This repository has no description
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 2.8 kB View raw
"""Gemini embeddings: gemini-embedding-001, 1536-dim, L2-normalized for cosine."""

from __future__ import annotations

import math
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # httpx is referenced only in annotations, which are never evaluated at
    # runtime because of ``from __future__ import annotations``.
    import httpx

DEFAULT_MODEL = "gemini-embedding-001"
DEFAULT_DIM = 1536
MAX_CHARS = 24_000
DEFAULT_BATCH_SIZE = 16

# Endpoint for DEFAULT_MODEL; kept as a public constant for existing importers.
GEMINI_BATCH_URL = (
    "https://generativelanguage.googleapis.com/v1beta/"
    "models/gemini-embedding-001:batchEmbedContents"
)
_GEMINI_MODELS_URL = "https://generativelanguage.googleapis.com/v1beta/models/"


def embedding_model() -> str:
    """Return the embedding model name.

    Reads ``TANGLED_EMBEDDING_MODEL``; an unset or blank value yields
    ``DEFAULT_MODEL``.
    """
    return os.getenv("TANGLED_EMBEDDING_MODEL", DEFAULT_MODEL).strip() or DEFAULT_MODEL


def batch_size() -> int:
    """Return the embedding batch size, clamped to the range 1..100.

    Reads ``TANGLED_EMBED_BATCH_SIZE``; an unset or blank value yields
    ``DEFAULT_BATCH_SIZE``.

    Raises:
        ValueError: If the variable is set to something that is not an integer.
    """
    raw = os.getenv("TANGLED_EMBED_BATCH_SIZE", "").strip()
    if not raw:
        # A blank entry (e.g. ``TANGLED_EMBED_BATCH_SIZE=`` in .env) means
        # "use the default", matching how the other env helpers behave.
        return DEFAULT_BATCH_SIZE
    try:
        requested = int(raw)
    except ValueError as err:
        raise ValueError(
            f"TANGLED_EMBED_BATCH_SIZE must be an integer, got {raw!r}"
        ) from err
    return max(1, min(100, requested))


def gemini_api_key() -> str:
    """Return the API key from ``GEMINI_API_KEY``, else ``GOOGLE_API_KEY``.

    Raises:
        RuntimeError: If neither variable holds a non-blank value.
    """
    key = (
        os.getenv("GEMINI_API_KEY", "").strip()
        or os.getenv("GOOGLE_API_KEY", "").strip()
    )
    if not key:
        raise RuntimeError(
            "GEMINI_API_KEY (or GOOGLE_API_KEY) is not set. "
            "Add it to .env to compute embeddings."
        )
    return key


def truncate(text: str) -> str:
    """Strip surrounding whitespace and cap the result at ``MAX_CHARS`` characters."""
    return text.strip()[:MAX_CHARS]


def l2_normalize(vec: list[float]) -> list[float]:
    """Return a new list holding *vec* scaled to unit Euclidean length.

    A vector whose norm is zero (including the empty vector) cannot be scaled
    and is returned as an unchanged copy.
    """
    norm = math.sqrt(sum(x * x for x in vec))
    if norm == 0:
        # Copy so the result never aliases the caller's list.
        return list(vec)
    return [x / norm for x in vec]


def embed_texts(
    client: httpx.Client,
    *,
    api_key: str,
    texts: list[str],
    task_type: str = "RETRIEVAL_DOCUMENT",
) -> list[list[float]]:
    """Embed texts via Gemini batchEmbedContents; returns L2-normalized 1536-dim vectors.

    Args:
        client: HTTP client used to POST the request.
        api_key: Value sent in the ``x-goog-api-key`` header.
        texts: Texts to embed; an empty list returns ``[]`` without a request.
        task_type: Value sent as ``taskType`` for every text.

    Returns:
        One L2-normalized vector of length ``DEFAULT_DIM`` per input text, in
        the order the API returned them.

    Raises:
        RuntimeError: On a non-200 status, a body that is not valid JSON, an
            embedding count that differs from ``len(texts)``, a row without a
            ``values`` list, or a vector whose length is not ``DEFAULT_DIM``.
    """
    if not texts:
        return []

    # Honour the TANGLED_EMBEDDING_MODEL override in both the URL and the
    # per-request model field, so the model reported by embedding_model() is
    # the one that actually produced the vectors.
    model = embedding_model()
    requests = [
        {
            "model": f"models/{model}",
            "content": {"parts": [{"text": text}]},
            "taskType": task_type,
            "outputDimensionality": DEFAULT_DIM,
        }
        for text in texts
    ]

    resp = client.post(
        f"{_GEMINI_MODELS_URL}{model}:batchEmbedContents",
        headers={
            "x-goog-api-key": api_key,
            "Content-Type": "application/json",
        },
        json={"requests": requests},
        timeout=120.0,
    )
    if resp.status_code != 200:
        raise RuntimeError(
            f"Gemini embeddings HTTP {resp.status_code}: {resp.text[:500]}"
        )

    try:
        payload = resp.json()
    except ValueError as err:
        raise RuntimeError("Gemini embeddings response is not valid JSON") from err

    embeddings = payload.get("embeddings") if isinstance(payload, dict) else None
    if not isinstance(embeddings, list):
        # Missing or malformed field: treat as "no embeddings" so the count
        # check below reports it uniformly.
        embeddings = []
    if len(embeddings) != len(texts):
        raise RuntimeError(f"Expected {len(texts)} embeddings, got {len(embeddings)}")

    vectors: list[list[float]] = []
    for row in embeddings:
        values = row.get("values") if isinstance(row, dict) else None
        if not isinstance(values, list):
            raise RuntimeError("Gemini response missing embedding values")
        if len(values) != DEFAULT_DIM:
            raise RuntimeError(
                f"Expected dim {DEFAULT_DIM}, got {len(values)}. "
                "Check outputDimensionality support for your API key."
            )
        vectors.append(l2_normalize(values))
    return vectors