1use std::collections::{HashMap, HashSet};
4
5use xlog_core::{RelId, Result, XlogError};
6use xlog_ir::rir::{LookupPerm, ProjectExpr, VariableOrder};
7use xlog_ir::{ExecutionPlan, JoinType, RirNode};
8
9use super::RelationDelta;
10use super::{DeltaRecomputeStats, Executor};
11
12fn triangle_delta_var_order(leader_idx: u8) -> VariableOrder {
13 let lookup_perms = match leader_idx {
14 0 => vec![
15 LookupPerm {
16 input_idx: 1,
17 swap_cols: false,
18 },
19 LookupPerm {
20 input_idx: 2,
21 swap_cols: false,
22 },
23 ],
24 1 => vec![
25 LookupPerm {
26 input_idx: 2,
27 swap_cols: true,
28 },
29 LookupPerm {
30 input_idx: 0,
31 swap_cols: true,
32 },
33 ],
34 2 => vec![
35 LookupPerm {
36 input_idx: 1,
37 swap_cols: true,
38 },
39 LookupPerm {
40 input_idx: 0,
41 swap_cols: false,
42 },
43 ],
44 _ => unreachable!("triangle leader_idx out of range"),
45 };
46 let kernel_output_cols = match leader_idx {
47 0 => vec![
48 ProjectExpr::Column(0),
49 ProjectExpr::Column(1),
50 ProjectExpr::Column(2),
51 ],
52 1 => vec![
53 ProjectExpr::Column(2),
54 ProjectExpr::Column(0),
55 ProjectExpr::Column(1),
56 ],
57 2 => vec![
58 ProjectExpr::Column(0),
59 ProjectExpr::Column(2),
60 ProjectExpr::Column(1),
61 ],
62 _ => unreachable!("triangle leader_idx out of range"),
63 };
64 VariableOrder::legacy(leader_idx, lookup_perms, kernel_output_cols)
65}
66
67fn cycle4_delta_var_order(leader_idx: u8) -> VariableOrder {
68 let lookup_perms = (1..4)
69 .map(|offset| LookupPerm {
70 input_idx: ((leader_idx as usize + offset) % 4) as u8,
71 swap_cols: false,
72 })
73 .collect();
74 let kernel_output_cols = match leader_idx {
75 0 => vec![
76 ProjectExpr::Column(0),
77 ProjectExpr::Column(1),
78 ProjectExpr::Column(2),
79 ProjectExpr::Column(3),
80 ],
81 1 => vec![
82 ProjectExpr::Column(3),
83 ProjectExpr::Column(0),
84 ProjectExpr::Column(1),
85 ProjectExpr::Column(2),
86 ],
87 2 => vec![
88 ProjectExpr::Column(2),
89 ProjectExpr::Column(3),
90 ProjectExpr::Column(0),
91 ProjectExpr::Column(1),
92 ],
93 3 => vec![
94 ProjectExpr::Column(1),
95 ProjectExpr::Column(2),
96 ProjectExpr::Column(3),
97 ProjectExpr::Column(0),
98 ],
99 _ => unreachable!("4-cycle leader_idx out of range"),
100 };
101 VariableOrder::legacy(leader_idx, lookup_perms, kernel_output_cols)
102}
103
104fn delta_outermost_var_order(
105 input_count: usize,
106 replaced_input_idx: Option<usize>,
107 current: Option<&VariableOrder>,
108) -> Option<VariableOrder> {
109 let Some(idx) = replaced_input_idx else {
110 return current.cloned();
111 };
112 if current.and_then(|order| order.kclique.as_ref()).is_some() {
113 return current.cloned();
114 }
115 if idx == 0 {
116 return None;
117 }
118 match input_count {
119 3 if idx < 3 => Some(triangle_delta_var_order(idx as u8)),
120 4 if idx < 4 => Some(cycle4_delta_var_order(idx as u8)),
121 _ => current.cloned(),
122 }
123}
124
125impl Executor {
126 pub fn apply_deltas_and_recompute(
131 &mut self,
132 plan: &ExecutionPlan,
133 deltas: &HashMap<String, RelationDelta>,
134 ) -> Result<DeltaRecomputeStats> {
135 if deltas.is_empty() {
136 return Ok(DeltaRecomputeStats::default());
137 }
138
139 let has_deletes = deltas
140 .values()
141 .any(|d| d.delete.as_ref().map(|b| !b.is_empty()).unwrap_or(false));
142
143 for (name, delta) in deltas {
145 let existing = self.store.get(name);
146
147 let base_schema = existing
148 .map(|b| b.schema().clone())
149 .or_else(|| delta.insert.as_ref().map(|b| b.schema().clone()))
150 .or_else(|| delta.delete.as_ref().map(|b| b.schema().clone()))
151 .ok_or_else(|| {
152 XlogError::Execution(format!(
153 "Delta update for {} has no existing relation and no schema",
154 name
155 ))
156 })?;
157
158 let mut updated = if let Some(buf) = existing {
159 self.clone_buffer(buf)?
160 } else {
161 self.create_empty_buffer(base_schema)?
162 };
163
164 if let Some(delete_buf) = &delta.delete {
165 updated = self.provider.diff_gpu(&updated, delete_buf)?;
166 }
167 if let Some(insert_buf) = &delta.insert {
168 updated = self.provider.union_gpu(&updated, insert_buf)?;
169 }
170
171 self.store_put(name, updated);
172 }
173
174 let changed_preds: HashSet<&str> = deltas.keys().map(|s| s.as_str()).collect();
176
177 let mut pred_to_scc: HashMap<&str, u32> = HashMap::new();
178 for scc in &plan.sccs {
179 for pred in &scc.predicates {
180 pred_to_scc.insert(pred.as_str(), scc.id);
181 }
182 }
183
184 let mut dependents: HashMap<u32, Vec<u32>> = HashMap::new();
185 for (scc_id, rules) in plan.rules_by_scc.iter().enumerate() {
186 let scc_id = scc_id as u32;
187 for rule in rules {
188 let mut rels = Vec::new();
189 Self::collect_scan_rels(&rule.body, &mut rels);
190 for rel in rels {
191 let Some(name) = self.get_rel_name(rel) else {
192 continue;
193 };
194 let Some(&dep_scc) = pred_to_scc.get(name) else {
195 continue;
196 };
197 if dep_scc == scc_id {
198 continue;
199 }
200 dependents.entry(dep_scc).or_default().push(scc_id);
201 }
202 }
203 }
204
205 let mut affected: HashSet<u32> = HashSet::new();
206 let mut queue: Vec<u32> = Vec::new();
207 for pred in &changed_preds {
208 if let Some(&scc) = pred_to_scc.get(*pred) {
209 affected.insert(scc);
210 queue.push(scc);
211 }
212 }
213
214 while let Some(scc) = queue.pop() {
215 if let Some(deps) = dependents.get(&scc) {
216 for &next in deps {
217 if affected.insert(next) {
218 queue.push(next);
219 }
220 }
221 }
222 }
223
224 if affected.is_empty() {
225 return Ok(DeltaRecomputeStats {
226 changed_relations: deltas.len(),
227 has_deletes,
228 affected_sccs: 0,
229 recomputed_sccs: 0,
230 incremental_sccs: 0,
231 });
232 }
233
234 fn contains_non_monotonic_ops(node: &RirNode) -> bool {
235 match node {
236 RirNode::Unit | RirNode::Scan { .. } => false,
237 RirNode::Filter { input, .. }
238 | RirNode::Project { input, .. }
239 | RirNode::Distinct { input, .. } => contains_non_monotonic_ops(input),
240 RirNode::Union { inputs } => inputs.iter().any(contains_non_monotonic_ops),
241 RirNode::GroupBy { .. } | RirNode::Diff { .. } => true,
242 RirNode::Join {
243 left,
244 right,
245 join_type,
246 ..
247 } => {
248 matches!(join_type, JoinType::Anti | JoinType::LeftOuter)
249 || contains_non_monotonic_ops(left)
250 || contains_non_monotonic_ops(right)
251 }
252 RirNode::Fixpoint {
253 base, recursive, ..
254 } => contains_non_monotonic_ops(base) || contains_non_monotonic_ops(recursive),
255 RirNode::TensorMaskedJoin { .. } => false,
256 RirNode::MultiWayJoin { fallback, .. } | RirNode::ChainJoin { fallback, .. } => {
260 contains_non_monotonic_ops(fallback)
261 }
262 }
263 }
264
265 let mut recompute_sccs: HashSet<u32> = HashSet::new();
271 if has_deletes {
272 recompute_sccs = affected.clone();
273 } else {
274 for &scc_id in &affected {
275 if let Some(rules) = plan.rules_by_scc.get(scc_id as usize) {
276 if rules.iter().any(|r| contains_non_monotonic_ops(&r.body)) {
277 recompute_sccs.insert(scc_id);
278 }
279 }
280 }
281
282 let mut queue: Vec<u32> = recompute_sccs.iter().copied().collect();
285 while let Some(scc) = queue.pop() {
286 if let Some(deps) = dependents.get(&scc) {
287 for &next in deps {
288 if !affected.contains(&next) {
289 continue;
290 }
291 if recompute_sccs.insert(next) {
292 queue.push(next);
293 }
294 }
295 }
296 }
297 }
298
299 for scc_id in &recompute_sccs {
301 let Some(scc) = plan.sccs.iter().find(|s| s.id == *scc_id) else {
302 continue;
303 };
304
305 for pred in &scc.predicates {
306 if changed_preds.contains(pred.as_str()) {
307 continue;
308 }
309 let schema = self
310 .store
311 .get(pred)
312 .map(|b| b.schema().clone())
313 .or_else(|| {
314 plan.rules_by_scc
315 .get(*scc_id as usize)
316 .and_then(|rules| rules.iter().find(|r| r.head == pred.as_str()))
317 .and_then(|r| {
318 let schema = r.meta.schema.clone();
319 if schema.arity() > 0 {
320 Some(schema)
321 } else {
322 None
323 }
324 })
325 })
326 .ok_or_else(|| {
327 XlogError::Execution(format!(
328 "Missing schema for predicate {} during recompute",
329 pred
330 ))
331 })?;
332
333 let empty = self.create_empty_buffer(schema)?;
334 self.store_put(pred, empty);
335 }
336 }
337
338 for stratum in &plan.strata {
340 for &scc_id in &stratum.sccs {
341 if !affected.contains(&scc_id) {
342 continue;
343 }
344 let rules = plan.rules_by_scc.get(scc_id as usize).ok_or_else(|| {
345 XlogError::Execution(format!("Missing rules for SCC {}", scc_id))
346 })?;
347 let is_recursive = plan
348 .sccs
349 .iter()
350 .find(|s| s.id == scc_id)
351 .map(|s| s.is_recursive)
352 .unwrap_or(false);
353
354 if is_recursive {
355 self.execute_recursive_scc(rules)?;
356 } else {
357 self.execute_non_recursive_scc(rules)?;
358 }
359 }
360 }
361
362 Ok(DeltaRecomputeStats {
363 changed_relations: deltas.len(),
364 has_deletes,
365 affected_sccs: affected.len(),
366 recomputed_sccs: recompute_sccs.len(),
367 incremental_sccs: affected.len().saturating_sub(recompute_sccs.len()),
368 })
369 }
370
371 pub(crate) fn collect_scan_rels(node: &RirNode, out: &mut Vec<RelId>) {
372 match node {
373 RirNode::Unit => {}
374 RirNode::Scan { rel } => out.push(*rel),
375 RirNode::Filter { input, .. } | RirNode::Project { input, .. } => {
376 Self::collect_scan_rels(input, out);
377 }
378 RirNode::Join { left, right, .. }
379 | RirNode::ChainJoin { left, right, .. }
380 | RirNode::Diff { left, right } => {
381 Self::collect_scan_rels(left, out);
382 Self::collect_scan_rels(right, out);
383 }
384 RirNode::GroupBy { input, .. } | RirNode::Distinct { input, .. } => {
385 Self::collect_scan_rels(input, out);
386 }
387 RirNode::Union { inputs } => {
388 for input in inputs {
389 Self::collect_scan_rels(input, out);
390 }
391 }
392 RirNode::Fixpoint {
393 base, recursive, ..
394 } => {
395 Self::collect_scan_rels(base, out);
396 Self::collect_scan_rels(recursive, out);
397 }
398 RirNode::TensorMaskedJoin { rel_index, .. } => {
399 for (rel_id, _) in rel_index {
400 out.push(*rel_id);
401 }
402 }
403 RirNode::MultiWayJoin { inputs, .. } => {
406 for input in inputs {
407 Self::collect_scan_rels(input, out);
408 }
409 }
410 }
411 }
412
413 pub(crate) fn rewrite_scan_nth(
414 node: &RirNode,
415 target: RelId,
416 nth: usize,
417 replacement: RelId,
418 ) -> Option<RirNode> {
419 let mut remaining = nth;
420 let (rewritten, replaced) =
421 Self::rewrite_scan_nth_impl(node, target, &mut remaining, replacement);
422 replaced.then_some(rewritten)
423 }
424
425 fn rewrite_scan_nth_impl(
426 node: &RirNode,
427 target: RelId,
428 remaining: &mut usize,
429 replacement: RelId,
430 ) -> (RirNode, bool) {
431 match node {
432 RirNode::Unit => (RirNode::Unit, false),
433 RirNode::Scan { rel } => {
434 if *rel == target {
435 if *remaining == 0 {
436 *remaining = usize::MAX;
444 return (RirNode::Scan { rel: replacement }, true);
445 }
446 if *remaining != usize::MAX {
447 *remaining -= 1;
448 }
449 }
450 (node.clone(), false)
451 }
452
453 RirNode::Filter { input, predicate } => {
454 let (new_input, replaced) =
455 Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
456 (
457 RirNode::Filter {
458 input: Box::new(new_input),
459 predicate: predicate.clone(),
460 },
461 replaced,
462 )
463 }
464
465 RirNode::Project { input, columns } => {
466 let (new_input, replaced) =
467 Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
468 (
469 RirNode::Project {
470 input: Box::new(new_input),
471 columns: columns.clone(),
472 },
473 replaced,
474 )
475 }
476
477 RirNode::Join {
478 left,
479 right,
480 left_keys,
481 right_keys,
482 join_type,
483 } => {
484 let (new_left, replaced_left) =
485 Self::rewrite_scan_nth_impl(left, target, remaining, replacement);
486 if replaced_left {
487 return (
488 RirNode::Join {
489 left: Box::new(new_left),
490 right: right.clone(),
491 left_keys: left_keys.clone(),
492 right_keys: right_keys.clone(),
493 join_type: *join_type,
494 },
495 true,
496 );
497 }
498 let (new_right, replaced_right) =
499 Self::rewrite_scan_nth_impl(right, target, remaining, replacement);
500 (
501 RirNode::Join {
502 left: Box::new(new_left),
503 right: Box::new(new_right),
504 left_keys: left_keys.clone(),
505 right_keys: right_keys.clone(),
506 join_type: *join_type,
507 },
508 replaced_right,
509 )
510 }
511
512 RirNode::GroupBy {
513 input,
514 key_cols,
515 aggs,
516 } => {
517 let (new_input, replaced) =
518 Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
519 (
520 RirNode::GroupBy {
521 input: Box::new(new_input),
522 key_cols: key_cols.clone(),
523 aggs: aggs.clone(),
524 },
525 replaced,
526 )
527 }
528
529 RirNode::Union { inputs } => {
530 let mut replaced_any = false;
531 let mut new_inputs = Vec::with_capacity(inputs.len());
532 for input in inputs {
533 let (new_input, replaced) =
534 Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
535 replaced_any |= replaced;
536 new_inputs.push(new_input);
537 }
538 (RirNode::Union { inputs: new_inputs }, replaced_any)
539 }
540
541 RirNode::Distinct { input, key_cols } => {
542 let (new_input, replaced) =
543 Self::rewrite_scan_nth_impl(input, target, remaining, replacement);
544 (
545 RirNode::Distinct {
546 input: Box::new(new_input),
547 key_cols: key_cols.clone(),
548 },
549 replaced,
550 )
551 }
552
553 RirNode::Diff { left, right } => {
554 let (new_left, replaced_left) =
555 Self::rewrite_scan_nth_impl(left, target, remaining, replacement);
556 if replaced_left {
557 return (
558 RirNode::Diff {
559 left: Box::new(new_left),
560 right: right.clone(),
561 },
562 true,
563 );
564 }
565 let (new_right, replaced_right) =
566 Self::rewrite_scan_nth_impl(right, target, remaining, replacement);
567 (
568 RirNode::Diff {
569 left: Box::new(new_left),
570 right: Box::new(new_right),
571 },
572 replaced_right,
573 )
574 }
575
576 RirNode::Fixpoint {
577 scc_id,
578 base,
579 recursive,
580 delta_rel,
581 full_rel,
582 } => {
583 let (new_base, replaced_base) =
584 Self::rewrite_scan_nth_impl(base, target, remaining, replacement);
585 if replaced_base {
586 return (
587 RirNode::Fixpoint {
588 scc_id: *scc_id,
589 base: Box::new(new_base),
590 recursive: recursive.clone(),
591 delta_rel: *delta_rel,
592 full_rel: *full_rel,
593 },
594 true,
595 );
596 }
597 let (new_recursive, replaced_recursive) =
598 Self::rewrite_scan_nth_impl(recursive, target, remaining, replacement);
599 (
600 RirNode::Fixpoint {
601 scc_id: *scc_id,
602 base: Box::new(new_base),
603 recursive: Box::new(new_recursive),
604 delta_rel: *delta_rel,
605 full_rel: *full_rel,
606 },
607 replaced_recursive,
608 )
609 }
610 RirNode::TensorMaskedJoin { .. } => {
611 (node.clone(), false)
613 }
614 RirNode::ChainJoin {
615 left,
616 right,
617 left_key,
618 right_key,
619 output_columns,
620 fallback,
621 } => {
622 let starting_remaining = *remaining;
623 let mut inputs_remaining = starting_remaining;
624 let (new_left, replaced_left) =
625 Self::rewrite_scan_nth_impl(left, target, &mut inputs_remaining, replacement);
626 let (new_right, replaced_right) =
627 Self::rewrite_scan_nth_impl(right, target, &mut inputs_remaining, replacement);
628 let mut fallback_remaining = starting_remaining;
629 let (new_fallback, fallback_replaced) = Self::rewrite_scan_nth_impl(
630 fallback,
631 target,
632 &mut fallback_remaining,
633 replacement,
634 );
635 *remaining = inputs_remaining;
636 (
637 RirNode::ChainJoin {
638 left: Box::new(new_left),
639 right: Box::new(new_right),
640 left_key: *left_key,
641 right_key: *right_key,
642 output_columns: output_columns.clone(),
643 fallback: Box::new(new_fallback),
644 },
645 replaced_left || replaced_right || fallback_replaced,
646 )
647 }
648 RirNode::MultiWayJoin {
658 inputs,
659 slot_vars,
660 output_columns,
661 fallback,
662 plan,
663 var_order,
664 } => {
665 let starting_remaining = *remaining;
666 let mut inputs_remaining = starting_remaining;
667 let mut new_inputs = Vec::with_capacity(inputs.len());
668 let mut any_replaced = false;
669 let mut replaced_input_idx = None;
670 for (idx, inp) in inputs.iter().enumerate() {
671 let (new_inp, replaced) = Self::rewrite_scan_nth_impl(
672 inp,
673 target,
674 &mut inputs_remaining,
675 replacement,
676 );
677 any_replaced |= replaced;
678 if replaced {
679 replaced_input_idx = Some(idx);
680 }
681 new_inputs.push(new_inp);
682 }
683 let mut fallback_remaining = starting_remaining;
684 let (new_fallback, fallback_replaced) = Self::rewrite_scan_nth_impl(
685 fallback,
686 target,
687 &mut fallback_remaining,
688 replacement,
689 );
690 *remaining = inputs_remaining;
691 let input_count = new_inputs.len();
692 (
693 RirNode::MultiWayJoin {
694 inputs: new_inputs,
695 slot_vars: slot_vars.clone(),
696 output_columns: output_columns.clone(),
697 fallback: Box::new(new_fallback),
698 plan: plan.clone(),
699 var_order: delta_outermost_var_order(
700 input_count,
701 replaced_input_idx,
702 var_order.as_ref(),
703 ),
704 },
705 any_replaced || fallback_replaced,
706 )
707 }
708 }
709 }
710}
711
#[cfg(test)]
mod multiway_walker_tests {
    use super::*;
    use xlog_ir::rir::ProjectExpr;

    /// A bare `Scan` of `rel`.
    fn scan(rel: RelId) -> RirNode {
        RirNode::Scan { rel }
    }

    /// True when `target` is scanned anywhere along the `Project` / `Join`
    /// spine that the fixtures below use for their fallback plans.
    fn spine_scans(node: &RirNode, target: RelId) -> bool {
        match node {
            RirNode::Scan { rel } => *rel == target,
            RirNode::Project { input, .. } => spine_scans(input, target),
            RirNode::Join { left, right, .. } => {
                spine_scans(left, target) || spine_scans(right, target)
            }
            _ => false,
        }
    }

    /// Triangle `a(x,y) ⋈ b(y,z) ⋈ c(x,z)` as a 3-input `MultiWayJoin` whose
    /// fallback is the equivalent pair of binary joins projected to 3 columns.
    fn triangle_multiway(a: RelId, b: RelId, c: RelId) -> RirNode {
        let ab = RirNode::Join {
            left: Box::new(scan(a)),
            right: Box::new(scan(b)),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        let abc = RirNode::Join {
            left: Box::new(ab),
            right: Box::new(scan(c)),
            left_keys: vec![0, 3],
            right_keys: vec![0, 1],
            join_type: JoinType::Inner,
        };
        let projection = vec![
            ProjectExpr::Column(0),
            ProjectExpr::Column(1),
            ProjectExpr::Column(3),
        ];
        RirNode::MultiWayJoin {
            inputs: vec![scan(a), scan(b), scan(c)],
            slot_vars: vec![
                vec![Some(0), Some(1)],
                vec![Some(1), Some(2)],
                vec![Some(0), Some(2)],
            ],
            output_columns: projection.clone(),
            fallback: Box::new(RirNode::Project {
                input: Box::new(abc),
                columns: projection,
            }),
            plan: None,
            var_order: None,
        }
    }

    /// 4-cycle `a(x0,x1) ⋈ b(x1,x2) ⋈ c(x2,x3) ⋈ d(x0,x3)` as a 4-input
    /// `MultiWayJoin` whose fallback is a left-deep chain of binary joins.
    fn fourway_multiway(a: RelId, b: RelId, c: RelId, d: RelId) -> RirNode {
        let ab = RirNode::Join {
            left: Box::new(scan(a)),
            right: Box::new(scan(b)),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        let abc = RirNode::Join {
            left: Box::new(ab),
            right: Box::new(scan(c)),
            left_keys: vec![3],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        let abcd = RirNode::Join {
            left: Box::new(abc),
            right: Box::new(scan(d)),
            left_keys: vec![0, 5],
            right_keys: vec![0, 1],
            join_type: JoinType::Inner,
        };
        RirNode::MultiWayJoin {
            inputs: vec![scan(a), scan(b), scan(c), scan(d)],
            slot_vars: vec![
                vec![Some(0), Some(1)],
                vec![Some(1), Some(2)],
                vec![Some(2), Some(3)],
                vec![Some(0), Some(3)],
            ],
            output_columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(2),
                ProjectExpr::Column(3),
            ],
            fallback: Box::new(RirNode::Project {
                input: Box::new(abcd),
                columns: vec![
                    ProjectExpr::Column(0),
                    ProjectExpr::Column(1),
                    ProjectExpr::Column(3),
                    ProjectExpr::Column(5),
                ],
            }),
            plan: None,
            var_order: None,
        }
    }

    #[test]
    fn collect_scan_rels_walks_multiway_inputs_only() {
        let node = triangle_multiway(RelId(10), RelId(20), RelId(30));
        let mut seen = Vec::new();
        Executor::collect_scan_rels(&node, &mut seen);
        assert_eq!(seen.len(), 3, "expected 3 scan rels, got: {:?}", seen);
        for id in [10, 20, 30] {
            assert!(seen.contains(&RelId(id)));
        }
    }

    #[test]
    fn rewrite_scan_nth_rewrites_inputs_and_fallback() {
        let node = triangle_multiway(RelId(10), RelId(20), RelId(30));

        let rewritten =
            Executor::rewrite_scan_nth(&node, RelId(10), 0, RelId(99)).expect("occ=0 must succeed");
        let RirNode::MultiWayJoin {
            inputs, fallback, ..
        } = rewritten
        else {
            panic!("expected MultiWayJoin after rewrite");
        };
        assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(99) }));
        assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(20) }));
        assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(30) }));
        assert!(
            spine_scans(&fallback, RelId(99)),
            "fallback must contain RelId(99) at the 0-th occurrence position"
        );
        assert!(
            !spine_scans(&fallback, RelId(10)),
            "fallback must NOT contain RelId(10) — the only occurrence was substituted"
        );

        assert!(
            Executor::rewrite_scan_nth(&node, RelId(10), 1, RelId(99)).is_none(),
            "occ=1 must return None — RelId(10) has only 1 occurrence per view"
        );
    }

    #[test]
    fn collect_scan_rels_handles_4_inputs() {
        let node = fourway_multiway(RelId(10), RelId(20), RelId(30), RelId(40));
        let mut seen = Vec::new();
        Executor::collect_scan_rels(&node, &mut seen);
        assert_eq!(
            seen.len(),
            4,
            "expected 4 scan rels, got {} entries: {:?}",
            seen.len(),
            seen
        );
        for id in [10, 20, 30, 40] {
            assert!(seen.contains(&RelId(id)), "RelId({}) missing", id);
        }
    }

    #[test]
    fn rewrite_scan_nth_handles_4_inputs_and_fallback() {
        let node = fourway_multiway(RelId(10), RelId(20), RelId(30), RelId(40));

        let rewritten =
            Executor::rewrite_scan_nth(&node, RelId(40), 0, RelId(99)).expect("occ=0 must succeed");
        let RirNode::MultiWayJoin {
            inputs, fallback, ..
        } = rewritten
        else {
            panic!("expected MultiWayJoin");
        };
        assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(10) }));
        assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(20) }));
        assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(30) }));
        assert!(matches!(inputs[3], RirNode::Scan { rel: RelId(99) }));
        assert!(
            spine_scans(&fallback, RelId(99)),
            "fallback must contain RelId(99) at the 0-th occurrence position"
        );
        assert!(
            !spine_scans(&fallback, RelId(40)),
            "fallback must NOT contain RelId(40) — the only occurrence was substituted"
        );

        assert!(
            Executor::rewrite_scan_nth(&node, RelId(40), 1, RelId(99)).is_none(),
            "occ=1 must return None — RelId(40) has only 1 occurrence per view"
        );
    }

    #[test]
    fn delta_outermost_leader_selection_rebinds_triangle_variant() {
        let node = triangle_multiway(RelId(10), RelId(20), RelId(30));
        let rewritten =
            Executor::rewrite_scan_nth(&node, RelId(20), 0, RelId(99)).expect("rewrite must hit");
        let RirNode::MultiWayJoin { var_order, .. } = rewritten else {
            panic!("expected MultiWayJoin");
        };
        let order = var_order.expect("rewritten input 1 must become leader");
        assert_eq!(order.leader_idx, 1);
        assert_eq!(order.lookup_perms.len(), 2);
        assert_eq!(order.lookup_perms[0].input_idx, 2);
        assert_eq!(order.lookup_perms[1].input_idx, 0);
    }

    #[test]
    fn delta_outermost_leader_selection_rebinds_4cycle_variant() {
        let node = fourway_multiway(RelId(10), RelId(20), RelId(30), RelId(40));
        let rewritten =
            Executor::rewrite_scan_nth(&node, RelId(40), 0, RelId(99)).expect("rewrite must hit");
        let RirNode::MultiWayJoin { var_order, .. } = rewritten else {
            panic!("expected MultiWayJoin");
        };
        let order = var_order.expect("rewritten input 3 must become leader");
        assert_eq!(order.leader_idx, 3);
        assert_eq!(order.lookup_perms.len(), 3);
        assert_eq!(order.lookup_perms[0].input_idx, 0);
        assert_eq!(order.lookup_perms[1].input_idx, 1);
        assert_eq!(order.lookup_perms[2].input_idx, 2);
    }
}
1003
#[cfg(test)]
mod rewrite_scan_nth_occurrence_identity_tests {
    use super::*;
    use xlog_ir::rir::ProjectExpr;
    use xlog_ir::JoinType;

    /// `ChainJoin` over `left_rel ⋈ right_rel` whose fallback is the
    /// equivalent binary `Join` followed by a `Project`.
    fn chain_join(left_rel: RelId, right_rel: RelId) -> RirNode {
        let fallback_join = RirNode::Join {
            left: Box::new(RirNode::Scan { rel: left_rel }),
            right: Box::new(RirNode::Scan { rel: right_rel }),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        RirNode::ChainJoin {
            left: Box::new(RirNode::Scan { rel: left_rel }),
            right: Box::new(RirNode::Scan { rel: right_rel }),
            left_key: 1,
            right_key: 0,
            output_columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
            fallback: Box::new(RirNode::Project {
                input: Box::new(fallback_join),
                columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
            }),
        }
    }

    /// Triangle-shaped `MultiWayJoin` where all three inputs (and all three
    /// fallback scans) read the same relation `target_rel`.
    fn three_same_predicate_multiway(target_rel: RelId) -> RirNode {
        let scan = || RirNode::Scan { rel: target_rel };
        let fallback_join = RirNode::Join {
            left: Box::new(RirNode::Join {
                left: Box::new(scan()),
                right: Box::new(scan()),
                left_keys: vec![1],
                right_keys: vec![0],
                join_type: JoinType::Inner,
            }),
            right: Box::new(scan()),
            left_keys: vec![0, 3],
            right_keys: vec![0, 1],
            join_type: JoinType::Inner,
        };
        RirNode::MultiWayJoin {
            inputs: vec![scan(), scan(), scan()],
            slot_vars: vec![
                vec![Some(0), Some(1)],
                vec![Some(1), Some(2)],
                vec![Some(0), Some(2)],
            ],
            output_columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(2),
            ],
            fallback: Box::new(RirNode::Project {
                input: Box::new(fallback_join),
                columns: vec![
                    ProjectExpr::Column(0),
                    ProjectExpr::Column(1),
                    ProjectExpr::Column(3),
                ],
            }),
            plan: None,
            var_order: None,
        }
    }

    /// Records every scanned relation id in pre-order, descending into both
    /// the dispatch inputs and the fallback of `ChainJoin` / `MultiWayJoin`.
    fn collect_scans_in_order(node: &RirNode, out: &mut Vec<RelId>) {
        match node {
            RirNode::Unit => {}
            RirNode::Scan { rel } => out.push(*rel),
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                for (rid, _) in rel_index {
                    out.push(*rid);
                }
            }
            RirNode::Filter { input, .. }
            | RirNode::Project { input, .. }
            | RirNode::GroupBy { input, .. }
            | RirNode::Distinct { input, .. } => collect_scans_in_order(input, out),
            RirNode::Join { left, right, .. } | RirNode::Diff { left, right } => {
                for side in [left, right] {
                    collect_scans_in_order(side, out);
                }
            }
            RirNode::Fixpoint {
                base, recursive, ..
            } => {
                for part in [base, recursive] {
                    collect_scans_in_order(part, out);
                }
            }
            RirNode::ChainJoin {
                left,
                right,
                fallback,
                ..
            } => {
                for part in [left, right, fallback] {
                    collect_scans_in_order(part, out);
                }
            }
            RirNode::Union { inputs } => {
                for child in inputs {
                    collect_scans_in_order(child, out);
                }
            }
            RirNode::MultiWayJoin {
                inputs, fallback, ..
            } => {
                for child in inputs.iter().chain(std::iter::once(&**fallback)) {
                    collect_scans_in_order(child, out);
                }
            }
        }
    }

    #[test]
    fn chain_join_rewrite_scan_nth_updates_dispatch_shape_and_fallback() {
        let original_rel = RelId(7);
        let delta_rel = RelId(700);
        let body = chain_join(original_rel, RelId(8));
        let rewritten = Executor::rewrite_scan_nth(&body, original_rel, 0, delta_rel)
            .expect("target scan must rewrite");

        let mut walk = Vec::new();
        collect_scans_in_order(&rewritten, &mut walk);
        assert_eq!(
            walk,
            vec![delta_rel, RelId(8), delta_rel, RelId(8)],
            "ChainJoin dispatch inputs and fallback must target the same occurrence"
        );

        let RirNode::ChainJoin { left, fallback, .. } = rewritten else {
            panic!("expected rewritten ChainJoin");
        };
        assert!(matches!(left.as_ref(), RirNode::Scan { rel } if *rel == delta_rel));
        let RirNode::Project { input, .. } = fallback.as_ref() else {
            panic!("expected ChainJoin fallback Project");
        };
        let RirNode::Join {
            left: fallback_left,
            ..
        } = input.as_ref()
        else {
            panic!("expected ChainJoin fallback Join");
        };
        assert!(matches!(fallback_left.as_ref(), RirNode::Scan { rel } if *rel == delta_rel));
    }

    #[test]
    fn rewrite_scan_nth_replaces_exact_kth_occurrence_in_inputs_and_fallback() {
        let target = RelId(7);
        let body = three_same_predicate_multiway(target);

        let mut before = Vec::new();
        collect_scans_in_order(&body, &mut before);
        assert_eq!(
            before,
            vec![target; 6],
            "pre-rewrite: 6 target Scans in canonical walk order"
        );

        for occ in 0..3 {
            let replacement = RelId(100 + occ as u32);
            let rewritten = Executor::rewrite_scan_nth(&body, target, occ, replacement)
                .unwrap_or_else(|| panic!("occ={} must succeed", occ));

            let mut after = Vec::new();
            collect_scans_in_order(&rewritten, &mut after);

            // Positions 0..3 are the dispatch inputs, 3..6 the fallback scans.
            let mut expected = vec![target; 6];
            expected[occ] = replacement;
            expected[3 + occ] = replacement;
            assert_eq!(
                after, expected,
                "occ={}: post-rewrite Scan order must replace EXACTLY the k-th occurrence in inputs AND fallback; got {:?}, expected {:?}",
                occ, after, expected
            );
        }
    }

    #[test]
    fn rewrite_scan_nth_input_fallback_symmetry_at_occ_0() {
        let target = RelId(7);
        let replacement = RelId(99);
        let body = three_same_predicate_multiway(target);

        let rewritten =
            Executor::rewrite_scan_nth(&body, target, 0, replacement).expect("occ=0 must succeed");

        let RirNode::MultiWayJoin {
            inputs, fallback, ..
        } = rewritten
        else {
            panic!("expected MultiWayJoin after rewrite");
        };
        assert!(
            matches!(inputs[0], RirNode::Scan { rel } if rel == replacement),
            "input[0] must be replacement; got {:?}",
            inputs[0]
        );
        assert!(matches!(inputs[1], RirNode::Scan { rel } if rel == target));
        assert!(matches!(inputs[2], RirNode::Scan { rel } if rel == target));

        let mut fallback_walk = Vec::new();
        collect_scans_in_order(&fallback, &mut fallback_walk);
        assert_eq!(
            fallback_walk,
            vec![replacement, target, target],
            "fallback walk order: occ=0 must replace position 0 only"
        );
    }
}