1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Instant;
4
5use pyo3::exceptions::{PyRuntimeError, PyValueError};
6use pyo3::prelude::*;
7use pyo3::types::{PyDict, PyList, PySequence};
8
9use xlog_cuda::DlpackManagedTensor;
10use xlog_gpu::logic as gpu_logic;
11use xlog_logic::ast::ProbEngine;
12use xlog_neural::{NetworkRegistry, TensorSourceRegistry};
13use xlog_prob::exact::{ExactDdnnfProgram, GpuConfig};
14use xlog_prob::mc::McProgram;
15use xlog_runtime::RelationDelta;
16
17use std::collections::HashMap as StdHashMap;
18
19use super::neural_registry::NeuralPredicateRegistry;
20use super::{
21 dlpack_capsule_from_tensor, dlpack_from_py, enforce_call_memory_limit,
22 parse_prob_engine_override, provider_from_config, provider_memory_stats, types,
23 CompiledLogicProgram, CompiledProbProgram, CompiledProgram, LogicDeltaStats, LogicEvalResult,
24 LogicProgram, LogicQueryResult, LogicRelationSession, Program, RelationChangeCallback,
25};
26
27#[pymethods]
28impl Program {
29 #[staticmethod]
30 #[pyo3(signature = (source, device=0, memory_mb=32768, prob_engine=None))]
31 pub fn compile(
32 source: &str,
33 device: usize,
34 memory_mb: u64,
35 prob_engine: Option<String>,
36 ) -> PyResult<CompiledProgram> {
37 if memory_mb == 0 {
38 return Err(PyValueError::new_err("memory_mb must be > 0"));
39 }
40
41 let mut config = GpuConfig::default();
42 config.device_ordinal = device;
43 config.memory_bytes = memory_mb * 1024 * 1024;
44
45 let ast = xlog_logic::parse_program(source).map_err(types::xlog_err)?;
47
48 let declared_networks: HashSet<String> = ast
50 .neural_predicates
51 .iter()
52 .map(|np| np.network.clone())
53 .collect();
54 let mut declared_network_forms: HashMap<String, bool> = HashMap::new();
56 for np in &ast.neural_predicates {
57 let is_embedding = np.labels.is_none();
58 match declared_network_forms.get(&np.network) {
59 Some(&existing_form) if existing_form != is_embedding => {
60 return Err(PyValueError::new_err(format!(
61 "network '{}' is declared as both classification and embedding; \
62 each network name must have a single form",
63 np.network
64 )));
65 }
66 _ => {
67 declared_network_forms.insert(np.network.clone(), is_embedding);
68 }
69 }
70 }
71
72 let neural_registry = NeuralPredicateRegistry::from_ast(&ast).map_err(types::val_err)?;
73
74 let engine = match prob_engine {
75 Some(s) => parse_prob_engine_override(&s)?,
76 None => ast.prob_engine(),
77 };
78
79 let program = match engine {
80 ProbEngine::ExactDdnnf => CompiledProbProgram::Exact(
81 ExactDdnnfProgram::compile_source_with_gpu(source, config)
82 .map_err(types::xlog_err)?,
83 ),
84 ProbEngine::Mc => CompiledProbProgram::Mc(
85 McProgram::compile_source_with_gpu(source, config).map_err(types::xlog_err)?,
86 ),
87 };
88 let provider = provider_from_config(config).map_err(types::xlog_err)?;
89
90 Ok(CompiledProgram {
91 program,
92 output_provider: Arc::new(provider),
93 network_registry: NetworkRegistry::new(),
94 neural_registry,
95 declared_networks,
96 declared_network_forms,
97 tensor_sources: TensorSourceRegistry::new(),
98 domain_source: None,
99 _source: source.to_string(),
100 ast,
101 _gpu_config: config,
102 _prob_engine: engine,
103 query_signature_cache: StdHashMap::new(),
104 circuit_cache: StdHashMap::new(),
105 circuit_cache_hits: 0,
106 circuit_cache_misses: 0,
107 template_compile_count: 0,
108 batch_queries: true,
109 last_compile_profile: None,
110 })
111 }
112}
113
114#[pymethods]
115impl LogicProgram {
116 #[staticmethod]
117 #[pyo3(signature = (source, device=0, memory_mb=32768))]
118 pub fn compile(source: &str, device: usize, memory_mb: u64) -> PyResult<CompiledLogicProgram> {
119 if memory_mb == 0 {
120 return Err(PyValueError::new_err("memory_mb must be > 0"));
121 }
122
123 let mut config = GpuConfig::default();
124 config.device_ordinal = device;
125 config.memory_bytes = memory_mb * 1024 * 1024;
126
127 let program = gpu_logic::LogicProgram::compile(source).map_err(types::xlog_err)?;
128 let provider = provider_from_config(config).map_err(types::xlog_err)?;
129
130 Ok(CompiledLogicProgram {
131 program,
132 provider: Arc::new(provider),
133 })
134 }
135}
136
137#[pymethods]
138impl CompiledLogicProgram {
139 #[pyo3(signature = (dlpack_inputs=None, memory_mb=None))]
140 pub fn evaluate(
141 &self,
142 py: Python<'_>,
143 dlpack_inputs: Option<&Bound<'_, PyDict>>,
144 memory_mb: Option<u64>,
145 ) -> PyResult<LogicEvalResult> {
146 enforce_call_memory_limit(&self.provider, memory_mb)?;
147 let mut inputs: HashMap<String, xlog_cuda::CudaBuffer> = HashMap::new();
148
149 if let Some(dict) = dlpack_inputs {
150 for (k, v) in dict.iter() {
151 let name: String = k.extract()?;
152 let schema = self.program.schema(&name).ok_or_else(|| {
153 PyValueError::new_err(format!(
154 "Unknown input relation {} (not present in compiled schemas)",
155 name
156 ))
157 })?;
158
159 let tensors = collect_dlpack_columns(
160 &v,
161 &format!(
162 "Input relation {} must be a sequence of DLPack columns",
163 name
164 ),
165 )?;
166
167 let buffer = self
168 .provider
169 .from_dlpack_tensors_with_schema(schema.clone(), tensors)
170 .map_err(types::xlog_err)?;
171
172 inputs.insert(name, buffer);
173 }
174 }
175
176 let result = self
177 .program
178 .evaluate(self.provider.clone(), inputs)
179 .map_err(types::xlog_err)?;
180 pack_logic_result_with_provider(py, &self.provider, result)
181 }
182
183 pub fn session(&self) -> PyResult<LogicRelationSession> {
184 let relation_store = self
185 .program
186 .create_relation_store(self.provider.clone())
187 .map_err(types::xlog_err)?;
188 Ok(LogicRelationSession {
189 program: self.program.clone(),
190 provider: self.provider.clone(),
191 relation_store,
192 evaluation_store: None,
193 session_runtime: None,
194 last_delta_stats: None,
195 relation_callbacks: Vec::new(),
196 next_relation_callback_id: 1,
197 relation_generations: HashMap::new(),
198 })
199 }
200
201 pub fn memory_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
203 provider_memory_stats(py, &self.provider)
204 }
205
206 pub fn rule_provenance(&self, py: Python<'_>) -> PyResult<PyObject> {
207 pack_rule_provenance(py, &self.program.rule_provenance())
208 }
209
210 pub fn proof_traces(&self, py: Python<'_>) -> PyResult<PyObject> {
211 pack_proof_traces(py, &self.program.proof_traces())
212 }
213}
214
215impl CompiledLogicProgram {}
216
217#[pymethods]
218impl LogicRelationSession {
219 pub fn put_relation(
220 &mut self,
221 name: String,
222 dlpack_columns: &Bound<'_, PyAny>,
223 ) -> PyResult<()> {
224 if name.starts_with("__") {
225 return Err(PyValueError::new_err(format!(
226 "Relation {} is internal and cannot be stored in a persistent session",
227 name
228 )));
229 }
230 let schema = self.program.schema(&name).ok_or_else(|| {
231 PyValueError::new_err(format!(
232 "Unknown relation {} (not present in compiled schemas)",
233 name
234 ))
235 })?;
236 let tensors = collect_dlpack_columns(
237 dlpack_columns,
238 &format!("Relation {} must be a sequence of DLPack columns", name),
239 )?;
240 let buffer = self
241 .provider
242 .from_dlpack_tensors_with_schema(schema.clone(), tensors)
243 .map_err(types::xlog_err)?;
244 self.relation_store.put(&name, buffer);
245 self.evaluation_store = None;
246 self.session_runtime = None;
247 self.last_delta_stats = None;
248 Ok(())
249 }
250
251 #[pyo3(signature = (memory_mb=None))]
252 pub fn evaluate(
253 &mut self,
254 py: Python<'_>,
255 memory_mb: Option<u64>,
256 ) -> PyResult<LogicEvalResult> {
257 enforce_call_memory_limit(&self.provider, memory_mb)?;
258 let result = if let Some(store) = &self.evaluation_store {
259 self.program
260 .evaluate_cached_relation_store(self.provider.clone(), store)
261 .map_err(types::xlog_err)?
262 } else {
263 if self.session_runtime.is_none() {
264 self.session_runtime = Some(
265 self.program
266 .create_session_runtime(self.provider.clone(), &self.relation_store, false)
267 .map_err(types::xlog_err)?,
268 );
269 }
270 let runtime = self.session_runtime.as_mut().ok_or_else(|| {
271 PyRuntimeError::new_err("session runtime unavailable during evaluation")
272 })?;
273 let (result, store) = self
274 .program
275 .evaluate_with_session_runtime(self.provider.clone(), runtime)
276 .map_err(types::xlog_err)?;
277 self.evaluation_store = Some(store);
278 result
279 };
280 pack_logic_result_with_provider(py, &self.provider, result)
281 }
282
283 pub fn insert_relation(
284 &mut self,
285 py: Python<'_>,
286 name: String,
287 dlpack_columns: &Bound<'_, PyAny>,
288 ) -> PyResult<PyObject> {
289 let insert = self.relation_delta_buffer(&name, dlpack_columns)?;
290 self.apply_single_relation_delta(py, name, Some(insert), None)
291 }
292
293 pub fn delete_relation(
294 &mut self,
295 py: Python<'_>,
296 name: String,
297 dlpack_columns: &Bound<'_, PyAny>,
298 ) -> PyResult<PyObject> {
299 let delete = self.relation_delta_buffer(&name, dlpack_columns)?;
300 self.apply_single_relation_delta(py, name, None, Some(delete))
301 }
302
303 #[pyo3(signature = (name, insert_columns=None, delete_columns=None))]
304 pub fn apply_relation_delta(
305 &mut self,
306 py: Python<'_>,
307 name: String,
308 insert_columns: Option<&Bound<'_, PyAny>>,
309 delete_columns: Option<&Bound<'_, PyAny>>,
310 ) -> PyResult<PyObject> {
311 if insert_columns.is_none() && delete_columns.is_none() {
312 return Err(PyValueError::new_err(
313 "apply_relation_delta requires insert_columns, delete_columns, or both",
314 ));
315 }
316 let insert = insert_columns
317 .map(|columns| self.relation_delta_buffer(&name, columns))
318 .transpose()?;
319 let delete = delete_columns
320 .map(|columns| self.relation_delta_buffer(&name, columns))
321 .transpose()?;
322 self.apply_single_relation_delta(py, name, insert, delete)
323 }
324
325 pub fn apply_relation_delta_batch(
326 &mut self,
327 py: Python<'_>,
328 updates: &Bound<'_, PyAny>,
329 ) -> PyResult<PyObject> {
330 let (batch, relation_names) =
331 self.parse_relation_delta_batch("apply_relation_delta_batch", updates)?;
332
333 let report = self
334 .program
335 .apply_relation_delta_batch_with_session_runtime(
336 self.provider.clone(),
337 &mut self.relation_store,
338 &mut self.evaluation_store,
339 &mut self.session_runtime,
340 batch,
341 )
342 .map_err(types::xlog_err)?;
343 let stats = logic_delta_stats_from_report(report);
344 self.last_delta_stats = Some(stats.clone());
345 self.fire_relation_callbacks(py, &relation_names, &stats)?;
346 pack_delta_stats(py, &stats)
347 }
348
349 #[pyo3(signature = (updates, check_equivalence=false))]
350 pub fn apply_relation_delta_debug(
351 &mut self,
352 py: Python<'_>,
353 updates: &Bound<'_, PyAny>,
354 check_equivalence: bool,
355 ) -> PyResult<PyObject> {
356 let (batch, relation_names) =
357 self.parse_relation_delta_batch("apply_relation_delta_debug", updates)?;
358 let delta_start = Instant::now();
359 let report = self
360 .program
361 .apply_relation_delta_batch_with_session_runtime(
362 self.provider.clone(),
363 &mut self.relation_store,
364 &mut self.evaluation_store,
365 &mut self.session_runtime,
366 batch,
367 )
368 .map_err(types::xlog_err)?;
369 let delta_micros = delta_start.elapsed().as_micros().max(1) as u64;
370 let mut stats = logic_delta_stats_from_report(report);
371 if check_equivalence {
372 let full_start = Instant::now();
373 let (_, full_store) = self
374 .program
375 .evaluate_with_relation_store_and_cache(
376 self.provider.clone(),
377 &self.relation_store,
378 false,
379 )
380 .map_err(types::xlog_err)?;
381 let full_micros = full_start.elapsed().as_micros() as u64;
382 let cached_store = self.evaluation_store.as_ref().ok_or_else(|| {
383 PyRuntimeError::new_err("delta debug missing cached store after delta application")
384 })?;
385 stats.equivalent_to_full_recompute = Some(
386 self.program
387 .relation_stores_query_equivalent(
388 self.provider.as_ref(),
389 &full_store,
390 cached_store,
391 )
392 .map_err(types::xlog_err)?,
393 );
394 let speedup = full_micros as f64 / delta_micros as f64;
395 stats.planner_telemetry.measured_delta_speedup = Some(speedup);
396 if speedup >= 1.0 {
397 stats
398 .planner_telemetry
399 .planner_advice
400 .push(format!("delta path is faster by {speedup:.2}x"));
401 } else {
402 stats.planner_telemetry.planner_advice.push(format!(
403 "full recompute may be faster; delta measured {speedup:.2}x"
404 ));
405 }
406 }
407 self.last_delta_stats = Some(stats.clone());
408 self.fire_relation_callbacks(py, &relation_names, &stats)?;
409 pack_delta_stats(py, &stats)
410 }
411
412 pub fn delta_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
413 match &self.last_delta_stats {
414 Some(stats) => pack_delta_stats(py, stats),
415 None => {
416 let dict = PyDict::new(py);
417 dict.set_item("status", "unavailable")?;
418 dict.set_item("reason", "no relation delta has been applied")?;
419 Ok(dict.into())
420 }
421 }
422 }
423
424 pub fn rule_provenance(&self, py: Python<'_>) -> PyResult<PyObject> {
425 pack_rule_provenance(py, &self.program.rule_provenance())
426 }
427
428 pub fn proof_traces(&self, py: Python<'_>) -> PyResult<PyObject> {
429 pack_proof_traces(py, &self.program.proof_traces())
430 }
431
432 pub fn register_relation_callback(
433 &mut self,
434 py: Python<'_>,
435 callback: PyObject,
436 ) -> PyResult<u64> {
437 if !callback.bind(py).is_callable() {
438 return Err(PyValueError::new_err(
439 "register_relation_callback expects a callable",
440 ));
441 }
442 let id = self.next_relation_callback_id;
443 self.next_relation_callback_id = self.next_relation_callback_id.saturating_add(1);
444 self.relation_callbacks
445 .push(RelationChangeCallback { id, callback });
446 Ok(id)
447 }
448
449 pub fn unregister_relation_callback(&mut self, callback_id: u64) -> bool {
450 let before = self.relation_callbacks.len();
451 self.relation_callbacks
452 .retain(|registered| registered.id != callback_id);
453 before != self.relation_callbacks.len()
454 }
455
456 pub fn cuda_graph_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
457 let dict = PyDict::new(py);
458 dict.set_item(
459 "csm_cuda_graph_captures",
460 self.provider.csm_cuda_graph_captures(),
461 )?;
462 dict.set_item(
463 "csm_cuda_graph_launches",
464 self.provider.csm_cuda_graph_launches(),
465 )?;
466 dict.set_item(
467 "csm_cuda_graph_fallbacks",
468 self.provider.csm_cuda_graph_fallbacks(),
469 )?;
470 dict.set_item(
471 "csm_cuda_graph_cache_hits",
472 self.provider.csm_cuda_graph_cache_hits(),
473 )?;
474 Ok(dict.into())
475 }
476
477 pub fn host_transfer_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
478 let stats = self.provider.host_transfer_stats();
479 let dict = PyDict::new(py);
480 dict.set_item("dtoh_bytes", stats.dtoh_bytes)?;
481 dict.set_item("htod_bytes", stats.htod_bytes)?;
482 dict.set_item("dtoh_calls", stats.dtoh_calls)?;
483 dict.set_item("htod_calls", stats.htod_calls)?;
484 Ok(dict.into())
485 }
486
487 pub fn join_index_cache_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
488 let dict = PyDict::new(py);
489 let stats = self
490 .session_runtime
491 .as_ref()
492 .map(|runtime| runtime.join_index_cache_stats())
493 .unwrap_or_default();
494 dict.set_item("lookups", stats.lookups)?;
495 dict.set_item("hits", stats.hits)?;
496 dict.set_item("misses", stats.misses)?;
497 dict.set_item("builds", stats.builds)?;
498 dict.set_item("evictions", stats.evictions)?;
499 dict.set_item("invalidations", stats.invalidations)?;
500 dict.set_item("stale_rejections", stats.stale_rejections)?;
501 dict.set_item("background_build_requests", stats.background_build_requests)?;
502 dict.set_item(
503 "background_builds_completed",
504 stats.background_builds_completed,
505 )?;
506 dict.set_item(
507 "background_builds_deferred",
508 stats.background_builds_deferred,
509 )?;
510 dict.set_item("entries", stats.entries)?;
511 dict.set_item("total_bytes", stats.total_bytes)?;
512 Ok(dict.into())
513 }
514
515 pub fn wcoj_dispatch_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
519 let dict = PyDict::new(py);
520 let stats = self
521 .session_runtime
522 .as_ref()
523 .map(|runtime| runtime.wcoj_dispatch_stats())
524 .unwrap_or_default();
525 dict.set_item("free_join_dispatch_count", stats.free_join_dispatch_count)?;
526 dict.set_item(
527 "factorized_delta_dispatch_count",
528 stats.factorized_delta_dispatch_count,
529 )?;
530 dict.set_item(
531 "wcoj_groupby_fusion_dispatch_count",
532 stats.wcoj_groupby_fusion_dispatch_count,
533 )?;
534 dict.set_item("wcoj_error_decline_count", stats.wcoj_error_decline_count)?;
535 Ok(dict.into())
536 }
537
538 pub fn reset_host_transfer_stats(&self) {
539 self.provider.reset_host_transfer_stats()
540 }
541
542 pub fn memory_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
544 provider_memory_stats(py, &self.provider)
545 }
546
547 pub fn export_relation(&mut self, py: Python<'_>, name: &str) -> PyResult<Vec<PyObject>> {
548 let existing = self.relation_store.get(name).ok_or_else(|| {
549 PyValueError::new_err(format!(
550 "Relation '{}' not found in persistent session",
551 name
552 ))
553 })?;
554 let replacement = self
555 .provider
556 .clone_buffer(existing)
557 .map_err(types::xlog_err)?;
558 let buffer = self.relation_store.remove(name).ok_or_else(|| {
559 PyRuntimeError::new_err(format!("Relation '{}' disappeared during export", name))
560 })?;
561 self.relation_store.put(name, replacement);
562 export_buffer_columns(py, &self.provider, buffer)
563 }
564
565 pub fn remove_relation(&mut self, name: &str) -> bool {
566 let removed = self.relation_store.remove(name).is_some();
567 if removed {
568 self.evaluation_store = None;
569 self.session_runtime = None;
570 self.last_delta_stats = None;
571 }
572 removed
573 }
574
575 pub fn clear_relations(&mut self) {
576 self.relation_store.clear();
577 self.evaluation_store = None;
578 self.session_runtime = None;
579 self.last_delta_stats = None;
580 }
581}
582
583impl LogicRelationSession {
584 fn parse_relation_delta_batch(
585 &self,
586 method_name: &str,
587 updates: &Bound<'_, PyAny>,
588 ) -> PyResult<(Vec<(String, RelationDelta)>, Vec<String>)> {
589 let seq = updates.downcast::<PySequence>().map_err(|_| {
590 PyValueError::new_err(format!(
591 "{method_name} expects a sequence of update dictionaries"
592 ))
593 })?;
594 let mut batch: Vec<(String, RelationDelta)> = Vec::with_capacity(seq.len()? as usize);
595 let mut relation_names: Vec<String> = Vec::new();
596 for item in seq.try_iter()? {
597 let item = item?;
598 let dict = item.downcast::<PyDict>().map_err(|_| {
599 PyValueError::new_err(format!("{method_name} updates must be dictionaries"))
600 })?;
601 let name_obj = dict.get_item("name")?.ok_or_else(|| {
602 PyValueError::new_err(format!("{method_name} update missing 'name'"))
603 })?;
604 let name: String = name_obj.extract()?;
605 let insert = optional_delta_columns(dict, "insert_columns")
606 .map(|columns| self.relation_delta_buffer(&name, &columns))
607 .transpose()?;
608 let delete = optional_delta_columns(dict, "delete_columns")
609 .map(|columns| self.relation_delta_buffer(&name, &columns))
610 .transpose()?;
611 if insert.is_none() && delete.is_none() {
612 return Err(PyValueError::new_err(format!(
613 "{method_name} updates require insert_columns, delete_columns, or both"
614 )));
615 }
616 relation_names.push(name.clone());
617 batch.push((name, RelationDelta::new(insert, delete)));
618 }
619 Ok((batch, relation_names))
620 }
621
622 fn relation_delta_buffer(
623 &self,
624 name: &str,
625 dlpack_columns: &Bound<'_, PyAny>,
626 ) -> PyResult<xlog_cuda::CudaBuffer> {
627 if name.starts_with("__") {
628 return Err(PyValueError::new_err(format!(
629 "Relation {} is internal and cannot be updated in a persistent session",
630 name
631 )));
632 }
633 let schema = self.program.schema(name).ok_or_else(|| {
634 PyValueError::new_err(format!(
635 "Unknown relation {} (not present in compiled schemas)",
636 name
637 ))
638 })?;
639 let tensors = collect_dlpack_columns(
640 dlpack_columns,
641 &format!(
642 "Relation {} delta must be a sequence of DLPack columns",
643 name
644 ),
645 )?;
646 self.provider
647 .from_dlpack_tensors_with_schema(schema.clone(), tensors)
648 .map_err(types::xlog_err)
649 }
650
651 fn apply_single_relation_delta(
652 &mut self,
653 py: Python<'_>,
654 name: String,
655 insert: Option<xlog_cuda::CudaBuffer>,
656 delete: Option<xlog_cuda::CudaBuffer>,
657 ) -> PyResult<PyObject> {
658 let relation_names = vec![name.clone()];
659 let mut deltas = HashMap::new();
660 deltas.insert(name, RelationDelta::new(insert, delete));
661 let report = self
662 .program
663 .apply_relation_deltas_with_session_runtime(
664 self.provider.clone(),
665 &mut self.relation_store,
666 &mut self.evaluation_store,
667 &mut self.session_runtime,
668 deltas,
669 )
670 .map_err(types::xlog_err)?;
671 let stats = LogicDeltaStats {
672 input_delta_count: report.input_delta_count,
673 changed_relations: report.changed_relations,
674 changed_relation_names: report.changed_relation_names,
675 insert_rows: report.insert_rows,
676 delete_rows: report.delete_rows,
677 has_deletes: report.has_deletes,
678 affected_sccs: report.affected_sccs,
679 recomputed_sccs: report.recomputed_sccs,
680 incremental_sccs: report.incremental_sccs,
681 coalesced_insert_rows: report.coalesced_insert_rows,
682 coalesced_delete_rows: report.coalesced_delete_rows,
683 canceled_rows: report.canceled_rows,
684 equivalent_to_full_recompute: None,
685 planner_telemetry: report.planner_telemetry,
686 debug_trace: report.debug_trace,
687 };
688 self.last_delta_stats = Some(stats.clone());
689 self.fire_relation_callbacks(py, &relation_names, &stats)?;
690 pack_delta_stats(py, &stats)
691 }
692
693 fn fire_relation_callbacks(
694 &mut self,
695 py: Python<'_>,
696 relation_names: &[String],
697 stats: &LogicDeltaStats,
698 ) -> PyResult<()> {
699 if self.relation_callbacks.is_empty() || stats.changed_relations == 0 {
700 return Ok(());
701 }
702
703 let mut seen = HashSet::new();
704 let mut events: Vec<(String, u64)> = Vec::new();
705 for relation in relation_names {
706 if seen.insert(relation.clone()) {
707 let generation = self
708 .relation_generations
709 .entry(relation.clone())
710 .and_modify(|current| *current = current.saturating_add(1))
711 .or_insert(1);
712 events.push((relation.clone(), *generation));
713 }
714 }
715
716 for (relation, generation) in events {
717 let payload = relation_callback_payload(py, &relation, generation, stats)?;
718 for registered in &self.relation_callbacks {
719 registered.callback.call1(py, (payload.clone_ref(py),))?;
720 }
721 }
722
723 Ok(())
724 }
725}
726
727fn relation_callback_payload(
728 py: Python<'_>,
729 relation: &str,
730 generation: u64,
731 stats: &LogicDeltaStats,
732) -> PyResult<PyObject> {
733 let dict = PyDict::new(py);
734 dict.set_item("relation", relation)?;
735 dict.set_item("generation", generation)?;
736 dict.set_item("input_delta_count", stats.input_delta_count)?;
737 dict.set_item(
738 "changed_relation_names",
739 stats.changed_relation_names.clone(),
740 )?;
741 dict.set_item("insert_rows", stats.insert_rows)?;
742 dict.set_item("delete_rows", stats.delete_rows)?;
743 dict.set_item("has_deletes", stats.has_deletes)?;
744 dict.set_item("coalesced_insert_rows", stats.coalesced_insert_rows)?;
745 dict.set_item("coalesced_delete_rows", stats.coalesced_delete_rows)?;
746 dict.set_item("canceled_rows", stats.canceled_rows)?;
747 dict.set_item("affected_sccs", stats.affected_sccs)?;
748 dict.set_item("recomputed_sccs", stats.recomputed_sccs)?;
749 dict.set_item("incremental_sccs", stats.incremental_sccs)?;
750 dict.set_item("debug_trace", stats.debug_trace.clone())?;
751 dict.set_item("telemetry", pack_delta_stats(py, stats)?)?;
752 Ok(dict.into())
753}
754
755fn optional_delta_columns<'py>(dict: &Bound<'py, PyDict>, key: &str) -> Option<Bound<'py, PyAny>> {
756 match dict.get_item(key) {
757 Ok(Some(value)) if !value.is_none() => Some(value),
758 _ => None,
759 }
760}
761
762fn logic_delta_stats_from_report(report: gpu_logic::LogicDeltaReport) -> LogicDeltaStats {
763 LogicDeltaStats {
764 input_delta_count: report.input_delta_count,
765 changed_relations: report.changed_relations,
766 changed_relation_names: report.changed_relation_names,
767 insert_rows: report.insert_rows,
768 delete_rows: report.delete_rows,
769 has_deletes: report.has_deletes,
770 affected_sccs: report.affected_sccs,
771 recomputed_sccs: report.recomputed_sccs,
772 incremental_sccs: report.incremental_sccs,
773 coalesced_insert_rows: report.coalesced_insert_rows,
774 coalesced_delete_rows: report.coalesced_delete_rows,
775 canceled_rows: report.canceled_rows,
776 equivalent_to_full_recompute: None,
777 planner_telemetry: report.planner_telemetry,
778 debug_trace: report.debug_trace,
779 }
780}
781
782fn pack_delta_stats(py: Python<'_>, stats: &LogicDeltaStats) -> PyResult<PyObject> {
783 let dict = PyDict::new(py);
784 dict.set_item("status", "ok")?;
785 dict.set_item("input_delta_count", stats.input_delta_count)?;
786 dict.set_item("changed_relations", stats.changed_relations)?;
787 dict.set_item(
788 "changed_relation_names",
789 stats.changed_relation_names.clone(),
790 )?;
791 dict.set_item("insert_rows", stats.insert_rows)?;
792 dict.set_item("delete_rows", stats.delete_rows)?;
793 dict.set_item("has_deletes", stats.has_deletes)?;
794 dict.set_item("affected_sccs", stats.affected_sccs)?;
795 dict.set_item("recomputed_sccs", stats.recomputed_sccs)?;
796 dict.set_item("incremental_sccs", stats.incremental_sccs)?;
797 dict.set_item("coalesced_insert_rows", stats.coalesced_insert_rows)?;
798 dict.set_item("coalesced_delete_rows", stats.coalesced_delete_rows)?;
799 dict.set_item("canceled_rows", stats.canceled_rows)?;
800 dict.set_item(
801 "equivalent_to_full_recompute",
802 stats.equivalent_to_full_recompute,
803 )?;
804 dict.set_item(
805 "planner_telemetry",
806 pack_delta_planner_telemetry(py, &stats.planner_telemetry)?,
807 )?;
808 dict.set_item("debug_trace", stats.debug_trace.clone())?;
809 Ok(dict.into())
810}
811
812fn pack_delta_planner_telemetry(
813 py: Python<'_>,
814 telemetry: &gpu_logic::DeltaPlannerTelemetry,
815) -> PyResult<PyObject> {
816 let dict = PyDict::new(py);
817 dict.set_item("cache_reused", telemetry.cache_reused)?;
818 dict.set_item("fallback_decision", telemetry.fallback_decision.clone())?;
819 dict.set_item("affected_sccs", telemetry.affected_sccs)?;
820 dict.set_item("recomputed_sccs", telemetry.recomputed_sccs)?;
821 dict.set_item("incremental_sccs", telemetry.incremental_sccs)?;
822 dict.set_item("estimated_delta_speedup", telemetry.estimated_delta_speedup)?;
823 dict.set_item("measured_delta_speedup", telemetry.measured_delta_speedup)?;
824 dict.set_item("planner_advice", telemetry.planner_advice.clone())?;
825 Ok(dict.into())
826}
827
828fn pack_rule_provenance(
829 py: Python<'_>,
830 entries: &[xlog_logic::RuleProvenance],
831) -> PyResult<PyObject> {
832 let list = PyList::empty(py);
833 for entry in entries {
834 let dict = PyDict::new(py);
835 dict.set_item("rule_id", &entry.rule_id)?;
836 dict.set_item("head", &entry.head)?;
837 dict.set_item("source_kind", entry.source_kind.as_str())?;
838 dict.set_item("source_span", entry.source_span.clone())?;
839 dict.set_item("generation_trace_hash", entry.generation_trace_hash.clone())?;
840 dict.set_item("support_relation_ids", entry.support_relation_ids.clone())?;
841 dict.set_item(
842 "counterexample_relation_ids",
843 entry.counterexample_relation_ids.clone(),
844 )?;
845 list.append(dict)?;
846 }
847 Ok(list.into())
848}
849
850fn pack_proof_traces(
851 py: Python<'_>,
852 entries: &[xlog_logic::QueryProofTrace],
853) -> PyResult<PyObject> {
854 let list = PyList::empty(py);
855 for entry in entries {
856 let dict = PyDict::new(py);
857 dict.set_item("query_id", &entry.query_id)?;
858 dict.set_item("query", &entry.query)?;
859 dict.set_item("answer_relation", &entry.answer_relation)?;
860 dict.set_item("rule_ids", entry.rule_ids.clone())?;
861 dict.set_item("source_facts", entry.source_facts.clone())?;
862 dict.set_item("rejected_alternatives", entry.rejected_alternatives.clone())?;
863 list.append(dict)?;
864 }
865 Ok(list.into())
866}
867
868fn collect_dlpack_columns(
869 obj: &Bound<'_, PyAny>,
870 type_error_message: &str,
871) -> PyResult<Vec<DlpackManagedTensor>> {
872 let seq = obj
873 .downcast::<PySequence>()
874 .map_err(|_| PyValueError::new_err(type_error_message.to_string()))?;
875
876 let mut tensors: Vec<DlpackManagedTensor> = Vec::with_capacity(seq.len()?);
877 for item in seq.try_iter()? {
878 let item = item?;
879 tensors.push(dlpack_from_py(&item)?);
880 }
881 Ok(tensors)
882}
883
884fn export_buffer_columns(
885 py: Python<'_>,
886 provider: &Arc<xlog_cuda::CudaKernelProvider>,
887 buffer: xlog_cuda::CudaBuffer,
888) -> PyResult<Vec<PyObject>> {
889 let arity = buffer.arity();
890 let table = provider.to_dlpack_table(buffer);
891 let mut tensors: Vec<PyObject> = Vec::with_capacity(arity);
892 for col_idx in 0..arity {
893 let tensor = table.column(col_idx).map_err(types::xlog_err)?;
894 tensors.push(dlpack_capsule_from_tensor(py, tensor)?);
895 }
896 Ok(tensors)
897}
898
899fn pack_logic_result_with_provider(
900 py: Python<'_>,
901 provider: &Arc<xlog_cuda::CudaKernelProvider>,
902 result: gpu_logic::LogicEvalResult,
903) -> PyResult<LogicEvalResult> {
904 let mut queries: Vec<Py<LogicQueryResult>> = Vec::with_capacity(result.queries.len());
905
906 for q in result.queries {
907 let num_rows = provider
908 .validated_logical_row_count(&q.buffer)
909 .map_err(types::xlog_err)?;
910 let is_true = q.columns.is_empty() && num_rows > 0;
911 let tensors = if q.columns.is_empty() {
912 Vec::new()
913 } else {
914 export_buffer_columns(py, provider, q.buffer)?
915 };
916
917 queries.push(Py::new(
918 py,
919 LogicQueryResult {
920 relation_name: q.relation_name,
921 columns: q.columns,
922 sort_labels: q.sort_labels,
923 tensors,
924 num_rows,
925 is_true,
926 },
927 )?);
928 }
929
930 Ok(LogicEvalResult { queries })
931}