Another project
1use std::collections::{BTreeMap, BTreeSet};
2
3use bone_types::{Parameter, ParameterIndex, ParentIndex, ResidualIndex};
4
5use crate::graph::Component;
6use crate::residual::{CurveRadius, LineHandle, PointHandle, Residual};
7
8#[derive(Clone, Debug, PartialEq)]
9pub struct ConstraintSystem {
10 parameters: Vec<Parameter>,
11 residuals: Vec<Residual>,
12}
13
14impl ConstraintSystem {
15 #[must_use]
16 pub fn new(parameters: Vec<Parameter>, residuals: Vec<Residual>) -> Self {
17 Self {
18 parameters,
19 residuals,
20 }
21 }
22
23 #[must_use]
24 pub fn parameters(&self) -> &[Parameter] {
25 &self.parameters
26 }
27
28 #[must_use]
29 pub fn residuals(&self) -> &[Residual] {
30 &self.residuals
31 }
32
33 #[must_use]
34 pub fn parameter_count(&self) -> usize {
35 self.parameters.len()
36 }
37
38 #[must_use]
39 pub fn row_count(&self) -> usize {
40 self.residuals.iter().map(Residual::rows).sum()
41 }
42
43 #[must_use]
44 pub fn row_offsets(&self) -> Vec<ResidualIndex> {
45 self.residuals
46 .iter()
47 .scan(0u32, |acc, r| {
48 let start = ResidualIndex::new(*acc);
49 let Ok(n) = u32::try_from(r.rows()) else {
50 unreachable!("Residual::rows is bounded to a small constant")
51 };
52 let Some(next) = acc.checked_add(n) else {
53 unreachable!("row count fits in u32 by construction of ResidualIndex")
54 };
55 *acc = next;
56 Some(start)
57 })
58 .collect()
59 }
60
61 #[must_use]
62 pub fn parameter_at(&self, idx: ParameterIndex) -> Parameter {
63 self.parameters[idx.as_usize()]
64 }
65
66 #[must_use]
67 pub fn residual_at_parent(&self, parent: ParentIndex) -> &Residual {
68 &self.residuals[parent.as_usize()]
69 }
70
71 #[must_use]
72 pub fn with_extra_residuals(&self, extra: Vec<Residual>) -> Self {
73 let mut residuals = self.residuals.clone();
74 residuals.extend(extra);
75 Self {
76 parameters: self.parameters.clone(),
77 residuals,
78 }
79 }
80
81 #[must_use]
82 pub fn parent_of_row(&self, row: ResidualIndex) -> ParentIndex {
83 let target = row.as_usize();
84 let found = self
85 .residuals
86 .iter()
87 .scan(0usize, |acc, r| {
88 let start = *acc;
89 *acc += r.rows();
90 Some((start, *acc))
91 })
92 .position(|(begin, end)| target >= begin && target < end);
93 let Some(idx) = found else {
94 unreachable!("ResidualIndex {row} is past the residual row count")
95 };
96 let Ok(iv) = u32::try_from(idx) else {
97 unreachable!("parent count fits in u32 via ResidualIndex")
98 };
99 ParentIndex::new(iv)
100 }
101
102 #[must_use]
103 pub fn rows_of_parent(&self, parent: ParentIndex) -> Vec<ResidualIndex> {
104 let idx = parent.as_usize();
105 let Some(residual) = self.residuals.get(idx) else {
106 unreachable!("ParentIndex {parent} out of range for this ConstraintSystem")
107 };
108 let start_usize: usize = self.residuals[..idx].iter().map(Residual::rows).sum();
109 let Ok(start) = u32::try_from(start_usize) else {
110 unreachable!("row count fits in u32 via ResidualIndex")
111 };
112 (0..residual.rows())
113 .map(|o| {
114 let Ok(ov) = u32::try_from(o) else {
115 unreachable!("residual row count is a small constant")
116 };
117 ResidualIndex::new(start + ov)
118 })
119 .collect()
120 }
121
122 #[must_use]
123 pub fn without_parents(&self, exclude: &BTreeSet<ParentIndex>) -> Self {
124 let mask: BTreeSet<usize> = exclude.iter().copied().map(ParentIndex::as_usize).collect();
125 let residuals = self
126 .residuals
127 .iter()
128 .enumerate()
129 .filter(|(i, _)| !mask.contains(i))
130 .map(|(_, r)| r.clone())
131 .collect();
132 Self {
133 parameters: self.parameters.clone(),
134 residuals,
135 }
136 }
137
138 #[must_use]
139 pub fn subsystem(&self, component: &Component) -> Subsystem {
140 let orig_params: &[ParameterIndex] = component.parameters();
141 let remap: BTreeMap<ParameterIndex, ParameterIndex> = orig_params
142 .iter()
143 .copied()
144 .enumerate()
145 .map(|(i, orig)| {
146 let Ok(iv) = u32::try_from(i) else {
147 unreachable!("component parameter count fits in u32 by construction")
148 };
149 (orig, ParameterIndex::new(iv))
150 })
151 .collect();
152 let parameters: Vec<Parameter> = orig_params
153 .iter()
154 .map(|p| self.parameters[p.as_usize()])
155 .collect();
156 let residuals: Vec<Residual> = component
157 .residual_parents()
158 .iter()
159 .map(|&parent| remap_residual(self.residual_at_parent(parent), &remap))
160 .collect();
161 Subsystem {
162 system: ConstraintSystem::new(parameters, residuals),
163 param_map: orig_params.to_vec(),
164 }
165 }
166}
167
168#[derive(Clone, Debug)]
169pub struct Subsystem {
170 system: ConstraintSystem,
171 param_map: Vec<ParameterIndex>,
172}
173
174impl Subsystem {
175 #[must_use]
176 pub fn system(&self) -> &ConstraintSystem {
177 &self.system
178 }
179
180 #[must_use]
181 pub fn param_map(&self) -> &[ParameterIndex] {
182 &self.param_map
183 }
184}
185
186fn remap_residual(r: &Residual, m: &BTreeMap<ParameterIndex, ParameterIndex>) -> Residual {
187 let idx = |p: ParameterIndex| -> ParameterIndex {
188 let Some(&mapped) = m.get(&p) else {
189 unreachable!("residual parameter missing from component remap (graph bug)")
190 };
191 mapped
192 };
193 let point = |p: PointHandle| PointHandle {
194 x: idx(p.x),
195 y: idx(p.y),
196 };
197 let line = |l: LineHandle| LineHandle {
198 a: point(l.a),
199 b: point(l.b),
200 };
201 let curve = |c: CurveRadius| match c {
202 CurveRadius::Explicit { center, radius } => CurveRadius::Explicit {
203 center: point(center),
204 radius: idx(radius),
205 },
206 CurveRadius::FromSpoke { center, spoke } => CurveRadius::FromSpoke {
207 center: point(center),
208 spoke: point(spoke),
209 },
210 };
211 match *r {
212 Residual::Pin { param, target } => Residual::Pin {
213 param: idx(param),
214 target,
215 },
216 Residual::Horizontal(l) => Residual::Horizontal(line(l)),
217 Residual::Vertical(l) => Residual::Vertical(line(l)),
218 Residual::Parallel(a, b) => Residual::Parallel(line(a), line(b)),
219 Residual::Perpendicular(a, b) => Residual::Perpendicular(line(a), line(b)),
220 Residual::TangentLineCurve { line: l, curve: c } => Residual::TangentLineCurve {
221 line: line(l),
222 curve: curve(c),
223 },
224 Residual::TangentCurveCurve { a, b } => Residual::TangentCurveCurve {
225 a: curve(a),
226 b: curve(b),
227 },
228 Residual::CoincidentPointPoint(p, q) => Residual::CoincidentPointPoint(point(p), point(q)),
229 Residual::CoincidentPointLine { point: p, line: l } => Residual::CoincidentPointLine {
230 point: point(p),
231 line: line(l),
232 },
233 Residual::CoincidentPointCurve { point: p, curve: c } => Residual::CoincidentPointCurve {
234 point: point(p),
235 curve: curve(c),
236 },
237 Residual::MidpointPointLine { point: p, line: l } => Residual::MidpointPointLine {
238 point: point(p),
239 line: line(l),
240 },
241 Residual::EqualLength(a, b) => Residual::EqualLength(line(a), line(b)),
242 Residual::EqualRadius(a, b) => Residual::EqualRadius(curve(a), curve(b)),
243 Residual::LinearDistance { a, b, value_mm } => Residual::LinearDistance {
244 a: point(a),
245 b: point(b),
246 value_mm,
247 },
248 Residual::AngularBetweenLines { a, b, angle_rad } => Residual::AngularBetweenLines {
249 a: line(a),
250 b: line(b),
251 angle_rad,
252 },
253 Residual::RadiusCurve { curve: c, value_mm } => Residual::RadiusCurve {
254 curve: curve(c),
255 value_mm,
256 },
257 Residual::Symmetric { a, b, axis } => Residual::Symmetric {
258 a: point(a),
259 b: point(b),
260 axis: line(axis),
261 },
262 }
263}