1use std::collections::HashMap;
80use xlog_core::RelId;
81use xlog_ir::rir::{
82 ColumnSwap, CostPredictionRecord as RirCostPredictionRecord, KCliqueVariableOrder,
83 MultiwayPlan, PlannedHashReason, ProjectExpr, SortedLayoutSpec, StreamGroupId, VariableOrder,
84 K_CLIQUE_MAX_EDGES, K_CLIQUE_MAX_K,
85};
86use xlog_ir::{ExecutionPlan, JoinType, RirNode};
87use xlog_stats::{StatsManager, StatsSnapshot};
88
89use crate::compiler_config::CompilerConfig;
90use crate::hypergraph::var_order::{
91 plan_kclique_var_order, FullVariableOrder, KCliqueEdge, KCliqueShape,
92};
93use crate::hypergraph::VertexId;
94use crate::wcoj_var_ordering::{wcoj_cost_gate_predicts_wcoj, WcojVariableOrderingModel};
95
96pub fn promote_multiway(
116 plan: &mut ExecutionPlan,
117 _rel_ids: &HashMap<String, RelId>,
118 stats: &StatsManager,
119 config: &CompilerConfig,
120) {
121 let rel_arities = plan.rel_arities.clone();
126 for (scc_id, rules) in plan.rules_by_scc.iter_mut().enumerate() {
127 if plan.sccs.get(scc_id).is_none() {
128 continue;
129 }
130 for rule in rules.iter_mut() {
131 if try_promote_triangle_inside_aggregate(&mut rule.body, stats, config) {
158 continue;
159 }
160 if try_promote_4cycle_inside_aggregate(&mut rule.body, stats, config) {
165 continue;
166 }
167 if try_promote_clique_inside_aggregate(&mut rule.body, stats) {
173 continue;
174 }
175 if try_promote_general_multiway_inside_aggregate(&mut rule.body, &rel_arities) {
183 continue;
184 }
185 if let Some(promoted) = try_promote_chain(&rule.body) {
186 rule.body = promoted;
187 continue;
188 }
189 let normalized_tri = normalize_triangle_to_left_deep(&rule.body);
199 let body_for_tri = normalized_tri.as_ref().unwrap_or(&rule.body);
200 if let Some(promoted) = try_promote_triangle(body_for_tri, stats, config) {
201 rule.body = promoted;
202 continue;
203 }
204 let normalized_4c = normalize_4cycle_to_bushy(&rule.body);
205 let body_for_4c = normalized_4c.as_ref().unwrap_or(&rule.body);
206 if let Some(promoted) = try_promote_4cycle(body_for_4c, stats, config) {
207 rule.body = promoted;
208 continue;
209 }
210 if let Some(promoted) = try_promote_clique_k(&rule.body, 5, stats)
218 .or_else(|| try_promote_clique_k(&rule.body, 6, stats))
219 .or_else(|| try_promote_clique_k(&rule.body, 7, stats))
220 .or_else(|| try_promote_clique_k(&rule.body, 8, stats))
221 {
222 rule.body = promoted;
223 continue;
224 }
225 if let Some(promoted) = try_promote_general_multiway(&rule.body, &rel_arities) {
234 rule.body = promoted;
235 continue;
236 }
237 }
238 }
239}
240
/// Flattens an `(atom, column)` pair into a single slot number, `2 * atom_idx + col_idx`.
///
/// Debug builds assert `atom_idx < 3` and `col_idx < 2`, i.e. three atoms with two
/// columns each.
fn ac_idx(atom_idx: u8, col_idx: u8) -> u8 {
    debug_assert!(atom_idx < 3);
    debug_assert!(col_idx < 2);
    let atom_base = 2 * atom_idx;
    atom_base + col_idx
}
249
/// Maps output column `k` of the inner (two-atom) join to its `(atom, column)` pair.
///
/// Columns 0 and 1 belong to atom 0, columns 2 and 3 to atom 1; any other `k`
/// yields `None`.
fn inner_output_ac(k: usize) -> Option<(u8, u8)> {
    if k < 4 {
        // Two columns per atom: quotient selects the atom, remainder the column.
        Some(((k / 2) as u8, (k % 2) as u8))
    } else {
        None
    }
}
262
263fn outer_output_ac(k: usize) -> Option<(u8, u8)> {
267 match k {
268 0..=3 => inner_output_ac(k),
269 4 => Some((2, 0)),
270 5 => Some((2, 1)),
271 _ => None,
272 }
273}
274
/// Returns the representative of `x` in the six-slot union-find forest `parent`.
///
/// After locating the root, the nodes on the path from `x` are re-pointed directly
/// at that root.
fn uf_find(parent: &mut [u8; 6], x: u8) -> u8 {
    // Pass 1: climb until a self-parented node is reached.
    let mut rep = x;
    loop {
        let up = parent[usize::from(rep)];
        if up == rep {
            break;
        }
        rep = up;
    }

    // Pass 2: re-point every node on the path whose parent is not yet the root.
    let mut node = x;
    loop {
        let up = parent[usize::from(node)];
        if up == rep {
            break;
        }
        parent[usize::from(node)] = rep;
        node = up;
    }
    rep
}
290
291fn uf_union(parent: &mut [u8; 6], a: u8, b: u8) {
292 let ra = uf_find(parent, a);
293 let rb = uf_find(parent, b);
294 if ra != rb {
295 parent[rb as usize] = ra;
296 }
297}
298
299#[allow(clippy::too_many_arguments)]
317fn infer_triangle_semantics(
318 inner_left_rel: RelId,
319 inner_right_rel: RelId,
320 outer_third_rel: RelId,
321 lk2: &[usize],
322 rk2: &[usize],
323 lk1: &[usize],
324 rk1: &[usize],
325 project_cols: &[ProjectExpr],
326) -> Option<(RelId, RelId, RelId)> {
327 if lk2.len() != 1 || rk2.len() != 1 {
328 return None;
329 }
330 if lk1.len() != 2 || rk1.len() != 2 {
331 return None;
332 }
333 if project_cols.len() != 3 {
334 return None;
335 }
336 if lk2[0] >= 2 || rk2[0] >= 2 {
337 return None;
338 }
339 if lk1.iter().any(|k| *k >= 4) || rk1.iter().any(|k| *k >= 2) {
340 return None;
341 }
342
343 let mut parent = [0u8, 1, 2, 3, 4, 5];
346
347 uf_union(
349 &mut parent,
350 ac_idx(0, lk2[0] as u8),
351 ac_idx(1, rk2[0] as u8),
352 );
353
354 for i in 0..2 {
356 let (inner_atom, inner_col) = inner_output_ac(lk1[i])?;
357 uf_union(
358 &mut parent,
359 ac_idx(inner_atom, inner_col),
360 ac_idx(2, rk1[i] as u8),
361 );
362 }
363
364 let roots: [u8; 6] = std::array::from_fn(|i| uf_find(&mut parent, i as u8));
367 let mut counts: HashMap<u8, u8> = HashMap::new();
368 for r in &roots {
369 *counts.entry(*r).or_insert(0) += 1;
370 }
371 if counts.len() != 3 || counts.values().any(|c| *c != 2) {
372 return None;
373 }
374
375 let mut head_classes: [u8; 3] = [0; 3];
378 for (i, pc) in project_cols.iter().enumerate() {
379 let ProjectExpr::Column(k) = pc else {
380 return None;
381 };
382 let (atom, col) = outer_output_ac(*k)?;
383 head_classes[i] = uf_find(&mut parent, ac_idx(atom, col));
384 }
385 if head_classes[0] == head_classes[1]
387 || head_classes[0] == head_classes[2]
388 || head_classes[1] == head_classes[2]
389 {
390 return None;
391 }
392 let x_class = head_classes[0];
393 let y_class = head_classes[1];
394 let z_class = head_classes[2];
395
396 let atom_classes = |atom_idx: u8| -> (u8, u8) {
399 (
400 roots[ac_idx(atom_idx, 0) as usize],
401 roots[ac_idx(atom_idx, 1) as usize],
402 )
403 };
404
405 let atom_rels = [inner_left_rel, inner_right_rel, outer_third_rel];
406 let mut rel_xy: Option<RelId> = None;
407 let mut rel_yz: Option<RelId> = None;
408 let mut rel_xz: Option<RelId> = None;
409
410 for atom_idx in 0..3u8 {
411 let (c0, c1) = atom_classes(atom_idx);
412 let binds_x = c0 == x_class || c1 == x_class;
413 let binds_y = c0 == y_class || c1 == y_class;
414 let binds_z = c0 == z_class || c1 == z_class;
415 match (binds_x, binds_y, binds_z) {
416 (true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
417 (false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
418 (true, false, true) => rel_xz = Some(atom_rels[atom_idx as usize]),
419 _ => return None,
420 }
421 }
422
423 Some((rel_xy?, rel_yz?, rel_xz?))
424}
425
426fn normalize_triangle_to_left_deep(node: &RirNode) -> Option<RirNode> {
446 let RirNode::Project {
447 input: outer_input,
448 columns,
449 } = node
450 else {
451 return None;
452 };
453 let RirNode::Join {
454 left: outer_l,
455 right: outer_r,
456 left_keys: outer_lk,
457 right_keys: outer_rk,
458 join_type: outer_jt,
459 } = outer_input.as_ref()
460 else {
461 return None;
462 };
463 if !matches!(outer_jt, JoinType::Inner) {
464 return None;
465 }
466 let RirNode::Scan { rel: _ } = outer_l.as_ref() else {
468 return None;
469 };
470 let RirNode::Join { .. } = outer_r.as_ref() else {
471 return None;
472 };
473 let RirNode::Join {
476 left: inner_l,
477 right: inner_r,
478 ..
479 } = outer_r.as_ref()
480 else {
481 return None;
482 };
483 if !matches!(inner_l.as_ref(), RirNode::Scan { .. })
484 || !matches!(inner_r.as_ref(), RirNode::Scan { .. })
485 {
486 return None;
487 }
488 let new_outer = RirNode::Join {
493 left: outer_r.clone(),
494 right: outer_l.clone(),
495 left_keys: outer_rk.clone(),
496 right_keys: outer_lk.clone(),
497 join_type: JoinType::Inner,
498 };
499 let new_columns: Vec<ProjectExpr> = columns
501 .iter()
502 .map(|expr| match expr {
503 ProjectExpr::Column(k) => ProjectExpr::Column((*k + 4) % 6),
504 other => other.clone(),
505 })
506 .collect();
507 Some(RirNode::Project {
508 input: Box::new(new_outer),
509 columns: new_columns,
510 })
511}
512
513fn normalize_4cycle_to_bushy(node: &RirNode) -> Option<RirNode> {
530 let RirNode::Project {
531 input: outer_input,
532 columns,
533 } = node
534 else {
535 return None;
536 };
537 let RirNode::Join {
538 left: outer_l,
539 right: outer_r,
540 left_keys: outer_lk,
541 right_keys: outer_rk,
542 join_type: outer_jt,
543 } = outer_input.as_ref()
544 else {
545 return None;
546 };
547 if !matches!(outer_jt, JoinType::Inner) {
548 return None;
549 }
550 let RirNode::Scan { rel: r0 } = outer_l.as_ref() else {
553 return None;
554 };
555 let RirNode::Join {
556 left: middle_l,
557 right: middle_r,
558 left_keys: middle_lk,
559 right_keys: middle_rk,
560 join_type: middle_jt,
561 } = outer_r.as_ref()
562 else {
563 return None;
564 };
565 if !matches!(middle_jt, JoinType::Inner) {
566 return None;
567 }
568 let RirNode::Scan { rel: r1 } = middle_l.as_ref() else {
569 return None;
570 };
571 let RirNode::Join {
572 left: deep_l,
573 right: deep_r,
574 left_keys: deep_lk,
575 right_keys: deep_rk,
576 join_type: deep_jt,
577 } = middle_r.as_ref()
578 else {
579 return None;
580 };
581 if !matches!(deep_jt, JoinType::Inner) {
582 return None;
583 }
584 let RirNode::Scan { rel: r2 } = deep_l.as_ref() else {
585 return None;
586 };
587 let RirNode::Scan { rel: r3 } = deep_r.as_ref() else {
588 return None;
589 };
590 if outer_lk.as_slice() != [0, 1] || outer_rk.as_slice() != [5, 0] {
592 return None;
593 }
594 if middle_lk.as_slice() != [1] || middle_rk.as_slice() != [0] {
595 return None;
596 }
597 if deep_lk.as_slice() != [1] || deep_rk.as_slice() != [0] {
598 return None;
599 }
600 let inner_left = RirNode::Join {
608 left: Box::new(RirNode::Scan { rel: *r0 }),
609 right: Box::new(RirNode::Scan { rel: *r1 }),
610 left_keys: vec![1],
611 right_keys: vec![0],
612 join_type: JoinType::Inner,
613 };
614 let inner_right = RirNode::Join {
615 left: Box::new(RirNode::Scan { rel: *r2 }),
616 right: Box::new(RirNode::Scan { rel: *r3 }),
617 left_keys: vec![1],
618 right_keys: vec![0],
619 join_type: JoinType::Inner,
620 };
621 let new_outer = RirNode::Join {
622 left: Box::new(inner_left),
623 right: Box::new(inner_right),
624 left_keys: vec![3, 0],
625 right_keys: vec![0, 3],
626 join_type: JoinType::Inner,
627 };
628 Some(RirNode::Project {
629 input: Box::new(new_outer),
630 columns: columns.clone(),
631 })
632}
633
634fn try_promote_triangle_inside_aggregate(
646 body: &mut RirNode,
647 stats: &StatsManager,
648 config: &CompilerConfig,
649) -> bool {
650 let RirNode::Project { input: gb, .. } = body else {
651 return false;
652 };
653 let RirNode::GroupBy {
654 input: group_input, ..
655 } = gb.as_mut()
656 else {
657 return false;
658 };
659 let RirNode::Project {
660 input: inner,
661 columns: group_cols,
662 } = group_input.as_mut()
663 else {
664 return false;
665 };
666 let canonical = RirNode::Project {
669 input: inner.clone(),
670 columns: vec![
671 ProjectExpr::Column(0),
672 ProjectExpr::Column(1),
673 ProjectExpr::Column(3),
674 ],
675 };
676 let normalized = normalize_triangle_to_left_deep(&canonical);
677 let candidate = normalized.as_ref().unwrap_or(&canonical);
678 let Some(promoted) = try_promote_triangle(candidate, stats, config) else {
679 return false;
680 };
681 let RirNode::MultiWayJoin { output_columns, .. } = &promoted else {
682 return false;
683 };
684 let mut remapped: Vec<ProjectExpr> = Vec::with_capacity(group_cols.len());
685 for col in group_cols.iter() {
686 let ProjectExpr::Column(c) = col else {
687 return false;
688 };
689 let Some(pos) = output_columns
690 .iter()
691 .position(|oc| matches!(oc, ProjectExpr::Column(x) if x == c))
692 else {
693 return false;
694 };
695 remapped.push(ProjectExpr::Column(pos));
696 }
697 *group_cols = remapped;
698 **inner = promoted;
699 true
700}
701
702fn try_promote_4cycle_inside_aggregate(
715 body: &mut RirNode,
716 stats: &StatsManager,
717 config: &CompilerConfig,
718) -> bool {
719 let RirNode::Project { input: gb, .. } = body else {
720 return false;
721 };
722 let RirNode::GroupBy {
723 input: group_input, ..
724 } = gb.as_mut()
725 else {
726 return false;
727 };
728 let RirNode::Project {
729 input: inner,
730 columns: group_cols,
731 } = group_input.as_mut()
732 else {
733 return false;
734 };
735 let canonical = RirNode::Project {
738 input: inner.clone(),
739 columns: vec![
740 ProjectExpr::Column(0),
741 ProjectExpr::Column(1),
742 ProjectExpr::Column(3),
743 ProjectExpr::Column(5),
744 ],
745 };
746 let normalized = normalize_4cycle_to_bushy(&canonical);
747 let candidate = normalized.as_ref().unwrap_or(&canonical);
748 let Some(promoted) = try_promote_4cycle(candidate, stats, config) else {
749 return false;
750 };
751 let RirNode::MultiWayJoin { output_columns, .. } = &promoted else {
752 return false;
753 };
754 let mut remapped: Vec<ProjectExpr> = Vec::with_capacity(group_cols.len());
755 for col in group_cols.iter() {
756 let ProjectExpr::Column(c) = col else {
757 return false;
758 };
759 let Some(pos) = output_columns
760 .iter()
761 .position(|oc| matches!(oc, ProjectExpr::Column(x) if x == c))
762 else {
763 return false;
764 };
765 remapped.push(ProjectExpr::Column(pos));
766 }
767 *group_cols = remapped;
768 **inner = promoted;
769 true
770}
771
772fn try_promote_clique_inside_aggregate(body: &mut RirNode, stats: &StatsManager) -> bool {
791 let RirNode::Project { input: gb, .. } = body else {
792 return false;
793 };
794 let RirNode::GroupBy {
795 input: group_input, ..
796 } = gb.as_mut()
797 else {
798 return false;
799 };
800 let RirNode::Project {
801 input: inner,
802 columns: group_cols,
803 } = group_input.as_mut()
804 else {
805 return false;
806 };
807 let mut scans: Vec<RelId> = Vec::new();
810 let mut key_pairs: Vec<(usize, usize)> = Vec::new();
811 if walk_clique_node(inner, &mut scans, &mut key_pairs).is_none() {
812 return false;
813 }
814 let k = match scans.len() {
815 10 => 5,
816 15 => 6,
817 _ => return false,
818 };
819 let n_slots = 2 * scans.len();
822 let mut parent: Vec<usize> = (0..n_slots).collect();
823 for (a, b) in &key_pairs {
824 if *a >= n_slots || *b >= n_slots {
825 return false;
826 }
827 uf_union_clique(&mut parent, *a, *b);
828 }
829 let mut class_roots: Vec<usize> = Vec::new();
830 let mut first_slot_of_class: HashMap<usize, usize> = HashMap::new();
831 for slot in 0..n_slots {
832 let root = uf_find_clique(&mut parent, slot);
833 if !first_slot_of_class.contains_key(&root) {
834 first_slot_of_class.insert(root, slot);
835 class_roots.push(root);
836 }
837 }
838 if class_roots.len() != k {
839 return false;
840 }
841 let class_idx: HashMap<usize, usize> = class_roots
852 .iter()
853 .enumerate()
854 .map(|(idx, root)| (*root, idx))
855 .collect();
856 let mut in_degree = vec![0usize; k];
857 for atom in 0..scans.len() {
858 let cls_a = uf_find_clique(&mut parent, 2 * atom);
859 let cls_b = uf_find_clique(&mut parent, 2 * atom + 1);
860 if cls_a == cls_b {
861 return false;
862 }
863 in_degree[class_idx[&cls_b]] += 1;
864 }
865 let mut class_pos_by_head = vec![usize::MAX; k];
866 for (pos, deg) in in_degree.iter().enumerate() {
867 if *deg >= k || class_pos_by_head[*deg] != usize::MAX {
868 return false;
869 }
870 class_pos_by_head[*deg] = pos;
871 }
872 let representative_slots: Vec<usize> = class_pos_by_head
873 .iter()
874 .map(|pos| first_slot_of_class[&class_roots[*pos]])
875 .collect();
876 let class_to_head: HashMap<usize, usize> = class_pos_by_head
877 .iter()
878 .enumerate()
879 .map(|(head_idx, pos)| (class_roots[*pos], head_idx))
880 .collect();
881 let canonical = RirNode::Project {
882 input: inner.clone(),
883 columns: representative_slots
884 .iter()
885 .map(|slot| ProjectExpr::Column(*slot))
886 .collect(),
887 };
888 let Some(promoted) = try_promote_clique_k(&canonical, k, stats) else {
889 return false;
890 };
891 let mut remapped: Vec<ProjectExpr> = Vec::with_capacity(group_cols.len());
894 for col in group_cols.iter() {
895 let ProjectExpr::Column(c) = col else {
896 return false;
897 };
898 if *c >= n_slots {
899 return false;
900 }
901 let root = uf_find_clique(&mut parent, *c);
902 let Some(head_idx) = class_to_head.get(&root) else {
903 return false;
904 };
905 remapped.push(ProjectExpr::Column(*head_idx));
906 }
907 *group_cols = remapped;
908 **inner = promoted;
909 true
910}
911
912fn try_promote_triangle(
915 node: &RirNode,
916 stats: &StatsManager,
917 config: &CompilerConfig,
918) -> Option<RirNode> {
919 let RirNode::Project {
920 input: outer_input,
921 columns,
922 } = node
923 else {
924 return None;
925 };
926 let RirNode::Join {
927 left: l1,
928 right: r1,
929 left_keys: lk1,
930 right_keys: rk1,
931 join_type: jt1,
932 } = outer_input.as_ref()
933 else {
934 return None;
935 };
936 if !matches!(jt1, JoinType::Inner) {
937 return None;
938 }
939 let RirNode::Scan { rel: rel_third } = r1.as_ref() else {
940 return None;
941 };
942 let RirNode::Join {
943 left: l2,
944 right: r2,
945 left_keys: lk2,
946 right_keys: rk2,
947 join_type: jt2,
948 } = l1.as_ref()
949 else {
950 return None;
951 };
952 if !matches!(jt2, JoinType::Inner) {
953 return None;
954 }
955 let RirNode::Scan { rel: rel_inner_l } = l2.as_ref() else {
956 return None;
957 };
958 let RirNode::Scan { rel: rel_inner_r } = r2.as_ref() else {
959 return None;
960 };
961
962 let (rel_xy, rel_yz, rel_xz) = infer_triangle_semantics(
964 *rel_inner_l,
965 *rel_inner_r,
966 *rel_third,
967 lk2,
968 rk2,
969 lk1,
970 rk1,
971 columns,
972 )?;
973
974 let inputs = vec![
975 RirNode::Scan { rel: rel_xy },
976 RirNode::Scan { rel: rel_yz },
977 RirNode::Scan { rel: rel_xz },
978 ];
979 let slot_vars = vec![
983 vec![Some(0u32), Some(1)],
984 vec![Some(1u32), Some(2)],
985 vec![Some(0u32), Some(2)],
986 ];
987 let output_columns = columns.clone();
988 let fallback = Box::new(node.clone());
989 use crate::compiler_config::WcojVarOrderingKind;
994 use crate::wcoj_var_ordering::{
995 build_triangle_var_order, HeatAwareLeaderModel, LeaderCardinalityModel,
996 };
997 let leader_idx = match config.wcoj_variable_ordering {
998 WcojVarOrderingKind::Disabled => None,
999 WcojVarOrderingKind::LeaderCardinality => {
1000 LeaderCardinalityModel.pick_triangle_leader([rel_xy, rel_yz, rel_xz], stats, config)
1001 }
1002 WcojVarOrderingKind::HeatAware => {
1003 HeatAwareLeaderModel.pick_triangle_leader([rel_xy, rel_yz, rel_xz], stats, config)
1004 }
1005 };
1006 let var_order = leader_idx.map(build_triangle_var_order);
1007 Some(RirNode::MultiWayJoin {
1008 inputs,
1009 slot_vars,
1010 output_columns,
1011 fallback,
1012 plan: None,
1013 var_order,
1014 })
1015}
1016
1017fn try_promote_chain(node: &RirNode) -> Option<RirNode> {
1021 let RirNode::Project { input, columns } = node else {
1022 return None;
1023 };
1024 let RirNode::Join {
1025 left,
1026 right,
1027 left_keys,
1028 right_keys,
1029 join_type,
1030 } = input.as_ref()
1031 else {
1032 return None;
1033 };
1034 if !matches!(join_type, JoinType::Inner) {
1035 return None;
1036 }
1037 if left_keys.len() != 1 || right_keys.len() != 1 {
1038 return None;
1039 }
1040 let left_key = left_keys[0];
1041 let right_key = right_keys[0];
1042 if left_key >= 2 || right_key >= 2 {
1043 return None;
1044 }
1045 let RirNode::Scan { rel: rel_left } = left.as_ref() else {
1046 return None;
1047 };
1048 let RirNode::Scan { rel: rel_right } = right.as_ref() else {
1049 return None;
1050 };
1051
1052 Some(RirNode::ChainJoin {
1053 left: Box::new(RirNode::Scan { rel: *rel_left }),
1054 right: Box::new(RirNode::Scan { rel: *rel_right }),
1055 left_key,
1056 right_key,
1057 output_columns: columns.clone(),
1058 fallback: Box::new(node.clone()),
1059 })
1060}
1061
/// Flattens an `(atom, column)` pair into a single slot number, `2 * atom_idx + col_idx`.
///
/// Debug builds assert `atom_idx < 4` and `col_idx < 2`, i.e. four atoms with two
/// columns each.
fn ac_idx_4(atom_idx: u8, col_idx: u8) -> u8 {
    debug_assert!(atom_idx < 4);
    debug_assert!(col_idx < 2);
    let atom_base = 2 * atom_idx;
    atom_base + col_idx
}
1071
/// Maps output column `k` of the left pair join to its `(atom, column)` pair.
///
/// Columns 0 and 1 belong to atom 0, columns 2 and 3 to atom 1; any other `k`
/// yields `None`.
fn outer_left_inner_output_ac(k: usize) -> Option<(u8, u8)> {
    if k < 4 {
        // Two columns per atom: quotient selects the atom, remainder the column.
        Some(((k / 2) as u8, (k % 2) as u8))
    } else {
        None
    }
}
1081
/// Maps output column `k` of the right pair join to its `(atom, column)` pair.
///
/// Columns 0 and 1 belong to atom 2, columns 2 and 3 to atom 3; any other `k`
/// yields `None`.
fn outer_right_inner_output_ac(k: usize) -> Option<(u8, u8)> {
    if k < 4 {
        // Same layout as the left pair, shifted by the two atoms that precede it.
        Some((2 + (k / 2) as u8, (k % 2) as u8))
    } else {
        None
    }
}
1091
1092fn outer_4cycle_output_ac(k: usize) -> Option<(u8, u8)> {
1093 match k {
1094 0..=3 => outer_left_inner_output_ac(k),
1095 4..=7 => outer_right_inner_output_ac(k - 4),
1096 _ => None,
1097 }
1098}
1099
/// Returns the representative of `x` in the eight-slot union-find forest `parent`.
///
/// After locating the root, the nodes on the path from `x` are re-pointed directly
/// at that root.
fn uf_find_8(parent: &mut [u8; 8], x: u8) -> u8 {
    // Pass 1: climb until a self-parented node is reached.
    let mut rep = x;
    loop {
        let up = parent[usize::from(rep)];
        if up == rep {
            break;
        }
        rep = up;
    }

    // Pass 2: re-point every node on the path whose parent is not yet the root.
    let mut node = x;
    loop {
        let up = parent[usize::from(node)];
        if up == rep {
            break;
        }
        parent[usize::from(node)] = rep;
        node = up;
    }
    rep
}
1113
1114fn uf_union_8(parent: &mut [u8; 8], a: u8, b: u8) {
1115 let ra = uf_find_8(parent, a);
1116 let rb = uf_find_8(parent, b);
1117 if ra != rb {
1118 parent[rb as usize] = ra;
1119 }
1120}
1121
1122#[allow(clippy::too_many_arguments)]
1129fn infer_4cycle_semantics(
1130 rel_ll: RelId,
1131 rel_lr: RelId,
1132 rel_rl: RelId,
1133 rel_rr: RelId,
1134 ilk_l: &[usize],
1135 irk_l: &[usize],
1136 ilk_r: &[usize],
1137 irk_r: &[usize],
1138 olk: &[usize],
1139 ork: &[usize],
1140 project_cols: &[ProjectExpr],
1141) -> Option<(RelId, RelId, RelId, RelId)> {
1142 if ilk_l.len() != 1 || irk_l.len() != 1 {
1143 return None;
1144 }
1145 if ilk_r.len() != 1 || irk_r.len() != 1 {
1146 return None;
1147 }
1148 if olk.len() != 2 || ork.len() != 2 {
1149 return None;
1150 }
1151 if project_cols.len() != 4 {
1152 return None;
1153 }
1154 if ilk_l[0] >= 2 || irk_l[0] >= 2 || ilk_r[0] >= 2 || irk_r[0] >= 2 {
1155 return None;
1156 }
1157 if olk.iter().any(|k| *k >= 4) || ork.iter().any(|k| *k >= 4) {
1158 return None;
1159 }
1160
1161 let mut parent = [0u8, 1, 2, 3, 4, 5, 6, 7];
1162
1163 uf_union_8(
1165 &mut parent,
1166 ac_idx_4(0, ilk_l[0] as u8),
1167 ac_idx_4(1, irk_l[0] as u8),
1168 );
1169
1170 uf_union_8(
1172 &mut parent,
1173 ac_idx_4(2, ilk_r[0] as u8),
1174 ac_idx_4(3, irk_r[0] as u8),
1175 );
1176
1177 for i in 0..2 {
1181 let (la, lc) = outer_left_inner_output_ac(olk[i])?;
1182 let (ra, rc) = outer_right_inner_output_ac(ork[i])?;
1183 uf_union_8(&mut parent, ac_idx_4(la, lc), ac_idx_4(ra, rc));
1184 }
1185
1186 let roots: [u8; 8] = std::array::from_fn(|i| uf_find_8(&mut parent, i as u8));
1187 let mut counts: HashMap<u8, u8> = HashMap::new();
1188 for r in &roots {
1189 *counts.entry(*r).or_insert(0) += 1;
1190 }
1191 if counts.len() != 4 || counts.values().any(|c| *c != 2) {
1192 return None;
1193 }
1194
1195 let mut head_classes: [u8; 4] = [0; 4];
1197 for (i, pc) in project_cols.iter().enumerate() {
1198 let ProjectExpr::Column(k) = pc else {
1199 return None;
1200 };
1201 let (atom, col) = outer_4cycle_output_ac(*k)?;
1202 head_classes[i] = uf_find_8(&mut parent, ac_idx_4(atom, col));
1203 }
1204 for i in 0..4 {
1206 for j in (i + 1)..4 {
1207 if head_classes[i] == head_classes[j] {
1208 return None;
1209 }
1210 }
1211 }
1212 let w_class = head_classes[0];
1213 let x_class = head_classes[1];
1214 let y_class = head_classes[2];
1215 let z_class = head_classes[3];
1216
1217 let atom_classes = |atom_idx: u8| -> (u8, u8) {
1218 (
1219 roots[ac_idx_4(atom_idx, 0) as usize],
1220 roots[ac_idx_4(atom_idx, 1) as usize],
1221 )
1222 };
1223
1224 let atom_rels = [rel_ll, rel_lr, rel_rl, rel_rr];
1225 let mut rel_wx: Option<RelId> = None;
1226 let mut rel_xy: Option<RelId> = None;
1227 let mut rel_yz: Option<RelId> = None;
1228 let mut rel_zw: Option<RelId> = None;
1229
1230 for atom_idx in 0..4u8 {
1231 let (c0, c1) = atom_classes(atom_idx);
1232 let binds_w = c0 == w_class || c1 == w_class;
1233 let binds_x = c0 == x_class || c1 == x_class;
1234 let binds_y = c0 == y_class || c1 == y_class;
1235 let binds_z = c0 == z_class || c1 == z_class;
1236 match (binds_w, binds_x, binds_y, binds_z) {
1237 (true, true, false, false) => rel_wx = Some(atom_rels[atom_idx as usize]),
1238 (false, true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
1239 (false, false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
1240 (true, false, false, true) => rel_zw = Some(atom_rels[atom_idx as usize]),
1241 _ => return None,
1242 }
1243 }
1244
1245 Some((rel_wx?, rel_xy?, rel_yz?, rel_zw?))
1246}
1247
1248fn try_promote_4cycle(
1268 node: &RirNode,
1269 stats: &StatsManager,
1270 config: &CompilerConfig,
1271) -> Option<RirNode> {
1272 let RirNode::Project {
1273 input: outer_input,
1274 columns,
1275 } = node
1276 else {
1277 return None;
1278 };
1279 let RirNode::Join {
1280 left: outer_l,
1281 right: outer_r,
1282 left_keys: olk,
1283 right_keys: ork,
1284 join_type: ojt,
1285 } = outer_input.as_ref()
1286 else {
1287 return None;
1288 };
1289 if !matches!(ojt, JoinType::Inner) {
1290 return None;
1291 }
1292 let RirNode::Join {
1293 left: ll,
1294 right: lr,
1295 left_keys: ilk_l,
1296 right_keys: irk_l,
1297 join_type: ijt_l,
1298 } = outer_l.as_ref()
1299 else {
1300 return None;
1301 };
1302 if !matches!(ijt_l, JoinType::Inner) {
1303 return None;
1304 }
1305 let RirNode::Scan { rel: rel_ll } = ll.as_ref() else {
1306 return None;
1307 };
1308 let RirNode::Scan { rel: rel_lr } = lr.as_ref() else {
1309 return None;
1310 };
1311 let RirNode::Join {
1312 left: rl,
1313 right: rr,
1314 left_keys: ilk_r,
1315 right_keys: irk_r,
1316 join_type: ijt_r,
1317 } = outer_r.as_ref()
1318 else {
1319 return None;
1320 };
1321 if !matches!(ijt_r, JoinType::Inner) {
1322 return None;
1323 }
1324 let RirNode::Scan { rel: rel_rl } = rl.as_ref() else {
1325 return None;
1326 };
1327 let RirNode::Scan { rel: rel_rr } = rr.as_ref() else {
1328 return None;
1329 };
1330
1331 let (rel_wx, rel_xy, rel_yz, rel_zw) = infer_4cycle_semantics(
1334 *rel_ll, *rel_lr, *rel_rl, *rel_rr, ilk_l, irk_l, ilk_r, irk_r, olk, ork, columns,
1335 )?;
1336
1337 let inputs = vec![
1338 RirNode::Scan { rel: rel_wx },
1339 RirNode::Scan { rel: rel_xy },
1340 RirNode::Scan { rel: rel_yz },
1341 RirNode::Scan { rel: rel_zw },
1342 ];
1343 let slot_vars = vec![
1348 vec![Some(0u32), Some(1)],
1349 vec![Some(1u32), Some(2)],
1350 vec![Some(2u32), Some(3)],
1351 vec![Some(3u32), Some(0)],
1352 ];
1353 let output_columns = columns.clone();
1354 let fallback = Box::new(node.clone());
1355 use crate::compiler_config::WcojVarOrderingKind;
1359 use crate::wcoj_var_ordering::{
1360 build_cycle4_var_order, HeatAwareLeaderModel, LeaderCardinalityModel,
1361 };
1362 let leader_idx_4 = match config.wcoj_variable_ordering {
1363 WcojVarOrderingKind::Disabled => None,
1364 WcojVarOrderingKind::LeaderCardinality => LeaderCardinalityModel.pick_4cycle_leader(
1365 [rel_wx, rel_xy, rel_yz, rel_zw],
1366 stats,
1367 config,
1368 ),
1369 WcojVarOrderingKind::HeatAware => {
1370 HeatAwareLeaderModel.pick_4cycle_leader([rel_wx, rel_xy, rel_yz, rel_zw], stats, config)
1371 }
1372 };
1373 let var_order = leader_idx_4.map(build_cycle4_var_order);
1374 Some(RirNode::MultiWayJoin {
1375 inputs,
1376 slot_vars,
1377 output_columns,
1378 fallback,
1379 plan: None,
1380 var_order,
1381 })
1382}
1383
/// Index of edge `(i, j)` among the edges of a `k`-vertex clique, with edges
/// enumerated row by row: `(0,1), (0,2), …, (0,k-1), (1,2), …`.
///
/// Debug builds assert `i < j < k`.
fn clique_edge_idx(i: usize, j: usize, k: usize) -> usize {
    debug_assert!(i < j && j < k);
    // Number of edges whose smaller endpoint is below `i`.
    let row_start = i * (2 * k - i - 1) / 2;
    // Position of `j` within row `i`.
    let within_row = j - i - 1;
    row_start + within_row
}
1406
/// Returns the representative of `x` in the union-find forest `parent`.
///
/// While climbing, each visited node is re-pointed at its grandparent and the walk
/// continues from that grandparent.
fn uf_find_clique(parent: &mut [usize], mut x: usize) -> usize {
    loop {
        let up = parent[x];
        if up == x {
            return x;
        }
        let grand = parent[up];
        parent[x] = grand;
        x = grand;
    }
}
1415
1416fn uf_union_clique(parent: &mut [usize], a: usize, b: usize) {
1417 let ra = uf_find_clique(parent, a);
1418 let rb = uf_find_clique(parent, b);
1419 if ra != rb {
1420 parent[rb] = ra;
1421 }
1422}
1423
1424#[allow(clippy::type_complexity)]
1440fn flatten_clique_body(body: &RirNode) -> Option<(Vec<RelId>, Vec<(usize, usize)>, Vec<usize>)> {
1441 let RirNode::Project { input, columns } = body else {
1442 return None;
1443 };
1444 let mut scans: Vec<RelId> = Vec::new();
1445 let mut key_pairs: Vec<(usize, usize)> = Vec::new();
1446 let _width = walk_clique_node(input, &mut scans, &mut key_pairs)?;
1447 let mut project_globals: Vec<usize> = Vec::with_capacity(columns.len());
1448 for c in columns {
1449 let xlog_ir::rir::ProjectExpr::Column(k) = c else {
1450 return None;
1451 };
1452 project_globals.push(*k);
1453 }
1454 Some((scans, key_pairs, project_globals))
1455}
1456
1457fn walk_clique_node(
1460 node: &RirNode,
1461 scans: &mut Vec<RelId>,
1462 key_pairs: &mut Vec<(usize, usize)>,
1463) -> Option<usize> {
1464 match node {
1465 RirNode::Scan { rel } => {
1466 scans.push(*rel);
1467 Some(2)
1468 }
1469 RirNode::Join {
1470 left,
1471 right,
1472 left_keys,
1473 right_keys,
1474 join_type,
1475 } => {
1476 if !matches!(join_type, JoinType::Inner) {
1477 return None;
1478 }
1479 let left_offset = scans.len() * 2;
1480 let left_width = walk_clique_node(left, scans, key_pairs)?;
1481 let right_offset = left_offset + left_width;
1482 let right_width = walk_clique_node(right, scans, key_pairs)?;
1483 if left_keys.len() != right_keys.len() {
1484 return None;
1485 }
1486 for (lk, rk) in left_keys.iter().zip(right_keys.iter()) {
1487 if *lk >= left_width || *rk >= right_width {
1488 return None;
1489 }
1490 key_pairs.push((left_offset + *lk, right_offset + *rk));
1491 }
1492 Some(left_width + right_width)
1493 }
1494 _ => None,
1498 }
1499}
1500
1501fn walk_general_node(
1510 node: &RirNode,
1511 arities: &HashMap<RelId, usize>,
1512 scans: &mut Vec<RelId>,
1513 widths: &mut Vec<usize>,
1514 key_pairs: &mut Vec<(usize, usize)>,
1515) -> Option<usize> {
1516 match node {
1517 RirNode::Scan { rel } => {
1518 let width = *arities.get(rel)?;
1519 if width == 0 {
1520 return None;
1521 }
1522 scans.push(*rel);
1523 widths.push(width);
1524 Some(width)
1525 }
1526 RirNode::Join {
1527 left,
1528 right,
1529 left_keys,
1530 right_keys,
1531 join_type,
1532 } => {
1533 if !matches!(join_type, JoinType::Inner) {
1534 return None;
1535 }
1536 let left_offset: usize = widths.iter().sum();
1539 let left_width = walk_general_node(left, arities, scans, widths, key_pairs)?;
1540 let right_offset = left_offset + left_width;
1541 let right_width = walk_general_node(right, arities, scans, widths, key_pairs)?;
1542 if left_keys.len() != right_keys.len() || left_keys.is_empty() {
1543 return None;
1544 }
1545 for (lk, rk) in left_keys.iter().zip(right_keys.iter()) {
1546 if *lk >= left_width || *rk >= right_width {
1547 return None;
1548 }
1549 key_pairs.push((left_offset + *lk, right_offset + *rk));
1550 }
1551 Some(left_width + right_width)
1552 }
1553 _ => None,
1554 }
1555}
1556
1557fn try_promote_general_multiway(
1571 body: &RirNode,
1572 arities: &HashMap<RelId, usize>,
1573) -> Option<RirNode> {
1574 let RirNode::Project { input, columns } = body else {
1575 return None;
1576 };
1577 let mut scans: Vec<RelId> = Vec::new();
1578 let mut widths: Vec<usize> = Vec::new();
1579 let mut key_pairs: Vec<(usize, usize)> = Vec::new();
1580 let total = walk_general_node(input, arities, &mut scans, &mut widths, &mut key_pairs)?;
1581 if scans.len() < 3 {
1582 return None;
1583 }
1584
1585 let mut parent: Vec<usize> = (0..total).collect();
1590 for (a, b) in &key_pairs {
1591 if *a >= total || *b >= total {
1592 return None;
1593 }
1594 uf_union_clique(&mut parent, *a, *b);
1595 }
1596 let mut class_of_root: HashMap<usize, u32> = HashMap::new();
1597 let mut slot_class: Vec<u32> = Vec::with_capacity(total);
1598 for slot in 0..total {
1599 let root = uf_find_clique(&mut parent, slot);
1600 let next = class_of_root.len() as u32;
1601 let cls = *class_of_root.entry(root).or_insert(next);
1602 slot_class.push(cls);
1603 }
1604
1605 for c in columns {
1608 let ProjectExpr::Column(k) = c else {
1609 return None;
1610 };
1611 if *k >= total {
1612 return None;
1613 }
1614 }
1615
1616 let inputs: Vec<RirNode> = scans
1617 .iter()
1618 .map(|rel| RirNode::Scan { rel: *rel })
1619 .collect();
1620 let mut slot_vars: Vec<Vec<Option<u32>>> = Vec::with_capacity(scans.len());
1621 let mut offset = 0usize;
1622 for &w in &widths {
1623 slot_vars.push((offset..offset + w).map(|s| Some(slot_class[s])).collect());
1624 offset += w;
1625 }
1626
1627 Some(RirNode::MultiWayJoin {
1628 inputs,
1629 slot_vars,
1630 output_columns: columns.clone(),
1631 fallback: Box::new(body.clone()),
1632 plan: Some(MultiwayPlan::FreeJoin),
1638 var_order: None,
1639 })
1640}
1641
1642fn try_promote_general_multiway_inside_aggregate(
1657 body: &mut RirNode,
1658 arities: &HashMap<RelId, usize>,
1659) -> bool {
1660 let RirNode::Project { input: gb, .. } = body else {
1661 return false;
1662 };
1663 let RirNode::GroupBy {
1664 input: group_input, ..
1665 } = gb.as_mut()
1666 else {
1667 return false;
1668 };
1669 let RirNode::Project {
1670 input: inner,
1671 columns: group_cols,
1672 } = group_input.as_mut()
1673 else {
1674 return false;
1675 };
1676 let candidate = RirNode::Project {
1677 input: inner.clone(),
1678 columns: group_cols.clone(),
1679 };
1680 let Some(promoted) = try_promote_general_multiway(&candidate, arities) else {
1681 return false;
1682 };
1683 *group_cols = (0..group_cols.len()).map(ProjectExpr::Column).collect();
1686 **inner = promoted;
1687 true
1688}
1689
1690fn try_promote_clique_k(body: &RirNode, k: usize, stats: &StatsManager) -> Option<RirNode> {
1698 if !(5..=8).contains(&k) {
1699 return None;
1700 }
1701 let expected_edges = k * (k - 1) / 2;
1702
1703 let (scans, key_pairs, project_globals) = flatten_clique_body(body)?;
1705
1706 if scans.len() != expected_edges {
1708 return None;
1709 }
1710
1711 if project_globals.len() != k {
1713 return None;
1714 }
1715
1716 let n_slots = 2 * expected_edges;
1719 let mut parent: Vec<usize> = (0..n_slots).collect();
1720 for (a, b) in &key_pairs {
1721 if *a >= n_slots || *b >= n_slots {
1722 return None;
1723 }
1724 uf_union_clique(&mut parent, *a, *b);
1725 }
1726
1727 let mut head_class: Vec<usize> = Vec::with_capacity(k);
1731 for col in &project_globals {
1732 if *col >= n_slots {
1733 return None;
1734 }
1735 head_class.push(uf_find_clique(&mut parent, *col));
1736 }
1737 let mut sorted_head_classes = head_class.clone();
1739 sorted_head_classes.sort();
1740 sorted_head_classes.dedup();
1741 if sorted_head_classes.len() != k {
1742 return None;
1743 }
1744
1745 let mut all_class_count: HashMap<usize, usize> = HashMap::new();
1749 for slot in 0..n_slots {
1750 let root = uf_find_clique(&mut parent, slot);
1751 *all_class_count.entry(root).or_insert(0) += 1;
1752 }
1753 if all_class_count.len() != k {
1754 return None;
1755 }
1756 for &count in all_class_count.values() {
1759 if count != k - 1 {
1760 return None;
1761 }
1762 }
1763
1764 let mut class_to_head_idx: HashMap<usize, usize> = HashMap::new();
1767 for (head_idx, cls) in head_class.iter().enumerate() {
1768 class_to_head_idx.insert(*cls, head_idx);
1769 }
1770
1771 let mut atom_pairs: Vec<(usize, usize)> = Vec::with_capacity(expected_edges);
1777 let mut canonical_to_scan_idx: HashMap<(usize, usize), usize> = HashMap::new();
1778 for (atom_i, _rel) in scans.iter().enumerate() {
1779 let slot_a = 2 * atom_i;
1780 let slot_b = 2 * atom_i + 1;
1781 let cls_a = uf_find_clique(&mut parent, slot_a);
1782 let cls_b = uf_find_clique(&mut parent, slot_b);
1783 if cls_a == cls_b {
1784 return None;
1786 }
1787 let head_a = class_to_head_idx.get(&cls_a)?;
1788 let head_b = class_to_head_idx.get(&cls_b)?;
1789 if *head_a > *head_b {
1794 return None;
1795 }
1796 let (lo, hi) = (*head_a, *head_b);
1797 atom_pairs.push((lo, hi));
1798 if canonical_to_scan_idx.insert((lo, hi), atom_i).is_some() {
1800 return None;
1801 }
1802 }
1803
1804 if canonical_to_scan_idx.len() != expected_edges {
1807 return None;
1808 }
1809 for i in 0..k {
1810 for j in (i + 1)..k {
1811 if !canonical_to_scan_idx.contains_key(&(i, j)) {
1812 return None;
1813 }
1814 }
1815 }
1816
1817 let mut reordered_scans: Vec<RelId> = Vec::with_capacity(expected_edges);
1819 for i in 0..k {
1820 for j in (i + 1)..k {
1821 let scan_idx = canonical_to_scan_idx[&(i, j)];
1822 reordered_scans.push(scans[scan_idx]);
1823 }
1824 }
1825
1826 let inputs: Vec<RirNode> = reordered_scans
1831 .iter()
1832 .map(|rel| RirNode::Scan { rel: *rel })
1833 .collect();
1834 let mut slot_vars: Vec<Vec<Option<u32>>> = Vec::with_capacity(expected_edges);
1835 for i in 0..k {
1836 for j in (i + 1)..k {
1837 let _ = clique_edge_idx(i, j, k); slot_vars.push(vec![Some(i as u32), Some(j as u32)]);
1839 }
1840 }
1841 let RirNode::Project { columns, .. } = body else {
1845 return None;
1846 };
1847 let output_columns = columns.clone();
1848 let fallback = Box::new(body.clone());
1849 let shape = build_kclique_shape(k, &reordered_scans)?;
1850 let planner_stats = kclique_planner_stats(stats);
1851
1852 let (plan, var_order) = match plan_kclique_var_order(&shape, &planner_stats) {
1857 Some(full_order) => {
1858 let evidence = rir_cost_prediction(&full_order);
1859 if wcoj_cost_gate_predicts_wcoj(evidence.wcoj_cost, evidence.hash_cost) {
1860 let kclique_order = kclique_variable_order_from_plan(&shape, &full_order)?;
1861 (
1862 MultiwayPlan::WcojWithPlan(kclique_order.clone()),
1863 Some(VariableOrder::kclique(kclique_order)),
1864 )
1865 } else {
1866 (
1867 MultiwayPlan::PlannedHashRoute {
1868 reason: PlannedHashReason::PlannerPredictsHashWins,
1869 planner_evidence: evidence,
1870 },
1871 None,
1872 )
1873 }
1874 }
1875 None => (
1876 MultiwayPlan::PlannedHashRoute {
1877 reason: PlannedHashReason::IncompleteStatsSafeDefault,
1878 planner_evidence: RirCostPredictionRecord::empty(),
1879 },
1880 None,
1881 ),
1882 };
1883
1884 Some(RirNode::MultiWayJoin {
1885 inputs,
1886 slot_vars,
1887 output_columns,
1888 fallback,
1889 plan: Some(plan),
1890 var_order,
1891 })
1892}
1893
1894fn build_kclique_shape(k: usize, rels: &[RelId]) -> Option<KCliqueShape> {
1895 let mut edges = Vec::with_capacity(rels.len());
1896 let mut idx = 0usize;
1897 for i in 0..k {
1898 for j in (i + 1)..k {
1899 let rel_id = *rels.get(idx)?;
1900 edges.push(KCliqueEdge {
1901 rel_id,
1902 left: VertexId(i),
1903 right: VertexId(j),
1904 left_col: 0,
1905 right_col: 1,
1906 });
1907 idx += 1;
1908 }
1909 }
1910 KCliqueShape::from_edges(k as u8, edges)
1911}
1912
/// Returns the `StatsSnapshot` produced by `stats.snapshot()`; the k-clique
/// promotion passes this snapshot to the variable-order planner.
fn kclique_planner_stats(stats: &StatsManager) -> StatsSnapshot {
    stats.snapshot()
}
1916
1917fn rir_cost_prediction(plan: &FullVariableOrder) -> RirCostPredictionRecord {
1918 RirCostPredictionRecord {
1919 wcoj_cost: plan.cost_prediction.wcoj_cost,
1920 hash_cost: plan.cost_prediction.hash_cost,
1921 }
1922}
1923
1924fn kclique_variable_order_from_plan(
1925 shape: &KCliqueShape,
1926 plan: &FullVariableOrder,
1927) -> Option<KCliqueVariableOrder> {
1928 let k = shape.variable_count();
1929 let expected_edges = usize::from(k) * usize::from(k - 1) / 2;
1930 if plan.variable_order.len() != usize::from(k) || plan.edge_permutation.len() != expected_edges
1931 {
1932 return None;
1933 }
1934
1935 let mut variable_positions = [u8::MAX; K_CLIQUE_MAX_K];
1936 for (position, variable) in plan.variable_order.iter().enumerate() {
1937 if variable.0 >= usize::from(k) {
1938 return None;
1939 }
1940 variable_positions[variable.0] = position as u8;
1941 }
1942
1943 let mut edge_permutation = [u8::MAX; K_CLIQUE_MAX_EDGES];
1944 let mut column_swaps = Vec::new();
1945 let mut leader_slot = None;
1946 for (slot, edge_idx) in plan.edge_permutation.iter().copied().enumerate() {
1947 let edge = shape.edges().get(edge_idx)?;
1948 let left_pos = variable_positions[edge.left.0];
1949 let right_pos = variable_positions[edge.right.0];
1950 if left_pos == u8::MAX || right_pos == u8::MAX {
1951 return None;
1952 }
1953 edge_permutation[slot] = edge_idx as u8;
1954 if left_pos > right_pos {
1955 column_swaps.push(ColumnSwap {
1956 edge_slot: slot as u8,
1957 swap_cols: true,
1958 });
1959 }
1960 if [left_pos, right_pos].into_iter().min() == Some(0)
1961 && [left_pos, right_pos].into_iter().max() == Some(1)
1962 {
1963 leader_slot = Some(slot as u8);
1964 }
1965 }
1966
1967 let sorted_edge_slots = vec![leader_slot.unwrap_or(0)];
1968 let sorted_layout_requirements = SortedLayoutSpec {
1969 edge_slots: sorted_edge_slots,
1970 key_columns: vec![vec![0, 1]],
1971 };
1972
1973 let helper_split_specs = plan.helper_split_specs.clone();
1975 Some(KCliqueVariableOrder::new(
1976 k,
1977 variable_positions,
1978 edge_permutation,
1979 column_swaps,
1980 sorted_layout_requirements,
1981 helper_split_specs,
1982 StreamGroupId(0),
1983 ))
1984}
1985
1986#[cfg(test)]
1987mod tests {
1988 use super::*;
1989 use xlog_core::RelId;
1990 use xlog_ir::{CompiledRule, ExecutionPlan, PlanBuilder, Scc};
1991
1992 fn canonical_triangle_tree() -> RirNode {
1993 let inner = RirNode::Join {
1994 left: Box::new(RirNode::Scan { rel: RelId(1) }),
1995 right: Box::new(RirNode::Scan { rel: RelId(2) }),
1996 left_keys: vec![1],
1997 right_keys: vec![0],
1998 join_type: JoinType::Inner,
1999 };
2000 let outer = RirNode::Join {
2001 left: Box::new(inner),
2002 right: Box::new(RirNode::Scan { rel: RelId(3) }),
2003 left_keys: vec![0, 3],
2004 right_keys: vec![0, 1],
2005 join_type: JoinType::Inner,
2006 };
2007 RirNode::Project {
2008 input: Box::new(outer),
2009 columns: vec![
2010 ProjectExpr::Column(0),
2011 ProjectExpr::Column(1),
2012 ProjectExpr::Column(3),
2013 ],
2014 }
2015 }
2016
2017 fn plan_with_body(body: RirNode) -> ExecutionPlan {
2018 let mut builder = PlanBuilder::new();
2019 builder.add_scc(Scc {
2020 id: 0,
2021 predicates: vec!["t".to_string()],
2022 is_recursive: false,
2023 });
2024 builder.add_rule(
2025 0,
2026 CompiledRule {
2027 head: "t".to_string(),
2028 body,
2029 meta: Default::default(),
2030 },
2031 );
2032 builder.build()
2033 }
2034
2035 fn canonical_chain_tree() -> RirNode {
2036 let join = RirNode::Join {
2037 left: Box::new(RirNode::Scan { rel: RelId(1) }),
2038 right: Box::new(RirNode::Scan { rel: RelId(2) }),
2039 left_keys: vec![1],
2040 right_keys: vec![0],
2041 join_type: JoinType::Inner,
2042 };
2043 RirNode::Project {
2044 input: Box::new(join),
2045 columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
2046 }
2047 }
2048
2049 #[test]
2050 fn promotes_canonical_chain() {
2051 let mut plan = plan_with_body(canonical_chain_tree());
2052 promote_multiway(
2053 &mut plan,
2054 &HashMap::new(),
2055 &StatsManager::new(),
2056 &CompilerConfig::default(),
2057 );
2058 let body = &plan.rules_by_scc[0][0].body;
2059 match body {
2060 RirNode::ChainJoin {
2061 left,
2062 right,
2063 left_key,
2064 right_key,
2065 output_columns,
2066 fallback,
2067 } => {
2068 assert!(matches!(left.as_ref(), RirNode::Scan { rel: RelId(1) }));
2069 assert!(matches!(right.as_ref(), RirNode::Scan { rel: RelId(2) }));
2070 assert_eq!(*left_key, 1);
2071 assert_eq!(*right_key, 0);
2072 assert_eq!(
2073 output_columns,
2074 &vec![ProjectExpr::Column(0), ProjectExpr::Column(3)]
2075 );
2076 assert!(matches!(fallback.as_ref(), RirNode::Project { .. }));
2077 }
2078 other => panic!("expected ChainJoin, got {:?}", other),
2079 }
2080 }
2081
2082 #[test]
2083 fn chain_promotion_rejects_non_inner_join() {
2084 let mut body = canonical_chain_tree();
2085 if let RirNode::Project { input, .. } = &mut body {
2086 if let RirNode::Join { join_type, .. } = input.as_mut() {
2087 *join_type = JoinType::LeftOuter;
2088 }
2089 }
2090 assert!(try_promote_chain(&body).is_none());
2091 }
2092
2093 #[test]
2094 fn chain_promotion_rejects_multi_key_join() {
2095 let mut body = canonical_chain_tree();
2096 if let RirNode::Project { input, .. } = &mut body {
2097 if let RirNode::Join {
2098 left_keys,
2099 right_keys,
2100 ..
2101 } = input.as_mut()
2102 {
2103 *left_keys = vec![0, 1];
2104 *right_keys = vec![0, 1];
2105 }
2106 }
2107 assert!(try_promote_chain(&body).is_none());
2108 }
2109
2110 #[test]
2111 fn promotes_canonical_triangle() {
2112 let mut plan = plan_with_body(canonical_triangle_tree());
2113 promote_multiway(
2114 &mut plan,
2115 &HashMap::new(),
2116 &StatsManager::new(),
2117 &CompilerConfig::default(),
2118 );
2119 let body = &plan.rules_by_scc[0][0].body;
2120 match body {
2121 RirNode::MultiWayJoin {
2122 inputs,
2123 slot_vars,
2124 output_columns,
2125 fallback,
2126 var_order: _,
2127 ..
2128 } => {
2129 assert_eq!(inputs.len(), 3);
2130 assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(1) }));
2131 assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(2) }));
2132 assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(3) }));
2133 assert_eq!(
2134 slot_vars,
2135 &vec![
2136 vec![Some(0u32), Some(1)],
2137 vec![Some(1u32), Some(2)],
2138 vec![Some(0u32), Some(2)],
2139 ]
2140 );
2141 assert_eq!(
2142 output_columns,
2143 &vec![
2144 ProjectExpr::Column(0),
2145 ProjectExpr::Column(1),
2146 ProjectExpr::Column(3),
2147 ]
2148 );
2149 assert!(matches!(fallback.as_ref(), RirNode::Project { .. }));
2151 }
2152 other => panic!("expected MultiWayJoin, got {:?}", other),
2153 }
2154 }
2155
2156 #[test]
2157 fn fallback_is_structurally_equal_to_input() {
2158 let pre = canonical_triangle_tree();
2159 let mut plan = plan_with_body(pre.clone());
2160 promote_multiway(
2161 &mut plan,
2162 &HashMap::new(),
2163 &StatsManager::new(),
2164 &CompilerConfig::default(),
2165 );
2166 let body = &plan.rules_by_scc[0][0].body;
2167 let RirNode::MultiWayJoin { fallback, .. } = body else {
2168 panic!("expected MultiWayJoin");
2169 };
2170 assert_eq!(format!("{:?}", fallback.as_ref()), format!("{:?}", pre));
2174 }
2175
2176 #[test]
2177 fn idempotent_under_repeat_calls() {
2178 let mut plan = plan_with_body(canonical_triangle_tree());
2179 promote_multiway(
2180 &mut plan,
2181 &HashMap::new(),
2182 &StatsManager::new(),
2183 &CompilerConfig::default(),
2184 );
2185 let first = format!("{:?}", &plan.rules_by_scc[0][0].body);
2186 promote_multiway(
2187 &mut plan,
2188 &HashMap::new(),
2189 &StatsManager::new(),
2190 &CompilerConfig::default(),
2191 );
2192 let second = format!("{:?}", &plan.rules_by_scc[0][0].body);
2193 assert_eq!(first, second);
2194 }
2195
2196 #[test]
2202 fn promotes_triangle_with_x_shared_inner_pair() {
2203 let inner = RirNode::Join {
2204 left: Box::new(RirNode::Scan { rel: RelId(1) }), right: Box::new(RirNode::Scan { rel: RelId(2) }), left_keys: vec![0],
2207 right_keys: vec![0],
2208 join_type: JoinType::Inner,
2209 };
2210 let outer = RirNode::Join {
2211 left: Box::new(inner),
2212 right: Box::new(RirNode::Scan { rel: RelId(3) }), left_keys: vec![1, 3],
2214 right_keys: vec![0, 1],
2215 join_type: JoinType::Inner,
2216 };
2217 let body = RirNode::Project {
2218 input: Box::new(outer),
2219 columns: vec![
2220 ProjectExpr::Column(0),
2221 ProjectExpr::Column(1),
2222 ProjectExpr::Column(3),
2223 ],
2224 };
2225 let mut plan = plan_with_body(body);
2226 promote_multiway(
2227 &mut plan,
2228 &HashMap::new(),
2229 &StatsManager::new(),
2230 &CompilerConfig::default(),
2231 );
2232 let RirNode::MultiWayJoin {
2233 inputs, slot_vars, ..
2234 } = &plan.rules_by_scc[0][0].body
2235 else {
2236 panic!("expected MultiWayJoin after promotion");
2237 };
2238 let scan_rels: Vec<RelId> = inputs
2240 .iter()
2241 .map(|n| match n {
2242 RirNode::Scan { rel } => *rel,
2243 _ => panic!("expected Scan in MultiWayJoin inputs"),
2244 })
2245 .collect();
2246 assert_eq!(scan_rels, vec![RelId(1), RelId(3), RelId(2)]);
2247 assert_eq!(
2249 slot_vars,
2250 &vec![
2251 vec![Some(0u32), Some(1)],
2252 vec![Some(1u32), Some(2)],
2253 vec![Some(0u32), Some(2)],
2254 ]
2255 );
2256 }
2257
2258 #[test]
2264 fn promotes_triangle_with_z_shared_inner_pair() {
2265 let inner = RirNode::Join {
2266 left: Box::new(RirNode::Scan { rel: RelId(1) }), right: Box::new(RirNode::Scan { rel: RelId(2) }), left_keys: vec![1],
2269 right_keys: vec![1],
2270 join_type: JoinType::Inner,
2271 };
2272 let outer = RirNode::Join {
2273 left: Box::new(inner),
2274 right: Box::new(RirNode::Scan { rel: RelId(3) }), left_keys: vec![0, 2],
2276 right_keys: vec![0, 1],
2277 join_type: JoinType::Inner,
2278 };
2279 let body = RirNode::Project {
2280 input: Box::new(outer),
2281 columns: vec![
2284 ProjectExpr::Column(0),
2285 ProjectExpr::Column(2),
2286 ProjectExpr::Column(3),
2287 ],
2288 };
2289 let mut plan = plan_with_body(body);
2290 promote_multiway(
2291 &mut plan,
2292 &HashMap::new(),
2293 &StatsManager::new(),
2294 &CompilerConfig::default(),
2295 );
2296 let RirNode::MultiWayJoin {
2297 inputs, slot_vars, ..
2298 } = &plan.rules_by_scc[0][0].body
2299 else {
2300 panic!("expected MultiWayJoin after promotion");
2301 };
2302 let scan_rels: Vec<RelId> = inputs
2303 .iter()
2304 .map(|n| match n {
2305 RirNode::Scan { rel } => *rel,
2306 _ => panic!("expected Scan in MultiWayJoin inputs"),
2307 })
2308 .collect();
2309 assert_eq!(scan_rels, vec![RelId(3), RelId(2), RelId(1)]);
2310 assert_eq!(
2311 slot_vars,
2312 &vec![
2313 vec![Some(0u32), Some(1)],
2314 vec![Some(1u32), Some(2)],
2315 vec![Some(0u32), Some(2)],
2316 ]
2317 );
2318 }
2319
2320 #[test]
2321 fn promotes_triangle_with_rotated_projection_columns() {
2322 let inner = RirNode::Join {
2334 left: Box::new(RirNode::Scan { rel: RelId(1) }),
2335 right: Box::new(RirNode::Scan { rel: RelId(2) }),
2336 left_keys: vec![1],
2337 right_keys: vec![0],
2338 join_type: JoinType::Inner,
2339 };
2340 let outer = RirNode::Join {
2341 left: Box::new(inner),
2342 right: Box::new(RirNode::Scan { rel: RelId(3) }),
2343 left_keys: vec![0, 3],
2344 right_keys: vec![0, 1],
2345 join_type: JoinType::Inner,
2346 };
2347 let body = RirNode::Project {
2348 input: Box::new(outer),
2349 columns: vec![
2352 ProjectExpr::Column(1),
2353 ProjectExpr::Column(0),
2354 ProjectExpr::Column(3),
2355 ],
2356 };
2357 let mut plan = plan_with_body(body);
2358 promote_multiway(
2359 &mut plan,
2360 &HashMap::new(),
2361 &StatsManager::new(),
2362 &CompilerConfig::default(),
2363 );
2364 let RirNode::MultiWayJoin {
2366 slot_vars,
2367 output_columns,
2368 ..
2369 } = &plan.rules_by_scc[0][0].body
2370 else {
2371 panic!("expected MultiWayJoin after promotion");
2372 };
2373 assert_eq!(
2375 slot_vars,
2376 &vec![
2377 vec![Some(0u32), Some(1)],
2378 vec![Some(1u32), Some(2)],
2379 vec![Some(0u32), Some(2)],
2380 ]
2381 );
2382 assert_eq!(
2384 output_columns,
2385 &vec![
2386 ProjectExpr::Column(1),
2387 ProjectExpr::Column(0),
2388 ProjectExpr::Column(3),
2389 ]
2390 );
2391 }
2392
2393 #[test]
2394 fn rejects_non_inner_join() {
2395 let inner = RirNode::Join {
2396 left: Box::new(RirNode::Scan { rel: RelId(1) }),
2397 right: Box::new(RirNode::Scan { rel: RelId(2) }),
2398 left_keys: vec![1],
2399 right_keys: vec![0],
2400 join_type: JoinType::LeftOuter,
2401 };
2402 let outer = RirNode::Join {
2403 left: Box::new(inner),
2404 right: Box::new(RirNode::Scan { rel: RelId(3) }),
2405 left_keys: vec![0, 3],
2406 right_keys: vec![0, 1],
2407 join_type: JoinType::Inner,
2408 };
2409 let body = RirNode::Project {
2410 input: Box::new(outer),
2411 columns: vec![
2412 ProjectExpr::Column(0),
2413 ProjectExpr::Column(1),
2414 ProjectExpr::Column(3),
2415 ],
2416 };
2417 let mut plan = plan_with_body(body);
2418 promote_multiway(
2419 &mut plan,
2420 &HashMap::new(),
2421 &StatsManager::new(),
2422 &CompilerConfig::default(),
2423 );
2424 assert!(matches!(
2425 &plan.rules_by_scc[0][0].body,
2426 RirNode::Project { .. }
2427 ));
2428 }
2429
2430 #[test]
2431 fn rejects_filter_above_outer_join() {
2432 let inner = RirNode::Join {
2435 left: Box::new(RirNode::Scan { rel: RelId(1) }),
2436 right: Box::new(RirNode::Scan { rel: RelId(2) }),
2437 left_keys: vec![1],
2438 right_keys: vec![0],
2439 join_type: JoinType::Inner,
2440 };
2441 let outer = RirNode::Join {
2442 left: Box::new(inner),
2443 right: Box::new(RirNode::Scan { rel: RelId(3) }),
2444 left_keys: vec![0, 3],
2445 right_keys: vec![0, 1],
2446 join_type: JoinType::Inner,
2447 };
2448 let filtered = RirNode::Filter {
2449 input: Box::new(outer),
2450 predicate: xlog_ir::Expr::Column(0),
2451 };
2452 let body = RirNode::Project {
2453 input: Box::new(filtered),
2454 columns: vec![
2455 ProjectExpr::Column(0),
2456 ProjectExpr::Column(1),
2457 ProjectExpr::Column(3),
2458 ],
2459 };
2460 let mut plan = plan_with_body(body);
2461 promote_multiway(
2462 &mut plan,
2463 &HashMap::new(),
2464 &StatsManager::new(),
2465 &CompilerConfig::default(),
2466 );
2467 assert!(matches!(
2468 &plan.rules_by_scc[0][0].body,
2469 RirNode::Project { .. }
2470 ));
2471 }
2472
2473 #[test]
2474 fn meta_preserved_byte_for_byte() {
2475 use xlog_core::Schema;
2476 use xlog_ir::metadata::RirMeta;
2477
2478 let schema = Schema::new(vec![
2479 ("x".to_string(), xlog_core::ScalarType::U32),
2480 ("y".to_string(), xlog_core::ScalarType::U32),
2481 ("z".to_string(), xlog_core::ScalarType::U32),
2482 ]);
2483 let meta_pre = RirMeta::with_schema(schema).with_rows(100, 250);
2484
2485 let mut builder = PlanBuilder::new();
2486 builder.add_scc(Scc {
2487 id: 0,
2488 predicates: vec!["t".to_string()],
2489 is_recursive: false,
2490 });
2491 builder.add_rule(
2492 0,
2493 CompiledRule {
2494 head: "t".to_string(),
2495 body: canonical_triangle_tree(),
2496 meta: meta_pre.clone(),
2497 },
2498 );
2499 let mut plan = builder.build();
2500 promote_multiway(
2501 &mut plan,
2502 &HashMap::new(),
2503 &StatsManager::new(),
2504 &CompilerConfig::default(),
2505 );
2506 assert_eq!(
2507 format!("{:?}", &plan.rules_by_scc[0][0].meta),
2508 format!("{:?}", meta_pre),
2509 );
2510 }
2511
2512 #[test]
2522 fn promotes_stable_triangle_in_recursive_scc() {
2523 let mut builder = PlanBuilder::new();
2524 builder.add_scc(Scc {
2525 id: 0,
2526 predicates: vec!["tri".to_string()],
2527 is_recursive: true,
2528 });
2529 builder.add_rule(
2530 0,
2531 CompiledRule {
2532 head: "tri".to_string(),
2533 body: canonical_triangle_tree(),
2534 meta: Default::default(),
2535 },
2536 );
2537 let mut plan = builder.build();
2538 promote_multiway(
2541 &mut plan,
2542 &HashMap::new(),
2543 &StatsManager::new(),
2544 &CompilerConfig::default(),
2545 );
2546 assert!(matches!(
2547 &plan.rules_by_scc[0][0].body,
2548 RirNode::MultiWayJoin { .. }
2549 ));
2550 }
2551
2552 #[test]
2557 fn promotes_linear_recursive_triangle() {
2558 let mut builder = PlanBuilder::new();
2559 builder.add_scc(Scc {
2560 id: 0,
2561 predicates: vec!["tri".to_string()],
2562 is_recursive: true,
2563 });
2564 builder.add_rule(
2565 0,
2566 CompiledRule {
2567 head: "tri".to_string(),
2568 body: canonical_triangle_tree(),
2569 meta: Default::default(),
2570 },
2571 );
2572 let mut plan = builder.build();
2573 let mut rel_ids = HashMap::new();
2574 rel_ids.insert("tri".to_string(), RelId(2)); promote_multiway(
2576 &mut plan,
2577 &rel_ids,
2578 &StatsManager::new(),
2579 &CompilerConfig::default(),
2580 );
2581 assert!(matches!(
2582 &plan.rules_by_scc[0][0].body,
2583 RirNode::MultiWayJoin { .. }
2584 ));
2585 }
2586
2587 #[test]
2593 fn promotes_multirec_triangle_in_recursive_scc() {
2594 let mut builder = PlanBuilder::new();
2595 builder.add_scc(Scc {
2596 id: 0,
2597 predicates: vec!["tri_a".to_string(), "tri_b".to_string()],
2598 is_recursive: true,
2599 });
2600 builder.add_rule(
2601 0,
2602 CompiledRule {
2603 head: "tri_a".to_string(),
2604 body: canonical_triangle_tree(),
2605 meta: Default::default(),
2606 },
2607 );
2608 let mut plan = builder.build();
2609 let mut rel_ids = HashMap::new();
2610 rel_ids.insert("tri_a".to_string(), RelId(1));
2611 rel_ids.insert("tri_b".to_string(), RelId(2));
2612 promote_multiway(
2614 &mut plan,
2615 &rel_ids,
2616 &StatsManager::new(),
2617 &CompilerConfig::default(),
2618 );
2619 assert!(matches!(
2620 &plan.rules_by_scc[0][0].body,
2621 RirNode::MultiWayJoin { .. }
2622 ));
2623 }
2624
2625 #[test]
2629 fn promotes_linear_rec_and_non_rec_sccs_in_mixed_plan() {
2630 let mut builder = PlanBuilder::new();
2631 builder.add_scc(Scc {
2632 id: 0,
2633 predicates: vec!["rec".to_string()],
2634 is_recursive: true,
2635 });
2636 builder.add_rule(
2637 0,
2638 CompiledRule {
2639 head: "rec".to_string(),
2640 body: canonical_triangle_tree(),
2641 meta: Default::default(),
2642 },
2643 );
2644 builder.add_scc(Scc {
2645 id: 1,
2646 predicates: vec!["nonrec".to_string()],
2647 is_recursive: false,
2648 });
2649 builder.add_rule(
2650 1,
2651 CompiledRule {
2652 head: "nonrec".to_string(),
2653 body: canonical_triangle_tree(),
2654 meta: Default::default(),
2655 },
2656 );
2657 let mut plan = builder.build();
2658 let mut rel_ids = HashMap::new();
2659 rel_ids.insert("rec".to_string(), RelId(1)); promote_multiway(
2663 &mut plan,
2664 &rel_ids,
2665 &StatsManager::new(),
2666 &CompilerConfig::default(),
2667 );
2668 assert!(matches!(
2669 &plan.rules_by_scc[0][0].body,
2670 RirNode::MultiWayJoin { .. }
2671 ));
2672 assert!(matches!(
2673 &plan.rules_by_scc[1][0].body,
2674 RirNode::MultiWayJoin { .. }
2675 ));
2676 }
2677
2678 fn canonical_4cycle_tree() -> RirNode {
2688 let inner_l = RirNode::Join {
2689 left: Box::new(RirNode::Scan { rel: RelId(1) }),
2690 right: Box::new(RirNode::Scan { rel: RelId(2) }),
2691 left_keys: vec![1],
2692 right_keys: vec![0],
2693 join_type: JoinType::Inner,
2694 };
2695 let inner_r = RirNode::Join {
2696 left: Box::new(RirNode::Scan { rel: RelId(3) }),
2697 right: Box::new(RirNode::Scan { rel: RelId(4) }),
2698 left_keys: vec![1],
2699 right_keys: vec![0],
2700 join_type: JoinType::Inner,
2701 };
2702 let outer = RirNode::Join {
2703 left: Box::new(inner_l),
2704 right: Box::new(inner_r),
2705 left_keys: vec![0, 3],
2706 right_keys: vec![3, 0],
2707 join_type: JoinType::Inner,
2708 };
2709 RirNode::Project {
2710 input: Box::new(outer),
2711 columns: vec![
2712 ProjectExpr::Column(0),
2713 ProjectExpr::Column(1),
2714 ProjectExpr::Column(3),
2715 ProjectExpr::Column(5),
2716 ],
2717 }
2718 }
2719
2720 fn plan_with_4cycle_body(body: RirNode) -> ExecutionPlan {
2721 let mut builder = PlanBuilder::new();
2722 builder.add_scc(Scc {
2723 id: 0,
2724 predicates: vec!["cycle4".to_string()],
2725 is_recursive: false,
2726 });
2727 builder.add_rule(
2728 0,
2729 CompiledRule {
2730 head: "cycle4".to_string(),
2731 body,
2732 meta: Default::default(),
2733 },
2734 );
2735 builder.build()
2736 }
2737
2738 #[test]
2739 fn promotes_canonical_4cycle() {
2740 let mut plan = plan_with_4cycle_body(canonical_4cycle_tree());
2741 promote_multiway(
2742 &mut plan,
2743 &HashMap::new(),
2744 &StatsManager::new(),
2745 &CompilerConfig::default(),
2746 );
2747 let body = &plan.rules_by_scc[0][0].body;
2748 match body {
2749 RirNode::MultiWayJoin {
2750 inputs,
2751 slot_vars,
2752 output_columns,
2753 fallback,
2754 var_order: _,
2755 ..
2756 } => {
2757 assert_eq!(inputs.len(), 4);
2758 assert!(matches!(inputs[0], RirNode::Scan { rel: RelId(1) }));
2759 assert!(matches!(inputs[1], RirNode::Scan { rel: RelId(2) }));
2760 assert!(matches!(inputs[2], RirNode::Scan { rel: RelId(3) }));
2761 assert!(matches!(inputs[3], RirNode::Scan { rel: RelId(4) }));
2762 assert_eq!(
2763 slot_vars,
2764 &vec![
2765 vec![Some(0u32), Some(1)],
2766 vec![Some(1u32), Some(2)],
2767 vec![Some(2u32), Some(3)],
2768 vec![Some(3u32), Some(0)],
2769 ]
2770 );
2771 assert_eq!(
2772 output_columns,
2773 &vec![
2774 ProjectExpr::Column(0),
2775 ProjectExpr::Column(1),
2776 ProjectExpr::Column(3),
2777 ProjectExpr::Column(5),
2778 ]
2779 );
2780 assert!(matches!(fallback.as_ref(), RirNode::Project { .. }));
2781 }
2782 other => panic!("expected MultiWayJoin, got {:?}", other),
2783 }
2784 }
2785
2786 #[test]
2787 fn fallback_4cycle_is_structurally_equal_to_input() {
2788 let pre = canonical_4cycle_tree();
2789 let mut plan = plan_with_4cycle_body(pre.clone());
2790 promote_multiway(
2791 &mut plan,
2792 &HashMap::new(),
2793 &StatsManager::new(),
2794 &CompilerConfig::default(),
2795 );
2796 let body = &plan.rules_by_scc[0][0].body;
2797 let RirNode::MultiWayJoin { fallback, .. } = body else {
2798 panic!("expected MultiWayJoin");
2799 };
2800 assert_eq!(format!("{:?}", fallback.as_ref()), format!("{:?}", pre));
2801 }
2802
2803 #[test]
2804 fn idempotent_4cycle_under_repeat_calls() {
2805 let mut plan = plan_with_4cycle_body(canonical_4cycle_tree());
2806 promote_multiway(
2807 &mut plan,
2808 &HashMap::new(),
2809 &StatsManager::new(),
2810 &CompilerConfig::default(),
2811 );
2812 let first = format!("{:?}", &plan.rules_by_scc[0][0].body);
2813 promote_multiway(
2814 &mut plan,
2815 &HashMap::new(),
2816 &StatsManager::new(),
2817 &CompilerConfig::default(),
2818 );
2819 let second = format!("{:?}", &plan.rules_by_scc[0][0].body);
2820 assert_eq!(first, second);
2821 }
2822
2823 #[test]
2824 fn rejects_4cycle_with_left_deep_shape() {
2825 let inner = RirNode::Join {
2832 left: Box::new(RirNode::Scan { rel: RelId(1) }),
2833 right: Box::new(RirNode::Scan { rel: RelId(2) }),
2834 left_keys: vec![1],
2835 right_keys: vec![0],
2836 join_type: JoinType::Inner,
2837 };
2838 let outer = RirNode::Join {
2839 left: Box::new(inner),
2840 right: Box::new(RirNode::Scan { rel: RelId(3) }),
2841 left_keys: vec![0, 3],
2842 right_keys: vec![0, 1],
2843 join_type: JoinType::Inner,
2844 };
2845 let body = RirNode::Project {
2846 input: Box::new(outer),
2847 columns: vec![
2848 ProjectExpr::Column(0),
2849 ProjectExpr::Column(1),
2850 ProjectExpr::Column(3),
2851 ProjectExpr::Column(5),
2852 ],
2853 };
2854 assert!(
2859 try_promote_4cycle(&body, &StatsManager::new(), &CompilerConfig::default()).is_none()
2860 );
2861 }
2862
2863 #[test]
2871 fn promotes_4cycle_with_alternative_inner_grouping() {
2872 let inner_left = RirNode::Join {
2875 left: Box::new(RirNode::Scan { rel: RelId(2) }),
2876 right: Box::new(RirNode::Scan { rel: RelId(3) }),
2877 left_keys: vec![1],
2878 right_keys: vec![0],
2879 join_type: JoinType::Inner,
2880 };
2881 let inner_right = RirNode::Join {
2882 left: Box::new(RirNode::Scan { rel: RelId(4) }),
2883 right: Box::new(RirNode::Scan { rel: RelId(1) }),
2884 left_keys: vec![1],
2885 right_keys: vec![0],
2886 join_type: JoinType::Inner,
2887 };
2888 let outer = RirNode::Join {
2889 left: Box::new(inner_left),
2890 right: Box::new(inner_right),
2891 left_keys: vec![0, 3],
2894 right_keys: vec![3, 0],
2895 join_type: JoinType::Inner,
2896 };
2897 let body = RirNode::Project {
2900 input: Box::new(outer),
2901 columns: vec![
2902 ProjectExpr::Column(5),
2903 ProjectExpr::Column(0),
2904 ProjectExpr::Column(1),
2905 ProjectExpr::Column(3),
2906 ],
2907 };
2908 let mut plan = plan_with_body(body);
2909 promote_multiway(
2910 &mut plan,
2911 &HashMap::new(),
2912 &StatsManager::new(),
2913 &CompilerConfig::default(),
2914 );
2915 let RirNode::MultiWayJoin {
2916 inputs, slot_vars, ..
2917 } = &plan.rules_by_scc[0][0].body
2918 else {
2919 panic!("expected MultiWayJoin after promotion");
2920 };
2921 let scan_rels: Vec<RelId> = inputs
2922 .iter()
2923 .map(|n| match n {
2924 RirNode::Scan { rel } => *rel,
2925 _ => panic!("expected Scan in MultiWayJoin inputs"),
2926 })
2927 .collect();
2928 assert_eq!(
2931 scan_rels,
2932 vec![RelId(1), RelId(2), RelId(3), RelId(4)],
2933 "inputs must be in semantic order regardless of positional layout"
2934 );
2935 assert_eq!(
2937 slot_vars,
2938 &vec![
2939 vec![Some(0u32), Some(1)],
2940 vec![Some(1u32), Some(2)],
2941 vec![Some(2u32), Some(3)],
2942 vec![Some(3u32), Some(0)],
2943 ]
2944 );
2945 }
2946
2947 #[test]
2948 fn rejects_4cycle_with_rotated_columns() {
2949 let mut body = canonical_4cycle_tree();
2950 if let RirNode::Project { columns, .. } = &mut body {
2951 columns.swap(0, 1);
2953 }
2954 assert!(
2955 try_promote_4cycle(&body, &StatsManager::new(), &CompilerConfig::default()).is_none()
2956 );
2957 }
2958
2959 #[test]
2960 fn rejects_4cycle_with_non_inner_outer_join() {
2961 let mut body = canonical_4cycle_tree();
2962 if let RirNode::Project { input, .. } = &mut body {
2963 if let RirNode::Join { join_type, .. } = input.as_mut() {
2964 *join_type = JoinType::LeftOuter;
2965 }
2966 }
2967 assert!(
2968 try_promote_4cycle(&body, &StatsManager::new(), &CompilerConfig::default()).is_none()
2969 );
2970 }
2971
2972 #[test]
2973 fn rejects_4cycle_with_wrong_outer_keys() {
2974 let mut body = canonical_4cycle_tree();
2975 if let RirNode::Project { input, .. } = &mut body {
2976 if let RirNode::Join { left_keys, .. } = input.as_mut() {
2977 *left_keys = vec![0, 4]; }
2979 }
2980 assert!(
2981 try_promote_4cycle(&body, &StatsManager::new(), &CompilerConfig::default()).is_none()
2982 );
2983 }
2984
2985 #[test]
2988 fn promotes_stable_4cycle_in_recursive_scc() {
2989 let mut builder = PlanBuilder::new();
2990 builder.add_scc(Scc {
2991 id: 0,
2992 predicates: vec!["rec_cycle".to_string()],
2993 is_recursive: true,
2994 });
2995 builder.add_rule(
2996 0,
2997 CompiledRule {
2998 head: "rec_cycle".to_string(),
2999 body: canonical_4cycle_tree(),
3000 meta: Default::default(),
3001 },
3002 );
3003 let mut plan = builder.build();
3004 promote_multiway(
3006 &mut plan,
3007 &HashMap::new(),
3008 &StatsManager::new(),
3009 &CompilerConfig::default(),
3010 );
3011 assert!(matches!(
3012 &plan.rules_by_scc[0][0].body,
3013 RirNode::MultiWayJoin { .. }
3014 ));
3015 }
3016
3017 #[test]
3020 fn promotes_linear_recursive_4cycle() {
3021 let mut builder = PlanBuilder::new();
3022 builder.add_scc(Scc {
3023 id: 0,
3024 predicates: vec!["rec_cycle".to_string()],
3025 is_recursive: true,
3026 });
3027 builder.add_rule(
3028 0,
3029 CompiledRule {
3030 head: "rec_cycle".to_string(),
3031 body: canonical_4cycle_tree(),
3032 meta: Default::default(),
3033 },
3034 );
3035 let mut plan = builder.build();
3036 let mut rel_ids = HashMap::new();
3037 rel_ids.insert("rec_cycle".to_string(), RelId(2));
3040 promote_multiway(
3041 &mut plan,
3042 &rel_ids,
3043 &StatsManager::new(),
3044 &CompilerConfig::default(),
3045 );
3046 assert!(matches!(
3047 &plan.rules_by_scc[0][0].body,
3048 RirNode::MultiWayJoin { .. }
3049 ));
3050 }
3051
3052 #[test]
3055 fn promotes_multirec_4cycle_in_recursive_scc() {
3056 let mut builder = PlanBuilder::new();
3057 builder.add_scc(Scc {
3058 id: 0,
3059 predicates: vec!["rc_a".to_string(), "rc_b".to_string()],
3060 is_recursive: true,
3061 });
3062 builder.add_rule(
3063 0,
3064 CompiledRule {
3065 head: "rc_a".to_string(),
3066 body: canonical_4cycle_tree(),
3067 meta: Default::default(),
3068 },
3069 );
3070 let mut plan = builder.build();
3071 let mut rel_ids = HashMap::new();
3072 rel_ids.insert("rc_a".to_string(), RelId(1));
3073 rel_ids.insert("rc_b".to_string(), RelId(2));
3074 promote_multiway(
3075 &mut plan,
3076 &rel_ids,
3077 &StatsManager::new(),
3078 &CompilerConfig::default(),
3079 );
3080 assert!(matches!(
3081 &plan.rules_by_scc[0][0].body,
3082 RirNode::MultiWayJoin { .. }
3083 ));
3084 }
3085
3086 #[test]
3087 fn triangle_does_not_match_4cycle_promoter() {
3088 let triangle = canonical_triangle_tree();
3092 assert!(
3093 try_promote_4cycle(&triangle, &StatsManager::new(), &CompilerConfig::default())
3094 .is_none()
3095 );
3096 let four_cycle = canonical_4cycle_tree();
3097 assert!(try_promote_triangle(
3098 &four_cycle,
3099 &StatsManager::new(),
3100 &CompilerConfig::default()
3101 )
3102 .is_none());
3103 }
3104}