1use std::collections::{HashMap, HashSet};
5use std::sync::{Arc, Mutex};
6
7use cudarc::driver::LaunchConfig;
8use xlog_core::{MemoryBudget, Result, ScalarType, XlogError};
9use xlog_cuda::LaunchAsync;
10use xlog_logic::ast::Program;
11
12use crate::compilation::gpu_cache::{
13 GpuCircuitCache, GpuCircuitCacheConfig, GpuCircuitCacheHandle,
14};
15use crate::compilation::gpu_cnf::GpuCnfVarTables;
16#[cfg(feature = "host-io")]
17use crate::compilation::gpu_weights::map_nodes_to_vars_gpu;
18use crate::compilation::gpu_weights::{build_evidence_by_var_gpu, build_weights_gpu};
19use crate::compilation::{
20 compile_gpu_d4_and_verify_cached, encode_cnf_gpu, CircuitCompileProfile, DeviceRandomVarList,
21 GpuCompileConfig, GpuPirGraph, GpuPirRoots,
22};
23use crate::neural_fast_path::{GpuWeightSlots, NeuralFastPathConfig};
24use crate::provenance::{
25 extract_from_program, extract_from_source, AggregateLiftStatus, GroundAtom, Provenance, Value,
26};
27use xlog_cuda::memory::TrackedCudaSlice;
28use xlog_cuda::provider::{
29 arith_kernels, filter_kernels, neural_kernels, weights_kernels, ARITH_MODULE, FILTER_MODULE,
30 NEURAL_MODULE, WEIGHTS_MODULE,
31};
32use xlog_cuda::{CudaBuffer, CudaDevice, CudaKernelProvider, GpuMemoryManager};
33
/// Probability of one query atom, reported in both log space and linear space.
#[derive(Debug, Clone)]
pub struct QueryProbability {
    /// Ground atom the probability belongs to.
    pub atom: GroundAtom,
    /// Log-probability of the atom; `f64::NEG_INFINITY` is used for probability zero.
    pub log_prob: f64,
    /// Probability of the atom in linear space.
    pub prob: f64,
}
40
/// Outcome of an exact evaluation: the evidence log-weight plus one entry per query.
#[derive(Debug, Clone)]
pub struct ExactResult {
    /// Log of the evidence weight as computed by the evaluator. The evaluation
    /// paths in this file that do not run the circuit report `0.0` here.
    pub log_z_e: f64,
    /// One probability record per query, in query order.
    pub query_probs: Vec<QueryProbability>,
}
46
/// Probability of one query atom together with gradient vectors.
#[derive(Debug, Clone)]
pub struct QueryGradients {
    /// Ground atom the record belongs to.
    pub atom: GroundAtom,
    /// Log-probability of the atom.
    pub log_prob: f64,
    /// Probability of the atom in linear space.
    pub prob: f64,
    /// Gradient with respect to the "true" weights.
    /// NOTE(review): presumably indexed by CNF variable — confirm against the producer.
    pub grad_true: Vec<f64>,
    /// Gradient with respect to the "false" weights.
    /// NOTE(review): presumably indexed by CNF variable — confirm against the producer.
    pub grad_false: Vec<f64>,
}
55
/// Outcome of an exact evaluation that also carries per-query gradients.
#[derive(Debug, Clone)]
pub struct ExactResultWithGrads {
    /// Log of the evidence weight as computed by the evaluator.
    pub log_z_e: f64,
    /// One gradient record per query.
    pub query_grads: Vec<QueryGradients>,
}
61
/// Internal description of a query: its ground atom and, when available, its CNF variable.
#[derive(Debug, Clone)]
struct QuerySpec {
    /// Ground atom of the query; only read by the `host-io` evaluation paths.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    atom: GroundAtom,
    /// CNF variable of the query. `None` means the query has no variable; the
    /// host evaluation path then reports probability zero for it.
    var: Option<u32>,
}
68
69fn neural_slot_count_u32(slot_count: usize) -> Result<u32> {
70 u32::try_from(slot_count).map_err(|_| {
71 XlogError::Compilation(
72 "Neural fast-path group slot count exceeds GPU u32 index space".to_string(),
73 )
74 })
75}
76
77fn checked_launch_grid_u32(context: &str, item_count: u32, block_size: u32) -> Result<u32> {
78 if block_size == 0 {
79 return Err(XlogError::Kernel(format!(
80 "{context} launch block size must be non-zero"
81 )));
82 }
83 if item_count == 0 {
84 return Ok(0);
85 }
86 item_count
87 .checked_add(block_size - 1)
88 .map(|rounded| rounded / block_size)
89 .ok_or_else(|| XlogError::Kernel(format!("{context} launch grid overflow")))
90}
91
/// GPU-side state shared by the exact evaluation paths of a compiled program.
struct GpuExactState {
    /// Kernel provider used for device access and memory allocation.
    provider: Arc<CudaKernelProvider>,
    /// Circuit cache; the mutex serialises every evaluation that touches it.
    cache: Mutex<GpuCircuitCache>,
    /// Handle identifying this program's circuit inside `cache`.
    handle: GpuCircuitCacheHandle,
    /// Device copies of batched query-variable lists, keyed by their host contents,
    /// so that a repeated batch reuses the earlier upload.
    query_var_batch_cache: Mutex<HashMap<Vec<u32>, Arc<TrackedCudaSlice<u32>>>>,
}
101
/// GPU settings used when compiling an exact program.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct GpuConfig {
    /// Ordinal of the device to use (presumably the CUDA device index — confirm).
    pub device_ordinal: usize,
    /// Memory size in bytes (presumably the GPU memory budget — confirm against the compiler).
    pub memory_bytes: u64,
    /// Decision-order hint switch; its effect is defined by the compiler, not in this file.
    pub decision_order_hint: bool,
}

impl Default for GpuConfig {
    /// Device 0, 32 GiB of memory, decision-order hint disabled.
    fn default() -> Self {
        const GIB: u64 = 1 << 30;
        GpuConfig {
            device_ordinal: 0,
            memory_bytes: 32 * GIB,
            decision_order_hint: false,
        }
    }
}
129
130impl GpuExactState {
131 fn new(
132 provider: Arc<CudaKernelProvider>,
133 cache: GpuCircuitCache,
134 handle: GpuCircuitCacheHandle,
135 ) -> Result<Self> {
136 Ok(Self {
137 provider,
138 cache: Mutex::new(cache),
139 handle,
140 query_var_batch_cache: Mutex::new(HashMap::new()),
141 })
142 }
143
144 fn provider(&self) -> &Arc<CudaKernelProvider> {
145 &self.provider
146 }
147
148 fn handle(&self) -> &GpuCircuitCacheHandle {
149 &self.handle
150 }
151
152 fn cached_query_var_batch(
157 &self,
158 query_vars_host: Vec<u32>,
159 ) -> Result<Arc<TrackedCudaSlice<u32>>> {
160 let mut cache = self
161 .query_var_batch_cache
162 .lock()
163 .unwrap_or_else(|poisoned| poisoned.into_inner());
164 if let Some(cached) = cache.get(&query_vars_host) {
165 return Ok(Arc::clone(cached));
166 }
167 let mut query_vars = self.provider.memory().alloc::<u32>(query_vars_host.len())?;
168 self.provider
169 .htod_sync_copy_into_tracked(&query_vars_host, &mut query_vars)
170 .map_err(|e| {
171 XlogError::Kernel(format!("Failed to upload batched query vars: {}", e))
172 })?;
173 let query_vars = Arc::new(query_vars);
174 cache.insert(query_vars_host, Arc::clone(&query_vars));
175 Ok(query_vars)
176 }
177}
178
/// One query handled by the count-lift evaluator.
///
/// All numeric fields are passed unchanged to the `weights_count_lift_exact` kernel.
#[cfg_attr(not(feature = "host-io"), allow(dead_code))]
struct GpuCountLiftQuery {
    /// Ground atom reported back with the resulting probability.
    atom: GroundAtom,
    /// Target count handed to the kernel; the evaluator sizes its scratch
    /// buffer as `target_count + 1`.
    target_count: u32,
    /// Leaf count handed to the kernel (presumably the length of `leaf_probs` — confirm).
    leaf_count: u32,
    /// Device buffer of leaf probabilities read by the kernel.
    leaf_probs: TrackedCudaSlice<f64>,
}
186
/// State of the count-lift evaluator: a kernel provider plus the queries to answer.
#[cfg_attr(not(feature = "host-io"), allow(dead_code))]
struct GpuCountLiftState {
    /// Kernel provider used for kernel lookup, allocation and device-to-host copies.
    provider: Arc<CudaKernelProvider>,
    /// Queries evaluated, in order, by `evaluate`.
    queries: Vec<GpuCountLiftQuery>,
}
192
193impl GpuCountLiftState {
194 fn new(provider: Arc<CudaKernelProvider>, queries: Vec<GpuCountLiftQuery>) -> Self {
195 Self { provider, queries }
196 }
197
198 #[cfg(feature = "host-io")]
199 fn evaluate(&self) -> Result<ExactResult> {
200 let func = self
201 .provider
202 .device()
203 .inner()
204 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_COUNT_LIFT_EXACT)
205 .ok_or_else(|| {
206 XlogError::Kernel("weights_count_lift_exact kernel not found".to_string())
207 })?;
208 let mut query_probs = Vec::with_capacity(self.queries.len());
209 for query in &self.queries {
210 let scratch_len = query
211 .target_count
212 .checked_add(1)
213 .ok_or_else(|| XlogError::Compilation("count-lift target overflow".to_string()))?;
214 let mut scratch = self.provider.memory().alloc::<f64>(scratch_len as usize)?;
215 let mut out = self.provider.memory().alloc::<f64>(1)?;
216 unsafe {
217 func.clone().launch(
218 LaunchConfig {
219 grid_dim: (1, 1, 1),
220 block_dim: (1, 1, 1),
221 shared_mem_bytes: 0,
222 },
223 (
224 &query.leaf_probs,
225 query.leaf_count,
226 query.target_count,
227 &mut scratch,
228 &mut out,
229 ),
230 )
231 }
232 .map_err(|e| XlogError::Kernel(format!("weights_count_lift_exact failed: {}", e)))?;
233 let mut host = vec![0.0f64; 1];
234 self.provider
235 .device()
236 .inner()
237 .dtoh_sync_copy_into(&out, &mut host)
238 .map_err(|e| XlogError::Kernel(format!("count-lift result dtoh failed: {}", e)))?;
239 let mut prob = host[0];
240 if (-1e-12..0.0).contains(&prob) || prob == -1e-12 {
241 prob = 0.0;
242 } else if prob > 1.0 && (1.0..=1.0 + 1e-12).contains(&prob) {
243 prob = 1.0;
244 }
245 if !prob.is_finite() || !(0.0..=1.0).contains(&prob) {
246 return Err(XlogError::Kernel(format!(
247 "count-lift GPU evaluator returned invalid probability {}",
248 prob
249 )));
250 }
251 let log_prob = if prob == 0.0 {
252 f64::NEG_INFINITY
253 } else {
254 prob.ln()
255 };
256 query_probs.push(QueryProbability {
257 atom: query.atom.clone(),
258 log_prob,
259 prob,
260 });
261 }
262 Ok(ExactResult {
263 log_z_e: 0.0,
264 query_probs,
265 })
266 }
267}
268
/// A compiled exact-inference program together with its optional GPU backends.
#[derive(Clone)]
pub struct ExactDdnnfProgram {
    /// Circuit-based GPU state; `None` when the program has no compiled circuit.
    gpu: Option<Arc<GpuExactState>>,
    /// Count-lift evaluator; when present, `evaluate` delegates to it.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    count_lift_gpu: Option<Arc<GpuCountLiftState>>,
    /// Queries of the program, addressed by index in the neural fast-path API.
    queries: Vec<QuerySpec>,
    /// Device-resident list of random variables, read back by `random_var_indices`.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    random_vars: Option<Arc<DeviceRandomVarList>>,
    /// Largest CNF variable; query variables must lie in `1..=max_var`.
    max_var: u32,
    /// Whether the program was compiled from source text or from a parsed program.
    #[cfg_attr(not(feature = "host-io"), allow(dead_code))]
    origin: ExactProgramOrigin,
    /// GPU configuration the program was compiled with.
    #[allow(dead_code)]
    gpu_config: GpuConfig,
    /// Profile of the most recent compilation, if one was recorded.
    last_compile_profile: Option<CircuitCompileProfile>,
}
285
/// Records which entry point produced an `ExactDdnnfProgram`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ExactProgramOrigin {
    /// Compiled from source text (`compile_source*`).
    Source,
    /// Compiled from an already parsed `Program` (`compile_from_program`).
    Program,
}
291
292impl ExactDdnnfProgram {
293 pub fn compile_source(source: &str) -> Result<Self> {
294 let provenance = extract_from_source(source)?;
295 Self::compile_provenance_with_gpu(
296 provenance,
297 GpuConfig::default(),
298 ExactProgramOrigin::Source,
299 )
300 }
301
302 pub fn compile_source_with_gpu(source: &str, config: GpuConfig) -> Result<Self> {
303 let provenance = extract_from_source(source)?;
304 Self::compile_provenance_with_gpu(provenance, config, ExactProgramOrigin::Source)
305 }
306
307 pub fn compile_from_program(program: &Program, config: GpuConfig) -> Result<Self> {
308 let provenance = extract_from_program(program)?;
309 Self::compile_provenance_with_gpu(provenance, config, ExactProgramOrigin::Program)
310 }
311
    /// Returns a copy of the GPU configuration stored in this program.
    #[allow(dead_code)]
    pub(crate) fn gpu_config(&self) -> GpuConfig {
        self.gpu_config
    }
316
    /// Returns which entry point (source text or parsed program) produced this program.
    #[cfg(feature = "host-io")]
    pub(crate) fn origin(&self) -> ExactProgramOrigin {
        self.origin
    }
321
    /// Returns `true` when the program holds circuit-based GPU state.
    pub fn uses_gpu_production_backend(&self) -> bool {
        self.gpu.is_some()
    }
325
    /// Returns the stored compile profile, or `None` when none was recorded.
    pub fn last_compile_profile(&self) -> Option<&CircuitCompileProfile> {
        self.last_compile_profile.as_ref()
    }
330
    /// Returns `true` when the program holds a count-lift evaluator.
    #[doc(hidden)]
    #[cfg(feature = "host-io")]
    pub fn uses_gpu_native_count_lift(&self) -> bool {
        self.count_lift_gpu.is_some()
    }
336
337 #[cfg(feature = "host-io")]
338 pub fn evaluate(&self) -> Result<ExactResult> {
339 if let Some(count_lift_gpu) = &self.count_lift_gpu {
340 return count_lift_gpu.evaluate();
341 }
342
343 if self.gpu.is_none() {
349 let mut query_probs: Vec<QueryProbability> = Vec::with_capacity(self.queries.len());
350 for query in &self.queries {
351 query_probs.push(QueryProbability {
352 atom: query.atom.clone(),
353 log_prob: f64::NEG_INFINITY,
354 prob: 0.0,
355 });
356 }
357 return Ok(ExactResult {
358 log_z_e: 0.0,
359 query_probs,
360 });
361 }
362
363 let log_z_e = self.eval_log_z_gpu(None)?;
364 if log_z_e.is_infinite() && log_z_e.is_sign_negative() {
365 return Err(XlogError::Execution(
366 "Exact inference error: evidence is inconsistent (P(E)=0)".to_string(),
367 ));
368 }
369
370 let mut query_probs: Vec<QueryProbability> = Vec::with_capacity(self.queries.len());
371 for query in &self.queries {
372 let (log_prob, prob) = match query.var {
373 None => (f64::NEG_INFINITY, 0.0),
374 Some(var) => {
375 let log_z_eq = self.eval_log_z_gpu(Some(var))?;
376 let log_prob = log_z_eq - log_z_e;
377 let mut prob = if log_prob.is_infinite() && log_prob.is_sign_negative() {
378 0.0
379 } else {
380 log_prob.exp()
381 };
382 if prob.is_nan() {
383 return Err(XlogError::Execution(
384 "Exact inference error: NaN probability encountered".to_string(),
385 ));
386 }
387 prob = prob.clamp(0.0, 1.0);
388 (log_prob, prob)
389 }
390 };
391
392 query_probs.push(QueryProbability {
393 atom: query.atom.clone(),
394 log_prob,
395 prob,
396 });
397 }
398
399 Ok(ExactResult {
400 log_z_e,
401 query_probs,
402 })
403 }
404
405 pub fn num_vars(&self) -> usize {
406 if self.max_var == 0 {
407 0
408 } else {
409 (self.max_var as usize) + 1
410 }
411 }
412
413 #[cfg(feature = "host-io")]
419 pub fn random_var_indices(&self) -> Vec<u32> {
420 let Some(state) = self.gpu.as_ref() else {
421 return Vec::new();
422 };
423 let Some(random_vars) = self.random_vars.as_ref() else {
424 return Vec::new();
425 };
426 if random_vars.is_empty() {
427 return Vec::new();
428 }
429 let count = random_vars.count() as usize;
430 let mut host = vec![0u32; count];
431 let view = random_vars.list().slice(0..count);
432 if let Err(e) = state
433 .provider()
434 .device()
435 .inner()
436 .dtoh_sync_copy_into(&view, &mut host)
437 {
438 eprintln!("Failed to read random var list: {}", e);
439 return Vec::new();
440 }
441 host
442 }
443
444 pub(crate) fn query_var(&self, idx: usize) -> Option<u32> {
446 self.queries.get(idx).and_then(|q| q.var)
447 }
448
    /// Runs the neural fast-path backward pass for a single query without
    /// producing a device-side loss.
    ///
    /// Delegates to the shared implementation with no loss output and
    /// `expected_true = true`. `probs` and `out_grads` hold one buffer per slot
    /// group; gradients are written into `out_grads`.
    ///
    /// # Errors
    /// Propagates every validation and kernel error of the shared implementation.
    pub fn neural_backward_nll_buffers(
        &self,
        slots: &GpuWeightSlots,
        query_idx: usize,
        probs: &[CudaBuffer],
        out_grads: &mut [CudaBuffer],
        cfg: NeuralFastPathConfig,
    ) -> Result<()> {
        self.neural_backward_nll_buffers_inner(slots, query_idx, probs, out_grads, cfg, None, true)
    }
470
471 pub fn neural_backward_nll_buffers_with_device_loss(
476 &self,
477 slots: &GpuWeightSlots,
478 query_idx: usize,
479 probs: &[CudaBuffer],
480 out_grads: &mut [CudaBuffer],
481 cfg: NeuralFastPathConfig,
482 expected_true: bool,
483 ) -> Result<TrackedCudaSlice<f64>> {
484 let state = self.gpu_state()?;
485 let mut loss = state.provider.memory().alloc::<f64>(1)?;
486 self.neural_backward_nll_buffers_inner(
487 slots,
488 query_idx,
489 probs,
490 out_grads,
491 cfg,
492 Some(&mut loss),
493 expected_true,
494 )?;
495 Ok(loss)
496 }
497
498 pub fn neural_backward_nll_buffers_batch_with_device_loss(
506 &self,
507 slots: &GpuWeightSlots,
508 query_indices: &[usize],
509 probs_batch: &[Vec<CudaBuffer>],
510 out_grads_batch: &mut [Vec<CudaBuffer>],
511 cfg: NeuralFastPathConfig,
512 expected_true: bool,
513 ) -> Result<TrackedCudaSlice<f64>> {
514 let batch = query_indices.len();
515 if batch == 0 {
516 return Err(XlogError::Execution(
517 "Neural fast-path batch error: empty query batch".to_string(),
518 ));
519 }
520 if probs_batch.len() != batch || out_grads_batch.len() != batch {
521 return Err(XlogError::Compilation(format!(
522 "Neural fast-path batch error: query/prob/grad batch mismatch ({}/{}/{})",
523 batch,
524 probs_batch.len(),
525 out_grads_batch.len()
526 )));
527 }
528
529 let state = self.gpu_state()?;
530 let batch_u32 = u32::try_from(batch).map_err(|_| {
531 XlogError::Compilation("Neural fast-path batch size exceeds u32".to_string())
532 })?;
533 let device = state.provider.device().inner();
534
535 {
537 let cache = state
538 .cache
539 .lock()
540 .unwrap_or_else(|poisoned| poisoned.into_inner());
541 if cache.has_any_free_var_mask() {
542 drop(cache);
543 let mut losses = state.provider.memory().alloc::<f64>(batch)?;
544 for q in 0..batch {
545 let loss_q = self.neural_backward_nll_buffers_with_device_loss(
546 slots,
547 query_indices[q],
548 &probs_batch[q],
549 &mut out_grads_batch[q],
550 cfg,
551 expected_true,
552 )?;
553 let mut dst = losses.slice_mut(q..(q + 1));
554 device.dtod_copy(&loss_q, &mut dst).map_err(|e| {
555 XlogError::Kernel(format!(
556 "Failed to copy fallback batch loss to output: {}",
557 e
558 ))
559 })?;
560 }
561 return Ok(losses);
562 }
563 }
564
565 let fill = device
566 .get_func(NEURAL_MODULE, neural_kernels::NEURAL_FILL_AD_CHAIN_F32)
567 .ok_or_else(|| {
568 XlogError::Kernel("neural_fill_ad_chain_f32 kernel not found".to_string())
569 })?;
570 let scatter = device
571 .get_func(
572 NEURAL_MODULE,
573 neural_kernels::NEURAL_SCATTER_AD_CHAIN_GRADS_F32,
574 )
575 .ok_or_else(|| {
576 XlogError::Kernel("neural_scatter_ad_chain_grads_f32 kernel not found".to_string())
577 })?;
578 let binary_f64 = device
579 .get_func(ARITH_MODULE, arith_kernels::ARITH_BINARY_F64)
580 .ok_or_else(|| XlogError::Kernel("arith_binary_f64 kernel not found".to_string()))?;
581 let apply_query_false_batched = device
582 .get_func(
583 WEIGHTS_MODULE,
584 weights_kernels::WEIGHTS_APPLY_QUERY_VARS_FALSE_BATCHED,
585 )
586 .ok_or_else(|| {
587 XlogError::Kernel(
588 "weights_apply_query_vars_false_batched kernel not found".to_string(),
589 )
590 })?;
591 let apply_query_true_batched = device
592 .get_func(
593 WEIGHTS_MODULE,
594 weights_kernels::WEIGHTS_APPLY_QUERY_VARS_TRUE_BATCHED,
595 )
596 .ok_or_else(|| {
597 XlogError::Kernel(
598 "weights_apply_query_vars_true_batched kernel not found".to_string(),
599 )
600 })?;
601
602 let mut cache = state
603 .cache
604 .lock()
605 .unwrap_or_else(|poisoned| poisoned.into_inner());
606 let var_stride = cache.var_stride()?;
607 let var_stride_usize = var_stride as usize;
608 let node_stride = cache.node_stride();
609 let node_stride_usize = node_stride as usize;
610
611 let mut var_log_true_batch = state
612 .provider
613 .memory()
614 .alloc::<f64>(batch * var_stride_usize)?;
615 let mut var_log_false_batch = state
616 .provider
617 .memory()
618 .alloc::<f64>(batch * var_stride_usize)?;
619 cache.copy_slot_weights_to_batch(
620 state.handle(),
621 &mut var_log_true_batch,
622 &mut var_log_false_batch,
623 batch_u32,
624 )?;
625
626 let mut values_batch = state
627 .provider
628 .memory()
629 .alloc::<f64>(batch * node_stride_usize)?;
630 let mut adj_batch = state
631 .provider
632 .memory()
633 .alloc::<f64>(batch * node_stride_usize)?;
634 let mut grad_true_batch = state
635 .provider
636 .memory()
637 .alloc::<f64>(batch * var_stride_usize)?;
638 let mut grad_false_batch = state
639 .provider
640 .memory()
641 .alloc::<f64>(batch * var_stride_usize)?;
642 let mut base_roots = state.provider.memory().alloc::<f64>(batch)?;
643 let mut query_roots = state.provider.memory().alloc::<f64>(batch)?;
644 let mut losses = state.provider.memory().alloc::<f64>(batch)?;
645 let mut force_saved = state.provider.memory().alloc::<f64>(batch)?;
646
647 let mut query_vars_host: Vec<u32> = Vec::with_capacity(batch);
648
649 for q in 0..batch {
651 if probs_batch[q].len() != out_grads_batch[q].len() {
652 return Err(XlogError::Compilation(format!(
653 "Neural fast-path batch error: probs len {} != out_grads len {} for query {}",
654 probs_batch[q].len(),
655 out_grads_batch[q].len(),
656 q
657 )));
658 }
659 if probs_batch[q].len() != slots.num_groups_usize() {
660 return Err(XlogError::Compilation(format!(
661 "Neural fast-path batch error: expected {} groups, got {} for query {}",
662 slots.num_groups_usize(),
663 probs_batch[q].len(),
664 q
665 )));
666 }
667
668 let query_var = self.query_var(query_indices[q]).ok_or_else(|| {
669 XlogError::Execution(format!(
670 "Neural fast-path batch error: query {} has no CNF var",
671 query_indices[q]
672 ))
673 })?;
674 if query_var == 0 || query_var > self.max_var {
675 return Err(XlogError::Compilation(format!(
676 "Neural fast-path batch error: query var {} out of bounds (max_var={})",
677 query_var, self.max_var
678 )));
679 }
680 query_vars_host.push(query_var);
681
682 let row_start = q
683 .checked_mul(var_stride_usize)
684 .ok_or_else(|| XlogError::Compilation("Neural batch row overflow".to_string()))?;
685 let row_end = row_start + var_stride_usize;
686
687 for (g, prob_buf) in probs_batch[q].iter().enumerate() {
688 if prob_buf.arity() != 1 {
689 return Err(XlogError::Compilation(
690 "Neural fast-path expects 1-column prob buffers".to_string(),
691 ));
692 }
693 let ty = prob_buf.schema().column_type(0).ok_or_else(|| {
694 XlogError::Compilation("Missing prob buffer schema".to_string())
695 })?;
696 if ty != ScalarType::F32 {
697 return Err(XlogError::Compilation(format!(
698 "Neural fast-path expects prob dtype F32, got {:?}",
699 ty
700 )));
701 }
702
703 let slot_vars = slots.group_slot_cnf_var(g)?;
704 let labels = neural_slot_count_u32(slot_vars.len())?;
705 if prob_buf.num_rows() != labels as u64 {
706 return Err(XlogError::Compilation(format!(
707 "Neural fast-path prob rows {} != labels {}",
708 prob_buf.num_rows(),
709 labels
710 )));
711 }
712 if out_grads_batch[q][g].num_rows() != labels as u64 {
713 return Err(XlogError::Compilation(format!(
714 "Neural fast-path grad rows {} != labels {}",
715 out_grads_batch[q][g].num_rows(),
716 labels
717 )));
718 }
719
720 let prob_col = prob_buf.column(0).ok_or_else(|| {
721 XlogError::Compilation("Neural fast-path missing prob column".to_string())
722 })?;
723 let mut q_true = var_log_true_batch.slice_mut(row_start..row_end);
724 let mut q_false = var_log_false_batch.slice_mut(row_start..row_end);
725
726 unsafe {
728 fill.clone().launch(
729 LaunchConfig {
730 grid_dim: (1, 1, 1),
731 block_dim: (1, 1, 1),
732 shared_mem_bytes: 0,
733 },
734 (
735 prob_col,
736 labels,
737 &slot_vars,
738 cfg.eps,
739 cfg.min_p,
740 &mut q_true,
741 &mut q_false,
742 ),
743 )
744 }
745 .map_err(|e| {
746 XlogError::Kernel(format!("neural_fill_ad_chain_f32 failed: {}", e))
747 })?;
748 }
749 }
750
751 cache.eval_grads_inplace_fused_batched(
753 state.handle(),
754 &var_log_true_batch,
755 &var_log_false_batch,
756 &mut values_batch,
757 &mut adj_batch,
758 &mut grad_true_batch,
759 &mut grad_false_batch,
760 batch_u32,
761 )?;
762 cache.copy_root_batched_from_values(
763 state.handle(),
764 &values_batch,
765 &mut base_roots,
766 batch_u32,
767 )?;
768
769 for q in 0..batch {
771 let row_start = q
772 .checked_mul(var_stride_usize)
773 .ok_or_else(|| XlogError::Compilation("Neural batch row overflow".to_string()))?;
774 let row_end = row_start + var_stride_usize;
775 let q_grad_true = grad_true_batch.slice(row_start..row_end);
776 let q_grad_false = grad_false_batch.slice(row_start..row_end);
777
778 for (g, prob_buf) in probs_batch[q].iter().enumerate() {
779 let slot_vars = slots.group_slot_cnf_var(g)?;
780 let labels = neural_slot_count_u32(slot_vars.len())?;
781 let prob_col = prob_buf.column(0).ok_or_else(|| {
782 XlogError::Compilation("Neural fast-path missing prob column".to_string())
783 })?;
784 let out_col = out_grads_batch[q][g]
785 .columns
786 .get_mut(0)
787 .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
788
789 let shared_bytes: u32 = 3u64
790 .checked_mul(labels as u64)
791 .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
792 .and_then(|n| u32::try_from(n).ok())
793 .ok_or_else(|| {
794 XlogError::Kernel("Neural scatter shared memory overflow".to_string())
795 })?;
796
797 unsafe {
799 scatter.clone().launch(
800 LaunchConfig {
801 grid_dim: (1, 1, 1),
802 block_dim: (1, 1, 1),
803 shared_mem_bytes: shared_bytes,
804 },
805 (
806 prob_col,
807 labels,
808 &slot_vars,
809 cfg.eps,
810 cfg.min_p,
811 &q_grad_true,
812 &q_grad_false,
813 0u8,
814 out_col,
815 ),
816 )
817 }
818 .map_err(|e| XlogError::Kernel(format!("neural_scatter (base) failed: {}", e)))?;
819 }
820 }
821
822 let query_vars = state.cached_query_var_batch(query_vars_host)?;
825 let force_grid = checked_launch_grid_u32("gpu exact batched query force", batch_u32, 256)?;
826 if force_grid != 0 {
827 if expected_true {
828 unsafe {
830 apply_query_false_batched.clone().launch(
831 LaunchConfig {
832 grid_dim: (force_grid, 1, 1),
833 block_dim: (256, 1, 1),
834 shared_mem_bytes: 0,
835 },
836 (
837 query_vars.as_ref(),
838 batch_u32,
839 self.max_var,
840 var_stride,
841 &mut var_log_false_batch,
842 &mut force_saved,
843 ),
844 )
845 }
846 .map_err(|e| {
847 XlogError::Kernel(format!(
848 "weights_apply_query_vars_false_batched failed: {}",
849 e
850 ))
851 })?;
852 } else {
853 unsafe {
855 apply_query_true_batched.clone().launch(
856 LaunchConfig {
857 grid_dim: (force_grid, 1, 1),
858 block_dim: (256, 1, 1),
859 shared_mem_bytes: 0,
860 },
861 (
862 query_vars.as_ref(),
863 batch_u32,
864 self.max_var,
865 var_stride,
866 &mut var_log_true_batch,
867 &mut force_saved,
868 ),
869 )
870 }
871 .map_err(|e| {
872 XlogError::Kernel(format!(
873 "weights_apply_query_vars_true_batched failed: {}",
874 e
875 ))
876 })?;
877 }
878 }
879
880 cache.eval_grads_inplace_fused_batched(
882 state.handle(),
883 &var_log_true_batch,
884 &var_log_false_batch,
885 &mut values_batch,
886 &mut adj_batch,
887 &mut grad_true_batch,
888 &mut grad_false_batch,
889 batch_u32,
890 )?;
891 cache.copy_root_batched_from_values(
892 state.handle(),
893 &values_batch,
894 &mut query_roots,
895 batch_u32,
896 )?;
897
898 let loss_grid = checked_launch_grid_u32("gpu exact batched query loss", batch_u32, 256)?;
899 if loss_grid != 0 {
900 unsafe {
902 binary_f64.clone().launch(
903 LaunchConfig {
904 grid_dim: (loss_grid, 1, 1),
905 block_dim: (256, 1, 1),
906 shared_mem_bytes: 0,
907 },
908 (&base_roots, &query_roots, batch_u32, 1u8, &mut losses),
909 )
910 }
911 .map_err(|e| XlogError::Kernel(format!("Failed to compute batched NLL loss: {}", e)))?;
912 }
913
914 for q in 0..batch {
916 let row_start = q
917 .checked_mul(var_stride_usize)
918 .ok_or_else(|| XlogError::Compilation("Neural batch row overflow".to_string()))?;
919 let row_end = row_start + var_stride_usize;
920 let q_grad_true = grad_true_batch.slice(row_start..row_end);
921 let q_grad_false = grad_false_batch.slice(row_start..row_end);
922
923 for (g, prob_buf) in probs_batch[q].iter().enumerate() {
924 let slot_vars = slots.group_slot_cnf_var(g)?;
925 let labels = neural_slot_count_u32(slot_vars.len())?;
926 let prob_col = prob_buf.column(0).ok_or_else(|| {
927 XlogError::Compilation("Neural fast-path missing prob column".to_string())
928 })?;
929 let out_col = out_grads_batch[q][g]
930 .columns
931 .get_mut(0)
932 .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
933
934 let shared_bytes: u32 = 3u64
935 .checked_mul(labels as u64)
936 .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
937 .and_then(|n| u32::try_from(n).ok())
938 .ok_or_else(|| {
939 XlogError::Kernel("Neural scatter shared memory overflow".to_string())
940 })?;
941
942 unsafe {
944 scatter.clone().launch(
945 LaunchConfig {
946 grid_dim: (1, 1, 1),
947 block_dim: (1, 1, 1),
948 shared_mem_bytes: shared_bytes,
949 },
950 (
951 prob_col,
952 labels,
953 &slot_vars,
954 cfg.eps,
955 cfg.min_p,
956 &q_grad_true,
957 &q_grad_false,
958 1u8,
959 out_col,
960 ),
961 )
962 }
963 .map_err(|e| XlogError::Kernel(format!("neural_scatter (query) failed: {}", e)))?;
964 }
965 }
966
967 Ok(losses)
968 }
969
970 #[allow(clippy::too_many_arguments)]
971 fn neural_backward_nll_buffers_inner(
972 &self,
973 slots: &GpuWeightSlots,
974 query_idx: usize,
975 probs: &[CudaBuffer],
976 out_grads: &mut [CudaBuffer],
977 cfg: NeuralFastPathConfig,
978 out_loss: Option<&mut TrackedCudaSlice<f64>>,
979 expected_true: bool,
980 ) -> Result<()> {
981 if self.gpu.is_none() {
982 return Err(XlogError::Execution(
983 "Neural fast-path error: program has no compiled circuit".to_string(),
984 ));
985 }
986
987 let query_var = self.query_var(query_idx).ok_or_else(|| {
988 XlogError::Execution(format!(
989 "Neural fast-path error: query {} has no CNF var",
990 query_idx
991 ))
992 })?;
993
994 if probs.len() != out_grads.len() {
995 return Err(XlogError::Compilation(format!(
996 "Neural fast-path error: probs len {} != out_grads len {}",
997 probs.len(),
998 out_grads.len()
999 )));
1000 }
1001 if probs.len() != slots.num_groups_usize() {
1002 return Err(XlogError::Compilation(format!(
1003 "Neural fast-path error: expected {} groups, got {}",
1004 slots.num_groups_usize(),
1005 probs.len()
1006 )));
1007 }
1008
1009 let state = self.gpu_state()?;
1010 let device = state.provider.device().inner();
1011
1012 let fill = device
1013 .get_func(NEURAL_MODULE, neural_kernels::NEURAL_FILL_AD_CHAIN_F32)
1014 .ok_or_else(|| {
1015 XlogError::Kernel("neural_fill_ad_chain_f32 kernel not found".to_string())
1016 })?;
1017 let scatter = device
1018 .get_func(
1019 NEURAL_MODULE,
1020 neural_kernels::NEURAL_SCATTER_AD_CHAIN_GRADS_F32,
1021 )
1022 .ok_or_else(|| {
1023 XlogError::Kernel("neural_scatter_ad_chain_grads_f32 kernel not found".to_string())
1024 })?;
1025 let binary_f64 = device
1026 .get_func(ARITH_MODULE, arith_kernels::ARITH_BINARY_F64)
1027 .ok_or_else(|| XlogError::Kernel("arith_binary_f64 kernel not found".to_string()))?;
1028
1029 let mut cache = state
1030 .cache
1031 .lock()
1032 .unwrap_or_else(|poisoned| poisoned.into_inner());
1033
1034 let root_idx = state.handle().root() as usize;
1035
1036 let mut base_log_z: Option<TrackedCudaSlice<f64>> = if out_loss.is_some() {
1039 Some(state.provider.memory().alloc::<f64>(1)?)
1040 } else {
1041 None
1042 };
1043
1044 for (g, prob_buf) in probs.iter().enumerate() {
1046 if prob_buf.arity() != 1 {
1047 return Err(XlogError::Compilation(
1048 "Neural fast-path expects 1-column prob buffers".to_string(),
1049 ));
1050 }
1051 let ty = prob_buf
1052 .schema()
1053 .column_type(0)
1054 .ok_or_else(|| XlogError::Compilation("Missing prob buffer schema".to_string()))?;
1055 if ty != ScalarType::F32 {
1056 return Err(XlogError::Compilation(format!(
1057 "Neural fast-path expects prob dtype F32, got {:?}",
1058 ty
1059 )));
1060 }
1061
1062 let slot_vars = slots.group_slot_cnf_var(g)?;
1063 let labels = neural_slot_count_u32(slot_vars.len())?;
1064
1065 if prob_buf.num_rows() != labels as u64 {
1066 return Err(XlogError::Compilation(format!(
1067 "Neural fast-path prob rows {} != labels {}",
1068 prob_buf.num_rows(),
1069 labels
1070 )));
1071 }
1072
1073 let prob_col = prob_buf.column(0).ok_or_else(|| {
1074 XlogError::Compilation("Neural fast-path missing prob column".to_string())
1075 })?;
1076
1077 let (var_log_true, var_log_false) = cache.var_log_weights_mut();
1078
1079 unsafe {
1081 fill.clone().launch(
1082 LaunchConfig {
1083 grid_dim: (1, 1, 1),
1084 block_dim: (1, 1, 1),
1085 shared_mem_bytes: 0,
1086 },
1087 (
1088 prob_col,
1089 labels,
1090 &slot_vars,
1091 cfg.eps,
1092 cfg.min_p,
1093 var_log_true,
1094 var_log_false,
1095 ),
1096 )
1097 }
1098 .map_err(|e| XlogError::Kernel(format!("neural_fill_ad_chain_f32 failed: {}", e)))?;
1099 }
1100
1101 cache.eval_grads_inplace_fused(state.handle())?;
1103 if let Some(base) = base_log_z.as_mut() {
1104 let root_view = cache.values().slice(root_idx..(root_idx + 1));
1105 device.dtod_copy(&root_view, base).map_err(|e| {
1106 XlogError::Kernel(format!("Failed to copy base logZ on GPU: {}", e))
1107 })?;
1108 }
1109 for (g, prob_buf) in probs.iter().enumerate() {
1110 let slot_vars = slots.group_slot_cnf_var(g)?;
1111 let labels = neural_slot_count_u32(slot_vars.len())?;
1112
1113 let out_buf = out_grads.get_mut(g).ok_or_else(|| {
1114 XlogError::Compilation("Neural fast-path missing output grad buffer".to_string())
1115 })?;
1116 if out_buf.arity() != 1 {
1117 return Err(XlogError::Compilation(
1118 "Neural fast-path expects 1-column grad buffers".to_string(),
1119 ));
1120 }
1121 let out_ty = out_buf
1122 .schema()
1123 .column_type(0)
1124 .ok_or_else(|| XlogError::Compilation("Missing grad buffer schema".to_string()))?;
1125 if out_ty != ScalarType::F32 {
1126 return Err(XlogError::Compilation(format!(
1127 "Neural fast-path expects grad dtype F32, got {:?}",
1128 out_ty
1129 )));
1130 }
1131 if out_buf.num_rows() != labels as u64 {
1132 return Err(XlogError::Compilation(format!(
1133 "Neural fast-path grad rows {} != labels {}",
1134 out_buf.num_rows(),
1135 labels
1136 )));
1137 }
1138
1139 let prob_col = prob_buf.column(0).ok_or_else(|| {
1140 XlogError::Compilation("Neural fast-path missing prob column".to_string())
1141 })?;
1142 let out_col = out_buf
1143 .columns
1144 .get_mut(0)
1145 .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
1146
1147 let shared_bytes: u32 = 3u64
1148 .checked_mul(labels as u64)
1149 .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
1150 .and_then(|n| u32::try_from(n).ok())
1151 .ok_or_else(|| {
1152 XlogError::Kernel("Neural scatter shared memory overflow".to_string())
1153 })?;
1154
1155 unsafe {
1157 scatter.clone().launch(
1158 LaunchConfig {
1159 grid_dim: (1, 1, 1),
1160 block_dim: (1, 1, 1),
1161 shared_mem_bytes: shared_bytes,
1162 },
1163 (
1164 prob_col,
1165 labels,
1166 &slot_vars,
1167 cfg.eps,
1168 cfg.min_p,
1169 cache.grad_true(),
1170 cache.grad_false(),
1171 0u8,
1172 out_col,
1173 ),
1174 )
1175 }
1176 .map_err(|e| XlogError::Kernel(format!("neural_scatter (base) failed: {}", e)))?;
1177 }
1178
1179 if query_var == 0 || query_var > self.max_var {
1181 return Err(XlogError::Compilation(format!(
1182 "Neural fast-path error: query var {} out of bounds (max_var={})",
1183 query_var, self.max_var
1184 )));
1185 }
1186
1187 let mut restore = state.provider.memory().alloc::<f64>(1)?;
1188 if expected_true {
1189 {
1190 let (_, var_log_false) = cache.var_log_weights_mut();
1191 force_query_var_false(state.provider(), var_log_false, query_var, &mut restore)?;
1192 }
1193 } else {
1194 {
1195 let (var_log_true, _) = cache.var_log_weights_mut();
1196 force_query_var_true(state.provider(), var_log_true, query_var, &mut restore)?;
1197 }
1198 }
1199
1200 cache.eval_grads_inplace_fused(state.handle())?;
1201 if let Some(out) = out_loss {
1202 let base = base_log_z
1203 .as_ref()
1204 .expect("base_log_z allocated when out_loss requested");
1205 let root_view = cache.values().slice(root_idx..(root_idx + 1));
1206 unsafe {
1208 binary_f64.clone().launch(
1209 LaunchConfig {
1210 grid_dim: (1, 1, 1),
1211 block_dim: (1, 1, 1),
1212 shared_mem_bytes: 0,
1213 },
1214 (base, &root_view, 1u32, 1u8, out),
1215 )
1216 }
1217 .map_err(|e| XlogError::Kernel(format!("Failed to compute NLL loss on GPU: {}", e)))?;
1218 }
1219 for (g, prob_buf) in probs.iter().enumerate() {
1220 let slot_vars = slots.group_slot_cnf_var(g)?;
1221 let labels = neural_slot_count_u32(slot_vars.len())?;
1222
1223 let prob_col = prob_buf.column(0).ok_or_else(|| {
1224 XlogError::Compilation("Neural fast-path missing prob column".to_string())
1225 })?;
1226 let out_col = out_grads[g]
1227 .columns
1228 .get_mut(0)
1229 .ok_or_else(|| XlogError::Compilation("Missing grad column".to_string()))?;
1230
1231 let shared_bytes: u32 = 3u64
1232 .checked_mul(labels as u64)
1233 .and_then(|n| n.checked_mul(std::mem::size_of::<f64>() as u64))
1234 .and_then(|n| u32::try_from(n).ok())
1235 .ok_or_else(|| {
1236 XlogError::Kernel("Neural scatter shared memory overflow".to_string())
1237 })?;
1238
1239 unsafe {
1241 scatter.clone().launch(
1242 LaunchConfig {
1243 grid_dim: (1, 1, 1),
1244 block_dim: (1, 1, 1),
1245 shared_mem_bytes: shared_bytes,
1246 },
1247 (
1248 prob_col,
1249 labels,
1250 &slot_vars,
1251 cfg.eps,
1252 cfg.min_p,
1253 cache.grad_true(),
1254 cache.grad_false(),
1255 1u8,
1256 out_col,
1257 ),
1258 )
1259 }
1260 .map_err(|e| XlogError::Kernel(format!("neural_scatter (query) failed: {}", e)))?;
1261 }
1262 if expected_true {
1263 {
1264 let (_, var_log_false) = cache.var_log_weights_mut();
1265 restore_query_var_false(state.provider(), var_log_false, query_var, &restore)?;
1266 }
1267 } else {
1268 {
1269 let (var_log_true, _) = cache.var_log_weights_mut();
1270 restore_query_var_true(state.provider(), var_log_true, query_var, &restore)?;
1271 }
1272 }
1273
1274 Ok(())
1275 }
1276
1277 #[cfg(feature = "host-io")]
1278 pub fn evaluate_gpu_with_grads(&self) -> Result<ExactResultWithGrads> {
1279 if self.gpu.is_none() {
1280 if self.count_lift_gpu.is_some() {
1281 return Err(XlogError::UnsupportedEpistemicConstruct {
1282 construct: "GPU exact gradient evaluation".to_string(),
1283 context: "GPU count-lift exact backend does not expose gradient evaluation; \
1284 gradient production paths require a compiled GPU-native Decision-DNNF exact backend"
1285 .to_string(),
1286 });
1287 }
1288 return Ok(ExactResultWithGrads {
1289 log_z_e: 0.0,
1290 query_grads: Vec::new(),
1291 });
1292 }
1293
1294 let weights_len = if self.max_var == 0 {
1295 0
1296 } else {
1297 (self.max_var as usize) + 1
1298 };
1299
1300 let (log_z_e, grad_true_e, grad_false_e) = self.eval_log_z_and_grads_gpu_cached(None)?;
1301
1302 if log_z_e.is_infinite() && log_z_e.is_sign_negative() {
1303 return Err(XlogError::Execution(
1304 "Exact inference error: evidence is inconsistent (P(E)=0)".to_string(),
1305 ));
1306 }
1307
1308 let mut query_grads: Vec<QueryGradients> = Vec::with_capacity(self.queries.len());
1309
1310 for query in &self.queries {
1311 let Some(var) = query.var else {
1312 query_grads.push(QueryGradients {
1313 atom: query.atom.clone(),
1314 log_prob: f64::NEG_INFINITY,
1315 prob: 0.0,
1316 grad_true: vec![0.0; weights_len],
1317 grad_false: vec![0.0; weights_len],
1318 });
1319 continue;
1320 };
1321
1322 let idx = var as usize;
1323 if idx >= weights_len {
1324 return Err(XlogError::Compilation(format!(
1325 "Exact inference error: query var {} out of bounds (len={})",
1326 var, weights_len
1327 )));
1328 }
1329
1330 let (log_z_eq, grad_true_eq, grad_false_eq) =
1331 self.eval_log_z_and_grads_gpu_cached(Some(var))?;
1332
1333 let log_prob = log_z_eq - log_z_e;
1334 let mut prob = if log_prob.is_infinite() && log_prob.is_sign_negative() {
1335 0.0
1336 } else {
1337 log_prob.exp()
1338 };
1339 if prob.is_nan() {
1340 return Err(XlogError::Execution(
1341 "Exact inference error: NaN probability encountered".to_string(),
1342 ));
1343 }
1344 prob = prob.clamp(0.0, 1.0);
1345
1346 if grad_true_eq.len() != grad_true_e.len() || grad_false_eq.len() != grad_false_e.len()
1347 {
1348 return Err(XlogError::Execution(
1349 "Exact inference error: gradient length mismatch".to_string(),
1350 ));
1351 }
1352
1353 let mut grad_true: Vec<f64> = grad_true_eq;
1354 let mut grad_false: Vec<f64> = grad_false_eq;
1355 for i in 0..grad_true.len() {
1356 grad_true[i] -= grad_true_e[i];
1357 grad_false[i] -= grad_false_e[i];
1358 }
1359
1360 query_grads.push(QueryGradients {
1361 atom: query.atom.clone(),
1362 log_prob,
1363 prob,
1364 grad_true,
1365 grad_false,
1366 });
1367 }
1368
1369 Ok(ExactResultWithGrads {
1370 log_z_e,
1371 query_grads,
1372 })
1373 }
1374
1375 fn compile_provenance_with_gpu(
1376 provenance: Provenance,
1377 config: GpuConfig,
1378 origin: ExactProgramOrigin,
1379 ) -> Result<Self> {
1380 if config.memory_bytes == 0 {
1381 return Err(XlogError::Kernel(
1382 "GPU memory budget must be non-zero".to_string(),
1383 ));
1384 }
1385
1386 let provenance = if config.decision_order_hint {
1387 crate::decision_order::apply_decision_order_hint(provenance)
1388 } else {
1389 provenance
1390 };
1391
1392 let mut roots_set: HashSet<crate::pir::PirNodeId> = HashSet::new();
1393
1394 let mut evidence_formulas: Vec<(crate::pir::PirNodeId, bool, GroundAtom)> = Vec::new();
1395 let mut evidence_atoms: std::collections::HashMap<GroundAtom, bool> =
1396 std::collections::HashMap::new();
1397 for (atom, value) in &provenance.evidence {
1398 if let Some(prev) = evidence_atoms.insert(atom.clone(), *value) {
1399 if prev != *value {
1400 return Err(XlogError::Execution(format!(
1401 "Exact inference error: conflicting evidence for {}",
1402 display_atom(atom)
1403 )));
1404 }
1405 }
1406
1407 let formula = provenance.query_formula(&atom.predicate, &atom.args);
1408 match formula {
1409 Some(id) => {
1410 roots_set.insert(id);
1411 evidence_formulas.push((id, *value, atom.clone()));
1412 }
1413 None => {
1414 if *value {
1415 return Err(XlogError::Execution(format!(
1416 "Exact inference error: evidence atom is never derivable: {}",
1417 display_atom(atom)
1418 )));
1419 }
1420 }
1421 }
1422 }
1423
1424 let mut queries: Vec<QuerySpec> = Vec::new();
1425 #[cfg(feature = "host-io")]
1426 let mut query_nodes: Vec<(usize, crate::pir::PirNodeId)> = Vec::new();
1427 for atom in &provenance.queries {
1428 let formula = provenance.query_formula(&atom.predicate, &atom.args);
1429 if let Some(id) = formula {
1430 roots_set.insert(id);
1431 #[cfg(feature = "host-io")]
1432 {
1433 query_nodes.push((queries.len(), id));
1434 }
1435 }
1436 queries.push(QuerySpec {
1437 atom: atom.clone(),
1438 var: None,
1439 });
1440 }
1441
1442 for (idx, node) in provenance.pir.nodes().iter().enumerate() {
1446 match node {
1447 crate::pir::PirNode::Decision { .. }
1448 | crate::pir::PirNode::Lit { .. }
1449 | crate::pir::PirNode::NegLit { .. } => {
1450 roots_set.insert(crate::pir::PirNodeId::from_u32(idx as u32));
1451 }
1452 _ => {}
1453 }
1454 }
1455
1456 let mut roots: Vec<crate::pir::PirNodeId> = roots_set.into_iter().collect();
1457 roots.sort();
1458
1459 if roots.is_empty() {
1460 return Ok(Self {
1461 gpu: None,
1462 count_lift_gpu: None,
1463 queries,
1464 random_vars: None,
1465 max_var: 0,
1466 origin,
1467 gpu_config: config,
1468 last_compile_profile: None,
1469 });
1470 }
1471
1472 let count_lift_gpu = try_build_count_lift_gpu_state(&provenance, &queries, config)?;
1473 if let Some(count_lift_gpu) = count_lift_gpu {
1474 return Ok(Self {
1475 gpu: None,
1476 count_lift_gpu: Some(count_lift_gpu),
1477 queries,
1478 random_vars: None,
1479 max_var: 0,
1480 origin,
1481 gpu_config: config,
1482 last_compile_profile: None,
1483 });
1484 }
1485
1486 let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
1487 let memory = Arc::new(GpuMemoryManager::new(
1488 device.clone(),
1489 MemoryBudget::with_limit(config.memory_bytes),
1490 ));
1491 let provider = Arc::new(CudaKernelProvider::new(device, memory)?);
1492
1493 let canonical_cnf_hash = crate::cnf::canonical_pir_hash(&provenance.pir, &roots)?;
1494 let gpu_pir = GpuPirGraph::from_host(&provenance.pir, &provider)?;
1495 let gpu_roots = GpuPirRoots::from_host(&roots, &provider)?;
1496 let encoding = encode_cnf_gpu(&gpu_pir, &gpu_roots, &provider)?;
1497 if encoding.vars.max_var != encoding.cnf.var_cap {
1498 return Err(XlogError::Compilation(format!(
1499 "Exact inference error: CNF var_cap {} != vars.max_var {}",
1500 encoding.cnf.var_cap, encoding.vars.max_var
1501 )));
1502 }
1503
1504 let (leaf_probs_host, choice_true_host, choice_false_host) =
1505 build_weight_sources(&provenance)?;
1506
1507 let leaf_probs = upload_f64(&provider, &leaf_probs_host)?;
1508 let choice_true = upload_f64(&provider, &choice_true_host)?;
1509 let choice_false = upload_f64(&provider, &choice_false_host)?;
1510
1511 let evidence_by_var = if evidence_formulas.is_empty() {
1512 let mut evidence = provider
1513 .memory()
1514 .alloc::<u8>((encoding.vars.max_var as usize) + 1)?;
1515 provider
1516 .device()
1517 .inner()
1518 .memset_zeros(&mut evidence)
1519 .map_err(|e| XlogError::Kernel(format!("Failed to zero evidence buffer: {}", e)))?;
1520 evidence
1521 } else {
1522 let mut nodes: Vec<u32> = Vec::with_capacity(evidence_formulas.len());
1523 let mut vals: Vec<u8> = Vec::with_capacity(evidence_formulas.len());
1524 for (node, value, _atom) in &evidence_formulas {
1525 nodes.push(node.as_u32());
1526 vals.push(if *value { 1u8 } else { 2u8 });
1527 }
1528 let evidence_nodes = upload_u32(&provider, &nodes)?;
1529 let evidence_vals = upload_u8(&provider, &vals)?;
1530 build_evidence_by_var_gpu(
1531 &encoding.vars.node_var,
1532 &evidence_nodes,
1533 &evidence_vals,
1534 encoding.vars.max_var,
1535 &provider,
1536 )?
1537 };
1538
1539 let weights = build_weights_gpu(
1540 &encoding.vars,
1541 &leaf_probs,
1542 &choice_true,
1543 &choice_false,
1544 &evidence_by_var,
1545 &provider,
1546 )?;
1547 let random_var_count = leaf_probs_host
1548 .len()
1549 .checked_add(choice_true_host.len())
1550 .ok_or_else(|| XlogError::Compilation("random var count overflow".to_string()))?;
1551 let random_var_count = u32::try_from(random_var_count)
1552 .map_err(|_| XlogError::Compilation("random var count exceeds u32".to_string()))?;
1553 let num_leaf_probs = u32::try_from(leaf_probs_host.len())
1554 .map_err(|_| XlogError::Compilation("leaf_probs count exceeds u32".to_string()))?;
1555 let num_choice_probs = u32::try_from(choice_true_host.len())
1556 .map_err(|_| XlogError::Compilation("choice_probs count exceeds u32".to_string()))?;
1557 let (random_var_list, actual_random_var_count) = collect_random_vars_device(
1558 &provider,
1559 &encoding.vars,
1560 num_leaf_probs,
1561 num_choice_probs,
1562 random_var_count,
1563 )?;
1564 let random_vars =
1565 DeviceRandomVarList::from_device(random_var_list, actual_random_var_count)?;
1566
1567 let compile_config = default_compile_config(&encoding.cnf, config.memory_bytes)?;
1568 let cache_config = default_cache_config(&encoding.cnf, &compile_config)?;
1569
1570 let mut cache = GpuCircuitCache::new(&provider, cache_config)?;
1571 let (handle, compile_profile) = compile_gpu_d4_and_verify_cached(
1572 &encoding.cnf,
1573 &encoding.decision_var_limit,
1574 &provider,
1575 &compile_config,
1576 &mut cache,
1577 &random_vars,
1578 Some(canonical_cnf_hash),
1579 )?;
1580 cache.store_weights(&handle, &weights.log_true, &weights.log_false)?;
1581
1582 #[cfg(feature = "host-io")]
1583 if !query_nodes.is_empty() {
1584 let mut node_ids: Vec<u32> = Vec::with_capacity(query_nodes.len());
1585 for (_idx, node) in &query_nodes {
1586 node_ids.push(node.as_u32());
1587 }
1588 let node_ids_device = upload_u32(&provider, &node_ids)?;
1589 let vars_device = map_nodes_to_vars_gpu(
1590 &encoding.vars.node_var,
1591 &node_ids_device,
1592 encoding.vars.max_var,
1593 &provider,
1594 )?;
1595
1596 let mut vars_host = vec![0u32; vars_device.len()];
1597 provider
1598 .device()
1599 .inner()
1600 .dtoh_sync_copy_into(&vars_device, &mut vars_host)
1601 .map_err(|e| XlogError::Kernel(format!("Failed to read query vars: {}", e)))?;
1602
1603 for (i, (query_idx, _)) in query_nodes.iter().enumerate() {
1604 let var = vars_host[i];
1605 queries[*query_idx].var = Some(var);
1606 }
1607 }
1608
1609 let state = GpuExactState::new(provider, cache, handle)?;
1610
1611 Ok(Self {
1612 gpu: Some(Arc::new(state)),
1613 count_lift_gpu: None,
1614 queries,
1615 random_vars: Some(Arc::new(random_vars)),
1616 max_var: encoding.vars.max_var,
1617 origin,
1618 gpu_config: config,
1619 last_compile_profile: compile_profile,
1620 })
1621 }
1622
1623 #[cfg(feature = "host-io")]
1624 fn eval_log_z_gpu(&self, query_true: Option<u32>) -> Result<f64> {
1625 let state = self.gpu_state()?;
1626 let mut cache = state
1627 .cache
1628 .lock()
1629 .unwrap_or_else(|poisoned| poisoned.into_inner());
1630
1631 if let Some(var) = query_true {
1632 if var == 0 || var > self.max_var {
1633 return Err(XlogError::Compilation(format!(
1634 "Exact inference error: query var {} out of bounds (max_var={})",
1635 var, self.max_var
1636 )));
1637 }
1638 }
1639
1640 let mut restore = None;
1641 if let Some(var) = query_true {
1642 let mut buf = state.provider.memory().alloc::<f64>(1)?;
1643 {
1644 let (_, var_log_false) = cache.var_log_weights_mut();
1645 force_query_var_false(state.provider(), var_log_false, var, &mut buf)?;
1646 }
1647 restore = Some((var, buf));
1648 }
1649
1650 let mut out_log_z = state.provider.memory().alloc::<f64>(1)?;
1651 let eval_result = cache.eval_log_wmc_device_inplace(state.handle(), &mut out_log_z);
1652
1653 if let Some((var, buf)) = restore {
1654 let (_, var_log_false) = cache.var_log_weights_mut();
1655 let restore_result =
1656 restore_query_var_false(state.provider(), var_log_false, var, &buf);
1657 if let Err(err) = eval_result {
1658 restore_result?;
1659 return Err(err);
1660 }
1661 restore_result?;
1662 } else {
1663 eval_result?;
1664 }
1665
1666 let mut host = [0.0f64];
1667 state
1668 .provider
1669 .device()
1670 .inner()
1671 .dtoh_sync_copy_into(&out_log_z, &mut host)
1672 .map_err(|e| XlogError::Kernel(format!("Failed to read logZ: {}", e)))?;
1673 Ok(host[0])
1674 }
1675
1676 fn gpu_state(&self) -> Result<Arc<GpuExactState>> {
1677 self.gpu.clone().ok_or_else(|| {
1678 XlogError::Execution(
1679 "Exact inference GPU error: program has no compiled circuit".to_string(),
1680 )
1681 })
1682 }
1683
1684 #[cfg(feature = "host-io")]
1685 fn eval_log_z_and_grads_gpu_cached(
1686 &self,
1687 query_true: Option<u32>,
1688 ) -> Result<(f64, Vec<f64>, Vec<f64>)> {
1689 let state = self.gpu_state()?;
1690 let mut cache = state
1691 .cache
1692 .lock()
1693 .unwrap_or_else(|poisoned| poisoned.into_inner());
1694
1695 if let Some(var) = query_true {
1696 if var == 0 || var > self.max_var {
1697 return Err(XlogError::Compilation(format!(
1698 "Exact inference error: query var {} out of bounds (max_var={})",
1699 var, self.max_var
1700 )));
1701 }
1702 }
1703
1704 let mut restore = None;
1705 if let Some(var) = query_true {
1706 let mut buf = state.provider.memory().alloc::<f64>(1)?;
1707 {
1708 let (_, var_log_false) = cache.var_log_weights_mut();
1709 force_query_var_false(state.provider(), var_log_false, var, &mut buf)?;
1710 }
1711 restore = Some((var, buf));
1712 }
1713
1714 let eval_result = cache.eval_grads_inplace(state.handle());
1715
1716 if let Some((var, buf)) = restore {
1717 let (_, var_log_false) = cache.var_log_weights_mut();
1718 let restore_result =
1719 restore_query_var_false(state.provider(), var_log_false, var, &buf);
1720 if let Err(err) = eval_result {
1721 restore_result?;
1722 return Err(err);
1723 }
1724 restore_result?;
1725 } else {
1726 eval_result?;
1727 }
1728
1729 let weights_len = if self.max_var == 0 {
1730 0
1731 } else {
1732 (self.max_var as usize) + 1
1733 };
1734
1735 let device = state.provider.device().inner();
1736 let mut host_grad_true: Vec<f64> = vec![0.0; weights_len];
1737 let mut host_grad_false: Vec<f64> = vec![0.0; weights_len];
1738
1739 let root_idx = state.handle().root() as usize;
1740 let root_view = cache.values().slice(root_idx..(root_idx + 1));
1741 let mut log_z = [0.0_f64];
1742 device
1743 .dtoh_sync_copy_into(&root_view, &mut log_z)
1744 .map_err(|e| XlogError::Kernel(format!("Failed to read logZ: {}", e)))?;
1745
1746 let var_stride = cache.var_stride()? as usize;
1749 let slot = state.handle().slot_index() as usize;
1750 let grad_start = slot * var_stride;
1751 let grad_end = grad_start + weights_len;
1752 let grad_true_slot = cache.grad_true().slice(grad_start..grad_end);
1753 let grad_false_slot = cache.grad_false().slice(grad_start..grad_end);
1754 device
1755 .dtoh_sync_copy_into(&grad_true_slot, &mut host_grad_true)
1756 .map_err(|e| XlogError::Kernel(format!("Failed to download grad_true: {}", e)))?;
1757 device
1758 .dtoh_sync_copy_into(&grad_false_slot, &mut host_grad_false)
1759 .map_err(|e| XlogError::Kernel(format!("Failed to download grad_false: {}", e)))?;
1760
1761 Ok((log_z[0], host_grad_true, host_grad_false))
1762 }
1763}
1764
1765fn try_build_count_lift_gpu_state(
1766 provenance: &Provenance,
1767 queries: &[QuerySpec],
1768 config: GpuConfig,
1769) -> Result<Option<Arc<GpuCountLiftState>>> {
1770 if queries.is_empty() || !provenance.evidence.is_empty() || !provenance.choice_probs.is_empty()
1771 {
1772 return Ok(None);
1773 }
1774
1775 let fired_count_predicates: HashSet<&str> = provenance
1776 .aggregate_lifting
1777 .iter()
1778 .filter(|entry| {
1779 entry.status == AggregateLiftStatus::Fired
1780 && entry.operator.as_str() == "count"
1781 && entry.deterministic_rows == 0
1782 })
1783 .map(|entry| entry.predicate.as_str())
1784 .collect();
1785 if fired_count_predicates.is_empty() {
1786 return Ok(None);
1787 }
1788 if queries
1789 .iter()
1790 .any(|query| !fired_count_predicates.contains(query.atom.predicate.as_str()))
1791 {
1792 return Ok(None);
1793 }
1794
1795 let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
1796 let memory = Arc::new(GpuMemoryManager::new(
1797 device.clone(),
1798 MemoryBudget::with_limit(config.memory_bytes),
1799 ));
1800 let provider = Arc::new(CudaKernelProvider::new(device, memory)?);
1801 let mut gpu_queries = Vec::with_capacity(queries.len());
1802 for query in queries {
1803 let target_count = match count_lift_query_target(query)? {
1804 Some(target) => target,
1805 None => return Ok(None),
1806 };
1807 let root = match provenance.query_formula(&query.atom.predicate, &query.atom.args) {
1808 Some(root) => root,
1809 None => return Ok(None),
1810 };
1811 let mut leaves = HashSet::new();
1812 collect_count_lift_leaves(provenance, root, &mut leaves)?;
1813 if leaves.is_empty() || leaves.len() > 64 {
1814 return Ok(None);
1815 }
1816 if target_count > leaves.len() as u32 {
1817 return Ok(None);
1818 }
1819 let mut leaves: Vec<_> = leaves.into_iter().collect();
1820 leaves.sort_by_key(|leaf| leaf.as_u32());
1821 let mut leaf_probs_host = Vec::with_capacity(leaves.len());
1822 for leaf in leaves {
1823 let p = *provenance.leaf_probs.get(&leaf).ok_or_else(|| {
1824 XlogError::Compilation(format!(
1825 "Count-lift GPU evaluator missing probability for leaf {}",
1826 leaf.as_u32()
1827 ))
1828 })?;
1829 leaf_probs_host.push(p);
1830 }
1831 let leaf_count = u32::try_from(leaf_probs_host.len())
1832 .map_err(|_| XlogError::Compilation("count-lift leaf count exceeds u32".to_string()))?;
1833 let leaf_probs = upload_f64(&provider, &leaf_probs_host)?;
1834 gpu_queries.push(GpuCountLiftQuery {
1835 atom: query.atom.clone(),
1836 target_count,
1837 leaf_count,
1838 leaf_probs,
1839 });
1840 }
1841 Ok(Some(Arc::new(GpuCountLiftState::new(
1842 provider,
1843 gpu_queries,
1844 ))))
1845}
1846
1847fn count_lift_query_target(query: &QuerySpec) -> Result<Option<u32>> {
1848 match query.atom.args.last() {
1849 Some(Value::I64(value)) if *value >= 0 => u32::try_from(*value)
1850 .map(Some)
1851 .map_err(|_| XlogError::Compilation("count-lift target exceeds u32".to_string())),
1852 _ => Ok(None),
1853 }
1854}
1855
1856fn collect_count_lift_leaves(
1857 provenance: &Provenance,
1858 node: crate::pir::PirNodeId,
1859 leaves: &mut HashSet<crate::pir::LeafId>,
1860) -> Result<()> {
1861 let pir_node = provenance.pir.node(node).ok_or_else(|| {
1862 XlogError::Compilation(format!(
1863 "Count-lift GPU evaluator saw invalid PIR node {}",
1864 node.as_u32()
1865 ))
1866 })?;
1867 match pir_node {
1868 crate::pir::PirNode::Const(_) => Ok(()),
1869 crate::pir::PirNode::Lit { leaf } | crate::pir::PirNode::NegLit { leaf } => {
1870 leaves.insert(*leaf);
1871 Ok(())
1872 }
1873 crate::pir::PirNode::And { children } | crate::pir::PirNode::Or { children } => {
1874 for child in children {
1875 collect_count_lift_leaves(provenance, *child, leaves)?;
1876 }
1877 Ok(())
1878 }
1879 crate::pir::PirNode::Decision { .. } => Err(XlogError::Compilation(
1880 "Count-lift GPU evaluator does not support annotated-disjunction choices".to_string(),
1881 )),
1882 }
1883}
1884
1885fn force_query_var_false(
1886 provider: &Arc<CudaKernelProvider>,
1887 log_false: &mut TrackedCudaSlice<f64>,
1888 var: u32,
1889 restore: &mut TrackedCudaSlice<f64>,
1890) -> Result<()> {
1891 let device = provider.device().inner();
1892 let func = device
1893 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FORCE_VAR_FALSE)
1894 .ok_or_else(|| XlogError::Kernel("weights_force_var_false kernel not found".to_string()))?;
1895 unsafe {
1897 func.clone().launch(
1898 LaunchConfig {
1899 grid_dim: (1, 1, 1),
1900 block_dim: (1, 1, 1),
1901 shared_mem_bytes: 0,
1902 },
1903 (var, log_false, restore),
1904 )
1905 }
1906 .map_err(|e| XlogError::Kernel(format!("weights_force_var_false failed: {}", e)))?;
1907 Ok(())
1908}
1909
1910fn restore_query_var_false(
1911 provider: &Arc<CudaKernelProvider>,
1912 log_false: &mut TrackedCudaSlice<f64>,
1913 var: u32,
1914 restore: &TrackedCudaSlice<f64>,
1915) -> Result<()> {
1916 let device = provider.device().inner();
1917 let func = device
1918 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_RESTORE_VAR_FALSE)
1919 .ok_or_else(|| {
1920 XlogError::Kernel("weights_restore_var_false kernel not found".to_string())
1921 })?;
1922 unsafe {
1924 func.clone().launch(
1925 LaunchConfig {
1926 grid_dim: (1, 1, 1),
1927 block_dim: (1, 1, 1),
1928 shared_mem_bytes: 0,
1929 },
1930 (var, log_false, restore),
1931 )
1932 }
1933 .map_err(|e| XlogError::Kernel(format!("weights_restore_var_false failed: {}", e)))?;
1934 Ok(())
1935}
1936
1937fn force_query_var_true(
1938 provider: &Arc<CudaKernelProvider>,
1939 log_true: &mut TrackedCudaSlice<f64>,
1940 var: u32,
1941 restore: &mut TrackedCudaSlice<f64>,
1942) -> Result<()> {
1943 let device = provider.device().inner();
1944 let func = device
1945 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FORCE_VAR_TRUE)
1946 .ok_or_else(|| XlogError::Kernel("weights_force_var_true kernel not found".to_string()))?;
1947 unsafe {
1949 func.clone().launch(
1950 LaunchConfig {
1951 grid_dim: (1, 1, 1),
1952 block_dim: (1, 1, 1),
1953 shared_mem_bytes: 0,
1954 },
1955 (var, log_true, restore),
1956 )
1957 }
1958 .map_err(|e| XlogError::Kernel(format!("weights_force_var_true failed: {}", e)))?;
1959 Ok(())
1960}
1961
1962fn restore_query_var_true(
1963 provider: &Arc<CudaKernelProvider>,
1964 log_true: &mut TrackedCudaSlice<f64>,
1965 var: u32,
1966 restore: &TrackedCudaSlice<f64>,
1967) -> Result<()> {
1968 let device = provider.device().inner();
1969 let func = device
1970 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_RESTORE_VAR_TRUE)
1971 .ok_or_else(|| {
1972 XlogError::Kernel("weights_restore_var_true kernel not found".to_string())
1973 })?;
1974 unsafe {
1976 func.clone().launch(
1977 LaunchConfig {
1978 grid_dim: (1, 1, 1),
1979 block_dim: (1, 1, 1),
1980 shared_mem_bytes: 0,
1981 },
1982 (var, log_true, restore),
1983 )
1984 }
1985 .map_err(|e| XlogError::Kernel(format!("weights_restore_var_true failed: {}", e)))?;
1986 Ok(())
1987}
1988
1989pub(crate) fn default_compile_config(
1990 cnf: &xlog_solve::GpuCnf,
1991 memory_bytes: u64,
1992) -> Result<GpuCompileConfig> {
1993 let frontier_depth: u16 = 6;
1997
1998 let var_cap = cnf.var_cap.max(1);
1999 let trail_bytes_per_item = (var_cap as u64)
2000 .checked_add(1)
2001 .and_then(|v| v.checked_mul(std::mem::size_of::<i32>() as u64))
2002 .ok_or_else(|| XlogError::Compilation("trail size overflow".to_string()))?;
2003 let denom = trail_bytes_per_item
2004 .checked_mul(8)
2005 .ok_or_else(|| XlogError::Compilation("trail memory denominator overflow".to_string()))?;
2006 if memory_bytes
2007 < denom.checked_mul(8).ok_or_else(|| {
2008 XlogError::Compilation("minimum frontier memory requirement overflow".to_string())
2009 })?
2010 {
2011 return Err(XlogError::Compilation(format!(
2012 "memory budget {} cannot hold the minimum GPU-native Decision-DNNF frontier allocation",
2013 memory_bytes
2014 )));
2015 }
2016 let max_items_by_trail = memory_bytes / denom;
2017 let max_frontier_items = max_items_by_trail.min(4096).min(u64::from(u32::MAX)) as u32;
2018
2019 let frontier_cap_factor = (1u64
2023 .checked_shl(frontier_depth as u32)
2024 .unwrap_or(u64::from(u32::MAX)))
2025 .min(u64::from(max_frontier_items)) as u32;
2026
2027 let per_item_nodes = cnf
2028 .var_cap
2029 .checked_mul(5)
2030 .ok_or_else(|| XlogError::Compilation("smooth_node_cap overflow".to_string()))?
2031 .max(1024);
2032 let smooth_node_cap = per_item_nodes
2033 .checked_mul(frontier_cap_factor)
2034 .ok_or_else(|| XlogError::Compilation("smooth_node_cap overflow".to_string()))?;
2035
2036 let mut smooth_edge_cap = smooth_node_cap
2039 .checked_mul(2)
2040 .ok_or_else(|| XlogError::Compilation("smooth_edge_cap overflow".to_string()))?;
2041 if smooth_edge_cap < max_frontier_items {
2042 smooth_edge_cap = max_frontier_items;
2043 }
2044
2045 let mut cdcl_learned_bytes = memory_bytes / 8;
2050 if cdcl_learned_bytes < 4 * 1024 * 1024 {
2051 cdcl_learned_bytes = 4 * 1024 * 1024;
2052 }
2053
2054 let config = GpuCompileConfig {
2055 frontier_depth,
2056 max_frontier_items,
2057 max_depth: 128,
2058 smooth_node_cap,
2059 smooth_edge_cap,
2060 cdcl_restart_interval: 64,
2061 cdcl_learned_bytes,
2062 cdcl_conflict_budget: None,
2063 incremental_verify: false,
2064 };
2065 Ok(config)
2066}
2067
2068pub(crate) fn default_cache_config(
2069 cnf: &xlog_solve::GpuCnf,
2070 compile: &GpuCompileConfig,
2071) -> Result<GpuCircuitCacheConfig> {
2072 if compile.smooth_node_cap == 0 || compile.smooth_edge_cap == 0 {
2073 return Err(XlogError::Compilation(
2074 "GPU cache config requires non-zero smoothing caps".to_string(),
2075 ));
2076 }
2077 Ok(GpuCircuitCacheConfig {
2078 num_slots: 4, table_size: 8,
2080 node_cap: compile.smooth_node_cap,
2081 edge_cap: compile.smooth_edge_cap,
2082 level_cap: compile.smooth_node_cap,
2083 var_cap: cnf.var_cap,
2084 })
2085}
2086
2087pub(crate) fn build_weight_sources(
2088 provenance: &Provenance,
2089) -> Result<(Vec<f64>, Vec<f64>, Vec<f64>)> {
2090 let max_leaf = provenance.leaf_probs.keys().map(|leaf| leaf.as_u32()).max();
2091 let leaf_len = max_leaf.map(|v| v as usize + 1).unwrap_or(0);
2092 let mut leaf_probs = vec![0.0f64; leaf_len];
2093 let mut leaf_seen = vec![false; leaf_len];
2094 for (leaf, p) in &provenance.leaf_probs {
2095 let idx = leaf.as_u32() as usize;
2096 if idx >= leaf_len {
2097 return Err(XlogError::Compilation(
2098 "leaf probability index out of bounds".to_string(),
2099 ));
2100 }
2101 leaf_probs[idx] = *p;
2102 leaf_seen[idx] = true;
2103 }
2104 if let Some((idx, _)) = leaf_seen.iter().enumerate().find(|(_, seen)| !**seen) {
2105 return Err(XlogError::Compilation(format!(
2106 "missing probability for leaf {}",
2107 idx
2108 )));
2109 }
2110
2111 let max_choice = provenance
2112 .choice_probs
2113 .keys()
2114 .map(|choice| choice.as_u32())
2115 .max();
2116 let choice_len = max_choice.map(|v| v as usize + 1).unwrap_or(0);
2117 let mut choice_true = vec![0.0f64; choice_len];
2118 let mut choice_false = vec![0.0f64; choice_len];
2119 let mut choice_seen = vec![false; choice_len];
2120 for (choice, (pt, pf)) in &provenance.choice_probs {
2121 let idx = choice.as_u32() as usize;
2122 if idx >= choice_len {
2123 return Err(XlogError::Compilation(
2124 "choice probability index out of bounds".to_string(),
2125 ));
2126 }
2127 choice_true[idx] = *pt;
2128 choice_false[idx] = *pf;
2129 choice_seen[idx] = true;
2130 }
2131 if let Some((idx, _)) = choice_seen.iter().enumerate().find(|(_, seen)| !**seen) {
2132 return Err(XlogError::Compilation(format!(
2133 "missing probability for choice {}",
2134 idx
2135 )));
2136 }
2137
2138 Ok((leaf_probs, choice_true, choice_false))
2139}
2140
2141pub(crate) fn upload_u32(
2142 provider: &Arc<CudaKernelProvider>,
2143 host: &[u32],
2144) -> Result<TrackedCudaSlice<u32>> {
2145 let memory = provider.memory();
2146 let mut buf = memory.alloc::<u32>(host.len())?;
2147 provider
2148 .htod_sync_copy_into_tracked(host, &mut buf)
2149 .map_err(|e| XlogError::Kernel(format!("Failed to upload u32 buffer: {}", e)))?;
2150 Ok(buf)
2151}
2152
2153pub(crate) fn upload_u8(
2154 provider: &Arc<CudaKernelProvider>,
2155 host: &[u8],
2156) -> Result<TrackedCudaSlice<u8>> {
2157 let memory = provider.memory();
2158 let mut buf = memory.alloc::<u8>(host.len())?;
2159 provider
2160 .htod_sync_copy_into_tracked(host, &mut buf)
2161 .map_err(|e| XlogError::Kernel(format!("Failed to upload u8 buffer: {}", e)))?;
2162 Ok(buf)
2163}
2164
2165pub(crate) fn upload_f64(
2166 provider: &Arc<CudaKernelProvider>,
2167 host: &[f64],
2168) -> Result<TrackedCudaSlice<f64>> {
2169 let memory = provider.memory();
2170 let mut buf = memory.alloc::<f64>(host.len())?;
2171 provider
2172 .htod_sync_copy_into_tracked(host, &mut buf)
2173 .map_err(|e| XlogError::Kernel(format!("Failed to upload f64 buffer: {}", e)))?;
2174 Ok(buf)
2175}
2176
2177fn capture_compact_count_device(
2178 provider: &Arc<CudaKernelProvider>,
2179 prefix_sum: &TrackedCudaSlice<u32>,
2180 mask: &TrackedCudaSlice<u8>,
2181 n: u32,
2182) -> Result<TrackedCudaSlice<u32>> {
2183 let mut out = provider.memory().alloc::<u32>(1)?;
2184 let device = provider.device().inner();
2185 let capture_fn = device
2186 .get_func(FILTER_MODULE, filter_kernels::CAPTURE_COMPACT_COUNT)
2187 .ok_or_else(|| XlogError::Kernel("capture_compact_count kernel not found".to_string()))?;
2188 unsafe {
2190 capture_fn.clone().launch(
2191 LaunchConfig {
2192 grid_dim: (1, 1, 1),
2193 block_dim: (1, 1, 1),
2194 shared_mem_bytes: 0,
2195 },
2196 (prefix_sum, mask, n, &mut out),
2197 )
2198 }
2199 .map_err(|e| XlogError::Kernel(format!("capture_compact_count failed: {}", e)))?;
2200 Ok(out)
2201}
2202
2203pub(crate) fn collect_random_vars_device(
2204 provider: &Arc<CudaKernelProvider>,
2205 vars: &GpuCnfVarTables,
2206 num_leaf_probs: u32,
2207 num_choice_probs: u32,
2208 _expected_count: u32,
2209) -> Result<(TrackedCudaSlice<u32>, u32)> {
2210 let device = provider.device().inner();
2211 let memory = provider.memory();
2212
2213 let mask_len = vars
2214 .max_var
2215 .checked_add(1)
2216 .ok_or_else(|| XlogError::Compilation("random var mask_len overflow".to_string()))?;
2217 let mask_len_usize = usize::try_from(mask_len)
2218 .map_err(|_| XlogError::Compilation("random var mask_len exceeds usize".to_string()))?;
2219
2220 let mut mask = memory.alloc::<u8>(mask_len_usize)?;
2221 device
2222 .memset_zeros(&mut mask)
2223 .map_err(|e| XlogError::Kernel(format!("Failed to zero random var mask: {}", e)))?;
2224
2225 let mut iota = memory.alloc::<u32>(mask_len_usize)?;
2226 let fill_iota = device
2227 .get_func(FILTER_MODULE, filter_kernels::FILL_U32_IOTA)
2228 .ok_or_else(|| XlogError::Kernel("fill_u32_iota kernel not found".to_string()))?;
2229 let block_size = 256u32;
2230 let grid = checked_launch_grid_u32("fill random-var iota", mask_len, block_size)?;
2231 unsafe {
2233 fill_iota.clone().launch(
2234 LaunchConfig {
2235 grid_dim: (grid, 1, 1),
2236 block_dim: (block_size, 1, 1),
2237 shared_mem_bytes: 0,
2238 },
2239 (&mut iota, mask_len, 0u32),
2240 )
2241 }
2242 .map_err(|e| XlogError::Kernel(format!("fill_u32_iota failed: {}", e)))?;
2243
2244 let leaf_len = num_leaf_probs;
2249 let choice_len = num_choice_probs;
2250
2251 let mark_kernel = device
2252 .get_func(FILTER_MODULE, filter_kernels::MARK_RANDOM_VARS)
2253 .ok_or_else(|| XlogError::Kernel("mark_random_vars kernel not found".to_string()))?;
2254 let mark_n = leaf_len.max(choice_len);
2255 if mark_n > 0 {
2256 let grid = checked_launch_grid_u32("mark random vars", mark_n, block_size)?;
2257 unsafe {
2259 mark_kernel.clone().launch(
2260 LaunchConfig {
2261 grid_dim: (grid, 1, 1),
2262 block_dim: (block_size, 1, 1),
2263 shared_mem_bytes: 0,
2264 },
2265 (
2266 &vars.leaf_var,
2267 &vars.choice_var,
2268 leaf_len,
2269 choice_len,
2270 &mut mask,
2271 mask_len,
2272 ),
2273 )
2274 }
2275 .map_err(|e| XlogError::Kernel(format!("mark_random_vars failed: {}", e)))?;
2276 }
2277
2278 let prefix_sum = provider.scan_u8_mask_device(&mask, mask_len)?;
2279 let count_device = capture_compact_count_device(provider, &prefix_sum, &mask, mask_len)?;
2280
2281 let actual_count = {
2285 let mut buf = vec![0u32; 1];
2286 device
2287 .dtoh_sync_copy_into(&count_device, &mut buf)
2288 .map_err(|e| XlogError::Kernel(format!("dtoh count_device failed: {}", e)))?;
2289 buf[0]
2290 };
2291
2292 if actual_count == 0 {
2293 let out = provider.memory().alloc::<u32>(0)?;
2295 return Ok((out, 0));
2296 }
2297
2298 let mut out = memory.alloc::<u32>(mask_len_usize)?;
2299 let compact_fn = device
2300 .get_func(FILTER_MODULE, filter_kernels::COMPACT_U32_BY_MASK)
2301 .ok_or_else(|| XlogError::Kernel("compact_u32_by_mask kernel not found".to_string()))?;
2302 unsafe {
2304 compact_fn.clone().launch(
2305 LaunchConfig {
2306 grid_dim: (grid, 1, 1),
2307 block_dim: (block_size, 1, 1),
2308 shared_mem_bytes: 0,
2309 },
2310 (&iota, &mask, &prefix_sum, mask_len, &mut out),
2311 )
2312 }
2313 .map_err(|e| XlogError::Kernel(format!("compact_u32_by_mask failed: {}", e)))?;
2314
2315 Ok((out, actual_count))
2316}
2317
2318fn display_atom(atom: &GroundAtom) -> String {
2319 if atom.args.is_empty() {
2320 format!("{}()", atom.predicate)
2321 } else {
2322 format!("{}({} args)", atom.predicate, atom.args.len())
2323 }
2324}
2325
#[cfg(all(test, feature = "host-io"))]
mod tests {
    use super::*;
    use xlog_cuda::CudaDevice;

    /// Returns `true` when CUDA device 0 cannot be opened, after printing a skip notice.
    fn cuda_missing() -> bool {
        let missing = CudaDevice::new(0).is_err();
        if missing {
            eprintln!("Skipping test: CUDA runtime unavailable");
        }
        missing
    }

    /// Compiles and evaluates `source`, checks that it yields exactly one query
    /// result, and returns that result's probability.
    fn sole_query_prob(source: &str) -> f64 {
        let compiled = ExactDdnnfProgram::compile_source(source).unwrap();
        let outcome = compiled.evaluate().unwrap();

        assert_eq!(outcome.query_probs.len(), 1);
        outcome.query_probs[0].prob
    }

    /// A rule defined as the negation of a 0.3 fact should evaluate to 0.7.
    #[test]
    fn test_exact_negation_probability() {
        let _gpu_guard = crate::test_gpu_lock::lock();
        if cuda_missing() {
            return;
        }

        let p_dry = sole_query_prob(
            r#"
0.3::rain().
dry() :- not rain().
query(dry()).
"#,
        );

        assert!(
            (p_dry - 0.7).abs() < 1e-6,
            "P(dry) should be 0.7, got {}",
            p_dry
        );
    }

    /// Two stacked negations over a 0.4 fact should evaluate back to 0.4.
    #[test]
    fn test_exact_multi_layer_negation() {
        let _gpu_guard = crate::test_gpu_lock::lock();
        if cuda_missing() {
            return;
        }

        let p_a = sole_query_prob(
            r#"
0.4::c().
b() :- not c().
a() :- not b().
query(a()).
"#,
        );

        assert!(
            (p_a - 0.4).abs() < 1e-6,
            "P(a) should be 0.4, got {}",
            p_a
        );
    }

    /// Checks that forcing a query variable's false-weight writes `-inf` on the device,
    /// and that evaluating with that query variable yields a smaller logZ than
    /// evaluating without it.
    #[test]
    fn test_eval_log_z_changes_for_sprinkler_given_wet() {
        let _gpu_guard = crate::test_gpu_lock::lock();
        if cuda_missing() {
            return;
        }

        let program = ExactDdnnfProgram::compile_source(
            r#"
0.7::rain().
0.2::sprinkler().
wet() :- rain().
wet() :- sprinkler().
evidence(wet(), true).
query(rain()).
query(sprinkler()).
"#,
        )
        .unwrap();

        // Both evaluations run before the cache lock is taken below.
        let log_z_e = program.eval_log_z_gpu(None).unwrap();
        let sprinkler_var = program.query_var(1).unwrap();
        let log_z_eq = program.eval_log_z_gpu(Some(sprinkler_var)).unwrap();

        let state = program.gpu_state().unwrap();
        let mut cache = state
            .cache
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        let (_, var_log_false) = cache.var_log_weights_mut();

        let device = state.provider.device().inner();
        let slot = sprinkler_var as usize;

        // Snapshot the false-weight of the sprinkler variable before forcing it.
        let log_false_before = {
            let mut host = [0.0f64];
            let view = var_log_false.slice(slot..slot + 1);
            device.dtoh_sync_copy_into(&view, &mut host).unwrap();
            host[0]
        };

        let mut restore = state.provider.memory().alloc::<f64>(1).unwrap();
        force_query_var_false(state.provider(), var_log_false, sprinkler_var, &mut restore)
            .unwrap();

        // Read the same slot again while the variable is forced.
        let log_false_forced = {
            let mut host = [0.0f64];
            let view = var_log_false.slice(slot..slot + 1);
            device.dtoh_sync_copy_into(&view, &mut host).unwrap();
            host[0]
        };

        // Undo the forcing before asserting so a failed assertion does not leave the
        // cached weights modified.
        restore_query_var_false(state.provider(), var_log_false, sprinkler_var, &restore).unwrap();

        assert!(
            log_false_before.is_finite(),
            "expected finite log_false before forcing"
        );
        assert!(
            log_false_forced.is_infinite() && log_false_forced.is_sign_negative(),
            "expected -inf log_false after forcing, got {}",
            log_false_forced
        );
        assert!(
            log_z_eq < log_z_e,
            "conditioning on sprinkler should reduce logZ (log_z_e={}, log_z_eq={})",
            log_z_e,
            log_z_eq
        );
    }
}