Skip to main content

pyxlog/
ilp_exact.rs

1//! pyxlog bridge for the native bounded exact-induction engine.
2//!
3//! Deliberately separate from `crates/pyxlog/src/ilp.rs` (already ~2.7k LOC)
4//! so the native exact-induction binding seam has its own testable surface.
5//! Converts the Python-visible `induce_exact_native(...)` call into an
6//! [`xlog_induce::InduceExactRequest`] and maps the result back to a
7//! `dict`-shaped `PyObject` that the Python wrapper
8//! (`crates/pyxlog/python/pyxlog/ilp/exact_induce.py`) packages into
9//! `ExactInductionResult` / `ScoredCandidate` dataclass instances.
10//!
11//! Implementation history:
12//!   * The initial bridge raised `PyNotImplementedError`.
13//!   * The first native path added name-to-`RelId` resolution, DLPack-backed
14//!     positive/negative buffer construction, the engine call, and a
15//!     dict-shaped return.
16//!   * The batched scoring kernel is wired through
17//!     `CudaKernelProvider::ilp_exact_score`; native parity is locked by
18//!     `python/tests/test_ilp_exact_induce.py`.
19
20use pyo3::exceptions::PyValueError;
21use pyo3::prelude::*;
22use pyo3::types::{PyDict, PyList};
23
24use xlog_core::RelId;
25use xlog_cuda::CudaBuffer;
26use xlog_induce::{
27    induce_exact as induce_exact_engine, ExactInductionConfig, ExactInductionResult,
28    InduceExactRequest,
29};
30
31use super::{dlpack_from_py, types, CompiledIlpProgram};
32
33#[pymethods]
34impl CompiledIlpProgram {
35    /// Native backend entrypoint for `pyxlog.ilp.induce_exact(backend="native")`.
36    ///
37    /// Returns a `dict` (not a dataclass) with the following shape — the
38    /// Python wrapper repackages this into the existing `ExactInductionResult`
39    /// / `ScoredCandidate` dataclass instances:
40    ///
41    /// ```text
42    /// {
43    ///   "candidates": [
44    ///     {"topology": str, "head_relation": str,
45    ///      "left_relation": str, "right_relation": str,
46    ///      "positives_covered": int, "negatives_covered": int,
47    ///      "local_rank": int,
48    ///      "next_positives_covered": int, "next_negatives_covered": int,
49    ///      "tie_class_size": int},
50    ///     ...
51    ///   ],
52    ///   "total_scored": int,
53    ///   "candidate_count": int,
54    ///   "positive_count": int,
55    ///   "negative_count": int,
56    /// }
57    /// ```
58    #[pyo3(signature = (
59        head_relation,
60        candidate_relations,
61        positive_arg0,
62        positive_arg1,
63        negative_arg0 = None,
64        negative_arg1 = None,
65        k_per_topology = 2,
66        deterministic = true,
67    ))]
68    #[allow(clippy::too_many_arguments)]
69    pub fn induce_exact_native<'py>(
70        &mut self,
71        py: Python<'py>,
72        head_relation: String,
73        candidate_relations: Vec<String>,
74        positive_arg0: &Bound<'py, PyAny>,
75        positive_arg1: &Bound<'py, PyAny>,
76        negative_arg0: Option<&Bound<'py, PyAny>>,
77        negative_arg1: Option<&Bound<'py, PyAny>>,
78        k_per_topology: u32,
79        deterministic: bool,
80    ) -> PyResult<PyObject> {
81        // ── 1. Resolve head_relation → RelId ───────────────────────────────
82        // Python reference silently returns an empty result when the head
83        // isn't in the ILP schema — mirror that behavior here.
84        let head_rid = match self.rel_index.iter().find(|(_, n)| n == &head_relation) {
85            Some((rid, _)) => *rid,
86            None => return empty_result_dict(py),
87        };
88
89        // ── 2. Filter candidate_relations against rel_index ────────────────
90        // Drops names that aren't declared in the ILP schema, matching
91        // Python's `if cname in name_to_idx: body_indices.append(...)`.
92        let filtered_candidates: Vec<(RelId, String)> = candidate_relations
93            .into_iter()
94            .filter_map(|name| {
95                self.rel_index
96                    .iter()
97                    .find(|(_, n)| n == &name)
98                    .map(|(rid, _)| (*rid, name))
99            })
100            .collect();
101
102        // ── 3. Build positive / negative CudaBuffers via DLPack ────────────
103        // Use head_relation's declared schema — this also validates that the
104        // query tensors are column-type-compatible with the head relation.
105        let head_schema = self.schemas.get(&head_relation).cloned().ok_or_else(|| {
106            PyValueError::new_err(format!(
107                "induce_exact_native: head_relation {:?} has no schema",
108                head_relation,
109            ))
110        })?;
111
112        let pos_dmt0 = dlpack_from_py(positive_arg0)?;
113        let pos_dmt1 = dlpack_from_py(positive_arg1)?;
114        let pos_buf = self
115            .provider
116            .from_dlpack_tensors_with_schema(head_schema.clone(), vec![pos_dmt0, pos_dmt1])
117            .map_err(types::xlog_err)?;
118
119        let neg_buf = match (negative_arg0, negative_arg1) {
120            (Some(a0), Some(a1)) => {
121                let dmt0 = dlpack_from_py(a0)?;
122                let dmt1 = dlpack_from_py(a1)?;
123                Some(
124                    self.provider
125                        .from_dlpack_tensors_with_schema(head_schema.clone(), vec![dmt0, dmt1])
126                        .map_err(types::xlog_err)?,
127                )
128            }
129            (None, None) => None,
130            _ => {
131                return Err(PyValueError::new_err(
132                    "induce_exact_native: negative_arg0 and negative_arg1 must both be set or both None",
133                ));
134            }
135        };
136
137        // ── 4. Look up candidate buffers in the relation store ─────────────
138        // NOTE: Declared-but-unloaded candidates are silently dropped here.
139        // The Python reference would score them as zero-coverage (empty fact
140        // set). The test suite always loads facts for every candidate, so
141        // this divergence is not observable in the native parity test suite.
142        let store = self.executor.store();
143        let candidates_with_bufs: Vec<(RelId, &CudaBuffer)> = filtered_candidates
144            .iter()
145            .filter_map(|(rid, name)| store.get(name).map(|buf| (*rid, buf)))
146            .collect();
147
148        // ── 5. Call the engine ─────────────────────────────────────────────
149        let request = InduceExactRequest {
150            head_rel_idx: head_rid,
151            candidates: &candidates_with_bufs,
152            positives: &pos_buf,
153            negatives: neg_buf.as_ref(),
154            config: ExactInductionConfig {
155                k_per_topology,
156                deterministic,
157            },
158        };
159        let result = induce_exact_engine(&self.provider, &request).map_err(types::xlog_err)?;
160
161        // ── 6. Marshal the result into a dict — RelId → relation name uses
162        //       the same `rel_index` we resolved names against. ─────────────
163        result_to_py_dict(py, &result, &head_relation, &self.rel_index)
164    }
165}
166
167fn empty_result_dict(py: Python<'_>) -> PyResult<PyObject> {
168    let d = PyDict::new(py);
169    let empty_candidates = PyList::empty(py);
170    d.set_item("candidates", empty_candidates)?;
171    d.set_item("total_scored", 0u32)?;
172    d.set_item("candidate_count", 0u32)?;
173    d.set_item("positive_count", 0u32)?;
174    d.set_item("negative_count", 0u32)?;
175    Ok(d.into())
176}
177
178fn result_to_py_dict(
179    py: Python<'_>,
180    result: &ExactInductionResult,
181    head_relation: &str,
182    rel_index: &[(RelId, String)],
183) -> PyResult<PyObject> {
184    let name_of = |rid: RelId| -> PyResult<&str> {
185        rel_index
186            .iter()
187            .find(|(r, _)| *r == rid)
188            .map(|(_, n)| n.as_str())
189            .ok_or_else(|| {
190                PyValueError::new_err(format!(
191                    "induce_exact_native: result references unknown RelId {:?}",
192                    rid,
193                ))
194            })
195    };
196
197    let candidates_list = PyList::empty(py);
198    for c in &result.candidates {
199        let entry = PyDict::new(py);
200        entry.set_item("topology", c.topology.as_str())?;
201        entry.set_item("head_relation", head_relation)?;
202        entry.set_item("left_relation", name_of(c.left_rel_idx)?)?;
203        entry.set_item("right_relation", name_of(c.right_rel_idx)?)?;
204        entry.set_item("positives_covered", c.positives_covered)?;
205        entry.set_item("negatives_covered", c.negatives_covered)?;
206        entry.set_item("local_rank", c.local_rank)?;
207        entry.set_item("next_positives_covered", c.next_positives_covered)?;
208        entry.set_item("next_negatives_covered", c.next_negatives_covered)?;
209        entry.set_item("tie_class_size", c.tie_class_size)?;
210        candidates_list.append(entry)?;
211    }
212
213    let d = PyDict::new(py);
214    d.set_item("candidates", candidates_list)?;
215    d.set_item("total_scored", result.total_scored)?;
216    d.set_item("candidate_count", result.candidate_count)?;
217    d.set_item("positive_count", result.positive_count)?;
218    d.set_item("negative_count", result.negative_count)?;
219    Ok(d.into())
220}