Skip to main content

xlog_solve/
gpu_cnf.rs

1use std::sync::Arc;
2
3use cudarc::driver::DeviceSlice;
4use xlog_core::{Result, XlogError};
5use xlog_cuda::memory::TrackedCudaSlice;
6use xlog_cuda::CudaKernelProvider;
7
8use crate::instance::SolveInstance;
9
/// GPU-resident CNF in CSR form (DIMACS literals, 1-based variable ids).
///
/// This is the solver-facing CNF representation used by the GPU CDCL verifier.
///
/// Layout: the literals of clause `i` occupy
/// `literals[clause_offsets[i]..clause_offsets[i + 1]]`. The three `*_cap`
/// fields are plain host-side values that size the device buffers, while the
/// live counts are kept in the single-element `num_*` device buffers.
pub struct GpuCnf {
    /// Variable capacity (>= num_vars).
    pub var_cap: u32,
    /// Clause capacity (>= num_clauses).
    pub clause_cap: u32,
    /// Literal capacity (>= num_lits).
    pub lit_cap: u32,
    /// Device-resident num_vars (len = 1).
    pub num_vars: TrackedCudaSlice<u32>,
    /// Device-resident num_clauses (len = 1).
    pub num_clauses: TrackedCudaSlice<u32>,
    /// Device-resident num_lits (len = 1).
    pub num_lits: TrackedCudaSlice<u32>,
    /// CSR offsets (len = clause_cap + 1).
    ///
    /// When built by `from_host`, entry 0 is 0 and entry `i + 1` is the end
    /// offset of clause `i` in `literals`.
    pub clause_offsets: TrackedCudaSlice<u32>,
    /// Flattened CSR literal array (len = lit_cap).
    ///
    /// `from_host` stores DIMACS-encoded literals and rejects the value 0.
    pub literals: TrackedCudaSlice<i32>,
}
31
32impl GpuCnf {
33    pub(crate) fn require_provider_memory(
34        &self,
35        provider: &CudaKernelProvider,
36        context: &'static str,
37    ) -> Result<()> {
38        let expected_memory = Arc::as_ptr(provider.memory()) as usize;
39        self.require_slice_provider_memory(
40            context,
41            "num_vars",
42            self.num_vars.memory_manager_ptr_value(),
43            expected_memory,
44        )?;
45        self.require_slice_provider_memory(
46            context,
47            "num_clauses",
48            self.num_clauses.memory_manager_ptr_value(),
49            expected_memory,
50        )?;
51        self.require_slice_provider_memory(
52            context,
53            "num_lits",
54            self.num_lits.memory_manager_ptr_value(),
55            expected_memory,
56        )?;
57        self.require_slice_provider_memory(
58            context,
59            "clause_offsets",
60            self.clause_offsets.memory_manager_ptr_value(),
61            expected_memory,
62        )?;
63        self.require_slice_provider_memory(
64            context,
65            "literals",
66            self.literals.memory_manager_ptr_value(),
67            expected_memory,
68        )?;
69
70        if self.num_vars.len() != 1 || self.num_clauses.len() != 1 || self.num_lits.len() != 1 {
71            return Err(XlogError::UnsupportedEpistemicConstruct {
72                construct: context.to_string(),
73                context: format!(
74                    "GPU CNF scalar buffers must have len=1, got num_vars={} num_clauses={} num_lits={}",
75                    self.num_vars.len(),
76                    self.num_clauses.len(),
77                    self.num_lits.len()
78                ),
79            });
80        }
81        let expected_offsets = (self.clause_cap as usize).checked_add(1).ok_or_else(|| {
82            XlogError::UnsupportedEpistemicConstruct {
83                construct: context.to_string(),
84                context: "GPU CNF clause offset length overflowed".to_string(),
85            }
86        })?;
87        if self.clause_offsets.len() != expected_offsets
88            || self.literals.len() != self.lit_cap as usize
89        {
90            return Err(XlogError::UnsupportedEpistemicConstruct {
91                construct: context.to_string(),
92                context: format!(
93                    "GPU CNF buffer lengths must match capacities, got offsets={}/{} literals={}/{}",
94                    self.clause_offsets.len(),
95                    expected_offsets,
96                    self.literals.len(),
97                    self.lit_cap
98                ),
99            });
100        }
101        Ok(())
102    }
103
104    fn require_slice_provider_memory(
105        &self,
106        context: &'static str,
107        name: &'static str,
108        actual_memory: usize,
109        expected_memory: usize,
110    ) -> Result<()> {
111        if actual_memory != expected_memory {
112            return Err(XlogError::UnsupportedEpistemicConstruct {
113                construct: context.to_string(),
114                context: format!(
115                    "GPU CNF buffer {name} belongs to memory manager {actual_memory}, expected {expected_memory}"
116                ),
117            });
118        }
119        Ok(())
120    }
121
122    #[inline]
123    #[allow(dead_code)] // diagnostic accessor, retained for debugging
124    pub(crate) fn offsets_len(&self) -> usize {
125        self.clause_offsets.len()
126    }
127
128    #[inline]
129    pub fn num_literals_cap(&self) -> usize {
130        self.lit_cap as usize
131    }
132
133    /// Host -> device upload helper for tests and tooling.
134    ///
135    /// Production GPU-native paths should build `GpuCnf` directly on device.
136    pub fn from_host(instance: &SolveInstance, provider: &Arc<CudaKernelProvider>) -> Result<Self> {
137        if instance.objective != crate::Objective::Satisfaction {
138            return Err(XlogError::Compilation(format!(
139                "GpuCnf::from_host only supports Objective::Satisfaction, got {:?}",
140                instance.objective
141            )));
142        }
143        if instance.num_vars == 0 {
144            return Err(XlogError::Compilation(
145                "GpuCnf::from_host requires num_vars > 0".to_string(),
146            ));
147        }
148        if instance.num_vars > i32::MAX as u32 {
149            return Err(XlogError::Compilation(
150                "GpuCnf::from_host requires DIMACS variables to fit i32".to_string(),
151            ));
152        }
153        if !instance.validate() {
154            return Err(XlogError::Compilation(
155                "GpuCnf::from_host saw a literal variable outside num_vars".to_string(),
156            ));
157        }
158
159        let num_vars = instance.num_vars;
160        let num_clauses = u32::try_from(instance.clauses.len()).map_err(|_| {
161            XlogError::Compilation("GpuCnf::from_host clause count exceeds u32".to_string())
162        })?;
163        let offsets_len = instance.clauses.len().checked_add(1).ok_or_else(|| {
164            XlogError::Compilation("GpuCnf::from_host clause offset count overflow".to_string())
165        })?;
166        let total_literals = instance.clauses.iter().try_fold(0usize, |acc, clause| {
167            acc.checked_add(clause.literals.len()).ok_or_else(|| {
168                XlogError::Compilation("GpuCnf::from_host literal count overflow".to_string())
169            })
170        })?;
171        let lit_cap = u32::try_from(total_literals).map_err(|_| {
172            XlogError::Compilation("GpuCnf::from_host literal count exceeds u32".to_string())
173        })?;
174
175        // Build CSR on host.
176        let mut clause_offsets: Vec<u32> = Vec::with_capacity(offsets_len);
177        clause_offsets.push(0);
178
179        let mut literals: Vec<i32> = Vec::with_capacity(total_literals);
180        for clause in &instance.clauses {
181            let start = clause_offsets.last().copied().ok_or_else(|| {
182                XlogError::Kernel(
183                    "GpuCnf::from_host internal error: missing initial clause offset".to_string(),
184                )
185            })?;
186            let len = u32::try_from(clause.literals.len()).map_err(|_| {
187                XlogError::Compilation(
188                    "GpuCnf::from_host clause literal count exceeds u32".to_string(),
189                )
190            })?;
191            let end = start
192                .checked_add(len)
193                .ok_or_else(|| XlogError::Compilation("CNF literal count overflow".to_string()))?;
194            clause_offsets.push(end);
195
196            for &lit in &clause.literals {
197                let dimacs = lit.to_dimacs();
198                if dimacs == 0 {
199                    return Err(XlogError::Compilation(
200                        "CNF contains DIMACS 0 literal".to_string(),
201                    ));
202                }
203                literals.push(dimacs);
204            }
205        }
206
207        if clause_offsets.len() != offsets_len {
208            return Err(XlogError::Kernel(
209                "GpuCnf::from_host internal error: offsets length mismatch".to_string(),
210            ));
211        }
212        if literals.len() != total_literals {
213            return Err(XlogError::Kernel(
214                "GpuCnf::from_host internal error: literal length mismatch".to_string(),
215            ));
216        }
217
218        let memory = provider.memory();
219
220        // Device scalars (len=1 each).
221        let mut d_num_vars = memory.alloc::<u32>(1)?;
222        let mut d_num_clauses = memory.alloc::<u32>(1)?;
223        let mut d_num_lits = memory.alloc::<u32>(1)?;
224
225        provider
226            .htod_launch_metadata_sync_copy_into(&[num_vars], &mut d_num_vars)
227            .map_err(|e| XlogError::Kernel(format!("Failed to upload CNF num_vars: {}", e)))?;
228        provider
229            .htod_launch_metadata_sync_copy_into(&[num_clauses], &mut d_num_clauses)
230            .map_err(|e| XlogError::Kernel(format!("Failed to upload CNF num_clauses: {}", e)))?;
231
232        let mut d_offsets = memory.alloc::<u32>(clause_offsets.len())?;
233        let mut d_lits = memory.alloc::<i32>(literals.len())?;
234
235        provider
236            .htod_sync_copy_into_tracked(&clause_offsets, &mut d_offsets)
237            .map_err(|e| XlogError::Kernel(format!("Failed to upload CNF offsets: {}", e)))?;
238        provider
239            .htod_sync_copy_into_tracked(&literals, &mut d_lits)
240            .map_err(|e| XlogError::Kernel(format!("Failed to upload CNF lits: {}", e)))?;
241
242        provider
243            .htod_launch_metadata_sync_copy_into(&[lit_cap], &mut d_num_lits)
244            .map_err(|e| XlogError::Kernel(format!("Failed to upload CNF num_lits: {}", e)))?;
245
246        Ok(Self {
247            var_cap: num_vars,
248            clause_cap: num_clauses,
249            lit_cap,
250            num_vars: d_num_vars,
251            num_clauses: d_num_clauses,
252            num_lits: d_num_lits,
253            clause_offsets: d_offsets,
254            literals: d_lits,
255        })
256    }
257}