This repository has no description
1package com.performancecoachlab.posedetection.custom
2
3import androidx.compose.runtime.Composable
4import androidx.compose.ui.platform.LocalContext
5import co.touchlab.kermit.Logger
6import org.json.JSONObject
7import android.os.Build
8import org.tensorflow.lite.Interpreter
9import org.tensorflow.lite.gpu.GpuDelegate
10import org.tensorflow.lite.support.common.FileUtil
11import org.tensorflow.lite.support.metadata.MetadataExtractor
12import java.io.ByteArrayInputStream
13import java.nio.MappedByteBuffer
14import java.nio.charset.StandardCharsets
15import java.util.zip.GZIPInputStream
16import java.util.zip.Inflater
17import java.util.zip.InflaterInputStream
18import java.util.zip.ZipInputStream
19
/**
 * Loads the Android TFLite model at `modelPath.androidModelPath` and wraps it in an [ObjectModel].
 *
 * Acceleration is attempted in the order GPU → NNAPI (API 27+) → CPU. If the interpreter cannot
 * be built with the chosen delegate, that delegate is closed (best effort) and the interpreter is
 * rebuilt on the CPU. Class labels are read from the model's metadata via [labels].
 *
 * This performs file I/O and interpreter construction every time it is invoked.
 * NOTE(review): a delegate backing a successfully built interpreter is not stored anywhere (only
 * the interpreter reaches [AndroidDetector]), so it is never closed — consider tying its lifetime
 * to the interpreter.
 *
 * @throws IllegalArgumentException if `modelPath.androidModelPath` is null, or if the model's
 *   first input or output tensor has no shape.
 */
@Composable
actual fun initialiseObjectModel(modelPath: ModelPath): ObjectModel {
    val androidModelPath = requireNotNull(modelPath.androidModelPath) {
        "Android model path cannot be null"
    }

    fun cpuOptions(): Interpreter.Options = Interpreter.Options().apply { setNumThreads(4) }

    // Prefer GPU, then NNAPI (API 27+), then CPU. `selectedDelegate` tracks which one actually
    // ends up in the final interpreter so it can be logged; `delegate` is whichever native
    // delegate was attached to `options` (null for CPU) so it can be released on failure.
    var selectedDelegate = "CPU"
    val (options, delegate) = runCatching<Pair<Interpreter.Options, AutoCloseable?>> {
        val gpu = GpuDelegate()
        val opts = Interpreter.Options().apply {
            addDelegate(gpu)
            setNumThreads(2)
        }
        selectedDelegate = "GPU"
        Logger.i { "TFLite: GPU delegate constructed" }
        opts to gpu
    }.onFailure { t ->
        Logger.w(t) { "TFLite: GPU delegate not available; trying NNAPI" }
    }.getOrElse {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O_MR1) {
            runCatching<Pair<Interpreter.Options, AutoCloseable?>> {
                val nnapi = org.tensorflow.lite.nnapi.NnApiDelegate()
                val opts = Interpreter.Options().apply {
                    addDelegate(nnapi)
                    setNumThreads(2)
                }
                selectedDelegate = "NNAPI"
                Logger.i { "TFLite: NNAPI delegate constructed" }
                opts to nnapi
            }.onFailure { t ->
                Logger.w(t) { "TFLite: NNAPI delegate not available; falling back to CPU" }
            }.getOrElse {
                selectedDelegate = "CPU"
                cpuOptions() to null
            }
        } else {
            selectedDelegate = "CPU"
            cpuOptions() to null
        }
    }

    val model = FileUtil.loadMappedFile(LocalContext.current, androidModelPath)
    val labels = labels(model)

    val interpreter = runCatching {
        Interpreter(model, options)
    }.onFailure { t ->
        // If the chosen delegate can't actually build an interpreter (common for GPU on models
        // with unsupported ops), release it — GPU *or* NNAPI — and fall back to pure CPU.
        Logger.w(t) { "TFLite: failed to create interpreter with $selectedDelegate delegate; retrying on CPU" }
        runCatching { delegate?.close() }
            .onFailure { e -> Logger.w(e) { "TFLite: failed to close $selectedDelegate delegate" } }
        selectedDelegate = "CPU"
    }.getOrElse {
        Interpreter(model, cpuOptions())
    }

    val inputShape = interpreter.getInputTensor(0)?.shape()
    val outputShape = interpreter.getOutputTensor(0)?.shape()
    Logger.i {
        "TFLite: model='$androidModelPath' delegate=$selectedDelegate " +
            "inputShape=${inputShape?.toList()} outputShape=${outputShape?.toList()}"
    }
    val modelInfo = ModelInfo.fromShapes(
        inputShape = inputShape
            ?: throw IllegalArgumentException("Invalid model: input shape is null"),
        outputShape = outputShape
            ?: throw IllegalArgumentException("Invalid model: output shape is null"),
        labels = labels,
    )
    return ObjectModel(AndroidDetector(interpreter = interpreter, modelInfo = modelInfo))
}
96
/**
 * Extracts class labels from the TFLite metadata embedded in [model].
 *
 * Every file associated with the model metadata is read, decompressed if needed (see
 * [decodeUtf8PossiblyCompressed]) and parsed. The first file that yields a non-empty `names`
 * mapping wins. JSON such as `{"names": {"0": "person"}}` is tried first. If that yields
 * nothing, Python-dict-style text such as `{'names': {0: 'person'}}` is parsed with a regex
 * fallback; Android's org.json rejects the integer keys in that form.
 *
 * Best effort: any failure is logged and results in an empty list.
 *
 * @return the labels found, or an empty list if none could be loaded.
 */
fun labels(model: MappedByteBuffer): List<String> {
    return runCatching {
        val extractor = MetadataExtractor(model)
        val files = extractor.associatedFileNames.orEmpty()
        if (files.isEmpty()) return@runCatching emptyList()

        // Try every associated file; pick the first that decodes to something with a `names` map.
        for (name in files) {
            // Close each associated-file stream once its bytes have been read.
            val rawBytes = runCatching {
                extractor.getAssociatedFile(name).use { it.readBytes() }
            }.getOrNull() ?: continue

            val decoded = rawBytes.decodeUtf8PossiblyCompressed()
            if (!decoded.trimStart().startsWith("{")) continue

            val labels = parseUltralyticsNamesJson(decoded).ifEmpty {
                // Guarded so a failure in the fallback does not abort the remaining files.
                runCatching { parseUltralyticsNames(decoded) }.getOrDefault(emptyList())
            }
            if (labels.isNotEmpty()) {
                Logger.d { "Loaded ${labels.size} labels from associated file '$name'" }
                return@runCatching labels
            }
        }
        emptyList()
    }.onFailure { t ->
        Logger.w(t) { "Failed to load labels from TFLite metadata" }
    }.getOrDefault(emptyList())
}
123
/**
 * Decodes this byte array as UTF-8 text, first trying to undo common compression/archive wrappers.
 *
 * Decoders are tried in a fixed order. The first result that looks like a JSON object (non-blank,
 * and starting with `{` after leading whitespace) is returned:
 *  1. a zip, gzip, then zlib stream found at a non-zero offset (first occurrence of each signature);
 *  2. the whole buffer as gzip (or tar.gz), zip, zlib, then raw DEFLATE.
 *
 * If none match, the bytes are decoded as plain UTF-8.
 *
 * @return the decoded text; an empty string for empty input.
 */
private fun ByteArray.decodeUtf8PossiblyCompressed(): String {
    if (isEmpty()) return ""
    val bytes = this

    fun String.looksLikeJsonObject(): Boolean = isNotBlank() && trimStart().startsWith("{")

    // Runs one decoder; an exception simply means "no match".
    fun attempt(decode: () -> String): String = runCatching(decode).getOrDefault("")

    fun tail(offset: Int): ByteArray = bytes.copyOfRange(offset, bytes.size)

    // Scan for embedded magic headers / common compressed-stream signatures.
    val zipOffset =
        bytes.indexOfSubsequence(byteArrayOf('P'.code.toByte(), 'K'.code.toByte(), 0x03, 0x04))
    val gzipOffset = bytes.indexOfSubsequence(byteArrayOf(0x1F.toByte(), 0x8B.toByte()))
    // Common zlib headers (CMF/FLG): 0x78 0x9C (default), 0x78 0xDA (best), 0x78 0x01 (no compression).
    val zlibOffsets = listOf(0x9C, 0xDA, 0x01)
        .map { flg -> bytes.indexOfSubsequence(byteArrayOf(0x78.toByte(), flg.toByte())) }
        .filter { it >= 0 }
        .distinct()
        .sorted()

    // Lazy: later decoders only run if every earlier candidate was rejected.
    val candidates = sequence<String> {
        if (zipOffset > 0) {
            yield(attempt { tail(zipOffset).decodeFromZipIfPossible(preferEntryName = "metadata.json") })
        }
        if (gzipOffset > 0) {
            yield(attempt { tail(gzipOffset).decodeFromGzipOrTarGz() })
        }
        for (offset in zlibOffsets) {
            if (offset > 0) yield(attempt { tail(offset).decodeFromZlib() })
        }
        yield(attempt { bytes.decodeFromGzipOrTarGz() })
        yield(attempt { bytes.decodeFromZipIfPossible(preferEntryName = "metadata.json") })
        yield(attempt { bytes.decodeFromZlib() })
        yield(attempt { bytes.decodeFromDeflateRaw() })
    }

    return candidates.firstOrNull { it.looksLikeJsonObject() }
        ?: bytes.toString(StandardCharsets.UTF_8)
}
189
/**
 * Inflates this byte array as a raw DEFLATE stream (no zlib/gzip header) and decodes it as UTF-8.
 *
 * @return the decoded text, or an empty string if the data is not valid raw DEFLATE.
 */
private fun ByteArray.decodeFromDeflateRaw(): String {
    val inflater = Inflater(true) // nowrap = true => raw DEFLATE
    return try {
        runCatching {
            InflaterInputStream(ByteArrayInputStream(this), inflater)
                .bufferedReader(StandardCharsets.UTF_8)
                .use { it.readText() }
        }.getOrDefault("")
    } finally {
        // InflaterInputStream.close() only ends the Inflater it created itself, not a
        // caller-supplied one, so release the native zlib state explicitly.
        inflater.end()
    }
}
198
/**
 * Inflates this byte array as a zlib-wrapped DEFLATE stream and decodes the payload as UTF-8.
 * Any failure (bad header, truncated data, ...) yields an empty string.
 */
private fun ByteArray.decodeFromZlib(): String =
    runCatching {
        val zlibStream = InflaterInputStream(inputStream())
        zlibStream.reader(StandardCharsets.UTF_8).use { reader -> reader.readText() }
    }.getOrElse { "" }
206
/**
 * Gunzips this byte array and returns the payload as UTF-8 text.
 *
 * If the gunzipped payload is a ustar/PAX tar archive, the text of its `metadata.json` is returned
 * instead; failing that, the first non-blank regular file is used.
 *
 * @return the decoded text, or an empty string if the input is not valid gzip or decodes to blank text.
 */
private fun ByteArray.decodeFromGzipOrTarGz(): String {
    val inflated = runCatching {
        GZIPInputStream(ByteArrayInputStream(this)).use { it.readBytes() }
    }.getOrNull() ?: return ""

    // ustar (and PAX/GNU) headers carry the "ustar" magic at offset 257. The tar check must come
    // first: the raw bytes of a tar archive decode to non-blank text, so a plain-text-first
    // strategy would never reach the tar extraction.
    val ustarMagic = "ustar".toByteArray(StandardCharsets.US_ASCII)
    val looksLikeTar = inflated.size >= 512 &&
        inflated.copyOfRange(257, 257 + ustarMagic.size).contentEquals(ustarMagic)
    if (looksLikeTar) {
        val fromTar = runCatching {
            extractFirstTarFileUtf8(ByteArrayInputStream(inflated))
        }.getOrDefault("")
        if (fromTar.isNotBlank()) return fromTar
    }

    val text = inflated.toString(StandardCharsets.UTF_8)
    return if (text.isNotBlank()) text else ""
}
225
/**
 * Minimal TAR reader over [input].
 *
 * It walks the 512-byte header blocks and returns the UTF-8 text of the regular file named
 * `metadata.json` (case-insensitive). Otherwise it returns the first regular file whose text is
 * non-blank, or an empty string if neither exists.
 *
 * Only the 100-byte name field, the octal size field and the type flag are interpreted:
 *  - Non-regular entries (directories, GNU long-name `L`, PAX `x`, ...) are skipped.
 *  - The ustar `prefix` field is ignored.
 *  - A size field that is not octal digits is treated as 0.
 *
 * Reading stops at the first all-zero header block, at a truncated block, or at a negative or
 * out-of-range size.
 */
private fun extractFirstTarFileUtf8(input: java.io.InputStream): String {
    val blockSize = 512

    // Fills `buf` completely; false if the stream ends first.
    fun readExactly(buf: ByteArray): Boolean {
        var off = 0
        while (off < buf.size) {
            val r = input.read(buf, off, buf.size - off)
            if (r <= 0) return false
            off += r
        }
        return true
    }

    // Header fields are NUL- and/or space-terminated ASCII; keep only the text before the first NUL.
    fun ByteArray.field(from: Int, to: Int): String =
        copyOfRange(from, to).toString(StandardCharsets.US_ASCII).substringBefore('\u0000')

    val header = ByteArray(blockSize)
    var firstText: String? = null

    while (readExactly(header)) {
        // End of archive is marked by zero blocks; the first one is enough to stop.
        if (header.all { it == 0.toByte() }) break

        val name = header.field(0, 100)
        // Cutting at the first NUL *before* trimming also handles "...12 \0"-style terminators.
        val fileSize = header.field(124, 136).trim().toLongOrNull(8) ?: 0L
        // Corrupt header: a negative or non-representable size cannot be read into a ByteArray.
        if (fileSize !in 0L..Int.MAX_VALUE.toLong()) break

        val typeFlag = header[156]
        val isRegularFile = typeFlag == 0.toByte() || typeFlag == '0'.code.toByte()

        val fileData = ByteArray(fileSize.toInt())
        if (fileSize > 0 && !readExactly(fileData)) break

        // TAR pads file data up to a multiple of the block size.
        val pad = ((blockSize - (fileSize % blockSize)) % blockSize).toInt()
        if (pad > 0 && !readExactly(ByteArray(pad))) break

        if (isRegularFile) {
            val text = fileData.toString(StandardCharsets.UTF_8)
            if (name.equals("metadata.json", ignoreCase = true)) return text
            if (firstText == null && text.isNotBlank()) firstText = text
        }
    }

    return firstText.orEmpty()
}
277
/**
 * Finds the first position at which [needle] occurs in this array.
 *
 * Returns -1 when [needle] is empty, longer than this array, or absent. This is a naive scan,
 * intended for short signatures such as magic headers.
 */
private fun ByteArray.indexOfSubsequence(needle: ByteArray): Int {
    if (needle.isEmpty() || needle.size > size) return -1
    val lastStart = size - needle.size
    return (0..lastStart).firstOrNull { start ->
        needle.indices.all { k -> this[start + k] == needle[k] }
    } ?: -1
}
288
/**
 * Parses Ultralytics-style JSON metadata and returns its class labels.
 *
 * Expects a document of the form `{"names": {"0": "person", "1": "bicycle", ...}}`.
 *  - Entries are ignored if the key is not an integer, the key is negative, or the trimmed label is blank.
 *  - Remaining labels are returned in ascending index order; gaps are skipped, not filled.
 *  - If several keys parse to the same index (e.g. "1" and "01"), the last non-blank one in
 *    `keys()` iteration order wins.
 *
 * @return the labels, or an empty list if [text] is not valid JSON or has no `names` object.
 */
private fun parseUltralyticsNamesJson(text: String): List<String> {
    return runCatching {
        val namesObj = JSONObject(text).optJSONObject("names")
            ?: return@runCatching emptyList()

        // A sorted map rather than an index-sized list: sparse or very large keys
        // (e.g. "2147483647") no longer overflow the size or allocate a huge buffer.
        val byIndex = sortedMapOf<Int, String>()
        for (key in namesObj.keys()) {
            val idx = key.toIntOrNull() ?: continue
            if (idx < 0) continue
            val label = namesObj.optString(key, "").trim()
            if (label.isNotBlank()) byIndex[idx] = label
        }
        byIndex.values.toList()
    }.getOrDefault(emptyList())
}
314
/**
 * Regex-based label parser for Python-dict-style metadata.
 *
 * Handles text such as `{'names': {0: 'person', 1: "men's shirt"}}`, as produced by Python's
 * `str(dict)`. JSON parsing cannot handle the integer keys in this form.
 *
 * Only the first `names: { ... }` block is considered, and it ends at the first `}`. Each entry
 * must look like `<int>: '<label>'` or `<int>: "<label>"`. The closing quote must match the
 * opening one, so a label containing the other quote character stays intact.
 *
 * @return labels in ascending index order with blanks dropped; later duplicates of an index win.
 *   Returns an empty list if no entries are found.
 */
private fun parseUltralyticsNames(text: String): List<String> {
    // Grab the `names: { ... }` section (non-greedy) to avoid matching other maps.
    val namesBlock = Regex("""['"]names['"]\s*:\s*\{([\s\S]*?)\}""")
        .find(text)
        ?.groupValues
        ?.getOrNull(1)
        ?: return emptyList()

    // Match entries like: 0: 'person' OR 0: "person". Group 2 is the opening quote; \2 closes it.
    val entry = Regex("""(\d+)\s*:\s*(['"])(.*?)\2""")

    // Sorted map keyed by index: sparse or huge indices don't allocate an index-sized list.
    val byIndex = sortedMapOf<Int, String>()
    for (m in entry.findAll(namesBlock)) {
        val idx = m.groupValues[1].toIntOrNull() ?: continue
        val name = m.groupValues[3].trim()
        if (name.isNotBlank()) byIndex[idx] = name
    }
    return byIndex.values.toList()
}
340
/**
 * Tensor layout and class labels of a loaded model.
 *
 * @property inputShape raw shape of the model's first input tensor.
 * @property outputShape raw shape of the model's first output tensor.
 * @property inputWidth input image width. [fromShapes] sets 0 for unsupported input ranks.
 * @property inputHeight input image height. [fromShapes] sets 0 for unsupported input ranks.
 * @property inputChannels input channel count, or null when unknown.
 * @property isNhwc true for channels-last input, false for channels-first, null when unknown.
 * @property labels class labels (may be empty).
 *
 * [equals], [hashCode] and [toString] compare and print the shape arrays by content. The
 * data-class defaults would use array identity.
 */
data class ModelInfo(
    val inputShape: IntArray,
    val outputShape: IntArray,
    val inputWidth: Int,
    val inputHeight: Int,
    val inputChannels: Int? = null,
    val isNhwc: Boolean? = null,
    val labels: List<String>,
) {
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is ModelInfo) return false
        return inputShape.contentEquals(other.inputShape) &&
            outputShape.contentEquals(other.outputShape) &&
            inputWidth == other.inputWidth &&
            inputHeight == other.inputHeight &&
            inputChannels == other.inputChannels &&
            isNhwc == other.isNhwc &&
            labels == other.labels
    }

    override fun hashCode(): Int {
        var result = inputShape.contentHashCode()
        result = 31 * result + outputShape.contentHashCode()
        result = 31 * result + inputWidth
        result = 31 * result + inputHeight
        result = 31 * result + (inputChannels ?: 0)
        result = 31 * result + (isNhwc?.hashCode() ?: 0)
        result = 31 * result + labels.hashCode()
        return result
    }

    override fun toString(): String =
        "ModelInfo(inputShape=${inputShape.contentToString()}, " +
            "outputShape=${outputShape.contentToString()}, " +
            "inputWidth=$inputWidth, inputHeight=$inputHeight, inputChannels=$inputChannels, " +
            "isNhwc=$isNhwc, labels=$labels)"

    companion object {
        /**
         * Derives image dimensions and layout from raw tensor shapes.
         *
         * Supported input ranks:
         *  - Rank 4: treated as NHWC `[N, H, W, C]` when the last dim is in 1..4, otherwise as
         *    NCHW `[N, C, H, W]`.
         *  - Rank 3: treated as `[H, W, C]` (no batch).
         *  - Any other rank: width/height 0 and unknown channels/layout.
         */
        fun fromShapes(
            inputShape: IntArray,
            outputShape: IntArray,
            labels: List<String>,
        ): ModelInfo {
            val (h, w, c, nhwc) = when (inputShape.size) {
                4 -> {
                    // Heuristic: a small trailing dimension is a channel count.
                    val isNhwcGuess = inputShape[3] in 1..4
                    if (isNhwcGuess) {
                        Quad(inputShape[1], inputShape[2], inputShape[3], true)
                    } else {
                        Quad(inputShape[2], inputShape[3], inputShape[1], false)
                    }
                }

                3 -> Quad(inputShape[0], inputShape[1], inputShape[2], true)
                else -> Quad(0, 0, null, null)
            }

            return ModelInfo(
                inputShape = inputShape,
                outputShape = outputShape,
                inputWidth = w,
                inputHeight = h,
                inputChannels = c,
                isNhwc = nhwc,
                labels = labels,
            )
        }
    }
}
386
/**
 * Internal tuple returned while decoding an input tensor shape.
 *
 * Holds height, width, channel count and whether the layout is NHWC. Channels and layout are
 * null when unknown.
 */
private data class Quad(
    val h: Int,
    val w: Int,
    val c: Int?,
    val nhwc: Boolean?,
)
390
/**
 * Android-side detector handle.
 *
 * Pairs a TFLite [Interpreter] with the [ModelInfo] that describes its tensors and labels.
 */
data class AndroidDetector(
    val interpreter: Interpreter,
    val modelInfo: ModelInfo,
)
394
/**
 * Android `actual` of the shared object-detection model.
 *
 * Wraps the [AndroidDetector] it was constructed with.
 */
actual class ObjectModel {

    // Assigned exactly once, by the only constructor.
    private val detector: AndroidDetector

    constructor(detector: AndroidDetector) {
        this.detector = detector
    }

    /** Returns the detector this model was built with. */
    fun getDetector(): AndroidDetector? = detector
}
407
/**
 * Reads this byte array as a zip archive.
 *
 * Returns the UTF-8 text of the entry named [preferEntryName] (case-insensitive) as soon as it is
 * seen. Otherwise returns the first non-blank file entry. Returns an empty string if the data is
 * not a readable zip or holds neither.
 */
private fun ByteArray.decodeFromZipIfPossible(preferEntryName: String): String = runCatching {
    ZipInputStream(ByteArrayInputStream(this)).use { zip ->
        val seenNames = mutableListOf<String>()
        var fallbackText: String? = null

        for (entry in generateSequence { zip.nextEntry }) {
            seenNames += entry.name
            if (entry.isDirectory) continue

            val payload = zip.readBytes()
            val decoded = payload.toString(StandardCharsets.UTF_8)

            if (entry.name.equals(preferEntryName, ignoreCase = true)) {
                Logger.d { "TFLite metadata zip entry=${entry.name} size=${payload.size} entries=$seenNames" }
                return@use decoded
            }

            if (fallbackText == null && decoded.isNotBlank()) fallbackText = decoded
        }

        fallbackText.orEmpty()
    }
}.getOrDefault("")
436
/** Android implementation: resolves the [ObjectModel] for [modelPath] through [ObjectModelProvider]. */
@Composable
internal actual fun platformRememberObjectModel(modelPath: ModelPath): ObjectModel =
    ObjectModelProvider.get(modelPath)