Skip to main content

xlog_prob/
gpu.rs

1//! GPU evaluator for XGCF circuits (CUDA).
2
3use std::ffi::c_void;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::provider::{
9    arith_kernels, d4_kernels, filter_kernels, ARITH_MODULE, D4_MODULE, FILTER_MODULE,
10};
11use xlog_cuda::{circuit_kernels, AsKernelParam, CudaKernelProvider, LaunchAsync, CIRCUIT_MODULE};
12
13use crate::compilation::gpu_d4::exclusive_scan_u32_inplace;
14use crate::xgcf::{Xgcf, XgcfNodeType};
15
16/// Device-resident circuit buffers produced by the GPU compiler.
17///
18/// This matches the XGCF node layout used by `kernels/circuit.cu` and the SAT verifier CNF encoder
19/// in `kernels/sat.cu`.
20pub struct GpuCircuitBuilder {
21    pub node_type: TrackedCudaSlice<u8>,
22    pub child_offsets: TrackedCudaSlice<u32>,
23    pub child_indices: TrackedCudaSlice<u32>,
24    pub lit: TrackedCudaSlice<i32>,
25    pub decision_var: TrackedCudaSlice<u32>,
26    pub decision_child_false: TrackedCudaSlice<u32>,
27    pub decision_child_true: TrackedCudaSlice<u32>,
28}
29
30/// Device layout metadata for XGCF construction.
31pub struct GpuCircuitLayout {
32    pub num_nodes: u32,
33    pub num_edges: u32,
34    pub num_levels: u32,
35    pub level_offsets: TrackedCudaSlice<u32>,
36    pub level_nodes: TrackedCudaSlice<u32>,
37    pub root: u32,
38    pub max_var: u32,
39    pub num_nodes_device: Option<TrackedCudaSlice<u32>>,
40    pub num_edges_device: Option<TrackedCudaSlice<u32>>,
41}
42
43pub struct GpuXgcf {
44    node_type: TrackedCudaSlice<u8>,
45    child_offsets: TrackedCudaSlice<u32>,
46    child_indices: TrackedCudaSlice<u32>,
47    lit: TrackedCudaSlice<i32>,
48    decision_var: TrackedCudaSlice<u32>,
49    decision_child_false: TrackedCudaSlice<u32>,
50    decision_child_true: TrackedCudaSlice<u32>,
51    level_nodes: TrackedCudaSlice<u32>,
52    level_offsets: TrackedCudaSlice<u32>,
53    /// Optional host mirror for efficient per-level launch sizing when the circuit was uploaded
54    /// from host (`GpuXgcf::upload`). GPU-native compilation paths do not populate this.
55    level_offsets_host: Option<Vec<u32>>,
56    node_cap: u32,
57    edge_cap: u32,
58    num_levels: u32,
59    root: u32,
60    max_var: u32,
61    meta_num_nodes: TrackedCudaSlice<u32>,
62    meta_num_edges: TrackedCudaSlice<u32>,
63    var_log_true: TrackedCudaSlice<f64>,
64    var_log_false: TrackedCudaSlice<f64>,
65    values: TrackedCudaSlice<f64>,
66    adj: TrackedCudaSlice<f64>,
67    grad_true: TrackedCudaSlice<f64>,
68    grad_false: TrackedCudaSlice<f64>,
69    free_var_mask: Option<TrackedCudaSlice<u8>>,
70}
71
72fn checked_gpu_u32_len(context: &str, len: usize) -> Result<u32> {
73    u32::try_from(len)
74        .map_err(|_| XlogError::Compilation(format!("{context} exceeds u32::MAX: {len}")))
75}
76
77fn checked_gpu_len_add_one(context: &str, len: usize) -> Result<usize> {
78    len.checked_add(1)
79        .ok_or_else(|| XlogError::Compilation(format!("{context} length overflow")))
80}
81
82fn checked_gpu_launch_blocks(context: &str, item_count: usize, block_size: u32) -> Result<u32> {
83    let item_count = u32::try_from(item_count).map_err(|_| {
84        XlogError::Kernel(format!(
85            "{context} launch item count exceeds u32::MAX: {item_count}"
86        ))
87    })?;
88    item_count
89        .checked_add(block_size - 1)
90        .map(|rounded| rounded / block_size)
91        .ok_or_else(|| XlogError::Kernel(format!("{context} launch grid overflow")))
92}
93
94fn checked_host_level_width(level_offsets: &[u32], level: usize) -> Result<usize> {
95    let start = level_offsets[level];
96    let end = level_offsets[level + 1];
97    if end < start {
98        return Err(XlogError::Compilation(format!(
99            "XGCF invariant violation: level_offsets decrease at level {} ({} > {})",
100            level, start, end
101        )));
102    }
103    Ok((end - start) as usize)
104}
105
106fn validate_xgcf_for_gpu_upload(circuit: &Xgcf) -> Result<(u32, u32, u32)> {
107    let n = circuit.node_type.len();
108    if n == 0 {
109        return Err(XlogError::Compilation(
110            "GPU XGCF upload requires at least one node".to_string(),
111        ));
112    }
113    let node_count = checked_gpu_u32_len("GPU XGCF node count", n)?;
114    let child_offsets_len = checked_gpu_len_add_one("GPU XGCF child_offsets", n)?;
115    if circuit.child_offsets.len() != child_offsets_len {
116        return Err(XlogError::Compilation(format!(
117            "XGCF invariant violation: child_offsets len {} != num_nodes+1 ({})",
118            circuit.child_offsets.len(),
119            child_offsets_len
120        )));
121    }
122    if circuit.lit.len() != n
123        || circuit.decision_var.len() != n
124        || circuit.decision_child_false.len() != n
125        || circuit.decision_child_true.len() != n
126    {
127        return Err(XlogError::Compilation(
128            "XGCF invariant violation: per-node arrays length mismatch".to_string(),
129        ));
130    }
131
132    let edge_count = checked_gpu_u32_len("GPU XGCF edge count", circuit.child_indices.len())?;
133    let mut previous_offset = 0u32;
134    for (idx, &offset) in circuit.child_offsets.iter().enumerate() {
135        if offset < previous_offset {
136            return Err(XlogError::Compilation(format!(
137                "XGCF invariant violation: child_offsets decrease at index {} ({} > {})",
138                idx, previous_offset, offset
139            )));
140        }
141        if offset > edge_count {
142            return Err(XlogError::Compilation(format!(
143                "XGCF invariant violation: child_offsets[{}] {} exceeds child_indices len {}",
144                idx, offset, edge_count
145            )));
146        }
147        previous_offset = offset;
148    }
149    if previous_offset != edge_count {
150        return Err(XlogError::Compilation(format!(
151            "XGCF invariant violation: final child offset {} != child_indices len {}",
152            previous_offset, edge_count
153        )));
154    }
155    for (edge, &child) in circuit.child_indices.iter().enumerate() {
156        if child >= node_count {
157            return Err(XlogError::Compilation(format!(
158                "XGCF invariant violation: child_indices[{}] {} out of bounds (num_nodes={})",
159                edge, child, node_count
160            )));
161        }
162    }
163
164    for (idx, &ty) in circuit.node_type.iter().enumerate() {
165        match ty {
166            XgcfNodeType::Const0 | XgcfNodeType::Const1 => {}
167            XgcfNodeType::Lit => {
168                if circuit.lit[idx] == 0 {
169                    return Err(XlogError::Compilation(format!(
170                        "XGCF invariant violation: LIT node {} has lit=0",
171                        idx
172                    )));
173                }
174            }
175            XgcfNodeType::And | XgcfNodeType::Or => {
176                if circuit.child_offsets[idx] == circuit.child_offsets[idx + 1] {
177                    return Err(XlogError::Compilation(format!(
178                        "XGCF invariant violation: {:?} node {} has no children",
179                        ty, idx
180                    )));
181                }
182            }
183            XgcfNodeType::Decision => {
184                if circuit.decision_var[idx] == 0 {
185                    return Err(XlogError::Compilation(format!(
186                        "XGCF invariant violation: DECISION node {} has var=0",
187                        idx
188                    )));
189                }
190                if circuit.decision_child_false[idx] >= node_count {
191                    return Err(XlogError::Compilation(format!(
192                        "XGCF invariant violation: DECISION node {} false child {} out of bounds",
193                        idx, circuit.decision_child_false[idx]
194                    )));
195                }
196                if circuit.decision_child_true[idx] >= node_count {
197                    return Err(XlogError::Compilation(format!(
198                        "XGCF invariant violation: DECISION node {} true child {} out of bounds",
199                        idx, circuit.decision_child_true[idx]
200                    )));
201                }
202            }
203        }
204    }
205
206    if circuit.level_offsets.is_empty() || circuit.level_offsets[0] != 0 {
207        return Err(XlogError::Compilation(
208            "XGCF invariant violation: level_offsets must start at 0".to_string(),
209        ));
210    }
211    let level_nodes_len =
212        checked_gpu_u32_len("GPU XGCF level_nodes len", circuit.level_nodes.len())?;
213    let mut previous_level_offset = 0u32;
214    for (idx, &offset) in circuit.level_offsets.iter().enumerate() {
215        if offset < previous_level_offset {
216            return Err(XlogError::Compilation(format!(
217                "XGCF invariant violation: level_offsets decrease at index {} ({} > {})",
218                idx, previous_level_offset, offset
219            )));
220        }
221        if offset > level_nodes_len {
222            return Err(XlogError::Compilation(format!(
223                "XGCF invariant violation: level_offsets[{}] {} exceeds level_nodes len {}",
224                idx, offset, level_nodes_len
225            )));
226        }
227        previous_level_offset = offset;
228    }
229    if previous_level_offset != level_nodes_len {
230        return Err(XlogError::Compilation(format!(
231            "XGCF invariant violation: level_offsets last {} != level_nodes.len {}",
232            previous_level_offset, level_nodes_len
233        )));
234    }
235    for (idx, &node) in circuit.level_nodes.iter().enumerate() {
236        if node >= node_count {
237            return Err(XlogError::Compilation(format!(
238                "XGCF invariant violation: level_nodes[{}] {} out of bounds (num_nodes={})",
239                idx, node, node_count
240            )));
241        }
242    }
243    let num_levels_usize = circuit.level_offsets.len() - 1;
244    let num_levels = checked_gpu_u32_len("GPU XGCF level count", num_levels_usize)?;
245    if num_levels == 0 {
246        return Err(XlogError::Compilation(
247            "GPU XGCF upload requires at least one level".to_string(),
248        ));
249    }
250
251    if circuit.roots.len() != 1 {
252        return Err(XlogError::Compilation(format!(
253            "GPU XGCF eval expects exactly 1 root, got {}",
254            circuit.roots.len()
255        )));
256    }
257    if circuit.roots[0] >= node_count {
258        return Err(XlogError::Compilation(format!(
259            "XGCF invariant violation: root {} out of bounds (num_nodes={})",
260            circuit.roots[0], node_count
261        )));
262    }
263
264    Ok((node_count, edge_count, num_levels))
265}
266
267impl GpuXgcf {
268    pub fn from_device(
269        builder: GpuCircuitBuilder,
270        layout: GpuCircuitLayout,
271        provider: &CudaKernelProvider,
272    ) -> Result<GpuXgcf> {
273        if layout.num_nodes == 0 {
274            return Err(XlogError::Compilation(
275                "GpuXgcf::from_device requires num_nodes > 0".to_string(),
276            ));
277        }
278        if layout.root >= layout.num_nodes {
279            return Err(XlogError::Compilation(format!(
280                "GpuXgcf::from_device: root {} out of bounds (num_nodes={})",
281                layout.root, layout.num_nodes
282            )));
283        }
284        if layout.num_levels == 0 {
285            return Err(XlogError::Compilation(
286                "GpuXgcf::from_device requires num_levels > 0".to_string(),
287            ));
288        }
289
290        let num_nodes = layout.num_nodes as usize;
291        let num_edges = layout.num_edges as usize;
292        let node_cap = builder.node_type.len();
293        if num_nodes == 0 || num_nodes > node_cap {
294            return Err(XlogError::Compilation(
295                "GpuXgcf::from_device: num_nodes out of bounds".to_string(),
296            ));
297        }
298        let child_offsets_len =
299            checked_gpu_len_add_one("GpuXgcf::from_device child_offsets", node_cap)?;
300        if builder.child_offsets.len() != child_offsets_len
301            || builder.lit.len() != node_cap
302            || builder.decision_var.len() != node_cap
303            || builder.decision_child_false.len() != node_cap
304            || builder.decision_child_true.len() != node_cap
305        {
306            return Err(XlogError::Compilation(
307                "GpuXgcf::from_device: circuit buffer length mismatch".to_string(),
308            ));
309        }
310        if num_edges > builder.child_indices.len() {
311            return Err(XlogError::Compilation(
312                "GpuXgcf::from_device: num_edges out of bounds".to_string(),
313            ));
314        }
315
316        let num_levels = layout.num_levels as usize;
317        let level_offsets_len =
318            checked_gpu_len_add_one("GpuXgcf::from_device level_offsets", num_levels)?;
319        if layout.level_offsets.len() != level_offsets_len {
320            return Err(XlogError::Compilation(format!(
321                "GpuXgcf::from_device: level_offsets len {} != num_levels+1 ({})",
322                layout.level_offsets.len(),
323                level_offsets_len
324            )));
325        }
326        if layout.level_nodes.len() < num_nodes {
327            return Err(XlogError::Compilation(format!(
328                "GpuXgcf::from_device: level_nodes len {} < num_nodes ({})",
329                layout.level_nodes.len(),
330                num_nodes
331            )));
332        }
333
334        let memory = provider.memory();
335
336        let weights_len = (layout.max_var as usize) + 1;
337        let var_log_true = memory.alloc::<f64>(weights_len)?;
338        let var_log_false = memory.alloc::<f64>(weights_len)?;
339        let values = memory.alloc::<f64>(num_nodes)?;
340        let adj = memory.alloc::<f64>(num_nodes)?;
341        let grad_true = memory.alloc::<f64>(weights_len)?;
342        let grad_false = memory.alloc::<f64>(weights_len)?;
343
344        let meta_num_nodes = match layout.num_nodes_device {
345            Some(meta) => meta,
346            None => {
347                let mut meta = memory.alloc::<u32>(1)?;
348                provider
349                    .htod_launch_metadata_sync_copy_into(&[layout.num_nodes], &mut meta)
350                    .map_err(|e| {
351                        XlogError::Kernel(format!("Failed to upload num_nodes meta: {}", e))
352                    })?;
353                meta
354            }
355        };
356        let meta_num_edges = match layout.num_edges_device {
357            Some(meta) => meta,
358            None => {
359                let mut meta = memory.alloc::<u32>(1)?;
360                provider
361                    .htod_launch_metadata_sync_copy_into(&[layout.num_edges], &mut meta)
362                    .map_err(|e| {
363                        XlogError::Kernel(format!("Failed to upload num_edges meta: {}", e))
364                    })?;
365                meta
366            }
367        };
368
369        Ok(Self {
370            node_type: builder.node_type,
371            child_offsets: builder.child_offsets,
372            child_indices: builder.child_indices,
373            lit: builder.lit,
374            decision_var: builder.decision_var,
375            decision_child_false: builder.decision_child_false,
376            decision_child_true: builder.decision_child_true,
377            level_nodes: layout.level_nodes,
378            level_offsets: layout.level_offsets,
379            level_offsets_host: None,
380            node_cap: layout.num_nodes,
381            edge_cap: layout.num_edges,
382            num_levels: layout.num_levels,
383            root: layout.root,
384            max_var: layout.max_var,
385            meta_num_nodes,
386            meta_num_edges,
387            var_log_true,
388            var_log_false,
389            values,
390            adj,
391            grad_true,
392            grad_false,
393            free_var_mask: None,
394        })
395    }
396
397    /// GPU-native smoothing pass for random variables.
398    ///
399    /// Returns a new device-resident circuit that is smooth w.r.t. `random_var_list`.
400    /// This method performs no device->host data-plane transfers and traps on capacity overflow.
401    pub fn smooth_random_vars_device(
402        &self,
403        provider: &CudaKernelProvider,
404        random_var_list: &TrackedCudaSlice<u32>,
405        random_var_count: u32,
406        smooth_node_cap: u32,
407        smooth_edge_cap: u32,
408    ) -> Result<GpuXgcf> {
409        if smooth_node_cap == 0 || smooth_edge_cap == 0 {
410            return Err(XlogError::Compilation(
411                "GPU smoothing requires non-zero node/edge caps".to_string(),
412            ));
413        }
414
415        let num_nodes = self.node_cap;
416        if num_nodes == 0 {
417            return Err(XlogError::Compilation(
418                "GPU smoothing: num_nodes must be > 0".to_string(),
419            ));
420        }
421        if self.child_offsets.len() < (num_nodes as usize + 1) {
422            return Err(XlogError::Compilation(
423                "GPU smoothing: child_offsets len mismatch".to_string(),
424            ));
425        }
426        let num_edges = self.edge_cap;
427        if num_edges == 0 {
428            return Err(XlogError::Compilation(
429                "GPU smoothing: num_edges must be > 0".to_string(),
430            ));
431        }
432
433        let list_len = u32::try_from(random_var_list.len()).map_err(|_| {
434            XlogError::Compilation("GPU smoothing: random var list len exceeds u32".to_string())
435        })?;
436        let num_random_vars = random_var_count;
437        if num_random_vars > list_len {
438            return Err(XlogError::Compilation(format!(
439                "GPU smoothing: random var count {} exceeds list len {}",
440                num_random_vars, list_len
441            )));
442        }
443
444        let base_node = 2u32.checked_add(num_random_vars).ok_or_else(|| {
445            XlogError::Compilation("GPU smoothing: base node overflow".to_string())
446        })?;
447        let base_nodes = (base_node as u64)
448            .checked_add(num_nodes as u64)
449            .ok_or_else(|| {
450                XlogError::Compilation("GPU smoothing: base node overflow".to_string())
451            })?;
452        if base_nodes > smooth_node_cap as u64 {
453            return Err(XlogError::Compilation(format!(
454                "GPU smoothing: base nodes {} exceed smooth_node_cap {}",
455                base_nodes, smooth_node_cap
456            )));
457        }
458
459        let words_per_support = num_random_vars.div_ceil(32).max(1);
460
461        let support_len = (num_nodes as u64)
462            .checked_mul(words_per_support as u64)
463            .and_then(|v| usize::try_from(v).ok())
464            .ok_or_else(|| {
465                XlogError::Compilation("GPU smoothing: support size overflow".to_string())
466            })?;
467
468        let dec_entries = (num_nodes as u64)
469            .checked_mul(2)
470            .and_then(|v| usize::try_from(v).ok())
471            .ok_or_else(|| {
472                XlogError::Compilation("GPU smoothing: decision array overflow".to_string())
473            })?;
474        let dec_entries_u32 = u32::try_from(dec_entries).map_err(|_| {
475            XlogError::Compilation("GPU smoothing: decision entries exceed u32".to_string())
476        })?;
477
478        let device = provider.device().inner();
479        let memory = provider.memory();
480        let block_size: u32 = 256;
481
482        let map_len = (self.max_var as usize)
483            .checked_add(1)
484            .ok_or_else(|| XlogError::Compilation("GPU smoothing: max_var overflow".to_string()))?;
485        let map_len_u32 = u32::try_from(map_len).map_err(|_| {
486            XlogError::Compilation("GPU smoothing: random map len exceeds u32".to_string())
487        })?;
488        let mut d_random_map = memory.alloc::<u32>(map_len)?;
489        if map_len > 0 {
490            let fill_const = device
491                .get_func(FILTER_MODULE, filter_kernels::FILL_U32_CONST)
492                .ok_or_else(|| XlogError::Kernel("fill_u32_const kernel not found".to_string()))?;
493            let grid = map_len_u32.div_ceil(block_size);
494            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
495            unsafe {
496                fill_const.clone().launch(
497                    LaunchConfig {
498                        grid_dim: (grid, 1, 1),
499                        block_dim: (block_size, 1, 1),
500                        shared_mem_bytes: 0,
501                    },
502                    (&mut d_random_map, map_len_u32, u32::MAX),
503                )
504            }
505            .map_err(|e| XlogError::Kernel(format!("fill_u32_const failed: {}", e)))?;
506        }
507        if num_random_vars > 0 {
508            let map_kernel = device
509                .get_func(FILTER_MODULE, filter_kernels::RANDOM_VAR_TO_BIT_FROM_LIST)
510                .ok_or_else(|| {
511                    XlogError::Kernel("random_var_to_bit_from_list kernel not found".to_string())
512                })?;
513            let grid = num_random_vars.div_ceil(block_size);
514            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
515            unsafe {
516                map_kernel.clone().launch(
517                    LaunchConfig {
518                        grid_dim: (grid, 1, 1),
519                        block_dim: (block_size, 1, 1),
520                        shared_mem_bytes: 0,
521                    },
522                    (
523                        random_var_list,
524                        num_random_vars,
525                        map_len_u32,
526                        &mut d_random_map,
527                    ),
528                )
529            }
530            .map_err(|e| XlogError::Kernel(format!("random_var_to_bit_from_list failed: {}", e)))?;
531        }
532
533        let mut support = memory.alloc::<u32>(support_len)?;
534        device
535            .memset_zeros(&mut support)
536            .map_err(|e| XlogError::Kernel(format!("Failed to zero support: {}", e)))?;
537
538        let support_kernel = device
539            .get_func(D4_MODULE, d4_kernels::D4_SUPPORT_LEVEL)
540            .ok_or_else(|| XlogError::Kernel("d4_support_level kernel not found".to_string()))?;
541
542        let num_levels = self.num_levels as usize;
543        let random_map_len = map_len_u32;
544        for level in 0..num_levels {
545            let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
546                Some(off) => checked_host_level_width(off, level)?,
547                None => self.level_nodes.len(),
548            };
549            if num_level_nodes == 0 {
550                continue;
551            }
552            let num_blocks =
553                checked_gpu_launch_blocks("d4_support_level", num_level_nodes, block_size)?;
554            let config = LaunchConfig {
555                grid_dim: (num_blocks, 1, 1),
556                block_dim: (block_size, 1, 1),
557                shared_mem_bytes: 0,
558            };
559            let level_u32 = level as u32;
560            let mut params: Vec<*mut c_void> = vec![
561                (&self.node_type).as_kernel_param(),
562                (&self.child_offsets).as_kernel_param(),
563                (&self.child_indices).as_kernel_param(),
564                (&self.lit).as_kernel_param(),
565                (&self.decision_var).as_kernel_param(),
566                (&self.decision_child_false).as_kernel_param(),
567                (&self.decision_child_true).as_kernel_param(),
568                (&self.level_nodes).as_kernel_param(),
569                (&self.level_offsets).as_kernel_param(),
570                level_u32.as_kernel_param(),
571                (&d_random_map).as_kernel_param(),
572                random_map_len.as_kernel_param(),
573                words_per_support.as_kernel_param(),
574                (&support).as_kernel_param(),
575            ];
576            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
577            unsafe { support_kernel.clone().launch(config, &mut params) }
578                .map_err(|e| XlogError::Kernel(format!("d4_support_level failed: {}", e)))?;
579        }
580
581        if num_random_vars > 0 {
582            let root_kernel = device
583                .get_func(D4_MODULE, d4_kernels::D4_SUPPORT_SET_ROOT_BITS)
584                .ok_or_else(|| {
585                    XlogError::Kernel("d4_support_set_root_bits kernel not found".to_string())
586                })?;
587            let num_words = num_random_vars.div_ceil(32);
588            let grid = num_words.div_ceil(block_size);
589            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
590            unsafe {
591                root_kernel.clone().launch(
592                    LaunchConfig {
593                        grid_dim: (grid, 1, 1),
594                        block_dim: (block_size, 1, 1),
595                        shared_mem_bytes: 0,
596                    },
597                    (self.root, num_random_vars, words_per_support, &mut support),
598                )
599            }
600            .map_err(|e| XlogError::Kernel(format!("d4_support_set_root_bits failed: {}", e)))?;
601        }
602
603        let mut wrap_prefix_or = memory.alloc::<u32>(num_edges as usize)?;
604        let mut wrap_missing_or = memory.alloc::<u32>(num_edges as usize)?;
605        let mut wrap_prefix_dec = memory.alloc::<u32>(dec_entries)?;
606        let mut wrap_missing_dec = memory.alloc::<u32>(dec_entries)?;
607
608        device
609            .memset_zeros(&mut wrap_prefix_or)
610            .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_prefix_or: {}", e)))?;
611        device
612            .memset_zeros(&mut wrap_missing_or)
613            .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_missing_or: {}", e)))?;
614        device
615            .memset_zeros(&mut wrap_prefix_dec)
616            .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_prefix_dec: {}", e)))?;
617        device
618            .memset_zeros(&mut wrap_missing_dec)
619            .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_missing_dec: {}", e)))?;
620
621        let mut out_edge_counts = memory.alloc::<u32>(smooth_node_cap as usize)?;
622        device
623            .memset_zeros(&mut out_edge_counts)
624            .map_err(|e| XlogError::Kernel(format!("Failed to zero edge_counts: {}", e)))?;
625
626        let count_kernel = device
627            .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_COUNT)
628            .ok_or_else(|| XlogError::Kernel("d4_smooth_count kernel not found".to_string()))?;
629        let num_blocks = num_nodes.div_ceil(block_size);
630        let mut params: Vec<*mut c_void> = vec![
631            (&self.node_type).as_kernel_param(),
632            (&self.child_offsets).as_kernel_param(),
633            (&self.child_indices).as_kernel_param(),
634            (&self.decision_var).as_kernel_param(),
635            (&self.decision_child_false).as_kernel_param(),
636            (&self.decision_child_true).as_kernel_param(),
637            (&self.meta_num_nodes).as_kernel_param(),
638            (&support).as_kernel_param(),
639            words_per_support.as_kernel_param(),
640            (&d_random_map).as_kernel_param(),
641            random_map_len.as_kernel_param(),
642            (&wrap_prefix_or).as_kernel_param(),
643            (&wrap_missing_or).as_kernel_param(),
644            (&wrap_prefix_dec).as_kernel_param(),
645            (&wrap_missing_dec).as_kernel_param(),
646            (&out_edge_counts).as_kernel_param(),
647            base_node.as_kernel_param(),
648            smooth_node_cap.as_kernel_param(),
649        ];
650        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
651        unsafe {
652            count_kernel.clone().launch(
653                LaunchConfig {
654                    grid_dim: (num_blocks, 1, 1),
655                    block_dim: (block_size, 1, 1),
656                    shared_mem_bytes: 0,
657                },
658                &mut params,
659            )
660        }
661        .map_err(|e| XlogError::Kernel(format!("d4_smooth_count failed: {}", e)))?;
662
663        exclusive_scan_u32_inplace(provider, &mut wrap_prefix_or, num_edges)?;
664        exclusive_scan_u32_inplace(provider, &mut wrap_prefix_dec, dec_entries_u32)?;
665
666        let mut wrap_counts = memory.alloc::<u32>(3)?;
667        device
668            .memset_zeros(&mut wrap_counts)
669            .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_counts: {}", e)))?;
670
671        let counts_kernel = device
672            .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_WRAPPER_COUNTS)
673            .ok_or_else(|| {
674                XlogError::Kernel("d4_smooth_wrapper_counts kernel not found".to_string())
675            })?;
676        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
677        unsafe {
678            counts_kernel.clone().launch(
679                LaunchConfig {
680                    grid_dim: (1, 1, 1),
681                    block_dim: (1, 1, 1),
682                    shared_mem_bytes: 0,
683                },
684                (
685                    &wrap_prefix_or,
686                    &wrap_missing_or,
687                    num_edges,
688                    &wrap_prefix_dec,
689                    &wrap_missing_dec,
690                    dec_entries_u32,
691                    base_node,
692                    &self.meta_num_nodes,
693                    u32::MAX,
694                    &mut wrap_counts,
695                ),
696            )
697        }
698        .map_err(|e| XlogError::Kernel(format!("d4_smooth_wrapper_counts failed: {}", e)))?;
699
700        let wrap_or_kernel = device
701            .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_WRAPPER_EDGE_COUNTS_OR)
702            .ok_or_else(|| {
703                XlogError::Kernel("d4_smooth_wrapper_edge_counts_or kernel not found".to_string())
704            })?;
705        if num_edges > 0 {
706            let num_blocks = num_edges.div_ceil(block_size);
707            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
708            unsafe {
709                wrap_or_kernel.clone().launch(
710                    LaunchConfig {
711                        grid_dim: (num_blocks, 1, 1),
712                        block_dim: (block_size, 1, 1),
713                        shared_mem_bytes: 0,
714                    },
715                    (
716                        &wrap_prefix_or,
717                        &wrap_missing_or,
718                        num_edges,
719                        base_node,
720                        &self.meta_num_nodes,
721                        smooth_node_cap,
722                        &mut out_edge_counts,
723                    ),
724                )
725            }
726            .map_err(|e| {
727                XlogError::Kernel(format!("d4_smooth_wrapper_edge_counts_or failed: {}", e))
728            })?;
729        }
730
731        let wrap_dec_kernel = device
732            .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_WRAPPER_EDGE_COUNTS_DEC)
733            .ok_or_else(|| {
734                XlogError::Kernel("d4_smooth_wrapper_edge_counts_dec kernel not found".to_string())
735            })?;
736        if dec_entries > 0 {
737            let num_blocks = dec_entries_u32.div_ceil(block_size);
738            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
739            unsafe {
740                wrap_dec_kernel.clone().launch(
741                    LaunchConfig {
742                        grid_dim: (num_blocks, 1, 1),
743                        block_dim: (block_size, 1, 1),
744                        shared_mem_bytes: 0,
745                    },
746                    (
747                        &wrap_prefix_dec,
748                        &wrap_missing_dec,
749                        dec_entries_u32,
750                        base_node,
751                        &self.meta_num_nodes,
752                        &wrap_counts,
753                        smooth_node_cap,
754                        &mut out_edge_counts,
755                    ),
756                )
757            }
758            .map_err(|e| {
759                XlogError::Kernel(format!("d4_smooth_wrapper_edge_counts_dec failed: {}", e))
760            })?;
761        }
762
763        let mut out_child_offsets = memory.alloc::<u32>((smooth_node_cap as usize) + 1)?;
764        device
765            .memset_zeros(&mut out_child_offsets)
766            .map_err(|e| XlogError::Kernel(format!("Failed to zero child_offsets: {}", e)))?;
767        if smooth_node_cap > 0 {
768            device
769                .dtod_copy(
770                    &out_edge_counts,
771                    &mut out_child_offsets.slice_mut(0..smooth_node_cap as usize),
772                )
773                .map_err(|e| XlogError::Kernel(format!("Failed to copy edge_counts: {}", e)))?;
774        }
775        let child_scan_len = smooth_node_cap.checked_add(1).ok_or_else(|| {
776            XlogError::Compilation("GPU smoothing: child offset scan overflow".to_string())
777        })?;
778        exclusive_scan_u32_inplace(provider, &mut out_child_offsets, child_scan_len)?;
779
780        let edge_cap_check = device
781            .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_CHECK_EDGE_CAP)
782            .ok_or_else(|| {
783                XlogError::Kernel("d4_smooth_check_edge_cap kernel not found".to_string())
784            })?;
785        let mut meta_num_nodes = memory.alloc::<u32>(1)?;
786        let mut meta_num_edges = memory.alloc::<u32>(1)?;
787        device
788            .memset_zeros(&mut meta_num_nodes)
789            .map_err(|e| XlogError::Kernel(format!("Failed to zero smooth num_nodes: {}", e)))?;
790        device
791            .memset_zeros(&mut meta_num_edges)
792            .map_err(|e| XlogError::Kernel(format!("Failed to zero smooth num_edges: {}", e)))?;
793        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
794        unsafe {
795            edge_cap_check.clone().launch(
796                LaunchConfig {
797                    grid_dim: (1, 1, 1),
798                    block_dim: (1, 1, 1),
799                    shared_mem_bytes: 0,
800                },
801                (
802                    &out_child_offsets,
803                    smooth_node_cap,
804                    smooth_edge_cap,
805                    &wrap_counts,
806                    &mut meta_num_nodes,
807                    &mut meta_num_edges,
808                ),
809            )
810        }
811        .map_err(|e| XlogError::Kernel(format!("d4_smooth_check_edge_cap failed: {}", e)))?;
812
813        let mut out_node_type = memory.alloc::<u8>(smooth_node_cap as usize)?;
814        let mut out_child_indices = memory.alloc::<u32>(smooth_edge_cap as usize)?;
815        let mut out_lit = memory.alloc::<i32>(smooth_node_cap as usize)?;
816        let mut out_decision_var = memory.alloc::<u32>(smooth_node_cap as usize)?;
817        let mut out_decision_child_false = memory.alloc::<u32>(smooth_node_cap as usize)?;
818        let mut out_decision_child_true = memory.alloc::<u32>(smooth_node_cap as usize)?;
819        let mut out_node_level = memory.alloc::<u32>(smooth_node_cap as usize)?;
820
821        device
822            .memset_zeros(&mut out_node_type)
823            .map_err(|e| XlogError::Kernel(format!("Failed to zero node_type: {}", e)))?;
824        device
825            .memset_zeros(&mut out_child_indices)
826            .map_err(|e| XlogError::Kernel(format!("Failed to zero child_indices: {}", e)))?;
827        device
828            .memset_zeros(&mut out_lit)
829            .map_err(|e| XlogError::Kernel(format!("Failed to zero lit: {}", e)))?;
830        device
831            .memset_zeros(&mut out_decision_var)
832            .map_err(|e| XlogError::Kernel(format!("Failed to zero decision_var: {}", e)))?;
833        device
834            .memset_zeros(&mut out_decision_child_false)
835            .map_err(|e| {
836                XlogError::Kernel(format!("Failed to zero decision_child_false: {}", e))
837            })?;
838        device
839            .memset_zeros(&mut out_decision_child_true)
840            .map_err(|e| XlogError::Kernel(format!("Failed to zero decision_child_true: {}", e)))?;
841        device
842            .memset_zeros(&mut out_node_level)
843            .map_err(|e| XlogError::Kernel(format!("Failed to zero node_level: {}", e)))?;
844
845        let init_kernel = device
846            .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_INIT_NODES)
847            .ok_or_else(|| {
848                XlogError::Kernel("d4_smooth_init_nodes kernel not found".to_string())
849            })?;
850        let init_blocks = checked_gpu_launch_blocks(
851            "d4_smooth_init_nodes",
852            num_random_vars.max(1) as usize,
853            block_size,
854        )?;
855        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
856        unsafe {
857            init_kernel.clone().launch(
858                LaunchConfig {
859                    grid_dim: (init_blocks, 1, 1),
860                    block_dim: (block_size, 1, 1),
861                    shared_mem_bytes: 0,
862                },
863                (
864                    random_var_list,
865                    num_random_vars,
866                    smooth_node_cap,
867                    &mut out_node_type,
868                    &mut out_lit,
869                    &mut out_decision_var,
870                    &mut out_decision_child_false,
871                    &mut out_decision_child_true,
872                    &mut out_node_level,
873                ),
874            )
875        }
876        .map_err(|e| XlogError::Kernel(format!("d4_smooth_init_nodes failed: {}", e)))?;
877
878        let num_levels_out = self
879            .num_levels
880            .checked_mul(2)
881            .and_then(|levels| levels.checked_add(4))
882            .ok_or_else(|| {
883                XlogError::Compilation("GPU smoothing output level count overflow".to_string())
884            })?;
885        let num_levels_out_usize = num_levels_out as usize;
886        let level_offsets_len =
887            checked_gpu_len_add_one("GPU smoothing level offsets", num_levels_out_usize)?;
888
889        let emit_kernel = device
890            .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_EMIT_LEVEL)
891            .ok_or_else(|| {
892                XlogError::Kernel("d4_smooth_emit_level kernel not found".to_string())
893            })?;
894        for level in 0..num_levels {
895            let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
896                Some(off) => checked_host_level_width(off, level)?,
897                None => self.level_nodes.len(),
898            };
899            if num_level_nodes == 0 {
900                continue;
901            }
902            let num_blocks =
903                checked_gpu_launch_blocks("xgcf_smooth_forward", num_level_nodes, block_size)?;
904            let level_u32 = level as u32;
905            let mut params: Vec<*mut c_void> = vec![
906                (&self.node_type).as_kernel_param(),
907                (&self.child_offsets).as_kernel_param(),
908                (&self.child_indices).as_kernel_param(),
909                (&self.lit).as_kernel_param(),
910                (&self.decision_var).as_kernel_param(),
911                (&self.decision_child_false).as_kernel_param(),
912                (&self.decision_child_true).as_kernel_param(),
913                (&self.level_nodes).as_kernel_param(),
914                (&self.level_offsets).as_kernel_param(),
915                level_u32.as_kernel_param(),
916                (&support).as_kernel_param(),
917                words_per_support.as_kernel_param(),
918                (&wrap_prefix_or).as_kernel_param(),
919                (&wrap_missing_or).as_kernel_param(),
920                (&wrap_prefix_dec).as_kernel_param(),
921                (&wrap_missing_dec).as_kernel_param(),
922                base_node.as_kernel_param(),
923                (&self.meta_num_nodes).as_kernel_param(),
924                (&wrap_counts).as_kernel_param(),
925                num_random_vars.as_kernel_param(),
926                num_levels_out.as_kernel_param(),
927                (&out_node_type).as_kernel_param(),
928                (&out_child_offsets).as_kernel_param(),
929                (&out_child_indices).as_kernel_param(),
930                (&out_lit).as_kernel_param(),
931                (&out_decision_var).as_kernel_param(),
932                (&out_decision_child_false).as_kernel_param(),
933                (&out_decision_child_true).as_kernel_param(),
934                (&out_node_level).as_kernel_param(),
935            ];
936            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
937            unsafe {
938                emit_kernel.clone().launch(
939                    LaunchConfig {
940                        grid_dim: (num_blocks, 1, 1),
941                        block_dim: (block_size, 1, 1),
942                        shared_mem_bytes: 0,
943                    },
944                    &mut params,
945                )
946            }
947            .map_err(|e| XlogError::Kernel(format!("d4_smooth_emit_level failed: {}", e)))?;
948        }
949
950        let mut level_counts = memory.alloc::<u32>(num_levels_out_usize)?;
951        let mut level_offsets = memory.alloc::<u32>(level_offsets_len)?;
952        let mut level_cursors = memory.alloc::<u32>(num_levels_out_usize)?;
953        let mut level_nodes = memory.alloc::<u32>(smooth_node_cap as usize)?;
954
955        device
956            .memset_zeros(&mut level_counts)
957            .map_err(|e| XlogError::Kernel(format!("Failed to zero level_counts: {}", e)))?;
958        device
959            .memset_zeros(&mut level_offsets)
960            .map_err(|e| XlogError::Kernel(format!("Failed to zero level_offsets: {}", e)))?;
961        device
962            .memset_zeros(&mut level_cursors)
963            .map_err(|e| XlogError::Kernel(format!("Failed to zero level_cursors: {}", e)))?;
964        device
965            .memset_zeros(&mut level_nodes)
966            .map_err(|e| XlogError::Kernel(format!("Failed to zero level_nodes: {}", e)))?;
967
968        let mut compile_needed = memory.alloc::<u32>(1)?;
969        provider
970            .htod_launch_metadata_sync_copy_into(&[1u32], &mut compile_needed)
971            .map_err(|e| XlogError::Kernel(format!("Failed to upload compile_needed: {}", e)))?;
972
973        let levelize_counts = device
974            .get_func(D4_MODULE, d4_kernels::D4_LEVELIZE_COUNTS)
975            .ok_or_else(|| XlogError::Kernel("d4_levelize_counts kernel not found".to_string()))?;
976        let num_blocks =
977            checked_gpu_launch_blocks("d4_smooth_levelize", smooth_node_cap as usize, block_size)?;
978        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
979        unsafe {
980            levelize_counts.clone().launch(
981                LaunchConfig {
982                    grid_dim: (num_blocks, 1, 1),
983                    block_dim: (block_size, 1, 1),
984                    shared_mem_bytes: 0,
985                },
986                (
987                    &compile_needed,
988                    &out_node_level,
989                    &meta_num_nodes,
990                    num_levels_out,
991                    &mut level_counts,
992                ),
993            )
994        }
995        .map_err(|e| XlogError::Kernel(format!("d4_levelize_counts failed: {}", e)))?;
996
997        device
998            .dtod_copy(
999                &level_counts,
1000                &mut level_offsets.slice_mut(0..num_levels_out_usize),
1001            )
1002            .map_err(|e| XlogError::Kernel(format!("Failed to copy level_counts: {}", e)))?;
1003        let level_scan_len = num_levels_out.checked_add(1).ok_or_else(|| {
1004            XlogError::Compilation("GPU smoothing: level offset scan overflow".to_string())
1005        })?;
1006        exclusive_scan_u32_inplace(provider, &mut level_offsets, level_scan_len)?;
1007
1008        let levelize_emit = device
1009            .get_func(D4_MODULE, d4_kernels::D4_LEVELIZE_EMIT)
1010            .ok_or_else(|| XlogError::Kernel("d4_levelize_emit kernel not found".to_string()))?;
1011        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1012        unsafe {
1013            levelize_emit.clone().launch(
1014                LaunchConfig {
1015                    grid_dim: (num_blocks, 1, 1),
1016                    block_dim: (block_size, 1, 1),
1017                    shared_mem_bytes: 0,
1018                },
1019                (
1020                    &compile_needed,
1021                    &out_node_level,
1022                    &meta_num_nodes,
1023                    num_levels_out,
1024                    &level_offsets,
1025                    &mut level_cursors,
1026                    &mut level_nodes,
1027                ),
1028            )
1029        }
1030        .map_err(|e| XlogError::Kernel(format!("d4_levelize_emit failed: {}", e)))?;
1031
1032        // No device synchronize: result is device buffers used by subsequent GPU ops.
1033        let builder = GpuCircuitBuilder {
1034            node_type: out_node_type,
1035            child_offsets: out_child_offsets,
1036            child_indices: out_child_indices,
1037            lit: out_lit,
1038            decision_var: out_decision_var,
1039            decision_child_false: out_decision_child_false,
1040            decision_child_true: out_decision_child_true,
1041        };
1042        let layout = GpuCircuitLayout {
1043            num_nodes: smooth_node_cap,
1044            num_edges: smooth_edge_cap,
1045            num_levels: num_levels_out,
1046            level_offsets,
1047            level_nodes,
1048            root: base_node + self.root,
1049            max_var: self.max_var,
1050            num_nodes_device: Some(meta_num_nodes),
1051            num_edges_device: Some(meta_num_edges),
1052        };
1053
1054        GpuXgcf::from_device(builder, layout, provider)
1055    }
1056
1057    pub fn upload(provider: &CudaKernelProvider, circuit: &Xgcf) -> Result<Self> {
1058        let (node_cap, edge_cap, num_levels) = validate_xgcf_for_gpu_upload(circuit)?;
1059
1060        let memory = provider.memory().clone();
1061
1062        let n = circuit.node_type.len();
1063        let mut host_node_type: Vec<u8> = Vec::with_capacity(n);
1064        for &ty in &circuit.node_type {
1065            host_node_type.push(ty as u8);
1066        }
1067
1068        let mut max_var: u32 = 0;
1069        for (&ty, &lit) in circuit.node_type.iter().zip(circuit.lit.iter()) {
1070            if ty == XgcfNodeType::Lit && lit != 0 {
1071                max_var = max_var.max(lit.unsigned_abs());
1072            }
1073        }
1074        for &var in &circuit.decision_var {
1075            max_var = max_var.max(var);
1076        }
1077
1078        let mut d_node_type = memory.alloc::<u8>(n)?;
1079        provider
1080            .htod_sync_copy_into_tracked(&host_node_type, &mut d_node_type)
1081            .map_err(|e| XlogError::Kernel(format!("Failed to upload circuit node_type: {}", e)))?;
1082
1083        let mut d_child_offsets = memory.alloc::<u32>(circuit.child_offsets.len())?;
1084        provider
1085            .htod_sync_copy_into_tracked(&circuit.child_offsets, &mut d_child_offsets)
1086            .map_err(|e| {
1087                XlogError::Kernel(format!("Failed to upload circuit child_offsets: {}", e))
1088            })?;
1089
1090        let mut d_child_indices = memory.alloc::<u32>(circuit.child_indices.len())?;
1091        provider
1092            .htod_sync_copy_into_tracked(&circuit.child_indices, &mut d_child_indices)
1093            .map_err(|e| {
1094                XlogError::Kernel(format!("Failed to upload circuit child_indices: {}", e))
1095            })?;
1096
1097        let mut d_lit = memory.alloc::<i32>(circuit.lit.len())?;
1098        provider
1099            .htod_sync_copy_into_tracked(&circuit.lit, &mut d_lit)
1100            .map_err(|e| XlogError::Kernel(format!("Failed to upload circuit lit: {}", e)))?;
1101
1102        let mut d_decision_var = memory.alloc::<u32>(circuit.decision_var.len())?;
1103        provider
1104            .htod_sync_copy_into_tracked(&circuit.decision_var, &mut d_decision_var)
1105            .map_err(|e| {
1106                XlogError::Kernel(format!("Failed to upload circuit decision_var: {}", e))
1107            })?;
1108
1109        let mut d_decision_child_false = memory.alloc::<u32>(circuit.decision_child_false.len())?;
1110        provider
1111            .htod_sync_copy_into_tracked(&circuit.decision_child_false, &mut d_decision_child_false)
1112            .map_err(|e| {
1113                XlogError::Kernel(format!(
1114                    "Failed to upload circuit decision_child_false: {}",
1115                    e
1116                ))
1117            })?;
1118
1119        let mut d_decision_child_true = memory.alloc::<u32>(circuit.decision_child_true.len())?;
1120        provider
1121            .htod_sync_copy_into_tracked(&circuit.decision_child_true, &mut d_decision_child_true)
1122            .map_err(|e| {
1123                XlogError::Kernel(format!(
1124                    "Failed to upload circuit decision_child_true: {}",
1125                    e
1126                ))
1127            })?;
1128
1129        let mut d_level_nodes = memory.alloc::<u32>(circuit.level_nodes.len())?;
1130        provider
1131            .htod_sync_copy_into_tracked(&circuit.level_nodes, &mut d_level_nodes)
1132            .map_err(|e| {
1133                XlogError::Kernel(format!("Failed to upload circuit level_nodes: {}", e))
1134            })?;
1135
1136        let mut d_level_offsets = memory.alloc::<u32>(circuit.level_offsets.len())?;
1137        provider
1138            .htod_sync_copy_into_tracked(&circuit.level_offsets, &mut d_level_offsets)
1139            .map_err(|e| {
1140                XlogError::Kernel(format!("Failed to upload circuit level_offsets: {}", e))
1141            })?;
1142
1143        let weights_len = (max_var as usize) + 1;
1144        let var_log_true = memory.alloc::<f64>(weights_len)?;
1145        let var_log_false = memory.alloc::<f64>(weights_len)?;
1146        let values = memory.alloc::<f64>(n)?;
1147        let adj = memory.alloc::<f64>(n)?;
1148        let grad_true = memory.alloc::<f64>(weights_len)?;
1149        let grad_false = memory.alloc::<f64>(weights_len)?;
1150        let mut meta_num_nodes = memory.alloc::<u32>(1)?;
1151        provider
1152            .htod_launch_metadata_sync_copy_into(&[node_cap], &mut meta_num_nodes)
1153            .map_err(|e| XlogError::Kernel(format!("Failed to upload num_nodes meta: {}", e)))?;
1154        let mut meta_num_edges = memory.alloc::<u32>(1)?;
1155        provider
1156            .htod_launch_metadata_sync_copy_into(&[edge_cap], &mut meta_num_edges)
1157            .map_err(|e| XlogError::Kernel(format!("Failed to upload num_edges meta: {}", e)))?;
1158
1159        Ok(Self {
1160            node_type: d_node_type,
1161            child_offsets: d_child_offsets,
1162            child_indices: d_child_indices,
1163            lit: d_lit,
1164            decision_var: d_decision_var,
1165            decision_child_false: d_decision_child_false,
1166            decision_child_true: d_decision_child_true,
1167            level_nodes: d_level_nodes,
1168            level_offsets: d_level_offsets,
1169            level_offsets_host: Some(circuit.level_offsets.clone()),
1170            node_cap,
1171            edge_cap,
1172            num_levels,
1173            root: circuit.roots[0],
1174            max_var,
1175            meta_num_nodes,
1176            meta_num_edges,
1177            var_log_true,
1178            var_log_false,
1179            values,
1180            adj,
1181            grad_true,
1182            grad_false,
1183            free_var_mask: None,
1184        })
1185    }
1186
1187    pub fn max_var(&self) -> u32 {
1188        self.max_var
1189    }
1190
1191    /// Root node id of the circuit (XGCF requires exactly one root for evaluation/verification).
1192    pub fn root(&self) -> u32 {
1193        self.root
1194    }
1195
1196    /// Capacity (upper bound) for XGCF nodes in the circuit buffers.
1197    pub fn num_nodes(&self) -> usize {
1198        self.node_cap as usize
1199    }
1200
1201    /// Capacity (upper bound) for XGCF edges in the circuit buffers.
1202    pub fn num_edges(&self) -> usize {
1203        self.edge_cap as usize
1204    }
1205
1206    /// Number of topological levels in the circuit.
1207    pub fn num_levels(&self) -> u32 {
1208        self.num_levels
1209    }
1210
1211    /// Device-resident actual node count (len = 1).
1212    pub fn num_nodes_device(&self) -> &TrackedCudaSlice<u32> {
1213        &self.meta_num_nodes
1214    }
1215
1216    /// Device-resident actual edge count (len = 1).
1217    pub fn num_edges_device(&self) -> &TrackedCudaSlice<u32> {
1218        &self.meta_num_edges
1219    }
1220
1221    /// Device-resident level -> node index mapping (len = num_nodes).
1222    pub fn level_nodes(&self) -> &TrackedCudaSlice<u32> {
1223        &self.level_nodes
1224    }
1225
1226    /// Device-resident offsets for each level (len = num_levels + 1).
1227    pub fn level_offsets(&self) -> &TrackedCudaSlice<u32> {
1228        &self.level_offsets
1229    }
1230
1231    /// Device-resident node type tags (see `XgcfNodeType`).
1232    pub fn node_type(&self) -> &TrackedCudaSlice<u8> {
1233        &self.node_type
1234    }
1235
1236    /// Device-resident CSR child offsets for AND/OR nodes (len = num_nodes + 1).
1237    pub fn child_offsets(&self) -> &TrackedCudaSlice<u32> {
1238        &self.child_offsets
1239    }
1240
1241    /// Device-resident CSR child indices for AND/OR nodes.
1242    pub fn child_indices(&self) -> &TrackedCudaSlice<u32> {
1243        &self.child_indices
1244    }
1245
1246    /// Device-resident literals for LIT nodes (signed DIMACS, 1-based var ids).
1247    pub fn lit(&self) -> &TrackedCudaSlice<i32> {
1248        &self.lit
1249    }
1250
1251    /// Device-resident decision var ids for DECISION nodes (0 for non-decision).
1252    pub fn decision_var(&self) -> &TrackedCudaSlice<u32> {
1253        &self.decision_var
1254    }
1255
1256    pub fn decision_child_false(&self) -> &TrackedCudaSlice<u32> {
1257        &self.decision_child_false
1258    }
1259
1260    pub fn decision_child_true(&self) -> &TrackedCudaSlice<u32> {
1261        &self.decision_child_true
1262    }
1263
1264    /// Device-resident per-node values buffer (log-space). Written by forward pass.
1265    pub fn values(&self) -> &TrackedCudaSlice<f64> {
1266        &self.values
1267    }
1268
1269    /// Device-resident gradient buffer for ln(true-weight) per CNF variable.
1270    pub fn grad_true(&self) -> &TrackedCudaSlice<f64> {
1271        &self.grad_true
1272    }
1273
1274    /// Device-resident gradient buffer for ln(false-weight) per CNF variable.
1275    pub fn grad_false(&self) -> &TrackedCudaSlice<f64> {
1276        &self.grad_false
1277    }
1278
1279    /// Device-resident log(true-weight) table.
1280    pub fn var_log_true(&self) -> &TrackedCudaSlice<f64> {
1281        &self.var_log_true
1282    }
1283
1284    /// Device-resident log(false-weight) table.
1285    pub fn var_log_false(&self) -> &TrackedCudaSlice<f64> {
1286        &self.var_log_false
1287    }
1288
1289    /// Mutable access to device-resident log(true-weight) table.
1290    #[allow(dead_code)] // reserved: individual mutable weight access for exact_gpu path
1291    pub(crate) fn var_log_true_mut(&mut self) -> &mut TrackedCudaSlice<f64> {
1292        &mut self.var_log_true
1293    }
1294
1295    /// Mutable access to device-resident log(false-weight) table.
1296    #[allow(dead_code)] // reserved: individual mutable weight access for exact_gpu path
1297    pub(crate) fn var_log_false_mut(&mut self) -> &mut TrackedCudaSlice<f64> {
1298        &mut self.var_log_false
1299    }
1300
1301    /// Mutable access to both log-weight tables (true/false) at once.
1302    ///
1303    /// This is useful when passing both slices to a single CUDA kernel launch, avoiding
1304    /// overlapping mutable borrows of `self`.
1305    pub fn var_log_weights_mut(
1306        &mut self,
1307    ) -> (&mut TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) {
1308        (&mut self.var_log_true, &mut self.var_log_false)
1309    }
1310
1311    /// Attach a device-resident free-variable mask (length = max_var + 1).
1312    pub fn set_free_var_mask_device(&mut self, mask: TrackedCudaSlice<u8>) -> Result<()> {
1313        if mask.len() != self.var_log_true.len() {
1314            return Err(XlogError::Compilation(format!(
1315                "GPU free-var mask len {} != weights len {}",
1316                mask.len(),
1317                self.var_log_true.len()
1318            )));
1319        }
1320        self.free_var_mask = Some(mask);
1321        Ok(())
1322    }
1323
1324    /// Upload a host free-variable mask (length = max_var + 1).
1325    #[allow(dead_code)] // reserved: host-side mask upload for testing/diagnostics
1326    pub(crate) fn set_free_var_mask_from_host(
1327        &mut self,
1328        provider: &CudaKernelProvider,
1329        mask: &[u8],
1330    ) -> Result<()> {
1331        if mask.len() != self.var_log_true.len() {
1332            return Err(XlogError::Compilation(format!(
1333                "GPU free-var mask len {} != weights len {}",
1334                mask.len(),
1335                self.var_log_true.len()
1336            )));
1337        }
1338        let memory = provider.memory();
1339        let mut d_mask = memory.alloc::<u8>(mask.len())?;
1340        provider
1341            .htod_sync_copy_into_tracked(mask, &mut d_mask)
1342            .map_err(|e| XlogError::Kernel(format!("Failed to upload free_var_mask: {}", e)))?;
1343        self.free_var_mask = Some(d_mask);
1344        Ok(())
1345    }
1346
1347    /// Upload a host weight table into the device-resident `var_log_true/var_log_false` buffers.
1348    ///
1349    /// This is intended for one-time initialization of static weights (evidence + non-neural facts).
1350    /// Neural fast-path updates should overwrite only the relevant subset on GPU.
1351    pub fn set_base_weights(
1352        &mut self,
1353        provider: &CudaKernelProvider,
1354        var_log_weights: &[(f64, f64)],
1355    ) -> Result<()> {
1356        let weights_len = (self.max_var as usize) + 1;
1357        if var_log_weights.len() < weights_len {
1358            return Err(XlogError::Compilation(format!(
1359                "GPU XGCF weights init expects weight table len >= {}, got {}",
1360                weights_len,
1361                var_log_weights.len()
1362            )));
1363        }
1364
1365        let mut host_true: Vec<f64> = Vec::with_capacity(weights_len);
1366        let mut host_false: Vec<f64> = Vec::with_capacity(weights_len);
1367        for &(t, f) in &var_log_weights[..weights_len] {
1368            host_true.push(t);
1369            host_false.push(f);
1370        }
1371
1372        provider
1373            .htod_sync_copy_into_tracked(&host_true, &mut self.var_log_true)
1374            .map_err(|e| XlogError::Kernel(format!("Failed to upload log_true weights: {}", e)))?;
1375        provider
1376            .htod_sync_copy_into_tracked(&host_false, &mut self.var_log_false)
1377            .map_err(|e| XlogError::Kernel(format!("Failed to upload log_false weights: {}", e)))?;
1378
1379        Ok(())
1380    }
1381
1382    /// Evaluate logZ on the device using the currently loaded weights and write it into `out_log_z`.
1383    ///
1384    /// This method performs no device->host transfers.
1385    pub fn eval_log_wmc_device_inplace(
1386        &mut self,
1387        provider: &CudaKernelProvider,
1388        out_log_z: &mut TrackedCudaSlice<f64>,
1389    ) -> Result<()> {
1390        if out_log_z.len() != 1 {
1391            return Err(XlogError::Compilation(format!(
1392                "GPU device logZ output len {} != 1",
1393                out_log_z.len()
1394            )));
1395        }
1396
1397        let device = provider.device().inner();
1398        let func = device
1399            .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FORWARD_LEVEL)
1400            .ok_or_else(|| XlogError::Kernel("xgcf_forward_level kernel not found".to_string()))?;
1401
1402        let block_size: u32 = 256;
1403        let num_levels: usize = self.num_levels as usize;
1404        for level in 0..num_levels {
1405            let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
1406                Some(off) => checked_host_level_width(off, level)?,
1407                None => self.level_nodes.len(),
1408            };
1409            if num_level_nodes == 0 {
1410                continue;
1411            }
1412
1413            let num_blocks =
1414                checked_gpu_launch_blocks("xgcf_forward_level", num_level_nodes, block_size)?;
1415            let config = LaunchConfig {
1416                grid_dim: (num_blocks, 1, 1),
1417                block_dim: (block_size, 1, 1),
1418                shared_mem_bytes: 0,
1419            };
1420            let level_u32: u32 = level as u32;
1421
1422            let mut params: Vec<*mut c_void> = vec![
1423                (&self.node_type).as_kernel_param(),
1424                (&self.child_offsets).as_kernel_param(),
1425                (&self.child_indices).as_kernel_param(),
1426                (&self.lit).as_kernel_param(),
1427                (&self.decision_var).as_kernel_param(),
1428                (&self.decision_child_false).as_kernel_param(),
1429                (&self.decision_child_true).as_kernel_param(),
1430                (&self.level_nodes).as_kernel_param(),
1431                (&self.level_offsets).as_kernel_param(),
1432                level_u32.as_kernel_param(),
1433                (&self.var_log_true).as_kernel_param(),
1434                (&self.var_log_false).as_kernel_param(),
1435                (&self.values).as_kernel_param(),
1436            ];
1437
1438            // SAFETY: xgcf_forward_level(...) writes values for the provided level nodes.
1439            unsafe { func.clone().launch(config, &mut params) }
1440                .map_err(|e| XlogError::Kernel(format!("xgcf_forward_level failed: {}", e)))?;
1441        }
1442
1443        self.apply_free_var_correction(provider, true, false)?;
1444
1445        let root_idx = self.root as usize;
1446        let root_view = self.values.slice(root_idx..(root_idx + 1));
1447        device
1448            .dtod_copy(&root_view, out_log_z)
1449            .map_err(|e| XlogError::Kernel(format!("Failed to copy device logZ: {}", e)))?;
1450
1451        // No device synchronize: callers read back with a synchronous host copy
1452        // or pass the result to subsequent GPU operations (same-stream ordering).
1453        Ok(())
1454    }
1455
1456    /// Evaluate logZ on the device and write it into `out_log_z` (uploads weights from host).
1457    pub fn eval_log_wmc_device_into(
1458        &mut self,
1459        provider: &CudaKernelProvider,
1460        var_log_weights: &[(f64, f64)],
1461        out_log_z: &mut TrackedCudaSlice<f64>,
1462    ) -> Result<()> {
1463        self.set_base_weights(provider, var_log_weights)?;
1464        self.eval_log_wmc_device_inplace(provider, out_log_z)
1465    }
1466
1467    /// Evaluate logZ on the device and return a device-resident scalar (uploads weights from host).
1468    pub fn eval_log_wmc_device(
1469        &mut self,
1470        provider: &CudaKernelProvider,
1471        var_log_weights: &[(f64, f64)],
1472    ) -> Result<TrackedCudaSlice<f64>> {
1473        let memory = provider.memory();
1474        let mut out_log_z = memory.alloc::<f64>(1)?;
1475        self.eval_log_wmc_device_into(provider, var_log_weights, &mut out_log_z)?;
1476        Ok(out_log_z)
1477    }
1478
1479    fn apply_free_var_correction(
1480        &mut self,
1481        provider: &CudaKernelProvider,
1482        apply_log_z: bool,
1483        apply_grads: bool,
1484    ) -> Result<()> {
1485        let Some(mask) = self.free_var_mask.as_ref() else {
1486            return Ok(());
1487        };
1488
1489        if mask.len() != self.var_log_true.len() {
1490            return Err(XlogError::Compilation(format!(
1491                "GPU free-var mask len {} != weights len {}",
1492                mask.len(),
1493                self.var_log_true.len()
1494            )));
1495        }
1496
1497        let n = u32::try_from(mask.len())
1498            .map_err(|_| XlogError::Compilation("GPU free-var mask length overflow".to_string()))?;
1499        if n == 0 {
1500            return Ok(());
1501        }
1502
1503        let device = provider.device().inner();
1504        let block_dim = 256u32;
1505        let grid_dim = n.div_ceil(block_dim);
1506
1507        if apply_grads {
1508            let apply_grad = device
1509                .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FREE_VAR_APPLY_GRAD)
1510                .ok_or_else(|| {
1511                    XlogError::Kernel("xgcf_free_var_apply_grad kernel not found".to_string())
1512                })?;
1513            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1514            unsafe {
1515                apply_grad.clone().launch(
1516                    LaunchConfig {
1517                        grid_dim: (grid_dim, 1, 1),
1518                        block_dim: (block_dim, 1, 1),
1519                        shared_mem_bytes: 0,
1520                    },
1521                    (
1522                        mask,
1523                        &self.var_log_true,
1524                        &self.var_log_false,
1525                        n,
1526                        &mut self.grad_true,
1527                        &mut self.grad_false,
1528                    ),
1529                )
1530            }
1531            .map_err(|e| XlogError::Kernel(format!("xgcf_free_var_apply_grad failed: {}", e)))?;
1532        }
1533
1534        if apply_log_z {
1535            let reduce_stage = device
1536                .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FREE_VAR_REDUCE_STAGE)
1537                .ok_or_else(|| {
1538                    XlogError::Kernel("xgcf_free_var_reduce_stage kernel not found".to_string())
1539                })?;
1540            let add_scalar = device
1541                .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_ADD_SCALAR)
1542                .ok_or_else(|| XlogError::Kernel("xgcf_add_scalar kernel not found".to_string()))?;
1543
1544            let memory = provider.memory();
1545            let mut buf_a = memory.alloc::<f64>(mask.len())?;
1546            let mut buf_b = memory.alloc::<f64>(mask.len())?;
1547
1548            let mut stage_n = n;
1549            let mut stage0 = true;
1550            let mut output_is_a = true;
1551            loop {
1552                let out_len = stage_n.div_ceil(2);
1553                let stage_grid = out_len.div_ceil(block_dim);
1554
1555                let (in_buf, out_buf): (&TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) =
1556                    if output_is_a {
1557                        (&buf_b, &mut buf_a)
1558                    } else {
1559                        (&buf_a, &mut buf_b)
1560                    };
1561                let mode = if stage0 { 0u32 } else { 1u32 };
1562
1563                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1564                unsafe {
1565                    reduce_stage.clone().launch(
1566                        LaunchConfig {
1567                            grid_dim: (stage_grid, 1, 1),
1568                            block_dim: (block_dim, 1, 1),
1569                            shared_mem_bytes: 0,
1570                        },
1571                        (
1572                            mask,
1573                            &self.var_log_true,
1574                            &self.var_log_false,
1575                            in_buf,
1576                            stage_n,
1577                            mode,
1578                            out_buf,
1579                        ),
1580                    )
1581                }
1582                .map_err(|e| {
1583                    XlogError::Kernel(format!("xgcf_free_var_reduce_stage failed: {}", e))
1584                })?;
1585
1586                if out_len == 1 {
1587                    let result_buf = if output_is_a { &buf_a } else { &buf_b };
1588                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1589                    unsafe {
1590                        add_scalar.clone().launch(
1591                            LaunchConfig {
1592                                grid_dim: (1, 1, 1),
1593                                block_dim: (1, 1, 1),
1594                                shared_mem_bytes: 0,
1595                            },
1596                            (&mut self.values, self.root, result_buf),
1597                        )
1598                    }
1599                    .map_err(|e| XlogError::Kernel(format!("xgcf_add_scalar failed: {}", e)))?;
1600                    break;
1601                }
1602
1603                stage_n = out_len;
1604                stage0 = false;
1605                output_is_a = !output_is_a;
1606            }
1607        }
1608
1609        Ok(())
1610    }
1611
1612    /// Evaluate the circuit and populate `grad_true/grad_false` on the device (no host reads).
1613    ///
1614    /// Preconditions:
1615    /// - `var_log_true/var_log_false` contain the current weights on device.
1616    /// - Caller may read back results for testing/debugging, but this API performs no dtoh transfers.
1617    pub fn eval_grads_inplace(&mut self, provider: &CudaKernelProvider) -> Result<()> {
1618        let device = provider.device().inner();
1619
1620        // Forward pass (identical to eval_log_wmc, minus weight upload and root readback).
1621        let func = device
1622            .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FORWARD_LEVEL)
1623            .ok_or_else(|| XlogError::Kernel("xgcf_forward_level kernel not found".to_string()))?;
1624
1625        let block_size: u32 = 256;
1626        let num_levels: usize = self.num_levels as usize;
1627        for level in 0..num_levels {
1628            let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
1629                Some(off) => checked_host_level_width(off, level)?,
1630                None => self.level_nodes.len(),
1631            };
1632            if num_level_nodes == 0 {
1633                continue;
1634            }
1635
1636            let num_blocks =
1637                checked_gpu_launch_blocks("xgcf_forward_level", num_level_nodes, block_size)?;
1638            let config = LaunchConfig {
1639                grid_dim: (num_blocks, 1, 1),
1640                block_dim: (block_size, 1, 1),
1641                shared_mem_bytes: 0,
1642            };
1643            let level_u32: u32 = level as u32;
1644
1645            let mut params: Vec<*mut c_void> = vec![
1646                (&self.node_type).as_kernel_param(),
1647                (&self.child_offsets).as_kernel_param(),
1648                (&self.child_indices).as_kernel_param(),
1649                (&self.lit).as_kernel_param(),
1650                (&self.decision_var).as_kernel_param(),
1651                (&self.decision_child_false).as_kernel_param(),
1652                (&self.decision_child_true).as_kernel_param(),
1653                (&self.level_nodes).as_kernel_param(),
1654                (&self.level_offsets).as_kernel_param(),
1655                level_u32.as_kernel_param(),
1656                (&self.var_log_true).as_kernel_param(),
1657                (&self.var_log_false).as_kernel_param(),
1658                (&self.values).as_kernel_param(),
1659            ];
1660
1661            // SAFETY: xgcf_forward_level(...) writes values for the provided level nodes.
1662            unsafe { func.clone().launch(config, &mut params) }
1663                .map_err(|e| XlogError::Kernel(format!("xgcf_forward_level failed: {}", e)))?;
1664        }
1665
1666        // Backward pass buffers.
1667        device
1668            .memset_zeros(&mut self.adj)
1669            .map_err(|e| XlogError::Kernel(format!("Failed to zero adj buffer: {}", e)))?;
1670        device
1671            .memset_zeros(&mut self.grad_true)
1672            .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_true buffer: {}", e)))?;
1673        device
1674            .memset_zeros(&mut self.grad_false)
1675            .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_false buffer: {}", e)))?;
1676
1677        // Set root adjoint to 1.0 via GPU kernel (avoid host copy).
1678        let root_idx = self.root as usize;
1679        let mut root_adj_view = self.adj.slice_mut(root_idx..(root_idx + 1));
1680        let fill_const = device
1681            .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_F64)
1682            .ok_or_else(|| {
1683                XlogError::Kernel("arith_fill_const_f64 kernel not found".to_string())
1684            })?;
1685        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1686        unsafe {
1687            fill_const.clone().launch(
1688                LaunchConfig {
1689                    grid_dim: (1, 1, 1),
1690                    block_dim: (1, 1, 1),
1691                    shared_mem_bytes: 0,
1692                },
1693                (1.0_f64, 1u32, &mut root_adj_view),
1694            )
1695        }
1696        .map_err(|e| XlogError::Kernel(format!("arith_fill_const_f64 failed: {}", e)))?;
1697
1698        let propagate = device
1699            .get_func(
1700                CIRCUIT_MODULE,
1701                circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
1702            )
1703            .ok_or_else(|| {
1704                XlogError::Kernel("xgcf_backward_level_propagate kernel not found".to_string())
1705            })?;
1706        let decision_grad = device
1707            .get_func(
1708                CIRCUIT_MODULE,
1709                circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
1710            )
1711            .ok_or_else(|| {
1712                XlogError::Kernel("xgcf_backward_level_decision_grad kernel not found".to_string())
1713            })?;
1714        let lit_grad = device
1715            .get_func(
1716                CIRCUIT_MODULE,
1717                circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
1718            )
1719            .ok_or_else(|| {
1720                XlogError::Kernel("xgcf_backward_level_lit_grad kernel not found".to_string())
1721            })?;
1722
1723        let num_levels: usize = self.num_levels as usize;
1724        for level in (0..num_levels).rev() {
1725            let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
1726                Some(off) => checked_host_level_width(off, level)?,
1727                None => self.level_nodes.len(),
1728            };
1729            if num_level_nodes == 0 {
1730                continue;
1731            }
1732
1733            let num_blocks =
1734                checked_gpu_launch_blocks("xgcf_backward_level", num_level_nodes, block_size)?;
1735            let config = LaunchConfig {
1736                grid_dim: (num_blocks, 1, 1),
1737                block_dim: (block_size, 1, 1),
1738                shared_mem_bytes: 0,
1739            };
1740            let level_u32: u32 = level as u32;
1741
1742            let mut params: Vec<*mut c_void> = vec![
1743                (&self.node_type).as_kernel_param(),
1744                (&self.child_offsets).as_kernel_param(),
1745                (&self.child_indices).as_kernel_param(),
1746                (&self.decision_var).as_kernel_param(),
1747                (&self.decision_child_false).as_kernel_param(),
1748                (&self.decision_child_true).as_kernel_param(),
1749                (&self.level_nodes).as_kernel_param(),
1750                (&self.level_offsets).as_kernel_param(),
1751                level_u32.as_kernel_param(),
1752                (&self.var_log_true).as_kernel_param(),
1753                (&self.var_log_false).as_kernel_param(),
1754                (&self.values).as_kernel_param(),
1755                (&self.adj).as_kernel_param(),
1756            ];
1757
1758            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1759            unsafe { propagate.clone().launch(config, &mut params) }.map_err(|e| {
1760                XlogError::Kernel(format!("xgcf_backward_level_propagate failed: {}", e))
1761            })?;
1762
1763            let mut params: Vec<*mut c_void> = vec![
1764                (&self.node_type).as_kernel_param(),
1765                (&self.decision_var).as_kernel_param(),
1766                (&self.decision_child_false).as_kernel_param(),
1767                (&self.decision_child_true).as_kernel_param(),
1768                (&self.level_nodes).as_kernel_param(),
1769                (&self.level_offsets).as_kernel_param(),
1770                level_u32.as_kernel_param(),
1771                (&self.var_log_true).as_kernel_param(),
1772                (&self.var_log_false).as_kernel_param(),
1773                (&self.values).as_kernel_param(),
1774                (&self.adj).as_kernel_param(),
1775                (&self.grad_true).as_kernel_param(),
1776                (&self.grad_false).as_kernel_param(),
1777            ];
1778
1779            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1780            unsafe { decision_grad.clone().launch(config, &mut params) }.map_err(|e| {
1781                XlogError::Kernel(format!("xgcf_backward_level_decision_grad failed: {}", e))
1782            })?;
1783
1784            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1785            unsafe {
1786                lit_grad.clone().launch(
1787                    config,
1788                    (
1789                        &self.node_type,
1790                        &self.lit,
1791                        &self.level_nodes,
1792                        &self.level_offsets,
1793                        level_u32,
1794                        &self.adj,
1795                        &self.grad_true,
1796                        &self.grad_false,
1797                    ),
1798                )
1799            }
1800            .map_err(|e| {
1801                XlogError::Kernel(format!("xgcf_backward_level_lit_grad failed: {}", e))
1802            })?;
1803        }
1804
1805        self.apply_free_var_correction(provider, true, true)?;
1806        // No device synchronize: callers batch multiple eval/backward calls
1807        // before syncing at the query boundary.
1808        Ok(())
1809    }
1810
1811    #[cfg(feature = "host-io")]
1812    pub fn eval_log_wmc(
1813        &mut self,
1814        provider: &CudaKernelProvider,
1815        var_log_weights: &[(f64, f64)],
1816    ) -> Result<f64> {
1817        let device = provider.device().inner();
1818        let mut out_log_z = provider.memory().alloc::<f64>(1)?;
1819        self.eval_log_wmc_device_into(provider, var_log_weights, &mut out_log_z)?;
1820
1821        let mut host = [0.0_f64];
1822        device
1823            .dtoh_sync_copy_into(&out_log_z, &mut host)
1824            .map_err(|e| XlogError::Kernel(format!("Failed to read circuit root value: {}", e)))?;
1825        Ok(host[0])
1826    }
1827
1828    #[cfg(feature = "host-io")]
1829    pub fn eval_log_wmc_and_grads(
1830        &mut self,
1831        provider: &CudaKernelProvider,
1832        var_log_weights: &[(f64, f64)],
1833    ) -> Result<(f64, Vec<f64>, Vec<f64>)> {
1834        self.set_base_weights(provider, var_log_weights)?;
1835        self.eval_grads_inplace(provider)?;
1836
1837        let device = provider.device().inner();
1838
1839        let weights_len = (self.max_var as usize) + 1;
1840        let mut host_grad_true: Vec<f64> = vec![0.0; weights_len];
1841        let mut host_grad_false: Vec<f64> = vec![0.0; weights_len];
1842
1843        let root_idx = self.root as usize;
1844        let root_view = self.values.slice(root_idx..(root_idx + 1));
1845        let mut log_z = [0.0_f64];
1846        device
1847            .dtoh_sync_copy_into(&root_view, &mut log_z)
1848            .map_err(|e| XlogError::Kernel(format!("Failed to read circuit root value: {}", e)))?;
1849
1850        device
1851            .dtoh_sync_copy_into(&self.grad_true, &mut host_grad_true)
1852            .map_err(|e| XlogError::Kernel(format!("Failed to download grad_true: {}", e)))?;
1853        device
1854            .dtoh_sync_copy_into(&self.grad_false, &mut host_grad_false)
1855            .map_err(|e| XlogError::Kernel(format!("Failed to download grad_false: {}", e)))?;
1856
1857        Ok((log_z[0], host_grad_true, host_grad_false))
1858    }
1859}