Skip to main content

xlog_integration/
wcoj_dispatch.rs

1//! v0.6.2 minimal env-gated GPU 3-way WCOJ triangle dispatch.
2//!
//! Two public entries:
//! [`try_wcoj_triangle_u32_dispatch`] (env-driven) and
//! [`try_wcoj_triangle_u32_dispatch_with_gate`] (boolean-driven,
//! for tests). The slice is intentionally narrow:
7//!
//!   * **Env flag only.** `XLOG_USE_WCOJ_TRIANGLE_U32=1` (or
//!     `true` in any ASCII letter case) opts in. Anything else
//!     (unset, `0`, `false`, etc.) means the helper returns
//!     `Ok(None)` unconditionally — the caller takes the existing
//!     binary-join path.
13//!   * **Recognizes exactly one shape.** A rule of the form
14//!     `tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z)` over
15//!     2-column WCOJ-eligible relations (U32, Symbol, or U64
16//!     keys — see `WcojKeyWidth`): three positive 2-arity body
17//!     atoms covering the head's three distinct variables in
18//!     head-position order. No negation, no comparison filters,
19//!     no recursion (head predicate not in body), no
20//!     reversed-axis atoms (e.g. `e1(Y, X)`), no constants in
21//!     atom args. The planner must also return
22//!     [`xlog_logic::hypergraph::RulePlan::MultiwayCandidate`].
23//!   * **Width uniformity.** All three slots must share a key
24//!     width. A mixed-width triangle (e.g. e1 U32, e2 U64) is
25//!     rejected at this dispatch level — the binary-join chain
26//!     handles it.
27//!   * **Silent fallback.** Any mismatch — gate off, shape
28//!     mismatch, planner verdict not multiway, missing input
29//!     buffer, unsupported scalar type, mixed-width slots —
30//!     returns `Ok(None)` without an error or log line. The
31//!     caller is expected to silently route to the existing
32//!     binary-join path. This keeps the env flag truly opt-in
33//!     and prevents the helper from accidentally diverting
34//!     work it can't handle.
35//!   * **Strict GPU pipeline on dispatch.** When all checks pass,
36//!     the helper builds three sorted+deduped layouts and runs
37//!     the matching WCOJ triangle kernel on the configured
38//!     `launch_stream` — `wcoj_layout_u32_recorded` /
39//!     `wcoj_triangle_u32_recorded` for 4-byte keys, the
40//!     `_u64_recorded` siblings for 8-byte keys. All
41//!     [`xlog_cuda::launch::LaunchRecorder`] discipline carries
42//!     through unchanged.
43//!
44//! What this slice deliberately does NOT do:
45//!
46//!   * No automatic detection at the executor level — callers
47//!     pass the rule + input buffers explicitly. Executor
48//!     wiring lives in `xlog-runtime`.
49//!   * No recursion / SCC mixed execution.
50//!   * No cost model.
51//!   * No mixed-width admission (U32+U64 triangle stays on the
52//!     binary-join path).
53//!   * No histogram-guided block dispatch.
54
55use std::collections::{BTreeMap, BTreeSet};
56
57use xlog_core::{Result, ScalarType};
58use xlog_cuda::device_runtime::StreamId;
59use xlog_cuda::memory::CudaBuffer;
60use xlog_cuda::CudaKernelProvider;
61use xlog_logic::ast::{BodyLiteral, Rule, Term};
62use xlog_logic::hypergraph::{plan_rule, RefRelation, RefRelationStore, RefValue, RulePlan};
63
/// Name of the environment variable that opts into the WCOJ
/// triangle dispatch.
///
/// The gate counts as ON only for the value `"1"` or for `"true"`
/// in any ASCII letter case. Every other state counts as OFF: unset,
/// empty, `"0"`, `"false"`, and so on.
///
/// Despite the `_U32` suffix, the same flag also gates the 8-byte
/// (`U64`) triangle path.
pub const ENV_USE_WCOJ_TRIANGLE_U32: &str = "XLOG_USE_WCOJ_TRIANGLE_U32";
68
69/// Env-driven entry. Reads `XLOG_USE_WCOJ_TRIANGLE_U32` and
70/// delegates to [`try_wcoj_triangle_u32_dispatch_with_gate`].
71///
72/// Tests should prefer [`try_wcoj_triangle_u32_dispatch_with_gate`]
73/// to avoid races on the process-global env var; the env-driven
74/// form exists for production callers that want to opt in via
75/// configuration alone.
76pub fn try_wcoj_triangle_u32_dispatch(
77    rule: &Rule,
78    inputs: &BTreeMap<String, CudaBuffer>,
79    provider: &CudaKernelProvider,
80    launch_stream: StreamId,
81) -> Result<Option<CudaBuffer>> {
82    let gate = std::env::var(ENV_USE_WCOJ_TRIANGLE_U32)
83        .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
84        .unwrap_or(false);
85    try_wcoj_triangle_u32_dispatch_with_gate(gate, rule, inputs, provider, launch_stream)
86}
87
88/// Test-friendly form that takes the gate as an explicit boolean.
89/// Production callers use [`try_wcoj_triangle_u32_dispatch`] which
90/// reads the env var.
91pub fn try_wcoj_triangle_u32_dispatch_with_gate(
92    gate_enabled: bool,
93    rule: &Rule,
94    inputs: &BTreeMap<String, CudaBuffer>,
95    provider: &CudaKernelProvider,
96    launch_stream: StreamId,
97) -> Result<Option<CudaBuffer>> {
98    if !gate_enabled {
99        return Ok(None);
100    }
101    // Match the triangle shape — silently bail on any mismatch.
102    let Some(matched) = match_triangle_shape(rule, inputs) else {
103        return Ok(None);
104    };
105
106    // Run the planner on a synthetic store carrying the inputs'
107    // schemas. plan_rule needs schemas (not rows) for typed
108    // gating; rows are not consulted at the gate phase.
109    let mut plan_store: RefRelationStore = BTreeMap::new();
110    for atom_match in &matched.atoms {
111        // Skip duplicates: the same predicate may legitimately
112        // appear in multiple slots (e.g. all three atoms over a
113        // single edge relation), and a BTreeMap insert collapses
114        // them naturally. Schemas are guaranteed identical
115        // because they all came from the same `inputs[name]`.
116        plan_store.insert(
117            atom_match.predicate.clone(),
118            schema_only_relation(&inputs[&atom_match.predicate]),
119        );
120    }
121    let plan = plan_rule(rule, &plan_store).ok();
122    match plan {
123        Some(RulePlan::MultiwayCandidate { .. }) => {}
124        _ => return Ok(None),
125    }
126
127    // Construct sorted+deduped layouts for each slot and run
128    // the WCOJ triangle kernel. Branch by width: 4-byte
129    // (U32 / Symbol) and 8-byte (U64) inputs go to parallel
130    // provider entries. Mixed-width triangles never reach this
131    // point — `match_triangle_shape` requires all three slots
132    // to share a width.
133    let result = match matched.width {
134        WcojKeyWidth::FourByte => {
135            let buf_xy = provider.wcoj_layout_u32_recorded(matched.e_xy, launch_stream)?;
136            let buf_yz = provider.wcoj_layout_u32_recorded(matched.e_yz, launch_stream)?;
137            let buf_xz = provider.wcoj_layout_u32_recorded(matched.e_xz, launch_stream)?;
138            provider.wcoj_triangle_u32_recorded(&buf_xy, &buf_yz, &buf_xz, launch_stream)?
139        }
140        WcojKeyWidth::EightByte => {
141            let buf_xy = provider.wcoj_layout_u64_recorded(matched.e_xy, launch_stream)?;
142            let buf_yz = provider.wcoj_layout_u64_recorded(matched.e_yz, launch_stream)?;
143            let buf_xz = provider.wcoj_layout_u64_recorded(matched.e_xz, launch_stream)?;
144            provider.wcoj_triangle_u64_recorded(&buf_xy, &buf_yz, &buf_xz, launch_stream)?
145        }
146    };
147    Ok(Some(result))
148}
149
150/// Matched atom slots after pattern recognition. Each `&CudaBuffer`
151/// borrows from the caller's `inputs` map. `width` is the key
152/// width shared by all three slots (mixed-width triangles never
153/// reach this struct — see `match_triangle_shape`).
154struct MatchedAtoms<'a> {
155    atoms: Vec<MatchedAtom>,
156    e_xy: &'a CudaBuffer,
157    e_yz: &'a CudaBuffer,
158    e_xz: &'a CudaBuffer,
159    width: WcojKeyWidth,
160}
161
/// Per-atom record kept after shape matching. It currently holds only
/// the body predicate name, which is used to build the planner's
/// schema store.
struct MatchedAtom {
    /// Predicate name of the body atom (a key into `inputs`).
    predicate: String,
}
165
/// Key width of a two-column relation, as seen by the WCOJ kernels.
///
/// * `FourByte`: `U32` or `Symbol` columns, which share one 4-byte
///   physical layout.
/// * `EightByte`: `U64` columns.
///
/// No other scalar type is eligible at this dispatch level. Type
/// compatibility across relations is left to the upstream planner
/// (`analyze_typed`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WcojKeyWidth {
    /// 4-byte keys (`U32` / `Symbol`).
    FourByte,
    /// 8-byte keys (`U64`).
    EightByte,
}
177
178/// Attempt to match the rule + inputs against the v1 triangle
179/// shape. Returns `Some(...)` if every check passes; `None`
180/// otherwise. All checks are pure validation — no buffer
181/// modification, no kernel launches, no errors.
182fn match_triangle_shape<'a>(
183    rule: &Rule,
184    inputs: &'a BTreeMap<String, CudaBuffer>,
185) -> Option<MatchedAtoms<'a>> {
186    // Head must have exactly 3 distinct variable terms.
187    let head_vars = head_vars_or_none(rule)?;
188
189    // Body must have exactly 3 BodyLiteral::Positive atoms; no
190    // negation, no comparison filters, no is-expr.
191    if rule.body.len() != 3 {
192        return None;
193    }
194    let mut positive_atoms = Vec::with_capacity(3);
195    for literal in &rule.body {
196        match literal {
197            BodyLiteral::Positive(atom) => positive_atoms.push(atom),
198            _ => return None,
199        }
200    }
201
202    // Head predicate must NOT appear in the body — rules out
203    // direct recursion. Indirect recursion through SCC structure
204    // is also unsupported in this slice; since we only see one
205    // rule at a time, "head predicate equals any body
206    // predicate" is the conservative check.
207    let head_predicate = &rule.head.predicate;
208    for atom in &positive_atoms {
209        if &atom.predicate == head_predicate {
210            return None;
211        }
212    }
213
214    // Each atom must be 2-arity, with both args being head
215    // variables (not constants, anonymous, or non-head vars).
216    // Collect (head_position_0, head_position_1) per atom; the
217    // pair must be in ascending order — i.e. atom args are in
218    // head-position order, not reversed.
219    //
220    // After this collection, the three atoms together must
221    // cover EXACTLY the three pairs {0,1}, {1,2}, {0,2} (the
222    // triangle's three edges over head positions).
223    let mut atom_specs: Vec<(usize, usize, &xlog_logic::ast::Atom)> = Vec::with_capacity(3);
224    for atom in &positive_atoms {
225        if atom.terms.len() != 2 {
226            return None;
227        }
228        let pos0 = head_pos_of_var(&head_vars, &atom.terms[0])?;
229        let pos1 = head_pos_of_var(&head_vars, &atom.terms[1])?;
230        if pos0 >= pos1 {
231            // Same variable twice (`e(X, X)`) or reversed
232            // axis (`e(Y, X)`). Both are out-of-scope for v1
233            // dispatch.
234            return None;
235        }
236        atom_specs.push((pos0, pos1, atom));
237    }
238    let pair_set: BTreeSet<(usize, usize)> = atom_specs.iter().map(|(a, b, _)| (*a, *b)).collect();
239    let mut expected: BTreeSet<(usize, usize)> = BTreeSet::new();
240    expected.insert((0, 1));
241    expected.insert((1, 2));
242    expected.insert((0, 2));
243    if pair_set != expected {
244        return None;
245    }
246
247    // Look up each input buffer by predicate name. Validate
248    // every buffer is a 2-column WCOJ-eligible relation
249    // (4-byte U32/Symbol or 8-byte U64) AND that all three
250    // slots share the same width — mixed-width triangles fall
251    // back here so the binary-join chain handles them.
252    let mut e_xy: Option<&CudaBuffer> = None;
253    let mut e_yz: Option<&CudaBuffer> = None;
254    let mut e_xz: Option<&CudaBuffer> = None;
255    let mut atoms_out: Vec<MatchedAtom> = Vec::with_capacity(3);
256    let mut shared_width: Option<WcojKeyWidth> = None;
257    for (pos0, pos1, atom) in &atom_specs {
258        let buf = inputs.get(&atom.predicate)?;
259        let w = classify_two_col_wcoj_width(buf)?;
260        match shared_width {
261            None => shared_width = Some(w),
262            Some(prev) if prev == w => {}
263            Some(_) => return None,
264        }
265        atoms_out.push(MatchedAtom {
266            predicate: atom.predicate.clone(),
267        });
268        match (*pos0, *pos1) {
269            (0, 1) => e_xy = Some(buf),
270            (1, 2) => e_yz = Some(buf),
271            (0, 2) => e_xz = Some(buf),
272            _ => unreachable!("pair_set check above guarantees these three"),
273        }
274    }
275    Some(MatchedAtoms {
276        atoms: atoms_out,
277        e_xy: e_xy?,
278        e_yz: e_yz?,
279        e_xz: e_xz?,
280        width: shared_width?,
281    })
282}
283
284/// Return the rule head's variable names if every head term is a
285/// distinct `Term::Variable`; `None` otherwise.
286fn head_vars_or_none(rule: &Rule) -> Option<Vec<String>> {
287    if rule.head.terms.len() != 3 {
288        return None;
289    }
290    let mut out = Vec::with_capacity(3);
291    for term in &rule.head.terms {
292        match term {
293            Term::Variable(name) => out.push(name.clone()),
294            _ => return None,
295        }
296    }
297    // Distinctness.
298    let unique: BTreeSet<&String> = out.iter().collect();
299    if unique.len() != 3 {
300        return None;
301    }
302    Some(out)
303}
304
305/// Map an atom-arg term to its position in the head variable list
306/// (0, 1, or 2). Returns `None` for anything that isn't a
307/// head-list variable.
308fn head_pos_of_var(head_vars: &[String], term: &Term) -> Option<usize> {
309    match term {
310        Term::Variable(name) => head_vars.iter().position(|v| v == name),
311        _ => None,
312    }
313}
314
315/// Classify a binary [`CudaBuffer`]'s key width for WCOJ
316/// dispatch. Returns
317///
318/// * `Some(WcojKeyWidth::FourByte)` when both columns are
319///   `U32` or `Symbol` (bit-identical 4-byte physical layout —
320///   the WCOJ kernel reads them identically),
321/// * `Some(WcojKeyWidth::EightByte)` when both columns are
322///   `U64`,
323/// * `None` for any other arity / type combination, including
324///   mixed widths within a single buffer (e.g. one column U32
325///   and the other U64).
326///
327/// Cross-relation type-compatibility (so a Symbol column never
328/// joins with a U32 column with the same bit pattern, and a
329/// U32 column never joins with a U64 column) is enforced by
330/// the planner upstream via `xlog_logic::hypergraph::analyze_typed`;
331/// the dispatch helper additionally requires width-uniformity
332/// across the three slots in `match_triangle_shape`.
333fn classify_two_col_wcoj_width(buf: &CudaBuffer) -> Option<WcojKeyWidth> {
334    if buf.arity() != 2 {
335        return None;
336    }
337    let c0 = buf.schema.column_type(0)?;
338    let c1 = buf.schema.column_type(1)?;
339    let width0 = scalar_wcoj_width(c0)?;
340    let width1 = scalar_wcoj_width(c1)?;
341    if width0 != width1 {
342        return None;
343    }
344    Some(width0)
345}
346
347/// Map a single [`ScalarType`] to its WCOJ key width, or `None`
348/// if the type is not supported by any current WCOJ entry.
349fn scalar_wcoj_width(ty: ScalarType) -> Option<WcojKeyWidth> {
350    match ty {
351        ScalarType::U32 | ScalarType::Symbol => Some(WcojKeyWidth::FourByte),
352        ScalarType::U64 => Some(WcojKeyWidth::EightByte),
353        _ => None,
354    }
355}
356
357/// Build a synthetic schema-only [`RefRelation`] for the planner's
358/// view. Rows are intentionally empty: `plan_rule` only consults
359/// schemas for typed gating, never rows.
360fn schema_only_relation(buf: &CudaBuffer) -> RefRelation {
361    let arity = buf.arity();
362    let schema: Vec<ScalarType> = (0..arity)
363        .map(|i| buf.schema.column_type(i).unwrap_or(ScalarType::U32))
364        .collect();
365    RefRelation {
366        schema,
367        rows: Vec::<Vec<RefValue>>::new(),
368    }
369}