Skip to main content

xlog_prob/compilation/
gpu_cnf.rs

1//! GPU-native Tseitin CNF encoding for PIR graphs.
2
3use std::ffi::c_void;
4use std::sync::Arc;
5
6use cudarc::driver::{DeviceSlice, LaunchConfig};
7use xlog_core::{Result, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::provider::{cnf_kernels, CNF_MODULE};
10use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
11use xlog_solve::GpuCnf;
12
13use crate::compilation::gpu_pir::GpuPirGraph;
14use crate::compilation::gpu_pir::GpuPirRoots;
15
/// GPU-resident CNF variable tables for PIR ids.
///
/// Built by `encode_cnf_gpu`. All three tables live in device memory; only `max_var` is a
/// host-side value.
pub struct GpuCnfVarTables {
    /// One `u32` slot per PIR node. Zero-filled, then written by the node-variable assignment
    /// kernel. Presumably holds the CNF variable id of each node — confirm against the kernel.
    pub node_var: TrackedCudaSlice<u32>,
    /// Leaf variable table, sized to the PIR node count. Zero-filled, then written by the
    /// leaf-variable assignment kernel. Presumably indexed by leaf id — confirm against the
    /// kernel.
    pub leaf_var: TrackedCudaSlice<u32>,
    /// Choice variable table, sized to the PIR node count. Zero-filled, then written by the
    /// choice-variable assignment kernel. Presumably indexed by decision variable — confirm
    /// against the kernel.
    pub choice_var: TrackedCudaSlice<u32>,
    /// Host-side upper bound on variable ids. `encode_cnf_gpu` stores the variable *capacity*
    /// (3 x node count) here, not the exact number of variables used; the exact count stays
    /// on the device in the CNF's `num_vars`.
    pub max_var: u32,
}
23
/// GPU-resident CNF encoding bundle (CNF + var tables).
///
/// This is the value returned by `encode_cnf_gpu`.
pub struct GpuCnfEncoding {
    /// The encoded formula: host-side capacities plus device-resident counts, clause offsets
    /// and literals.
    pub cnf: GpuCnf,
    /// Device-resident variable tables produced while encoding `cnf`.
    pub vars: GpuCnfVarTables,
    /// Largest variable id that is semantically meaningful and should be eligible for branching
    /// in the GPU CDCL verifier (len = 1, device-resident).
    ///
    /// For PIR Tseitin encodings this is the end of the leaf+choice var range, excluding internal
    /// node Tseitin vars which are propagation-only.
    pub decision_var_limit: TrackedCudaSlice<u32>,
}
35
/// Largest launch-grid x-dimension that `checked_grid_dim` accepts.
///
/// NOTE(review): 65,535 is the CUDA limit for the y/z grid dimensions; devices of compute
/// capability 3.0 and later allow x-dimensions up to 2^31 - 1. Confirm whether this
/// conservative cap is intentional before raising it.
const MAX_GRID_X: u64 = 65_535;
37
38fn checked_grid_dim(n: u32, block: u32, context: &str) -> Result<u32> {
39    if block == 0 {
40        return Err(XlogError::Kernel(format!(
41            "{context}: CUDA launch block size must be nonzero"
42        )));
43    }
44    let grid = if n == 0 {
45        1
46    } else {
47        u64::from(n).div_ceil(u64::from(block))
48    };
49    if grid > MAX_GRID_X {
50        return Err(XlogError::Kernel(format!(
51            "{context}: launch grid {grid} exceeds x-dimension limit {MAX_GRID_X} \
52             for {n} elements with block size {block}"
53        )));
54    }
55    Ok(grid as u32)
56}
57
58/// Encode a GPU PIR graph into GPU-resident Tseitin CNF.
59pub fn encode_cnf_gpu(
60    pir: &GpuPirGraph,
61    roots: &GpuPirRoots,
62    provider: &Arc<CudaKernelProvider>,
63) -> Result<GpuCnfEncoding> {
64    if roots.roots.is_empty() {
65        return Err(XlogError::Compilation(
66            "Cannot encode CNF for empty PIR root set".to_string(),
67        ));
68    }
69    let num_nodes = pir.node_type.len();
70    if num_nodes == 0 {
71        return Err(XlogError::Compilation(
72            "Cannot encode CNF for empty PIR graph".to_string(),
73        ));
74    }
75
76    let num_nodes_u32 = u32::try_from(num_nodes)
77        .map_err(|_| XlogError::Compilation("PIR node count overflow".to_string()))?;
78    let num_roots_u32 = u32::try_from(roots.roots.len())
79        .map_err(|_| XlogError::Compilation("PIR root count exceeds u32::MAX".to_string()))?;
80
81    let num_edges = pir.children.len();
82    let n64 = num_nodes as u64;
83    let e64 = num_edges as u64;
84
85    let var_cap = u32::try_from(
86        n64.checked_mul(3)
87            .ok_or_else(|| XlogError::Kernel("CNF var capacity overflow".to_string()))?,
88    )
89    .map_err(|_| XlogError::Kernel("CNF var capacity exceeds u32::MAX".to_string()))?;
90    let clause_cap = u32::try_from(
91        e64.checked_add(
92            n64.checked_mul(4)
93                .ok_or_else(|| XlogError::Kernel("CNF clause capacity overflow".to_string()))?,
94        )
95        .ok_or_else(|| XlogError::Kernel("CNF clause capacity overflow".to_string()))?,
96    )
97    .map_err(|_| XlogError::Kernel("CNF clause capacity exceeds u32::MAX".to_string()))?;
98    let lit_cap =
99        u32::try_from(
100            e64.checked_mul(3)
101                .ok_or_else(|| XlogError::Kernel("CNF literal capacity overflow".to_string()))?
102                .checked_add(n64.checked_mul(12).ok_or_else(|| {
103                    XlogError::Kernel("CNF literal capacity overflow".to_string())
104                })?)
105                .ok_or_else(|| XlogError::Kernel("CNF literal capacity overflow".to_string()))?,
106        )
107        .map_err(|_| XlogError::Kernel("CNF literal capacity exceeds u32::MAX".to_string()))?;
108
109    let leaf_cap = num_nodes_u32;
110    let choice_cap = num_nodes_u32;
111
112    let memory = provider.memory();
113    let device = provider.device().inner();
114
115    let mut reachable = memory.alloc::<u32>(num_nodes)?;
116    let mut queue = memory.alloc::<u32>(num_nodes)?;
117    let mut queue_ready = memory.alloc::<u32>(num_nodes)?;
118    let mut head = memory.alloc::<u32>(1)?;
119    let mut tail = memory.alloc::<u32>(1)?;
120    let mut in_flight = memory.alloc::<u32>(1)?;
121
122    let mut leaf_used = memory.alloc::<u32>(leaf_cap as usize)?;
123    let mut choice_used = memory.alloc::<u32>(choice_cap as usize)?;
124    let mut leaf_var = memory.alloc::<u32>(leaf_cap as usize)?;
125    let mut choice_var = memory.alloc::<u32>(choice_cap as usize)?;
126
127    let mut node_needs_var = memory.alloc::<u32>(num_nodes)?;
128    let mut node_var = memory.alloc::<u32>(num_nodes)?;
129
130    let mut clause_counts = memory.alloc::<u32>(num_nodes)?;
131    let mut lit_counts = memory.alloc::<u32>(num_nodes)?;
132
133    let mut leaf_prefix = memory.alloc::<u32>(leaf_cap as usize)?;
134    let mut choice_prefix = memory.alloc::<u32>(choice_cap as usize)?;
135
136    let mut node_last = memory.alloc::<u32>(1)?;
137    let mut clause_last = memory.alloc::<u32>(1)?;
138    let mut lit_last = memory.alloc::<u32>(1)?;
139
140    let mut num_leaf = memory.alloc::<u32>(1)?;
141    let mut num_choice = memory.alloc::<u32>(1)?;
142    let mut base_choice = memory.alloc::<u32>(1)?;
143    let mut base_node = memory.alloc::<u32>(1)?;
144    let mut decision_var_limit = memory.alloc::<u32>(1)?;
145
146    let d_num_vars = memory.alloc::<u32>(1)?;
147    let d_num_clauses = memory.alloc::<u32>(1)?;
148    let d_num_lits = memory.alloc::<u32>(1)?;
149
150    let mut d_offsets = memory.alloc::<u32>((clause_cap as usize) + 1)?;
151    let d_lits = memory.alloc::<i32>(lit_cap as usize)?;
152
153    device
154        .memset_zeros(&mut reachable)
155        .map_err(|e| XlogError::Kernel(format!("zero reachable: {}", e)))?;
156    device
157        .memset_zeros(&mut queue)
158        .map_err(|e| XlogError::Kernel(format!("zero queue: {}", e)))?;
159    device
160        .memset_zeros(&mut queue_ready)
161        .map_err(|e| XlogError::Kernel(format!("zero queue_ready: {}", e)))?;
162    device
163        .memset_zeros(&mut head)
164        .map_err(|e| XlogError::Kernel(format!("zero head: {}", e)))?;
165    device
166        .memset_zeros(&mut tail)
167        .map_err(|e| XlogError::Kernel(format!("zero tail: {}", e)))?;
168    device
169        .memset_zeros(&mut in_flight)
170        .map_err(|e| XlogError::Kernel(format!("zero in_flight: {}", e)))?;
171    device
172        .memset_zeros(&mut leaf_used)
173        .map_err(|e| XlogError::Kernel(format!("zero leaf_used: {}", e)))?;
174    device
175        .memset_zeros(&mut choice_used)
176        .map_err(|e| XlogError::Kernel(format!("zero choice_used: {}", e)))?;
177    device
178        .memset_zeros(&mut leaf_var)
179        .map_err(|e| XlogError::Kernel(format!("zero leaf_var: {}", e)))?;
180    device
181        .memset_zeros(&mut choice_var)
182        .map_err(|e| XlogError::Kernel(format!("zero choice_var: {}", e)))?;
183    device
184        .memset_zeros(&mut node_needs_var)
185        .map_err(|e| XlogError::Kernel(format!("zero node_needs_var: {}", e)))?;
186    device
187        .memset_zeros(&mut node_var)
188        .map_err(|e| XlogError::Kernel(format!("zero node_var: {}", e)))?;
189    device
190        .memset_zeros(&mut clause_counts)
191        .map_err(|e| XlogError::Kernel(format!("zero clause_counts: {}", e)))?;
192    device
193        .memset_zeros(&mut lit_counts)
194        .map_err(|e| XlogError::Kernel(format!("zero lit_counts: {}", e)))?;
195
196    let reach_init_fn = device
197        .get_func(CNF_MODULE, cnf_kernels::CNF_REACHABILITY_INIT)
198        .ok_or_else(|| XlogError::Kernel("cnf_reachability_init kernel not found".to_string()))?;
199    let reach_bfs_fn = device
200        .get_func(CNF_MODULE, cnf_kernels::CNF_REACHABILITY_BFS)
201        .ok_or_else(|| XlogError::Kernel("cnf_reachability_bfs kernel not found".to_string()))?;
202    let mark_leaf_choice_fn = device
203        .get_func(CNF_MODULE, cnf_kernels::CNF_MARK_LEAF_CHOICE)
204        .ok_or_else(|| XlogError::Kernel("cnf_mark_leaf_choice kernel not found".to_string()))?;
205    let assign_leaf_var_fn = device
206        .get_func(CNF_MODULE, cnf_kernels::CNF_ASSIGN_LEAF_VAR)
207        .ok_or_else(|| XlogError::Kernel("cnf_assign_leaf_var kernel not found".to_string()))?;
208    let assign_choice_var_fn = device
209        .get_func(CNF_MODULE, cnf_kernels::CNF_ASSIGN_CHOICE_VAR)
210        .ok_or_else(|| XlogError::Kernel("cnf_assign_choice_var kernel not found".to_string()))?;
211    let mark_node_vars_fn = device
212        .get_func(CNF_MODULE, cnf_kernels::CNF_MARK_NODE_VARS)
213        .ok_or_else(|| XlogError::Kernel("cnf_mark_node_vars kernel not found".to_string()))?;
214    let count_clauses_fn = device
215        .get_func(CNF_MODULE, cnf_kernels::CNF_COUNT_CLAUSES)
216        .ok_or_else(|| XlogError::Kernel("cnf_count_clauses kernel not found".to_string()))?;
217    let capture_last_fn = device
218        .get_func(CNF_MODULE, cnf_kernels::CNF_CAPTURE_LAST_COUNTS)
219        .ok_or_else(|| XlogError::Kernel("cnf_capture_last_counts kernel not found".to_string()))?;
220    let leaf_choice_totals_fn = device
221        .get_func(CNF_MODULE, cnf_kernels::CNF_COMPUTE_LEAF_CHOICE_TOTALS)
222        .ok_or_else(|| {
223            XlogError::Kernel("cnf_compute_leaf_choice_totals kernel not found".to_string())
224        })?;
225    let compute_totals_fn = device
226        .get_func(CNF_MODULE, cnf_kernels::CNF_COMPUTE_TOTALS)
227        .ok_or_else(|| XlogError::Kernel("cnf_compute_totals kernel not found".to_string()))?;
228    let assign_node_var_fn = device
229        .get_func(CNF_MODULE, cnf_kernels::CNF_ASSIGN_NODE_VAR)
230        .ok_or_else(|| XlogError::Kernel("cnf_assign_node_var kernel not found".to_string()))?;
231    let emit_clauses_fn = device
232        .get_func(CNF_MODULE, cnf_kernels::CNF_EMIT_CLAUSES)
233        .ok_or_else(|| XlogError::Kernel("cnf_emit_clauses kernel not found".to_string()))?;
234    let set_clause_end_fn = device
235        .get_func(CNF_MODULE, cnf_kernels::CNF_SET_CLAUSE_END)
236        .ok_or_else(|| XlogError::Kernel("cnf_set_clause_end kernel not found".to_string()))?;
237
238    let block = 256u32;
239
240    let grid_roots = checked_grid_dim(num_roots_u32, block, "cnf_reachability_init")?;
241    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
242    unsafe {
243        reach_init_fn.clone().launch(
244            LaunchConfig {
245                grid_dim: (grid_roots, 1, 1),
246                block_dim: (block, 1, 1),
247                shared_mem_bytes: 0,
248            },
249            (
250                &roots.roots,
251                num_roots_u32,
252                num_nodes_u32,
253                &mut reachable,
254                &mut queue,
255                &mut queue_ready,
256                &mut head,
257                &mut tail,
258                &mut in_flight,
259            ),
260        )
261    }
262    .map_err(|e| XlogError::Kernel(format!("cnf_reachability_init failed: {}", e)))?;
263
264    let grid_nodes = checked_grid_dim(num_nodes_u32, block, "cnf node kernels")?;
265    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
266    unsafe {
267        reach_bfs_fn.clone().launch(
268            LaunchConfig {
269                grid_dim: (grid_nodes, 1, 1),
270                block_dim: (block, 1, 1),
271                shared_mem_bytes: 0,
272            },
273            (
274                &pir.node_type,
275                &pir.child_offsets,
276                &pir.children,
277                &pir.decision_child_false,
278                &pir.decision_child_true,
279                num_nodes_u32,
280                &mut reachable,
281                &mut queue,
282                &mut queue_ready,
283                &mut head,
284                &mut tail,
285                &mut in_flight,
286            ),
287        )
288    }
289    .map_err(|e| XlogError::Kernel(format!("cnf_reachability_bfs failed: {}", e)))?;
290
291    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
292    unsafe {
293        mark_leaf_choice_fn.clone().launch(
294            LaunchConfig {
295                grid_dim: (grid_nodes, 1, 1),
296                block_dim: (block, 1, 1),
297                shared_mem_bytes: 0,
298            },
299            (
300                &pir.node_type,
301                &pir.leaf_id,
302                &pir.decision_var,
303                &reachable,
304                num_nodes_u32,
305                leaf_cap,
306                choice_cap,
307                &mut leaf_used,
308                &mut choice_used,
309            ),
310        )
311    }
312    .map_err(|e| XlogError::Kernel(format!("cnf_mark_leaf_choice failed: {}", e)))?;
313
314    if leaf_cap > 0 {
315        device
316            .dtod_copy(&leaf_used, &mut leaf_prefix)
317            .map_err(|e| XlogError::Kernel(format!("copy leaf_used: {}", e)))?;
318        provider.exclusive_scan_u32_inplace(&mut leaf_prefix, leaf_cap)?;
319    }
320    if choice_cap > 0 {
321        device
322            .dtod_copy(&choice_used, &mut choice_prefix)
323            .map_err(|e| XlogError::Kernel(format!("copy choice_used: {}", e)))?;
324        provider.exclusive_scan_u32_inplace(&mut choice_prefix, choice_cap)?;
325    }
326
327    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
328    unsafe {
329        leaf_choice_totals_fn.clone().launch(
330            LaunchConfig {
331                grid_dim: (1, 1, 1),
332                block_dim: (1, 1, 1),
333                shared_mem_bytes: 0,
334            },
335            (
336                &leaf_prefix,
337                &leaf_used,
338                leaf_cap,
339                &choice_prefix,
340                &choice_used,
341                choice_cap,
342                &mut num_leaf,
343                &mut num_choice,
344                &mut base_choice,
345                &mut base_node,
346                &mut decision_var_limit,
347            ),
348        )
349    }
350    .map_err(|e| XlogError::Kernel(format!("cnf_compute_leaf_choice_totals failed: {}", e)))?;
351
352    if leaf_cap > 0 {
353        let grid_leaf = checked_grid_dim(leaf_cap, block, "cnf_assign_leaf_var")?;
354        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
355        unsafe {
356            assign_leaf_var_fn.clone().launch(
357                LaunchConfig {
358                    grid_dim: (grid_leaf, 1, 1),
359                    block_dim: (block, 1, 1),
360                    shared_mem_bytes: 0,
361                },
362                (&leaf_used, &leaf_prefix, leaf_cap, &mut leaf_var),
363            )
364        }
365        .map_err(|e| XlogError::Kernel(format!("cnf_assign_leaf_var failed: {}", e)))?;
366    }
367    if choice_cap > 0 {
368        let grid_choice = checked_grid_dim(choice_cap, block, "cnf_assign_choice_var")?;
369        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
370        unsafe {
371            assign_choice_var_fn.clone().launch(
372                LaunchConfig {
373                    grid_dim: (grid_choice, 1, 1),
374                    block_dim: (block, 1, 1),
375                    shared_mem_bytes: 0,
376                },
377                (
378                    &choice_used,
379                    &choice_prefix,
380                    choice_cap,
381                    &base_choice,
382                    &mut choice_var,
383                ),
384            )
385        }
386        .map_err(|e| XlogError::Kernel(format!("cnf_assign_choice_var failed: {}", e)))?;
387    }
388
389    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
390    unsafe {
391        mark_node_vars_fn.clone().launch(
392            LaunchConfig {
393                grid_dim: (grid_nodes, 1, 1),
394                block_dim: (block, 1, 1),
395                shared_mem_bytes: 0,
396            },
397            (
398                &pir.node_type,
399                &reachable,
400                num_nodes_u32,
401                &mut node_needs_var,
402            ),
403        )
404    }
405    .map_err(|e| XlogError::Kernel(format!("cnf_mark_node_vars failed: {}", e)))?;
406
407    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
408    unsafe {
409        count_clauses_fn.clone().launch(
410            LaunchConfig {
411                grid_dim: (grid_nodes, 1, 1),
412                block_dim: (block, 1, 1),
413                shared_mem_bytes: 0,
414            },
415            (
416                &pir.node_type,
417                &pir.child_offsets,
418                &reachable,
419                num_nodes_u32,
420                &mut clause_counts,
421                &mut lit_counts,
422            ),
423        )
424    }
425    .map_err(|e| XlogError::Kernel(format!("cnf_count_clauses failed: {}", e)))?;
426
427    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
428    unsafe {
429        capture_last_fn.clone().launch(
430            LaunchConfig {
431                grid_dim: (1, 1, 1),
432                block_dim: (1, 1, 1),
433                shared_mem_bytes: 0,
434            },
435            (
436                &node_needs_var,
437                &clause_counts,
438                &lit_counts,
439                num_nodes_u32,
440                &mut node_last,
441                &mut clause_last,
442                &mut lit_last,
443            ),
444        )
445    }
446    .map_err(|e| XlogError::Kernel(format!("cnf_capture_last_counts failed: {}", e)))?;
447
448    provider.exclusive_scan_u32_inplace(&mut node_needs_var, num_nodes_u32)?;
449    provider.exclusive_scan_u32_inplace(&mut clause_counts, num_nodes_u32)?;
450    provider.exclusive_scan_u32_inplace(&mut lit_counts, num_nodes_u32)?;
451
452    let mut totals_params: Vec<*mut c_void> = vec![
453        (&node_needs_var).as_kernel_param(),
454        (&clause_counts).as_kernel_param(),
455        (&lit_counts).as_kernel_param(),
456        (&node_last).as_kernel_param(),
457        (&clause_last).as_kernel_param(),
458        (&lit_last).as_kernel_param(),
459        num_nodes_u32.as_kernel_param(),
460        (&base_node).as_kernel_param(),
461        var_cap.as_kernel_param(),
462        clause_cap.as_kernel_param(),
463        lit_cap.as_kernel_param(),
464        (&d_num_vars).as_kernel_param(),
465        (&d_num_clauses).as_kernel_param(),
466        (&d_num_lits).as_kernel_param(),
467    ];
468    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
469    unsafe {
470        compute_totals_fn.clone().launch(
471            LaunchConfig {
472                grid_dim: (1, 1, 1),
473                block_dim: (1, 1, 1),
474                shared_mem_bytes: 0,
475            },
476            &mut totals_params,
477        )
478    }
479    .map_err(|e| XlogError::Kernel(format!("cnf_compute_totals failed: {}", e)))?;
480
481    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
482    unsafe {
483        assign_node_var_fn.clone().launch(
484            LaunchConfig {
485                grid_dim: (grid_nodes, 1, 1),
486                block_dim: (block, 1, 1),
487                shared_mem_bytes: 0,
488            },
489            (
490                &pir.node_type,
491                &pir.leaf_id,
492                &reachable,
493                &node_needs_var,
494                &base_node,
495                num_nodes_u32,
496                leaf_cap,
497                &leaf_var,
498                &mut node_var,
499            ),
500        )
501    }
502    .map_err(|e| XlogError::Kernel(format!("cnf_assign_node_var failed: {}", e)))?;
503
504    let mut emit_params: Vec<*mut c_void> = vec![
505        (&pir.node_type).as_kernel_param(),
506        (&pir.child_offsets).as_kernel_param(),
507        (&pir.children).as_kernel_param(),
508        (&pir.leaf_id).as_kernel_param(),
509        (&pir.decision_var).as_kernel_param(),
510        (&pir.decision_child_false).as_kernel_param(),
511        (&pir.decision_child_true).as_kernel_param(),
512        (&reachable).as_kernel_param(),
513        (&node_var).as_kernel_param(),
514        (&leaf_var).as_kernel_param(),
515        (&choice_var).as_kernel_param(),
516        (&clause_counts).as_kernel_param(),
517        (&lit_counts).as_kernel_param(),
518        num_nodes_u32.as_kernel_param(),
519        leaf_cap.as_kernel_param(),
520        choice_cap.as_kernel_param(),
521        (&d_offsets).as_kernel_param(),
522        (&d_lits).as_kernel_param(),
523    ];
524
525    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
526    unsafe {
527        emit_clauses_fn.clone().launch(
528            LaunchConfig {
529                grid_dim: (grid_nodes, 1, 1),
530                block_dim: (block, 1, 1),
531                shared_mem_bytes: 0,
532            },
533            &mut emit_params,
534        )
535    }
536    .map_err(|e| XlogError::Kernel(format!("cnf_emit_clauses failed: {}", e)))?;
537
538    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
539    unsafe {
540        set_clause_end_fn.clone().launch(
541            LaunchConfig {
542                grid_dim: (1, 1, 1),
543                block_dim: (1, 1, 1),
544                shared_mem_bytes: 0,
545            },
546            (&mut d_offsets, &d_num_clauses, &d_num_lits),
547        )
548    }
549    .map_err(|e| XlogError::Kernel(format!("cnf_set_clause_end failed: {}", e)))?;
550    // No device synchronize: returns device-resident CNF; same-stream ordering suffices.
551    Ok(GpuCnfEncoding {
552        cnf: GpuCnf {
553            var_cap,
554            clause_cap,
555            lit_cap,
556            num_vars: d_num_vars,
557            num_clauses: d_num_clauses,
558            num_lits: d_num_lits,
559            clause_offsets: d_offsets,
560            literals: d_lits,
561        },
562        vars: GpuCnfVarTables {
563            node_var,
564            leaf_var,
565            choice_var,
566            max_var: var_cap,
567        },
568        decision_var_limit,
569    })
570}