Skip to main content

xlog_runtime/executor/
recursive.rs

1//! Recursive SCC execution using semi-naive fixpoint iteration.
2
3use std::collections::{BTreeSet, HashMap, HashSet};
4use std::time::Instant;
5
6use xlog_core::{RelId, Result, Schema, XlogError};
7use xlog_cuda::CudaBuffer;
8use xlog_ir::{ExecutionPlan, RirNode, Stratum};
9
10use super::delta::DeltaRelationTracker;
11use super::Executor;
12
13impl Executor {
    /// Maximum iterations for fixpoint computation to prevent infinite loops.
    ///
    /// NOTE(review): the visible portion of `execute_recursive_scc` bounds its
    /// loop with `self.config.max_iterations` rather than this constant —
    /// confirm whether this value is still referenced anywhere.
    const MAX_FIXPOINT_ITERATIONS: usize = 1000;
16
17    /// For a `MultiWayJoin` or `ChainJoin` body, try the specialized WCOJ
18    /// dispatchers first; on decline, fall back to the embedded fallback
19    /// subtree via `execute_node`. For any other RIR variant, defer to
20    /// `execute_node` directly.
21    ///
22    /// Used at two sites in the recursive engine: the seeding pass, where
23    /// stable rules and recursive rules get their initial dispatch on the full
24    /// body, and the per-variant loop, where each recursive scan with a
25    /// non-empty delta is rewritten to its delta `RelId` for one dispatch.
26    /// Multi-recursive bodies, including distinct recursive predicates and
27    /// same-predicate self-recursive bodies, reach a `MultiWayJoin` here after
28    /// the promoter admits bodies with more than one recursive scan; the
29    /// per-variant rewrite loop builds one variant per recursive occurrence
30    /// with a non-empty delta and dispatches each via this helper.
31    ///
32    /// Counter semantics: `wcoj_*_dispatch_count` increments once per
33    /// successful WCOJ kernel result: once per recursive rule, iteration, and
34    /// variant. Non-recursive dispatch sites increment once per rule per call.
35    fn execute_wcoj_or_fallback_node(&mut self, node: &RirNode) -> Result<CudaBuffer> {
36        if let RirNode::ChainJoin { .. } = node {
37            if let Some(buf) = self.try_dispatch_chain_on_body(node)? {
38                return Ok(buf);
39            }
40            return self.execute_node(node);
41        }
42        if let RirNode::MultiWayJoin { .. } = node {
43            // Triangle, 4-cycle, then K-clique. A body cannot
44            // match more than one paper-derived shape (different
45            // atom counts). The dispatcher's own gate handles
46            // env-var / config / adaptive decisions; this site is
47            // purely structural.
48            if let Some(buf) = self.try_dispatch_wcoj_triangle_on_body(node)? {
49                return Ok(buf);
50            }
51            if let Some(buf) = self.try_dispatch_wcoj_4cycle_on_body(node)? {
52                return Ok(buf);
53            }
54            // Recursive clique bodies use the same launch-local metadata
55            // builders as non-recursive K-clique dispatch, so rewritten
56            // semi-naive variants are eligible here too.
57            if let Some(buf) = self.try_dispatch_wcoj_clique5_on_body(node)? {
58                return Ok(buf);
59            }
60            if let Some(buf) = self.try_dispatch_wcoj_clique6_on_body(node)? {
61                return Ok(buf);
62            }
63            if let Some(buf) = self.try_dispatch_wcoj_clique7_on_body(node)? {
64                return Ok(buf);
65            }
66            if let Some(buf) = self.try_dispatch_wcoj_clique8_on_body(node)? {
67                return Ok(buf);
68            }
69            // Generalized Free Join dispatch for every multiway shape the
70            // dedicated dispatchers declined. The hook re-checks dedicated
71            // shapes structurally, so it only fires on general bodies.
72            if let Some(buf) = self.try_dispatch_free_join(node)? {
73                return Ok(buf);
74            }
75        }
76        self.execute_node(node)
77    }
78
79    fn refresh_kclique_edge_metadata_after_merge(
80        &mut self,
81        rules: &[xlog_ir::CompiledRule],
82        pred: &str,
83    ) {
84        let start = Instant::now();
85        let affected_rules = rules
86            .iter()
87            .filter(|rule| self.kclique_body_mentions_pred(&rule.body, pred))
88            .count() as u64;
89        self.record_kclique_histogram_refresh_time(start, affected_rules);
90    }
91
92    fn record_kclique_histogram_refresh_time(&mut self, start: Instant, affected_rules: u64) {
93        if affected_rules == 0 {
94            return;
95        }
96        self.kclique_histogram_refresh_count = self
97            .kclique_histogram_refresh_count
98            .saturating_add(affected_rules);
99        self.kclique_histogram_refresh_nanos = self
100            .kclique_histogram_refresh_nanos
101            .saturating_add(start.elapsed().as_nanos());
102    }
103
104    fn kclique_body_mentions_pred(&self, node: &RirNode, pred: &str) -> bool {
105        let RirNode::MultiWayJoin {
106            inputs, var_order, ..
107        } = node
108        else {
109            return false;
110        };
111        let Some(order) = var_order.as_ref().and_then(|order| order.kclique.as_ref()) else {
112            return false;
113        };
114        if !matches!(order.k, 5..=8) {
115            return false;
116        }
117        inputs.iter().any(|input| {
118            let RirNode::Scan { rel } = input else {
119                return false;
120            };
121            self.rel_names.get(rel).is_some_and(|name| name == pred)
122        })
123    }
124
125    /// Stub: always returns an error directing callers to use `execute_plan` instead.
126    pub fn execute_stratum(&mut self, _stratum: &Stratum) -> Result<()> {
127        Err(XlogError::Execution(
128            "execute_stratum cannot be called directly; use execute_plan instead which provides \
129             the required rules_by_scc context"
130                .to_string(),
131        ))
132    }
133
134    /// Execute all rules in a non-recursive strongly connected component once.
135    pub fn execute_non_recursive_scc(&mut self, rules: &[xlog_ir::CompiledRule]) -> Result<()> {
136        for rule in rules {
137            let result = self.execute_node(&rule.body)?;
138
139            if let Some(existing) = self.store.get(&rule.head) {
140                if result.is_empty() {
141                    continue;
142                }
143                let merged = self.provider.union_gpu(existing, &result)?;
144                self.store_put(&rule.head, merged);
145            } else {
146                let key_cols: Vec<usize> = (0..result.arity()).collect();
147                let deduped = if result.is_empty() {
148                    result
149                } else {
150                    self.provider.dedup(&result, &key_cols)?
151                };
152                self.store_put(&rule.head, deduped);
153            }
154        }
155        Ok(())
156    }
157
158    /// Execute a stratum (internal implementation)
159    ///
160    /// Processes all SCCs in the stratum by executing their rules.
161    /// For recursive SCCs, uses semi-naive fixpoint iteration.
162    pub(super) fn execute_stratum_impl(
163        &mut self,
164        stratum: &Stratum,
165        plan: &ExecutionPlan,
166    ) -> Result<()> {
167        // Process each SCC in the stratum
168        for &scc_id in &stratum.sccs {
169            // Get rules for this SCC
170            if let Some(rules) = plan.rules_by_scc.get(scc_id as usize) {
171                // Get SCC metadata
172                let scc = plan.sccs.get(scc_id as usize);
173                let is_recursive = scc.map(|s| s.is_recursive).unwrap_or(false);
174
175                if is_recursive {
176                    // Recursive SCC: use semi-naive fixpoint iteration. The
177                    // recursive engine invokes WCOJ dispatch via
178                    // `execute_wcoj_or_fallback_node` on both the seeding
179                    // pass and per-variant evaluation when the promoted body
180                    // shape is eligible.
181                    self.execute_recursive_scc(rules)?;
182                } else {
183                    // Non-recursive SCC: execute rules once, union results for same predicate.
184                    for rule in rules {
185                        // Route two-atom ChainJoin bodies before the
186                        // triangle/4-cycle/KC attempts. The dispatcher
187                        // silently declines on non-chain bodies or when
188                        // the env gate disables the route.
189                        if let Some(chain_result) = self.try_dispatch_chain_on_body(&rule.body)? {
190                            if let Some(existing) = self.store.get(&rule.head) {
191                                let merged = self.provider.union_gpu(existing, &chain_result)?;
192                                self.store_put(&rule.head, merged);
193                            } else {
194                                let key_cols: Vec<usize> = (0..chain_result.arity()).collect();
195                                let deduped = if chain_result.is_empty() {
196                                    chain_result
197                                } else {
198                                    let dedup_input_rows = chain_result.num_rows();
199                                    let start = self.profiler.start_op();
200                                    let deduped = self.provider.dedup(&chain_result, &key_cols)?;
201                                    if let Some(start) = start {
202                                        let mem = self.provider.memory().allocated_bytes();
203                                        self.profiler.record_op(
204                                            "dedup",
205                                            dedup_input_rows,
206                                            deduped.num_rows(),
207                                            start,
208                                            mem,
209                                        );
210                                        self.profiler.record_peak_memory(mem);
211                                    }
212                                    deduped
213                                };
214                                self.store_put(&rule.head, deduped);
215                            }
216                            continue;
217                        }
218
219                        // WCOJ triangle dispatch, gated by runtime configuration.
220                        // Try to short-circuit the rule via the GPU
221                        // 3-way kernel. On Some(_), install the
222                        // result and skip the binary-join path for
223                        // this rule. On None (gate off, shape
224                        // mismatch, missing input, kernel error),
225                        // fall through silently. See
226                        // `wcoj_dispatch::try_dispatch_wcoj_triangle`
227                        // for the full match contract.
228                        if let Some(wcoj_result) = self.try_dispatch_wcoj_triangle(rule)? {
229                            // Mirrors the binary-join arm below:
230                            // union with existing result if predicate
231                            // already has data; otherwise install
232                            // directly. WCOJ output is already
233                            // sorted+deduped, so the dedup pass on
234                            // the else branch is unnecessary here.
235                            if let Some(existing) = self.store.get(&rule.head) {
236                                let merged = self.provider.union_gpu(existing, &wcoj_result)?;
237                                self.store_put(&rule.head, merged);
238                            } else {
239                                self.store_put(&rule.head, wcoj_result);
240                            }
241                            continue;
242                        }
243
244                        // WCOJ 4-cycle dispatch.
245                        // Same pattern as triangle. Order is a doc
246                        // anchor — a body cannot match both shapes
247                        // (different atom counts), so triangle's
248                        // earlier attempt always returns None on a
249                        // 4-cycle body and vice versa.
250                        if let Some(wcoj_result) = self.try_dispatch_wcoj_4cycle(rule)? {
251                            if let Some(existing) = self.store.get(&rule.head) {
252                                let merged = self.provider.union_gpu(existing, &wcoj_result)?;
253                                self.store_put(&rule.head, merged);
254                            } else {
255                                self.store_put(&rule.head, wcoj_result);
256                            }
257                            continue;
258                        }
259
260                        // K-clique dispatch for k=5..k=8.
261                        // Same shape-gated default-dispatch
262                        // pattern as triangle / 4-cycle; silent
263                        // fallback to MultiWayJoin.fallback on
264                        // dispatcher decline or kernel error.
265                        if let Some(wcoj_result) = self.try_dispatch_wcoj_clique5(rule)? {
266                            if let Some(existing) = self.store.get(&rule.head) {
267                                let merged = self.provider.union_gpu(existing, &wcoj_result)?;
268                                self.store_put(&rule.head, merged);
269                            } else {
270                                self.store_put(&rule.head, wcoj_result);
271                            }
272                            continue;
273                        }
274                        if let Some(wcoj_result) = self.try_dispatch_wcoj_clique6(rule)? {
275                            if let Some(existing) = self.store.get(&rule.head) {
276                                let merged = self.provider.union_gpu(existing, &wcoj_result)?;
277                                self.store_put(&rule.head, merged);
278                            } else {
279                                self.store_put(&rule.head, wcoj_result);
280                            }
281                            continue;
282                        }
283                        if let Some(wcoj_result) = self.try_dispatch_wcoj_clique7(rule)? {
284                            if let Some(existing) = self.store.get(&rule.head) {
285                                let merged = self.provider.union_gpu(existing, &wcoj_result)?;
286                                self.store_put(&rule.head, merged);
287                            } else {
288                                self.store_put(&rule.head, wcoj_result);
289                            }
290                            continue;
291                        }
292                        if let Some(wcoj_result) = self.try_dispatch_wcoj_clique8(rule)? {
293                            if let Some(existing) = self.store.get(&rule.head) {
294                                let merged = self.provider.union_gpu(existing, &wcoj_result)?;
295                                self.store_put(&rule.head, merged);
296                            } else {
297                                self.store_put(&rule.head, wcoj_result);
298                            }
299                            continue;
300                        }
301
302                        // Generalized Free Join dispatch for every multiway
303                        // shape the dedicated dispatchers above declined. The
304                        // dispatcher re-checks those shapes structurally, so
305                        // it only fires on general bodies. Unlike the
306                        // dedicated kernels, the frontier engine emits one row
307                        // per derivation path, so the install mirrors the
308                        // binary-join arm below: `union_gpu` dedups, and
309                        // fresh installs dedup explicitly.
310                        if let Some(fj_result) = self.try_dispatch_free_join(&rule.body)? {
311                            if let Some(existing) = self.store.get(&rule.head) {
312                                let merged = self.provider.union_gpu(existing, &fj_result)?;
313                                self.store_put(&rule.head, merged);
314                            } else {
315                                let key_cols: Vec<usize> = (0..fj_result.arity()).collect();
316                                let deduped = if fj_result.is_empty() {
317                                    fj_result
318                                } else {
319                                    self.provider.dedup(&fj_result, &key_cols)?
320                                };
321                                self.store_put(&rule.head, deduped);
322                            }
323                            continue;
324                        }
325
326                        // When WCOJ dispatch declines on a `MultiWayJoin`
327                        // body (gate off, kernel error, adaptive score below
328                        // threshold, ...), execute the embedded `fallback`,
329                        // the post-optimizer binary-join tree the promoter
330                        // captured. `execute_node`'s `MultiWayJoin` arm is the
331                        // defensive safety net; explicit destructuring here
332                        // keeps the intent visible at the dispatch site.
333                        let body_to_execute = match &rule.body {
334                            xlog_ir::RirNode::MultiWayJoin { fallback, .. }
335                            | xlog_ir::RirNode::ChainJoin { fallback, .. } => fallback.as_ref(),
336                            other => other,
337                        };
338                        let result = self.execute_node(body_to_execute)?;
339
340                        // Union with existing result if predicate already has data
341                        if let Some(existing) = self.store.get(&rule.head) {
342                            let union_input_rows = existing.num_rows() + result.num_rows();
343                            let start = self.profiler.start_op();
344                            let merged = self.provider.union_gpu(existing, &result)?;
345                            if let Some(start) = start {
346                                let mem = self.provider.memory().allocated_bytes();
347                                self.profiler.record_op(
348                                    "union",
349                                    union_input_rows,
350                                    merged.num_rows(),
351                                    start,
352                                    mem,
353                                );
354                                self.profiler.record_peak_memory(mem);
355                            }
356                            self.store_put(&rule.head, merged);
357                        } else {
358                            let key_cols: Vec<usize> = (0..result.arity()).collect();
359                            let deduped = if result.is_empty() {
360                                result
361                            } else {
362                                let dedup_input_rows = result.num_rows();
363                                let start = self.profiler.start_op();
364                                let deduped = self.provider.dedup(&result, &key_cols)?;
365                                if let Some(start) = start {
366                                    let mem = self.provider.memory().allocated_bytes();
367                                    self.profiler.record_op(
368                                        "dedup",
369                                        dedup_input_rows,
370                                        deduped.num_rows(),
371                                        start,
372                                        mem,
373                                    );
374                                    self.profiler.record_peak_memory(mem);
375                                }
376                                deduped
377                            };
378                            self.store_put(&rule.head, deduped);
379                        }
380                    }
381                }
382            }
383        }
384
385        Ok(())
386    }
387
388    /// Execute a recursive SCC using semi-naive fixpoint iteration
389    ///
390    /// The algorithm:
391    /// 1. Execute all rules once to get initial result
392    /// 2. Track which relations changed (delta)
393    /// 3. Re-execute rules, using delta from previous iteration
394    /// 4. Repeat until no changes (fixpoint reached)
395    pub fn execute_recursive_scc(&mut self, rules: &[xlog_ir::CompiledRule]) -> Result<()> {
396        // Reset the per-iteration stats trace at SCC entry so tests see a
397        // fresh trace per invocation. Gated on the `recursive-stats-trace`
398        // feature; default OFF.
399        #[cfg(feature = "recursive-stats-trace")]
400        {
401            self.last_recursive_stats_trace.entries.clear();
402        }
403        // Identify SCC predicates from rule heads (these are the recursive IDBs).
404        let mut recursive_pred_names: BTreeSet<String> = BTreeSet::new();
405        let mut schema_by_pred: HashMap<String, Schema> = HashMap::new();
406        for rule in rules {
407            recursive_pred_names.insert(rule.head.clone());
408            if rule.meta.schema.arity() > 0 {
409                schema_by_pred
410                    .entry(rule.head.clone())
411                    .or_insert_with(|| rule.meta.schema.clone());
412            }
413        }
414        let recursive_pred_lookup: HashSet<String> = recursive_pred_names.iter().cloned().collect();
415        let recursive_preds: Vec<String> = recursive_pred_names.into_iter().collect();
416
417        // Ensure all recursive predicates exist in the store so scans never fail
418        // due to evaluation order (mutual recursion can reference an as-yet-empty relation).
419        for pred in &recursive_preds {
420            if !self.store.contains(pred) {
421                let schema = schema_by_pred
422                    .get(pred)
423                    .cloned()
424                    .or_else(|| self.store.get(pred).map(|b| b.schema().clone()))
425                    .ok_or_else(|| {
426                        XlogError::Execution(format!(
427                            "Missing schema for recursive predicate {}",
428                            pred
429                        ))
430                    })?;
431                let empty = self.create_empty_buffer(schema)?;
432                self.store_put(pred, empty);
433            }
434        }
435
436        // Create per-predicate delta relations (distinct RelIds) so semi-naive evaluation
437        // can target a single recursive Scan occurrence without overriding *all* scans of
438        // that predicate in a rule (required for self-joins like p(X,Y), p(Y,Z)).
439        let mut next_rel_id = self
440            .rel_names
441            .keys()
442            .map(|r| r.0)
443            .max()
444            .unwrap_or(0)
445            .saturating_add(1);
446
447        let mut delta_tracker = DeltaRelationTracker::new();
448        for pred in &recursive_preds {
449            let rel_id = RelId(next_rel_id);
450            next_rel_id = next_rel_id.saturating_add(1);
451            let name = format!("__delta_{}_{}", pred, rel_id.0);
452            self.register_relation(rel_id, &name);
453            delta_tracker.insert(pred.clone(), rel_id, name);
454        }
455
456        // Execute all rules once against the current store to seed initial results.
457        // Accumulate per-head before mutating the store to avoid order dependence.
458        //
459        // Route through `execute_wcoj_or_fallback_node` so promoted
460        // MultiWayJoin bodies for stable and linear-recursive triangles or
461        // 4-cycles get a chance at WCOJ dispatch on the seeding pass. Stable
462        // rules with zero recursive scans only run here, so without this hook
463        // they would never see a kernel.
464        let mut derived_initial: HashMap<String, CudaBuffer> = HashMap::new();
465        for rule in rules {
466            let result = self.execute_wcoj_or_fallback_node(&rule.body)?;
467            if let Some(acc) = derived_initial.get_mut(&rule.head) {
468                let union_input = acc.num_rows() + result.num_rows();
469                let start = self.profiler.start_op();
470                let merged = self.provider.union_gpu(acc, &result)?;
471                if let Some(start) = start {
472                    let mem = self.provider.memory().allocated_bytes();
473                    self.profiler
474                        .record_op("union", union_input, merged.num_rows(), start, mem);
475                    self.profiler.record_peak_memory(mem);
476                }
477                *acc = merged;
478            } else {
479                derived_initial.insert(rule.head.clone(), result);
480            }
481        }
482
483        // Initialize delta from the newly-derived tuples only.
484        //
485        // This supports incremental maintenance: if the SCC is executed again after EDB inserts,
486        // the delta relations start with only the *new* tuples, not a full rescan of the current
487        // fixed point.
488        for pred in &recursive_preds {
489            let full_old = self
490                .store
491                .remove(pred)
492                .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", pred)))?;
493
494            let derived = match derived_initial.remove(pred) {
495                Some(buf) => buf,
496                None => self.create_empty_buffer(full_old.schema().clone())?,
497            };
498
499            let union_input = full_old.num_rows() + derived.num_rows();
500            let start = self.profiler.start_op();
501            let merged = self.provider.union_gpu(&full_old, &derived)?;
502            if let Some(start) = start {
503                let mem = self.provider.memory().allocated_bytes();
504                self.profiler
505                    .record_op("union", union_input, merged.num_rows(), start, mem);
506                self.profiler.record_peak_memory(mem);
507            }
508
509            let full_new = merged;
510
511            let delta_name = delta_tracker.delta_name(pred)?;
512
513            let full_old_rows = self.buffer_row_count(&full_old)?;
514            let full_new_rows = self.buffer_row_count(&full_new)?;
515            let delta_initial = if full_new_rows == 0 {
516                self.create_empty_buffer(full_new.schema().clone())?
517            } else if full_old_rows == 0 {
518                self.clone_buffer(&full_new)?
519            } else {
520                let diff_input = full_new.num_rows() + full_old.num_rows();
521                let start = self.profiler.start_op();
522                let diffed = self.provider.diff_gpu(&full_new, &full_old)?;
523                if let Some(start) = start {
524                    let mem = self.provider.memory().allocated_bytes();
525                    self.profiler
526                        .record_op("diff", diff_input, diffed.num_rows(), start, mem);
527                    self.profiler.record_peak_memory(mem);
528                }
529                diffed
530            };
531
532            // Seed-iteration cardinality refresh. Capture the actual
533            // `delta_initial` row count before the `store_put` move; after the
534            // move, the buffer is gone.
535            let delta_initial_rows = self.buffer_row_count(&delta_initial)? as u64;
536            let seed_full_rows = full_new_rows as u64;
537            // Pre-resolve rel_id lookups before the &mut self stats
538            // borrow below.
539            let full_rel_opt = self.name_to_rel_id(pred);
540            let delta_rel = delta_tracker.delta_rel_id(pred)?;
541
542            self.store_put(pred, full_new);
543            self.store_put(delta_name, delta_initial);
544
545            // Stats updates fire whether or not WCOJ ran on the seed
546            // pass. update_cardinality is a no-op for unregistered
547            // rel_ids (defensive: tests that don't register an IDB
548            // head get a no-op for the full_rel write).
549            if let Some(full_rel) = full_rel_opt {
550                self.stats.update_cardinality(full_rel, seed_full_rows);
551            }
552            self.stats.update_cardinality(delta_rel, delta_initial_rows);
553
554            // Seed stats trace entry, gated on `recursive-stats-trace`.
555            #[cfg(feature = "recursive-stats-trace")]
556            self.last_recursive_stats_trace
557                .entries
558                .push(super::RecursiveStatsTraceEntry {
559                    iteration: 0,
560                    pred: pred.clone(),
561                    full_rel: full_rel_opt.unwrap_or(RelId(u32::MAX)),
562                    delta_rel,
563                    full_rows: seed_full_rows,
564                    delta_rows: delta_initial_rows,
565                    phase: super::RecursiveStatsPhase::Seed,
566                    binary_est_for_variant: None,
567                });
568        }
569
570        // Iterate until no new tuples are produced.
571        let mut reached_fixpoint = false;
572        let max_iterations = self.config.max_iterations as usize;
573        let mut iteration_count = 0usize;
574        // D3 — per-fixpoint dispatch context for the factorized delta
575        // (domain bounds + normalized EDB statics are cached across
576        // iterations).
577        let mut fd_ctx = super::wcoj_dispatch::FactorizedDeltaCtx::default();
578        for _iteration in 0..max_iterations {
579            iteration_count += 1;
580            // Compute delta_new_raw per head by evaluating each rule once per recursive Scan occurrence.
581            let mut delta_new_raw_by_head: HashMap<String, CudaBuffer> = HashMap::new();
582            // D3 — factorized novel sets per head: already diffed
583            // against the stable relation and full-row deduped at
584            // dispatch time. Kept separate from the raw accumulator so
585            // all-factorized heads can skip the legacy diff entirely.
586            let mut delta_novel_by_head: HashMap<String, CudaBuffer> = HashMap::new();
587
588            for rule in rules {
589                let mut scans = Vec::new();
590                Self::collect_scan_rels(&rule.body, &mut scans);
591
592                // Build a list of (rel_id, occurrence_idx, pred_name) for recursive scans.
593                let mut seen: HashMap<RelId, usize> = HashMap::new();
594                let mut variants: Vec<(RelId, usize, String)> = Vec::new();
595                for rel_id in scans {
596                    let pred_name = match self.get_rel_name(rel_id) {
597                        Some(n) => n.to_string(),
598                        None => continue,
599                    };
600                    if !recursive_pred_lookup.contains(&pred_name) {
601                        continue;
602                    }
603
604                    // Skip variants where the delta for this predicate is empty.
605                    let delta_name = match delta_tracker.get(&pred_name) {
606                        Some((_rel_id, name)) => name.as_str(),
607                        None => continue,
608                    };
609                    let delta_is_empty = match self.store.get(delta_name) {
610                        Some(delta) => self.buffer_row_count(delta)? == 0,
611                        None => true,
612                    };
613                    if delta_is_empty {
614                        continue;
615                    }
616
617                    let occ = seen.entry(rel_id).or_insert(0);
618                    variants.push((rel_id, *occ, pred_name));
619                    *occ += 1;
620                }
621
622                if variants.is_empty() {
623                    // Base rule: it can only contribute on the first seeding pass.
624                    continue;
625                }
626
627                let mut rule_delta_raw: Option<CudaBuffer> = None;
628                let mut rule_delta_novel: Option<CudaBuffer> = None;
629                for (rel_id, occ, pred_name) in variants {
630                    let delta_rel_id = delta_tracker.delta_rel_id(&pred_name)?;
631
632                    let variant_node =
633                        Self::rewrite_scan_nth(&rule.body, rel_id, occ, delta_rel_id).ok_or_else(
634                            || {
635                                XlogError::Execution(format!(
636                                    "Failed to rewrite rule body for predicate {}",
637                                    pred_name
638                                ))
639                            },
640                        )?;
641
642                    // Try the factorized delta pipeline first: a qualifying
643                    // ChainJoin variant returns the novel set directly
644                    // (already diffed against the head's stable relation and
645                    // deduped). Declines are silent and fall through to the
646                    // legacy path.
647                    if let Some(novel) = self.try_dispatch_factorized_delta(
648                        &variant_node,
649                        delta_rel_id,
650                        &rule.head,
651                        &recursive_pred_lookup,
652                        &mut fd_ctx,
653                    )? {
654                        rule_delta_novel = Some(match rule_delta_novel {
655                            Some(acc) => self.provider.union_gpu(&acc, &novel)?,
656                            None => novel,
657                        });
658                        continue;
659                    }
660
661                    // Try WCOJ on the rewritten variant body before falling
662                    // back to the binary-join walker.
663                    // For a linear-recursive triangle/4-cycle, the
664                    // variant has one Scan's RelId swapped to its
665                    // delta — the kernel reads from the delta store
666                    // entry transparently, no special-case dispatch
667                    // logic needed.
668                    let out = self.execute_wcoj_or_fallback_node(&variant_node)?;
669                    rule_delta_raw = Some(if let Some(acc) = rule_delta_raw {
670                        let union_input = acc.num_rows() + out.num_rows();
671                        let start = self.profiler.start_op();
672                        let merged = self.provider.union_gpu(&acc, &out)?;
673                        if let Some(start) = start {
674                            let mem = self.provider.memory().allocated_bytes();
675                            self.profiler.record_op(
676                                "union",
677                                union_input,
678                                merged.num_rows(),
679                                start,
680                                mem,
681                            );
682                            self.profiler.record_peak_memory(mem);
683                        }
684                        merged
685                    } else {
686                        out
687                    });
688                }
689
690                // D3 — a rule with BOTH factorized and legacy variant
691                // outputs folds its novel set into the raw accumulator
692                // (the legacy diff is a no-op on novel rows, so this is
693                // sound); an all-factorized rule keeps its novel set on
694                // the diff-free track.
695                if rule_delta_raw.is_some() {
696                    if let Some(novel) = rule_delta_novel.take() {
697                        let raw = rule_delta_raw.as_ref().expect("checked above");
698                        rule_delta_raw = Some(self.provider.union_gpu(raw, &novel)?);
699                    }
700                }
701                if let Some(rule_out) = rule_delta_raw {
702                    if let Some(acc) = delta_new_raw_by_head.get_mut(&rule.head) {
703                        let union_input = acc.num_rows() + rule_out.num_rows();
704                        let start = self.profiler.start_op();
705                        let merged = self.provider.union_gpu(acc, &rule_out)?;
706                        if let Some(start) = start {
707                            let mem = self.provider.memory().allocated_bytes();
708                            self.profiler.record_op(
709                                "union",
710                                union_input,
711                                merged.num_rows(),
712                                start,
713                                mem,
714                            );
715                            self.profiler.record_peak_memory(mem);
716                        }
717                        *acc = merged;
718                    } else {
719                        delta_new_raw_by_head.insert(rule.head.clone(), rule_out);
720                    }
721                }
722                if let Some(rule_novel) = rule_delta_novel {
723                    if let Some(acc) = delta_novel_by_head.get_mut(&rule.head) {
724                        *acc = self.provider.union_gpu(acc, &rule_novel)?;
725                    } else {
726                        delta_novel_by_head.insert(rule.head.clone(), rule_novel);
727                    }
728                }
729            }
730
731            // Finalize delta_new per head: delta_new = dedup(delta_raw - full).
732            delta_tracker.begin_iteration();
733
734            for pred in &recursive_preds {
735                let full = self
736                    .store
737                    .get(pred)
738                    .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", pred)))?;
739                // Capture the current full row count for the trace's
740                // `full_rows` field before this iteration's delta relation is
741                // replaced. Gated on `recursive-stats-trace` so production
742                // builds do not compute it.
743                #[cfg(feature = "recursive-stats-trace")]
744                let pre_phase4_full_rows = self.buffer_row_count(full)? as u64;
745
746                let delta_raw = delta_new_raw_by_head.remove(pred);
747                let delta_novel = delta_novel_by_head.remove(pred);
748                // D3 — when a head received both raw and factorized
749                // contributions (different rules), fold the novel set
750                // into the raw side before the legacy diff (sound: the
751                // diff is a no-op on novel rows). An all-factorized
752                // head skips the diff entirely — its novel set is
753                // already diffed and deduped by construction.
754                let (delta_raw, delta_novel) = match (delta_raw, delta_novel) {
755                    (Some(raw), Some(novel)) => {
756                        (Some(self.provider.union_gpu(&raw, &novel)?), None)
757                    }
758                    other => other,
759                };
760                let delta_new = if let Some(novel) = delta_novel {
761                    novel
762                } else if let Some(delta_raw) = delta_raw {
763                    if self.buffer_row_count(&delta_raw)? == 0 {
764                        self.create_empty_buffer(full.schema().clone())?
765                    } else {
766                        let diff_input = delta_raw.num_rows() + full.num_rows();
767                        let start = self.profiler.start_op();
768                        let diffed = self.provider.diff_gpu(&delta_raw, full)?;
769                        if let Some(start) = start {
770                            let mem = self.provider.memory().allocated_bytes();
771                            self.profiler.record_op(
772                                "diff",
773                                diff_input,
774                                diffed.num_rows(),
775                                start,
776                                mem,
777                            );
778                            self.profiler.record_peak_memory(mem);
779                        }
780                        diffed
781                    }
782                } else {
783                    self.create_empty_buffer(full.schema().clone())?
784                };
785
786                let delta_name = delta_tracker.delta_name(pred)?.to_string();
787                let delta_new_rows = self.buffer_row_count(&delta_new)? as u64;
788                if delta_new_rows != 0 {
789                    delta_tracker.mark_changed();
790                }
791                // Pre-resolve rel_id lookups before the &mut self
792                // store_put + stats update below. `full_rel_opt` is
793                // only used by the trace under the
794                // `recursive-stats-trace` feature.
795                #[cfg(feature = "recursive-stats-trace")]
796                let full_rel_opt = self.name_to_rel_id(pred);
797                let delta_rel = delta_tracker.delta_rel_id(pred)?;
798                self.store_put(&delta_name, delta_new);
799
800                // Refresh the delta relation cardinality after computing this
801                // iteration's delta. The full relation cardinality is not
802                // updated here because the full relation has not changed yet
803                // this iteration; the merge step owns that.
804                self.stats.update_cardinality(delta_rel, delta_new_rows);
805
806                // Delta stats trace entry, gated on `recursive-stats-trace`.
807                // binary_est_for_variant captures the cost model's
808                // first-binary-hop estimate for the linear-recursive fixtures
809                // (`pred == "e1"` rewrites
810                // Scan(e1) → Scan(delta_e1); first hop is
811                // `delta_e1.col1 ⋈ e2.col0`). Populated inline because
812                // delta_rel is unregistered at fixpoint exit, so the
813                // test cannot recompute after `execute_plan` returns.
814                #[cfg(feature = "recursive-stats-trace")]
815                let binary_est_for_variant: Option<u64> = if pred == "e1" {
816                    self.name_to_rel_id("e2").map(|e2_rel| {
817                        self.stats
818                            .estimate_join_cardinality(delta_rel, e2_rel, &[1], &[0])
819                    })
820                } else {
821                    None
822                };
823                #[cfg(feature = "recursive-stats-trace")]
824                self.last_recursive_stats_trace
825                    .entries
826                    .push(super::RecursiveStatsTraceEntry {
827                        iteration: iteration_count,
828                        pred: pred.clone(),
829                        full_rel: full_rel_opt.unwrap_or(RelId(u32::MAX)),
830                        delta_rel,
831                        full_rows: pre_phase4_full_rows,
832                        delta_rows: delta_new_rows,
833                        phase: super::RecursiveStatsPhase::Phase2Delta,
834                        binary_est_for_variant,
835                    });
836            }
837
838            // Fixpoint reached if no deltas produced.
839            if delta_tracker.is_converged() {
840                reached_fixpoint = true;
841                self.profiler.record_iterations(iteration_count);
842                break;
843            }
844
845            // Merge deltas into full relations.
846            for pred in &recursive_preds {
847                let full_old = self
848                    .store
849                    .remove(pred)
850                    .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", pred)))?;
851                let dn = delta_tracker.delta_name(pred)?.to_string();
852                let delta = self
853                    .store_remove(&dn)
854                    .ok_or_else(|| XlogError::Execution(format!("Missing relation: {}", dn)))?;
855
856                if self.buffer_row_count(&delta)? == 0 {
857                    // Zero-delta short-circuit: full and delta are unchanged
858                    // this iteration. The delta relation record with zero
859                    // rows stands, and the full relation record from the prior
860                    // merge stands. No additional update is needed.
861                    self.store_put(pred, full_old);
862                    self.store_put(&dn, delta);
863                    continue;
864                }
865
866                let union_input = full_old.num_rows() + delta.num_rows();
867                let start = self.profiler.start_op();
868                let merged = self.provider.union_gpu(&full_old, &delta)?;
869                if let Some(start) = start {
870                    let mem = self.provider.memory().allocated_bytes();
871                    self.profiler
872                        .record_op("union", union_input, merged.num_rows(), start, mem);
873                    self.profiler.record_peak_memory(mem);
874                }
875
876                let full_new = merged;
877                // Capture `full_new`'s row count before the `store_put` move
878                // and pre-resolve `full_rel_opt` before the mutable stats
879                // borrow. The delta row count and delta relation id are only
880                // used by the trace under the `recursive-stats-trace` feature.
881                let full_new_rows_phase4 = self.buffer_row_count(&full_new)? as u64;
882                #[cfg(feature = "recursive-stats-trace")]
883                let delta_rows_phase4 = self.buffer_row_count(&delta)? as u64;
884                let full_rel_opt = self.name_to_rel_id(pred);
885                #[cfg(feature = "recursive-stats-trace")]
886                let delta_rel = delta_tracker.delta_rel_id(pred)?;
887                self.store_put(pred, full_new);
888                self.store_put(&dn, delta);
889
890                // Record the full relation's new cardinality. The delta
891                // relation was already recorded for this iteration.
892                if let Some(full_rel) = full_rel_opt {
893                    self.stats
894                        .update_cardinality(full_rel, full_new_rows_phase4);
895                }
896                self.refresh_kclique_edge_metadata_after_merge(rules, pred);
897
898                // Full-relation stats trace entry, gated on `recursive-stats-trace`.
899                #[cfg(feature = "recursive-stats-trace")]
900                self.last_recursive_stats_trace
901                    .entries
902                    .push(super::RecursiveStatsTraceEntry {
903                        iteration: iteration_count,
904                        pred: pred.clone(),
905                        full_rel: full_rel_opt.unwrap_or(RelId(u32::MAX)),
906                        delta_rel,
907                        full_rows: full_new_rows_phase4,
908                        delta_rows: delta_rows_phase4,
909                        phase: super::RecursiveStatsPhase::Phase4Full,
910                        binary_est_for_variant: None,
911                    });
912            }
913        }
914
915        // Cleanup: remove delta relations from store and relation mapping.
916        for (_pred, (rel_id, delta_name)) in delta_tracker.into_inner() {
917            self.store_remove(&delta_name);
918            self.rel_names.remove(&rel_id);
919            self.name_to_rel.remove(&delta_name);
920            let _ = self.stats.unregister_relation(rel_id);
921        }
922
923        if !reached_fixpoint {
924            // Record iterations even on failure for debugging
925            self.profiler.record_iterations(iteration_count);
926            return Err(XlogError::Execution(format!(
927                "Recursive SCC iteration limit ({}) exceeded",
928                self.config.max_iterations
929            )));
930        }
931
932        Ok(())
933    }
934
935    /// Execute a Fixpoint node using semi-naive evaluation
936    ///
937    /// The semi-naive algorithm avoids redundant computation in recursive queries:
938    ///
939    /// 1. **Initialize:**
940    ///    - Compute base case: `R = base_result`
941    ///    - Set delta to base: `delta = R`
942    ///    - Store both `R` and `delta` in RelationStore
943    ///
944    /// 2. **Iterate until fixpoint:**
945    ///    - Compute new tuples: `delta_new = recursive_result` using current `delta`
946    ///    - Remove already-known tuples: `delta_new = delta_new - R`
947    ///    - If `delta_new` is empty, we have reached fixpoint
948    ///    - Otherwise: `R = R union delta_new`, `delta = delta_new`
949    ///
950    /// 3. **Return:** Final `R`
951    ///
952    /// # Arguments
953    /// * `scc_id` - SCC identifier for logging/debugging
954    /// * `base` - Base case RIR tree (non-recursive facts/rules)
955    /// * `recursive` - Recursive RIR tree (references delta relation)
956    /// * `delta_rel` - RelId for delta relation
957    /// * `full_rel` - RelId for full relation
958    ///
959    /// # Returns
960    /// A CudaBuffer containing the final fixpoint result
961    ///
962    /// # Errors
963    /// Returns an error if iteration limit is exceeded
964    pub(super) fn execute_fixpoint(
965        &mut self,
966        scc_id: u32,
967        base: &RirNode,
968        recursive: &RirNode,
969        delta_rel: RelId,
970        full_rel: RelId,
971    ) -> Result<CudaBuffer> {
972        // Compute base case R = eval(base)
973        let r_initial = self.execute_node(base)?;
974
975        // Handle empty base case using device-resident row count
976        if self.buffer_row_count(&r_initial)? == 0 {
977            return Ok(r_initial);
978        }
979
980        // Initialize delta = R (clone the base result)
981        let delta_initial = self.clone_buffer(&r_initial)?;
982
983        // Get relation names for delta and full relations
984        let delta_name = self.get_or_create_rel_name(delta_rel, &format!("__delta_{}", scc_id));
985        let full_name = self.get_or_create_rel_name(full_rel, &format!("__full_{}", scc_id));
986
987        // Store initial R and delta in relation store
988        self.store_put(&full_name, r_initial);
989        self.store_put(&delta_name, delta_initial);
990
991        // Iterate until fixpoint
992        for _iteration in 0..Self::MAX_FIXPOINT_ITERATIONS {
993            // Evaluate recursive step using current delta
994            // The recursive RIR tree should reference delta_rel internally
995            let delta_new_raw = self.execute_node(recursive)?;
996
997            // Get current R for set difference
998            let current_r = self.store.get(&full_name).ok_or_else(|| {
999                XlogError::Execution(format!(
1000                    "Full relation {} not found during fixpoint iteration",
1001                    full_name
1002                ))
1003            })?;
1004
1005            // Compute delta_new = delta_new_raw - R (remove already-known tuples)
1006            let delta_new = self.provider.diff_gpu(&delta_new_raw, current_r)?;
1007
1008            // Check for fixpoint: if delta_new is empty, we are done
1009            if self.buffer_row_count(&delta_new)? == 0 {
1010                // Fixpoint reached - return final R
1011                let final_r = self.store_remove(&full_name).ok_or_else(|| {
1012                    XlogError::Execution("Full relation lost during fixpoint".to_string())
1013                })?;
1014
1015                // Clean up delta relation
1016                self.store_remove(&delta_name);
1017
1018                return Ok(final_r);
1019            }
1020
1021            // Not at fixpoint yet: R = R union delta_new
1022            let new_r = self.provider.union_gpu(current_r, &delta_new)?;
1023
1024            // Update relations for next iteration
1025            // delta = delta_new (the newly discovered tuples)
1026            self.store_put(&delta_name, delta_new);
1027            self.store_put(&full_name, new_r);
1028        }
1029
1030        // Iteration limit exceeded
1031        Err(XlogError::Execution(format!(
1032            "Fixpoint iteration limit ({}) exceeded for SCC {}",
1033            Self::MAX_FIXPOINT_ITERATIONS,
1034            scc_id
1035        )))
1036    }
1037}