···
1
1
package com.performancecoachlab.posedetection.camera
2
2
3
3
import android.graphics.Bitmap
4
4
+
import android.graphics.Canvas
5
5
+
import android.graphics.Matrix
6
6
+
import android.graphics.Paint
4
7
import androidx.annotation.OptIn
5
8
import androidx.camera.core.ExperimentalGetImage
6
9
import androidx.camera.core.ImageProxy
···
27
30
import kotlin.math.max
28
31
import kotlin.math.min
29
32
import kotlin.run
33
33
+
import androidx.core.graphics.createBitmap
30
34
31
35
actual enum class PlatformType {
32
36
ANDROID, IOS;
···
36
40
}
37
41
}
38
42
43
43
+
/**
 * Single-slot bitmap cache.
 *
 * Remembers the most recently created bitmap together with the width, height and
 * config it was requested with, and hands that same instance back while requests
 * keep asking for identical parameters.
 *
 * The object performs no synchronization of its own.
 * NOTE(review): presumably only reached from a single analysis thread — confirm.
 */
private object BitmapPool {
    private var slot: Bitmap? = null
    private var slotWidth: Int = 0
    private var slotHeight: Int = 0
    private var slotConfig: Bitmap.Config = Bitmap.Config.ARGB_8888

    /**
     * Returns a bitmap matching [width], [height] and [config].
     *
     * When the cached bitmap is still usable (not recycled) and was created with the
     * same parameters, it is cleared to transparent and returned again. Otherwise a
     * new bitmap is created, stored as the cached instance, and returned.
     */
    fun obtain(width: Int, height: Int, config: Bitmap.Config = Bitmap.Config.ARGB_8888): Bitmap {
        val reusable = slot?.takeIf { candidate ->
            !candidate.isRecycled &&
                slotWidth == width &&
                slotHeight == height &&
                slotConfig == config
        }
        if (reusable != null) {
            // Wipe the previous frame's pixels before handing the bitmap out again.
            reusable.eraseColor(android.graphics.Color.TRANSPARENT)
            return reusable
        }

        val fresh = createBitmap(width, height, config)
        slot = fresh
        slotWidth = width
        slotHeight = height
        slotConfig = config
        return fresh
    }
}
70
70
+
71
71
+
/**
 * Returns this bitmap rotated by [degrees].
 *
 * A rotation that is a whole multiple of 360 returns the receiver itself (no copy);
 * any other value produces a new, filtered bitmap via [Bitmap.createBitmap].
 */
private fun Bitmap.rotateToNew(degrees: Int): Bitmap {
    val isIdentity = degrees % 360 == 0
    if (isIdentity) return this

    val rotation = Matrix()
    rotation.postRotate(degrees.toFloat())
    return Bitmap.createBitmap(this, 0, 0, width, height, rotation, true)
}
76
76
+
77
77
+
/**
 * Draws [src] scaled to fill the whole of [dst], using bitmap filtering.
 *
 * [dst] is written in place; nothing is returned.
 */
private fun resizeInto(src: Bitmap, dst: Bitmap) {
    val canvas = Canvas(dst)
    val filteringPaint = Paint(Paint.FILTER_BITMAP_FLAG)
    // Fully qualified: the unqualified `Rect` in this file refers to a different type.
    val fullDestination = android.graphics.Rect(0, 0, dst.width, dst.height)
    canvas.drawBitmap(src, null, fullDestination, filteringPaint)
}
82
82
+
39
83
@OptIn(ExperimentalGetImage::class)
40
84
fun ImageProxy.process(
41
85
objectDetector: Interpreter?,
···
44
88
focusArea: Rect?,
45
89
onComplete: (AnalysisResult, Bitmap) -> Unit
46
90
) {
47
47
-
val bitmap = objectDetector?.let { interperter ->
48
48
-
val inputShape = interperter.getInputTensor(0)?.shape()
91
91
+
if (objectDetector == null && poseDetector == null) return
92
92
+
93
93
+
val rotationDegrees = imageInfo.rotationDegrees
94
94
+
95
95
+
// 1) Create ONE rotated bitmap in "analysis space" (used by MLKit + drawing + coordinate mapping).
96
96
+
val analysisBitmap: Bitmap = toBitmap().rotateToNew(rotationDegrees)
97
97
+
98
98
+
// 2) MLKit image must match analysisBitmap coordinate space. Rotation is now 0.
99
99
+
val mlKitImage: InputImage? = poseDetector?.let {
100
100
+
val masked = analysisBitmap.applyFocusAreaMask(focusArea, rotationDegrees)
101
101
+
InputImage.fromBitmap(masked, 0)
102
102
+
}
103
103
+
104
104
+
// 3) Tensor input: resize into pooled bitmap (avoid allocating each frame).
105
105
+
val processedTensorImage: TensorImage? = objectDetector?.let { interpreter ->
106
106
+
val inputShape = interpreter.getInputTensor(0)?.shape()
49
107
var tensorWidth = 0
50
108
var tensorHeight = 0
51
109
if (inputShape != null) {
52
110
tensorWidth = inputShape[1]
53
111
tensorHeight = inputShape[2]
54
54
-
55
55
-
// If in case input shape is in format of [1, 3, ..., ...]
56
112
if (inputShape[1] == 3) {
57
113
tensorWidth = inputShape[2]
58
114
tensorHeight = inputShape[3]
59
115
}
60
116
}
61
61
-
toBitmap().rotate(imageInfo.rotationDegrees.toFloat())
62
62
-
.scale(tensorWidth, tensorHeight, false)
63
63
-
}
64
64
-
if(bitmap == null){
65
65
-
return
66
66
-
}
117
117
+
if (tensorWidth <= 0 || tensorHeight <= 0) return@let null
67
118
119
119
+
val tensorBitmap = BitmapPool.obtain(tensorWidth, tensorHeight, Bitmap.Config.ARGB_8888)
120
120
+
resizeInto(analysisBitmap, tensorBitmap)
68
121
69
69
-
//val tensorImage = TensorImage.fromBitmap(bitmap)
70
70
-
val tensorImage = TensorImage(DataType.FLOAT32)
71
71
-
tensorImage.load(bitmap)
72
72
-
val processedImage = imageProcessor.process(tensorImage)
73
73
-
val mlKitImage = InputImage.fromBitmap(bitmap.applyFocusAreaMask(focusArea,imageInfo.rotationDegrees
74
74
-
), imageInfo.rotationDegrees)
122
122
+
TensorImage(DataType.FLOAT32).also { ti ->
123
123
+
ti.load(tensorBitmap)
124
124
+
}.let { ti ->
125
125
+
imageProcessor.process(ti)
126
126
+
}
127
127
+
}
128
128
+
129
129
+
// If no objectDetector, we still want pose results; if no poseDetector, we still want objects.
75
130
process(
76
76
-
tensorImage = processedImage,
131
131
+
tensorImage = processedTensorImage,
77
132
mlKitImage = mlKitImage,
78
133
objectDetector = objectDetector,
79
134
poseDetector = poseDetector,
80
135
timestamp = timestamp,
81
81
-
width = width,
82
82
-
height = height,
83
83
-
bitmap = bitmap,
136
136
+
width = analysisBitmap.width,
137
137
+
height = analysisBitmap.height,
138
138
+
bitmap = analysisBitmap,
84
139
onComplete = onComplete
85
140
)
86
141
}
87
142
143
143
+
/**
 * Runs the available detectors on one prepared frame and reports the combined result.
 *
 * Object detection runs synchronously when both [objectDetector] and [tensorImage] are
 * present; otherwise the object list is empty. Pose detection is started when both
 * [poseDetector] and [mlKitImage] are present. [onComplete] is invoked exactly once,
 * from the completion listener of the (possibly empty) set of pose tasks.
 *
 * An object-detector output with an unrecognised layout yields an empty object list
 * but does NOT prevent pose detection from running.
 *
 * @param tensorImage pre-processed detector input, or null to skip object detection.
 * @param mlKitImage pose-detector input, or null to skip pose detection.
 * @param objectDetector interpreter used for object detection, or null.
 * @param poseDetector ML Kit pose detector, or null.
 * @param timestamp forwarded unchanged to `skeleton(...)`.
 * @param width frame width used to scale box coordinates and build the skeleton.
 * @param height frame height used to scale box coordinates and build the skeleton.
 * @param bitmap passed back untouched as the second argument of [onComplete].
 * @param onComplete receives the analysis result and [bitmap].
 */
private fun process(
    tensorImage: TensorImage?,
    mlKitImage: InputImage?,
    objectDetector: Interpreter?,
    poseDetector: PoseDetector?,
    timestamp: Long,
    width: Int,
    height: Int,
    bitmap: Bitmap,
    onComplete: (AnalysisResult, Bitmap) -> Unit
) {
    val objectsDetected = if (objectDetector != null && tensorImage != null) {
        val outputShape = objectDetector.getOutputTensor(0).shape()
        val output = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)

        objectDetector.run(tensorImage.buffer, output.buffer)

        decodeDetections(output.floatArray, outputShape, width, height)
    } else {
        emptyList()
    }

    var skeleton: Skeleton? = null
    val poseDetectionTask = if (poseDetector != null && mlKitImage != null) {
        poseDetector.process(mlKitImage)
            .addOnSuccessListener { pose ->
                skeleton = skeleton(pose, timestamp, width, height)
            }
            // Best effort: on failure the skeleton simply stays null.
            .addOnFailureListener { }
    } else null

    // With no pose task the list is empty and only the completion listener remains to report the result.
    Tasks.whenAllComplete(listOfNotNull(poseDetectionTask)).addOnCompleteListener {
        onComplete(
            AnalysisResult(
                skeleton = skeleton,
                objects = objectsDetected
            ),
            bitmap
        )
    }
}

/**
 * Decodes a flat detector output into [AnalysisObject]s.
 *
 * Each detection carries six values, read by channel index as x1, y1, x2, y2,
 * confidence, class. Supported layouts are `[_, elements, 6]` and `[_, 6, elements]`;
 * any other shape returns an empty list. Detections with confidence at or below 0.25
 * are dropped. Box coordinates are clamped to 0..1 and then scaled by [width] and
 * [height] (presumably the model emits normalised coordinates — confirm).
 */
private fun decodeDetections(
    values: FloatArray,
    outputShape: IntArray,
    width: Int,
    height: Int
): List<AnalysisObject> {
    if (outputShape.size != 3) return emptyList()

    val dim1 = outputShape[1]
    val dim2 = outputShape[2]

    // Whichever dimension equals 6 is the per-detection channel axis.
    val isElementsFirst = when {
        dim2 == 6 -> true
        dim1 == 6 -> false
        else -> return emptyList()
    }
    val elements = if (isElementsFirst) dim1 else dim2
    val channels = if (isElementsFirst) dim2 else dim1

    fun valueAt(elementIndex: Int, channelIndex: Int): Float =
        if (isElementsFirst) {
            values[elementIndex * channels + channelIndex]
        } else {
            values[channelIndex * elements + elementIndex]
        }

    return (0 until elements).mapNotNull { i ->
        val cnf = valueAt(i, 4)
        if (cnf <= 0.25f) return@mapNotNull null

        val x1 = valueAt(i, 0)
        val y1 = valueAt(i, 1)
        val x2 = valueAt(i, 2)
        val y2 = valueAt(i, 3)
        val cls = valueAt(i, 5).toInt()

        // Order the corners, then clamp so boxes never leave the frame.
        val leftN = min(x1, x2).coerceIn(0f, 1f)
        val topN = min(y1, y2).coerceIn(0f, 1f)
        val rightN = max(x1, x2).coerceIn(0f, 1f)
        val bottomN = max(y1, y2).coerceIn(0f, 1f)

        AnalysisObject(
            boundingBox = Rect(
                left = leftN * width,
                top = topN * height,
                right = rightN * width,
                bottom = bottomN * height
            ),
            trackingId = 0,
            labels = listOf(com.performancecoachlab.posedetection.recording.Label("$cls", cnf)),
            frameSize = FrameSize(width = width.absoluteValue, height = height.absoluteValue)
        )
    }
}
241
241
+
88
242
private fun Rect?.toGraphicsRect(width: Int, height: Int):android.graphics.Rect {
89
243
return this?.let {
90
244
android.graphics.Rect((it.left*width).toInt(),
···
124
278
)
125
279
}
126
280
127
127
-
/**
128
128
-
* Crops the bitmap to the specified focus area rectangle
129
129
-
* @param focusArea The rectangle area to crop to (in normalized coordinates 0.0-1.0)
130
130
-
* @return A new bitmap cropped to the focus area, or the original bitmap if focusArea is null
131
131
-
*/
132
281
fun Bitmap.cropToFocusArea(focusArea: Rect?): Bitmap {
133
282
return focusArea?.let { rect ->
134
283
val left = (rect.left * width.toFloat()).toInt().coerceIn(0, width)
···
147
296
} ?: this
148
297
}
149
298
150
150
-
/**
151
151
-
* Creates a copy of the bitmap with everything outside the focus area blacked out
152
152
-
* @param focusArea The rectangle area to keep visible (in normalized coordinates 0.0-1.0)
153
153
-
* @param angle The rotation angle in degrees (must be a multiple of 90) to apply to the focus area rectangle
154
154
-
* @return A new bitmap with areas outside the focus area blacked out, or the original bitmap if focusArea is null
155
155
-
*/
156
156
-
fun Bitmap.applyFocusAreaMask(focusArea: Rect?, _angle: Int = 0): Bitmap {
299
299
+
300
300
+
fun Bitmap.applyFocusAreaMask(focusArea: Rect?, angle: Int = 0): Bitmap {
157
301
return focusArea?.let { rect ->
158
302
val result = this.copy(this.config ?: Bitmap.Config.ARGB_8888, true)
159
159
-
val canvas = android.graphics.Canvas(result)
160
160
-
val paint = android.graphics.Paint().apply {
303
303
+
val canvas = Canvas(result)
304
304
+
val paint = Paint().apply {
161
305
color = android.graphics.Color.BLACK
162
306
}
163
163
-
164
164
-
val angle = 0
165
307
// Transform the rectangle coordinates based on the angle
166
308
val transformedRect = when (angle % 360) {
167
309
90 -> Rect(
···
209
351
210
352
result
211
353
} ?: this
212
212
-
}
213
213
-
214
214
-
private fun process(
215
215
-
tensorImage: TensorImage,
216
216
-
mlKitImage: InputImage?,
217
217
-
objectDetector: Interpreter?,
218
218
-
poseDetector: PoseDetector?,
219
219
-
timestamp: Long,
220
220
-
width: Int,
221
221
-
height: Int,
222
222
-
bitmap: Bitmap,
223
223
-
onComplete: (AnalysisResult, Bitmap) -> Unit
224
224
-
) {
225
225
-
Logger.d{"Processing image of size: ${tensorImage.width}x${tensorImage.height}" }
226
226
-
val objectsDetected = objectDetector?.let { interpreter ->
227
227
-
val outputTensor = interpreter.getOutputTensor(0)
228
228
-
val outputShape = outputTensor.shape() // e.g. [1, 180, 6] or [1, 6, 180]
229
229
-
val output = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)
230
230
-
231
231
-
interpreter.run(tensorImage.buffer, output.buffer)
232
232
-
233
233
-
val array = output.floatArray
234
234
-
if (outputShape.size != 3) return@let emptyList<AnalysisObject>()
235
235
-
236
236
-
val dim1 = outputShape[1]
237
237
-
val dim2 = outputShape[2]
238
238
-
239
239
-
// We expect 6 values per detection: x1,y1,x2,y2,cnf,cls
240
240
-
// So whichever dimension equals 6 is the "channels".
241
241
-
val channels: Int
242
242
-
val elements: Int
243
243
-
val isElementsFirst: Boolean // true if shape is [1, elements, channels]
244
244
-
245
245
-
when {
246
246
-
dim2 == 6 -> {
247
247
-
// [1, elements, 6]
248
248
-
elements = dim1
249
249
-
channels = dim2
250
250
-
isElementsFirst = true
251
251
-
}
252
252
-
dim1 == 6 -> {
253
253
-
// [1, 6, elements]
254
254
-
channels = dim1
255
255
-
elements = dim2
256
256
-
isElementsFirst = false
257
257
-
}
258
258
-
else -> {
259
259
-
// Unknown layout; bail out rather than silently producing 0 detections.
260
260
-
return@let emptyList<AnalysisObject>()
261
261
-
}
262
262
-
}
263
263
-
264
264
-
Logger.d{"Processing objects: ${elements}" }
265
265
-
266
266
-
fun valueAt(elementIndex: Int, channelIndex: Int): Float {
267
267
-
return if (isElementsFirst) {
268
268
-
// base = elementIndex * channels + channelIndex
269
269
-
array[elementIndex * channels + channelIndex]
270
270
-
} else {
271
271
-
// base = channelIndex * elements + elementIndex
272
272
-
array[channelIndex * elements + elementIndex]
273
273
-
}
274
274
-
}
275
275
-
276
276
-
(0 until elements).mapNotNull { i ->
277
277
-
val cnf = valueAt(i, 4)
278
278
-
if (cnf > 0.25f) {
279
279
-
val x1 = valueAt(i, 0)
280
280
-
val y1 = valueAt(i, 1)
281
281
-
val x2 = valueAt(i, 2)
282
282
-
val y2 = valueAt(i, 3)
283
283
-
val cls = valueAt(i, 5).toInt()
284
284
-
285
285
-
val leftN = min(x1, x2)
286
286
-
val topN = min(y1, y2)
287
287
-
val rightN = max(x1, x2)
288
288
-
val bottomN = max(y1, y2)
289
289
-
290
290
-
val leftPx = leftN * width.absoluteValue
291
291
-
val topPx = topN * height.absoluteValue
292
292
-
val rightPx = rightN * width.absoluteValue
293
293
-
val bottomPx = bottomN * height.absoluteValue
294
294
-
295
295
-
296
296
-
AnalysisObject(
297
297
-
boundingBox = Rect(left = leftPx, top = topPx, right = rightPx, bottom = bottomPx),
298
298
-
trackingId = 0,
299
299
-
labels = listOf(com.performancecoachlab.posedetection.recording.Label("$cls", cnf)),
300
300
-
frameSize = FrameSize(width = width.absoluteValue, height = height.absoluteValue)
301
301
-
)
302
302
-
} else null
303
303
-
}
304
304
-
} ?: emptyList()
305
305
-
306
306
-
Logger.d{"Processed objecs size: ${objectsDetected.size}" }
307
307
-
/*
308
308
-
val objectsDetected = objectDetector?.detect(tensorImage)?.map { result ->
309
309
-
AnalysisObject(
310
310
-
boundingBox = result.boundingBox.let {
311
311
-
Rect(
312
312
-
left = it.left,
313
313
-
top = it.top,
314
314
-
right = it.right,
315
315
-
bottom = it.bottom
316
316
-
)
317
317
-
},
318
318
-
trackingId = 0,
319
319
-
labels = result.categories.map { category ->
320
320
-
com.performancecoachlab.posedetection.recording.Label(
321
321
-
category.label,
322
322
-
category.score
323
323
-
)
324
324
-
},
325
325
-
frameSize = FrameSize(
326
326
-
width = width.absoluteValue,
327
327
-
height = height.absoluteValue
328
328
-
)
329
329
-
)
330
330
-
} ?: emptyList()*/
331
331
-
var skeleton: Skeleton? = null
332
332
-
val poseDetectionTask = mlKitImage?.let {
333
333
-
val rotation = it.rotationDegrees
334
334
-
poseDetector?.process(it)?.addOnSuccessListener { pose ->
335
335
-
skeleton = skeleton(pose, timestamp, width, height).let{
336
336
-
it.rotate(rotation)
337
337
-
}
338
338
-
}?.addOnFailureListener { e ->
339
339
-
//println(e)
340
340
-
}
341
341
-
}
342
342
-
Tasks.whenAllComplete(listOfNotNull(poseDetectionTask)).addOnCompleteListener {
343
343
-
onComplete(
344
344
-
AnalysisResult(
345
345
-
skeleton = skeleton,
346
346
-
objects = objectsDetected
347
347
-
),
348
348
-
bitmap
349
349
-
)
350
350
-
}
351
354
}
352
355
353
356
data class BoundingBox(