This repository has no description
1"""Gemini embeddings: gemini-embedding-001, 1536-dim, L2-normalized for cosine."""
2
3from __future__ import annotations
4
5import math
6import os
7
8import httpx
9
# Gemini embedding model used for every request built in this module.
DEFAULT_MODEL = "gemini-embedding-001"
# Requested outputDimensionality; embed_texts rejects vectors of any other length.
DEFAULT_DIM = 1536
# Character (not token) cap applied by truncate() after stripping whitespace.
MAX_CHARS = 24_000
# Derived from DEFAULT_MODEL so the endpoint and the per-request "model" field
# sent by embed_texts cannot drift apart.
GEMINI_BATCH_URL = (
    "https://generativelanguage.googleapis.com/v1beta/"
    f"models/{DEFAULT_MODEL}:batchEmbedContents"
)
17
18
def embedding_model() -> str:
    """Return the configured embedding model name.

    Reads TANGLED_EMBEDDING_MODEL and strips surrounding whitespace. An unset,
    empty, or all-whitespace value yields DEFAULT_MODEL.

    NOTE(review): embed_texts in this module always sends DEFAULT_MODEL rather
    than this value; confirm that callers rely on that.
    """
    configured = os.getenv("TANGLED_EMBEDDING_MODEL", DEFAULT_MODEL).strip()
    if configured:
        return configured
    return DEFAULT_MODEL
21
22
def batch_size() -> int:
    """Return the embedding batch size from TANGLED_EMBED_BATCH_SIZE.

    An unset, empty, or all-whitespace value falls back to 16. The result is
    clamped to the inclusive range [1, 100].

    Raises:
        ValueError: if the variable holds a non-integer value.
    """
    default = 16
    raw = os.getenv("TANGLED_EMBED_BATCH_SIZE", "").strip()
    if not raw:
        # A blank value (e.g. `TANGLED_EMBED_BATCH_SIZE=` in .env) means "use
        # the default" rather than crashing inside int("").
        return default
    try:
        value = int(raw)
    except ValueError as err:
        raise ValueError(
            f"TANGLED_EMBED_BATCH_SIZE must be an integer, got {raw!r}"
        ) from err
    return max(1, min(100, value))
26
27
def gemini_api_key() -> str:
    """Return the Gemini API key from the environment.

    GEMINI_API_KEY takes precedence. GOOGLE_API_KEY is consulted only when the
    former is unset or blank. The returned key is whitespace-stripped.

    Raises:
        RuntimeError: if neither variable holds a non-blank value.
    """
    for var_name in ("GEMINI_API_KEY", "GOOGLE_API_KEY"):
        candidate = os.getenv(var_name, "").strip()
        if candidate:
            return candidate
    raise RuntimeError(
        "GEMINI_API_KEY (or GOOGLE_API_KEY) is not set. "
        "Add it to .env to compute embeddings."
    )
38 return key
39
40
def truncate(text: str) -> str:
    """Strip surrounding whitespace, then keep at most MAX_CHARS characters.

    The cap counts characters, not tokens. Trailing whitespace exposed by the
    cut is not stripped again.
    """
    stripped = text.strip()
    if len(stripped) <= MAX_CHARS:
        return stripped
    return stripped[:MAX_CHARS]
44
45
def l2_normalize(vec: list[float]) -> list[float]:
    """Scale *vec* to unit Euclidean (L2) length.

    A vector whose norm is zero (all zeros, or empty) cannot be normalized and
    is returned as the very same list object. Otherwise a new list is returned
    and *vec* is left untouched.
    """
    magnitude = math.sqrt(sum(component * component for component in vec))
    return [component / magnitude for component in vec] if magnitude != 0 else vec
51
52
# Upper bound on requests sent in one batchEmbedContents call. batch_size() in
# this module clamps to the same value. NOTE(review): assumed to be the API's
# per-call limit; confirm against the Gemini API docs.
_MAX_REQUESTS_PER_CALL = 100


def embed_texts(
    client: httpx.Client,
    *,
    api_key: str,
    texts: list[str],
    task_type: str = "RETRIEVAL_DOCUMENT",
) -> list[list[float]]:
    """Embed texts via Gemini batchEmbedContents; returns L2-normalized 1536-dim vectors.

    Texts are sent exactly as given; no truncation is applied here. Inputs
    longer than ``_MAX_REQUESTS_PER_CALL`` are split into several sequential
    calls, and the results are concatenated in input order.

    Args:
        client: HTTP client used for the POST request(s).
        api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
        texts: Texts to embed, one vector per text.
        task_type: Gemini ``taskType`` applied to every request.

    Returns:
        One L2-normalized vector of length ``DEFAULT_DIM`` per input text, in
        the same order as *texts*. An empty *texts* returns ``[]`` without any
        HTTP call.

    Raises:
        RuntimeError: on a non-200 response, a body that is not the expected
            JSON shape, an embedding-count mismatch, or a vector whose length
            is not ``DEFAULT_DIM``.
    """
    if not texts:
        return []

    vectors: list[list[float]] = []
    # Chunks are processed sequentially, so output order matches ``texts``.
    for start in range(0, len(texts), _MAX_REQUESTS_PER_CALL):
        chunk = texts[start : start + _MAX_REQUESTS_PER_CALL]
        vectors.extend(
            _embed_chunk(client, api_key=api_key, texts=chunk, task_type=task_type)
        )
    return vectors


def _embed_chunk(
    client: httpx.Client,
    *,
    api_key: str,
    texts: list[str],
    task_type: str,
) -> list[list[float]]:
    """Send one batchEmbedContents call for *texts* and return normalized vectors."""
    requests = [
        {
            "model": f"models/{DEFAULT_MODEL}",
            "content": {"parts": [{"text": text}]},
            "taskType": task_type,
            "outputDimensionality": DEFAULT_DIM,
        }
        for text in texts
    ]

    resp = client.post(
        GEMINI_BATCH_URL,
        headers={
            "x-goog-api-key": api_key,
            "Content-Type": "application/json",
        },
        json={"requests": requests},
        timeout=120.0,
    )
    if resp.status_code != 200:
        raise RuntimeError(
            f"Gemini embeddings HTTP {resp.status_code}: {resp.text[:500]}"
        )
    return _parse_vectors(resp, expected=len(texts))


def _parse_vectors(resp: httpx.Response, *, expected: int) -> list[list[float]]:
    """Validate a batchEmbedContents response body and L2-normalize its vectors."""
    try:
        payload = resp.json()
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        raise RuntimeError(
            f"Gemini embeddings returned a non-JSON body: {resp.text[:500]}"
        ) from err

    # Treat a non-object payload or a missing/non-list "embeddings" field as
    # "no embeddings"; the count check below then reports the mismatch.
    embeddings = payload.get("embeddings") if isinstance(payload, dict) else None
    if not isinstance(embeddings, list):
        embeddings = []
    if len(embeddings) != expected:
        raise RuntimeError(f"Expected {expected} embeddings, got {len(embeddings)}")

    vectors: list[list[float]] = []
    for row in embeddings:
        values = row.get("values") if isinstance(row, dict) else None
        if not isinstance(values, list):
            raise RuntimeError("Gemini response missing embedding values")
        if len(values) != DEFAULT_DIM:
            raise RuntimeError(
                f"Expected dim {DEFAULT_DIM}, got {len(values)}. "
                "Check outputDimensionality support for your API key."
            )
        vectors.append(l2_normalize(values))
    return vectors
103 return vectors