Skip to main content

xlog_prob/compilation/gpu_d4/
mod.rs

1//! GPU-native Decision-DNNF knowledge compilation.
2//!
3//! This module provides the configuration and kernel-facing utilities needed
4//! to compile a device-resident CNF into a device-resident XGCF circuit.
5//!
6//! Primary spec: docs/design/2026-01-22-gpu-native-compilation-design.md (Section 5.2.4).
7
8use std::sync::Arc;
9
10use cudarc::driver::{DeviceSlice, LaunchConfig};
11use xlog_core::{Result, XlogError};
12use xlog_cuda::memory::TrackedCudaSlice;
13use xlog_cuda::provider::{d4_kernels, scan_kernels, D4_MODULE, SCAN_MODULE};
14use xlog_cuda::{CudaKernelProvider, LaunchAsync};
15use xlog_solve::GpuCnf;
16
17use crate::gpu::GpuXgcf;
18
19pub(crate) mod build;
20pub(crate) mod frontier;
21
22// Re-export used by test code in build.rs (via super::super::).
23#[cfg(test)]
24pub(crate) use frontier::build_frontier_bitset;
25
26pub(super) fn alloc_compile_gate(
27    provider: &CudaKernelProvider,
28    value: u32,
29) -> Result<TrackedCudaSlice<u32>> {
30    let memory = provider.memory();
31    let mut gate = memory.alloc::<u32>(1)?;
32    provider
33        .htod_launch_metadata_sync_copy_into(&[value], &mut gate)
34        .map_err(|e| XlogError::Kernel(format!("compile gate upload failed: {}", e)))?;
35    Ok(gate)
36}
37
38pub(super) fn memset_u8_sync(dst: &mut TrackedCudaSlice<u8>, value: u8) -> Result<()> {
39    dst.stream()
40        .context()
41        .bind_to_thread()
42        .map_err(|e| XlogError::Kernel(format!("bind_to_thread failed: {}", e)))?;
43    let dptr = dst.device_ptr_value();
44    unsafe { cudarc::driver::result::memset_d8_sync(dptr, value, dst.len()) }
45        .map_err(|e| XlogError::Kernel(format!("memset_d8_sync failed: {}", e)))?;
46    Ok(())
47}
48
/// Tuning knobs for GPU-native Decision-DNNF compilation and the GPU CDCL verifier.
///
/// This struct is the public control-plane contract of the GPU-native compilation pipeline.
/// [`GpuCompileConfig::default()`] yields fixed, conservative values;
/// [`crate::exact::default_compile_config`] instead sizes things dynamically from a CNF.
#[derive(Clone, Copy, Debug)]
pub struct GpuCompileConfig {
    /// Number of BFS expansion levels performed before each frontier item is handed to a
    /// per-block DFS worker.
    pub frontier_depth: u16,
    /// Upper bound on the number of frontier work items; exceeding it is a hard error.
    pub max_frontier_items: u32,
    /// Defensive absolute depth limit; exceeding it is a hard error (never UNKNOWN).
    pub max_depth: u16,
    /// Upper bound on the nodes the GPU smoothing pass may emit.
    pub smooth_node_cap: u32,
    /// Upper bound on the edges the GPU smoothing pass may emit.
    pub smooth_edge_cap: u32,

    /// Deterministic restart interval of the CDCL verifier.
    pub cdcl_restart_interval: u32,
    /// Size, in bytes, of the verifier instance's learned-clause arena.
    pub cdcl_learned_bytes: u64,
    /// Conflict budget for debugging/profiling only; production must leave it unbounded.
    pub cdcl_conflict_budget: Option<u64>,

    /// Reuse the equivalence verifier's workspace to amortize arena allocation.
    pub incremental_verify: bool,
}

impl Default for GpuCompileConfig {
    /// Returns static, conservative settings aimed at small-to-medium CNFs.
    ///
    /// Production code should prefer [`crate::exact::default_compile_config`], which sizes
    /// the arenas from the CNF and the memory budget.
    fn default() -> Self {
        const MIB: u64 = 1 << 20;

        GpuCompileConfig {
            frontier_depth: 6,
            max_frontier_items: 64,
            max_depth: 128,
            smooth_node_cap: 1 << 16,
            smooth_edge_cap: 1 << 17,
            cdcl_restart_interval: 64,
            cdcl_learned_bytes: 4 * MIB,
            cdcl_conflict_budget: None,
            incremental_verify: false,
        }
    }
}
97
98/// Validate `GpuCnf` CSR invariants on the GPU (fail-fast trap on invalid input).
99///
100/// This is used as a mandatory invariant check for GPU-native compilation paths where
101/// the host cannot safely "peek" into device-resident CNF buffers.
102pub fn validate_cnf_gpu(cnf: &GpuCnf, provider: &CudaKernelProvider) -> Result<()> {
103    let device = provider.device().inner();
104    let validate = device
105        .get_func(D4_MODULE, d4_kernels::D4_VALIDATE_CNF)
106        .ok_or_else(|| XlogError::Kernel("d4_validate_cnf kernel not found".to_string()))?;
107
108    let var_cap = cnf.var_cap;
109    let clause_cap = cnf.clause_cap;
110    let lit_cap = cnf.lit_cap;
111
112    // SAFETY: d4_validate_cnf(var_cap, clause_cap, lit_cap, num_vars*, num_clauses*, num_lits*, offsets*, lits*)
113    unsafe {
114        validate.clone().launch(
115            LaunchConfig {
116                grid_dim: (1, 1, 1),
117                block_dim: (1, 1, 1),
118                shared_mem_bytes: 0,
119            },
120            (
121                var_cap,
122                clause_cap,
123                lit_cap,
124                &cnf.num_vars,
125                &cnf.num_clauses,
126                &cnf.num_lits,
127                &cnf.clause_offsets,
128                &cnf.literals,
129            ),
130        )
131    }
132    .map_err(|e| XlogError::Kernel(format!("d4_validate_cnf failed: {}", e)))?;
133
134    // `d4_validate_cnf` uses a device-side trap for fail-fast invariants. If it trips, we must
135    // surface that error here; otherwise it will show up later as a panicking `CudaSlice` drop.
136    provider
137        .device()
138        .synchronize()
139        .map_err(|e| XlogError::Kernel(format!("sync after d4_validate_cnf failed: {}", e)))?;
140
141    Ok(())
142}
143
144/// Compute free-variable mask on device (vars absent from CNF and circuit).
145///
146/// Returns a device-resident u8 mask of length var_cap+1 where mask[v]=1 means
147/// var v is free. Traps on device if a variable appears in CNF clauses but is
148/// missing from the circuit.
149pub fn compute_free_var_mask_gpu(
150    cnf: &GpuCnf,
151    circuit: &GpuXgcf,
152    provider: &CudaKernelProvider,
153) -> Result<TrackedCudaSlice<u8>> {
154    let compile_needed = alloc_compile_gate(provider, 1)?;
155    compute_free_var_mask_gpu_gated(cnf, circuit, provider, &compile_needed)
156}
157
158pub(crate) fn compute_free_var_mask_gpu_gated(
159    cnf: &GpuCnf,
160    circuit: &GpuXgcf,
161    provider: &CudaKernelProvider,
162    compile_needed: &TrackedCudaSlice<u32>,
163) -> Result<TrackedCudaSlice<u8>> {
164    if cnf.var_cap == 0 {
165        return Err(XlogError::Compilation(
166            "compute_free_var_mask_gpu requires var_cap > 0".to_string(),
167        ));
168    }
169    if circuit.max_var() > cnf.var_cap {
170        return Err(XlogError::Compilation(format!(
171            "compute_free_var_mask_gpu: circuit max_var {} exceeds CNF var_cap {}",
172            circuit.max_var(),
173            cnf.var_cap
174        )));
175    }
176
177    let num_nodes = u32::try_from(circuit.num_nodes()).map_err(|_| {
178        XlogError::Compilation("compute_free_var_mask_gpu: num_nodes overflow".to_string())
179    })?;
180
181    let mask_len = (cnf.var_cap as u64)
182        .checked_add(1)
183        .and_then(|v| usize::try_from(v).ok())
184        .ok_or_else(|| {
185            XlogError::Compilation("compute_free_var_mask_gpu: mask length overflow".to_string())
186        })?;
187
188    let memory = provider.memory();
189    let device = provider.device().inner();
190
191    let mut vars_in_clauses = memory.alloc::<u32>(mask_len)?;
192    let mut vars_in_circuit = memory.alloc::<u32>(mask_len)?;
193    let mut free_var_mask = memory.alloc::<u8>(mask_len)?;
194
195    device.memset_zeros(&mut vars_in_clauses).map_err(|e| {
196        XlogError::Kernel(format!(
197            "compute_free_var_mask_gpu: zero vars_in_clauses: {}",
198            e
199        ))
200    })?;
201    device.memset_zeros(&mut vars_in_circuit).map_err(|e| {
202        XlogError::Kernel(format!(
203            "compute_free_var_mask_gpu: zero vars_in_circuit: {}",
204            e
205        ))
206    })?;
207    device.memset_zeros(&mut free_var_mask).map_err(|e| {
208        XlogError::Kernel(format!(
209            "compute_free_var_mask_gpu: zero free_var_mask: {}",
210            e
211        ))
212    })?;
213
214    let block_dim = 256u32;
215
216    if cnf.lit_cap > 0 {
217        let grid_dim = cnf.lit_cap.div_ceil(block_dim);
218        let mark_clauses = device
219            .get_func(D4_MODULE, d4_kernels::D4_MARK_VARS_IN_CLAUSES)
220            .ok_or_else(|| {
221                XlogError::Kernel("d4_mark_vars_in_clauses kernel not found".to_string())
222            })?;
223        unsafe {
224            mark_clauses.clone().launch(
225                LaunchConfig {
226                    grid_dim: (grid_dim, 1, 1),
227                    block_dim: (block_dim, 1, 1),
228                    shared_mem_bytes: 0,
229                },
230                (
231                    compile_needed,
232                    cnf.var_cap,
233                    cnf.lit_cap,
234                    &cnf.num_vars,
235                    &cnf.num_lits,
236                    &cnf.literals,
237                    &mut vars_in_clauses,
238                ),
239            )
240        }
241        .map_err(|e| XlogError::Kernel(format!("d4_mark_vars_in_clauses failed: {}", e)))?;
242    }
243
244    if num_nodes > 0 {
245        let grid_dim = num_nodes.div_ceil(block_dim);
246        let mark_circuit = device
247            .get_func(D4_MODULE, d4_kernels::D4_MARK_VARS_IN_CIRCUIT)
248            .ok_or_else(|| {
249                XlogError::Kernel("d4_mark_vars_in_circuit kernel not found".to_string())
250            })?;
251        unsafe {
252            mark_circuit.clone().launch(
253                LaunchConfig {
254                    grid_dim: (grid_dim, 1, 1),
255                    block_dim: (block_dim, 1, 1),
256                    shared_mem_bytes: 0,
257                },
258                (
259                    compile_needed,
260                    circuit.node_type(),
261                    circuit.lit(),
262                    circuit.decision_var(),
263                    circuit.num_nodes_device(),
264                    &cnf.num_vars,
265                    cnf.var_cap,
266                    &mut vars_in_circuit,
267                ),
268            )
269        }
270        .map_err(|e| XlogError::Kernel(format!("d4_mark_vars_in_circuit failed: {}", e)))?;
271    }
272
273    let mask_len_u32 = cnf.var_cap.checked_add(1).ok_or_else(|| {
274        XlogError::Compilation("compute_free_var_mask_gpu: mask length overflow".to_string())
275    })?;
276    let grid_dim = mask_len_u32.div_ceil(block_dim);
277    let build_mask = device
278        .get_func(D4_MODULE, d4_kernels::D4_BUILD_FREE_VAR_MASK)
279        .ok_or_else(|| XlogError::Kernel("d4_build_free_var_mask kernel not found".to_string()))?;
280    unsafe {
281        build_mask.clone().launch(
282            LaunchConfig {
283                grid_dim: (grid_dim, 1, 1),
284                block_dim: (block_dim, 1, 1),
285                shared_mem_bytes: 0,
286            },
287            (
288                compile_needed,
289                &cnf.num_vars,
290                cnf.var_cap,
291                &vars_in_clauses,
292                &vars_in_circuit,
293                &mut free_var_mask,
294            ),
295        )
296    }
297    .map_err(|e| XlogError::Kernel(format!("d4_build_free_var_mask failed: {}", e)))?;
298
299    Ok(free_var_mask)
300}
301
302pub(super) fn bitset_words_per_item(var_cap: u32) -> Result<u32> {
303    // Bit 0 is unused (DIMACS vars are 1-based), so we allocate var_cap+1 bits.
304    let bits = var_cap
305        .checked_add(1)
306        .ok_or_else(|| XlogError::Kernel("bitset var_cap+1 overflow".to_string()))?;
307    Ok(bits.div_ceil(32))
308}
309
310pub(super) fn checked_pool_len_u32(max_items: u32, stride: u32, context: &str) -> Result<u32> {
311    let len = (max_items as u64)
312        .checked_mul(stride as u64)
313        .ok_or_else(|| XlogError::Kernel(format!("{} pool length overflow", context)))?;
314    if len > (u32::MAX as u64) {
315        return Err(XlogError::Kernel(format!(
316            "{} pool length {} exceeds u32::MAX",
317            context, len
318        )));
319    }
320    Ok(len as u32)
321}
322
323pub(super) fn checked_pool_len_usize(max_items: u32, stride: u32, context: &str) -> Result<usize> {
324    let len_u32 = checked_pool_len_u32(max_items, stride, context)?;
325    Ok(len_u32 as usize)
326}
327
328pub(crate) fn exclusive_scan_u32_inplace(
329    provider: &CudaKernelProvider,
330    data: &mut TrackedCudaSlice<u32>,
331    n: u32,
332) -> Result<()> {
333    if n == 0 {
334        return Ok(());
335    }
336
337    let device = provider.device().inner();
338    let block_size = 256u32;
339
340    if n <= block_size {
341        let phase2 = device
342            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
343            .ok_or_else(|| {
344                XlogError::Kernel("multiblock_scan_phase2 kernel not found".to_string())
345            })?;
346        unsafe {
347            phase2.clone().launch(
348                LaunchConfig {
349                    grid_dim: (1, 1, 1),
350                    block_dim: (block_size, 1, 1),
351                    shared_mem_bytes: 0,
352                },
353                (&mut *data, n),
354            )
355        }
356        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase2 failed: {}", e)))?;
357        return Ok(());
358    }
359
360    let num_blocks = n.div_ceil(block_size);
361    let memory = provider.memory();
362    let mut block_sums = memory.alloc::<u32>(num_blocks as usize)?;
363
364    let phase1 = device
365        .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
366        .ok_or_else(|| {
367            XlogError::Kernel("multiblock_scan_u32_phase1 kernel not found".to_string())
368        })?;
369    unsafe {
370        phase1.clone().launch(
371            LaunchConfig {
372                grid_dim: (num_blocks, 1, 1),
373                block_dim: (block_size, 1, 1),
374                shared_mem_bytes: 0,
375            },
376            (&mut *data, &mut block_sums, n),
377        )
378    }
379    .map_err(|e| XlogError::Kernel(format!("multiblock_scan_u32_phase1 failed: {}", e)))?;
380
381    if num_blocks > 1 {
382        exclusive_scan_u32_inplace(provider, &mut block_sums, num_blocks)?;
383    }
384
385    let phase3 = device
386        .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
387        .ok_or_else(|| XlogError::Kernel("multiblock_scan_phase3 kernel not found".to_string()))?;
388    unsafe {
389        phase3.clone().launch(
390            LaunchConfig {
391                grid_dim: (num_blocks, 1, 1),
392                block_dim: (block_size, 1, 1),
393                shared_mem_bytes: 0,
394            },
395            (&mut *data, &block_sums, n),
396        )
397    }
398    .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;
399
400    Ok(())
401}
402
403/// Compile a device-resident CNF into a device-resident XGCF circuit.
404pub(crate) fn compile_gpu_d4(
405    cnf: &GpuCnf,
406    provider: &Arc<CudaKernelProvider>,
407    config: &GpuCompileConfig,
408) -> Result<GpuXgcf> {
409    let compile_needed = alloc_compile_gate(provider, 1)?;
410    Ok(build::compile_gpu_d4_with_gate(cnf, provider, config, &compile_needed)?.0)
411}
412
413/// Compile a device-resident CNF into a device-resident XGCF circuit,
414/// skipping work on the device when `compile_needed` is 0.
415pub fn compile_gpu_d4_gated(
416    cnf: &GpuCnf,
417    provider: &Arc<CudaKernelProvider>,
418    config: &GpuCompileConfig,
419    compile_needed: &TrackedCudaSlice<u32>,
420) -> Result<GpuXgcf> {
421    Ok(build::compile_gpu_d4_with_gate(cnf, provider, config, compile_needed)?.0)
422}
423
424/// Gated compile that also returns the BFS frontier item count for profiling
425/// (0 unless `XLOG_WARMUP_PROFILE=1`).
426pub(crate) fn compile_gpu_d4_gated_with_stats(
427    cnf: &GpuCnf,
428    provider: &Arc<CudaKernelProvider>,
429    config: &GpuCompileConfig,
430    compile_needed: &TrackedCudaSlice<u32>,
431) -> Result<(GpuXgcf, u32)> {
432    build::compile_gpu_d4_with_gate(cnf, provider, config, compile_needed)
433}
434
#[cfg(test)]
mod tests {
    use super::GpuCompileConfig;

    /// The shipped defaults must give the GPU smoothing pass nonzero node and edge caps.
    /// (The previous form asserted on literals it had just assigned and could never fail.)
    #[test]
    fn gpu_d4_compile_config_requires_smoothing_caps() {
        let config = GpuCompileConfig::default();
        assert!(config.smooth_node_cap > 0);
        assert!(config.smooth_edge_cap > 0);
    }

    /// Every field is public, so callers can build a config with a struct literal. The
    /// type is `Copy`, so the original stays usable after being copied.
    #[test]
    fn gpu_d4_compile_config_is_constructible_and_copy() {
        let config = GpuCompileConfig {
            frontier_depth: 0,
            max_frontier_items: 1,
            max_depth: 8,
            cdcl_restart_interval: 128,
            cdcl_learned_bytes: 1 << 20,
            cdcl_conflict_budget: None,
            smooth_node_cap: 256,
            smooth_edge_cap: 512,
            incremental_verify: false,
        };
        let copy = config;
        assert_eq!(copy.smooth_node_cap, config.smooth_node_cap);
        assert_eq!(copy.smooth_edge_cap, 512);
    }

    /// The defaults honour the documented contract. The conflict budget is unbounded, since
    /// production must be unbounded. The frontier/arena sizes are nonzero. The absolute
    /// depth cap is no smaller than the BFS frontier depth.
    #[test]
    fn gpu_d4_default_config_respects_documented_limits() {
        let config = GpuCompileConfig::default();
        assert!(config.cdcl_conflict_budget.is_none());
        assert!(config.max_frontier_items > 0);
        assert!(config.cdcl_learned_bytes > 0);
        assert!(config.max_depth >= config.frontier_depth);
    }
}