Skip to main content

xlog_runtime/executor/
rewrite.rs

1//! Tree rewriting and incremental delta recomputation.
2
3use std::collections::{HashMap, HashSet};
4
5use xlog_core::{RelId, Result, XlogError};
6use xlog_ir::rir::{LookupPerm, ProjectExpr, VariableOrder};
7use xlog_ir::{ExecutionPlan, JoinType, RirNode};
8
9use super::RelationDelta;
10use super::{DeltaRecomputeStats, Executor};
11
12fn triangle_delta_var_order(leader_idx: u8) -> VariableOrder {
13    let lookup_perms = match leader_idx {
14        0 => vec![
15            LookupPerm {
16                input_idx: 1,
17                swap_cols: false,
18            },
19            LookupPerm {
20                input_idx: 2,
21                swap_cols: false,
22            },
23        ],
24        1 => vec![
25            LookupPerm {
26                input_idx: 2,
27                swap_cols: true,
28            },
29            LookupPerm {
30                input_idx: 0,
31                swap_cols: true,
32            },
33        ],
34        2 => vec![
35            LookupPerm {
36                input_idx: 1,
37                swap_cols: true,
38            },
39            LookupPerm {
40                input_idx: 0,
41                swap_cols: false,
42            },
43        ],
44        _ => unreachable!("triangle leader_idx out of range"),
45    };
46    let kernel_output_cols = match leader_idx {
47        0 => vec![
48            ProjectExpr::Column(0),
49            ProjectExpr::Column(1),
50            ProjectExpr::Column(2),
51        ],
52        1 => vec![
53            ProjectExpr::Column(2),
54            ProjectExpr::Column(0),
55            ProjectExpr::Column(1),
56        ],
57        2 => vec![
58            ProjectExpr::Column(0),
59            ProjectExpr::Column(2),
60            ProjectExpr::Column(1),
61        ],
62        _ => unreachable!("triangle leader_idx out of range"),
63    };
64    VariableOrder::legacy(leader_idx, lookup_perms, kernel_output_cols)
65}
66
67fn cycle4_delta_var_order(leader_idx: u8) -> VariableOrder {
68    let lookup_perms = (1..4)
69        .map(|offset| LookupPerm {
70            input_idx: ((leader_idx as usize + offset) % 4) as u8,
71            swap_cols: false,
72        })
73        .collect();
74    let kernel_output_cols = match leader_idx {
75        0 => vec![
76            ProjectExpr::Column(0),
77            ProjectExpr::Column(1),
78            ProjectExpr::Column(2),
79            ProjectExpr::Column(3),
80        ],
81        1 => vec![
82            ProjectExpr::Column(3),
83            ProjectExpr::Column(0),
84            ProjectExpr::Column(1),
85            ProjectExpr::Column(2),
86        ],
87        2 => vec![
88            ProjectExpr::Column(2),
89            ProjectExpr::Column(3),
90            ProjectExpr::Column(0),
91            ProjectExpr::Column(1),
92        ],
93        3 => vec![
94            ProjectExpr::Column(1),
95            ProjectExpr::Column(2),
96            ProjectExpr::Column(3),
97            ProjectExpr::Column(0),
98        ],
99        _ => unreachable!("4-cycle leader_idx out of range"),
100    };
101    VariableOrder::legacy(leader_idx, lookup_perms, kernel_output_cols)
102}
103
104fn delta_outermost_var_order(
105    input_count: usize,
106    replaced_input_idx: Option<usize>,
107    current: Option<&VariableOrder>,
108) -> Option<VariableOrder> {
109    let Some(idx) = replaced_input_idx else {
110        return current.cloned();
111    };
112    if current.and_then(|order| order.kclique.as_ref()).is_some() {
113        return current.cloned();
114    }
115    if idx == 0 {
116        return None;
117    }
118    match input_count {
119        3 if idx < 3 => Some(triangle_delta_var_order(idx as u8)),
120        4 if idx < 4 => Some(cycle4_delta_var_order(idx as u8)),
121        _ => current.cloned(),
122    }
123}
124
125impl Executor {
126    /// Apply base-relation deltas and recompute affected SCCs (no recompilation).
127    ///
128    /// This provides correctness for both insertions and deletions by recomputing any SCCs that
129    /// depend (directly or transitively) on the changed relations.
130    pub fn apply_deltas_and_recompute(
131        &mut self,
132        plan: &ExecutionPlan,
133        deltas: &HashMap<String, RelationDelta>,
134    ) -> Result<DeltaRecomputeStats> {
135        if deltas.is_empty() {
136            return Ok(DeltaRecomputeStats::default());
137        }
138
139        let has_deletes = deltas
140            .values()
141            .any(|d| d.delete.as_ref().map(|b| !b.is_empty()).unwrap_or(false));
142
143        // 1) Apply EDB updates.
144        for (name, delta) in deltas {
145            let existing = self.store.get(name);
146
147            let base_schema = existing
148                .map(|b| b.schema().clone())
149                .or_else(|| delta.insert.as_ref().map(|b| b.schema().clone()))
150                .or_else(|| delta.delete.as_ref().map(|b| b.schema().clone()))
151                .ok_or_else(|| {
152                    XlogError::Execution(format!(
153                        "Delta update for {} has no existing relation and no schema",
154                        name
155                    ))
156                })?;
157
158            let mut updated = if let Some(buf) = existing {
159                self.clone_buffer(buf)?
160            } else {
161                self.create_empty_buffer(base_schema)?
162            };
163
164            if let Some(delete_buf) = &delta.delete {
165                updated = self.provider.diff_gpu(&updated, delete_buf)?;
166            }
167            if let Some(insert_buf) = &delta.insert {
168                updated = self.provider.union_gpu(&updated, insert_buf)?;
169            }
170
171            self.store_put(name, updated);
172        }
173
174        // 2) Compute affected SCC closure.
175        let changed_preds: HashSet<&str> = deltas.keys().map(|s| s.as_str()).collect();
176
177        let mut pred_to_scc: HashMap<&str, u32> = HashMap::new();
178        for scc in &plan.sccs {
179            for pred in &scc.predicates {
180                pred_to_scc.insert(pred.as_str(), scc.id);
181            }
182        }
183
184        let mut dependents: HashMap<u32, Vec<u32>> = HashMap::new();
185        for (scc_id, rules) in plan.rules_by_scc.iter().enumerate() {
186            let scc_id = scc_id as u32;
187            for rule in rules {
188                let mut rels = Vec::new();
189                Self::collect_scan_rels(&rule.body, &mut rels);
190                for rel in rels {
191                    let Some(name) = self.get_rel_name(rel) else {
192                        continue;
193                    };
194                    let Some(&dep_scc) = pred_to_scc.get(name) else {
195                        continue;
196                    };
197                    if dep_scc == scc_id {
198                        continue;
199                    }
200                    dependents.entry(dep_scc).or_default().push(scc_id);
201                }
202            }
203        }
204
205        let mut affected: HashSet<u32> = HashSet::new();
206        let mut queue: Vec<u32> = Vec::new();
207        for pred in &changed_preds {
208            if let Some(&scc) = pred_to_scc.get(*pred) {
209                affected.insert(scc);
210                queue.push(scc);
211            }
212        }
213
214        while let Some(scc) = queue.pop() {
215            if let Some(deps) = dependents.get(&scc) {
216                for &next in deps {
217                    if affected.insert(next) {
218                        queue.push(next);
219                    }
220                }
221            }
222        }
223
224        if affected.is_empty() {
225            return Ok(DeltaRecomputeStats {
226                changed_relations: deltas.len(),
227                has_deletes,
228                affected_sccs: 0,
229                recomputed_sccs: 0,
230                incremental_sccs: 0,
231            });
232        }
233
234        fn contains_non_monotonic_ops(node: &RirNode) -> bool {
235            match node {
236                RirNode::Unit | RirNode::Scan { .. } => false,
237                RirNode::Filter { input, .. }
238                | RirNode::Project { input, .. }
239                | RirNode::Distinct { input, .. } => contains_non_monotonic_ops(input),
240                RirNode::Union { inputs } => inputs.iter().any(contains_non_monotonic_ops),
241                RirNode::GroupBy { .. } | RirNode::Diff { .. } => true,
242                RirNode::Join {
243                    left,
244                    right,
245                    join_type,
246                    ..
247                } => {
248                    matches!(join_type, JoinType::Anti | JoinType::LeftOuter)
249                        || contains_non_monotonic_ops(left)
250                        || contains_non_monotonic_ops(right)
251                }
252                RirNode::Fixpoint {
253                    base, recursive, ..
254                } => contains_non_monotonic_ops(base) || contains_non_monotonic_ops(recursive),
255                RirNode::TensorMaskedJoin { .. } => false,
256                // Walk the fallback. The promoter only wraps already-monotonic
257                // multiway subtrees, but the fallback is the load-bearing
258                // source of truth.
259                RirNode::MultiWayJoin { fallback, .. } | RirNode::ChainJoin { fallback, .. } => {
260                    contains_non_monotonic_ops(fallback)
261                }
262            }
263        }
264
265        // 3) Decide which SCCs must be recomputed (cleared first).
266        //
267        // If there are deletes, we always recompute for correctness.
268        // If there are only inserts, we can incrementally update SCCs that are monotone w.r.t.
269        // insertion (no anti-joins, diffs, or aggregates) and do a targeted recompute for the rest.
270        let mut recompute_sccs: HashSet<u32> = HashSet::new();
271        if has_deletes {
272            recompute_sccs = affected.clone();
273        } else {
274            for &scc_id in &affected {
275                if let Some(rules) = plan.rules_by_scc.get(scc_id as usize) {
276                    if rules.iter().any(|r| contains_non_monotonic_ops(&r.body)) {
277                        recompute_sccs.insert(scc_id);
278                    }
279                }
280            }
281
282            // If any SCC is recomputed due to non-monotonic ops, all dependents must also be
283            // recomputed because their prior outputs may now be invalid.
284            let mut queue: Vec<u32> = recompute_sccs.iter().copied().collect();
285            while let Some(scc) = queue.pop() {
286                if let Some(deps) = dependents.get(&scc) {
287                    for &next in deps {
288                        if !affected.contains(&next) {
289                            continue;
290                        }
291                        if recompute_sccs.insert(next) {
292                            queue.push(next);
293                        }
294                    }
295                }
296            }
297        }
298
299        // 4) Clear IDB relations for SCCs we are recomputing (but never clear directly-updated bases).
300        for scc_id in &recompute_sccs {
301            let Some(scc) = plan.sccs.iter().find(|s| s.id == *scc_id) else {
302                continue;
303            };
304
305            for pred in &scc.predicates {
306                if changed_preds.contains(pred.as_str()) {
307                    continue;
308                }
309                let schema = self
310                    .store
311                    .get(pred)
312                    .map(|b| b.schema().clone())
313                    .or_else(|| {
314                        plan.rules_by_scc
315                            .get(*scc_id as usize)
316                            .and_then(|rules| rules.iter().find(|r| r.head == pred.as_str()))
317                            .and_then(|r| {
318                                let schema = r.meta.schema.clone();
319                                if schema.arity() > 0 {
320                                    Some(schema)
321                                } else {
322                                    None
323                                }
324                            })
325                    })
326                    .ok_or_else(|| {
327                        XlogError::Execution(format!(
328                            "Missing schema for predicate {} during recompute",
329                            pred
330                        ))
331                    })?;
332
333                let empty = self.create_empty_buffer(schema)?;
334                self.store_put(pred, empty);
335            }
336        }
337
338        // 5) Re-execute affected SCCs in plan order (incremental for insert-only monotone SCCs).
339        for stratum in &plan.strata {
340            for &scc_id in &stratum.sccs {
341                if !affected.contains(&scc_id) {
342                    continue;
343                }
344                let rules = plan.rules_by_scc.get(scc_id as usize).ok_or_else(|| {
345                    XlogError::Execution(format!("Missing rules for SCC {}", scc_id))
346                })?;
347                let is_recursive = plan
348                    .sccs
349                    .iter()
350                    .find(|s| s.id == scc_id)
351                    .map(|s| s.is_recursive)
352                    .unwrap_or(false);
353
354                if is_recursive {
355                    self.execute_recursive_scc(rules)?;
356                } else {
357                    self.execute_non_recursive_scc(rules)?;
358                }
359            }
360        }
361
362        Ok(DeltaRecomputeStats {
363            changed_relations: deltas.len(),
364            has_deletes,
365            affected_sccs: affected.len(),
366            recomputed_sccs: recompute_sccs.len(),
367            incremental_sccs: affected.len().saturating_sub(recompute_sccs.len()),
368        })
369    }
370
371    pub(crate) fn collect_scan_rels(node: &RirNode, out: &mut Vec<RelId>) {
372        match node {
373            RirNode::Unit => {}
374            RirNode::Scan { rel } => out.push(*rel),
375            RirNode::Filter { input, .. } | RirNode::Project { input, .. } => {
376                Self::collect_scan_rels(input, out);
377            }
378            RirNode::Join { left, right, .. }
379            | RirNode::ChainJoin { left, right, .. }
380            | RirNode::Diff { left, right } => {
381                Self::collect_scan_rels(left, out);
382                Self::collect_scan_rels(right, out);
383            }
384            RirNode::GroupBy { input, .. } | RirNode::Distinct { input, .. } => {
385                Self::collect_scan_rels(input, out);
386            }
387            RirNode::Union { inputs } => {
388                for input in inputs {
389                    Self::collect_scan_rels(input, out);
390                }
391            }
392            RirNode::Fixpoint {
393                base, recursive, ..
394            } => {
395                Self::collect_scan_rels(base, out);
396                Self::collect_scan_rels(recursive, out);
397            }
398            RirNode::TensorMaskedJoin { rel_index, .. } => {
399                for (rel_id, _) in rel_index {
400                    out.push(*rel_id);
401                }
402            }
403            // Collect from `inputs` only; the fallback subtree references the
404            // same set by promoter invariant.
405            RirNode::MultiWayJoin { inputs, .. } => {
406                for input in inputs {
407                    Self::collect_scan_rels(input, out);
408                }
409            }
410        }
411    }
412
413    pub(crate) fn rewrite_scan_nth(
414        node: &RirNode,
415        target: RelId,
416        nth: usize,
417        replacement: RelId,
418    ) -> Option<RirNode> {
419        let mut remaining = nth;
420        let (rewritten, replaced) =
421            Self::rewrite_scan_nth_impl(node, target, &mut remaining, replacement);
422        replaced.then_some(rewritten)
423    }
424
425    fn rewrite_scan_nth_impl(
426        node: &RirNode,
427        target: RelId,
428        remaining: &mut usize,
429        replacement: RelId,
430    ) -> (RirNode, bool) {
431        match node {
432            RirNode::Unit => (RirNode::Unit, false),
433            RirNode::Scan { rel } => {
434                if *rel == target {
435                    if *remaining == 0 {
436                        // Replace exactly one occurrence per `rewrite_scan_nth`
437                        // call, then mark this walk "done" via the
438                        // `usize::MAX` sentinel so subsequent matches in the
439                        // same walk do NOT replace again. Without this, a body
440                        // with 2+ same-predicate recursive Scans would have
441                        // ALL occurrences after `nth` overwritten when the
442                        // caller intended only the `nth`-th to be substituted.
443                        *remaining = usize::MAX;
444                        return (RirNode::Scan { rel: replacement }, true);
445                    }
446                    if *remaining != usize::MAX {
447                        *remaining -= 1;
448                    }
449                }
450                (node.clone(), false)
451            }
452
453            RirNode::Filter { input, predicate } => {
454                let (new_input, replaced) =
455                    Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
456                (
457                    RirNode::Filter {
458                        input: Box::new(new_input),
459                        predicate: predicate.clone(),
460                    },
461                    replaced,
462                )
463            }
464
465            RirNode::Project { input, columns } => {
466                let (new_input, replaced) =
467                    Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
468                (
469                    RirNode::Project {
470                        input: Box::new(new_input),
471                        columns: columns.clone(),
472                    },
473                    replaced,
474                )
475            }
476
477            RirNode::Join {
478                left,
479                right,
480                left_keys,
481                right_keys,
482                join_type,
483            } => {
484                let (new_left, replaced_left) =
485                    Self::rewrite_scan_nth_impl(left, target, remaining, replacement);
486                if replaced_left {
487                    return (
488                        RirNode::Join {
489                            left: Box::new(new_left),
490                            right: right.clone(),
491                            left_keys: left_keys.clone(),
492                            right_keys: right_keys.clone(),
493                            join_type: *join_type,
494                        },
495                        true,
496                    );
497                }
498                let (new_right, replaced_right) =
499                    Self::rewrite_scan_nth_impl(right, target, remaining, replacement);
500                (
501                    RirNode::Join {
502                        left: Box::new(new_left),
503                        right: Box::new(new_right),
504                        left_keys: left_keys.clone(),
505                        right_keys: right_keys.clone(),
506                        join_type: *join_type,
507                    },
508                    replaced_right,
509                )
510            }
511
512            RirNode::GroupBy {
513                input,
514                key_cols,
515                aggs,
516            } => {
517                let (new_input, replaced) =
518                    Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
519                (
520                    RirNode::GroupBy {
521                        input: Box::new(new_input),
522                        key_cols: key_cols.clone(),
523                        aggs: aggs.clone(),
524                    },
525                    replaced,
526                )
527            }
528
529            RirNode::Union { inputs } => {
530                let mut replaced_any = false;
531                let mut new_inputs = Vec::with_capacity(inputs.len());
532                for input in inputs {
533                    let (new_input, replaced) =
534                        Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
535                    replaced_any |= replaced;
536                    new_inputs.push(new_input);
537                }
538                (RirNode::Union { inputs: new_inputs }, replaced_any)
539            }
540
541            RirNode::Distinct { input, key_cols } => {
542                let (new_input, replaced) =
543                    Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
544                (
545                    RirNode::Distinct {
546                        input: Box::new(new_input),
547                        key_cols: key_cols.clone(),
548                    },
549                    replaced,
550                )
551            }
552
553            RirNode::Diff { left, right } => {
554                let (new_left, replaced_left) =
555                    Self::rewrite_scan_nth_impl(left, target, remaining, replacement);
556                if replaced_left {
557                    return (
558                        RirNode::Diff {
559                            left: Box::new(new_left),
560                            right: right.clone(),
561                        },
562                        true,
563                    );
564                }
565                let (new_right, replaced_right) =
566                    Self::rewrite_scan_nth_impl(right, target, remaining, replacement);
567                (
568                    RirNode::Diff {
569                        left: Box::new(new_left),
570                        right: Box::new(new_right),
571                    },
572                    replaced_right,
573                )
574            }
575
576            RirNode::Fixpoint {
577                scc_id,
578                base,
579                recursive,
580                delta_rel,
581                full_rel,
582            } => {
583                let (new_base, replaced_base) =
584                    Self::rewrite_scan_nth_impl(base, target, remaining, replacement);
585                if replaced_base {
586                    return (
587                        RirNode::Fixpoint {
588                            scc_id: *scc_id,
589                            base: Box::new(new_base),
590                            recursive: recursive.clone(),
591                            delta_rel: *delta_rel,
592                            full_rel: *full_rel,
593                        },
594                        true,
595                    );
596                }
597                let (new_recursive, replaced_recursive) =
598                    Self::rewrite_scan_nth_impl(recursive, target, remaining, replacement);
599                (
600                    RirNode::Fixpoint {
601                        scc_id: *scc_id,
602                        base: Box::new(new_base),
603                        recursive: Box::new(new_recursive),
604                        delta_rel: *delta_rel,
605                        full_rel: *full_rel,
606                    },
607                    replaced_recursive,
608                )
609            }
610            RirNode::TensorMaskedJoin { .. } => {
611                // TensorMaskedJoin is a leaf node — no child scans to rewrite.
612                (node.clone(), false)
613            }
614            RirNode::ChainJoin {
615                left,
616                right,
617                left_key,
618                right_key,
619                output_columns,
620                fallback,
621            } => {
622                let starting_remaining = *remaining;
623                let mut inputs_remaining = starting_remaining;
624                let (new_left, replaced_left) =
625                    Self::rewrite_scan_nth_impl(left, target, &mut inputs_remaining, replacement);
626                let (new_right, replaced_right) =
627                    Self::rewrite_scan_nth_impl(right, target, &mut inputs_remaining, replacement);
628                let mut fallback_remaining = starting_remaining;
629                let (new_fallback, fallback_replaced) = Self::rewrite_scan_nth_impl(
630                    fallback,
631                    target,
632                    &mut fallback_remaining,
633                    replacement,
634                );
635                *remaining = inputs_remaining;
636                (
637                    RirNode::ChainJoin {
638                        left: Box::new(new_left),
639                        right: Box::new(new_right),
640                        left_key: *left_key,
641                        right_key: *right_key,
642                        output_columns: output_columns.clone(),
643                        fallback: Box::new(new_fallback),
644                    },
645                    replaced_left || replaced_right || fallback_replaced,
646                )
647            }
648            // Rewrite `inputs` and `fallback` with SEPARATE `remaining`
649            // counter copies. Both views are the same logical body, so each
650            // must independently target the N-th occurrence. Sharing one
651            // counter across the two walks contaminated the fallback's count
652            // by the inputs' consumed matches, which produced
653            // wrong-occurrence substitutions on self-recursive bodies. The
654            // outer caller's `remaining` is updated to whatever the inputs
655            // walk consumed, so siblings of this MultiWayJoin (rare; typically
656            // wrapped in Project) see consistent counting.
657            RirNode::MultiWayJoin {
658                inputs,
659                slot_vars,
660                output_columns,
661                fallback,
662                plan,
663                var_order,
664            } => {
665                let starting_remaining = *remaining;
666                let mut inputs_remaining = starting_remaining;
667                let mut new_inputs = Vec::with_capacity(inputs.len());
668                let mut any_replaced = false;
669                let mut replaced_input_idx = None;
670                for (idx, inp) in inputs.iter().enumerate() {
671                    let (new_inp, replaced) = Self::rewrite_scan_nth_impl(
672                        inp,
673                        target,
674                        &mut inputs_remaining,
675                        replacement,
676                    );
677                    any_replaced |= replaced;
678                    if replaced {
679                        replaced_input_idx = Some(idx);
680                    }
681                    new_inputs.push(new_inp);
682                }
683                let mut fallback_remaining = starting_remaining;
684                let (new_fallback, fallback_replaced) = Self::rewrite_scan_nth_impl(
685                    fallback,
686                    target,
687                    &mut fallback_remaining,
688                    replacement,
689                );
690                *remaining = inputs_remaining;
691                let input_count = new_inputs.len();
692                (
693                    RirNode::MultiWayJoin {
694                        inputs: new_inputs,
695                        slot_vars: slot_vars.clone(),
696                        output_columns: output_columns.clone(),
697                        fallback: Box::new(new_fallback),
698                        plan: plan.clone(),
699                        var_order: delta_outermost_var_order(
700                            input_count,
701                            replaced_input_idx,
702                            var_order.as_ref(),
703                        ),
704                    },
705                    any_replaced || fallback_replaced,
706                )
707            }
708        }
709    }
710}
711
712#[cfg(test)]
713mod multiway_walker_tests {
714    //! Walker arm coverage for `MultiWayJoin` in the rewrite module.
715    //! `contains_non_monotonic_ops` is a nested `fn` inside an `Executor`
716    //! method and is not directly callable; its arm is exercised through
717    //! integration tests. The two `pub(crate)` walkers below are testable in
718    //! isolation.
719
720    use super::*;
721    use xlog_ir::rir::ProjectExpr;
722
723    fn triangle_multiway(a: RelId, b: RelId, c: RelId) -> RirNode {
724        let scan_a = RirNode::Scan { rel: a };
725        let scan_b = RirNode::Scan { rel: b };
726        let scan_c = RirNode::Scan { rel: c };
727        let inner = RirNode::Join {
728            left: Box::new(scan_a.clone()),
729            right: Box::new(scan_b.clone()),
730            left_keys: vec![1],
731            right_keys: vec![0],
732            join_type: JoinType::Inner,
733        };
734        let outer = RirNode::Join {
735            left: Box::new(inner),
736            right: Box::new(scan_c.clone()),
737            left_keys: vec![0, 3],
738            right_keys: vec![0, 1],
739            join_type: JoinType::Inner,
740        };
741        let fallback = RirNode::Project {
742            input: Box::new(outer),
743            columns: vec![
744                ProjectExpr::Column(0),
745                ProjectExpr::Column(1),
746                ProjectExpr::Column(3),
747            ],
748        };
749        RirNode::MultiWayJoin {
750            inputs: vec![scan_a, scan_b, scan_c],
751            slot_vars: vec![
752                vec![Some(0), Some(1)],
753                vec![Some(1), Some(2)],
754                vec![Some(0), Some(2)],
755            ],
756            output_columns: vec![
757                ProjectExpr::Column(0),
758                ProjectExpr::Column(1),
759                ProjectExpr::Column(3),
760            ],
761            fallback: Box::new(fallback),
762            plan: None,
763            var_order: None,
764        }
765    }
766
767    #[test]
768    fn collect_scan_rels_walks_multiway_inputs_only() {
769        let node = triangle_multiway(RelId(10), RelId(20), RelId(30));
770        let mut out = Vec::new();
771        Executor::collect_scan_rels(&node, &mut out);
772        // One entry per input slot; fallback is NOT walked (would
773        // double-count to 6 entries if it were).
774        assert_eq!(out.len(), 3, "expected 3 scan rels, got: {:?}", out);
775        assert!(out.contains(&RelId(10)));
776        assert!(out.contains(&RelId(20)));
777        assert!(out.contains(&RelId(30)));
778    }
779
780    /// Input/fallback symmetric rewrite semantic: `RelId(10)` appears once in
781    /// `inputs[0]` AND once inside `fallback` (the outer join's leftmost leaf).
782    /// Both copies are the 0-th occurrence in their respective walks. Because
783    /// inputs and fallback are two views of the same logical body, `occ=0`
784    /// substitutes BOTH copies.
785    ///
786    /// `occ=1` returns `None` because `RelId(10)` has only ONE
787    /// occurrence per view; there is no 2nd occurrence to substitute.
788    #[test]
789    fn rewrite_scan_nth_rewrites_inputs_and_fallback() {
790        let node = triangle_multiway(RelId(10), RelId(20), RelId(30));
791
792        // occ=0 substitutes input[0] AND fallback's leftmost leaf.
793        let rewritten =
794            Executor::rewrite_scan_nth(&node, RelId(10), 0, RelId(99)).expect("occ=0 must succeed");
795        match rewritten {
796            RirNode::MultiWayJoin {
797                inputs, fallback, ..
798            } => {
799                // Input[0] is the replacement.
800                assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(99) }));
801                // Inputs[1] and [2] are unchanged.
802                assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(20) }));
803                assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(30) }));
804                // Fallback's RelId(10) leaf is now RelId(99); no
805                // RelId(10) remains in the fallback.
806                fn find_rel(n: &RirNode, target: RelId) -> bool {
807                    match n {
808                        RirNode::Scan { rel } => *rel == target,
809                        RirNode::Project { input, .. } => find_rel(input, target),
810                        RirNode::Join { left, right, .. } => {
811                            find_rel(left, target) || find_rel(right, target)
812                        }
813                        _ => false,
814                    }
815                }
816                assert!(
817                    find_rel(&fallback, RelId(99)),
818                    "fallback must contain RelId(99) at the 0-th occurrence position"
819                );
820                assert!(
821                    !find_rel(&fallback, RelId(10)),
822                    "fallback must NOT contain RelId(10) — the only occurrence was substituted"
823                );
824            }
825            _ => panic!("expected MultiWayJoin after rewrite"),
826        }
827
828        // occ=1 returns None: RelId(10) appears only once per view.
829        assert!(
830            Executor::rewrite_scan_nth(&node, RelId(10), 1, RelId(99)).is_none(),
831            "occ=1 must return None — RelId(10) has only 1 occurrence per view"
832        );
833    }
834
835    /// MultiWayJoin shape-agnosticism guard.
836    ///
837    /// The earliest promoter path was triangle-only; later paths add 4-input
838    /// shapes. The walker arms in `collect_scan_rels` and
839    /// `rewrite_scan_nth_impl` must NOT hard-code `inputs.len() == 3`.
840    /// Synthesize a 4-input `MultiWayJoin` directly and exercise the walker.
841    /// This test does NOT execute the IR through the runtime; it only pins the
842    /// walker's contract.
843    fn fourway_multiway(a: RelId, b: RelId, c: RelId, d: RelId) -> RirNode {
844        // Synthetic 4-cycle slot_vars [[A,B],[B,C],[C,D],[A,D]] with
845        // a stub fallback whose Scan leaves repeat each rel once.
846        let inner1 = RirNode::Join {
847            left: Box::new(RirNode::Scan { rel: a }),
848            right: Box::new(RirNode::Scan { rel: b }),
849            left_keys: vec![1],
850            right_keys: vec![0],
851            join_type: JoinType::Inner,
852        };
853        let inner2 = RirNode::Join {
854            left: Box::new(inner1),
855            right: Box::new(RirNode::Scan { rel: c }),
856            left_keys: vec![3],
857            right_keys: vec![0],
858            join_type: JoinType::Inner,
859        };
860        let outer = RirNode::Join {
861            left: Box::new(inner2),
862            right: Box::new(RirNode::Scan { rel: d }),
863            left_keys: vec![0, 5],
864            right_keys: vec![0, 1],
865            join_type: JoinType::Inner,
866        };
867        let fallback = RirNode::Project {
868            input: Box::new(outer),
869            columns: vec![
870                xlog_ir::rir::ProjectExpr::Column(0),
871                xlog_ir::rir::ProjectExpr::Column(1),
872                xlog_ir::rir::ProjectExpr::Column(3),
873                xlog_ir::rir::ProjectExpr::Column(5),
874            ],
875        };
876        RirNode::MultiWayJoin {
877            inputs: vec![
878                RirNode::Scan { rel: a },
879                RirNode::Scan { rel: b },
880                RirNode::Scan { rel: c },
881                RirNode::Scan { rel: d },
882            ],
883            slot_vars: vec![
884                vec![Some(0), Some(1)],
885                vec![Some(1), Some(2)],
886                vec![Some(2), Some(3)],
887                vec![Some(0), Some(3)],
888            ],
889            output_columns: vec![
890                xlog_ir::rir::ProjectExpr::Column(0),
891                xlog_ir::rir::ProjectExpr::Column(1),
892                xlog_ir::rir::ProjectExpr::Column(2),
893                xlog_ir::rir::ProjectExpr::Column(3),
894            ],
895            fallback: Box::new(fallback),
896            plan: None,
897            var_order: None,
898        }
899    }
900
901    #[test]
902    fn collect_scan_rels_handles_4_inputs() {
903        let node = fourway_multiway(RelId(10), RelId(20), RelId(30), RelId(40));
904        let mut out = Vec::new();
905        Executor::collect_scan_rels(&node, &mut out);
906        assert_eq!(
907            out.len(),
908            4,
909            "expected 4 scan rels, got {} entries: {:?}",
910            out.len(),
911            out
912        );
913        for id in [10, 20, 30, 40] {
914            assert!(out.contains(&RelId(id)), "RelId({}) missing", id);
915        }
916    }
917
918    /// Input/fallback symmetric rewrite semantic for the 4-input shape:
919    /// `RelId(40)` appears once in `inputs[3]` AND once inside `fallback` (the
920    /// outer join's right scan). Both copies are the 0-th occurrence in their
921    /// respective walks; the input/fallback symmetry contract makes `occ=0`
922    /// substitute BOTH copies.
923    ///
924    /// `occ=1` returns `None` because `RelId(40)` has only ONE
925    /// occurrence per view.
926    #[test]
927    fn rewrite_scan_nth_handles_4_inputs_and_fallback() {
928        let node = fourway_multiway(RelId(10), RelId(20), RelId(30), RelId(40));
929
930        // occ=0 substitutes input[3] AND fallback's RelId(40) leaf.
931        let rewritten =
932            Executor::rewrite_scan_nth(&node, RelId(40), 0, RelId(99)).expect("occ=0 must succeed");
933        let RirNode::MultiWayJoin {
934            inputs, fallback, ..
935        } = rewritten
936        else {
937            panic!("expected MultiWayJoin");
938        };
939        // Input[3] is the replacement; inputs [0..2] unchanged.
940        assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(10) }));
941        assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(20) }));
942        assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(30) }));
943        assert!(matches!(inputs[3], RirNode::Scan { rel: RelId(99) }));
944        // Fallback's RelId(40) leaf is now RelId(99); no RelId(40)
945        // remains in the fallback.
946        fn find_rel(n: &RirNode, target: RelId) -> bool {
947            match n {
948                RirNode::Scan { rel } => *rel == target,
949                RirNode::Project { input, .. } => find_rel(input, target),
950                RirNode::Join { left, right, .. } => {
951                    find_rel(left, target) || find_rel(right, target)
952                }
953                _ => false,
954            }
955        }
956        assert!(
957            find_rel(&fallback, RelId(99)),
958            "fallback must contain RelId(99) at the 0-th occurrence position"
959        );
960        assert!(
961            !find_rel(&fallback, RelId(40)),
962            "fallback must NOT contain RelId(40) — the only occurrence was substituted"
963        );
964
965        // occ=1 returns None: RelId(40) appears only once per view.
966        assert!(
967            Executor::rewrite_scan_nth(&node, RelId(40), 1, RelId(99)).is_none(),
968            "occ=1 must return None — RelId(40) has only 1 occurrence per view"
969        );
970    }
971
972    #[test]
973    fn delta_outermost_leader_selection_rebinds_triangle_variant() {
974        let node = triangle_multiway(RelId(10), RelId(20), RelId(30));
975        let rewritten =
976            Executor::rewrite_scan_nth(&node, RelId(20), 0, RelId(99)).expect("rewrite must hit");
977        let RirNode::MultiWayJoin { var_order, .. } = rewritten else {
978            panic!("expected MultiWayJoin");
979        };
980        let var_order = var_order.expect("rewritten input 1 must become leader");
981        assert_eq!(var_order.leader_idx, 1);
982        assert_eq!(var_order.lookup_perms.len(), 2);
983        assert_eq!(var_order.lookup_perms[0].input_idx, 2);
984        assert_eq!(var_order.lookup_perms[1].input_idx, 0);
985    }
986
987    #[test]
988    fn delta_outermost_leader_selection_rebinds_4cycle_variant() {
989        let node = fourway_multiway(RelId(10), RelId(20), RelId(30), RelId(40));
990        let rewritten =
991            Executor::rewrite_scan_nth(&node, RelId(40), 0, RelId(99)).expect("rewrite must hit");
992        let RirNode::MultiWayJoin { var_order, .. } = rewritten else {
993            panic!("expected MultiWayJoin");
994        };
995        let var_order = var_order.expect("rewritten input 3 must become leader");
996        assert_eq!(var_order.leader_idx, 3);
997        assert_eq!(var_order.lookup_perms.len(), 3);
998        assert_eq!(var_order.lookup_perms[0].input_idx, 0);
999        assert_eq!(var_order.lookup_perms[1].input_idx, 1);
1000        assert_eq!(var_order.lookup_perms[2].input_idx, 2);
1001    }
1002}
1003
#[cfg(test)]
mod rewrite_scan_nth_occurrence_identity_tests {
    //! `rewrite_scan_nth` occurrence-identity preservation. Two mechanisms
    //! provide the guarantees below: the scan-case sentinel after
    //! replacement, and the `MultiWayJoin` arm's separate `remaining`
    //! counters for inputs vs fallback.
    //!
    //! 1. For a body with N same-predicate occurrences, calling
    //!    `rewrite_scan_nth(body, target, occ=k, replacement)` substitutes
    //!    EXACTLY ONE occurrence (the k-th) — not 0, not >1.
    //!
    //! 2. For a `MultiWayJoin` whose `inputs` and `fallback` both contain
    //!    the target, occ=k substitutes the k-th occurrence INDEPENDENTLY
    //!    in inputs AND in fallback. Both views share the same logical
    //!    body, so both must reflect the same logical rewrite.
    //!
    //! Fixed behavior bugs:
    //! - The Scan case early-returned on match without decrementing
    //!   `remaining`. Later matches in the same walk therefore also saw
    //!   remaining==0 and were replaced too.
    //! - The MultiWayJoin arm shared `&mut remaining` across the inputs walk
    //!   and the subsequent fallback walk. The fallback walk's counter was
    //!   contaminated by the inputs' consumption.
    //!
    //! Both bugs stayed latent on distinct-recursive-predicate fixtures. They
    //! only manifested on same-predicate self-recursive bodies with multiple
    //! target occurrences.

    use super::*;
    use xlog_ir::rir::ProjectExpr;
    use xlog_ir::JoinType;

    /// Build a `ChainJoin` of `left_rel ⋈ right_rel` keyed on left column 1 =
    /// right column 0, projecting columns `[0, 3]`.
    ///
    /// The dispatch children (`left`/`right`) and the fallback's binary
    /// `Join` are clones of the same two `Scan` leaves. A depth-first walk
    /// therefore sees `[left_rel, right_rel, left_rel, right_rel]`.
    fn chain_join(left_rel: RelId, right_rel: RelId) -> RirNode {
        let left = RirNode::Scan { rel: left_rel };
        let right = RirNode::Scan { rel: right_rel };
        let fallback = RirNode::Project {
            input: Box::new(RirNode::Join {
                left: Box::new(left.clone()),
                right: Box::new(right.clone()),
                left_keys: vec![1],
                right_keys: vec![0],
                join_type: JoinType::Inner,
            }),
            columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
        };
        RirNode::ChainJoin {
            left: Box::new(left),
            right: Box::new(right),
            left_key: 1,
            right_key: 0,
            output_columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
            fallback: Box::new(fallback),
        }
    }

    /// Build a synthetic `MultiWayJoin` whose `inputs` are 3 same-predicate
    /// Scans (`Scan { rel: target_rel }` × 3). Its fallback mirrors the same 3
    /// Scans inside a left-deep Join chain.
    ///
    /// This is the structural shape of a self-recursive triangle body like
    /// `tri(X,Y,Z) :- p(X,Y), p(Y,Z), p(X,Z)` after promotion: 3 input slots
    /// all target `p`, and the fallback holds 3 `p` Scans in the binary-join
    /// expansion. The left-deep Join gives the canonical depth-first walk
    /// order `[innermost-left, inner-right, outer-right]`. That order matches
    /// the inputs' left-to-right order, so the k-th input occurrence and the
    /// k-th fallback occurrence correspond to the same logical body slot.
    fn three_same_predicate_multiway(target_rel: RelId) -> RirNode {
        let inputs = vec![
            RirNode::Scan { rel: target_rel },
            RirNode::Scan { rel: target_rel },
            RirNode::Scan { rel: target_rel },
        ];
        let inner = RirNode::Join {
            left: Box::new(RirNode::Scan { rel: target_rel }),
            right: Box::new(RirNode::Scan { rel: target_rel }),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        let outer = RirNode::Join {
            left: Box::new(inner),
            right: Box::new(RirNode::Scan { rel: target_rel }),
            left_keys: vec![0, 3],
            right_keys: vec![0, 1],
            join_type: JoinType::Inner,
        };
        let fallback = RirNode::Project {
            input: Box::new(outer),
            columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
            ],
        };
        RirNode::MultiWayJoin {
            inputs,
            slot_vars: vec![
                vec![Some(0), Some(1)],
                vec![Some(1), Some(2)],
                vec![Some(0), Some(2)],
            ],
            output_columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(2),
            ],
            fallback: Box::new(fallback),
            plan: None,
            var_order: None,
        }
    }

    /// Walk `node` depth-first, left before right, and append every
    /// relation id it references to `out` in encounter order.
    ///
    /// Per variant:
    /// - `Scan` pushes its `rel`.
    /// - `TensorMaskedJoin` pushes each id in its `rel_index`, in index order.
    /// - `Join`/`Diff` walk left then right.
    /// - `Fixpoint` walks `base` then `recursive`.
    /// - `ChainJoin` walks `left`, `right`, then `fallback`.
    /// - `MultiWayJoin` walks `inputs` in order, then `fallback`.
    ///
    /// Because both views of the dual-view nodes are visited, the regression
    /// tests can assert EXACT post-rewrite positional identity.
    fn collect_scans_in_order(node: &RirNode, out: &mut Vec<RelId>) {
        match node {
            RirNode::Unit => {}
            RirNode::Scan { rel } => out.push(*rel),
            RirNode::Filter { input, .. }
            | RirNode::Project { input, .. }
            | RirNode::GroupBy { input, .. }
            | RirNode::Distinct { input, .. } => collect_scans_in_order(input, out),
            RirNode::Join { left, right, .. } | RirNode::Diff { left, right } => {
                collect_scans_in_order(left, out);
                collect_scans_in_order(right, out);
            }
            RirNode::ChainJoin {
                left,
                right,
                fallback,
                ..
            } => {
                collect_scans_in_order(left, out);
                collect_scans_in_order(right, out);
                collect_scans_in_order(fallback, out);
            }
            RirNode::Union { inputs } => {
                for n in inputs {
                    collect_scans_in_order(n, out);
                }
            }
            RirNode::Fixpoint {
                base, recursive, ..
            } => {
                collect_scans_in_order(base, out);
                collect_scans_in_order(recursive, out);
            }
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                for (rid, _) in rel_index {
                    out.push(*rid);
                }
            }
            RirNode::MultiWayJoin {
                inputs, fallback, ..
            } => {
                for inp in inputs {
                    collect_scans_in_order(inp, out);
                }
                collect_scans_in_order(fallback, out);
            }
        }
    }

    /// A `ChainJoin`'s dispatch children and its fallback are two views of one
    /// body. Rewriting occurrence 0 of the left relation must update the
    /// dispatch `left` AND the fallback Join's left leaf, and nothing else.
    #[test]
    fn chain_join_rewrite_scan_nth_updates_dispatch_shape_and_fallback() {
        let target = RelId(7);
        let delta = RelId(700);
        let body = chain_join(target, RelId(8));
        let rewritten =
            Executor::rewrite_scan_nth(&body, target, 0, delta).expect("target scan must rewrite");

        let mut scans = Vec::new();
        collect_scans_in_order(&rewritten, &mut scans);
        assert_eq!(
            scans,
            vec![delta, RelId(8), delta, RelId(8)],
            "ChainJoin dispatch inputs and fallback must target the same occurrence"
        );

        let RirNode::ChainJoin { left, fallback, .. } = rewritten else {
            panic!("expected rewritten ChainJoin");
        };
        assert!(matches!(left.as_ref(), RirNode::Scan { rel } if *rel == delta));
        let RirNode::Project { input, .. } = fallback.as_ref() else {
            panic!("expected ChainJoin fallback Project");
        };
        let RirNode::Join { left, .. } = input.as_ref() else {
            panic!("expected ChainJoin fallback Join");
        };
        assert!(matches!(left.as_ref(), RirNode::Scan { rel } if *rel == delta));
    }

    /// Regression pin with 3 occurrences of `target` in `inputs` and 3 in
    /// `fallback`. With both the sentinel fix and the separate-counter fix
    /// applied, occ=k substitutes only two positions:
    /// - the k-th occurrence in **inputs walk order**;
    /// - the k-th occurrence in **fallback walk order**.
    ///
    /// All other occurrences remain unchanged.
    ///
    /// This test asserts EXACT positional identity, not just the total
    /// replacement count. A broken implementation that always rewrites
    /// occurrence 0 for every occ would pass a count-only check but fails
    /// this positional assertion.
    #[test]
    fn rewrite_scan_nth_replaces_exact_kth_occurrence_in_inputs_and_fallback() {
        let target = RelId(7);
        let body = three_same_predicate_multiway(target);

        // Pre-rewrite: 6 Scans of target in canonical walk order.
        // [input[0], input[1], input[2], fallback's innermost-left,
        //  fallback's inner-right, fallback's outer-right].
        let mut pre = Vec::new();
        collect_scans_in_order(&body, &mut pre);
        assert_eq!(
            pre,
            vec![target, target, target, target, target, target],
            "pre-rewrite: 6 target Scans in canonical walk order"
        );

        // For each occ in {0, 1, 2}, the k-th occurrence in the INPUTS walk
        // AND the k-th occurrence in the FALLBACK walk are replaced — and
        // nothing else.
        for occ in 0..3 {
            // A distinct RelId per occ makes a buggy implementation that
            // always rewrites occurrence 0 produce a different post-rewrite
            // Scan order than expected.
            let replacement = RelId(100 + occ as u32);
            let rewritten = Executor::rewrite_scan_nth(&body, target, occ, replacement)
                .unwrap_or_else(|| panic!("occ={} must succeed", occ));

            let mut post = Vec::new();
            collect_scans_in_order(&rewritten, &mut post);

            // Expected sequence: positions 0..3 are the inputs walk,
            // positions 3..6 the fallback walk. Position `occ` in each half
            // becomes `replacement`; all others remain `target`.
            let mut expected = vec![target; 6];
            expected[occ] = replacement; // k-th input occurrence
            expected[3 + occ] = replacement; // k-th fallback occurrence

            // `assert_eq!` already prints both sides on failure.
            assert_eq!(
                post, expected,
                "occ={}: post-rewrite Scan order must replace EXACTLY the k-th occurrence in inputs AND fallback",
                occ
            );
        }
    }

    /// Regression pin for input/fallback symmetry at occ=0. The target appears
    /// in input[0] AND in the fallback's leftmost leaf, and occ=0 must
    /// substitute BOTH copies. This locks the "logical body shared between
    /// inputs and fallback" semantic.
    ///
    /// The EXACT post-rewrite shape is asserted:
    /// - input[0] becomes the replacement;
    /// - the fallback's leftmost Scan becomes the replacement;
    /// - everything else is identical to pre-rewrite.
    ///
    /// It complements
    /// `rewrite_scan_nth_replaces_exact_kth_occurrence_in_inputs_and_fallback`,
    /// which sweeps occ ∈ {0, 1, 2}; this is the focused occ=0 cert.
    #[test]
    fn rewrite_scan_nth_input_fallback_symmetry_at_occ_0() {
        let target = RelId(7);
        let replacement = RelId(99);
        let body = three_same_predicate_multiway(target);

        let rewritten =
            Executor::rewrite_scan_nth(&body, target, 0, replacement).expect("occ=0 must succeed");

        let RirNode::MultiWayJoin {
            inputs, fallback, ..
        } = rewritten
        else {
            panic!("expected MultiWayJoin after rewrite");
        };

        // input[0] must be the replacement (the 0th occurrence in inputs).
        assert!(
            matches!(inputs[0], RirNode::Scan { rel } if rel == replacement),
            "input[0] must be replacement; got {:?}",
            inputs[0]
        );
        // input[1] and input[2] must remain the original target.
        assert!(matches!(inputs[1], RirNode::Scan { rel } if rel == target));
        assert!(matches!(inputs[2], RirNode::Scan { rel } if rel == target));
        // Fallback walk order: [innermost-left, inner-right, outer-right].
        // occ=0 must replace ONLY the innermost-left (position 0 in the
        // fallback walk).
        let mut fallback_scans = Vec::new();
        collect_scans_in_order(&fallback, &mut fallback_scans);
        assert_eq!(
            fallback_scans,
            vec![replacement, target, target],
            "fallback walk order: occ=0 must replace position 0 only"
        );
    }
}