Skip to main content

xlog_prob/
neural_fast_path.rs

1//! GPU neural fast-path helpers (device slot mapping + AD-chain glue).
2//!
3//! This module contains GPU-resident tables used to map neural predicate outputs
4//! (probability vectors) to CNF variable ids in the compiled circuit.
5
6use cudarc::driver::{CudaView, DeviceSlice};
7use xlog_core::{Result, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::CudaKernelProvider;
10
/// Numerical-stability settings for GPU neural fast-path weight injection.
///
/// These values are used when neural-network output probabilities are mapped
/// to log-weights of CNF variables.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct NeuralFastPathConfig {
    /// Probability mass set aside for the implicit "none" outcome.
    pub eps: f64,
    /// Lower clamp applied to probabilities to keep computations numerically stable.
    pub min_p: f64,
}
23
24impl Default for NeuralFastPathConfig {
25    fn default() -> Self {
26        Self {
27            eps: 1e-7,
28            min_p: 1e-12,
29        }
30    }
31}
32
33/// Device-resident mapping from neural output slots to CNF variable ids.
34///
35/// Slots are grouped (one group per neural predicate instance). Each slot is a
36/// CNF var id (DIMACS, 1-based) whose log-weights should be updated from the
37/// group’s probability vector.
38pub struct GpuWeightSlots {
39    group_offsets_host: Vec<u32>,
40    group_offsets: TrackedCudaSlice<u32>, // len = num_groups + 1
41    slot_cnf_var: TrackedCudaSlice<u32>,  // len = total_slots
42}
43
44impl GpuWeightSlots {
45    /// Upload a slot mapping from host vectors.
46    ///
47    /// `groups[g][i]` is the CNF variable id corresponding to label/slot `i` of group `g`.
48    pub fn upload(provider: &CudaKernelProvider, groups: &[Vec<u32>]) -> Result<Self> {
49        if groups.len() > u32::MAX as usize {
50            return Err(XlogError::Compilation(
51                "Neural fast-path group count exceeds GPU u32 index space".to_string(),
52            ));
53        }
54        let offset_count = groups.len().checked_add(1).ok_or_else(|| {
55            XlogError::Compilation("Neural fast-path group offset count overflow".to_string())
56        })?;
57        let total_slots = groups.iter().try_fold(0usize, |acc, group| {
58            acc.checked_add(group.len()).ok_or_else(|| {
59                XlogError::Compilation("Neural fast-path slot count overflow".to_string())
60            })
61        })?;
62        if total_slots > u32::MAX as usize {
63            return Err(XlogError::Compilation(
64                "Neural fast-path slot count exceeds GPU u32 index space".to_string(),
65            ));
66        }
67
68        let mut offsets: Vec<u32> = Vec::with_capacity(offset_count);
69        offsets.push(0);
70
71        let mut flat: Vec<u32> = Vec::with_capacity(total_slots);
72        for g in groups {
73            flat.extend_from_slice(g);
74            let next_offset = u32::try_from(flat.len()).map_err(|_| {
75                XlogError::Compilation(
76                    "Neural fast-path slot offset exceeds GPU u32 index space".to_string(),
77                )
78            })?;
79            offsets.push(next_offset);
80        }
81
82        let memory = provider.memory().clone();
83
84        let mut d_offsets = memory.alloc::<u32>(offsets.len())?;
85        provider
86            .htod_sync_copy_into_tracked(&offsets, &mut d_offsets)
87            .map_err(|e| {
88                XlogError::Kernel(format!("Failed to upload weight slot offsets: {}", e))
89            })?;
90
91        let mut d_vars = memory.alloc::<u32>(flat.len())?;
92        provider
93            .htod_sync_copy_into_tracked(&flat, &mut d_vars)
94            .map_err(|e| XlogError::Kernel(format!("Failed to upload weight slot vars: {}", e)))?;
95
96        Ok(Self {
97            group_offsets_host: offsets,
98            group_offsets: d_offsets,
99            slot_cnf_var: d_vars,
100        })
101    }
102
103    pub fn num_groups(&self) -> u32 {
104        debug_assert!(!self.group_offsets_host.is_empty());
105        (self.group_offsets_host.len() - 1) as u32
106    }
107
108    pub fn num_groups_usize(&self) -> usize {
109        debug_assert!(!self.group_offsets_host.is_empty());
110        self.group_offsets_host.len() - 1
111    }
112
113    pub fn total_slots(&self) -> u32 {
114        debug_assert!(!self.group_offsets_host.is_empty());
115        self.group_offsets_host[self.group_offsets_host.len() - 1]
116    }
117
118    pub fn group_offsets(&self) -> &TrackedCudaSlice<u32> {
119        &self.group_offsets
120    }
121
122    pub fn slot_cnf_var(&self) -> &TrackedCudaSlice<u32> {
123        &self.slot_cnf_var
124    }
125
126    /// Device view over `slot_cnf_var` for a single group.
127    pub fn group_slot_cnf_var(&self, group_idx: usize) -> Result<CudaView<'_, u32>> {
128        let start = *self
129            .group_offsets_host
130            .get(group_idx)
131            .ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
132            as usize;
133        let end = *self
134            .group_offsets_host
135            .get(group_idx + 1)
136            .ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
137            as usize;
138        if end < start || end > self.slot_cnf_var.len() {
139            return Err(XlogError::Compilation(
140                "Invalid group slot range in GpuWeightSlots".to_string(),
141            ));
142        }
143        Ok(self.slot_cnf_var.slice(start..end))
144    }
145}