xlog_integration/wcoj_dispatch.rs
1//! v0.6.2 minimal env-gated GPU 3-way WCOJ triangle dispatch.
2//!
//! Two public entry points:
//! [`try_wcoj_triangle_u32_dispatch`] (env-driven) and
//! [`try_wcoj_triangle_u32_dispatch_with_gate`] (boolean-driven,
//! for tests). The slice is intentionally narrow:
7//!
//! * **Env flag only.** `XLOG_USE_WCOJ_TRIANGLE_U32=1` (or
//!   `true`, matched ASCII case-insensitively) opts in. Anything
//!   else (unset, `0`, `false`, etc.) means the helper returns
//!   `Ok(None)` unconditionally — the caller takes the existing
//!   binary-join path.
13//! * **Recognizes exactly one shape.** A rule of the form
14//! `tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z)` over
15//! 2-column WCOJ-eligible relations (U32, Symbol, or U64
16//! keys — see `WcojKeyWidth`): three positive 2-arity body
17//! atoms covering the head's three distinct variables in
18//! head-position order. No negation, no comparison filters,
19//! no recursion (head predicate not in body), no
20//! reversed-axis atoms (e.g. `e1(Y, X)`), no constants in
21//! atom args. The planner must also return
22//! [`xlog_logic::hypergraph::RulePlan::MultiwayCandidate`].
23//! * **Width uniformity.** All three slots must share a key
24//! width. A mixed-width triangle (e.g. e1 U32, e2 U64) is
25//! rejected at this dispatch level — the binary-join chain
26//! handles it.
27//! * **Silent fallback.** Any mismatch — gate off, shape
28//! mismatch, planner verdict not multiway, missing input
29//! buffer, unsupported scalar type, mixed-width slots —
30//! returns `Ok(None)` without an error or log line. The
31//! caller is expected to silently route to the existing
32//! binary-join path. This keeps the env flag truly opt-in
33//! and prevents the helper from accidentally diverting
34//! work it can't handle.
35//! * **Strict GPU pipeline on dispatch.** When all checks pass,
36//! the helper builds three sorted+deduped layouts and runs
37//! the matching WCOJ triangle kernel on the configured
38//! `launch_stream` — `wcoj_layout_u32_recorded` /
39//! `wcoj_triangle_u32_recorded` for 4-byte keys, the
40//! `_u64_recorded` siblings for 8-byte keys. All
41//! [`xlog_cuda::launch::LaunchRecorder`] discipline carries
42//! through unchanged.
43//!
44//! What this slice deliberately does NOT do:
45//!
46//! * No automatic detection at the executor level — callers
47//! pass the rule + input buffers explicitly. Executor
48//! wiring lives in `xlog-runtime`.
49//! * No recursion / SCC mixed execution.
50//! * No cost model.
51//! * No mixed-width admission (U32+U64 triangle stays on the
52//! binary-join path).
53//! * No histogram-guided block dispatch.
54
55use std::collections::{BTreeMap, BTreeSet};
56
57use xlog_core::{Result, ScalarType};
58use xlog_cuda::device_runtime::StreamId;
59use xlog_cuda::memory::CudaBuffer;
60use xlog_cuda::CudaKernelProvider;
61use xlog_logic::ast::{BodyLiteral, Rule, Term};
62use xlog_logic::hypergraph::{plan_rule, RefRelation, RefRelationStore, RefValue, RulePlan};
63
/// Name of the environment variable that gates the WCOJ triangle dispatch.
///
/// [`try_wcoj_triangle_u32_dispatch`] reads it on every call and treats the
/// gate as ON only when the value is `"1"` or an ASCII case-insensitive
/// `"true"`; anything else (unset, `"0"`, `"false"`, empty string, …)
/// means OFF.
pub const ENV_USE_WCOJ_TRIANGLE_U32: &str = "XLOG_USE_WCOJ_TRIANGLE_U32";
68
69/// Env-driven entry. Reads `XLOG_USE_WCOJ_TRIANGLE_U32` and
70/// delegates to [`try_wcoj_triangle_u32_dispatch_with_gate`].
71///
72/// Tests should prefer [`try_wcoj_triangle_u32_dispatch_with_gate`]
73/// to avoid races on the process-global env var; the env-driven
74/// form exists for production callers that want to opt in via
75/// configuration alone.
76pub fn try_wcoj_triangle_u32_dispatch(
77 rule: &Rule,
78 inputs: &BTreeMap<String, CudaBuffer>,
79 provider: &CudaKernelProvider,
80 launch_stream: StreamId,
81) -> Result<Option<CudaBuffer>> {
82 let gate = std::env::var(ENV_USE_WCOJ_TRIANGLE_U32)
83 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
84 .unwrap_or(false);
85 try_wcoj_triangle_u32_dispatch_with_gate(gate, rule, inputs, provider, launch_stream)
86}
87
88/// Test-friendly form that takes the gate as an explicit boolean.
89/// Production callers use [`try_wcoj_triangle_u32_dispatch`] which
90/// reads the env var.
91pub fn try_wcoj_triangle_u32_dispatch_with_gate(
92 gate_enabled: bool,
93 rule: &Rule,
94 inputs: &BTreeMap<String, CudaBuffer>,
95 provider: &CudaKernelProvider,
96 launch_stream: StreamId,
97) -> Result<Option<CudaBuffer>> {
98 if !gate_enabled {
99 return Ok(None);
100 }
101 // Match the triangle shape — silently bail on any mismatch.
102 let Some(matched) = match_triangle_shape(rule, inputs) else {
103 return Ok(None);
104 };
105
106 // Run the planner on a synthetic store carrying the inputs'
107 // schemas. plan_rule needs schemas (not rows) for typed
108 // gating; rows are not consulted at the gate phase.
109 let mut plan_store: RefRelationStore = BTreeMap::new();
110 for atom_match in &matched.atoms {
111 // Skip duplicates: the same predicate may legitimately
112 // appear in multiple slots (e.g. all three atoms over a
113 // single edge relation), and a BTreeMap insert collapses
114 // them naturally. Schemas are guaranteed identical
115 // because they all came from the same `inputs[name]`.
116 plan_store.insert(
117 atom_match.predicate.clone(),
118 schema_only_relation(&inputs[&atom_match.predicate]),
119 );
120 }
121 let plan = plan_rule(rule, &plan_store).ok();
122 match plan {
123 Some(RulePlan::MultiwayCandidate { .. }) => {}
124 _ => return Ok(None),
125 }
126
127 // Construct sorted+deduped layouts for each slot and run
128 // the WCOJ triangle kernel. Branch by width: 4-byte
129 // (U32 / Symbol) and 8-byte (U64) inputs go to parallel
130 // provider entries. Mixed-width triangles never reach this
131 // point — `match_triangle_shape` requires all three slots
132 // to share a width.
133 let result = match matched.width {
134 WcojKeyWidth::FourByte => {
135 let buf_xy = provider.wcoj_layout_u32_recorded(matched.e_xy, launch_stream)?;
136 let buf_yz = provider.wcoj_layout_u32_recorded(matched.e_yz, launch_stream)?;
137 let buf_xz = provider.wcoj_layout_u32_recorded(matched.e_xz, launch_stream)?;
138 provider.wcoj_triangle_u32_recorded(&buf_xy, &buf_yz, &buf_xz, launch_stream)?
139 }
140 WcojKeyWidth::EightByte => {
141 let buf_xy = provider.wcoj_layout_u64_recorded(matched.e_xy, launch_stream)?;
142 let buf_yz = provider.wcoj_layout_u64_recorded(matched.e_yz, launch_stream)?;
143 let buf_xz = provider.wcoj_layout_u64_recorded(matched.e_xz, launch_stream)?;
144 provider.wcoj_triangle_u64_recorded(&buf_xy, &buf_yz, &buf_xz, launch_stream)?
145 }
146 };
147 Ok(Some(result))
148}
149
/// Result of a successful triangle-shape match, produced by
/// `match_triangle_shape`. Each `&CudaBuffer` borrows from the caller's
/// `inputs` map.
struct MatchedAtoms<'a> {
    /// One entry per body atom, in rule-body order. The dispatch function
    /// uses these predicate names to populate the planner's schema store.
    atoms: Vec<MatchedAtom>,
    /// Buffer of the body atom over head positions (0, 1), i.e. `(X, Y)`.
    e_xy: &'a CudaBuffer,
    /// Buffer of the body atom over head positions (1, 2), i.e. `(Y, Z)`.
    e_yz: &'a CudaBuffer,
    /// Buffer of the body atom over head positions (0, 2), i.e. `(X, Z)`.
    e_xz: &'a CudaBuffer,
    /// Key width shared by all three slots; `match_triangle_shape` returns
    /// `None` for mixed-width triangles, so they never reach this struct.
    width: WcojKeyWidth,
}
161
/// A single body atom accepted by `match_triangle_shape`.
struct MatchedAtom {
    /// Predicate name of the atom; also the key of its buffer in the
    /// caller's `inputs` map.
    predicate: String,
}
165
/// Physical key width of a WCOJ-eligible binary relation.
///
/// `scalar_wcoj_width` maps scalar types onto these variants; any type it
/// does not map is not eligible for this dispatch.
///
/// NOTE(review): the earlier comment here said the planner's
/// `analyze_typed` rejects other types "before we get here". In this file
/// the width classification runs inside `match_triangle_shape`, which is
/// called before `plan_rule` — confirm the intended ordering and where
/// cross-relation type compatibility is actually enforced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WcojKeyWidth {
    /// Keys of type `U32` or `Symbol` (described in the original notes as
    /// sharing a bit-identical 4-byte layout).
    FourByte,
    /// Keys of type `U64`.
    EightByte,
}
177
178/// Attempt to match the rule + inputs against the v1 triangle
179/// shape. Returns `Some(...)` if every check passes; `None`
180/// otherwise. All checks are pure validation — no buffer
181/// modification, no kernel launches, no errors.
182fn match_triangle_shape<'a>(
183 rule: &Rule,
184 inputs: &'a BTreeMap<String, CudaBuffer>,
185) -> Option<MatchedAtoms<'a>> {
186 // Head must have exactly 3 distinct variable terms.
187 let head_vars = head_vars_or_none(rule)?;
188
189 // Body must have exactly 3 BodyLiteral::Positive atoms; no
190 // negation, no comparison filters, no is-expr.
191 if rule.body.len() != 3 {
192 return None;
193 }
194 let mut positive_atoms = Vec::with_capacity(3);
195 for literal in &rule.body {
196 match literal {
197 BodyLiteral::Positive(atom) => positive_atoms.push(atom),
198 _ => return None,
199 }
200 }
201
202 // Head predicate must NOT appear in the body — rules out
203 // direct recursion. Indirect recursion through SCC structure
204 // is also unsupported in this slice; since we only see one
205 // rule at a time, "head predicate equals any body
206 // predicate" is the conservative check.
207 let head_predicate = &rule.head.predicate;
208 for atom in &positive_atoms {
209 if &atom.predicate == head_predicate {
210 return None;
211 }
212 }
213
214 // Each atom must be 2-arity, with both args being head
215 // variables (not constants, anonymous, or non-head vars).
216 // Collect (head_position_0, head_position_1) per atom; the
217 // pair must be in ascending order — i.e. atom args are in
218 // head-position order, not reversed.
219 //
220 // After this collection, the three atoms together must
221 // cover EXACTLY the three pairs {0,1}, {1,2}, {0,2} (the
222 // triangle's three edges over head positions).
223 let mut atom_specs: Vec<(usize, usize, &xlog_logic::ast::Atom)> = Vec::with_capacity(3);
224 for atom in &positive_atoms {
225 if atom.terms.len() != 2 {
226 return None;
227 }
228 let pos0 = head_pos_of_var(&head_vars, &atom.terms[0])?;
229 let pos1 = head_pos_of_var(&head_vars, &atom.terms[1])?;
230 if pos0 >= pos1 {
231 // Same variable twice (`e(X, X)`) or reversed
232 // axis (`e(Y, X)`). Both are out-of-scope for v1
233 // dispatch.
234 return None;
235 }
236 atom_specs.push((pos0, pos1, atom));
237 }
238 let pair_set: BTreeSet<(usize, usize)> = atom_specs.iter().map(|(a, b, _)| (*a, *b)).collect();
239 let mut expected: BTreeSet<(usize, usize)> = BTreeSet::new();
240 expected.insert((0, 1));
241 expected.insert((1, 2));
242 expected.insert((0, 2));
243 if pair_set != expected {
244 return None;
245 }
246
247 // Look up each input buffer by predicate name. Validate
248 // every buffer is a 2-column WCOJ-eligible relation
249 // (4-byte U32/Symbol or 8-byte U64) AND that all three
250 // slots share the same width — mixed-width triangles fall
251 // back here so the binary-join chain handles them.
252 let mut e_xy: Option<&CudaBuffer> = None;
253 let mut e_yz: Option<&CudaBuffer> = None;
254 let mut e_xz: Option<&CudaBuffer> = None;
255 let mut atoms_out: Vec<MatchedAtom> = Vec::with_capacity(3);
256 let mut shared_width: Option<WcojKeyWidth> = None;
257 for (pos0, pos1, atom) in &atom_specs {
258 let buf = inputs.get(&atom.predicate)?;
259 let w = classify_two_col_wcoj_width(buf)?;
260 match shared_width {
261 None => shared_width = Some(w),
262 Some(prev) if prev == w => {}
263 Some(_) => return None,
264 }
265 atoms_out.push(MatchedAtom {
266 predicate: atom.predicate.clone(),
267 });
268 match (*pos0, *pos1) {
269 (0, 1) => e_xy = Some(buf),
270 (1, 2) => e_yz = Some(buf),
271 (0, 2) => e_xz = Some(buf),
272 _ => unreachable!("pair_set check above guarantees these three"),
273 }
274 }
275 Some(MatchedAtoms {
276 atoms: atoms_out,
277 e_xy: e_xy?,
278 e_yz: e_yz?,
279 e_xz: e_xz?,
280 width: shared_width?,
281 })
282}
283
284/// Return the rule head's variable names if every head term is a
285/// distinct `Term::Variable`; `None` otherwise.
286fn head_vars_or_none(rule: &Rule) -> Option<Vec<String>> {
287 if rule.head.terms.len() != 3 {
288 return None;
289 }
290 let mut out = Vec::with_capacity(3);
291 for term in &rule.head.terms {
292 match term {
293 Term::Variable(name) => out.push(name.clone()),
294 _ => return None,
295 }
296 }
297 // Distinctness.
298 let unique: BTreeSet<&String> = out.iter().collect();
299 if unique.len() != 3 {
300 return None;
301 }
302 Some(out)
303}
304
305/// Map an atom-arg term to its position in the head variable list
306/// (0, 1, or 2). Returns `None` for anything that isn't a
307/// head-list variable.
308fn head_pos_of_var(head_vars: &[String], term: &Term) -> Option<usize> {
309 match term {
310 Term::Variable(name) => head_vars.iter().position(|v| v == name),
311 _ => None,
312 }
313}
314
315/// Classify a binary [`CudaBuffer`]'s key width for WCOJ
316/// dispatch. Returns
317///
318/// * `Some(WcojKeyWidth::FourByte)` when both columns are
319/// `U32` or `Symbol` (bit-identical 4-byte physical layout —
320/// the WCOJ kernel reads them identically),
321/// * `Some(WcojKeyWidth::EightByte)` when both columns are
322/// `U64`,
323/// * `None` for any other arity / type combination, including
324/// mixed widths within a single buffer (e.g. one column U32
325/// and the other U64).
326///
327/// Cross-relation type-compatibility (so a Symbol column never
328/// joins with a U32 column with the same bit pattern, and a
329/// U32 column never joins with a U64 column) is enforced by
330/// the planner upstream via `xlog_logic::hypergraph::analyze_typed`;
331/// the dispatch helper additionally requires width-uniformity
332/// across the three slots in `match_triangle_shape`.
333fn classify_two_col_wcoj_width(buf: &CudaBuffer) -> Option<WcojKeyWidth> {
334 if buf.arity() != 2 {
335 return None;
336 }
337 let c0 = buf.schema.column_type(0)?;
338 let c1 = buf.schema.column_type(1)?;
339 let width0 = scalar_wcoj_width(c0)?;
340 let width1 = scalar_wcoj_width(c1)?;
341 if width0 != width1 {
342 return None;
343 }
344 Some(width0)
345}
346
347/// Map a single [`ScalarType`] to its WCOJ key width, or `None`
348/// if the type is not supported by any current WCOJ entry.
349fn scalar_wcoj_width(ty: ScalarType) -> Option<WcojKeyWidth> {
350 match ty {
351 ScalarType::U32 | ScalarType::Symbol => Some(WcojKeyWidth::FourByte),
352 ScalarType::U64 => Some(WcojKeyWidth::EightByte),
353 _ => None,
354 }
355}
356
357/// Build a synthetic schema-only [`RefRelation`] for the planner's
358/// view. Rows are intentionally empty: `plan_rule` only consults
359/// schemas for typed gating, never rows.
360fn schema_only_relation(buf: &CudaBuffer) -> RefRelation {
361 let arity = buf.arity();
362 let schema: Vec<ScalarType> = (0..arity)
363 .map(|i| buf.schema.column_type(i).unwrap_or(ScalarType::U32))
364 .collect();
365 RefRelation {
366 schema,
367 rows: Vec::<Vec<RefValue>>::new(),
368 }
369}