1use std::sync::Arc;
4
5use std::ffi::c_void;
6
7use cudarc::driver::LaunchConfig;
8use xlog_core::{Result, XlogError};
9use xlog_cuda::memory::TrackedCudaSlice;
10use xlog_cuda::provider::sat_kernels;
11use xlog_cuda::provider::SAT_MODULE;
12use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
13use xlog_solve::{GpuCdclConfig, GpuCdclSolver, GpuCnf};
14
15#[cfg(debug_assertions)]
16use crate::compilation::gpu_d4::validate_cnf_gpu;
17
18use crate::gpu::GpuXgcf;
19
20const MAX_GRID_X: u64 = 65_535;
21
22fn checked_launch_grid(elements: u32, block: u32, context: &str) -> Result<u32> {
23 if block == 0 {
24 return Err(XlogError::Kernel(format!(
25 "{context}: CUDA launch block size must be nonzero"
26 )));
27 }
28 let grid = if elements == 0 {
29 1
30 } else {
31 u64::from(elements).div_ceil(u64::from(block))
32 };
33 if grid > MAX_GRID_X {
34 return Err(XlogError::Kernel(format!(
35 "{context}: launch grid {grid} exceeds x-dimension limit {MAX_GRID_X} \
36 for {elements} elements with block size {block}"
37 )));
38 }
39 Ok(grid as u32)
40}
41
42fn checked_clause_offset_span(clause_cap: u32, context: &str) -> Result<u32> {
43 clause_cap
44 .checked_add(1)
45 .ok_or_else(|| XlogError::Kernel(format!("{context}: clause offset span overflow")))
46}
47
48#[derive(Debug, Clone, Copy, Default)]
55#[non_exhaustive]
56pub struct GpuEquivalenceConfig {
57 pub cdcl: GpuCdclConfig,
59 pub reuse_workspace: bool,
61}
62
63pub struct GpuEquivalenceQueries {
65 pub q1: GpuCnf,
66 pub q2: GpuCnf,
67 pub q2_unsat_var_base: TrackedCudaSlice<u32>,
69}
70
71struct CircuitCnf {
72 cnf: GpuCnf,
73 internal_prefix: TrackedCudaSlice<u32>,
76}
77
78fn build_circuit_cnf(
79 provider: &Arc<CudaKernelProvider>,
80 circuit: &GpuXgcf,
81 base_num_vars: &TrackedCudaSlice<u32>,
82 base_var_cap: u32,
83 compile_needed: &TrackedCudaSlice<u32>,
84) -> Result<CircuitCnf> {
85 if base_var_cap == 0 {
86 return Err(XlogError::Compilation(
87 "GPU equivalence verifier requires base_var_cap > 0".to_string(),
88 ));
89 }
90 if circuit.max_var() > base_var_cap {
91 return Err(XlogError::Compilation(format!(
92 "Circuit references var {} but base CNF has only {} vars",
93 circuit.max_var(),
94 base_var_cap
95 )));
96 }
97
98 let num_nodes = circuit.num_nodes();
99 if num_nodes == 0 {
100 return Err(XlogError::Compilation(
101 "GPU equivalence verifier requires circuit with num_nodes > 0".to_string(),
102 ));
103 }
104 if circuit.root() as usize >= num_nodes {
105 return Err(XlogError::Compilation(format!(
106 "GPU equivalence verifier: circuit root {} out of bounds (num_nodes={})",
107 circuit.root(),
108 num_nodes
109 )));
110 }
111
112 let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
113 XlogError::Compilation(format!(
114 "GPU equivalence verifier: circuit num_nodes {} exceeds u32::MAX",
115 num_nodes
116 ))
117 })?;
118
119 let num_edges = circuit.num_edges();
121 let n64 = num_nodes as u64;
122 let e64 = num_edges as u64;
123
124 let var_cap = u32::try_from((base_var_cap as u64).saturating_add(n64))
125 .map_err(|_| XlogError::Kernel("Circuit CNF var capacity exceeds u32::MAX".to_string()))?;
126 let clause_cap =
127 u32::try_from(e64.checked_add(4u64.saturating_mul(n64)).ok_or_else(|| {
128 XlogError::Kernel("Circuit CNF clause capacity overflow".to_string())
129 })?)
130 .map_err(|_| {
131 XlogError::Kernel("Circuit CNF clause capacity exceeds u32::MAX".to_string())
132 })?;
133 let lit_cap = u32::try_from(
134 (3u64.saturating_mul(e64))
135 .checked_add(12u64.saturating_mul(n64))
136 .ok_or_else(|| {
137 XlogError::Kernel("Circuit CNF literal capacity overflow".to_string())
138 })?,
139 )
140 .map_err(|_| XlogError::Kernel("Circuit CNF literal capacity exceeds u32::MAX".to_string()))?;
141
142 let memory = provider.memory();
143 let device = provider.device().inner();
144
145 let mut internal_prefix = memory.alloc::<u32>(num_nodes)?;
147 let mut clause_base = memory.alloc::<u32>(num_nodes)?;
148 let mut lit_base = memory.alloc::<u32>(num_nodes)?;
149
150 let counts_fn = device
151 .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_COUNTS)
152 .ok_or_else(|| XlogError::Kernel("sat_xgcf_cnf_counts kernel not found".to_string()))?;
153
154 let block = 256u32;
155 let grid = checked_launch_grid(num_nodes_u32, block, "sat_xgcf_cnf_counts")?;
156
157 unsafe {
159 counts_fn.clone().launch(
160 LaunchConfig {
161 grid_dim: (grid, 1, 1),
162 block_dim: (block, 1, 1),
163 shared_mem_bytes: 0,
164 },
165 (
166 compile_needed,
167 circuit.node_type(),
168 circuit.child_offsets(),
169 num_nodes_u32,
170 &mut internal_prefix,
171 &mut clause_base,
172 &mut lit_base,
173 ),
174 )
175 }
176 .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_counts failed: {}", e)))?;
177
178 let mut internal_last = memory.alloc::<u32>(1)?;
180 let mut clause_last = memory.alloc::<u32>(1)?;
181 let mut lit_last = memory.alloc::<u32>(1)?;
182
183 let capture_last_fn = device
184 .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_CAPTURE_LAST_COUNTS)
185 .ok_or_else(|| {
186 XlogError::Kernel("sat_xgcf_cnf_capture_last_counts kernel not found".to_string())
187 })?;
188 unsafe {
190 capture_last_fn.clone().launch(
191 LaunchConfig {
192 grid_dim: (1, 1, 1),
193 block_dim: (1, 1, 1),
194 shared_mem_bytes: 0,
195 },
196 (
197 &internal_prefix,
198 &clause_base,
199 &lit_base,
200 num_nodes_u32,
201 &mut internal_last,
202 &mut clause_last,
203 &mut lit_last,
204 ),
205 )
206 }
207 .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_capture_last_counts failed: {}", e)))?;
208
209 provider.exclusive_scan_u32_inplace(&mut internal_prefix, num_nodes_u32)?;
210 provider.exclusive_scan_u32_inplace(&mut clause_base, num_nodes_u32)?;
211 provider.exclusive_scan_u32_inplace(&mut lit_base, num_nodes_u32)?;
212 let d_num_vars = memory.alloc::<u32>(1)?;
216 let d_num_clauses = memory.alloc::<u32>(1)?;
217 let d_num_lits = memory.alloc::<u32>(1)?;
218 let mut d_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
219 let d_lits = memory.alloc::<i32>(lit_cap as usize)?;
220
221 let totals_fn = device
222 .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_COMPUTE_TOTALS)
223 .ok_or_else(|| {
224 XlogError::Kernel("sat_xgcf_cnf_compute_totals kernel not found".to_string())
225 })?;
226 let mut totals_params: Vec<*mut c_void> = vec![
228 (&internal_prefix).as_kernel_param(),
229 (&clause_base).as_kernel_param(),
230 (&lit_base).as_kernel_param(),
231 (&internal_last).as_kernel_param(),
232 (&clause_last).as_kernel_param(),
233 (&lit_last).as_kernel_param(),
234 num_nodes_u32.as_kernel_param(),
235 (base_num_vars).as_kernel_param(),
236 clause_cap.as_kernel_param(),
237 lit_cap.as_kernel_param(),
238 (&d_num_vars).as_kernel_param(),
239 (&d_num_clauses).as_kernel_param(),
240 (&d_num_lits).as_kernel_param(),
241 ];
242 unsafe {
244 totals_fn.clone().launch(
245 LaunchConfig {
246 grid_dim: (1, 1, 1),
247 block_dim: (1, 1, 1),
248 shared_mem_bytes: 0,
249 },
250 &mut totals_params,
251 )
252 }
253 .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_compute_totals failed: {}", e)))?;
254
255 let emit_fn = device
256 .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_CNF_EMIT)
257 .ok_or_else(|| XlogError::Kernel("sat_xgcf_cnf_emit kernel not found".to_string()))?;
258
259 let mut params: Vec<*mut c_void> = vec![
262 compile_needed.as_kernel_param(),
263 circuit.node_type().as_kernel_param(),
264 circuit.child_offsets().as_kernel_param(),
265 circuit.child_indices().as_kernel_param(),
266 circuit.lit().as_kernel_param(),
267 circuit.decision_var().as_kernel_param(),
268 circuit.decision_child_false().as_kernel_param(),
269 circuit.decision_child_true().as_kernel_param(),
270 (&internal_prefix).as_kernel_param(),
271 (&clause_base).as_kernel_param(),
272 (&lit_base).as_kernel_param(),
273 (base_num_vars).as_kernel_param(),
274 num_nodes_u32.as_kernel_param(),
275 (&d_offsets).as_kernel_param(),
276 (&d_lits).as_kernel_param(),
277 ];
278
279 unsafe {
281 emit_fn.clone().launch(
282 LaunchConfig {
283 grid_dim: (grid, 1, 1),
284 block_dim: (block, 1, 1),
285 shared_mem_bytes: 0,
286 },
287 &mut params,
288 )
289 }
290 .map_err(|e| XlogError::Kernel(format!("sat_xgcf_cnf_emit failed: {}", e)))?;
291
292 let term_fn = device
294 .get_func(SAT_MODULE, sat_kernels::SAT_CNF_WRITE_TERMINATOR)
295 .ok_or_else(|| {
296 XlogError::Kernel("sat_cnf_write_terminator kernel not found".to_string())
297 })?;
298 unsafe {
300 term_fn.clone().launch(
301 LaunchConfig {
302 grid_dim: (1, 1, 1),
303 block_dim: (1, 1, 1),
304 shared_mem_bytes: 0,
305 },
306 (&mut d_offsets, &d_num_clauses, &d_num_lits),
307 )
308 }
309 .map_err(|e| XlogError::Kernel(format!("sat_cnf_write_terminator failed: {}", e)))?;
310 Ok(CircuitCnf {
313 cnf: GpuCnf {
314 var_cap,
315 clause_cap,
316 lit_cap,
317 num_vars: d_num_vars,
318 num_clauses: d_num_clauses,
319 num_lits: d_num_lits,
320 clause_offsets: d_offsets,
321 literals: d_lits,
322 },
323 internal_prefix,
324 })
325}
326
327fn build_phi_and_not_c(
328 provider: &Arc<CudaKernelProvider>,
329 phi: &GpuCnf,
330 circuit: &GpuXgcf,
331 circuit_cnf: &CircuitCnf,
332 compile_needed: &TrackedCudaSlice<u32>,
333) -> Result<GpuCnf> {
334 let device = provider.device().inner();
335 let memory = provider.memory();
336
337 let phi_clause_cap = phi.clause_cap;
338 let phi_lit_cap = phi.lit_cap;
339
340 let clause_cap = u32::try_from(
341 (phi_clause_cap as u64)
342 .checked_add(circuit_cnf.cnf.clause_cap as u64)
343 .and_then(|v| v.checked_add(1))
344 .ok_or_else(|| XlogError::Kernel("phi ∧ ¬C clause capacity overflow".to_string()))?,
345 )
346 .map_err(|_| XlogError::Kernel("phi ∧ ¬C clause capacity exceeds u32::MAX".to_string()))?;
347 let lit_cap = u32::try_from(
348 (phi_lit_cap as u64)
349 .checked_add(circuit_cnf.cnf.lit_cap as u64)
350 .and_then(|v| v.checked_add(1))
351 .ok_or_else(|| XlogError::Kernel("phi ∧ ¬C literal capacity overflow".to_string()))?,
352 )
353 .map_err(|_| XlogError::Kernel("phi ∧ ¬C literal capacity exceeds u32::MAX".to_string()))?;
354
355 let var_cap = circuit_cnf.cnf.var_cap;
356
357 let out_num_vars = memory.alloc::<u32>(1)?;
358 let out_num_clauses = memory.alloc::<u32>(1)?;
359 let out_num_lits = memory.alloc::<u32>(1)?;
360 let d_unused0 = memory.alloc::<u32>(1)?;
361 let d_unused1 = memory.alloc::<u32>(1)?;
362 let d_unused2 = memory.alloc::<u32>(1)?;
363
364 let mut d_zero = memory.alloc::<u32>(1)?;
365 provider
366 .htod_launch_metadata_sync_copy_into(&[0u32], &mut d_zero)
367 .map_err(|e| XlogError::Kernel(format!("Failed to upload zero: {}", e)))?;
368
369 let mut out_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
370 let mut out_lits = memory.alloc::<i32>(lit_cap as usize)?;
371
372 let copy_fn = device
373 .get_func(SAT_MODULE, sat_kernels::SAT_CNF_COPY_INTO)
374 .ok_or_else(|| XlogError::Kernel("sat_cnf_copy_into kernel not found".to_string()))?;
375
376 let block = 256u32;
377 let phi_copy_elems =
378 checked_clause_offset_span(phi_clause_cap, "sat_cnf_copy_into(phi)")?.max(phi_lit_cap);
379 let grid = checked_launch_grid(phi_copy_elems, block, "sat_cnf_copy_into(phi)")?;
380
381 unsafe {
386 copy_fn.clone().launch(
387 LaunchConfig {
388 grid_dim: (grid, 1, 1),
389 block_dim: (block, 1, 1),
390 shared_mem_bytes: 0,
391 },
392 (
393 &phi.clause_offsets,
394 &phi.literals,
395 &phi.num_clauses,
396 &phi.num_lits,
397 phi.clause_cap,
398 phi.lit_cap,
399 &d_zero,
400 &d_zero,
401 clause_cap,
402 lit_cap,
403 &mut out_offsets,
404 &mut out_lits,
405 ),
406 )
407 }
408 .map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(phi) failed: {}", e)))?;
409
410 let circuit_copy_elems =
412 checked_clause_offset_span(circuit_cnf.cnf.clause_cap, "sat_cnf_copy_into(circuit)")?
413 .max(circuit_cnf.cnf.lit_cap);
414 let grid_c = checked_launch_grid(circuit_copy_elems, block, "sat_cnf_copy_into(circuit)")?;
415 unsafe {
417 copy_fn.clone().launch(
418 LaunchConfig {
419 grid_dim: (grid_c, 1, 1),
420 block_dim: (block, 1, 1),
421 shared_mem_bytes: 0,
422 },
423 (
424 &circuit_cnf.cnf.clause_offsets,
425 &circuit_cnf.cnf.literals,
426 &circuit_cnf.cnf.num_clauses,
427 &circuit_cnf.cnf.num_lits,
428 circuit_cnf.cnf.clause_cap,
429 circuit_cnf.cnf.lit_cap,
430 &phi.num_clauses,
431 &phi.num_lits,
432 clause_cap,
433 lit_cap,
434 &mut out_offsets,
435 &mut out_lits,
436 ),
437 )
438 }
439 .map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(C) failed: {}", e)))?;
440
441 let unit_fn = device
443 .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE)
444 .ok_or_else(|| {
445 XlogError::Kernel("sat_xgcf_write_root_unit_clause kernel not found".to_string())
446 })?;
447
448 let root = circuit.root();
452 let force_true: i32 = 0;
453 let out_var_cap = var_cap;
454 let out_clause_cap = clause_cap;
455 let out_lit_cap = lit_cap;
456
457 let mut params: Vec<*mut c_void> = vec![
458 compile_needed.as_kernel_param(),
459 circuit.node_type().as_kernel_param(),
460 circuit.lit().as_kernel_param(),
461 (&circuit_cnf.internal_prefix).as_kernel_param(),
462 (&phi.num_vars).as_kernel_param(),
463 root.as_kernel_param(),
464 force_true.as_kernel_param(), (&phi.num_clauses).as_kernel_param(),
466 (&phi.num_lits).as_kernel_param(),
467 (&circuit_cnf.cnf.num_vars).as_kernel_param(),
468 (&circuit_cnf.cnf.num_clauses).as_kernel_param(),
469 (&circuit_cnf.cnf.num_lits).as_kernel_param(),
470 (&d_zero).as_kernel_param(), (&d_zero).as_kernel_param(), (&d_zero).as_kernel_param(), out_var_cap.as_kernel_param(),
474 out_clause_cap.as_kernel_param(),
475 out_lit_cap.as_kernel_param(),
476 (&out_num_vars).as_kernel_param(),
477 (&out_num_clauses).as_kernel_param(),
478 (&out_num_lits).as_kernel_param(),
479 (&d_unused0).as_kernel_param(),
480 (&d_unused1).as_kernel_param(),
481 (&d_unused2).as_kernel_param(),
482 (&out_offsets).as_kernel_param(),
483 (&out_lits).as_kernel_param(),
484 ];
485
486 unsafe {
488 unit_fn.clone().launch(
489 LaunchConfig {
490 grid_dim: (1, 1, 1),
491 block_dim: (1, 1, 1),
492 shared_mem_bytes: 0,
493 },
494 &mut params,
495 )
496 }
497 .map_err(|e| XlogError::Kernel(format!("sat_xgcf_write_root_unit_clause failed: {}", e)))?;
498 Ok(GpuCnf {
501 var_cap,
502 clause_cap,
503 lit_cap,
504 num_vars: out_num_vars,
505 num_clauses: out_num_clauses,
506 num_lits: out_num_lits,
507 clause_offsets: out_offsets,
508 literals: out_lits,
509 })
510}
511
512fn build_c_and_not_phi(
513 provider: &Arc<CudaKernelProvider>,
514 phi: &GpuCnf,
515 circuit: &GpuXgcf,
516 circuit_cnf: &CircuitCnf,
517 compile_needed: &TrackedCudaSlice<u32>,
518) -> Result<(GpuCnf, TrackedCudaSlice<u32>)> {
519 let device = provider.device().inner();
520 let memory = provider.memory();
521
522 let phi_clause_cap = phi.clause_cap;
523 let phi_lit_cap = phi.lit_cap;
524
525 let notphi_clause_cap = u32::try_from(
529 (phi_lit_cap as u64)
530 .checked_add(phi_clause_cap as u64)
531 .and_then(|v| v.checked_add(1))
532 .ok_or_else(|| XlogError::Kernel("¬phi clause count overflow".to_string()))?,
533 )
534 .map_err(|_| XlogError::Kernel("¬phi clause count exceeds u32::MAX".to_string()))?;
535 let notphi_lit_cap = u32::try_from(
536 (phi_lit_cap as u64)
537 .checked_mul(3)
538 .and_then(|v| v.checked_add(2u64.saturating_mul(phi_clause_cap as u64)))
539 .ok_or_else(|| XlogError::Kernel("¬phi literal count overflow".to_string()))?,
540 )
541 .map_err(|_| XlogError::Kernel("¬phi literal count exceeds u32::MAX".to_string()))?;
542
543 let var_cap = circuit_cnf
544 .cnf
545 .var_cap
546 .checked_add(phi_clause_cap)
547 .ok_or_else(|| XlogError::Kernel("C ∧ ¬phi var capacity overflow".to_string()))?;
548 let clause_cap = u32::try_from(
549 (circuit_cnf.cnf.clause_cap as u64)
550 .checked_add(1)
551 .and_then(|v| v.checked_add(notphi_clause_cap as u64))
552 .ok_or_else(|| XlogError::Kernel("C ∧ ¬phi clause capacity overflow".to_string()))?,
553 )
554 .map_err(|_| XlogError::Kernel("C ∧ ¬phi clause capacity exceeds u32::MAX".to_string()))?;
555 let lit_cap = u32::try_from(
556 (circuit_cnf.cnf.lit_cap as u64)
557 .checked_add(1)
558 .and_then(|v| v.checked_add(notphi_lit_cap as u64))
559 .ok_or_else(|| XlogError::Kernel("C ∧ ¬phi literal capacity overflow".to_string()))?,
560 )
561 .map_err(|_| XlogError::Kernel("C ∧ ¬phi literal capacity exceeds u32::MAX".to_string()))?;
562
563 let out_num_vars = memory.alloc::<u32>(1)?;
564 let out_num_clauses = memory.alloc::<u32>(1)?;
565 let out_num_lits = memory.alloc::<u32>(1)?;
566
567 let mut d_zero = memory.alloc::<u32>(1)?;
568 provider
569 .htod_launch_metadata_sync_copy_into(&[0u32], &mut d_zero)
570 .map_err(|e| XlogError::Kernel(format!("Failed to upload zero: {}", e)))?;
571
572 let mut d_extra_num_vars = memory.alloc::<u32>(1)?;
574 let mut d_extra_num_clauses = memory.alloc::<u32>(1)?;
575 let mut d_extra_num_lits = memory.alloc::<u32>(1)?;
576
577 let d_unsat_var_base = memory.alloc::<u32>(1)?;
578 let d_notphi_clause_base = memory.alloc::<u32>(1)?;
579 let d_notphi_lit_base = memory.alloc::<u32>(1)?;
580
581 let mut out_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
582 let mut out_lits = memory.alloc::<i32>(lit_cap as usize)?;
583
584 let copy_fn = device
585 .get_func(SAT_MODULE, sat_kernels::SAT_CNF_COPY_INTO)
586 .ok_or_else(|| XlogError::Kernel("sat_cnf_copy_into kernel not found".to_string()))?;
587
588 let block = 256u32;
590 let circuit_copy_elems =
591 checked_clause_offset_span(circuit_cnf.cnf.clause_cap, "sat_cnf_copy_into(circuit)")?
592 .max(circuit_cnf.cnf.lit_cap);
593 let grid = checked_launch_grid(circuit_copy_elems, block, "sat_cnf_copy_into(circuit)")?;
594 unsafe {
597 copy_fn.clone().launch(
598 LaunchConfig {
599 grid_dim: (grid, 1, 1),
600 block_dim: (block, 1, 1),
601 shared_mem_bytes: 0,
602 },
603 (
604 &circuit_cnf.cnf.clause_offsets,
605 &circuit_cnf.cnf.literals,
606 &circuit_cnf.cnf.num_clauses,
607 &circuit_cnf.cnf.num_lits,
608 circuit_cnf.cnf.clause_cap,
609 circuit_cnf.cnf.lit_cap,
610 &d_zero,
611 &d_zero,
612 clause_cap,
613 lit_cap,
614 &mut out_offsets,
615 &mut out_lits,
616 ),
617 )
618 }
619 .map_err(|e| XlogError::Kernel(format!("sat_cnf_copy_into(C) failed: {}", e)))?;
620
621 let notphi_counts_fn = device
623 .get_func(SAT_MODULE, sat_kernels::SAT_NOT_PHI_COUNTS)
624 .ok_or_else(|| XlogError::Kernel("sat_not_phi_counts kernel not found".to_string()))?;
625 unsafe {
627 notphi_counts_fn.clone().launch(
628 LaunchConfig {
629 grid_dim: (1, 1, 1),
630 block_dim: (1, 1, 1),
631 shared_mem_bytes: 0,
632 },
633 (
634 compile_needed,
635 &phi.num_clauses,
636 &phi.num_lits,
637 &mut d_extra_num_vars,
638 &mut d_extra_num_clauses,
639 &mut d_extra_num_lits,
640 ),
641 )
642 }
643 .map_err(|e| XlogError::Kernel(format!("sat_not_phi_counts failed: {}", e)))?;
644
645 let unit_fn = device
647 .get_func(SAT_MODULE, sat_kernels::SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE)
648 .ok_or_else(|| {
649 XlogError::Kernel("sat_xgcf_write_root_unit_clause kernel not found".to_string())
650 })?;
651
652 let root = circuit.root();
654 let force_true: i32 = 1;
655 let out_var_cap = var_cap;
656 let out_clause_cap = clause_cap;
657 let out_lit_cap = lit_cap;
658
659 let mut params: Vec<*mut c_void> = vec![
660 compile_needed.as_kernel_param(),
661 circuit.node_type().as_kernel_param(),
662 circuit.lit().as_kernel_param(),
663 (&circuit_cnf.internal_prefix).as_kernel_param(),
664 (&phi.num_vars).as_kernel_param(),
665 root.as_kernel_param(),
666 force_true.as_kernel_param(), (&d_zero).as_kernel_param(), (&d_zero).as_kernel_param(), (&circuit_cnf.cnf.num_vars).as_kernel_param(),
670 (&circuit_cnf.cnf.num_clauses).as_kernel_param(),
671 (&circuit_cnf.cnf.num_lits).as_kernel_param(),
672 (&d_extra_num_vars).as_kernel_param(), (&d_extra_num_clauses).as_kernel_param(), (&d_extra_num_lits).as_kernel_param(), out_var_cap.as_kernel_param(),
676 out_clause_cap.as_kernel_param(),
677 out_lit_cap.as_kernel_param(),
678 (&out_num_vars).as_kernel_param(),
679 (&out_num_clauses).as_kernel_param(),
680 (&out_num_lits).as_kernel_param(),
681 (&d_unsat_var_base).as_kernel_param(),
682 (&d_notphi_clause_base).as_kernel_param(),
683 (&d_notphi_lit_base).as_kernel_param(),
684 (&out_offsets).as_kernel_param(),
685 (&out_lits).as_kernel_param(),
686 ];
687
688 unsafe {
690 unit_fn.clone().launch(
691 LaunchConfig {
692 grid_dim: (1, 1, 1),
693 block_dim: (1, 1, 1),
694 shared_mem_bytes: 0,
695 },
696 &mut params,
697 )
698 }
699 .map_err(|e| XlogError::Kernel(format!("sat_xgcf_write_root_unit_clause failed: {}", e)))?;
700
701 let not_phi_fn = device
703 .get_func(SAT_MODULE, sat_kernels::SAT_EMIT_NOT_PHI)
704 .ok_or_else(|| XlogError::Kernel("sat_emit_not_phi kernel not found".to_string()))?;
705
706 let block = 256u32;
707 let grid = checked_launch_grid(phi_clause_cap, block, "sat_emit_not_phi")?;
708
709 unsafe {
711 not_phi_fn.clone().launch(
712 LaunchConfig {
713 grid_dim: (grid, 1, 1),
714 block_dim: (block, 1, 1),
715 shared_mem_bytes: 0,
716 },
717 (
718 compile_needed,
719 &phi.clause_offsets,
720 &phi.literals,
721 &phi.num_clauses,
722 &d_unsat_var_base,
723 &d_notphi_clause_base,
724 &d_notphi_lit_base,
725 &mut out_offsets,
726 &mut out_lits,
727 ),
728 )
729 }
730 .map_err(|e| XlogError::Kernel(format!("sat_emit_not_phi failed: {}", e)))?;
731 Ok((
734 GpuCnf {
735 var_cap,
736 clause_cap,
737 lit_cap,
738 num_vars: out_num_vars,
739 num_clauses: out_num_clauses,
740 num_lits: out_num_lits,
741 clause_offsets: out_offsets,
742 literals: out_lits,
743 },
744 d_unsat_var_base,
745 ))
746}
747
748pub(crate) fn check_equivalence_gpu(
749 phi: &GpuCnf,
750 phi_decision_var_limit: &TrackedCudaSlice<u32>,
751 circuit: &GpuXgcf,
752 provider: &Arc<CudaKernelProvider>,
753 config: GpuEquivalenceConfig,
754) -> Result<()> {
755 let queries = build_equivalence_queries_gpu(phi, circuit, provider)?;
756
757 #[cfg(debug_assertions)]
758 {
759 validate_cnf_gpu(&queries.q1, provider.as_ref())?;
761 validate_cnf_gpu(&queries.q2, provider.as_ref())?;
762 }
763
764 let solver = GpuCdclSolver::new(provider.clone(), config.cdcl);
765 if config.reuse_workspace {
766 let max_var_cap = std::cmp::max(queries.q1.var_cap, queries.q2.var_cap);
767 let max_clause_cap = std::cmp::max(queries.q1.clause_cap, queries.q2.clause_cap);
768 let mut ws = solver.new_workspace(max_var_cap, max_clause_cap)?;
769 solver.solve_expect_unsat_with_branch_limit_ws(
771 &mut ws,
772 &queries.q1,
773 phi_decision_var_limit,
774 )?;
775 solver.solve_expect_unsat_with_decision_ranges_ws(
777 &mut ws,
778 &queries.q2,
779 phi_decision_var_limit,
780 &queries.q2_unsat_var_base,
781 &phi.num_clauses,
782 )?;
783 } else {
784 solver.solve_expect_unsat_with_branch_limit(&queries.q1, phi_decision_var_limit)?;
786 solver.solve_expect_unsat_with_decision_ranges(
788 &queries.q2,
789 phi_decision_var_limit,
790 &queries.q2_unsat_var_base,
791 &phi.num_clauses,
792 )?;
793 }
794 Ok(())
795}
796
797pub fn build_equivalence_queries_gpu(
804 phi: &GpuCnf,
805 circuit: &GpuXgcf,
806 provider: &Arc<CudaKernelProvider>,
807) -> Result<GpuEquivalenceQueries> {
808 let memory = provider.memory();
810 let mut compile_needed = memory.alloc::<u32>(1)?;
811 provider
812 .htod_launch_metadata_sync_copy_into(&[1u32], &mut compile_needed)
813 .map_err(|e| XlogError::Kernel(format!("Failed to upload compile_needed=1: {}", e)))?;
814
815 let circuit_cnf = build_circuit_cnf(
816 provider,
817 circuit,
818 &phi.num_vars,
819 phi.var_cap,
820 &compile_needed,
821 )?;
822 let q1 = build_phi_and_not_c(provider, phi, circuit, &circuit_cnf, &compile_needed)?;
823 let (q2, q2_unsat_var_base) =
824 build_c_and_not_phi(provider, phi, circuit, &circuit_cnf, &compile_needed)?;
825 Ok(GpuEquivalenceQueries {
826 q1,
827 q2,
828 q2_unsat_var_base,
829 })
830}
831
832pub(crate) fn check_equivalence_gpu_gated(
833 phi: &GpuCnf,
834 phi_decision_var_limit: &TrackedCudaSlice<u32>,
835 circuit: &GpuXgcf,
836 provider: &Arc<CudaKernelProvider>,
837 config: GpuEquivalenceConfig,
838 compile_needed: &TrackedCudaSlice<u32>,
839) -> Result<()> {
840 #[cfg(debug_assertions)]
841 eprintln!("[xlog-prob] equivalence: build_circuit_cnf");
842 let circuit_cnf = build_circuit_cnf(
843 provider,
844 circuit,
845 &phi.num_vars,
846 phi.var_cap,
847 compile_needed,
848 )?;
849 #[cfg(debug_assertions)]
850 {
851 provider.device().synchronize().map_err(|e| {
852 XlogError::Kernel(format!("sync after build_circuit_cnf failed: {}", e))
853 })?;
854 eprintln!("[xlog-prob] equivalence: build_phi_and_not_c");
855 }
856
857 let q1 = build_phi_and_not_c(provider, phi, circuit, &circuit_cnf, compile_needed)?;
858 #[cfg(debug_assertions)]
859 {
860 provider.device().synchronize().map_err(|e| {
861 XlogError::Kernel(format!("sync after build_phi_and_not_c failed: {}", e))
862 })?;
863 eprintln!("[xlog-prob] equivalence: build_c_and_not_phi");
864 }
865 let (q2, q2_unsat_var_base) =
866 build_c_and_not_phi(provider, phi, circuit, &circuit_cnf, compile_needed)?;
867 #[cfg(debug_assertions)]
868 {
869 provider.device().synchronize().map_err(|e| {
870 XlogError::Kernel(format!("sync after build_c_and_not_phi failed: {}", e))
871 })?;
872 eprintln!(
873 "[xlog-prob] equivalence: caps: phi(v={} c={} l={}) circuit_cnf(v={} c={} l={}) q1(v={} c={} l={}) q2(v={} c={} l={})",
874 phi.var_cap,
875 phi.clause_cap,
876 phi.lit_cap,
877 circuit_cnf.cnf.var_cap,
878 circuit_cnf.cnf.clause_cap,
879 circuit_cnf.cnf.lit_cap,
880 q1.var_cap,
881 q1.clause_cap,
882 q1.lit_cap,
883 q2.var_cap,
884 q2.clause_cap,
885 q2.lit_cap,
886 );
887 eprintln!("[xlog-prob] equivalence: solve_expect_unsat q1");
888 }
889
890 #[cfg(debug_assertions)]
891 {
892 validate_cnf_gpu(&q1, provider.as_ref())?;
893 validate_cnf_gpu(&q2, provider.as_ref())?;
894 }
895
896 let solver = GpuCdclSolver::new(provider.clone(), config.cdcl);
897 if config.reuse_workspace {
898 let max_var_cap = std::cmp::max(q1.var_cap, q2.var_cap);
899 let max_clause_cap = std::cmp::max(q1.clause_cap, q2.clause_cap);
900 let mut ws = solver.new_workspace(max_var_cap, max_clause_cap)?;
901 solver.solve_expect_unsat_with_branch_limit_gated_ws(
902 &mut ws,
903 &q1,
904 compile_needed,
905 phi_decision_var_limit,
906 )?;
907 #[cfg(debug_assertions)]
908 {
909 provider.device().synchronize().map_err(|e| {
910 XlogError::Kernel(format!("sync after solve_expect_unsat(q1) failed: {}", e))
911 })?;
912 eprintln!("[xlog-prob] equivalence: solve_expect_unsat q2");
913 }
914 solver.solve_expect_unsat_with_decision_ranges_gated_ws(
915 &mut ws,
916 &q2,
917 compile_needed,
918 phi_decision_var_limit,
919 &q2_unsat_var_base,
920 &phi.num_clauses,
921 )?;
922 } else {
923 solver.solve_expect_unsat_with_branch_limit_gated(
924 &q1,
925 compile_needed,
926 phi_decision_var_limit,
927 )?;
928 #[cfg(debug_assertions)]
929 {
930 provider.device().synchronize().map_err(|e| {
931 XlogError::Kernel(format!("sync after solve_expect_unsat(q1) failed: {}", e))
932 })?;
933 eprintln!("[xlog-prob] equivalence: solve_expect_unsat q2");
934 }
935 solver.solve_expect_unsat_with_decision_ranges_gated(
936 &q2,
937 compile_needed,
938 phi_decision_var_limit,
939 &q2_unsat_var_base,
940 &phi.num_clauses,
941 )?;
942 }
943 #[cfg(debug_assertions)]
944 {
945 provider.device().synchronize().map_err(|e| {
946 XlogError::Kernel(format!("sync after solve_expect_unsat(q2) failed: {}", e))
947 })?;
948 eprintln!("[xlog-prob] equivalence: done");
949 }
950 Ok(())
951}
952
/// Returns the `(max_vars, max_clauses)` size budget for verification.
///
/// The limits come from the `XLOG_D4_VERIFY_MAX_VARS` and
/// `XLOG_D4_VERIFY_MAX_CLAUSES` environment variables. A variable that is
/// unset, or whose trimmed value does not parse as a `u32`, yields
/// `u32::MAX`. The environment is read on the first call only; the result is
/// cached for the rest of the process.
fn verify_size_budget() -> (u32, u32) {
    static BUDGET: std::sync::OnceLock<(u32, u32)> = std::sync::OnceLock::new();

    /// Reads one limit from the environment, defaulting to `u32::MAX`.
    fn read_limit(key: &str) -> u32 {
        match std::env::var(key) {
            Ok(raw) => raw.trim().parse::<u32>().unwrap_or(u32::MAX),
            Err(_) => u32::MAX,
        }
    }

    *BUDGET.get_or_init(|| {
        let max_vars = read_limit("XLOG_D4_VERIFY_MAX_VARS");
        let max_clauses = read_limit("XLOG_D4_VERIFY_MAX_CLAUSES");
        (max_vars, max_clauses)
    })
}
981
982fn enforce_verify_size_bound(
986 var_cap: u32,
987 clause_cap: u32,
988 var_budget: u32,
989 clause_budget: u32,
990 context: &str,
991) -> Result<()> {
992 if var_cap > var_budget || clause_cap > clause_budget {
993 return Err(XlogError::CompileCapacityExceeded {
996 context: context.to_string(),
997 detail: format!(
998 "CNF {var_cap} vars / {clause_cap} clauses exceeds size bound \
999 ({var_budget} vars / {clause_budget} clauses)"
1000 ),
1001 });
1002 }
1003 Ok(())
1004}
1005
1006pub(crate) fn check_verify_size_bound(phi: &GpuCnf, context: &str) -> Result<()> {
1014 if std::env::var("XLOG_DEBUG_VERIFY_SIZE").as_deref() == Ok("1") {
1019 eprintln!(
1020 "[xlog-prob] verify-size {context}: var_cap={} clause_cap={} lit_cap={}",
1021 phi.var_cap, phi.clause_cap, phi.lit_cap
1022 );
1023 }
1024 let (var_budget, clause_budget) = verify_size_budget();
1025 enforce_verify_size_bound(
1026 phi.var_cap,
1027 phi.clause_cap,
1028 var_budget,
1029 clause_budget,
1030 context,
1031 )
1032}
1033
1034pub fn validate_equivalence_gpu(
1035 phi: &GpuCnf,
1036 phi_decision_var_limit: &TrackedCudaSlice<u32>,
1037 circuit: &GpuXgcf,
1038 provider: &Arc<CudaKernelProvider>,
1039 config: GpuEquivalenceConfig,
1040) -> Result<()> {
1041 check_verify_size_bound(phi, "validate_equivalence_gpu")?;
1042 check_equivalence_gpu(phi, phi_decision_var_limit, circuit, provider, config)
1043}
1044
1045pub fn validate_equivalence_gpu_gated(
1046 phi: &GpuCnf,
1047 phi_decision_var_limit: &TrackedCudaSlice<u32>,
1048 circuit: &GpuXgcf,
1049 provider: &Arc<CudaKernelProvider>,
1050 config: GpuEquivalenceConfig,
1051 compile_needed: &TrackedCudaSlice<u32>,
1052) -> Result<()> {
1053 check_verify_size_bound(phi, "validate_equivalence_gpu_gated")?;
1054 check_equivalence_gpu_gated(
1055 phi,
1056 phi_decision_var_limit,
1057 circuit,
1058 provider,
1059 config,
1060 compile_needed,
1061 )
1062}
1063
#[cfg(test)]
mod verify_size_bound_tests {
    use super::enforce_verify_size_bound;
    use xlog_core::XlogError;

    /// Budget value that places no limit on a dimension.
    const UNBOUNDED: u32 = u32::MAX;

    /// Exceeding only the variable budget yields the typed capacity error
    /// with the caller's context and a descriptive detail string.
    #[test]
    fn over_var_budget_declines_typed() {
        let err = enforce_verify_size_bound(5000, 100, 4096, UNBOUNDED, "ctx")
            .expect_err("must decline over var budget");
        if let XlogError::CompileCapacityExceeded { context, detail } = &err {
            assert_eq!(context, "ctx");
            assert!(detail.contains("5000 vars"), "detail: {detail}");
            assert!(detail.contains("size bound"), "detail: {detail}");
        } else {
            panic!("wrong error variant: {err:?}");
        }
    }

    /// Exceeding only the clause budget yields the same typed error.
    #[test]
    fn over_clause_budget_declines_typed() {
        let outcome = enforce_verify_size_bound(10, 20_000, UNBOUNDED, 16_384, "ctx");
        let err = outcome.expect_err("must decline over clause budget");
        assert!(matches!(err, XlogError::CompileCapacityExceeded { .. }));
    }

    /// Capacities below both budgets pass.
    #[test]
    fn within_budget_proceeds() {
        let outcome = enforce_verify_size_bound(100, 200, 4096, 16_384, "ctx");
        outcome.expect("within budget must pass the bound");
    }

    /// With both budgets at `u32::MAX`, even maximal capacities pass.
    #[test]
    fn unbounded_default_never_declines() {
        let outcome = enforce_verify_size_bound(u32::MAX, u32::MAX, UNBOUNDED, UNBOUNDED, "ctx");
        outcome.expect("unbounded default must not decline");
    }
}