Skip to main content

xlog_prob/compilation/
gpu_cache.rs

1//! GPU-resident circuit cache helpers.
2
3use std::sync::Arc;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::provider::{cache_kernels, CACHE_MODULE};
9use xlog_cuda::{AsKernelParam, CudaKernelProvider, LaunchAsync};
10use xlog_solve::GpuCnf;
11
12use super::disk_cache;
13use crate::gpu::GpuXgcf;
14
/// Sizing parameters for the GPU-resident circuit cache.
///
/// Sets how many XGCF circuit slots are cached and the per-slot capacity
/// limits for nodes, edges, levels, and variables. Production callers should
/// prefer [`crate::exact::default_cache_config`], which derives the caps from
/// the CNF and the compile config.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct GpuCircuitCacheConfig {
    /// How many circuit slots stay resident on the GPU.
    pub num_slots: u32,
    /// Size of the circuit lookup hash table (should be >= 2 * num_slots).
    pub table_size: u32,
    /// Upper bound on the nodes of a single cached circuit.
    pub node_cap: u32,
    /// Upper bound on the edges of a single cached circuit.
    pub edge_cap: u32,
    /// Upper bound on the levels (BFS depth) of a single cached circuit.
    pub level_cap: u32,
    /// Largest CNF variable id (1-based, DIMACS) across all cached circuits.
    pub var_cap: u32,
}
37
38impl Default for GpuCircuitCacheConfig {
39    /// Conservative defaults for small CNFs (< 64 variables).
40    ///
41    /// Production callers should derive caps from the actual CNF dimensions.
42    fn default() -> Self {
43        Self {
44            num_slots: 4,
45            table_size: 8,
46            node_cap: 65_536,
47            edge_cap: 131_072,
48            level_cap: 65_536,
49            var_cap: 128,
50        }
51    }
52}
53
54fn cache_grid_dim_for_u32_count(context: &str, count: u32, block_dim: u32) -> Result<u32> {
55    if count == 0 {
56        return Ok(0);
57    }
58    if block_dim == 0 {
59        return Err(XlogError::Compilation(format!(
60            "{context}: GPU cache block size must be nonzero"
61        )));
62    }
63    let padded = count
64        .checked_add(block_dim - 1)
65        .ok_or_else(|| XlogError::Compilation(format!("{context}: GPU cache grid overflow")))?;
66    Ok(padded / block_dim)
67}
68
69fn cache_grid_dim_for_u64_count(context: &str, count: u64, block_dim: u32) -> Result<u32> {
70    if count == 0 {
71        return Ok(0);
72    }
73    if block_dim == 0 {
74        return Err(XlogError::Compilation(format!(
75            "{context}: GPU cache block size must be nonzero"
76        )));
77    }
78    let block = block_dim as u64;
79    let grid = count
80        .checked_add(block - 1)
81        .map(|padded| padded / block)
82        .ok_or_else(|| XlogError::Compilation(format!("{context}: GPU cache grid overflow")))?;
83    u32::try_from(grid)
84        .map_err(|_| XlogError::Compilation(format!("{context}: GPU cache grid exceeds u32")))
85}
86
87pub struct GpuCircuitCache {
88    provider: Arc<CudaKernelProvider>,
89    table_size: u32,
90    num_slots: u32,
91    node_cap: u32,
92    edge_cap: u32,
93    level_cap: u32,
94    var_cap: u32,
95    keys: TrackedCudaSlice<u64>,
96    slots: TrackedCudaSlice<u32>,
97    state: TrackedCudaSlice<u32>,
98    last_used: TrackedCudaSlice<u64>,
99    slot_states: TrackedCudaSlice<u32>,
100    clock: TrackedCudaSlice<u64>,
101    node_type: TrackedCudaSlice<u8>,
102    child_offsets: TrackedCudaSlice<u32>,
103    child_indices: TrackedCudaSlice<u32>,
104    lit: TrackedCudaSlice<i32>,
105    decision_var: TrackedCudaSlice<u32>,
106    decision_child_false: TrackedCudaSlice<u32>,
107    decision_child_true: TrackedCudaSlice<u32>,
108    level_nodes: TrackedCudaSlice<u32>,
109    level_offsets: TrackedCudaSlice<u32>,
110    var_log_true: TrackedCudaSlice<f64>,
111    var_log_false: TrackedCudaSlice<f64>,
112    values: TrackedCudaSlice<f64>,
113    adj: TrackedCudaSlice<f64>,
114    grad_true: TrackedCudaSlice<f64>,
115    grad_false: TrackedCudaSlice<f64>,
116    meta_num_nodes: TrackedCudaSlice<u32>,
117    meta_num_levels: TrackedCudaSlice<u32>,
118    meta_root: TrackedCudaSlice<u32>,
119    meta_max_var: TrackedCudaSlice<u32>,
120    always_on: TrackedCudaSlice<u32>,
121    zero_f64: TrackedCudaSlice<f64>,
122    one_f64: TrackedCudaSlice<f64>,
123    free_var_mask: TrackedCudaSlice<u8>,
124    has_free_var_mask: Vec<bool>,
125}
126
127pub struct GpuCacheLookup {
128    provider: Arc<CudaKernelProvider>,
129    slot: TrackedCudaSlice<u32>,
130    compile_needed: TrackedCudaSlice<u32>,
131}
132
133impl GpuCacheLookup {
134    pub fn slot_device(&self) -> &TrackedCudaSlice<u32> {
135        &self.slot
136    }
137
138    pub fn compile_needed_device(&self) -> &TrackedCudaSlice<u32> {
139        &self.compile_needed
140    }
141
142    pub fn provider(&self) -> &Arc<CudaKernelProvider> {
143        &self.provider
144    }
145
146    pub fn into_handle(self) -> Result<GpuCircuitCacheHandle> {
147        let slot_host_vec: Vec<u32> = self
148            .provider
149            .device()
150            .inner()
151            .dtoh_sync_copy(&self.slot)
152            .map_err(|e| XlogError::Kernel(format!("dtoh slot index: {}", e)))?;
153        Ok(GpuCircuitCacheHandle {
154            provider: self.provider,
155            slot: self.slot,
156            compile_needed: self.compile_needed,
157            slot_host: slot_host_vec[0],
158            num_nodes: 0,
159            num_levels: 0,
160            root: 0,
161            max_var: 0,
162        })
163    }
164}
165
166pub struct GpuCircuitCacheHandle {
167    provider: Arc<CudaKernelProvider>,
168    slot: TrackedCudaSlice<u32>,
169    compile_needed: TrackedCudaSlice<u32>,
170    slot_host: u32,
171    num_nodes: u32,
172    num_levels: u32,
173    root: u32,
174    max_var: u32,
175}
176
177impl GpuCircuitCacheHandle {
178    pub fn slot_device(&self) -> &TrackedCudaSlice<u32> {
179        &self.slot
180    }
181
182    pub fn compile_needed_device(&self) -> &TrackedCudaSlice<u32> {
183        &self.compile_needed
184    }
185
186    pub fn provider(&self) -> &Arc<CudaKernelProvider> {
187        &self.provider
188    }
189
190    pub fn num_nodes(&self) -> u32 {
191        self.num_nodes
192    }
193
194    pub fn num_levels(&self) -> u32 {
195        self.num_levels
196    }
197
198    pub fn root(&self) -> u32 {
199        self.root
200    }
201
202    pub fn max_var(&self) -> u32 {
203        self.max_var
204    }
205
206    pub(crate) fn slot_index(&self) -> u32 {
207        self.slot_host
208    }
209
210    #[allow(dead_code)] // reserved API: used by future cache-warming path
211    pub(crate) fn set_meta(&mut self, num_nodes: u32, num_levels: u32, root: u32, max_var: u32) {
212        self.num_nodes = num_nodes;
213        self.num_levels = num_levels;
214        self.root = root;
215        self.max_var = max_var;
216    }
217}
218
219/// Compute a deterministic CNF hash on the GPU.
220///
221/// Hash input order matches the cache kernel: num_vars, num_clauses, num_lits,
222/// clause_offsets[0..num_clauses], literals[0..num_lits-1].
223pub fn hash_cnf_gpu(
224    cnf: &GpuCnf,
225    provider: &Arc<CudaKernelProvider>,
226) -> Result<TrackedCudaSlice<u64>> {
227    let memory = provider.memory();
228    let mut out_hash = memory.alloc::<u64>(1)?;
229
230    let func = provider
231        .device()
232        .inner()
233        .get_func(CACHE_MODULE, cache_kernels::CACHE_CNF_HASH)
234        .ok_or_else(|| XlogError::Kernel("cache_cnf_hash kernel not found".to_string()))?;
235
236    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
237    unsafe {
238        func.clone().launch(
239            LaunchConfig {
240                grid_dim: (1, 1, 1),
241                block_dim: (1, 1, 1),
242                shared_mem_bytes: 0,
243            },
244            (
245                &cnf.num_vars,
246                &cnf.num_clauses,
247                &cnf.num_lits,
248                &cnf.clause_offsets,
249                &cnf.literals,
250                &mut out_hash,
251            ),
252        )
253    }
254    .map_err(|e| XlogError::Kernel(format!("cache_cnf_hash launch failed: {}", e)))?;
255    // No device synchronize: hash stays device-resident for lookup kernel; same-stream ordering suffices.
256    Ok(out_hash)
257}
258
259impl GpuCircuitCache {
260    pub fn provider(&self) -> &Arc<CudaKernelProvider> {
261        &self.provider
262    }
263
264    pub fn var_log_weights_mut(
265        &mut self,
266    ) -> (&mut TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) {
267        (&mut self.var_log_true, &mut self.var_log_false)
268    }
269
270    pub fn grad_true(&self) -> &TrackedCudaSlice<f64> {
271        &self.grad_true
272    }
273
274    pub fn grad_false(&self) -> &TrackedCudaSlice<f64> {
275        &self.grad_false
276    }
277
278    pub fn values(&self) -> &TrackedCudaSlice<f64> {
279        &self.values
280    }
281
282    pub fn meta_num_nodes_device(&self) -> &TrackedCudaSlice<u32> {
283        &self.meta_num_nodes
284    }
285
286    pub fn meta_num_levels_device(&self) -> &TrackedCudaSlice<u32> {
287        &self.meta_num_levels
288    }
289
290    pub fn meta_root_device(&self) -> &TrackedCudaSlice<u32> {
291        &self.meta_root
292    }
293
294    pub fn meta_max_var_device(&self) -> &TrackedCudaSlice<u32> {
295        &self.meta_max_var
296    }
297
298    pub fn num_slots(&self) -> u32 {
299        self.num_slots
300    }
301
302    pub(crate) fn has_any_free_var_mask(&self) -> bool {
303        self.has_free_var_mask.iter().any(|&v| v)
304    }
305
306    pub(crate) fn has_free_var_mask_for_slot(&self, slot: u32) -> bool {
307        self.has_free_var_mask
308            .get(slot as usize)
309            .copied()
310            .unwrap_or(false)
311    }
312
313    pub(crate) fn var_stride(&self) -> Result<u32> {
314        self.var_cap
315            .checked_add(1)
316            .ok_or_else(|| XlogError::Compilation("GpuCircuitCache var_cap overflow".to_string()))
317    }
318
319    pub(crate) fn node_stride(&self) -> u32 {
320        self.node_cap
321    }
322
323    pub(crate) fn copy_slot_weights_to_batch(
324        &mut self,
325        handle: &GpuCircuitCacheHandle,
326        out_true_batch: &mut TrackedCudaSlice<f64>,
327        out_false_batch: &mut TrackedCudaSlice<f64>,
328        batch_size: u32,
329    ) -> Result<()> {
330        if batch_size == 0 {
331            return Ok(());
332        }
333        let var_stride = self.var_stride()?;
334        let expected = (batch_size as usize)
335            .checked_mul(var_stride as usize)
336            .ok_or_else(|| {
337                XlogError::Compilation("GpuCircuitCache batch weight size overflow".to_string())
338            })?;
339        if out_true_batch.len() != expected || out_false_batch.len() != expected {
340            return Err(XlogError::Compilation(format!(
341                "GpuCircuitCache batched weight buffers must both have len {}, got {} and {}",
342                expected,
343                out_true_batch.len(),
344                out_false_batch.len()
345            )));
346        }
347
348        let device = self.provider.device().inner();
349        let func = device
350            .get_func(
351                xlog_cuda::provider::WEIGHTS_MODULE,
352                xlog_cuda::provider::weights_kernels::WEIGHTS_COPY_SLOT_TO_BATCH,
353            )
354            .ok_or_else(|| {
355                XlogError::Kernel("weights_copy_slot_to_batch kernel not found".to_string())
356            })?;
357
358        let block_dim = 256u32;
359        let total = (batch_size as u64)
360            .checked_mul(var_stride as u64)
361            .ok_or_else(|| {
362                XlogError::Compilation("GpuCircuitCache batch copy overflow".to_string())
363            })?;
364        let grid_dim =
365            cache_grid_dim_for_u64_count("GpuCircuitCache batch weight copy", total, block_dim)?;
366        if grid_dim == 0 {
367            return Ok(());
368        }
369
370        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
371        unsafe {
372            func.clone().launch(
373                LaunchConfig {
374                    grid_dim: (grid_dim, 1, 1),
375                    block_dim: (block_dim, 1, 1),
376                    shared_mem_bytes: 0,
377                },
378                (
379                    handle.slot_device(),
380                    self.var_cap,
381                    &self.var_log_true,
382                    &self.var_log_false,
383                    out_true_batch,
384                    out_false_batch,
385                    var_stride,
386                    batch_size,
387                ),
388            )
389        }
390        .map_err(|e| XlogError::Kernel(format!("weights_copy_slot_to_batch failed: {}", e)))?;
391
392        Ok(())
393    }
394
    /// Runs the batched forward evaluation and backward pass for the slot
    /// referenced by `handle`, writing into caller-provided batch buffers.
    ///
    /// Steps, in order:
    /// 1. Return immediately when `batch_size` is zero.
    /// 2. Reject slots whose free-var-mask flag is set (not supported here).
    /// 3. Check buffer lengths: the four variable buffers must hold
    ///    `batch_size * (var_cap + 1)` elements, and `values_batch` / `adj_batch`
    ///    must hold `batch_size * node_cap` elements.
    /// 4. Zero `adj_batch`, `grad_true_batch` and `grad_false_batch`.
    /// 5. Launch `xgcf_eval_all_levels_cached_batched`,
    ///    `xgcf_set_root_adj_cached_batched` and
    ///    `xgcf_backward_all_levels_cached_batched`, each with `batch_size`
    ///    blocks.
    ///
    /// No explicit device synchronization is performed in this function.
    ///
    /// # Errors
    ///
    /// Returns [`XlogError::Execution`] for the unsupported free-var case,
    /// [`XlogError::Compilation`] for size overflow or length mismatches, and
    /// [`XlogError::Kernel`] when zeroing, kernel lookup, or a launch fails.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn eval_grads_inplace_fused_batched(
        &mut self,
        handle: &GpuCircuitCacheHandle,
        var_log_true_batch: &TrackedCudaSlice<f64>,
        var_log_false_batch: &TrackedCudaSlice<f64>,
        values_batch: &mut TrackedCudaSlice<f64>,
        adj_batch: &mut TrackedCudaSlice<f64>,
        grad_true_batch: &mut TrackedCudaSlice<f64>,
        grad_false_batch: &mut TrackedCudaSlice<f64>,
        batch_size: u32,
    ) -> Result<()> {
        if batch_size == 0 {
            return Ok(());
        }
        // Uses the host-side slot index read when the handle was created.
        if self.has_free_var_mask_for_slot(handle.slot_index()) {
            return Err(XlogError::Execution(
                "Batched fused eval currently does not support free-var correction".to_string(),
            ));
        }

        // Row lengths: var_cap + 1 for variable buffers, node_cap for node buffers.
        let var_stride = self.var_stride()?;
        let node_stride = self.node_stride();
        let expected_var = (batch_size as usize)
            .checked_mul(var_stride as usize)
            .ok_or_else(|| {
                XlogError::Compilation("GpuCircuitCache batched var buffer overflow".to_string())
            })?;
        let expected_node = (batch_size as usize)
            .checked_mul(node_stride as usize)
            .ok_or_else(|| {
                XlogError::Compilation("GpuCircuitCache batched node buffer overflow".to_string())
            })?;

        if var_log_true_batch.len() != expected_var
            || var_log_false_batch.len() != expected_var
            || grad_true_batch.len() != expected_var
            || grad_false_batch.len() != expected_var
        {
            return Err(XlogError::Compilation(format!(
                "GpuCircuitCache batched var buffers must have len {}",
                expected_var
            )));
        }
        if values_batch.len() != expected_node || adj_batch.len() != expected_node {
            return Err(XlogError::Compilation(format!(
                "GpuCircuitCache batched node buffers must have len {}",
                expected_node
            )));
        }

        // Clear the adjoint and gradient outputs; `values_batch` is not zeroed here.
        let device = self.provider.device().inner();
        device
            .memset_zeros(adj_batch)
            .map_err(|e| XlogError::Kernel(format!("Failed to zero batched adj: {}", e)))?;
        device
            .memset_zeros(grad_true_batch)
            .map_err(|e| XlogError::Kernel(format!("Failed to zero batched grad_true: {}", e)))?;
        device
            .memset_zeros(grad_false_batch)
            .map_err(|e| XlogError::Kernel(format!("Failed to zero batched grad_false: {}", e)))?;

        // Resolve all three kernels before launching any of them.
        let eval_all = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED_BATCHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_eval_all_levels_cached_batched not found".to_string())
            })?;
        let set_root_adj = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_SET_ROOT_ADJ_CACHED_BATCHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_set_root_adj_cached_batched not found".to_string())
            })?;
        let backward_all = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_BACKWARD_ALL_LEVELS_CACHED_BATCHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_backward_all_levels_cached_batched not found".to_string())
            })?;

        let block_size = 256u32;
        // Forward-pass arguments as raw kernel-parameter pointers. Per the SAFETY
        // note below, the order must match the kernel's PTX signature.
        // NOTE(review): scalar entries presumably point at the named locals and
        // fields, which stay alive until the launch call below; confirm against
        // the `AsKernelParam` implementation.
        let mut eval_params: Vec<*mut std::ffi::c_void> = vec![
            handle.slot_device().as_kernel_param(),
            self.node_cap.as_kernel_param(),
            self.edge_cap.as_kernel_param(),
            self.level_cap.as_kernel_param(),
            self.var_cap.as_kernel_param(),
            (&self.node_type).as_kernel_param(),
            (&self.child_offsets).as_kernel_param(),
            (&self.child_indices).as_kernel_param(),
            (&self.lit).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            (&self.meta_num_levels).as_kernel_param(),
            var_log_true_batch.as_kernel_param(),
            var_log_false_batch.as_kernel_param(),
            var_stride.as_kernel_param(),
            values_batch.as_kernel_param(),
            node_stride.as_kernel_param(),
            batch_size.as_kernel_param(),
        ];
        // Forward pass: `batch_size` blocks of 256 threads.
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            eval_all.clone().launch(
                LaunchConfig {
                    grid_dim: (batch_size, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                &mut eval_params,
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!("xgcf_eval_all_levels_cached_batched failed: {}", e))
        })?;

        // Root-adjoint kernel: `batch_size` blocks of a single thread each.
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            set_root_adj.clone().launch(
                LaunchConfig {
                    grid_dim: (batch_size, 1, 1),
                    block_dim: (1, 1, 1),
                    shared_mem_bytes: 0,
                },
                (
                    handle.slot_device(),
                    self.node_cap,
                    &self.meta_root,
                    &mut *adj_batch,
                    node_stride,
                    batch_size,
                ),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!("xgcf_set_root_adj_cached_batched failed: {}", e))
        })?;

        // Backward-pass arguments.
        // NOTE(review): here `lit` follows the decision arrays, whereas the
        // forward list places it before them; confirm both orders against the
        // respective kernel signatures.
        let mut backward_params: Vec<*mut std::ffi::c_void> = vec![
            handle.slot_device().as_kernel_param(),
            self.node_cap.as_kernel_param(),
            self.edge_cap.as_kernel_param(),
            self.level_cap.as_kernel_param(),
            self.var_cap.as_kernel_param(),
            (&self.node_type).as_kernel_param(),
            (&self.child_offsets).as_kernel_param(),
            (&self.child_indices).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.lit).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            (&self.meta_num_levels).as_kernel_param(),
            var_log_true_batch.as_kernel_param(),
            var_log_false_batch.as_kernel_param(),
            var_stride.as_kernel_param(),
            values_batch.as_kernel_param(),
            node_stride.as_kernel_param(),
            adj_batch.as_kernel_param(),
            node_stride.as_kernel_param(),
            grad_true_batch.as_kernel_param(),
            grad_false_batch.as_kernel_param(),
            var_stride.as_kernel_param(),
            batch_size.as_kernel_param(),
        ];
        // Backward pass: same launch shape as the forward pass.
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            backward_all.clone().launch(
                LaunchConfig {
                    grid_dim: (batch_size, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                &mut backward_params,
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!(
                "xgcf_backward_all_levels_cached_batched failed: {}",
                e
            ))
        })?;

        Ok(())
    }
591
592    pub(crate) fn copy_root_batched_from_values(
593        &self,
594        handle: &GpuCircuitCacheHandle,
595        values_batch: &TrackedCudaSlice<f64>,
596        out_roots: &mut TrackedCudaSlice<f64>,
597        batch_size: u32,
598    ) -> Result<()> {
599        if batch_size == 0 {
600            return Ok(());
601        }
602        let node_stride = self.node_stride();
603        let expected_values = (batch_size as usize)
604            .checked_mul(node_stride as usize)
605            .ok_or_else(|| {
606                XlogError::Compilation("GpuCircuitCache batched values overflow".to_string())
607            })?;
608        if values_batch.len() != expected_values || out_roots.len() != batch_size as usize {
609            return Err(XlogError::Compilation(format!(
610                "GpuCircuitCache root copy expects values len {} and roots len {}, got {} and {}",
611                expected_values,
612                batch_size,
613                values_batch.len(),
614                out_roots.len()
615            )));
616        }
617
618        let device = self.provider.device().inner();
619        let copy_root = device
620            .get_func(
621                xlog_cuda::CIRCUIT_MODULE,
622                xlog_cuda::circuit_kernels::XGCF_COPY_ROOT_CACHED_META_BATCHED,
623            )
624            .ok_or_else(|| {
625                XlogError::Kernel("xgcf_copy_root_cached_meta_batched not found".to_string())
626            })?;
627        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
628        unsafe {
629            copy_root.clone().launch(
630                LaunchConfig {
631                    grid_dim: (batch_size, 1, 1),
632                    block_dim: (1, 1, 1),
633                    shared_mem_bytes: 0,
634                },
635                (
636                    handle.slot_device(),
637                    self.node_cap,
638                    &self.meta_root,
639                    values_batch,
640                    node_stride,
641                    out_roots,
642                    batch_size,
643                ),
644            )
645        }
646        .map_err(|e| {
647            XlogError::Kernel(format!("xgcf_copy_root_cached_meta_batched failed: {}", e))
648        })?;
649        Ok(())
650    }
651
652    pub fn new(provider: &Arc<CudaKernelProvider>, config: GpuCircuitCacheConfig) -> Result<Self> {
653        if config.num_slots == 0 {
654            return Err(XlogError::Compilation(
655                "GpuCircuitCache requires num_slots > 0".to_string(),
656            ));
657        }
658        if config.table_size == 0 {
659            return Err(XlogError::Compilation(
660                "GpuCircuitCache requires table_size > 0".to_string(),
661            ));
662        }
663        if config.table_size < config.num_slots {
664            return Err(XlogError::Compilation(format!(
665                "GpuCircuitCache table_size {} < num_slots {}",
666                config.table_size, config.num_slots
667            )));
668        }
669        if config.node_cap == 0
670            || config.edge_cap == 0
671            || config.level_cap == 0
672            || config.var_cap == 0
673        {
674            return Err(XlogError::Compilation(
675                "GpuCircuitCache requires non-zero caps".to_string(),
676            ));
677        }
678
679        let memory = provider.memory();
680        let device = provider.device().inner();
681
682        let table_len = usize::try_from(config.table_size).map_err(|_| {
683            XlogError::Compilation("GpuCircuitCache table_size overflow".to_string())
684        })?;
685        let slot_len = usize::try_from(config.num_slots).map_err(|_| {
686            XlogError::Compilation("GpuCircuitCache num_slots overflow".to_string())
687        })?;
688
689        let node_cap = usize::try_from(config.node_cap)
690            .map_err(|_| XlogError::Compilation("GpuCircuitCache node_cap overflow".to_string()))?;
691        let edge_cap = usize::try_from(config.edge_cap)
692            .map_err(|_| XlogError::Compilation("GpuCircuitCache edge_cap overflow".to_string()))?;
693        let level_cap = usize::try_from(config.level_cap).map_err(|_| {
694            XlogError::Compilation("GpuCircuitCache level_cap overflow".to_string())
695        })?;
696        let var_cap = usize::try_from(config.var_cap)
697            .map_err(|_| XlogError::Compilation("GpuCircuitCache var_cap overflow".to_string()))?;
698
699        let node_slots = slot_len.checked_mul(node_cap).ok_or_else(|| {
700            XlogError::Compilation("GpuCircuitCache node slots overflow".to_string())
701        })?;
702        let edge_slots = slot_len.checked_mul(edge_cap).ok_or_else(|| {
703            XlogError::Compilation("GpuCircuitCache edge slots overflow".to_string())
704        })?;
705        let var_slots = slot_len.checked_mul(var_cap + 1).ok_or_else(|| {
706            XlogError::Compilation("GpuCircuitCache var slots overflow".to_string())
707        })?;
708        let node_offsets = slot_len.checked_mul(node_cap + 1).ok_or_else(|| {
709            XlogError::Compilation("GpuCircuitCache offset slots overflow".to_string())
710        })?;
711        let level_offsets = slot_len.checked_mul(level_cap + 1).ok_or_else(|| {
712            XlogError::Compilation("GpuCircuitCache level offsets overflow".to_string())
713        })?;
714
715        let mut keys = memory.alloc::<u64>(table_len)?;
716        let mut slots = memory.alloc::<u32>(table_len)?;
717        let mut state = memory.alloc::<u32>(table_len)?;
718        let mut last_used = memory.alloc::<u64>(table_len)?;
719        let mut slot_states = memory.alloc::<u32>(slot_len)?;
720        let mut clock = memory.alloc::<u64>(1)?;
721
722        let mut node_type = memory.alloc::<u8>(node_slots)?;
723        let mut child_offsets = memory.alloc::<u32>(node_offsets)?;
724        let mut child_indices = memory.alloc::<u32>(edge_slots)?;
725        let mut lit = memory.alloc::<i32>(node_slots)?;
726        let mut decision_var = memory.alloc::<u32>(node_slots)?;
727        let mut decision_child_false = memory.alloc::<u32>(node_slots)?;
728        let mut decision_child_true = memory.alloc::<u32>(node_slots)?;
729        let mut level_nodes = memory.alloc::<u32>(node_slots)?;
730        let mut level_offsets = memory.alloc::<u32>(level_offsets)?;
731
732        let mut var_log_true = memory.alloc::<f64>(var_slots)?;
733        let mut var_log_false = memory.alloc::<f64>(var_slots)?;
734        let mut values = memory.alloc::<f64>(node_slots)?;
735        let mut adj = memory.alloc::<f64>(node_slots)?;
736        let mut grad_true = memory.alloc::<f64>(var_slots)?;
737        let mut grad_false = memory.alloc::<f64>(var_slots)?;
738        let mut free_var_mask = memory.alloc::<u8>(var_slots)?;
739        let mut meta_num_nodes = memory.alloc::<u32>(slot_len)?;
740        let mut meta_num_levels = memory.alloc::<u32>(slot_len)?;
741        let mut meta_root = memory.alloc::<u32>(slot_len)?;
742        let mut meta_max_var = memory.alloc::<u32>(slot_len)?;
743        let mut always_on = memory.alloc::<u32>(1)?;
744        let zero_len = node_cap.max(var_cap + 1);
745        let mut zero_f64 = memory.alloc::<f64>(zero_len)?;
746        let mut one_f64 = memory.alloc::<f64>(1)?;
747
748        device
749            .memset_zeros(&mut keys)
750            .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero keys failed: {}", e)))?;
751        device
752            .memset_zeros(&mut slots)
753            .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero slots failed: {}", e)))?;
754        device
755            .memset_zeros(&mut state)
756            .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero state failed: {}", e)))?;
757        device.memset_zeros(&mut last_used).map_err(|e| {
758            XlogError::Kernel(format!("GpuCircuitCache zero last_used failed: {}", e))
759        })?;
760        device.memset_zeros(&mut slot_states).map_err(|e| {
761            XlogError::Kernel(format!("GpuCircuitCache zero slot_states failed: {}", e))
762        })?;
763        device
764            .memset_zeros(&mut clock)
765            .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero clock failed: {}", e)))?;
766
767        device.memset_zeros(&mut node_type).map_err(|e| {
768            XlogError::Kernel(format!("GpuCircuitCache zero node_type failed: {}", e))
769        })?;
770        device.memset_zeros(&mut child_offsets).map_err(|e| {
771            XlogError::Kernel(format!("GpuCircuitCache zero child_offsets failed: {}", e))
772        })?;
773        device.memset_zeros(&mut child_indices).map_err(|e| {
774            XlogError::Kernel(format!("GpuCircuitCache zero child_indices failed: {}", e))
775        })?;
776        device
777            .memset_zeros(&mut lit)
778            .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero lit failed: {}", e)))?;
779        device.memset_zeros(&mut decision_var).map_err(|e| {
780            XlogError::Kernel(format!("GpuCircuitCache zero decision_var failed: {}", e))
781        })?;
782        device
783            .memset_zeros(&mut decision_child_false)
784            .map_err(|e| {
785                XlogError::Kernel(format!(
786                    "GpuCircuitCache zero decision_child_false failed: {}",
787                    e
788                ))
789            })?;
790        device.memset_zeros(&mut decision_child_true).map_err(|e| {
791            XlogError::Kernel(format!(
792                "GpuCircuitCache zero decision_child_true failed: {}",
793                e
794            ))
795        })?;
796        device.memset_zeros(&mut level_nodes).map_err(|e| {
797            XlogError::Kernel(format!("GpuCircuitCache zero level_nodes failed: {}", e))
798        })?;
799        device.memset_zeros(&mut level_offsets).map_err(|e| {
800            XlogError::Kernel(format!("GpuCircuitCache zero level_offsets failed: {}", e))
801        })?;
802        device.memset_zeros(&mut var_log_true).map_err(|e| {
803            XlogError::Kernel(format!("GpuCircuitCache zero var_log_true failed: {}", e))
804        })?;
805        device.memset_zeros(&mut var_log_false).map_err(|e| {
806            XlogError::Kernel(format!("GpuCircuitCache zero var_log_false failed: {}", e))
807        })?;
808        device
809            .memset_zeros(&mut values)
810            .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero values failed: {}", e)))?;
811        device
812            .memset_zeros(&mut adj)
813            .map_err(|e| XlogError::Kernel(format!("GpuCircuitCache zero adj failed: {}", e)))?;
814        device.memset_zeros(&mut grad_true).map_err(|e| {
815            XlogError::Kernel(format!("GpuCircuitCache zero grad_true failed: {}", e))
816        })?;
817        device.memset_zeros(&mut grad_false).map_err(|e| {
818            XlogError::Kernel(format!("GpuCircuitCache zero grad_false failed: {}", e))
819        })?;
820        device.memset_zeros(&mut free_var_mask).map_err(|e| {
821            XlogError::Kernel(format!("GpuCircuitCache zero free_var_mask failed: {}", e))
822        })?;
823        device.memset_zeros(&mut meta_num_nodes).map_err(|e| {
824            XlogError::Kernel(format!("GpuCircuitCache zero meta_num_nodes failed: {}", e))
825        })?;
826        device.memset_zeros(&mut meta_num_levels).map_err(|e| {
827            XlogError::Kernel(format!(
828                "GpuCircuitCache zero meta_num_levels failed: {}",
829                e
830            ))
831        })?;
832        device.memset_zeros(&mut meta_root).map_err(|e| {
833            XlogError::Kernel(format!("GpuCircuitCache zero meta_root failed: {}", e))
834        })?;
835        device.memset_zeros(&mut meta_max_var).map_err(|e| {
836            XlogError::Kernel(format!("GpuCircuitCache zero meta_max_var failed: {}", e))
837        })?;
838        device.memset_zeros(&mut zero_f64).map_err(|e| {
839            XlogError::Kernel(format!("GpuCircuitCache zero zero_f64 failed: {}", e))
840        })?;
841        provider
842            .htod_launch_metadata_sync_copy_into(&[1u32], &mut always_on)
843            .map_err(|e| {
844                XlogError::Kernel(format!("GpuCircuitCache init always_on failed: {}", e))
845            })?;
846        provider
847            .htod_launch_metadata_sync_copy_into(&[1.0f64], &mut one_f64)
848            .map_err(|e| {
849                XlogError::Kernel(format!("GpuCircuitCache init one_f64 failed: {}", e))
850            })?;
851
852        Ok(Self {
853            provider: provider.clone(),
854            table_size: config.table_size,
855            num_slots: config.num_slots,
856            node_cap: config.node_cap,
857            edge_cap: config.edge_cap,
858            level_cap: config.level_cap,
859            var_cap: config.var_cap,
860            keys,
861            slots,
862            state,
863            last_used,
864            slot_states,
865            clock,
866            node_type,
867            child_offsets,
868            child_indices,
869            lit,
870            decision_var,
871            decision_child_false,
872            decision_child_true,
873            level_nodes,
874            level_offsets,
875            var_log_true,
876            var_log_false,
877            values,
878            adj,
879            grad_true,
880            grad_false,
881            meta_num_nodes,
882            meta_num_levels,
883            meta_root,
884            meta_max_var,
885            always_on,
886            zero_f64,
887            one_f64,
888            free_var_mask,
889            has_free_var_mask: vec![false; config.num_slots as usize],
890        })
891    }
892
893    pub fn lookup_or_insert(&mut self, key: u64) -> Result<GpuCacheLookup> {
894        let memory = self.provider.memory();
895        let mut key_device = memory.alloc::<u64>(1)?;
896        self.provider
897            .htod_launch_metadata_sync_copy_into(&[key], &mut key_device)
898            .map_err(|e| XlogError::Kernel(format!("cache upload key failed: {}", e)))?;
899        self.lookup_or_insert_device(&key_device)
900    }
901
902    pub(crate) fn lookup_or_insert_device(
903        &mut self,
904        key_device: &TrackedCudaSlice<u64>,
905    ) -> Result<GpuCacheLookup> {
906        let memory = self.provider.memory();
907        let mut out_slot = memory.alloc::<u32>(1)?;
908        let mut out_compile_needed = memory.alloc::<u32>(1)?;
909
910        let func = self
911            .provider
912            .device()
913            .inner()
914            .get_func(CACHE_MODULE, cache_kernels::CACHE_LOOKUP_OR_INSERT)
915            .ok_or_else(|| {
916                XlogError::Kernel("cache_lookup_or_insert kernel not found".to_string())
917            })?;
918
919        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
920        unsafe {
921            func.clone().launch(
922                LaunchConfig {
923                    grid_dim: (1, 1, 1),
924                    block_dim: (1, 1, 1),
925                    shared_mem_bytes: 0,
926                },
927                (
928                    key_device,
929                    self.table_size,
930                    self.num_slots,
931                    &mut self.keys,
932                    &mut self.slots,
933                    &mut self.state,
934                    &mut self.last_used,
935                    &mut self.slot_states,
936                    &mut self.clock,
937                    &mut out_slot,
938                    &mut out_compile_needed,
939                ),
940            )
941        }
942        .map_err(|e| XlogError::Kernel(format!("cache_lookup_or_insert failed: {}", e)))?;
943        // No device synchronize: slot and compile_needed stay device-resident; same-stream ordering suffices.
944        Ok(GpuCacheLookup {
945            provider: self.provider.clone(),
946            slot: out_slot,
947            compile_needed: out_compile_needed,
948        })
949    }
950
951    pub fn claim_slot(&mut self, key: u64) -> Result<GpuCircuitCacheHandle> {
952        let lookup = self.lookup_or_insert(key)?;
953        lookup.into_handle()
954    }
955
956    pub fn store_from_xgcf(
957        &mut self,
958        handle: &mut GpuCircuitCacheHandle,
959        xgcf: &GpuXgcf,
960    ) -> Result<()> {
961        // Download the actual node/edge counts from device-resident metadata.
962        // xgcf.num_nodes() / num_edges() return the CAPACITY (node_cap / edge_cap),
963        // not the actual count produced by d4. Using the capacity would store garbage
964        // data beyond the actual circuit, corrupting disk cache artifacts.
965        let device = self.provider.device().inner();
966        let num_nodes_host: Vec<u32> = device
967            .dtoh_sync_copy(xgcf.num_nodes_device())
968            .map_err(|e| XlogError::Kernel(format!("dtoh meta_num_nodes: {}", e)))?;
969        let num_nodes = num_nodes_host[0];
970        if num_nodes == 0 {
971            return Err(XlogError::Compilation(
972                "GpuCircuitCache store: num_nodes must be > 0".to_string(),
973            ));
974        }
975        if num_nodes > self.node_cap {
976            return Err(XlogError::Compilation(format!(
977                "GpuCircuitCache store: num_nodes {} exceeds node_cap {}",
978                num_nodes, self.node_cap
979            )));
980        }
981
982        let num_edges_host: Vec<u32> = device
983            .dtoh_sync_copy(xgcf.num_edges_device())
984            .map_err(|e| XlogError::Kernel(format!("dtoh meta_num_edges: {}", e)))?;
985        let num_edges = num_edges_host[0];
986        if num_edges > self.edge_cap {
987            return Err(XlogError::Compilation(format!(
988                "GpuCircuitCache store: num_edges {} exceeds edge_cap {}",
989                num_edges, self.edge_cap
990            )));
991        }
992
993        let num_levels = xgcf.num_levels();
994        if num_levels == 0 {
995            return Err(XlogError::Compilation(
996                "GpuCircuitCache store: num_levels must be > 0".to_string(),
997            ));
998        }
999        if num_levels > self.level_cap {
1000            return Err(XlogError::Compilation(format!(
1001                "GpuCircuitCache store: num_levels {} exceeds level_cap {}",
1002                num_levels, self.level_cap
1003            )));
1004        }
1005
1006        let root = xgcf.root();
1007        if root >= num_nodes {
1008            return Err(XlogError::Compilation(format!(
1009                "GpuCircuitCache store: root {} out of bounds (num_nodes={})",
1010                root, num_nodes
1011            )));
1012        }
1013
1014        let max_var = xgcf.max_var();
1015        if max_var > self.var_cap {
1016            return Err(XlogError::Compilation(format!(
1017                "GpuCircuitCache store: max_var {} exceeds var_cap {}",
1018                max_var, self.var_cap
1019            )));
1020        }
1021
1022        let expected_child_offsets = (num_nodes as usize) + 1;
1023        if xgcf.child_offsets().len() < expected_child_offsets {
1024            return Err(XlogError::Compilation(format!(
1025                "GpuCircuitCache store: child_offsets len {} < num_nodes+1 {}",
1026                xgcf.child_offsets().len(),
1027                expected_child_offsets
1028            )));
1029        }
1030        if xgcf.level_nodes().len() < num_nodes as usize {
1031            return Err(XlogError::Compilation(format!(
1032                "GpuCircuitCache store: level_nodes len {} < num_nodes {}",
1033                xgcf.level_nodes().len(),
1034                num_nodes
1035            )));
1036        }
1037        let expected_level_offsets = (num_levels as usize) + 1;
1038        if xgcf.level_offsets().len() != expected_level_offsets {
1039            return Err(XlogError::Compilation(format!(
1040                "GpuCircuitCache store: level_offsets len {} != num_levels+1 {}",
1041                xgcf.level_offsets().len(),
1042                expected_level_offsets
1043            )));
1044        }
1045
1046        handle.num_nodes = num_nodes;
1047        handle.num_levels = num_levels;
1048        handle.root = root;
1049        handle.max_var = max_var;
1050
1051        let store_u8 = device
1052            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U8)
1053            .ok_or_else(|| XlogError::Kernel("cache_store_u8 kernel not found".to_string()))?;
1054        let store_u32 = device
1055            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U32)
1056            .ok_or_else(|| XlogError::Kernel("cache_store_u32 kernel not found".to_string()))?;
1057        let store_i32 = device
1058            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_I32)
1059            .ok_or_else(|| XlogError::Kernel("cache_store_i32 kernel not found".to_string()))?;
1060        let store_f64 = device
1061            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
1062            .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
1063        let store_meta = device
1064            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_META)
1065            .ok_or_else(|| XlogError::Kernel("cache_store_meta kernel not found".to_string()))?;
1066
1067        let block_dim = 256u32;
1068
1069        let node_stride = self.node_cap;
1070        let offset_stride = self.node_cap.checked_add(1).ok_or_else(|| {
1071            XlogError::Compilation("GpuCircuitCache store: node_cap overflow".to_string())
1072        })?;
1073        let level_offset_stride = self.level_cap.checked_add(1).ok_or_else(|| {
1074            XlogError::Compilation("GpuCircuitCache store: level_cap overflow".to_string())
1075        })?;
1076        let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1077            XlogError::Compilation("GpuCircuitCache store: var_cap overflow".to_string())
1078        })?;
1079
1080        let num_nodes_plus1 = num_nodes.checked_add(1).ok_or_else(|| {
1081            XlogError::Compilation("GpuCircuitCache store: num_nodes overflow".to_string())
1082        })?;
1083        let num_levels_plus1 = num_levels.checked_add(1).ok_or_else(|| {
1084            XlogError::Compilation("GpuCircuitCache store: num_levels overflow".to_string())
1085        })?;
1086        let weights_len = max_var.checked_add(1).ok_or_else(|| {
1087            XlogError::Compilation("GpuCircuitCache store: max_var overflow".to_string())
1088        })?;
1089
1090        let grid_nodes =
1091            cache_grid_dim_for_u32_count("GpuCircuitCache store node_type", num_nodes, block_dim)?;
1092        if grid_nodes != 0 {
1093            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1094            unsafe {
1095                store_u8.clone().launch(
1096                    LaunchConfig {
1097                        grid_dim: (grid_nodes, 1, 1),
1098                        block_dim: (block_dim, 1, 1),
1099                        shared_mem_bytes: 0,
1100                    },
1101                    (
1102                        handle.slot_device(),
1103                        handle.compile_needed_device(),
1104                        node_stride,
1105                        xgcf.node_type(),
1106                        &mut self.node_type,
1107                        num_nodes,
1108                    ),
1109                )
1110            }
1111            .map_err(|e| XlogError::Kernel(format!("cache_store_u8 failed: {}", e)))?;
1112        }
1113
1114        let grid_offsets = cache_grid_dim_for_u32_count(
1115            "GpuCircuitCache store child_offsets",
1116            num_nodes_plus1,
1117            block_dim,
1118        )?;
1119        if grid_offsets != 0 {
1120            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1121            unsafe {
1122                store_u32.clone().launch(
1123                    LaunchConfig {
1124                        grid_dim: (grid_offsets, 1, 1),
1125                        block_dim: (block_dim, 1, 1),
1126                        shared_mem_bytes: 0,
1127                    },
1128                    (
1129                        handle.slot_device(),
1130                        handle.compile_needed_device(),
1131                        offset_stride,
1132                        xgcf.child_offsets(),
1133                        &mut self.child_offsets,
1134                        num_nodes_plus1,
1135                    ),
1136                )
1137            }
1138            .map_err(|e| XlogError::Kernel(format!("cache_store_child_offsets failed: {}", e)))?;
1139        }
1140
1141        let grid_edges = cache_grid_dim_for_u32_count(
1142            "GpuCircuitCache store child_indices",
1143            num_edges,
1144            block_dim,
1145        )?;
1146        if grid_edges != 0 {
1147            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1148            unsafe {
1149                store_u32.clone().launch(
1150                    LaunchConfig {
1151                        grid_dim: (grid_edges, 1, 1),
1152                        block_dim: (block_dim, 1, 1),
1153                        shared_mem_bytes: 0,
1154                    },
1155                    (
1156                        handle.slot_device(),
1157                        handle.compile_needed_device(),
1158                        self.edge_cap,
1159                        xgcf.child_indices(),
1160                        &mut self.child_indices,
1161                        num_edges,
1162                    ),
1163                )
1164            }
1165            .map_err(|e| XlogError::Kernel(format!("cache_store_child_indices failed: {}", e)))?;
1166        }
1167
1168        if grid_nodes != 0 {
1169            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1170            unsafe {
1171                store_i32.clone().launch(
1172                    LaunchConfig {
1173                        grid_dim: (grid_nodes, 1, 1),
1174                        block_dim: (block_dim, 1, 1),
1175                        shared_mem_bytes: 0,
1176                    },
1177                    (
1178                        handle.slot_device(),
1179                        handle.compile_needed_device(),
1180                        node_stride,
1181                        xgcf.lit(),
1182                        &mut self.lit,
1183                        num_nodes,
1184                    ),
1185                )
1186            }
1187            .map_err(|e| XlogError::Kernel(format!("cache_store_lit failed: {}", e)))?;
1188
1189            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1190            unsafe {
1191                store_u32.clone().launch(
1192                    LaunchConfig {
1193                        grid_dim: (grid_nodes, 1, 1),
1194                        block_dim: (block_dim, 1, 1),
1195                        shared_mem_bytes: 0,
1196                    },
1197                    (
1198                        handle.slot_device(),
1199                        handle.compile_needed_device(),
1200                        node_stride,
1201                        xgcf.decision_var(),
1202                        &mut self.decision_var,
1203                        num_nodes,
1204                    ),
1205                )
1206            }
1207            .map_err(|e| XlogError::Kernel(format!("cache_store_decision_var failed: {}", e)))?;
1208
1209            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1210            unsafe {
1211                store_u32.clone().launch(
1212                    LaunchConfig {
1213                        grid_dim: (grid_nodes, 1, 1),
1214                        block_dim: (block_dim, 1, 1),
1215                        shared_mem_bytes: 0,
1216                    },
1217                    (
1218                        handle.slot_device(),
1219                        handle.compile_needed_device(),
1220                        node_stride,
1221                        xgcf.decision_child_false(),
1222                        &mut self.decision_child_false,
1223                        num_nodes,
1224                    ),
1225                )
1226            }
1227            .map_err(|e| {
1228                XlogError::Kernel(format!("cache_store_decision_child_false failed: {}", e))
1229            })?;
1230
1231            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1232            unsafe {
1233                store_u32.clone().launch(
1234                    LaunchConfig {
1235                        grid_dim: (grid_nodes, 1, 1),
1236                        block_dim: (block_dim, 1, 1),
1237                        shared_mem_bytes: 0,
1238                    },
1239                    (
1240                        handle.slot_device(),
1241                        handle.compile_needed_device(),
1242                        node_stride,
1243                        xgcf.decision_child_true(),
1244                        &mut self.decision_child_true,
1245                        num_nodes,
1246                    ),
1247                )
1248            }
1249            .map_err(|e| {
1250                XlogError::Kernel(format!("cache_store_decision_child_true failed: {}", e))
1251            })?;
1252
1253            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1254            unsafe {
1255                store_u32.clone().launch(
1256                    LaunchConfig {
1257                        grid_dim: (grid_nodes, 1, 1),
1258                        block_dim: (block_dim, 1, 1),
1259                        shared_mem_bytes: 0,
1260                    },
1261                    (
1262                        handle.slot_device(),
1263                        handle.compile_needed_device(),
1264                        node_stride,
1265                        xgcf.level_nodes(),
1266                        &mut self.level_nodes,
1267                        num_nodes,
1268                    ),
1269                )
1270            }
1271            .map_err(|e| XlogError::Kernel(format!("cache_store_level_nodes failed: {}", e)))?;
1272        }
1273
1274        let grid_levels = cache_grid_dim_for_u32_count(
1275            "GpuCircuitCache store level_offsets",
1276            num_levels_plus1,
1277            block_dim,
1278        )?;
1279        if grid_levels != 0 {
1280            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1281            unsafe {
1282                store_u32.clone().launch(
1283                    LaunchConfig {
1284                        grid_dim: (grid_levels, 1, 1),
1285                        block_dim: (block_dim, 1, 1),
1286                        shared_mem_bytes: 0,
1287                    },
1288                    (
1289                        handle.slot_device(),
1290                        handle.compile_needed_device(),
1291                        level_offset_stride,
1292                        xgcf.level_offsets(),
1293                        &mut self.level_offsets,
1294                        num_levels_plus1,
1295                    ),
1296                )
1297            }
1298            .map_err(|e| XlogError::Kernel(format!("cache_store_level_offsets failed: {}", e)))?;
1299        }
1300
1301        let grid_weights = cache_grid_dim_for_u32_count(
1302            "GpuCircuitCache store free_var_mask",
1303            weights_len,
1304            block_dim,
1305        )?;
1306        if grid_weights != 0 {
1307            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1308            unsafe {
1309                store_f64.clone().launch(
1310                    LaunchConfig {
1311                        grid_dim: (grid_weights, 1, 1),
1312                        block_dim: (block_dim, 1, 1),
1313                        shared_mem_bytes: 0,
1314                    },
1315                    (
1316                        handle.slot_device(),
1317                        handle.compile_needed_device(),
1318                        var_stride,
1319                        xgcf.var_log_true(),
1320                        &mut self.var_log_true,
1321                        weights_len,
1322                    ),
1323                )
1324            }
1325            .map_err(|e| XlogError::Kernel(format!("cache_store_var_log_true failed: {}", e)))?;
1326
1327            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1328            unsafe {
1329                store_f64.clone().launch(
1330                    LaunchConfig {
1331                        grid_dim: (grid_weights, 1, 1),
1332                        block_dim: (block_dim, 1, 1),
1333                        shared_mem_bytes: 0,
1334                    },
1335                    (
1336                        handle.slot_device(),
1337                        handle.compile_needed_device(),
1338                        var_stride,
1339                        xgcf.var_log_false(),
1340                        &mut self.var_log_false,
1341                        weights_len,
1342                    ),
1343                )
1344            }
1345            .map_err(|e| XlogError::Kernel(format!("cache_store_var_log_false failed: {}", e)))?;
1346        }
1347
1348        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1349        unsafe {
1350            store_meta.clone().launch(
1351                LaunchConfig {
1352                    grid_dim: (1, 1, 1),
1353                    block_dim: (1, 1, 1),
1354                    shared_mem_bytes: 0,
1355                },
1356                (
1357                    handle.slot_device(),
1358                    handle.compile_needed_device(),
1359                    self.num_slots,
1360                    num_nodes,
1361                    num_levels,
1362                    root,
1363                    max_var,
1364                    &mut self.meta_num_nodes,
1365                    &mut self.meta_num_levels,
1366                    &mut self.meta_root,
1367                    &mut self.meta_max_var,
1368                ),
1369            )
1370        }
1371        .map_err(|e| XlogError::Kernel(format!("cache_store_meta failed: {}", e)))?;
1372
1373        // No device synchronize needed: all stores are GPU-to-GPU on the same stream.
1374        // Same-stream ordering guarantees subsequent kernels see the stored data.
1375        Ok(())
1376    }
1377
1378    pub fn store_weights(
1379        &mut self,
1380        handle: &GpuCircuitCacheHandle,
1381        weights_true: &TrackedCudaSlice<f64>,
1382        weights_false: &TrackedCudaSlice<f64>,
1383    ) -> Result<()> {
1384        let weights_len = handle.max_var.checked_add(1).ok_or_else(|| {
1385            XlogError::Compilation("GpuCircuitCache store_weights max_var overflow".to_string())
1386        })?;
1387        let weights_len_usize = usize::try_from(weights_len).map_err(|_| {
1388            XlogError::Compilation("GpuCircuitCache store_weights len overflow".to_string())
1389        })?;
1390        if weights_true.len() < weights_len_usize || weights_false.len() < weights_len_usize {
1391            return Err(XlogError::Compilation(format!(
1392                "GpuCircuitCache store_weights requires weights len >= {}, got true={} false={}",
1393                weights_len,
1394                weights_true.len(),
1395                weights_false.len()
1396            )));
1397        }
1398
1399        let device = self.provider.device().inner();
1400        let store_f64 = device
1401            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
1402            .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
1403
1404        let block_dim = 256u32;
1405        let grid_dim = if weights_len == 0 {
1406            0
1407        } else {
1408            weights_len.div_ceil(block_dim)
1409        };
1410        if grid_dim != 0 {
1411            let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1412                XlogError::Compilation("GpuCircuitCache store_weights var_cap overflow".to_string())
1413            })?;
1414            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1415            unsafe {
1416                store_f64.clone().launch(
1417                    LaunchConfig {
1418                        grid_dim: (grid_dim, 1, 1),
1419                        block_dim: (block_dim, 1, 1),
1420                        shared_mem_bytes: 0,
1421                    },
1422                    (
1423                        handle.slot_device(),
1424                        handle.compile_needed_device(),
1425                        var_stride,
1426                        weights_true,
1427                        &mut self.var_log_true,
1428                        weights_len,
1429                    ),
1430                )
1431            }
1432            .map_err(|e| XlogError::Kernel(format!("cache_store_weights_true failed: {}", e)))?;
1433
1434            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1435            unsafe {
1436                store_f64.clone().launch(
1437                    LaunchConfig {
1438                        grid_dim: (grid_dim, 1, 1),
1439                        block_dim: (block_dim, 1, 1),
1440                        shared_mem_bytes: 0,
1441                    },
1442                    (
1443                        handle.slot_device(),
1444                        handle.compile_needed_device(),
1445                        var_stride,
1446                        weights_false,
1447                        &mut self.var_log_false,
1448                        weights_len,
1449                    ),
1450                )
1451            }
1452            .map_err(|e| XlogError::Kernel(format!("cache_store_weights_false failed: {}", e)))?;
1453        }
1454
1455        // No device synchronize: same-stream ordering guarantees visibility.
1456        Ok(())
1457    }
1458
1459    pub fn overwrite_weights(
1460        &mut self,
1461        handle: &GpuCircuitCacheHandle,
1462        weights_true: &TrackedCudaSlice<f64>,
1463        weights_false: &TrackedCudaSlice<f64>,
1464    ) -> Result<()> {
1465        let weights_len = handle.max_var.checked_add(1).ok_or_else(|| {
1466            XlogError::Compilation("GpuCircuitCache overwrite_weights max_var overflow".to_string())
1467        })?;
1468        let weights_len_usize = usize::try_from(weights_len).map_err(|_| {
1469            XlogError::Compilation("GpuCircuitCache overwrite_weights len overflow".to_string())
1470        })?;
1471        if weights_true.len() < weights_len_usize || weights_false.len() < weights_len_usize {
1472            return Err(XlogError::Compilation(format!(
1473                "GpuCircuitCache overwrite_weights requires weights len >= {}, got true={} false={}",
1474                weights_len,
1475                weights_true.len(),
1476                weights_false.len()
1477            )));
1478        }
1479
1480        let device = self.provider.device().inner();
1481        let store_f64 = device
1482            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
1483            .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;
1484
1485        let block_dim = 256u32;
1486        let grid_dim = if weights_len == 0 {
1487            0
1488        } else {
1489            weights_len.div_ceil(block_dim)
1490        };
1491        if grid_dim != 0 {
1492            let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1493                XlogError::Compilation(
1494                    "GpuCircuitCache overwrite_weights var_cap overflow".to_string(),
1495                )
1496            })?;
1497            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1498            unsafe {
1499                store_f64.clone().launch(
1500                    LaunchConfig {
1501                        grid_dim: (grid_dim, 1, 1),
1502                        block_dim: (block_dim, 1, 1),
1503                        shared_mem_bytes: 0,
1504                    },
1505                    (
1506                        handle.slot_device(),
1507                        &self.always_on,
1508                        var_stride,
1509                        weights_true,
1510                        &mut self.var_log_true,
1511                        weights_len,
1512                    ),
1513                )
1514            }
1515            .map_err(|e| {
1516                XlogError::Kernel(format!("cache_overwrite_weights_true failed: {}", e))
1517            })?;
1518
1519            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1520            unsafe {
1521                store_f64.clone().launch(
1522                    LaunchConfig {
1523                        grid_dim: (grid_dim, 1, 1),
1524                        block_dim: (block_dim, 1, 1),
1525                        shared_mem_bytes: 0,
1526                    },
1527                    (
1528                        handle.slot_device(),
1529                        &self.always_on,
1530                        var_stride,
1531                        weights_false,
1532                        &mut self.var_log_false,
1533                        weights_len,
1534                    ),
1535                )
1536            }
1537            .map_err(|e| {
1538                XlogError::Kernel(format!("cache_overwrite_weights_false failed: {}", e))
1539            })?;
1540        }
1541
1542        // No device synchronize: same-stream ordering guarantees visibility.
1543        Ok(())
1544    }
1545
1546    pub fn store_free_var_mask(
1547        &mut self,
1548        handle: &GpuCircuitCacheHandle,
1549        mask: &TrackedCudaSlice<u8>,
1550    ) -> Result<()> {
1551        let mask_len = u32::try_from(mask.len()).map_err(|_| {
1552            XlogError::Compilation("GpuCircuitCache free_var_mask len overflow".to_string())
1553        })?;
1554        let expected_len = handle.max_var.checked_add(1).ok_or_else(|| {
1555            XlogError::Compilation("GpuCircuitCache free_var_mask max_var overflow".to_string())
1556        })?;
1557        if mask_len != expected_len {
1558            return Err(XlogError::Compilation(format!(
1559                "GpuCircuitCache free_var_mask len {} != expected {}",
1560                mask_len, expected_len
1561            )));
1562        }
1563        let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1564            XlogError::Compilation("GpuCircuitCache free_var_mask var_cap overflow".to_string())
1565        })?;
1566        if expected_len > var_stride {
1567            return Err(XlogError::Compilation(format!(
1568                "GpuCircuitCache free_var_mask len {} exceeds var_cap+1 {}",
1569                expected_len, var_stride
1570            )));
1571        }
1572
1573        let device = self.provider.device().inner();
1574        let store_u8 = device
1575            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U8)
1576            .ok_or_else(|| XlogError::Kernel("cache_store_u8 kernel not found".to_string()))?;
1577
1578        let block_dim = 256u32;
1579        let grid_dim = mask_len.div_ceil(block_dim);
1580        if grid_dim == 0 {
1581            return Ok(());
1582        }
1583
1584        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1585        unsafe {
1586            store_u8.clone().launch(
1587                LaunchConfig {
1588                    grid_dim: (grid_dim, 1, 1),
1589                    block_dim: (block_dim, 1, 1),
1590                    shared_mem_bytes: 0,
1591                },
1592                (
1593                    handle.slot_device(),
1594                    handle.compile_needed_device(),
1595                    var_stride,
1596                    mask,
1597                    &mut self.free_var_mask,
1598                    mask_len,
1599                ),
1600            )
1601        }
1602        .map_err(|e| XlogError::Kernel(format!("cache_store_free_var_mask failed: {}", e)))?;
1603
1604        // No device synchronize: same-stream ordering guarantees visibility.
1605        let slot_idx = handle.slot_index() as usize;
1606        debug_assert!(
1607            slot_idx < self.has_free_var_mask.len(),
1608            "slot_index {} exceeds num_slots {}",
1609            slot_idx,
1610            self.has_free_var_mask.len()
1611        );
1612        if slot_idx < self.has_free_var_mask.len() {
1613            self.has_free_var_mask[slot_idx] = true;
1614        }
1615        Ok(())
1616    }
1617
1618    /// Populate a cache slot from host-resident arrays loaded from the disk cache.
1619    ///
1620    /// This mirrors [`store_from_xgcf`] but takes a [`disk_cache::CircuitArtifact`]
1621    /// (host `Vec`s) instead of a device-resident `GpuXgcf`. Each host array is
1622    /// uploaded to a temporary device buffer and then stored into the slot via the
1623    /// same `cache_store_*` kernels.
1624    pub(crate) fn restore_from_host_arrays(
1625        &mut self,
1626        handle: &mut GpuCircuitCacheHandle,
1627        artifact: &disk_cache::CircuitArtifact,
1628    ) -> Result<()> {
1629        // -- Validate sizes against cache caps --
1630        let num_nodes = artifact.num_nodes;
1631        if num_nodes == 0 {
1632            return Err(XlogError::Compilation(
1633                "GpuCircuitCache restore: num_nodes must be > 0".to_string(),
1634            ));
1635        }
1636        if num_nodes > self.node_cap {
1637            return Err(XlogError::Compilation(format!(
1638                "GpuCircuitCache restore: num_nodes {} exceeds node_cap {}",
1639                num_nodes, self.node_cap
1640            )));
1641        }
1642
1643        let num_edges = artifact.num_edges;
1644        if num_edges > self.edge_cap {
1645            return Err(XlogError::Compilation(format!(
1646                "GpuCircuitCache restore: num_edges {} exceeds edge_cap {}",
1647                num_edges, self.edge_cap
1648            )));
1649        }
1650
1651        let num_levels = artifact.num_levels;
1652        if num_levels == 0 {
1653            return Err(XlogError::Compilation(
1654                "GpuCircuitCache restore: num_levels must be > 0".to_string(),
1655            ));
1656        }
1657        if num_levels > self.level_cap {
1658            return Err(XlogError::Compilation(format!(
1659                "GpuCircuitCache restore: num_levels {} exceeds level_cap {}",
1660                num_levels, self.level_cap
1661            )));
1662        }
1663
1664        let root = artifact.root;
1665        if root >= num_nodes {
1666            return Err(XlogError::Compilation(format!(
1667                "GpuCircuitCache restore: root {} out of bounds (num_nodes={})",
1668                root, num_nodes
1669            )));
1670        }
1671
1672        let max_var = artifact.max_var;
1673        if max_var > self.var_cap {
1674            return Err(XlogError::Compilation(format!(
1675                "GpuCircuitCache restore: max_var {} exceeds var_cap {}",
1676                max_var, self.var_cap
1677            )));
1678        }
1679
1680        let expected_child_offsets = (num_nodes as usize) + 1;
1681        if artifact.child_offsets.len() < expected_child_offsets {
1682            return Err(XlogError::Compilation(format!(
1683                "GpuCircuitCache restore: child_offsets len {} < num_nodes+1 {}",
1684                artifact.child_offsets.len(),
1685                expected_child_offsets
1686            )));
1687        }
1688        if artifact.level_nodes.len() < num_nodes as usize {
1689            return Err(XlogError::Compilation(format!(
1690                "GpuCircuitCache restore: level_nodes len {} < num_nodes {}",
1691                artifact.level_nodes.len(),
1692                num_nodes
1693            )));
1694        }
1695        let expected_level_offsets = (num_levels as usize) + 1;
1696        if artifact.level_offsets.len() != expected_level_offsets {
1697            return Err(XlogError::Compilation(format!(
1698                "GpuCircuitCache restore: level_offsets len {} != num_levels+1 {}",
1699                artifact.level_offsets.len(),
1700                expected_level_offsets
1701            )));
1702        }
1703
1704        // -- Set handle metadata --
1705        handle.num_nodes = num_nodes;
1706        handle.num_levels = num_levels;
1707        handle.root = root;
1708        handle.max_var = max_var;
1709
1710        // -- Load kernels --
1711        let device = self.provider.device().inner();
1712        let memory = self.provider.memory();
1713
1714        let store_u8 = device
1715            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U8)
1716            .ok_or_else(|| XlogError::Kernel("cache_store_u8 kernel not found".to_string()))?;
1717        let store_u32 = device
1718            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_U32)
1719            .ok_or_else(|| XlogError::Kernel("cache_store_u32 kernel not found".to_string()))?;
1720        let store_i32 = device
1721            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_I32)
1722            .ok_or_else(|| XlogError::Kernel("cache_store_i32 kernel not found".to_string()))?;
1723        let store_meta = device
1724            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_META)
1725            .ok_or_else(|| XlogError::Kernel("cache_store_meta kernel not found".to_string()))?;
1726
1727        let block_dim = 256u32;
1728
1729        let node_stride = self.node_cap;
1730        let offset_stride = self.node_cap.checked_add(1).ok_or_else(|| {
1731            XlogError::Compilation("GpuCircuitCache restore: node_cap overflow".to_string())
1732        })?;
1733        let level_offset_stride = self.level_cap.checked_add(1).ok_or_else(|| {
1734            XlogError::Compilation("GpuCircuitCache restore: level_cap overflow".to_string())
1735        })?;
1736        let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
1737            XlogError::Compilation("GpuCircuitCache restore: var_cap overflow".to_string())
1738        })?;
1739
1740        let num_nodes_plus1 = num_nodes.checked_add(1).ok_or_else(|| {
1741            XlogError::Compilation("GpuCircuitCache restore: num_nodes overflow".to_string())
1742        })?;
1743        let num_levels_plus1 = num_levels.checked_add(1).ok_or_else(|| {
1744            XlogError::Compilation("GpuCircuitCache restore: num_levels overflow".to_string())
1745        })?;
1746
1747        // -- Upload node_type (u8, num_nodes elements) --
1748        let grid_nodes = cache_grid_dim_for_u32_count(
1749            "GpuCircuitCache restore node_type",
1750            num_nodes,
1751            block_dim,
1752        )?;
1753        if grid_nodes != 0 {
1754            let mut d_node_type = memory.alloc::<u8>(num_nodes as usize)?;
1755            self.provider
1756                .htod_sync_copy_into_tracked(
1757                    &artifact.node_type[..num_nodes as usize],
1758                    &mut d_node_type,
1759                )
1760                .map_err(|e| XlogError::Kernel(format!("restore htod node_type failed: {}", e)))?;
1761            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1762            unsafe {
1763                store_u8.clone().launch(
1764                    LaunchConfig {
1765                        grid_dim: (grid_nodes, 1, 1),
1766                        block_dim: (block_dim, 1, 1),
1767                        shared_mem_bytes: 0,
1768                    },
1769                    (
1770                        handle.slot_device(),
1771                        handle.compile_needed_device(),
1772                        node_stride,
1773                        &d_node_type,
1774                        &mut self.node_type,
1775                        num_nodes,
1776                    ),
1777                )
1778            }
1779            .map_err(|e| {
1780                XlogError::Kernel(format!("restore cache_store node_type failed: {}", e))
1781            })?;
1782        }
1783
1784        // -- Upload child_offsets (u32, num_nodes+1 elements) --
1785        let grid_offsets = cache_grid_dim_for_u32_count(
1786            "GpuCircuitCache restore child_offsets",
1787            num_nodes_plus1,
1788            block_dim,
1789        )?;
1790        if grid_offsets != 0 {
1791            let mut d_child_offsets = memory.alloc::<u32>(num_nodes_plus1 as usize)?;
1792            self.provider
1793                .htod_sync_copy_into_tracked(
1794                    &artifact.child_offsets[..num_nodes_plus1 as usize],
1795                    &mut d_child_offsets,
1796                )
1797                .map_err(|e| {
1798                    XlogError::Kernel(format!("restore htod child_offsets failed: {}", e))
1799                })?;
1800            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1801            unsafe {
1802                store_u32.clone().launch(
1803                    LaunchConfig {
1804                        grid_dim: (grid_offsets, 1, 1),
1805                        block_dim: (block_dim, 1, 1),
1806                        shared_mem_bytes: 0,
1807                    },
1808                    (
1809                        handle.slot_device(),
1810                        handle.compile_needed_device(),
1811                        offset_stride,
1812                        &d_child_offsets,
1813                        &mut self.child_offsets,
1814                        num_nodes_plus1,
1815                    ),
1816                )
1817            }
1818            .map_err(|e| {
1819                XlogError::Kernel(format!("restore cache_store child_offsets failed: {}", e))
1820            })?;
1821        }
1822
1823        // -- Upload child_indices (u32, num_edges elements) --
1824        let grid_edges = cache_grid_dim_for_u32_count(
1825            "GpuCircuitCache restore child_indices",
1826            num_edges,
1827            block_dim,
1828        )?;
1829        if grid_edges != 0 {
1830            let mut d_child_indices = memory.alloc::<u32>(num_edges as usize)?;
1831            self.provider
1832                .htod_sync_copy_into_tracked(
1833                    &artifact.child_indices[..num_edges as usize],
1834                    &mut d_child_indices,
1835                )
1836                .map_err(|e| {
1837                    XlogError::Kernel(format!("restore htod child_indices failed: {}", e))
1838                })?;
1839            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1840            unsafe {
1841                store_u32.clone().launch(
1842                    LaunchConfig {
1843                        grid_dim: (grid_edges, 1, 1),
1844                        block_dim: (block_dim, 1, 1),
1845                        shared_mem_bytes: 0,
1846                    },
1847                    (
1848                        handle.slot_device(),
1849                        handle.compile_needed_device(),
1850                        self.edge_cap,
1851                        &d_child_indices,
1852                        &mut self.child_indices,
1853                        num_edges,
1854                    ),
1855                )
1856            }
1857            .map_err(|e| {
1858                XlogError::Kernel(format!("restore cache_store child_indices failed: {}", e))
1859            })?;
1860        }
1861
1862        // -- Upload lit (i32, num_nodes elements) --
1863        if grid_nodes != 0 {
1864            let mut d_lit = memory.alloc::<i32>(num_nodes as usize)?;
1865            self.provider
1866                .htod_sync_copy_into_tracked(&artifact.lit[..num_nodes as usize], &mut d_lit)
1867                .map_err(|e| XlogError::Kernel(format!("restore htod lit failed: {}", e)))?;
1868            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1869            unsafe {
1870                store_i32.clone().launch(
1871                    LaunchConfig {
1872                        grid_dim: (grid_nodes, 1, 1),
1873                        block_dim: (block_dim, 1, 1),
1874                        shared_mem_bytes: 0,
1875                    },
1876                    (
1877                        handle.slot_device(),
1878                        handle.compile_needed_device(),
1879                        node_stride,
1880                        &d_lit,
1881                        &mut self.lit,
1882                        num_nodes,
1883                    ),
1884                )
1885            }
1886            .map_err(|e| XlogError::Kernel(format!("restore cache_store lit failed: {}", e)))?;
1887
1888            // -- Upload decision_var (u32, num_nodes elements) --
1889            let mut d_decision_var = memory.alloc::<u32>(num_nodes as usize)?;
1890            self.provider
1891                .htod_sync_copy_into_tracked(
1892                    &artifact.decision_var[..num_nodes as usize],
1893                    &mut d_decision_var,
1894                )
1895                .map_err(|e| {
1896                    XlogError::Kernel(format!("restore htod decision_var failed: {}", e))
1897                })?;
1898            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1899            unsafe {
1900                store_u32.clone().launch(
1901                    LaunchConfig {
1902                        grid_dim: (grid_nodes, 1, 1),
1903                        block_dim: (block_dim, 1, 1),
1904                        shared_mem_bytes: 0,
1905                    },
1906                    (
1907                        handle.slot_device(),
1908                        handle.compile_needed_device(),
1909                        node_stride,
1910                        &d_decision_var,
1911                        &mut self.decision_var,
1912                        num_nodes,
1913                    ),
1914                )
1915            }
1916            .map_err(|e| {
1917                XlogError::Kernel(format!("restore cache_store decision_var failed: {}", e))
1918            })?;
1919
1920            // -- Upload decision_child_false (u32, num_nodes elements) --
1921            let mut d_decision_child_false = memory.alloc::<u32>(num_nodes as usize)?;
1922            self.provider
1923                .htod_sync_copy_into_tracked(
1924                    &artifact.decision_child_false[..num_nodes as usize],
1925                    &mut d_decision_child_false,
1926                )
1927                .map_err(|e| {
1928                    XlogError::Kernel(format!("restore htod decision_child_false failed: {}", e))
1929                })?;
1930            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1931            unsafe {
1932                store_u32.clone().launch(
1933                    LaunchConfig {
1934                        grid_dim: (grid_nodes, 1, 1),
1935                        block_dim: (block_dim, 1, 1),
1936                        shared_mem_bytes: 0,
1937                    },
1938                    (
1939                        handle.slot_device(),
1940                        handle.compile_needed_device(),
1941                        node_stride,
1942                        &d_decision_child_false,
1943                        &mut self.decision_child_false,
1944                        num_nodes,
1945                    ),
1946                )
1947            }
1948            .map_err(|e| {
1949                XlogError::Kernel(format!(
1950                    "restore cache_store decision_child_false failed: {}",
1951                    e
1952                ))
1953            })?;
1954
1955            // -- Upload decision_child_true (u32, num_nodes elements) --
1956            let mut d_decision_child_true = memory.alloc::<u32>(num_nodes as usize)?;
1957            self.provider
1958                .htod_sync_copy_into_tracked(
1959                    &artifact.decision_child_true[..num_nodes as usize],
1960                    &mut d_decision_child_true,
1961                )
1962                .map_err(|e| {
1963                    XlogError::Kernel(format!("restore htod decision_child_true failed: {}", e))
1964                })?;
1965            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1966            unsafe {
1967                store_u32.clone().launch(
1968                    LaunchConfig {
1969                        grid_dim: (grid_nodes, 1, 1),
1970                        block_dim: (block_dim, 1, 1),
1971                        shared_mem_bytes: 0,
1972                    },
1973                    (
1974                        handle.slot_device(),
1975                        handle.compile_needed_device(),
1976                        node_stride,
1977                        &d_decision_child_true,
1978                        &mut self.decision_child_true,
1979                        num_nodes,
1980                    ),
1981                )
1982            }
1983            .map_err(|e| {
1984                XlogError::Kernel(format!(
1985                    "restore cache_store decision_child_true failed: {}",
1986                    e
1987                ))
1988            })?;
1989
1990            // -- Upload level_nodes (u32, num_nodes elements) --
1991            let mut d_level_nodes = memory.alloc::<u32>(num_nodes as usize)?;
1992            self.provider
1993                .htod_sync_copy_into_tracked(
1994                    &artifact.level_nodes[..num_nodes as usize],
1995                    &mut d_level_nodes,
1996                )
1997                .map_err(|e| {
1998                    XlogError::Kernel(format!("restore htod level_nodes failed: {}", e))
1999                })?;
2000            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2001            unsafe {
2002                store_u32.clone().launch(
2003                    LaunchConfig {
2004                        grid_dim: (grid_nodes, 1, 1),
2005                        block_dim: (block_dim, 1, 1),
2006                        shared_mem_bytes: 0,
2007                    },
2008                    (
2009                        handle.slot_device(),
2010                        handle.compile_needed_device(),
2011                        node_stride,
2012                        &d_level_nodes,
2013                        &mut self.level_nodes,
2014                        num_nodes,
2015                    ),
2016                )
2017            }
2018            .map_err(|e| {
2019                XlogError::Kernel(format!("restore cache_store level_nodes failed: {}", e))
2020            })?;
2021        }
2022
2023        // -- Upload level_offsets (u32, num_levels+1 elements) --
2024        let grid_levels = cache_grid_dim_for_u32_count(
2025            "GpuCircuitCache restore level_offsets",
2026            num_levels_plus1,
2027            block_dim,
2028        )?;
2029        if grid_levels != 0 {
2030            let mut d_level_offsets = memory.alloc::<u32>(num_levels_plus1 as usize)?;
2031            self.provider
2032                .htod_sync_copy_into_tracked(
2033                    &artifact.level_offsets[..num_levels_plus1 as usize],
2034                    &mut d_level_offsets,
2035                )
2036                .map_err(|e| {
2037                    XlogError::Kernel(format!("restore htod level_offsets failed: {}", e))
2038                })?;
2039            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2040            unsafe {
2041                store_u32.clone().launch(
2042                    LaunchConfig {
2043                        grid_dim: (grid_levels, 1, 1),
2044                        block_dim: (block_dim, 1, 1),
2045                        shared_mem_bytes: 0,
2046                    },
2047                    (
2048                        handle.slot_device(),
2049                        handle.compile_needed_device(),
2050                        level_offset_stride,
2051                        &d_level_offsets,
2052                        &mut self.level_offsets,
2053                        num_levels_plus1,
2054                    ),
2055                )
2056            }
2057            .map_err(|e| {
2058                XlogError::Kernel(format!("restore cache_store level_offsets failed: {}", e))
2059            })?;
2060        }
2061
2062        // -- Store metadata via cache_store_meta kernel --
2063        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2064        unsafe {
2065            store_meta.clone().launch(
2066                LaunchConfig {
2067                    grid_dim: (1, 1, 1),
2068                    block_dim: (1, 1, 1),
2069                    shared_mem_bytes: 0,
2070                },
2071                (
2072                    handle.slot_device(),
2073                    handle.compile_needed_device(),
2074                    self.num_slots,
2075                    num_nodes,
2076                    num_levels,
2077                    root,
2078                    max_var,
2079                    &mut self.meta_num_nodes,
2080                    &mut self.meta_num_levels,
2081                    &mut self.meta_root,
2082                    &mut self.meta_max_var,
2083                ),
2084            )
2085        }
2086        .map_err(|e| XlogError::Kernel(format!("restore cache_store_meta failed: {}", e)))?;
2087
2088        // -- Zero the free_var_mask region for this slot, then conditionally write --
2089        let slot_idx = handle.slot_index() as usize;
2090
2091        // Zero the slot's mask region by uploading a zero buffer and storing it.
2092        // We always zero to ensure stale mask data from a previous occupant is cleared.
2093        let mask_cap = var_stride; // max_var+1 capacity per slot
2094        let grid_mask_zero = cache_grid_dim_for_u32_count(
2095            "GpuCircuitCache restore zero free_var_mask",
2096            mask_cap,
2097            block_dim,
2098        )?;
2099        if grid_mask_zero != 0 {
2100            let mut d_zeros = memory.alloc::<u8>(mask_cap as usize)?;
2101            device.memset_zeros(&mut d_zeros).map_err(|e| {
2102                XlogError::Kernel(format!("restore memset_zeros free_var_mask failed: {}", e))
2103            })?;
2104            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2105            unsafe {
2106                store_u8.clone().launch(
2107                    LaunchConfig {
2108                        grid_dim: (grid_mask_zero, 1, 1),
2109                        block_dim: (block_dim, 1, 1),
2110                        shared_mem_bytes: 0,
2111                    },
2112                    (
2113                        handle.slot_device(),
2114                        handle.compile_needed_device(),
2115                        var_stride,
2116                        &d_zeros,
2117                        &mut self.free_var_mask,
2118                        mask_cap,
2119                    ),
2120                )
2121            }
2122            .map_err(|e| {
2123                XlogError::Kernel(format!(
2124                    "restore cache_store zero free_var_mask failed: {}",
2125                    e
2126                ))
2127            })?;
2128        }
2129
2130        // Write the actual free_var_mask if the artifact has one.
2131        let has_mask = artifact.has_free_var_mask && !artifact.free_var_mask.is_empty();
2132        if has_mask {
2133            let mask_len = max_var.checked_add(1).ok_or_else(|| {
2134                XlogError::Compilation(
2135                    "GpuCircuitCache restore: free_var_mask max_var overflow".to_string(),
2136                )
2137            })?;
2138            let actual_len = std::cmp::min(mask_len as usize, artifact.free_var_mask.len());
2139            if actual_len > 0 {
2140                let actual_len_u32 = u32::try_from(actual_len).map_err(|_| {
2141                    XlogError::Compilation(
2142                        "GpuCircuitCache restore free_var_mask len exceeds u32".to_string(),
2143                    )
2144                })?;
2145                let grid_mask = cache_grid_dim_for_u32_count(
2146                    "GpuCircuitCache restore free_var_mask",
2147                    actual_len_u32,
2148                    block_dim,
2149                )?;
2150                if grid_mask != 0 {
2151                    let mut d_mask = memory.alloc::<u8>(actual_len)?;
2152                    self.provider
2153                        .htod_sync_copy_into_tracked(
2154                            &artifact.free_var_mask[..actual_len],
2155                            &mut d_mask,
2156                        )
2157                        .map_err(|e| {
2158                            XlogError::Kernel(format!("restore htod free_var_mask failed: {}", e))
2159                        })?;
2160                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2161                    unsafe {
2162                        store_u8.clone().launch(
2163                            LaunchConfig {
2164                                grid_dim: (grid_mask, 1, 1),
2165                                block_dim: (block_dim, 1, 1),
2166                                shared_mem_bytes: 0,
2167                            },
2168                            (
2169                                handle.slot_device(),
2170                                handle.compile_needed_device(),
2171                                var_stride,
2172                                &d_mask,
2173                                &mut self.free_var_mask,
2174                                actual_len_u32,
2175                            ),
2176                        )
2177                    }
2178                    .map_err(|e| {
2179                        XlogError::Kernel(format!(
2180                            "restore cache_store free_var_mask failed: {}",
2181                            e
2182                        ))
2183                    })?;
2184                }
2185            }
2186        }
2187
2188        // Set per-slot has_free_var_mask flag.
2189        debug_assert!(
2190            slot_idx < self.has_free_var_mask.len(),
2191            "slot_index {} exceeds num_slots {}",
2192            slot_idx,
2193            self.has_free_var_mask.len()
2194        );
2195        if slot_idx < self.has_free_var_mask.len() {
2196            self.has_free_var_mask[slot_idx] = has_mask;
2197        }
2198
2199        // No device synchronize needed: all stores are H→D copies followed by
2200        // same-stream kernel launches, so ordering is guaranteed.
2201        Ok(())
2202    }
2203
2204    /// Extract a [`disk_cache::CircuitArtifact`] from a populated GPU cache slot.
2205    ///
2206    /// This is the inverse of [`restore_from_host_arrays`]: it reads device-resident
2207    /// topology arrays from the cache slot and builds host vectors suitable for disk
2208    /// serialization. The caller must ensure the slot has been populated (i.e. after
2209    /// `store_from_xgcf` + `store_free_var_mask`).
2210    pub(crate) fn build_artifact_from_device(
2211        &self,
2212        handle: &GpuCircuitCacheHandle,
2213        provider: &Arc<CudaKernelProvider>,
2214    ) -> Result<disk_cache::CircuitArtifact> {
2215        let device = provider.device().inner();
2216        let slot = handle.slot_index() as usize;
2217        let num_nodes = handle.num_nodes();
2218        let num_levels = handle.num_levels();
2219        let root = handle.root();
2220        let max_var = handle.max_var();
2221
2222        if num_nodes == 0 {
2223            return Err(XlogError::Compilation(
2224                "build_artifact_from_device: num_nodes is 0".to_string(),
2225            ));
2226        }
2227
2228        let node_stride = self.node_cap as usize;
2229        let offset_stride = (self.node_cap as usize) + 1;
2230        let edge_stride = self.edge_cap as usize;
2231        let level_offset_stride = (self.level_cap as usize) + 1;
2232        let var_stride = (self.var_cap as usize) + 1;
2233
2234        let slot_node_start = slot * node_stride;
2235        let slot_offset_start = slot * offset_stride;
2236        let slot_level_offset_start = slot * level_offset_stride;
2237        let slot_var_start = slot * var_stride;
2238
2239        let nn = num_nodes as usize;
2240        let nn1 = nn + 1;
2241        let nl1 = (num_levels as usize) + 1;
2242
2243        // Determine num_edges from child_offsets[num_nodes] - child_offsets[0].
2244        // We read child_offsets first, then derive num_edges from it.
2245        let child_offsets_view = self
2246            .child_offsets
2247            .slice(slot_offset_start..(slot_offset_start + nn1));
2248        let child_offsets: Vec<u32> = device
2249            .dtoh_sync_copy(&child_offsets_view)
2250            .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh child_offsets: {}", e)))?;
2251        let num_edges = if nn1 > 0 {
2252            child_offsets[nn]
2253                .checked_sub(child_offsets[0])
2254                .ok_or_else(|| {
2255                    XlogError::Compilation(
2256                        "build_artifact_from_device: child_offsets[num_nodes] < child_offsets[0]"
2257                            .to_string(),
2258                    )
2259                })?
2260        } else {
2261            0
2262        };
2263
2264        // Read child_indices from the edge region.
2265        let slot_edge_start = slot * edge_stride;
2266        let ne = num_edges as usize;
2267        let child_indices: Vec<u32> = if ne > 0 {
2268            let view = self
2269                .child_indices
2270                .slice(slot_edge_start..(slot_edge_start + ne));
2271            device.dtoh_sync_copy(&view).map_err(|e| {
2272                XlogError::Kernel(format!("build_artifact dtoh child_indices: {}", e))
2273            })?
2274        } else {
2275            Vec::new()
2276        };
2277
2278        // node_type (u8)
2279        let node_type_view = self
2280            .node_type
2281            .slice(slot_node_start..(slot_node_start + nn));
2282        let node_type: Vec<u8> = device
2283            .dtoh_sync_copy(&node_type_view)
2284            .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh node_type: {}", e)))?;
2285
2286        // lit (i32)
2287        let lit_view = self.lit.slice(slot_node_start..(slot_node_start + nn));
2288        let lit: Vec<i32> = device
2289            .dtoh_sync_copy(&lit_view)
2290            .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh lit: {}", e)))?;
2291
2292        // decision_var (u32)
2293        let dv_view = self
2294            .decision_var
2295            .slice(slot_node_start..(slot_node_start + nn));
2296        let decision_var: Vec<u32> = device
2297            .dtoh_sync_copy(&dv_view)
2298            .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh decision_var: {}", e)))?;
2299
2300        // decision_child_false (u32)
2301        let dcf_view = self
2302            .decision_child_false
2303            .slice(slot_node_start..(slot_node_start + nn));
2304        let decision_child_false: Vec<u32> = device.dtoh_sync_copy(&dcf_view).map_err(|e| {
2305            XlogError::Kernel(format!("build_artifact dtoh decision_child_false: {}", e))
2306        })?;
2307
2308        // decision_child_true (u32)
2309        let dct_view = self
2310            .decision_child_true
2311            .slice(slot_node_start..(slot_node_start + nn));
2312        let decision_child_true: Vec<u32> = device.dtoh_sync_copy(&dct_view).map_err(|e| {
2313            XlogError::Kernel(format!("build_artifact dtoh decision_child_true: {}", e))
2314        })?;
2315
2316        // level_nodes (u32)
2317        let ln_view = self
2318            .level_nodes
2319            .slice(slot_node_start..(slot_node_start + nn));
2320        let level_nodes: Vec<u32> = device
2321            .dtoh_sync_copy(&ln_view)
2322            .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh level_nodes: {}", e)))?;
2323
2324        // level_offsets (u32)
2325        let lo_view = self
2326            .level_offsets
2327            .slice(slot_level_offset_start..(slot_level_offset_start + nl1));
2328        let level_offsets: Vec<u32> = device
2329            .dtoh_sync_copy(&lo_view)
2330            .map_err(|e| XlogError::Kernel(format!("build_artifact dtoh level_offsets: {}", e)))?;
2331
2332        // free_var_mask (u8)
2333        let has_free_var_mask = self.has_free_var_mask_for_slot(slot as u32);
2334        let mask_len = (max_var as usize) + 1;
2335        let free_var_mask: Vec<u8> = if mask_len > 0 {
2336            let fvm_view = self
2337                .free_var_mask
2338                .slice(slot_var_start..(slot_var_start + mask_len));
2339            device.dtoh_sync_copy(&fvm_view).map_err(|e| {
2340                XlogError::Kernel(format!("build_artifact dtoh free_var_mask: {}", e))
2341            })?
2342        } else {
2343            Vec::new()
2344        };
2345
2346        Ok(disk_cache::CircuitArtifact {
2347            num_nodes,
2348            num_edges,
2349            num_levels,
2350            root,
2351            max_var,
2352            has_free_var_mask,
2353            node_type,
2354            child_offsets,
2355            child_indices,
2356            lit,
2357            decision_var,
2358            decision_child_false,
2359            decision_child_true,
2360            level_nodes,
2361            level_offsets,
2362            free_var_mask,
2363        })
2364    }
2365
    /// Evaluate the cached circuit referenced by `handle` and write the result
    /// into `out_log_z`.
    ///
    /// This is a plain alias: it forwards both arguments unchanged to
    /// `eval_log_wmc_device_only` and returns that call's result, so the
    /// preconditions and errors are exactly those of `eval_log_wmc_device_only`.
    pub fn eval_log_wmc_device_inplace(
        &mut self,
        handle: &GpuCircuitCacheHandle,
        out_log_z: &mut TrackedCudaSlice<f64>,
    ) -> Result<()> {
        self.eval_log_wmc_device_only(handle, out_log_z)
    }
2373
    /// Evaluate the cached circuit in `handle`'s slot and copy its root value
    /// into the single-element device buffer `out_log_z`.
    ///
    /// Work enqueued, in order:
    /// 1. `xgcf_eval_all_levels_cached`, launched as one block of 256 threads.
    /// 2. `apply_free_var_correction_cached(handle, true, false)`, i.e. with
    ///    `apply_log_z = true` and `apply_grads = false`.
    /// 3. `xgcf_copy_root_cached_meta`, launched as a single thread, which is
    ///    given `values`, `meta_root` and `out_log_z`.
    ///
    /// This function body issues no device synchronize.
    ///
    /// # Errors
    ///
    /// - [`XlogError::Compilation`] if `out_log_z.len() != 1`.
    /// - [`XlogError::Kernel`] if either kernel cannot be found or a launch fails.
    /// - Any error returned by `apply_free_var_correction_cached` is propagated.
    pub fn eval_log_wmc_device_only(
        &mut self,
        handle: &GpuCircuitCacheHandle,
        out_log_z: &mut TrackedCudaSlice<f64>,
    ) -> Result<()> {
        // The output buffer must hold exactly one f64.
        if out_log_z.len() != 1 {
            return Err(XlogError::Compilation(format!(
                "GPU cache logZ output len {} != 1",
                out_log_z.len()
            )));
        }

        // Scoped block: `device`, `eval_all` and `params` go out of scope before
        // the `&mut self` call that follows.
        {
            let device = self.provider.device().inner();
            let eval_all = device
                .get_func(
                    xlog_cuda::CIRCUIT_MODULE,
                    xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED,
                )
                .ok_or_else(|| {
                    XlogError::Kernel("xgcf_eval_all_levels_cached kernel not found".to_string())
                })?;

            let block_size: u32 = 256;
            // Raw kernel parameter list; the order must match the kernel's
            // signature, which is not visible in this file section.
            let mut params: Vec<*mut std::ffi::c_void> = vec![
                handle.slot_device().as_kernel_param(),
                self.node_cap.as_kernel_param(),
                self.edge_cap.as_kernel_param(),
                self.level_cap.as_kernel_param(),
                self.var_cap.as_kernel_param(),
                (&self.node_type).as_kernel_param(),
                (&self.child_offsets).as_kernel_param(),
                (&self.child_indices).as_kernel_param(),
                (&self.lit).as_kernel_param(),
                (&self.decision_var).as_kernel_param(),
                (&self.decision_child_false).as_kernel_param(),
                (&self.decision_child_true).as_kernel_param(),
                (&self.level_nodes).as_kernel_param(),
                (&self.level_offsets).as_kernel_param(),
                (&self.var_log_true).as_kernel_param(),
                (&self.var_log_false).as_kernel_param(),
                (&self.values).as_kernel_param(),
                (&self.meta_num_levels).as_kernel_param(),
            ];
            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                eval_all.clone().launch(
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    &mut params,
                )
            }
            .map_err(|e| XlogError::Kernel(format!("xgcf_eval_all_levels_cached failed: {}", e)))?;
        }

        // Log-Z correction only; gradients are not touched on this path.
        self.apply_free_var_correction_cached(handle, true, false)?;

        let device = self.provider.device().inner();
        let copy_root = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_COPY_ROOT_CACHED_META,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_copy_root_cached_meta kernel not found".to_string())
            })?;
        // Single-thread launch that receives `values`, `meta_root` and the
        // caller's output buffer.
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            copy_root.clone().launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (1, 1, 1),
                    shared_mem_bytes: 0,
                },
                (
                    handle.slot_device(),
                    self.node_cap,
                    &self.values,
                    &self.meta_root,
                    out_log_z,
                ),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("xgcf_copy_root_cached_meta failed: {}", e)))?;

        // No device synchronize: callers read back with a synchronous host copy
        // or pass the result to subsequent GPU operations (same-stream ordering).
        Ok(())
    }
2466
    /// Run a forward evaluation followed by a per-level backward pass for the
    /// circuit in `handle`'s slot, leaving all results in the cache's device
    /// buffers (`values`, `adj`, `grad_true`, `grad_false`).
    ///
    /// Work enqueued, in order:
    /// 1. `xgcf_eval_all_levels_cached`, launched as one block of 256 threads.
    /// 2. `cache_store_f64` three times, storing from `zero_f64` into `adj`,
    ///    `grad_true` and `grad_false`.
    /// 3. `xgcf_add_scalar_cached` on `adj` with `meta_root` and `one_f64`
    ///    (presumably seeds the root adjoint; confirm against the kernel).
    /// 4. For every level index from `level_cap - 1` down to `0`:
    ///    `xgcf_backward_level_propagate_cached`,
    ///    `xgcf_backward_level_decision_grad_cached` and
    ///    `xgcf_backward_level_lit_grad_cached`.
    /// 5. `apply_free_var_correction_cached(handle, true, true)`.
    ///
    /// This function body issues no device synchronize.
    ///
    /// # Errors
    ///
    /// - [`XlogError::Kernel`] if a kernel cannot be found or a launch fails.
    /// - [`XlogError::Compilation`] if `var_cap + 1` overflows `u32`.
    /// - Errors from `cache_grid_dim_for_u32_count` and
    ///   `apply_free_var_correction_cached` are propagated.
    pub fn eval_grads_inplace(&mut self, handle: &GpuCircuitCacheHandle) -> Result<()> {
        let device = self.provider.device().inner();
        let eval_all = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_eval_all_levels_cached kernel not found".to_string())
            })?;
        let block_size: u32 = 256;
        // Forward pass. The raw parameter order must match the kernel's
        // signature, which is not visible in this file section.
        let mut params: Vec<*mut std::ffi::c_void> = vec![
            handle.slot_device().as_kernel_param(),
            self.node_cap.as_kernel_param(),
            self.edge_cap.as_kernel_param(),
            self.level_cap.as_kernel_param(),
            self.var_cap.as_kernel_param(),
            (&self.node_type).as_kernel_param(),
            (&self.child_offsets).as_kernel_param(),
            (&self.child_indices).as_kernel_param(),
            (&self.lit).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            (&self.var_log_true).as_kernel_param(),
            (&self.var_log_false).as_kernel_param(),
            (&self.values).as_kernel_param(),
            (&self.meta_num_levels).as_kernel_param(),
        ];
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            eval_all.clone().launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                &mut params,
            )
        }
        .map_err(|e| XlogError::Kernel(format!("xgcf_eval_all_levels_cached failed: {}", e)))?;

        let device = self.provider.device().inner();
        let store_f64 = device
            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
            .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;

        let node_stride = self.node_cap;
        // `var_stride` and `weights_len` are the same value (`var_cap + 1`),
        // computed twice with the same overflow check.
        let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
            XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
        })?;
        let weights_len = self.var_cap.checked_add(1).ok_or_else(|| {
            XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
        })?;

        // Reset `adj` for this slot: one store per node position up to `node_cap`.
        let grid_nodes = cache_grid_dim_for_u32_count(
            "GpuCircuitCache eval_grads zero adj",
            self.node_cap,
            block_size,
        )?;
        if grid_nodes != 0 {
            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                store_f64.clone().launch(
                    LaunchConfig {
                        grid_dim: (grid_nodes, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        handle.slot_device(),
                        &self.always_on,
                        node_stride,
                        &self.zero_f64,
                        &mut self.adj,
                        self.node_cap,
                    ),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("cache zero adj failed: {}", e)))?;
        }

        // Reset both gradient buffers for this slot (`var_cap + 1` entries each).
        let grid_weights = cache_grid_dim_for_u32_count(
            "GpuCircuitCache eval_grads zero weights",
            weights_len,
            block_size,
        )?;
        if grid_weights != 0 {
            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                store_f64.clone().launch(
                    LaunchConfig {
                        grid_dim: (grid_weights, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        handle.slot_device(),
                        &self.always_on,
                        var_stride,
                        &self.zero_f64,
                        &mut self.grad_true,
                        weights_len,
                    ),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("cache zero grad_true failed: {}", e)))?;

            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                store_f64.clone().launch(
                    LaunchConfig {
                        grid_dim: (grid_weights, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        handle.slot_device(),
                        &self.always_on,
                        var_stride,
                        &self.zero_f64,
                        &mut self.grad_false,
                        weights_len,
                    ),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("cache zero grad_false failed: {}", e)))?;
        }

        let add_scalar = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_ADD_SCALAR_CACHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_add_scalar_cached kernel not found".to_string())
            })?;
        // Single-thread launch given `adj`, `meta_root` and `one_f64`.
        // NOTE(review): presumably adds 1.0 to the root's adjoint - confirm
        // against the kernel source.
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            add_scalar.clone().launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (1, 1, 1),
                    shared_mem_bytes: 0,
                },
                (
                    handle.slot_device(),
                    self.node_cap,
                    &mut self.adj,
                    &self.meta_root,
                    &self.one_f64,
                ),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("xgcf_add_scalar_cached (adj) failed: {}", e)))?;

        // Look up the three per-level backward kernels once, before the loop.
        let propagate = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE_CACHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel(
                    "xgcf_backward_level_propagate_cached kernel not found".to_string(),
                )
            })?;
        let decision_grad = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD_CACHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel(
                    "xgcf_backward_level_decision_grad_cached kernel not found".to_string(),
                )
            })?;
        let lit_grad = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD_CACHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel(
                    "xgcf_backward_level_lit_grad_cached kernel not found".to_string(),
                )
            })?;

        // The loop runs over the cache-wide `level_cap`, not the circuit's own
        // level count, and issues three launches per iteration.
        // NOTE(review): the kernels receive `meta_num_levels`; presumably they
        // no-op for levels the circuit does not have - confirm in the kernels.
        let num_blocks = self.node_cap.div_ceil(block_size);
        let num_levels = self.level_cap;
        for level in (0..num_levels).rev() {
            // Loop-invariant guard: with `node_cap == 0` nothing is launched.
            if num_blocks == 0 {
                continue;
            }
            // Bound to a local so the pointer taken by `as_kernel_param()` below
            // refers to a value that lives for the whole iteration.
            let level_u32: u32 = level;
            let mut params: Vec<*mut std::ffi::c_void> = vec![
                handle.slot_device().as_kernel_param(),
                self.node_cap.as_kernel_param(),
                self.edge_cap.as_kernel_param(),
                self.level_cap.as_kernel_param(),
                self.var_cap.as_kernel_param(),
                (&self.node_type).as_kernel_param(),
                (&self.child_offsets).as_kernel_param(),
                (&self.child_indices).as_kernel_param(),
                (&self.decision_var).as_kernel_param(),
                (&self.decision_child_false).as_kernel_param(),
                (&self.decision_child_true).as_kernel_param(),
                (&self.level_nodes).as_kernel_param(),
                (&self.level_offsets).as_kernel_param(),
                level_u32.as_kernel_param(),
                (&self.var_log_true).as_kernel_param(),
                (&self.var_log_false).as_kernel_param(),
                (&self.values).as_kernel_param(),
                (&self.adj).as_kernel_param(),
                (&self.meta_num_levels).as_kernel_param(),
            ];

            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                propagate.clone().launch(
                    LaunchConfig {
                        grid_dim: (num_blocks, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    &mut params,
                )
            }
            .map_err(|e| {
                XlogError::Kernel(format!(
                    "xgcf_backward_level_propagate_cached failed: {}",
                    e
                ))
            })?;

            // Decision-node gradient kernel: same level, additionally given
            // `grad_true` / `grad_false`.
            let mut params: Vec<*mut std::ffi::c_void> = vec![
                handle.slot_device().as_kernel_param(),
                self.node_cap.as_kernel_param(),
                self.edge_cap.as_kernel_param(),
                self.level_cap.as_kernel_param(),
                self.var_cap.as_kernel_param(),
                (&self.node_type).as_kernel_param(),
                (&self.decision_var).as_kernel_param(),
                (&self.decision_child_false).as_kernel_param(),
                (&self.decision_child_true).as_kernel_param(),
                (&self.level_nodes).as_kernel_param(),
                (&self.level_offsets).as_kernel_param(),
                level_u32.as_kernel_param(),
                (&self.var_log_true).as_kernel_param(),
                (&self.var_log_false).as_kernel_param(),
                (&self.values).as_kernel_param(),
                (&self.adj).as_kernel_param(),
                (&self.grad_true).as_kernel_param(),
                (&self.grad_false).as_kernel_param(),
                (&self.meta_num_levels).as_kernel_param(),
            ];

            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                decision_grad.clone().launch(
                    LaunchConfig {
                        grid_dim: (num_blocks, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    &mut params,
                )
            }
            .map_err(|e| {
                XlogError::Kernel(format!(
                    "xgcf_backward_level_decision_grad_cached failed: {}",
                    e
                ))
            })?;

            // Literal gradient kernel: given `lit`, `adj` and both gradient buffers.
            let mut params: Vec<*mut std::ffi::c_void> = vec![
                handle.slot_device().as_kernel_param(),
                self.node_cap.as_kernel_param(),
                self.edge_cap.as_kernel_param(),
                self.level_cap.as_kernel_param(),
                self.var_cap.as_kernel_param(),
                (&self.node_type).as_kernel_param(),
                (&self.lit).as_kernel_param(),
                (&self.level_nodes).as_kernel_param(),
                (&self.level_offsets).as_kernel_param(),
                level_u32.as_kernel_param(),
                (&self.adj).as_kernel_param(),
                (&self.grad_true).as_kernel_param(),
                (&self.grad_false).as_kernel_param(),
                (&self.meta_num_levels).as_kernel_param(),
            ];

            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                lit_grad.clone().launch(
                    LaunchConfig {
                        grid_dim: (num_blocks, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    &mut params,
                )
            }
            .map_err(|e| {
                XlogError::Kernel(format!("xgcf_backward_level_lit_grad_cached failed: {}", e))
            })?;
        }

        // Apply the correction to both log-Z and the gradients.
        self.apply_free_var_correction_cached(handle, true, true)?;
        // No device synchronize: callers batch multiple eval/backward calls
        // before syncing at the query boundary.
        Ok(())
    }
2781
    /// Like [`eval_grads_inplace`] but replaces the per-level backward loop
    /// with a single launch of `xgcf_backward_all_levels_cached` (one block of
    /// 256 threads).
    ///
    /// The forward pass, the zeroing of `adj` / `grad_true` / `grad_false`, the
    /// `xgcf_add_scalar_cached` launch and the final
    /// `apply_free_var_correction_cached(handle, true, true)` call are the same
    /// sequence as in `eval_grads_inplace`.
    ///
    /// This function body issues no device synchronize, so the caller can batch
    /// multiple queries before syncing.
    ///
    /// # Errors
    ///
    /// - [`XlogError::Kernel`] if a kernel cannot be found or a launch fails.
    /// - [`XlogError::Compilation`] if `var_cap + 1` overflows `u32`.
    /// - Errors from `cache_grid_dim_for_u32_count` and
    ///   `apply_free_var_correction_cached` are propagated.
    pub fn eval_grads_inplace_fused(&mut self, handle: &GpuCircuitCacheHandle) -> Result<()> {
        let device = self.provider.device().inner();
        let eval_all = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_EVAL_ALL_LEVELS_CACHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_eval_all_levels_cached kernel not found".to_string())
            })?;
        let block_size: u32 = 256;
        // Forward pass. The raw parameter order must match the kernel's
        // signature, which is not visible in this file section.
        let mut params: Vec<*mut std::ffi::c_void> = vec![
            handle.slot_device().as_kernel_param(),
            self.node_cap.as_kernel_param(),
            self.edge_cap.as_kernel_param(),
            self.level_cap.as_kernel_param(),
            self.var_cap.as_kernel_param(),
            (&self.node_type).as_kernel_param(),
            (&self.child_offsets).as_kernel_param(),
            (&self.child_indices).as_kernel_param(),
            (&self.lit).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            (&self.var_log_true).as_kernel_param(),
            (&self.var_log_false).as_kernel_param(),
            (&self.values).as_kernel_param(),
            (&self.meta_num_levels).as_kernel_param(),
        ];
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            eval_all.clone().launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                &mut params,
            )
        }
        .map_err(|e| XlogError::Kernel(format!("xgcf_eval_all_levels_cached failed: {}", e)))?;

        let device = self.provider.device().inner();
        let store_f64 = device
            .get_func(CACHE_MODULE, cache_kernels::CACHE_STORE_F64)
            .ok_or_else(|| XlogError::Kernel("cache_store_f64 kernel not found".to_string()))?;

        let node_stride = self.node_cap;
        // `var_stride` and `weights_len` are the same value (`var_cap + 1`),
        // computed twice with the same overflow check.
        let var_stride = self.var_cap.checked_add(1).ok_or_else(|| {
            XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
        })?;
        let weights_len = self.var_cap.checked_add(1).ok_or_else(|| {
            XlogError::Compilation("GPU cache eval_grads var_cap overflow".to_string())
        })?;

        // Reset `adj` for this slot: one store per node position up to `node_cap`.
        let grid_nodes = cache_grid_dim_for_u32_count(
            "GpuCircuitCache batched eval_grads zero adj",
            self.node_cap,
            block_size,
        )?;
        if grid_nodes != 0 {
            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                store_f64.clone().launch(
                    LaunchConfig {
                        grid_dim: (grid_nodes, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        handle.slot_device(),
                        &self.always_on,
                        node_stride,
                        &self.zero_f64,
                        &mut self.adj,
                        self.node_cap,
                    ),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("cache zero adj failed: {}", e)))?;
        }

        // Reset both gradient buffers for this slot (`var_cap + 1` entries each).
        let grid_weights = cache_grid_dim_for_u32_count(
            "GpuCircuitCache batched eval_grads zero weights",
            weights_len,
            block_size,
        )?;
        if grid_weights != 0 {
            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                store_f64.clone().launch(
                    LaunchConfig {
                        grid_dim: (grid_weights, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        handle.slot_device(),
                        &self.always_on,
                        var_stride,
                        &self.zero_f64,
                        &mut self.grad_true,
                        weights_len,
                    ),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("cache zero grad_true failed: {}", e)))?;

            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
            unsafe {
                store_f64.clone().launch(
                    LaunchConfig {
                        grid_dim: (grid_weights, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (
                        handle.slot_device(),
                        &self.always_on,
                        var_stride,
                        &self.zero_f64,
                        &mut self.grad_false,
                        weights_len,
                    ),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("cache zero grad_false failed: {}", e)))?;
        }

        let add_scalar = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_ADD_SCALAR_CACHED,
            )
            .ok_or_else(|| {
                XlogError::Kernel("xgcf_add_scalar_cached kernel not found".to_string())
            })?;
        // Single-thread launch given `adj`, `meta_root` and `one_f64`.
        // NOTE(review): presumably adds 1.0 to the root's adjoint - confirm
        // against the kernel source.
        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            add_scalar.clone().launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (1, 1, 1),
                    shared_mem_bytes: 0,
                },
                (
                    handle.slot_device(),
                    self.node_cap,
                    &mut self.adj,
                    &self.meta_root,
                    &self.one_f64,
                ),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("xgcf_add_scalar_cached (adj) failed: {}", e)))?;

        // Fused backward: single kernel replaces the per-level loop.
        let backward_all = device
            .get_func(
                xlog_cuda::CIRCUIT_MODULE,
                xlog_cuda::circuit_kernels::XGCF_BACKWARD_ALL_LEVELS_CACHED,
            )
            .ok_or_else(|| XlogError::Kernel("xgcf_backward_all_levels_cached not found".into()))?;

        // No per-level index is passed; the kernel is handed `meta_num_levels`
        // along with every topology, value and gradient buffer.
        let mut params: Vec<*mut std::ffi::c_void> = vec![
            handle.slot_device().as_kernel_param(),
            self.node_cap.as_kernel_param(),
            self.edge_cap.as_kernel_param(),
            self.level_cap.as_kernel_param(),
            self.var_cap.as_kernel_param(),
            (&self.node_type).as_kernel_param(),
            (&self.child_offsets).as_kernel_param(),
            (&self.child_indices).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.lit).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            (&self.var_log_true).as_kernel_param(),
            (&self.var_log_false).as_kernel_param(),
            (&self.values).as_kernel_param(),
            (&self.adj).as_kernel_param(),
            (&self.grad_true).as_kernel_param(),
            (&self.grad_false).as_kernel_param(),
            (&self.meta_num_levels).as_kernel_param(),
        ];

        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
        unsafe {
            backward_all.clone().launch(
                LaunchConfig {
                    grid_dim: (1, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                &mut params,
            )
        }
        .map_err(|e| XlogError::Kernel(format!("xgcf_backward_all_levels_cached failed: {}", e)))?;

        // Apply the correction to both log-Z and the gradients.
        self.apply_free_var_correction_cached(handle, true, true)?;
        Ok(())
    }
2992
2993    fn apply_free_var_correction_cached(
2994        &mut self,
2995        handle: &GpuCircuitCacheHandle,
2996        apply_log_z: bool,
2997        apply_grads: bool,
2998    ) -> Result<()> {
2999        if !self.has_free_var_mask_for_slot(handle.slot_index()) {
3000            return Ok(());
3001        }
3002        let n = self
3003            .var_cap
3004            .checked_add(1)
3005            .ok_or_else(|| XlogError::Compilation("GPU cache free-var overflow".to_string()))?;
3006        if n == 0 {
3007            return Ok(());
3008        }
3009
3010        let device = self.provider.device().inner();
3011        let block_dim = 256u32;
3012        let grid_dim = n.div_ceil(block_dim);
3013
3014        if apply_grads {
3015            let apply_grad = device
3016                .get_func(
3017                    xlog_cuda::CIRCUIT_MODULE,
3018                    xlog_cuda::circuit_kernels::XGCF_FREE_VAR_APPLY_GRAD_CACHED,
3019                )
3020                .ok_or_else(|| {
3021                    XlogError::Kernel(
3022                        "xgcf_free_var_apply_grad_cached kernel not found".to_string(),
3023                    )
3024                })?;
3025            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
3026            unsafe {
3027                apply_grad.clone().launch(
3028                    LaunchConfig {
3029                        grid_dim: (grid_dim, 1, 1),
3030                        block_dim: (block_dim, 1, 1),
3031                        shared_mem_bytes: 0,
3032                    },
3033                    (
3034                        handle.slot_device(),
3035                        self.var_cap,
3036                        &self.free_var_mask,
3037                        &self.var_log_true,
3038                        &self.var_log_false,
3039                        n,
3040                        &mut self.grad_true,
3041                        &mut self.grad_false,
3042                    ),
3043                )
3044            }
3045            .map_err(|e| {
3046                XlogError::Kernel(format!("xgcf_free_var_apply_grad_cached failed: {}", e))
3047            })?;
3048        }
3049
3050        if apply_log_z {
3051            let reduce_stage = device
3052                .get_func(
3053                    xlog_cuda::CIRCUIT_MODULE,
3054                    xlog_cuda::circuit_kernels::XGCF_FREE_VAR_REDUCE_STAGE_CACHED,
3055                )
3056                .ok_or_else(|| {
3057                    XlogError::Kernel(
3058                        "xgcf_free_var_reduce_stage_cached kernel not found".to_string(),
3059                    )
3060                })?;
3061            let add_scalar = device
3062                .get_func(
3063                    xlog_cuda::CIRCUIT_MODULE,
3064                    xlog_cuda::circuit_kernels::XGCF_ADD_SCALAR_CACHED,
3065                )
3066                .ok_or_else(|| {
3067                    XlogError::Kernel("xgcf_add_scalar_cached kernel not found".to_string())
3068                })?;
3069
3070            let memory = self.provider.memory();
3071            let mut buf_a = memory.alloc::<f64>(n as usize)?;
3072            let mut buf_b = memory.alloc::<f64>(n as usize)?;
3073
3074            let mut stage_n = n;
3075            let mut stage0 = true;
3076            let mut output_is_a = true;
3077            loop {
3078                let out_len = stage_n.div_ceil(2);
3079                let stage_grid = out_len.div_ceil(block_dim);
3080
3081                let (in_buf, out_buf): (&TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) =
3082                    if output_is_a {
3083                        (&buf_b, &mut buf_a)
3084                    } else {
3085                        (&buf_a, &mut buf_b)
3086                    };
3087                let mode = if stage0 { 0u32 } else { 1u32 };
3088
3089                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
3090                unsafe {
3091                    reduce_stage.clone().launch(
3092                        LaunchConfig {
3093                            grid_dim: (stage_grid, 1, 1),
3094                            block_dim: (block_dim, 1, 1),
3095                            shared_mem_bytes: 0,
3096                        },
3097                        (
3098                            handle.slot_device(),
3099                            self.var_cap,
3100                            &self.free_var_mask,
3101                            &self.var_log_true,
3102                            &self.var_log_false,
3103                            in_buf,
3104                            stage_n,
3105                            mode,
3106                            out_buf,
3107                        ),
3108                    )
3109                }
3110                .map_err(|e| {
3111                    XlogError::Kernel(format!("xgcf_free_var_reduce_stage_cached failed: {}", e))
3112                })?;
3113
3114                if out_len == 1 {
3115                    let result_buf = if output_is_a { &buf_a } else { &buf_b };
3116                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
3117                    unsafe {
3118                        add_scalar.clone().launch(
3119                            LaunchConfig {
3120                                grid_dim: (1, 1, 1),
3121                                block_dim: (1, 1, 1),
3122                                shared_mem_bytes: 0,
3123                            },
3124                            (
3125                                handle.slot_device(),
3126                                self.node_cap,
3127                                &mut self.values,
3128                                &self.meta_root,
3129                                result_buf,
3130                            ),
3131                        )
3132                    }
3133                    .map_err(|e| {
3134                        XlogError::Kernel(format!("xgcf_add_scalar_cached failed: {}", e))
3135                    })?;
3136                    break;
3137                }
3138
3139                stage_n = out_len;
3140                stage0 = false;
3141                output_is_a = !output_is_a;
3142            }
3143        }
3144
3145        Ok(())
3146    }
3147}