Sunstead trust scoring project
1"""M6 GraphSAGE (stretch). Inductive node classification on the vouch graph.
2
3PRD M6 / 6.5: GraphSAGE, 2 layers, hidden 64, out 1; nodes are contributors with
4the per-DID feature vector as node features; edges are positive vouches + co-
5contribution edges (denounce-count rides as a node feature, no signed-edge GNN).
6Trained OFFLINE, served in-process.
7
8Guardrail (PRD section 8, repeated): SHIP THE GNN ONLY IF IT BEATS THE CALIBRATED
9LightGBM BASELINE AND IS STABLE. So `train_and_compare` writes a verdict, and
10`load_if_winner` (used by fusion) returns a scorer ONLY when the GNN actually beat
11M5 on the time-split holdout. On a small, sparsely-vouched graph it won't, and the
12system correctly keeps serving M5 — "always have M4/M5 working first."
13
14Optional: needs `uv pip install -e '.[gnn]'` (torch + torch-geometric, multi-GB).
15"""
16
17from __future__ import annotations
18
19import json
20from types import SimpleNamespace
21# OpenMP dual-libomp guard (lightgbm + torch) is set in trust/__init__.py — it must
22# run before either library imports, which package init guarantees.
23
24from .config import MODEL_DIR
25from .db import connection
26from . import eigentrust, learned
27
# Checkpoint written by `train_and_compare`: model state dict plus the feature
# normalisation stats (mean/std) and the input dimension.
CKPT = MODEL_DIR / "gnn.pt"
# JSON verdict of the GNN-vs-M5 comparison; `load_if_winner` reads it before serving.
VERDICT = MODEL_DIR / "gnn_verdict.json"
# Output width of the first SAGEConv layer (module docstring: "hidden 64").
HIDDEN = 64
31
32
def _sage(in_dim: int):
    """Build an untrained two-layer GraphSAGE network that emits one logit per node.

    torch and torch-geometric are imported inside the function, so merely
    importing this module does not pull in the optional GNN dependencies.

    Args:
        in_dim: Width of each node's feature vector.

    Returns:
        A freshly initialised ``torch.nn.Module`` whose ``forward(x, ei)`` takes
        the node features and the edge index and returns the second layer's
        output with its trailing size-1 dimension squeezed away.
    """
    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import SAGEConv

    class SAGE(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # The attribute names double as state-dict key prefixes, so they
            # must stay stable for saved checkpoints to load.
            self.c1 = SAGEConv(in_dim, HIDDEN)  # inductive: generalizes to unseen nodes
            self.c2 = SAGEConv(HIDDEN, 1)

        def forward(self, x, ei):
            hidden = F.relu(self.c1(x, ei))
            logits = self.c2(hidden, ei)
            return logits.squeeze(-1)

    return SAGE()
49
50
def _build_graph(con, mean=None, std=None):
    """Assemble the contributor graph (features, edges, labels, split masks) from the DB.

    Args:
        con: Open database connection; only ``con.execute(sql).fetchall()`` is used.
        mean: Optional per-feature mean to normalise with (e.g. the one stored in
            a checkpoint). When ``None``, both ``mean`` and ``std`` are computed
            from the current feature matrix.
        std: Per-feature standard deviation paired with ``mean``.

    Returns:
        A ``SimpleNamespace`` with ``x`` (normalised node features),
        ``edge_index`` (2 x E long tensor, every edge stored in both directions),
        ``y`` (soft labels; 0 for unlabelled nodes), ``train_mask`` / ``val_mask``
        (boolean node masks), ``dids`` / ``didx`` (node order and DID -> row
        index), ``val_dids``, ``label``, ``feats``, ``er`` and the ``mean`` /
        ``std`` actually used.
    """
    import torch

    er = eigentrust.compute(con)
    dids = [row[0] for row in con.execute("SELECT did FROM contributors ORDER BY did").fetchall()]
    didx = {did: i for i, did in enumerate(dids)}
    fcols = [col[0] for col in con.execute("DESCRIBE features").fetchall()]
    feats = {row[0]: dict(zip(fcols, row)) for row in con.execute("SELECT * FROM features").fetchall()}

    raw = torch.tensor([learned._vec(d, feats.get(d, {}), er) for d in dids], dtype=torch.float)
    if mean is None:
        mean = raw.mean(0, keepdim=True)
        # The sample std of a single row is NaN and clamp_min leaves NaN as-is;
        # treat an undefined std as zero variance so the features stay finite.
        std = torch.nan_to_num(raw.std(0, keepdim=True), nan=0.0).clamp_min(1e-6)
    x = (raw - mean) / std

    src, dst = [], []

    def add_undirected(a, b):
        # Store the edge in both directions (undirected edges for SAGE
        # mean-aggregation); endpoints that are not contributor nodes are skipped.
        ia, ib = didx.get(a), didx.get(b)
        if ia is None or ib is None:
            return
        src.extend((ia, ib))
        dst.extend((ib, ia))

    for voucher, subject in con.execute(
        "SELECT voucher_did, subject_did FROM vouches WHERE polarity > 0"
    ).fetchall():
        add_undirected(voucher, subject)
    for a, b in con.execute(  # co-contribution: authored PRs to the same repo (PRD 6.5)
        "SELECT DISTINCT a.author_did, b.author_did FROM pull_requests a JOIN pull_requests b "
        "ON a.repo = b.repo AND a.author_did < b.author_did"
    ).fetchall():
        add_undirected(a, b)
    edge_index = (torch.tensor([src, dst], dtype=torch.long) if src
                  else torch.empty((2, 0), dtype=torch.long))

    # node labels: soft target = clean_merge_rate; temporal split by latest labelled PR
    lab = con.execute(
        "SELECT author_did, AVG(clean_merge), MAX(opened_at) FROM pr_labels "
        "WHERE clean_merge IS NOT NULL GROUP BY author_did ORDER BY MAX(opened_at)"
    ).fetchall()
    # Only authors that are graph nodes can carry a label. Dropping the others
    # here (the same guard the edge loops apply) avoids a KeyError below and
    # keeps val_dids in step with val_mask.
    labelled = [(d, float(r)) for d, r, _ in lab if d in didx]
    label = dict(labelled)
    ordered = [d for d, _ in labelled]
    k = max(1, int(len(ordered) * 0.7))  # oldest 70% train, newest 30% validate
    train_dids, val_dids = ordered[:k], ordered[k:]

    y = torch.zeros(len(dids))
    train_mask = torch.zeros(len(dids), dtype=torch.bool)
    val_mask = torch.zeros(len(dids), dtype=torch.bool)
    for d in ordered:
        y[didx[d]] = label[d]
    for d in train_dids:
        train_mask[didx[d]] = True
    for d in val_dids:
        val_mask[didx[d]] = True

    return SimpleNamespace(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask,
                           dids=dids, didx=didx, val_dids=val_dids, label=label, feats=feats, er=er,
                           mean=mean, std=std)
100
101
102def _acc(prob, y) -> float:
103 return float(((prob >= 0.5) == (y >= 0.5)).float().mean()) if len(y) else float("nan")
104
105
def _m5_val_acc(g) -> float | None:
    """Score the M5 baseline on the graph's validation DIDs.

    Args:
        g: Graph namespace from ``_build_graph``; ``val_dids``, ``feats``,
            ``er`` and ``label`` are read.

    Returns:
        The share of validation DIDs where the baseline's probability and the
        soft label land on the same side of 0.5, or ``None`` when loading the
        baseline raises ImportError, no baseline is available, or there are no
        validation DIDs.
    """
    try:
        scorer = learned.load()
    except ImportError:
        return None
    if scorer is None:
        return None
    if not g.val_dids:
        return None

    correct = 0
    for did in g.val_dids:
        predicted_clean = scorer.prob(did, g.feats.get(did, {}), g.er) >= 0.5
        actually_clean = g.label[did] >= 0.5
        if predicted_clean == actually_clean:
            correct += 1
    return correct / len(g.val_dids)
116
117
def train_and_compare(epochs: int = 300) -> dict:
    """Train the GraphSAGE model, compare it with M5 on the holdout, and persist both results.

    Args:
        epochs: Number of full-graph optimisation steps.

    Returns:
        The verdict dict that is also written to ``VERDICT``: ``gnn_val_acc``
        (rounded to 3 places), ``m5_val_acc``, ``val_nodes``, ``stable`` and
        ``gnn_wins``.

    Raises:
        SystemExit: If the temporal split leaves no training or no validation nodes.
    """
    import torch

    with connection(read_only=True) as con:
        g = _build_graph(con)
    if not (g.val_dids and int(g.train_mask.sum()) != 0):
        raise SystemExit("not enough labelled nodes for a temporal split; seed/ingest more history")

    model = _sage(g.x.size(1))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = torch.nn.BCEWithLogitsLoss()
    train_targets = g.y[g.train_mask]  # soft labels of the training nodes only
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        logits = model(g.x, g.edge_index)
        loss = criterion(logits[g.train_mask], train_targets)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        prob = torch.sigmoid(model(g.x, g.edge_index))
    stable = bool(torch.isfinite(prob).all())
    gnn_acc = _acc(prob[g.val_mask], g.y[g.val_mask])
    m5_acc = _m5_val_acc(g)
    # Beat the baseline strictly, and only when a baseline exists (PRD guardrail).
    if stable and m5_acc is not None:
        gnn_wins = bool(gnn_acc > m5_acc)
    else:
        gnn_wins = False

    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    # The checkpoint is written whether or not the GNN wins; the verdict gates serving.
    checkpoint = {"state": model.state_dict(), "mean": g.mean, "std": g.std, "in_dim": g.x.size(1)}
    torch.save(checkpoint, CKPT)
    verdict = {
        "gnn_val_acc": round(gnn_acc, 3),
        "m5_val_acc": m5_acc,
        "val_nodes": len(g.val_dids),
        "stable": stable,
        "gnn_wins": gnn_wins,
    }
    VERDICT.write_text(json.dumps(verdict, indent=2))
    return verdict
149
150
class GNNScorer:
    """Serve GNN probabilities in-process from a loaded checkpoint.

    Each ``prob`` call rebuilds the graph from the current database, runs one
    forward pass over it, and reads off the requested node.
    """

    def __init__(self, ckpt):
        # ckpt is indexed with "state", "mean", "std" and "in_dim" by `prob`.
        self.ckpt = ckpt

    def prob(self, did: str) -> float:
        """Return the sigmoid output for ``did``, or 0.0 if it is not a node in the graph."""
        import torch

        checkpoint = self.ckpt
        with connection(read_only=True) as con:
            # Normalise with the stored stats instead of recomputing them.
            g = _build_graph(con, mean=checkpoint["mean"], std=checkpoint["std"])
        net = _sage(checkpoint["in_dim"])
        net.load_state_dict(checkpoint["state"])
        net.eval()
        with torch.no_grad():
            scores = torch.sigmoid(net(g.x, g.edge_index))
        node = g.didx.get(did)
        if node is None:
            return 0.0
        return float(scores[node])
169
170
def load_if_winner() -> GNNScorer | None:
    """Serving hook used by fusion: a GNN scorer ONLY if it beat M5 (else None).

    Returns:
        A ``GNNScorer`` wrapping the saved checkpoint, or ``None`` when any of
        the following holds, so the caller keeps serving M5:

        * the checkpoint or verdict file is missing;
        * the verdict cannot be read or parsed (e.g. it is caught mid-write);
        * the verdict does not record ``"gnn_wins": true``;
        * torch is not installed.
    """
    if not CKPT.exists():
        return None
    try:
        # EAFP: read directly instead of exists()-then-read, which can race with a writer.
        verdict = json.loads(VERDICT.read_text())
    except (OSError, ValueError):
        # Missing, unreadable or malformed verdict: no proof that the GNN won.
        return None
    # Require a real JSON `true`; a merely truthy value such as "false" must not enable serving.
    if not isinstance(verdict, dict) or verdict.get("gnn_wins") is not True:
        return None
    try:
        import torch
    except ImportError:
        return None
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
    # so CKPT must only ever come from a trusted `train_and_compare` run. The saved
    # payload (state dict, mean/std tensors, an int) looks loadable with
    # weights_only=True — confirm and tighten.
    return GNNScorer(torch.load(CKPT, weights_only=False))
182
183
def main() -> None:
    """Run the training/comparison with default settings and print a two-line summary."""
    v = train_and_compare()
    metrics_line = (f"[gnn] val nodes={v['val_nodes']} GNN acc={v['gnn_val_acc']} "
                    f"M5 acc={v['m5_val_acc']} stable={v['stable']}")
    print(metrics_line)

    # The second line states whether the GNN will actually be served.
    if v["gnn_wins"]:
        outcome = "SERVED (beats calibrated M5)"
    else:
        outcome = "NOT served; system keeps M5 (PRD guardrail: ship only if it beats the baseline)"
    print(f"[gnn] gnn_wins={v['gnn_wins']} -> " + outcome)
191
192
def demo() -> None:
    """Self-check: trains, produces finite probs, writes a verdict — stability, not winning.

    Initialises and seeds the database, makes a best-effort attempt to train the
    M5 baseline, then runs ``train_and_compare`` and asserts that the GNN output
    is finite and that the verdict flag is a bool. A failure to train M5 is
    reported on stderr instead of being silently discarded; the self-check still
    proceeds.

    Raises:
        AssertionError: If the GNN produced non-finite outputs or the verdict is malformed.
    """
    import sys

    from .db import connection as conn, init_db
    from .seed import seed as load_seed

    with conn(read_only=False) as con:
        init_db(con)
        load_seed(con)
    try:
        learned.train()  # so there's an M5 baseline to compare against
    except Exception as exc:
        # Deliberately best-effort: the demo checks GNN stability, not winning,
        # so a failed baseline must not abort it — but it should be visible.
        print(f"[gnn] M5 baseline training failed ({exc!r}); continuing without a fresh baseline",
              file=sys.stderr)
    v = train_and_compare(epochs=200)
    assert v["stable"], "GNN produced non-finite outputs"
    assert isinstance(v["gnn_wins"], bool)
    print(f"gnn_val_acc={v['gnn_val_acc']} m5_val_acc={v['m5_val_acc']} gnn_wins={v['gnn_wins']} ok")
209
210
# Executing the module as a script runs the self-check (`demo`), not `main`.
if __name__ == "__main__":
    demo()