// xlog_prob/exact_gpu.rs

//! GPU-only exact compilation helpers (no device-to-host reads in this module).
2
3use std::collections::{HashMap, HashSet};
4use std::sync::{Arc, Mutex};
5
6use cudarc::driver::DeviceSlice;
7use xlog_core::{MemoryBudget, Result, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};
10
11use crate::compilation::gpu_cache::{GpuCircuitCache, GpuCircuitCacheHandle};
12use crate::compilation::gpu_weights::{
13    apply_query_vars_device, build_evidence_by_var_gpu, build_weights_gpu, map_nodes_to_vars_gpu,
14    restore_query_vars_device, GpuWeights,
15};
16use crate::compilation::{
17    compile_gpu_d4_and_verify_cached, encode_cnf_gpu, DeviceRandomVarList, GpuPirGraph, GpuPirRoots,
18};
19use crate::exact::{
20    build_weight_sources, collect_random_vars_device, default_cache_config, default_compile_config,
21    upload_f64, upload_u32, upload_u8, GpuConfig,
22};
23use crate::provenance::{GroundAtom, Provenance};
24
25pub(crate) struct ExactGpuState {
26    provider: Option<Arc<CudaKernelProvider>>,
27    cache: Option<Mutex<GpuCircuitCache>>,
28    handle: Option<GpuCircuitCacheHandle>,
29    weights: Option<GpuWeights>,
30    max_var: u32,
31    query_vars_device: Option<TrackedCudaSlice<u32>>,
32    query_indices: Vec<usize>,
33    queries: Vec<GroundAtom>,
34}
35
36impl ExactGpuState {
37    fn empty(queries: Vec<GroundAtom>) -> Self {
38        Self {
39            provider: None,
40            cache: None,
41            handle: None,
42            weights: None,
43            max_var: 0,
44            query_vars_device: None,
45            query_indices: Vec::new(),
46            queries,
47        }
48    }
49
50    pub(crate) fn provider(&self) -> Option<&Arc<CudaKernelProvider>> {
51        self.provider.as_ref()
52    }
53
54    pub(crate) fn lock_cache(&self) -> Option<std::sync::MutexGuard<'_, GpuCircuitCache>> {
55        self.cache.as_ref().map(|cache| {
56            cache
57                .lock()
58                .unwrap_or_else(|poisoned| poisoned.into_inner())
59        })
60    }
61
62    pub(crate) fn handle(&self) -> Option<&GpuCircuitCacheHandle> {
63        self.handle.as_ref()
64    }
65
66    pub(crate) fn weights(&self) -> Option<&GpuWeights> {
67        self.weights.as_ref()
68    }
69
70    pub(crate) fn max_var(&self) -> u32 {
71        self.max_var
72    }
73
74    pub(crate) fn query_vars_device(&self) -> Option<&TrackedCudaSlice<u32>> {
75        self.query_vars_device.as_ref()
76    }
77
78    pub(crate) fn query_indices(&self) -> &[usize] {
79        &self.query_indices
80    }
81
82    pub fn queries(&self) -> &[GroundAtom] {
83        &self.queries
84    }
85
86    pub(crate) fn allocate_query_restore(&self) -> Result<Option<TrackedCudaSlice<f64>>> {
87        let Some(provider) = self.provider.as_ref() else {
88            return Ok(None);
89        };
90        let Some(query_vars) = self.query_vars_device.as_ref() else {
91            return Ok(None);
92        };
93        let buf = provider.memory().alloc::<f64>(query_vars.len())?;
94        Ok(Some(buf))
95    }
96
97    pub(crate) fn apply_query_vars(
98        &self,
99        cache: &mut GpuCircuitCache,
100        saved: &mut TrackedCudaSlice<f64>,
101    ) -> Result<()> {
102        let Some(provider) = self.provider.as_ref() else {
103            return Ok(());
104        };
105        let Some(query_vars) = self.query_vars_device.as_ref() else {
106            return Ok(());
107        };
108        let (_, log_false) = cache.var_log_weights_mut();
109        apply_query_vars_device(provider, query_vars, self.max_var, log_false, saved)
110    }
111
112    pub(crate) fn restore_query_vars(
113        &self,
114        cache: &mut GpuCircuitCache,
115        saved: &TrackedCudaSlice<f64>,
116    ) -> Result<()> {
117        let Some(provider) = self.provider.as_ref() else {
118            return Ok(());
119        };
120        let Some(query_vars) = self.query_vars_device.as_ref() else {
121            return Ok(());
122        };
123        let (_, log_false) = cache.var_log_weights_mut();
124        restore_query_vars_device(provider, query_vars, self.max_var, log_false, saved)
125    }
126}
127
128pub(crate) fn compile_provenance_gpu_only(
129    provenance: &Provenance,
130    config: GpuConfig,
131) -> Result<ExactGpuState> {
132    if config.memory_bytes == 0 {
133        return Err(XlogError::Kernel(
134            "GPU memory budget must be non-zero".to_string(),
135        ));
136    }
137
138    let mut roots_set: HashSet<crate::pir::PirNodeId> = HashSet::new();
139    let mut evidence_formulas: Vec<(crate::pir::PirNodeId, bool, GroundAtom)> = Vec::new();
140    let mut evidence_atoms: HashMap<GroundAtom, bool> = HashMap::new();
141    for (atom, value) in &provenance.evidence {
142        if let Some(prev) = evidence_atoms.insert(atom.clone(), *value) {
143            if prev != *value {
144                return Err(XlogError::Execution(format!(
145                    "Exact inference error: conflicting evidence for {}",
146                    display_atom(atom)
147                )));
148            }
149        }
150
151        let formula = provenance.query_formula(&atom.predicate, &atom.args);
152        match formula {
153            Some(id) => {
154                roots_set.insert(id);
155                evidence_formulas.push((id, *value, atom.clone()));
156            }
157            None => {
158                if *value {
159                    return Err(XlogError::Execution(format!(
160                        "Exact inference error: evidence atom is never derivable: {}",
161                        display_atom(atom)
162                    )));
163                }
164            }
165        }
166    }
167
168    let mut queries: Vec<GroundAtom> = Vec::new();
169    let mut query_nodes: Vec<(usize, crate::pir::PirNodeId)> = Vec::new();
170    for atom in &provenance.queries {
171        let formula = provenance.query_formula(&atom.predicate, &atom.args);
172        if let Some(id) = formula {
173            roots_set.insert(id);
174            query_nodes.push((queries.len(), id));
175        }
176        queries.push(atom.clone());
177    }
178
179    // Ensure ALL probabilistic variable nodes (Decision, Lit, NegLit) are reachable
180    // so they get CNF variables. Required for template/neural fast-path slot mapping.
181    for (idx, node) in provenance.pir.nodes().iter().enumerate() {
182        match node {
183            crate::pir::PirNode::Decision { .. }
184            | crate::pir::PirNode::Lit { .. }
185            | crate::pir::PirNode::NegLit { .. } => {
186                roots_set.insert(crate::pir::PirNodeId::from_u32(idx as u32));
187            }
188            _ => {}
189        }
190    }
191
192    let mut roots: Vec<crate::pir::PirNodeId> = roots_set.into_iter().collect();
193    roots.sort();
194
195    if roots.is_empty() {
196        return Ok(ExactGpuState::empty(queries));
197    }
198
199    let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
200    let memory = Arc::new(GpuMemoryManager::new(
201        device.clone(),
202        MemoryBudget::with_limit(config.memory_bytes),
203    ));
204    let provider = Arc::new(CudaKernelProvider::new(device, memory)?);
205
206    let canonical_cnf_hash = crate::cnf::canonical_pir_hash(&provenance.pir, &roots)?;
207    let gpu_pir = GpuPirGraph::from_host(&provenance.pir, &provider)?;
208    let gpu_roots = GpuPirRoots::from_host(&roots, &provider)?;
209    let encoding = encode_cnf_gpu(&gpu_pir, &gpu_roots, &provider)?;
210    if encoding.vars.max_var != encoding.cnf.var_cap {
211        return Err(XlogError::Compilation(format!(
212            "Exact inference error: CNF var_cap {} != vars.max_var {}",
213            encoding.cnf.var_cap, encoding.vars.max_var
214        )));
215    }
216
217    let (leaf_probs_host, choice_true_host, choice_false_host) = build_weight_sources(provenance)?;
218    let leaf_probs = upload_f64(&provider, &leaf_probs_host)?;
219    let choice_true = upload_f64(&provider, &choice_true_host)?;
220    let choice_false = upload_f64(&provider, &choice_false_host)?;
221
222    let evidence_by_var = if evidence_formulas.is_empty() {
223        let mut evidence = provider
224            .memory()
225            .alloc::<u8>((encoding.vars.max_var as usize) + 1)?;
226        provider
227            .device()
228            .inner()
229            .memset_zeros(&mut evidence)
230            .map_err(|e| XlogError::Kernel(format!("Failed to zero evidence buffer: {}", e)))?;
231        evidence
232    } else {
233        let mut nodes: Vec<u32> = Vec::with_capacity(evidence_formulas.len());
234        let mut vals: Vec<u8> = Vec::with_capacity(evidence_formulas.len());
235        for (node, value, _atom) in &evidence_formulas {
236            nodes.push(node.as_u32());
237            vals.push(if *value { 1u8 } else { 2u8 });
238        }
239        let evidence_nodes = upload_u32(&provider, &nodes)?;
240        let evidence_vals = upload_u8(&provider, &vals)?;
241        build_evidence_by_var_gpu(
242            &encoding.vars.node_var,
243            &evidence_nodes,
244            &evidence_vals,
245            encoding.vars.max_var,
246            &provider,
247        )?
248    };
249
250    let weights = build_weights_gpu(
251        &encoding.vars,
252        &leaf_probs,
253        &choice_true,
254        &choice_false,
255        &evidence_by_var,
256        &provider,
257    )?;
258
259    let random_var_count = leaf_probs_host
260        .len()
261        .checked_add(choice_true_host.len())
262        .ok_or_else(|| XlogError::Compilation("random var count overflow".to_string()))?;
263    let random_var_count = u32::try_from(random_var_count)
264        .map_err(|_| XlogError::Compilation("random var count exceeds u32".to_string()))?;
265    let num_leaf_probs = u32::try_from(leaf_probs_host.len())
266        .map_err(|_| XlogError::Compilation("leaf_probs count exceeds u32".to_string()))?;
267    let num_choice_probs = u32::try_from(choice_true_host.len())
268        .map_err(|_| XlogError::Compilation("choice_probs count exceeds u32".to_string()))?;
269    let (random_var_list, actual_random_var_count) = collect_random_vars_device(
270        &provider,
271        &encoding.vars,
272        num_leaf_probs,
273        num_choice_probs,
274        random_var_count,
275    )?;
276    let random_vars = DeviceRandomVarList::from_device(random_var_list, actual_random_var_count)?;
277    let compile_config = default_compile_config(&encoding.cnf, config.memory_bytes)?;
278    let cache_config = default_cache_config(&encoding.cnf, &compile_config)?;
279
280    let mut cache = GpuCircuitCache::new(&provider, cache_config)?;
281    let (handle, _compile_profile) = compile_gpu_d4_and_verify_cached(
282        &encoding.cnf,
283        &encoding.decision_var_limit,
284        &provider,
285        &compile_config,
286        &mut cache,
287        &random_vars,
288        Some(canonical_cnf_hash),
289    )?;
290    cache.store_weights(&handle, &weights.log_true, &weights.log_false)?;
291
292    let (query_indices, query_vars_device) = if query_nodes.is_empty() {
293        (Vec::new(), None)
294    } else {
295        let mut node_ids: Vec<u32> = Vec::with_capacity(query_nodes.len());
296        let mut indices: Vec<usize> = Vec::with_capacity(query_nodes.len());
297        for (idx, node) in &query_nodes {
298            indices.push(*idx);
299            node_ids.push(node.as_u32());
300        }
301        let node_ids_device = upload_u32(&provider, &node_ids)?;
302        let vars_device = map_nodes_to_vars_gpu(
303            &encoding.vars.node_var,
304            &node_ids_device,
305            encoding.vars.max_var,
306            &provider,
307        )?;
308        (indices, Some(vars_device))
309    };
310
311    Ok(ExactGpuState {
312        provider: Some(provider),
313        cache: Some(Mutex::new(cache)),
314        handle: Some(handle),
315        weights: Some(weights),
316        max_var: encoding.vars.max_var,
317        query_vars_device,
318        query_indices,
319        queries,
320    })
321}
322
323fn display_atom(atom: &GroundAtom) -> String {
324    if atom.args.is_empty() {
325        format!("{}()", atom.predicate)
326    } else {
327        format!("{}({} args)", atom.predicate, atom.args.len())
328    }
329}