Skip to main content

xlog_solve/
gpu_cdcl.rs

1use std::ffi::c_void;
2use std::sync::Arc;
3
4use cudarc::driver::LaunchConfig;
5use xlog_core::{Result, XlogError};
6use xlog_cuda::memory::TrackedCudaSlice;
7use xlog_cuda::provider::{sat_kernels, SAT_MODULE};
8use xlog_cuda::{AsKernelParam, CudaKernelProvider, DeviceSlice, LaunchAsync};
9
10use crate::gpu_cnf::GpuCnf;
11
// Must match kernels/sat.cu.
/// Status code for an unsatisfiable outcome. Presumably written by the kernel
/// into `out_status` — confirm against kernels/sat.cu.
const SAT_STATUS_UNSAT: i32 = 0;
/// Status code for a satisfiable outcome. Presumably written by the kernel
/// into `out_status` — confirm against kernels/sat.cu.
const SAT_STATUS_SAT: i32 = 1;
/// Conflict budget reached before the search terminated (mirrors the kernel's
/// `SAT_STATUS_BUDGET_EXHAUSTED`). INDETERMINATE — the verifier must treat this
/// as a fail-closed decline, never as a proof.
const SAT_STATUS_BUDGET_EXHAUSTED: i32 = 2;
19
/// Device buffers handed back by a per-call (non-workspace) CDCL launch.
///
/// Built by `launch_cdcl_with_decision_ranges_gated` from buffers it allocates
/// for that single launch. Only the buffers listed here outlive the launch
/// function; every other buffer it allocates is dropped when it returns.
struct GpuCdclRun {
    /// The kernel's `assign` buffer (`var_cap + 1` entries).
    assignment: TrackedCudaSlice<i8>,
    // Scratch buffers used only by sat_cdcl_solve, but must stay alive until the solver kernel completes.
    #[allow(dead_code)]
    decision_heap: TrackedCudaSlice<u32>,
    #[allow(dead_code)]
    decision_heap_pos: TrackedCudaSlice<u32>,

    /// Learned-clause offsets arena (`max_learned_clauses + 1` entries).
    learned_offsets: TrackedCudaSlice<u32>,
    /// Learned-clause literal arena (`max_learned_lits` entries).
    learned_lits: TrackedCudaSlice<i32>,
    /// Proof-trace offsets arena (`max_learned_clauses + 1` entries).
    proof_offsets: TrackedCudaSlice<u32>,
    /// Proof-trace data arena (`max_proof_u32` words).
    proof_data: TrackedCudaSlice<u32>,

    /// Single-element status output (compared against the `SAT_STATUS_*` codes).
    out_status: TrackedCudaSlice<i32>,
    /// Single-element error output; `0` is treated as "no error" by the status check.
    out_error: TrackedCudaSlice<i32>,
    /// Single-element learned-clause count; zeroed on the host before launch.
    out_learned_count: TrackedCudaSlice<u32>,
}
37
/// Configuration for the GPU CDCL (Conflict-Driven Clause Learning) solver.
///
/// Controls arena sizes for learned clauses, proof traces, and the
/// deterministic restart/reduce schedule. The defaults are tuned for
/// equivalence verification of circuits with up to ~10 000 variables;
/// larger problems may need proportionally larger arenas.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; start from [`GpuCdclConfig::default`] and assign
/// the fields that need to differ.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct GpuCdclConfig {
    /// Maximum number of learned clauses retained in the arena.
    /// Must be nonzero: workspace creation and kernel launch reject `0`.
    pub max_learned_clauses: u32,
    /// Maximum total literals across all learned clauses.
    /// Must be nonzero: workspace creation and kernel launch reject `0`.
    pub max_learned_lits: u32,
    /// Maximum size (in u32 words) of the resolution proof trace.
    /// Must be at least `2`: workspace creation and kernel launch reject smaller values.
    pub max_proof_u32: u32,
    /// Base conflict count between deterministic restarts.
    pub restart_base: u32,
    /// Conflict count between clause-database reductions.
    pub reduce_interval: u32,
    /// Conflict budget: `0` = unlimited (default; no behavior change). When
    /// nonzero, the solver bails with status `BUDGET_EXHAUSTED` after this many
    /// conflicts without terminating, bounding wall-clock so a treewidth-hard
    /// instance returns gracefully instead of running to a CUDA launch timeout
    /// that poisons the primary context.
    pub max_conflicts: u32,
}
64
65impl Default for GpuCdclConfig {
66    fn default() -> Self {
67        Self {
68            max_learned_clauses: 32_768,
69            max_learned_lits: 262_144,
70            max_proof_u32: 1_048_576,
71            restart_base: 100,
72            reduce_interval: 2000,
73            max_conflicts: 0,
74        }
75    }
76}
77
/// GPU-native CDCL SAT solver backed by CUDA kernels.
pub struct GpuCdclSolver {
    /// Kernel provider used for device allocation, host/device copies and
    /// looking up the `sat_cdcl_solve` kernel.
    provider: Arc<CudaKernelProvider>,
    /// Arena sizes and restart/reduce schedule passed to the kernel as scalars.
    config: GpuCdclConfig,
}
83
/// Pre-allocated solver arena for reuse across multiple CDCL solves.
///
/// Owns the 29 device buffers that `launch_cdcl_with_decision_ranges_gated` normally
/// allocates per call. Does NOT own CNF storage (clause_offsets/literals stay on GpuCnf).
///
/// Created via [`GpuCdclSolver::new_workspace`]. Passed as `&mut` to `_ws` solver methods.
///
/// `reset_for_solve()` is intentionally a no-op: the `sat_cdcl_solve` kernel initializes
/// mutable state at launch. The learned-import path preserves the learned/proof arenas
/// below `out_learned_count[0]` and initializes the remaining workspace buffers.
pub struct GpuCdclWorkspace {
    // Capacity limits (used for overflow checks)
    pub(crate) var_cap: usize,
    // Input clause capacity plus the creating solver's max_learned_clauses.
    pub(crate) clause_total_cap: usize,

    // Variable state (var_cap + 1 each)
    pub(crate) assign: TrackedCudaSlice<i8>,
    pub(crate) level: TrackedCudaSlice<u32>,
    pub(crate) reason: TrackedCudaSlice<i32>,
    pub(crate) var_activity: TrackedCudaSlice<u32>,
    pub(crate) var_phase: TrackedCudaSlice<i8>,
    pub(crate) decision_heap: TrackedCudaSlice<u32>,
    pub(crate) decision_heap_pos: TrackedCudaSlice<u32>,

    // Trail (var_cap + 1 each)
    pub(crate) trail: TrackedCudaSlice<i32>,
    pub(crate) trail_lim: TrackedCudaSlice<u32>,

    // Analysis scratch (var_cap + 1 each)
    pub(crate) seen: TrackedCudaSlice<u8>,
    pub(crate) learnt_tmp: TrackedCudaSlice<i32>,
    pub(crate) proof_vars_tmp: TrackedCudaSlice<u32>,
    pub(crate) proof_reason_tmp: TrackedCudaSlice<u32>,

    // Watch lists
    pub(crate) watch0_pos: TrackedCudaSlice<u32>, // clause_total_cap
    pub(crate) watch1_pos: TrackedCudaSlice<u32>, // clause_total_cap
    pub(crate) watch_head: TrackedCudaSlice<i32>, // 2 * var_cap
    pub(crate) watch_next: TrackedCudaSlice<i32>, // 2 * clause_total_cap
    pub(crate) watch_prev: TrackedCudaSlice<i32>, // 2 * clause_total_cap

    // Learned clause arena
    pub(crate) learned_offsets: TrackedCudaSlice<u32>, // max_learned_clauses + 1
    pub(crate) learned_lits: TrackedCudaSlice<i32>,    // max_learned_lits
    pub(crate) learned_deleted: TrackedCudaSlice<u8>,  // max_learned_clauses
    pub(crate) learned_lbd: TrackedCudaSlice<u32>,     // max_learned_clauses
    pub(crate) learned_activity: TrackedCudaSlice<u32>, // max_learned_clauses
    pub(crate) learned_locked: TrackedCudaSlice<u8>,   // max_learned_clauses

    // Proof trace
    pub(crate) proof_offsets: TrackedCudaSlice<u32>, // max_learned_clauses + 1
    pub(crate) proof_data: TrackedCudaSlice<u32>,    // max_proof_u32

    // Scalar outputs
    pub(crate) out_status: TrackedCudaSlice<i32>, // 1
    pub(crate) out_error: TrackedCudaSlice<i32>,  // 1
    pub(crate) out_learned_count: TrackedCudaSlice<u32>, // 1
}
142
impl GpuCdclWorkspace {
    /// No-op: the sat_cdcl_solve kernel initializes mutable state at launch.
    ///
    /// Kept as an explicit call site so the workspace launch path has one
    /// obvious place to add host-side resets if they are ever needed.
    #[inline]
    pub(crate) fn reset_for_solve(&mut self) {
        // Intentionally empty; launch flags decide whether learned/proof arenas are imported.
    }

    /// Variable capacity this workspace was allocated for.
    ///
    /// This is the `max_var_cap` given at creation; the per-variable buffers
    /// hold one element more than this.
    #[inline]
    #[allow(dead_code)] // diagnostic accessor, retained for debugging
    pub(crate) fn var_cap(&self) -> usize {
        self.var_cap
    }

    /// Total clause capacity (input + learned) this workspace was allocated for.
    #[inline]
    #[allow(dead_code)] // diagnostic accessor, retained for debugging
    pub(crate) fn clause_total_cap(&self) -> usize {
        self.clause_total_cap
    }

    /// Device pointer of the assignment buffer (for diagnostics / reuse verification).
    ///
    /// Returns the raw pointer value only; it does not borrow or pin the buffer.
    #[inline]
    pub fn assign_device_ptr(&self) -> cudarc::driver::sys::CUdeviceptr {
        self.assign.device_ptr_value()
    }
}
170
/// Raw CDCL outputs (device-resident) for debugging and research.
///
/// Production verifier paths should prefer `solve_expect_sat*` / `solve_expect_unsat*`,
/// which validate results on GPU and return typed errors for status mismatches.
///
/// Nothing here has been validated: callers must read `out_error` and
/// `out_status` themselves before trusting `assignment`.
pub struct GpuCdclRawOutput {
    /// Per-variable assignment buffer. NOTE(review): presumably `var_cap + 1`
    /// entries as in the launch path — confirm at the construction site.
    pub assignment: TrackedCudaSlice<i8>,
    /// Single-element solver status (one of the `SAT_STATUS_*` codes).
    pub out_status: TrackedCudaSlice<i32>,
    /// Single-element kernel error code; `0` is treated as "no error".
    pub out_error: TrackedCudaSlice<i32>,
    /// Single-element learned-clause count.
    pub out_learned_count: TrackedCudaSlice<u32>,
}
181
182fn checked_solver_len_add_one(context: &str, value: usize) -> Result<usize> {
183    value
184        .checked_add(1)
185        .ok_or_else(|| XlogError::Kernel(format!("{context} length overflow")))
186}
187
188fn checked_solver_len_double(context: &str, value: usize) -> Result<usize> {
189    value
190        .checked_mul(2)
191        .ok_or_else(|| XlogError::Kernel(format!("{context} length overflow")))
192}
193
194impl GpuCdclSolver {
195    fn require_expected_status(
196        &self,
197        out_status: &TrackedCudaSlice<i32>,
198        out_error: &TrackedCudaSlice<i32>,
199        expected_status: i32,
200        context: &'static str,
201    ) -> Result<()> {
202        let actual_status = self
203            .provider
204            .dtoh_scalar_untracked(out_status, 0)
205            .map_err(|e| XlogError::Kernel(format!("Failed to read {context} status: {e}")))?;
206        let actual_error = self
207            .provider
208            .dtoh_scalar_untracked(out_error, 0)
209            .map_err(|e| XlogError::Kernel(format!("Failed to read {context} error: {e}")))?;
210        // A conflict-budget bail is not a solver error and not a proof: it is a
211        // fail-closed decline (the search was cut short on purpose to avoid a
212        // context-poisoning launch timeout). Surface it as the typed,
213        // catchable VerifyBudgetExceeded rather than a generic kernel error.
214        if actual_error == 0 && actual_status == SAT_STATUS_BUDGET_EXHAUSTED {
215            return Err(XlogError::VerifyBudgetExceeded {
216                context: context.to_string(),
217                detail: "CDCL conflict budget exhausted before the search \
218                         terminated (indeterminate; verify declined fail-closed)"
219                    .to_string(),
220            });
221        }
222        if actual_error != 0 || actual_status != expected_status {
223            return Err(XlogError::Kernel(format!(
224                "{context} expected status {expected_status}, got status {actual_status} error {actual_error}"
225            )));
226        }
227        Ok(())
228    }
229
230    fn provider_memory_ptr(&self) -> usize {
231        Arc::as_ptr(self.provider.memory()) as usize
232    }
233
234    fn require_slice_on_provider<T: cudarc::driver::DeviceRepr>(
235        &self,
236        name: &'static str,
237        slice: &TrackedCudaSlice<T>,
238    ) -> Result<()> {
239        let expected_memory = self.provider_memory_ptr();
240        let actual_memory = slice.memory_manager_ptr_value();
241        if actual_memory != expected_memory {
242            return Err(XlogError::UnsupportedEpistemicConstruct {
243                construct: "GPU CDCL solver provider boundary".to_string(),
244                context: format!(
245                    "{name} belongs to memory manager {actual_memory}, expected {expected_memory}"
246                ),
247            });
248        }
249        Ok(())
250    }
251
252    pub(crate) fn require_workspace_on_provider(&self, ws: &GpuCdclWorkspace) -> Result<()> {
253        macro_rules! require_workspace_slice {
254            ($field:ident) => {
255                self.require_slice_on_provider(
256                    concat!("workspace.", stringify!($field)),
257                    &ws.$field,
258                )?
259            };
260        }
261
262        require_workspace_slice!(assign);
263        require_workspace_slice!(level);
264        require_workspace_slice!(reason);
265        require_workspace_slice!(var_activity);
266        require_workspace_slice!(var_phase);
267        require_workspace_slice!(decision_heap);
268        require_workspace_slice!(decision_heap_pos);
269        require_workspace_slice!(trail);
270        require_workspace_slice!(trail_lim);
271        require_workspace_slice!(seen);
272        require_workspace_slice!(learnt_tmp);
273        require_workspace_slice!(proof_vars_tmp);
274        require_workspace_slice!(proof_reason_tmp);
275        require_workspace_slice!(watch0_pos);
276        require_workspace_slice!(watch1_pos);
277        require_workspace_slice!(watch_head);
278        require_workspace_slice!(watch_next);
279        require_workspace_slice!(watch_prev);
280        require_workspace_slice!(learned_offsets);
281        require_workspace_slice!(learned_lits);
282        require_workspace_slice!(learned_deleted);
283        require_workspace_slice!(learned_lbd);
284        require_workspace_slice!(learned_activity);
285        require_workspace_slice!(learned_locked);
286        require_workspace_slice!(proof_offsets);
287        require_workspace_slice!(proof_data);
288        require_workspace_slice!(out_status);
289        require_workspace_slice!(out_error);
290        require_workspace_slice!(out_learned_count);
291        Ok(())
292    }
293
294    pub(crate) fn require_workspace_capacity_for_cnf(
295        &self,
296        ws: &GpuCdclWorkspace,
297        var_cap: u32,
298        clause_cap: u32,
299    ) -> Result<()> {
300        let num_vars_cap = var_cap as usize;
301        if num_vars_cap > ws.var_cap {
302            return Err(XlogError::Kernel(format!(
303                "CNF var_cap {} exceeds workspace var_cap {}",
304                num_vars_cap, ws.var_cap
305            )));
306        }
307
308        let max_learned_clauses = self.config.max_learned_clauses as usize;
309        let max_total_clauses = (clause_cap as usize)
310            .checked_add(max_learned_clauses)
311            .ok_or_else(|| XlogError::Kernel("SAT clause capacity overflow".to_string()))?;
312        if max_total_clauses > ws.clause_total_cap {
313            return Err(XlogError::Kernel(format!(
314                "CNF clause_total {} exceeds workspace clause_total_cap {}",
315                max_total_clauses, ws.clause_total_cap
316            )));
317        }
318
319        Ok(())
320    }
321
    /// Creates a solver bound to `provider` with the given `config`.
    ///
    /// No device work happens here and `config` is not validated; invalid
    /// arena sizes are rejected later by workspace creation and kernel launch.
    pub fn new(provider: Arc<CudaKernelProvider>, config: GpuCdclConfig) -> Self {
        Self { provider, config }
    }
325
    /// Pre-allocate a reusable solver arena.
    ///
    /// `max_var_cap` and `max_clause_cap` must be >= the `var_cap` / `clause_cap` of any
    /// `GpuCnf` that will be solved with this workspace. If a solve call exceeds these
    /// capacities, it returns `XlogError::Kernel`.
    ///
    /// The learned-clause and proof arenas are sized from this solver's
    /// configuration at the time of the call.
    ///
    /// # Errors
    ///
    /// * `XlogError::Compilation` if `max_var_cap` is `0`, if
    ///   `max_learned_clauses` or `max_learned_lits` is `0`, or if
    ///   `max_proof_u32` is below `2`.
    /// * `XlogError::Kernel` if a buffer length computation overflows `usize`.
    /// * Any error returned by the provider's allocator.
    pub fn new_workspace(&self, max_var_cap: u32, max_clause_cap: u32) -> Result<GpuCdclWorkspace> {
        let num_vars_cap = max_var_cap as usize;
        let num_clauses_cap = max_clause_cap as usize;
        let max_learned_clauses = self.config.max_learned_clauses as usize;
        let max_learned_lits = self.config.max_learned_lits as usize;
        let max_proof_u32 = self.config.max_proof_u32 as usize;

        // Reject degenerate sizes before touching the device.
        if max_var_cap == 0 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver workspace requires max_var_cap > 0".to_string(),
            ));
        }
        if self.config.max_learned_clauses == 0 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver requires max_learned_clauses > 0".to_string(),
            ));
        }
        if self.config.max_learned_lits == 0 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver requires max_learned_lits > 0".to_string(),
            ));
        }
        if self.config.max_proof_u32 < 2 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver requires max_proof_u32 >= 2".to_string(),
            ));
        }

        // All derived lengths use checked arithmetic so an overflow becomes an
        // error instead of an undersized allocation.
        let max_total_clauses = num_clauses_cap
            .checked_add(max_learned_clauses)
            .ok_or_else(|| XlogError::Kernel("SAT clause capacity overflow".to_string()))?;
        let vars_plus_one =
            checked_solver_len_add_one("SAT workspace variable arena", num_vars_cap)?;
        let learned_offsets_len =
            checked_solver_len_add_one("SAT workspace learned offsets", max_learned_clauses)?;
        let watch_head_len = checked_solver_len_double("SAT workspace watch head", num_vars_cap)?;
        let watch_clause_len =
            checked_solver_len_double("SAT workspace watch clauses", max_total_clauses)?;

        let memory = self.provider.memory();

        // Buffers allocated before a failing allocation are dropped when `?`
        // returns early.
        Ok(GpuCdclWorkspace {
            var_cap: num_vars_cap,
            clause_total_cap: max_total_clauses,

            // Variable state
            assign: memory.alloc::<i8>(vars_plus_one)?,
            level: memory.alloc::<u32>(vars_plus_one)?,
            reason: memory.alloc::<i32>(vars_plus_one)?,
            var_activity: memory.alloc::<u32>(vars_plus_one)?,
            var_phase: memory.alloc::<i8>(vars_plus_one)?,
            decision_heap: memory.alloc::<u32>(vars_plus_one)?,
            decision_heap_pos: memory.alloc::<u32>(vars_plus_one)?,

            // Trail
            trail: memory.alloc::<i32>(vars_plus_one)?,
            trail_lim: memory.alloc::<u32>(vars_plus_one)?,

            // Analysis scratch
            seen: memory.alloc::<u8>(vars_plus_one)?,
            learnt_tmp: memory.alloc::<i32>(vars_plus_one)?,
            proof_vars_tmp: memory.alloc::<u32>(vars_plus_one)?,
            proof_reason_tmp: memory.alloc::<u32>(vars_plus_one)?,

            // Watch lists
            watch0_pos: memory.alloc::<u32>(max_total_clauses)?,
            watch1_pos: memory.alloc::<u32>(max_total_clauses)?,
            watch_head: memory.alloc::<i32>(watch_head_len)?,
            watch_next: memory.alloc::<i32>(watch_clause_len)?,
            watch_prev: memory.alloc::<i32>(watch_clause_len)?,

            // Learned
            learned_offsets: memory.alloc::<u32>(learned_offsets_len)?,
            learned_lits: memory.alloc::<i32>(max_learned_lits)?,
            learned_deleted: memory.alloc::<u8>(max_learned_clauses)?,
            learned_lbd: memory.alloc::<u32>(max_learned_clauses)?,
            learned_activity: memory.alloc::<u32>(max_learned_clauses)?,
            learned_locked: memory.alloc::<u8>(max_learned_clauses)?,

            // Proof
            proof_offsets: memory.alloc::<u32>(learned_offsets_len)?,
            proof_data: memory.alloc::<u32>(max_proof_u32)?,

            // Outputs
            out_status: memory.alloc::<i32>(1)?,
            out_error: memory.alloc::<i32>(1)?,
            out_learned_count: memory.alloc::<u32>(1)?,
        })
    }
420
421    fn alloc_u32_scalar(&self, value: u32) -> Result<TrackedCudaSlice<u32>> {
422        let memory = self.provider.memory();
423        let mut gate = memory.alloc::<u32>(1)?;
424        self.provider
425            .htod_launch_metadata_sync_copy_into(&[value], &mut gate)
426            .map_err(|e| XlogError::Kernel(format!("GpuCdclSolver gate upload failed: {}", e)))?;
427        Ok(gate)
428    }
429
    /// Launches `sat_cdcl_solve` with freshly allocated buffers (non-workspace path).
    ///
    /// Steps, in order:
    /// 1. Checks that `cnf` and the four gate/decision scalars belong to this
    ///    solver's provider and that each scalar has length 1.
    /// 2. Validates `cnf.var_cap` and the configured arena sizes.
    /// 3. Allocates every solver buffer, sized from `cnf.var_cap`,
    ///    `cnf.clause_cap` and the configuration.
    /// 4. Zeroes `out_learned_count` from the host (nothing is imported).
    /// 5. Launches the kernel as a single block of 256 threads.
    ///
    /// The returned [`GpuCdclRun`] owns the assignment, decision heap, learned
    /// and proof arenas, and the scalar outputs. This function does not read
    /// any result back; callers inspect `out_status` / `out_error`.
    ///
    /// NOTE(review): buffers not moved into `GpuCdclRun` (`level`, `reason`,
    /// `trail`, the watch lists, `learned_deleted`, ...) are dropped when this
    /// function returns. If the launch is asynchronous and dropping a
    /// `TrackedCudaSlice` is not ordered after the kernel on the same stream,
    /// the kernel could touch freed memory. Confirm the drop semantics of
    /// `TrackedCudaSlice`.
    ///
    /// # Errors
    ///
    /// * `XlogError::Compilation` for a wrong scalar length, `var_cap == 0`, or
    ///   an invalid arena size in the configuration.
    /// * `XlogError::Kernel` for length overflow, a missing kernel, a failed
    ///   metadata upload or a failed launch.
    /// * Whatever the provider-boundary checks and the allocator return.
    fn launch_cdcl_with_decision_ranges_gated(
        &self,
        cnf: &GpuCnf,
        compile_needed: &TrackedCudaSlice<u32>,
        decision_base_limit: &TrackedCudaSlice<u32>,
        decision_extra_base: &TrackedCudaSlice<u32>,
        decision_extra_count: &TrackedCudaSlice<u32>,
    ) -> Result<GpuCdclRun> {
        let num_vars_cap = cnf.var_cap as usize;
        let num_clauses_cap = cnf.clause_cap as usize;

        // Every buffer passed to the kernel must come from this solver's memory manager.
        cnf.require_provider_memory(&self.provider, "GPU CDCL solver provider boundary")?;
        self.require_slice_on_provider("compile_needed", compile_needed)?;
        self.require_slice_on_provider("decision_base_limit", decision_base_limit)?;
        self.require_slice_on_provider("decision_extra_base", decision_extra_base)?;
        self.require_slice_on_provider("decision_extra_count", decision_extra_count)?;
        if compile_needed.len() != 1 {
            return Err(XlogError::Compilation(format!(
                "GpuCdclSolver requires compile_needed len=1, got {}",
                compile_needed.len()
            )));
        }

        if cnf.var_cap == 0 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver requires num_vars > 0".to_string(),
            ));
        }
        if decision_base_limit.len() != 1 {
            return Err(XlogError::Compilation(format!(
                "GpuCdclSolver requires decision_base_limit len=1, got {}",
                decision_base_limit.len()
            )));
        }
        if decision_extra_base.len() != 1 {
            return Err(XlogError::Compilation(format!(
                "GpuCdclSolver requires decision_extra_base len=1, got {}",
                decision_extra_base.len()
            )));
        }
        if decision_extra_count.len() != 1 {
            return Err(XlogError::Compilation(format!(
                "GpuCdclSolver requires decision_extra_count len=1, got {}",
                decision_extra_count.len()
            )));
        }
        if self.config.max_learned_clauses == 0 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver requires max_learned_clauses > 0".to_string(),
            ));
        }
        if self.config.max_learned_lits == 0 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver requires max_learned_lits > 0".to_string(),
            ));
        }
        if self.config.max_proof_u32 < 2 {
            return Err(XlogError::Compilation(
                "GpuCdclSolver requires max_proof_u32 >= 2".to_string(),
            ));
        }

        let max_learned_clauses = self.config.max_learned_clauses as usize;
        let max_learned_lits = self.config.max_learned_lits as usize;
        let max_proof_u32 = self.config.max_proof_u32 as usize;

        // Derived buffer lengths, with overflow reported as an error.
        let max_total_clauses = num_clauses_cap
            .checked_add(max_learned_clauses)
            .ok_or_else(|| XlogError::Kernel("SAT clause capacity overflow".to_string()))?;
        let vars_plus_one = checked_solver_len_add_one("SAT variable arena", num_vars_cap)?;
        let learned_offsets_len =
            checked_solver_len_add_one("SAT learned offsets", max_learned_clauses)?;
        let watch_head_len = checked_solver_len_double("SAT watch head", num_vars_cap)?;
        let watch_clause_len = checked_solver_len_double("SAT watch clauses", max_total_clauses)?;

        let memory = self.provider.memory();

        // Variable state
        let assign = memory.alloc::<i8>(vars_plus_one)?;
        let level = memory.alloc::<u32>(vars_plus_one)?;
        let reason = memory.alloc::<i32>(vars_plus_one)?;
        let var_activity = memory.alloc::<u32>(vars_plus_one)?;
        let var_phase = memory.alloc::<i8>(vars_plus_one)?;
        let decision_heap = memory.alloc::<u32>(vars_plus_one)?;
        let decision_heap_pos = memory.alloc::<u32>(vars_plus_one)?;

        // Trail / levels
        let trail = memory.alloc::<i32>(vars_plus_one)?;
        let trail_lim = memory.alloc::<u32>(vars_plus_one)?;

        // Analysis scratch
        let seen = memory.alloc::<u8>(vars_plus_one)?;
        let learnt_tmp = memory.alloc::<i32>(vars_plus_one)?;
        let proof_vars_tmp = memory.alloc::<u32>(vars_plus_one)?;
        let proof_reason_tmp = memory.alloc::<u32>(vars_plus_one)?;

        // Watched literals
        let watch0_pos = memory.alloc::<u32>(max_total_clauses)?;
        let watch1_pos = memory.alloc::<u32>(max_total_clauses)?;
        let watch_head = memory.alloc::<i32>(watch_head_len)?;
        let watch_next = memory.alloc::<i32>(watch_clause_len)?;
        let watch_prev = memory.alloc::<i32>(watch_clause_len)?;

        // Learned clause arena
        let learned_offsets = memory.alloc::<u32>(learned_offsets_len)?;
        let learned_lits = memory.alloc::<i32>(max_learned_lits)?;
        let learned_deleted = memory.alloc::<u8>(max_learned_clauses)?;
        let learned_lbd = memory.alloc::<u32>(max_learned_clauses)?;
        let learned_activity = memory.alloc::<u32>(max_learned_clauses)?;
        let learned_locked = memory.alloc::<u8>(max_learned_clauses)?;

        // Proof trace arena
        let proof_offsets = memory.alloc::<u32>(learned_offsets_len)?;
        let proof_data = memory.alloc::<u32>(max_proof_u32)?;

        // Device-resident outputs
        let out_status = memory.alloc::<i32>(1)?;
        let out_error = memory.alloc::<i32>(1)?;
        let mut out_learned_count = memory.alloc::<u32>(1)?;
        // This path never imports learned clauses, so the count starts at zero.
        self.provider
            .htod_launch_metadata_sync_copy_into(&[0u32], &mut out_learned_count)
            .map_err(|e| {
                XlogError::Kernel(format!("Failed to init learned import count: {}", e))
            })?;

        let sat_fn = self
            .provider
            .device()
            .inner()
            .get_func(SAT_MODULE, sat_kernels::SAT_CDCL_SOLVE)
            .ok_or_else(|| XlogError::Kernel("sat_cdcl_solve kernel not found".to_string()))?;

        // IMPORTANT: When launching with an explicit `Vec<*mut c_void>` parameter list, scalar
        // kernel arguments MUST be backed by stable host storage until `cuLaunchKernel` copies
        // them. Do not pass temporaries like `self.config.restart_base.as_kernel_param()`.
        let cnf_var_cap = cnf.var_cap;
        let cnf_clause_cap = cnf.clause_cap;
        let cfg_max_learned_clauses = self.config.max_learned_clauses;
        let cfg_max_learned_lits = self.config.max_learned_lits;
        let cfg_max_proof_u32 = self.config.max_proof_u32;
        let cfg_restart_base = self.config.restart_base;
        let cfg_reduce_interval = self.config.reduce_interval;
        let cfg_max_conflicts = self.config.max_conflicts;
        // The same buffer is passed twice: here as the import count and again
        // as the final `out_learned_count` argument.
        let learned_import_count_param = (&out_learned_count).as_kernel_param();

        // Order must match the kernel's parameter list exactly.
        let mut params: Vec<*mut c_void> = vec![
            compile_needed.as_kernel_param(),
            (&cnf.clause_offsets).as_kernel_param(),
            (&cnf.literals).as_kernel_param(),
            (&cnf.num_vars).as_kernel_param(),
            (&cnf.num_clauses).as_kernel_param(),
            decision_base_limit.as_kernel_param(),
            decision_extra_base.as_kernel_param(),
            decision_extra_count.as_kernel_param(),
            cnf_var_cap.as_kernel_param(),
            cnf_clause_cap.as_kernel_param(),
            cfg_max_learned_clauses.as_kernel_param(),
            cfg_max_learned_lits.as_kernel_param(),
            cfg_max_proof_u32.as_kernel_param(),
            cfg_restart_base.as_kernel_param(),
            cfg_reduce_interval.as_kernel_param(),
            cfg_max_conflicts.as_kernel_param(),
            learned_import_count_param,
            (&assign).as_kernel_param(),
            (&level).as_kernel_param(),
            (&reason).as_kernel_param(),
            (&var_activity).as_kernel_param(),
            (&var_phase).as_kernel_param(),
            (&decision_heap).as_kernel_param(),
            (&decision_heap_pos).as_kernel_param(),
            (&trail).as_kernel_param(),
            (&trail_lim).as_kernel_param(),
            (&seen).as_kernel_param(),
            (&learnt_tmp).as_kernel_param(),
            (&proof_vars_tmp).as_kernel_param(),
            (&proof_reason_tmp).as_kernel_param(),
            (&watch0_pos).as_kernel_param(),
            (&watch1_pos).as_kernel_param(),
            (&watch_head).as_kernel_param(),
            (&watch_next).as_kernel_param(),
            (&watch_prev).as_kernel_param(),
            (&learned_offsets).as_kernel_param(),
            (&learned_lits).as_kernel_param(),
            (&learned_deleted).as_kernel_param(),
            (&learned_lbd).as_kernel_param(),
            (&learned_activity).as_kernel_param(),
            (&learned_locked).as_kernel_param(),
            (&proof_offsets).as_kernel_param(),
            (&proof_data).as_kernel_param(),
            (&out_status).as_kernel_param(),
            (&out_error).as_kernel_param(),
            (&out_learned_count).as_kernel_param(),
        ];

        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            sat_fn.clone().launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    // One block per SAT instance; use a full block so sat_cdcl_solve can do
                    // block-parallel propagation and initialization.
                    block_dim: (256, 1, 1),
                    shared_mem_bytes: 0,
                },
                &mut params,
            )
        }
        .map_err(|e| XlogError::Kernel(format!("Failed to launch SAT solver kernel: {}", e)))?;

        // Only these buffers are kept; all other locals are dropped on return.
        Ok(GpuCdclRun {
            assignment: assign,
            decision_heap,
            decision_heap_pos,
            learned_offsets,
            learned_lits,
            proof_offsets,
            proof_data,
            out_status,
            out_error,
            out_learned_count,
        })
    }
652
653    /// Launch CDCL using pre-allocated workspace buffers.
654    ///
655    /// Like `launch_cdcl_with_decision_ranges_gated` but uses `ws` buffers instead of
656    /// allocating per call. Returns `Result<()>` — the caller reads `ws.out_*` directly.
657    #[allow(clippy::too_many_arguments)]
658    fn launch_cdcl_with_workspace_gated(
659        &self,
660        ws: &mut GpuCdclWorkspace,
661        cnf: &GpuCnf,
662        compile_needed: &TrackedCudaSlice<u32>,
663        decision_base_limit: &TrackedCudaSlice<u32>,
664        decision_extra_base: &TrackedCudaSlice<u32>,
665        decision_extra_count: &TrackedCudaSlice<u32>,
666        import_existing_learned: bool,
667    ) -> Result<()> {
668        cnf.require_provider_memory(&self.provider, "GPU CDCL solver provider boundary")?;
669        self.require_workspace_on_provider(ws)?;
670        self.require_slice_on_provider("compile_needed", compile_needed)?;
671        self.require_slice_on_provider("decision_base_limit", decision_base_limit)?;
672        self.require_slice_on_provider("decision_extra_base", decision_extra_base)?;
673        self.require_slice_on_provider("decision_extra_count", decision_extra_count)?;
674        if compile_needed.len() != 1 {
675            return Err(XlogError::Compilation(format!(
676                "GpuCdclSolver requires compile_needed len=1, got {}",
677                compile_needed.len()
678            )));
679        }
680
681        self.require_workspace_capacity_for_cnf(ws, cnf.var_cap, cnf.clause_cap)?;
682
683        // Replicate all validation checks from the existing launch method.
684        if cnf.var_cap == 0 {
685            return Err(XlogError::Compilation(
686                "GpuCdclSolver requires num_vars > 0".to_string(),
687            ));
688        }
689        if decision_base_limit.len() != 1 {
690            return Err(XlogError::Compilation(format!(
691                "GpuCdclSolver requires decision_base_limit len=1, got {}",
692                decision_base_limit.len()
693            )));
694        }
695        if decision_extra_base.len() != 1 {
696            return Err(XlogError::Compilation(format!(
697                "GpuCdclSolver requires decision_extra_base len=1, got {}",
698                decision_extra_base.len()
699            )));
700        }
701        if decision_extra_count.len() != 1 {
702            return Err(XlogError::Compilation(format!(
703                "GpuCdclSolver requires decision_extra_count len=1, got {}",
704                decision_extra_count.len()
705            )));
706        }
707        if self.config.max_learned_clauses == 0 {
708            return Err(XlogError::Compilation(
709                "GpuCdclSolver requires max_learned_clauses > 0".to_string(),
710            ));
711        }
712        if self.config.max_learned_lits == 0 {
713            return Err(XlogError::Compilation(
714                "GpuCdclSolver requires max_learned_lits > 0".to_string(),
715            ));
716        }
717        if self.config.max_proof_u32 < 2 {
718            return Err(XlogError::Compilation(
719                "GpuCdclSolver requires max_proof_u32 >= 2".to_string(),
720            ));
721        }
722
723        // No-op: the sat_cdcl_solve kernel initializes all mutable state at launch.
724        ws.reset_for_solve();
725        if !import_existing_learned {
726            self.provider
727                .htod_launch_metadata_sync_copy_into(&[0u32], &mut ws.out_learned_count)
728                .map_err(|e| {
729                    XlogError::Kernel(format!("Failed to init learned import count: {}", e))
730                })?;
731        }
732
733        let sat_fn = self
734            .provider
735            .device()
736            .inner()
737            .get_func(SAT_MODULE, sat_kernels::SAT_CDCL_SOLVE)
738            .ok_or_else(|| XlogError::Kernel("sat_cdcl_solve kernel not found".to_string()))?;
739
740        // Scalar kernel arguments must be backed by stable host storage until cuLaunchKernel
741        // copies them.
742        let cnf_var_cap = cnf.var_cap;
743        let cnf_clause_cap = cnf.clause_cap;
744        let cfg_max_learned_clauses = self.config.max_learned_clauses;
745        let cfg_max_learned_lits = self.config.max_learned_lits;
746        let cfg_max_proof_u32 = self.config.max_proof_u32;
747        let cfg_restart_base = self.config.restart_base;
748        let cfg_reduce_interval = self.config.reduce_interval;
749        let cfg_max_conflicts = self.config.max_conflicts;
750        let learned_import_count_param = (&ws.out_learned_count).as_kernel_param();
751
752        // Parameter order MUST match launch_cdcl_with_decision_ranges_gated exactly.
753        let mut params: Vec<*mut c_void> = vec![
754            compile_needed.as_kernel_param(),
755            (&cnf.clause_offsets).as_kernel_param(),
756            (&cnf.literals).as_kernel_param(),
757            (&cnf.num_vars).as_kernel_param(),
758            (&cnf.num_clauses).as_kernel_param(),
759            decision_base_limit.as_kernel_param(),
760            decision_extra_base.as_kernel_param(),
761            decision_extra_count.as_kernel_param(),
762            cnf_var_cap.as_kernel_param(),
763            cnf_clause_cap.as_kernel_param(),
764            cfg_max_learned_clauses.as_kernel_param(),
765            cfg_max_learned_lits.as_kernel_param(),
766            cfg_max_proof_u32.as_kernel_param(),
767            cfg_restart_base.as_kernel_param(),
768            cfg_reduce_interval.as_kernel_param(),
769            cfg_max_conflicts.as_kernel_param(),
770            learned_import_count_param,
771            (&ws.assign).as_kernel_param(),
772            (&ws.level).as_kernel_param(),
773            (&ws.reason).as_kernel_param(),
774            (&ws.var_activity).as_kernel_param(),
775            (&ws.var_phase).as_kernel_param(),
776            (&ws.decision_heap).as_kernel_param(),
777            (&ws.decision_heap_pos).as_kernel_param(),
778            (&ws.trail).as_kernel_param(),
779            (&ws.trail_lim).as_kernel_param(),
780            (&ws.seen).as_kernel_param(),
781            (&ws.learnt_tmp).as_kernel_param(),
782            (&ws.proof_vars_tmp).as_kernel_param(),
783            (&ws.proof_reason_tmp).as_kernel_param(),
784            (&ws.watch0_pos).as_kernel_param(),
785            (&ws.watch1_pos).as_kernel_param(),
786            (&ws.watch_head).as_kernel_param(),
787            (&ws.watch_next).as_kernel_param(),
788            (&ws.watch_prev).as_kernel_param(),
789            (&ws.learned_offsets).as_kernel_param(),
790            (&ws.learned_lits).as_kernel_param(),
791            (&ws.learned_deleted).as_kernel_param(),
792            (&ws.learned_lbd).as_kernel_param(),
793            (&ws.learned_activity).as_kernel_param(),
794            (&ws.learned_locked).as_kernel_param(),
795            (&ws.proof_offsets).as_kernel_param(),
796            (&ws.proof_data).as_kernel_param(),
797            (&ws.out_status).as_kernel_param(),
798            (&ws.out_error).as_kernel_param(),
799            (&ws.out_learned_count).as_kernel_param(),
800        ];
801
802        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
803        unsafe {
804            sat_fn.clone().launch(
805                LaunchConfig {
806                    grid_dim: (1, 1, 1),
807                    block_dim: (256, 1, 1),
808                    shared_mem_bytes: 0,
809                },
810                &mut params,
811            )
812        }
813        .map_err(|e| XlogError::Kernel(format!("Failed to launch SAT solver kernel: {}", e)))?;
814
815        Ok(())
816    }
817
818    /// Launch CDCL and return raw device outputs without enforcing SAT/UNSAT on device.
819    ///
820    /// This is intentionally **not** used in production verifier paths. It exists so tests and
821    /// debugging tools can inspect `out_status/out_error` without modifying kernel behavior.
822    pub fn solve_raw_with_branch_limit(
823        &self,
824        cnf: &GpuCnf,
825        branch_var_limit: &TrackedCudaSlice<u32>,
826    ) -> Result<GpuCdclRawOutput> {
827        let compile_needed = self.alloc_u32_scalar(1)?;
828        self.solve_raw_with_branch_limit_gated(cnf, &compile_needed, branch_var_limit)
829    }
830
831    /// Gated variant of `solve_raw_with_branch_limit`.
832    pub fn solve_raw_with_branch_limit_gated(
833        &self,
834        cnf: &GpuCnf,
835        compile_needed: &TrackedCudaSlice<u32>,
836        branch_var_limit: &TrackedCudaSlice<u32>,
837    ) -> Result<GpuCdclRawOutput> {
838        let zero = self.alloc_u32_scalar(0)?;
839        let run = self.launch_cdcl_with_decision_ranges_gated(
840            cnf,
841            compile_needed,
842            branch_var_limit,
843            &zero,
844            &zero,
845        )?;
846        // Ensure kernel completion so `out_*` are valid for inspection.
847        self.provider.device().synchronize()?;
848
849        let GpuCdclRun {
850            assignment,
851            out_status,
852            out_error,
853            out_learned_count,
854            ..
855        } = run;
856
857        Ok(GpuCdclRawOutput {
858            assignment,
859            out_status,
860            out_error,
861            out_learned_count,
862        })
863    }
864
865    /// Launch CDCL and return raw device outputs without enforcing SAT/UNSAT on device, using an
866    /// explicit decision variable set:
867    /// - decision vars include all `v` in `1..=decision_base_limit[0]` and
868    ///   `decision_extra_base[0]..(decision_extra_base[0] + decision_extra_count[0] - 1)`.
869    ///
870    /// Production verifier paths should prefer `solve_expect_*` methods which enforce results on
871    /// GPU without host reads.
872    pub fn solve_raw_with_decision_ranges(
873        &self,
874        cnf: &GpuCnf,
875        decision_base_limit: &TrackedCudaSlice<u32>,
876        decision_extra_base: &TrackedCudaSlice<u32>,
877        decision_extra_count: &TrackedCudaSlice<u32>,
878    ) -> Result<GpuCdclRawOutput> {
879        let compile_needed = self.alloc_u32_scalar(1)?;
880        self.solve_raw_with_decision_ranges_gated(
881            cnf,
882            &compile_needed,
883            decision_base_limit,
884            decision_extra_base,
885            decision_extra_count,
886        )
887    }
888
889    /// Gated variant of `solve_raw_with_decision_ranges`.
890    pub fn solve_raw_with_decision_ranges_gated(
891        &self,
892        cnf: &GpuCnf,
893        compile_needed: &TrackedCudaSlice<u32>,
894        decision_base_limit: &TrackedCudaSlice<u32>,
895        decision_extra_base: &TrackedCudaSlice<u32>,
896        decision_extra_count: &TrackedCudaSlice<u32>,
897    ) -> Result<GpuCdclRawOutput> {
898        let run = self.launch_cdcl_with_decision_ranges_gated(
899            cnf,
900            compile_needed,
901            decision_base_limit,
902            decision_extra_base,
903            decision_extra_count,
904        )?;
905        // Ensure kernel completion so `out_*` are valid for inspection.
906        self.provider.device().synchronize()?;
907
908        let GpuCdclRun {
909            assignment,
910            out_status,
911            out_error,
912            out_learned_count,
913            ..
914        } = run;
915
916        Ok(GpuCdclRawOutput {
917            assignment,
918            out_status,
919            out_error,
920            out_learned_count,
921        })
922    }
923
924    /// Solve and enforce SAT entirely on GPU (no device->host reads).
925    pub fn solve_expect_sat(&self, cnf: &GpuCnf) -> Result<TrackedCudaSlice<i8>> {
926        let compile_needed = self.alloc_u32_scalar(1)?;
927        self.solve_expect_sat_gated(cnf, &compile_needed)
928    }
929
930    /// Solve and enforce SAT entirely on GPU (no device->host reads),
931    /// skipping all GPU work if `compile_needed` is 0.
932    pub fn solve_expect_sat_gated(
933        &self,
934        cnf: &GpuCnf,
935        compile_needed: &TrackedCudaSlice<u32>,
936    ) -> Result<TrackedCudaSlice<i8>> {
937        self.solve_expect_sat_with_branch_limit_gated(cnf, compile_needed, &cnf.num_vars)
938    }
939
940    /// Solve and enforce SAT entirely on GPU (no device->host reads), using an explicit decision
941    /// variable set:
942    /// - decision vars include all `v` in `1..=decision_base_limit[0]` and
943    ///   `decision_extra_base[0]..(decision_extra_base[0] + decision_extra_count[0] - 1)`.
944    pub fn solve_expect_sat_with_decision_ranges(
945        &self,
946        cnf: &GpuCnf,
947        decision_base_limit: &TrackedCudaSlice<u32>,
948        decision_extra_base: &TrackedCudaSlice<u32>,
949        decision_extra_count: &TrackedCudaSlice<u32>,
950    ) -> Result<TrackedCudaSlice<i8>> {
951        let compile_needed = self.alloc_u32_scalar(1)?;
952        self.solve_expect_sat_with_decision_ranges_gated(
953            cnf,
954            &compile_needed,
955            decision_base_limit,
956            decision_extra_base,
957            decision_extra_count,
958        )
959    }
960
    /// Gated variant of `solve_expect_sat_with_decision_ranges`.
    ///
    /// Steps, in the order they are issued:
    /// 1. Launch the solver via `launch_cdcl_with_decision_ranges_gated`.
    /// 2. Call `require_expected_status` with `SAT_STATUS_SAT`, launch the
    ///    `sat_assert_status` kernel with the same expected status, and synchronize.
    /// 3. Launch `sat_check_model` over the CNF and the produced assignment, launch
    ///    `sat_assert_ok` on its result flag, and synchronize again.
    ///
    /// `compile_needed` is forwarded to the solver launch and is the first argument of
    /// every kernel launched here.
    ///
    /// In debug builds, setting the `XLOG_CDCL_TRACE` environment variable prints
    /// timing lines to stderr.
    ///
    /// # Errors
    ///
    /// Returns `XlogError::Kernel` when a kernel cannot be found in `SAT_MODULE` or fails
    /// to launch. Propagates errors from the solver launch, `require_expected_status`,
    /// the `out_ok` allocation, and device synchronization.
    ///
    /// # Returns
    ///
    /// The assignment buffer of the solver run.
    pub fn solve_expect_sat_with_decision_ranges_gated(
        &self,
        cnf: &GpuCnf,
        compile_needed: &TrackedCudaSlice<u32>,
        decision_base_limit: &TrackedCudaSlice<u32>,
        decision_extra_base: &TrackedCudaSlice<u32>,
        decision_extra_count: &TrackedCudaSlice<u32>,
    ) -> Result<TrackedCudaSlice<i8>> {
        // Debug-only tracing: enabled by the mere presence of XLOG_CDCL_TRACE.
        #[cfg(debug_assertions)]
        let trace = std::env::var_os("XLOG_CDCL_TRACE").is_some();
        #[cfg(debug_assertions)]
        let t0 = std::time::Instant::now();

        let run = self.launch_cdcl_with_decision_ranges_gated(
            cnf,
            compile_needed,
            decision_base_limit,
            decision_extra_base,
            decision_extra_count,
        )?;
        // NOTE(review): `require_expected_status` is defined outside this view; confirm
        // how it inspects `out_status`/`out_error` relative to the kernel launch above.
        self.require_expected_status(
            &run.out_status,
            &run.out_error,
            SAT_STATUS_SAT,
            "GPU CDCL SAT expectation",
        )?;

        let device = self.provider.device().inner();
        let memory = self.provider.memory();

        let assert_status_fn = device
            .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_STATUS)
            .ok_or_else(|| XlogError::Kernel("sat_assert_status kernel not found".to_string()))?;
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            assert_status_fn
                .clone()
                .launch(
                    // Single-thread launch.
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (1, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        compile_needed,
                        &run.out_status,
                        &run.out_error,
                        SAT_STATUS_SAT,
                    ),
                )
                .map_err(|e| {
                    XlogError::Kernel(format!("Failed to launch sat_assert_status: {}", e))
                })?;
        }
        // Fail-fast if the solver did not produce SAT.
        self.provider.device().synchronize()?;
        #[cfg(debug_assertions)]
        if trace {
            eprintln!("[xlog-solve] cdcl(sat) time: {:?}", t0.elapsed());
        }

        // Result flag written by sat_check_model and consumed by sat_assert_ok.
        // NOTE(review): unlike the UNSAT path, which seeds its `out_ok` with 1 before the
        // proof check, this buffer is not explicitly initialised here. Confirm that
        // `memory.alloc` zero-fills or that `sat_check_model` always writes it.
        let mut out_ok = memory.alloc::<i32>(1)?;
        let check_fn = device
            .get_func(SAT_MODULE, sat_kernels::SAT_CHECK_MODEL)
            .ok_or_else(|| XlogError::Kernel("sat_check_model kernel not found".to_string()))?;
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            check_fn
                .clone()
                .launch(
                    // One block of 256 threads.
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (256, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        compile_needed,
                        &cnf.clause_offsets,
                        &cnf.literals,
                        &cnf.num_clauses,
                        &run.assignment,
                        &mut out_ok,
                    ),
                )
                .map_err(|e| {
                    XlogError::Kernel(format!("Failed to launch SAT model check: {}", e))
                })?;
        }

        let assert_ok_fn = device
            .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_OK)
            .ok_or_else(|| XlogError::Kernel("sat_assert_ok kernel not found".to_string()))?;
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            assert_ok_fn
                .clone()
                .launch(
                    // Single-thread launch.
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (1, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (compile_needed, &out_ok),
                )
                .map_err(|e| XlogError::Kernel(format!("Failed to launch sat_assert_ok: {}", e)))?;
        }
        // Wait for the model check and its assertion before handing out the assignment.
        self.provider.device().synchronize()?;
        #[cfg(debug_assertions)]
        if trace {
            eprintln!(
                "[xlog-solve] cdcl(sat)+model_check time: {:?}",
                t0.elapsed()
            );
        }

        Ok(run.assignment)
    }
1079
1080    /// Solve and enforce SAT entirely on GPU (no device->host reads),
1081    /// restricting branching to variables in `1..=branch_var_limit[0]`.
1082    pub fn solve_expect_sat_with_branch_limit(
1083        &self,
1084        cnf: &GpuCnf,
1085        branch_var_limit: &TrackedCudaSlice<u32>,
1086    ) -> Result<TrackedCudaSlice<i8>> {
1087        let compile_needed = self.alloc_u32_scalar(1)?;
1088        self.solve_expect_sat_with_branch_limit_gated(cnf, &compile_needed, branch_var_limit)
1089    }
1090
1091    /// Solve and enforce SAT entirely on GPU (no device->host reads),
1092    /// skipping all GPU work if `compile_needed` is 0, and restricting branching to
1093    /// variables in `1..=branch_var_limit[0]`.
1094    pub fn solve_expect_sat_with_branch_limit_gated(
1095        &self,
1096        cnf: &GpuCnf,
1097        compile_needed: &TrackedCudaSlice<u32>,
1098        branch_var_limit: &TrackedCudaSlice<u32>,
1099    ) -> Result<TrackedCudaSlice<i8>> {
1100        let zero = self.alloc_u32_scalar(0)?;
1101        self.solve_expect_sat_with_decision_ranges_gated(
1102            cnf,
1103            compile_needed,
1104            branch_var_limit,
1105            &zero,
1106            &zero,
1107        )
1108    }
1109
1110    /// Solve and enforce UNSAT entirely on GPU (no device->host reads).
1111    pub fn solve_expect_unsat(&self, cnf: &GpuCnf) -> Result<()> {
1112        let compile_needed = self.alloc_u32_scalar(1)?;
1113        self.solve_expect_unsat_gated(cnf, &compile_needed)
1114    }
1115
1116    /// Solve and enforce UNSAT entirely on GPU (no device->host reads),
1117    /// skipping all GPU work if `compile_needed` is 0.
1118    pub fn solve_expect_unsat_gated(
1119        &self,
1120        cnf: &GpuCnf,
1121        compile_needed: &TrackedCudaSlice<u32>,
1122    ) -> Result<()> {
1123        self.solve_expect_unsat_with_branch_limit_gated(cnf, compile_needed, &cnf.num_vars)
1124    }
1125
1126    /// Solve and enforce UNSAT entirely on GPU (no device->host reads),
1127    /// restricting branching to variables in `1..=branch_var_limit[0]`.
1128    pub fn solve_expect_unsat_with_branch_limit(
1129        &self,
1130        cnf: &GpuCnf,
1131        branch_var_limit: &TrackedCudaSlice<u32>,
1132    ) -> Result<()> {
1133        let compile_needed = self.alloc_u32_scalar(1)?;
1134        self.solve_expect_unsat_with_branch_limit_gated(cnf, &compile_needed, branch_var_limit)
1135    }
1136
1137    /// Solve and enforce UNSAT with GPU proof verification, using an explicit decision variable set:
1138    /// - decision vars include all `v` in `1..=decision_base_limit[0]` and
1139    ///   `decision_extra_base[0]..(decision_extra_base[0] + decision_extra_count[0] - 1)`.
1140    pub fn solve_expect_unsat_with_decision_ranges(
1141        &self,
1142        cnf: &GpuCnf,
1143        decision_base_limit: &TrackedCudaSlice<u32>,
1144        decision_extra_base: &TrackedCudaSlice<u32>,
1145        decision_extra_count: &TrackedCudaSlice<u32>,
1146    ) -> Result<()> {
1147        let compile_needed = self.alloc_u32_scalar(1)?;
1148        self.solve_expect_unsat_with_decision_ranges_gated(
1149            cnf,
1150            &compile_needed,
1151            decision_base_limit,
1152            decision_extra_base,
1153            decision_extra_count,
1154        )
1155    }
1156
1157    /// Solve and enforce UNSAT entirely on GPU (no device->host reads),
1158    /// skipping all GPU work if `compile_needed` is 0, and restricting branching to
1159    /// variables in `1..=branch_var_limit[0]`.
1160    pub fn solve_expect_unsat_with_branch_limit_gated(
1161        &self,
1162        cnf: &GpuCnf,
1163        compile_needed: &TrackedCudaSlice<u32>,
1164        branch_var_limit: &TrackedCudaSlice<u32>,
1165    ) -> Result<()> {
1166        let zero = self.alloc_u32_scalar(0)?;
1167        self.solve_expect_unsat_with_decision_ranges_gated(
1168            cnf,
1169            compile_needed,
1170            branch_var_limit,
1171            &zero,
1172            &zero,
1173        )
1174    }
1175
    /// Solve and enforce UNSAT entirely on GPU (no device->host reads), using an explicit decision
    /// variable set:
    /// - decision vars include all `v` in `1..=decision_base_limit[0]` and
    ///   `decision_extra_base[0]..(decision_extra_base[0] + decision_extra_count[0] - 1)`.
    ///
    /// Steps, in the order they are issued:
    /// 1. Launch the solver via `launch_cdcl_with_decision_ranges_gated`.
    /// 2. Call `require_expected_status` with `SAT_STATUS_UNSAT`, launch the
    ///    `sat_assert_status` kernel with the same expected status, and synchronize.
    /// 3. Allocate the proof-check result flag (seeded with 1), the per-block scratch
    ///    buffers, and the zeroed `needed` mask.
    /// 4. Launch `sat_proof_mark_needed` and synchronize.
    /// 5. Launch `sat_proof_check` across `proof_blocks` blocks, then `sat_assert_ok`
    ///    on the result flag, and synchronize.
    ///
    /// `compile_needed` is forwarded to the solver launch and is the first argument of
    /// every kernel launched here.
    ///
    /// In debug builds, setting the `XLOG_CDCL_TRACE` environment variable prints
    /// timing lines to stderr.
    ///
    /// # Errors
    ///
    /// Returns `XlogError::Kernel` when a kernel cannot be found in `SAT_MODULE`, a
    /// launch, memset or host-to-device copy fails, or the scratch capacity overflows.
    /// If no scratch configuration can be allocated, the last allocation error is
    /// returned. Also propagates errors from the solver launch,
    /// `require_expected_status`, other allocations, and device synchronization.
    pub fn solve_expect_unsat_with_decision_ranges_gated(
        &self,
        cnf: &GpuCnf,
        compile_needed: &TrackedCudaSlice<u32>,
        decision_base_limit: &TrackedCudaSlice<u32>,
        decision_extra_base: &TrackedCudaSlice<u32>,
        decision_extra_count: &TrackedCudaSlice<u32>,
    ) -> Result<()> {
        // Debug-only tracing: enabled by the mere presence of XLOG_CDCL_TRACE.
        #[cfg(debug_assertions)]
        let trace = std::env::var_os("XLOG_CDCL_TRACE").is_some();
        #[cfg(debug_assertions)]
        let t0 = std::time::Instant::now();

        let run = self.launch_cdcl_with_decision_ranges_gated(
            cnf,
            compile_needed,
            decision_base_limit,
            decision_extra_base,
            decision_extra_count,
        )?;
        // NOTE(review): `require_expected_status` is defined outside this view; confirm
        // how it inspects `out_status`/`out_error` relative to the kernel launch above.
        self.require_expected_status(
            &run.out_status,
            &run.out_error,
            SAT_STATUS_UNSAT,
            "GPU CDCL UNSAT expectation",
        )?;

        let device = self.provider.device().inner();
        let memory = self.provider.memory();

        let assert_status_fn = device
            .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_STATUS)
            .ok_or_else(|| XlogError::Kernel("sat_assert_status kernel not found".to_string()))?;
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            assert_status_fn
                .clone()
                .launch(
                    // Single-thread launch.
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (1, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        compile_needed,
                        &run.out_status,
                        &run.out_error,
                        SAT_STATUS_UNSAT,
                    ),
                )
                .map_err(|e| {
                    XlogError::Kernel(format!("Failed to launch sat_assert_status: {}", e))
                })?;
        }
        // Fail-fast if the solver did not produce UNSAT.
        self.provider.device().synchronize()?;
        #[cfg(debug_assertions)]
        if trace {
            eprintln!("[xlog-solve] cdcl(unsat) time: {:?}", t0.elapsed());
        }

        // Proof-check result flag, seeded with 1 from the host before the check runs.
        let mut out_ok = memory.alloc::<i32>(1)?;
        self.provider
            .htod_launch_metadata_sync_copy_into(&[1i32], &mut out_ok)
            .map_err(|e| XlogError::Kernel(format!("Failed to init proof out_ok: {}", e)))?;

        // sat_proof_check uses scratch buffers sized to `scratch_cap` per verifier block. To keep
        // proof checking fast on large instances, allocate multiple scratch regions and verify
        // learned clauses in parallel across blocks.
        let scratch_cap_u32 = cnf
            .var_cap
            .checked_add(1)
            .ok_or_else(|| XlogError::Kernel("Proof scratch capacity overflow".to_string()))?;
        let scratch_cap = scratch_cap_u32 as usize;

        // Try block counts from largest to smallest; the first count for which all three
        // scratch buffers (each `scratch_cap * blocks` elements) can be allocated wins.
        // The most recent failure is remembered so it can be reported if none succeeds.
        let mut proof_blocks: usize = 1;
        let mut scratch_a = None;
        let mut scratch_b = None;
        let mut scratch_map = None;
        let mut last_alloc_err: Option<XlogError> = None;
        for blocks in [512usize, 256, 128, 64, 32, 16, 8, 4, 2, 1] {
            let len = match scratch_cap.checked_mul(blocks) {
                Some(v) => v,
                None => {
                    // Length does not fit in usize for this block count; try a smaller one.
                    last_alloc_err = Some(XlogError::Kernel(
                        "Proof scratch allocation length overflow".to_string(),
                    ));
                    continue;
                }
            };

            let a = match memory.alloc::<i32>(len) {
                Ok(buf) => buf,
                Err(e) => {
                    last_alloc_err = Some(e);
                    continue;
                }
            };
            let b = match memory.alloc::<i32>(len) {
                Ok(buf) => buf,
                Err(e) => {
                    last_alloc_err = Some(e);
                    // Drop `a` before retrying with a smaller configuration.
                    drop(a);
                    continue;
                }
            };
            let m = match memory.alloc::<u32>(len) {
                Ok(buf) => buf,
                Err(e) => {
                    last_alloc_err = Some(e);
                    // Release both earlier buffers before retrying with a smaller configuration.
                    drop(a);
                    drop(b);
                    continue;
                }
            };

            proof_blocks = blocks;
            scratch_a = Some(a);
            scratch_b = Some(b);
            scratch_map = Some(m);
            break;
        }
        // The three buffers are set together, so a missing `scratch_a` means every
        // configuration failed; report the last recorded allocation error in that case.
        let scratch_a = scratch_a.ok_or_else(|| {
            last_alloc_err.unwrap_or_else(|| {
                XlogError::Kernel("Failed to allocate proof scratch buffers".to_string())
            })
        })?;
        let scratch_b = scratch_b
            .ok_or_else(|| XlogError::Kernel("Missing proof scratch buffer".to_string()))?;
        let mut scratch_map = scratch_map
            .ok_or_else(|| XlogError::Kernel("Missing proof scratch map buffer".to_string()))?;
        // Only the map buffer is zeroed here; `scratch_a`/`scratch_b` are passed as allocated.
        device
            .memset_zeros(&mut scratch_map)
            .map_err(|e| XlogError::Kernel(format!("Failed to zero proof scratch map: {}", e)))?;
        #[cfg(debug_assertions)]
        if trace {
            eprintln!("[xlog-solve] proof_check blocks: {}", proof_blocks);
        }
        #[cfg(debug_assertions)]
        let t_mark = std::time::Instant::now();

        // `needed` mask: one byte per learned-clause slot (`max_learned_clauses`), zeroed
        // before sat_proof_mark_needed runs.
        let needed_cap_u32 = self.config.max_learned_clauses;
        let needed_cap = needed_cap_u32 as usize;
        let mut needed = memory.alloc::<u8>(needed_cap)?;
        device
            .memset_zeros(&mut needed)
            .map_err(|e| XlogError::Kernel(format!("Failed to zero proof needed mask: {}", e)))?;

        let mark_needed_fn = device
            .get_func(SAT_MODULE, sat_kernels::SAT_PROOF_MARK_NEEDED)
            .ok_or_else(|| {
                XlogError::Kernel("sat_proof_mark_needed kernel not found".to_string())
            })?;
        // Raw parameter list: the order must match the kernel's signature.
        let mut mark_params: Vec<*mut c_void> = vec![
            compile_needed.as_kernel_param(),
            (&cnf.num_clauses).as_kernel_param(),
            (&run.out_learned_count).as_kernel_param(),
            (&run.proof_offsets).as_kernel_param(),
            (&run.proof_data).as_kernel_param(),
            needed_cap_u32.as_kernel_param(),
            (&needed).as_kernel_param(),
        ];
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            mark_needed_fn
                .clone()
                .launch(
                    // Single-thread launch.
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (1, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    &mut mark_params,
                )
                .map_err(|e| {
                    XlogError::Kernel(format!("Failed to launch sat_proof_mark_needed: {}", e))
                })?;
        }
        self.provider.device().synchronize()?;
        #[cfg(debug_assertions)]
        if trace {
            eprintln!(
                "[xlog-solve] proof_mark_needed time: {:?}",
                t_mark.elapsed()
            );
        }

        let proof_fn = device
            .get_func(SAT_MODULE, sat_kernels::SAT_PROOF_CHECK)
            .ok_or_else(|| XlogError::Kernel("sat_proof_check kernel not found".to_string()))?;
        #[cfg(debug_assertions)]
        let t_proof = std::time::Instant::now();
        // The grid dimension is a u32; convert the chosen block count with a checked cast.
        let proof_blocks_u32 = u32::try_from(proof_blocks)
            .map_err(|_| XlogError::Kernel("Proof check grid dim exceeds u32::MAX".to_string()))?;
        // Raw parameter list: the order must match the kernel's signature.
        let mut proof_params: Vec<*mut c_void> = vec![
            compile_needed.as_kernel_param(),
            (&cnf.clause_offsets).as_kernel_param(),
            (&cnf.literals).as_kernel_param(),
            (&cnf.num_clauses).as_kernel_param(),
            (&run.learned_offsets).as_kernel_param(),
            (&run.learned_lits).as_kernel_param(),
            (&run.out_learned_count).as_kernel_param(),
            (&run.proof_offsets).as_kernel_param(),
            (&run.proof_data).as_kernel_param(),
            (&needed).as_kernel_param(),
            needed_cap_u32.as_kernel_param(),
            (&scratch_a).as_kernel_param(),
            (&scratch_b).as_kernel_param(),
            (&scratch_map).as_kernel_param(),
            scratch_cap_u32.as_kernel_param(),
            (&out_ok).as_kernel_param(),
        ];
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            proof_fn
                .clone()
                .launch(
                    // One block per allocated scratch region, 128 threads each.
                    LaunchConfig {
                        grid_dim: (proof_blocks_u32, 1, 1),
                        block_dim: (128, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    &mut proof_params,
                )
                .map_err(|e| {
                    XlogError::Kernel(format!("Failed to launch SAT proof check: {}", e))
                })?;
        }

        let assert_ok_fn = device
            .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_OK)
            .ok_or_else(|| XlogError::Kernel("sat_assert_ok kernel not found".to_string()))?;
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            assert_ok_fn
                .clone()
                .launch(
                    // Single-thread launch.
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (1, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (compile_needed, &out_ok),
                )
                .map_err(|e| XlogError::Kernel(format!("Failed to launch sat_assert_ok: {}", e)))?;
        }
        // Wait for the proof check and its assertion before returning.
        self.provider.device().synchronize()?;
        #[cfg(debug_assertions)]
        if trace {
            eprintln!("[xlog-solve] proof_check time: {:?}", t_proof.elapsed());
            eprintln!(
                "[xlog-solve] cdcl(unsat)+proof_check time: {:?}",
                t0.elapsed()
            );
        }

        Ok(())
    }
1439
1440    // ── Workspace-reuse variants ──────────────────────────────────────────
1441
1442    /// Solve and enforce UNSAT entirely on GPU using a pre-allocated workspace,
1443    /// restricting branching to variables in `1..=branch_var_limit[0]`.
1444    pub fn solve_expect_unsat_with_branch_limit_ws(
1445        &self,
1446        ws: &mut GpuCdclWorkspace,
1447        cnf: &GpuCnf,
1448        branch_var_limit: &TrackedCudaSlice<u32>,
1449    ) -> Result<()> {
1450        let compile_needed = self.alloc_u32_scalar(1)?;
1451        self.solve_expect_unsat_with_branch_limit_gated_ws(
1452            ws,
1453            cnf,
1454            &compile_needed,
1455            branch_var_limit,
1456        )
1457    }
1458
1459    /// Gated workspace variant: solve and enforce UNSAT entirely on GPU,
1460    /// restricting branching to variables in `1..=branch_var_limit[0]`.
1461    /// Skips all GPU work if `compile_needed` is 0.
1462    pub fn solve_expect_unsat_with_branch_limit_gated_ws(
1463        &self,
1464        ws: &mut GpuCdclWorkspace,
1465        cnf: &GpuCnf,
1466        compile_needed: &TrackedCudaSlice<u32>,
1467        branch_var_limit: &TrackedCudaSlice<u32>,
1468    ) -> Result<()> {
1469        let zero = self.alloc_u32_scalar(0)?;
1470        self.solve_expect_unsat_with_decision_ranges_gated_ws(
1471            ws,
1472            cnf,
1473            compile_needed,
1474            branch_var_limit,
1475            &zero,
1476            &zero,
1477        )
1478    }
1479
1480    /// Solve and enforce UNSAT entirely on GPU using a pre-allocated workspace,
1481    /// with explicit decision variable ranges.
1482    pub fn solve_expect_unsat_with_decision_ranges_ws(
1483        &self,
1484        ws: &mut GpuCdclWorkspace,
1485        cnf: &GpuCnf,
1486        decision_base_limit: &TrackedCudaSlice<u32>,
1487        decision_extra_base: &TrackedCudaSlice<u32>,
1488        decision_extra_count: &TrackedCudaSlice<u32>,
1489    ) -> Result<()> {
1490        let compile_needed = self.alloc_u32_scalar(1)?;
1491        self.solve_expect_unsat_with_decision_ranges_gated_ws(
1492            ws,
1493            cnf,
1494            &compile_needed,
1495            decision_base_limit,
1496            decision_extra_base,
1497            decision_extra_count,
1498        )
1499    }
1500
1501    /// Gated workspace variant (LEAF): solve and enforce UNSAT entirely on GPU
1502    /// using a pre-allocated workspace, with explicit decision variable ranges.
1503    /// Skips all GPU work if `compile_needed` is 0.
1504    ///
1505    /// This is the leaf implementation for all `_ws` UNSAT methods. It:
1506    /// 1. Launches CDCL via `launch_cdcl_with_workspace_gated`
1507    /// 2. Asserts UNSAT status on GPU
1508    /// 3. Verifies the UNSAT proof on GPU (sat_proof_mark_needed + sat_proof_check + sat_assert_ok)
1509    pub fn solve_expect_unsat_with_decision_ranges_gated_ws(
1510        &self,
1511        ws: &mut GpuCdclWorkspace,
1512        cnf: &GpuCnf,
1513        compile_needed: &TrackedCudaSlice<u32>,
1514        decision_base_limit: &TrackedCudaSlice<u32>,
1515        decision_extra_base: &TrackedCudaSlice<u32>,
1516        decision_extra_count: &TrackedCudaSlice<u32>,
1517    ) -> Result<()> {
1518        self.solve_expect_unsat_with_decision_ranges_gated_ws_inner(
1519            ws,
1520            cnf,
1521            compile_needed,
1522            decision_base_limit,
1523            decision_extra_base,
1524            decision_extra_count,
1525            false,
1526        )
1527    }
1528
1529    /// Solve and enforce UNSAT on GPU while importing the current workspace learned arena.
1530    ///
1531    /// The imported learned count is read from `ws.out_learned_count` on device. Callers must only
1532    /// use this when the imported learned clauses are valid for `cnf`.
1533    pub fn solve_expect_unsat_with_branch_limit_ws_importing_learned(
1534        &self,
1535        ws: &mut GpuCdclWorkspace,
1536        cnf: &GpuCnf,
1537        branch_var_limit: &TrackedCudaSlice<u32>,
1538    ) -> Result<()> {
1539        let compile_needed = self.alloc_u32_scalar(1)?;
1540        let zero = self.alloc_u32_scalar(0)?;
1541        self.solve_expect_unsat_with_decision_ranges_gated_ws_inner(
1542            ws,
1543            cnf,
1544            &compile_needed,
1545            branch_var_limit,
1546            &zero,
1547            &zero,
1548            true,
1549        )
1550    }
1551
1552    #[allow(clippy::too_many_arguments)]
1553    fn solve_expect_unsat_with_decision_ranges_gated_ws_inner(
1554        &self,
1555        ws: &mut GpuCdclWorkspace,
1556        cnf: &GpuCnf,
1557        compile_needed: &TrackedCudaSlice<u32>,
1558        decision_base_limit: &TrackedCudaSlice<u32>,
1559        decision_extra_base: &TrackedCudaSlice<u32>,
1560        decision_extra_count: &TrackedCudaSlice<u32>,
1561        import_existing_learned: bool,
1562    ) -> Result<()> {
1563        #[cfg(debug_assertions)]
1564        let trace = std::env::var_os("XLOG_CDCL_TRACE").is_some();
1565        #[cfg(debug_assertions)]
1566        let t0 = std::time::Instant::now();
1567
1568        self.launch_cdcl_with_workspace_gated(
1569            ws,
1570            cnf,
1571            compile_needed,
1572            decision_base_limit,
1573            decision_extra_base,
1574            decision_extra_count,
1575            import_existing_learned,
1576        )?;
1577        self.require_expected_status(
1578            &ws.out_status,
1579            &ws.out_error,
1580            SAT_STATUS_UNSAT,
1581            "GPU CDCL workspace UNSAT expectation",
1582        )?;
1583
1584        let device = self.provider.device().inner();
1585        let memory = self.provider.memory();
1586
1587        let assert_status_fn = device
1588            .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_STATUS)
1589            .ok_or_else(|| XlogError::Kernel("sat_assert_status kernel not found".to_string()))?;
1590        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1591        unsafe {
1592            assert_status_fn
1593                .clone()
1594                .launch(
1595                    LaunchConfig {
1596                        grid_dim: (1, 1, 1),
1597                        block_dim: (1, 1, 1),
1598                        shared_mem_bytes: 0,
1599                    },
1600                    (
1601                        compile_needed,
1602                        &ws.out_status,
1603                        &ws.out_error,
1604                        SAT_STATUS_UNSAT,
1605                    ),
1606                )
1607                .map_err(|e| {
1608                    XlogError::Kernel(format!("Failed to launch sat_assert_status: {}", e))
1609                })?;
1610        }
1611        // Fail-fast if the solver did not produce UNSAT.
1612        self.provider.device().synchronize()?;
1613        #[cfg(debug_assertions)]
1614        if trace {
1615            eprintln!("[xlog-solve] cdcl_ws(unsat) time: {:?}", t0.elapsed());
1616        }
1617
1618        let mut out_ok = memory.alloc::<i32>(1)?;
1619        self.provider
1620            .htod_launch_metadata_sync_copy_into(&[1i32], &mut out_ok)
1621            .map_err(|e| XlogError::Kernel(format!("Failed to init proof out_ok: {}", e)))?;
1622
1623        // sat_proof_check uses scratch buffers sized to `scratch_cap` per verifier block. To keep
1624        // proof checking fast on large instances, allocate multiple scratch regions and verify
1625        // learned clauses in parallel across blocks.
1626        let scratch_cap_u32 = cnf
1627            .var_cap
1628            .checked_add(1)
1629            .ok_or_else(|| XlogError::Kernel("Proof scratch capacity overflow".to_string()))?;
1630        let scratch_cap = scratch_cap_u32 as usize;
1631
1632        let mut proof_blocks: usize = 1;
1633        let mut scratch_a = None;
1634        let mut scratch_b = None;
1635        let mut scratch_map = None;
1636        let mut last_alloc_err: Option<XlogError> = None;
1637        for blocks in [512usize, 256, 128, 64, 32, 16, 8, 4, 2, 1] {
1638            let len = match scratch_cap.checked_mul(blocks) {
1639                Some(v) => v,
1640                None => {
1641                    last_alloc_err = Some(XlogError::Kernel(
1642                        "Proof scratch allocation length overflow".to_string(),
1643                    ));
1644                    continue;
1645                }
1646            };
1647
1648            let a = match memory.alloc::<i32>(len) {
1649                Ok(buf) => buf,
1650                Err(e) => {
1651                    last_alloc_err = Some(e);
1652                    continue;
1653                }
1654            };
1655            let b = match memory.alloc::<i32>(len) {
1656                Ok(buf) => buf,
1657                Err(e) => {
1658                    last_alloc_err = Some(e);
1659                    // Drop `a` before retrying with a smaller configuration.
1660                    drop(a);
1661                    continue;
1662                }
1663            };
1664            let m = match memory.alloc::<u32>(len) {
1665                Ok(buf) => buf,
1666                Err(e) => {
1667                    last_alloc_err = Some(e);
1668                    drop(a);
1669                    drop(b);
1670                    continue;
1671                }
1672            };
1673
1674            proof_blocks = blocks;
1675            scratch_a = Some(a);
1676            scratch_b = Some(b);
1677            scratch_map = Some(m);
1678            break;
1679        }
1680        let scratch_a = scratch_a.ok_or_else(|| {
1681            last_alloc_err.unwrap_or_else(|| {
1682                XlogError::Kernel("Failed to allocate proof scratch buffers".to_string())
1683            })
1684        })?;
1685        let scratch_b = scratch_b
1686            .ok_or_else(|| XlogError::Kernel("Missing proof scratch buffer".to_string()))?;
1687        let mut scratch_map = scratch_map
1688            .ok_or_else(|| XlogError::Kernel("Missing proof scratch map buffer".to_string()))?;
1689        device
1690            .memset_zeros(&mut scratch_map)
1691            .map_err(|e| XlogError::Kernel(format!("Failed to zero proof scratch map: {}", e)))?;
1692        #[cfg(debug_assertions)]
1693        if trace {
1694            eprintln!("[xlog-solve] proof_check_ws blocks: {}", proof_blocks);
1695        }
1696        #[cfg(debug_assertions)]
1697        let t_mark = std::time::Instant::now();
1698
1699        let needed_cap_u32 = self.config.max_learned_clauses;
1700        let needed_cap = needed_cap_u32 as usize;
1701        let mut needed = memory.alloc::<u8>(needed_cap)?;
1702        device
1703            .memset_zeros(&mut needed)
1704            .map_err(|e| XlogError::Kernel(format!("Failed to zero proof needed mask: {}", e)))?;
1705
1706        let mark_needed_fn = device
1707            .get_func(SAT_MODULE, sat_kernels::SAT_PROOF_MARK_NEEDED)
1708            .ok_or_else(|| {
1709                XlogError::Kernel("sat_proof_mark_needed kernel not found".to_string())
1710            })?;
1711        let mut mark_params: Vec<*mut c_void> = vec![
1712            compile_needed.as_kernel_param(),
1713            (&cnf.num_clauses).as_kernel_param(),
1714            (&ws.out_learned_count).as_kernel_param(),
1715            (&ws.proof_offsets).as_kernel_param(),
1716            (&ws.proof_data).as_kernel_param(),
1717            needed_cap_u32.as_kernel_param(),
1718            (&needed).as_kernel_param(),
1719        ];
1720        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1721        unsafe {
1722            mark_needed_fn
1723                .clone()
1724                .launch(
1725                    LaunchConfig {
1726                        grid_dim: (1, 1, 1),
1727                        block_dim: (1, 1, 1),
1728                        shared_mem_bytes: 0,
1729                    },
1730                    &mut mark_params,
1731                )
1732                .map_err(|e| {
1733                    XlogError::Kernel(format!("Failed to launch sat_proof_mark_needed: {}", e))
1734                })?;
1735        }
1736        self.provider.device().synchronize()?;
1737        #[cfg(debug_assertions)]
1738        if trace {
1739            eprintln!(
1740                "[xlog-solve] proof_mark_needed_ws time: {:?}",
1741                t_mark.elapsed()
1742            );
1743        }
1744
1745        let proof_fn = device
1746            .get_func(SAT_MODULE, sat_kernels::SAT_PROOF_CHECK)
1747            .ok_or_else(|| XlogError::Kernel("sat_proof_check kernel not found".to_string()))?;
1748        #[cfg(debug_assertions)]
1749        let t_proof = std::time::Instant::now();
1750        let proof_blocks_u32 = u32::try_from(proof_blocks)
1751            .map_err(|_| XlogError::Kernel("Proof check grid dim exceeds u32::MAX".to_string()))?;
1752        let mut proof_params: Vec<*mut c_void> = vec![
1753            compile_needed.as_kernel_param(),
1754            (&cnf.clause_offsets).as_kernel_param(),
1755            (&cnf.literals).as_kernel_param(),
1756            (&cnf.num_clauses).as_kernel_param(),
1757            (&ws.learned_offsets).as_kernel_param(),
1758            (&ws.learned_lits).as_kernel_param(),
1759            (&ws.out_learned_count).as_kernel_param(),
1760            (&ws.proof_offsets).as_kernel_param(),
1761            (&ws.proof_data).as_kernel_param(),
1762            (&needed).as_kernel_param(),
1763            needed_cap_u32.as_kernel_param(),
1764            (&scratch_a).as_kernel_param(),
1765            (&scratch_b).as_kernel_param(),
1766            (&scratch_map).as_kernel_param(),
1767            scratch_cap_u32.as_kernel_param(),
1768            (&out_ok).as_kernel_param(),
1769        ];
1770        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1771        unsafe {
1772            proof_fn
1773                .clone()
1774                .launch(
1775                    LaunchConfig {
1776                        grid_dim: (proof_blocks_u32, 1, 1),
1777                        block_dim: (128, 1, 1),
1778                        shared_mem_bytes: 0,
1779                    },
1780                    &mut proof_params,
1781                )
1782                .map_err(|e| {
1783                    XlogError::Kernel(format!("Failed to launch SAT proof check: {}", e))
1784                })?;
1785        }
1786
1787        let assert_ok_fn = device
1788            .get_func(SAT_MODULE, sat_kernels::SAT_ASSERT_OK)
1789            .ok_or_else(|| XlogError::Kernel("sat_assert_ok kernel not found".to_string()))?;
1790        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1791        unsafe {
1792            assert_ok_fn
1793                .clone()
1794                .launch(
1795                    LaunchConfig {
1796                        grid_dim: (1, 1, 1),
1797                        block_dim: (1, 1, 1),
1798                        shared_mem_bytes: 0,
1799                    },
1800                    (compile_needed, &out_ok),
1801                )
1802                .map_err(|e| XlogError::Kernel(format!("Failed to launch sat_assert_ok: {}", e)))?;
1803        }
1804        self.provider.device().synchronize()?;
1805        #[cfg(debug_assertions)]
1806        if trace {
1807            eprintln!("[xlog-solve] proof_check_ws time: {:?}", t_proof.elapsed());
1808            eprintln!(
1809                "[xlog-solve] cdcl_ws(unsat)+proof_check time: {:?}",
1810                t0.elapsed()
1811            );
1812        }
1813
1814        Ok(())
1815    }
1816}