This repository has no description
1#!/usr/bin/env python3
2"""Compute embeddings for tangled_issues (title + body)."""
3
4from __future__ import annotations
5
6import os
7import sys
8from pathlib import Path
9
10import httpx
11from dotenv import load_dotenv
12
13from db import connect, init_schema, register_pgvector, set_crawl_state
14from embeddings import (
15 DEFAULT_DIM,
16 DEFAULT_MODEL,
17 batch_size,
18 embed_texts,
19 embedding_model,
20 gemini_api_key,
21 truncate,
22)
23from progress import banner, log, phase, step, summary_block
24
# Directory two levels above this file (the parent of the directory that
# contains this script); main() looks for a .env file here first.
REPO_ROOT = Path(__file__).resolve().parent.parent
# Key passed to set_crawl_state() to record this job's status.
CRAWL_KEY = "issues:embed"
27
28
def _issue_limit() -> int | None:
    """Return the optional cap on how many issues to embed in one run.

    Reads ``TANGLED_ISSUE_EMBED_LIMIT`` from the environment.

    Returns:
        ``None`` when the variable is unset or blank (no cap); otherwise the
        parsed integer, clamped to a minimum of 1 so that zero or a negative
        value still yields a limit of one issue.

    Raises:
        ValueError: If the variable is set to something that is not an
            integer. The message names the variable and the offending value
            so a misconfigured environment is easy to diagnose.
    """
    raw = os.getenv("TANGLED_ISSUE_EMBED_LIMIT", "").strip()
    if not raw:
        return None
    try:
        limit = int(raw)
    except ValueError:
        # int()'s own message does not say which setting was wrong.
        raise ValueError(
            f"TANGLED_ISSUE_EMBED_LIMIT must be an integer, got {raw!r}"
        ) from None
    return max(1, limit)
34
35
def _force_reembed() -> bool:
    """Tell whether issues that already have an embedding should be redone.

    Controlled by ``TANGLED_ISSUE_EMBED_FORCE``: the values ``1``, ``true``
    and ``yes`` (case-insensitive, surrounding whitespace ignored) enable it;
    anything else, including an unset variable, disables it.
    """
    flag = os.getenv("TANGLED_ISSUE_EMBED_FORCE", "")
    normalized = flag.strip().lower()
    return normalized in {"1", "true", "yes"}
38
39
def _issue_text(title: str | None, body: str | None) -> str:
    """Build the text that gets embedded for a single issue.

    The title and body are joined with a blank line between them; a part that
    is ``None``, empty or whitespace-only is left out. The joined text is
    passed through ``truncate`` before being returned.
    """
    non_blank = (piece for piece in (title, body) if piece and piece.strip())
    combined = "\n\n".join(non_blank)
    return truncate(combined)
43
44
def run_embed_issues(dsn: str) -> dict[str, int]:
    """Embed issues from ``tangled_issues`` and store the vectors in place.

    Selects issues that have a non-blank title or body (most recent
    ``fetched_at`` first), sends their text to ``embed_texts`` in batches,
    and writes every returned vector to ``tangled_issues.embedding`` along
    with the model name and ``embedded_at = now()``. Each successful batch
    is committed on its own. The job status is recorded under ``CRAWL_KEY``
    through ``set_crawl_state`` ("running" at the start, "complete" at the
    end).

    Which rows are selected depends on the environment, via ``_issue_limit``
    and ``_force_reembed``: unless force is enabled, only rows whose
    ``embedding`` is null are picked up.

    Args:
        dsn: Database connection string handed to ``connect``.

    Returns:
        Counters for the run: ``embedded`` (rows updated), ``batches``
        (batches committed) and ``errors`` (rows belonging to batches whose
        embedding request failed or returned the wrong number of vectors).
    """
    api_key = gemini_api_key()
    model = embedding_model()
    bs = batch_size()
    issue_limit = _issue_limit()
    force = _force_reembed()

    banner("ISSUE EMBED — Gemini → tangled_issues.embedding")
    log("embed-issues", f"Model: {model} dim={DEFAULT_DIM} L2-normalized batch={bs}")
    if issue_limit:
        log("embed-issues", f"Limit: {issue_limit}")
    if force:
        log("embed-issues", "Force re-embed enabled")

    where = "1=1"
    if not force:
        where += " and embedding is null"
    query = f"""
        select uri, author_handle, title, body
        from tangled_issues
        where {where}
          and coalesce(nullif(trim(title), ''), nullif(trim(body), '')) is not null
        order by fetched_at desc
    """
    if issue_limit:
        # issue_limit is an int produced by _issue_limit(), never raw text.
        query += f" limit {issue_limit}"

    with connect(dsn) as conn:
        rows = conn.execute(query).fetchall()

    if not rows:
        log("embed-issues", "Nothing to embed (run fetch-issues first).")
        return {"embedded": 0, "batches": 0, "errors": 0}

    total = len(rows)
    log("embed-issues", f"Embedding {total} issues …")
    stats = {"embedded": 0, "batches": 0, "errors": 0}

    phase(1, "Gemini batchEmbedContents → tangled_issues.embedding")

    # NOTE(review): if an exception escapes the loop below (e.g. a database
    # error), the crawl state is left at "running". Handling that needs a
    # failure status that set_crawl_state accepts -- confirm before adding.
    with httpx.Client() as client, connect(dsn) as conn:
        register_pgvector(conn)
        set_crawl_state(
            conn,
            key=CRAWL_KEY,
            status="running",
            meta={"count": total, "model": model, "dim": DEFAULT_DIM},
        )
        conn.commit()

        for start in range(0, total, bs):
            batch = rows[start : start + bs]
            texts = [_issue_text(r.get("title"), r.get("body")) for r in batch]

            try:
                vectors = list(embed_texts(client, api_key=api_key, texts=texts))
                # Check the count before touching the database: a short or
                # long response is an API failure like any other, so it is
                # counted and skipped instead of aborting the run part-way
                # through this batch's updates.
                if len(vectors) != len(batch):
                    raise ValueError(
                        f"expected {len(batch)} embeddings, got {len(vectors)}"
                    )
            except Exception as exc:
                stats["errors"] += len(batch)
                step(
                    "embed-issues",
                    start + len(batch),
                    total,
                    f"ERROR batch: {exc}",
                )
                continue

            # strict=True stays as a safety net; the lengths were verified above.
            for row, vec in zip(batch, vectors, strict=True):
                conn.execute(
                    """
                    update tangled_issues
                    set embedding = %s, embedding_model = %s, embedded_at = now()
                    where uri = %s
                    """,
                    (vec, model, row["uri"]),
                )

            stats["embedded"] += len(batch)
            stats["batches"] += 1
            conn.commit()

            # Report every committed batch. The previous condition was true
            # for all full batches anyway and only hid the final partial
            # batch when an earlier batch had failed.
            last = batch[-1]
            label = f"{last.get('author_handle') or '?'}: {(last.get('title') or '')[:40]}"
            step("embed-issues", stats["embedded"], total, f"OK {label}")

        set_crawl_state(conn, key=CRAWL_KEY, status="complete", meta=stats)
        conn.commit()

    summary_block(
        "Issue embed complete",
        [f"Embedded: {stats['embedded']}", f"Errors: {stats['errors']}"],
    )
    return stats
139
140
def main() -> None:
    """Command-line entry point: load settings, prepare the schema, embed.

    Loads the first ``.env`` file that exists at ``REPO_ROOT`` or next to
    this script; when neither exists, ``load_dotenv()`` is called without a
    path. ``DB_CONNECTION_STRING`` is then read from the environment and
    passed to ``init_schema`` followed by ``run_embed_issues``.

    Raises:
        SystemExit: With status 1 when ``DB_CONNECTION_STRING`` is unset or
            blank (an error line is printed to stderr first).
    """
    env_candidates = (REPO_ROOT / ".env", Path(__file__).parent / ".env")
    env_file = next((path for path in env_candidates if path.exists()), None)
    if env_file is None:
        load_dotenv()
    else:
        load_dotenv(env_file)

    dsn = os.getenv("DB_CONNECTION_STRING", "").strip()
    if not dsn:
        print("ERROR: DB_CONNECTION_STRING not set", file=sys.stderr)
        raise SystemExit(1)

    init_schema(dsn)
    run_embed_issues(dsn)
156
157
if __name__ == "__main__":
    # Ctrl-C ends the script with a short notice on stderr and exit status
    # 130 instead of a KeyboardInterrupt traceback.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("\nInterrupted.\n")
        raise SystemExit(130) from None