This repository has no description
0

Configure Feed

Select the types of activity you want to include in your feed.

at master 5.5 kB View raw
#!/usr/bin/env python3
"""
Copy Colab-trained .tflite models into the sample app's Android assets.

Training runs produced by the Colab notebook end up in
``collab/output/<run-id>/*.tflite`` (delivered via Drive sync). This tool
copies a run's models into the sample app's Android assets directory, where
``DiscoverModels.android.kt`` finds them on the next build.

Usage::

    tools/sync_models.py                  # sync the newest run (default)
    tools/sync_models.py --list           # show available runs, newest first
    tools/sync_models.py --run <run-id>   # sync one specific run
    tools/sync_models.py --clean          # delete previously-synced models

Every copied file is named ``<run-id>__<model-base>.tflite``. The double
underscore is on purpose: ``DiscoverModels.android.kt`` turns single
underscores into spaces in the picker, so ``__`` shows up as a clear gap
between the run id and the model name in the dropdown.

``--clean`` never removes the three baseline models bundled with the repo.
"""

from __future__ import annotations

import argparse
import re
import shutil
import sys
from pathlib import Path

# The script is invoked as tools/sync_models.py, so the repository root is
# the parent of the directory containing this file.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Parent of the per-run output directories written by the Colab notebook.
COLAB_OUTPUT_DIR = REPO_ROOT.joinpath("collab", "output")
# Sample app's Android assets directory; DiscoverModels.android.kt picks up
# the .tflite files placed here.
ASSETS_DIR = REPO_ROOT.joinpath("sample", "composeApp", "src", "androidMain", "assets")

# Models committed to the repo. --clean keeps these and deletes every other
# *.tflite under ASSETS_DIR, treating it as something this script synced.
BASELINE_MODELS = frozenset(
    {
        "yolo11n_dataset_dataset.tflite",
        "yolo11n_su_416.tflite",
        "yolov10n_float16.tflite",
    }
)

# Matches every character sanitize() replaces: anything other than ASCII
# letters, digits, ".", "_" and "-".
SAFE_NAME = re.compile(r"[^A-Za-z0-9._-]")


def sanitize(name: str) -> str:
    """Replace anything that isn't safe for an Android asset filename.

    Characters matched by ``SAFE_NAME`` become ``_``, then leading and trailing
    underscores are stripped. Returns ``"unnamed"`` if nothing is left, so the
    result is never empty.
    """
    return SAFE_NAME.sub("_", name).strip("_") or "unnamed"


def _tflite_files(directory: Path) -> list[Path]:
    """Return the regular ``*.tflite`` files directly inside *directory*, sorted."""
    # is_file() drops directories that happen to match the glob; copying or
    # unlinking one of those would raise partway through a sync/clean.
    return sorted(p for p in directory.glob("*.tflite") if p.is_file())


def _is_plain_run_id(run_id: str) -> bool:
    """Return True if *run_id* is a single path component other than "", "." or ".."."""
    return run_id not in ("", ".", "..") and Path(run_id).name == run_id


def _describe_run(run: Path) -> str:
    """Format the ``--list`` line for *run*: its name followed by its .tflite files."""
    tflites = _tflite_files(run)
    if tflites:
        summary = ", ".join(p.name for p in tflites)
    else:
        summary = "(no .tflite files)"
    # ": " keeps the run id from running straight into the first filename.
    return f" {run.name}: {summary}"


def list_runs() -> list[Path]:
    """Return run-id directories under collab/output, newest first by mtime.

    Returns an empty list when COLAB_OUTPUT_DIR does not exist.
    """
    if not COLAB_OUTPUT_DIR.exists():
        return []
    runs = [p for p in COLAB_OUTPUT_DIR.iterdir() if p.is_dir()]
    runs.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    return runs


def cmd_list() -> int:
    """Print each run under COLAB_OUTPUT_DIR (newest first) with its models; return 0."""
    runs = list_runs()
    rel_output = COLAB_OUTPUT_DIR.relative_to(REPO_ROOT)
    if not runs:
        print(f"(no runs found under {rel_output})")
        return 0
    print(f"Runs under {rel_output} (newest first):")
    for run in runs:
        print(_describe_run(run))
    return 0


def cmd_sync(run_id: str) -> int:
    """Copy every .tflite file of run *run_id* into ASSETS_DIR.

    Each file is written as ``<sanitize(run_id)>__<sanitize(stem)>.tflite``,
    overwriting any existing asset with that name.

    Returns 0 on success. Returns 2 (after printing to stderr) when *run_id* is
    not a single directory name, the run directory is missing or has no .tflite
    files, or ASSETS_DIR does not exist.
    """
    if not _is_plain_run_id(run_id):
        # "", ".", ".." and values like "a/b" or "/abs" would resolve to
        # collab/output itself or to a directory outside it.
        print(f"error: run id must be a single directory name: {run_id!r}", file=sys.stderr)
        return 2

    run_dir = COLAB_OUTPUT_DIR / run_id
    if not run_dir.is_dir():
        print(f"error: no such run directory: {run_dir}", file=sys.stderr)
        available = [r.name for r in list_runs()]
        if available:
            print(f" available: {', '.join(available)}", file=sys.stderr)
        return 2

    tflites = _tflite_files(run_dir)
    if not tflites:
        print(f"error: no .tflite files in {run_dir}", file=sys.stderr)
        return 2

    if not ASSETS_DIR.exists():
        print(f"error: assets dir does not exist: {ASSETS_DIR}", file=sys.stderr)
        return 2

    safe_run = sanitize(run_id)
    print(f"Syncing run '{run_id}' ({len(tflites)} model(s)) -> {ASSETS_DIR.relative_to(REPO_ROOT)}")
    written: set[str] = set()
    for src in tflites:
        base = sanitize(src.stem)
        dest_name = f"{safe_run}__{base}.tflite"
        if dest_name in BASELINE_MODELS:
            # Defensive: dest_name always contains "__" and none of the current
            # baseline names do, but this keeps a future baseline from being
            # overwritten.
            print(f" skip {src.name} (would shadow baseline {dest_name})")
            continue
        if dest_name in written:
            # sanitize() can map distinct stems (e.g. "a b" and "a_b") to the
            # same name; copying both would silently overwrite the first.
            print(f" skip {src.name} (sanitized name {dest_name} already used by another file in this run)")
            continue
        dest = ASSETS_DIR / dest_name
        shutil.copy2(src, dest)
        written.add(dest_name)
        size_mb = dest.stat().st_size / (1024 * 1024)
        print(f" copied {src.name} -> {dest.name} ({size_mb:.1f} MB)")
    return 0


def cmd_clean() -> int:
    """Delete every non-baseline .tflite file in ASSETS_DIR.

    Returns 0, or 2 (after printing to stderr) if ASSETS_DIR does not exist.
    """
    if not ASSETS_DIR.exists():
        print(f"error: assets dir does not exist: {ASSETS_DIR}", file=sys.stderr)
        return 2
    removed = 0
    for f in _tflite_files(ASSETS_DIR):
        if f.name in BASELINE_MODELS:
            continue
        f.unlink()
        print(f" removed {f.name}")
        removed += 1
    if removed == 0:
        print("(nothing to clean — only baseline models present)")
    else:
        print(f"removed {removed} synced model(s); baselines preserved")
    return 0


def cmd_latest() -> int:
    """Sync the newest run (by directory mtime); return 2 if there are no runs."""
    runs = list_runs()
    if not runs:
        print(f"error: no runs under {COLAB_OUTPUT_DIR}", file=sys.stderr)
        return 2
    return cmd_sync(runs[0].name)


def main(argv: list[str]) -> int:
    """Parse CLI arguments and dispatch; returns the process exit code.

    --list, --run and --clean are mutually exclusive; with none of them the
    newest run is synced.
    """
    parser = argparse.ArgumentParser(
        description="Sync Colab-trained .tflite models into the sample app's assets dir.",
    )
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--list", action="store_true", help="list available runs newest-first")
    group.add_argument("--run", metavar="RUN_ID", help="sync the named run from collab/output/")
    group.add_argument("--clean", action="store_true", help="remove previously-synced models (preserve baselines)")
    args = parser.parse_args(argv)

    if args.list:
        return cmd_list()
    if args.clean:
        return cmd_clean()
    # `is not None` so an explicit but empty `--run ""` is reported as an
    # error instead of silently falling through to syncing the latest run.
    if args.run is not None:
        return cmd_sync(args.run)
    # Default: sync the latest run.
    return cmd_latest()


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))