Sunstead trust scoring project
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 8.2 kB View raw
1"""M6 GraphSAGE (stretch). Inductive node classification on the vouch graph. 2 3PRD M6 / 6.5: GraphSAGE, 2 layers, hidden 64, out 1; nodes are contributors with 4the per-DID feature vector as node features; edges are positive vouches + co- 5contribution edges (denounce-count rides as a node feature, no signed-edge GNN). 6Trained OFFLINE, served in-process. 7 8Guardrail (PRD section 8, repeated): SHIP THE GNN ONLY IF IT BEATS THE CALIBRATED 9LightGBM BASELINE AND IS STABLE. So `train_and_compare` writes a verdict, and 10`load_if_winner` (used by fusion) returns a scorer ONLY when the GNN actually beat 11M5 on the time-split holdout. On a small, sparsely-vouched graph it won't, and the 12system correctly keeps serving M5 — "always have M4/M5 working first." 13 14Optional: needs `uv pip install -e '.[gnn]'` (torch + torch-geometric, multi-GB). 15""" 16 17from __future__ import annotations 18 19import json 20from types import SimpleNamespace 21# OpenMP dual-libomp guard (lightgbm + torch) is set in trust/__init__.py — it must 22# run before either library imports, which package init guarantees. 23 24from .config import MODEL_DIR 25from .db import connection 26from . import eigentrust, learned 27 28CKPT = MODEL_DIR / "gnn.pt" 29VERDICT = MODEL_DIR / "gnn_verdict.json" 30HIDDEN = 64 31 32 33def _sage(in_dim: int): 34 import torch 35 from torch_geometric.nn import SAGEConv 36 37 class SAGE(torch.nn.Module): 38 def __init__(self): 39 super().__init__() 40 self.c1 = SAGEConv(in_dim, HIDDEN) # inductive: generalizes to unseen nodes 41 self.c2 = SAGEConv(HIDDEN, 1) 42 43 def forward(self, x, ei): 44 import torch.nn.functional as F 45 46 return self.c2(F.relu(self.c1(x, ei)), ei).squeeze(-1) 47 48 return SAGE() 49 50 51def _build_graph(con, mean=None, std=None): 52 import torch 53 54 er = eigentrust.compute(con) 55 dids = [r[0] for r in con.execute("SELECT did FROM contributors ORDER BY did").fetchall()] 56 didx = {d: i for i, d in enumerate(dids)} 57 fcols = [c[0] for c in con.execute("DESCRIBE features").fetchall()] 58 feats = {r[0]: dict(zip(fcols, r)) for r in con.execute("SELECT * FROM features").fetchall()} 59 60 raw = torch.tensor([learned._vec(d, feats.get(d, {}), er) for d in dids], dtype=torch.float) 61 if mean is None: 62 mean, std = raw.mean(0, keepdim=True), raw.std(0, keepdim=True).clamp_min(1e-6) 63 x = (raw - mean) / std 64 65 src, dst = [], [] 66 for v, s in con.execute("SELECT voucher_did, subject_did FROM vouches WHERE polarity > 0").fetchall(): 67 if v in didx and s in didx: # undirected edges for SAGE mean-aggregation 68 src += [didx[v], didx[s]]; dst += [didx[s], didx[v]] 69 for a, b in con.execute( # co-contribution: authored PRs to the same repo (PRD 6.5) 70 "SELECT DISTINCT a.author_did, b.author_did FROM pull_requests a JOIN pull_requests b " 71 "ON a.repo = b.repo AND a.author_did < b.author_did" 72 ).fetchall(): 73 if a in didx and b in didx: 74 src += [didx[a], didx[b]]; dst += [didx[b], didx[a]] 75 edge_index = (torch.tensor([src, dst], dtype=torch.long) if src 76 else torch.empty((2, 0), dtype=torch.long)) 77 78 # node labels: soft target = clean_merge_rate; temporal split by latest labelled PR 79 lab = con.execute( 80 "SELECT author_did, AVG(clean_merge), MAX(opened_at) FROM pr_labels " 81 "WHERE clean_merge IS NOT NULL GROUP BY author_did ORDER BY MAX(opened_at)" 82 ).fetchall() 83 label = {d: float(r) for d, r, _ in lab} 84 ordered = [d for d, _, _ in lab] 85 k = max(1, int(len(ordered) * 0.7)) 86 train_dids, val_dids = ordered[:k], ordered[k:] 87 y = torch.zeros(len(dids)) 88 train_mask = torch.zeros(len(dids), dtype=torch.bool) 89 val_mask = torch.zeros(len(dids), dtype=torch.bool) 90 for d in ordered: 91 y[didx[d]] = label[d] 92 for d in train_dids: 93 train_mask[didx[d]] = True 94 for d in val_dids: 95 val_mask[didx[d]] = True 96 97 return SimpleNamespace(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, 98 dids=dids, didx=didx, val_dids=val_dids, label=label, feats=feats, er=er, 99 mean=mean, std=std) 100 101 102def _acc(prob, y) -> float: 103 return float(((prob >= 0.5) == (y >= 0.5)).float().mean()) if len(y) else float("nan") 104 105 106def _m5_val_acc(g) -> float | None: 107 try: 108 s = learned.load() 109 except ImportError: 110 return None 111 if s is None or not g.val_dids: 112 return None 113 hits = sum(int((s.prob(d, g.feats.get(d, {}), g.er) >= 0.5) == (g.label[d] >= 0.5)) 114 for d in g.val_dids) 115 return hits / len(g.val_dids) 116 117 118def train_and_compare(epochs: int = 300) -> dict: 119 import torch 120 121 with connection(read_only=True) as con: 122 g = _build_graph(con) 123 if not g.val_dids or int(g.train_mask.sum()) == 0: 124 raise SystemExit("not enough labelled nodes for a temporal split; seed/ingest more history") 125 126 model = _sage(g.x.size(1)) 127 opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) 128 lossfn = torch.nn.BCEWithLogitsLoss() 129 for _ in range(epochs): 130 model.train(); opt.zero_grad() 131 loss = lossfn(model(g.x, g.edge_index)[g.train_mask], g.y[g.train_mask]) 132 loss.backward(); opt.step() 133 134 model.eval() 135 with torch.no_grad(): 136 prob = torch.sigmoid(model(g.x, g.edge_index)) 137 stable = bool(torch.isfinite(prob).all()) 138 gnn_acc = _acc(prob[g.val_mask], g.y[g.val_mask]) 139 m5_acc = _m5_val_acc(g) 140 # Beat the baseline strictly, and only when a baseline exists (PRD guardrail). 141 gnn_wins = bool(stable and m5_acc is not None and gnn_acc > m5_acc) 142 143 MODEL_DIR.mkdir(parents=True, exist_ok=True) 144 torch.save({"state": model.state_dict(), "mean": g.mean, "std": g.std, "in_dim": g.x.size(1)}, CKPT) 145 verdict = {"gnn_val_acc": round(gnn_acc, 3), "m5_val_acc": m5_acc, 146 "val_nodes": len(g.val_dids), "stable": stable, "gnn_wins": gnn_wins} 147 VERDICT.write_text(json.dumps(verdict, indent=2)) 148 return verdict 149 150 151class GNNScorer: 152 """In-process inductive inference: rebuild the current graph, forward, read the node.""" 153 154 def __init__(self, ckpt): 155 self.ckpt = ckpt 156 157 def prob(self, did: str) -> float: 158 import torch 159 160 with connection(read_only=True) as con: 161 g = _build_graph(con, mean=self.ckpt["mean"], std=self.ckpt["std"]) 162 model = _sage(self.ckpt["in_dim"]) 163 model.load_state_dict(self.ckpt["state"]) 164 model.eval() 165 with torch.no_grad(): 166 p = torch.sigmoid(model(g.x, g.edge_index)) 167 i = g.didx.get(did) 168 return float(p[i]) if i is not None else 0.0 169 170 171def load_if_winner() -> GNNScorer | None: 172 """Serving hook used by fusion: a GNN scorer ONLY if it beat M5 (else None).""" 173 if not (VERDICT.exists() and CKPT.exists()): 174 return None 175 if not json.loads(VERDICT.read_text()).get("gnn_wins"): 176 return None 177 try: 178 import torch 179 except ImportError: 180 return None 181 return GNNScorer(torch.load(CKPT, weights_only=False)) 182 183 184def main() -> None: 185 v = train_and_compare() 186 print(f"[gnn] val nodes={v['val_nodes']} GNN acc={v['gnn_val_acc']} " 187 f"M5 acc={v['m5_val_acc']} stable={v['stable']}") 188 print(f"[gnn] gnn_wins={v['gnn_wins']} -> " 189 + ("SERVED (beats calibrated M5)" if v["gnn_wins"] 190 else "NOT served; system keeps M5 (PRD guardrail: ship only if it beats the baseline)")) 191 192 193def demo() -> None: 194 """Self-check: trains, produces finite probs, writes a verdict — stability, not winning.""" 195 from .db import connection as conn, init_db 196 from .seed import seed as load_seed 197 198 with conn(read_only=False) as con: 199 init_db(con) 200 load_seed(con) 201 try: 202 learned.train() # so there's an M5 baseline to compare against 203 except Exception: 204 pass 205 v = train_and_compare(epochs=200) 206 assert v["stable"], "GNN produced non-finite outputs" 207 assert isinstance(v["gnn_wins"], bool) 208 print(f"gnn_val_acc={v['gnn_val_acc']} m5_val_acc={v['m5_val_acc']} gnn_wins={v['gnn_wins']} ok") 209 210 211if __name__ == "__main__": 212 demo()