Skip to main content

xlog_prob/
exact.rs

1//! Exact probabilistic inference via GPU-native Decision-DNNF knowledge compilation
2//! and weighted model counting.
3
4use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, Mutex};
6
7use cudarc::driver::LaunchConfig;
8use xlog_core::{MemoryBudget, Result, ScalarType, XlogError};
9use xlog_cuda::LaunchAsync;
10use xlog_logic::ast::Program;
11
12use crate::compilation::gpu_cache::{
13    GpuCircuitCache, GpuCircuitCacheConfig, GpuCircuitCacheHandle,
14};
15use crate::compilation::gpu_cnf::GpuCnfVarTables;
16#[cfg(feature = "host-io")]
17use crate::compilation::gpu_weights::map_nodes_to_vars_gpu;
18use crate::compilation::gpu_weights::{build_evidence_by_var_gpu, build_weights_gpu};
19use crate::compilation::{
20    compile_gpu_d4_and_verify_cached, encode_cnf_gpu, CircuitCompileProfile, DeviceRandomVarList,
21    GpuCompileConfig, GpuPirGraph, GpuPirRoots,
22};
23use crate::neural_fast_path::{GpuWeightSlots, NeuralFastPathConfig};
24use crate::provenance::{
25    extract_from_program, extract_from_source, AggregateLiftStatus, GroundAtom, Provenance, Value,
26};
27use xlog_cuda::memory::TrackedCudaSlice;
28use xlog_cuda::provider::{
29    arith_kernels, filter_kernels, neural_kernels, weights_kernels, ARITH_MODULE, FILTER_MODULE,
30    NEURAL_MODULE, WEIGHTS_MODULE,
31};
32use xlog_cuda::{CudaBuffer, CudaDevice, CudaKernelProvider, GpuMemoryManager};
33
/// Exact probability computed for a single query atom.
///
/// The evaluators in this file produce `prob` in `[0, 1]` and report
/// `f64::NEG_INFINITY` as `log_prob` for queries that cannot hold.
#[derive(Debug, Clone)]
pub struct QueryProbability {
    /// Ground atom the probability refers to.
    pub atom: GroundAtom,
    /// Natural-log probability of `atom`.
    pub log_prob: f64,
    /// Probability of `atom`.
    pub prob: f64,
}
40
/// Outcome of exact inference over all queries of a compiled program.
#[derive(Debug, Clone)]
pub struct ExactResult {
    /// Log of the evidence normaliser that query log-weights are measured
    /// against. The evaluators in this file report `0.0` when no circuit is
    /// evaluated (count-lift path, or a program without a GPU circuit).
    pub log_z_e: f64,
    /// One entry per query, in the program's query order.
    pub query_probs: Vec<QueryProbability>,
}
46
/// Probability of a single query atom together with gradient vectors.
#[derive(Debug, Clone)]
pub struct QueryGradients {
    /// Ground atom the values refer to.
    pub atom: GroundAtom,
    /// Natural-log probability of `atom`.
    pub log_prob: f64,
    /// Probability of `atom`.
    pub prob: f64,
    // NOTE(review): presumably the gradient with respect to each variable's
    // "true" weight, indexed by CNF variable — confirm at the producer, which
    // is not visible in this part of the file.
    pub grad_true: Vec<f64>,
    // NOTE(review): presumably the counterpart of `grad_true` for the "false"
    // weights — confirm at the producer.
    pub grad_false: Vec<f64>,
}
55
/// Gradient-carrying counterpart of [`ExactResult`].
#[derive(Debug, Clone)]
pub struct ExactResultWithGrads {
    /// Log of the evidence normaliser (same role as [`ExactResult::log_z_e`]).
    pub log_z_e: f64,
    /// One probability-plus-gradient entry per query.
    pub query_grads: Vec<QueryGradients>,
}
61
/// A query atom paired with the CNF variable that represents it, if any.
#[derive(Debug, Clone)]
struct QuerySpec {
    /// Ground atom being queried; only read by the `host-io` evaluation path.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    atom: GroundAtom,
    /// CNF variable id of the query formula. `None` means the query has no
    /// variable; `evaluate` then reports probability zero for it.
    var: Option<u32>,
}
68
69fn neural_slot_count_u32(slot_count: usize) -> Result<u32> {
70    u32::try_from(slot_count).map_err(|_| {
71        XlogError::Compilation(
72            "Neural fast-path group slot count exceeds GPU u32 index space".to_string(),
73        )
74    })
75}
76
77fn checked_launch_grid_u32(context: &str, item_count: u32, block_size: u32) -> Result<u32> {
78    if block_size == 0 {
79        return Err(XlogError::Kernel(format!(
80            "{context} launch block size must be non-zero"
81        )));
82    }
83    if item_count == 0 {
84        return Ok(0);
85    }
86    item_count
87        .checked_add(block_size - 1)
88        .map(|rounded| rounded / block_size)
89        .ok_or_else(|| XlogError::Kernel(format!("{context} launch grid overflow")))
90}
91
/// GPU-side state backing one compiled exact-inference circuit.
///
/// The two caches sit behind a `Mutex` so they can be used through `&self`.
struct GpuExactState {
    /// CUDA kernel provider used for device allocations, transfers and kernel lookup.
    provider: Arc<CudaKernelProvider>,
    /// Circuit cache the evaluation calls are issued against; locked per use.
    cache: Mutex<GpuCircuitCache>,
    /// Handle passed to `cache` calls alongside the buffers
    /// (presumably identifies this program's circuit — confirm in `gpu_cache`).
    handle: GpuCircuitCacheHandle,
    /// Device-resident batched query-var metadata, keyed by the host vector of
    /// CNF query vars. Lets a warm training loop reuse a single upload instead
    /// of re-uploading (a tracked htod) on every batched force call.
    query_var_batch_cache: Mutex<HashMap<Vec<u32>, Arc<TrackedCudaSlice<u32>>>>,
}
101
/// GPU device selection and memory budget for probabilistic inference.
///
/// Use [`GpuConfig::default()`] and override individual fields as needed.
/// The default selects device 0, a 32 GiB budget and no decision-order hint.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct GpuConfig {
    /// CUDA device ordinal (0-based).
    pub device_ordinal: usize,
    /// Device memory budget in bytes (clamped to available memory at runtime).
    pub memory_bytes: u64,
    /// Host-side Decision-DNNF compiler decision-order hint: renumber leaf/choice
    /// variables by descending structural fanout in the provenance DAG before CNF
    /// encoding, steering the deterministic variable-id tie-breaks of the
    /// (unchanged) GPU-native Decision-DNNF branching heuristic. Query probabilities
    /// are unaffected; only compile-time search shape can differ.
    pub decision_order_hint: bool,
}
119
120impl Default for GpuConfig {
121    fn default() -> Self {
122        Self {
123            device_ordinal: 0,
124            memory_bytes: 32 * 1024 * 1024 * 1024, // 32 GB — clamped to available device memory by GpuMemoryManager at runtime.
125            decision_order_hint: false,
126        }
127    }
128}
129
130impl GpuExactState {
131    fn new(
132        provider: Arc<CudaKernelProvider>,
133        cache: GpuCircuitCache,
134        handle: GpuCircuitCacheHandle,
135    ) -> Result<Self> {
136        Ok(Self {
137            provider,
138            cache: Mutex::new(cache),
139            handle,
140            query_var_batch_cache: Mutex::new(HashMap::new()),
141        })
142    }
143
144    fn provider(&self) -> &Arc<CudaKernelProvider> {
145        &self.provider
146    }
147
148    fn handle(&self) -> &GpuCircuitCacheHandle {
149        &self.handle
150    }
151
152    /// Device-resident batched query vars for `query_vars_host`, uploading once
153    /// and reusing the cached slice on repeat calls with the same vars. The
154    /// upload is a tracked htod; caching it keeps a warm training loop free of
155    /// per-step host transfers.
156    fn cached_query_var_batch(
157        &self,
158        query_vars_host: Vec<u32>,
159    ) -> Result<Arc<TrackedCudaSlice<u32>>> {
160        let mut cache = self
161            .query_var_batch_cache
162            .lock()
163            .unwrap_or_else(|poisoned| poisoned.into_inner());
164        if let Some(cached) = cache.get(&query_vars_host) {
165            return Ok(Arc::clone(cached));
166        }
167        let mut query_vars = self.provider.memory().alloc::<u32>(query_vars_host.len())?;
168        self.provider
169            .htod_sync_copy_into_tracked(&query_vars_host, &mut query_vars)
170            .map_err(|e| {
171                XlogError::Kernel(format!("Failed to upload batched query vars: {}", e))
172            })?;
173        let query_vars = Arc::new(query_vars);
174        cache.insert(query_vars_host, Arc::clone(&query_vars));
175        Ok(query_vars)
176    }
177}
178
/// One count-lift query prepared for the `weights_count_lift_exact` kernel.
#[cfg_attr(not(feature = "host-io"), allow(dead_code))]
struct GpuCountLiftQuery {
    /// Query atom reported alongside the computed probability.
    atom: GroundAtom,
    /// Count argument passed to the kernel; the evaluator sizes its scratch
    /// buffer as `target_count + 1`.
    target_count: u32,
    /// Leaf count passed to the kernel together with `leaf_probs`.
    leaf_count: u32,
    /// Device-resident `f64` leaf probabilities
    /// (presumably `leaf_count` entries — confirm at the construction site).
    leaf_probs: TrackedCudaSlice<f64>,
}
186
/// Kernel provider plus the count-lift queries evaluated through it.
#[cfg_attr(not(feature = "host-io"), allow(dead_code))]
struct GpuCountLiftState {
    /// CUDA kernel provider used for kernel lookup, allocations and readback.
    provider: Arc<CudaKernelProvider>,
    /// Queries to evaluate; results are reported in this order.
    queries: Vec<GpuCountLiftQuery>,
}
192
193impl GpuCountLiftState {
194    fn new(provider: Arc<CudaKernelProvider>, queries: Vec<GpuCountLiftQuery>) -> Self {
195        Self { provider, queries }
196    }
197
198    #[cfg(feature = "host-io")]
199    fn evaluate(&self) -> Result<ExactResult> {
200        let func = self
201            .provider
202            .device()
203            .inner()
204            .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_COUNT_LIFT_EXACT)
205            .ok_or_else(|| {
206                XlogError::Kernel("weights_count_lift_exact kernel not found".to_string())
207            })?;
208        let mut query_probs = Vec::with_capacity(self.queries.len());
209        for query in &self.queries {
210            let scratch_len = query
211                .target_count
212                .checked_add(1)
213                .ok_or_else(|| XlogError::Compilation("count-lift target overflow".to_string()))?;
214            let mut scratch = self.provider.memory().alloc::<f64>(scratch_len as usize)?;
215            let mut out = self.provider.memory().alloc::<f64>(1)?;
216            unsafe {
217                func.clone().launch(
218                    LaunchConfig {
219                        grid_dim: (1, 1, 1),
220                        block_dim: (1, 1, 1),
221                        shared_mem_bytes: 0,
222                    },
223                    (
224                        &query.leaf_probs,
225                        query.leaf_count,
226                        query.target_count,
227                        &mut scratch,
228                        &mut out,
229                    ),
230                )
231            }
232            .map_err(|e| XlogError::Kernel(format!("weights_count_lift_exact failed: {}", e)))?;
233            let mut host = vec![0.0f64; 1];
234            self.provider
235                .device()
236                .inner()
237                .dtoh_sync_copy_into(&out, &mut host)
238                .map_err(|e| XlogError::Kernel(format!("count-lift result dtoh failed: {}", e)))?;
239            let mut prob = host[0];
240            if (-1e-12..0.0).contains(&prob) || prob == -1e-12 {
241                prob = 0.0;
242            } else if prob > 1.0 && (1.0..=1.0 + 1e-12).contains(&prob) {
243                prob = 1.0;
244            }
245            if !prob.is_finite() || !(0.0..=1.0).contains(&prob) {
246                return Err(XlogError::Kernel(format!(
247                    "count-lift GPU evaluator returned invalid probability {}",
248                    prob
249                )));
250            }
251            let log_prob = if prob == 0.0 {
252                f64::NEG_INFINITY
253            } else {
254                prob.ln()
255            };
256            query_probs.push(QueryProbability {
257                atom: query.atom.clone(),
258                log_prob,
259                prob,
260            });
261        }
262        Ok(ExactResult {
263            log_z_e: 0.0,
264            query_probs,
265        })
266    }
267}
268
/// A program compiled for exact probabilistic inference.
///
/// Cloning is cheap for the GPU parts, which are shared through `Arc`.
#[derive(Clone)]
pub struct ExactDdnnfProgram {
    /// GPU circuit state. `None` when compilation produced no circuit; in that
    /// case `evaluate` reports probability zero for every query.
    gpu: Option<Arc<GpuExactState>>,
    /// Count-lift evaluator; when present, `evaluate` delegates to it entirely.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    count_lift_gpu: Option<Arc<GpuCountLiftState>>,
    /// Query atoms with their optional CNF variables, in query order.
    queries: Vec<QuerySpec>,
    /// Device-resident list of random variables, read back by `random_var_indices`.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    random_vars: Option<Arc<DeviceRandomVarList>>,
    /// Largest CNF variable id; `0` means the program has no variables.
    max_var: u32,
    /// Whether the program was compiled from source text or from a parsed `Program`.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    origin: ExactProgramOrigin,
    #[allow(dead_code)] // retained: config is stored for future re-compilation paths
    gpu_config: GpuConfig,
    /// Latest circuit compilation profile (populated on cache miss when profiling).
    last_compile_profile: Option<CircuitCompileProfile>,
}
285
/// Records which entry point an [`ExactDdnnfProgram`] was compiled through.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ExactProgramOrigin {
    /// Compiled from source text (`compile_source` / `compile_source_with_gpu`).
    Source,
    /// Compiled from an already-parsed `Program` (`compile_from_program`).
    Program,
}
291
292impl ExactDdnnfProgram {
293    pub fn compile_source(source: &str) -> Result<Self> {
294        let provenance = extract_from_source(source)?;
295        Self::compile_provenance_with_gpu(
296            provenance,
297            GpuConfig::default(),
298            ExactProgramOrigin::Source,
299        )
300    }
301
302    pub fn compile_source_with_gpu(source: &str, config: GpuConfig) -> Result<Self> {
303        let provenance = extract_from_source(source)?;
304        Self::compile_provenance_with_gpu(provenance, config, ExactProgramOrigin::Source)
305    }
306
307    pub fn compile_from_program(program: &Program, config: GpuConfig) -> Result<Self> {
308        let provenance = extract_from_program(program)?;
309        Self::compile_provenance_with_gpu(provenance, config, ExactProgramOrigin::Program)
310    }
311
    /// Returns a copy of the GPU configuration stored on this program.
    #[allow(dead_code)] // retained: accessor for future re-compilation paths
    pub(crate) fn gpu_config(&self) -> GpuConfig {
        self.gpu_config
    }
316
    /// Returns which entry point (source text or parsed program) produced this program.
    #[cfg(feature = "host-io")]
    pub(crate) fn origin(&self) -> ExactProgramOrigin {
        self.origin
    }
321
    /// Returns `true` when a GPU circuit state is attached to this program.
    pub fn uses_gpu_production_backend(&self) -> bool {
        self.gpu.is_some()
    }
325
    /// Returns the latest circuit compilation profile, if one was recorded
    /// (populated when `XLOG_WARMUP_PROFILE=1`).
    pub fn last_compile_profile(&self) -> Option<&CircuitCompileProfile> {
        self.last_compile_profile.as_ref()
    }
330
    /// Returns `true` when a count-lift evaluator is attached, in which case
    /// `evaluate` is answered by that evaluator.
    #[doc(hidden)]
    #[cfg(feature = "host-io")]
    pub fn uses_gpu_native_count_lift(&self) -> bool {
        self.count_lift_gpu.is_some()
    }
336
337    #[cfg(feature = "host-io")]
338    pub fn evaluate(&self) -> Result<ExactResult> {
339        if let Some(count_lift_gpu) = &self.count_lift_gpu {
340            return count_lift_gpu.evaluate();
341        }
342
343        // `gpu` is `None` only when compilation found an empty PIR root set
344        // (no probabilistic leaves and no derivations reach any query), so
345        // every query atom is unprovable and P = 0 is the correct semantics —
346        // this is NOT a missing-GPU fallback; a real circuit with an
347        // unavailable GPU fails at compile time instead.
348        if self.gpu.is_none() {
349            let mut query_probs: Vec<QueryProbability> = Vec::with_capacity(self.queries.len());
350            for query in &self.queries {
351                query_probs.push(QueryProbability {
352                    atom: query.atom.clone(),
353                    log_prob: f64::NEG_INFINITY,
354                    prob: 0.0,
355                });
356            }
357            return Ok(ExactResult {
358                log_z_e: 0.0,
359                query_probs,
360            });
361        }
362
363        let log_z_e = self.eval_log_z_gpu(None)?;
364        if log_z_e.is_infinite() && log_z_e.is_sign_negative() {
365            return Err(XlogError::Execution(
366                "Exact inference error: evidence is inconsistent (P(E)=0)".to_string(),
367            ));
368        }
369
370        let mut query_probs: Vec<QueryProbability> = Vec::with_capacity(self.queries.len());
371        for query in &self.queries {
372            let (log_prob, prob) = match query.var {
373                None => (f64::NEG_INFINITY, 0.0),
374                Some(var) => {
375                    let log_z_eq = self.eval_log_z_gpu(Some(var))?;
376                    let log_prob = log_z_eq - log_z_e;
377                    let mut prob = if log_prob.is_infinite() && log_prob.is_sign_negative() {
378                        0.0
379                    } else {
380                        log_prob.exp()
381                    };
382                    if prob.is_nan() {
383                        return Err(XlogError::Execution(
384                            "Exact inference error: NaN probability encountered".to_string(),
385                        ));
386                    }
387                    prob = prob.clamp(0.0, 1.0);
388                    (log_prob, prob)
389                }
390            };
391
392            query_probs.push(QueryProbability {
393                atom: query.atom.clone(),
394                log_prob,
395                prob,
396            });
397        }
398
399        Ok(ExactResult {
400            log_z_e,
401            query_probs,
402        })
403    }
404
405    pub fn num_vars(&self) -> usize {
406        if self.max_var == 0 {
407            0
408        } else {
409            (self.max_var as usize) + 1
410        }
411    }
412
413    /// Returns the indices of random (probabilistic) variables in order.
414    ///
415    /// Random variables are those with non-trivial weights (not (0.0, 0.0)).
416    /// These correspond to annotated disjunctions in the source program.
417    /// The order matches the order variables were assigned during CNF encoding.
418    #[cfg(feature = "host-io")]
419    pub fn random_var_indices(&self) -> Vec<u32> {
420        let Some(state) = self.gpu.as_ref() else {
421            return Vec::new();
422        };
423        let Some(random_vars) = self.random_vars.as_ref() else {
424            return Vec::new();
425        };
426        if random_vars.is_empty() {
427            return Vec::new();
428        }
429        let count = random_vars.count() as usize;
430        let mut host = vec![0u32; count];
431        let view = random_vars.list().slice(0..count);
432        if let Err(e) = state
433            .provider()
434            .device()
435            .inner()
436            .dtoh_sync_copy_into(&view, &mut host)
437        {
438            eprintln!("Failed to read random var list: {}", e);
439            return Vec::new();
440        }
441        host
442    }
443
    /// CNF variable id for the `idx`-th query formula (DIMACS, 1-based), if present.
    ///
    /// Returns `None` both when `idx` is out of range and when the query has no
    /// CNF variable.
    pub(crate) fn query_var(&self, idx: usize) -> Option<u32> {
        self.queries.get(idx).and_then(|q| q.var)
    }
448
    /// GPU neural fast-path: compute NLL gradients w.r.t. probability tensors (no host reads).
    ///
    /// This implements the design in `docs/design/2026-01-22-gpu-native-compilation-design.md` §5.3:
    /// - Fill AD conditional-chain log-weights from device-resident `p[label]`.
    /// - Run XGCF forward+backward on GPU.
    /// - Scatter gradients back into probability-space via the correct chain rule (uses both grad_true + grad_false).
    ///
    /// The output gradient buffers are updated in-place:
    /// - Base run: `out = dlogZ_base/dp`
    /// - Query-forced run: `out -= dlogZ_query/dp`
    ///   Result: `out = dL/dp` for `L = -log P(query | evidence)` (NLL).
    ///
    /// This is a thin wrapper: it forwards to `neural_backward_nll_buffers_inner`
    /// without a device loss buffer and with `expected_true = true`.
    ///
    /// # Errors
    /// Propagates any error returned by the inner implementation.
    pub fn neural_backward_nll_buffers(
        &self,
        slots: &GpuWeightSlots,
        query_idx: usize,
        probs: &[CudaBuffer],
        out_grads: &mut [CudaBuffer],
        cfg: NeuralFastPathConfig,
    ) -> Result<()> {
        self.neural_backward_nll_buffers_inner(slots, query_idx, probs, out_grads, cfg, None, true)
    }
470
    /// Same as [`Self::neural_backward_nll_buffers`], but also returns the device-resident scalar NLL loss:
    /// `L = -log P(query | evidence)`.
    ///
    /// The returned slice has length 1 and is written on GPU (no device->host reads).
    ///
    /// `expected_true` is forwarded unchanged to the inner implementation
    /// (presumably selects whether the query is expected to hold or not —
    /// confirm in `neural_backward_nll_buffers_inner`).
    ///
    /// # Errors
    /// Fails when the GPU state is unavailable, when the one-element loss buffer
    /// cannot be allocated, or when the inner implementation reports an error.
    pub fn neural_backward_nll_buffers_with_device_loss(
        &self,
        slots: &GpuWeightSlots,
        query_idx: usize,
        probs: &[CudaBuffer],
        out_grads: &mut [CudaBuffer],
        cfg: NeuralFastPathConfig,
        expected_true: bool,
    ) -> Result<TrackedCudaSlice<f64>> {
        let state = self.gpu_state()?;
        // The loss buffer is allocated up front so the inner pass can write into it.
        let mut loss = state.provider.memory().alloc::<f64>(1)?;
        self.neural_backward_nll_buffers_inner(
            slots,
            query_idx,
            probs,
            out_grads,
            cfg,
            Some(&mut loss),
            expected_true,
        )?;
        Ok(loss)
    }
497
498    /// Batched variant of [`Self::neural_backward_nll_buffers_with_device_loss`].
499    ///
500    /// Computes NLL gradients for `batch` queries that share one compiled circuit
501    /// template and returns a device-resident vector of `batch` scalar losses.
502    ///
503    /// On circuits that require free-variable correction, this falls back to the
504    /// existing per-query path for correctness.
505    pub fn neural_backward_nll_buffers_batch_with_device_loss(
506        &self,
507        slots: &GpuWeightSlots,
508        query_indices: &[usize],
509        probs_batch: &[Vec<CudaBuffer>],
510        out_grads_batch: &mut [Vec<CudaBuffer>],
511        cfg: NeuralFastPathConfig,
512        expected_true: bool,
513    ) -> Result<TrackedCudaSlice<f64>> {
514        let batch = query_indices.len();
515        if batch == 0 {
516            return Err(XlogError::Execution(
517                "Neural fast-path batch error: empty query batch".to_string(),
518            ));
519        }
520        if probs_batch.len() != batch || out_grads_batch.len() != batch {
521            return Err(XlogError::Compilation(format!(
522                "Neural fast-path batch error: query/prob/grad batch mismatch ({}/{}/{})",
523                batch,
524                probs_batch.len(),
525                out_grads_batch.len()
526            )));
527        }
528
529        let state = self.gpu_state()?;
530        let batch_u32 = u32::try_from(batch).map_err(|_| {
531            XlogError::Compilation("Neural fast-path batch size exceeds u32".to_string())
532        })?;
533        let device = state.provider.device().inner();
534
535        // Fallback for circuits that currently require per-query free-var correction.
536        {
537            let cache = state
538                .cache
539                .lock()
540                .unwrap_or_else(|poisoned| poisoned.into_inner());
541            if cache.has_any_free_var_mask() {
542                drop(cache);
543                let mut losses = state.provider.memory().alloc::<f64>(batch)?;
544                for q in 0..batch {
545                    let loss_q = self.neural_backward_nll_buffers_with_device_loss(
546                        slots,
547                        query_indices[q],
548                        &probs_batch[q],
549                        &mut out_grads_batch[q],
550                        cfg,
551                        expected_true,
552                    )?;
553                    let mut dst = losses.slice_mut(q..(q + 1));
554                    device.dtod_copy(&loss_q, &mut dst).map_err(|e| {
555                        XlogError::Kernel(format!(
556                            "Failed to copy fallback batch loss to output: {}",
557                            e
558                        ))
559                    })?;
560                }
561                return Ok(losses);
562            }
563        }
564
565        let fill = device
566            .get_func(NEURAL_MODULE, neural_kernels::NEURAL_FILL_AD_CHAIN_F32)
567            .ok_or_else(|| {
568                XlogError::Kernel("neural_fill_ad_chain_f32 kernel not found".to_string())
569            })?;
570        let scatter = device
571            .get_func(
572                NEURAL_MODULE,
573                neural_kernels::NEURAL_SCATTER_AD_CHAIN_GRADS_F32,
574            )
575            .ok_or_else(|| {
576                XlogError::Kernel("neural_scatter_ad_chain_grads_f32 kernel not found".to_string())
577            })?;
578        let binary_f64 = device
579            .get_func(ARITH_MODULE, arith_kernels::ARITH_BINARY_F64)
580            .ok_or_else(|| XlogError::Kernel("arith_binary_f64 kernel not found".to_string()))?;
581        let apply_query_false_batched = device
582            .get_func(
583                WEIGHTS_MODULE,
584                weights_kernels::WEIGHTS_APPLY_QUERY_VARS_FALSE_BATCHED,
585            )
586            .ok_or_else(|| {
587                XlogError::Kernel(
588                    "weights_apply_query_vars_false_batched kernel not found".to_string(),
589                )
590            })?;
591        let apply_query_true_batched = device
592            .get_func(
593                WEIGHTS_MODULE,
594                weights_kernels::WEIGHTS_APPLY_QUERY_VARS_TRUE_BATCHED,
595            )
596            .ok_or_else(|| {
597                XlogError::Kernel(
598                    "weights_apply_query_vars_true_batched kernel not found".to_string(),
599                )
600            })?;
601
602        let mut cache = state
603            .cache
604            .lock()
605            .unwrap_or_else(|poisoned| poisoned.into_inner());
606        let var_stride = cache.var_stride()?;
607        let var_stride_usize = var_stride as usize;
608        let node_stride = cache.node_stride();
609        let node_stride_usize = node_stride as usize;
610
611        let mut var_log_true_batch = state
612            .provider
613            .memory()
614            .alloc::<f64>(batch * var_stride_usize)?;
615        let mut var_log_false_batch = state
616            .provider
617            .memory()
618            .alloc::<f64>(batch * var_stride_usize)?;
619        cache.copy_slot_weights_to_batch(
620            state.handle(),
621            &mut var_log_true_batch,
622            &mut var_log_false_batch,
623            batch_u32,
624        )?;
625
626        let mut values_batch = state
627            .provider
628            .memory()
629            .alloc::<f64>(batch * node_stride_usize)?;
630        let mut adj_batch = state
631            .provider
632            .memory()
633            .alloc::<f64>(batch * node_stride_usize)?;
634        let mut grad_true_batch = state
635            .provider
636            .memory()
637            .alloc::<f64>(batch * var_stride_usize)?;
638        let mut grad_false_batch = state
639            .provider
640            .memory()
641            .alloc::<f64>(batch * var_stride_usize)?;
642        let mut base_roots = state.provider.memory().alloc::<f64>(batch)?;
643        let mut query_roots = state.provider.memory().alloc::<f64>(batch)?;
644        let mut losses = state.provider.memory().alloc::<f64>(batch)?;
645        let mut force_saved = state.provider.memory().alloc::<f64>(batch)?;
646
647        let mut query_vars_host: Vec<u32> = Vec::with_capacity(batch);
648
649        // Fill per-query var weight rows from device-resident probability tensors.
650        for q in 0..batch {
651            if probs_batch[q].len() != out_grads_batch[q].len() {
652                return Err(XlogError::Compilation(format!(
653                    "Neural fast-path batch error: probs len {} != out_grads len {} for query {}",
654                    probs_batch[q].len(),
655                    out_grads_batch[q].len(),
656                    q
657                )));
658            }
659            if probs_batch[q].len() != slots.num_groups_usize() {
660                return Err(XlogError::Compilation(format!(
661                    "Neural fast-path batch error: expected {} groups, got {} for query {}",
662                    slots.num_groups_usize(),
663                    probs_batch[q].len(),
664                    q
665                )));
666            }
667
668            let query_var = self.query_var(query_indices[q]).ok_or_else(|| {
669                XlogError::Execution(format!(
670                    "Neural fast-path batch error: query {} has no CNF var",
671                    query_indices[q]
672                ))
673            })?;
674            if query_var == 0 || query_var > self.max_var {
675                return Err(XlogError::Compilation(format!(
676                    "Neural fast-path batch error: query var {} out of bounds (max_var={})",
677                    query_var, self.max_var
678                )));
679            }
680            query_vars_host.push(query_var);
681
682            let row_start = q
683                .checked_mul(var_stride_usize)
684                .ok_or_else(|| XlogError::Compilation("Neural batch row overflow".to_string()))?;
685            let row_end = row_start + var_stride_usize;
686
687            for (g, prob_buf) in probs_batch[q].iter().enumerate() {
688                if prob_buf.arity() != 1 {
689                    return Err(XlogError::Compilation(
690                        "Neural fast-path expects 1-column prob buffers".to_string(),
691                    ));
692                }
693                let ty = prob_buf.schema().column_type(0).ok_or_else(|| {
694                    XlogError::Compilation("Missing prob buffer schema".to_string())
695                })?;
696                if ty != ScalarType::F32 {
697                    return Err(XlogError::Compilation(format!(
698                        "Neural fast-path expects prob dtype F32, got {:?}",
699                        ty
700                    )));
701                }
702
703                let slot_vars = slots.group_slot_cnf_var(g)?;
704                let labels = neural_slot_count_u32(slot_vars.len())?;
705                if prob_buf.num_rows() != labels as u64 {
706                    return Err(XlogError::Compilation(format!(
707                        "Neural fast-path prob rows {} != labels {}",
708                        prob_buf.num_rows(),
709                        labels
710                    )));
711                }
712                if out_grads_batch[q][g].num_rows() != labels as u64 {
713                    return Err(XlogError::Compilation(format!(
714                        "Neural fast-path grad rows {} != labels {}",
715                        out_grads_batch[q][g].num_rows(),
716                        labels
717                    )));
718                }
719
720                let prob_col = prob_buf.column(0).ok_or_else(|| {
721                    XlogError::Compilation("Neural fast-path missing prob column".to_string())
722                })?;
723                let mut q_true = var_log_true_batch.slice_mut(row_start..row_end);
724                let mut q_false = var_log_false_batch.slice_mut(row_start..row_end);
725
726                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
727                unsafe {
728                    fill.clone().launch(
729                        LaunchConfig {
730                            grid_dim: (1, 1, 1),
731                            block_dim: (1, 1, 1),
732                            shared_mem_bytes: 0,
733                        },
734                        (
735                            prob_col,
736                            labels,
737                            &slot_vars,
738                            cfg.eps,
739                            cfg.min_p,
740                            &mut q_true,
741                            &mut q_false,
742                        ),
743                    )
744                }
745                .map_err(|e| {
746                    XlogError::Kernel(format!("neural_fill_ad_chain_f32 failed: {}", e))
747                })?;
748            }
749        }
750
751        // Base pass (all queries): grads = dlogZ_base/dp, roots = logZ_base.
752        cache.eval_grads_inplace_fused_batched(
753            state.handle(),
754            &var_log_true_batch,
755            &var_log_false_batch,
756            &mut values_batch,
757            &mut adj_batch,
758            &mut grad_true_batch,
759            &mut grad_false_batch,
760            batch_u32,
761        )?;
762        cache.copy_root_batched_from_values(
763            state.handle(),
764            &values_batch,
765            &mut base_roots,
766            batch_u32,
767        )?;
768
769        // Scatter base gradients into output buffers.
770        for q in 0..batch {
771            let row_start = q
772                .checked_mul(var_stride_usize)
773                .ok_or_else(|| XlogError::Compilation("Neural batch row overflow".to_string()))?;
774            let row_end = row_start + var_stride_usize;
775            let q_grad_true = grad_true_batch.slice(row_start..row_end);
776            let q_grad_false = grad_false_batch.slice(row_start..row_end);
777
778            for (g, prob_buf) in probs_batch[q].iter().enumerate() {
779                let slot_vars = slots.group_slot_cnf_var(g)?;
780                let labels = neural_slot_count_u32(slot_vars.len())?;
781                let prob_col = prob_buf.column(0).ok_or_else(|| {
782                    XlogError::Compilation("Neural fast-path missing prob column".to_string())
783                })?;
784                let out_col = out_grads_batch[q][g]
785                    .columns
786                    .get_mut(0)
787                    .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
788
789                let shared_bytes: u32 = 3u64
790                    .checked_mul(labels as u64)
791                    .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
792                    .and_then(|n| u32::try_from(n).ok())
793                    .ok_or_else(|| {
794                        XlogError::Kernel("Neural scatter shared memory overflow".to_string())
795                    })?;
796
797                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
798                unsafe {
799                    scatter.clone().launch(
800                        LaunchConfig {
801                            grid_dim: (1, 1, 1),
802                            block_dim: (1, 1, 1),
803                            shared_mem_bytes: shared_bytes,
804                        },
805                        (
806                            prob_col,
807                            labels,
808                            &slot_vars,
809                            cfg.eps,
810                            cfg.min_p,
811                            &q_grad_true,
812                            &q_grad_false,
813                            0u8,
814                            out_col,
815                        ),
816                    )
817                }
818                .map_err(|e| XlogError::Kernel(format!("neural_scatter (base) failed: {}", e)))?;
819            }
820        }
821
822        // Reuse the device-resident query-var batch (uploaded once and cached),
823        // so a warm training loop performs no per-step tracked host transfer here.
824        let query_vars = state.cached_query_var_batch(query_vars_host)?;
825        let force_grid = checked_launch_grid_u32("gpu exact batched query force", batch_u32, 256)?;
826        if force_grid != 0 {
827            if expected_true {
828                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
829                unsafe {
830                    apply_query_false_batched.clone().launch(
831                        LaunchConfig {
832                            grid_dim: (force_grid, 1, 1),
833                            block_dim: (256, 1, 1),
834                            shared_mem_bytes: 0,
835                        },
836                        (
837                            query_vars.as_ref(),
838                            batch_u32,
839                            self.max_var,
840                            var_stride,
841                            &mut var_log_false_batch,
842                            &mut force_saved,
843                        ),
844                    )
845                }
846                .map_err(|e| {
847                    XlogError::Kernel(format!(
848                        "weights_apply_query_vars_false_batched failed: {}",
849                        e
850                    ))
851                })?;
852            } else {
853                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
854                unsafe {
855                    apply_query_true_batched.clone().launch(
856                        LaunchConfig {
857                            grid_dim: (force_grid, 1, 1),
858                            block_dim: (256, 1, 1),
859                            shared_mem_bytes: 0,
860                        },
861                        (
862                            query_vars.as_ref(),
863                            batch_u32,
864                            self.max_var,
865                            var_stride,
866                            &mut var_log_true_batch,
867                            &mut force_saved,
868                        ),
869                    )
870                }
871                .map_err(|e| {
872                    XlogError::Kernel(format!(
873                        "weights_apply_query_vars_true_batched failed: {}",
874                        e
875                    ))
876                })?;
877            }
878        }
879
880        // Query-forced pass (all queries): grads = dlogZ_query/dp, roots = logZ_query.
881        cache.eval_grads_inplace_fused_batched(
882            state.handle(),
883            &var_log_true_batch,
884            &var_log_false_batch,
885            &mut values_batch,
886            &mut adj_batch,
887            &mut grad_true_batch,
888            &mut grad_false_batch,
889            batch_u32,
890        )?;
891        cache.copy_root_batched_from_values(
892            state.handle(),
893            &values_batch,
894            &mut query_roots,
895            batch_u32,
896        )?;
897
898        let loss_grid = checked_launch_grid_u32("gpu exact batched query loss", batch_u32, 256)?;
899        if loss_grid != 0 {
900            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
901            unsafe {
902                binary_f64.clone().launch(
903                    LaunchConfig {
904                        grid_dim: (loss_grid, 1, 1),
905                        block_dim: (256, 1, 1),
906                        shared_mem_bytes: 0,
907                    },
908                    (&base_roots, &query_roots, batch_u32, 1u8, &mut losses),
909                )
910            }
911            .map_err(|e| XlogError::Kernel(format!("Failed to compute batched NLL loss: {}", e)))?;
912        }
913
914        // Scatter query gradients with subtract mode.
915        for q in 0..batch {
916            let row_start = q
917                .checked_mul(var_stride_usize)
918                .ok_or_else(|| XlogError::Compilation("Neural batch row overflow".to_string()))?;
919            let row_end = row_start + var_stride_usize;
920            let q_grad_true = grad_true_batch.slice(row_start..row_end);
921            let q_grad_false = grad_false_batch.slice(row_start..row_end);
922
923            for (g, prob_buf) in probs_batch[q].iter().enumerate() {
924                let slot_vars = slots.group_slot_cnf_var(g)?;
925                let labels = neural_slot_count_u32(slot_vars.len())?;
926                let prob_col = prob_buf.column(0).ok_or_else(|| {
927                    XlogError::Compilation("Neural fast-path missing prob column".to_string())
928                })?;
929                let out_col = out_grads_batch[q][g]
930                    .columns
931                    .get_mut(0)
932                    .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
933
934                let shared_bytes: u32 = 3u64
935                    .checked_mul(labels as u64)
936                    .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
937                    .and_then(|n| u32::try_from(n).ok())
938                    .ok_or_else(|| {
939                        XlogError::Kernel("Neural scatter shared memory overflow".to_string())
940                    })?;
941
942                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
943                unsafe {
944                    scatter.clone().launch(
945                        LaunchConfig {
946                            grid_dim: (1, 1, 1),
947                            block_dim: (1, 1, 1),
948                            shared_mem_bytes: shared_bytes,
949                        },
950                        (
951                            prob_col,
952                            labels,
953                            &slot_vars,
954                            cfg.eps,
955                            cfg.min_p,
956                            &q_grad_true,
957                            &q_grad_false,
958                            1u8,
959                            out_col,
960                        ),
961                    )
962                }
963                .map_err(|e| XlogError::Kernel(format!("neural_scatter (query) failed: {}", e)))?;
964            }
965        }
966
967        Ok(losses)
968    }
969
970    #[allow(clippy::too_many_arguments)]
971    fn neural_backward_nll_buffers_inner(
972        &self,
973        slots: &GpuWeightSlots,
974        query_idx: usize,
975        probs: &[CudaBuffer],
976        out_grads: &mut [CudaBuffer],
977        cfg: NeuralFastPathConfig,
978        out_loss: Option<&mut TrackedCudaSlice<f64>>,
979        expected_true: bool,
980    ) -> Result<()> {
981        if self.gpu.is_none() {
982            return Err(XlogError::Execution(
983                "Neural fast-path error: program has no compiled circuit".to_string(),
984            ));
985        }
986
987        let query_var = self.query_var(query_idx).ok_or_else(|| {
988            XlogError::Execution(format!(
989                "Neural fast-path error: query {} has no CNF var",
990                query_idx
991            ))
992        })?;
993
994        if probs.len() != out_grads.len() {
995            return Err(XlogError::Compilation(format!(
996                "Neural fast-path error: probs len {} != out_grads len {}",
997                probs.len(),
998                out_grads.len()
999            )));
1000        }
1001        if probs.len() != slots.num_groups_usize() {
1002            return Err(XlogError::Compilation(format!(
1003                "Neural fast-path error: expected {} groups, got {}",
1004                slots.num_groups_usize(),
1005                probs.len()
1006            )));
1007        }
1008
1009        let state = self.gpu_state()?;
1010        let device = state.provider.device().inner();
1011
1012        let fill = device
1013            .get_func(NEURAL_MODULE, neural_kernels::NEURAL_FILL_AD_CHAIN_F32)
1014            .ok_or_else(|| {
1015                XlogError::Kernel("neural_fill_ad_chain_f32 kernel not found".to_string())
1016            })?;
1017        let scatter = device
1018            .get_func(
1019                NEURAL_MODULE,
1020                neural_kernels::NEURAL_SCATTER_AD_CHAIN_GRADS_F32,
1021            )
1022            .ok_or_else(|| {
1023                XlogError::Kernel("neural_scatter_ad_chain_grads_f32 kernel not found".to_string())
1024            })?;
1025        let binary_f64 = device
1026            .get_func(ARITH_MODULE, arith_kernels::ARITH_BINARY_F64)
1027            .ok_or_else(|| XlogError::Kernel("arith_binary_f64 kernel not found".to_string()))?;
1028
1029        let mut cache = state
1030            .cache
1031            .lock()
1032            .unwrap_or_else(|poisoned| poisoned.into_inner());
1033
1034        let root_idx = state.handle().root() as usize;
1035
1036        // If the caller requested the scalar loss, keep the base logZ on device so we can compute
1037        // loss = logZ_base - logZ_query without any host reads.
1038        let mut base_log_z: Option<TrackedCudaSlice<f64>> = if out_loss.is_some() {
1039            Some(state.provider.memory().alloc::<f64>(1)?)
1040        } else {
1041            None
1042        };
1043
1044        // 1) Update AD chain weights from device-resident p[label].
1045        for (g, prob_buf) in probs.iter().enumerate() {
1046            if prob_buf.arity() != 1 {
1047                return Err(XlogError::Compilation(
1048                    "Neural fast-path expects 1-column prob buffers".to_string(),
1049                ));
1050            }
1051            let ty = prob_buf
1052                .schema()
1053                .column_type(0)
1054                .ok_or_else(|| XlogError::Compilation("Missing prob buffer schema".to_string()))?;
1055            if ty != ScalarType::F32 {
1056                return Err(XlogError::Compilation(format!(
1057                    "Neural fast-path expects prob dtype F32, got {:?}",
1058                    ty
1059                )));
1060            }
1061
1062            let slot_vars = slots.group_slot_cnf_var(g)?;
1063            let labels = neural_slot_count_u32(slot_vars.len())?;
1064
1065            if prob_buf.num_rows() != labels as u64 {
1066                return Err(XlogError::Compilation(format!(
1067                    "Neural fast-path prob rows {} != labels {}",
1068                    prob_buf.num_rows(),
1069                    labels
1070                )));
1071            }
1072
1073            let prob_col = prob_buf.column(0).ok_or_else(|| {
1074                XlogError::Compilation("Neural fast-path missing prob column".to_string())
1075            })?;
1076
1077            let (var_log_true, var_log_false) = cache.var_log_weights_mut();
1078
1079            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1080            unsafe {
1081                fill.clone().launch(
1082                    LaunchConfig {
1083                        grid_dim: (1, 1, 1),
1084                        block_dim: (1, 1, 1),
1085                        shared_mem_bytes: 0,
1086                    },
1087                    (
1088                        prob_col,
1089                        labels,
1090                        &slot_vars,
1091                        cfg.eps,
1092                        cfg.min_p,
1093                        var_log_true,
1094                        var_log_false,
1095                    ),
1096                )
1097            }
1098            .map_err(|e| XlogError::Kernel(format!("neural_fill_ad_chain_f32 failed: {}", e)))?;
1099        }
1100
1101        // 2) Base run: out = dlogZ_base/dp
1102        cache.eval_grads_inplace_fused(state.handle())?;
1103        if let Some(base) = base_log_z.as_mut() {
1104            let root_view = cache.values().slice(root_idx..(root_idx + 1));
1105            device.dtod_copy(&root_view, base).map_err(|e| {
1106                XlogError::Kernel(format!("Failed to copy base logZ on GPU: {}", e))
1107            })?;
1108        }
1109        for (g, prob_buf) in probs.iter().enumerate() {
1110            let slot_vars = slots.group_slot_cnf_var(g)?;
1111            let labels = neural_slot_count_u32(slot_vars.len())?;
1112
1113            let out_buf = out_grads.get_mut(g).ok_or_else(|| {
1114                XlogError::Compilation("Neural fast-path missing output grad buffer".to_string())
1115            })?;
1116            if out_buf.arity() != 1 {
1117                return Err(XlogError::Compilation(
1118                    "Neural fast-path expects 1-column grad buffers".to_string(),
1119                ));
1120            }
1121            let out_ty = out_buf
1122                .schema()
1123                .column_type(0)
1124                .ok_or_else(|| XlogError::Compilation("Missing grad buffer schema".to_string()))?;
1125            if out_ty != ScalarType::F32 {
1126                return Err(XlogError::Compilation(format!(
1127                    "Neural fast-path expects grad dtype F32, got {:?}",
1128                    out_ty
1129                )));
1130            }
1131            if out_buf.num_rows() != labels as u64 {
1132                return Err(XlogError::Compilation(format!(
1133                    "Neural fast-path grad rows {} != labels {}",
1134                    out_buf.num_rows(),
1135                    labels
1136                )));
1137            }
1138
1139            let prob_col = prob_buf.column(0).ok_or_else(|| {
1140                XlogError::Compilation("Neural fast-path missing prob column".to_string())
1141            })?;
1142            let out_col = out_buf
1143                .columns
1144                .get_mut(0)
1145                .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
1146
1147            let shared_bytes: u32 = 3u64
1148                .checked_mul(labels as u64)
1149                .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
1150                .and_then(|n| u32::try_from(n).ok())
1151                .ok_or_else(|| {
1152                    XlogError::Kernel("Neural scatter shared memory overflow".to_string())
1153                })?;
1154
1155            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1156            unsafe {
1157                scatter.clone().launch(
1158                    LaunchConfig {
1159                        grid_dim: (1, 1, 1),
1160                        block_dim: (1, 1, 1),
1161                        shared_mem_bytes: shared_bytes,
1162                    },
1163                    (
1164                        prob_col,
1165                        labels,
1166                        &slot_vars,
1167                        cfg.eps,
1168                        cfg.min_p,
1169                        cache.grad_true(),
1170                        cache.grad_false(),
1171                        0u8,
1172                        out_col,
1173                    ),
1174                )
1175            }
1176            .map_err(|e| XlogError::Kernel(format!("neural_scatter (base) failed: {}", e)))?;
1177        }
1178
1179        // 3) Query run: out -= dlogZ_query/dp
1180        if query_var == 0 || query_var > self.max_var {
1181            return Err(XlogError::Compilation(format!(
1182                "Neural fast-path error: query var {} out of bounds (max_var={})",
1183                query_var, self.max_var
1184            )));
1185        }
1186
1187        let mut restore = state.provider.memory().alloc::<f64>(1)?;
1188        if expected_true {
1189            {
1190                let (_, var_log_false) = cache.var_log_weights_mut();
1191                force_query_var_false(state.provider(), var_log_false, query_var, &mut restore)?;
1192            }
1193        } else {
1194            {
1195                let (var_log_true, _) = cache.var_log_weights_mut();
1196                force_query_var_true(state.provider(), var_log_true, query_var, &mut restore)?;
1197            }
1198        }
1199
1200        cache.eval_grads_inplace_fused(state.handle())?;
1201        if let Some(out) = out_loss {
1202            let base = base_log_z
1203                .as_ref()
1204                .expect("base_log_z allocated when out_loss requested");
1205            let root_view = cache.values().slice(root_idx..(root_idx + 1));
1206            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1207            unsafe {
1208                binary_f64.clone().launch(
1209                    LaunchConfig {
1210                        grid_dim: (1, 1, 1),
1211                        block_dim: (1, 1, 1),
1212                        shared_mem_bytes: 0,
1213                    },
1214                    (base, &root_view, 1u32, 1u8, out),
1215                )
1216            }
1217            .map_err(|e| XlogError::Kernel(format!("Failed to compute NLL loss on GPU: {}", e)))?;
1218        }
1219        for (g, prob_buf) in probs.iter().enumerate() {
1220            let slot_vars = slots.group_slot_cnf_var(g)?;
1221            let labels = neural_slot_count_u32(slot_vars.len())?;
1222
1223            let prob_col = prob_buf.column(0).ok_or_else(|| {
1224                XlogError::Compilation("Neural fast-path missing prob column".to_string())
1225            })?;
1226            let out_col = out_grads[g]
1227                .columns
1228                .get_mut(0)
1229                .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
1230
1231            let shared_bytes: u32 = 3u64
1232                .checked_mul(labels as u64)
1233                .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
1234                .and_then(|n| u32::try_from(n).ok())
1235                .ok_or_else(|| {
1236                    XlogError::Kernel("Neural scatter shared memory overflow".to_string())
1237                })?;
1238
1239            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1240            unsafe {
1241                scatter.clone().launch(
1242                    LaunchConfig {
1243                        grid_dim: (1, 1, 1),
1244                        block_dim: (1, 1, 1),
1245                        shared_mem_bytes: shared_bytes,
1246                    },
1247                    (
1248                        prob_col,
1249                        labels,
1250                        &slot_vars,
1251                        cfg.eps,
1252                        cfg.min_p,
1253                        cache.grad_true(),
1254                        cache.grad_false(),
1255                        1u8,
1256                        out_col,
1257                    ),
1258                )
1259            }
1260            .map_err(|e| XlogError::Kernel(format!("neural_scatter (query) failed: {}", e)))?;
1261        }
1262        if expected_true {
1263            {
1264                let (_, var_log_false) = cache.var_log_weights_mut();
1265                restore_query_var_false(state.provider(), var_log_false, query_var, &restore)?;
1266            }
1267        } else {
1268            {
1269                let (var_log_true, _) = cache.var_log_weights_mut();
1270                restore_query_var_true(state.provider(), var_log_true, query_var, &restore)?;
1271            }
1272        }
1273
1274        Ok(())
1275    }
1276
/// Evaluates the evidence log-partition value and, for every query, its
/// log-probability, probability and gradients relative to the evidence-only
/// evaluation.
///
/// The evidence-only evaluation (`eval_log_z_and_grads_gpu_cached(None)`)
/// yields `log_z_e` and two gradient vectors. Each query with a CNF variable
/// is evaluated again with `Some(var)`; its `log_prob` is
/// `log_z_eq - log_z_e`, and its `grad_true` / `grad_false` are the
/// element-wise differences between the query gradients and the evidence-only
/// gradients. Queries without a CNF variable are reported with probability
/// `0` and all-zero gradients of length `max_var + 1` (or empty when
/// `max_var == 0`).
///
/// When no circuit is compiled and no count-lift backend is present, an
/// empty result with `log_z_e = 0.0` is returned.
///
/// # Errors
/// * `XlogError::UnsupportedEpistemicConstruct` if only the count-lift
///   backend is available.
/// * `XlogError::Execution` if the evidence has probability zero, a
///   probability evaluates to NaN, or gradient lengths disagree between the
///   evidence and query evaluations.
/// * `XlogError::Compilation` if a query variable is out of bounds.
/// * Any error propagated from `eval_log_z_and_grads_gpu_cached`.
#[cfg(feature = "host-io")]
pub fn evaluate_gpu_with_grads(&self) -> Result<ExactResultWithGrads> {
    if self.gpu.is_none() {
        if self.count_lift_gpu.is_some() {
            return Err(XlogError::UnsupportedEpistemicConstruct {
                construct: "GPU exact gradient evaluation".to_string(),
                context: "GPU count-lift exact backend does not expose gradient evaluation; \
                          gradient production paths require a compiled GPU-native Decision-DNNF exact backend"
                    .to_string(),
            });
        }
        return Ok(ExactResultWithGrads {
            log_z_e: 0.0,
            query_grads: Vec::new(),
        });
    }

    // Index 0 is unused by CNF variables, hence `max_var + 1` entries.
    let weights_len = if self.max_var == 0 {
        0
    } else {
        (self.max_var as usize) + 1
    };

    let (log_z_e, grad_true_e, grad_false_e) = self.eval_log_z_and_grads_gpu_cached(None)?;

    if log_z_e == f64::NEG_INFINITY {
        return Err(XlogError::Execution(
            "Exact inference error: evidence is inconsistent (P(E)=0)".to_string(),
        ));
    }

    let mut query_grads: Vec<QueryGradients> = Vec::with_capacity(self.queries.len());

    for query in &self.queries {
        // A query without a CNF variable is reported as impossible.
        let Some(var) = query.var else {
            query_grads.push(QueryGradients {
                atom: query.atom.clone(),
                log_prob: f64::NEG_INFINITY,
                prob: 0.0,
                grad_true: vec![0.0; weights_len],
                grad_false: vec![0.0; weights_len],
            });
            continue;
        };

        let idx = var as usize;
        if idx >= weights_len {
            return Err(XlogError::Compilation(format!(
                "Exact inference error: query var {} out of bounds (len={})",
                var, weights_len
            )));
        }

        let (log_z_eq, grad_true_eq, grad_false_eq) =
            self.eval_log_z_and_grads_gpu_cached(Some(var))?;

        let log_prob = log_z_eq - log_z_e;
        // exp(-inf) is already 0.0; the explicit branch keeps that case obvious.
        let prob = if log_prob == f64::NEG_INFINITY {
            0.0
        } else {
            log_prob.exp()
        };
        if prob.is_nan() {
            return Err(XlogError::Execution(
                "Exact inference error: NaN probability encountered".to_string(),
            ));
        }
        // Guard against round-off pushing the ratio slightly above 1.
        let prob = prob.clamp(0.0, 1.0);

        if grad_true_eq.len() != grad_true_e.len() || grad_false_eq.len() != grad_false_e.len()
        {
            return Err(XlogError::Execution(
                "Exact inference error: gradient length mismatch".to_string(),
            ));
        }

        // Subtract the evidence-only gradients element-wise. Each vector is walked
        // over its own length, so the "true" and "false" vectors are never indexed
        // with each other's bounds.
        let mut grad_true: Vec<f64> = grad_true_eq;
        let mut grad_false: Vec<f64> = grad_false_eq;
        for (g, base) in grad_true.iter_mut().zip(grad_true_e.iter()) {
            *g -= *base;
        }
        for (g, base) in grad_false.iter_mut().zip(grad_false_e.iter()) {
            *g -= *base;
        }

        query_grads.push(QueryGradients {
            atom: query.atom.clone(),
            log_prob,
            prob,
            grad_true,
            grad_false,
        });
    }

    Ok(ExactResultWithGrads {
        log_z_e,
        query_grads,
    })
}
1374
1375    fn compile_provenance_with_gpu(
1376        provenance: Provenance,
1377        config: GpuConfig,
1378        origin: ExactProgramOrigin,
1379    ) -> Result<Self> {
1380        if config.memory_bytes == 0 {
1381            return Err(XlogError::Kernel(
1382                "GPU memory budget must be non-zero".to_string(),
1383            ));
1384        }
1385
1386        let provenance = if config.decision_order_hint {
1387            crate::decision_order::apply_decision_order_hint(provenance)
1388        } else {
1389            provenance
1390        };
1391
1392        let mut roots_set: HashSet<crate::pir::PirNodeId> = HashSet::new();
1393
1394        let mut evidence_formulas: Vec<(crate::pir::PirNodeId, bool, GroundAtom)> = Vec::new();
1395        let mut evidence_atoms: std::collections::HashMap<GroundAtom, bool> =
1396            std::collections::HashMap::new();
1397        for (atom, value) in &provenance.evidence {
1398            if let Some(prev) = evidence_atoms.insert(atom.clone(), *value) {
1399                if prev != *value {
1400                    return Err(XlogError::Execution(format!(
1401                        "Exact inference error: conflicting evidence for {}",
1402                        display_atom(atom)
1403                    )));
1404                }
1405            }
1406
1407            let formula = provenance.query_formula(&atom.predicate, &atom.args);
1408            match formula {
1409                Some(id) => {
1410                    roots_set.insert(id);
1411                    evidence_formulas.push((id, *value, atom.clone()));
1412                }
1413                None => {
1414                    if *value {
1415                        return Err(XlogError::Execution(format!(
1416                            "Exact inference error: evidence atom is never derivable: {}",
1417                            display_atom(atom)
1418                        )));
1419                    }
1420                }
1421            }
1422        }
1423
1424        let mut queries: Vec<QuerySpec> = Vec::new();
1425        #[cfg(feature = "host-io")]
1426        let mut query_nodes: Vec<(usize, crate::pir::PirNodeId)> = Vec::new();
1427        for atom in &provenance.queries {
1428            let formula = provenance.query_formula(&atom.predicate, &atom.args);
1429            if let Some(id) = formula {
1430                roots_set.insert(id);
1431                #[cfg(feature = "host-io")]
1432                {
1433                    query_nodes.push((queries.len(), id));
1434                }
1435            }
1436            queries.push(QuerySpec {
1437                atom: atom.clone(),
1438                var: None,
1439            });
1440        }
1441
1442        // Ensure ALL probabilistic variable nodes (Decision, Lit, NegLit) are reachable
1443        // so they get CNF variables. This is required for the template/neural fast-path
1444        // where GpuWeightSlots expects one CNF variable per ChoiceVarId/LeafId.
1445        for (idx, node) in provenance.pir.nodes().iter().enumerate() {
1446            match node {
1447                crate::pir::PirNode::Decision { .. }
1448                | crate::pir::PirNode::Lit { .. }
1449                | crate::pir::PirNode::NegLit { .. } => {
1450                    roots_set.insert(crate::pir::PirNodeId::from_u32(idx as u32));
1451                }
1452                _ => {}
1453            }
1454        }
1455
1456        let mut roots: Vec<crate::pir::PirNodeId> = roots_set.into_iter().collect();
1457        roots.sort();
1458
1459        if roots.is_empty() {
1460            return Ok(Self {
1461                gpu: None,
1462                count_lift_gpu: None,
1463                queries,
1464                random_vars: None,
1465                max_var: 0,
1466                origin,
1467                gpu_config: config,
1468                last_compile_profile: None,
1469            });
1470        }
1471
1472        let count_lift_gpu = try_build_count_lift_gpu_state(&provenance, &queries, config)?;
1473        if let Some(count_lift_gpu) = count_lift_gpu {
1474            return Ok(Self {
1475                gpu: None,
1476                count_lift_gpu: Some(count_lift_gpu),
1477                queries,
1478                random_vars: None,
1479                max_var: 0,
1480                origin,
1481                gpu_config: config,
1482                last_compile_profile: None,
1483            });
1484        }
1485
1486        let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
1487        let memory = Arc::new(GpuMemoryManager::new(
1488            device.clone(),
1489            MemoryBudget::with_limit(config.memory_bytes),
1490        ));
1491        let provider = Arc::new(CudaKernelProvider::new(device, memory)?);
1492
1493        let canonical_cnf_hash = crate::cnf::canonical_pir_hash(&provenance.pir, &roots)?;
1494        let gpu_pir = GpuPirGraph::from_host(&provenance.pir, &provider)?;
1495        let gpu_roots = GpuPirRoots::from_host(&roots, &provider)?;
1496        let encoding = encode_cnf_gpu(&gpu_pir, &gpu_roots, &provider)?;
1497        if encoding.vars.max_var != encoding.cnf.var_cap {
1498            return Err(XlogError::Compilation(format!(
1499                "Exact inference error: CNF var_cap {} != vars.max_var {}",
1500                encoding.cnf.var_cap, encoding.vars.max_var
1501            )));
1502        }
1503
1504        let (leaf_probs_host, choice_true_host, choice_false_host) =
1505            build_weight_sources(&provenance)?;
1506
1507        let leaf_probs = upload_f64(&provider, &leaf_probs_host)?;
1508        let choice_true = upload_f64(&provider, &choice_true_host)?;
1509        let choice_false = upload_f64(&provider, &choice_false_host)?;
1510
1511        let evidence_by_var = if evidence_formulas.is_empty() {
1512            let mut evidence = provider
1513                .memory()
1514                .alloc::<u8>((encoding.vars.max_var as usize) + 1)?;
1515            provider
1516                .device()
1517                .inner()
1518                .memset_zeros(&mut evidence)
1519                .map_err(|e| XlogError::Kernel(format!("Failed to zero evidence buffer: {}", e)))?;
1520            evidence
1521        } else {
1522            let mut nodes: Vec<u32> = Vec::with_capacity(evidence_formulas.len());
1523            let mut vals: Vec<u8> = Vec::with_capacity(evidence_formulas.len());
1524            for (node, value, _atom) in &evidence_formulas {
1525                nodes.push(node.as_u32());
1526                vals.push(if *value { 1u8 } else { 2u8 });
1527            }
1528            let evidence_nodes = upload_u32(&provider, &nodes)?;
1529            let evidence_vals = upload_u8(&provider, &vals)?;
1530            build_evidence_by_var_gpu(
1531                &encoding.vars.node_var,
1532                &evidence_nodes,
1533                &evidence_vals,
1534                encoding.vars.max_var,
1535                &provider,
1536            )?
1537        };
1538
1539        let weights = build_weights_gpu(
1540            &encoding.vars,
1541            &leaf_probs,
1542            &choice_true,
1543            &choice_false,
1544            &evidence_by_var,
1545            &provider,
1546        )?;
1547        let random_var_count = leaf_probs_host
1548            .len()
1549            .checked_add(choice_true_host.len())
1550            .ok_or_else(|| XlogError::Compilation("random var count overflow".to_string()))?;
1551        let random_var_count = u32::try_from(random_var_count)
1552            .map_err(|_| XlogError::Compilation("random var count exceeds u32".to_string()))?;
1553        let num_leaf_probs = u32::try_from(leaf_probs_host.len())
1554            .map_err(|_| XlogError::Compilation("leaf_probs count exceeds u32".to_string()))?;
1555        let num_choice_probs = u32::try_from(choice_true_host.len())
1556            .map_err(|_| XlogError::Compilation("choice_probs count exceeds u32".to_string()))?;
1557        let (random_var_list, actual_random_var_count) = collect_random_vars_device(
1558            &provider,
1559            &encoding.vars,
1560            num_leaf_probs,
1561            num_choice_probs,
1562            random_var_count,
1563        )?;
1564        let random_vars =
1565            DeviceRandomVarList::from_device(random_var_list, actual_random_var_count)?;
1566
1567        let compile_config = default_compile_config(&encoding.cnf, config.memory_bytes)?;
1568        let cache_config = default_cache_config(&encoding.cnf, &compile_config)?;
1569
1570        let mut cache = GpuCircuitCache::new(&provider, cache_config)?;
1571        let (handle, compile_profile) = compile_gpu_d4_and_verify_cached(
1572            &encoding.cnf,
1573            &encoding.decision_var_limit,
1574            &provider,
1575            &compile_config,
1576            &mut cache,
1577            &random_vars,
1578            Some(canonical_cnf_hash),
1579        )?;
1580        cache.store_weights(&handle, &weights.log_true, &weights.log_false)?;
1581
1582        #[cfg(feature = "host-io")]
1583        if !query_nodes.is_empty() {
1584            let mut node_ids: Vec<u32> = Vec::with_capacity(query_nodes.len());
1585            for (_idx, node) in &query_nodes {
1586                node_ids.push(node.as_u32());
1587            }
1588            let node_ids_device = upload_u32(&provider, &node_ids)?;
1589            let vars_device = map_nodes_to_vars_gpu(
1590                &encoding.vars.node_var,
1591                &node_ids_device,
1592                encoding.vars.max_var,
1593                &provider,
1594            )?;
1595
1596            let mut vars_host = vec![0u32; vars_device.len()];
1597            provider
1598                .device()
1599                .inner()
1600                .dtoh_sync_copy_into(&vars_device, &mut vars_host)
1601                .map_err(|e| XlogError::Kernel(format!("Failed to read query vars: {}", e)))?;
1602
1603            for (i, (query_idx, _)) in query_nodes.iter().enumerate() {
1604                let var = vars_host[i];
1605                queries[*query_idx].var = Some(var);
1606            }
1607        }
1608
1609        let state = GpuExactState::new(provider, cache, handle)?;
1610
1611        Ok(Self {
1612            gpu: Some(Arc::new(state)),
1613            count_lift_gpu: None,
1614            queries,
1615            random_vars: Some(Arc::new(random_vars)),
1616            max_var: encoding.vars.max_var,
1617            origin,
1618            gpu_config: config,
1619            last_compile_profile: compile_profile,
1620        })
1621    }
1622
1623    #[cfg(feature = "host-io")]
1624    fn eval_log_z_gpu(&self, query_true: Option<u32>) -> Result<f64> {
1625        let state = self.gpu_state()?;
1626        let mut cache = state
1627            .cache
1628            .lock()
1629            .unwrap_or_else(|poisoned| poisoned.into_inner());
1630
1631        if let Some(var) = query_true {
1632            if var == 0 || var > self.max_var {
1633                return Err(XlogError::Compilation(format!(
1634                    "Exact inference error: query var {} out of bounds (max_var={})",
1635                    var, self.max_var
1636                )));
1637            }
1638        }
1639
1640        let mut restore = None;
1641        if let Some(var) = query_true {
1642            let mut buf = state.provider.memory().alloc::<f64>(1)?;
1643            {
1644                let (_, var_log_false) = cache.var_log_weights_mut();
1645                force_query_var_false(state.provider(), var_log_false, var, &mut buf)?;
1646            }
1647            restore = Some((var, buf));
1648        }
1649
1650        let mut out_log_z = state.provider.memory().alloc::<f64>(1)?;
1651        let eval_result = cache.eval_log_wmc_device_inplace(state.handle(), &mut out_log_z);
1652
1653        if let Some((var, buf)) = restore {
1654            let (_, var_log_false) = cache.var_log_weights_mut();
1655            let restore_result =
1656                restore_query_var_false(state.provider(), var_log_false, var, &buf);
1657            if let Err(err) = eval_result {
1658                restore_result?;
1659                return Err(err);
1660            }
1661            restore_result?;
1662        } else {
1663            eval_result?;
1664        }
1665
1666        let mut host = [0.0f64];
1667        state
1668            .provider
1669            .device()
1670            .inner()
1671            .dtoh_sync_copy_into(&out_log_z, &mut host)
1672            .map_err(|e| XlogError::Kernel(format!("Failed to read logZ: {}", e)))?;
1673        Ok(host[0])
1674    }
1675
1676    fn gpu_state(&self) -> Result<Arc<GpuExactState>> {
1677        self.gpu.clone().ok_or_else(|| {
1678            XlogError::Execution(
1679                "Exact inference GPU error: program has no compiled circuit".to_string(),
1680            )
1681        })
1682    }
1683
1684    #[cfg(feature = "host-io")]
1685    fn eval_log_z_and_grads_gpu_cached(
1686        &self,
1687        query_true: Option<u32>,
1688    ) -> Result<(f64, Vec<f64>, Vec<f64>)> {
1689        let state = self.gpu_state()?;
1690        let mut cache = state
1691            .cache
1692            .lock()
1693            .unwrap_or_else(|poisoned| poisoned.into_inner());
1694
1695        if let Some(var) = query_true {
1696            if var == 0 || var > self.max_var {
1697                return Err(XlogError::Compilation(format!(
1698                    "Exact inference error: query var {} out of bounds (max_var={})",
1699                    var, self.max_var
1700                )));
1701            }
1702        }
1703
1704        let mut restore = None;
1705        if let Some(var) = query_true {
1706            let mut buf = state.provider.memory().alloc::<f64>(1)?;
1707            {
1708                let (_, var_log_false) = cache.var_log_weights_mut();
1709                force_query_var_false(state.provider(), var_log_false, var, &mut buf)?;
1710            }
1711            restore = Some((var, buf));
1712        }
1713
1714        let eval_result = cache.eval_grads_inplace(state.handle());
1715
1716        if let Some((var, buf)) = restore {
1717            let (_, var_log_false) = cache.var_log_weights_mut();
1718            let restore_result =
1719                restore_query_var_false(state.provider(), var_log_false, var, &buf);
1720            if let Err(err) = eval_result {
1721                restore_result?;
1722                return Err(err);
1723            }
1724            restore_result?;
1725        } else {
1726            eval_result?;
1727        }
1728
1729        let weights_len = if self.max_var == 0 {
1730            0
1731        } else {
1732            (self.max_var as usize) + 1
1733        };
1734
1735        let device = state.provider.device().inner();
1736        let mut host_grad_true: Vec<f64> = vec![0.0; weights_len];
1737        let mut host_grad_false: Vec<f64> = vec![0.0; weights_len];
1738
1739        let root_idx = state.handle().root() as usize;
1740        let root_view = cache.values().slice(root_idx..(root_idx + 1));
1741        let mut log_z = [0.0_f64];
1742        device
1743            .dtoh_sync_copy_into(&root_view, &mut log_z)
1744            .map_err(|e| XlogError::Kernel(format!("Failed to read logZ: {}", e)))?;
1745
1746        // Gradient buffers are multi-slot: [slot0_var0..slot0_varN, slot1_var0..].
1747        // Slice into the correct slot to download only this circuit's gradients.
1748        let var_stride = cache.var_stride()? as usize;
1749        let slot = state.handle().slot_index() as usize;
1750        let grad_start = slot * var_stride;
1751        let grad_end = grad_start + weights_len;
1752        let grad_true_slot = cache.grad_true().slice(grad_start..grad_end);
1753        let grad_false_slot = cache.grad_false().slice(grad_start..grad_end);
1754        device
1755            .dtoh_sync_copy_into(&grad_true_slot, &mut host_grad_true)
1756            .map_err(|e| XlogError::Kernel(format!("Failed to download grad_true: {}", e)))?;
1757        device
1758            .dtoh_sync_copy_into(&grad_false_slot, &mut host_grad_false)
1759            .map_err(|e| XlogError::Kernel(format!("Failed to download grad_false: {}", e)))?;
1760
1761        Ok((log_z[0], host_grad_true, host_grad_false))
1762    }
1763}
1764
1765fn try_build_count_lift_gpu_state(
1766    provenance: &Provenance,
1767    queries: &[QuerySpec],
1768    config: GpuConfig,
1769) -> Result<Option<Arc<GpuCountLiftState>>> {
1770    if queries.is_empty() || !provenance.evidence.is_empty() || !provenance.choice_probs.is_empty()
1771    {
1772        return Ok(None);
1773    }
1774
1775    let fired_count_predicates: HashSet<&str> = provenance
1776        .aggregate_lifting
1777        .iter()
1778        .filter(|entry| {
1779            entry.status == AggregateLiftStatus::Fired
1780                && entry.operator.as_str() == "count"
1781                && entry.deterministic_rows == 0
1782        })
1783        .map(|entry| entry.predicate.as_str())
1784        .collect();
1785    if fired_count_predicates.is_empty() {
1786        return Ok(None);
1787    }
1788    if queries
1789        .iter()
1790        .any(|query| !fired_count_predicates.contains(query.atom.predicate.as_str()))
1791    {
1792        return Ok(None);
1793    }
1794
1795    let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
1796    let memory = Arc::new(GpuMemoryManager::new(
1797        device.clone(),
1798        MemoryBudget::with_limit(config.memory_bytes),
1799    ));
1800    let provider = Arc::new(CudaKernelProvider::new(device, memory)?);
1801    let mut gpu_queries = Vec::with_capacity(queries.len());
1802    for query in queries {
1803        let target_count = match count_lift_query_target(query)? {
1804            Some(target) => target,
1805            None => return Ok(None),
1806        };
1807        let root = match provenance.query_formula(&query.atom.predicate, &query.atom.args) {
1808            Some(root) => root,
1809            None => return Ok(None),
1810        };
1811        let mut leaves = HashSet::new();
1812        collect_count_lift_leaves(provenance, root, &mut leaves)?;
1813        if leaves.is_empty() || leaves.len() > 64 {
1814            return Ok(None);
1815        }
1816        if target_count > leaves.len() as u32 {
1817            return Ok(None);
1818        }
1819        let mut leaves: Vec<_> = leaves.into_iter().collect();
1820        leaves.sort_by_key(|leaf| leaf.as_u32());
1821        let mut leaf_probs_host = Vec::with_capacity(leaves.len());
1822        for leaf in leaves {
1823            let p = *provenance.leaf_probs.get(&leaf).ok_or_else(|| {
1824                XlogError::Compilation(format!(
1825                    "Count-lift GPU evaluator missing probability for leaf {}",
1826                    leaf.as_u32()
1827                ))
1828            })?;
1829            leaf_probs_host.push(p);
1830        }
1831        let leaf_count = u32::try_from(leaf_probs_host.len())
1832            .map_err(|_| XlogError::Compilation("count-lift leaf count exceeds u32".to_string()))?;
1833        let leaf_probs = upload_f64(&provider, &leaf_probs_host)?;
1834        gpu_queries.push(GpuCountLiftQuery {
1835            atom: query.atom.clone(),
1836            target_count,
1837            leaf_count,
1838            leaf_probs,
1839        });
1840    }
1841    Ok(Some(Arc::new(GpuCountLiftState::new(
1842        provider,
1843        gpu_queries,
1844    ))))
1845}
1846
1847fn count_lift_query_target(query: &QuerySpec) -> Result<Option<u32>> {
1848    match query.atom.args.last() {
1849        Some(Value::I64(value)) if *value >= 0 => u32::try_from(*value)
1850            .map(Some)
1851            .map_err(|_| XlogError::Compilation("count-lift target exceeds u32".to_string())),
1852        _ => Ok(None),
1853    }
1854}
1855
1856fn collect_count_lift_leaves(
1857    provenance: &Provenance,
1858    node: crate::pir::PirNodeId,
1859    leaves: &mut HashSet<crate::pir::LeafId>,
1860) -> Result<()> {
1861    let pir_node = provenance.pir.node(node).ok_or_else(|| {
1862        XlogError::Compilation(format!(
1863            "Count-lift GPU evaluator saw invalid PIR node {}",
1864            node.as_u32()
1865        ))
1866    })?;
1867    match pir_node {
1868        crate::pir::PirNode::Const(_) => Ok(()),
1869        crate::pir::PirNode::Lit { leaf } | crate::pir::PirNode::NegLit { leaf } => {
1870            leaves.insert(*leaf);
1871            Ok(())
1872        }
1873        crate::pir::PirNode::And { children } | crate::pir::PirNode::Or { children } => {
1874            for child in children {
1875                collect_count_lift_leaves(provenance, *child, leaves)?;
1876            }
1877            Ok(())
1878        }
1879        crate::pir::PirNode::Decision { .. } => Err(XlogError::Compilation(
1880            "Count-lift GPU evaluator does not support annotated-disjunction choices".to_string(),
1881        )),
1882    }
1883}
1884
1885fn force_query_var_false(
1886    provider: &Arc<CudaKernelProvider>,
1887    log_false: &mut TrackedCudaSlice<f64>,
1888    var: u32,
1889    restore: &mut TrackedCudaSlice<f64>,
1890) -> Result<()> {
1891    let device = provider.device().inner();
1892    let func = device
1893        .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FORCE_VAR_FALSE)
1894        .ok_or_else(|| XlogError::Kernel("weights_force_var_false kernel not found".to_string()))?;
1895    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1896    unsafe {
1897        func.clone().launch(
1898            LaunchConfig {
1899                grid_dim: (1, 1, 1),
1900                block_dim: (1, 1, 1),
1901                shared_mem_bytes: 0,
1902            },
1903            (var, log_false, restore),
1904        )
1905    }
1906    .map_err(|e| XlogError::Kernel(format!("weights_force_var_false failed: {}", e)))?;
1907    Ok(())
1908}
1909
1910fn restore_query_var_false(
1911    provider: &Arc<CudaKernelProvider>,
1912    log_false: &mut TrackedCudaSlice<f64>,
1913    var: u32,
1914    restore: &TrackedCudaSlice<f64>,
1915) -> Result<()> {
1916    let device = provider.device().inner();
1917    let func = device
1918        .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_RESTORE_VAR_FALSE)
1919        .ok_or_else(|| {
1920            XlogError::Kernel("weights_restore_var_false kernel not found".to_string())
1921        })?;
1922    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1923    unsafe {
1924        func.clone().launch(
1925            LaunchConfig {
1926                grid_dim: (1, 1, 1),
1927                block_dim: (1, 1, 1),
1928                shared_mem_bytes: 0,
1929            },
1930            (var, log_false, restore),
1931        )
1932    }
1933    .map_err(|e| XlogError::Kernel(format!("weights_restore_var_false failed: {}", e)))?;
1934    Ok(())
1935}
1936
1937fn force_query_var_true(
1938    provider: &Arc<CudaKernelProvider>,
1939    log_true: &mut TrackedCudaSlice<f64>,
1940    var: u32,
1941    restore: &mut TrackedCudaSlice<f64>,
1942) -> Result<()> {
1943    let device = provider.device().inner();
1944    let func = device
1945        .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FORCE_VAR_TRUE)
1946        .ok_or_else(|| XlogError::Kernel("weights_force_var_true kernel not found".to_string()))?;
1947    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1948    unsafe {
1949        func.clone().launch(
1950            LaunchConfig {
1951                grid_dim: (1, 1, 1),
1952                block_dim: (1, 1, 1),
1953                shared_mem_bytes: 0,
1954            },
1955            (var, log_true, restore),
1956        )
1957    }
1958    .map_err(|e| XlogError::Kernel(format!("weights_force_var_true failed: {}", e)))?;
1959    Ok(())
1960}
1961
1962fn restore_query_var_true(
1963    provider: &Arc<CudaKernelProvider>,
1964    log_true: &mut TrackedCudaSlice<f64>,
1965    var: u32,
1966    restore: &TrackedCudaSlice<f64>,
1967) -> Result<()> {
1968    let device = provider.device().inner();
1969    let func = device
1970        .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_RESTORE_VAR_TRUE)
1971        .ok_or_else(|| {
1972            XlogError::Kernel("weights_restore_var_true kernel not found".to_string())
1973        })?;
1974    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1975    unsafe {
1976        func.clone().launch(
1977            LaunchConfig {
1978                grid_dim: (1, 1, 1),
1979                block_dim: (1, 1, 1),
1980                shared_mem_bytes: 0,
1981            },
1982            (var, log_true, restore),
1983        )
1984    }
1985    .map_err(|e| XlogError::Kernel(format!("weights_restore_var_true failed: {}", e)))?;
1986    Ok(())
1987}
1988
1989pub(crate) fn default_compile_config(
1990    cnf: &xlog_solve::GpuCnf,
1991    memory_bytes: u64,
1992) -> Result<GpuCompileConfig> {
1993    // Must match the default GPU-native Decision-DNNF compiler configuration expected
1994    // by the Python training paths.
1995    // Sizing is conservative and strictly bounded by `GpuCompileConfig::{smooth_node_cap,smooth_edge_cap}`.
1996    let frontier_depth: u16 = 6;
1997
1998    let var_cap = cnf.var_cap.max(1);
1999    let trail_bytes_per_item = (var_cap as u64)
2000        .checked_add(1)
2001        .and_then(|v| v.checked_mul(std::mem::size_of::<i32>() as u64))
2002        .ok_or_else(|| XlogError::Compilation("trail size overflow".to_string()))?;
2003    let denom = trail_bytes_per_item
2004        .checked_mul(8)
2005        .ok_or_else(|| XlogError::Compilation("trail memory denominator overflow".to_string()))?;
2006    if memory_bytes
2007        < denom.checked_mul(8).ok_or_else(|| {
2008            XlogError::Compilation("minimum frontier memory requirement overflow".to_string())
2009        })?
2010    {
2011        return Err(XlogError::Compilation(format!(
2012            "memory budget {} cannot hold the minimum GPU-native Decision-DNNF frontier allocation",
2013            memory_bytes
2014        )));
2015    }
2016    let max_items_by_trail = memory_bytes / denom;
2017    let max_frontier_items = max_items_by_trail.min(4096).min(u64::from(u32::MAX)) as u32;
2018
2019    // The GPU-native Decision-DNNF compiler emits one leaf circuit per frontier item;
2020    // caps must scale with the maximum frontier size (up to 2^frontier_depth,
2021    // bounded by max_frontier_items).
2022    let frontier_cap_factor = (1u64
2023        .checked_shl(frontier_depth as u32)
2024        .unwrap_or(u64::from(u32::MAX)))
2025    .min(u64::from(max_frontier_items)) as u32;
2026
2027    let per_item_nodes = cnf
2028        .var_cap
2029        .checked_mul(5)
2030        .ok_or_else(|| XlogError::Compilation("smooth_node_cap overflow".to_string()))?
2031        .max(1024);
2032    let smooth_node_cap = per_item_nodes
2033        .checked_mul(frontier_cap_factor)
2034        .ok_or_else(|| XlogError::Compilation("smooth_node_cap overflow".to_string()))?;
2035
2036    // Edge capacity scales with node capacity; AND/OR fanout grows edges but stays within a small
2037    // multiple of nodes for the compiler's frontier emission patterns.
2038    let mut smooth_edge_cap = smooth_node_cap
2039        .checked_mul(2)
2040        .ok_or_else(|| XlogError::Compilation("smooth_edge_cap overflow".to_string()))?;
2041    if smooth_edge_cap < max_frontier_items {
2042        smooth_edge_cap = max_frontier_items;
2043    }
2044
2045    // The verifier's UNSAT certificate (resolution trace) can be large even when the source CNF
2046    // is small, because equivalence checking builds CNF(C) with many Tseitin variables/clauses.
2047    // Allocate a larger share of the budget to the GPU CDCL arenas to avoid deterministic
2048    // overflow errors in production verifier paths.
2049    let mut cdcl_learned_bytes = memory_bytes / 8;
2050    if cdcl_learned_bytes < 4 * 1024 * 1024 {
2051        cdcl_learned_bytes = 4 * 1024 * 1024;
2052    }
2053
2054    let config = GpuCompileConfig {
2055        frontier_depth,
2056        max_frontier_items,
2057        max_depth: 128,
2058        smooth_node_cap,
2059        smooth_edge_cap,
2060        cdcl_restart_interval: 64,
2061        cdcl_learned_bytes,
2062        cdcl_conflict_budget: None,
2063        incremental_verify: false,
2064    };
2065    Ok(config)
2066}
2067
2068pub(crate) fn default_cache_config(
2069    cnf: &xlog_solve::GpuCnf,
2070    compile: &GpuCompileConfig,
2071) -> Result<GpuCircuitCacheConfig> {
2072    if compile.smooth_node_cap == 0 || compile.smooth_edge_cap == 0 {
2073        return Err(XlogError::Compilation(
2074            "GPU cache config requires non-zero smoothing caps".to_string(),
2075        ));
2076    }
2077    Ok(GpuCircuitCacheConfig {
2078        num_slots: 4, // Hold 4 circuit templates; power-of-2 hash table.
2079        table_size: 8,
2080        node_cap: compile.smooth_node_cap,
2081        edge_cap: compile.smooth_edge_cap,
2082        level_cap: compile.smooth_node_cap,
2083        var_cap: cnf.var_cap,
2084    })
2085}
2086
2087pub(crate) fn build_weight_sources(
2088    provenance: &Provenance,
2089) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>)> {
2090    let max_leaf = provenance.leaf_probs.keys().map(|leaf| leaf.as_u32()).max();
2091    let leaf_len = max_leaf.map(|v| v as usize + 1).unwrap_or(0);
2092    let mut leaf_probs = vec![0.0f64; leaf_len];
2093    let mut leaf_seen = vec![false; leaf_len];
2094    for (leaf, p) in &provenance.leaf_probs {
2095        let idx = leaf.as_u32() as usize;
2096        if idx >= leaf_len {
2097            return Err(XlogError::Compilation(
2098                "leaf probability index out of bounds".to_string(),
2099            ));
2100        }
2101        leaf_probs[idx] = *p;
2102        leaf_seen[idx] = true;
2103    }
2104    if let Some((idx, _)) = leaf_seen.iter().enumerate().find(|(_, seen)| !**seen) {
2105        return Err(XlogError::Compilation(format!(
2106            "missing probability for leaf {}",
2107            idx
2108        )));
2109    }
2110
2111    let max_choice = provenance
2112        .choice_probs
2113        .keys()
2114        .map(|choice| choice.as_u32())
2115        .max();
2116    let choice_len = max_choice.map(|v| v as usize + 1).unwrap_or(0);
2117    let mut choice_true = vec![0.0f64; choice_len];
2118    let mut choice_false = vec![0.0f64; choice_len];
2119    let mut choice_seen = vec![false; choice_len];
2120    for (choice, (pt, pf)) in &provenance.choice_probs {
2121        let idx = choice.as_u32() as usize;
2122        if idx >= choice_len {
2123            return Err(XlogError::Compilation(
2124                "choice probability index out of bounds".to_string(),
2125            ));
2126        }
2127        choice_true[idx] = *pt;
2128        choice_false[idx] = *pf;
2129        choice_seen[idx] = true;
2130    }
2131    if let Some((idx, _)) = choice_seen.iter().enumerate().find(|(_, seen)| !**seen) {
2132        return Err(XlogError::Compilation(format!(
2133            "missing probability for choice {}",
2134            idx
2135        )));
2136    }
2137
2138    Ok((leaf_probs, choice_true, choice_false))
2139}
2140
2141pub(crate) fn upload_u32(
2142    provider: &Arc<CudaKernelProvider>,
2143    host: &[u32],
2144) -> Result<TrackedCudaSlice<u32>> {
2145    let memory = provider.memory();
2146    let mut buf = memory.alloc::<u32>(host.len())?;
2147    provider
2148        .htod_sync_copy_into_tracked(host, &mut buf)
2149        .map_err(|e| XlogError::Kernel(format!("Failed to upload u32 buffer: {}", e)))?;
2150    Ok(buf)
2151}
2152
2153pub(crate) fn upload_u8(
2154    provider: &Arc<CudaKernelProvider>,
2155    host: &[u8],
2156) -> Result<TrackedCudaSlice<u8>> {
2157    let memory = provider.memory();
2158    let mut buf = memory.alloc::<u8>(host.len())?;
2159    provider
2160        .htod_sync_copy_into_tracked(host, &mut buf)
2161        .map_err(|e| XlogError::Kernel(format!("Failed to upload u8 buffer: {}", e)))?;
2162    Ok(buf)
2163}
2164
2165pub(crate) fn upload_f64(
2166    provider: &Arc<CudaKernelProvider>,
2167    host: &[f64],
2168) -> Result<TrackedCudaSlice<f64>> {
2169    let memory = provider.memory();
2170    let mut buf = memory.alloc::<f64>(host.len())?;
2171    provider
2172        .htod_sync_copy_into_tracked(host, &mut buf)
2173        .map_err(|e| XlogError::Kernel(format!("Failed to upload f64 buffer: {}", e)))?;
2174    Ok(buf)
2175}
2176
2177fn capture_compact_count_device(
2178    provider: &Arc<CudaKernelProvider>,
2179    prefix_sum: &TrackedCudaSlice<u32>,
2180    mask: &TrackedCudaSlice<u8>,
2181    n: u32,
2182) -> Result<TrackedCudaSlice<u32>> {
2183    let mut out = provider.memory().alloc::<u32>(1)?;
2184    let device = provider.device().inner();
2185    let capture_fn = device
2186        .get_func(FILTER_MODULE, filter_kernels::CAPTURE_COMPACT_COUNT)
2187        .ok_or_else(|| XlogError::Kernel("capture_compact_count kernel not found".to_string()))?;
2188    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2189    unsafe {
2190        capture_fn.clone().launch(
2191            LaunchConfig {
2192                grid_dim: (1, 1, 1),
2193                block_dim: (1, 1, 1),
2194                shared_mem_bytes: 0,
2195            },
2196            (prefix_sum, mask, n, &mut out),
2197        )
2198    }
2199    .map_err(|e| XlogError::Kernel(format!("capture_compact_count failed: {}", e)))?;
2200    Ok(out)
2201}
2202
2203pub(crate) fn collect_random_vars_device(
2204    provider: &Arc<CudaKernelProvider>,
2205    vars: &GpuCnfVarTables,
2206    num_leaf_probs: u32,
2207    num_choice_probs: u32,
2208    _expected_count: u32,
2209) -> Result<(TrackedCudaSlice<u32>, u32)> {
2210    let device = provider.device().inner();
2211    let memory = provider.memory();
2212
2213    let mask_len = vars
2214        .max_var
2215        .checked_add(1)
2216        .ok_or_else(|| XlogError::Compilation("random var mask_len overflow".to_string()))?;
2217    let mask_len_usize = usize::try_from(mask_len)
2218        .map_err(|_| XlogError::Compilation("random var mask_len exceeds usize".to_string()))?;
2219
2220    let mut mask = memory.alloc::<u8>(mask_len_usize)?;
2221    device
2222        .memset_zeros(&mut mask)
2223        .map_err(|e| XlogError::Kernel(format!("Failed to zero random var mask: {}", e)))?;
2224
2225    let mut iota = memory.alloc::<u32>(mask_len_usize)?;
2226    let fill_iota = device
2227        .get_func(FILTER_MODULE, filter_kernels::FILL_U32_IOTA)
2228        .ok_or_else(|| XlogError::Kernel("fill_u32_iota kernel not found".to_string()))?;
2229    let block_size = 256u32;
2230    let grid = checked_launch_grid_u32("fill random-var iota", mask_len, block_size)?;
2231    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2232    unsafe {
2233        fill_iota.clone().launch(
2234            LaunchConfig {
2235                grid_dim: (grid, 1, 1),
2236                block_dim: (block_size, 1, 1),
2237                shared_mem_bytes: 0,
2238            },
2239            (&mut iota, mask_len, 0u32),
2240        )
2241    }
2242    .map_err(|e| XlogError::Kernel(format!("fill_u32_iota failed: {}", e)))?;
2243
2244    // Only iterate over the probabilistic entries — leaf_var and choice_var are allocated
2245    // to num_nodes but only the first num_leaf_probs / num_choice_probs entries correspond
2246    // to variables with actual probabilities. Non-probabilistic PIR leaf nodes also get
2247    // CNF variables but must NOT be marked as random.
2248    let leaf_len = num_leaf_probs;
2249    let choice_len = num_choice_probs;
2250
2251    let mark_kernel = device
2252        .get_func(FILTER_MODULE, filter_kernels::MARK_RANDOM_VARS)
2253        .ok_or_else(|| XlogError::Kernel("mark_random_vars kernel not found".to_string()))?;
2254    let mark_n = leaf_len.max(choice_len);
2255    if mark_n > 0 {
2256        let grid = checked_launch_grid_u32("mark random vars", mark_n, block_size)?;
2257        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2258        unsafe {
2259            mark_kernel.clone().launch(
2260                LaunchConfig {
2261                    grid_dim: (grid, 1, 1),
2262                    block_dim: (block_size, 1, 1),
2263                    shared_mem_bytes: 0,
2264                },
2265                (
2266                    &vars.leaf_var,
2267                    &vars.choice_var,
2268                    leaf_len,
2269                    choice_len,
2270                    &mut mask,
2271                    mask_len,
2272                ),
2273            )
2274        }
2275        .map_err(|e| XlogError::Kernel(format!("mark_random_vars failed: {}", e)))?;
2276    }
2277
2278    let prefix_sum = provider.scan_u8_mask_device(&mask, mask_len)?;
2279    let count_device = capture_compact_count_device(provider, &prefix_sum, &mask, mask_len)?;
2280
2281    // Read the actual random var count from device (the GPU scan result is authoritative).
2282    // The host-side expected_count can be wrong when some ChoiceVarIds are unreachable
2283    // from query/evidence roots and don't get assigned CNF variables.
2284    let actual_count = {
2285        let mut buf = vec![0u32; 1];
2286        device
2287            .dtoh_sync_copy_into(&count_device, &mut buf)
2288            .map_err(|e| XlogError::Kernel(format!("dtoh count_device failed: {}", e)))?;
2289        buf[0]
2290    };
2291
2292    if actual_count == 0 {
2293        // No random variables in the circuit — return empty list.
2294        let out = provider.memory().alloc::<u32>(0)?;
2295        return Ok((out, 0));
2296    }
2297
2298    let mut out = memory.alloc::<u32>(mask_len_usize)?;
2299    let compact_fn = device
2300        .get_func(FILTER_MODULE, filter_kernels::COMPACT_U32_BY_MASK)
2301        .ok_or_else(|| XlogError::Kernel("compact_u32_by_mask kernel not found".to_string()))?;
2302    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
2303    unsafe {
2304        compact_fn.clone().launch(
2305            LaunchConfig {
2306                grid_dim: (grid, 1, 1),
2307                block_dim: (block_size, 1, 1),
2308                shared_mem_bytes: 0,
2309            },
2310            (&iota, &mask, &prefix_sum, mask_len, &mut out),
2311        )
2312    }
2313    .map_err(|e| XlogError::Kernel(format!("compact_u32_by_mask failed: {}", e)))?;
2314
2315    Ok((out, actual_count))
2316}
2317
2318fn display_atom(atom: &GroundAtom) -> String {
2319    if atom.args.is_empty() {
2320        format!("{}()", atom.predicate)
2321    } else {
2322        format!("{}({} args)", atom.predicate, atom.args.len())
2323    }
2324}
2325
// GPU-backed regression tests for exact inference. Each test takes the shared
// GPU test lock and silently skips when CUDA device 0 cannot be opened.
#[cfg(all(test, feature = "host-io"))]
mod tests {
    use super::*;
    use xlog_cuda::CudaDevice;

    /// Returns `true` when CUDA device 0 can be opened; otherwise logs the
    /// skip notice to stderr and returns `false`.
    fn cuda_ready() -> bool {
        let available = CudaDevice::new(0).is_ok();
        if !available {
            eprintln!("Skipping test: CUDA runtime unavailable");
        }
        available
    }

    /// Compiles and evaluates `source`, asserts that it yields exactly one
    /// query result, and returns that query's probability.
    fn single_query_prob(source: &str) -> f64 {
        let program = ExactDdnnfProgram::compile_source(source).unwrap();
        let outcome = program.evaluate().unwrap();

        assert_eq!(outcome.query_probs.len(), 1);
        outcome.query_probs[0].prob
    }

    #[test]
    fn test_exact_negation_probability() {
        let _gpu_guard = crate::test_gpu_lock::lock();
        if !cuda_ready() {
            return;
        }

        // A single negated probabilistic fact:
        //   P(dry) = P(not rain) = 1 - 0.3 = 0.7
        let dry_prob = single_query_prob(
            r#"
0.3::rain().
dry() :- not rain().
query(dry()).
"#,
        );

        let error = (dry_prob - 0.7).abs();
        assert!(error < 1e-6, "P(dry) should be 0.7, got {}", dry_prob);
    }

    #[test]
    fn test_exact_multi_layer_negation() {
        let _gpu_guard = crate::test_gpu_lock::lock();
        if !cuda_ready() {
            return;
        }

        // Two stacked negations:
        //   P(b) = P(not c) = 0.6
        //   P(a) = P(not b) = 0.4
        let a_prob = single_query_prob(
            r#"
0.4::c().
b() :- not c().
a() :- not b().
query(a()).
"#,
        );

        let error = (a_prob - 0.4).abs();
        assert!(error < 1e-6, "P(a) should be 0.4, got {}", a_prob);
    }

    #[test]
    fn test_eval_log_z_changes_for_sprinkler_given_wet() {
        let _gpu_guard = crate::test_gpu_lock::lock();
        if !cuda_ready() {
            return;
        }

        let program = ExactDdnnfProgram::compile_source(
            r#"
0.7::rain().
0.2::sprinkler().
wet() :- rain().
wet() :- sprinkler().
evidence(wet(), true).
query(rain()).
query(sprinkler()).
"#,
        )
        .unwrap();

        // log Z under the evidence alone, then with the second query
        // (sprinkler) additionally passed to the evaluation.
        let log_z_e = program.eval_log_z_gpu(None).unwrap();
        let sprinkler_var = program.query_var(1).unwrap();
        let log_z_eq = program.eval_log_z_gpu(Some(sprinkler_var)).unwrap();

        // Tolerate a poisoned mutex so an unrelated earlier panic does not
        // mask this test's own result.
        let state = program.gpu_state().unwrap();
        let mut cache = state
            .cache
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let (_, var_log_false) = cache.var_log_weights_mut();
        let slot = sprinkler_var as usize;

        // Snapshot the sprinkler variable's log-false weight before forcing.
        let mut baseline = [0.0f64];
        let baseline_view = var_log_false.slice(slot..(slot + 1));
        state
            .provider
            .device()
            .inner()
            .dtoh_sync_copy_into(&baseline_view, &mut baseline)
            .unwrap();

        // Force the variable; `saved` is handed back to the restore call below.
        let mut saved = state.provider.memory().alloc::<f64>(1).unwrap();
        force_query_var_false(state.provider(), var_log_false, sprinkler_var, &mut saved)
            .unwrap();

        // Read the same slot again while the forcing is in effect.
        let mut forced = [0.0f64];
        let forced_view = var_log_false.slice(slot..(slot + 1));
        state
            .provider
            .device()
            .inner()
            .dtoh_sync_copy_into(&forced_view, &mut forced)
            .unwrap();

        // Restore before asserting so the cached weights are put back even if
        // one of the checks below fails.
        restore_query_var_false(state.provider(), var_log_false, sprinkler_var, &saved).unwrap();

        let [log_false_before] = baseline;
        let [log_false_forced] = forced;

        assert!(
            log_false_before.is_finite(),
            "expected finite log_false before forcing"
        );
        assert!(
            log_false_forced.is_sign_negative() && log_false_forced.is_infinite(),
            "expected -inf log_false after forcing, got {}",
            log_false_forced
        );
        assert!(
            log_z_eq < log_z_e,
            "conditioning on sprinkler should reduce logZ (log_z_e={}, log_z_eq={})",
            log_z_e,
            log_z_eq
        );
    }
}