This repository has no description
0

Configure Feed

Select the types of activity you want to include in your feed.

feat: tensor interpretation for yolo models

+181 -22
+3 -1
posedetection/build.gradle.kts
··· 90 90 implementation(libs.pose.detection) 91 91 implementation(libs.pose.detection.common) 92 92 implementation(libs.androidx.media3.common.ktx) 93 - implementation(libs.tensorflow.lite.task.vision) 93 + //implementation(libs.tensorflow.lite.task.vision) 94 + implementation("com.google.ai.edge.litert:litert:1.4.1") 95 + implementation("com.google.ai.edge.litert:litert-support:1.4.1") 94 96 } 95 97 96 98 }
+3 -2
posedetection/src/androidMain/kotlin/com.performancecoachlab/posedetection/camera/CameraView.android.kt
··· 54 54 import com.performancecoachlab.posedetection.recording.AnalysisResult 55 55 import kotlinx.coroutines.launch 56 56 import org.tensorflow.lite.support.image.TensorImage 57 - import org.tensorflow.lite.task.vision.detector.ObjectDetector 58 57 import android.hardware.camera2.CameraCharacteristics 59 58 import androidx.camera.camera2.interop.Camera2CameraInfo 59 + import androidx.camera.camera2.interop.ExperimentalCamera2Interop 60 60 import androidx.camera.core.CameraInfo 61 61 import androidx.camera.core.CameraSelector 62 62 import co.touchlab.kermit.Logger ··· 68 68 val outputPath: String 69 69 ) 70 70 71 + @OptIn(ExperimentalCamera2Interop::class) 71 72 private fun buildBackUltraWideSelectorOrNull(cameraProvider: ProcessCameraProvider): CameraSelector? { 72 73 return try { 73 74 val bestBackInfo: CameraInfo? = cameraProvider.availableCameraInfos ··· 104 105 } 105 106 } 106 107 107 - @OptIn(ExperimentalGetImage::class) 108 + @OptIn(ExperimentalCamera2Interop::class) 108 109 @Composable 109 110 actual fun CameraView( 110 111 skeletonRepository: SkeletonRepository,
+147 -10
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/camera/Utils.android.kt
··· 12 12 import com.performancecoachlab.posedetection.recording.AnalysisResult 13 13 import com.performancecoachlab.posedetection.recording.FrameSize 14 14 import com.performancecoachlab.posedetection.skeleton.Skeleton 15 + import org.tensorflow.lite.DataType 16 + import org.tensorflow.lite.Interpreter 17 + import org.tensorflow.lite.support.common.ops.CastOp 18 + import org.tensorflow.lite.support.common.ops.NormalizeOp 19 + import org.tensorflow.lite.support.image.ImageProcessor 15 20 import org.tensorflow.lite.support.image.TensorImage 21 + import org.tensorflow.lite.support.tensorbuffer.TensorBuffer 22 + import posedetection.posedetection.generated.resources.Res 16 23 import kotlin.math.absoluteValue 24 + import androidx.core.graphics.scale 25 + import co.touchlab.kermit.Logger 26 + import kotlin.collections.get 27 + import kotlin.math.max 28 + import kotlin.math.min 29 + import kotlin.run 17 30 18 31 actual enum class PlatformType { 19 32 ANDROID, IOS; ··· 25 38 26 39 @OptIn(ExperimentalGetImage::class) 27 40 fun ImageProxy.process( 28 - objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?, 41 + objectDetector: Interpreter?, 29 42 poseDetector: PoseDetector?, 30 43 timestamp: Long, 31 44 focusArea: Rect?, 32 45 onComplete: (AnalysisResult, Bitmap) -> Unit 33 46 ) { 34 - val bitmap = toBitmap().rotate(imageInfo.rotationDegrees.toFloat()) 35 - val tensorImage = TensorImage.fromBitmap(bitmap) 47 + val bitmap = objectDetector?.let { interperter -> 48 + val inputShape = interperter.getInputTensor(0)?.shape() 49 + var tensorWidth = 0 50 + var tensorHeight = 0 51 + if (inputShape != null) { 52 + tensorWidth = inputShape[1] 53 + tensorHeight = inputShape[2] 54 + 55 + // If in case input shape is in format of [1, 3, ..., ...] 56 + if (inputShape[1] == 3) { 57 + tensorWidth = inputShape[2] 58 + tensorHeight = inputShape[3] 59 + } 60 + } 61 + toBitmap().rotate(imageInfo.rotationDegrees.toFloat()) 62 + .scale(tensorWidth, tensorHeight, false) 63 + } 64 + if(bitmap == null){ 65 + return 66 + } 67 + 68 + 69 + //val tensorImage = TensorImage.fromBitmap(bitmap) 70 + val tensorImage = TensorImage(DataType.FLOAT32) 71 + tensorImage.load(bitmap) 72 + val processedImage = imageProcessor.process(tensorImage) 36 73 val mlKitImage = InputImage.fromBitmap(bitmap.applyFocusAreaMask(focusArea,imageInfo.rotationDegrees 37 74 ), imageInfo.rotationDegrees) 38 75 process( 39 - tensorImage = tensorImage, 76 + tensorImage = processedImage, 40 77 mlKitImage = mlKitImage, 41 78 objectDetector = objectDetector, 42 79 poseDetector = poseDetector, ··· 57 94 ) 58 95 }?: android.graphics.Rect(0, 0, width, height) 59 96 } 97 + private val imageProcessor = ImageProcessor.Builder() 98 + .add(NormalizeOp(0f, 255f)) 99 + .add(CastOp(DataType.FLOAT32)) 100 + .build() 60 101 61 102 fun Bitmap.process( 62 - objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?, 103 + objectDetector: Interpreter?, 63 104 poseDetector: PoseDetector, 64 105 timestamp: Long, 65 106 focusArea: Rect?, 66 107 onComplete: (AnalysisResult, Bitmap) -> Unit 67 108 ) { 68 - val tensorImage = TensorImage.fromBitmap(this) 109 + //val tensorImage = TensorImage.fromBitmap(this) 110 + val tensorImage = TensorImage(DataType.FLOAT32) 111 + tensorImage.load(this) 112 + val processedImage = imageProcessor.process(tensorImage) 69 113 val mlKitImage = InputImage.fromBitmap(this.applyFocusAreaMask(focusArea), 0) 70 114 process( 71 - tensorImage = tensorImage, 115 + tensorImage = processedImage, 72 116 mlKitImage = mlKitImage, 73 117 objectDetector = objectDetector, 74 118 poseDetector = poseDetector, ··· 170 214 private fun process( 171 215 tensorImage: TensorImage, 172 216 mlKitImage: InputImage?, 173 - objectDetector: org.tensorflow.lite.task.vision.detector.ObjectDetector?, 217 + objectDetector: Interpreter?, 174 218 poseDetector: PoseDetector?, 175 219 timestamp: Long, 176 220 width: Int, ··· 178 222 bitmap: Bitmap, 179 223 onComplete: (AnalysisResult, Bitmap) -> Unit 180 224 ) { 225 + Logger.d{"Processing image of size: ${tensorImage.width}x${tensorImage.height}" } 226 + val objectsDetected = objectDetector?.let { interpreter -> 227 + val outputTensor = interpreter.getOutputTensor(0) 228 + val outputShape = outputTensor.shape() // e.g. [1, 180, 6] or [1, 6, 180] 229 + val output = TensorBuffer.createFixedSize(outputShape, DataType.FLOAT32) 230 + 231 + interpreter.run(tensorImage.buffer, output.buffer) 232 + 233 + val array = output.floatArray 234 + if (outputShape.size != 3) return@let emptyList<AnalysisObject>() 235 + 236 + val dim1 = outputShape[1] 237 + val dim2 = outputShape[2] 238 + 239 + // We expect 6 values per detection: x1,y1,x2,y2,cnf,cls 240 + // So whichever dimension equals 6 is the "channels". 241 + val channels: Int 242 + val elements: Int 243 + val isElementsFirst: Boolean // true if shape is [1, elements, channels] 244 + 245 + when { 246 + dim2 == 6 -> { 247 + // [1, elements, 6] 248 + elements = dim1 249 + channels = dim2 250 + isElementsFirst = true 251 + } 252 + dim1 == 6 -> { 253 + // [1, 6, elements] 254 + channels = dim1 255 + elements = dim2 256 + isElementsFirst = false 257 + } 258 + else -> { 259 + // Unknown layout; bail out rather than silently producing 0 detections. 260 + return@let emptyList<AnalysisObject>() 261 + } 262 + } 263 + 264 + Logger.d{"Processing objects: ${elements}" } 265 + 266 + fun valueAt(elementIndex: Int, channelIndex: Int): Float { 267 + return if (isElementsFirst) { 268 + // base = elementIndex * channels + channelIndex 269 + array[elementIndex * channels + channelIndex] 270 + } else { 271 + // base = channelIndex * elements + elementIndex 272 + array[channelIndex * elements + elementIndex] 273 + } 274 + } 275 + 276 + (0 until elements).mapNotNull { i -> 277 + val cnf = valueAt(i, 4) 278 + if (cnf > 0.25f) { 279 + val x1 = valueAt(i, 0) 280 + val y1 = valueAt(i, 1) 281 + val x2 = valueAt(i, 2) 282 + val y2 = valueAt(i, 3) 283 + val cls = valueAt(i, 5).toInt() 284 + 285 + val leftN = min(x1, x2) 286 + val topN = min(y1, y2) 287 + val rightN = max(x1, x2) 288 + val bottomN = max(y1, y2) 289 + 290 + val leftPx = leftN * width.absoluteValue 291 + val topPx = topN * height.absoluteValue 292 + val rightPx = rightN * width.absoluteValue 293 + val bottomPx = bottomN * height.absoluteValue 294 + 295 + 296 + AnalysisObject( 297 + boundingBox = Rect(left = leftPx, top = topPx, right = rightPx, bottom = bottomPx), 298 + trackingId = 0, 299 + labels = listOf(com.performancecoachlab.posedetection.recording.Label("$cls", cnf)), 300 + frameSize = FrameSize(width = width.absoluteValue, height = height.absoluteValue) 301 + ) 302 + } else null 303 + } 304 + } ?: emptyList() 305 + 306 + Logger.d{"Processed objecs size: ${objectsDetected.size}" } 307 + /* 181 308 val objectsDetected = objectDetector?.detect(tensorImage)?.map { result -> 182 309 AnalysisObject( 183 310 boundingBox = result.boundingBox.let { ··· 200 327 height = height.absoluteValue 201 328 ) 202 329 ) 203 - } ?: emptyList() 330 + } ?: emptyList()*/ 204 331 var skeleton: Skeleton? = null 205 332 val poseDetectionTask = mlKitImage?.let { 206 333 val rotation = it.rotationDegrees ··· 221 348 bitmap 222 349 ) 223 350 } 224 - } 351 + } 352 + 353 + data class BoundingBox( 354 + val x1: Float, 355 + val y1: Float, 356 + val x2: Float, 357 + val y2: Float, 358 + val cnf: Float, 359 + val cls: Int, 360 + val clsName: String 361 + )
+27 -8
posedetection/src/androidMain/kotlin/com/performancecoachlab/posedetection/custom/CustomObjectModel.android.kt
··· 2 2 3 3 import androidx.compose.runtime.Composable 4 4 import androidx.compose.ui.platform.LocalContext 5 + import org.tensorflow.lite.Interpreter 6 + import org.tensorflow.lite.support.common.FileUtil 5 7 6 8 @Composable 7 9 actual fun initialiseObjectModel(modelPath: ModelPath): ObjectModel { 8 - val options = org.tensorflow.lite.task.vision.detector.ObjectDetector.ObjectDetectorOptions.builder().setMaxResults(5).setScoreThreshold(0f).build() 9 - val detector = org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromFileAndOptions( 10 - LocalContext.current, 11 - modelPath.androidModelPath, 12 - options 13 - ) 14 - return ObjectModel(detector) 10 + if(modelPath.androidModelPath == null){ 11 + throw IllegalArgumentException("Android model path cannot be null") 12 + } 13 + val options = Interpreter.Options().apply{ 14 + this.setNumThreads(4) 15 + } 16 + 17 + val model = FileUtil.loadMappedFile(LocalContext.current, modelPath.androidModelPath) 18 + val interpreter = Interpreter(model, options) 19 + 20 + val inputShape = interpreter.getInputTensor(0)?.shape() 21 + val outputShape = interpreter.getOutputTensor(0)?.shape() 22 + 23 + return ObjectModel(interpreter) 15 24 } 16 25 17 26 actual class ObjectModel{ 18 - private var detector: org.tensorflow.lite.task.vision.detector.ObjectDetector? = null 27 + /*private var detector: org.tensorflow.lite.task.vision.detector.ObjectDetector? = null 19 28 20 29 constructor(detector: org.tensorflow.lite.task.vision.detector.ObjectDetector){ 21 30 this.detector = detector 22 31 } 23 32 24 33 fun getDetector(): org.tensorflow.lite.task.vision.detector.ObjectDetector? { 34 + return detector 35 + }*/ 36 + 37 + private var detector: Interpreter? = null 38 + 39 + constructor(detector: Interpreter){ 40 + this.detector = detector 41 + } 42 + 43 + fun getDetector(): Interpreter? { 25 44 return detector 26 45 } 27 46 }
+1 -1
sample/composeApp/src/commonMain/kotlin/com/nate/posedetection/App.kt
··· 75 75 import kotlin.time.Clock 76 76 import kotlin.time.ExperimentalTime 77 77 78 - val androidPath = "lite-model_efficientdet_lite2_detection_metadata_1.tflite" 78 + val androidPath = "hoops.tflite" 79 79 val iosPath = "YOLOv3FP16" 80 80 @Composable 81 81 internal fun App() = AppTheme {