alpha
Login
or
Join now
nateholland.bsky.social
/
PoseDetection
Star
0
Fork
0
Atom
Configure Feed
Issues
Pull Requests
Commits
Tags
Feed URL
Select the types of activity you want to include in your feed.
This repository has no description
Star
0
Fork
0
Atom
Configure Feed
Issues
Pull Requests
Commits
Tags
Feed URL
Select the types of activity you want to include in your feed.
Overview
Issues
Pulls
Pipelines
feat: get labels from metadata
author
nate
date
4 months ago
(Feb 5, 2026, 12:57 PM +0200)
commit
28c66ab9
28c66ab92d49833634829a2e2a651da206d7b923
parent
9ff107e5
9ff107e5cd8362d0fc475d622ca8b44b6c271115
+155
-87
6 changed files
Expand all
Collapse all
Unified
Split
posedetection
build.gradle.kts
src
androidMain
kotlin
com
performancecoachlab
posedetection
camera
Utils.android.kt
custom
CustomObjectModel.android.kt
recording
InputFrame.android.kt
com.performancecoachlab
posedetection
camera
CameraView.android.kt
sample
composeApp
src
commonMain
kotlin
com
nate
posedetection
App.kt
+1
posedetection/build.gradle.kts
Reviewed
···
93
93
//implementation(libs.tensorflow.lite.task.vision)
94
94
implementation("com.google.ai.edge.litert:litert:1.4.1")
95
95
implementation("com.google.ai.edge.litert:litert-support:1.4.1")
96
96
+
implementation("com.google.ai.edge.litert:litert-metadata:1.4.1")
96
97
}
97
98
98
99
}
+8
-1
posedetection/src/androidMain/kotlin/com.performancecoachlab/posedetection/camera/CameraView.android.kt
Reviewed
···
66
66
import androidx.camera.camera2.interop.ExperimentalCamera2Interop
67
67
import androidx.camera.core.CameraInfo
68
68
import androidx.camera.core.CameraSelector
69
69
+
import androidx.compose.ui.text.rememberTextMeasurer
69
70
import co.touchlab.kermit.Logger
70
71
import java.util.concurrent.atomic.AtomicBoolean
71
72
import java.util.concurrent.atomic.AtomicLong
···
157
158
var focus by remember { mutableStateOf(focusArea) }
158
159
var objectDetector by remember { mutableStateOf(objectModel) }
159
160
var currentDetectMode by remember { mutableStateOf(detectMode) }
161
161
+
162
162
+
val textMeasurer = rememberTextMeasurer()
160
163
LaunchedEffect(detectMode) { currentDetectMode = detectMode }
161
164
162
165
// Update focus when focusArea changes
···
435
438
)
436
439
437
440
DrawableShape.LABEL -> {
438
438
-
// no-op
441
441
+
drawLabelTextPlatform(
442
442
+
drawableObject = d.copy(obj = d.obj.copy(boundingBox = Rect(offset = Offset(left, top),
443
443
+
size = Size(w, h),))),
444
444
+
textMeasurer = textMeasurer
445
445
+
)
439
446
}
440
447
}
441
448
}
+24
-53
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/camera/Utils.android.kt
Reviewed
···
16
16
import com.performancecoachlab.posedetection.recording.FrameSize
17
17
import com.performancecoachlab.posedetection.skeleton.Skeleton
18
18
import org.tensorflow.lite.DataType
19
19
-
import org.tensorflow.lite.Interpreter
20
19
import org.tensorflow.lite.support.common.ops.CastOp
21
20
import org.tensorflow.lite.support.common.ops.NormalizeOp
22
21
import org.tensorflow.lite.support.image.ImageProcessor
23
22
import org.tensorflow.lite.support.image.TensorImage
24
23
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
25
25
-
import posedetection.posedetection.generated.resources.Res
26
24
import kotlin.math.absoluteValue
27
27
-
import androidx.core.graphics.scale
28
25
import co.touchlab.kermit.Logger
29
29
-
import kotlin.collections.get
30
26
import kotlin.math.max
31
27
import kotlin.math.min
32
32
-
import kotlin.run
33
28
import androidx.core.graphics.createBitmap
29
29
+
import com.performancecoachlab.posedetection.custom.AndroidDetector
30
30
+
import com.performancecoachlab.posedetection.custom.ModelInfo
34
31
import java.nio.ByteBuffer
35
32
36
33
actual enum class PlatformType {
···
166
163
// For OUTPUT_IMAGE_FORMAT_RGBA_8888 this should be 4.
167
164
if (pixelStride != 4) {
168
165
// Fallback: still attempt a direct copy into a fresh bitmap.
169
169
-
return Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888).also { it.copyPixelsFromBuffer(buffer) }
166
166
+
return createBitmap(width, height).also { it.copyPixelsFromBuffer(buffer) }
170
167
}
171
168
172
169
val bmp = AnalysisBitmapPool.obtain(width, height, Bitmap.Config.ARGB_8888)
···
194
191
195
192
@OptIn(ExperimentalGetImage::class)
196
193
fun ImageProxy.process(
197
197
-
objectDetector: Interpreter?,
194
194
+
objectDetector: AndroidDetector?,
198
195
poseDetector: PoseDetector?,
199
196
timestamp: Long,
200
197
focusArea: Rect?,
···
222
219
}
223
220
224
221
// 3) Tensor input: resize into pooled bitmap (avoid allocating each frame).
225
225
-
val processedTensorImage: TensorImage? = objectDetector?.let { interpreter ->
226
226
-
val inputShape = interpreter.getInputTensor(0)?.shape()
227
227
-
var tensorWidth = 0
228
228
-
var tensorHeight = 0
229
229
-
if (inputShape != null) {
230
230
-
tensorWidth = inputShape[1]
231
231
-
tensorHeight = inputShape[2]
232
232
-
if (inputShape[1] == 3) {
233
233
-
tensorWidth = inputShape[2]
234
234
-
tensorHeight = inputShape[3]
235
235
-
}
236
236
-
}
222
222
+
val processedTensorImage: TensorImage? = objectDetector?.modelInfo?.let { interpreter ->
223
223
+
val tensorWidth = interpreter.inputWidth
224
224
+
val tensorHeight = interpreter.inputHeight
237
225
if (tensorWidth <= 0 || tensorHeight <= 0) return@let null
238
226
239
227
val tensorBitmap = BitmapPool.obtain(tensorWidth, tensorHeight, Bitmap.Config.ARGB_8888)
···
263
251
fun process(
264
252
tensorImage: TensorImage?,
265
253
mlKitImage: InputImage?,
266
266
-
objectDetector: Interpreter?,
254
254
+
objectDetector: AndroidDetector?,
267
255
poseDetector: PoseDetector?,
268
256
timestamp: Long,
269
257
width: Int,
···
272
260
onComplete: (AnalysisResult, Bitmap) -> Unit
273
261
) {
274
262
val objectsDetected = if (objectDetector != null && tensorImage != null) {
275
275
-
val outputTensor = objectDetector.getOutputTensor(0)
276
276
-
val outputShape = outputTensor.shape()
263
263
+
val outputShape = objectDetector.modelInfo.outputShape
277
264
val output = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32)
278
265
279
279
-
objectDetector.run(tensorImage.buffer, output.buffer)
266
266
+
objectDetector.interpreter.run(tensorImage.buffer, output.buffer)
280
267
281
268
val array = output.floatArray
282
269
if (outputShape.size != 3) emptyList() else {
···
328
315
val rightPx = rightN * width
329
316
val bottomPx = bottomN * height
330
317
318
318
+
val label = objectDetector.modelInfo.label(cls)
319
319
+
331
320
AnalysisObject(
332
321
boundingBox = Rect(left = leftPx, top = topPx, right = rightPx, bottom = bottomPx),
333
322
trackingId = 0,
334
334
-
labels = listOf(com.performancecoachlab.posedetection.recording.Label("$cls", cnf)),
323
323
+
labels = listOf(com.performancecoachlab.posedetection.recording.Label(label, cnf)),
335
324
frameSize = FrameSize(width = width.absoluteValue, height = height.absoluteValue)
336
325
)
337
326
} else null
···
373
362
.add(CastOp(DataType.FLOAT32))
374
363
.build()
375
364
376
376
-
fun Bitmap.process(
377
377
-
objectDetector: Interpreter?,
378
378
-
poseDetector: PoseDetector,
379
379
-
timestamp: Long,
380
380
-
focusArea: Rect?,
381
381
-
onComplete: (AnalysisResult, Bitmap) -> Unit
382
382
-
) {
383
383
-
//val tensorImage = TensorImage.fromBitmap(this)
384
384
-
val tensorImage = TensorImage(DataType.FLOAT32)
385
385
-
tensorImage.load(this)
386
386
-
val processedImage = imageProcessor.process(tensorImage)
387
387
-
val mlKitImage = InputImage.fromBitmap(this.applyFocusAreaMask(focusArea), 0)
388
388
-
process(
389
389
-
tensorImage = processedImage,
390
390
-
mlKitImage = mlKitImage,
391
391
-
objectDetector = objectDetector,
392
392
-
poseDetector = poseDetector,
393
393
-
timestamp = timestamp,
394
394
-
width = width,
395
395
-
height = height,
396
396
-
bitmap = this,
397
397
-
onComplete = onComplete
398
398
-
)
399
399
-
}
400
400
-
401
365
fun Bitmap.cropToFocusArea(focusArea: Rect?): Bitmap {
402
366
return focusArea?.let { rect ->
403
367
val left = (rect.left * width.toFloat()).toInt().coerceIn(0, width)
···
536
500
val clsName: String
537
501
)
538
502
539
539
-
540
540
-
541
541
-
542
542
-
503
503
+
fun ModelInfo.label(cls: Int): String{
504
504
+
Logger.d{ "ModelInfo.label: cls=$cls" }
505
505
+
this.labels.let { labelsList ->
506
506
+
if (cls in labelsList.indices) {
507
507
+
Logger.d{ "ModelInfo.label: $cls = ${labelsList[cls]}" }
508
508
+
return labelsList[cls]
509
509
+
}
510
510
+
}
511
511
+
Logger.d{ "ModelInfo.label: $cls not in labels" }
512
512
+
return "$cls"
513
513
+
}
+112
-13
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/custom/CustomObjectModel.android.kt
Reviewed
···
2
2
3
3
import androidx.compose.runtime.Composable
4
4
import androidx.compose.ui.platform.LocalContext
5
5
+
import co.touchlab.kermit.Logger
5
6
import org.tensorflow.lite.Interpreter
6
7
import org.tensorflow.lite.support.common.FileUtil
8
8
+
import org.tensorflow.lite.support.metadata.MetadataExtractor
9
9
+
import java.nio.MappedByteBuffer
7
10
8
11
@Composable
9
12
actual fun initialiseObjectModel(modelPath: ModelPath): ObjectModel {
10
10
-
if(modelPath.androidModelPath == null){
13
13
+
if (modelPath.androidModelPath == null) {
11
14
throw IllegalArgumentException("Android model path cannot be null")
12
15
}
13
13
-
val options = Interpreter.Options().apply{
16
16
+
val options = Interpreter.Options().apply {
14
17
this.setNumThreads(4)
15
18
}
16
19
17
20
val model = FileUtil.loadMappedFile(LocalContext.current, modelPath.androidModelPath)
21
21
+
val labels = labels(model)
18
22
val interpreter = Interpreter(model, options)
19
23
20
24
val inputShape = interpreter.getInputTensor(0)?.shape()
21
25
val outputShape = interpreter.getOutputTensor(0)?.shape()
26
26
+
val modelInfo = ModelInfo.fromShapes(
27
27
+
inputShape = inputShape
28
28
+
?: throw IllegalArgumentException("Invalid model: input shape is null"),
29
29
+
outputShape = outputShape
30
30
+
?: throw IllegalArgumentException("Invalid model: output shape is null"),
31
31
+
labels,
32
32
+
)
33
33
+
val androidDetector = AndroidDetector(
34
34
+
interpreter = interpreter, modelInfo = modelInfo
35
35
+
)
36
36
+
return ObjectModel(androidDetector)
37
37
+
}
22
38
23
23
-
return ObjectModel(interpreter)
39
39
+
fun labels(model: MappedByteBuffer): List<String> {
40
40
+
return runCatching {
41
41
+
val extractor = MetadataExtractor(model)
42
42
+
43
43
+
// Ultralytics exports often store this as a .txt "metadata" file containing Python-dict text.
44
44
+
val metaFile = extractor.associatedFileNames.orEmpty()
45
45
+
.firstOrNull { it.endsWith(".txt", ignoreCase = true) || it.endsWith(".json", ignoreCase = true) }
46
46
+
?: return@runCatching emptyList()
47
47
+
48
48
+
val text = extractor.getAssociatedFile(metaFile)
49
49
+
.bufferedReader()
50
50
+
.use { it.readText() }
51
51
+
52
52
+
parseUltralyticsNames(text)
53
53
+
}.onFailure { t ->
54
54
+
Logger.w(t) { "Failed to load labels from TFLite metadata" }
55
55
+
}.getOrDefault(emptyList())
56
56
+
}
57
57
+
58
58
+
private fun parseUltralyticsNames(text: String): List<String> {
59
59
+
// Grab the `names: { ... }` section (non-greedy) to avoid matching other maps.
60
60
+
val namesBlock = Regex("""['"]names['"]\s*:\s*\{([\s\S]*?)\}""")
61
61
+
.find(text)
62
62
+
?.groupValues
63
63
+
?.getOrNull(1)
64
64
+
?: return emptyList()
65
65
+
66
66
+
// Match entries like: 0: 'person' OR 0: "person"
67
67
+
val entry = Regex("""(\d+)\s*:\s*['"]([^'"]+)['"]""")
68
68
+
val pairs = entry.findAll(namesBlock).mapNotNull { m ->
69
69
+
val idx = m.groupValues[1].toIntOrNull() ?: return@mapNotNull null
70
70
+
val name = m.groupValues[2].trim()
71
71
+
if (name.isBlank()) null else idx to name
72
72
+
}.toList()
73
73
+
74
74
+
if (pairs.isEmpty()) return emptyList()
75
75
+
76
76
+
val maxIdx = pairs.maxOf { it.first }
77
77
+
val out = MutableList(maxIdx + 1) { "" }
78
78
+
for ((i, label) in pairs) {
79
79
+
if (i in out.indices) out[i] = label
80
80
+
}
81
81
+
return out.filter { it.isNotBlank() }
24
82
}
25
83
26
26
-
actual class ObjectModel{
27
27
-
/*private var detector: org.tensorflow.lite.task.vision.detector.ObjectDetector? = null
84
84
+
data class ModelInfo(
85
85
+
val inputShape: IntArray,
86
86
+
val outputShape: IntArray,
87
87
+
val inputWidth: Int,
88
88
+
val inputHeight: Int,
89
89
+
val inputChannels: Int? = null,
90
90
+
val isNhwc: Boolean? = null,
91
91
+
val labels: List<String>,
92
92
+
) {
93
93
+
companion object {
94
94
+
fun fromShapes(inputShape: IntArray, outputShape: IntArray, labels: List<String>): ModelInfo {
95
95
+
// Common TFLite image shapes:
96
96
+
// NHWC: [1, H, W, C]
97
97
+
// NCHW: [1, C, H, W] (less common on Android)
98
98
+
// Some models might be [H, W, C] (no batch).
99
99
+
val (h, w, c, nhwc) = when (inputShape.size) {
100
100
+
4 -> {
101
101
+
val isNhwcGuess = inputShape[3] in 1..4
102
102
+
if (isNhwcGuess) {
103
103
+
Quad(inputShape[1], inputShape[2], inputShape[3], true)
104
104
+
} else {
105
105
+
Quad(inputShape[2], inputShape[3], inputShape[1], false)
106
106
+
}
107
107
+
}
28
108
29
29
-
constructor(detector: org.tensorflow.lite.task.vision.detector.ObjectDetector){
30
30
-
this.detector = detector
109
109
+
3 -> Quad(inputShape[0], inputShape[1], inputShape[2], true)
110
110
+
else -> Quad(0, 0, null, null)
111
111
+
}
112
112
+
113
113
+
return ModelInfo(
114
114
+
inputShape = inputShape,
115
115
+
outputShape = outputShape,
116
116
+
inputWidth = w,
117
117
+
inputHeight = h,
118
118
+
inputChannels = c,
119
119
+
isNhwc = nhwc,
120
120
+
labels = labels,
121
121
+
)
122
122
+
}
31
123
}
124
124
+
}
32
125
33
33
-
fun getDetector(): org.tensorflow.lite.task.vision.detector.ObjectDetector? {
34
34
-
return detector
35
35
-
}*/
126
126
+
private data class Quad(
127
127
+
val h: Int, val w: Int, val c: Int?, val nhwc: Boolean?
128
128
+
)
129
129
+
130
130
+
data class AndroidDetector(
131
131
+
val interpreter: Interpreter, val modelInfo: ModelInfo
132
132
+
)
133
133
+
134
134
+
actual class ObjectModel {
36
135
37
37
-
private var detector: Interpreter? = null
136
136
+
private var detector: AndroidDetector? = null
38
137
39
39
-
constructor(detector: Interpreter){
138
138
+
constructor(detector: AndroidDetector) {
40
139
this.detector = detector
41
140
}
42
141
43
43
-
fun getDetector(): Interpreter? {
142
142
+
fun getDetector(): AndroidDetector? {
44
143
return detector
45
144
}
46
145
}
+5
-15
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/recording/InputFrame.android.kt
Reviewed
···
24
24
import kotlin.math.max
25
25
import androidx.core.graphics.createBitmap
26
26
import co.touchlab.kermit.Logger
27
27
-
import com.google.android.gms.tasks.Tasks
28
27
import com.performancecoachlab.posedetection.camera.process
29
28
import kotlinx.coroutines.Dispatchers
30
29
import kotlinx.coroutines.launch
···
119
118
val mlKitImage = InputImage.fromBitmap(masked, 0)
120
119
121
120
// Object: resize to model input size, then normalize.
122
122
-
val tensorImage: TensorImage? = interpreter?.let { tfl ->
123
123
-
val shape = tfl.getInputTensor(0)?.shape()
124
124
-
var w = 0
125
125
-
var h = 0
126
126
-
if (shape != null) {
127
127
-
w = shape[1]
128
128
-
h = shape[2]
129
129
-
if (shape[1] == 3) {
130
130
-
w = shape[2]
131
131
-
h = shape[3]
132
132
-
}
133
133
-
}
121
121
+
val tensorImage: TensorImage? = interpreter?.modelInfo?.let { tfl ->
122
122
+
val w = tfl.inputWidth
123
123
+
val h = tfl.inputHeight
124
124
+
134
125
if (w <= 0 || h <= 0) null else {
135
126
val dst = ensureTensorBitmap(w, h)
136
136
-
// Draw scaled into dst (no new allocations).
137
137
-
android.graphics.Canvas(dst).drawBitmap(
127
127
+
Canvas(dst).drawBitmap(
138
128
inputFrame.bitmap,
139
129
null,
140
130
android.graphics.Rect(0, 0, w, h),
+5
-5
sample/composeApp/src/commonMain/kotlin/com/nate/posedetection/App.kt
Reviewed
···
80
80
import kotlin.time.Clock
81
81
import kotlin.time.ExperimentalTime
82
82
83
83
-
val androidPath = "basketballs_n1.tflite"
83
83
+
val androidPath = "yolov10n_float16.tflite"
84
84
val iosPath = "basketballs_n1"
85
85
@Composable
86
86
internal fun App() = AppTheme {
···
406
406
}
407
407
Button(
408
408
onClick = {
409
409
-
frontCamera = !frontCamera
409
409
+
recordingId = "${Clock.System.now().epochSeconds}"
410
410
},
411
411
modifier = Modifier.imePadding().padding(16.dp).align(Alignment.TopStart)
412
412
) {
···
431
431
rightKnee = Pose.PoseRange(160.0, 180.0)
432
432
)
433
433
skeleton?.let {
434
434
-
if (upRightPose.matches(it)) Text("Standing", fontSize = 80.sp)
435
435
-
else if (starPose.matches(it)) Text("Star", fontSize = 80.sp)
436
436
-
else Text("No Pose Detected", fontSize = 80.sp)
434
434
+
if (upRightPose.matches(it)) Text("Standing", fontSize = 15.sp)
435
435
+
else if (starPose.matches(it)) Text("Star", fontSize = 15.sp)
436
436
+
else Text("No Pose Detected", fontSize = 15.sp)
437
437
}
438
438
}