// xlog_induce/lib.rs

//! xlog-induce — native bounded exact-induction engine.
//!
//! Scores all `(left, right)` candidate pairs for the four supported topologies (chain,
//! star, fanout, fanin) in a single batched GPU pass and returns the top-K per topology
//! with full candidate metadata.
//!
//! Behaviorally equivalent to the `backend="python"` reference implementation
//! in `crates/pyxlog/python/pyxlog/ilp/exact_induce.py` on bounded requests;
//! the parity contract is locked by `python/tests/test_ilp_exact_induce.py`.
//!
//! The native production path includes request validation, deterministic reduction,
//! trivial-dead-end early returns, the batched scoring kernel, device-side top-K
//! selection, and compact selected-row transfers.
14
15pub mod index;
16pub mod provenance;
17pub mod reduce;
18pub mod score;
19pub mod types;
20mod validate;
21
22pub use provenance::InductionProvenanceRegistry;
23pub use reduce::{reduce_per_topology, ScoredPair};
24pub use types::{
25    ExactInductionConfig, ExactInductionResult, InducedRuleProvenance, InducedRuleRegistry,
26    InductionAlternative, InductionSupportRow, RuleSourceKind, ScoredCandidate, Topology,
27};
28
29use xlog_core::{RelId, Result, ScalarType, XlogError};
30use xlog_cuda::{CudaBuffer, CudaKernelProvider};
31
32use validate::{classify_request, PreKernelOutcome, RequestMetadata};
33
/// Column type shared by both columns of a binary-pair buffer.
///
/// These are the only scalar types `validate_pair_buffer` accepts; any other
/// column type is rejected with a type error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExactPairType {
    /// Both columns are `ScalarType::U64`.
    U64,
    /// Both columns are `ScalarType::U32`.
    U32,
    /// Both columns are `ScalarType::Symbol`.
    Symbol,
}
40
41/// Inputs to one [`induce_exact`] call.
42///
43/// Each candidate is a `(RelId, &CudaBuffer)` pair: the `RelId` is a label that
44/// flows through to every [`ScoredCandidate`] produced from that buffer, and
45/// the `CudaBuffer` carries the relation's current binary-pair facts.
46///
47/// `positives` and `negatives` are themselves binary-pair buffers (arity 2,
48/// column type `U64`). Name-to-`RelId` resolution and relation-store lookup
49/// happen at the pyxlog boundary — the engine only sees indices + handles.
50pub struct InduceExactRequest<'a> {
51    pub head_rel_idx: RelId,
52    pub candidates: &'a [(RelId, &'a CudaBuffer)],
53    pub positives: &'a CudaBuffer,
54    pub negatives: Option<&'a CudaBuffer>,
55    pub config: ExactInductionConfig,
56}
57
58/// Run exact induction against one request.
59///
60/// Returns an [`ExactInductionResult`] matching the Python reference on
61/// bounded inputs.
62///
63/// The `provider` argument owns the kernel launcher — it's passed separately
64/// from the request so the engine can also materialize short-lived GPU
65/// buffers (an empty negatives buffer when `request.negatives` is `None`).
66pub fn induce_exact(
67    provider: &CudaKernelProvider,
68    request: &InduceExactRequest<'_>,
69) -> Result<ExactInductionResult> {
70    // Empty candidates is a trivial dead-end and needs no CUDA inspection —
71    // matches the Python reference's `if not body_indices: return ...` path.
72    if request.candidates.is_empty() {
73        return Ok(ExactInductionResult::default());
74    }
75
76    // Buffer-level validation (arity 2, accepted typed pair columns). Runs before
77    // metadata extraction so we fail loud on pyxlog-side assembly bugs.
78    let pair_type = validate_pair_buffer(request.positives, "positives")?;
79    if let Some(neg) = request.negatives {
80        require_pair_type(neg, "negatives", pair_type)?;
81    }
82    for (i, (_, buf)) in request.candidates.iter().enumerate() {
83        require_pair_type(buf, &format!("candidate[{}]", i), pair_type)?;
84    }
85
86    // Extract row counts from the cached host-side metadata. The DLPack ingest
87    // path (`CudaKernelProvider::from_dlpack_tensors_with_schema`) populates
88    // `cached_row_count`, so this is a pure struct read — no device-to-host
89    // transfer. That's how we keep the hot-loop device-to-host transfer budget
90    // flat across candidate counts.
91    let pos_count = cached_rows(request.positives, "positives")?;
92    let neg_count = request
93        .negatives
94        .map(|b| cached_rows(b, "negatives"))
95        .transpose()?
96        .unwrap_or(0);
97
98    let meta = RequestMetadata {
99        candidate_count: request.candidates.len() as u32,
100        positive_count: pos_count,
101        negative_count: neg_count,
102        k_per_topology: request.config.k_per_topology,
103    };
104
105    match classify_request(meta) {
106        PreKernelOutcome::TrivialEmpty(result) => Ok(result),
107        PreKernelOutcome::Proceed(m) => score_and_reduce(provider, request, m),
108    }
109}
110
111fn score_and_reduce(
112    provider: &CudaKernelProvider,
113    request: &InduceExactRequest<'_>,
114    meta: RequestMetadata,
115) -> Result<ExactInductionResult> {
116    // ── Normalize negatives: engine + launcher expect an always-present
117    //    buffer. When the caller passes `None`, construct an empty U64 pair
118    //    buffer (zero rows) using the positives' schema. This keeps the
119    //    launcher signature and kernel signature uniform.
120    let empty_neg_holder: Option<CudaBuffer> = if request.negatives.is_none() {
121        Some(provider.create_empty_buffer(request.positives.schema().clone())?)
122    } else {
123        None
124    };
125    let negatives: &CudaBuffer = match request.negatives {
126        Some(b) => b,
127        None => empty_neg_holder
128            .as_ref()
129            .expect("holder populated in the None branch above"),
130    };
131
132    // ── Drive the batched scoring kernel and device-side top-K selector.
133    let candidate_buffers: Vec<&CudaBuffer> = request.candidates.iter().map(|(_, b)| *b).collect();
134    let selected = provider.ilp_exact_score_topk(
135        &candidate_buffers,
136        request.positives,
137        negatives,
138        request.config.k_per_topology,
139    )?;
140    let mut candidates = Vec::with_capacity(selected.len());
141    for row in selected {
142        let topology = topology_from_kernel_idx(row.topology_idx)?;
143        let left_idx = row.left_idx as usize;
144        let right_idx = row.right_idx as usize;
145        let (left_rel_idx, _) = request.candidates.get(left_idx).ok_or_else(|| {
146            XlogError::Execution(format!(
147                "induce_exact: device selector returned left index {} for {} candidates",
148                left_idx,
149                request.candidates.len()
150            ))
151        })?;
152        let (right_rel_idx, _) = request.candidates.get(right_idx).ok_or_else(|| {
153            XlogError::Execution(format!(
154                "induce_exact: device selector returned right index {} for {} candidates",
155                right_idx,
156                request.candidates.len()
157            ))
158        })?;
159        candidates.push(ScoredCandidate {
160            topology,
161            head_rel_idx: request.head_rel_idx,
162            left_rel_idx: *left_rel_idx,
163            right_rel_idx: *right_rel_idx,
164            positives_covered: row.positives_covered,
165            negatives_covered: row.negatives_covered,
166            local_rank: row.local_rank,
167            next_positives_covered: row.next_positives_covered,
168            next_negatives_covered: row.next_negatives_covered,
169            tie_class_size: row.tie_class_size,
170        });
171    }
172    let total_scored = 4u32
173        .checked_mul(meta.candidate_count)
174        .and_then(|v| v.checked_mul(meta.candidate_count))
175        .ok_or_else(|| XlogError::Execution("induce_exact: total_scored overflow".into()))?;
176
177    Ok(ExactInductionResult {
178        candidates,
179        total_scored,
180        candidate_count: meta.candidate_count,
181        positive_count: meta.positive_count,
182        negative_count: meta.negative_count,
183    })
184}
185
186fn topology_from_kernel_idx(idx: u32) -> Result<Topology> {
187    match idx {
188        0 => Ok(Topology::Chain),
189        1 => Ok(Topology::Star),
190        2 => Ok(Topology::Fanout),
191        3 => Ok(Topology::Fanin),
192        _ => Err(XlogError::Execution(format!(
193            "induce_exact: device selector returned topology index {}",
194            idx
195        ))),
196    }
197}
198
199fn validate_pair_buffer(buf: &CudaBuffer, label: &str) -> Result<ExactPairType> {
200    if buf.arity() != 2 {
201        return Err(XlogError::Execution(format!(
202            "induce_exact: {} buffer has arity {}, expected 2",
203            label,
204            buf.arity(),
205        )));
206    }
207    let mut pair_type = None;
208    for col_idx in 0..2 {
209        let t = buf.schema().column_type(col_idx).ok_or_else(|| {
210            XlogError::Type(format!(
211                "induce_exact: {} buffer column {} has no schema type",
212                label, col_idx,
213            ))
214        })?;
215        let col_type = match t {
216            ScalarType::U64 => ExactPairType::U64,
217            ScalarType::U32 => ExactPairType::U32,
218            ScalarType::Symbol => ExactPairType::Symbol,
219            _ => {
220                return Err(XlogError::Type(format!(
221                    "induce_exact: {} buffer column {} has type {:?}, expected U64, U32, or Symbol",
222                    label, col_idx, t,
223                )));
224            }
225        };
226        if let Some(expected) = pair_type {
227            if expected != col_type {
228                return Err(XlogError::Type(format!(
229                    "induce_exact: {} buffer column {} type mismatch: {:?} vs {:?}",
230                    label, col_idx, expected, col_type,
231                )));
232            }
233        } else {
234            pair_type = Some(col_type);
235        }
236    }
237    Ok(pair_type.expect("arity 2 loop sets pair type"))
238}
239
240fn require_pair_type(buf: &CudaBuffer, label: &str, expected: ExactPairType) -> Result<()> {
241    let actual = validate_pair_buffer(buf, label)?;
242    if actual != expected {
243        return Err(XlogError::Type(format!(
244            "induce_exact: {} buffer type mismatch: expected {:?}, got {:?}",
245            label, expected, actual,
246        )));
247    }
248    Ok(())
249}
250
251fn cached_rows(buf: &CudaBuffer, label: &str) -> Result<u32> {
252    buf.cached_row_count().ok_or_else(|| {
253        XlogError::Execution(format!(
254            "induce_exact: {} buffer has no cached row count \
255             (DLPack ingest path should populate it; required to avoid hot-loop device-to-host transfer)",
256            label,
257        ))
258    })
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// The string names must stay in sync with the Python side of the contract.
    #[test]
    fn topology_as_str_matches_python_contract() {
        let expectations = [
            (Topology::Chain, "chain"),
            (Topology::Star, "star"),
            (Topology::Fanout, "fanout"),
            (Topology::Fanin, "fanin"),
        ];
        for (topology, name) in expectations {
            assert_eq!(topology.as_str(), name);
        }
    }

    /// `Topology::ALL` must list the topologies in engine order.
    #[test]
    fn topology_all_is_engine_order() {
        let engine_order = [
            Topology::Chain,
            Topology::Star,
            Topology::Fanout,
            Topology::Fanin,
        ];
        assert_eq!(Topology::ALL, engine_order);
    }
}