This repository has no description
0

Configure Feed

Select the types of activity you want to include in your feed.

at master 15 kB View raw
package com.performancecoachlab.posedetection.custom

import android.os.Build
import androidx.compose.runtime.Composable
import androidx.compose.ui.platform.LocalContext
import co.touchlab.kermit.Logger
import org.json.JSONObject
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.nnapi.NnApiDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.metadata.MetadataExtractor
import java.io.ByteArrayInputStream
import java.io.InputStream
import java.nio.MappedByteBuffer
import java.nio.charset.StandardCharsets
import java.util.zip.GZIPInputStream
import java.util.zip.Inflater
import java.util.zip.InflaterInputStream
import java.util.zip.ZipInputStream

/** Thread count used when the interpreter runs without a hardware delegate. */
private const val CPU_THREADS = 4

/** Thread count used alongside a GPU or NNAPI delegate. */
private const val DELEGATE_THREADS = 2

/** Size of a TAR header / data block in bytes. */
private const val TAR_BLOCK_SIZE = 512

/** Regular TAR entries larger than this are skipped instead of being buffered in memory. */
private const val MAX_TAR_ENTRY_BYTES = 16L * 1024 * 1024

/**
 * Interpreter options together with the delegate (if any) that was attached to them.
 *
 * [delegate] is kept so it can be closed when the interpreter cannot be built with it.
 */
private class DelegateChoice(
    val name: String,
    val options: Interpreter.Options,
    val delegate: AutoCloseable?,
)

/**
 * Loads the TFLite model named by [ModelPath.androidModelPath] and wraps it in an [ObjectModel].
 *
 * @throws IllegalArgumentException if the Android model path is null or the model exposes no
 *   input/output tensor shape.
 */
@Composable
actual fun initialiseObjectModel(modelPath: ModelPath): ObjectModel {
    val assetPath = requireNotNull(modelPath.androidModelPath) { "Android model path cannot be null" }
    // Map the model before constructing any delegate so a missing asset cannot leak one.
    val model = FileUtil.loadMappedFile(LocalContext.current, assetPath)
    return createObjectModel(model, assetPath)
}

/**
 * Builds the interpreter for [model], preferring GPU, then NNAPI (API 27+), then CPU, and
 * derives the [ModelInfo] from its first input/output tensors.
 */
private fun createObjectModel(model: MappedByteBuffer, assetPath: String): ObjectModel {
    val labels = labels(model)
    val choice = chooseDelegate()

    // Track what actually ends up inside the final interpreter so it can be logged/released.
    var selectedDelegate = choice.name
    var activeDelegate = choice.delegate

    val interpreter = runCatching {
        Interpreter(model, choice.options)
    }.onFailure { t ->
        // If the chosen delegate can't actually build an interpreter (common
        // for GPU on models with unsupported ops), fall back to pure CPU.
        Logger.w(t) { "TFLite: failed to create interpreter with $selectedDelegate delegate; retrying on CPU" }
        closeQuietly(choice.delegate)
        activeDelegate = null
        selectedDelegate = "CPU"
    }.getOrElse {
        Interpreter(model, cpuOptions())
    }

    return try {
        val inputShape = interpreter.getInputTensor(0)?.shape()
        val outputShape = interpreter.getOutputTensor(0)?.shape()
        Logger.i {
            "TFLite: model='$assetPath' delegate=$selectedDelegate " +
                "inputShape=${inputShape?.toList()} outputShape=${outputShape?.toList()}"
        }
        val modelInfo = ModelInfo.fromShapes(
            inputShape = inputShape
                ?: throw IllegalArgumentException("Invalid model: input shape is null"),
            outputShape = outputShape
                ?: throw IllegalArgumentException("Invalid model: output shape is null"),
            labels = labels,
        )
        ObjectModel(AndroidDetector(interpreter = interpreter, modelInfo = modelInfo))
    } catch (t: Throwable) {
        // Nothing will own the interpreter if we fail here, so release it before rethrowing.
        closeQuietly(interpreter)
        closeQuietly(activeDelegate)
        throw t
    }
}

/** Plain CPU interpreter options. */
private fun cpuOptions(): Interpreter.Options =
    Interpreter.Options().apply { setNumThreads(CPU_THREADS) }

/** Picks the first delegate that can be constructed: GPU, then NNAPI, then none (CPU). */
private fun chooseDelegate(): DelegateChoice =
    gpuChoiceOrNull()
        ?: nnapiChoiceOrNull()
        ?: DelegateChoice(name = "CPU", options = cpuOptions(), delegate = null)

/** Returns GPU-backed options, or null when the GPU delegate cannot be constructed. */
private fun gpuChoiceOrNull(): DelegateChoice? {
    var created: GpuDelegate? = null
    return runCatching {
        val delegate = GpuDelegate().also { created = it }
        val opts = Interpreter.Options().apply {
            addDelegate(delegate)
            setNumThreads(DELEGATE_THREADS)
        }
        Logger.i { "TFLite: GPU delegate constructed" }
        DelegateChoice(name = "GPU", options = opts, delegate = delegate)
    }.onFailure { t ->
        closeQuietly(created)
        Logger.w(t) { "TFLite: GPU delegate not available; trying NNAPI" }
    }.getOrNull()
}

/** Returns NNAPI-backed options, or null below API 27 or when the delegate cannot be constructed. */
private fun nnapiChoiceOrNull(): DelegateChoice? {
    if (Build.VERSION.SDK_INT < Build.VERSION_CODES.O_MR1) return null
    var created: NnApiDelegate? = null
    return runCatching {
        val delegate = NnApiDelegate().also { created = it }
        val opts = Interpreter.Options().apply {
            addDelegate(delegate)
            setNumThreads(DELEGATE_THREADS)
        }
        Logger.i { "TFLite: NNAPI delegate constructed" }
        DelegateChoice(name = "NNAPI", options = opts, delegate = delegate)
    }.onFailure { t ->
        closeQuietly(created)
        Logger.w(t) { "TFLite: NNAPI delegate not available; falling back to CPU" }
    }.getOrNull()
}

/** Closes [resource] if non-null; a failure to close is logged rather than propagated. */
private fun closeQuietly(resource: AutoCloseable?) {
    runCatching { resource?.close() }
        .onFailure { t -> Logger.w(t) { "TFLite: failed to release native resource" } }
}

/**
 * Reads class labels from the metadata files associated with [model].
 *
 * Every associated file is tried in turn; the first one that decodes to JSON containing a
 * non-empty `names` object wins. Any failure is logged and yields an empty list.
 */
fun labels(model: MappedByteBuffer): List<String> {
    return runCatching {
        val extractor = MetadataExtractor(model)
        val files = extractor.associatedFileNames.orEmpty()
        if (files.isEmpty()) return@runCatching emptyList()

        for (name in files) {
            // `use` closes the associated-file stream once its bytes have been copied out.
            val rawBytes = runCatching {
                extractor.getAssociatedFile(name).use { it.readBytes() }
            }.getOrNull() ?: continue

            val decoded = rawBytes.decodeUtf8PossiblyCompressed()
            if (!decoded.startsWithJsonObject()) continue

            val parsed = parseUltralyticsNamesJson(decoded)
            if (parsed.isNotEmpty()) {
                Logger.d { "Loaded ${parsed.size} labels from associated file '$name'" }
                return@runCatching parsed
            }
        }
        emptyList()
    }.onFailure { t ->
        Logger.w(t) { "Failed to load labels from TFLite metadata" }
    }.getOrDefault(emptyList())
}

/** True when the first non-whitespace character is `{`. */
private fun String.startsWithJsonObject(): Boolean = trimStart().startsWith("{")

/**
 * Decodes these bytes to text, transparently unwrapping ZIP, GZIP (optionally a gzipped TAR),
 * zlib and raw DEFLATE containers.
 *
 * Decoders are tried in a fixed order — first on slices starting at an embedded magic header
 * found at an offset greater than zero, then on the whole array — and the first result that
 * starts with `{` is returned. If none qualifies the bytes are decoded as plain UTF-8.
 */
private fun ByteArray.decodeUtf8PossiblyCompressed(): String {
    if (isEmpty()) return ""
    val bytes = this

    // Scan for embedded magic headers / common compressed stream signatures.
    val zipOffset =
        indexOfSubsequence(byteArrayOf('P'.code.toByte(), 'K'.code.toByte(), 0x03, 0x04))
    val gzipOffset = indexOfSubsequence(byteArrayOf(0x1F.toByte(), 0x8B.toByte()))

    // Common zlib headers (CMF/FLG): 0x78 0x9C (default), 0x78 0xDA (best), 0x78 0x01 (no compression).
    // Offset 0 is left out here because the whole-array zlib attempt below covers it.
    val zlibOffsets = listOf(
        byteArrayOf(0x78.toByte(), 0x9C.toByte()),
        byteArrayOf(0x78.toByte(), 0xDA.toByte()),
        byteArrayOf(0x78.toByte(), 0x01.toByte()),
    ).map { indexOfSubsequence(it) }.filter { it > 0 }.distinct().sorted()

    fun tail(offset: Int): ByteArray = bytes.copyOfRange(offset, bytes.size)

    // Each attempt is lazy so slices are only copied when that attempt is actually reached.
    val attempts = mutableListOf<() -> String>()
    if (zipOffset > 0) {
        attempts.add { tail(zipOffset).decodeFromZipIfPossible(preferEntryName = "metadata.json") }
    }
    if (gzipOffset > 0) {
        attempts.add { tail(gzipOffset).decodeFromGzipOrTarGz() }
    }
    for (offset in zlibOffsets) {
        attempts.add { tail(offset).decodeFromZlib() }
    }
    attempts.add { bytes.decodeFromGzipOrTarGz() }
    attempts.add { bytes.decodeFromZipIfPossible(preferEntryName = "metadata.json") }
    attempts.add { bytes.decodeFromZlib() }
    attempts.add { bytes.decodeFromDeflateRaw() }

    for (attempt in attempts) {
        val text = runCatching(attempt).getOrDefault("")
        if (text.startsWithJsonObject()) return text
    }

    return toString(StandardCharsets.UTF_8)
}

/** Inflates these bytes as a raw DEFLATE stream (no zlib header); returns "" on failure. */
private fun ByteArray.decodeFromDeflateRaw(): String {
    val inflater = Inflater(true) // nowrap=true => raw DEFLATE
    return try {
        runCatching {
            InflaterInputStream(ByteArrayInputStream(this), inflater)
                .bufferedReader(StandardCharsets.UTF_8)
                .use { it.readText() }
        }.getOrDefault("")
    } finally {
        // Closing the stream does not end an Inflater supplied by the caller.
        inflater.end()
    }
}

/** Inflates these bytes as a zlib stream; returns "" on failure. */
private fun ByteArray.decodeFromZlib(): String {
    return runCatching {
        InflaterInputStream(ByteArrayInputStream(this))
            .bufferedReader(StandardCharsets.UTF_8)
            .use { it.readText() }
    }.getOrDefault("")
}

/**
 * Gunzips these bytes and returns the payload as text.
 *
 * A payload that already starts with `{` is returned directly. Otherwise the payload is read
 * as a TAR archive and the text of `metadata.json` (or of the first non-blank regular file)
 * is returned; if that yields nothing the raw payload text is returned. Returns "" when the
 * bytes are not valid GZIP.
 */
private fun ByteArray.decodeFromGzipOrTarGz(): String {
    val inflated = runCatching {
        GZIPInputStream(ByteArrayInputStream(this)).use { it.readBytes() }
    }.getOrNull() ?: return ""

    val plainText = inflated.toString(StandardCharsets.UTF_8)
    if (plainText.startsWithJsonObject()) return plainText

    // A gzipped TAR also inflates to non-blank "text", so the archive has to be tried
    // explicitly before settling for the raw payload.
    val fromTar = runCatching {
        extractFirstTarFileUtf8(ByteArrayInputStream(inflated))
    }.getOrDefault("")

    return fromTar.ifBlank { plainText }
}

/**
 * Minimal TAR reader.
 *
 * Returns the UTF-8 text of the entry named `metadata.json` (a leading `./` is ignored) if
 * present, otherwise the first non-blank regular file, otherwise "". Reading stops at the
 * first all-zero header block or when the stream is truncated.
 */
private fun extractFirstTarFileUtf8(input: InputStream): String {
    fun readExactly(buf: ByteArray): Boolean {
        var filled = 0
        while (filled < buf.size) {
            val read = input.read(buf, filled, buf.size - filled)
            if (read <= 0) return false
            filled += read
        }
        return true
    }

    val scratch = ByteArray(4096)
    fun skipExactly(count: Long): Boolean {
        var remaining = count
        while (remaining > 0) {
            val read = input.read(scratch, 0, minOf(scratch.size.toLong(), remaining).toInt())
            if (read <= 0) return false
            remaining -= read
        }
        return true
    }

    val header = ByteArray(TAR_BLOCK_SIZE)
    var firstText: String? = null

    while (readExactly(header)) {
        // An all-zero header block marks the end of the archive.
        if (header.all { it == 0.toByte() }) break

        val name = header.copyOfRange(0, 100)
            .toString(StandardCharsets.US_ASCII)
            .substringBefore('\u0000')
        // The size is octal text padded/terminated with spaces and/or NULs in any combination.
        val sizeOctal = header.copyOfRange(124, 136)
            .toString(StandardCharsets.US_ASCII)
            .trimStart(' ', '\u0000')
            .takeWhile { it in '0'..'7' }
        val fileSize = sizeOctal.toLongOrNull(8) ?: 0L
        val typeFlag = header[156]
        val isRegularFile = typeFlag == 0.toByte() || typeFlag == '0'.code.toByte()

        // TAR pads file data to 512 byte blocks.
        val padding = (TAR_BLOCK_SIZE - fileSize % TAR_BLOCK_SIZE) % TAR_BLOCK_SIZE

        if (!isRegularFile || fileSize > MAX_TAR_ENTRY_BYTES) {
            // Not a candidate: step over its data without buffering it.
            if (!skipExactly(fileSize + padding)) break
            continue
        }

        val fileData = ByteArray(fileSize.toInt())
        if (!readExactly(fileData)) break
        if (!skipExactly(padding)) break

        val text = fileData.toString(StandardCharsets.UTF_8)
        if (name.removePrefix("./").equals("metadata.json", ignoreCase = true)) return text
        if (firstText == null && text.isNotBlank()) firstText = text
    }

    return firstText.orEmpty()
}

/** Index of the first occurrence of [needle] in this array, or -1 when absent or [needle] is empty. */
private fun ByteArray.indexOfSubsequence(needle: ByteArray): Int {
    if (needle.isEmpty() || size < needle.size) return -1
    val lastStart = size - needle.size
    search@ for (start in 0..lastStart) {
        for (j in needle.indices) {
            if (this[start + j] != needle[j]) continue@search
        }
        return start
    }
    return -1
}

/**
 * Extracts labels from JSON of the form `{"names": {"0": "person", "1": "car", ...}}`.
 *
 * Only keys that parse as non-negative integers and map to non-blank values are kept. The
 * result is ordered by index and compacted, so when indices are sparse a label's position in
 * the list is not necessarily its index. Returns an empty list on any parse failure.
 */
private fun parseUltralyticsNamesJson(text: String): List<String> {
    return runCatching {
        val namesObj = JSONObject(text).optJSONObject("names")
            ?: return@runCatching emptyList()

        // A sorted map keeps memory proportional to the number of labels rather than to the
        // largest index found in the file.
        val byIndex = sortedMapOf<Int, String>()
        for (key in namesObj.keys()) {
            val idx = key.toIntOrNull() ?: continue
            if (idx < 0) continue
            val label = namesObj.optString(key, "").trim()
            if (label.isNotBlank()) byIndex[idx] = label
        }
        byIndex.values.toList()
    }.getOrDefault(emptyList())
}

/**
 * Regex-based variant of [parseUltralyticsNamesJson] for `names: {0: 'person', ...}` style
 * text that is not strict JSON. Not referenced elsewhere in this file.
 */
private fun parseUltralyticsNames(text: String): List<String> {
    // Grab the `names: { ... }` section (non-greedy) to avoid matching other maps.
    val namesBlock = Regex("""['"]names['"]\s*:\s*\{([\s\S]*?)\}""")
        .find(text)
        ?.groupValues
        ?.getOrNull(1)
        ?: return emptyList()

    // Match entries like: 0: 'person' OR 0: "person"
    val entry = Regex("""(\d+)\s*:\s*['"]([^'"]+)['"]""")
    val byIndex = sortedMapOf<Int, String>()
    for (match in entry.findAll(namesBlock)) {
        val idx = match.groupValues[1].toIntOrNull() ?: continue
        val name = match.groupValues[2].trim()
        if (name.isNotBlank()) byIndex[idx] = name
    }
    return byIndex.values.toList()
}

/**
 * Tensor shapes of a loaded model plus the input geometry derived from them.
 *
 * `equals`/`hashCode` are overridden because the generated data-class versions compare the
 * [IntArray] properties by reference rather than by content.
 */
data class ModelInfo(
    val inputShape: IntArray,
    val outputShape: IntArray,
    val inputWidth: Int,
    val inputHeight: Int,
    val inputChannels: Int? = null,
    val isNhwc: Boolean? = null,
    val labels: List<String>,
) {
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is ModelInfo) return false
        return inputShape.contentEquals(other.inputShape) &&
            outputShape.contentEquals(other.outputShape) &&
            inputWidth == other.inputWidth &&
            inputHeight == other.inputHeight &&
            inputChannels == other.inputChannels &&
            isNhwc == other.isNhwc &&
            labels == other.labels
    }

    override fun hashCode(): Int {
        var result = inputShape.contentHashCode()
        result = 31 * result + outputShape.contentHashCode()
        result = 31 * result + inputWidth
        result = 31 * result + inputHeight
        result = 31 * result + (inputChannels ?: 0)
        result = 31 * result + (isNhwc?.hashCode() ?: 0)
        result = 31 * result + labels.hashCode()
        return result
    }

    companion object {
        /**
         * Derives width/height/channels from [inputShape].
         *
         * - 4 dims: read as NHWC `[1, H, W, C]` when the last dim is in 1..4, otherwise as
         *   NCHW `[1, C, H, W]`.
         * - 3 dims: read as `[H, W, C]` (no batch).
         * - anything else: width and height are 0, channels and layout are null.
         */
        fun fromShapes(
            inputShape: IntArray,
            outputShape: IntArray,
            labels: List<String>
        ): ModelInfo {
            val (h, w, c, nhwc) = when (inputShape.size) {
                4 -> {
                    val isNhwcGuess = inputShape[3] in 1..4
                    if (isNhwcGuess) {
                        Quad(inputShape[1], inputShape[2], inputShape[3], true)
                    } else {
                        Quad(inputShape[2], inputShape[3], inputShape[1], false)
                    }
                }

                3 -> Quad(inputShape[0], inputShape[1], inputShape[2], true)
                else -> Quad(0, 0, null, null)
            }

            return ModelInfo(
                inputShape = inputShape,
                outputShape = outputShape,
                inputWidth = w,
                inputHeight = h,
                inputChannels = c,
                isNhwc = nhwc,
                labels = labels,
            )
        }
    }
}

/** Height / width / channels / layout tuple used only for destructuring in [ModelInfo.fromShapes]. */
private data class Quad(
    val h: Int, val w: Int, val c: Int?, val nhwc: Boolean?
)

/** A ready-to-run interpreter paired with the [ModelInfo] describing its tensors. */
data class AndroidDetector(
    val interpreter: Interpreter, val modelInfo: ModelInfo
)

/** Android implementation of the object model: a holder for an [AndroidDetector]. */
actual class ObjectModel {

    private var detector: AndroidDetector? = null

    constructor(detector: AndroidDetector) {
        this.detector = detector
    }

    /** Returns the detector supplied at construction. */
    fun getDetector(): AndroidDetector? {
        return detector
    }
}

/**
 * Reads these bytes as a ZIP archive and returns the UTF-8 text of the entry named
 * [preferEntryName] (case-insensitive), else of the first non-blank entry, else "".
 * Returns "" when the bytes cannot be read as a ZIP.
 */
private fun ByteArray.decodeFromZipIfPossible(preferEntryName: String): String {
    return runCatching {
        ZipInputStream(ByteArrayInputStream(this)).use { zis ->
            var firstNonEmpty: String? = null
            // Entry names seen so far; only used to enrich the debug log line.
            val entries = mutableListOf<String>()

            while (true) {
                val entry = zis.nextEntry ?: break
                entries += entry.name
                if (entry.isDirectory) continue

                val entryBytes = zis.readBytes()
                val text = entryBytes.toString(StandardCharsets.UTF_8)

                if (entry.name.equals(preferEntryName, ignoreCase = true)) {
                    Logger.d { "TFLite metadata zip entry=${entry.name} size=${entryBytes.size} entries=$entries" }
                    return@use text
                }

                if (firstNonEmpty == null && text.isNotBlank()) {
                    firstNonEmpty = text
                }
            }

            firstNonEmpty.orEmpty()
        }
    }.getOrDefault("")
}

/** Returns the [ObjectModel] for [modelPath] as supplied by [ObjectModelProvider]. */
@Composable
internal actual fun platformRememberObjectModel(modelPath: ModelPath): ObjectModel {
    return ObjectModelProvider.get(modelPath)
}