Another project
1use std::collections::BTreeMap;
2use std::time::{Duration, Instant};
3
4use bone_types::{
5 BudgetCeiling, NewtonDamping, NewtonStepTolerance, Parameter, ParameterIndex, ResidualIndex,
6 SolverResidual,
7};
8use faer::Mat;
9use faer::linalg::solvers::Solve;
10use faer::sparse::{SparseColMat, Triplet as FaerTriplet};
11use nalgebra::DVector;
12
13use crate::error::{Result, SolverError};
14use crate::graph::Decomposition;
15use crate::jacobian::{evaluate_residuals, jacobian_triplets};
16use crate::residual::Triplet;
17use crate::system::ConstraintSystem;
18
19#[derive(Copy, Clone, Debug, PartialEq)]
20pub struct NewtonConfig {
21 pub max_iterations: u32,
22 pub residual_tolerance: SolverResidual,
23 pub step_tolerance: NewtonStepTolerance,
24 pub line_search_shrinks: u32,
25 pub damping: NewtonDamping,
26 pub budget: Option<BudgetCeiling>,
27}
28
29impl NewtonConfig {
30 pub const DEFAULT: Self = Self {
31 max_iterations: 64,
32 residual_tolerance: SolverResidual::new(1e-10),
33 step_tolerance: NewtonStepTolerance::DEFAULT,
34 line_search_shrinks: 20,
35 damping: NewtonDamping::DEFAULT,
36 budget: None,
37 };
38}
39
40#[must_use]
41pub fn residual_norm(residuals: &[f64]) -> SolverResidual {
42 let sum: f64 = residuals.iter().map(|r| r * r).sum();
43 SolverResidual::new(sum.sqrt())
44}
45
46#[derive(Copy, Clone, Debug)]
47struct Stopwatch {
48 started_at: Instant,
49 budget: Option<BudgetCeiling>,
50}
51
52impl Stopwatch {
53 fn start(budget: Option<BudgetCeiling>) -> Self {
54 Self {
55 started_at: Instant::now(),
56 budget,
57 }
58 }
59
60 fn check(self) -> Result<(), SolverError> {
61 let Some(budget) = self.budget else {
62 return Ok(());
63 };
64 let elapsed = self.started_at.elapsed();
65 if elapsed >= budget.duration() {
66 Err(SolverError::Budget { elapsed })
67 } else {
68 Ok(())
69 }
70 }
71}
72
73#[must_use]
74fn remaining_budget(outer: BudgetCeiling, elapsed: Duration) -> BudgetCeiling {
75 BudgetCeiling::new(outer.duration().saturating_sub(elapsed))
76}
77
78struct Iterate {
79 params: Vec<f64>,
80 residuals: Vec<f64>,
81 norm: SolverResidual,
82}
83
84impl Iterate {
85 fn evaluate(system: &ConstraintSystem, params: Vec<f64>) -> Self {
86 let residuals = evaluate_residuals(system, ¶ms);
87 let norm = residual_norm(&residuals);
88 Self {
89 params,
90 residuals,
91 norm,
92 }
93 }
94}
95
96pub fn solve_newton(system: &ConstraintSystem, cfg: NewtonConfig) -> Result<Vec<Parameter>> {
97 let stopwatch = Stopwatch::start(cfg.budget);
98 let seed: Vec<f64> = system.parameters().iter().map(|p| p.value()).collect();
99 converge(
100 system,
101 cfg,
102 &Iterate::evaluate(system, seed),
103 cfg.max_iterations,
104 stopwatch,
105 )
106}
107
108pub fn solve_newton_decomposed(
109 system: &ConstraintSystem,
110 decomposition: &Decomposition,
111 cfg: NewtonConfig,
112) -> Result<Vec<Parameter>> {
113 let stopwatch = Stopwatch::start(cfg.budget);
114 let seed: Vec<Parameter> = system.parameters().to_vec();
115 decomposition
116 .components()
117 .iter()
118 .try_fold(seed, |acc, component| {
119 stopwatch.check()?;
120 let sub = system.subsystem(component);
121 if sub.system().row_count() == 0 {
122 return Ok(acc);
123 }
124 let sub_cfg = NewtonConfig {
125 budget: cfg
126 .budget
127 .map(|b| remaining_budget(b, stopwatch.started_at.elapsed())),
128 ..cfg
129 };
130 let solved = solve_newton(sub.system(), sub_cfg)
131 .map_err(|err| lift_subsystem_error(err, sub.param_map()))?;
132 Ok(splice(acc, sub.param_map(), &solved))
133 })
134}
135
136fn lift_subsystem_error(err: SolverError, param_map: &[ParameterIndex]) -> SolverError {
137 match err {
138 SolverError::InvalidSolutionFound { at } => SolverError::InvalidSolutionFound {
139 at: param_map[at.as_usize()],
140 },
141 other => other,
142 }
143}
144
145fn splice(
146 mut whole: Vec<Parameter>,
147 param_map: &[bone_types::ParameterIndex],
148 sub: &[Parameter],
149) -> Vec<Parameter> {
150 param_map
151 .iter()
152 .zip(sub.iter().copied())
153 .for_each(|(orig, value)| {
154 whole[orig.as_usize()] = value;
155 });
156 whole
157}
158
159fn converge(
160 system: &ConstraintSystem,
161 cfg: NewtonConfig,
162 state: &Iterate,
163 remaining: u32,
164 stopwatch: Stopwatch,
165) -> Result<Vec<Parameter>> {
166 if state.norm.value() < cfg.residual_tolerance.value() {
167 return Ok(to_parameters(&state.params));
168 }
169 if remaining == 0 {
170 return Err(SolverError::NoSolutionFound { last: state.norm });
171 }
172 stopwatch.check()?;
173 let triplets = jacobian_triplets(system, &state.params);
174 let step = least_squares_step(
175 system.parameter_count(),
176 &triplets,
177 &state.residuals,
178 cfg.damping,
179 )?;
180 if step.norm() < cfg.step_tolerance.value() {
181 return Ok(to_parameters(&state.params));
182 }
183 match line_search(
184 system,
185 &state.params,
186 &step,
187 state.norm,
188 cfg.line_search_shrinks,
189 ) {
190 Some(next) => converge(system, cfg, &next, remaining - 1, stopwatch),
191 None => Err(SolverError::NoSolutionFound { last: state.norm }),
192 }
193}
194
195fn to_parameters(values: &[f64]) -> Vec<Parameter> {
196 values.iter().copied().map(Parameter::new).collect()
197}
198
199fn least_squares_step(
200 params_len: usize,
201 triplets: &[Triplet],
202 residuals: &[f64],
203 damping: NewtonDamping,
204) -> Result<DVector<f64>> {
205 let (jtj_triplets, jtr) =
206 assemble_normal_equations(params_len, triplets, residuals, damping.value());
207 let Ok(jtj) =
208 SparseColMat::<usize, f64>::try_new_from_triplets(params_len, params_len, &jtj_triplets)
209 else {
210 unreachable!("normal-equations triplets are built with in-bounds indices")
211 };
212 let lu = jtj.sp_lu().map_err(|err| match err {
213 faer::sparse::linalg::LuError::SymbolicSingular { index } => {
214 let Ok(at) = u32::try_from(index) else {
215 unreachable!("J^T J column count fits in u32 via ParameterIndex")
216 };
217 SolverError::InvalidSolutionFound {
218 at: ParameterIndex::new(at),
219 }
220 }
221 faer::sparse::linalg::LuError::Generic(_) => {
222 unreachable!("faer sp_lu returns Generic only on malformed input")
223 }
224 })?;
225 let mut rhs = Mat::<f64>::from_fn(params_len, 1, |i, _| -jtr[i]);
226 lu.solve_in_place(rhs.as_mut());
227 Ok(DVector::from_iterator(
228 params_len,
229 (0..params_len).map(|i| rhs[(i, 0)]),
230 ))
231}
232
233fn assemble_normal_equations(
234 params_len: usize,
235 triplets: &[Triplet],
236 residuals: &[f64],
237 damping: f64,
238) -> (Vec<FaerTriplet<usize, usize, f64>>, Vec<f64>) {
239 let rows: BTreeMap<ResidualIndex, Vec<(usize, f64)>> =
240 triplets
241 .iter()
242 .fold(BTreeMap::new(), |mut acc, (row, col, val)| {
243 acc.entry(*row).or_default().push((col.as_usize(), *val));
244 acc
245 });
246 let jtj: BTreeMap<(usize, usize), f64> = rows
247 .values()
248 .flat_map(|entries| {
249 entries
250 .iter()
251 .flat_map(move |(i, vi)| entries.iter().map(move |(j, vj)| ((*i, *j), vi * vj)))
252 })
253 .fold(BTreeMap::new(), |mut acc, (key, val)| {
254 *acc.entry(key).or_insert(0.0) += val;
255 acc
256 });
257 let damped = (0..params_len).fold(jtj, |mut acc, i| {
258 *acc.entry((i, i)).or_insert(0.0) += damping;
259 acc
260 });
261 let jtj_triplets: Vec<FaerTriplet<usize, usize, f64>> = damped
262 .into_iter()
263 .map(|((i, j), v)| FaerTriplet::new(i, j, v))
264 .collect();
265 let jtr = triplets
266 .iter()
267 .fold(vec![0.0_f64; params_len], |mut acc, (row, col, val)| {
268 acc[col.as_usize()] += val * residuals[row.as_usize()];
269 acc
270 });
271 (jtj_triplets, jtr)
272}
273
274fn line_search(
275 system: &ConstraintSystem,
276 params: &[f64],
277 step: &DVector<f64>,
278 baseline: SolverResidual,
279 max_shrinks: u32,
280) -> Option<Iterate> {
281 let count = usize::try_from(max_shrinks)
282 .unwrap_or(usize::MAX)
283 .saturating_add(1);
284 std::iter::successors(Some(1.0_f64), |a| Some(a * 0.5))
285 .take(count)
286 .find_map(|alpha| {
287 let trial: Vec<f64> = params
288 .iter()
289 .zip(step.iter())
290 .map(|(p, s)| p + alpha * s)
291 .collect();
292 let candidate = Iterate::evaluate(system, trial);
293 (candidate.norm.value() < baseline.value()).then_some(candidate)
294 })
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// A subsystem-local singular index is mapped back through `param_map`.
    #[test]
    fn lift_subsystem_error_translates_singular_index() {
        let param_map = vec![
            ParameterIndex::new(4),
            ParameterIndex::new(9),
            ParameterIndex::new(12),
        ];
        let local = SolverError::InvalidSolutionFound {
            at: ParameterIndex::new(1),
        };
        match lift_subsystem_error(local, &param_map) {
            SolverError::InvalidSolutionFound { at } => assert_eq!(at, ParameterIndex::new(9)),
            other => panic!("expected InvalidSolutionFound, got {other:?}"),
        }
    }

    /// Variants without a parameter index are returned unchanged.
    #[test]
    fn lift_subsystem_error_passes_through_other_variants() {
        let param_map = [ParameterIndex::new(2), ParameterIndex::new(5)];
        let no_solution = SolverError::NoSolutionFound {
            last: SolverResidual::new(3.5),
        };
        match lift_subsystem_error(no_solution, &param_map) {
            SolverError::NoSolutionFound { last } => {
                assert!((last.value() - 3.5).abs() < f64::EPSILON);
            }
            other => panic!("expected NoSolutionFound, got {other:?}"),
        }
    }

    #[test]
    fn remaining_budget_subtracts_elapsed() {
        let outer = BudgetCeiling::new(Duration::from_millis(10));
        let r = remaining_budget(outer, Duration::from_millis(3));
        assert_eq!(r.duration(), Duration::from_millis(7));
    }

    #[test]
    fn remaining_budget_saturates_at_zero_when_overshot() {
        let outer = BudgetCeiling::new(Duration::from_millis(10));
        let r = remaining_budget(outer, Duration::from_millis(25));
        assert_eq!(r.duration(), Duration::ZERO);
    }

    /// With a zero budget the decomposed solve must report `Budget` before
    /// solving any component.
    #[test]
    fn decomposed_zero_budget_fires_across_disjoint_components() {
        use crate::graph::decompose;
        use crate::residual::Residual;
        use crate::system::ConstraintSystem;
        let system = ConstraintSystem::new(
            vec![Parameter::new(0.0), Parameter::new(0.0)],
            vec![
                Residual::Pin {
                    param: ParameterIndex::new(0),
                    target: 1.0,
                },
                Residual::Pin {
                    param: ParameterIndex::new(1),
                    target: 2.0,
                },
            ],
        );
        let decomp = decompose(&system);
        assert!(
            decomp.components().len() >= 2,
            "fixture must decompose into at least two components, got {}",
            decomp.components().len(),
        );
        let cfg = NewtonConfig {
            budget: Some(BudgetCeiling::new(Duration::ZERO)),
            ..NewtonConfig::DEFAULT
        };
        match solve_newton_decomposed(&system, &decomp, cfg) {
            Err(SolverError::Budget { .. }) => {}
            other => panic!("expected Budget, got {other:?}"),
        }
    }
}