1use std::sync::Arc;
9
10use cudarc::driver::{DeviceSlice, LaunchConfig};
11use xlog_core::{Result, XlogError};
12use xlog_cuda::memory::TrackedCudaSlice;
13use xlog_cuda::provider::{d4_kernels, scan_kernels, D4_MODULE, SCAN_MODULE};
14use xlog_cuda::{CudaKernelProvider, LaunchAsync};
15use xlog_solve::GpuCnf;
16
17use crate::gpu::GpuXgcf;
18
19pub(crate) mod build;
20pub(crate) mod frontier;
21
22#[cfg(test)]
24pub(crate) use frontier::build_frontier_bitset;
25
26pub(super) fn alloc_compile_gate(
27 provider: &CudaKernelProvider,
28 value: u32,
29) -> Result<TrackedCudaSlice<u32>> {
30 let memory = provider.memory();
31 let mut gate = memory.alloc::<u32>(1)?;
32 provider
33 .htod_launch_metadata_sync_copy_into(&[value], &mut gate)
34 .map_err(|e| XlogError::Kernel(format!("compile gate upload failed: {}", e)))?;
35 Ok(gate)
36}
37
38pub(super) fn memset_u8_sync(dst: &mut TrackedCudaSlice<u8>, value: u8) -> Result<()> {
39 dst.stream()
40 .context()
41 .bind_to_thread()
42 .map_err(|e| XlogError::Kernel(format!("bind_to_thread failed: {}", e)))?;
43 let dptr = dst.device_ptr_value();
44 unsafe { cudarc::driver::result::memset_d8_sync(dptr, value, dst.len()) }
45 .map_err(|e| XlogError::Kernel(format!("memset_d8_sync failed: {}", e)))?;
46 Ok(())
47}
48
/// Tunable limits for the GPU d4 compilation entry points in this module.
///
/// The per-field notes are read off the field names and types; the code that
/// consumes them is outside this block.
/// NOTE(review): confirm the exact semantics against the `build` module.
///
/// The type is plain `Copy` data and supports equality comparison so that
/// configurations can be compared directly (for example in assertions).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GpuCompileConfig {
    /// Depth associated with the frontier.
    pub frontier_depth: u16,
    /// Upper bound on the number of frontier items.
    pub max_frontier_items: u32,
    /// Upper bound on depth (presumably of the search — TODO confirm).
    pub max_depth: u16,
    /// Node capacity used for smoothing.
    pub smooth_node_cap: u32,
    /// Edge capacity used for smoothing.
    pub smooth_edge_cap: u32,

    /// CDCL restart interval (unit not visible here; presumably conflicts).
    pub cdcl_restart_interval: u32,
    /// Byte budget for CDCL learned data (presumably learned clauses).
    pub cdcl_learned_bytes: u64,
    /// Optional CDCL conflict budget; `None` presumably means "no limit" —
    /// TODO confirm.
    pub cdcl_conflict_budget: Option<u64>,

    /// Enables incremental verification.
    /// NOTE(review): what exactly is verified is not visible from this block.
    pub incremental_verify: bool,
}
77
78impl Default for GpuCompileConfig {
79 fn default() -> Self {
84 Self {
85 frontier_depth: 6,
86 max_frontier_items: 64,
87 max_depth: 128,
88 smooth_node_cap: 65_536,
89 smooth_edge_cap: 131_072,
90 cdcl_restart_interval: 64,
91 cdcl_learned_bytes: 4 * 1024 * 1024,
92 cdcl_conflict_budget: None,
93 incremental_verify: false,
94 }
95 }
96}
97
98pub fn validate_cnf_gpu(cnf: &GpuCnf, provider: &CudaKernelProvider) -> Result<()> {
103 let device = provider.device().inner();
104 let validate = device
105 .get_func(D4_MODULE, d4_kernels::D4_VALIDATE_CNF)
106 .ok_or_else(|| XlogError::Kernel("d4_validate_cnf kernel not found".to_string()))?;
107
108 let var_cap = cnf.var_cap;
109 let clause_cap = cnf.clause_cap;
110 let lit_cap = cnf.lit_cap;
111
112 unsafe {
114 validate.clone().launch(
115 LaunchConfig {
116 grid_dim: (1, 1, 1),
117 block_dim: (1, 1, 1),
118 shared_mem_bytes: 0,
119 },
120 (
121 var_cap,
122 clause_cap,
123 lit_cap,
124 &cnf.num_vars,
125 &cnf.num_clauses,
126 &cnf.num_lits,
127 &cnf.clause_offsets,
128 &cnf.literals,
129 ),
130 )
131 }
132 .map_err(|e| XlogError::Kernel(format!("d4_validate_cnf failed: {}", e)))?;
133
134 provider
137 .device()
138 .synchronize()
139 .map_err(|e| XlogError::Kernel(format!("sync after d4_validate_cnf failed: {}", e)))?;
140
141 Ok(())
142}
143
144pub fn compute_free_var_mask_gpu(
150 cnf: &GpuCnf,
151 circuit: &GpuXgcf,
152 provider: &CudaKernelProvider,
153) -> Result<TrackedCudaSlice<u8>> {
154 let compile_needed = alloc_compile_gate(provider, 1)?;
155 compute_free_var_mask_gpu_gated(cnf, circuit, provider, &compile_needed)
156}
157
158pub(crate) fn compute_free_var_mask_gpu_gated(
159 cnf: &GpuCnf,
160 circuit: &GpuXgcf,
161 provider: &CudaKernelProvider,
162 compile_needed: &TrackedCudaSlice<u32>,
163) -> Result<TrackedCudaSlice<u8>> {
164 if cnf.var_cap == 0 {
165 return Err(XlogError::Compilation(
166 "compute_free_var_mask_gpu requires var_cap > 0".to_string(),
167 ));
168 }
169 if circuit.max_var() > cnf.var_cap {
170 return Err(XlogError::Compilation(format!(
171 "compute_free_var_mask_gpu: circuit max_var {} exceeds CNF var_cap {}",
172 circuit.max_var(),
173 cnf.var_cap
174 )));
175 }
176
177 let num_nodes = u32::try_from(circuit.num_nodes()).map_err(|_| {
178 XlogError::Compilation("compute_free_var_mask_gpu: num_nodes overflow".to_string())
179 })?;
180
181 let mask_len = (cnf.var_cap as u64)
182 .checked_add(1)
183 .and_then(|v| usize::try_from(v).ok())
184 .ok_or_else(|| {
185 XlogError::Compilation("compute_free_var_mask_gpu: mask length overflow".to_string())
186 })?;
187
188 let memory = provider.memory();
189 let device = provider.device().inner();
190
191 let mut vars_in_clauses = memory.alloc::<u32>(mask_len)?;
192 let mut vars_in_circuit = memory.alloc::<u32>(mask_len)?;
193 let mut free_var_mask = memory.alloc::<u8>(mask_len)?;
194
195 device.memset_zeros(&mut vars_in_clauses).map_err(|e| {
196 XlogError::Kernel(format!(
197 "compute_free_var_mask_gpu: zero vars_in_clauses: {}",
198 e
199 ))
200 })?;
201 device.memset_zeros(&mut vars_in_circuit).map_err(|e| {
202 XlogError::Kernel(format!(
203 "compute_free_var_mask_gpu: zero vars_in_circuit: {}",
204 e
205 ))
206 })?;
207 device.memset_zeros(&mut free_var_mask).map_err(|e| {
208 XlogError::Kernel(format!(
209 "compute_free_var_mask_gpu: zero free_var_mask: {}",
210 e
211 ))
212 })?;
213
214 let block_dim = 256u32;
215
216 if cnf.lit_cap > 0 {
217 let grid_dim = cnf.lit_cap.div_ceil(block_dim);
218 let mark_clauses = device
219 .get_func(D4_MODULE, d4_kernels::D4_MARK_VARS_IN_CLAUSES)
220 .ok_or_else(|| {
221 XlogError::Kernel("d4_mark_vars_in_clauses kernel not found".to_string())
222 })?;
223 unsafe {
224 mark_clauses.clone().launch(
225 LaunchConfig {
226 grid_dim: (grid_dim, 1, 1),
227 block_dim: (block_dim, 1, 1),
228 shared_mem_bytes: 0,
229 },
230 (
231 compile_needed,
232 cnf.var_cap,
233 cnf.lit_cap,
234 &cnf.num_vars,
235 &cnf.num_lits,
236 &cnf.literals,
237 &mut vars_in_clauses,
238 ),
239 )
240 }
241 .map_err(|e| XlogError::Kernel(format!("d4_mark_vars_in_clauses failed: {}", e)))?;
242 }
243
244 if num_nodes > 0 {
245 let grid_dim = num_nodes.div_ceil(block_dim);
246 let mark_circuit = device
247 .get_func(D4_MODULE, d4_kernels::D4_MARK_VARS_IN_CIRCUIT)
248 .ok_or_else(|| {
249 XlogError::Kernel("d4_mark_vars_in_circuit kernel not found".to_string())
250 })?;
251 unsafe {
252 mark_circuit.clone().launch(
253 LaunchConfig {
254 grid_dim: (grid_dim, 1, 1),
255 block_dim: (block_dim, 1, 1),
256 shared_mem_bytes: 0,
257 },
258 (
259 compile_needed,
260 circuit.node_type(),
261 circuit.lit(),
262 circuit.decision_var(),
263 circuit.num_nodes_device(),
264 &cnf.num_vars,
265 cnf.var_cap,
266 &mut vars_in_circuit,
267 ),
268 )
269 }
270 .map_err(|e| XlogError::Kernel(format!("d4_mark_vars_in_circuit failed: {}", e)))?;
271 }
272
273 let mask_len_u32 = cnf.var_cap.checked_add(1).ok_or_else(|| {
274 XlogError::Compilation("compute_free_var_mask_gpu: mask length overflow".to_string())
275 })?;
276 let grid_dim = mask_len_u32.div_ceil(block_dim);
277 let build_mask = device
278 .get_func(D4_MODULE, d4_kernels::D4_BUILD_FREE_VAR_MASK)
279 .ok_or_else(|| XlogError::Kernel("d4_build_free_var_mask kernel not found".to_string()))?;
280 unsafe {
281 build_mask.clone().launch(
282 LaunchConfig {
283 grid_dim: (grid_dim, 1, 1),
284 block_dim: (block_dim, 1, 1),
285 shared_mem_bytes: 0,
286 },
287 (
288 compile_needed,
289 &cnf.num_vars,
290 cnf.var_cap,
291 &vars_in_clauses,
292 &vars_in_circuit,
293 &mut free_var_mask,
294 ),
295 )
296 }
297 .map_err(|e| XlogError::Kernel(format!("d4_build_free_var_mask failed: {}", e)))?;
298
299 Ok(free_var_mask)
300}
301
302pub(super) fn bitset_words_per_item(var_cap: u32) -> Result<u32> {
303 let bits = var_cap
305 .checked_add(1)
306 .ok_or_else(|| XlogError::Kernel("bitset var_cap+1 overflow".to_string()))?;
307 Ok(bits.div_ceil(32))
308}
309
310pub(super) fn checked_pool_len_u32(max_items: u32, stride: u32, context: &str) -> Result<u32> {
311 let len = (max_items as u64)
312 .checked_mul(stride as u64)
313 .ok_or_else(|| XlogError::Kernel(format!("{} pool length overflow", context)))?;
314 if len > (u32::MAX as u64) {
315 return Err(XlogError::Kernel(format!(
316 "{} pool length {} exceeds u32::MAX",
317 context, len
318 )));
319 }
320 Ok(len as u32)
321}
322
323pub(super) fn checked_pool_len_usize(max_items: u32, stride: u32, context: &str) -> Result<usize> {
324 let len_u32 = checked_pool_len_u32(max_items, stride, context)?;
325 Ok(len_u32 as usize)
326}
327
328pub(crate) fn exclusive_scan_u32_inplace(
329 provider: &CudaKernelProvider,
330 data: &mut TrackedCudaSlice<u32>,
331 n: u32,
332) -> Result<()> {
333 if n == 0 {
334 return Ok(());
335 }
336
337 let device = provider.device().inner();
338 let block_size = 256u32;
339
340 if n <= block_size {
341 let phase2 = device
342 .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
343 .ok_or_else(|| {
344 XlogError::Kernel("multiblock_scan_phase2 kernel not found".to_string())
345 })?;
346 unsafe {
347 phase2.clone().launch(
348 LaunchConfig {
349 grid_dim: (1, 1, 1),
350 block_dim: (block_size, 1, 1),
351 shared_mem_bytes: 0,
352 },
353 (&mut *data, n),
354 )
355 }
356 .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase2 failed: {}", e)))?;
357 return Ok(());
358 }
359
360 let num_blocks = n.div_ceil(block_size);
361 let memory = provider.memory();
362 let mut block_sums = memory.alloc::<u32>(num_blocks as usize)?;
363
364 let phase1 = device
365 .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
366 .ok_or_else(|| {
367 XlogError::Kernel("multiblock_scan_u32_phase1 kernel not found".to_string())
368 })?;
369 unsafe {
370 phase1.clone().launch(
371 LaunchConfig {
372 grid_dim: (num_blocks, 1, 1),
373 block_dim: (block_size, 1, 1),
374 shared_mem_bytes: 0,
375 },
376 (&mut *data, &mut block_sums, n),
377 )
378 }
379 .map_err(|e| XlogError::Kernel(format!("multiblock_scan_u32_phase1 failed: {}", e)))?;
380
381 if num_blocks > 1 {
382 exclusive_scan_u32_inplace(provider, &mut block_sums, num_blocks)?;
383 }
384
385 let phase3 = device
386 .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
387 .ok_or_else(|| XlogError::Kernel("multiblock_scan_phase3 kernel not found".to_string()))?;
388 unsafe {
389 phase3.clone().launch(
390 LaunchConfig {
391 grid_dim: (num_blocks, 1, 1),
392 block_dim: (block_size, 1, 1),
393 shared_mem_bytes: 0,
394 },
395 (&mut *data, &block_sums, n),
396 )
397 }
398 .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;
399
400 Ok(())
401}
402
403pub(crate) fn compile_gpu_d4(
405 cnf: &GpuCnf,
406 provider: &Arc<CudaKernelProvider>,
407 config: &GpuCompileConfig,
408) -> Result<GpuXgcf> {
409 let compile_needed = alloc_compile_gate(provider, 1)?;
410 Ok(build::compile_gpu_d4_with_gate(cnf, provider, config, &compile_needed)?.0)
411}
412
413pub fn compile_gpu_d4_gated(
416 cnf: &GpuCnf,
417 provider: &Arc<CudaKernelProvider>,
418 config: &GpuCompileConfig,
419 compile_needed: &TrackedCudaSlice<u32>,
420) -> Result<GpuXgcf> {
421 Ok(build::compile_gpu_d4_with_gate(cnf, provider, config, compile_needed)?.0)
422}
423
424pub(crate) fn compile_gpu_d4_gated_with_stats(
427 cnf: &GpuCnf,
428 provider: &Arc<CudaKernelProvider>,
429 config: &GpuCompileConfig,
430 compile_needed: &TrackedCudaSlice<u32>,
431) -> Result<(GpuXgcf, u32)> {
432 build::compile_gpu_d4_with_gate(cnf, provider, config, compile_needed)
433}
434
#[cfg(test)]
mod tests {
    use super::GpuCompileConfig;

    /// Spells out every field of the config (no `..Default::default()`), so
    /// this test stops compiling if the smoothing caps are removed from the
    /// struct, and then checks the chosen caps are non-zero.
    #[test]
    fn gpu_d4_compile_config_requires_smoothing_caps() {
        let GpuCompileConfig {
            smooth_node_cap,
            smooth_edge_cap,
            ..
        } = GpuCompileConfig {
            frontier_depth: 0,
            max_frontier_items: 1,
            max_depth: 8,
            smooth_node_cap: 256,
            smooth_edge_cap: 512,
            cdcl_restart_interval: 128,
            cdcl_learned_bytes: 1 << 20,
            cdcl_conflict_budget: None,
            incremental_verify: false,
        };

        assert!(smooth_node_cap > 0);
        assert!(smooth_edge_cap > 0);
    }
}