This repository has no description
1#!/usr/bin/env python3
2"""Compute and store one embedding vector per README in tangled_readmes."""
3
4from __future__ import annotations
5
6import os
7import sys
8from pathlib import Path
9
10import httpx
11from dotenv import load_dotenv
12
13from db import connect, init_schema, register_pgvector, set_crawl_state
14from embeddings import (
15 DEFAULT_DIM,
16 DEFAULT_MODEL,
17 batch_size,
18 embed_texts,
19 embedding_model,
20 gemini_api_key,
21 truncate,
22)
23from progress import banner, log, phase, step, summary_block
24
# Repository root: the parent of the directory that contains this script.
REPO_ROOT: Path = Path(__file__).resolve().parent.parent
# Key under which this job records its progress via set_crawl_state().
CRAWL_KEY: str = "readmes:embed"
27
28
29def _repo_limit() -> int | None:
30 raw = os.getenv("TANGLED_EMBED_README_LIMIT", "").strip()
31 if not raw:
32 return None
33 return max(1, int(raw))
34
35
36def _force_reembed() -> bool:
37 return os.getenv("TANGLED_EMBED_FORCE", "").strip().lower() in ("1", "true", "yes")
38
39
40def _select_query(*, force: bool, limit: int | None) -> str:
41 where = "status = 'found' and content is not null"
42 if not force:
43 where += " and embedding is null"
44 query = f"""
45 select repo_did, owner_handle, repo_name, content
46 from tangled_readmes
47 where {where}
48 order by fetched_at desc
49 """
50 if limit:
51 query += f" limit {limit}"
52 return query
53
54
def _row_label(row) -> str:
    """Return an "owner/repo" label for progress output.

    Uses "?" when the owner handle is missing and the first 16 characters
    of repo_did when the repo name is missing.
    """
    owner = row.get("owner_handle") or "?"
    name = row.get("repo_name") or row["repo_did"][:16]
    return f"{owner}/{name}"


def _store_vectors(conn, batch, vectors, model: str) -> None:
    """Write one embedding per row of *batch* into tangled_readmes.

    *vectors* must be aligned with *batch* (same length, same order). This
    helper does not commit; the caller commits once per batch.
    """
    for row, vec in zip(batch, vectors, strict=True):
        conn.execute(
            """
            update tangled_readmes
            set embedding = %s,
                embedding_model = %s,
                embedded_at = now()
            where repo_did = %s
            """,
            (vec, model, row["repo_did"]),
        )


def run_embed_readmes(dsn: str) -> dict[str, int]:
    """Embed README rows from tangled_readmes and store the vectors.

    The API key, model and batch size come from the embeddings helpers.
    The row cap and the force flag come from TANGLED_EMBED_README_LIMIT and
    TANGLED_EMBED_FORCE. Rows are embedded in batches. A batch whose API call
    raises, or whose response does not hold exactly one vector per input,
    is counted under "errors" and skipped, and the run continues.

    Args:
        dsn: Postgres connection string passed to connect().

    Returns:
        Counters {"embedded", "batches", "errors"}. When rows were processed,
        the same dict is stored as the crawl-state meta with status "complete".
    """
    api_key = gemini_api_key()
    model = embedding_model()
    bs = batch_size()
    repo_limit = _repo_limit()
    force = _force_reembed()

    banner("README EMBED — Gemini → tangled_readmes.embedding")
    log("embed", f"Model: {model} dim={DEFAULT_DIM} L2-normalized batch={bs}")
    if repo_limit:
        log("embed", f"Limit: {repo_limit}")
    if force:
        log("embed", "Force re-embed all matching rows")

    # Read the work list on a short-lived connection; writes use a second one below.
    with connect(dsn) as conn:
        register_pgvector(conn)
        rows = conn.execute(_select_query(force=force, limit=repo_limit)).fetchall()

    stats = {"embedded": 0, "batches": 0, "errors": 0}
    if not rows:
        log("embed", "Nothing to embed (run check-readmes first, or set TANGLED_EMBED_FORCE=1).")
        return stats

    total = len(rows)
    log("embed", f"Embedding {total} READMEs …")
    phase(1, "Gemini batchEmbedContents → tangled_readmes.embedding")

    with httpx.Client() as client, connect(dsn) as conn:
        register_pgvector(conn)
        set_crawl_state(
            conn,
            key=CRAWL_KEY,
            status="running",
            meta={"count": total, "model": model, "dim": DEFAULT_DIM},
        )
        conn.commit()

        # NOTE(review): an exception escaping this loop (e.g. a failing UPDATE)
        # still leaves the crawl state at "running"; only a normal exit marks it complete.
        for start in range(0, total, bs):
            batch = rows[start : start + bs]
            done = start + len(batch)  # slicing guarantees done <= total
            texts = [truncate(r["content"]) for r in batch]

            try:
                # Materialize so the length can be checked even if a lazy iterable is returned.
                vectors = list(embed_texts(client, api_key=api_key, texts=texts))
            except Exception as exc:
                # Best-effort: one failed API call skips this batch, not the whole run.
                stats["errors"] += len(batch)
                step("embed", done, total, f"ERROR batch @ {start}: {exc}")
                continue

            if len(vectors) != len(batch):
                # A misaligned response would otherwise make zip(strict=True) raise
                # mid-transaction and abort the entire run.
                stats["errors"] += len(batch)
                step(
                    "embed",
                    done,
                    total,
                    f"ERROR batch @ {start}: expected {len(batch)} vectors, got {len(vectors)}",
                )
                continue

            _store_vectors(conn, batch, vectors, model)
            stats["embedded"] += len(batch)
            stats["batches"] += 1
            conn.commit()

            n = stats["embedded"]
            if n <= 10 or n % bs == 0 or n == total:
                step("embed", n, total, f"OK {_row_label(batch[-1])}")

        set_crawl_state(conn, key=CRAWL_KEY, status="complete", meta=stats)
        conn.commit()

    summary_block(
        "README embed complete",
        [
            f"Embedded: {stats['embedded']}",
            f"Batches: {stats['batches']}",
            f"Errors: {stats['errors']}",
            "",
            "Cosine search (L2-normalized vectors):",
            "  order by embedding <=> query_vec",
        ],
    )
    return stats
147
148
def main() -> None:
    """Script entry point: load a .env file, check the DSN, then run the embed job.

    The first existing file among ``REPO_ROOT / ".env"`` and the script
    directory's ``.env`` is loaded. If neither exists, load_dotenv() is
    called without a path. Exits with status 1 when DB_CONNECTION_STRING
    is unset or blank.
    """
    env_candidates = (REPO_ROOT / ".env", Path(__file__).parent / ".env")
    env_file = next((path for path in env_candidates if path.exists()), None)
    if env_file is not None:
        load_dotenv(env_file)
    else:
        load_dotenv()

    dsn = os.getenv("DB_CONNECTION_STRING", "").strip()
    if dsn:
        init_schema(dsn)
        run_embed_readmes(dsn)
        return

    print("ERROR: DB_CONNECTION_STRING not set", file=sys.stderr)
    raise SystemExit(1)
164
165
if __name__ == "__main__":
    # On Ctrl-C, print a short note instead of a traceback and exit with 130,
    # the conventional shell status for SIGINT (128 + 2).
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("\nInterrupted.\n")
        raise SystemExit(130) from None