Skip to main content

pyxlog/
logic.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::Instant;
4
5use pyo3::exceptions::{PyRuntimeError, PyValueError};
6use pyo3::prelude::*;
7use pyo3::types::{PyDict, PyList, PySequence};
8
9use xlog_cuda::DlpackManagedTensor;
10use xlog_gpu::logic as gpu_logic;
11use xlog_logic::ast::ProbEngine;
12use xlog_neural::{NetworkRegistry, TensorSourceRegistry};
13use xlog_prob::exact::{ExactDdnnfProgram, GpuConfig};
14use xlog_prob::mc::McProgram;
15use xlog_runtime::RelationDelta;
16
17use std::collections::HashMap as StdHashMap;
18
19use super::neural_registry::NeuralPredicateRegistry;
20use super::{
21    dlpack_capsule_from_tensor, dlpack_from_py, enforce_call_memory_limit,
22    parse_prob_engine_override, provider_from_config, provider_memory_stats, types,
23    CompiledLogicProgram, CompiledProbProgram, CompiledProgram, LogicDeltaStats, LogicEvalResult,
24    LogicProgram, LogicQueryResult, LogicRelationSession, Program, RelationChangeCallback,
25};
26
27#[pymethods]
28impl Program {
29    #[staticmethod]
30    #[pyo3(signature = (source, device=0, memory_mb=32768, prob_engine=None))]
31    pub fn compile(
32        source: &str,
33        device: usize,
34        memory_mb: u64,
35        prob_engine: Option<String>,
36    ) -> PyResult<CompiledProgram> {
37        if memory_mb == 0 {
38            return Err(PyValueError::new_err("memory_mb must be > 0"));
39        }
40
41        let mut config = GpuConfig::default();
42        config.device_ordinal = device;
43        config.memory_bytes = memory_mb * 1024 * 1024;
44
45        // Parse the AST to get prob_engine and neural predicates
46        let ast = xlog_logic::parse_program(source).map_err(types::xlog_err)?;
47
48        // Extract declared neural network names
49        let declared_networks: HashSet<String> = ast
50            .neural_predicates
51            .iter()
52            .map(|np| np.network.clone())
53            .collect();
54        // Build by-network form index: network name -> is_embedding
55        let mut declared_network_forms: HashMap<String, bool> = HashMap::new();
56        for np in &ast.neural_predicates {
57            let is_embedding = np.labels.is_none();
58            match declared_network_forms.get(&np.network) {
59                Some(&existing_form) if existing_form != is_embedding => {
60                    return Err(PyValueError::new_err(format!(
61                        "network '{}' is declared as both classification and embedding; \
62                         each network name must have a single form",
63                        np.network
64                    )));
65                }
66                _ => {
67                    declared_network_forms.insert(np.network.clone(), is_embedding);
68                }
69            }
70        }
71
72        let neural_registry = NeuralPredicateRegistry::from_ast(&ast).map_err(types::val_err)?;
73
74        let engine = match prob_engine {
75            Some(s) => parse_prob_engine_override(&s)?,
76            None => ast.prob_engine(),
77        };
78
79        let program = match engine {
80            ProbEngine::ExactDdnnf => CompiledProbProgram::Exact(
81                ExactDdnnfProgram::compile_source_with_gpu(source, config)
82                    .map_err(types::xlog_err)?,
83            ),
84            ProbEngine::Mc => CompiledProbProgram::Mc(
85                McProgram::compile_source_with_gpu(source, config).map_err(types::xlog_err)?,
86            ),
87        };
88        let provider = provider_from_config(config).map_err(types::xlog_err)?;
89
90        Ok(CompiledProgram {
91            program,
92            output_provider: Arc::new(provider),
93            network_registry: NetworkRegistry::new(),
94            neural_registry,
95            declared_networks,
96            declared_network_forms,
97            tensor_sources: TensorSourceRegistry::new(),
98            domain_source: None,
99            _source: source.to_string(),
100            ast,
101            _gpu_config: config,
102            _prob_engine: engine,
103            query_signature_cache: StdHashMap::new(),
104            circuit_cache: StdHashMap::new(),
105            circuit_cache_hits: 0,
106            circuit_cache_misses: 0,
107            template_compile_count: 0,
108            batch_queries: true,
109            last_compile_profile: None,
110        })
111    }
112}
113
114#[pymethods]
115impl LogicProgram {
116    #[staticmethod]
117    #[pyo3(signature = (source, device=0, memory_mb=32768))]
118    pub fn compile(source: &str, device: usize, memory_mb: u64) -> PyResult<CompiledLogicProgram> {
119        if memory_mb == 0 {
120            return Err(PyValueError::new_err("memory_mb must be > 0"));
121        }
122
123        let mut config = GpuConfig::default();
124        config.device_ordinal = device;
125        config.memory_bytes = memory_mb * 1024 * 1024;
126
127        let program = gpu_logic::LogicProgram::compile(source).map_err(types::xlog_err)?;
128        let provider = provider_from_config(config).map_err(types::xlog_err)?;
129
130        Ok(CompiledLogicProgram {
131            program,
132            provider: Arc::new(provider),
133        })
134    }
135}
136
137#[pymethods]
138impl CompiledLogicProgram {
139    #[pyo3(signature = (dlpack_inputs=None, memory_mb=None))]
140    pub fn evaluate(
141        &self,
142        py: Python<'_>,
143        dlpack_inputs: Option<&Bound<'_, PyDict>>,
144        memory_mb: Option<u64>,
145    ) -> PyResult<LogicEvalResult> {
146        enforce_call_memory_limit(&self.provider, memory_mb)?;
147        let mut inputs: HashMap<String, xlog_cuda::CudaBuffer> = HashMap::new();
148
149        if let Some(dict) = dlpack_inputs {
150            for (k, v) in dict.iter() {
151                let name: String = k.extract()?;
152                let schema = self.program.schema(&name).ok_or_else(|| {
153                    PyValueError::new_err(format!(
154                        "Unknown input relation {} (not present in compiled schemas)",
155                        name
156                    ))
157                })?;
158
159                let tensors = collect_dlpack_columns(
160                    &v,
161                    &format!(
162                        "Input relation {} must be a sequence of DLPack columns",
163                        name
164                    ),
165                )?;
166
167                let buffer = self
168                    .provider
169                    .from_dlpack_tensors_with_schema(schema.clone(), tensors)
170                    .map_err(types::xlog_err)?;
171
172                inputs.insert(name, buffer);
173            }
174        }
175
176        let result = self
177            .program
178            .evaluate(self.provider.clone(), inputs)
179            .map_err(types::xlog_err)?;
180        pack_logic_result_with_provider(py, &self.provider, result)
181    }
182
183    pub fn session(&self) -> PyResult<LogicRelationSession> {
184        let relation_store = self
185            .program
186            .create_relation_store(self.provider.clone())
187            .map_err(types::xlog_err)?;
188        Ok(LogicRelationSession {
189            program: self.program.clone(),
190            provider: self.provider.clone(),
191            relation_store,
192            evaluation_store: None,
193            session_runtime: None,
194            last_delta_stats: None,
195            relation_callbacks: Vec::new(),
196            next_relation_callback_id: 1,
197            relation_generations: HashMap::new(),
198        })
199    }
200
201    /// Return memory diagnostics including allocated_bytes and memory_limit_bytes.
202    pub fn memory_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
203        provider_memory_stats(py, &self.provider)
204    }
205
206    pub fn rule_provenance(&self, py: Python<'_>) -> PyResult<PyObject> {
207        pack_rule_provenance(py, &self.program.rule_provenance())
208    }
209
210    pub fn proof_traces(&self, py: Python<'_>) -> PyResult<PyObject> {
211        pack_proof_traces(py, &self.program.proof_traces())
212    }
213}
214
215impl CompiledLogicProgram {}
216
217#[pymethods]
218impl LogicRelationSession {
219    pub fn put_relation(
220        &mut self,
221        name: String,
222        dlpack_columns: &Bound<'_, PyAny>,
223    ) -> PyResult<()> {
224        if name.starts_with("__") {
225            return Err(PyValueError::new_err(format!(
226                "Relation {} is internal and cannot be stored in a persistent session",
227                name
228            )));
229        }
230        let schema = self.program.schema(&name).ok_or_else(|| {
231            PyValueError::new_err(format!(
232                "Unknown relation {} (not present in compiled schemas)",
233                name
234            ))
235        })?;
236        let tensors = collect_dlpack_columns(
237            dlpack_columns,
238            &format!("Relation {} must be a sequence of DLPack columns", name),
239        )?;
240        let buffer = self
241            .provider
242            .from_dlpack_tensors_with_schema(schema.clone(), tensors)
243            .map_err(types::xlog_err)?;
244        self.relation_store.put(&name, buffer);
245        self.evaluation_store = None;
246        self.session_runtime = None;
247        self.last_delta_stats = None;
248        Ok(())
249    }
250
251    #[pyo3(signature = (memory_mb=None))]
252    pub fn evaluate(
253        &mut self,
254        py: Python<'_>,
255        memory_mb: Option<u64>,
256    ) -> PyResult<LogicEvalResult> {
257        enforce_call_memory_limit(&self.provider, memory_mb)?;
258        let result = if let Some(store) = &self.evaluation_store {
259            self.program
260                .evaluate_cached_relation_store(self.provider.clone(), store)
261                .map_err(types::xlog_err)?
262        } else {
263            if self.session_runtime.is_none() {
264                self.session_runtime = Some(
265                    self.program
266                        .create_session_runtime(self.provider.clone(), &self.relation_store, false)
267                        .map_err(types::xlog_err)?,
268                );
269            }
270            let runtime = self.session_runtime.as_mut().ok_or_else(|| {
271                PyRuntimeError::new_err("session runtime unavailable during evaluation")
272            })?;
273            let (result, store) = self
274                .program
275                .evaluate_with_session_runtime(self.provider.clone(), runtime)
276                .map_err(types::xlog_err)?;
277            self.evaluation_store = Some(store);
278            result
279        };
280        pack_logic_result_with_provider(py, &self.provider, result)
281    }
282
283    pub fn insert_relation(
284        &mut self,
285        py: Python<'_>,
286        name: String,
287        dlpack_columns: &Bound<'_, PyAny>,
288    ) -> PyResult<PyObject> {
289        let insert = self.relation_delta_buffer(&name, dlpack_columns)?;
290        self.apply_single_relation_delta(py, name, Some(insert), None)
291    }
292
293    pub fn delete_relation(
294        &mut self,
295        py: Python<'_>,
296        name: String,
297        dlpack_columns: &Bound<'_, PyAny>,
298    ) -> PyResult<PyObject> {
299        let delete = self.relation_delta_buffer(&name, dlpack_columns)?;
300        self.apply_single_relation_delta(py, name, None, Some(delete))
301    }
302
303    #[pyo3(signature = (name, insert_columns=None, delete_columns=None))]
304    pub fn apply_relation_delta(
305        &mut self,
306        py: Python<'_>,
307        name: String,
308        insert_columns: Option<&Bound<'_, PyAny>>,
309        delete_columns: Option<&Bound<'_, PyAny>>,
310    ) -> PyResult<PyObject> {
311        if insert_columns.is_none() && delete_columns.is_none() {
312            return Err(PyValueError::new_err(
313                "apply_relation_delta requires insert_columns, delete_columns, or both",
314            ));
315        }
316        let insert = insert_columns
317            .map(|columns| self.relation_delta_buffer(&name, columns))
318            .transpose()?;
319        let delete = delete_columns
320            .map(|columns| self.relation_delta_buffer(&name, columns))
321            .transpose()?;
322        self.apply_single_relation_delta(py, name, insert, delete)
323    }
324
325    pub fn apply_relation_delta_batch(
326        &mut self,
327        py: Python<'_>,
328        updates: &Bound<'_, PyAny>,
329    ) -> PyResult<PyObject> {
330        let (batch, relation_names) =
331            self.parse_relation_delta_batch("apply_relation_delta_batch", updates)?;
332
333        let report = self
334            .program
335            .apply_relation_delta_batch_with_session_runtime(
336                self.provider.clone(),
337                &mut self.relation_store,
338                &mut self.evaluation_store,
339                &mut self.session_runtime,
340                batch,
341            )
342            .map_err(types::xlog_err)?;
343        let stats = logic_delta_stats_from_report(report);
344        self.last_delta_stats = Some(stats.clone());
345        self.fire_relation_callbacks(py, &relation_names, &stats)?;
346        pack_delta_stats(py, &stats)
347    }
348
349    #[pyo3(signature = (updates, check_equivalence=false))]
350    pub fn apply_relation_delta_debug(
351        &mut self,
352        py: Python<'_>,
353        updates: &Bound<'_, PyAny>,
354        check_equivalence: bool,
355    ) -> PyResult<PyObject> {
356        let (batch, relation_names) =
357            self.parse_relation_delta_batch("apply_relation_delta_debug", updates)?;
358        let delta_start = Instant::now();
359        let report = self
360            .program
361            .apply_relation_delta_batch_with_session_runtime(
362                self.provider.clone(),
363                &mut self.relation_store,
364                &mut self.evaluation_store,
365                &mut self.session_runtime,
366                batch,
367            )
368            .map_err(types::xlog_err)?;
369        let delta_micros = delta_start.elapsed().as_micros().max(1) as u64;
370        let mut stats = logic_delta_stats_from_report(report);
371        if check_equivalence {
372            let full_start = Instant::now();
373            let (_, full_store) = self
374                .program
375                .evaluate_with_relation_store_and_cache(
376                    self.provider.clone(),
377                    &self.relation_store,
378                    false,
379                )
380                .map_err(types::xlog_err)?;
381            let full_micros = full_start.elapsed().as_micros() as u64;
382            let cached_store = self.evaluation_store.as_ref().ok_or_else(|| {
383                PyRuntimeError::new_err("delta debug missing cached store after delta application")
384            })?;
385            stats.equivalent_to_full_recompute = Some(
386                self.program
387                    .relation_stores_query_equivalent(
388                        self.provider.as_ref(),
389                        &full_store,
390                        cached_store,
391                    )
392                    .map_err(types::xlog_err)?,
393            );
394            let speedup = full_micros as f64 / delta_micros as f64;
395            stats.planner_telemetry.measured_delta_speedup = Some(speedup);
396            if speedup >= 1.0 {
397                stats
398                    .planner_telemetry
399                    .planner_advice
400                    .push(format!("delta path is faster by {speedup:.2}x"));
401            } else {
402                stats.planner_telemetry.planner_advice.push(format!(
403                    "full recompute may be faster; delta measured {speedup:.2}x"
404                ));
405            }
406        }
407        self.last_delta_stats = Some(stats.clone());
408        self.fire_relation_callbacks(py, &relation_names, &stats)?;
409        pack_delta_stats(py, &stats)
410    }
411
412    pub fn delta_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
413        match &self.last_delta_stats {
414            Some(stats) => pack_delta_stats(py, stats),
415            None => {
416                let dict = PyDict::new(py);
417                dict.set_item("status", "unavailable")?;
418                dict.set_item("reason", "no relation delta has been applied")?;
419                Ok(dict.into())
420            }
421        }
422    }
423
424    pub fn rule_provenance(&self, py: Python<'_>) -> PyResult<PyObject> {
425        pack_rule_provenance(py, &self.program.rule_provenance())
426    }
427
428    pub fn proof_traces(&self, py: Python<'_>) -> PyResult<PyObject> {
429        pack_proof_traces(py, &self.program.proof_traces())
430    }
431
432    pub fn register_relation_callback(
433        &mut self,
434        py: Python<'_>,
435        callback: PyObject,
436    ) -> PyResult<u64> {
437        if !callback.bind(py).is_callable() {
438            return Err(PyValueError::new_err(
439                "register_relation_callback expects a callable",
440            ));
441        }
442        let id = self.next_relation_callback_id;
443        self.next_relation_callback_id = self.next_relation_callback_id.saturating_add(1);
444        self.relation_callbacks
445            .push(RelationChangeCallback { id, callback });
446        Ok(id)
447    }
448
449    pub fn unregister_relation_callback(&mut self, callback_id: u64) -> bool {
450        let before = self.relation_callbacks.len();
451        self.relation_callbacks
452            .retain(|registered| registered.id != callback_id);
453        before != self.relation_callbacks.len()
454    }
455
456    pub fn cuda_graph_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
457        let dict = PyDict::new(py);
458        dict.set_item(
459            "csm_cuda_graph_captures",
460            self.provider.csm_cuda_graph_captures(),
461        )?;
462        dict.set_item(
463            "csm_cuda_graph_launches",
464            self.provider.csm_cuda_graph_launches(),
465        )?;
466        dict.set_item(
467            "csm_cuda_graph_fallbacks",
468            self.provider.csm_cuda_graph_fallbacks(),
469        )?;
470        dict.set_item(
471            "csm_cuda_graph_cache_hits",
472            self.provider.csm_cuda_graph_cache_hits(),
473        )?;
474        Ok(dict.into())
475    }
476
477    pub fn host_transfer_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
478        let stats = self.provider.host_transfer_stats();
479        let dict = PyDict::new(py);
480        dict.set_item("dtoh_bytes", stats.dtoh_bytes)?;
481        dict.set_item("htod_bytes", stats.htod_bytes)?;
482        dict.set_item("dtoh_calls", stats.dtoh_calls)?;
483        dict.set_item("htod_calls", stats.htod_calls)?;
484        Ok(dict.into())
485    }
486
487    pub fn join_index_cache_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
488        let dict = PyDict::new(py);
489        let stats = self
490            .session_runtime
491            .as_ref()
492            .map(|runtime| runtime.join_index_cache_stats())
493            .unwrap_or_default();
494        dict.set_item("lookups", stats.lookups)?;
495        dict.set_item("hits", stats.hits)?;
496        dict.set_item("misses", stats.misses)?;
497        dict.set_item("builds", stats.builds)?;
498        dict.set_item("evictions", stats.evictions)?;
499        dict.set_item("invalidations", stats.invalidations)?;
500        dict.set_item("stale_rejections", stats.stale_rejections)?;
501        dict.set_item("background_build_requests", stats.background_build_requests)?;
502        dict.set_item(
503            "background_builds_completed",
504            stats.background_builds_completed,
505        )?;
506        dict.set_item(
507            "background_builds_deferred",
508            stats.background_builds_deferred,
509        )?;
510        dict.set_item("entries", stats.entries)?;
511        dict.set_item("total_bytes", stats.total_bytes)?;
512        Ok(dict.into())
513    }
514
515    /// Multiway/Free-Join dispatch telemetry for the retained session
516    /// executor. Counters accumulate across evaluates within this session;
517    /// all zeros before the first evaluate.
518    pub fn wcoj_dispatch_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
519        let dict = PyDict::new(py);
520        let stats = self
521            .session_runtime
522            .as_ref()
523            .map(|runtime| runtime.wcoj_dispatch_stats())
524            .unwrap_or_default();
525        dict.set_item("free_join_dispatch_count", stats.free_join_dispatch_count)?;
526        dict.set_item(
527            "factorized_delta_dispatch_count",
528            stats.factorized_delta_dispatch_count,
529        )?;
530        dict.set_item(
531            "wcoj_groupby_fusion_dispatch_count",
532            stats.wcoj_groupby_fusion_dispatch_count,
533        )?;
534        dict.set_item("wcoj_error_decline_count", stats.wcoj_error_decline_count)?;
535        Ok(dict.into())
536    }
537
538    pub fn reset_host_transfer_stats(&self) {
539        self.provider.reset_host_transfer_stats()
540    }
541
542    /// Return memory diagnostics including allocated_bytes and memory_limit_bytes.
543    pub fn memory_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
544        provider_memory_stats(py, &self.provider)
545    }
546
547    pub fn export_relation(&mut self, py: Python<'_>, name: &str) -> PyResult<Vec<PyObject>> {
548        let existing = self.relation_store.get(name).ok_or_else(|| {
549            PyValueError::new_err(format!(
550                "Relation '{}' not found in persistent session",
551                name
552            ))
553        })?;
554        let replacement = self
555            .provider
556            .clone_buffer(existing)
557            .map_err(types::xlog_err)?;
558        let buffer = self.relation_store.remove(name).ok_or_else(|| {
559            PyRuntimeError::new_err(format!("Relation '{}' disappeared during export", name))
560        })?;
561        self.relation_store.put(name, replacement);
562        export_buffer_columns(py, &self.provider, buffer)
563    }
564
565    pub fn remove_relation(&mut self, name: &str) -> bool {
566        let removed = self.relation_store.remove(name).is_some();
567        if removed {
568            self.evaluation_store = None;
569            self.session_runtime = None;
570            self.last_delta_stats = None;
571        }
572        removed
573    }
574
575    pub fn clear_relations(&mut self) {
576        self.relation_store.clear();
577        self.evaluation_store = None;
578        self.session_runtime = None;
579        self.last_delta_stats = None;
580    }
581}
582
583impl LogicRelationSession {
584    fn parse_relation_delta_batch(
585        &self,
586        method_name: &str,
587        updates: &Bound<'_, PyAny>,
588    ) -> PyResult<(Vec<(String, RelationDelta)>, Vec<String>)> {
589        let seq = updates.downcast::<PySequence>().map_err(|_| {
590            PyValueError::new_err(format!(
591                "{method_name} expects a sequence of update dictionaries"
592            ))
593        })?;
594        let mut batch: Vec<(String, RelationDelta)> = Vec::with_capacity(seq.len()? as usize);
595        let mut relation_names: Vec<String> = Vec::new();
596        for item in seq.try_iter()? {
597            let item = item?;
598            let dict = item.downcast::<PyDict>().map_err(|_| {
599                PyValueError::new_err(format!("{method_name} updates must be dictionaries"))
600            })?;
601            let name_obj = dict.get_item("name")?.ok_or_else(|| {
602                PyValueError::new_err(format!("{method_name} update missing 'name'"))
603            })?;
604            let name: String = name_obj.extract()?;
605            let insert = optional_delta_columns(dict, "insert_columns")
606                .map(|columns| self.relation_delta_buffer(&name, &columns))
607                .transpose()?;
608            let delete = optional_delta_columns(dict, "delete_columns")
609                .map(|columns| self.relation_delta_buffer(&name, &columns))
610                .transpose()?;
611            if insert.is_none() && delete.is_none() {
612                return Err(PyValueError::new_err(format!(
613                    "{method_name} updates require insert_columns, delete_columns, or both"
614                )));
615            }
616            relation_names.push(name.clone());
617            batch.push((name, RelationDelta::new(insert, delete)));
618        }
619        Ok((batch, relation_names))
620    }
621
622    fn relation_delta_buffer(
623        &self,
624        name: &str,
625        dlpack_columns: &Bound<'_, PyAny>,
626    ) -> PyResult<xlog_cuda::CudaBuffer> {
627        if name.starts_with("__") {
628            return Err(PyValueError::new_err(format!(
629                "Relation {} is internal and cannot be updated in a persistent session",
630                name
631            )));
632        }
633        let schema = self.program.schema(name).ok_or_else(|| {
634            PyValueError::new_err(format!(
635                "Unknown relation {} (not present in compiled schemas)",
636                name
637            ))
638        })?;
639        let tensors = collect_dlpack_columns(
640            dlpack_columns,
641            &format!(
642                "Relation {} delta must be a sequence of DLPack columns",
643                name
644            ),
645        )?;
646        self.provider
647            .from_dlpack_tensors_with_schema(schema.clone(), tensors)
648            .map_err(types::xlog_err)
649    }
650
651    fn apply_single_relation_delta(
652        &mut self,
653        py: Python<'_>,
654        name: String,
655        insert: Option<xlog_cuda::CudaBuffer>,
656        delete: Option<xlog_cuda::CudaBuffer>,
657    ) -> PyResult<PyObject> {
658        let relation_names = vec![name.clone()];
659        let mut deltas = HashMap::new();
660        deltas.insert(name, RelationDelta::new(insert, delete));
661        let report = self
662            .program
663            .apply_relation_deltas_with_session_runtime(
664                self.provider.clone(),
665                &mut self.relation_store,
666                &mut self.evaluation_store,
667                &mut self.session_runtime,
668                deltas,
669            )
670            .map_err(types::xlog_err)?;
671        let stats = LogicDeltaStats {
672            input_delta_count: report.input_delta_count,
673            changed_relations: report.changed_relations,
674            changed_relation_names: report.changed_relation_names,
675            insert_rows: report.insert_rows,
676            delete_rows: report.delete_rows,
677            has_deletes: report.has_deletes,
678            affected_sccs: report.affected_sccs,
679            recomputed_sccs: report.recomputed_sccs,
680            incremental_sccs: report.incremental_sccs,
681            coalesced_insert_rows: report.coalesced_insert_rows,
682            coalesced_delete_rows: report.coalesced_delete_rows,
683            canceled_rows: report.canceled_rows,
684            equivalent_to_full_recompute: None,
685            planner_telemetry: report.planner_telemetry,
686            debug_trace: report.debug_trace,
687        };
688        self.last_delta_stats = Some(stats.clone());
689        self.fire_relation_callbacks(py, &relation_names, &stats)?;
690        pack_delta_stats(py, &stats)
691    }
692
693    fn fire_relation_callbacks(
694        &mut self,
695        py: Python<'_>,
696        relation_names: &[String],
697        stats: &LogicDeltaStats,
698    ) -> PyResult<()> {
699        if self.relation_callbacks.is_empty() || stats.changed_relations == 0 {
700            return Ok(());
701        }
702
703        let mut seen = HashSet::new();
704        let mut events: Vec<(String, u64)> = Vec::new();
705        for relation in relation_names {
706            if seen.insert(relation.clone()) {
707                let generation = self
708                    .relation_generations
709                    .entry(relation.clone())
710                    .and_modify(|current| *current = current.saturating_add(1))
711                    .or_insert(1);
712                events.push((relation.clone(), *generation));
713            }
714        }
715
716        for (relation, generation) in events {
717            let payload = relation_callback_payload(py, &relation, generation, stats)?;
718            for registered in &self.relation_callbacks {
719                registered.callback.call1(py, (payload.clone_ref(py),))?;
720            }
721        }
722
723        Ok(())
724    }
725}
726
727fn relation_callback_payload(
728    py: Python<'_>,
729    relation: &str,
730    generation: u64,
731    stats: &LogicDeltaStats,
732) -> PyResult<PyObject> {
733    let dict = PyDict::new(py);
734    dict.set_item("relation", relation)?;
735    dict.set_item("generation", generation)?;
736    dict.set_item("input_delta_count", stats.input_delta_count)?;
737    dict.set_item(
738        "changed_relation_names",
739        stats.changed_relation_names.clone(),
740    )?;
741    dict.set_item("insert_rows", stats.insert_rows)?;
742    dict.set_item("delete_rows", stats.delete_rows)?;
743    dict.set_item("has_deletes", stats.has_deletes)?;
744    dict.set_item("coalesced_insert_rows", stats.coalesced_insert_rows)?;
745    dict.set_item("coalesced_delete_rows", stats.coalesced_delete_rows)?;
746    dict.set_item("canceled_rows", stats.canceled_rows)?;
747    dict.set_item("affected_sccs", stats.affected_sccs)?;
748    dict.set_item("recomputed_sccs", stats.recomputed_sccs)?;
749    dict.set_item("incremental_sccs", stats.incremental_sccs)?;
750    dict.set_item("debug_trace", stats.debug_trace.clone())?;
751    dict.set_item("telemetry", pack_delta_stats(py, stats)?)?;
752    Ok(dict.into())
753}
754
755fn optional_delta_columns<'py>(dict: &Bound<'py, PyDict>, key: &str) -> Option<Bound<'py, PyAny>> {
756    match dict.get_item(key) {
757        Ok(Some(value)) if !value.is_none() => Some(value),
758        _ => None,
759    }
760}
761
762fn logic_delta_stats_from_report(report: gpu_logic::LogicDeltaReport) -> LogicDeltaStats {
763    LogicDeltaStats {
764        input_delta_count: report.input_delta_count,
765        changed_relations: report.changed_relations,
766        changed_relation_names: report.changed_relation_names,
767        insert_rows: report.insert_rows,
768        delete_rows: report.delete_rows,
769        has_deletes: report.has_deletes,
770        affected_sccs: report.affected_sccs,
771        recomputed_sccs: report.recomputed_sccs,
772        incremental_sccs: report.incremental_sccs,
773        coalesced_insert_rows: report.coalesced_insert_rows,
774        coalesced_delete_rows: report.coalesced_delete_rows,
775        canceled_rows: report.canceled_rows,
776        equivalent_to_full_recompute: None,
777        planner_telemetry: report.planner_telemetry,
778        debug_trace: report.debug_trace,
779    }
780}
781
782fn pack_delta_stats(py: Python<'_>, stats: &LogicDeltaStats) -> PyResult<PyObject> {
783    let dict = PyDict::new(py);
784    dict.set_item("status", "ok")?;
785    dict.set_item("input_delta_count", stats.input_delta_count)?;
786    dict.set_item("changed_relations", stats.changed_relations)?;
787    dict.set_item(
788        "changed_relation_names",
789        stats.changed_relation_names.clone(),
790    )?;
791    dict.set_item("insert_rows", stats.insert_rows)?;
792    dict.set_item("delete_rows", stats.delete_rows)?;
793    dict.set_item("has_deletes", stats.has_deletes)?;
794    dict.set_item("affected_sccs", stats.affected_sccs)?;
795    dict.set_item("recomputed_sccs", stats.recomputed_sccs)?;
796    dict.set_item("incremental_sccs", stats.incremental_sccs)?;
797    dict.set_item("coalesced_insert_rows", stats.coalesced_insert_rows)?;
798    dict.set_item("coalesced_delete_rows", stats.coalesced_delete_rows)?;
799    dict.set_item("canceled_rows", stats.canceled_rows)?;
800    dict.set_item(
801        "equivalent_to_full_recompute",
802        stats.equivalent_to_full_recompute,
803    )?;
804    dict.set_item(
805        "planner_telemetry",
806        pack_delta_planner_telemetry(py, &stats.planner_telemetry)?,
807    )?;
808    dict.set_item("debug_trace", stats.debug_trace.clone())?;
809    Ok(dict.into())
810}
811
812fn pack_delta_planner_telemetry(
813    py: Python<'_>,
814    telemetry: &gpu_logic::DeltaPlannerTelemetry,
815) -> PyResult<PyObject> {
816    let dict = PyDict::new(py);
817    dict.set_item("cache_reused", telemetry.cache_reused)?;
818    dict.set_item("fallback_decision", telemetry.fallback_decision.clone())?;
819    dict.set_item("affected_sccs", telemetry.affected_sccs)?;
820    dict.set_item("recomputed_sccs", telemetry.recomputed_sccs)?;
821    dict.set_item("incremental_sccs", telemetry.incremental_sccs)?;
822    dict.set_item("estimated_delta_speedup", telemetry.estimated_delta_speedup)?;
823    dict.set_item("measured_delta_speedup", telemetry.measured_delta_speedup)?;
824    dict.set_item("planner_advice", telemetry.planner_advice.clone())?;
825    Ok(dict.into())
826}
827
828fn pack_rule_provenance(
829    py: Python<'_>,
830    entries: &[xlog_logic::RuleProvenance],
831) -> PyResult<PyObject> {
832    let list = PyList::empty(py);
833    for entry in entries {
834        let dict = PyDict::new(py);
835        dict.set_item("rule_id", &entry.rule_id)?;
836        dict.set_item("head", &entry.head)?;
837        dict.set_item("source_kind", entry.source_kind.as_str())?;
838        dict.set_item("source_span", entry.source_span.clone())?;
839        dict.set_item("generation_trace_hash", entry.generation_trace_hash.clone())?;
840        dict.set_item("support_relation_ids", entry.support_relation_ids.clone())?;
841        dict.set_item(
842            "counterexample_relation_ids",
843            entry.counterexample_relation_ids.clone(),
844        )?;
845        list.append(dict)?;
846    }
847    Ok(list.into())
848}
849
850fn pack_proof_traces(
851    py: Python<'_>,
852    entries: &[xlog_logic::QueryProofTrace],
853) -> PyResult<PyObject> {
854    let list = PyList::empty(py);
855    for entry in entries {
856        let dict = PyDict::new(py);
857        dict.set_item("query_id", &entry.query_id)?;
858        dict.set_item("query", &entry.query)?;
859        dict.set_item("answer_relation", &entry.answer_relation)?;
860        dict.set_item("rule_ids", entry.rule_ids.clone())?;
861        dict.set_item("source_facts", entry.source_facts.clone())?;
862        dict.set_item("rejected_alternatives", entry.rejected_alternatives.clone())?;
863        list.append(dict)?;
864    }
865    Ok(list.into())
866}
867
868fn collect_dlpack_columns(
869    obj: &Bound<'_, PyAny>,
870    type_error_message: &str,
871) -> PyResult<Vec<DlpackManagedTensor>> {
872    let seq = obj
873        .downcast::<PySequence>()
874        .map_err(|_| PyValueError::new_err(type_error_message.to_string()))?;
875
876    let mut tensors: Vec<DlpackManagedTensor> = Vec::with_capacity(seq.len()?);
877    for item in seq.try_iter()? {
878        let item = item?;
879        tensors.push(dlpack_from_py(&item)?);
880    }
881    Ok(tensors)
882}
883
884fn export_buffer_columns(
885    py: Python<'_>,
886    provider: &Arc<xlog_cuda::CudaKernelProvider>,
887    buffer: xlog_cuda::CudaBuffer,
888) -> PyResult<Vec<PyObject>> {
889    let arity = buffer.arity();
890    let table = provider.to_dlpack_table(buffer);
891    let mut tensors: Vec<PyObject> = Vec::with_capacity(arity);
892    for col_idx in 0..arity {
893        let tensor = table.column(col_idx).map_err(types::xlog_err)?;
894        tensors.push(dlpack_capsule_from_tensor(py, tensor)?);
895    }
896    Ok(tensors)
897}
898
899fn pack_logic_result_with_provider(
900    py: Python<'_>,
901    provider: &Arc<xlog_cuda::CudaKernelProvider>,
902    result: gpu_logic::LogicEvalResult,
903) -> PyResult<LogicEvalResult> {
904    let mut queries: Vec<Py<LogicQueryResult>> = Vec::with_capacity(result.queries.len());
905
906    for q in result.queries {
907        let num_rows = provider
908            .validated_logical_row_count(&q.buffer)
909            .map_err(types::xlog_err)?;
910        let is_true = q.columns.is_empty() && num_rows > 0;
911        let tensors = if q.columns.is_empty() {
912            Vec::new()
913        } else {
914            export_buffer_columns(py, provider, q.buffer)?
915        };
916
917        queries.push(Py::new(
918            py,
919            LogicQueryResult {
920                relation_name: q.relation_name,
921                columns: q.columns,
922                sort_labels: q.sort_labels,
923                tensors,
924                num_rows,
925                is_true,
926            },
927        )?);
928    }
929
930    Ok(LogicEvalResult { queries })
931}