This repository has no description
#!/usr/bin/env python3
"""
Sync trained .tflite models from Colab output into the sample app's assets dir.

The Colab notebook writes trained models to ``collab/output/<run-id>/*.tflite``
(via Drive sync). This script copies those models into the sample app's
Android assets directory so ``DiscoverModels.android.kt`` will pick them up
on the next build.

Usage::

    tools/sync_models.py                  # sync the latest run (default)
    tools/sync_models.py --list           # list available runs, newest first
    tools/sync_models.py --run <run-id>   # sync a specific run
    tools/sync_models.py --clean          # remove previously synced models

Synced files are renamed ``<run-id>__<model-base>.tflite`` (double underscore),
with both parts sanitized to the characters ``[A-Za-z0-9._-]``.
The double underscore is deliberate: ``DiscoverModels.android.kt`` replaces
single underscores with spaces in the picker, so ``__`` becomes a clear visual
break between the run-id and the model name in the dropdown.

``--clean`` deletes every other ``.tflite`` file in the assets dir but preserves
the three baseline models bundled with the repo.
"""
24
25from __future__ import annotations
26
27import argparse
28import re
29import shutil
30import sys
31from pathlib import Path
32
# Repository layout, resolved relative to this script, which sits one
# directory below the repo root (e.g. tools/sync_models.py).
REPO_ROOT = Path(__file__).resolve().parent.parent
COLAB_OUTPUT_DIR = REPO_ROOT.joinpath("collab", "output")
ASSETS_DIR = REPO_ROOT.joinpath("sample", "composeApp", "src", "androidMain", "assets")

# Models that ship with the repository. --clean never deletes these; any other
# *.tflite directly under ASSETS_DIR is assumed to have been synced by this
# script and is removed by --clean.
BASELINE_MODELS = frozenset(
    (
        "yolo11n_dataset_dataset.tflite",
        "yolo11n_su_416.tflite",
        "yolov10n_float16.tflite",
    )
)

# Matches every character that is NOT allowed in a synced asset filename.
SAFE_NAME = re.compile(r"[^A-Za-z0-9._-]")
49
50
def sanitize(name: str) -> str:
    """Make *name* safe to use as part of an Android asset filename.

    Every character outside ``[A-Za-z0-9._-]`` becomes ``_``; leading and
    trailing underscores are then trimmed. If nothing is left, the
    placeholder ``"unnamed"`` is returned, so the result is never empty.
    """
    cleaned = SAFE_NAME.sub("_", name).strip("_")
    if not cleaned:
        return "unnamed"
    return cleaned
54
55
def list_runs() -> list[Path]:
    """Return the run-id directories under ``collab/output``, newest first.

    Runs are ordered by directory mtime, most recent first. Ties are broken by
    directory name (descending), so the order does not depend on the
    filesystem's listing order.

    Returns an empty list when ``collab/output`` is missing or is not a
    directory.
    """
    # is_dir() rather than exists(): a plain file at this path would make
    # iterdir() raise NotADirectoryError.
    if not COLAB_OUTPUT_DIR.is_dir():
        return []
    runs = [p for p in COLAB_OUTPUT_DIR.iterdir() if p.is_dir()]
    # Path.iterdir() yields entries in arbitrary order, so sorting on mtime
    # alone would make the "latest" run non-deterministic when mtimes tie.
    runs.sort(key=lambda p: (p.stat().st_mtime, p.name), reverse=True)
    return runs
63
64
def cmd_list() -> int:
    """Print every run under ``collab/output`` (newest first) with its models.

    A missing or empty output directory is reported, not treated as an
    error; the function always returns 0.
    """
    runs = list_runs()
    rel_output = COLAB_OUTPUT_DIR.relative_to(REPO_ROOT)
    if not runs:
        print(f"(no runs found under {rel_output})")
        return 0
    print(f"Runs under {rel_output} (newest first):")
    for run_dir in runs:
        model_names = [model.name for model in sorted(run_dir.glob("*.tflite"))]
        summary = ", ".join(model_names) if model_names else "(no .tflite files)"
        print(f" {run_dir.name} — {summary}")
    return 0
78 return 0
79
80
def cmd_sync(run_id: str) -> int:
    """Copy every ``.tflite`` model from one Colab run into the assets dir.

    Each model is written as ``<sanitized run-id>__<sanitized stem>.tflite``
    (see the module docstring for why the separator is a double underscore).

    Args:
        run_id: Directory name under ``collab/output`` (joined onto that
            path as-is).

    Returns:
        0 once the copy loop finishes (individual models may be skipped with
        a message); 2 if the run directory is missing, contains no
        ``.tflite`` files, or the assets dir does not exist.
    """
    run_dir = COLAB_OUTPUT_DIR / run_id
    if not run_dir.is_dir():
        print(f"error: no such run directory: {run_dir}", file=sys.stderr)
        available = [r.name for r in list_runs()]
        if available:
            print(f" available: {', '.join(available)}", file=sys.stderr)
        return 2

    # Only regular files: a directory whose name ends in ".tflite" would
    # otherwise make shutil.copy2 fail partway through the sync.
    tflites = sorted(p for p in run_dir.glob("*.tflite") if p.is_file())
    if not tflites:
        print(f"error: no .tflite files in {run_dir}", file=sys.stderr)
        return 2

    if not ASSETS_DIR.exists():
        print(f"error: assets dir does not exist: {ASSETS_DIR}", file=sys.stderr)
        return 2

    safe_run = sanitize(run_id)
    print(f"Syncing run '{run_id}' ({len(tflites)} model(s)) -> {ASSETS_DIR.relative_to(REPO_ROOT)}")
    # Asset name -> the source file that claimed it during this sync.
    # sanitize() is lossy ("a b" and "a_b" both become "a_b"), so two sources
    # can map to the same asset name; without this check the later copy
    # would silently overwrite the earlier one.
    claimed: dict[str, Path] = {}
    for src in tflites:
        base = sanitize(src.stem)
        dest_name = f"{safe_run}__{base}.tflite"
        # Defensive guard: never let a synced file replace a bundled model.
        if dest_name in BASELINE_MODELS:
            print(f" skip {src.name} (would shadow baseline {dest_name})")
            continue
        if dest_name in claimed:
            print(f" skip {src.name} (would overwrite {claimed[dest_name].name} as {dest_name})")
            continue
        claimed[dest_name] = src
        dest = ASSETS_DIR / dest_name
        shutil.copy2(src, dest)
        size_mb = dest.stat().st_size / (1024 * 1024)
        print(f" copied {src.name} -> {dest.name} ({size_mb:.1f} MB)")
    return 0
111 return 0
112
113
def cmd_clean() -> int:
    """Delete every non-baseline ``.tflite`` file directly under the assets dir.

    Returns 0 after cleaning (including when there was nothing to remove) and
    2 if the assets dir does not exist.
    """
    if not ASSETS_DIR.exists():
        print(f"error: assets dir does not exist: {ASSETS_DIR}", file=sys.stderr)
        return 2
    doomed = [p for p in sorted(ASSETS_DIR.glob("*.tflite")) if p.name not in BASELINE_MODELS]
    for victim in doomed:
        victim.unlink()
        print(f" removed {victim.name}")
    if doomed:
        print(f"removed {len(doomed)} synced model(s); baselines preserved")
    else:
        print("(nothing to clean — only baseline models present)")
    return 0
129 return 0
130
131
def cmd_latest() -> int:
    """Sync the newest run under ``collab/output`` (the default action).

    Returns 2 when no runs exist; otherwise whatever ``cmd_sync`` returns for
    the newest run.
    """
    candidates = list_runs()
    if candidates:
        return cmd_sync(candidates[0].name)
    print(f"error: no runs under {COLAB_OUTPUT_DIR}", file=sys.stderr)
    return 2
138
139
def main(argv: list[str] | None = None) -> int:
    """Parse command-line arguments and dispatch to the selected sub-command.

    Args:
        argv: Arguments without the program name. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``.

    Returns:
        The exit code of the chosen command. Argument errors (including a
        blank ``--run`` value) exit via argparse with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Sync Colab-trained .tflite models into the sample app's assets dir.",
    )
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--list", action="store_true", help="list available runs newest-first")
    group.add_argument("--run", metavar="RUN_ID", help="sync the named run from collab/output/")
    group.add_argument("--clean", action="store_true", help="remove previously-synced models (preserve baselines)")
    args = parser.parse_args(argv)

    if args.list:
        return cmd_list()
    if args.clean:
        return cmd_clean()
    # Test against None, not truthiness: an explicit `--run ""` must be
    # rejected, not silently fall through to "sync the latest run".
    if args.run is not None:
        if not args.run.strip():
            parser.error("--run requires a non-empty RUN_ID")
        return cmd_sync(args.run)
    # Default: sync the latest run.
    return cmd_latest()
158
159
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))