1use std::collections::{HashMap, HashSet};
4use std::sync::{Arc, Mutex};
5
6use cudarc::driver::DeviceSlice;
7use xlog_core::{MemoryBudget, Result, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};
10
11use crate::compilation::gpu_cache::{GpuCircuitCache, GpuCircuitCacheHandle};
12use crate::compilation::gpu_weights::{
13 apply_query_vars_device, build_evidence_by_var_gpu, build_weights_gpu, map_nodes_to_vars_gpu,
14 restore_query_vars_device, GpuWeights,
15};
16use crate::compilation::{
17 compile_gpu_d4_and_verify_cached, encode_cnf_gpu, DeviceRandomVarList, GpuPirGraph, GpuPirRoots,
18};
19use crate::exact::{
20 build_weight_sources, collect_random_vars_device, default_cache_config, default_compile_config,
21 upload_f64, upload_u32, upload_u8, GpuConfig,
22};
23use crate::provenance::{GroundAtom, Provenance};
24
25pub(crate) struct ExactGpuState {
26 provider: Option<Arc<CudaKernelProvider>>,
27 cache: Option<Mutex<GpuCircuitCache>>,
28 handle: Option<GpuCircuitCacheHandle>,
29 weights: Option<GpuWeights>,
30 max_var: u32,
31 query_vars_device: Option<TrackedCudaSlice<u32>>,
32 query_indices: Vec<usize>,
33 queries: Vec<GroundAtom>,
34}
35
36impl ExactGpuState {
37 fn empty(queries: Vec<GroundAtom>) -> Self {
38 Self {
39 provider: None,
40 cache: None,
41 handle: None,
42 weights: None,
43 max_var: 0,
44 query_vars_device: None,
45 query_indices: Vec::new(),
46 queries,
47 }
48 }
49
50 pub(crate) fn provider(&self) -> Option<&Arc<CudaKernelProvider>> {
51 self.provider.as_ref()
52 }
53
54 pub(crate) fn lock_cache(&self) -> Option<std::sync::MutexGuard<'_, GpuCircuitCache>> {
55 self.cache.as_ref().map(|cache| {
56 cache
57 .lock()
58 .unwrap_or_else(|poisoned| poisoned.into_inner())
59 })
60 }
61
62 pub(crate) fn handle(&self) -> Option<&GpuCircuitCacheHandle> {
63 self.handle.as_ref()
64 }
65
66 pub(crate) fn weights(&self) -> Option<&GpuWeights> {
67 self.weights.as_ref()
68 }
69
70 pub(crate) fn max_var(&self) -> u32 {
71 self.max_var
72 }
73
74 pub(crate) fn query_vars_device(&self) -> Option<&TrackedCudaSlice<u32>> {
75 self.query_vars_device.as_ref()
76 }
77
78 pub(crate) fn query_indices(&self) -> &[usize] {
79 &self.query_indices
80 }
81
82 pub fn queries(&self) -> &[GroundAtom] {
83 &self.queries
84 }
85
86 pub(crate) fn allocate_query_restore(&self) -> Result<Option<TrackedCudaSlice<f64>>> {
87 let Some(provider) = self.provider.as_ref() else {
88 return Ok(None);
89 };
90 let Some(query_vars) = self.query_vars_device.as_ref() else {
91 return Ok(None);
92 };
93 let buf = provider.memory().alloc::<f64>(query_vars.len())?;
94 Ok(Some(buf))
95 }
96
97 pub(crate) fn apply_query_vars(
98 &self,
99 cache: &mut GpuCircuitCache,
100 saved: &mut TrackedCudaSlice<f64>,
101 ) -> Result<()> {
102 let Some(provider) = self.provider.as_ref() else {
103 return Ok(());
104 };
105 let Some(query_vars) = self.query_vars_device.as_ref() else {
106 return Ok(());
107 };
108 let (_, log_false) = cache.var_log_weights_mut();
109 apply_query_vars_device(provider, query_vars, self.max_var, log_false, saved)
110 }
111
112 pub(crate) fn restore_query_vars(
113 &self,
114 cache: &mut GpuCircuitCache,
115 saved: &TrackedCudaSlice<f64>,
116 ) -> Result<()> {
117 let Some(provider) = self.provider.as_ref() else {
118 return Ok(());
119 };
120 let Some(query_vars) = self.query_vars_device.as_ref() else {
121 return Ok(());
122 };
123 let (_, log_false) = cache.var_log_weights_mut();
124 restore_query_vars_device(provider, query_vars, self.max_var, log_false, saved)
125 }
126}
127
128pub(crate) fn compile_provenance_gpu_only(
129 provenance: &Provenance,
130 config: GpuConfig,
131) -> Result<ExactGpuState> {
132 if config.memory_bytes == 0 {
133 return Err(XlogError::Kernel(
134 "GPU memory budget must be non-zero".to_string(),
135 ));
136 }
137
138 let mut roots_set: HashSet<crate::pir::PirNodeId> = HashSet::new();
139 let mut evidence_formulas: Vec<(crate::pir::PirNodeId, bool, GroundAtom)> = Vec::new();
140 let mut evidence_atoms: HashMap<GroundAtom, bool> = HashMap::new();
141 for (atom, value) in &provenance.evidence {
142 if let Some(prev) = evidence_atoms.insert(atom.clone(), *value) {
143 if prev != *value {
144 return Err(XlogError::Execution(format!(
145 "Exact inference error: conflicting evidence for {}",
146 display_atom(atom)
147 )));
148 }
149 }
150
151 let formula = provenance.query_formula(&atom.predicate, &atom.args);
152 match formula {
153 Some(id) => {
154 roots_set.insert(id);
155 evidence_formulas.push((id, *value, atom.clone()));
156 }
157 None => {
158 if *value {
159 return Err(XlogError::Execution(format!(
160 "Exact inference error: evidence atom is never derivable: {}",
161 display_atom(atom)
162 )));
163 }
164 }
165 }
166 }
167
168 let mut queries: Vec<GroundAtom> = Vec::new();
169 let mut query_nodes: Vec<(usize, crate::pir::PirNodeId)> = Vec::new();
170 for atom in &provenance.queries {
171 let formula = provenance.query_formula(&atom.predicate, &atom.args);
172 if let Some(id) = formula {
173 roots_set.insert(id);
174 query_nodes.push((queries.len(), id));
175 }
176 queries.push(atom.clone());
177 }
178
179 for (idx, node) in provenance.pir.nodes().iter().enumerate() {
182 match node {
183 crate::pir::PirNode::Decision { .. }
184 | crate::pir::PirNode::Lit { .. }
185 | crate::pir::PirNode::NegLit { .. } => {
186 roots_set.insert(crate::pir::PirNodeId::from_u32(idx as u32));
187 }
188 _ => {}
189 }
190 }
191
192 let mut roots: Vec<crate::pir::PirNodeId> = roots_set.into_iter().collect();
193 roots.sort();
194
195 if roots.is_empty() {
196 return Ok(ExactGpuState::empty(queries));
197 }
198
199 let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
200 let memory = Arc::new(GpuMemoryManager::new(
201 device.clone(),
202 MemoryBudget::with_limit(config.memory_bytes),
203 ));
204 let provider = Arc::new(CudaKernelProvider::new(device, memory)?);
205
206 let canonical_cnf_hash = crate::cnf::canonical_pir_hash(&provenance.pir, &roots)?;
207 let gpu_pir = GpuPirGraph::from_host(&provenance.pir, &provider)?;
208 let gpu_roots = GpuPirRoots::from_host(&roots, &provider)?;
209 let encoding = encode_cnf_gpu(&gpu_pir, &gpu_roots, &provider)?;
210 if encoding.vars.max_var != encoding.cnf.var_cap {
211 return Err(XlogError::Compilation(format!(
212 "Exact inference error: CNF var_cap {} != vars.max_var {}",
213 encoding.cnf.var_cap, encoding.vars.max_var
214 )));
215 }
216
217 let (leaf_probs_host, choice_true_host, choice_false_host) = build_weight_sources(provenance)?;
218 let leaf_probs = upload_f64(&provider, &leaf_probs_host)?;
219 let choice_true = upload_f64(&provider, &choice_true_host)?;
220 let choice_false = upload_f64(&provider, &choice_false_host)?;
221
222 let evidence_by_var = if evidence_formulas.is_empty() {
223 let mut evidence = provider
224 .memory()
225 .alloc::<u8>((encoding.vars.max_var as usize) + 1)?;
226 provider
227 .device()
228 .inner()
229 .memset_zeros(&mut evidence)
230 .map_err(|e| XlogError::Kernel(format!("Failed to zero evidence buffer: {}", e)))?;
231 evidence
232 } else {
233 let mut nodes: Vec<u32> = Vec::with_capacity(evidence_formulas.len());
234 let mut vals: Vec<u8> = Vec::with_capacity(evidence_formulas.len());
235 for (node, value, _atom) in &evidence_formulas {
236 nodes.push(node.as_u32());
237 vals.push(if *value { 1u8 } else { 2u8 });
238 }
239 let evidence_nodes = upload_u32(&provider, &nodes)?;
240 let evidence_vals = upload_u8(&provider, &vals)?;
241 build_evidence_by_var_gpu(
242 &encoding.vars.node_var,
243 &evidence_nodes,
244 &evidence_vals,
245 encoding.vars.max_var,
246 &provider,
247 )?
248 };
249
250 let weights = build_weights_gpu(
251 &encoding.vars,
252 &leaf_probs,
253 &choice_true,
254 &choice_false,
255 &evidence_by_var,
256 &provider,
257 )?;
258
259 let random_var_count = leaf_probs_host
260 .len()
261 .checked_add(choice_true_host.len())
262 .ok_or_else(|| XlogError::Compilation("random var count overflow".to_string()))?;
263 let random_var_count = u32::try_from(random_var_count)
264 .map_err(|_| XlogError::Compilation("random var count exceeds u32".to_string()))?;
265 let num_leaf_probs = u32::try_from(leaf_probs_host.len())
266 .map_err(|_| XlogError::Compilation("leaf_probs count exceeds u32".to_string()))?;
267 let num_choice_probs = u32::try_from(choice_true_host.len())
268 .map_err(|_| XlogError::Compilation("choice_probs count exceeds u32".to_string()))?;
269 let (random_var_list, actual_random_var_count) = collect_random_vars_device(
270 &provider,
271 &encoding.vars,
272 num_leaf_probs,
273 num_choice_probs,
274 random_var_count,
275 )?;
276 let random_vars = DeviceRandomVarList::from_device(random_var_list, actual_random_var_count)?;
277 let compile_config = default_compile_config(&encoding.cnf, config.memory_bytes)?;
278 let cache_config = default_cache_config(&encoding.cnf, &compile_config)?;
279
280 let mut cache = GpuCircuitCache::new(&provider, cache_config)?;
281 let (handle, _compile_profile) = compile_gpu_d4_and_verify_cached(
282 &encoding.cnf,
283 &encoding.decision_var_limit,
284 &provider,
285 &compile_config,
286 &mut cache,
287 &random_vars,
288 Some(canonical_cnf_hash),
289 )?;
290 cache.store_weights(&handle, &weights.log_true, &weights.log_false)?;
291
292 let (query_indices, query_vars_device) = if query_nodes.is_empty() {
293 (Vec::new(), None)
294 } else {
295 let mut node_ids: Vec<u32> = Vec::with_capacity(query_nodes.len());
296 let mut indices: Vec<usize> = Vec::with_capacity(query_nodes.len());
297 for (idx, node) in &query_nodes {
298 indices.push(*idx);
299 node_ids.push(node.as_u32());
300 }
301 let node_ids_device = upload_u32(&provider, &node_ids)?;
302 let vars_device = map_nodes_to_vars_gpu(
303 &encoding.vars.node_var,
304 &node_ids_device,
305 encoding.vars.max_var,
306 &provider,
307 )?;
308 (indices, Some(vars_device))
309 };
310
311 Ok(ExactGpuState {
312 provider: Some(provider),
313 cache: Some(Mutex::new(cache)),
314 handle: Some(handle),
315 weights: Some(weights),
316 max_var: encoding.vars.max_var,
317 query_vars_device,
318 query_indices,
319 queries,
320 })
321}
322
323fn display_atom(atom: &GroundAtom) -> String {
324 if atom.args.is_empty() {
325 format!("{}()", atom.predicate)
326 } else {
327 format!("{}({} args)", atom.predicate, atom.args.len())
328 }
329}