xlog_runtime/executor/recursive.rs
1//! Recursive SCC execution using semi-naive fixpoint iteration.
2
3use std::collections::{BTreeSet, HashMap, HashSet};
4use std::time::Instant;
5
6use xlog_core::{RelId, Result, Schema, XlogError};
7use xlog_cuda::CudaBuffer;
8use xlog_ir::{ExecutionPlan, RirNode, Stratum};
9
10use super::delta::DeltaRelationTracker;
11use super::Executor;
12
13impl Executor {
/// Upper bound on fixpoint iterations, guarding against runaway loops.
///
/// Note: `execute_recursive_scc` bounds its loop by `config.max_iterations`,
/// not by this constant.
const MAX_FIXPOINT_ITERATIONS: usize = 1_000;
16
17 /// For a `MultiWayJoin` or `ChainJoin` body, try the specialized WCOJ
18 /// dispatchers first; on decline, fall back to the embedded fallback
19 /// subtree via `execute_node`. For any other RIR variant, defer to
20 /// `execute_node` directly.
21 ///
22 /// Used at two sites in the recursive engine: the seeding pass, where
23 /// stable rules and recursive rules get their initial dispatch on the full
24 /// body, and the per-variant loop, where each recursive scan with a
25 /// non-empty delta is rewritten to its delta `RelId` for one dispatch.
26 /// Multi-recursive bodies, including distinct recursive predicates and
27 /// same-predicate self-recursive bodies, reach a `MultiWayJoin` here after
28 /// the promoter admits bodies with more than one recursive scan; the
29 /// per-variant rewrite loop builds one variant per recursive occurrence
30 /// with a non-empty delta and dispatches each via this helper.
31 ///
32 /// Counter semantics: `wcoj_*_dispatch_count` increments once per
33 /// successful WCOJ kernel result: once per recursive rule, iteration, and
34 /// variant. Non-recursive dispatch sites increment once per rule per call.
35 fn execute_wcoj_or_fallback_node(&mut self, node: &RirNode) -> Result<CudaBuffer> {
36 if let RirNode::ChainJoin { .. } = node {
37 if let Some(buf) = self.try_dispatch_chain_on_body(node)? {
38 return Ok(buf);
39 }
40 return self.execute_node(node);
41 }
42 if let RirNode::MultiWayJoin { .. } = node {
43 // Triangle, 4-cycle, then K-clique. A body cannot
44 // match more than one paper-derived shape (different
45 // atom counts). The dispatcher's own gate handles
46 // env-var / config / adaptive decisions; this site is
47 // purely structural.
48 if let Some(buf) = self.try_dispatch_wcoj_triangle_on_body(node)? {
49 return Ok(buf);
50 }
51 if let Some(buf) = self.try_dispatch_wcoj_4cycle_on_body(node)? {
52 return Ok(buf);
53 }
54 // Recursive clique bodies use the same launch-local metadata
55 // builders as non-recursive K-clique dispatch, so rewritten
56 // semi-naive variants are eligible here too.
57 if let Some(buf) = self.try_dispatch_wcoj_clique5_on_body(node)? {
58 return Ok(buf);
59 }
60 if let Some(buf) = self.try_dispatch_wcoj_clique6_on_body(node)? {
61 return Ok(buf);
62 }
63 if let Some(buf) = self.try_dispatch_wcoj_clique7_on_body(node)? {
64 return Ok(buf);
65 }
66 if let Some(buf) = self.try_dispatch_wcoj_clique8_on_body(node)? {
67 return Ok(buf);
68 }
69 // Generalized Free Join dispatch for every multiway shape the
70 // dedicated dispatchers declined. The hook re-checks dedicated
71 // shapes structurally, so it only fires on general bodies.
72 if let Some(buf) = self.try_dispatch_free_join(node)? {
73 return Ok(buf);
74 }
75 }
76 self.execute_node(node)
77 }
78
79 fn refresh_kclique_edge_metadata_after_merge(
80 &mut self,
81 rules: &[xlog_ir::CompiledRule],
82 pred: &str,
83 ) {
84 let start = Instant::now();
85 let affected_rules = rules
86 .iter()
87 .filter(|rule| self.kclique_body_mentions_pred(&rule.body, pred))
88 .count() as u64;
89 self.record_kclique_histogram_refresh_time(start, affected_rules);
90 }
91
92 fn record_kclique_histogram_refresh_time(&mut self, start: Instant, affected_rules: u64) {
93 if affected_rules == 0 {
94 return;
95 }
96 self.kclique_histogram_refresh_count = self
97 .kclique_histogram_refresh_count
98 .saturating_add(affected_rules);
99 self.kclique_histogram_refresh_nanos = self
100 .kclique_histogram_refresh_nanos
101 .saturating_add(start.elapsed().as_nanos());
102 }
103
104 fn kclique_body_mentions_pred(&self, node: &RirNode, pred: &str) -> bool {
105 let RirNode::MultiWayJoin {
106 inputs, var_order, ..
107 } = node
108 else {
109 return false;
110 };
111 let Some(order) = var_order.as_ref().and_then(|order| order.kclique.as_ref()) else {
112 return false;
113 };
114 if !matches!(order.k, 5..=8) {
115 return false;
116 }
117 inputs.iter().any(|input| {
118 let RirNode::Scan { rel } = input else {
119 return false;
120 };
121 self.rel_names.get(rel).is_some_and(|name| name == pred)
122 })
123 }
124
125 /// Stub: always returns an error directing callers to use `execute_plan` instead.
126 pub fn execute_stratum(&mut self, _stratum: &Stratum) -> Result<()> {
127 Err(XlogError::Execution(
128 "execute_stratum cannot be called directly; use execute_plan instead which provides \
129 the required rules_by_scc context"
130 .to_string(),
131 ))
132 }
133
134 /// Execute all rules in a non-recursive strongly connected component once.
135 pub fn execute_non_recursive_scc(&mut self, rules: &[xlog_ir::CompiledRule]) -> Result<()> {
136 for rule in rules {
137 let result = self.execute_node(&rule.body)?;
138
139 if let Some(existing) = self.store.get(&rule.head) {
140 if result.is_empty() {
141 continue;
142 }
143 let merged = self.provider.union_gpu(existing, &result)?;
144 self.store_put(&rule.head, merged);
145 } else {
146 let key_cols: Vec<usize> = (0..result.arity()).collect();
147 let deduped = if result.is_empty() {
148 result
149 } else {
150 self.provider.dedup(&result, &key_cols)?
151 };
152 self.store_put(&rule.head, deduped);
153 }
154 }
155 Ok(())
156 }
157
158 /// Execute a stratum (internal implementation)
159 ///
160 /// Processes all SCCs in the stratum by executing their rules.
161 /// For recursive SCCs, uses semi-naive fixpoint iteration.
162 pub(super) fn execute_stratum_impl(
163 &mut self,
164 stratum: &Stratum,
165 plan: &ExecutionPlan,
166 ) -> Result<()> {
167 // Process each SCC in the stratum
168 for &scc_id in &stratum.sccs {
169 // Get rules for this SCC
170 if let Some(rules) = plan.rules_by_scc.get(scc_id as usize) {
171 // Get SCC metadata
172 let scc = plan.sccs.get(scc_id as usize);
173 let is_recursive = scc.map(|s| s.is_recursive).unwrap_or(false);
174
175 if is_recursive {
176 // Recursive SCC: use semi-naive fixpoint iteration. The
177 // recursive engine invokes WCOJ dispatch via
178 // `execute_wcoj_or_fallback_node` on both the seeding
179 // pass and per-variant evaluation when the promoted body
180 // shape is eligible.
181 self.execute_recursive_scc(rules)?;
182 } else {
183 // Non-recursive SCC: execute rules once, union results for same predicate.
184 for rule in rules {
185 // Route two-atom ChainJoin bodies before the
186 // triangle/4-cycle/KC attempts. The dispatcher
187 // silently declines on non-chain bodies or when
188 // the env gate disables the route.
189 if let Some(chain_result) = self.try_dispatch_chain_on_body(&rule.body)? {
190 if let Some(existing) = self.store.get(&rule.head) {
191 let merged = self.provider.union_gpu(existing, &chain_result)?;
192 self.store_put(&rule.head, merged);
193 } else {
194 let key_cols: Vec<usize> = (0..chain_result.arity()).collect();
195 let deduped = if chain_result.is_empty() {
196 chain_result
197 } else {
198 let dedup_input_rows = chain_result.num_rows();
199 let start = self.profiler.start_op();
200 let deduped = self.provider.dedup(&chain_result, &key_cols)?;
201 if let Some(start) = start {
202 let mem = self.provider.memory().allocated_bytes();
203 self.profiler.record_op(
204 "dedup",
205 dedup_input_rows,
206 deduped.num_rows(),
207 start,
208 mem,
209 );
210 self.profiler.record_peak_memory(mem);
211 }
212 deduped
213 };
214 self.store_put(&rule.head, deduped);
215 }
216 continue;
217 }
218
219 // WCOJ triangle dispatch, gated by runtime configuration.
220 // Try to short-circuit the rule via the GPU
221 // 3-way kernel. On Some(_), install the
222 // result and skip the binary-join path for
223 // this rule. On None (gate off, shape
224 // mismatch, missing input, kernel error),
225 // fall through silently. See
226 // `wcoj_dispatch::try_dispatch_wcoj_triangle`
227 // for the full match contract.
228 if let Some(wcoj_result) = self.try_dispatch_wcoj_triangle(rule)? {
229 // Mirrors the binary-join arm below:
230 // union with existing result if predicate
231 // already has data; otherwise install
232 // directly. WCOJ output is already
233 // sorted+deduped, so the dedup pass on
234 // the else branch is unnecessary here.
235 if let Some(existing) = self.store.get(&rule.head) {
236 let merged = self.provider.union_gpu(existing, &wcoj_result)?;
237 self.store_put(&rule.head, merged);
238 } else {
239 self.store_put(&rule.head, wcoj_result);
240 }
241 continue;
242 }
243
244 // WCOJ 4-cycle dispatch.
245 // Same pattern as triangle. Order is a doc
246 // anchor — a body cannot match both shapes
247 // (different atom counts), so triangle's
248 // earlier attempt always returns None on a
249 // 4-cycle body and vice versa.
250 if let Some(wcoj_result) = self.try_dispatch_wcoj_4cycle(rule)? {
251 if let Some(existing) = self.store.get(&rule.head) {
252 let merged = self.provider.union_gpu(existing, &wcoj_result)?;
253 self.store_put(&rule.head, merged);
254 } else {
255 self.store_put(&rule.head, wcoj_result);
256 }
257 continue;
258 }
259
260 // K-clique dispatch for k=5..k=8.
261 // Same shape-gated default-dispatch
262 // pattern as triangle / 4-cycle; silent
263 // fallback to MultiWayJoin.fallback on
264 // dispatcher decline or kernel error.
265 if let Some(wcoj_result) = self.try_dispatch_wcoj_clique5(rule)? {
266 if let Some(existing) = self.store.get(&rule.head) {
267 let merged = self.provider.union_gpu(existing, &wcoj_result)?;
268 self.store_put(&rule.head, merged);
269 } else {
270 self.store_put(&rule.head, wcoj_result);
271 }
272 continue;
273 }
274 if let Some(wcoj_result) = self.try_dispatch_wcoj_clique6(rule)? {
275 if let Some(existing) = self.store.get(&rule.head) {
276 let merged = self.provider.union_gpu(existing, &wcoj_result)?;
277 self.store_put(&rule.head, merged);
278 } else {
279 self.store_put(&rule.head, wcoj_result);
280 }
281 continue;
282 }
283 if let Some(wcoj_result) = self.try_dispatch_wcoj_clique7(rule)? {
284 if let Some(existing) = self.store.get(&rule.head) {
285 let merged = self.provider.union_gpu(existing, &wcoj_result)?;
286 self.store_put(&rule.head, merged);
287 } else {
288 self.store_put(&rule.head, wcoj_result);
289 }
290 continue;
291 }
292 if let Some(wcoj_result) = self.try_dispatch_wcoj_clique8(rule)? {
293 if let Some(existing) = self.store.get(&rule.head) {
294 let merged = self.provider.union_gpu(existing, &wcoj_result)?;
295 self.store_put(&rule.head, merged);
296 } else {
297 self.store_put(&rule.head, wcoj_result);
298 }
299 continue;
300 }
301
302 // Generalized Free Join dispatch for every multiway
303 // shape the dedicated dispatchers above declined. The
304 // dispatcher re-checks those shapes structurally, so
305 // it only fires on general bodies. Unlike the
306 // dedicated kernels, the frontier engine emits one row
307 // per derivation path, so the install mirrors the
308 // binary-join arm below: `union_gpu` dedups, and
309 // fresh installs dedup explicitly.
310 if let Some(fj_result) = self.try_dispatch_free_join(&rule.body)? {
311 if let Some(existing) = self.store.get(&rule.head) {
312 let merged = self.provider.union_gpu(existing, &fj_result)?;
313 self.store_put(&rule.head, merged);
314 } else {
315 let key_cols: Vec<usize> = (0..fj_result.arity()).collect();
316 let deduped = if fj_result.is_empty() {
317 fj_result
318 } else {
319 self.provider.dedup(&fj_result, &key_cols)?
320 };
321 self.store_put(&rule.head, deduped);
322 }
323 continue;
324 }
325
326 // When WCOJ dispatch declines on a `MultiWayJoin`
327 // body (gate off, kernel error, adaptive score below
328 // threshold, ...), execute the embedded `fallback`,
329 // the post-optimizer binary-join tree the promoter
330 // captured. `execute_node`'s `MultiWayJoin` arm is the
331 // defensive safety net; explicit destructuring here
332 // keeps the intent visible at the dispatch site.
333 let body_to_execute = match &rule.body {
334 xlog_ir::RirNode::MultiWayJoin { fallback, .. }
335 | xlog_ir::RirNode::ChainJoin { fallback, .. } => fallback.as_ref(),
336 other => other,
337 };
338 let result = self.execute_node(body_to_execute)?;
339
340 // Union with existing result if predicate already has data
341 if let Some(existing) = self.store.get(&rule.head) {
342 let union_input_rows = existing.num_rows() + result.num_rows();
343 let start = self.profiler.start_op();
344 let merged = self.provider.union_gpu(existing, &result)?;
345 if let Some(start) = start {
346 let mem = self.provider.memory().allocated_bytes();
347 self.profiler.record_op(
348 "union",
349 union_input_rows,
350 merged.num_rows(),
351 start,
352 mem,
353 );
354 self.profiler.record_peak_memory(mem);
355 }
356 self.store_put(&rule.head, merged);
357 } else {
358 let key_cols: Vec<usize> = (0..result.arity()).collect();
359 let deduped = if result.is_empty() {
360 result
361 } else {
362 let dedup_input_rows = result.num_rows();
363 let start = self.profiler.start_op();
364 let deduped = self.provider.dedup(&result, &key_cols)?;
365 if let Some(start) = start {
366 let mem = self.provider.memory().allocated_bytes();
367 self.profiler.record_op(
368 "dedup",
369 dedup_input_rows,
370 deduped.num_rows(),
371 start,
372 mem,
373 );
374 self.profiler.record_peak_memory(mem);
375 }
376 deduped
377 };
378 self.store_put(&rule.head, deduped);
379 }
380 }
381 }
382 }
383 }
384
385 Ok(())
386 }
387
388 /// Execute a recursive SCC using semi-naive fixpoint iteration
389 ///
390 /// The algorithm:
391 /// 1. Execute all rules once to get initial result
392 /// 2. Track which relations changed (delta)
393 /// 3. Re-execute rules, using delta from previous iteration
394 /// 4. Repeat until no changes (fixpoint reached)
395 pub fn execute_recursive_scc(&mut self, rules: &[xlog_ir::CompiledRule]) -> Result<()> {
396 // Reset the per-iteration stats trace at SCC entry so tests see a
397 // fresh trace per invocation. Gated on the `recursive-stats-trace`
398 // feature; default OFF.
399 #[cfg(feature = "recursive-stats-trace")]
400 {
401 self.last_recursive_stats_trace.entries.clear();
402 }
403 // Identify SCC predicates from rule heads (these are the recursive IDBs).
404 let mut recursive_pred_names: BTreeSet<String> = BTreeSet::new();
405 let mut schema_by_pred: HashMap<String, Schema> = HashMap::new();
406 for rule in rules {
407 recursive_pred_names.insert(rule.head.clone());
408 if rule.meta.schema.arity() > 0 {
409 schema_by_pred
410 .entry(rule.head.clone())
411 .or_insert_with(|| rule.meta.schema.clone());
412 }
413 }
414 let recursive_pred_lookup: HashSet<String> = recursive_pred_names.iter().cloned().collect();
415 let recursive_preds: Vec<String> = recursive_pred_names.into_iter().collect();
416
417 // Ensure all recursive predicates exist in the store so scans never fail
418 // due to evaluation order (mutual recursion can reference an as-yet-empty relation).
419 for pred in &recursive_preds {
420 if !self.store.contains(pred) {
421 let schema = schema_by_pred
422 .get(pred)
423 .cloned()
424 .or_else(|| self.store.get(pred).map(|b| b.schema().clone()))
425 .ok_or_else(|| {
426 XlogError::Execution(format!(
427 "Missing schema for recursive predicate {}",
428 pred
429 ))
430 })?;
431 let empty = self.create_empty_buffer(schema)?;
432 self.store_put(pred, empty);
433 }
434 }
435
436 // Create per-predicate delta relations (distinct RelIds) so semi-naive evaluation
437 // can target a single recursive Scan occurrence without overriding *all* scans of
438 // that predicate in a rule (required for self-joins like p(X,Y), p(Y,Z)).
439 let mut next_rel_id = self
440 .rel_names
441 .keys()
442 .map(|r| r.0)
443 .max()
444 .unwrap_or(0)
445 .saturating_add(1);
446
447 let mut delta_tracker = DeltaRelationTracker::new();
448 for pred in &recursive_preds {
449 let rel_id = RelId(next_rel_id);
450 next_rel_id = next_rel_id.saturating_add(1);
451 let name = format!("__delta_{}_{}", pred, rel_id.0);
452 self.register_relation(rel_id, &name);
453 delta_tracker.insert(pred.clone(), rel_id, name);
454 }
455
456 // Execute all rules once against the current store to seed initial results.
457 // Accumulate per-head before mutating the store to avoid order dependence.
458 //
459 // Route through `execute_wcoj_or_fallback_node` so promoted
460 // MultiWayJoin bodies for stable and linear-recursive triangles or
461 // 4-cycles get a chance at WCOJ dispatch on the seeding pass. Stable
462 // rules with zero recursive scans only run here, so without this hook
463 // they would never see a kernel.
464 let mut derived_initial: HashMap<String, CudaBuffer> = HashMap::new();
465 for rule in rules {
466 let result = self.execute_wcoj_or_fallback_node(&rule.body)?;
467 if let Some(acc) = derived_initial.get_mut(&rule.head) {
468 let union_input = acc.num_rows() + result.num_rows();
469 let start = self.profiler.start_op();
470 let merged = self.provider.union_gpu(acc, &result)?;
471 if let Some(start) = start {
472 let mem = self.provider.memory().allocated_bytes();
473 self.profiler
474 .record_op("union", union_input, merged.num_rows(), start, mem);
475 self.profiler.record_peak_memory(mem);
476 }
477 *acc = merged;
478 } else {
479 derived_initial.insert(rule.head.clone(), result);
480 }
481 }
482
483 // Initialize delta from the newly-derived tuples only.
484 //
485 // This supports incremental maintenance: if the SCC is executed again after EDB inserts,
486 // the delta relations start with only the *new* tuples, not a full rescan of the current
487 // fixed point.
488 for pred in &recursive_preds {
489 let full_old = self
490 .store
491 .remove(pred)
492 .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", pred)))?;
493
494 let derived = match derived_initial.remove(pred) {
495 Some(buf) => buf,
496 None => self.create_empty_buffer(full_old.schema().clone())?,
497 };
498
499 let union_input = full_old.num_rows() + derived.num_rows();
500 let start = self.profiler.start_op();
501 let merged = self.provider.union_gpu(&full_old, &derived)?;
502 if let Some(start) = start {
503 let mem = self.provider.memory().allocated_bytes();
504 self.profiler
505 .record_op("union", union_input, merged.num_rows(), start, mem);
506 self.profiler.record_peak_memory(mem);
507 }
508
509 let full_new = merged;
510
511 let delta_name = delta_tracker.delta_name(pred)?;
512
513 let full_old_rows = self.buffer_row_count(&full_old)?;
514 let full_new_rows = self.buffer_row_count(&full_new)?;
515 let delta_initial = if full_new_rows == 0 {
516 self.create_empty_buffer(full_new.schema().clone())?
517 } else if full_old_rows == 0 {
518 self.clone_buffer(&full_new)?
519 } else {
520 let diff_input = full_new.num_rows() + full_old.num_rows();
521 let start = self.profiler.start_op();
522 let diffed = self.provider.diff_gpu(&full_new, &full_old)?;
523 if let Some(start) = start {
524 let mem = self.provider.memory().allocated_bytes();
525 self.profiler
526 .record_op("diff", diff_input, diffed.num_rows(), start, mem);
527 self.profiler.record_peak_memory(mem);
528 }
529 diffed
530 };
531
532 // Seed-iteration cardinality refresh. Capture the actual
533 // `delta_initial` row count before the `store_put` move; after the
534 // move, the buffer is gone.
535 let delta_initial_rows = self.buffer_row_count(&delta_initial)? as u64;
536 let seed_full_rows = full_new_rows as u64;
537 // Pre-resolve rel_id lookups before the &mut self stats
538 // borrow below.
539 let full_rel_opt = self.name_to_rel_id(pred);
540 let delta_rel = delta_tracker.delta_rel_id(pred)?;
541
542 self.store_put(pred, full_new);
543 self.store_put(delta_name, delta_initial);
544
545 // Stats updates fire whether or not WCOJ ran on the seed
546 // pass. update_cardinality is a no-op for unregistered
547 // rel_ids (defensive: tests that don't register an IDB
548 // head get a no-op for the full_rel write).
549 if let Some(full_rel) = full_rel_opt {
550 self.stats.update_cardinality(full_rel, seed_full_rows);
551 }
552 self.stats.update_cardinality(delta_rel, delta_initial_rows);
553
554 // Seed stats trace entry, gated on `recursive-stats-trace`.
555 #[cfg(feature = "recursive-stats-trace")]
556 self.last_recursive_stats_trace
557 .entries
558 .push(super::RecursiveStatsTraceEntry {
559 iteration: 0,
560 pred: pred.clone(),
561 full_rel: full_rel_opt.unwrap_or(RelId(u32::MAX)),
562 delta_rel,
563 full_rows: seed_full_rows,
564 delta_rows: delta_initial_rows,
565 phase: super::RecursiveStatsPhase::Seed,
566 binary_est_for_variant: None,
567 });
568 }
569
570 // Iterate until no new tuples are produced.
571 let mut reached_fixpoint = false;
572 let max_iterations = self.config.max_iterations as usize;
573 let mut iteration_count = 0usize;
574 // D3 — per-fixpoint dispatch context for the factorized delta
575 // (domain bounds + normalized EDB statics are cached across
576 // iterations).
577 let mut fd_ctx = super::wcoj_dispatch::FactorizedDeltaCtx::default();
578 for _iteration in 0..max_iterations {
579 iteration_count += 1;
580 // Compute delta_new_raw per head by evaluating each rule once per recursive Scan occurrence.
581 let mut delta_new_raw_by_head: HashMap<String, CudaBuffer> = HashMap::new();
582 // D3 — factorized novel sets per head: already diffed
583 // against the stable relation and full-row deduped at
584 // dispatch time. Kept separate from the raw accumulator so
585 // all-factorized heads can skip the legacy diff entirely.
586 let mut delta_novel_by_head: HashMap<String, CudaBuffer> = HashMap::new();
587
588 for rule in rules {
589 let mut scans = Vec::new();
590 Self::collect_scan_rels(&rule.body, &mut scans);
591
592 // Build a list of (rel_id, occurrence_idx, pred_name) for recursive scans.
593 let mut seen: HashMap<RelId, usize> = HashMap::new();
594 let mut variants: Vec<(RelId, usize, String)> = Vec::new();
595 for rel_id in scans {
596 let pred_name = match self.get_rel_name(rel_id) {
597 Some(n) => n.to_string(),
598 None => continue,
599 };
600 if !recursive_pred_lookup.contains(&pred_name) {
601 continue;
602 }
603
604 // Skip variants where the delta for this predicate is empty.
605 let delta_name = match delta_tracker.get(&pred_name) {
606 Some((_rel_id, name)) => name.as_str(),
607 None => continue,
608 };
609 let delta_is_empty = match self.store.get(delta_name) {
610 Some(delta) => self.buffer_row_count(delta)? == 0,
611 None => true,
612 };
613 if delta_is_empty {
614 continue;
615 }
616
617 let occ = seen.entry(rel_id).or_insert(0);
618 variants.push((rel_id, *occ, pred_name));
619 *occ += 1;
620 }
621
622 if variants.is_empty() {
623 // Base rule: it can only contribute on the first seeding pass.
624 continue;
625 }
626
627 let mut rule_delta_raw: Option<CudaBuffer> = None;
628 let mut rule_delta_novel: Option<CudaBuffer> = None;
629 for (rel_id, occ, pred_name) in variants {
630 let delta_rel_id = delta_tracker.delta_rel_id(&pred_name)?;
631
632 let variant_node =
633 Self::rewrite_scan_nth(&rule.body, rel_id, occ, delta_rel_id).ok_or_else(
634 || {
635 XlogError::Execution(format!(
636 "Failed to rewrite rule body for predicate {}",
637 pred_name
638 ))
639 },
640 )?;
641
642 // Try the factorized delta pipeline first: a qualifying
643 // ChainJoin variant returns the novel set directly
644 // (already diffed against the head's stable relation and
645 // deduped). Declines are silent and fall through to the
646 // legacy path.
647 if let Some(novel) = self.try_dispatch_factorized_delta(
648 &variant_node,
649 delta_rel_id,
650 &rule.head,
651 &recursive_pred_lookup,
652 &mut fd_ctx,
653 )? {
654 rule_delta_novel = Some(match rule_delta_novel {
655 Some(acc) => self.provider.union_gpu(&acc, &novel)?,
656 None => novel,
657 });
658 continue;
659 }
660
661 // Try WCOJ on the rewritten variant body before falling
662 // back to the binary-join walker.
663 // For a linear-recursive triangle/4-cycle, the
664 // variant has one Scan's RelId swapped to its
665 // delta — the kernel reads from the delta store
666 // entry transparently, no special-case dispatch
667 // logic needed.
668 let out = self.execute_wcoj_or_fallback_node(&variant_node)?;
669 rule_delta_raw = Some(if let Some(acc) = rule_delta_raw {
670 let union_input = acc.num_rows() + out.num_rows();
671 let start = self.profiler.start_op();
672 let merged = self.provider.union_gpu(&acc, &out)?;
673 if let Some(start) = start {
674 let mem = self.provider.memory().allocated_bytes();
675 self.profiler.record_op(
676 "union",
677 union_input,
678 merged.num_rows(),
679 start,
680 mem,
681 );
682 self.profiler.record_peak_memory(mem);
683 }
684 merged
685 } else {
686 out
687 });
688 }
689
690 // D3 — a rule with BOTH factorized and legacy variant
691 // outputs folds its novel set into the raw accumulator
692 // (the legacy diff is a no-op on novel rows, so this is
693 // sound); an all-factorized rule keeps its novel set on
694 // the diff-free track.
695 if rule_delta_raw.is_some() {
696 if let Some(novel) = rule_delta_novel.take() {
697 let raw = rule_delta_raw.as_ref().expect("checked above");
698 rule_delta_raw = Some(self.provider.union_gpu(raw, &novel)?);
699 }
700 }
701 if let Some(rule_out) = rule_delta_raw {
702 if let Some(acc) = delta_new_raw_by_head.get_mut(&rule.head) {
703 let union_input = acc.num_rows() + rule_out.num_rows();
704 let start = self.profiler.start_op();
705 let merged = self.provider.union_gpu(acc, &rule_out)?;
706 if let Some(start) = start {
707 let mem = self.provider.memory().allocated_bytes();
708 self.profiler.record_op(
709 "union",
710 union_input,
711 merged.num_rows(),
712 start,
713 mem,
714 );
715 self.profiler.record_peak_memory(mem);
716 }
717 *acc = merged;
718 } else {
719 delta_new_raw_by_head.insert(rule.head.clone(), rule_out);
720 }
721 }
722 if let Some(rule_novel) = rule_delta_novel {
723 if let Some(acc) = delta_novel_by_head.get_mut(&rule.head) {
724 *acc = self.provider.union_gpu(acc, &rule_novel)?;
725 } else {
726 delta_novel_by_head.insert(rule.head.clone(), rule_novel);
727 }
728 }
729 }
730
731 // Finalize delta_new per head: delta_new = dedup(delta_raw - full).
732 delta_tracker.begin_iteration();
733
734 for pred in &recursive_preds {
735 let full = self
736 .store
737 .get(pred)
738 .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", pred)))?;
739 // Capture the current full row count for the trace's
740 // `full_rows` field before this iteration's delta relation is
741 // replaced. Gated on `recursive-stats-trace` so production
742 // builds do not compute it.
743 #[cfg(feature = "recursive-stats-trace")]
744 let pre_phase4_full_rows = self.buffer_row_count(full)? as u64;
745
746 let delta_raw = delta_new_raw_by_head.remove(pred);
747 let delta_novel = delta_novel_by_head.remove(pred);
748 // D3 — when a head received both raw and factorized
749 // contributions (different rules), fold the novel set
750 // into the raw side before the legacy diff (sound: the
751 // diff is a no-op on novel rows). An all-factorized
752 // head skips the diff entirely — its novel set is
753 // already diffed and deduped by construction.
754 let (delta_raw, delta_novel) = match (delta_raw, delta_novel) {
755 (Some(raw), Some(novel)) => {
756 (Some(self.provider.union_gpu(&raw, &novel)?), None)
757 }
758 other => other,
759 };
760 let delta_new = if let Some(novel) = delta_novel {
761 novel
762 } else if let Some(delta_raw) = delta_raw {
763 if self.buffer_row_count(&delta_raw)? == 0 {
764 self.create_empty_buffer(full.schema().clone())?
765 } else {
766 let diff_input = delta_raw.num_rows() + full.num_rows();
767 let start = self.profiler.start_op();
768 let diffed = self.provider.diff_gpu(&delta_raw, full)?;
769 if let Some(start) = start {
770 let mem = self.provider.memory().allocated_bytes();
771 self.profiler.record_op(
772 "diff",
773 diff_input,
774 diffed.num_rows(),
775 start,
776 mem,
777 );
778 self.profiler.record_peak_memory(mem);
779 }
780 diffed
781 }
782 } else {
783 self.create_empty_buffer(full.schema().clone())?
784 };
785
786 let delta_name = delta_tracker.delta_name(pred)?.to_string();
787 let delta_new_rows = self.buffer_row_count(&delta_new)? as u64;
788 if delta_new_rows != 0 {
789 delta_tracker.mark_changed();
790 }
791 // Pre-resolve rel_id lookups before the &mut self
792 // store_put + stats update below. `full_rel_opt` is
793 // only used by the trace under the
794 // `recursive-stats-trace` feature.
795 #[cfg(feature = "recursive-stats-trace")]
796 let full_rel_opt = self.name_to_rel_id(pred);
797 let delta_rel = delta_tracker.delta_rel_id(pred)?;
798 self.store_put(&delta_name, delta_new);
799
800 // Refresh the delta relation cardinality after computing this
801 // iteration's delta. The full relation cardinality is not
802 // updated here because the full relation has not changed yet
803 // this iteration; the merge step owns that.
804 self.stats.update_cardinality(delta_rel, delta_new_rows);
805
806 // Delta stats trace entry, gated on `recursive-stats-trace`.
807 // binary_est_for_variant captures the cost model's
808 // first-binary-hop estimate for the linear-recursive fixtures
809 // (`pred == "e1"` rewrites
810 // Scan(e1) → Scan(delta_e1); first hop is
811 // `delta_e1.col1 ⋈ e2.col0`). Populated inline because
812 // delta_rel is unregistered at fixpoint exit, so the
813 // test cannot recompute after `execute_plan` returns.
814 #[cfg(feature = "recursive-stats-trace")]
815 let binary_est_for_variant: Option<u64> = if pred == "e1" {
816 self.name_to_rel_id("e2").map(|e2_rel| {
817 self.stats
818 .estimate_join_cardinality(delta_rel, e2_rel, &[1], &[0])
819 })
820 } else {
821 None
822 };
823 #[cfg(feature = "recursive-stats-trace")]
824 self.last_recursive_stats_trace
825 .entries
826 .push(super::RecursiveStatsTraceEntry {
827 iteration: iteration_count,
828 pred: pred.clone(),
829 full_rel: full_rel_opt.unwrap_or(RelId(u32::MAX)),
830 delta_rel,
831 full_rows: pre_phase4_full_rows,
832 delta_rows: delta_new_rows,
833 phase: super::RecursiveStatsPhase::Phase2Delta,
834 binary_est_for_variant,
835 });
836 }
837
838 // Fixpoint reached if no deltas produced.
839 if delta_tracker.is_converged() {
840 reached_fixpoint = true;
841 self.profiler.record_iterations(iteration_count);
842 break;
843 }
844
845 // Merge deltas into full relations.
846 for pred in &recursive_preds {
847 let full_old = self
848 .store
849 .remove(pred)
850 .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", pred)))?;
851 let dn = delta_tracker.delta_name(pred)?.to_string();
852 let delta = self
853 .store_remove(&dn)
854 .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", dn)))?;
855
856 if self.buffer_row_count(&delta)? == 0 {
857 // Zero-delta short-circuit: full and delta are unchanged
858 // this iteration. The delta relation record with zero
859 // rows stands, and the full relation record from the prior
860 // merge stands. No additional update is needed.
861 self.store_put(pred, full_old);
862 self.store_put(&dn, delta);
863 continue;
864 }
865
866 let union_input = full_old.num_rows() + delta.num_rows();
867 let start = self.profiler.start_op();
868 let merged = self.provider.union_gpu(&full_old, &delta)?;
869 if let Some(start) = start {
870 let mem = self.provider.memory().allocated_bytes();
871 self.profiler
872 .record_op("union", union_input, merged.num_rows(), start, mem);
873 self.profiler.record_peak_memory(mem);
874 }
875
876 let full_new = merged;
877 // Capture `full_new`'s row count before the `store_put` move
878 // and pre-resolve `full_rel_opt` before the mutable stats
879 // borrow. The delta row count and delta relation id are only
880 // used by the trace under the `recursive-stats-trace` feature.
881 let full_new_rows_phase4 = self.buffer_row_count(&full_new)? as u64;
882 #[cfg(feature = "recursive-stats-trace")]
883 let delta_rows_phase4 = self.buffer_row_count(&delta)? as u64;
884 let full_rel_opt = self.name_to_rel_id(pred);
885 #[cfg(feature = "recursive-stats-trace")]
886 let delta_rel = delta_tracker.delta_rel_id(pred)?;
887 self.store_put(pred, full_new);
888 self.store_put(&dn, delta);
889
890 // Record the full relation's new cardinality. The delta
891 // relation was already recorded for this iteration.
892 if let Some(full_rel) = full_rel_opt {
893 self.stats
894 .update_cardinality(full_rel, full_new_rows_phase4);
895 }
896 self.refresh_kclique_edge_metadata_after_merge(rules, pred);
897
898 // Full-relation stats trace entry, gated on `recursive-stats-trace`.
899 #[cfg(feature = "recursive-stats-trace")]
900 self.last_recursive_stats_trace
901 .entries
902 .push(super::RecursiveStatsTraceEntry {
903 iteration: iteration_count,
904 pred: pred.clone(),
905 full_rel: full_rel_opt.unwrap_or(RelId(u32::MAX)),
906 delta_rel,
907 full_rows: full_new_rows_phase4,
908 delta_rows: delta_rows_phase4,
909 phase: super::RecursiveStatsPhase::Phase4Full,
910 binary_est_for_variant: None,
911 });
912 }
913 }
914
915 // Cleanup: remove delta relations from store and relation mapping.
916 for (_pred, (rel_id, delta_name)) in delta_tracker.into_inner() {
917 self.store_remove(&delta_name);
918 self.rel_names.remove(&rel_id);
919 self.name_to_rel.remove(&delta_name);
920 let _ = self.stats.unregister_relation(rel_id);
921 }
922
923 if !reached_fixpoint {
924 // Record iterations even on failure for debugging
925 self.profiler.record_iterations(iteration_count);
926 return Err(XlogError::Execution(format!(
927 "Recursive SCC iteration limit ({}) exceeded",
928 self.config.max_iterations
929 )));
930 }
931
932 Ok(())
933 }
934
935 /// Execute a Fixpoint node using semi-naive evaluation
936 ///
937 /// The semi-naive algorithm avoids redundant computation in recursive queries:
938 ///
939 /// 1. **Initialize:**
940 /// - Compute base case: `R = base_result`
941 /// - Set delta to base: `delta = R`
942 /// - Store both `R` and `delta` in RelationStore
943 ///
944 /// 2. **Iterate until fixpoint:**
945 /// - Compute new tuples: `delta_new = recursive_result` using current `delta`
946 /// - Remove already-known tuples: `delta_new = delta_new - R`
947 /// - If `delta_new` is empty, we have reached fixpoint
948 /// - Otherwise: `R = R union delta_new`, `delta = delta_new`
949 ///
950 /// 3. **Return:** Final `R`
951 ///
952 /// # Arguments
953 /// * `scc_id` - SCC identifier for logging/debugging
954 /// * `base` - Base case RIR tree (non-recursive facts/rules)
955 /// * `recursive` - Recursive RIR tree (references delta relation)
956 /// * `delta_rel` - RelId for delta relation
957 /// * `full_rel` - RelId for full relation
958 ///
959 /// # Returns
960 /// A CudaBuffer containing the final fixpoint result
961 ///
962 /// # Errors
963 /// Returns an error if iteration limit is exceeded
964 pub(super) fn execute_fixpoint(
965 &mut self,
966 scc_id: u32,
967 base: &RirNode,
968 recursive: &RirNode,
969 delta_rel: RelId,
970 full_rel: RelId,
971 ) -> Result<CudaBuffer> {
972 // Compute base case R = eval(base)
973 let r_initial = self.execute_node(base)?;
974
975 // Handle empty base case using device-resident row count
976 if self.buffer_row_count(&r_initial)? == 0 {
977 return Ok(r_initial);
978 }
979
980 // Initialize delta = R (clone the base result)
981 let delta_initial = self.clone_buffer(&r_initial)?;
982
983 // Get relation names for delta and full relations
984 let delta_name = self.get_or_create_rel_name(delta_rel, &format!("__delta_{}", scc_id));
985 let full_name = self.get_or_create_rel_name(full_rel, &format!("__full_{}", scc_id));
986
987 // Store initial R and delta in relation store
988 self.store_put(&full_name, r_initial);
989 self.store_put(&delta_name, delta_initial);
990
991 // Iterate until fixpoint
992 for _iteration in 0..Self::MAX_FIXPOINT_ITERATIONS {
993 // Evaluate recursive step using current delta
994 // The recursive RIR tree should reference delta_rel internally
995 let delta_new_raw = self.execute_node(recursive)?;
996
997 // Get current R for set difference
998 let current_r = self.store.get(&full_name).ok_or_else(|| {
999 XlogError::Execution(format!(
1000 "Full relation {} not found during fixpoint iteration",
1001 full_name
1002 ))
1003 })?;
1004
1005 // Compute delta_new = delta_new_raw - R (remove already-known tuples)
1006 let delta_new = self.provider.diff_gpu(&delta_new_raw, current_r)?;
1007
1008 // Check for fixpoint: if delta_new is empty, we are done
1009 if self.buffer_row_count(&delta_new)? == 0 {
1010 // Fixpoint reached - return final R
1011 let final_r = self.store_remove(&full_name).ok_or_else(|| {
1012 XlogError::Execution("Full relation lost during fixpoint".to_string())
1013 })?;
1014
1015 // Clean up delta relation
1016 self.store_remove(&delta_name);
1017
1018 return Ok(final_r);
1019 }
1020
1021 // Not at fixpoint yet: R = R union delta_new
1022 let new_r = self.provider.union_gpu(current_r, &delta_new)?;
1023
1024 // Update relations for next iteration
1025 // delta = delta_new (the newly discovered tuples)
1026 self.store_put(&delta_name, delta_new);
1027 self.store_put(&full_name, new_r);
1028 }
1029
1030 // Iteration limit exceeded
1031 Err(XlogError::Execution(format!(
1032 "Fixpoint iteration limit ({}) exceeded for SCC {}",
1033 Self::MAX_FIXPOINT_ITERATIONS,
1034 scc_id
1035 )))
1036 }
1037}