Another project
0

Configure Feed

Select the types of activity you want to include in your feed.

at main 12 kB View raw
1use std::collections::BTreeMap; 2use std::time::{Duration, Instant}; 3 4use bone_types::{ 5 BudgetCeiling, NewtonDamping, NewtonStepTolerance, Parameter, ParameterIndex, ResidualIndex, 6 SolverResidual, 7}; 8use faer::Mat; 9use faer::linalg::solvers::Solve; 10use faer::sparse::{SparseColMat, Triplet as FaerTriplet}; 11use nalgebra::DVector; 12 13use crate::error::{Result, SolverError}; 14use crate::graph::Decomposition; 15use crate::jacobian::{evaluate_residuals, jacobian_triplets}; 16use crate::residual::Triplet; 17use crate::system::ConstraintSystem; 18 19#[derive(Copy, Clone, Debug, PartialEq)] 20pub struct NewtonConfig { 21 pub max_iterations: u32, 22 pub residual_tolerance: SolverResidual, 23 pub step_tolerance: NewtonStepTolerance, 24 pub line_search_shrinks: u32, 25 pub damping: NewtonDamping, 26 pub budget: Option<BudgetCeiling>, 27} 28 29impl NewtonConfig { 30 pub const DEFAULT: Self = Self { 31 max_iterations: 64, 32 residual_tolerance: SolverResidual::new(1e-10), 33 step_tolerance: NewtonStepTolerance::DEFAULT, 34 line_search_shrinks: 20, 35 damping: NewtonDamping::DEFAULT, 36 budget: None, 37 }; 38} 39 40#[must_use] 41pub fn residual_norm(residuals: &[f64]) -> SolverResidual { 42 let sum: f64 = residuals.iter().map(|r| r * r).sum(); 43 SolverResidual::new(sum.sqrt()) 44} 45 46#[derive(Copy, Clone, Debug)] 47struct Stopwatch { 48 started_at: Instant, 49 budget: Option<BudgetCeiling>, 50} 51 52impl Stopwatch { 53 fn start(budget: Option<BudgetCeiling>) -> Self { 54 Self { 55 started_at: Instant::now(), 56 budget, 57 } 58 } 59 60 fn check(self) -> Result<(), SolverError> { 61 let Some(budget) = self.budget else { 62 return Ok(()); 63 }; 64 let elapsed = self.started_at.elapsed(); 65 if elapsed >= budget.duration() { 66 Err(SolverError::Budget { elapsed }) 67 } else { 68 Ok(()) 69 } 70 } 71} 72 73#[must_use] 74fn remaining_budget(outer: BudgetCeiling, elapsed: Duration) -> BudgetCeiling { 75 BudgetCeiling::new(outer.duration().saturating_sub(elapsed)) 76} 77 78struct Iterate { 79 params: Vec<f64>, 80 residuals: Vec<f64>, 81 norm: SolverResidual, 82} 83 84impl Iterate { 85 fn evaluate(system: &ConstraintSystem, params: Vec<f64>) -> Self { 86 let residuals = evaluate_residuals(system, &params); 87 let norm = residual_norm(&residuals); 88 Self { 89 params, 90 residuals, 91 norm, 92 } 93 } 94} 95 96pub fn solve_newton(system: &ConstraintSystem, cfg: NewtonConfig) -> Result<Vec<Parameter>> { 97 let stopwatch = Stopwatch::start(cfg.budget); 98 let seed: Vec<f64> = system.parameters().iter().map(|p| p.value()).collect(); 99 converge( 100 system, 101 cfg, 102 &Iterate::evaluate(system, seed), 103 cfg.max_iterations, 104 stopwatch, 105 ) 106} 107 108pub fn solve_newton_decomposed( 109 system: &ConstraintSystem, 110 decomposition: &Decomposition, 111 cfg: NewtonConfig, 112) -> Result<Vec<Parameter>> { 113 let stopwatch = Stopwatch::start(cfg.budget); 114 let seed: Vec<Parameter> = system.parameters().to_vec(); 115 decomposition 116 .components() 117 .iter() 118 .try_fold(seed, |acc, component| { 119 stopwatch.check()?; 120 let sub = system.subsystem(component); 121 if sub.system().row_count() == 0 { 122 return Ok(acc); 123 } 124 let sub_cfg = NewtonConfig { 125 budget: cfg 126 .budget 127 .map(|b| remaining_budget(b, stopwatch.started_at.elapsed())), 128 ..cfg 129 }; 130 let solved = solve_newton(sub.system(), sub_cfg) 131 .map_err(|err| lift_subsystem_error(err, sub.param_map()))?; 132 Ok(splice(acc, sub.param_map(), &solved)) 133 }) 134} 135 136fn lift_subsystem_error(err: SolverError, param_map: &[ParameterIndex]) -> SolverError { 137 match err { 138 SolverError::InvalidSolutionFound { at } => SolverError::InvalidSolutionFound { 139 at: param_map[at.as_usize()], 140 }, 141 other => other, 142 } 143} 144 145fn splice( 146 mut whole: Vec<Parameter>, 147 param_map: &[bone_types::ParameterIndex], 148 sub: &[Parameter], 149) -> Vec<Parameter> { 150 param_map 151 .iter() 152 .zip(sub.iter().copied()) 153 .for_each(|(orig, value)| { 154 whole[orig.as_usize()] = value; 155 }); 156 whole 157} 158 159fn converge( 160 system: &ConstraintSystem, 161 cfg: NewtonConfig, 162 state: &Iterate, 163 remaining: u32, 164 stopwatch: Stopwatch, 165) -> Result<Vec<Parameter>> { 166 if state.norm.value() < cfg.residual_tolerance.value() { 167 return Ok(to_parameters(&state.params)); 168 } 169 if remaining == 0 { 170 return Err(SolverError::NoSolutionFound { last: state.norm }); 171 } 172 stopwatch.check()?; 173 let triplets = jacobian_triplets(system, &state.params); 174 let step = least_squares_step( 175 system.parameter_count(), 176 &triplets, 177 &state.residuals, 178 cfg.damping, 179 )?; 180 if step.norm() < cfg.step_tolerance.value() { 181 return Ok(to_parameters(&state.params)); 182 } 183 match line_search( 184 system, 185 &state.params, 186 &step, 187 state.norm, 188 cfg.line_search_shrinks, 189 ) { 190 Some(next) => converge(system, cfg, &next, remaining - 1, stopwatch), 191 None => Err(SolverError::NoSolutionFound { last: state.norm }), 192 } 193} 194 195fn to_parameters(values: &[f64]) -> Vec<Parameter> { 196 values.iter().copied().map(Parameter::new).collect() 197} 198 199fn least_squares_step( 200 params_len: usize, 201 triplets: &[Triplet], 202 residuals: &[f64], 203 damping: NewtonDamping, 204) -> Result<DVector<f64>> { 205 let (jtj_triplets, jtr) = 206 assemble_normal_equations(params_len, triplets, residuals, damping.value()); 207 let Ok(jtj) = 208 SparseColMat::<usize, f64>::try_new_from_triplets(params_len, params_len, &jtj_triplets) 209 else { 210 unreachable!("normal-equations triplets are built with in-bounds indices") 211 }; 212 let lu = jtj.sp_lu().map_err(|err| match err { 213 faer::sparse::linalg::LuError::SymbolicSingular { index } => { 214 let Ok(at) = u32::try_from(index) else { 215 unreachable!("J^T J column count fits in u32 via ParameterIndex") 216 }; 217 SolverError::InvalidSolutionFound { 218 at: ParameterIndex::new(at), 219 } 220 } 221 faer::sparse::linalg::LuError::Generic(_) => { 222 unreachable!("faer sp_lu returns Generic only on malformed input") 223 } 224 })?; 225 let mut rhs = Mat::<f64>::from_fn(params_len, 1, |i, _| -jtr[i]); 226 lu.solve_in_place(rhs.as_mut()); 227 Ok(DVector::from_iterator( 228 params_len, 229 (0..params_len).map(|i| rhs[(i, 0)]), 230 )) 231} 232 233fn assemble_normal_equations( 234 params_len: usize, 235 triplets: &[Triplet], 236 residuals: &[f64], 237 damping: f64, 238) -> (Vec<FaerTriplet<usize, usize, f64>>, Vec<f64>) { 239 let rows: BTreeMap<ResidualIndex, Vec<(usize, f64)>> = 240 triplets 241 .iter() 242 .fold(BTreeMap::new(), |mut acc, (row, col, val)| { 243 acc.entry(*row).or_default().push((col.as_usize(), *val)); 244 acc 245 }); 246 let jtj: BTreeMap<(usize, usize), f64> = rows 247 .values() 248 .flat_map(|entries| { 249 entries 250 .iter() 251 .flat_map(move |(i, vi)| entries.iter().map(move |(j, vj)| ((*i, *j), vi * vj))) 252 }) 253 .fold(BTreeMap::new(), |mut acc, (key, val)| { 254 *acc.entry(key).or_insert(0.0) += val; 255 acc 256 }); 257 let damped = (0..params_len).fold(jtj, |mut acc, i| { 258 *acc.entry((i, i)).or_insert(0.0) += damping; 259 acc 260 }); 261 let jtj_triplets: Vec<FaerTriplet<usize, usize, f64>> = damped 262 .into_iter() 263 .map(|((i, j), v)| FaerTriplet::new(i, j, v)) 264 .collect(); 265 let jtr = triplets 266 .iter() 267 .fold(vec![0.0_f64; params_len], |mut acc, (row, col, val)| { 268 acc[col.as_usize()] += val * residuals[row.as_usize()]; 269 acc 270 }); 271 (jtj_triplets, jtr) 272} 273 274fn line_search( 275 system: &ConstraintSystem, 276 params: &[f64], 277 step: &DVector<f64>, 278 baseline: SolverResidual, 279 max_shrinks: u32, 280) -> Option<Iterate> { 281 let count = usize::try_from(max_shrinks) 282 .unwrap_or(usize::MAX) 283 .saturating_add(1); 284 std::iter::successors(Some(1.0_f64), |a| Some(a * 0.5)) 285 .take(count) 286 .find_map(|alpha| { 287 let trial: Vec<f64> = params 288 .iter() 289 .zip(step.iter()) 290 .map(|(p, s)| p + alpha * s) 291 .collect(); 292 let candidate = Iterate::evaluate(system, trial); 293 (candidate.norm.value() < baseline.value()).then_some(candidate) 294 }) 295} 296 297#[cfg(test)] 298mod tests { 299 use super::*; 300 301 #[test] 302 fn lift_subsystem_error_translates_singular_index() { 303 let param_map = vec![ 304 ParameterIndex::new(4), 305 ParameterIndex::new(9), 306 ParameterIndex::new(12), 307 ]; 308 let local = SolverError::InvalidSolutionFound { 309 at: ParameterIndex::new(1), 310 }; 311 match lift_subsystem_error(local, &param_map) { 312 SolverError::InvalidSolutionFound { at } => assert_eq!(at, ParameterIndex::new(9)), 313 other => panic!("expected InvalidSolutionFound, got {other:?}"), 314 } 315 } 316 317 #[test] 318 fn lift_subsystem_error_passes_through_other_variants() { 319 let param_map = [ParameterIndex::new(2), ParameterIndex::new(5)]; 320 let no_solution = SolverError::NoSolutionFound { 321 last: SolverResidual::new(3.5), 322 }; 323 match lift_subsystem_error(no_solution, &param_map) { 324 SolverError::NoSolutionFound { last } => { 325 assert!((last.value() - 3.5).abs() < f64::EPSILON); 326 } 327 other => panic!("expected NoSolutionFound, got {other:?}"), 328 } 329 } 330 331 #[test] 332 fn remaining_budget_subtracts_elapsed() { 333 let outer = BudgetCeiling::new(Duration::from_millis(10)); 334 let r = remaining_budget(outer, Duration::from_millis(3)); 335 assert_eq!(r.duration(), Duration::from_millis(7)); 336 } 337 338 #[test] 339 fn remaining_budget_saturates_at_zero_when_overshot() { 340 let outer = BudgetCeiling::new(Duration::from_millis(10)); 341 let r = remaining_budget(outer, Duration::from_millis(25)); 342 assert_eq!(r.duration(), Duration::ZERO); 343 } 344 345 #[test] 346 fn decomposed_zero_budget_fires_across_disjoint_components() { 347 use crate::graph::decompose; 348 use crate::residual::Residual; 349 use crate::system::ConstraintSystem; 350 let system = ConstraintSystem::new( 351 vec![Parameter::new(0.0), Parameter::new(0.0)], 352 vec![ 353 Residual::Pin { 354 param: ParameterIndex::new(0), 355 target: 1.0, 356 }, 357 Residual::Pin { 358 param: ParameterIndex::new(1), 359 target: 2.0, 360 }, 361 ], 362 ); 363 let decomp = decompose(&system); 364 assert!( 365 decomp.components().len() >= 2, 366 "fixture must decompose into at least two components, got {}", 367 decomp.components().len(), 368 ); 369 let cfg = NewtonConfig { 370 budget: Some(BudgetCeiling::new(Duration::ZERO)), 371 ..NewtonConfig::DEFAULT 372 }; 373 match solve_newton_decomposed(&system, &decomp, cfg) { 374 Err(SolverError::Budget { .. }) => {} 375 other => panic!("expected Budget, got {other:?}"), 376 } 377 } 378}