1use std::collections::HashMap;
33use std::sync::Arc;
34use xlog_core::{RelId, Schema};
35use xlog_ir::{CompareOp, Expr, JoinType, RirNode};
36use xlog_stats::StatsManager;
37
/// Tunable knobs for the query optimizer.
///
/// The struct is `#[non_exhaustive]`: code outside this crate cannot build it
/// with a struct literal and should start from [`OptimizerConfig::default`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OptimizerConfig {
    /// Cut-off consulted by `Optimizer::should_use_greedy`: a plan referencing
    /// more distinct relations than this is reported as needing the greedy path.
    pub dp_threshold: usize,

    /// Heat threshold handed to the statistics manager by
    /// `Optimizer::recommend_indexes`.
    pub index_heat_threshold: f32,

    /// When `false`, `Optimizer::optimize` returns its input plan unchanged.
    pub enable_pushdown: bool,

    /// Selectivity assumed for predicates that are neither comparisons nor
    /// `And`/`Or`/`Not` combinations.
    pub default_filter_selectivity: f64,

    /// NOTE(review): not read in the visible part of this file; presumably the
    /// weight passed to `PlanCost::total_cost` by callers — confirm.
    pub transfer_cost_multiplier: f64,

    /// Row width (in bytes) assumed by the cost model when sizing scans and
    /// group-by outputs.
    pub default_bytes_per_row: u64,
}

impl Default for OptimizerConfig {
    /// Returns the stock configuration (pushdown enabled, DP threshold of 10).
    fn default() -> Self {
        OptimizerConfig {
            // Join-ordering / indexing knobs.
            dp_threshold: 10,
            index_heat_threshold: 0.7,
            // Rewrite toggles.
            enable_pushdown: true,
            // Cost-model fallbacks.
            default_filter_selectivity: 0.1,
            transfer_cost_multiplier: 100.0,
            default_bytes_per_row: 32,
        }
    }
}
94
/// Estimated cost of executing a (sub)plan.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PlanCost {
    /// Estimated number of output rows.
    pub rows: u64,

    /// Accumulated CPU cost, in the cost model's abstract units.
    pub cpu_cost: f64,

    /// Estimated GPU memory footprint (the scan estimate derives it from byte sizes).
    pub gpu_mem: u64,

    /// Number of transfers the plan is expected to perform.
    pub transfers: u32,
}

impl PlanCost {
    /// Creates a cost with the given row count and every other component zeroed.
    pub fn with_rows(rows: u64) -> Self {
        Self {
            rows,
            ..Default::default()
        }
    }

    /// Collapses the cost into a single scalar for plan comparison.
    ///
    /// CPU cost counts as-is, GPU memory is weighted by `0.001` per unit and
    /// each transfer is weighted by `transfer_weight`.
    pub fn total_cost(&self, transfer_weight: f64) -> f64 {
        self.cpu_cost + (self.gpu_mem as f64 * 0.001) + (self.transfers as f64 * transfer_weight)
    }

    /// Sequences `self` followed by `other`.
    ///
    /// The result has `other`'s row count, the summed CPU cost, the larger of
    /// the two memory footprints, and the summed transfer count. The transfer
    /// sum saturates at `u32::MAX` instead of overflowing, matching the
    /// saturating arithmetic used elsewhere in the cost model.
    pub fn then(self, other: PlanCost) -> PlanCost {
        PlanCost {
            rows: other.rows,
            cpu_cost: self.cpu_cost + other.cpu_cost,
            gpu_mem: self.gpu_mem.max(other.gpu_mem),
            transfers: self.transfers.saturating_add(other.transfers),
        }
    }
}
158
/// Rule-based plan optimizer with a statistics-driven cost model.
pub struct Optimizer {
    /// Shared statistics source used for cardinality and selectivity estimates.
    stats: Arc<StatsManager>,
    /// Tuning knobs; see [`OptimizerConfig`].
    config: OptimizerConfig,
    /// Known relation schemas, used to determine column counts of scans.
    schemas: HashMap<RelId, Schema>,
}
170
171impl Optimizer {
172 pub fn new(stats: Arc<StatsManager>) -> Self {
178 Self {
179 stats,
180 config: OptimizerConfig::default(),
181 schemas: HashMap::new(),
182 }
183 }
184
185 pub fn with_config(stats: Arc<StatsManager>, config: OptimizerConfig) -> Self {
192 Self {
193 stats,
194 config,
195 schemas: HashMap::new(),
196 }
197 }
198
    /// Replaces the optimizer's schema table wholesale with `schemas`.
    pub fn set_schemas(&mut self, schemas: HashMap<RelId, Schema>) {
        self.schemas = schemas;
    }
206
    /// Returns the active configuration.
    pub fn config(&self) -> &OptimizerConfig {
        &self.config
    }
211
    /// Returns the shared statistics manager handle.
    pub fn stats(&self) -> &Arc<StatsManager> {
        &self.stats
    }
216
217 pub fn optimize(&self, node: RirNode) -> RirNode {
235 if self.config.enable_pushdown {
236 self.predicate_pushdown(node)
237 } else {
238 node
239 }
240 }
241
242 fn predicate_pushdown(&self, node: RirNode) -> RirNode {
260 match node {
261 RirNode::Unit => RirNode::Unit,
263 RirNode::Scan { rel } => RirNode::Scan { rel },
264
265 RirNode::Filter { input, predicate } => {
267 let optimized_input = self.predicate_pushdown(*input);
269
270 match optimized_input {
271 RirNode::Filter {
273 input: inner_input,
274 predicate: inner_pred,
275 } => {
276 let merged = Expr::And(vec![inner_pred, predicate]);
277 RirNode::Filter {
278 input: inner_input,
279 predicate: merged,
280 }
281 }
282
283 RirNode::Project {
285 input: proj_input,
286 columns,
287 } => {
288 if let Some(remapped) =
290 self.remap_predicate_through_project(&predicate, &columns)
291 {
292 RirNode::Project {
294 input: Box::new(RirNode::Filter {
295 input: proj_input,
296 predicate: remapped,
297 }),
298 columns,
299 }
300 } else {
301 RirNode::Filter {
303 input: Box::new(RirNode::Project {
304 input: proj_input,
305 columns,
306 }),
307 predicate,
308 }
309 }
310 }
311
312 RirNode::Join {
314 left,
315 right,
316 left_keys,
317 right_keys,
318 join_type,
319 } => {
320 let left_width = self.estimate_width(&left);
321 let (left_preds, right_preds, remaining) =
322 self.split_predicate_for_join(&predicate, left_width);
323
324 let new_left = if !left_preds.is_empty() {
326 Box::new(RirNode::Filter {
327 input: left,
328 predicate: Self::conjoin(left_preds),
329 })
330 } else {
331 left
332 };
333
334 let new_right = if !right_preds.is_empty() {
335 Box::new(RirNode::Filter {
336 input: right,
337 predicate: Self::conjoin(right_preds),
338 })
339 } else {
340 right
341 };
342
343 let join_node = RirNode::Join {
344 left: new_left,
345 right: new_right,
346 left_keys,
347 right_keys,
348 join_type,
349 };
350
351 if !remaining.is_empty() {
353 RirNode::Filter {
354 input: Box::new(join_node),
355 predicate: Self::conjoin(remaining),
356 }
357 } else {
358 join_node
359 }
360 }
361
362 other => RirNode::Filter {
364 input: Box::new(other),
365 predicate,
366 },
367 }
368 }
369
370 RirNode::Project { input, columns } => RirNode::Project {
372 input: Box::new(self.predicate_pushdown(*input)),
373 columns,
374 },
375
376 RirNode::Join {
378 left,
379 right,
380 left_keys,
381 right_keys,
382 join_type,
383 } => RirNode::Join {
384 left: Box::new(self.predicate_pushdown(*left)),
385 right: Box::new(self.predicate_pushdown(*right)),
386 left_keys,
387 right_keys,
388 join_type,
389 },
390
391 RirNode::GroupBy {
393 input,
394 key_cols,
395 aggs,
396 } => RirNode::GroupBy {
397 input: Box::new(self.predicate_pushdown(*input)),
398 key_cols,
399 aggs,
400 },
401
402 RirNode::Union { inputs } => RirNode::Union {
404 inputs: inputs
405 .into_iter()
406 .map(|i| self.predicate_pushdown(i))
407 .collect(),
408 },
409
410 RirNode::Distinct { input, key_cols } => RirNode::Distinct {
412 input: Box::new(self.predicate_pushdown(*input)),
413 key_cols,
414 },
415
416 RirNode::Diff { left, right } => RirNode::Diff {
418 left: Box::new(self.predicate_pushdown(*left)),
419 right: Box::new(self.predicate_pushdown(*right)),
420 },
421
422 RirNode::Fixpoint {
424 scc_id,
425 base,
426 recursive,
427 delta_rel,
428 full_rel,
429 } => RirNode::Fixpoint {
430 scc_id,
431 base: Box::new(self.predicate_pushdown(*base)),
432 recursive: Box::new(self.predicate_pushdown(*recursive)),
433 delta_rel,
434 full_rel,
435 },
436
437 RirNode::TensorMaskedJoin { .. } => node, RirNode::MultiWayJoin { .. } | RirNode::ChainJoin { .. } => node,
443 }
444 }
445
446 fn remap_predicate_through_project(
452 &self,
453 predicate: &Expr,
454 columns: &[xlog_ir::ProjectExpr],
455 ) -> Option<Expr> {
456 let mut output_to_input: std::collections::HashMap<usize, usize> =
459 std::collections::HashMap::new();
460
461 for (out_idx, proj_expr) in columns.iter().enumerate() {
462 if let xlog_ir::ProjectExpr::Column(in_idx) = proj_expr {
463 output_to_input.insert(out_idx, *in_idx);
464 }
465 }
466
467 self.remap_expr(predicate, &output_to_input)
468 }
469
470 fn remap_expr(
472 &self,
473 expr: &Expr,
474 mapping: &std::collections::HashMap<usize, usize>,
475 ) -> Option<Expr> {
476 match expr {
477 Expr::Column(idx) => mapping.get(idx).map(|&new_idx| Expr::Column(new_idx)),
478
479 Expr::Const(val) => Some(Expr::Const(val.clone())),
480
481 Expr::Compare { left, op, right } => {
482 let new_left = self.remap_expr(left, mapping)?;
483 let new_right = self.remap_expr(right, mapping)?;
484 Some(Expr::Compare {
485 left: Box::new(new_left),
486 op: *op,
487 right: Box::new(new_right),
488 })
489 }
490
491 Expr::And(exprs) => {
492 let remapped: Option<Vec<_>> =
493 exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
494 remapped.map(Expr::And)
495 }
496
497 Expr::Or(exprs) => {
498 let remapped: Option<Vec<_>> =
499 exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
500 remapped.map(Expr::Or)
501 }
502
503 Expr::Not(inner) => {
504 let remapped = self.remap_expr(inner, mapping)?;
505 Some(Expr::Not(Box::new(remapped)))
506 }
507
508 Expr::Add(l, r) => {
510 let new_l = self.remap_expr(l, mapping)?;
511 let new_r = self.remap_expr(r, mapping)?;
512 Some(Expr::Add(Box::new(new_l), Box::new(new_r)))
513 }
514 Expr::Sub(l, r) => {
515 let new_l = self.remap_expr(l, mapping)?;
516 let new_r = self.remap_expr(r, mapping)?;
517 Some(Expr::Sub(Box::new(new_l), Box::new(new_r)))
518 }
519 Expr::Mul(l, r) => {
520 let new_l = self.remap_expr(l, mapping)?;
521 let new_r = self.remap_expr(r, mapping)?;
522 Some(Expr::Mul(Box::new(new_l), Box::new(new_r)))
523 }
524 Expr::Div(l, r) => {
525 let new_l = self.remap_expr(l, mapping)?;
526 let new_r = self.remap_expr(r, mapping)?;
527 Some(Expr::Div(Box::new(new_l), Box::new(new_r)))
528 }
529 Expr::Mod(l, r) => {
530 let new_l = self.remap_expr(l, mapping)?;
531 let new_r = self.remap_expr(r, mapping)?;
532 Some(Expr::Mod(Box::new(new_l), Box::new(new_r)))
533 }
534
535 Expr::Abs(inner) => {
537 let remapped = self.remap_expr(inner, mapping)?;
538 Some(Expr::Abs(Box::new(remapped)))
539 }
540 Expr::Min(l, r) => {
541 let new_l = self.remap_expr(l, mapping)?;
542 let new_r = self.remap_expr(r, mapping)?;
543 Some(Expr::Min(Box::new(new_l), Box::new(new_r)))
544 }
545 Expr::Max(l, r) => {
546 let new_l = self.remap_expr(l, mapping)?;
547 let new_r = self.remap_expr(r, mapping)?;
548 Some(Expr::Max(Box::new(new_l), Box::new(new_r)))
549 }
550 Expr::Pow(l, r) => {
551 let new_l = self.remap_expr(l, mapping)?;
552 let new_r = self.remap_expr(r, mapping)?;
553 Some(Expr::Pow(Box::new(new_l), Box::new(new_r)))
554 }
555 Expr::Cast(inner, scalar_type) => {
556 let remapped = self.remap_expr(inner, mapping)?;
557 Some(Expr::Cast(Box::new(remapped), *scalar_type))
558 }
559 Expr::Conditional {
560 condition,
561 then_expr,
562 else_expr,
563 } => {
564 let new_condition = self.remap_expr(condition, mapping)?;
565 let new_then = self.remap_expr(then_expr, mapping)?;
566 let new_else = self.remap_expr(else_expr, mapping)?;
567 Some(Expr::Conditional {
568 condition: Box::new(new_condition),
569 then_expr: Box::new(new_then),
570 else_expr: Box::new(new_else),
571 })
572 }
573 }
574 }
575
576 fn estimate_width(&self, node: &RirNode) -> usize {
578 match node {
579 RirNode::Unit => 0,
580 RirNode::Scan { rel } => {
581 if let Some(schema) = self.schemas.get(rel) {
583 schema.arity()
584 } else if let Some(stats) = self.stats.get_relation_stats(*rel) {
585 stats.column_stats.len().max(1)
586 } else {
587 4 }
589 }
590 RirNode::Filter { input, .. } => self.estimate_width(input),
591 RirNode::Project { columns, .. } => columns.len(),
592 RirNode::Join { left, right, .. } => {
593 self.estimate_width(left) + self.estimate_width(right)
594 }
595 RirNode::ChainJoin { output_columns, .. } => output_columns.len(),
596 RirNode::GroupBy { key_cols, aggs, .. } => key_cols.len() + aggs.len(),
597 RirNode::Union { inputs } => {
598 inputs.first().map(|i| self.estimate_width(i)).unwrap_or(0)
599 }
600 RirNode::Distinct { input, .. } => self.estimate_width(input),
601 RirNode::Diff { left, .. } => self.estimate_width(left),
602 RirNode::Fixpoint { base, .. } => self.estimate_width(base),
603 RirNode::TensorMaskedJoin { head_rel_id, .. } => self
606 .schemas
607 .get(head_rel_id)
608 .map(|s| s.arity())
609 .unwrap_or(2),
610 RirNode::MultiWayJoin { output_columns, .. } => output_columns.len(),
613 }
614 }
615
616 fn split_predicate_for_join(
620 &self,
621 predicate: &Expr,
622 left_width: usize,
623 ) -> (Vec<Expr>, Vec<Expr>, Vec<Expr>) {
624 let mut left_preds = Vec::new();
625 let mut right_preds = Vec::new();
626 let mut remaining = Vec::new();
627
628 let conjuncts = Self::flatten_and(predicate);
630
631 for conj in conjuncts {
632 let cols = Self::collect_columns(&conj);
633 let max_col = cols.iter().copied().max().unwrap_or(0);
634 let min_col = cols.iter().copied().min().unwrap_or(0);
635
636 if cols.is_empty() {
637 left_preds.push(conj);
639 } else if max_col < left_width {
640 left_preds.push(conj);
642 } else if min_col >= left_width {
643 let remapped = Self::remap_columns(&conj, |c| c - left_width);
645 right_preds.push(remapped);
646 } else {
647 remaining.push(conj);
649 }
650 }
651
652 (left_preds, right_preds, remaining)
653 }
654
655 fn flatten_and(expr: &Expr) -> Vec<Expr> {
657 match expr {
658 Expr::And(exprs) => exprs.iter().flat_map(Self::flatten_and).collect(),
659 other => vec![other.clone()],
660 }
661 }
662
663 fn collect_columns(expr: &Expr) -> Vec<usize> {
665 match expr {
666 Expr::Column(idx) => vec![*idx],
667 Expr::Const(_) => vec![],
668 Expr::Compare { left, right, .. } => {
669 let mut cols = Self::collect_columns(left);
670 cols.extend(Self::collect_columns(right));
671 cols
672 }
673 Expr::And(exprs) | Expr::Or(exprs) => {
674 exprs.iter().flat_map(Self::collect_columns).collect()
675 }
676 Expr::Not(inner) | Expr::Abs(inner) | Expr::Cast(inner, _) => {
677 Self::collect_columns(inner)
678 }
679 Expr::Add(l, r)
680 | Expr::Sub(l, r)
681 | Expr::Mul(l, r)
682 | Expr::Div(l, r)
683 | Expr::Mod(l, r)
684 | Expr::Min(l, r)
685 | Expr::Max(l, r)
686 | Expr::Pow(l, r) => {
687 let mut cols = Self::collect_columns(l);
688 cols.extend(Self::collect_columns(r));
689 cols
690 }
691 Expr::Conditional {
692 condition,
693 then_expr,
694 else_expr,
695 } => {
696 let mut cols = Self::collect_columns(condition);
697 cols.extend(Self::collect_columns(then_expr));
698 cols.extend(Self::collect_columns(else_expr));
699 cols
700 }
701 }
702 }
703
704 fn remap_columns<F: Fn(usize) -> usize + Copy>(expr: &Expr, f: F) -> Expr {
706 match expr {
707 Expr::Column(idx) => Expr::Column(f(*idx)),
708 Expr::Const(v) => Expr::Const(v.clone()),
709 Expr::Compare { left, op, right } => Expr::Compare {
710 left: Box::new(Self::remap_columns(left, f)),
711 op: *op,
712 right: Box::new(Self::remap_columns(right, f)),
713 },
714 Expr::And(exprs) => {
715 Expr::And(exprs.iter().map(|e| Self::remap_columns(e, f)).collect())
716 }
717 Expr::Or(exprs) => Expr::Or(exprs.iter().map(|e| Self::remap_columns(e, f)).collect()),
718 Expr::Not(inner) => Expr::Not(Box::new(Self::remap_columns(inner, f))),
719 Expr::Add(l, r) => Expr::Add(
720 Box::new(Self::remap_columns(l, f)),
721 Box::new(Self::remap_columns(r, f)),
722 ),
723 Expr::Sub(l, r) => Expr::Sub(
724 Box::new(Self::remap_columns(l, f)),
725 Box::new(Self::remap_columns(r, f)),
726 ),
727 Expr::Mul(l, r) => Expr::Mul(
728 Box::new(Self::remap_columns(l, f)),
729 Box::new(Self::remap_columns(r, f)),
730 ),
731 Expr::Div(l, r) => Expr::Div(
732 Box::new(Self::remap_columns(l, f)),
733 Box::new(Self::remap_columns(r, f)),
734 ),
735 Expr::Mod(l, r) => Expr::Mod(
736 Box::new(Self::remap_columns(l, f)),
737 Box::new(Self::remap_columns(r, f)),
738 ),
739 Expr::Abs(inner) => Expr::Abs(Box::new(Self::remap_columns(inner, f))),
740 Expr::Min(l, r) => Expr::Min(
741 Box::new(Self::remap_columns(l, f)),
742 Box::new(Self::remap_columns(r, f)),
743 ),
744 Expr::Max(l, r) => Expr::Max(
745 Box::new(Self::remap_columns(l, f)),
746 Box::new(Self::remap_columns(r, f)),
747 ),
748 Expr::Pow(l, r) => Expr::Pow(
749 Box::new(Self::remap_columns(l, f)),
750 Box::new(Self::remap_columns(r, f)),
751 ),
752 Expr::Cast(inner, t) => Expr::Cast(Box::new(Self::remap_columns(inner, f)), *t),
753 Expr::Conditional {
754 condition,
755 then_expr,
756 else_expr,
757 } => Expr::Conditional {
758 condition: Box::new(Self::remap_columns(condition, f)),
759 then_expr: Box::new(Self::remap_columns(then_expr, f)),
760 else_expr: Box::new(Self::remap_columns(else_expr, f)),
761 },
762 }
763 }
764
765 fn conjoin(predicates: Vec<Expr>) -> Expr {
767 debug_assert!(!predicates.is_empty());
768 if predicates.len() == 1 {
769 predicates.into_iter().next().unwrap()
770 } else {
771 Expr::And(predicates)
772 }
773 }
774
775 pub fn estimate_cost(&self, node: &RirNode) -> PlanCost {
788 match node {
789 RirNode::Unit => PlanCost {
790 rows: 1,
791 cpu_cost: 0.0,
792 gpu_mem: 0,
793 transfers: 0,
794 },
795 RirNode::Scan { rel } => self.estimate_scan_cost(*rel),
796
797 RirNode::Filter { input, predicate } => {
798 let input_cost = self.estimate_cost(input);
799 self.estimate_filter_cost(input_cost, predicate, input)
800 }
801
802 RirNode::Project { input, columns } => {
803 let input_cost = self.estimate_cost(input);
804 self.estimate_project_cost(input_cost, columns)
805 }
806
807 RirNode::Join {
808 left,
809 right,
810 left_keys,
811 right_keys,
812 join_type,
813 } => {
814 let left_cost = self.estimate_cost(left);
815 let right_cost = self.estimate_cost(right);
816 self.estimate_join_cost(
817 left_cost, right_cost, left, right, left_keys, right_keys, *join_type,
818 )
819 }
820
821 RirNode::ChainJoin {
822 left,
823 right,
824 left_key,
825 right_key,
826 output_columns,
827 ..
828 } => {
829 let left_cost = self.estimate_cost(left);
830 let right_cost = self.estimate_cost(right);
831 let join_cost = self.estimate_join_cost(
832 left_cost,
833 right_cost,
834 left,
835 right,
836 &[*left_key],
837 &[*right_key],
838 JoinType::Inner,
839 );
840 self.estimate_project_cost(join_cost, output_columns)
841 }
842
843 RirNode::GroupBy {
844 input,
845 key_cols,
846 aggs,
847 } => {
848 let input_cost = self.estimate_cost(input);
849 self.estimate_groupby_cost(input_cost, key_cols, aggs)
850 }
851
852 RirNode::Union { inputs } => {
853 let costs: Vec<_> = inputs.iter().map(|i| self.estimate_cost(i)).collect();
854 self.estimate_union_cost(costs)
855 }
856
857 RirNode::Distinct { input, key_cols } => {
858 let input_cost = self.estimate_cost(input);
859 self.estimate_distinct_cost(input_cost, key_cols)
860 }
861
862 RirNode::Diff { left, right } => {
863 let left_cost = self.estimate_cost(left);
864 let right_cost = self.estimate_cost(right);
865 self.estimate_diff_cost(left_cost, right_cost)
866 }
867
868 RirNode::Fixpoint {
869 base, recursive, ..
870 } => {
871 let base_cost = self.estimate_cost(base);
872 let recursive_cost = self.estimate_cost(recursive);
873 self.estimate_fixpoint_cost(base_cost, recursive_cost)
874 }
875
876 RirNode::TensorMaskedJoin {
877 max_active_rules, ..
878 } => PlanCost {
879 rows: *max_active_rules as u64,
880 cpu_cost: *max_active_rules as f64 * 100.0,
881 gpu_mem: *max_active_rules as u64 * 1024,
882 transfers: 1,
883 },
884 RirNode::MultiWayJoin { inputs, .. } => {
889 let mut total = PlanCost::default();
890 for inp in inputs {
891 let c = self.estimate_cost(inp);
892 total.rows = total.rows.saturating_add(c.rows);
893 total.cpu_cost += c.cpu_cost;
894 total.gpu_mem = total.gpu_mem.saturating_add(c.gpu_mem);
895 total.transfers = total.transfers.saturating_add(c.transfers);
896 }
897 total
898 }
899 }
900 }
901
902 fn estimate_scan_cost(&self, rel: RelId) -> PlanCost {
904 if let Some(stats) = self.stats.get_relation_stats(rel) {
905 PlanCost {
906 rows: stats.cardinality,
907 cpu_cost: stats.cardinality as f64 * 0.01, gpu_mem: stats
909 .byte_size
910 .max(stats.cardinality * self.config.default_bytes_per_row),
911 transfers: 0, }
913 } else {
914 let default_rows = 1000;
916 PlanCost {
917 rows: default_rows,
918 cpu_cost: default_rows as f64 * 0.01,
919 gpu_mem: default_rows * self.config.default_bytes_per_row,
920 transfers: 0,
921 }
922 }
923 }
924
925 fn estimate_filter_cost(
927 &self,
928 input_cost: PlanCost,
929 predicate: &Expr,
930 input: &RirNode,
931 ) -> PlanCost {
932 let selectivity = self.estimate_predicate_selectivity(predicate, input);
933 let output_rows = ((input_cost.rows as f64 * selectivity) as u64).max(1);
934
935 PlanCost {
936 rows: output_rows,
937 cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.02, gpu_mem: input_cost.gpu_mem, transfers: input_cost.transfers,
940 }
941 }
942
943 fn estimate_project_cost(
945 &self,
946 input_cost: PlanCost,
947 columns: &[xlog_ir::ProjectExpr],
948 ) -> PlanCost {
949 let computed_count = columns
951 .iter()
952 .filter(|c| matches!(c, xlog_ir::ProjectExpr::Computed(_, _)))
953 .count();
954
955 let compute_cost = computed_count as f64 * input_cost.rows as f64 * 0.05;
957
958 let output_width_ratio = columns.len() as f64 / (columns.len() + 2) as f64; PlanCost {
962 rows: input_cost.rows,
963 cpu_cost: input_cost.cpu_cost + compute_cost,
964 gpu_mem: (input_cost.gpu_mem as f64 * output_width_ratio) as u64,
965 transfers: input_cost.transfers,
966 }
967 }
968
969 #[allow(clippy::too_many_arguments)]
971 fn estimate_join_cost(
972 &self,
973 left_cost: PlanCost,
974 right_cost: PlanCost,
975 left: &RirNode,
976 right: &RirNode,
977 left_keys: &[usize],
978 right_keys: &[usize],
979 join_type: JoinType,
980 ) -> PlanCost {
981 let output_rows = match join_type {
984 JoinType::Semi => {
985 ((left_cost.rows as f64 * 0.5) as u64).max(1)
987 }
988 JoinType::Anti => {
989 ((left_cost.rows as f64 * 0.5) as u64).max(1)
991 }
992 JoinType::Inner | JoinType::LeftOuter => {
993 let left_rels = left.referenced_relations();
995 let right_rels = right.referenced_relations();
996
997 if left_rels.len() == 1 && right_rels.len() == 1 {
998 let estimated = self.stats.estimate_join_cardinality(
1000 left_rels[0],
1001 right_rels[0],
1002 left_keys,
1003 right_keys,
1004 );
1005
1006 match join_type {
1007 JoinType::LeftOuter => estimated.max(left_cost.rows),
1008 _ => estimated,
1009 }
1010 } else {
1011 match join_type {
1013 JoinType::Inner => {
1014 ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64).max(1)
1016 }
1017 JoinType::LeftOuter => {
1018 left_cost.rows.max(
1020 ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64)
1021 .max(1),
1022 )
1023 }
1024 _ => unreachable!(),
1025 }
1026 }
1027 }
1028 };
1029
1030 let build_cost = right_cost.rows as f64 * 1.0; let probe_cost = left_cost.rows as f64 * 0.5; let cpu_cost = left_cost.cpu_cost + right_cost.cpu_cost + build_cost + probe_cost;
1034
1035 let hash_table_overhead = right_cost.gpu_mem / 2; let gpu_mem = left_cost.gpu_mem + right_cost.gpu_mem + hash_table_overhead;
1038
1039 PlanCost {
1040 rows: output_rows,
1041 cpu_cost,
1042 gpu_mem,
1043 transfers: left_cost.transfers + right_cost.transfers,
1044 }
1045 }
1046
1047 fn estimate_groupby_cost(
1049 &self,
1050 input_cost: PlanCost,
1051 key_cols: &[usize],
1052 _aggs: &[(usize, xlog_core::AggOp)],
1053 ) -> PlanCost {
1054 let estimated_groups = if key_cols.is_empty() {
1057 1 } else {
1059 ((input_cost.rows as f64).sqrt() as u64).max(1)
1061 };
1062
1063 PlanCost {
1064 rows: estimated_groups,
1065 cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.5, gpu_mem: input_cost.gpu_mem + estimated_groups * self.config.default_bytes_per_row,
1067 transfers: input_cost.transfers,
1068 }
1069 }
1070
1071 fn estimate_union_cost(&self, input_costs: Vec<PlanCost>) -> PlanCost {
1073 let total_rows: u64 = input_costs.iter().map(|c| c.rows).sum();
1074 let total_cpu: f64 = input_costs.iter().map(|c| c.cpu_cost).sum();
1075 let max_gpu: u64 = input_costs.iter().map(|c| c.gpu_mem).max().unwrap_or(0);
1076 let total_transfers: u32 = input_costs.iter().map(|c| c.transfers).sum();
1077
1078 PlanCost {
1079 rows: total_rows,
1080 cpu_cost: total_cpu + total_rows as f64 * 0.01, gpu_mem: max_gpu, transfers: total_transfers,
1083 }
1084 }
1085
1086 fn estimate_distinct_cost(&self, input_cost: PlanCost, _key_cols: &[usize]) -> PlanCost {
1088 let estimated_distinct = (input_cost.rows as f64 * 0.7) as u64;
1090
1091 PlanCost {
1092 rows: estimated_distinct.max(1),
1093 cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.3, gpu_mem: input_cost.gpu_mem + input_cost.rows * 8, transfers: input_cost.transfers,
1096 }
1097 }
1098
1099 fn estimate_diff_cost(&self, left_cost: PlanCost, right_cost: PlanCost) -> PlanCost {
1101 let estimated_remaining = (left_cost.rows as f64 * 0.5) as u64;
1103
1104 PlanCost {
1105 rows: estimated_remaining.max(1),
1106 cpu_cost: left_cost.cpu_cost + right_cost.cpu_cost + right_cost.rows as f64 * 0.5,
1107 gpu_mem: left_cost.gpu_mem + right_cost.gpu_mem,
1108 transfers: left_cost.transfers + right_cost.transfers,
1109 }
1110 }
1111
1112 fn estimate_fixpoint_cost(&self, base_cost: PlanCost, recursive_cost: PlanCost) -> PlanCost {
1114 let estimated_iterations = ((base_cost.rows as f64).log2().ceil() as u64).max(1);
1117
1118 PlanCost {
1119 rows: base_cost.rows * estimated_iterations, cpu_cost: base_cost.cpu_cost + recursive_cost.cpu_cost * estimated_iterations as f64,
1121 gpu_mem: (base_cost.gpu_mem + recursive_cost.gpu_mem) * 2, transfers: base_cost.transfers + recursive_cost.transfers * estimated_iterations as u32,
1123 }
1124 }
1125
1126 fn estimate_predicate_selectivity(&self, predicate: &Expr, input: &RirNode) -> f64 {
1128 match predicate {
1129 Expr::Compare { left, op, right } => {
1130 self.estimate_compare_selectivity(left, *op, right, input)
1131 }
1132 Expr::And(exprs) => {
1133 exprs
1135 .iter()
1136 .map(|e| self.estimate_predicate_selectivity(e, input))
1137 .product()
1138 }
1139 Expr::Or(exprs) => {
1140 exprs
1143 .iter()
1144 .map(|e| self.estimate_predicate_selectivity(e, input))
1145 .fold(0.0, f64::max)
1146 }
1147 Expr::Not(inner) => 1.0 - self.estimate_predicate_selectivity(inner, input),
1148 _ => self.config.default_filter_selectivity,
1149 }
1150 }
1151
1152 fn estimate_compare_selectivity(
1154 &self,
1155 left: &Expr,
1156 op: CompareOp,
1157 right: &Expr,
1158 input: &RirNode,
1159 ) -> f64 {
1160 if let (Expr::Column(col_idx), Expr::Const(_)) | (Expr::Const(_), Expr::Column(col_idx)) =
1162 (left, right)
1163 {
1164 if let Some(rel_id) = self.find_column_relation(input, *col_idx) {
1166 if let Some(stats) = self.stats.get_relation_stats(rel_id) {
1167 if let Some(col_stats) = stats.get_column(*col_idx) {
1168 return match op {
1169 CompareOp::Eq => col_stats.equality_selectivity(stats.cardinality),
1170 CompareOp::Ne => {
1171 1.0 - col_stats.equality_selectivity(stats.cardinality)
1172 }
1173 CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => {
1174 0.33
1176 }
1177 };
1178 }
1179 }
1180 }
1181 }
1182
1183 match op {
1185 CompareOp::Eq => 0.1, CompareOp::Ne => 0.9, CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => 0.33, }
1189 }
1190
1191 fn find_column_relation(&self, node: &RirNode, col_idx: usize) -> Option<RelId> {
1193 match node {
1194 RirNode::Scan { rel } => Some(*rel),
1195 RirNode::Filter { input, .. } => self.find_column_relation(input, col_idx),
1196 RirNode::Project { input, columns } => {
1197 if col_idx < columns.len() {
1199 if let xlog_ir::ProjectExpr::Column(src_idx) = &columns[col_idx] {
1200 return self.find_column_relation(input, *src_idx);
1201 }
1202 }
1203 None
1204 }
1205 RirNode::Join { left, right, .. } => {
1206 let left_width = self.estimate_width(left);
1207 if col_idx < left_width {
1208 self.find_column_relation(left, col_idx)
1209 } else {
1210 self.find_column_relation(right, col_idx - left_width)
1211 }
1212 }
1213 RirNode::MultiWayJoin { .. } => None,
1219 _ => None, }
1221 }
1222
    /// Returns the relations the statistics manager reports as "hot" at the
    /// configured `index_heat_threshold`.
    pub fn recommend_indexes(&self) -> Vec<RelId> {
        self.stats.hot_relations(self.config.index_heat_threshold)
    }
1230
1231 pub fn should_use_greedy(&self, node: &RirNode) -> bool {
1235 let rels = node.referenced_relations();
1236 let unique_rels: std::collections::HashSet<_> = rels.iter().collect();
1237 unique_rels.len() > self.config.dp_threshold
1238 }
1239}
1240
1241pub mod selectivity_pass {
1254 use std::collections::HashMap;
1300 use xlog_core::RelId;
1301 use xlog_ir::ExecutionPlan;
1302 use xlog_stats::StatsManager;
1303
1304 pub fn run(plan: &mut ExecutionPlan, stats: &StatsManager, rel_ids: &HashMap<String, RelId>) {
1314 let _ = rel_ids;
1320 for rules in plan.rules_by_scc.iter_mut() {
1321 for rule in rules.iter_mut() {
1322 if let Some(rewritten) = super::reorder::try_reorder_triangle(&rule.body, stats) {
1323 rule.body = rewritten;
1324 continue;
1325 }
1326 if let Some(rewritten) = super::reorder::try_reorder_4cycle(&rule.body, stats) {
1327 rule.body = rewritten;
1328 }
1329 }
1330 }
1331 }
1332}
1333
1334pub mod helper_split_pass {
1336 use std::collections::{HashMap, HashSet};
1337
1338 use xlog_core::{RelId, ScalarType, Schema};
1339 use xlog_ir::rir::{HelperSplitSpec, KCliqueVariableOrder};
1340 use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
1341 use xlog_stats::StatsManager;
1342
    // NOTE(review): no use of this constant is visible in this part of the
    // module; presumably the skew ratio above which a join is treated as
    // "heavy" when choosing helper-split candidates — confirm at its use site.
    const HEAVY_SKEW_RATIO: f64 = 10.0;
1344
    /// Describes a helper relation introduced by this pass.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct HelperRelationSpec {
        /// Name returned by the allocator; also used as the helper rule's head.
        pub name: String,
        /// Relation id returned by the allocator.
        pub rel_id: RelId,
        /// Schema of the helper relation.
        pub schema: Schema,
        /// The two relations the helper is derived from.
        pub source_rels: [RelId; 2],
    }
1357
    /// Key columns of one join in a linearized rule body.
    // NOTE(review): presumably `left_keys[i]` pairs with `right_keys[i]` — confirm
    // against the code that builds these.
    struct JoinStep {
        left_keys: Vec<usize>,
        right_keys: Vec<usize>,
    }
1362
    /// A rule body decomposed into leaf relations, join steps and a final projection.
    // NOTE(review): the builder is outside the visible part of the file; the
    // per-field notes below are inferred from names and types — confirm.
    struct LinearBody {
        // Leaf relations; indexed by `Candidate::pair_start` in `try_rewrite_rule`.
        leaves: Vec<RelId>,
        // Presumably one class id per column of each leaf.
        leaf_classes: Vec<Vec<u32>>,
        // Presumably the joins connecting consecutive leaves.
        joins: Vec<JoinStep>,
        // Projection applied on top of the joined leaves.
        project: Vec<ProjectExpr>,
        // Presumably the class ids of the columns after all joins.
        final_classes: Vec<u32>,
    }
1370
    /// A join tree flattened into its leaves plus column bookkeeping.
    // NOTE(review): not used in the visible part of the file; field meanings
    // below are inferred from names — confirm.
    struct FlatJoin {
        leaves: Vec<RelId>,
        // Presumably indices into the concatenated leaf columns.
        output_cols: Vec<usize>,
        // Presumably pairs of column indices constrained to be equal.
        equalities: Vec<(usize, usize)>,
    }
1376
    /// A chosen pair of adjacent leaves to materialize as a helper relation.
    struct Candidate {
        // Index of the first leaf of the pair; the second is `pair_start + 1`.
        pair_start: usize,
        // Schema passed to the allocator for the new helper relation.
        helper_schema: Schema,
        // NOTE(review): the remaining fields are consumed by builders outside
        // the visible part of the file; meanings are inferred from names.
        helper_project: Vec<ProjectExpr>,
        helper_join_left_keys: Vec<usize>,
        helper_join_right_keys: Vec<usize>,
        exposed_classes: Vec<u32>,
    }
1385
    /// Result of successfully rewriting one rule.
    struct Rewrite {
        // Body of the newly inserted helper rule.
        helper_body: RirNode,
        // Replacement body for the original rule.
        outer_body: RirNode,
        // Description of the helper relation, reported back to the caller.
        spec: HelperRelationSpec,
    }
1391
    /// One edge input of a k-clique multi-way join.
    #[derive(Clone, Copy)]
    struct KCliqueHelperEdge {
        // Position of the edge in the multi-way join's `inputs`.
        slot: usize,
        // Relation scanned at that position.
        rel: RelId,
        // The two values returned by `kclique_edge_pair` for this slot; they
        // are compared against the hot variable in `kclique_helper_edges`.
        left: usize,
        right: usize,
    }
1399
1400 pub fn run<F>(
1402 plan: &mut ExecutionPlan,
1403 schemas: &HashMap<RelId, Schema>,
1404 stats: &StatsManager,
1405 mut allocate: F,
1406 ) -> Vec<HelperRelationSpec>
1407 where
1408 F: FnMut(Schema) -> (String, RelId),
1409 {
1410 let mut specs = Vec::new();
1411 for scc_idx in 0..plan.rules_by_scc.len() {
1412 let mut rule_idx = 0;
1413 while rule_idx < plan.rules_by_scc[scc_idx].len() {
1414 let rewrite = {
1415 let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1416 try_rewrite_rule(rule, schemas, stats, &mut allocate)
1417 };
1418 if let Some(rewrite) = rewrite {
1419 let helper_rule = CompiledRule {
1420 head: rewrite.spec.name.clone(),
1421 body: rewrite.helper_body,
1422 meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1423 };
1424 plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1425 rule_idx += 1;
1426 plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1427 add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1428 specs.push(rewrite.spec);
1429 }
1430 rule_idx += 1;
1431 }
1432 }
1433 specs
1434 }
1435
1436 pub fn run_kclique_specs<F>(
1442 plan: &mut ExecutionPlan,
1443 schemas: &HashMap<RelId, Schema>,
1444 mut allocate: F,
1445 ) -> Vec<HelperRelationSpec>
1446 where
1447 F: FnMut(Schema) -> (String, RelId),
1448 {
1449 let mut specs = Vec::new();
1450 for scc_idx in 0..plan.rules_by_scc.len() {
1451 let mut rule_idx = 0;
1452 while rule_idx < plan.rules_by_scc[scc_idx].len() {
1453 let rewrite = {
1454 let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1455 try_rewrite_kclique_rule(rule, schemas, &mut allocate)
1456 };
1457 if let Some(rewrite) = rewrite {
1458 let helper_rule = CompiledRule {
1459 head: rewrite.spec.name.clone(),
1460 body: rewrite.helper_body,
1461 meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1462 };
1463 plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1464 rule_idx += 1;
1465 plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1466 add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1467 specs.push(rewrite.spec);
1468 }
1469 rule_idx += 1;
1470 }
1471 }
1472 specs
1473 }
1474
1475 fn add_helper_to_scc(sccs: &mut [Scc], scc_idx: usize, helper: &str) {
1476 if let Some(scc) = sccs.get_mut(scc_idx) {
1477 if !scc.predicates.iter().any(|p| p == helper) {
1478 scc.predicates.push(helper.to_string());
1479 }
1480 }
1481 }
1482
1483 fn try_rewrite_rule<F>(
1484 rule: &CompiledRule,
1485 schemas: &HashMap<RelId, Schema>,
1486 stats: &StatsManager,
1487 allocate: &mut F,
1488 ) -> Option<Rewrite>
1489 where
1490 F: FnMut(Schema) -> (String, RelId),
1491 {
1492 let linear = linearize_project_body(&rule.body, schemas)?;
1493 let candidate = choose_candidate(&linear, schemas, stats)?;
1494 let (helper_name, helper_rel) = allocate(candidate.helper_schema.clone());
1495 let helper_body = build_helper_body(&linear, &candidate);
1496 let outer_body = build_outer_body(&linear, &candidate, helper_rel)?;
1497 Some(Rewrite {
1498 helper_body,
1499 outer_body,
1500 spec: HelperRelationSpec {
1501 name: helper_name,
1502 rel_id: helper_rel,
1503 schema: candidate.helper_schema,
1504 source_rels: [
1505 linear.leaves[candidate.pair_start],
1506 linear.leaves[candidate.pair_start + 1],
1507 ],
1508 },
1509 })
1510 }
1511
1512 fn try_rewrite_kclique_rule<F>(
1513 rule: &CompiledRule,
1514 schemas: &HashMap<RelId, Schema>,
1515 allocate: &mut F,
1516 ) -> Option<Rewrite>
1517 where
1518 F: FnMut(Schema) -> (String, RelId),
1519 {
1520 let mut outer_body = rule.body.clone();
1521 let RirNode::MultiWayJoin {
1522 inputs, var_order, ..
1523 } = &mut outer_body
1524 else {
1525 return None;
1526 };
1527 let kclique = var_order.as_ref()?.kclique.as_ref()?;
1528 let spec = kclique.helper_split_specs.first()?;
1529 let (hot_left, hot_right, target) = kclique_helper_edges(inputs, kclique, spec)?;
1530 let helper_schema = schemas.get(&target.rel)?.clone();
1531 let (helper_name, helper_rel) = allocate(helper_schema.clone());
1532 let helper_body = build_kclique_helper_body(spec, hot_left, hot_right, target)?;
1533 *inputs.get_mut(target.slot)? = RirNode::Scan { rel: helper_rel };
1534 Some(Rewrite {
1535 helper_body,
1536 outer_body,
1537 spec: HelperRelationSpec {
1538 name: helper_name,
1539 rel_id: helper_rel,
1540 schema: helper_schema,
1541 source_rels: [hot_left.rel, hot_right.rel],
1542 },
1543 })
1544 }
1545
1546 fn kclique_helper_edges(
1547 inputs: &[RirNode],
1548 kclique: &KCliqueVariableOrder,
1549 spec: &HelperSplitSpec,
1550 ) -> Option<(KCliqueHelperEdge, KCliqueHelperEdge, KCliqueHelperEdge)> {
1551 let k = usize::from(kclique.k);
1552 let hot = usize::from(spec.variable);
1553 let mut hot_edges = Vec::new();
1554 let mut target = None;
1555 for &slot in &spec.edge_slots {
1556 let slot = usize::from(slot);
1557 let (left, right) = kclique_edge_pair(slot, k)?;
1558 let RirNode::Scan { rel } = inputs.get(slot)? else {
1559 return None;
1560 };
1561 let edge = KCliqueHelperEdge {
1562 slot,
1563 rel: *rel,
1564 left,
1565 right,
1566 };
1567 if left == hot || right == hot {
1568 hot_edges.push(edge);
1569 } else {
1570 target = Some(edge);
1571 }
1572 }
1573 if hot_edges.len() != 2 {
1574 return None;
1575 }
1576 Some((hot_edges[0], hot_edges[1], target?))
1577 }
1578
1579 fn build_kclique_helper_body(
1580 spec: &HelperSplitSpec,
1581 hot_left: KCliqueHelperEdge,
1582 hot_right: KCliqueHelperEdge,
1583 target: KCliqueHelperEdge,
1584 ) -> Option<RirNode> {
1585 let hot = usize::from(spec.variable);
1586 let target_left = target.left;
1587 let target_right = target.right;
1588 let first_other = kclique_other_endpoint(hot_left, hot)?;
1589 let second_other = kclique_other_endpoint(hot_right, hot)?;
1590 if ![first_other, second_other].contains(&target_left)
1591 || ![first_other, second_other].contains(&target_right)
1592 {
1593 return None;
1594 }
1595
1596 let first_scan = RirNode::Scan { rel: hot_left.rel };
1597 let second_scan = RirNode::Scan { rel: hot_right.rel };
1598 let target_scan = RirNode::Scan { rel: target.rel };
1599 let first_hot_col = kclique_endpoint_col(hot_left, hot)?;
1600 let second_hot_col = kclique_endpoint_col(hot_right, hot)?;
1601 let first_other_col = kclique_endpoint_col(hot_left, first_other)?;
1602 let second_other_col = 2 + kclique_endpoint_col(hot_right, second_other)?;
1603
1604 let target_left_in_join = if first_other == target_left {
1605 first_other_col
1606 } else {
1607 second_other_col
1608 };
1609 let target_right_in_join = if first_other == target_right {
1610 first_other_col
1611 } else {
1612 second_other_col
1613 };
1614 let target_left_col = kclique_endpoint_col(target, target_left)?;
1615 let target_right_col = kclique_endpoint_col(target, target_right)?;
1616
1617 let hot_join = RirNode::Join {
1618 left: Box::new(first_scan),
1619 right: Box::new(second_scan),
1620 left_keys: vec![first_hot_col],
1621 right_keys: vec![second_hot_col],
1622 join_type: JoinType::Inner,
1623 };
1624 let helper_join = RirNode::Join {
1625 left: Box::new(hot_join),
1626 right: Box::new(target_scan),
1627 left_keys: vec![target_left_in_join, target_right_in_join],
1628 right_keys: vec![target_left_col, target_right_col],
1629 join_type: JoinType::Inner,
1630 };
1631 Some(RirNode::Project {
1632 input: Box::new(helper_join),
1633 columns: vec![ProjectExpr::Column(4), ProjectExpr::Column(5)],
1634 })
1635 }
1636
/// Decodes an edge index into its `(left, right)` vertex pair.
///
/// Edges of a `k`-clique are numbered in lexicographic order over all pairs
/// with `left < right < k`: `(0, 1), (0, 2), ..., (1, 2), ...`. Returns `None`
/// when `edge_idx` is not a valid edge number for `k` vertices.
fn kclique_edge_pair(edge_idx: usize, k: usize) -> Option<(usize, usize)> {
    (0..k)
        .flat_map(|left| ((left + 1)..k).map(move |right| (left, right)))
        .nth(edge_idx)
}
1649
1650 fn kclique_endpoint_col(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1651 if edge.left == variable {
1652 Some(0)
1653 } else if edge.right == variable {
1654 Some(1)
1655 } else {
1656 None
1657 }
1658 }
1659
1660 fn kclique_other_endpoint(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1661 if edge.left == variable {
1662 Some(edge.right)
1663 } else if edge.right == variable {
1664 Some(edge.left)
1665 } else {
1666 None
1667 }
1668 }
1669
1670 fn linearize_project_body(
1671 body: &RirNode,
1672 schemas: &HashMap<RelId, Schema>,
1673 ) -> Option<LinearBody> {
1674 let RirNode::Project { input, columns } = body else {
1675 return None;
1676 };
1677 let flat = collect_join_graph(input, schemas)?;
1678 if flat.leaves.len() < 6 {
1679 return None;
1680 }
1681 let mut offsets = Vec::with_capacity(flat.leaves.len());
1682 let mut total_cols = 0usize;
1683 for rel in &flat.leaves {
1684 offsets.push(total_cols);
1685 total_cols += schemas.get(rel)?.arity();
1686 }
1687 let mut uf = UnionFind::new(total_cols);
1688 for (left, right) in flat.equalities {
1689 if left >= total_cols || right >= total_cols {
1690 return None;
1691 }
1692 uf.union(left, right);
1693 }
1694 let mut leaf_classes: Vec<Vec<u32>> = Vec::with_capacity(flat.leaves.len());
1695 for (leaf_idx, rel) in flat.leaves.iter().enumerate() {
1696 let arity = schemas.get(rel)?.arity();
1697 let offset = offsets[leaf_idx];
1698 leaf_classes.push((0..arity).map(|col| uf.find(offset + col) as u32).collect());
1699 }
1700 let final_classes = flat
1701 .output_cols
1702 .iter()
1703 .map(|col| uf.find(*col) as u32)
1704 .collect();
1705 let joins = derive_left_deep_steps(&leaf_classes)?;
1706 Some(LinearBody {
1707 leaves: flat.leaves,
1708 leaf_classes,
1709 joins,
1710 project: columns.clone(),
1711 final_classes,
1712 })
1713 }
1714
1715 fn collect_join_graph(node: &RirNode, schemas: &HashMap<RelId, Schema>) -> Option<FlatJoin> {
1716 match node {
1717 RirNode::Scan { rel } => Some(FlatJoin {
1718 leaves: vec![*rel],
1719 output_cols: (0..schemas.get(rel)?.arity()).collect(),
1720 equalities: Vec::new(),
1721 }),
1722 RirNode::Join {
1723 left,
1724 right,
1725 left_keys,
1726 right_keys,
1727 join_type,
1728 } if *join_type == JoinType::Inner => {
1729 let left_flat = collect_join_graph(left, schemas)?;
1730 let right_flat = collect_join_graph(right, schemas)?;
1731 if left_keys.len() != right_keys.len() {
1732 return None;
1733 }
1734 let right_shift = total_width(&left_flat.leaves, schemas)?;
1735 let mut leaves = left_flat.leaves;
1736 leaves.extend(right_flat.leaves);
1737 let right_output_cols: Vec<usize> = right_flat
1738 .output_cols
1739 .iter()
1740 .map(|col| col + right_shift)
1741 .collect();
1742 let mut equalities = left_flat.equalities;
1743 equalities.extend(
1744 right_flat
1745 .equalities
1746 .iter()
1747 .map(|(left, right)| (left + right_shift, right + right_shift)),
1748 );
1749 for (&left_key, &right_key) in left_keys.iter().zip(right_keys.iter()) {
1750 equalities.push((
1751 *left_flat.output_cols.get(left_key)?,
1752 *right_output_cols.get(right_key)?,
1753 ));
1754 }
1755 let mut output_cols = left_flat.output_cols;
1756 output_cols.extend(right_output_cols);
1757 Some(FlatJoin {
1758 leaves,
1759 output_cols,
1760 equalities,
1761 })
1762 }
1763 _ => None,
1764 }
1765 }
1766
1767 fn total_width(leaves: &[RelId], schemas: &HashMap<RelId, Schema>) -> Option<usize> {
1768 leaves
1769 .iter()
1770 .map(|rel| schemas.get(rel).map(Schema::arity))
1771 .try_fold(0usize, |acc, width| width.map(|width| acc + width))
1772 }
1773
1774 fn derive_left_deep_steps(leaf_classes: &[Vec<u32>]) -> Option<Vec<JoinStep>> {
1775 let mut joins = Vec::with_capacity(leaf_classes.len().saturating_sub(1));
1776 let mut current = leaf_classes.first()?.clone();
1777 for classes in leaf_classes.iter().skip(1) {
1778 let mut left_keys = Vec::new();
1779 let mut right_keys = Vec::new();
1780 for (right_col, class) in classes.iter().enumerate() {
1781 if let Some(left_col) = current
1782 .iter()
1783 .position(|current_class| current_class == class)
1784 {
1785 left_keys.push(left_col);
1786 right_keys.push(right_col);
1787 }
1788 }
1789 if left_keys.is_empty() {
1790 return None;
1791 }
1792 joins.push(JoinStep {
1793 left_keys,
1794 right_keys,
1795 });
1796 current.extend(classes.iter().copied());
1797 }
1798 Some(joins)
1799 }
1800
1801 fn choose_candidate(
1802 linear: &LinearBody,
1803 schemas: &HashMap<RelId, Schema>,
1804 stats: &StatsManager,
1805 ) -> Option<Candidate> {
1806 for pair_start in 3..linear.leaves.len().saturating_sub(1) {
1807 let candidate = build_candidate(linear, schemas, pair_start)?;
1808 if skew_ratio_for_candidate(linear, stats, &candidate) >= HEAVY_SKEW_RATIO {
1809 return Some(candidate);
1810 }
1811 }
1812 None
1813 }
1814
1815 fn build_candidate(
1816 linear: &LinearBody,
1817 schemas: &HashMap<RelId, Schema>,
1818 pair_start: usize,
1819 ) -> Option<Candidate> {
1820 let left_rel = linear.leaves[pair_start];
1821 let right_rel = linear.leaves[pair_start + 1];
1822 let left_schema = schemas.get(&left_rel)?;
1823 let right_schema = schemas.get(&right_rel)?;
1824 let internal_step = linear.joins.get(pair_start)?;
1825 let mut helper_left_keys = Vec::new();
1826 let mut helper_right_keys = Vec::new();
1827 for (&left_key, &right_key) in internal_step
1828 .left_keys
1829 .iter()
1830 .zip(internal_step.right_keys.iter())
1831 {
1832 let class = class_at_state(linear, pair_start + 1, left_key)?;
1833 let left_col = linear.leaf_classes[pair_start]
1834 .iter()
1835 .position(|c| *c == class)?;
1836 helper_left_keys.push(left_col);
1837 helper_right_keys.push(right_key);
1838 }
1839 let internal: HashSet<u32> = helper_left_keys
1840 .iter()
1841 .map(|col| linear.leaf_classes[pair_start][*col])
1842 .collect();
1843 let outside = outside_classes(linear, pair_start);
1844 let output = projected_classes(linear)?;
1845 let mut exposed_classes = Vec::new();
1846 let mut helper_project = Vec::new();
1847 let mut helper_columns = Vec::new();
1848 for (col, class) in linear.leaf_classes[pair_start].iter().copied().enumerate() {
1849 if !internal.contains(&class)
1850 && (outside.contains(&class) || output.contains(&class))
1851 && !exposed_classes.contains(&class)
1852 {
1853 exposed_classes.push(class);
1854 helper_project.push(ProjectExpr::Column(col));
1855 let ty = left_schema.column_type(col).unwrap_or(ScalarType::U32);
1856 helper_columns.push((format!("c{}", helper_columns.len()), ty));
1857 }
1858 }
1859 let right_offset = left_schema.arity();
1860 for (col, class) in linear.leaf_classes[pair_start + 1]
1861 .iter()
1862 .copied()
1863 .enumerate()
1864 {
1865 if !internal.contains(&class)
1866 && (outside.contains(&class) || output.contains(&class))
1867 && !exposed_classes.contains(&class)
1868 {
1869 exposed_classes.push(class);
1870 helper_project.push(ProjectExpr::Column(right_offset + col));
1871 let ty = right_schema.column_type(col).unwrap_or(ScalarType::U32);
1872 helper_columns.push((format!("c{}", helper_columns.len()), ty));
1873 }
1874 }
1875 if exposed_classes.len() != 2 {
1876 return None;
1877 }
1878 Some(Candidate {
1879 pair_start,
1880 helper_schema: Schema::new(helper_columns),
1881 helper_project,
1882 helper_join_left_keys: helper_left_keys,
1883 helper_join_right_keys: helper_right_keys,
1884 exposed_classes,
1885 })
1886 }
1887
1888 fn class_at_state(linear: &LinearBody, leaf_count: usize, col: usize) -> Option<u32> {
1889 let mut idx = col;
1890 for leaf_idx in 0..leaf_count {
1891 let classes = &linear.leaf_classes[leaf_idx];
1892 if idx < classes.len() {
1893 return Some(classes[idx]);
1894 }
1895 idx -= classes.len();
1896 }
1897 None
1898 }
1899
1900 fn outside_classes(linear: &LinearBody, pair_start: usize) -> HashSet<u32> {
1901 linear
1902 .leaf_classes
1903 .iter()
1904 .enumerate()
1905 .filter(|(idx, _)| *idx != pair_start && *idx != pair_start + 1)
1906 .flat_map(|(_, classes)| classes.iter().copied())
1907 .collect()
1908 }
1909
1910 fn projected_classes(linear: &LinearBody) -> Option<HashSet<u32>> {
1911 let mut out = HashSet::new();
1912 for expr in &linear.project {
1913 let ProjectExpr::Column(col) = expr else {
1914 return None;
1915 };
1916 out.insert(*linear.final_classes.get(*col)?);
1917 }
1918 Some(out)
1919 }
1920
1921 fn skew_ratio_for_candidate(
1922 linear: &LinearBody,
1923 stats: &StatsManager,
1924 candidate: &Candidate,
1925 ) -> f64 {
1926 let rel = linear.leaves[candidate.pair_start];
1927 let Some(rel_stats) = stats.get_relation_stats(rel) else {
1928 return 0.0;
1929 };
1930 let mut ratio: f64 = 0.0;
1931 for (col, class) in linear.leaf_classes[candidate.pair_start]
1932 .iter()
1933 .copied()
1934 .enumerate()
1935 {
1936 if !candidate.exposed_classes.contains(&class) {
1937 continue;
1938 }
1939 let Some(col_stats) = rel_stats.get_column(col) else {
1940 continue;
1941 };
1942 if col_stats.distinct_estimate == 0 {
1943 continue;
1944 }
1945 ratio = ratio.max(rel_stats.cardinality as f64 / col_stats.distinct_estimate as f64);
1946 }
1947 ratio
1948 }
1949
1950 fn build_helper_body(linear: &LinearBody, candidate: &Candidate) -> RirNode {
1951 let left = RirNode::Scan {
1952 rel: linear.leaves[candidate.pair_start],
1953 };
1954 let right = RirNode::Scan {
1955 rel: linear.leaves[candidate.pair_start + 1],
1956 };
1957 RirNode::Project {
1958 input: Box::new(RirNode::Join {
1959 left: Box::new(left),
1960 right: Box::new(right),
1961 left_keys: candidate.helper_join_left_keys.clone(),
1962 right_keys: candidate.helper_join_right_keys.clone(),
1963 join_type: JoinType::Inner,
1964 }),
1965 columns: candidate.helper_project.clone(),
1966 }
1967 }
1968
1969 fn build_outer_body(
1970 linear: &LinearBody,
1971 candidate: &Candidate,
1972 helper_rel: RelId,
1973 ) -> Option<RirNode> {
1974 let mut node = RirNode::Scan {
1975 rel: linear.leaves[0],
1976 };
1977 let mut classes = linear.leaf_classes[0].clone();
1978 for leaf_idx in 1..candidate.pair_start {
1979 let step = &linear.joins[leaf_idx - 1];
1980 node = RirNode::Join {
1981 left: Box::new(node),
1982 right: Box::new(RirNode::Scan {
1983 rel: linear.leaves[leaf_idx],
1984 }),
1985 left_keys: step.left_keys.clone(),
1986 right_keys: step.right_keys.clone(),
1987 join_type: JoinType::Inner,
1988 };
1989 classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
1990 }
1991 let prefix_step = &linear.joins[candidate.pair_start - 1];
1992 let mut helper_right_keys = Vec::new();
1993 for &rk in &prefix_step.right_keys {
1994 let class = linear.leaf_classes[candidate.pair_start][rk];
1995 helper_right_keys.push(candidate.exposed_classes.iter().position(|c| *c == class)?);
1996 }
1997 node = RirNode::Join {
1998 left: Box::new(node),
1999 right: Box::new(RirNode::Scan { rel: helper_rel }),
2000 left_keys: prefix_step.left_keys.clone(),
2001 right_keys: helper_right_keys,
2002 join_type: JoinType::Inner,
2003 };
2004 classes.extend(candidate.exposed_classes.iter().copied());
2005 for leaf_idx in candidate.pair_start + 2..linear.leaves.len() {
2006 let step = &linear.joins[leaf_idx - 1];
2007 let mut left_keys = Vec::new();
2008 for &lk in &step.left_keys {
2009 let class = class_at_state(linear, leaf_idx, lk)?;
2010 left_keys.push(classes.iter().position(|c| *c == class)?);
2011 }
2012 node = RirNode::Join {
2013 left: Box::new(node),
2014 right: Box::new(RirNode::Scan {
2015 rel: linear.leaves[leaf_idx],
2016 }),
2017 left_keys,
2018 right_keys: step.right_keys.clone(),
2019 join_type: JoinType::Inner,
2020 };
2021 classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
2022 }
2023 let mut project = Vec::with_capacity(linear.project.len());
2024 for expr in &linear.project {
2025 let ProjectExpr::Column(col) = expr else {
2026 return None;
2027 };
2028 let class = *linear.final_classes.get(*col)?;
2029 let mapped = classes.iter().position(|c| *c == class)?;
2030 project.push(ProjectExpr::Column(mapped));
2031 }
2032 Some(RirNode::Project {
2033 input: Box::new(node),
2034 columns: project,
2035 })
2036 }
2037
/// Disjoint-set forest over the indices `0..len`, with path compression.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `len` singleton sets.
    fn new(len: usize) -> Self {
        let parent = (0..len).collect();
        Self { parent }
    }

    /// Returns the representative of `x`'s set and re-parents every node on
    /// the walked path directly to that representative.
    ///
    /// Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        let mut node = x;
        while node != root {
            node = std::mem::replace(&mut self.parent[node], root);
        }
        root
    }

    /// Merges the sets of `a` and `b`; the representative of `a`'s set
    /// becomes the representative of the merged set.
    fn union(&mut self, a: usize, b: usize) {
        let root_a = self.find(a);
        let root_b = self.find(b);
        if root_a == root_b {
            return;
        }
        self.parent[root_b] = root_a;
    }
}
2068}
2069
2070#[path = "optimizer/stream_schedule_pass.rs"]
2071pub mod stream_schedule_pass;
2072
#[cfg(test)]
mod helper_split_pass_tests {
    use std::collections::HashMap;

    use super::helper_split_pass;
    use xlog_core::{RelId, ScalarType, Schema};
    use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
    use xlog_stats::{ColumnStats, StatsManager};

    /// Name handed out by the stub allocator used in every test.
    const HELPER_NAME: &str = "__kclique_helper_6";

    /// Schema whose columns carry the given names and are all `U32`.
    fn u32_schema(names: &[&str]) -> Schema {
        Schema::new(
            names
                .iter()
                .map(|name| (name.to_string(), ScalarType::U32))
                .collect::<Vec<_>>(),
        )
    }

    /// Two-column schema shared by all six input relations.
    fn edge_schema() -> Schema {
        u32_schema(&["c0", "c1"])
    }

    /// Schema the pass is expected to give the helper relation.
    fn helper_schema() -> Schema {
        u32_schema(&["c0", "c1"])
    }

    /// Schemas for relations 0 through 5.
    fn schemas() -> HashMap<RelId, Schema> {
        let mut by_rel = HashMap::new();
        for idx in 0..6 {
            by_rel.insert(RelId(idx), edge_schema());
        }
        by_rel
    }

    /// Left-deep chain over relations 0..=5: five single-key joins in a row,
    /// closed by a two-key join with relation 5, then a five-column projection.
    fn left_deep_fixture_body() -> RirNode {
        // (relation joined on the right, left keys, right keys)
        let steps = vec![
            (1, vec![1], vec![0]),
            (2, vec![3], vec![0]),
            (3, vec![5], vec![0]),
            (4, vec![7], vec![0]),
            (5, vec![0, 9], vec![0, 1]),
        ];
        let chain = steps.into_iter().fold(
            RirNode::Scan { rel: RelId(0) },
            |left, (rel, left_keys, right_keys)| RirNode::Join {
                left: Box::new(left),
                right: Box::new(RirNode::Scan { rel: RelId(rel) }),
                left_keys,
                right_keys,
                join_type: JoinType::Inner,
            },
        );
        RirNode::Project {
            input: Box::new(chain),
            columns: vec![0, 1, 3, 5, 9]
                .into_iter()
                .map(ProjectExpr::Column)
                .collect(),
        }
    }

    /// Single-SCC plan containing only the `out` rule over the fixture body.
    fn plan() -> ExecutionPlan {
        let out_rule = CompiledRule {
            head: "out".to_string(),
            body: left_deep_fixture_body(),
            meta: RirMeta::with_schema(u32_schema(&["a", "b", "c", "d", "f"])),
        };
        let out_scc = Scc {
            id: 0,
            predicates: vec!["out".to_string()],
            is_recursive: false,
        };
        ExecutionPlan {
            sccs: vec![out_scc],
            strata: vec![],
            rules_by_scc: vec![vec![out_rule]],
            est_memory_peak: 0,
            rel_arities: HashMap::new(),
        }
    }

    /// Statistics with 8192 rows per relation and `distinct_d` distinct values
    /// in column 0 of relation 3.
    fn stats_for_de(distinct_d: u64) -> StatsManager {
        let mut stats = StatsManager::new();
        for idx in 0..6 {
            let rel = RelId(idx);
            stats.register_relation(rel);
            stats.update_cardinality(rel, 8192);
        }
        let mut first_col = ColumnStats::new(0, ScalarType::U32);
        first_col.update_distinct(distinct_d);
        stats.add_column_stats(RelId(3), first_col);
        stats
    }

    /// True when `rel` is read anywhere inside `node`.
    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
        match node {
            RirNode::Unit => false,
            RirNode::Scan { rel: scanned } => *scanned == rel,
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
            }
            RirNode::Project { input, .. }
            | RirNode::Filter { input, .. }
            | RirNode::Distinct { input, .. }
            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
                contains_scan(left, rel) || contains_scan(right, rel)
            }
            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
            RirNode::Fixpoint {
                base, recursive, ..
            } => contains_scan(base, rel) || contains_scan(recursive, rel),
            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
            RirNode::MultiWayJoin { inputs, .. } => {
                inputs.iter().any(|input| contains_scan(input, rel))
            }
        }
    }

    #[test]
    fn helper_split_extracts_buried_pair() {
        let mut plan = plan();
        let schemas = schemas();
        let skewed = stats_for_de(1);
        let specs = helper_split_pass::run(&mut plan, &schemas, &skewed, |_| {
            (HELPER_NAME.to_string(), RelId(6))
        });

        // Exactly one helper, described as allocated.
        assert_eq!(specs.len(), 1);
        let spec = &specs[0];
        assert_eq!(spec.name, "__kclique_helper_6");
        assert_eq!(spec.rel_id, RelId(6));
        assert_eq!(spec.schema, helper_schema());
        assert_eq!(spec.source_rels, [RelId(3), RelId(4)]);

        // The helper rule precedes the rewritten `out` rule.
        let rules = &plan.rules_by_scc[0];
        assert_eq!(rules.len(), 2);
        assert_eq!(rules[0].head, "__kclique_helper_6");
        assert_eq!(rules[1].head, "out");
        assert!(contains_scan(&rules[1].body, RelId(6)));

        // The helper is registered in the SCC.
        assert!(plan.sccs[0]
            .predicates
            .iter()
            .any(|predicate| predicate == "__kclique_helper_6"));
    }

    #[test]
    fn helper_split_ignores_flat_distribution() {
        let mut plan = plan();
        let schemas = schemas();
        let flat = stats_for_de(8192);
        let specs = helper_split_pass::run(&mut plan, &schemas, &flat, |_| {
            (HELPER_NAME.to_string(), RelId(6))
        });

        assert!(specs.is_empty());
        let rules = &plan.rules_by_scc[0];
        assert_eq!(rules.len(), 1);
        assert!(!contains_scan(&rules[0].body, RelId(6)));
    }
}
2249
2250mod reorder {
2253 use std::collections::HashMap;
2254 use xlog_core::RelId;
2255 use xlog_ir::rir::ProjectExpr;
2256 use xlog_ir::{JoinType, RirNode};
2257 use xlog_stats::StatsManager;
2258
/// Flat slot of column `col` of atom `atom` in the triangle layout, where
/// every atom owns two consecutive slots.
fn ac3(atom: u8, col: u8) -> u8 {
    let base = atom * 2;
    base + col
}
/// Flat slot of column `col` of atom `atom` in the 4-cycle layout, where
/// every atom owns two consecutive slots.
fn ac4(atom: u8, col: u8) -> u8 {
    let base = atom * 2;
    base + col
}
/// Union-find lookup on a fixed-size parent array.
///
/// Returns the root of `x` and re-parents every node on the walked path
/// directly to that root. Panics if an index falls outside the array.
fn uf_find_n<const N: usize>(parent: &mut [u8; N], x: u8) -> u8 {
    // First pass: climb to the root.
    let mut root = x;
    loop {
        let up = parent[usize::from(root)];
        if up == root {
            break;
        }
        root = up;
    }
    // Second pass: point everything on the path at the root.
    let mut node = x;
    while node != root {
        node = std::mem::replace(&mut parent[usize::from(node)], root);
    }
    root
}
2278 fn uf_union_n<const N: usize>(parent: &mut [u8; N], a: u8, b: u8) {
2279 let ra = uf_find_n(parent, a);
2280 let rb = uf_find_n(parent, b);
2281 if ra != rb {
2282 parent[rb as usize] = ra;
2283 }
2284 }
2285
2286 fn populated_card(stats: &StatsManager, rel: RelId) -> Option<u64> {
2287 stats
2288 .get_relation_stats(rel)
2289 .map(|s| s.cardinality)
2290 .filter(|c| *c > 0)
2291 }
2292
/// Relations of a matched triangle body, labelled by the two head variables
/// each relation touches.
///
/// `x`, `y` and `z` stand for the column classes of the first, second and
/// third projected column of the matched body.
struct TriangleSemantics {
    /// Relation whose columns fall in the `x` and `y` classes.
    rel_xy: RelId,
    /// Relation whose columns fall in the `y` and `z` classes.
    rel_yz: RelId,
    /// Relation whose columns fall in the `x` and `z` classes.
    rel_xz: RelId,
}
2302
2303 fn match_and_infer_triangle(body: &RirNode) -> Option<TriangleSemantics> {
2304 let RirNode::Project {
2305 input: outer_input,
2306 columns,
2307 } = body
2308 else {
2309 return None;
2310 };
2311 let RirNode::Join {
2312 left: l1,
2313 right: r1,
2314 left_keys: lk1,
2315 right_keys: rk1,
2316 join_type: jt1,
2317 } = outer_input.as_ref()
2318 else {
2319 return None;
2320 };
2321 if !matches!(jt1, JoinType::Inner) {
2322 return None;
2323 }
2324 let RirNode::Scan { rel: rel_third } = r1.as_ref() else {
2325 return None;
2326 };
2327 let RirNode::Join {
2328 left: l2,
2329 right: r2,
2330 left_keys: lk2,
2331 right_keys: rk2,
2332 join_type: jt2,
2333 } = l1.as_ref()
2334 else {
2335 return None;
2336 };
2337 if !matches!(jt2, JoinType::Inner) {
2338 return None;
2339 }
2340 let RirNode::Scan { rel: rel_inner_l } = l2.as_ref() else {
2341 return None;
2342 };
2343 let RirNode::Scan { rel: rel_inner_r } = r2.as_ref() else {
2344 return None;
2345 };
2346 if lk2.len() != 1 || rk2.len() != 1 || lk1.len() != 2 || rk1.len() != 2 {
2347 return None;
2348 }
2349 if columns.len() != 3 {
2350 return None;
2351 }
2352 if lk2[0] >= 2 || rk2[0] >= 2 {
2353 return None;
2354 }
2355 if lk1.iter().any(|k| *k >= 4) || rk1.iter().any(|k| *k >= 2) {
2356 return None;
2357 }
2358
2359 let mut parent = [0u8, 1, 2, 3, 4, 5];
2360 uf_union_n::<6>(&mut parent, ac3(0, lk2[0] as u8), ac3(1, rk2[0] as u8));
2361 for i in 0..2 {
2362 let inner_ac = match lk1[i] {
2363 0 => (0u8, 0u8),
2364 1 => (0, 1),
2365 2 => (1, 0),
2366 3 => (1, 1),
2367 _ => return None,
2368 };
2369 uf_union_n::<6>(
2370 &mut parent,
2371 ac3(inner_ac.0, inner_ac.1),
2372 ac3(2, rk1[i] as u8),
2373 );
2374 }
2375 let roots: [u8; 6] = std::array::from_fn(|i| uf_find_n::<6>(&mut parent, i as u8));
2376 let mut counts: HashMap<u8, u8> = HashMap::new();
2377 for r in &roots {
2378 *counts.entry(*r).or_insert(0) += 1;
2379 }
2380 if counts.len() != 3 || counts.values().any(|c| *c != 2) {
2381 return None;
2382 }
2383 let mut head_classes: [u8; 3] = [0; 3];
2384 for (i, pc) in columns.iter().enumerate() {
2385 let ProjectExpr::Column(k) = pc else {
2386 return None;
2387 };
2388 let outer_ac = match *k {
2389 0 => (0u8, 0u8),
2390 1 => (0, 1),
2391 2 => (1, 0),
2392 3 => (1, 1),
2393 4 => (2, 0),
2394 5 => (2, 1),
2395 _ => return None,
2396 };
2397 head_classes[i] = uf_find_n::<6>(&mut parent, ac3(outer_ac.0, outer_ac.1));
2398 }
2399 if head_classes[0] == head_classes[1]
2400 || head_classes[0] == head_classes[2]
2401 || head_classes[1] == head_classes[2]
2402 {
2403 return None;
2404 }
2405 let x_class = head_classes[0];
2406 let y_class = head_classes[1];
2407 let z_class = head_classes[2];
2408 let atom_classes = |a: u8| (roots[ac3(a, 0) as usize], roots[ac3(a, 1) as usize]);
2409 let atom_rels = [*rel_inner_l, *rel_inner_r, *rel_third];
2410 let mut rel_xy = None;
2411 let mut rel_yz = None;
2412 let mut rel_xz = None;
2413 for atom_idx in 0..3u8 {
2414 let (c0, c1) = atom_classes(atom_idx);
2415 let bx = c0 == x_class || c1 == x_class;
2416 let by = c0 == y_class || c1 == y_class;
2417 let bz = c0 == z_class || c1 == z_class;
2418 match (bx, by, bz) {
2419 (true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2420 (false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2421 (true, false, true) => rel_xz = Some(atom_rels[atom_idx as usize]),
2422 _ => return None,
2423 }
2424 }
2425 Some(TriangleSemantics {
2426 rel_xy: rel_xy?,
2427 rel_yz: rel_yz?,
2428 rel_xz: rel_xz?,
2429 })
2430 }
2431
/// Which two triangle relations are joined first, named after the variable
/// that pair shares.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[allow(clippy::enum_variant_names)]
enum TriangleInnerPair {
    /// Join `rel_xy` with `rel_yz` first, then close with `rel_xz`.
    YShared,
    /// Join `rel_xy` with `rel_xz` first, then close with `rel_yz`.
    XShared,
    /// Join `rel_xz` with `rel_yz` first, then close with `rel_xy`.
    ZShared,
}
2439
2440 fn build_triangle_body(s: &TriangleSemantics, inner_pair: TriangleInnerPair) -> RirNode {
2441 let mk_scan = |r: RelId| RirNode::Scan { rel: r };
2442 match inner_pair {
2443 TriangleInnerPair::YShared => {
2444 let inner = RirNode::Join {
2445 left: Box::new(mk_scan(s.rel_xy)),
2446 right: Box::new(mk_scan(s.rel_yz)),
2447 left_keys: vec![1],
2448 right_keys: vec![0],
2449 join_type: JoinType::Inner,
2450 };
2451 let outer = RirNode::Join {
2452 left: Box::new(inner),
2453 right: Box::new(mk_scan(s.rel_xz)),
2454 left_keys: vec![0, 3],
2455 right_keys: vec![0, 1],
2456 join_type: JoinType::Inner,
2457 };
2458 RirNode::Project {
2459 input: Box::new(outer),
2460 columns: vec![
2461 ProjectExpr::Column(0),
2462 ProjectExpr::Column(1),
2463 ProjectExpr::Column(3),
2464 ],
2465 }
2466 }
2467 TriangleInnerPair::XShared => {
2468 let inner = RirNode::Join {
2469 left: Box::new(mk_scan(s.rel_xy)),
2470 right: Box::new(mk_scan(s.rel_xz)),
2471 left_keys: vec![0],
2472 right_keys: vec![0],
2473 join_type: JoinType::Inner,
2474 };
2475 let outer = RirNode::Join {
2476 left: Box::new(inner),
2477 right: Box::new(mk_scan(s.rel_yz)),
2478 left_keys: vec![1, 3],
2479 right_keys: vec![0, 1],
2480 join_type: JoinType::Inner,
2481 };
2482 RirNode::Project {
2483 input: Box::new(outer),
2484 columns: vec![
2485 ProjectExpr::Column(0),
2486 ProjectExpr::Column(1),
2487 ProjectExpr::Column(3),
2488 ],
2489 }
2490 }
2491 TriangleInnerPair::ZShared => {
2492 let inner = RirNode::Join {
2493 left: Box::new(mk_scan(s.rel_xz)),
2494 right: Box::new(mk_scan(s.rel_yz)),
2495 left_keys: vec![1],
2496 right_keys: vec![1],
2497 join_type: JoinType::Inner,
2498 };
2499 let outer = RirNode::Join {
2500 left: Box::new(inner),
2501 right: Box::new(mk_scan(s.rel_xy)),
2502 left_keys: vec![0, 2],
2503 right_keys: vec![0, 1],
2504 join_type: JoinType::Inner,
2505 };
2506 RirNode::Project {
2507 input: Box::new(outer),
2508 columns: vec![
2509 ProjectExpr::Column(0),
2510 ProjectExpr::Column(2),
2511 ProjectExpr::Column(3),
2512 ],
2513 }
2514 }
2515 }
2516 }
2517
2518 pub fn try_reorder_triangle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2519 let s = match_and_infer_triangle(body)?;
2520 let _ = (
2521 populated_card(stats, s.rel_xy)?,
2522 populated_card(stats, s.rel_yz)?,
2523 populated_card(stats, s.rel_xz)?,
2524 );
2525 let est_y = stats.estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0]);
2526 let est_x = stats.estimate_join_cardinality(s.rel_xy, s.rel_xz, &[0], &[0]);
2527 let est_z = stats.estimate_join_cardinality(s.rel_yz, s.rel_xz, &[1], &[1]);
2528 let mut best = (TriangleInnerPair::YShared, est_y);
2529 if est_x < best.1 {
2530 best = (TriangleInnerPair::XShared, est_x);
2531 }
2532 if est_z < best.1 {
2533 best = (TriangleInnerPair::ZShared, est_z);
2534 }
2535 let candidate = build_triangle_body(&s, best.0);
2536 if format!("{:?}", candidate) == format!("{:?}", body) {
2542 return None;
2543 }
2544 Some(candidate)
2545 }
2546
/// The four relations of a matched 4-cycle body, labelled by the pair of
/// cycle variables each one connects.
///
/// NOTE(review): presumably `w`, `x`, `y`, `z` are the classes of the four
/// projected columns in head order, mirroring `TriangleSemantics` — confirm
/// against `match_and_infer_4cycle`.
struct Cycle4Semantics {
    /// Relation connecting `w` and `x`.
    rel_wx: RelId,
    /// Relation connecting `x` and `y`.
    rel_xy: RelId,
    /// Relation connecting `y` and `z`.
    rel_yz: RelId,
    /// Relation connecting `z` and `w`.
    rel_zw: RelId,
}
2557
2558 fn match_and_infer_4cycle(body: &RirNode) -> Option<Cycle4Semantics> {
2559 let RirNode::Project {
2560 input: outer_input,
2561 columns,
2562 } = body
2563 else {
2564 return None;
2565 };
2566 let RirNode::Join {
2567 left: outer_l,
2568 right: outer_r,
2569 left_keys: olk,
2570 right_keys: ork,
2571 join_type: ojt,
2572 } = outer_input.as_ref()
2573 else {
2574 return None;
2575 };
2576 if !matches!(ojt, JoinType::Inner) {
2577 return None;
2578 }
2579 let RirNode::Join {
2580 left: ll,
2581 right: lr,
2582 left_keys: ilk_l,
2583 right_keys: irk_l,
2584 join_type: ijt_l,
2585 } = outer_l.as_ref()
2586 else {
2587 return None;
2588 };
2589 if !matches!(ijt_l, JoinType::Inner) {
2590 return None;
2591 }
2592 let RirNode::Scan { rel: rel_ll } = ll.as_ref() else {
2593 return None;
2594 };
2595 let RirNode::Scan { rel: rel_lr } = lr.as_ref() else {
2596 return None;
2597 };
2598 let RirNode::Join {
2599 left: rl,
2600 right: rr,
2601 left_keys: ilk_r,
2602 right_keys: irk_r,
2603 join_type: ijt_r,
2604 } = outer_r.as_ref()
2605 else {
2606 return None;
2607 };
2608 if !matches!(ijt_r, JoinType::Inner) {
2609 return None;
2610 }
2611 let RirNode::Scan { rel: rel_rl } = rl.as_ref() else {
2612 return None;
2613 };
2614 let RirNode::Scan { rel: rel_rr } = rr.as_ref() else {
2615 return None;
2616 };
2617 if ilk_l.len() != 1 || irk_l.len() != 1 || ilk_r.len() != 1 || irk_r.len() != 1 {
2618 return None;
2619 }
2620 if olk.len() != 2 || ork.len() != 2 || columns.len() != 4 {
2621 return None;
2622 }
2623 if ilk_l[0] >= 2 || irk_l[0] >= 2 || ilk_r[0] >= 2 || irk_r[0] >= 2 {
2624 return None;
2625 }
2626 if olk.iter().any(|k| *k >= 4) || ork.iter().any(|k| *k >= 4) {
2627 return None;
2628 }
2629
2630 let mut parent = [0u8, 1, 2, 3, 4, 5, 6, 7];
2631 uf_union_n::<8>(&mut parent, ac4(0, ilk_l[0] as u8), ac4(1, irk_l[0] as u8));
2632 uf_union_n::<8>(&mut parent, ac4(2, ilk_r[0] as u8), ac4(3, irk_r[0] as u8));
2633 for i in 0..2 {
2634 let l_ac = match olk[i] {
2635 0 => (0u8, 0u8),
2636 1 => (0, 1),
2637 2 => (1, 0),
2638 3 => (1, 1),
2639 _ => return None,
2640 };
2641 let r_ac = match ork[i] {
2642 0 => (2u8, 0u8),
2643 1 => (2, 1),
2644 2 => (3, 0),
2645 3 => (3, 1),
2646 _ => return None,
2647 };
2648 uf_union_n::<8>(&mut parent, ac4(l_ac.0, l_ac.1), ac4(r_ac.0, r_ac.1));
2649 }
2650 let roots: [u8; 8] = std::array::from_fn(|i| uf_find_n::<8>(&mut parent, i as u8));
2651 let mut counts: HashMap<u8, u8> = HashMap::new();
2652 for r in &roots {
2653 *counts.entry(*r).or_insert(0) += 1;
2654 }
2655 if counts.len() != 4 || counts.values().any(|c| *c != 2) {
2656 return None;
2657 }
2658
2659 let mut head_classes: [u8; 4] = [0; 4];
2660 for (i, pc) in columns.iter().enumerate() {
2661 let ProjectExpr::Column(k) = pc else {
2662 return None;
2663 };
2664 let ac = match *k {
2665 0 => (0u8, 0u8),
2666 1 => (0, 1),
2667 2 => (1, 0),
2668 3 => (1, 1),
2669 4 => (2, 0),
2670 5 => (2, 1),
2671 6 => (3, 0),
2672 7 => (3, 1),
2673 _ => return None,
2674 };
2675 head_classes[i] = uf_find_n::<8>(&mut parent, ac4(ac.0, ac.1));
2676 }
2677 for i in 0..4 {
2678 for j in (i + 1)..4 {
2679 if head_classes[i] == head_classes[j] {
2680 return None;
2681 }
2682 }
2683 }
2684 let w_class = head_classes[0];
2685 let x_class = head_classes[1];
2686 let y_class = head_classes[2];
2687 let z_class = head_classes[3];
2688 let atom_classes = |a: u8| (roots[ac4(a, 0) as usize], roots[ac4(a, 1) as usize]);
2689 let atom_rels = [*rel_ll, *rel_lr, *rel_rl, *rel_rr];
2690 let mut rel_wx = None;
2691 let mut rel_xy = None;
2692 let mut rel_yz = None;
2693 let mut rel_zw = None;
2694 for atom_idx in 0..4u8 {
2695 let (c0, c1) = atom_classes(atom_idx);
2696 let bw = c0 == w_class || c1 == w_class;
2697 let bx = c0 == x_class || c1 == x_class;
2698 let by = c0 == y_class || c1 == y_class;
2699 let bz = c0 == z_class || c1 == z_class;
2700 match (bw, bx, by, bz) {
2701 (true, true, false, false) => rel_wx = Some(atom_rels[atom_idx as usize]),
2702 (false, true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2703 (false, false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2704 (true, false, false, true) => rel_zw = Some(atom_rels[atom_idx as usize]),
2705 _ => return None,
2706 }
2707 }
2708 Some(Cycle4Semantics {
2709 rel_wx: rel_wx?,
2710 rel_xy: rel_xy?,
2711 rel_yz: rel_yz?,
2712 rel_zw: rel_zw?,
2713 })
2714 }
2715
    /// Which pairs of 4-cycle edges are joined first by `build_4cycle_body`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum Cycle4Grouping {
        /// Inner joins are `rel_wx ⋈ rel_xy` and `rel_yz ⋈ rel_zw`.
        Default,
        /// Inner joins are `rel_xy ⋈ rel_yz` and `rel_zw ⋈ rel_wx`.
        Alt,
    }
2721
2722 fn build_4cycle_body(s: &Cycle4Semantics, g: Cycle4Grouping) -> RirNode {
2723 let mk_scan = |r: RelId| RirNode::Scan { rel: r };
2724 match g {
2725 Cycle4Grouping::Default => {
2726 let il = RirNode::Join {
2727 left: Box::new(mk_scan(s.rel_wx)),
2728 right: Box::new(mk_scan(s.rel_xy)),
2729 left_keys: vec![1],
2730 right_keys: vec![0],
2731 join_type: JoinType::Inner,
2732 };
2733 let ir = RirNode::Join {
2734 left: Box::new(mk_scan(s.rel_yz)),
2735 right: Box::new(mk_scan(s.rel_zw)),
2736 left_keys: vec![1],
2737 right_keys: vec![0],
2738 join_type: JoinType::Inner,
2739 };
2740 let outer = RirNode::Join {
2741 left: Box::new(il),
2742 right: Box::new(ir),
2743 left_keys: vec![0, 3],
2744 right_keys: vec![3, 0],
2745 join_type: JoinType::Inner,
2746 };
2747 RirNode::Project {
2748 input: Box::new(outer),
2749 columns: vec![
2750 ProjectExpr::Column(0),
2751 ProjectExpr::Column(1),
2752 ProjectExpr::Column(3),
2753 ProjectExpr::Column(5),
2754 ],
2755 }
2756 }
2757 Cycle4Grouping::Alt => {
2758 let il = RirNode::Join {
2759 left: Box::new(mk_scan(s.rel_xy)),
2760 right: Box::new(mk_scan(s.rel_yz)),
2761 left_keys: vec![1],
2762 right_keys: vec![0],
2763 join_type: JoinType::Inner,
2764 };
2765 let ir = RirNode::Join {
2766 left: Box::new(mk_scan(s.rel_zw)),
2767 right: Box::new(mk_scan(s.rel_wx)),
2768 left_keys: vec![1],
2769 right_keys: vec![0],
2770 join_type: JoinType::Inner,
2771 };
2772 let outer = RirNode::Join {
2773 left: Box::new(il),
2774 right: Box::new(ir),
2775 left_keys: vec![0, 3],
2776 right_keys: vec![3, 0],
2777 join_type: JoinType::Inner,
2778 };
2779 RirNode::Project {
2780 input: Box::new(outer),
2781 columns: vec![
2782 ProjectExpr::Column(5),
2783 ProjectExpr::Column(0),
2784 ProjectExpr::Column(1),
2785 ProjectExpr::Column(3),
2786 ],
2787 }
2788 }
2789 }
2790 }
2791
2792 pub fn try_reorder_4cycle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2793 let s = match_and_infer_4cycle(body)?;
2794 let _ = (
2795 populated_card(stats, s.rel_wx)?,
2796 populated_card(stats, s.rel_xy)?,
2797 populated_card(stats, s.rel_yz)?,
2798 populated_card(stats, s.rel_zw)?,
2799 );
2800 let est_default = stats
2801 .estimate_join_cardinality(s.rel_wx, s.rel_xy, &[1], &[0])
2802 .saturating_add(stats.estimate_join_cardinality(s.rel_yz, s.rel_zw, &[1], &[0]));
2803 let est_alt = stats
2804 .estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0])
2805 .saturating_add(stats.estimate_join_cardinality(s.rel_zw, s.rel_wx, &[1], &[0]));
2806 let chosen = if est_alt < est_default {
2807 Cycle4Grouping::Alt
2808 } else {
2809 Cycle4Grouping::Default
2810 };
2811 let candidate = build_4cycle_body(&s, chosen);
2812 if format!("{:?}", candidate) == format!("{:?}", body) {
2813 return None;
2814 }
2815 Some(candidate)
2816 }
2817}
2818
2819#[cfg(test)]
2820mod selectivity_pass_tests {
2821 use super::selectivity_pass;
2822 use crate::Compiler;
2823 use xlog_stats::StatsManager;
2824
2825 fn body_snapshots(plan: &xlog_ir::ExecutionPlan) -> Vec<String> {
2826 plan.rules_by_scc
2827 .iter()
2828 .flatten()
2829 .map(|r| format!("{:?}", r.body))
2830 .collect()
2831 }
2832
2833 #[test]
2834 fn selectivity_pass_is_noop_for_triangle_plan() {
2835 let mut compiler = Compiler::new();
2836 let plan = compiler
2837 .compile("tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z).")
2838 .expect("compile");
2839 let before = body_snapshots(&plan);
2840 let stats = StatsManager::new();
2841 let mut plan2 = plan.clone();
2842 selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2843 let after = body_snapshots(&plan2);
2844 assert_eq!(
2845 before, after,
2846 "selectivity_pass must preserve every triangle rule body byte-for-byte"
2847 );
2848 }
2849
2850 #[test]
2851 fn selectivity_pass_is_noop_for_4cycle_plan() {
2852 let mut compiler = Compiler::new();
2853 let plan = compiler
2854 .compile("cycle4(W, X, Y, Z) :- e1(W, X), e2(X, Y), e3(Y, Z), e4(Z, W).")
2855 .expect("compile");
2856 let before = body_snapshots(&plan);
2857 let stats = StatsManager::new();
2858 let mut plan2 = plan.clone();
2859 selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2860 let after = body_snapshots(&plan2);
2861 assert_eq!(
2862 before, after,
2863 "selectivity_pass must preserve every 4-cycle rule body byte-for-byte"
2864 );
2865 }
2866
2867 #[test]
2868 fn selectivity_pass_is_noop_for_recursive_scc() {
2869 let mut compiler = Compiler::new();
2870 let plan = compiler
2871 .compile(
2872 "edge(1, 2). edge(2, 3). \
2873 reach(X, Y) :- edge(X, Y). \
2874 reach(X, Z) :- reach(X, Y), edge(Y, Z).",
2875 )
2876 .expect("compile");
2877 let before = body_snapshots(&plan);
2878 let stats = StatsManager::new();
2879 let mut plan2 = plan.clone();
2880 selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2881 let after = body_snapshots(&plan2);
2882 assert_eq!(
2883 before, after,
2884 "selectivity_pass must preserve recursive SCC bodies byte-for-byte"
2885 );
2886 }
2887
2888 use xlog_core::RelId;
2893 use xlog_ir::plan::{CompiledRule, PlanBuilder, Scc};
2894 use xlog_ir::rir::ProjectExpr;
2895 use xlog_ir::{ExecutionPlan, JoinType, RirNode};
2896
2897 fn synth_triangle_plan() -> ExecutionPlan {
2906 let inner = RirNode::Join {
2907 left: Box::new(RirNode::Scan { rel: RelId(1) }),
2908 right: Box::new(RirNode::Scan { rel: RelId(2) }),
2909 left_keys: vec![1],
2910 right_keys: vec![0],
2911 join_type: JoinType::Inner,
2912 };
2913 let outer = RirNode::Join {
2914 left: Box::new(inner),
2915 right: Box::new(RirNode::Scan { rel: RelId(3) }),
2916 left_keys: vec![0, 3],
2917 right_keys: vec![0, 1],
2918 join_type: JoinType::Inner,
2919 };
2920 let body = RirNode::Project {
2921 input: Box::new(outer),
2922 columns: vec![
2923 ProjectExpr::Column(0),
2924 ProjectExpr::Column(1),
2925 ProjectExpr::Column(3),
2926 ],
2927 };
2928 let mut builder = PlanBuilder::new();
2929 builder.add_scc(Scc {
2930 id: 0,
2931 predicates: vec!["tri".to_string()],
2932 is_recursive: false,
2933 });
2934 builder.add_rule(
2935 0,
2936 CompiledRule {
2937 head: "tri".to_string(),
2938 body,
2939 meta: Default::default(),
2940 },
2941 );
2942 builder.build()
2943 }
2944
2945 fn seed_triangle_stats(c1: u64, c2: u64, c3: u64) -> StatsManager {
2949 let mut stats = StatsManager::new();
2950 for (rid, card) in [(RelId(1), c1), (RelId(2), c2), (RelId(3), c3)] {
2951 stats.register_relation(rid);
2952 stats.update_cardinality(rid, card);
2953 }
2954 stats
2955 }
2956
2957 fn inspect_triangle_inner_pair(plan: &xlog_ir::ExecutionPlan) -> Option<(RelId, RelId)> {
2968 let body = &plan.rules_by_scc.iter().flatten().next()?.body;
2969 let body = match body {
2970 xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
2971 other => other,
2972 };
2973 let xlog_ir::RirNode::Project { input, .. } = body else {
2974 return None;
2975 };
2976 let xlog_ir::RirNode::Join { left, .. } = input.as_ref() else {
2977 return None;
2978 };
2979 let xlog_ir::RirNode::Join {
2980 left: l2,
2981 right: r2,
2982 ..
2983 } = left.as_ref()
2984 else {
2985 return None;
2986 };
2987 let xlog_ir::RirNode::Scan { rel: rel_l } = l2.as_ref() else {
2988 return None;
2989 };
2990 let xlog_ir::RirNode::Scan { rel: rel_r } = r2.as_ref() else {
2991 return None;
2992 };
2993 Some((*rel_l, *rel_r))
2994 }
2995
2996 #[test]
3003 fn selectivity_pass_picks_y_shared_inner_when_e1_e2_smallest() {
3004 let mut plan = synth_triangle_plan();
3005 let stats = seed_triangle_stats(10, 10, 100_000);
3007 selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3008 let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3009 assert!(
3011 pair == (RelId(1), RelId(2)) || pair == (RelId(2), RelId(1)),
3012 "expected (RelId(1), RelId(2)) for Y-shared; got {:?}",
3013 pair
3014 );
3015 }
3016
3017 #[test]
3020 fn selectivity_pass_picks_x_shared_inner_when_e1_e3_smallest() {
3021 let mut plan = synth_triangle_plan();
3022 let stats = seed_triangle_stats(10, 100_000, 10);
3024 selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3025 let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3026 assert!(
3028 pair == (RelId(1), RelId(3)) || pair == (RelId(3), RelId(1)),
3029 "expected (RelId(1), RelId(3)) for X-shared; got {:?}",
3030 pair
3031 );
3032 }
3033
3034 #[test]
3037 fn selectivity_pass_picks_z_shared_inner_when_e2_e3_smallest() {
3038 let mut plan = synth_triangle_plan();
3039 let stats = seed_triangle_stats(100_000, 10, 10);
3041 selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3042 let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3043 assert!(
3045 pair == (RelId(2), RelId(3)) || pair == (RelId(3), RelId(2)),
3046 "expected (RelId(2), RelId(3)) for Z-shared; got {:?}",
3047 pair
3048 );
3049 }
3050
3051 #[test]
3056 fn selectivity_pass_two_snapshots_produce_different_inner_pairs() {
3057 let mut plan_a = synth_triangle_plan();
3058 let stats_a = seed_triangle_stats(10, 10, 100_000); selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
3060 let pair_a = inspect_triangle_inner_pair(&plan_a).expect("snapshot A pair");
3061
3062 let mut plan_b = synth_triangle_plan();
3063 let stats_b = seed_triangle_stats(100_000, 10, 10); selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
3065 let pair_b = inspect_triangle_inner_pair(&plan_b).expect("snapshot B pair");
3066
3067 let normalize = |(a, b): (RelId, RelId)| -> (RelId, RelId) {
3068 if a.0 <= b.0 {
3069 (a, b)
3070 } else {
3071 (b, a)
3072 }
3073 };
3074 assert_ne!(
3075 normalize(pair_a),
3076 normalize(pair_b),
3077 "two different stats snapshots must produce different inner pairs; \
3078 got A = {:?}, B = {:?}",
3079 pair_a,
3080 pair_b
3081 );
3082 }
3083
3084 #[test]
3092 fn selectivity_pass_with_only_relation_cards_may_pick_arbitrary_pair() {
3093 let mut plan = synth_triangle_plan();
3094 let stats = seed_triangle_stats(100, 100, 100);
3096 selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3097 let _ = inspect_triangle_inner_pair(&plan);
3100 }
3101
3102 fn synth_4cycle_plan() -> ExecutionPlan {
3113 let inner_left = RirNode::Join {
3114 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3115 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3116 left_keys: vec![1],
3117 right_keys: vec![0],
3118 join_type: JoinType::Inner,
3119 };
3120 let inner_right = RirNode::Join {
3121 left: Box::new(RirNode::Scan { rel: RelId(3) }),
3122 right: Box::new(RirNode::Scan { rel: RelId(4) }),
3123 left_keys: vec![1],
3124 right_keys: vec![0],
3125 join_type: JoinType::Inner,
3126 };
3127 let outer = RirNode::Join {
3128 left: Box::new(inner_left),
3129 right: Box::new(inner_right),
3130 left_keys: vec![0, 3],
3131 right_keys: vec![3, 0],
3132 join_type: JoinType::Inner,
3133 };
3134 let body = RirNode::Project {
3135 input: Box::new(outer),
3136 columns: vec![
3137 ProjectExpr::Column(0),
3138 ProjectExpr::Column(1),
3139 ProjectExpr::Column(3),
3140 ProjectExpr::Column(5),
3141 ],
3142 };
3143 let mut builder = PlanBuilder::new();
3144 builder.add_scc(Scc {
3145 id: 0,
3146 predicates: vec!["cyc".to_string()],
3147 is_recursive: false,
3148 });
3149 builder.add_rule(
3150 0,
3151 CompiledRule {
3152 head: "cyc".to_string(),
3153 body,
3154 meta: Default::default(),
3155 },
3156 );
3157 builder.build()
3158 }
3159
3160 fn seed_4cycle_stats(c1: u64, c2: u64, c3: u64, c4: u64) -> StatsManager {
3161 let mut stats = StatsManager::new();
3162 for (rid, card) in [
3163 (RelId(1), c1),
3164 (RelId(2), c2),
3165 (RelId(3), c3),
3166 (RelId(4), c4),
3167 ] {
3168 stats.register_relation(rid);
3169 stats.update_cardinality(rid, card);
3170 }
3171 stats
3172 }
3173
3174 fn inspect_4cycle_grouping(
3178 plan: &xlog_ir::ExecutionPlan,
3179 ) -> Option<(RelId, RelId, RelId, RelId)> {
3180 let body = &plan.rules_by_scc.iter().flatten().next()?.body;
3181 let body = match body {
3182 xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
3183 other => other,
3184 };
3185 let xlog_ir::RirNode::Project { input, .. } = body else {
3186 return None;
3187 };
3188 let xlog_ir::RirNode::Join { left, right, .. } = input.as_ref() else {
3189 return None;
3190 };
3191 let xlog_ir::RirNode::Join {
3192 left: ll,
3193 right: lr,
3194 ..
3195 } = left.as_ref()
3196 else {
3197 return None;
3198 };
3199 let xlog_ir::RirNode::Join {
3200 left: rl,
3201 right: rr,
3202 ..
3203 } = right.as_ref()
3204 else {
3205 return None;
3206 };
3207 let xlog_ir::RirNode::Scan { rel: r_ll } = ll.as_ref() else {
3208 return None;
3209 };
3210 let xlog_ir::RirNode::Scan { rel: r_lr } = lr.as_ref() else {
3211 return None;
3212 };
3213 let xlog_ir::RirNode::Scan { rel: r_rl } = rl.as_ref() else {
3214 return None;
3215 };
3216 let xlog_ir::RirNode::Scan { rel: r_rr } = rr.as_ref() else {
3217 return None;
3218 };
3219 Some((*r_ll, *r_lr, *r_rl, *r_rr))
3220 }
3221
3222 #[test]
3244 fn selectivity_pass_4cycle_picks_default_grouping_when_corners_smallest() {
3245 let mut plan = synth_4cycle_plan();
3246 let stats = seed_4cycle_stats(10, 10_000, 10_000, 10);
3247 selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3248 let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
3249 assert_eq!(
3251 (ll, lr, rl, rr),
3252 (RelId(1), RelId(2), RelId(3), RelId(4)),
3253 "expected Default grouping"
3254 );
3255 }
3256
3257 #[test]
3270 fn selectivity_pass_4cycle_picks_alt_grouping_when_diagonals_smallest() {
3271 let mut plan = synth_4cycle_plan();
3272 let stats = seed_4cycle_stats(10_000, 10_000, 10, 10);
3273 selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3274 let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
3275 assert_eq!(
3277 (ll, lr, rl, rr),
3278 (RelId(2), RelId(3), RelId(4), RelId(1)),
3279 "expected Alt grouping"
3280 );
3281 }
3282
3283 #[test]
3287 fn selectivity_pass_4cycle_two_snapshots_produce_different_groupings() {
3288 let mut plan_a = synth_4cycle_plan();
3289 let stats_a = seed_4cycle_stats(10, 10_000, 10_000, 10); selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
3291 let g_a = inspect_4cycle_grouping(&plan_a).expect("grouping a");
3292
3293 let mut plan_b = synth_4cycle_plan();
3294 let stats_b = seed_4cycle_stats(10_000, 10_000, 10, 10); selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
3296 let g_b = inspect_4cycle_grouping(&plan_b).expect("grouping b");
3297
3298 assert_ne!(
3299 g_a, g_b,
3300 "two different stats snapshots must produce different 4-cycle groupings; \
3301 got A = {:?}, B = {:?}",
3302 g_a, g_b
3303 );
3304 }
3305
3306 #[test]
3309 fn selectivity_pass_4cycle_skips_when_card_missing() {
3310 let mut plan = synth_4cycle_plan();
3311 let mut stats = StatsManager::new();
3313 for rid in [RelId(1), RelId(2), RelId(3)] {
3314 stats.register_relation(rid);
3315 stats.update_cardinality(rid, 100);
3316 }
3317 let before = format!("{:?}", plan.rules_by_scc[0][0].body);
3318 selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3319 let after = format!("{:?}", plan.rules_by_scc[0][0].body);
3320 assert_eq!(
3321 before, after,
3322 "missing-stats safety floor must leave body unchanged"
3323 );
3324 }
3325}
3326
3327#[cfg(test)]
3328mod tests {
3329 use super::*;
3330 use xlog_core::ScalarType;
3331 use xlog_ir::{ConstValue, ProjectExpr};
3332 use xlog_stats::ColumnStats;
3333
3334 fn make_stats_manager() -> Arc<StatsManager> {
3335 let mut mgr = StatsManager::new();
3336
3337 mgr.register_relation(RelId(1));
3339 mgr.update_cardinality(RelId(1), 10_000);
3340 mgr.update_byte_size(RelId(1), 320_000); mgr.register_relation(RelId(2));
3343 mgr.update_cardinality(RelId(2), 5_000);
3344 mgr.update_byte_size(RelId(2), 160_000);
3345
3346 mgr.register_relation(RelId(3));
3347 mgr.update_cardinality(RelId(3), 1_000);
3348 mgr.update_byte_size(RelId(3), 32_000);
3349
3350 let mut col0 = ColumnStats::new(0, ScalarType::I64);
3352 col0.update_distinct(1000);
3353 col0.update_range(0, 10000);
3354 mgr.add_column_stats(RelId(1), col0);
3355
3356 let mut col1 = ColumnStats::new(1, ScalarType::I64);
3357 col1.update_distinct(100);
3358 mgr.add_column_stats(RelId(1), col1);
3359
3360 Arc::new(mgr)
3361 }
3362
3363 #[test]
3364 fn test_optimizer_new() {
3365 let stats = make_stats_manager();
3366 let optimizer = Optimizer::new(stats);
3367
3368 assert_eq!(optimizer.config().dp_threshold, 10);
3369 assert!(optimizer.config().enable_pushdown);
3370 }
3371
3372 #[test]
3373 fn test_optimizer_with_config() {
3374 let stats = make_stats_manager();
3375 let config = OptimizerConfig {
3376 dp_threshold: 5,
3377 enable_pushdown: false,
3378 ..Default::default()
3379 };
3380 let optimizer = Optimizer::with_config(stats, config);
3381
3382 assert_eq!(optimizer.config().dp_threshold, 5);
3383 assert!(!optimizer.config().enable_pushdown);
3384 }
3385
3386 #[test]
3387 fn test_estimate_scan_cost() {
3388 let stats = make_stats_manager();
3389 let optimizer = Optimizer::new(stats);
3390
3391 let scan = RirNode::Scan { rel: RelId(1) };
3392 let cost = optimizer.estimate_cost(&scan);
3393
3394 assert_eq!(cost.rows, 10_000);
3395 assert!(cost.gpu_mem > 0);
3396 assert_eq!(cost.transfers, 0); }
3398
3399 #[test]
3400 fn test_estimate_scan_cost_unknown_relation() {
3401 let stats = Arc::new(StatsManager::new());
3402 let optimizer = Optimizer::new(stats);
3403
3404 let scan = RirNode::Scan { rel: RelId(999) };
3405 let cost = optimizer.estimate_cost(&scan);
3406
3407 assert_eq!(cost.rows, 1000);
3409 }
3410
3411 #[test]
3412 fn test_estimate_filter_cost() {
3413 let stats = make_stats_manager();
3414 let optimizer = Optimizer::new(stats);
3415
3416 let filter = RirNode::Filter {
3417 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3418 predicate: Expr::Compare {
3419 left: Box::new(Expr::Column(0)),
3420 op: CompareOp::Eq,
3421 right: Box::new(Expr::Const(ConstValue::I64(42))),
3422 },
3423 };
3424
3425 let cost = optimizer.estimate_cost(&filter);
3426
3427 assert!(cost.rows < 10_000);
3429 assert!(cost.rows >= 1);
3430 }
3431
3432 #[test]
3433 fn test_estimate_join_cost() {
3434 let stats = make_stats_manager();
3435 let optimizer = Optimizer::new(stats);
3436
3437 let join = RirNode::Join {
3438 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3439 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3440 left_keys: vec![0],
3441 right_keys: vec![0],
3442 join_type: JoinType::Inner,
3443 };
3444
3445 let cost = optimizer.estimate_cost(&join);
3446
3447 assert!(cost.rows > 0);
3449 assert!(cost.cpu_cost > 0.0);
3450 assert!(cost.gpu_mem > 0);
3451 }
3452
3453 #[test]
3454 fn test_estimate_join_cost_with_selectivity() {
3455 let mut mgr = StatsManager::new();
3456 mgr.register_relation(RelId(1));
3457 mgr.register_relation(RelId(2));
3458 mgr.update_cardinality(RelId(1), 1000);
3459 mgr.update_cardinality(RelId(2), 500);
3460
3461 mgr.record_join_result(RelId(1), RelId(2), vec![0], vec![0], 500_000, 2500);
3463
3464 let optimizer = Optimizer::new(Arc::new(mgr));
3465
3466 let join = RirNode::Join {
3467 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3468 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3469 left_keys: vec![0],
3470 right_keys: vec![0],
3471 join_type: JoinType::Inner,
3472 };
3473
3474 let cost = optimizer.estimate_cost(&join);
3475
3476 assert!(cost.rows > 0);
3478 }
3479
3480 #[test]
3481 fn test_predicate_pushdown_simple_scan() {
3482 let stats = make_stats_manager();
3483 let optimizer = Optimizer::new(stats);
3484
3485 let scan = RirNode::Scan { rel: RelId(1) };
3486 let optimized = optimizer.optimize(scan);
3487
3488 assert!(matches!(optimized, RirNode::Scan { rel: RelId(1) }));
3490 }
3491
3492 #[test]
3493 fn test_predicate_pushdown_filter_on_scan() {
3494 let stats = make_stats_manager();
3495 let optimizer = Optimizer::new(stats);
3496
3497 let filter = RirNode::Filter {
3498 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3499 predicate: Expr::Compare {
3500 left: Box::new(Expr::Column(0)),
3501 op: CompareOp::Eq,
3502 right: Box::new(Expr::Const(ConstValue::I64(42))),
3503 },
3504 };
3505
3506 let optimized = optimizer.optimize(filter);
3507
3508 assert!(matches!(optimized, RirNode::Filter { .. }));
3510 }
3511
3512 #[test]
3513 fn test_predicate_pushdown_merges_filters() {
3514 let stats = make_stats_manager();
3515 let optimizer = Optimizer::new(stats);
3516
3517 let nested_filter = RirNode::Filter {
3518 input: Box::new(RirNode::Filter {
3519 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3520 predicate: Expr::Compare {
3521 left: Box::new(Expr::Column(0)),
3522 op: CompareOp::Gt,
3523 right: Box::new(Expr::Const(ConstValue::I64(0))),
3524 },
3525 }),
3526 predicate: Expr::Compare {
3527 left: Box::new(Expr::Column(0)),
3528 op: CompareOp::Lt,
3529 right: Box::new(Expr::Const(ConstValue::I64(100))),
3530 },
3531 };
3532
3533 let optimized = optimizer.optimize(nested_filter);
3534
3535 if let RirNode::Filter { predicate, .. } = optimized {
3537 assert!(matches!(predicate, Expr::And(_)));
3538 } else {
3539 panic!("Expected Filter node");
3540 }
3541 }
3542
3543 #[test]
3544 fn test_predicate_pushdown_through_project() {
3545 let stats = make_stats_manager();
3546 let optimizer = Optimizer::new(stats);
3547
3548 let plan = RirNode::Filter {
3550 input: Box::new(RirNode::Project {
3551 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3552 columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(1)],
3553 }),
3554 predicate: Expr::Compare {
3555 left: Box::new(Expr::Column(0)),
3556 op: CompareOp::Eq,
3557 right: Box::new(Expr::Const(ConstValue::I64(42))),
3558 },
3559 };
3560
3561 let optimized = optimizer.optimize(plan);
3562
3563 assert!(matches!(optimized, RirNode::Project { .. }));
3565 if let RirNode::Project { input, .. } = optimized {
3566 assert!(matches!(*input, RirNode::Filter { .. }));
3567 }
3568 }
3569
3570 #[test]
3571 fn test_predicate_pushdown_into_join() {
3572 let stats = make_stats_manager();
3573 let optimizer = Optimizer::new(stats);
3574
3575 let plan = RirNode::Filter {
3577 input: Box::new(RirNode::Join {
3578 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3579 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3580 left_keys: vec![0],
3581 right_keys: vec![0],
3582 join_type: JoinType::Inner,
3583 }),
3584 predicate: Expr::Compare {
3585 left: Box::new(Expr::Column(0)), op: CompareOp::Eq,
3587 right: Box::new(Expr::Const(ConstValue::I64(42))),
3588 },
3589 };
3590
3591 let optimized = optimizer.optimize(plan);
3592
3593 if let RirNode::Join { left, .. } = optimized {
3595 assert!(matches!(*left, RirNode::Filter { .. }));
3596 } else {
3597 panic!("Expected Join node");
3598 }
3599 }
3600
3601 #[test]
3602 fn test_plan_cost_total() {
3603 let cost = PlanCost {
3604 rows: 1000,
3605 cpu_cost: 100.0,
3606 gpu_mem: 1_000_000,
3607 transfers: 2,
3608 };
3609
3610 let total = cost.total_cost(100.0);
3611
3612 assert!((total - 1300.0).abs() < 0.001);
3615 }
3616
3617 #[test]
3618 fn test_plan_cost_then() {
3619 let cost1 = PlanCost {
3620 rows: 1000,
3621 cpu_cost: 50.0,
3622 gpu_mem: 500,
3623 transfers: 1,
3624 };
3625
3626 let cost2 = PlanCost {
3627 rows: 500,
3628 cpu_cost: 25.0,
3629 gpu_mem: 800,
3630 transfers: 1,
3631 };
3632
3633 let combined = cost1.then(cost2);
3634
3635 assert_eq!(combined.rows, 500); assert_eq!(combined.cpu_cost, 75.0);
3637 assert_eq!(combined.gpu_mem, 800); assert_eq!(combined.transfers, 2);
3639 }
3640
3641 #[test]
3642 fn test_optimizer_config_default() {
3643 let config = OptimizerConfig::default();
3644
3645 assert_eq!(config.dp_threshold, 10);
3646 assert!((config.index_heat_threshold - 0.7).abs() < 0.001);
3647 assert!(config.enable_pushdown);
3648 assert!((config.default_filter_selectivity - 0.1).abs() < 0.001);
3649 }
3650
3651 #[test]
3652 fn test_should_use_greedy() {
3653 let stats = make_stats_manager();
3654 let config = OptimizerConfig {
3655 dp_threshold: 2,
3656 ..Default::default()
3657 };
3658 let optimizer = Optimizer::with_config(stats, config);
3659
3660 let single = RirNode::Scan { rel: RelId(1) };
3662 assert!(!optimizer.should_use_greedy(&single));
3663
3664 let multi = RirNode::Join {
3666 left: Box::new(RirNode::Join {
3667 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3668 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3669 left_keys: vec![0],
3670 right_keys: vec![0],
3671 join_type: JoinType::Inner,
3672 }),
3673 right: Box::new(RirNode::Scan { rel: RelId(3) }),
3674 left_keys: vec![0],
3675 right_keys: vec![0],
3676 join_type: JoinType::Inner,
3677 };
3678 assert!(optimizer.should_use_greedy(&multi));
3679 }
3680
3681 #[test]
3682 fn test_recommend_indexes() {
3683 let mut mgr = StatsManager::new();
3684 mgr.register_relation(RelId(1));
3685 mgr.register_relation(RelId(2));
3686
3687 for _ in 0..50 {
3689 mgr.record_access(RelId(1));
3690 }
3691
3692 let optimizer = Optimizer::new(Arc::new(mgr));
3693 let recommendations = optimizer.recommend_indexes();
3694
3695 assert!(recommendations.contains(&RelId(1)));
3696 assert!(!recommendations.contains(&RelId(2)));
3697 }
3698
3699 #[test]
3700 fn test_estimate_groupby_cost() {
3701 let stats = make_stats_manager();
3702 let optimizer = Optimizer::new(stats);
3703
3704 let groupby = RirNode::GroupBy {
3705 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3706 key_cols: vec![0],
3707 aggs: vec![(1, xlog_core::AggOp::Sum)],
3708 };
3709
3710 let cost = optimizer.estimate_cost(&groupby);
3711
3712 assert!(cost.rows < 10_000);
3714 assert!(cost.rows >= 1);
3715 }
3716
3717 #[test]
3718 fn test_estimate_union_cost() {
3719 let stats = make_stats_manager();
3720 let optimizer = Optimizer::new(stats);
3721
3722 let union = RirNode::Union {
3723 inputs: vec![
3724 RirNode::Scan { rel: RelId(1) },
3725 RirNode::Scan { rel: RelId(2) },
3726 ],
3727 };
3728
3729 let cost = optimizer.estimate_cost(&union);
3730
3731 assert_eq!(cost.rows, 15_000); }
3734
3735 #[test]
3736 fn test_estimate_distinct_cost() {
3737 let stats = make_stats_manager();
3738 let optimizer = Optimizer::new(stats);
3739
3740 let distinct = RirNode::Distinct {
3741 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3742 key_cols: vec![0],
3743 };
3744
3745 let cost = optimizer.estimate_cost(&distinct);
3746
3747 assert!(cost.rows <= 10_000);
3749 assert!(cost.rows >= 1);
3750 }
3751
3752 #[test]
3753 fn test_estimate_diff_cost() {
3754 let stats = make_stats_manager();
3755 let optimizer = Optimizer::new(stats);
3756
3757 let diff = RirNode::Diff {
3758 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3759 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3760 };
3761
3762 let cost = optimizer.estimate_cost(&diff);
3763
3764 assert!(cost.rows <= 10_000);
3766 assert!(cost.rows >= 1);
3767 }
3768
3769 #[test]
3770 fn test_estimate_fixpoint_cost() {
3771 let stats = make_stats_manager();
3772 let optimizer = Optimizer::new(stats);
3773
3774 let fixpoint = RirNode::Fixpoint {
3775 scc_id: 0,
3776 base: Box::new(RirNode::Scan { rel: RelId(1) }),
3777 recursive: Box::new(RirNode::Scan { rel: RelId(1) }),
3778 delta_rel: RelId(10),
3779 full_rel: RelId(11),
3780 };
3781
3782 let cost = optimizer.estimate_cost(&fixpoint);
3783
3784 assert!(cost.rows >= 10_000);
3786 }
3787
3788 #[test]
3789 fn test_predicate_selectivity_equality() {
3790 let stats = make_stats_manager();
3791 let optimizer = Optimizer::new(stats);
3792
3793 let scan = RirNode::Scan { rel: RelId(1) };
3794
3795 let eq_pred = Expr::Compare {
3797 left: Box::new(Expr::Column(0)),
3798 op: CompareOp::Eq,
3799 right: Box::new(Expr::Const(ConstValue::I64(42))),
3800 };
3801
3802 let selectivity = optimizer.estimate_predicate_selectivity(&eq_pred, &scan);
3803
3804 assert!(selectivity < 0.01);
3806 assert!(selectivity > 0.0);
3807 }
3808
3809 #[test]
3810 fn test_predicate_selectivity_and() {
3811 let stats = make_stats_manager();
3812 let optimizer = Optimizer::new(stats);
3813
3814 let scan = RirNode::Scan { rel: RelId(1) };
3815
3816 let and_pred = Expr::And(vec![
3818 Expr::Compare {
3819 left: Box::new(Expr::Column(0)),
3820 op: CompareOp::Gt,
3821 right: Box::new(Expr::Const(ConstValue::I64(0))),
3822 },
3823 Expr::Compare {
3824 left: Box::new(Expr::Column(0)),
3825 op: CompareOp::Lt,
3826 right: Box::new(Expr::Const(ConstValue::I64(100))),
3827 },
3828 ]);
3829
3830 let selectivity = optimizer.estimate_predicate_selectivity(&and_pred, &scan);
3831
3832 assert!(selectivity < 0.5);
3834 assert!(selectivity > 0.0);
3835 }
3836
3837 #[test]
3838 fn test_predicate_selectivity_not() {
3839 let stats = make_stats_manager();
3840 let optimizer = Optimizer::new(stats);
3841
3842 let scan = RirNode::Scan { rel: RelId(1) };
3843
3844 let not_pred = Expr::Not(Box::new(Expr::Compare {
3846 left: Box::new(Expr::Column(0)),
3847 op: CompareOp::Eq,
3848 right: Box::new(Expr::Const(ConstValue::I64(42))),
3849 }));
3850
3851 let selectivity = optimizer.estimate_predicate_selectivity(¬_pred, &scan);
3852
3853 assert!(selectivity > 0.9);
3855 }
3856
3857 #[test]
3858 fn test_join_type_semi() {
3859 let stats = make_stats_manager();
3860 let optimizer = Optimizer::new(stats);
3861
3862 let semi_join = RirNode::Join {
3863 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3864 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3865 left_keys: vec![0],
3866 right_keys: vec![0],
3867 join_type: JoinType::Semi,
3868 };
3869
3870 let cost = optimizer.estimate_cost(&semi_join);
3871
3872 assert!(cost.rows <= 10_000);
3874 }
3875
3876 #[test]
3877 fn test_join_type_anti() {
3878 let stats = make_stats_manager();
3879 let optimizer = Optimizer::new(stats);
3880
3881 let anti_join = RirNode::Join {
3882 left: Box::new(RirNode::Scan { rel: RelId(1) }),
3883 right: Box::new(RirNode::Scan { rel: RelId(2) }),
3884 left_keys: vec![0],
3885 right_keys: vec![0],
3886 join_type: JoinType::Anti,
3887 };
3888
3889 let cost = optimizer.estimate_cost(&anti_join);
3890
3891 assert!(cost.rows <= 10_000);
3893 }
3894
3895 #[test]
3896 fn test_pushdown_disabled() {
3897 let stats = make_stats_manager();
3898 let config = OptimizerConfig {
3899 enable_pushdown: false,
3900 ..Default::default()
3901 };
3902 let optimizer = Optimizer::with_config(stats, config);
3903
3904 let plan = RirNode::Filter {
3906 input: Box::new(RirNode::Filter {
3907 input: Box::new(RirNode::Scan { rel: RelId(1) }),
3908 predicate: Expr::Compare {
3909 left: Box::new(Expr::Column(0)),
3910 op: CompareOp::Gt,
3911 right: Box::new(Expr::Const(ConstValue::I64(0))),
3912 },
3913 }),
3914 predicate: Expr::Compare {
3915 left: Box::new(Expr::Column(0)),
3916 op: CompareOp::Lt,
3917 right: Box::new(Expr::Const(ConstValue::I64(100))),
3918 },
3919 };
3920
3921 let optimized = optimizer.optimize(plan.clone());
3922
3923 if let RirNode::Filter { input, .. } = optimized {
3926 assert!(matches!(*input, RirNode::Filter { .. }));
3927 } else {
3928 panic!("Expected Filter node");
3929 }
3930 }
3931
3932 #[test]
3933 fn test_collect_columns() {
3934 let expr = Expr::And(vec![
3935 Expr::Compare {
3936 left: Box::new(Expr::Column(0)),
3937 op: CompareOp::Eq,
3938 right: Box::new(Expr::Column(2)),
3939 },
3940 Expr::Compare {
3941 left: Box::new(Expr::Column(1)),
3942 op: CompareOp::Gt,
3943 right: Box::new(Expr::Const(ConstValue::I64(0))),
3944 },
3945 ]);
3946
3947 let cols = Optimizer::collect_columns(&expr);
3948
3949 assert!(cols.contains(&0));
3950 assert!(cols.contains(&1));
3951 assert!(cols.contains(&2));
3952 }
3953
3954 #[test]
3955 fn test_flatten_and() {
3956 let nested = Expr::And(vec![
3957 Expr::And(vec![
3958 Expr::Compare {
3959 left: Box::new(Expr::Column(0)),
3960 op: CompareOp::Eq,
3961 right: Box::new(Expr::Const(ConstValue::I64(1))),
3962 },
3963 Expr::Compare {
3964 left: Box::new(Expr::Column(1)),
3965 op: CompareOp::Eq,
3966 right: Box::new(Expr::Const(ConstValue::I64(2))),
3967 },
3968 ]),
3969 Expr::Compare {
3970 left: Box::new(Expr::Column(2)),
3971 op: CompareOp::Eq,
3972 right: Box::new(Expr::Const(ConstValue::I64(3))),
3973 },
3974 ]);
3975
3976 let flattened = Optimizer::flatten_and(&nested);
3977
3978 assert_eq!(flattened.len(), 3);
3979 }
3980
3981 #[test]
3982 fn test_conjoin_single() {
3983 let single = vec![Expr::Compare {
3984 left: Box::new(Expr::Column(0)),
3985 op: CompareOp::Eq,
3986 right: Box::new(Expr::Const(ConstValue::I64(42))),
3987 }];
3988
3989 let result = Optimizer::conjoin(single);
3990
3991 assert!(matches!(result, Expr::Compare { .. }));
3992 }
3993
3994 #[test]
3995 fn test_conjoin_multiple() {
3996 let multiple = vec![
3997 Expr::Compare {
3998 left: Box::new(Expr::Column(0)),
3999 op: CompareOp::Eq,
4000 right: Box::new(Expr::Const(ConstValue::I64(1))),
4001 },
4002 Expr::Compare {
4003 left: Box::new(Expr::Column(1)),
4004 op: CompareOp::Eq,
4005 right: Box::new(Expr::Const(ConstValue::I64(2))),
4006 },
4007 ];
4008
4009 let result = Optimizer::conjoin(multiple);
4010
4011 assert!(matches!(result, Expr::And(_)));
4012 }
4013
4014 #[test]
4015 fn test_predicate_pushdown_with_schemas() {
4016 let stats = make_stats_manager();
4019 let mut optimizer = Optimizer::new(stats);
4020
4021 let left_schema = Schema::new(vec![
4023 ("c0".to_string(), xlog_core::ScalarType::Symbol),
4024 ("c1".to_string(), xlog_core::ScalarType::Symbol),
4025 ("c2".to_string(), xlog_core::ScalarType::Symbol),
4026 ]);
4027 let right_schema = Schema::new(vec![
4028 ("c0".to_string(), xlog_core::ScalarType::Symbol),
4029 ("c1".to_string(), xlog_core::ScalarType::Symbol),
4030 ("c2".to_string(), xlog_core::ScalarType::U32),
4031 ]);
4032
4033 let mut schemas = HashMap::new();
4034 schemas.insert(RelId(1), left_schema);
4035 schemas.insert(RelId(2), right_schema);
4036 optimizer.set_schemas(schemas);
4037
4038 let plan = RirNode::Filter {
4040 input: Box::new(RirNode::Join {
4041 left: Box::new(RirNode::Scan { rel: RelId(1) }),
4042 right: Box::new(RirNode::Scan { rel: RelId(2) }),
4043 left_keys: vec![0],
4044 right_keys: vec![0],
4045 join_type: JoinType::Inner,
4046 }),
4047 predicate: Expr::Compare {
4048 left: Box::new(Expr::Column(5)), op: CompareOp::Ge,
4050 right: Box::new(Expr::Const(ConstValue::U32(4))),
4051 },
4052 };
4053
4054 let optimized = optimizer.optimize(plan);
4055
4056 if let RirNode::Join { right, .. } = optimized {
4058 if let RirNode::Filter { predicate, .. } = *right {
4059 if let Expr::Compare { left, .. } = predicate {
4060 if let Expr::Column(idx) = *left {
4061 assert_eq!(
4062 idx, 2,
4063 "Column should be remapped to 2 (5 - left_width(3) = 2)"
4064 );
4065 } else {
4066 panic!("Expected Column expression");
4067 }
4068 } else {
4069 panic!("Expected Compare predicate");
4070 }
4071 } else {
4072 panic!("Expected Filter on right side of join");
4073 }
4074 } else {
4075 panic!("Expected Join node");
4076 }
4077 }
4078
4079 fn build_canonical_triangle_multiway() -> RirNode {
4095 let scan_xy = RirNode::Scan { rel: RelId(1) };
4096 let scan_yz = RirNode::Scan { rel: RelId(2) };
4097 let scan_xz = RirNode::Scan { rel: RelId(3) };
4098 let inner_join = RirNode::Join {
4099 left: Box::new(scan_xy.clone()),
4100 right: Box::new(scan_yz.clone()),
4101 left_keys: vec![1],
4102 right_keys: vec![0],
4103 join_type: JoinType::Inner,
4104 };
4105 let outer_join = RirNode::Join {
4106 left: Box::new(inner_join),
4107 right: Box::new(scan_xz.clone()),
4108 left_keys: vec![0, 3],
4109 right_keys: vec![0, 1],
4110 join_type: JoinType::Inner,
4111 };
4112 let fallback = RirNode::Project {
4113 input: Box::new(outer_join),
4114 columns: vec![
4115 ProjectExpr::Column(0),
4116 ProjectExpr::Column(1),
4117 ProjectExpr::Column(3),
4118 ],
4119 };
4120 RirNode::MultiWayJoin {
4121 inputs: vec![scan_xy, scan_yz, scan_xz],
4122 slot_vars: vec![
4123 vec![Some(0), Some(1)],
4124 vec![Some(1), Some(2)],
4125 vec![Some(0), Some(2)],
4126 ],
4127 output_columns: vec![
4128 ProjectExpr::Column(0),
4129 ProjectExpr::Column(1),
4130 ProjectExpr::Column(3),
4131 ],
4132 fallback: Box::new(fallback),
4133 plan: None,
4134 var_order: None,
4135 }
4136 }
4137
4138 fn build_4input_multiway() -> RirNode {
4147 let scans = [RelId(1), RelId(2), RelId(3), RelId(1)]
4148 .map(|rel| RirNode::Scan { rel })
4149 .to_vec();
4150 let slot_vars = vec![
4152 vec![Some(0u32), Some(1)],
4153 vec![Some(1u32), Some(2)],
4154 vec![Some(2u32), Some(3)],
4155 vec![Some(0u32), Some(3)],
4156 ];
4157 let output_columns = vec![
4160 ProjectExpr::Column(0),
4161 ProjectExpr::Column(1),
4162 ProjectExpr::Column(2),
4163 ProjectExpr::Column(3),
4164 ];
4165 let fallback = RirNode::Unit;
4168 RirNode::MultiWayJoin {
4169 inputs: scans,
4170 slot_vars,
4171 output_columns,
4172 fallback: Box::new(fallback),
4173 plan: None,
4174 var_order: None,
4175 }
4176 }
4177
4178 #[test]
4179 fn optimize_returns_multiway_unchanged() {
4180 let optimizer = Optimizer::new(make_stats_manager());
4181 for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4182 let optimized = optimizer.optimize(node.clone());
4183 match (&node, &optimized) {
4184 (
4185 RirNode::MultiWayJoin {
4186 inputs: a_in,
4187 output_columns: a_out,
4188 ..
4189 },
4190 RirNode::MultiWayJoin {
4191 inputs: b_in,
4192 output_columns: b_out,
4193 ..
4194 },
4195 ) => {
4196 assert_eq!(a_in.len(), b_in.len());
4197 assert_eq!(a_out.len(), b_out.len());
4198 }
4199 _ => panic!("optimize() must return a MultiWayJoin"),
4200 }
4201 }
4202 }
4203
4204 #[test]
4205 fn estimate_width_uses_output_columns_arity() {
4206 let optimizer = Optimizer::new(make_stats_manager());
4207 assert_eq!(
4209 optimizer.estimate_width(&build_canonical_triangle_multiway()),
4210 3
4211 );
4212 assert_eq!(optimizer.estimate_width(&build_4input_multiway()), 4);
4216 }
4217
4218 #[test]
4219 fn estimate_cost_sums_input_costs() {
4220 let optimizer = Optimizer::new(make_stats_manager());
4221
4222 let cost_tri = optimizer.estimate_cost(&build_canonical_triangle_multiway());
4225 assert!(
4226 cost_tri.rows >= 16_000,
4227 "expected cost.rows >= 16000, got {}",
4228 cost_tri.rows
4229 );
4230
4231 let cost_4 = optimizer.estimate_cost(&build_4input_multiway());
4236 assert!(
4237 cost_4.rows >= 26_000,
4238 "expected 4-input cost.rows >= 26000, got {}",
4239 cost_4.rows
4240 );
4241 assert!(
4242 cost_4.rows > cost_tri.rows,
4243 "4-input cost ({}) must exceed triangle cost ({})",
4244 cost_4.rows,
4245 cost_tri.rows
4246 );
4247 }
4248
4249 #[test]
4250 fn find_column_relation_returns_none_for_multiway() {
4251 let optimizer = Optimizer::new(make_stats_manager());
4252 for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4257 for col in 0..node.referenced_relations().len() {
4258 assert!(
4259 optimizer.find_column_relation(&node, col).is_none(),
4260 "find_column_relation must return None for any \
4261 MultiWayJoin column (col={})",
4262 col,
4263 );
4264 }
4265 }
4266 }
4267}