This repository has no description
1package com.performancecoachlab.posedetection.custom
2
3import android.graphics.Bitmap
4import android.graphics.Canvas
5import android.graphics.Color
6import android.graphics.Paint
7import android.graphics.RectF
8import androidx.compose.ui.geometry.Rect
9import androidx.compose.ui.graphics.ImageBitmap
10import androidx.compose.ui.graphics.asAndroidBitmap
11import co.touchlab.kermit.Logger
12import com.performancecoachlab.posedetection.camera.label
13import com.performancecoachlab.posedetection.recording.AnalysisObject
14import com.performancecoachlab.posedetection.recording.FrameSize
15import com.performancecoachlab.posedetection.recording.Label
16import org.tensorflow.lite.DataType
17import org.tensorflow.lite.support.common.ops.CastOp
18import org.tensorflow.lite.support.common.ops.NormalizeOp
19import org.tensorflow.lite.support.image.ImageProcessor
20import org.tensorflow.lite.support.image.TensorImage
21import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
22import kotlin.math.absoluteValue
23import kotlin.math.max
24import kotlin.math.min
25import kotlin.math.roundToInt
26
/**
 * Android implementation of the multiplatform object detector.
 *
 * Wraps the TensorFlow Lite detector obtained from `model.getDetector()` and converts its
 * raw output into [AnalysisObject]s expressed in the pixel space of the input image.
 * If the model yields no detector, every call to [detect] returns an empty list.
 */
actual class ImageDetector actual constructor(model: ObjectModel) {

    private val detector: AndroidDetector? = model.getDetector()

    // Maps 0..255 pixel values to 0..1 and makes sure the tensor is FLOAT32.
    private val imageProcessor = ImageProcessor.Builder()
        .add(NormalizeOp(0f, 255f))
        .add(CastOp(DataType.FLOAT32))
        .build()

    /**
     * Runs object detection on [image].
     *
     * The image is letterboxed into the model's input size, normalized to [0, 1] and fed to
     * the interpreter. The output tensor must have rank 3 with one axis of size [CHANNELS]
     * (either `[batch, N, 6]` or `[batch, 6, N]`); each detection holds two box corners
     * (channels 0-3), a confidence (channel 4) and a class index (channel 5). Only the first
     * batch entry is read.
     *
     * @return detections whose confidence exceeds [CONFIDENCE_THRESHOLD], with bounding boxes
     *   in the original image's pixel coordinates, clamped to its bounds. Boxes that collapse
     *   to zero width/height after clamping (i.e. lie entirely in the letterbox padding) or
     *   are non-finite are dropped. Returns an empty list when there is no detector, the model
     *   reports a non-positive input size, the output shape is unsupported, the bitmap cannot
     *   be converted to ARGB_8888, or the interpreter throws (logged).
     *   `trackingId` and `timestamp` are left at 0 — presumably filled in by the caller.
     */
    actual fun detect(image: ImageBitmap): List<AnalysisObject> {
        val det = detector ?: return emptyList()
        val info = det.modelInfo
        val inputW = info.inputWidth
        val inputH = info.inputHeight
        if (inputW <= 0 || inputH <= 0) return emptyList()

        // Validate the output layout before doing any work: an unparseable shape would
        // otherwise cost a full preprocessing + inference pass for nothing.
        val outputShape = info.outputShape
        if (outputShape.size != 3) return emptyList()
        val dim1 = outputShape[1]
        val dim2 = outputShape[2]
        val isElementsFirst = when {
            dim2 == CHANNELS -> true
            dim1 == CHANNELS -> false
            else -> return emptyList()
        }
        val elements = if (isElementsFirst) dim1 else dim2
        val output = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)

        val original = image.asAndroidBitmap()
        val src: Bitmap = if (original.config == Bitmap.Config.ARGB_8888) {
            original
        } else {
            // copy() returns null when the conversion is unsupported; bail out instead of
            // crashing with an NPE further down.
            original.copy(Bitmap.Config.ARGB_8888, false) ?: run {
                Logger.e { "ImageDetector: could not convert bitmap (${original.config}) to ARGB_8888" }
                return emptyList()
            }
        }
        val imgW = src.width
        val imgH = src.height

        val letterboxed = try {
            letterbox(src, inputW, inputH)
        } finally {
            // Only recycle the temporary copy made above, never the caller's bitmap.
            if (src !== original) src.recycle()
        }

        val inferred = try {
            runInference(det, letterboxed.bitmap, output)
        } finally {
            // Inference has completed (or failed) by now; free the per-call bitmap eagerly.
            letterboxed.bitmap.recycle()
        }
        if (!inferred) return emptyList()

        val values = output.floatArray
        fun valueAt(element: Int, channel: Int): Float =
            if (isElementsFirst) values[element * CHANNELS + channel]
            else values[channel * elements + element]

        val imgWF = imgW.toFloat()
        val imgHF = imgH.toFloat()

        return (0 until elements).mapNotNull { i ->
            // `>` (rather than a negated `<=`) also rejects NaN confidences.
            val confidence = valueAt(i, 4)
            if (confidence > CONFIDENCE_THRESHOLD) {
                val x1 = valueAt(i, 0)
                val y1 = valueAt(i, 1)
                val x2 = valueAt(i, 2)
                val y2 = valueAt(i, 3)
                val cls = valueAt(i, 5).toInt()

                // Box corners are treated as normalized [0, 1] over the letterboxed input;
                // min/max makes the result independent of corner order.
                val left = letterboxed.sourceX(min(x1, x2)).coerceIn(0f, imgWF)
                val top = letterboxed.sourceY(min(y1, y2)).coerceIn(0f, imgHF)
                val right = letterboxed.sourceX(max(x1, x2)).coerceIn(0f, imgWF)
                val bottom = letterboxed.sourceY(max(y1, y2)).coerceIn(0f, imgHF)

                // A box lying entirely in the padding collapses onto the image edge; drop it.
                if (right > left && bottom > top) {
                    AnalysisObject(
                        boundingBox = Rect(
                            left = left,
                            top = top,
                            right = right,
                            bottom = bottom,
                        ),
                        trackingId = 0,
                        labels = listOf(Label(info.label(cls), confidence)),
                        frameSize = FrameSize(width = imgW.absoluteValue, height = imgH.absoluteValue),
                        timestamp = 0L,
                    )
                } else {
                    null
                }
            } else {
                null
            }
        }
    }

    /**
     * Scales [src] to fit inside [inputW] × [inputH] with its aspect ratio preserved,
     * centers it, and pads the remainder with gray [PAD_GRAY] — matching ultralytics
     * training-time preprocessing. (A naive stretch-resize distorts non-square camera
     * frames and hurts detection quality.)
     *
     * The returned bitmap is newly allocated and owned by the caller.
     */
    private fun letterbox(src: Bitmap, inputW: Int, inputH: Int): Letterbox {
        val fit = min(inputW.toFloat() / src.width, inputH.toFloat() / src.height)
        val scaledW = (src.width * fit).roundToInt().coerceAtLeast(1)
        val scaledH = (src.height * fit).roundToInt().coerceAtLeast(1)
        val padX = (inputW - scaledW) / 2f
        val padY = (inputH - scaledH) / 2f

        val canvasBitmap = Bitmap.createBitmap(inputW, inputH, Bitmap.Config.ARGB_8888)
        Canvas(canvasBitmap).apply {
            drawColor(Color.rgb(PAD_GRAY, PAD_GRAY, PAD_GRAY))
            drawBitmap(
                src,
                null,
                RectF(padX, padY, padX + scaledW, padY + scaledH),
                Paint(Paint.FILTER_BITMAP_FLAG),
            )
        }
        // Use the scale actually applied per axis (after rounding to whole pixels) so that
        // mapping boxes back lands exactly on the drawn image.
        return Letterbox(
            bitmap = canvasBitmap,
            inputW = inputW,
            inputH = inputH,
            padX = padX,
            padY = padY,
            scaleX = scaledW.toFloat() / src.width,
            scaleY = scaledH.toFloat() / src.height,
        )
    }

    /**
     * Converts [input] into a normalized FLOAT32 tensor and runs the interpreter, writing
     * the result into [output]. Returns false (after logging) if the interpreter throws.
     */
    private fun runInference(det: AndroidDetector, input: Bitmap, output: TensorBuffer): Boolean {
        val tensorImage = TensorImage(DataType.FLOAT32)
            .also { it.load(input) }
            .let(imageProcessor::process)
        return try {
            det.interpreter.run(tensorImage.buffer, output.buffer)
            true
        } catch (e: Exception) {
            Logger.e(e) { "ImageDetector: interpreter.run failed" }
            false
        }
    }

    /** A letterboxed model input plus the geometry needed to map outputs back to the source. */
    private class Letterbox(
        val bitmap: Bitmap,
        private val inputW: Int,
        private val inputH: Int,
        private val padX: Float,
        private val padY: Float,
        private val scaleX: Float,
        private val scaleY: Float,
    ) {
        /** Maps an x normalized over the letterboxed input to a source-image pixel x (unclamped). */
        fun sourceX(normalized: Float): Float = (normalized * inputW - padX) / scaleX

        /** Maps a y normalized over the letterboxed input to a source-image pixel y (unclamped). */
        fun sourceY(normalized: Float): Float = (normalized * inputH - padY) / scaleY
    }

    private companion object {
        /** Values per detection: x1, y1, x2, y2, confidence, class index. */
        const val CHANNELS = 6

        /** Detections with confidence at or below this value are discarded. */
        const val CONFIDENCE_THRESHOLD = 0.25f

        /** Gray level used for letterbox padding (ultralytics convention). */
        const val PAD_GRAY = 114
    }
}