Skip to main content

xlog_solve/
solver.rs

1//! Continuous Local Search (CLS) solver implementation.
2//!
3//! This module implements a CLS-based SAT/MaxSAT solver that treats SAT as continuous
4//! optimization. Variables are relaxed from {0,1} to [0,1] and gradient descent with
5//! momentum is used to minimize the unsatisfied clause penalty.
6//!
7//! # Algorithm Overview
8//!
9//! The CLS algorithm works as follows:
//! 1. Initialize variables with deterministic pseudo-random values in [0.3, 0.7]
11//! 2. For each iteration:
12//!    - Compute gradients of the unsatisfied clause penalty
13//!    - Update assignments using momentum-based gradient descent
14//!    - Clamp values to [0,1]
15//!    - Check if the discretized assignment satisfies all clauses
16//! 3. Return SAT if satisfied, or best effort approximation otherwise
17//!
18//! # Example
19//!
20//! ```
21//! use xlog_solve::{SolveInstance, Clause, Literal, Solver};
22//!
23//! // Create a SAT instance: (x0) AND (NOT x0 OR x1)
24//! let instance = SolveInstance::new(2, vec![
25//!     Clause::new(vec![Literal::positive(0)]),
26//!     Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
27//! ]);
28//!
29//! // Solve
30//! let solver = Solver::new_cpu();
31//! let result = solver.solve(instance);
32//! ```
33
34use crate::instance::SolveInstance;
35use crate::proof::{SolveProof, SolveResult, SolveStats, SolveStatus};
36
37// =============================================================================
38// SolverConfig - Configuration parameters for the CLS solver
39// =============================================================================
40
/// Hyper-parameters governing a continuous local search (CLS) run.
///
/// | Field | Role |
/// |-------|------|
/// | `max_iterations` | upper bound on gradient-descent steps |
/// | `learning_rate` | step size applied to each gradient |
/// | `momentum` | fraction of the previous velocity carried into the next step |
/// | `discretize_threshold` | cut-off used to round continuous values to booleans |
///
/// # Default Values
///
/// ```
/// use xlog_solve::SolverConfig;
///
/// let config = SolverConfig::default();
/// assert_eq!(config.max_iterations, 10000);
/// assert_eq!(config.learning_rate, 0.1);
/// assert_eq!(config.momentum, 0.9);
/// assert_eq!(config.discretize_threshold, 0.5);
/// ```
///
/// # Example
///
/// ```
/// use xlog_solve::SolverConfig;
///
/// let mut config = SolverConfig::default();
/// config.max_iterations = 5000;
/// config.learning_rate = 0.05;
/// config.momentum = 0.95;
/// config.discretize_threshold = 0.5;
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub struct SolverConfig {
    /// Iteration budget.
    ///
    /// Once this many iterations have run without a satisfying assignment,
    /// the solver stops and reports its best-effort result.
    pub max_iterations: u32,

    /// Step size of each gradient update.
    ///
    /// Bigger steps converge faster but may overshoot.
    /// Typical values are in the range [0.01, 0.5].
    pub learning_rate: f32,

    /// Share of the previous velocity retained on every update.
    ///
    /// Helps escape local minima and smooths the trajectory; values close
    /// to 1.0 weight history more heavily.
    /// Typical values are in the range [0.8, 0.99].
    pub momentum: f32,

    /// Rounding cut-off for turning continuous values into booleans.
    ///
    /// A value `>= threshold` maps to `true`, anything below to `false`.
    /// 0.5 treats both polarities symmetrically.
    pub discretize_threshold: f32,
}

impl Default for SolverConfig {
    /// Baseline configuration: 10000 iterations, learning rate 0.1,
    /// momentum 0.9, threshold 0.5.
    fn default() -> Self {
        Self::new(10000, 0.1, 0.9, 0.5)
    }
}

impl SolverConfig {
    /// Builds a configuration from explicit values.
    ///
    /// # Arguments
    ///
    /// * `max_iterations` - Maximum number of iterations
    /// * `learning_rate` - Step size for gradient updates
    /// * `momentum` - Momentum coefficient
    /// * `discretize_threshold` - Threshold for boolean conversion
    #[inline]
    pub const fn new(
        max_iterations: u32,
        learning_rate: f32,
        momentum: f32,
        discretize_threshold: f32,
    ) -> Self {
        Self {
            max_iterations,
            learning_rate,
            momentum,
            discretize_threshold,
        }
    }

    /// Preset for small instances: a short iteration budget (1000) paired
    /// with a larger learning rate (0.2).
    #[inline]
    pub const fn fast() -> Self {
        Self::new(1000, 0.2, 0.9, 0.5)
    }

    /// Preset for hard instances: a long iteration budget (50000), a smaller
    /// learning rate (0.05) and stronger momentum (0.95).
    #[inline]
    pub const fn thorough() -> Self {
        Self::new(50000, 0.05, 0.95, 0.5)
    }

    /// Returns a copy of `self` with `max_iterations` replaced.
    #[inline]
    pub const fn with_max_iterations(self, max_iterations: u32) -> Self {
        Self {
            max_iterations,
            ..self
        }
    }

    /// Returns a copy of `self` with `learning_rate` replaced.
    #[inline]
    pub const fn with_learning_rate(self, learning_rate: f32) -> Self {
        Self {
            learning_rate,
            ..self
        }
    }

    /// Returns a copy of `self` with `momentum` replaced.
    #[inline]
    pub const fn with_momentum(self, momentum: f32) -> Self {
        Self { momentum, ..self }
    }

    /// Returns a copy of `self` with `discretize_threshold` replaced.
    #[inline]
    pub const fn with_discretize_threshold(self, threshold: f32) -> Self {
        Self {
            discretize_threshold: threshold,
            ..self
        }
    }
}
192
193// =============================================================================
194// SolverState - Internal state during CLS optimization
195// =============================================================================
196
/// Working buffers of a single CLS optimisation run.
///
/// Holds the relaxed variable values together with the momentum velocities
/// and the most recently computed gradients.
///
/// # Memory Layout
///
/// The constructors give all three vectors one slot per variable:
/// - `assignments`: current continuous values in [0,1]
/// - `velocities`: momentum velocities (accumulated gradient history)
/// - `gradients`: gradients of the current iteration
#[derive(Debug, Clone)]
pub struct SolverState {
    /// Relaxed value of each variable, indexed by variable id.
    ///
    /// Values near 1.0 lean towards `true`, values near 0.0 towards `false`.
    pub assignments: Vec<f32>,

    /// Per-variable momentum velocity.
    ///
    /// Carries gradient history from one step to the next.
    pub velocities: Vec<f32>,

    /// Per-variable gradient of the unsatisfied-clause penalty for the
    /// current iteration. A negative entry means the penalty shrinks when
    /// that variable grows.
    pub gradients: Vec<f32>,
}

impl SolverState {
    /// Builds the initial state for `num_vars` variables.
    ///
    /// Starting values are deterministic: the fractional part of
    /// `(index + 1) * golden_ratio` is mapped into [0.3, 0.7]. This breaks
    /// symmetry between variables while keeping runs reproducible.
    /// Velocities and gradients start at zero.
    ///
    /// # Arguments
    ///
    /// * `num_vars` - Number of variables in the SAT instance
    ///
    /// # Example
    ///
    /// ```
    /// use xlog_solve::SolverState;
    ///
    /// let state = SolverState::new(10);
    /// assert_eq!(state.assignments.len(), 10);
    /// assert_eq!(state.velocities.len(), 10);
    /// assert_eq!(state.gradients.len(), 10);
    /// ```
    pub fn new(num_vars: u32) -> Self {
        // Golden ratio; successive multiples have well-spread fractional parts.
        const GOLDEN_RATIO: f64 = 1.618033988749895;

        let count = num_vars as usize;
        let mut assignments = Vec::with_capacity(count);
        for index in 0..count {
            let frac = ((index as f64 + 1.0) * GOLDEN_RATIO).fract() as f32;
            // Map [0, 1) into [0.3, 0.7] so no variable starts at an extreme.
            assignments.push(0.3 + frac * 0.4);
        }

        Self::with_assignments(assignments)
    }

    /// Builds a state around caller-supplied starting values.
    ///
    /// Velocities and gradients are zero-initialised to the same length.
    /// Handy for warm starts and for tests.
    ///
    /// # Arguments
    ///
    /// * `assignments` - Initial continuous assignments in [0,1]
    pub fn with_assignments(assignments: Vec<f32>) -> Self {
        let zeros = vec![0.0_f32; assignments.len()];
        Self {
            velocities: zeros.clone(),
            gradients: zeros,
            assignments,
        }
    }

    /// Rounds every continuous value to a boolean.
    ///
    /// A value maps to `true` exactly when it is `>= threshold`.
    ///
    /// # Arguments
    ///
    /// * `threshold` - The threshold for boolean conversion (typically 0.5)
    ///
    /// # Returns
    ///
    /// One boolean per variable, in variable order.
    ///
    /// # Example
    ///
    /// ```
    /// use xlog_solve::SolverState;
    ///
    /// let mut state = SolverState::new(3);
    /// state.assignments = vec![0.3, 0.7, 0.5];
    /// let discrete = state.discretize(0.5);
    /// assert_eq!(discrete, vec![false, true, true]);
    /// ```
    #[inline]
    pub fn discretize(&self, threshold: f32) -> Vec<bool> {
        let mut bits = Vec::with_capacity(self.assignments.len());
        for &value in &self.assignments {
            bits.push(value >= threshold);
        }
        bits
    }

    /// Number of variables tracked, i.e. the length of `assignments`.
    #[inline]
    pub fn num_vars(&self) -> usize {
        self.assignments.len()
    }

    /// Zeroes every velocity, discarding accumulated momentum.
    #[inline]
    pub fn reset_velocities(&mut self) {
        self.velocities.iter_mut().for_each(|v| *v = 0.0);
    }

    /// Zeroes every entry of the gradient buffer.
    #[inline]
    pub fn clear_gradients(&mut self) {
        self.gradients.iter_mut().for_each(|g| *g = 0.0);
    }
}
340
341// =============================================================================
342// Solver - The main CLS solver
343// =============================================================================
344
345/// The CLS-based SAT/MaxSAT solver.
346///
347/// Implements Continuous Local Search, treating SAT as continuous optimization.
348/// Variables are relaxed from {0,1} to [0,1] and gradient descent with momentum
349/// is used to minimize the unsatisfied clause penalty.
350///
351/// # Algorithm
352///
353/// The solver minimizes the "unsatisfied-ness" of clauses:
354/// - For each clause, compute the product of (1 - lit_value) for all literals
355/// - This product is 0 when any literal is satisfied, approaching 1 when all fail
356/// - Gradient descent moves variables to minimize this penalty
357///
358/// # Completeness
359///
360/// CLS is an incomplete solver - it can find satisfying assignments efficiently
361/// but cannot prove unsatisfiability. For unsatisfiable instances, it will
362/// return `Unknown` status with the best-effort assignment found.
363///
364/// # Example
365///
366/// ```
367/// use xlog_solve::{SolveInstance, Clause, Literal, Solver, SolverConfig};
368///
369/// // Create instance: (x0 OR x1) AND (NOT x0 OR x1)
370/// let instance = SolveInstance::new(2, vec![
371///     Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
372///     Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
373/// ]);
374///
375/// // Solve with default config
376/// let solver = Solver::new_cpu();
377/// let result = solver.solve(instance.clone());
378///
379/// // Or with custom config
380/// let config = SolverConfig::default().with_max_iterations(5000);
381/// let solver = Solver::with_config_cpu(config);
382/// let result = solver.solve(instance);
383/// ```
384#[derive(Debug, Clone)]
385pub struct Solver {
386    /// Configuration parameters for the solver.
387    config: SolverConfig,
388}
389
390impl Solver {
391    /// Creates a new CPU-based CLS solver with default configuration.
392    ///
393    /// # Example
394    ///
395    /// ```
396    /// use xlog_solve::Solver;
397    ///
398    /// let solver = Solver::new_cpu();
399    /// ```
400    #[inline]
401    pub fn new_cpu() -> Self {
402        Self {
403            config: SolverConfig::default(),
404        }
405    }
406
407    /// Creates a new CPU-based CLS solver with custom configuration.
408    ///
409    /// # Arguments
410    ///
411    /// * `config` - The solver configuration
412    ///
413    /// # Example
414    ///
415    /// ```
416    /// use xlog_solve::{Solver, SolverConfig};
417    ///
418    /// let config = SolverConfig::default()
419    ///     .with_max_iterations(5000)
420    ///     .with_learning_rate(0.05);
421    /// let solver = Solver::with_config_cpu(config);
422    /// ```
423    #[inline]
424    pub fn with_config_cpu(config: SolverConfig) -> Self {
425        Self { config }
426    }
427
428    /// Returns a reference to the solver configuration.
429    #[inline]
430    pub fn config(&self) -> &SolverConfig {
431        &self.config
432    }
433
434    /// Solves the given SAT instance.
435    ///
436    /// # Arguments
437    ///
438    /// * `instance` - The SAT/MaxSAT instance to solve
439    ///
440    /// # Returns
441    ///
442    /// A `SolveResult` containing:
443    /// - `Sat` status with satisfying assignment if found
444    /// - `Unknown` status with best-effort assignment if not found
445    ///
446    /// Note: CLS cannot prove unsatisfiability, so `Unsat` is never returned.
447    ///
448    /// # Example
449    ///
450    /// ```
451    /// use xlog_solve::{SolveInstance, Clause, Literal, Solver, SolveStatus};
452    ///
453    /// let instance = SolveInstance::new(1, vec![
454    ///     Clause::new(vec![Literal::positive(0)]),
455    /// ]);
456    ///
457    /// let solver = Solver::new_cpu();
458    /// let result = solver.solve(instance);
459    ///
460    /// assert!(matches!(result.status, SolveStatus::Sat));
461    /// ```
462    pub fn solve(&self, instance: SolveInstance) -> SolveResult {
463        let start = std::time::Instant::now();
464
465        // Handle edge cases
466        if instance.num_vars == 0 {
467            // Empty instance is trivially satisfiable if no clauses
468            // or trivially unsatisfiable if there are empty clauses
469            let has_empty_clause = instance.clauses.iter().any(|c| c.is_empty());
470            if has_empty_clause {
471                // Empty clause can never be satisfied
472                return SolveResult {
473                    status: SolveStatus::Unknown,
474                    proof: SolveProof::approximate(vec![], 0, instance.clauses.len() as u32, 0),
475                    stats: SolveStats::new(0, start.elapsed().as_micros() as u64, 0),
476                };
477            }
478            // No variables and no empty clauses - trivially SAT
479            return SolveResult::satisfiable(vec![]).with_stats(SolveStats::new(
480                0,
481                start.elapsed().as_micros() as u64,
482                0,
483            ));
484        }
485
486        if instance.clauses.is_empty() {
487            // No clauses - any assignment works
488            let assignment = vec![false; instance.num_vars as usize];
489            return SolveResult::satisfiable(assignment).with_stats(SolveStats::new(
490                0,
491                start.elapsed().as_micros() as u64,
492                0,
493            ));
494        }
495
496        // Check for empty clauses (impossible to satisfy)
497        if instance.clauses.iter().any(|c| c.is_empty()) {
498            return SolveResult {
499                status: SolveStatus::Unknown,
500                proof: SolveProof::approximate(
501                    vec![false; instance.num_vars as usize],
502                    instance.count_satisfied(&vec![false; instance.num_vars as usize]) as u32,
503                    instance.clauses.len() as u32,
504                    0,
505                ),
506                stats: SolveStats::new(0, start.elapsed().as_micros() as u64, 0),
507            };
508        }
509
510        let mut state = SolverState::new(instance.num_vars);
511
512        // Track best solution found
513        let mut best_assignment: Option<Vec<bool>> = None;
514        let mut best_satisfied: u32 = 0;
515
516        for iter in 0..self.config.max_iterations {
517            // Compute gradients
518            self.compute_gradients(&instance, &mut state);
519
520            // Update with momentum
521            self.update_assignments(&mut state);
522
523            // Check if solved
524            let discrete = state.discretize(self.config.discretize_threshold);
525            let satisfied = instance.count_satisfied(&discrete) as u32;
526
527            // Track best solution
528            if satisfied > best_satisfied {
529                best_satisfied = satisfied;
530                best_assignment = Some(discrete.clone());
531            }
532
533            if instance.is_satisfied(&discrete) {
534                return SolveResult::satisfiable(discrete).with_stats(SolveStats {
535                    iterations: iter + 1,
536                    duration_us: start.elapsed().as_micros() as u64,
537                    peak_memory: 0,
538                });
539            }
540        }
541
542        // Return best effort
543        let final_discrete =
544            best_assignment.unwrap_or_else(|| state.discretize(self.config.discretize_threshold));
545        let final_satisfied = instance.count_satisfied(&final_discrete) as u32;
546
547        SolveResult {
548            status: SolveStatus::Unknown,
549            proof: SolveProof::approximate(
550                final_discrete,
551                final_satisfied,
552                instance.clauses.len() as u32,
553                self.config.max_iterations,
554            ),
555            stats: SolveStats {
556                iterations: self.config.max_iterations,
557                duration_us: start.elapsed().as_micros() as u64,
558                peak_memory: 0,
559            },
560        }
561    }
562
563    /// Computes gradients of the unsatisfied clause penalty.
564    ///
565    /// For each clause, the "unsatisfied-ness" is the product of (1 - lit_value)
566    /// for all literals. The gradient is computed using the product rule:
567    ///
568    /// d(clause_unsat)/d(var) = d/d(var)[ prod_i (1 - lit_i) ]
569    ///                        = sign(lit) * prod_{j != i} (1 - lit_j)
570    ///
571    /// where sign(lit) = -1 for positive literals, +1 for negative literals.
572    ///
573    /// # Arguments
574    ///
575    /// * `instance` - The SAT instance
576    /// * `state` - The solver state to update with computed gradients
577    fn compute_gradients(&self, instance: &SolveInstance, state: &mut SolverState) {
578        state.gradients.fill(0.0);
579
580        for clause in &instance.clauses {
581            // Compute clause unsatisfaction: prod_i (1 - lit_val_i)
582            // This is 0 when any literal is satisfied (lit_val = 1)
583            // and approaches 1 when all literals are unsatisfied (lit_val = 0)
584            let mut clause_unsat = 1.0f32;
585            for lit in &clause.literals {
586                let val = state.assignments[lit.var as usize];
587                // lit_val is the "truth value" of the literal:
588                // - For positive literal: lit_val = val
589                // - For negative literal: lit_val = 1 - val
590                let lit_val = if lit.negated { 1.0 - val } else { val };
591                clause_unsat *= 1.0 - lit_val;
592            }
593
594            // Skip nearly satisfied clauses (gradient contribution negligible)
595            // This is an important optimization: when a clause is satisfied,
596            // its gradient contribution is essentially zero, so we skip it
597            if clause_unsat < 0.001 {
598                continue;
599            }
600
601            // Compute gradient contribution for each literal in this clause
602            // Using the product rule: d/dx[f(x)*g(y)] = g(y) * df/dx
603            // The gradient of (1 - lit_val) with respect to var is:
604            // - For positive literal: d/dvar[1 - var] = -1
605            // - For negative literal: d/dvar[1 - (1-var)] = d/dvar[var] = +1
606            for lit in &clause.literals {
607                let var = lit.var as usize;
608
609                // Compute product of other terms (excluding this literal)
610                // This gives us the coefficient when applying product rule
611                let mut other_product = 1.0f32;
612                for other_lit in &clause.literals {
613                    if other_lit.var != lit.var {
614                        let other_val = state.assignments[other_lit.var as usize];
615                        let lit_val = if other_lit.negated {
616                            1.0 - other_val
617                        } else {
618                            other_val
619                        };
620                        other_product *= 1.0 - lit_val;
621                    }
622                }
623
624                // The derivative of (1 - lit_val) with respect to var:
625                // - Positive literal: lit_val = var, so d(1-var)/dvar = -1
626                // - Negative literal: lit_val = 1-var, so d(1-(1-var))/dvar = d(var)/dvar = +1
627                // But we want to minimize unsatisfaction, so gradient points to increase satisfaction
628                let sign = if lit.negated { 1.0 } else { -1.0 };
629                state.gradients[var] += sign * other_product;
630            }
631        }
632    }
633
634    /// Updates variable assignments using momentum-based gradient descent.
635    ///
636    /// The update rule is:
637    /// ```text
638    /// velocity[i] = momentum * velocity[i] - learning_rate * gradient[i]
639    /// assignment[i] = clamp(assignment[i] + velocity[i], 0.0, 1.0)
640    /// ```
641    ///
642    /// # Arguments
643    ///
644    /// * `state` - The solver state to update
645    fn update_assignments(&self, state: &mut SolverState) {
646        for i in 0..state.assignments.len() {
647            // Momentum update: accumulate velocity
648            state.velocities[i] = self.config.momentum * state.velocities[i]
649                - self.config.learning_rate * state.gradients[i];
650
651            // Update assignment
652            state.assignments[i] += state.velocities[i];
653
654            // Clamp to valid range [0, 1]
655            state.assignments[i] = state.assignments[i].clamp(0.0, 1.0);
656        }
657    }
658}
659
660impl Default for Solver {
661    fn default() -> Self {
662        Self::new_cpu()
663    }
664}
665
666// =============================================================================
667// Tests
668// =============================================================================
669
670#[cfg(test)]
671mod tests {
672    use super::*;
673    use crate::instance::{Clause, Literal};
674
675    // ==========================================================================
676    // SolverConfig Tests
677    // ==========================================================================
678
679    #[test]
680    fn test_solver_config_default() {
681        let config = SolverConfig::default();
682        assert_eq!(config.max_iterations, 10000);
683        assert_eq!(config.learning_rate, 0.1);
684        assert_eq!(config.momentum, 0.9);
685        assert_eq!(config.discretize_threshold, 0.5);
686    }
687
688    #[test]
689    fn test_solver_config_new() {
690        let config = SolverConfig::new(5000, 0.05, 0.95, 0.4);
691        assert_eq!(config.max_iterations, 5000);
692        assert_eq!(config.learning_rate, 0.05);
693        assert_eq!(config.momentum, 0.95);
694        assert_eq!(config.discretize_threshold, 0.4);
695    }
696
697    #[test]
698    fn test_solver_config_fast() {
699        let config = SolverConfig::fast();
700        assert_eq!(config.max_iterations, 1000);
701        assert_eq!(config.learning_rate, 0.2);
702    }
703
704    #[test]
705    fn test_solver_config_thorough() {
706        let config = SolverConfig::thorough();
707        assert_eq!(config.max_iterations, 50000);
708        assert_eq!(config.learning_rate, 0.05);
709    }
710
711    #[test]
712    fn test_solver_config_builders() {
713        let config = SolverConfig::default()
714            .with_max_iterations(2000)
715            .with_learning_rate(0.2)
716            .with_momentum(0.8)
717            .with_discretize_threshold(0.6);
718
719        assert_eq!(config.max_iterations, 2000);
720        assert_eq!(config.learning_rate, 0.2);
721        assert_eq!(config.momentum, 0.8);
722        assert_eq!(config.discretize_threshold, 0.6);
723    }
724
725    // ==========================================================================
726    // SolverState Tests
727    // ==========================================================================
728
729    #[test]
730    fn test_solver_state_new() {
731        let state = SolverState::new(5);
732        assert_eq!(state.assignments.len(), 5);
733        assert_eq!(state.velocities.len(), 5);
734        assert_eq!(state.gradients.len(), 5);
735
736        // All velocities and gradients should be zero
737        assert!(state.velocities.iter().all(|&v| v == 0.0));
738        assert!(state.gradients.iter().all(|&g| g == 0.0));
739
740        // Assignments should be in [0.3, 0.7]
741        for &val in &state.assignments {
742            assert!((0.3..=0.7).contains(&val));
743        }
744    }
745
746    #[test]
747    fn test_solver_state_with_assignments() {
748        let assignments = vec![0.1, 0.5, 0.9];
749        let state = SolverState::with_assignments(assignments.clone());
750        assert_eq!(state.assignments, assignments);
751        assert!(state.velocities.iter().all(|&v| v == 0.0));
752    }
753
754    #[test]
755    fn test_solver_state_discretize() {
756        let mut state = SolverState::new(4);
757        state.assignments = vec![0.2, 0.5, 0.6, 0.9];
758
759        let discrete = state.discretize(0.5);
760        assert_eq!(discrete, vec![false, true, true, true]);
761
762        let discrete_high = state.discretize(0.7);
763        assert_eq!(discrete_high, vec![false, false, false, true]);
764    }
765
766    #[test]
767    fn test_solver_state_num_vars() {
768        let state = SolverState::new(10);
769        assert_eq!(state.num_vars(), 10);
770    }
771
772    #[test]
773    fn test_solver_state_reset_velocities() {
774        let mut state = SolverState::new(3);
775        state.velocities = vec![1.0, 2.0, 3.0];
776        state.reset_velocities();
777        assert!(state.velocities.iter().all(|&v| v == 0.0));
778    }
779
780    #[test]
781    fn test_solver_state_clear_gradients() {
782        let mut state = SolverState::new(3);
783        state.gradients = vec![1.0, 2.0, 3.0];
784        state.clear_gradients();
785        assert!(state.gradients.iter().all(|&g| g == 0.0));
786    }
787
788    // ==========================================================================
789    // Solver Construction Tests
790    // ==========================================================================
791
792    #[test]
793    fn test_solver_new_cpu() {
794        let solver = Solver::new_cpu();
795        assert_eq!(solver.config().max_iterations, 10000);
796    }
797
798    #[test]
799    fn test_solver_with_config_cpu() {
800        let config = SolverConfig::fast();
801        let solver = Solver::with_config_cpu(config);
802        assert_eq!(solver.config().max_iterations, 1000);
803    }
804
805    #[test]
806    fn test_solver_default() {
807        let solver = Solver::default();
808        assert_eq!(solver.config().max_iterations, 10000);
809    }
810
811    // ==========================================================================
812    // Core Algorithm Tests (from task spec)
813    // ==========================================================================
814
815    #[test]
816    fn test_solver_simple_sat() {
817        // (x0) - trivially satisfiable
818        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
819        let solver = Solver::new_cpu();
820        let result = solver.solve(instance);
821        assert!(matches!(result.status, SolveStatus::Sat));
822        if let Some(assignment) = result.assignment() {
823            assert!(assignment[0]); // x0 must be true
824        }
825    }
826
827    #[test]
828    fn test_solver_two_clause() {
829        // (x0 OR x1) AND (NOT x0 OR x1) - x1 must be true
830        let instance = SolveInstance::new(
831            2,
832            vec![
833                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
834                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
835            ],
836        );
837        let solver = Solver::new_cpu();
838        let result = solver.solve(instance);
839        assert!(matches!(result.status, SolveStatus::Sat));
840        if let Some(assignment) = result.assignment() {
841            assert!(assignment[1]); // x1 must be true
842        }
843    }
844
845    #[test]
846    fn test_solver_unsat() {
847        // (x0) AND (NOT x0) - unsatisfiable
848        let instance = SolveInstance::new(
849            1,
850            vec![
851                Clause::new(vec![Literal::positive(0)]),
852                Clause::new(vec![Literal::negative(0)]),
853            ],
854        );
855        let solver = Solver::new_cpu();
856        let result = solver.solve(instance);
857        // CLS is incomplete for UNSAT, so Unknown is acceptable
858        assert!(matches!(
859            result.status,
860            SolveStatus::Unsat | SolveStatus::Unknown
861        ));
862    }
863
864    // ==========================================================================
865    // Additional Algorithm Tests
866    // ==========================================================================
867
868    #[test]
869    fn test_solver_empty_instance() {
870        // No clauses - any assignment works
871        let instance = SolveInstance::new(3, vec![]);
872        let solver = Solver::new_cpu();
873        let result = solver.solve(instance);
874        assert!(matches!(result.status, SolveStatus::Sat));
875    }
876
877    #[test]
878    fn test_solver_no_variables() {
879        // No variables and no clauses
880        let instance = SolveInstance::new(0, vec![]);
881        let solver = Solver::new_cpu();
882        let result = solver.solve(instance);
883        assert!(matches!(result.status, SolveStatus::Sat));
884    }
885
886    #[test]
887    fn test_solver_unit_propagation() {
888        // Unit clause forces x0=true, which makes second clause satisfied via x1
889        // (x0) AND (NOT x0 OR x1)
890        let instance = SolveInstance::new(
891            2,
892            vec![
893                Clause::new(vec![Literal::positive(0)]),
894                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
895            ],
896        );
897        let solver = Solver::new_cpu();
898        let result = solver.solve(instance);
899        assert!(matches!(result.status, SolveStatus::Sat));
900        if let Some(assignment) = result.assignment() {
901            assert!(assignment[0]); // x0 must be true
902        }
903    }
904
905    #[test]
906    fn test_solver_negative_unit() {
907        // (NOT x0) - must set x0=false
908        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::negative(0)])]);
909        let solver = Solver::new_cpu();
910        let result = solver.solve(instance);
911        assert!(matches!(result.status, SolveStatus::Sat));
912        if let Some(assignment) = result.assignment() {
913            assert!(!assignment[0]); // x0 must be false
914        }
915    }
916
917    #[test]
918    fn test_solver_three_vars() {
919        // (x0 OR x1) AND (NOT x1 OR x2) AND (NOT x2 OR x0)
920        // Satisfiable: x0=true, x1=true, x2=true
921        let instance = SolveInstance::new(
922            3,
923            vec![
924                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
925                Clause::new(vec![Literal::negative(1), Literal::positive(2)]),
926                Clause::new(vec![Literal::negative(2), Literal::positive(0)]),
927            ],
928        );
929        let solver = Solver::new_cpu();
930        let result = solver.solve(instance.clone());
931        assert!(matches!(result.status, SolveStatus::Sat));
932        if let Some(assignment) = result.assignment() {
933            assert!(instance.is_satisfied(assignment));
934        }
935    }
936
937    #[test]
938    fn test_solver_all_positive() {
939        // (x0) AND (x1) AND (x2) - all must be true
940        let instance = SolveInstance::new(
941            3,
942            vec![
943                Clause::new(vec![Literal::positive(0)]),
944                Clause::new(vec![Literal::positive(1)]),
945                Clause::new(vec![Literal::positive(2)]),
946            ],
947        );
948        let solver = Solver::new_cpu();
949        let result = solver.solve(instance);
950        assert!(matches!(result.status, SolveStatus::Sat));
951        if let Some(assignment) = result.assignment() {
952            assert!(assignment.iter().all(|&v| v));
953        }
954    }
955
956    #[test]
957    fn test_solver_all_negative() {
958        // (NOT x0) AND (NOT x1) AND (NOT x2) - all must be false
959        let instance = SolveInstance::new(
960            3,
961            vec![
962                Clause::new(vec![Literal::negative(0)]),
963                Clause::new(vec![Literal::negative(1)]),
964                Clause::new(vec![Literal::negative(2)]),
965            ],
966        );
967        let solver = Solver::new_cpu();
968        let result = solver.solve(instance);
969        assert!(matches!(result.status, SolveStatus::Sat));
970        if let Some(assignment) = result.assignment() {
971            assert!(assignment.iter().all(|&v| !v));
972        }
973    }
974
975    #[test]
976    fn test_solver_xor_like() {
977        // (x0 OR x1) AND (NOT x0 OR NOT x1) - XOR: exactly one must be true
978        let instance = SolveInstance::new(
979            2,
980            vec![
981                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
982                Clause::new(vec![Literal::negative(0), Literal::negative(1)]),
983            ],
984        );
985        let solver = Solver::new_cpu();
986        let result = solver.solve(instance);
987        assert!(matches!(result.status, SolveStatus::Sat));
988        if let Some(assignment) = result.assignment() {
989            // Exactly one should be true
990            assert!(assignment[0] != assignment[1]);
991        }
992    }
993
994    #[test]
995    fn test_solver_binary_clause() {
996        // (x0 OR x1) - at least one true
997        let instance = SolveInstance::new(
998            2,
999            vec![Clause::new(vec![
1000                Literal::positive(0),
1001                Literal::positive(1),
1002            ])],
1003        );
1004        let solver = Solver::new_cpu();
1005        let result = solver.solve(instance);
1006        assert!(matches!(result.status, SolveStatus::Sat));
1007        if let Some(assignment) = result.assignment() {
1008            assert!(assignment[0] || assignment[1]);
1009        }
1010    }
1011
1012    #[test]
1013    fn test_solver_ternary_clause() {
1014        // (x0 OR x1 OR x2) - at least one true
1015        let instance = SolveInstance::new(
1016            3,
1017            vec![Clause::new(vec![
1018                Literal::positive(0),
1019                Literal::positive(1),
1020                Literal::positive(2),
1021            ])],
1022        );
1023        let solver = Solver::new_cpu();
1024        let result = solver.solve(instance);
1025        assert!(matches!(result.status, SolveStatus::Sat));
1026        if let Some(assignment) = result.assignment() {
1027            assert!(assignment[0] || assignment[1] || assignment[2]);
1028        }
1029    }
1030
1031    // ==========================================================================
1032    // Statistics Tests
1033    // ==========================================================================
1034
1035    #[test]
1036    fn test_solver_stats() {
1037        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
1038        let solver = Solver::new_cpu();
1039        let result = solver.solve(instance);
1040
1041        // Should have iteration count > 0
1042        assert!(result.stats.iterations > 0);
1043        // Duration should be recorded
1044        // Note: Very fast executions might have 0 microseconds
1045        assert!(result.stats.iterations <= solver.config().max_iterations);
1046    }
1047
1048    #[test]
1049    fn test_solver_stats_iterations_limited() {
1050        // Test that iterations are limited by config
1051        let config = SolverConfig::default().with_max_iterations(10);
1052        let solver = Solver::with_config_cpu(config);
1053
1054        // Unsatisfiable instance will run to max iterations
1055        let instance = SolveInstance::new(
1056            1,
1057            vec![
1058                Clause::new(vec![Literal::positive(0)]),
1059                Clause::new(vec![Literal::negative(0)]),
1060            ],
1061        );
1062        let result = solver.solve(instance);
1063
1064        assert!(result.stats.iterations <= 10);
1065    }
1066
1067    // ==========================================================================
1068    // Gradient Computation Tests
1069    // ==========================================================================
1070
1071    #[test]
1072    fn test_compute_gradients_single_positive() {
1073        // (x0) - gradient should push x0 towards 1
1074        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
1075        let solver = Solver::new_cpu();
1076        let mut state = SolverState::with_assignments(vec![0.5]);
1077
1078        solver.compute_gradients(&instance, &mut state);
1079
1080        // Gradient should be negative (to increase x0 when subtracted)
1081        assert!(state.gradients[0] < 0.0);
1082    }
1083
1084    #[test]
1085    fn test_compute_gradients_single_negative() {
1086        // (NOT x0) - gradient should push x0 towards 0
1087        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::negative(0)])]);
1088        let solver = Solver::new_cpu();
1089        let mut state = SolverState::with_assignments(vec![0.5]);
1090
1091        solver.compute_gradients(&instance, &mut state);
1092
1093        // Gradient should be positive (to decrease x0 when subtracted)
1094        assert!(state.gradients[0] > 0.0);
1095    }
1096
1097    #[test]
1098    fn test_compute_gradients_satisfied_clause() {
1099        // (x0) with x0=1.0 - clause is satisfied, gradient should be near zero
1100        let instance = SolveInstance::new(1, vec![Clause::new(vec![Literal::positive(0)])]);
1101        let solver = Solver::new_cpu();
1102        let mut state = SolverState::with_assignments(vec![1.0]);
1103
1104        solver.compute_gradients(&instance, &mut state);
1105
1106        // Gradient should be very small (clause is satisfied)
1107        assert!(state.gradients[0].abs() < 0.01);
1108    }
1109
1110    // ==========================================================================
1111    // Update Assignments Tests
1112    // ==========================================================================
1113
1114    #[test]
1115    fn test_update_assignments_clamps() {
1116        let solver = Solver::with_config_cpu(SolverConfig::default().with_learning_rate(10.0));
1117        let mut state = SolverState::with_assignments(vec![0.5]);
1118        state.gradients = vec![-1.0]; // Large negative gradient
1119
1120        solver.update_assignments(&mut state);
1121
1122        // Should be clamped to [0, 1]
1123        assert!(state.assignments[0] >= 0.0);
1124        assert!(state.assignments[0] <= 1.0);
1125    }
1126
1127    #[test]
1128    fn test_update_assignments_momentum() {
1129        let solver = Solver::with_config_cpu(
1130            SolverConfig::default()
1131                .with_learning_rate(0.1)
1132                .with_momentum(0.5),
1133        );
1134        let mut state = SolverState::with_assignments(vec![0.5]);
1135        state.velocities = vec![0.1]; // Previous velocity
1136        state.gradients = vec![-0.1]; // Gradient
1137
1138        solver.update_assignments(&mut state);
1139
1140        // Velocity should incorporate both momentum and gradient
1141        // v = 0.5 * 0.1 - 0.1 * (-0.1) = 0.05 + 0.01 = 0.06
1142        let expected_velocity = 0.5 * 0.1 - 0.1 * (-0.1);
1143        assert!((state.velocities[0] - expected_velocity).abs() < 1e-6);
1144    }
1145
1146    // ==========================================================================
1147    // Edge Case Tests
1148    // ==========================================================================
1149
1150    #[test]
1151    fn test_solver_empty_clause() {
1152        // Empty clause can never be satisfied
1153        let instance = SolveInstance::new(1, vec![Clause::new(vec![])]);
1154        let solver = Solver::new_cpu();
1155        let result = solver.solve(instance);
1156
1157        // Should return Unknown (CLS cannot prove UNSAT)
1158        assert!(matches!(result.status, SolveStatus::Unknown));
1159    }
1160
1161    #[test]
1162    fn test_solver_large_clause() {
1163        // Large clause (10 literals)
1164        let literals: Vec<Literal> = (0..10).map(Literal::positive).collect();
1165        let instance = SolveInstance::new(10, vec![Clause::new(literals)]);
1166        let solver = Solver::new_cpu();
1167        let result = solver.solve(instance);
1168
1169        assert!(matches!(result.status, SolveStatus::Sat));
1170    }
1171
1172    #[test]
1173    fn test_solver_many_clauses() {
1174        // Many unit clauses (20 variables, each must be true)
1175        let clauses: Vec<Clause> = (0..20)
1176            .map(|i| Clause::new(vec![Literal::positive(i)]))
1177            .collect();
1178        let instance = SolveInstance::new(20, clauses);
1179        let solver = Solver::new_cpu();
1180        let result = solver.solve(instance);
1181
1182        assert!(matches!(result.status, SolveStatus::Sat));
1183        if let Some(assignment) = result.assignment() {
1184            assert!(assignment.iter().all(|&v| v));
1185        }
1186    }
1187
1188    #[test]
1189    fn test_solver_pigeon_hole_small() {
1190        // 2 pigeons, 1 hole - unsatisfiable
1191        // p1_h1: pigeon 1 in hole 1
1192        // p2_h1: pigeon 2 in hole 1
1193        // Clauses:
1194        // (p1_h1) - pigeon 1 must be somewhere
1195        // (p2_h1) - pigeon 2 must be somewhere
1196        // (NOT p1_h1 OR NOT p2_h1) - hole 1 can have at most one pigeon
1197        let instance = SolveInstance::new(
1198            2,
1199            vec![
1200                Clause::new(vec![Literal::positive(0)]), // pigeon 1 in hole 1
1201                Clause::new(vec![Literal::positive(1)]), // pigeon 2 in hole 1
1202                Clause::new(vec![Literal::negative(0), Literal::negative(1)]), // at most one
1203            ],
1204        );
1205        let solver = Solver::new_cpu();
1206        let result = solver.solve(instance);
1207
1208        // CLS cannot prove UNSAT, so should return Unknown
1209        assert!(matches!(result.status, SolveStatus::Unknown));
1210    }
1211
1212    #[test]
1213    fn test_solver_deterministic() {
1214        // Same instance should give same result (solver is deterministic)
1215        let instance = SolveInstance::new(
1216            2,
1217            vec![
1218                Clause::new(vec![Literal::positive(0), Literal::positive(1)]),
1219                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
1220            ],
1221        );
1222
1223        let solver = Solver::new_cpu();
1224        let result1 = solver.solve(instance.clone());
1225        let result2 = solver.solve(instance);
1226
1227        assert_eq!(result1.status, result2.status);
1228        assert_eq!(result1.assignment(), result2.assignment());
1229    }
1230
1231    #[test]
1232    fn test_solver_with_fast_config() {
1233        // Fast config should still solve simple instances
1234        let instance = SolveInstance::new(
1235            2,
1236            vec![Clause::new(vec![
1237                Literal::positive(0),
1238                Literal::positive(1),
1239            ])],
1240        );
1241        let solver = Solver::with_config_cpu(SolverConfig::fast());
1242        let result = solver.solve(instance);
1243
1244        assert!(matches!(result.status, SolveStatus::Sat));
1245    }
1246
1247    #[test]
1248    fn test_solver_implication_chain() {
1249        // x0 -> x1 -> x2 -> x3, with x0=true
1250        // (x0) AND (NOT x0 OR x1) AND (NOT x1 OR x2) AND (NOT x2 OR x3)
1251        // Should result in all variables true
1252        let instance = SolveInstance::new(
1253            4,
1254            vec![
1255                Clause::new(vec![Literal::positive(0)]),
1256                Clause::new(vec![Literal::negative(0), Literal::positive(1)]),
1257                Clause::new(vec![Literal::negative(1), Literal::positive(2)]),
1258                Clause::new(vec![Literal::negative(2), Literal::positive(3)]),
1259            ],
1260        );
1261        let solver = Solver::new_cpu();
1262        let result = solver.solve(instance);
1263
1264        assert!(matches!(result.status, SolveStatus::Sat));
1265        if let Some(assignment) = result.assignment() {
1266            // All should be true due to implications
1267            assert!(assignment.iter().all(|&v| v));
1268        }
1269    }
1270}