1use pyo3::exceptions::PyValueError;
21use pyo3::prelude::*;
22use pyo3::types::{PyDict, PyList};
23
24use xlog_core::RelId;
25use xlog_cuda::CudaBuffer;
26use xlog_induce::{
27 induce_exact as induce_exact_engine, ExactInductionConfig, ExactInductionResult,
28 InduceExactRequest,
29};
30
31use super::{dlpack_from_py, types, CompiledIlpProgram};
32
33#[pymethods]
34impl CompiledIlpProgram {
35 #[pyo3(signature = (
59 head_relation,
60 candidate_relations,
61 positive_arg0,
62 positive_arg1,
63 negative_arg0 = None,
64 negative_arg1 = None,
65 k_per_topology = 2,
66 deterministic = true,
67 ))]
68 #[allow(clippy::too_many_arguments)]
69 pub fn induce_exact_native<'py>(
70 &mut self,
71 py: Python<'py>,
72 head_relation: String,
73 candidate_relations: Vec<String>,
74 positive_arg0: &Bound<'py, PyAny>,
75 positive_arg1: &Bound<'py, PyAny>,
76 negative_arg0: Option<&Bound<'py, PyAny>>,
77 negative_arg1: Option<&Bound<'py, PyAny>>,
78 k_per_topology: u32,
79 deterministic: bool,
80 ) -> PyResult<PyObject> {
81 let head_rid = match self.rel_index.iter().find(|(_, n)| n == &head_relation) {
85 Some((rid, _)) => *rid,
86 None => return empty_result_dict(py),
87 };
88
89 let filtered_candidates: Vec<(RelId, String)> = candidate_relations
93 .into_iter()
94 .filter_map(|name| {
95 self.rel_index
96 .iter()
97 .find(|(_, n)| n == &name)
98 .map(|(rid, _)| (*rid, name))
99 })
100 .collect();
101
102 let head_schema = self.schemas.get(&head_relation).cloned().ok_or_else(|| {
106 PyValueError::new_err(format!(
107 "induce_exact_native: head_relation {:?} has no schema",
108 head_relation,
109 ))
110 })?;
111
112 let pos_dmt0 = dlpack_from_py(positive_arg0)?;
113 let pos_dmt1 = dlpack_from_py(positive_arg1)?;
114 let pos_buf = self
115 .provider
116 .from_dlpack_tensors_with_schema(head_schema.clone(), vec![pos_dmt0, pos_dmt1])
117 .map_err(types::xlog_err)?;
118
119 let neg_buf = match (negative_arg0, negative_arg1) {
120 (Some(a0), Some(a1)) => {
121 let dmt0 = dlpack_from_py(a0)?;
122 let dmt1 = dlpack_from_py(a1)?;
123 Some(
124 self.provider
125 .from_dlpack_tensors_with_schema(head_schema.clone(), vec![dmt0, dmt1])
126 .map_err(types::xlog_err)?,
127 )
128 }
129 (None, None) => None,
130 _ => {
131 return Err(PyValueError::new_err(
132 "induce_exact_native: negative_arg0 and negative_arg1 must both be set or both None",
133 ));
134 }
135 };
136
137 let store = self.executor.store();
143 let candidates_with_bufs: Vec<(RelId, &CudaBuffer)> = filtered_candidates
144 .iter()
145 .filter_map(|(rid, name)| store.get(name).map(|buf| (*rid, buf)))
146 .collect();
147
148 let request = InduceExactRequest {
150 head_rel_idx: head_rid,
151 candidates: &candidates_with_bufs,
152 positives: &pos_buf,
153 negatives: neg_buf.as_ref(),
154 config: ExactInductionConfig {
155 k_per_topology,
156 deterministic,
157 },
158 };
159 let result = induce_exact_engine(&self.provider, &request).map_err(types::xlog_err)?;
160
161 result_to_py_dict(py, &result, &head_relation, &self.rel_index)
164 }
165}
166
167fn empty_result_dict(py: Python<'_>) -> PyResult<PyObject> {
168 let d = PyDict::new(py);
169 let empty_candidates = PyList::empty(py);
170 d.set_item("candidates", empty_candidates)?;
171 d.set_item("total_scored", 0u32)?;
172 d.set_item("candidate_count", 0u32)?;
173 d.set_item("positive_count", 0u32)?;
174 d.set_item("negative_count", 0u32)?;
175 Ok(d.into())
176}
177
178fn result_to_py_dict(
179 py: Python<'_>,
180 result: &ExactInductionResult,
181 head_relation: &str,
182 rel_index: &[(RelId, String)],
183) -> PyResult<PyObject> {
184 let name_of = |rid: RelId| -> PyResult<&str> {
185 rel_index
186 .iter()
187 .find(|(r, _)| *r == rid)
188 .map(|(_, n)| n.as_str())
189 .ok_or_else(|| {
190 PyValueError::new_err(format!(
191 "induce_exact_native: result references unknown RelId {:?}",
192 rid,
193 ))
194 })
195 };
196
197 let candidates_list = PyList::empty(py);
198 for c in &result.candidates {
199 let entry = PyDict::new(py);
200 entry.set_item("topology", c.topology.as_str())?;
201 entry.set_item("head_relation", head_relation)?;
202 entry.set_item("left_relation", name_of(c.left_rel_idx)?)?;
203 entry.set_item("right_relation", name_of(c.right_rel_idx)?)?;
204 entry.set_item("positives_covered", c.positives_covered)?;
205 entry.set_item("negatives_covered", c.negatives_covered)?;
206 entry.set_item("local_rank", c.local_rank)?;
207 entry.set_item("next_positives_covered", c.next_positives_covered)?;
208 entry.set_item("next_negatives_covered", c.next_negatives_covered)?;
209 entry.set_item("tie_class_size", c.tie_class_size)?;
210 candidates_list.append(entry)?;
211 }
212
213 let d = PyDict::new(py);
214 d.set_item("candidates", candidates_list)?;
215 d.set_item("total_scored", result.total_scored)?;
216 d.set_item("candidate_count", result.candidate_count)?;
217 d.set_item("positive_count", result.positive_count)?;
218 d.set_item("negative_count", result.negative_count)?;
219 Ok(d.into())
220}