This repository has no description
0

Configure Feed

Select the types of activity you want to include in your feed.

feat: get labels from metadata

+155 -87
+1
posedetection/build.gradle.kts
··· 93 93 //implementation(libs.tensorflow.lite.task.vision) 94 94 implementation("com.google.ai.edge.litert:litert:1.4.1") 95 95 implementation("com.google.ai.edge.litert:litert-support:1.4.1") 96 + implementation("com.google.ai.edge.litert:litert-metadata:1.4.1") 96 97 } 97 98 98 99 }
+8 -1
posedetection/src/androidMain/kotlin/com.performancecoachlab/posedetection/camera/CameraView.android.kt
··· 66 66 import androidx.camera.camera2.interop.ExperimentalCamera2Interop 67 67 import androidx.camera.core.CameraInfo 68 68 import androidx.camera.core.CameraSelector 69 + import androidx.compose.ui.text.rememberTextMeasurer 69 70 import co.touchlab.kermit.Logger 70 71 import java.util.concurrent.atomic.AtomicBoolean 71 72 import java.util.concurrent.atomic.AtomicLong ··· 157 158 var focus by remember { mutableStateOf(focusArea) } 158 159 var objectDetector by remember { mutableStateOf(objectModel) } 159 160 var currentDetectMode by remember { mutableStateOf(detectMode) } 161 + 162 + val textMeasurer = rememberTextMeasurer() 160 163 LaunchedEffect(detectMode) { currentDetectMode = detectMode } 161 164 162 165 // Update focus when focusArea changes ··· 435 438 ) 436 439 437 440 DrawableShape.LABEL -> { 438 - // no-op 441 + drawLabelTextPlatform( 442 + drawableObject = d.copy(obj = d.obj.copy(boundingBox = Rect(offset = Offset(left, top), 443 + size = Size(w, h),))), 444 + textMeasurer = textMeasurer 445 + ) 439 446 } 440 447 } 441 448 }
+24 -53
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/camera/Utils.android.kt
··· 16 16 import com.performancecoachlab.posedetection.recording.FrameSize 17 17 import com.performancecoachlab.posedetection.skeleton.Skeleton 18 18 import org.tensorflow.lite.DataType 19 - import org.tensorflow.lite.Interpreter 20 19 import org.tensorflow.lite.support.common.ops.CastOp 21 20 import org.tensorflow.lite.support.common.ops.NormalizeOp 22 21 import org.tensorflow.lite.support.image.ImageProcessor 23 22 import org.tensorflow.lite.support.image.TensorImage 24 23 import org.tensorflow.lite.support.tensorbuffer.TensorBuffer 25 - import posedetection.posedetection.generated.resources.Res 26 24 import kotlin.math.absoluteValue 27 - import androidx.core.graphics.scale 28 25 import co.touchlab.kermit.Logger 29 - import kotlin.collections.get 30 26 import kotlin.math.max 31 27 import kotlin.math.min 32 - import kotlin.run 33 28 import androidx.core.graphics.createBitmap 29 + import com.performancecoachlab.posedetection.custom.AndroidDetector 30 + import com.performancecoachlab.posedetection.custom.ModelInfo 34 31 import java.nio.ByteBuffer 35 32 36 33 actual enum class PlatformType { ··· 166 163 // For OUTPUT_IMAGE_FORMAT_RGBA_8888 this should be 4. 167 164 if (pixelStride != 4) { 168 165 // Fallback: still attempt a direct copy into a fresh bitmap. 169 - return Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888).also { it.copyPixelsFromBuffer(buffer) } 166 + return createBitmap(width, height).also { it.copyPixelsFromBuffer(buffer) } 170 167 } 171 168 172 169 val bmp = AnalysisBitmapPool.obtain(width, height, Bitmap.Config.ARGB_8888) ··· 194 191 195 192 @OptIn(ExperimentalGetImage::class) 196 193 fun ImageProxy.process( 197 - objectDetector: Interpreter?, 194 + objectDetector: AndroidDetector?, 198 195 poseDetector: PoseDetector?, 199 196 timestamp: Long, 200 197 focusArea: Rect?, ··· 222 219 } 223 220 224 221 // 3) Tensor input: resize into pooled bitmap (avoid allocating each frame). 225 - val processedTensorImage: TensorImage? = objectDetector?.let { interpreter -> 226 - val inputShape = interpreter.getInputTensor(0)?.shape() 227 - var tensorWidth = 0 228 - var tensorHeight = 0 229 - if (inputShape != null) { 230 - tensorWidth = inputShape[1] 231 - tensorHeight = inputShape[2] 232 - if (inputShape[1] == 3) { 233 - tensorWidth = inputShape[2] 234 - tensorHeight = inputShape[3] 235 - } 236 - } 222 + val processedTensorImage: TensorImage? = objectDetector?.modelInfo?.let { interpreter -> 223 + val tensorWidth = interpreter.inputWidth 224 + val tensorHeight = interpreter.inputHeight 237 225 if (tensorWidth <= 0 || tensorHeight <= 0) return@let null 238 226 239 227 val tensorBitmap = BitmapPool.obtain(tensorWidth, tensorHeight, Bitmap.Config.ARGB_8888) ··· 263 251 fun process( 264 252 tensorImage: TensorImage?, 265 253 mlKitImage: InputImage?, 266 - objectDetector: Interpreter?, 254 + objectDetector: AndroidDetector?, 267 255 poseDetector: PoseDetector?, 268 256 timestamp: Long, 269 257 width: Int, ··· 272 260 onComplete: (AnalysisResult, Bitmap) -> Unit 273 261 ) { 274 262 val objectsDetected = if (objectDetector != null && tensorImage != null) { 275 - val outputTensor = objectDetector.getOutputTensor(0) 276 - val outputShape = outputTensor.shape() 263 + val outputShape = objectDetector.modelInfo.outputShape 277 264 val output = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32) 278 265 279 - objectDetector.run(tensorImage.buffer, output.buffer) 266 + objectDetector.interpreter.run(tensorImage.buffer, output.buffer) 280 267 281 268 val array = output.floatArray 282 269 if (outputShape.size != 3) emptyList() else { ··· 328 315 val rightPx = rightN * width 329 316 val bottomPx = bottomN * height 330 317 318 + val label = objectDetector.modelInfo.label(cls) 319 + 331 320 AnalysisObject( 332 321 boundingBox = Rect(left = leftPx, top = topPx, right = rightPx, bottom = bottomPx), 333 322 trackingId = 0, 334 - labels = listOf(com.performancecoachlab.posedetection.recording.Label("$cls", cnf)), 323 + labels = listOf(com.performancecoachlab.posedetection.recording.Label(label, cnf)), 335 324 frameSize = FrameSize(width = width.absoluteValue, height = height.absoluteValue) 336 325 ) 337 326 } else null ··· 373 362 .add(CastOp(DataType.FLOAT32)) 374 363 .build() 375 364 376 - fun Bitmap.process( 377 - objectDetector: Interpreter?, 378 - poseDetector: PoseDetector, 379 - timestamp: Long, 380 - focusArea: Rect?, 381 - onComplete: (AnalysisResult, Bitmap) -> Unit 382 - ) { 383 - //val tensorImage = TensorImage.fromBitmap(this) 384 - val tensorImage = TensorImage(DataType.FLOAT32) 385 - tensorImage.load(this) 386 - val processedImage = imageProcessor.process(tensorImage) 387 - val mlKitImage = InputImage.fromBitmap(this.applyFocusAreaMask(focusArea), 0) 388 - process( 389 - tensorImage = processedImage, 390 - mlKitImage = mlKitImage, 391 - objectDetector = objectDetector, 392 - poseDetector = poseDetector, 393 - timestamp = timestamp, 394 - width = width, 395 - height = height, 396 - bitmap = this, 397 - onComplete = onComplete 398 - ) 399 - } 400 - 401 365 fun Bitmap.cropToFocusArea(focusArea: Rect?): Bitmap { 402 366 return focusArea?.let { rect -> 403 367 val left = (rect.left * width.toFloat()).toInt().coerceIn(0, width) ··· 536 500 val clsName: String 537 501 ) 538 502 539 - 540 - 541 - 542 - 503 + fun ModelInfo.label(cls: Int): String{ 504 + Logger.d{ "ModelInfo.label: cls=$cls" } 505 + this.labels.let { labelsList -> 506 + if (cls in labelsList.indices) { 507 + Logger.d{ "ModelInfo.label: $cls = ${labelsList[cls]}" } 508 + return labelsList[cls] 509 + } 510 + } 511 + Logger.d{ "ModelInfo.label: $cls not in labels" } 512 + return "$cls" 513 + }
+112 -13
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/custom/CustomObjectModel.android.kt
··· 2 2 3 3 import androidx.compose.runtime.Composable 4 4 import androidx.compose.ui.platform.LocalContext 5 + import co.touchlab.kermit.Logger 5 6 import org.tensorflow.lite.Interpreter 6 7 import org.tensorflow.lite.support.common.FileUtil 8 + import org.tensorflow.lite.support.metadata.MetadataExtractor 9 + import java.nio.MappedByteBuffer 7 10 8 11 @Composable 9 12 actual fun initialiseObjectModel(modelPath: ModelPath): ObjectModel { 10 - if(modelPath.androidModelPath == null){ 13 + if (modelPath.androidModelPath == null) { 11 14 throw IllegalArgumentException("Android model path cannot be null") 12 15 } 13 - val options = Interpreter.Options().apply{ 16 + val options = Interpreter.Options().apply { 14 17 this.setNumThreads(4) 15 18 } 16 19 17 20 val model = FileUtil.loadMappedFile(LocalContext.current, modelPath.androidModelPath) 21 + val labels = labels(model) 18 22 val interpreter = Interpreter(model, options) 19 23 20 24 val inputShape = interpreter.getInputTensor(0)?.shape() 21 25 val outputShape = interpreter.getOutputTensor(0)?.shape() 26 + val modelInfo = ModelInfo.fromShapes( 27 + inputShape = inputShape 28 + ?: throw IllegalArgumentException("Invalid model: input shape is null"), 29 + outputShape = outputShape 30 + ?: throw IllegalArgumentException("Invalid model: output shape is null"), 31 + labels, 32 + ) 33 + val androidDetector = AndroidDetector( 34 + interpreter = interpreter, modelInfo = modelInfo 35 + ) 36 + return ObjectModel(androidDetector) 37 + } 22 38 23 - return ObjectModel(interpreter) 39 + fun labels(model: MappedByteBuffer): List<String> { 40 + return runCatching { 41 + val extractor = MetadataExtractor(model) 42 + 43 + // Ultralytics exports often store this as a .txt "metadata" file containing Python-dict text. 44 + val metaFile = extractor.associatedFileNames.orEmpty() 45 + .firstOrNull { it.endsWith(".txt", ignoreCase = true) || it.endsWith(".json", ignoreCase = true) } 46 + ?: return@runCatching emptyList() 47 + 48 + val text = extractor.getAssociatedFile(metaFile) 49 + .bufferedReader() 50 + .use { it.readText() } 51 + 52 + parseUltralyticsNames(text) 53 + }.onFailure { t -> 54 + Logger.w(t) { "Failed to load labels from TFLite metadata" } 55 + }.getOrDefault(emptyList()) 56 + } 57 + 58 + private fun parseUltralyticsNames(text: String): List<String> { 59 + // Grab the `names: { ... }` section (non-greedy) to avoid matching other maps. 60 + val namesBlock = Regex("""['"]names['"]\s*:\s*\{([\s\S]*?)\}""") 61 + .find(text) 62 + ?.groupValues 63 + ?.getOrNull(1) 64 + ?: return emptyList() 65 + 66 + // Match entries like: 0: 'person' OR 0: "person" 67 + val entry = Regex("""(\d+)\s*:\s*['"]([^'"]+)['"]""") 68 + val pairs = entry.findAll(namesBlock).mapNotNull { m -> 69 + val idx = m.groupValues[1].toIntOrNull() ?: return@mapNotNull null 70 + val name = m.groupValues[2].trim() 71 + if (name.isBlank()) null else idx to name 72 + }.toList() 73 + 74 + if (pairs.isEmpty()) return emptyList() 75 + 76 + val maxIdx = pairs.maxOf { it.first } 77 + val out = MutableList(maxIdx + 1) { "" } 78 + for ((i, label) in pairs) { 79 + if (i in out.indices) out[i] = label 80 + } 81 + return out.filter { it.isNotBlank() } 24 82 } 25 83 26 - actual class ObjectModel{ 27 - /*private var detector: org.tensorflow.lite.task.vision.detector.ObjectDetector? = null 84 + data class ModelInfo( 85 + val inputShape: IntArray, 86 + val outputShape: IntArray, 87 + val inputWidth: Int, 88 + val inputHeight: Int, 89 + val inputChannels: Int? = null, 90 + val isNhwc: Boolean? = null, 91 + val labels: List<String>, 92 + ) { 93 + companion object { 94 + fun fromShapes(inputShape: IntArray, outputShape: IntArray, labels: List<String>): ModelInfo { 95 + // Common TFLite image shapes: 96 + // NHWC: [1, H, W, C] 97 + // NCHW: [1, C, H, W] (less common on Android) 98 + // Some models might be [H, W, C] (no batch). 99 + val (h, w, c, nhwc) = when (inputShape.size) { 100 + 4 -> { 101 + val isNhwcGuess = inputShape[3] in 1..4 102 + if (isNhwcGuess) { 103 + Quad(inputShape[1], inputShape[2], inputShape[3], true) 104 + } else { 105 + Quad(inputShape[2], inputShape[3], inputShape[1], false) 106 + } 107 + } 28 108 29 - constructor(detector: org.tensorflow.lite.task.vision.detector.ObjectDetector){ 30 - this.detector = detector 109 + 3 -> Quad(inputShape[0], inputShape[1], inputShape[2], true) 110 + else -> Quad(0, 0, null, null) 111 + } 112 + 113 + return ModelInfo( 114 + inputShape = inputShape, 115 + outputShape = outputShape, 116 + inputWidth = w, 117 + inputHeight = h, 118 + inputChannels = c, 119 + isNhwc = nhwc, 120 + labels = labels, 121 + ) 122 + } 31 123 } 124 + } 32 125 33 - fun getDetector(): org.tensorflow.lite.task.vision.detector.ObjectDetector? { 34 - return detector 35 - }*/ 126 + private data class Quad( 127 + val h: Int, val w: Int, val c: Int?, val nhwc: Boolean? 128 + ) 129 + 130 + data class AndroidDetector( 131 + val interpreter: Interpreter, val modelInfo: ModelInfo 132 + ) 133 + 134 + actual class ObjectModel { 36 135 37 - private var detector: Interpreter? = null 136 + private var detector: AndroidDetector? = null 38 137 39 - constructor(detector: Interpreter){ 138 + constructor(detector: AndroidDetector) { 40 139 this.detector = detector 41 140 } 42 141 43 - fun getDetector(): Interpreter? { 142 + fun getDetector(): AndroidDetector? { 44 143 return detector 45 144 } 46 145 }
+5 -15
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/recording/InputFrame.android.kt
··· 24 24 import kotlin.math.max 25 25 import androidx.core.graphics.createBitmap 26 26 import co.touchlab.kermit.Logger 27 - import com.google.android.gms.tasks.Tasks 28 27 import com.performancecoachlab.posedetection.camera.process 29 28 import kotlinx.coroutines.Dispatchers 30 29 import kotlinx.coroutines.launch ··· 119 118 val mlKitImage = InputImage.fromBitmap(masked, 0) 120 119 121 120 // Object: resize to model input size, then normalize. 122 - val tensorImage: TensorImage? = interpreter?.let { tfl -> 123 - val shape = tfl.getInputTensor(0)?.shape() 124 - var w = 0 125 - var h = 0 126 - if (shape != null) { 127 - w = shape[1] 128 - h = shape[2] 129 - if (shape[1] == 3) { 130 - w = shape[2] 131 - h = shape[3] 132 - } 133 - } 121 + val tensorImage: TensorImage? = interpreter?.modelInfo?.let { tfl -> 122 + val w = tfl.inputWidth 123 + val h = tfl.inputHeight 124 + 134 125 if (w <= 0 || h <= 0) null else { 135 126 val dst = ensureTensorBitmap(w, h) 136 - // Draw scaled into dst (no new allocations). 137 - android.graphics.Canvas(dst).drawBitmap( 127 + Canvas(dst).drawBitmap( 138 128 inputFrame.bitmap, 139 129 null, 140 130 android.graphics.Rect(0, 0, w, h),
+5 -5
sample/composeApp/src/commonMain/kotlin/com/nate/posedetection/App.kt
··· 80 80 import kotlin.time.Clock 81 81 import kotlin.time.ExperimentalTime 82 82 83 - val androidPath = "basketballs_n1.tflite" 83 + val androidPath = "yolov10n_float16.tflite" 84 84 val iosPath = "basketballs_n1" 85 85 @Composable 86 86 internal fun App() = AppTheme { ··· 406 406 } 407 407 Button( 408 408 onClick = { 409 - frontCamera = !frontCamera 409 + recordingId = "${Clock.System.now().epochSeconds}" 410 410 }, 411 411 modifier = Modifier.imePadding().padding(16.dp).align(Alignment.TopStart) 412 412 ) { ··· 431 431 rightKnee = Pose.PoseRange(160.0, 180.0) 432 432 ) 433 433 skeleton?.let { 434 - if (upRightPose.matches(it)) Text("Standing", fontSize = 80.sp) 435 - else if (starPose.matches(it)) Text("Star", fontSize = 80.sp) 436 - else Text("No Pose Detected", fontSize = 80.sp) 434 + if (upRightPose.matches(it)) Text("Standing", fontSize = 15.sp) 435 + else if (starPose.matches(it)) Text("Star", fontSize = 15.sp) 436 + else Text("No Pose Detected", fontSize = 15.sp) 437 437 } 438 438 }