Skip to main content

xlog_logic/
optimizer.rs

1//! Query optimizer for join ordering and predicate pushdown.
2//!
3//! This module provides cost-based query optimization for XLOG's relational IR.
4//! It uses GPU-resident statistics from [`xlog_stats::StatsManager`] to make
5//! informed decisions about:
6//!
7//! - **Predicate pushdown**: Moving filter predicates closer to base scans to
8//!   reduce intermediate result sizes early in the pipeline.
9//! - **Cost estimation**: Computing expected row counts, CPU costs, GPU memory
10//!   usage, and data transfer counts for plan nodes.
11//! - **Join ordering**: (Future) Reordering joins based on selectivity estimates
12//!   to minimize intermediate result sizes.
13//!
14//! # Usage
15//!
16//! ```ignore
17//! use std::sync::Arc;
18//! use xlog_logic::optimizer::{Optimizer, OptimizerConfig, PlanCost};
19//! use xlog_stats::StatsManager;
20//!
21//! let stats = Arc::new(StatsManager::new());
22//! let optimizer = Optimizer::new(stats);
23//!
24//! // Optimize a query plan
25//! let optimized_plan = optimizer.optimize(original_plan);
26//!
27//! // Get cost estimates
28//! let cost = optimizer.estimate_cost(&optimized_plan);
29//! println!("Estimated rows: {}, GPU memory: {} bytes", cost.rows, cost.gpu_mem);
30//! ```
31
32use std::collections::HashMap;
33use std::sync::Arc;
34use xlog_core::{RelId, Schema};
35use xlog_ir::{CompareOp, Expr, JoinType, RirNode};
36use xlog_stats::StatsManager;
37
/// Tuning knobs for the query optimizer.
///
/// Bundles the thresholds that pick between optimization algorithms, the
/// feature toggles for individual rewrites, and the fallback constants the
/// cost model uses when statistics are missing. Start from
/// `OptimizerConfig::default()` and override individual fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OptimizerConfig {
    /// Largest number of relations for which join ordering is solved
    /// exhaustively with dynamic programming.
    ///
    /// Queries over more relations fall back to a greedy ordering so that
    /// planning time does not grow exponentially. Default: 10.
    pub dp_threshold: usize,

    /// Access-heat level above which a relation is recommended for indexing.
    ///
    /// Relations hotter than this are candidates for index construction to
    /// speed up future queries. Default: 0.7.
    pub index_heat_threshold: f32,

    /// Whether predicate pushdown runs during optimization.
    ///
    /// When set, filter predicates are moved through projections and joins
    /// so they are evaluated as early as possible. Default: true.
    pub enable_pushdown: bool,

    /// Fraction of rows a filter is assumed to keep when no statistics apply.
    ///
    /// Used only when column statistics cannot give a better estimate.
    /// Default: 0.1 (10% of rows survive).
    pub default_filter_selectivity: f64,

    /// Relative cost of one GPU-to-host data transfer.
    ///
    /// Expresses how much more expensive a transfer is than a local GPU
    /// operation. Default: 100.0.
    pub transfer_cost_multiplier: f64,

    /// Row width, in bytes, assumed for GPU memory estimates without a schema.
    ///
    /// Default: 32 bytes (roughly 4 columns of 8 bytes each).
    pub default_bytes_per_row: u64,
}

impl Default for OptimizerConfig {
    /// Returns the documented default for every field.
    fn default() -> Self {
        // Rewrite toggles
        let enable_pushdown = true;
        // Algorithm-selection thresholds
        let dp_threshold = 10;
        let index_heat_threshold = 0.7;
        // Cost-model fallbacks
        let default_filter_selectivity = 0.1;
        let transfer_cost_multiplier = 100.0;
        let default_bytes_per_row = 32;

        Self {
            dp_threshold,
            index_heat_threshold,
            enable_pushdown,
            default_filter_selectivity,
            transfer_cost_multiplier,
            default_bytes_per_row,
        }
    }
}
94
/// Cost estimate for a query plan node.
///
/// Captures the multi-dimensional cost of executing a plan node so that
/// alternative plans can be compared; [`PlanCost::total_cost`] collapses the
/// components into a single comparable number.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PlanCost {
    /// Estimated number of output rows.
    pub rows: u64,

    /// Estimated CPU cost (arbitrary units, relative comparisons only).
    ///
    /// This represents processing overhead that cannot be parallelized on
    /// the GPU, such as coordination, scheduling, and result materialization.
    pub cpu_cost: f64,

    /// Estimated GPU memory usage in bytes.
    ///
    /// Includes both input buffers and intermediate storage required for
    /// the operation.
    pub gpu_mem: u64,

    /// Number of GPU-to-host or host-to-GPU data transfers.
    ///
    /// Transfers are typically the most expensive operations in GPU computing
    /// and should be minimized.
    pub transfers: u32,
}

impl PlanCost {
    /// Cost units charged per byte of GPU memory by [`PlanCost::total_cost`]
    /// (1 GB = 1,000,000 units).
    const GPU_MEM_COST_PER_BYTE: f64 = 0.001;

    /// Creates a cost estimate with the given row count and all other
    /// components zero.
    pub fn with_rows(rows: u64) -> Self {
        Self {
            rows,
            ..Default::default()
        }
    }

    /// Computes a scalar cost value for comparison purposes.
    ///
    /// The result is `cpu_cost + gpu_mem * 0.001 + transfers * transfer_weight`:
    /// - CPU cost is taken directly
    /// - GPU memory is scaled by 0.001 (1GB = 1M cost units)
    /// - Transfers are multiplied by `transfer_weight`, which is normally large
    ///   because transfers have high latency
    ///
    /// # Arguments
    ///
    /// * `transfer_weight` - Weight multiplier for transfer costs
    pub fn total_cost(&self, transfer_weight: f64) -> f64 {
        self.cpu_cost
            + (self.gpu_mem as f64 * Self::GPU_MEM_COST_PER_BYTE)
            + (f64::from(self.transfers) * transfer_weight)
    }

    /// Combines two costs representing sequential operations, `self` first.
    ///
    /// - `rows` comes from `other`, the later operation.
    /// - `cpu_cost` and `transfers` are summed; `transfers` saturates at
    ///   `u32::MAX` instead of overflowing.
    /// - `gpu_mem` is the larger of the two. Memory is modelled as the peak
    ///   usage of either step, not the sum.
    pub fn then(self, other: PlanCost) -> PlanCost {
        PlanCost {
            rows: other.rows,
            cpu_cost: self.cpu_cost + other.cpu_cost,
            gpu_mem: self.gpu_mem.max(other.gpu_mem),
            // Saturate: a cost estimate must never panic on overflow.
            transfers: self.transfers.saturating_add(other.transfers),
        }
    }
}
158
159/// Query optimizer using statistics for cost-based decisions.
160///
161/// The optimizer transforms query plans to improve execution efficiency
162/// by applying rewrites like predicate pushdown and using statistics to
163/// estimate costs for different plan alternatives.
164pub struct Optimizer {
165    stats: Arc<StatsManager>,
166    config: OptimizerConfig,
167    /// Schemas for relations, keyed by RelId
168    schemas: HashMap<RelId, Schema>,
169}
170
171impl Optimizer {
172    /// Creates a new optimizer with default configuration.
173    ///
174    /// # Arguments
175    ///
176    /// * `stats` - Shared statistics manager for cardinality and selectivity estimates
177    pub fn new(stats: Arc<StatsManager>) -> Self {
178        Self {
179            stats,
180            config: OptimizerConfig::default(),
181            schemas: HashMap::new(),
182        }
183    }
184
185    /// Creates a new optimizer with custom configuration.
186    ///
187    /// # Arguments
188    ///
189    /// * `stats` - Shared statistics manager
190    /// * `config` - Custom optimizer configuration
191    pub fn with_config(stats: Arc<StatsManager>, config: OptimizerConfig) -> Self {
192        Self {
193            stats,
194            config,
195            schemas: HashMap::new(),
196        }
197    }
198
199    /// Sets the schemas for relations.
200    ///
201    /// This information is used by the optimizer to accurately determine
202    /// column widths during predicate pushdown.
203    pub fn set_schemas(&mut self, schemas: HashMap<RelId, Schema>) {
204        self.schemas = schemas;
205    }
206
207    /// Returns a reference to the current configuration.
208    pub fn config(&self) -> &OptimizerConfig {
209        &self.config
210    }
211
212    /// Returns a reference to the statistics manager.
213    pub fn stats(&self) -> &Arc<StatsManager> {
214        &self.stats
215    }
216
217    /// Optimizes an execution plan by applying transformation rules.
218    ///
219    /// Currently applies:
220    /// - Predicate pushdown (if enabled)
221    ///
222    /// Future optimizations may include:
223    /// - Join reordering based on cardinality estimates
224    /// - Projection pushdown
225    /// - Common subexpression elimination
226    ///
227    /// # Arguments
228    ///
229    /// * `node` - The plan to optimize
230    ///
231    /// # Returns
232    ///
233    /// An optimized plan that is semantically equivalent to the input
234    pub fn optimize(&self, node: RirNode) -> RirNode {
235        if self.config.enable_pushdown {
236            self.predicate_pushdown(node)
237        } else {
238            node
239        }
240    }
241
242    /// Pushes filter predicates closer to scan nodes.
243    ///
244    /// This transformation reduces intermediate result sizes by applying
245    /// filters as early as possible in the query pipeline. The rules are:
246    ///
247    /// - Filters can be pushed through projections (with column remapping)
248    /// - Filters can be pushed into one or both sides of a join if the
249    ///   predicate references only columns from that side
250    /// - Filters on join keys can inform join selectivity estimates
251    ///
252    /// # Arguments
253    ///
254    /// * `node` - The plan node to transform
255    ///
256    /// # Returns
257    ///
258    /// The transformed plan with predicates pushed down where beneficial
259    fn predicate_pushdown(&self, node: RirNode) -> RirNode {
260        match node {
261            // Base case: scan nodes cannot be transformed further
262            RirNode::Unit => RirNode::Unit,
263            RirNode::Scan { rel } => RirNode::Scan { rel },
264
265            // Filter on top of another node: try to push down
266            RirNode::Filter { input, predicate } => {
267                // First, recursively optimize the input
268                let optimized_input = self.predicate_pushdown(*input);
269
270                match optimized_input {
271                    // Filter on Filter: merge predicates
272                    RirNode::Filter {
273                        input: inner_input,
274                        predicate: inner_pred,
275                    } => {
276                        let merged = Expr::And(vec![inner_pred, predicate]);
277                        RirNode::Filter {
278                            input: inner_input,
279                            predicate: merged,
280                        }
281                    }
282
283                    // Filter on Project: push through if possible
284                    RirNode::Project {
285                        input: proj_input,
286                        columns,
287                    } => {
288                        // Check if predicate only references pass-through columns
289                        if let Some(remapped) =
290                            self.remap_predicate_through_project(&predicate, &columns)
291                        {
292                            // Push the remapped predicate below the projection
293                            RirNode::Project {
294                                input: Box::new(RirNode::Filter {
295                                    input: proj_input,
296                                    predicate: remapped,
297                                }),
298                                columns,
299                            }
300                        } else {
301                            // Cannot push: keep filter above
302                            RirNode::Filter {
303                                input: Box::new(RirNode::Project {
304                                    input: proj_input,
305                                    columns,
306                                }),
307                                predicate,
308                            }
309                        }
310                    }
311
312                    // Filter on Join: try to push to appropriate side
313                    RirNode::Join {
314                        left,
315                        right,
316                        left_keys,
317                        right_keys,
318                        join_type,
319                    } => {
320                        let left_width = self.estimate_width(&left);
321                        let (left_preds, right_preds, remaining) =
322                            self.split_predicate_for_join(&predicate, left_width);
323
324                        // Apply pushed predicates to each side
325                        let new_left = if !left_preds.is_empty() {
326                            Box::new(RirNode::Filter {
327                                input: left,
328                                predicate: Self::conjoin(left_preds),
329                            })
330                        } else {
331                            left
332                        };
333
334                        let new_right = if !right_preds.is_empty() {
335                            Box::new(RirNode::Filter {
336                                input: right,
337                                predicate: Self::conjoin(right_preds),
338                            })
339                        } else {
340                            right
341                        };
342
343                        let join_node = RirNode::Join {
344                            left: new_left,
345                            right: new_right,
346                            left_keys,
347                            right_keys,
348                            join_type,
349                        };
350
351                        // Apply remaining predicates that couldn't be pushed
352                        if !remaining.is_empty() {
353                            RirNode::Filter {
354                                input: Box::new(join_node),
355                                predicate: Self::conjoin(remaining),
356                            }
357                        } else {
358                            join_node
359                        }
360                    }
361
362                    // Default: cannot push further
363                    other => RirNode::Filter {
364                        input: Box::new(other),
365                        predicate,
366                    },
367                }
368            }
369
370            // Project: recursively optimize input
371            RirNode::Project { input, columns } => RirNode::Project {
372                input: Box::new(self.predicate_pushdown(*input)),
373                columns,
374            },
375
376            // Join: recursively optimize both sides
377            RirNode::Join {
378                left,
379                right,
380                left_keys,
381                right_keys,
382                join_type,
383            } => RirNode::Join {
384                left: Box::new(self.predicate_pushdown(*left)),
385                right: Box::new(self.predicate_pushdown(*right)),
386                left_keys,
387                right_keys,
388                join_type,
389            },
390
391            // GroupBy: recursively optimize input
392            RirNode::GroupBy {
393                input,
394                key_cols,
395                aggs,
396            } => RirNode::GroupBy {
397                input: Box::new(self.predicate_pushdown(*input)),
398                key_cols,
399                aggs,
400            },
401
402            // Union: recursively optimize all inputs
403            RirNode::Union { inputs } => RirNode::Union {
404                inputs: inputs
405                    .into_iter()
406                    .map(|i| self.predicate_pushdown(i))
407                    .collect(),
408            },
409
410            // Distinct: recursively optimize input
411            RirNode::Distinct { input, key_cols } => RirNode::Distinct {
412                input: Box::new(self.predicate_pushdown(*input)),
413                key_cols,
414            },
415
416            // Diff: recursively optimize both sides
417            RirNode::Diff { left, right } => RirNode::Diff {
418                left: Box::new(self.predicate_pushdown(*left)),
419                right: Box::new(self.predicate_pushdown(*right)),
420            },
421
422            // Fixpoint: recursively optimize base and recursive parts
423            RirNode::Fixpoint {
424                scc_id,
425                base,
426                recursive,
427                delta_rel,
428                full_rel,
429            } => RirNode::Fixpoint {
430                scc_id,
431                base: Box::new(self.predicate_pushdown(*base)),
432                recursive: Box::new(self.predicate_pushdown(*recursive)),
433                delta_rel,
434                full_rel,
435            },
436
437            RirNode::TensorMaskedJoin { .. } => node, // Leaf-like: no pushdown
438
439            // Promoted physical-shape nodes are produced after the
440            // optimizer runs. Required for compile safety and as a
441            // no-op fallback if the call order ever changes.
442            RirNode::MultiWayJoin { .. } | RirNode::ChainJoin { .. } => node,
443        }
444    }
445
446    /// Attempts to remap a predicate through a projection.
447    ///
448    /// Returns `Some(remapped_predicate)` if all column references in the
449    /// predicate can be traced back through pass-through columns.
450    /// Returns `None` if the predicate references computed columns.
451    fn remap_predicate_through_project(
452        &self,
453        predicate: &Expr,
454        columns: &[xlog_ir::ProjectExpr],
455    ) -> Option<Expr> {
456        // Build a mapping from output column index to input column index
457        // Only for pass-through columns
458        let mut output_to_input: std::collections::HashMap<usize, usize> =
459            std::collections::HashMap::new();
460
461        for (out_idx, proj_expr) in columns.iter().enumerate() {
462            if let xlog_ir::ProjectExpr::Column(in_idx) = proj_expr {
463                output_to_input.insert(out_idx, *in_idx);
464            }
465        }
466
467        self.remap_expr(predicate, &output_to_input)
468    }
469
470    /// Recursively remaps column references in an expression.
471    fn remap_expr(
472        &self,
473        expr: &Expr,
474        mapping: &std::collections::HashMap<usize, usize>,
475    ) -> Option<Expr> {
476        match expr {
477            Expr::Column(idx) => mapping.get(idx).map(|&new_idx| Expr::Column(new_idx)),
478
479            Expr::Const(val) => Some(Expr::Const(val.clone())),
480
481            Expr::Compare { left, op, right } => {
482                let new_left = self.remap_expr(left, mapping)?;
483                let new_right = self.remap_expr(right, mapping)?;
484                Some(Expr::Compare {
485                    left: Box::new(new_left),
486                    op: *op,
487                    right: Box::new(new_right),
488                })
489            }
490
491            Expr::And(exprs) => {
492                let remapped: Option<Vec<_>> =
493                    exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
494                remapped.map(Expr::And)
495            }
496
497            Expr::Or(exprs) => {
498                let remapped: Option<Vec<_>> =
499                    exprs.iter().map(|e| self.remap_expr(e, mapping)).collect();
500                remapped.map(Expr::Or)
501            }
502
503            Expr::Not(inner) => {
504                let remapped = self.remap_expr(inner, mapping)?;
505                Some(Expr::Not(Box::new(remapped)))
506            }
507
508            // Arithmetic operations
509            Expr::Add(l, r) => {
510                let new_l = self.remap_expr(l, mapping)?;
511                let new_r = self.remap_expr(r, mapping)?;
512                Some(Expr::Add(Box::new(new_l), Box::new(new_r)))
513            }
514            Expr::Sub(l, r) => {
515                let new_l = self.remap_expr(l, mapping)?;
516                let new_r = self.remap_expr(r, mapping)?;
517                Some(Expr::Sub(Box::new(new_l), Box::new(new_r)))
518            }
519            Expr::Mul(l, r) => {
520                let new_l = self.remap_expr(l, mapping)?;
521                let new_r = self.remap_expr(r, mapping)?;
522                Some(Expr::Mul(Box::new(new_l), Box::new(new_r)))
523            }
524            Expr::Div(l, r) => {
525                let new_l = self.remap_expr(l, mapping)?;
526                let new_r = self.remap_expr(r, mapping)?;
527                Some(Expr::Div(Box::new(new_l), Box::new(new_r)))
528            }
529            Expr::Mod(l, r) => {
530                let new_l = self.remap_expr(l, mapping)?;
531                let new_r = self.remap_expr(r, mapping)?;
532                Some(Expr::Mod(Box::new(new_l), Box::new(new_r)))
533            }
534
535            // Built-in functions
536            Expr::Abs(inner) => {
537                let remapped = self.remap_expr(inner, mapping)?;
538                Some(Expr::Abs(Box::new(remapped)))
539            }
540            Expr::Min(l, r) => {
541                let new_l = self.remap_expr(l, mapping)?;
542                let new_r = self.remap_expr(r, mapping)?;
543                Some(Expr::Min(Box::new(new_l), Box::new(new_r)))
544            }
545            Expr::Max(l, r) => {
546                let new_l = self.remap_expr(l, mapping)?;
547                let new_r = self.remap_expr(r, mapping)?;
548                Some(Expr::Max(Box::new(new_l), Box::new(new_r)))
549            }
550            Expr::Pow(l, r) => {
551                let new_l = self.remap_expr(l, mapping)?;
552                let new_r = self.remap_expr(r, mapping)?;
553                Some(Expr::Pow(Box::new(new_l), Box::new(new_r)))
554            }
555            Expr::Cast(inner, scalar_type) => {
556                let remapped = self.remap_expr(inner, mapping)?;
557                Some(Expr::Cast(Box::new(remapped), *scalar_type))
558            }
559            Expr::Conditional {
560                condition,
561                then_expr,
562                else_expr,
563            } => {
564                let new_condition = self.remap_expr(condition, mapping)?;
565                let new_then = self.remap_expr(then_expr, mapping)?;
566                let new_else = self.remap_expr(else_expr, mapping)?;
567                Some(Expr::Conditional {
568                    condition: Box::new(new_condition),
569                    then_expr: Box::new(new_then),
570                    else_expr: Box::new(new_else),
571                })
572            }
573        }
574    }
575
576    /// Estimates the output width (number of columns) of a plan node.
577    fn estimate_width(&self, node: &RirNode) -> usize {
578        match node {
579            RirNode::Unit => 0,
580            RirNode::Scan { rel } => {
581                // Use schema if available, otherwise stats, otherwise default
582                if let Some(schema) = self.schemas.get(rel) {
583                    schema.arity()
584                } else if let Some(stats) = self.stats.get_relation_stats(*rel) {
585                    stats.column_stats.len().max(1)
586                } else {
587                    4 // Default assumption
588                }
589            }
590            RirNode::Filter { input, .. } => self.estimate_width(input),
591            RirNode::Project { columns, .. } => columns.len(),
592            RirNode::Join { left, right, .. } => {
593                self.estimate_width(left) + self.estimate_width(right)
594            }
595            RirNode::ChainJoin { output_columns, .. } => output_columns.len(),
596            RirNode::GroupBy { key_cols, aggs, .. } => key_cols.len() + aggs.len(),
597            RirNode::Union { inputs } => {
598                inputs.first().map(|i| self.estimate_width(i)).unwrap_or(0)
599            }
600            RirNode::Distinct { input, .. } => self.estimate_width(input),
601            RirNode::Diff { left, .. } => self.estimate_width(left),
602            RirNode::Fixpoint { base, .. } => self.estimate_width(base),
603            // TensorMaskedJoin schemas are keyed by RelId.
604            // Use head_rel_id (not head_rel_name) for lookup.
605            RirNode::TensorMaskedJoin { head_rel_id, .. } => self
606                .schemas
607                .get(head_rel_id)
608                .map(|s| s.arity())
609                .unwrap_or(2),
610            // MultiWayJoin is produced after promotion; width equals the
611            // head projection arity, mirroring the Project arm.
612            RirNode::MultiWayJoin { output_columns, .. } => output_columns.len(),
613        }
614    }
615
616    /// Splits a predicate into parts pushable to left, right, or neither side of a join.
617    ///
618    /// Returns (left_predicates, right_predicates, remaining_predicates).
619    fn split_predicate_for_join(
620        &self,
621        predicate: &Expr,
622        left_width: usize,
623    ) -> (Vec<Expr>, Vec<Expr>, Vec<Expr>) {
624        let mut left_preds = Vec::new();
625        let mut right_preds = Vec::new();
626        let mut remaining = Vec::new();
627
628        // Flatten AND expressions
629        let conjuncts = Self::flatten_and(predicate);
630
631        for conj in conjuncts {
632            let cols = Self::collect_columns(&conj);
633            let max_col = cols.iter().copied().max().unwrap_or(0);
634            let min_col = cols.iter().copied().min().unwrap_or(0);
635
636            if cols.is_empty() {
637                // No columns referenced, can push to either side
638                left_preds.push(conj);
639            } else if max_col < left_width {
640                // All columns from left side
641                left_preds.push(conj);
642            } else if min_col >= left_width {
643                // All columns from right side - need to remap
644                let remapped = Self::remap_columns(&conj, |c| c - left_width);
645                right_preds.push(remapped);
646            } else {
647                // References both sides, cannot push
648                remaining.push(conj);
649            }
650        }
651
652        (left_preds, right_preds, remaining)
653    }
654
655    /// Flattens nested AND expressions into a list of conjuncts.
656    fn flatten_and(expr: &Expr) -> Vec<Expr> {
657        match expr {
658            Expr::And(exprs) => exprs.iter().flat_map(Self::flatten_and).collect(),
659            other => vec![other.clone()],
660        }
661    }
662
663    /// Collects all column indices referenced in an expression.
664    fn collect_columns(expr: &Expr) -> Vec<usize> {
665        match expr {
666            Expr::Column(idx) => vec![*idx],
667            Expr::Const(_) => vec![],
668            Expr::Compare { left, right, .. } => {
669                let mut cols = Self::collect_columns(left);
670                cols.extend(Self::collect_columns(right));
671                cols
672            }
673            Expr::And(exprs) | Expr::Or(exprs) => {
674                exprs.iter().flat_map(Self::collect_columns).collect()
675            }
676            Expr::Not(inner) | Expr::Abs(inner) | Expr::Cast(inner, _) => {
677                Self::collect_columns(inner)
678            }
679            Expr::Add(l, r)
680            | Expr::Sub(l, r)
681            | Expr::Mul(l, r)
682            | Expr::Div(l, r)
683            | Expr::Mod(l, r)
684            | Expr::Min(l, r)
685            | Expr::Max(l, r)
686            | Expr::Pow(l, r) => {
687                let mut cols = Self::collect_columns(l);
688                cols.extend(Self::collect_columns(r));
689                cols
690            }
691            Expr::Conditional {
692                condition,
693                then_expr,
694                else_expr,
695            } => {
696                let mut cols = Self::collect_columns(condition);
697                cols.extend(Self::collect_columns(then_expr));
698                cols.extend(Self::collect_columns(else_expr));
699                cols
700            }
701        }
702    }
703
704    /// Remaps column references in an expression using a transformation function.
705    fn remap_columns<F: Fn(usize) -> usize + Copy>(expr: &Expr, f: F) -> Expr {
706        match expr {
707            Expr::Column(idx) => Expr::Column(f(*idx)),
708            Expr::Const(v) => Expr::Const(v.clone()),
709            Expr::Compare { left, op, right } => Expr::Compare {
710                left: Box::new(Self::remap_columns(left, f)),
711                op: *op,
712                right: Box::new(Self::remap_columns(right, f)),
713            },
714            Expr::And(exprs) => {
715                Expr::And(exprs.iter().map(|e| Self::remap_columns(e, f)).collect())
716            }
717            Expr::Or(exprs) => Expr::Or(exprs.iter().map(|e| Self::remap_columns(e, f)).collect()),
718            Expr::Not(inner) => Expr::Not(Box::new(Self::remap_columns(inner, f))),
719            Expr::Add(l, r) => Expr::Add(
720                Box::new(Self::remap_columns(l, f)),
721                Box::new(Self::remap_columns(r, f)),
722            ),
723            Expr::Sub(l, r) => Expr::Sub(
724                Box::new(Self::remap_columns(l, f)),
725                Box::new(Self::remap_columns(r, f)),
726            ),
727            Expr::Mul(l, r) => Expr::Mul(
728                Box::new(Self::remap_columns(l, f)),
729                Box::new(Self::remap_columns(r, f)),
730            ),
731            Expr::Div(l, r) => Expr::Div(
732                Box::new(Self::remap_columns(l, f)),
733                Box::new(Self::remap_columns(r, f)),
734            ),
735            Expr::Mod(l, r) => Expr::Mod(
736                Box::new(Self::remap_columns(l, f)),
737                Box::new(Self::remap_columns(r, f)),
738            ),
739            Expr::Abs(inner) => Expr::Abs(Box::new(Self::remap_columns(inner, f))),
740            Expr::Min(l, r) => Expr::Min(
741                Box::new(Self::remap_columns(l, f)),
742                Box::new(Self::remap_columns(r, f)),
743            ),
744            Expr::Max(l, r) => Expr::Max(
745                Box::new(Self::remap_columns(l, f)),
746                Box::new(Self::remap_columns(r, f)),
747            ),
748            Expr::Pow(l, r) => Expr::Pow(
749                Box::new(Self::remap_columns(l, f)),
750                Box::new(Self::remap_columns(r, f)),
751            ),
752            Expr::Cast(inner, t) => Expr::Cast(Box::new(Self::remap_columns(inner, f)), *t),
753            Expr::Conditional {
754                condition,
755                then_expr,
756                else_expr,
757            } => Expr::Conditional {
758                condition: Box::new(Self::remap_columns(condition, f)),
759                then_expr: Box::new(Self::remap_columns(then_expr, f)),
760                else_expr: Box::new(Self::remap_columns(else_expr, f)),
761            },
762        }
763    }
764
765    /// Combines a list of predicates into a single AND expression.
766    fn conjoin(predicates: Vec<Expr>) -> Expr {
767        debug_assert!(!predicates.is_empty());
768        if predicates.len() == 1 {
769            predicates.into_iter().next().unwrap()
770        } else {
771            Expr::And(predicates)
772        }
773    }
774
775    /// Estimates the cost of executing a plan node.
776    ///
777    /// Recursively computes cost estimates for the entire plan tree,
778    /// using statistics when available and falling back to heuristics.
779    ///
780    /// # Arguments
781    ///
782    /// * `node` - The plan node to estimate
783    ///
784    /// # Returns
785    ///
786    /// A [`PlanCost`] with estimated rows, CPU cost, GPU memory, and transfers
787    pub fn estimate_cost(&self, node: &RirNode) -> PlanCost {
788        match node {
789            RirNode::Unit => PlanCost {
790                rows: 1,
791                cpu_cost: 0.0,
792                gpu_mem: 0,
793                transfers: 0,
794            },
795            RirNode::Scan { rel } => self.estimate_scan_cost(*rel),
796
797            RirNode::Filter { input, predicate } => {
798                let input_cost = self.estimate_cost(input);
799                self.estimate_filter_cost(input_cost, predicate, input)
800            }
801
802            RirNode::Project { input, columns } => {
803                let input_cost = self.estimate_cost(input);
804                self.estimate_project_cost(input_cost, columns)
805            }
806
807            RirNode::Join {
808                left,
809                right,
810                left_keys,
811                right_keys,
812                join_type,
813            } => {
814                let left_cost = self.estimate_cost(left);
815                let right_cost = self.estimate_cost(right);
816                self.estimate_join_cost(
817                    left_cost, right_cost, left, right, left_keys, right_keys, *join_type,
818                )
819            }
820
821            RirNode::ChainJoin {
822                left,
823                right,
824                left_key,
825                right_key,
826                output_columns,
827                ..
828            } => {
829                let left_cost = self.estimate_cost(left);
830                let right_cost = self.estimate_cost(right);
831                let join_cost = self.estimate_join_cost(
832                    left_cost,
833                    right_cost,
834                    left,
835                    right,
836                    &[*left_key],
837                    &[*right_key],
838                    JoinType::Inner,
839                );
840                self.estimate_project_cost(join_cost, output_columns)
841            }
842
843            RirNode::GroupBy {
844                input,
845                key_cols,
846                aggs,
847            } => {
848                let input_cost = self.estimate_cost(input);
849                self.estimate_groupby_cost(input_cost, key_cols, aggs)
850            }
851
852            RirNode::Union { inputs } => {
853                let costs: Vec<_> = inputs.iter().map(|i| self.estimate_cost(i)).collect();
854                self.estimate_union_cost(costs)
855            }
856
857            RirNode::Distinct { input, key_cols } => {
858                let input_cost = self.estimate_cost(input);
859                self.estimate_distinct_cost(input_cost, key_cols)
860            }
861
862            RirNode::Diff { left, right } => {
863                let left_cost = self.estimate_cost(left);
864                let right_cost = self.estimate_cost(right);
865                self.estimate_diff_cost(left_cost, right_cost)
866            }
867
868            RirNode::Fixpoint {
869                base, recursive, ..
870            } => {
871                let base_cost = self.estimate_cost(base);
872                let recursive_cost = self.estimate_cost(recursive);
873                self.estimate_fixpoint_cost(base_cost, recursive_cost)
874            }
875
876            RirNode::TensorMaskedJoin {
877                max_active_rules, ..
878            } => PlanCost {
879                rows: *max_active_rules as u64,
880                cpu_cost: *max_active_rules as f64 * 100.0,
881                gpu_mem: *max_active_rules as u64 * 1024,
882                transfers: 1,
883            },
884            // MultiWayJoin heuristic cost is the sum of input scan costs.
885            // Post-promoter dispatch decides whether to run the WCOJ kernel
886            // or fall back; full multiway cost-model integration is separate
887            // planner work.
888            RirNode::MultiWayJoin { inputs, .. } => {
889                let mut total = PlanCost::default();
890                for inp in inputs {
891                    let c = self.estimate_cost(inp);
892                    total.rows = total.rows.saturating_add(c.rows);
893                    total.cpu_cost += c.cpu_cost;
894                    total.gpu_mem = total.gpu_mem.saturating_add(c.gpu_mem);
895                    total.transfers = total.transfers.saturating_add(c.transfers);
896                }
897                total
898            }
899        }
900    }
901
902    /// Estimates cost for a base relation scan.
903    fn estimate_scan_cost(&self, rel: RelId) -> PlanCost {
904        if let Some(stats) = self.stats.get_relation_stats(rel) {
905            PlanCost {
906                rows: stats.cardinality,
907                cpu_cost: stats.cardinality as f64 * 0.01, // Minimal per-row CPU cost
908                gpu_mem: stats
909                    .byte_size
910                    .max(stats.cardinality * self.config.default_bytes_per_row),
911                transfers: 0, // Data already on GPU
912            }
913        } else {
914            // Default estimates for unknown relations
915            let default_rows = 1000;
916            PlanCost {
917                rows: default_rows,
918                cpu_cost: default_rows as f64 * 0.01,
919                gpu_mem: default_rows * self.config.default_bytes_per_row,
920                transfers: 0,
921            }
922        }
923    }
924
925    /// Estimates cost for a filter operation.
926    fn estimate_filter_cost(
927        &self,
928        input_cost: PlanCost,
929        predicate: &Expr,
930        input: &RirNode,
931    ) -> PlanCost {
932        let selectivity = self.estimate_predicate_selectivity(predicate, input);
933        let output_rows = ((input_cost.rows as f64 * selectivity) as u64).max(1);
934
935        PlanCost {
936            rows: output_rows,
937            cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.02, // Predicate eval cost
938            gpu_mem: input_cost.gpu_mem, // Filter reuses input memory
939            transfers: input_cost.transfers,
940        }
941    }
942
943    /// Estimates cost for a projection operation.
944    fn estimate_project_cost(
945        &self,
946        input_cost: PlanCost,
947        columns: &[xlog_ir::ProjectExpr],
948    ) -> PlanCost {
949        // Count computed vs pass-through columns
950        let computed_count = columns
951            .iter()
952            .filter(|c| matches!(c, xlog_ir::ProjectExpr::Computed(_, _)))
953            .count();
954
955        // Computed columns add CPU cost
956        let compute_cost = computed_count as f64 * input_cost.rows as f64 * 0.05;
957
958        // Output size may be smaller if fewer columns
959        let output_width_ratio = columns.len() as f64 / (columns.len() + 2) as f64; // Rough estimate
960
961        PlanCost {
962            rows: input_cost.rows,
963            cpu_cost: input_cost.cpu_cost + compute_cost,
964            gpu_mem: (input_cost.gpu_mem as f64 * output_width_ratio) as u64,
965            transfers: input_cost.transfers,
966        }
967    }
968
969    /// Estimates cost for a join operation.
970    #[allow(clippy::too_many_arguments)]
971    fn estimate_join_cost(
972        &self,
973        left_cost: PlanCost,
974        right_cost: PlanCost,
975        left: &RirNode,
976        right: &RirNode,
977        left_keys: &[usize],
978        right_keys: &[usize],
979        join_type: JoinType,
980    ) -> PlanCost {
981        // Semi and Anti joins always produce at most left_cost.rows
982        // Handle these specially before checking stats
983        let output_rows = match join_type {
984            JoinType::Semi => {
985                // At most left side rows, estimate 50% match
986                ((left_cost.rows as f64 * 0.5) as u64).max(1)
987            }
988            JoinType::Anti => {
989                // At most left side rows, estimate 50% don't match
990                ((left_cost.rows as f64 * 0.5) as u64).max(1)
991            }
992            JoinType::Inner | JoinType::LeftOuter => {
993                // Get relation IDs for selectivity lookup
994                let left_rels = left.referenced_relations();
995                let right_rels = right.referenced_relations();
996
997                if left_rels.len() == 1 && right_rels.len() == 1 {
998                    // Simple join between two base relations
999                    let estimated = self.stats.estimate_join_cardinality(
1000                        left_rels[0],
1001                        right_rels[0],
1002                        left_keys,
1003                        right_keys,
1004                    );
1005
1006                    match join_type {
1007                        JoinType::LeftOuter => estimated.max(left_cost.rows),
1008                        _ => estimated,
1009                    }
1010                } else {
1011                    // Multi-way or complex join: use heuristic
1012                    match join_type {
1013                        JoinType::Inner => {
1014                            // Assume 10% selectivity for inner joins
1015                            ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64).max(1)
1016                        }
1017                        JoinType::LeftOuter => {
1018                            // At least left side rows
1019                            left_cost.rows.max(
1020                                ((left_cost.rows as f64 * right_cost.rows as f64 * 0.1) as u64)
1021                                    .max(1),
1022                            )
1023                        }
1024                        _ => unreachable!(),
1025                    }
1026                }
1027            }
1028        };
1029
1030        // Join CPU cost: hash build + probe
1031        let build_cost = right_cost.rows as f64 * 1.0; // Build hash table
1032        let probe_cost = left_cost.rows as f64 * 0.5; // Probe operations
1033        let cpu_cost = left_cost.cpu_cost + right_cost.cpu_cost + build_cost + probe_cost;
1034
1035        // GPU memory: both inputs plus hash table overhead
1036        let hash_table_overhead = right_cost.gpu_mem / 2; // Approximate hash table size
1037        let gpu_mem = left_cost.gpu_mem + right_cost.gpu_mem + hash_table_overhead;
1038
1039        PlanCost {
1040            rows: output_rows,
1041            cpu_cost,
1042            gpu_mem,
1043            transfers: left_cost.transfers + right_cost.transfers,
1044        }
1045    }
1046
1047    /// Estimates cost for a group-by with aggregation.
1048    fn estimate_groupby_cost(
1049        &self,
1050        input_cost: PlanCost,
1051        key_cols: &[usize],
1052        _aggs: &[(usize, xlog_core::AggOp)],
1053    ) -> PlanCost {
1054        // Estimate distinct groups based on key columns
1055        // Heuristic: sqrt(input_rows) for unknown cardinality
1056        let estimated_groups = if key_cols.is_empty() {
1057            1 // No grouping = single result
1058        } else {
1059            // Rough estimate: assume good reduction
1060            ((input_cost.rows as f64).sqrt() as u64).max(1)
1061        };
1062
1063        PlanCost {
1064            rows: estimated_groups,
1065            cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.5, // Aggregation cost
1066            gpu_mem: input_cost.gpu_mem + estimated_groups * self.config.default_bytes_per_row,
1067            transfers: input_cost.transfers,
1068        }
1069    }
1070
1071    /// Estimates cost for a union operation.
1072    fn estimate_union_cost(&self, input_costs: Vec<PlanCost>) -> PlanCost {
1073        let total_rows: u64 = input_costs.iter().map(|c| c.rows).sum();
1074        let total_cpu: f64 = input_costs.iter().map(|c| c.cpu_cost).sum();
1075        let max_gpu: u64 = input_costs.iter().map(|c| c.gpu_mem).max().unwrap_or(0);
1076        let total_transfers: u32 = input_costs.iter().map(|c| c.transfers).sum();
1077
1078        PlanCost {
1079            rows: total_rows,
1080            cpu_cost: total_cpu + total_rows as f64 * 0.01, // Concatenation cost
1081            gpu_mem: max_gpu,                               // Can process sequentially
1082            transfers: total_transfers,
1083        }
1084    }
1085
1086    /// Estimates cost for a distinct operation.
1087    fn estimate_distinct_cost(&self, input_cost: PlanCost, _key_cols: &[usize]) -> PlanCost {
1088        // Heuristic: distinct reduces rows by some factor
1089        let estimated_distinct = (input_cost.rows as f64 * 0.7) as u64;
1090
1091        PlanCost {
1092            rows: estimated_distinct.max(1),
1093            cpu_cost: input_cost.cpu_cost + input_cost.rows as f64 * 0.3, // Hash-based dedup
1094            gpu_mem: input_cost.gpu_mem + input_cost.rows * 8,            // Hash set overhead
1095            transfers: input_cost.transfers,
1096        }
1097    }
1098
1099    /// Estimates cost for a set difference operation.
1100    fn estimate_diff_cost(&self, left_cost: PlanCost, right_cost: PlanCost) -> PlanCost {
1101        // Diff removes matching rows from left
1102        let estimated_remaining = (left_cost.rows as f64 * 0.5) as u64;
1103
1104        PlanCost {
1105            rows: estimated_remaining.max(1),
1106            cpu_cost: left_cost.cpu_cost + right_cost.cpu_cost + right_cost.rows as f64 * 0.5,
1107            gpu_mem: left_cost.gpu_mem + right_cost.gpu_mem,
1108            transfers: left_cost.transfers + right_cost.transfers,
1109        }
1110    }
1111
1112    /// Estimates cost for a fixpoint (recursive) operation.
1113    fn estimate_fixpoint_cost(&self, base_cost: PlanCost, recursive_cost: PlanCost) -> PlanCost {
1114        // Fixpoint cost depends on number of iterations
1115        // Heuristic: assume log2(base_rows) iterations
1116        let estimated_iterations = ((base_cost.rows as f64).log2().ceil() as u64).max(1);
1117
1118        PlanCost {
1119            rows: base_cost.rows * estimated_iterations, // Output accumulates
1120            cpu_cost: base_cost.cpu_cost + recursive_cost.cpu_cost * estimated_iterations as f64,
1121            gpu_mem: (base_cost.gpu_mem + recursive_cost.gpu_mem) * 2, // Need delta and full
1122            transfers: base_cost.transfers + recursive_cost.transfers * estimated_iterations as u32,
1123        }
1124    }
1125
1126    /// Estimates selectivity of a predicate expression.
1127    fn estimate_predicate_selectivity(&self, predicate: &Expr, input: &RirNode) -> f64 {
1128        match predicate {
1129            Expr::Compare { left, op, right } => {
1130                self.estimate_compare_selectivity(left, *op, right, input)
1131            }
1132            Expr::And(exprs) => {
1133                // Multiply selectivities (independence assumption)
1134                exprs
1135                    .iter()
1136                    .map(|e| self.estimate_predicate_selectivity(e, input))
1137                    .product()
1138            }
1139            Expr::Or(exprs) => {
1140                // P(A or B) = P(A) + P(B) - P(A)P(B) for independent events
1141                // Simplified: max of selectivities as lower bound
1142                exprs
1143                    .iter()
1144                    .map(|e| self.estimate_predicate_selectivity(e, input))
1145                    .fold(0.0, f64::max)
1146            }
1147            Expr::Not(inner) => 1.0 - self.estimate_predicate_selectivity(inner, input),
1148            _ => self.config.default_filter_selectivity,
1149        }
1150    }
1151
1152    /// Estimates selectivity for a comparison predicate.
1153    fn estimate_compare_selectivity(
1154        &self,
1155        left: &Expr,
1156        op: CompareOp,
1157        right: &Expr,
1158        input: &RirNode,
1159    ) -> f64 {
1160        // Try to get column statistics if comparing column to constant
1161        if let (Expr::Column(col_idx), Expr::Const(_)) | (Expr::Const(_), Expr::Column(col_idx)) =
1162            (left, right)
1163        {
1164            // Find the base relation for this column
1165            if let Some(rel_id) = self.find_column_relation(input, *col_idx) {
1166                if let Some(stats) = self.stats.get_relation_stats(rel_id) {
1167                    if let Some(col_stats) = stats.get_column(*col_idx) {
1168                        return match op {
1169                            CompareOp::Eq => col_stats.equality_selectivity(stats.cardinality),
1170                            CompareOp::Ne => {
1171                                1.0 - col_stats.equality_selectivity(stats.cardinality)
1172                            }
1173                            CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => {
1174                                // Range predicates: estimate ~33% selectivity
1175                                0.33
1176                            }
1177                        };
1178                    }
1179                }
1180            }
1181        }
1182
1183        // Default selectivities by operator
1184        match op {
1185            CompareOp::Eq => 0.1, // 10% for equality
1186            CompareOp::Ne => 0.9, // 90% for inequality
1187            CompareOp::Lt | CompareOp::Le | CompareOp::Gt | CompareOp::Ge => 0.33, // 33% for ranges
1188        }
1189    }
1190
1191    /// Finds the base relation that provides a given column.
1192    fn find_column_relation(&self, node: &RirNode, col_idx: usize) -> Option<RelId> {
1193        match node {
1194            RirNode::Scan { rel } => Some(*rel),
1195            RirNode::Filter { input, .. } => self.find_column_relation(input, col_idx),
1196            RirNode::Project { input, columns } => {
1197                // Trace column through projection
1198                if col_idx < columns.len() {
1199                    if let xlog_ir::ProjectExpr::Column(src_idx) = &columns[col_idx] {
1200                        return self.find_column_relation(input, *src_idx);
1201                    }
1202                }
1203                None
1204            }
1205            RirNode::Join { left, right, .. } => {
1206                let left_width = self.estimate_width(left);
1207                if col_idx < left_width {
1208                    self.find_column_relation(left, col_idx)
1209                } else {
1210                    self.find_column_relation(right, col_idx - left_width)
1211                }
1212            }
1213            // MultiWayJoin has no stable column-to-input mapping here.
1214            // The promoter runs after the optimizer, so this arm is
1215            // unreachable in production. A half-mapped implementation that
1216            // walked `inputs` via `slot_vars` would be more dangerous than
1217            // returning None for this optimizer fallback.
1218            RirNode::MultiWayJoin { .. } => None,
1219            _ => None, // Complex cases: give up
1220        }
1221    }
1222
1223    /// Returns relations that should have indexes built based on access heat.
1224    ///
1225    /// This is useful for adaptive query processing where frequently accessed
1226    /// relations benefit from index structures.
1227    pub fn recommend_indexes(&self) -> Vec<RelId> {
1228        self.stats.hot_relations(self.config.index_heat_threshold)
1229    }
1230
1231    /// Returns true if the query involves more relations than the DP threshold.
1232    ///
1233    /// Used to decide between exhaustive and greedy join ordering algorithms.
1234    pub fn should_use_greedy(&self, node: &RirNode) -> bool {
1235        let rels = node.referenced_relations();
1236        let unique_rels: std::collections::HashSet<_> = rels.iter().collect();
1237        unique_rels.len() > self.config.dp_threshold
1238    }
1239}
1240
1241// Selectivity-aware optimizer pass.
1242//
1243// No-op by default for unrecognized shapes. This pass owns the
1244// selectivity-driven join reordering hook; broader planner work may add
1245// reordering logic that consults `stats` for more shapes.
1246//
1247// Walks `plan.rules_by_scc[*].body` and rewrites nodes in place. The default
1248// no-op preserves every existing plan tree byte-for-byte. Tests assert
1249// structural equality for triangle, 4-cycle, and recursive-SCC plans.
1250//
1251// Compile-pipeline ordering: runs between `Optimizer::optimize` and
1252// `xlog_logic::promote::promote_multiway`.
1253pub mod selectivity_pass {
1254    //! Selectivity-driven join reordering for canonical lowered triangle and
1255    //! 4-cycle bodies.
1256    //!
1257    //! ## Behavior
1258    //!
1259    //! For each rule body that matches the canonical lowered
1260    //! triangle or 4-cycle shape, the pass enumerates the valid
1261    //! candidate inner pairings (3 for triangle, 2 for 4-cycle),
1262    //! computes each candidate's
1263    //! `StatsManager::estimate_join_cardinality` with
1264    //! **pair-derived join keys from the shared-variable
1265    //! mapping**, and rewrites the body so the smallest-cost
1266    //! choice is materialized first. Tie → keep the optimizer's
1267    //! existing order (deterministic no-op).
1268    //!
1269    //! ## Safety floor
1270    //!
1271    //! If any input atom for a recognized body has no
1272    //! `StatsManager` entry OR `cardinality == 0`, the body is
1273    //! left unchanged. Recursive deltas / freshly-uploaded
1274    //! relations / unseeded predicates therefore stay on the
1275    //! optimizer's default order until stats are populated.
1276    //!
1277    //! ## Default-fallback edge case
1278    //!
1279    //! `StatsManager::estimate_join_cardinality` returns `u64`
1280    //! with no provenance — the caller cannot tell whether the
1281    //! estimate came from the cached `JoinSelectivity` table,
1282    //! the column-distinct heuristic, or the 10% default
1283    //! fallback. When all input atoms have populated
1284    //! cardinalities but no column statistics, the per-pair
1285    //! estimates may all collapse to the same fallback ratio,
1286    //! making the chosen pairing uninformative. **This is an
1287    //! accepted trade-off**: row-set parity holds regardless of
1288    //! selectivity quality (the rewrite preserves semantics);
1289    //! the integration checks gate on row-set + WCOJ-dispatch
1290    //! correctness, not on optimal pair choice.
1291    //!
1292    //! ## Promoter coordination
1293    //!
1294    //! The triangle and 4-cycle promoters accept the canonical *semantic*
1295    //! shape with any valid key combination; they emit `MultiWayJoin.inputs`
1296    //! and `slot_vars` in canonical semantic order regardless of the body's
1297    //! positional layout. Reordered bodies therefore still promote and still
1298    //! dispatch the WCOJ kernel correctly.
1299    use std::collections::HashMap;
1300    use xlog_core::RelId;
1301    use xlog_ir::ExecutionPlan;
1302    use xlog_stats::StatsManager;
1303
1304    /// Selectivity-driven join reordering for canonical triangle and 4-cycle
1305    /// bodies. See module-level doc.
1306    ///
1307    /// `rel_ids` is the predicate-name → RelId map used to
1308    /// resolve body Scans against `StatsManager` lookups.
1309    /// Production callers pass `Compiler::lowerer().rel_ids()`.
1310    /// Test callers can pass an empty map; with no
1311    /// `StatsManager` entries either, the safety floor leaves
1312    /// every body unchanged (legacy no-op behavior preserved).
1313    pub fn run(plan: &mut ExecutionPlan, stats: &StatsManager, rel_ids: &HashMap<String, RelId>) {
1314        // `rel_ids` is reserved for future shape-extension
1315        // work; the current rewriters operate on RelIds
1316        // directly from the body's Scans, so the map isn't
1317        // consulted here. Production callers still pass it
1318        // so the API surface is forward-compatible.
1319        let _ = rel_ids;
1320        for rules in plan.rules_by_scc.iter_mut() {
1321            for rule in rules.iter_mut() {
1322                if let Some(rewritten) = super::reorder::try_reorder_triangle(&rule.body, stats) {
1323                    rule.body = rewritten;
1324                    continue;
1325                }
1326                if let Some(rewritten) = super::reorder::try_reorder_4cycle(&rule.body, stats) {
1327                    rule.body = rewritten;
1328                }
1329            }
1330        }
1331    }
1332}
1333
1334/// Ahead-of-time helper-relation splitting for deep joins with buried skew.
1335pub mod helper_split_pass {
1336    use std::collections::{HashMap, HashSet};
1337
1338    use xlog_core::{RelId, ScalarType, Schema};
1339    use xlog_ir::rir::{HelperSplitSpec, KCliqueVariableOrder};
1340    use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
1341    use xlog_stats::StatsManager;
1342
    // NOTE(review): consumed outside this view (presumably by
    // `choose_candidate`). Judging by the name, it is the skew ratio at or
    // above which a join input counts as "heavy" and worth splitting into a
    // helper relation. TODO: confirm against its use site.
    const HEAVY_SKEW_RATIO: f64 = 10.0;
1344
    /// Description of a helper relation introduced by the pass.
    ///
    /// [`run`] and [`run_kclique_specs`] return one spec per rewritten rule.
    /// The helper rule they insert into the plan uses `name` as its head and
    /// `schema` as its metadata schema.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub struct HelperRelationSpec {
        /// Predicate name allocated for the helper relation.
        pub name: String,
        /// Relation identifier allocated for the helper relation.
        pub rel_id: RelId,
        /// Output schema of the helper relation.
        pub schema: Schema,
        /// Pair of source relations extracted into the helper body.
        ///
        /// For the generic split these are two adjacent leaves of the
        /// linearized join chain. For the K-clique split they are the
        /// relations of the two "hot" edges returned by `kclique_helper_edges`.
        pub source_rels: [RelId; 2],
    }
1357
    /// Key columns for one join step of a linearized join chain.
    // NOTE(review): presumably `joins[i]` joins the accumulated prefix with
    // `leaves[i + 1]`, mirroring `RirNode::Join`'s key fields. TODO: confirm
    // against `linearize_project_body`.
    struct JoinStep {
        left_keys: Vec<usize>,
        right_keys: Vec<usize>,
    }

    /// A rule body flattened by `linearize_project_body` into a chain of leaf
    /// relations joined step by step, plus a projection (`project`).
    struct LinearBody {
        /// Relations of the chain's leaves, in chain order. A helper candidate
        /// is the adjacent pair `leaves[i]`, `leaves[i + 1]`.
        leaves: Vec<RelId>,
        // NOTE(review): by name, per-leaf column equivalence-class ids and the
        // classes of the final output columns. TODO: confirm.
        leaf_classes: Vec<Vec<u32>>,
        joins: Vec<JoinStep>,
        project: Vec<ProjectExpr>,
        final_classes: Vec<u32>,
    }

    // NOTE(review): not referenced in this view. By its fields, a flattened
    // join: leaf relations, the output column list, and pairs of columns
    // constrained equal. TODO: confirm.
    struct FlatJoin {
        leaves: Vec<RelId>,
        output_cols: Vec<usize>,
        equalities: Vec<(usize, usize)>,
    }
1376
    /// Helper-split choice returned by `choose_candidate`.
    struct Candidate {
        /// Index into `LinearBody::leaves` of the first relation of the pair to
        /// extract. The pair is `leaves[pair_start]` and `leaves[pair_start + 1]`.
        pair_start: usize,
        /// Schema handed to the `allocate` callback and recorded in the spec.
        helper_schema: Schema,
        // NOTE(review): the remaining fields are consumed by
        // `build_helper_body` / `build_outer_body` (not visible here). By name:
        // the helper's projection, the keys joining the helper back into the
        // outer body, and the equivalence classes the helper exposes.
        // TODO: confirm.
        helper_project: Vec<ProjectExpr>,
        helper_join_left_keys: Vec<usize>,
        helper_join_right_keys: Vec<usize>,
        exposed_classes: Vec<u32>,
    }

    /// Result of a successful single-rule rewrite.
    struct Rewrite {
        /// Body of the helper rule inserted directly before the consumer rule.
        helper_body: RirNode,
        /// Replacement body for the consumer rule.
        outer_body: RirNode,
        /// Description of the allocated helper relation.
        spec: HelperRelationSpec,
    }
1391
    /// One edge input of a K-clique `MultiWayJoin`, as resolved by
    /// `kclique_helper_edges`.
    #[derive(Clone, Copy)]
    struct KCliqueHelperEdge {
        /// Index of this edge in the `MultiWayJoin`'s `inputs`.
        slot: usize,
        /// Relation scanned at that slot.
        rel: RelId,
        /// First variable of the pair `kclique_edge_pair(slot, k)` assigns to
        /// this slot. Compared against the helper spec's hot variable.
        left: usize,
        /// Second variable of that pair.
        right: usize,
    }
1399
1400    /// Rewrite eligible rules in-place and return the helper relations introduced.
1401    pub fn run<F>(
1402        plan: &mut ExecutionPlan,
1403        schemas: &HashMap<RelId, Schema>,
1404        stats: &StatsManager,
1405        mut allocate: F,
1406    ) -> Vec<HelperRelationSpec>
1407    where
1408        F: FnMut(Schema) -> (String, RelId),
1409    {
1410        let mut specs = Vec::new();
1411        for scc_idx in 0..plan.rules_by_scc.len() {
1412            let mut rule_idx = 0;
1413            while rule_idx < plan.rules_by_scc[scc_idx].len() {
1414                let rewrite = {
1415                    let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1416                    try_rewrite_rule(rule, schemas, stats, &mut allocate)
1417                };
1418                if let Some(rewrite) = rewrite {
1419                    let helper_rule = CompiledRule {
1420                        head: rewrite.spec.name.clone(),
1421                        body: rewrite.helper_body,
1422                        meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1423                    };
1424                    plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1425                    rule_idx += 1;
1426                    plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1427                    add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1428                    specs.push(rewrite.spec);
1429                }
1430                rule_idx += 1;
1431            }
1432        }
1433        specs
1434    }
1435
1436    /// K-clique helper split entry for K-clique plans that already carry
1437    /// planner-produced `HelperSplitSpec`s. The pass reuses the compiler-owned
1438    /// helper-relation lifecycle: emit a helper rule before the consumer rule,
1439    /// allocate a compiler-owned helper relation, and rewrite the consumer to
1440    /// scan that helper.
1441    pub fn run_kclique_specs<F>(
1442        plan: &mut ExecutionPlan,
1443        schemas: &HashMap<RelId, Schema>,
1444        mut allocate: F,
1445    ) -> Vec<HelperRelationSpec>
1446    where
1447        F: FnMut(Schema) -> (String, RelId),
1448    {
1449        let mut specs = Vec::new();
1450        for scc_idx in 0..plan.rules_by_scc.len() {
1451            let mut rule_idx = 0;
1452            while rule_idx < plan.rules_by_scc[scc_idx].len() {
1453                let rewrite = {
1454                    let rule = &plan.rules_by_scc[scc_idx][rule_idx];
1455                    try_rewrite_kclique_rule(rule, schemas, &mut allocate)
1456                };
1457                if let Some(rewrite) = rewrite {
1458                    let helper_rule = CompiledRule {
1459                        head: rewrite.spec.name.clone(),
1460                        body: rewrite.helper_body,
1461                        meta: RirMeta::with_schema(rewrite.spec.schema.clone()),
1462                    };
1463                    plan.rules_by_scc[scc_idx].insert(rule_idx, helper_rule);
1464                    rule_idx += 1;
1465                    plan.rules_by_scc[scc_idx][rule_idx].body = rewrite.outer_body;
1466                    add_helper_to_scc(&mut plan.sccs, scc_idx, &rewrite.spec.name);
1467                    specs.push(rewrite.spec);
1468                }
1469                rule_idx += 1;
1470            }
1471        }
1472        specs
1473    }
1474
1475    fn add_helper_to_scc(sccs: &mut [Scc], scc_idx: usize, helper: &str) {
1476        if let Some(scc) = sccs.get_mut(scc_idx) {
1477            if !scc.predicates.iter().any(|p| p == helper) {
1478                scc.predicates.push(helper.to_string());
1479            }
1480        }
1481    }
1482
1483    fn try_rewrite_rule<F>(
1484        rule: &CompiledRule,
1485        schemas: &HashMap<RelId, Schema>,
1486        stats: &StatsManager,
1487        allocate: &mut F,
1488    ) -> Option<Rewrite>
1489    where
1490        F: FnMut(Schema) -> (String, RelId),
1491    {
1492        let linear = linearize_project_body(&rule.body, schemas)?;
1493        let candidate = choose_candidate(&linear, schemas, stats)?;
1494        let (helper_name, helper_rel) = allocate(candidate.helper_schema.clone());
1495        let helper_body = build_helper_body(&linear, &candidate);
1496        let outer_body = build_outer_body(&linear, &candidate, helper_rel)?;
1497        Some(Rewrite {
1498            helper_body,
1499            outer_body,
1500            spec: HelperRelationSpec {
1501                name: helper_name,
1502                rel_id: helper_rel,
1503                schema: candidate.helper_schema,
1504                source_rels: [
1505                    linear.leaves[candidate.pair_start],
1506                    linear.leaves[candidate.pair_start + 1],
1507                ],
1508            },
1509        })
1510    }
1511
1512    fn try_rewrite_kclique_rule<F>(
1513        rule: &CompiledRule,
1514        schemas: &HashMap<RelId, Schema>,
1515        allocate: &mut F,
1516    ) -> Option<Rewrite>
1517    where
1518        F: FnMut(Schema) -> (String, RelId),
1519    {
1520        let mut outer_body = rule.body.clone();
1521        let RirNode::MultiWayJoin {
1522            inputs, var_order, ..
1523        } = &mut outer_body
1524        else {
1525            return None;
1526        };
1527        let kclique = var_order.as_ref()?.kclique.as_ref()?;
1528        let spec = kclique.helper_split_specs.first()?;
1529        let (hot_left, hot_right, target) = kclique_helper_edges(inputs, kclique, spec)?;
1530        let helper_schema = schemas.get(&target.rel)?.clone();
1531        let (helper_name, helper_rel) = allocate(helper_schema.clone());
1532        let helper_body = build_kclique_helper_body(spec, hot_left, hot_right, target)?;
1533        *inputs.get_mut(target.slot)? = RirNode::Scan { rel: helper_rel };
1534        Some(Rewrite {
1535            helper_body,
1536            outer_body,
1537            spec: HelperRelationSpec {
1538                name: helper_name,
1539                rel_id: helper_rel,
1540                schema: helper_schema,
1541                source_rels: [hot_left.rel, hot_right.rel],
1542            },
1543        })
1544    }
1545
1546    fn kclique_helper_edges(
1547        inputs: &[RirNode],
1548        kclique: &KCliqueVariableOrder,
1549        spec: &HelperSplitSpec,
1550    ) -> Option<(KCliqueHelperEdge, KCliqueHelperEdge, KCliqueHelperEdge)> {
1551        let k = usize::from(kclique.k);
1552        let hot = usize::from(spec.variable);
1553        let mut hot_edges = Vec::new();
1554        let mut target = None;
1555        for &slot in &spec.edge_slots {
1556            let slot = usize::from(slot);
1557            let (left, right) = kclique_edge_pair(slot, k)?;
1558            let RirNode::Scan { rel } = inputs.get(slot)? else {
1559                return None;
1560            };
1561            let edge = KCliqueHelperEdge {
1562                slot,
1563                rel: *rel,
1564                left,
1565                right,
1566            };
1567            if left == hot || right == hot {
1568                hot_edges.push(edge);
1569            } else {
1570                target = Some(edge);
1571            }
1572        }
1573        if hot_edges.len() != 2 {
1574            return None;
1575        }
1576        Some((hot_edges[0], hot_edges[1], target?))
1577    }
1578
    /// Builds the helper body for a k-clique split. The result is the target
    /// edge restricted to tuples whose endpoints are both joined, through
    /// `hot_left` and `hot_right`, to a common value of the hot variable.
    ///
    /// Shape: `Project[4, 5](Join(Join(hot_left, hot_right on hot), target on
    /// both other endpoints))`.
    ///
    /// NOTE(review): the constants 2, 4 and 5 assume every involved edge
    /// relation has exactly two columns. Confirm that k-clique edge relations
    /// are always binary.
    ///
    /// Returns `None` if either hot edge does not touch the hot variable, or if
    /// the target edge's endpoints are not the hot edges' other endpoints.
    fn build_kclique_helper_body(
        spec: &HelperSplitSpec,
        hot_left: KCliqueHelperEdge,
        hot_right: KCliqueHelperEdge,
        target: KCliqueHelperEdge,
    ) -> Option<RirNode> {
        let hot = usize::from(spec.variable);
        let target_left = target.left;
        let target_right = target.right;
        let first_other = kclique_other_endpoint(hot_left, hot)?;
        let second_other = kclique_other_endpoint(hot_right, hot)?;
        // The target must connect exactly the two non-hot endpoints.
        if ![first_other, second_other].contains(&target_left)
            || ![first_other, second_other].contains(&target_right)
        {
            return None;
        }

        let first_scan = RirNode::Scan { rel: hot_left.rel };
        let second_scan = RirNode::Scan { rel: hot_right.rel };
        let target_scan = RirNode::Scan { rel: target.rel };
        let first_hot_col = kclique_endpoint_col(hot_left, hot)?;
        let second_hot_col = kclique_endpoint_col(hot_right, hot)?;
        let first_other_col = kclique_endpoint_col(hot_left, first_other)?;
        // `hot_join` output is [hot_left.0, hot_left.1, hot_right.0, hot_right.1],
        // so `hot_right`'s columns sit at offset 2.
        let second_other_col = 2 + kclique_endpoint_col(hot_right, second_other)?;

        // Position in `hot_join`'s output of each target endpoint's value.
        let target_left_in_join = if first_other == target_left {
            first_other_col
        } else {
            second_other_col
        };
        let target_right_in_join = if first_other == target_right {
            first_other_col
        } else {
            second_other_col
        };
        let target_left_col = kclique_endpoint_col(target, target_left)?;
        let target_right_col = kclique_endpoint_col(target, target_right)?;

        let hot_join = RirNode::Join {
            left: Box::new(first_scan),
            right: Box::new(second_scan),
            left_keys: vec![first_hot_col],
            right_keys: vec![second_hot_col],
            join_type: JoinType::Inner,
        };
        let helper_join = RirNode::Join {
            left: Box::new(hot_join),
            right: Box::new(target_scan),
            left_keys: vec![target_left_in_join, target_right_in_join],
            right_keys: vec![target_left_col, target_right_col],
            join_type: JoinType::Inner,
        };
        // The target's two columns follow the four `hot_join` columns.
        Some(RirNode::Project {
            input: Box::new(helper_join),
            columns: vec![ProjectExpr::Column(4), ProjectExpr::Column(5)],
        })
    }
1636
1637    fn kclique_edge_pair(edge_idx: usize, k: usize) -> Option<(usize, usize)> {
1638        let mut idx = 0usize;
1639        for left in 0..k {
1640            for right in (left + 1)..k {
1641                if idx == edge_idx {
1642                    return Some((left, right));
1643                }
1644                idx += 1;
1645            }
1646        }
1647        None
1648    }
1649
1650    fn kclique_endpoint_col(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1651        if edge.left == variable {
1652            Some(0)
1653        } else if edge.right == variable {
1654            Some(1)
1655        } else {
1656            None
1657        }
1658    }
1659
1660    fn kclique_other_endpoint(edge: KCliqueHelperEdge, variable: usize) -> Option<usize> {
1661        if edge.left == variable {
1662            Some(edge.right)
1663        } else if edge.right == variable {
1664            Some(edge.left)
1665        } else {
1666            None
1667        }
1668    }
1669
1670    fn linearize_project_body(
1671        body: &RirNode,
1672        schemas: &HashMap<RelId, Schema>,
1673    ) -> Option<LinearBody> {
1674        let RirNode::Project { input, columns } = body else {
1675            return None;
1676        };
1677        let flat = collect_join_graph(input, schemas)?;
1678        if flat.leaves.len() < 6 {
1679            return None;
1680        }
1681        let mut offsets = Vec::with_capacity(flat.leaves.len());
1682        let mut total_cols = 0usize;
1683        for rel in &flat.leaves {
1684            offsets.push(total_cols);
1685            total_cols += schemas.get(rel)?.arity();
1686        }
1687        let mut uf = UnionFind::new(total_cols);
1688        for (left, right) in flat.equalities {
1689            if left >= total_cols || right >= total_cols {
1690                return None;
1691            }
1692            uf.union(left, right);
1693        }
1694        let mut leaf_classes: Vec<Vec<u32>> = Vec::with_capacity(flat.leaves.len());
1695        for (leaf_idx, rel) in flat.leaves.iter().enumerate() {
1696            let arity = schemas.get(rel)?.arity();
1697            let offset = offsets[leaf_idx];
1698            leaf_classes.push((0..arity).map(|col| uf.find(offset + col) as u32).collect());
1699        }
1700        let final_classes = flat
1701            .output_cols
1702            .iter()
1703            .map(|col| uf.find(*col) as u32)
1704            .collect();
1705        let joins = derive_left_deep_steps(&leaf_classes)?;
1706        Some(LinearBody {
1707            leaves: flat.leaves,
1708            leaf_classes,
1709            joins,
1710            project: columns.clone(),
1711            final_classes,
1712        })
1713    }
1714
    /// Flattens a tree of `Scan`s and inner `Join`s into its leaves plus the
    /// column equalities the joins impose.
    ///
    /// Leaf columns are numbered globally (within this subtree) by laying the
    /// leaves out in tree order, left subtree first. `output_cols[i]` is the
    /// global column behind position `i` of this node's output. Each equality
    /// pairs two global columns.
    ///
    /// Returns `None` for any other node kind or join type, for key lists of
    /// different lengths, for keys outside a side's output, or for missing
    /// schemas.
    fn collect_join_graph(node: &RirNode, schemas: &HashMap<RelId, Schema>) -> Option<FlatJoin> {
        match node {
            RirNode::Scan { rel } => Some(FlatJoin {
                leaves: vec![*rel],
                output_cols: (0..schemas.get(rel)?.arity()).collect(),
                equalities: Vec::new(),
            }),
            RirNode::Join {
                left,
                right,
                left_keys,
                right_keys,
                join_type,
            } if *join_type == JoinType::Inner => {
                let left_flat = collect_join_graph(left, schemas)?;
                let right_flat = collect_join_graph(right, schemas)?;
                if left_keys.len() != right_keys.len() {
                    return None;
                }
                // Right-subtree global columns start after all left leaf columns.
                let right_shift = total_width(&left_flat.leaves, schemas)?;
                let mut leaves = left_flat.leaves;
                leaves.extend(right_flat.leaves);
                let right_output_cols: Vec<usize> = right_flat
                    .output_cols
                    .iter()
                    .map(|col| col + right_shift)
                    .collect();
                let mut equalities = left_flat.equalities;
                equalities.extend(
                    right_flat
                        .equalities
                        .iter()
                        .map(|(left, right)| (left + right_shift, right + right_shift)),
                );
                // Join keys are positions in each side's output; translate them
                // to global columns before recording the equality.
                for (&left_key, &right_key) in left_keys.iter().zip(right_keys.iter()) {
                    equalities.push((
                        *left_flat.output_cols.get(left_key)?,
                        *right_output_cols.get(right_key)?,
                    ));
                }
                // The join's output is the left output followed by the right output.
                let mut output_cols = left_flat.output_cols;
                output_cols.extend(right_output_cols);
                Some(FlatJoin {
                    leaves,
                    output_cols,
                    equalities,
                })
            }
            _ => None,
        }
    }
1766
1767    fn total_width(leaves: &[RelId], schemas: &HashMap<RelId, Schema>) -> Option<usize> {
1768        leaves
1769            .iter()
1770            .map(|rel| schemas.get(rel).map(Schema::arity))
1771            .try_fold(0usize, |acc, width| width.map(|width| acc + width))
1772    }
1773
1774    fn derive_left_deep_steps(leaf_classes: &[Vec<u32>]) -> Option<Vec<JoinStep>> {
1775        let mut joins = Vec::with_capacity(leaf_classes.len().saturating_sub(1));
1776        let mut current = leaf_classes.first()?.clone();
1777        for classes in leaf_classes.iter().skip(1) {
1778            let mut left_keys = Vec::new();
1779            let mut right_keys = Vec::new();
1780            for (right_col, class) in classes.iter().enumerate() {
1781                if let Some(left_col) = current
1782                    .iter()
1783                    .position(|current_class| current_class == class)
1784                {
1785                    left_keys.push(left_col);
1786                    right_keys.push(right_col);
1787                }
1788            }
1789            if left_keys.is_empty() {
1790                return None;
1791            }
1792            joins.push(JoinStep {
1793                left_keys,
1794                right_keys,
1795            });
1796            current.extend(classes.iter().copied());
1797        }
1798        Some(joins)
1799    }
1800
1801    fn choose_candidate(
1802        linear: &LinearBody,
1803        schemas: &HashMap<RelId, Schema>,
1804        stats: &StatsManager,
1805    ) -> Option<Candidate> {
1806        for pair_start in 3..linear.leaves.len().saturating_sub(1) {
1807            let candidate = build_candidate(linear, schemas, pair_start)?;
1808            if skew_ratio_for_candidate(linear, stats, &candidate) >= HEAVY_SKEW_RATIO {
1809                return Some(candidate);
1810            }
1811        }
1812        None
1813    }
1814
    /// Tries to build a helper candidate from leaves `pair_start` and
    /// `pair_start + 1` of `linear`.
    ///
    /// The helper joins the two leaves on the keys of the original step that
    /// attached leaf `pair_start + 1`, translated into columns of leaf
    /// `pair_start`. For each class that is not one of those join keys but is
    /// used by another leaf or by the final projection, the helper projects the
    /// first column carrying it, taking the left leaf before the right one.
    /// Helper columns are named `c0`, `c1`, … and fall back to `U32` when the
    /// source schema reports no type.
    ///
    /// Returns `None` if:
    /// - a schema is missing;
    /// - a key of that step refers to a class absent from leaf `pair_start`;
    /// - the final projection is not made of plain columns;
    /// - the number of exposed classes is not exactly two.
    fn build_candidate(
        linear: &LinearBody,
        schemas: &HashMap<RelId, Schema>,
        pair_start: usize,
    ) -> Option<Candidate> {
        let left_rel = linear.leaves[pair_start];
        let right_rel = linear.leaves[pair_start + 1];
        let left_schema = schemas.get(&left_rel)?;
        let right_schema = schemas.get(&right_rel)?;
        // `joins[pair_start]` is the step that attached leaf `pair_start + 1`.
        let internal_step = linear.joins.get(pair_start)?;
        let mut helper_left_keys = Vec::new();
        let mut helper_right_keys = Vec::new();
        for (&left_key, &right_key) in internal_step
            .left_keys
            .iter()
            .zip(internal_step.right_keys.iter())
        {
            // The step's left key indexes the accumulated columns; re-express
            // it as a column of leaf `pair_start`, or give up.
            let class = class_at_state(linear, pair_start + 1, left_key)?;
            let left_col = linear.leaf_classes[pair_start]
                .iter()
                .position(|c| *c == class)?;
            helper_left_keys.push(left_col);
            helper_right_keys.push(right_key);
        }
        // Classes consumed by the pair's internal join; they are not exposed.
        let internal: HashSet<u32> = helper_left_keys
            .iter()
            .map(|col| linear.leaf_classes[pair_start][*col])
            .collect();
        let outside = outside_classes(linear, pair_start);
        let output = projected_classes(linear)?;
        let mut exposed_classes = Vec::new();
        let mut helper_project = Vec::new();
        let mut helper_columns = Vec::new();
        for (col, class) in linear.leaf_classes[pair_start].iter().copied().enumerate() {
            if !internal.contains(&class)
                && (outside.contains(&class) || output.contains(&class))
                && !exposed_classes.contains(&class)
            {
                exposed_classes.push(class);
                helper_project.push(ProjectExpr::Column(col));
                let ty = left_schema.column_type(col).unwrap_or(ScalarType::U32);
                helper_columns.push((format!("c{}", helper_columns.len()), ty));
            }
        }
        // In the helper join's output, right-leaf columns follow the left leaf's.
        let right_offset = left_schema.arity();
        for (col, class) in linear.leaf_classes[pair_start + 1]
            .iter()
            .copied()
            .enumerate()
        {
            if !internal.contains(&class)
                && (outside.contains(&class) || output.contains(&class))
                && !exposed_classes.contains(&class)
            {
                exposed_classes.push(class);
                helper_project.push(ProjectExpr::Column(right_offset + col));
                let ty = right_schema.column_type(col).unwrap_or(ScalarType::U32);
                helper_columns.push((format!("c{}", helper_columns.len()), ty));
            }
        }
        if exposed_classes.len() != 2 {
            return None;
        }
        Some(Candidate {
            pair_start,
            helper_schema: Schema::new(helper_columns),
            helper_project,
            helper_join_left_keys: helper_left_keys,
            helper_join_right_keys: helper_right_keys,
            exposed_classes,
        })
    }
1887
1888    fn class_at_state(linear: &LinearBody, leaf_count: usize, col: usize) -> Option<u32> {
1889        let mut idx = col;
1890        for leaf_idx in 0..leaf_count {
1891            let classes = &linear.leaf_classes[leaf_idx];
1892            if idx < classes.len() {
1893                return Some(classes[idx]);
1894            }
1895            idx -= classes.len();
1896        }
1897        None
1898    }
1899
1900    fn outside_classes(linear: &LinearBody, pair_start: usize) -> HashSet<u32> {
1901        linear
1902            .leaf_classes
1903            .iter()
1904            .enumerate()
1905            .filter(|(idx, _)| *idx != pair_start && *idx != pair_start + 1)
1906            .flat_map(|(_, classes)| classes.iter().copied())
1907            .collect()
1908    }
1909
1910    fn projected_classes(linear: &LinearBody) -> Option<HashSet<u32>> {
1911        let mut out = HashSet::new();
1912        for expr in &linear.project {
1913            let ProjectExpr::Column(col) = expr else {
1914                return None;
1915            };
1916            out.insert(*linear.final_classes.get(*col)?);
1917        }
1918        Some(out)
1919    }
1920
1921    fn skew_ratio_for_candidate(
1922        linear: &LinearBody,
1923        stats: &StatsManager,
1924        candidate: &Candidate,
1925    ) -> f64 {
1926        let rel = linear.leaves[candidate.pair_start];
1927        let Some(rel_stats) = stats.get_relation_stats(rel) else {
1928            return 0.0;
1929        };
1930        let mut ratio: f64 = 0.0;
1931        for (col, class) in linear.leaf_classes[candidate.pair_start]
1932            .iter()
1933            .copied()
1934            .enumerate()
1935        {
1936            if !candidate.exposed_classes.contains(&class) {
1937                continue;
1938            }
1939            let Some(col_stats) = rel_stats.get_column(col) else {
1940                continue;
1941            };
1942            if col_stats.distinct_estimate == 0 {
1943                continue;
1944            }
1945            ratio = ratio.max(rel_stats.cardinality as f64 / col_stats.distinct_estimate as f64);
1946        }
1947        ratio
1948    }
1949
1950    fn build_helper_body(linear: &LinearBody, candidate: &Candidate) -> RirNode {
1951        let left = RirNode::Scan {
1952            rel: linear.leaves[candidate.pair_start],
1953        };
1954        let right = RirNode::Scan {
1955            rel: linear.leaves[candidate.pair_start + 1],
1956        };
1957        RirNode::Project {
1958            input: Box::new(RirNode::Join {
1959                left: Box::new(left),
1960                right: Box::new(right),
1961                left_keys: candidate.helper_join_left_keys.clone(),
1962                right_keys: candidate.helper_join_right_keys.clone(),
1963                join_type: JoinType::Inner,
1964            }),
1965            columns: candidate.helper_project.clone(),
1966        }
1967    }
1968
    /// Rebuilds the rule body with the candidate pair replaced by a scan of
    /// `helper_rel`.
    ///
    /// - Leaves before the pair are joined exactly as in `linear`.
    /// - The helper is joined with the step that originally attached leaf
    ///   `pair_start`; that step's right keys are remapped to helper columns.
    /// - Later leaves are re-keyed by class against the new column layout.
    /// - The final projection is remapped the same way.
    ///
    /// Returns `None` whenever a needed class is missing from the rewritten
    /// layout, e.g. a class internal to the pair that another leaf or the
    /// projection still needs.
    ///
    /// Assumes `candidate.pair_start >= 1`, since it reads
    /// `joins[pair_start - 1]`.
    fn build_outer_body(
        linear: &LinearBody,
        candidate: &Candidate,
        helper_rel: RelId,
    ) -> Option<RirNode> {
        let mut node = RirNode::Scan {
            rel: linear.leaves[0],
        };
        // Class id of each column of `node`'s output, kept in sync as joins are added.
        let mut classes = linear.leaf_classes[0].clone();
        // Prefix before the pair: identical layout, so the original steps apply.
        for leaf_idx in 1..candidate.pair_start {
            let step = &linear.joins[leaf_idx - 1];
            node = RirNode::Join {
                left: Box::new(node),
                right: Box::new(RirNode::Scan {
                    rel: linear.leaves[leaf_idx],
                }),
                left_keys: step.left_keys.clone(),
                right_keys: step.right_keys.clone(),
                join_type: JoinType::Inner,
            };
            classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
        }
        // Attach the helper in place of leaf `pair_start`: its right keys were
        // columns of that leaf, so map them through the exposed classes.
        let prefix_step = &linear.joins[candidate.pair_start - 1];
        let mut helper_right_keys = Vec::new();
        for &rk in &prefix_step.right_keys {
            let class = linear.leaf_classes[candidate.pair_start][rk];
            helper_right_keys.push(candidate.exposed_classes.iter().position(|c| *c == class)?);
        }
        node = RirNode::Join {
            left: Box::new(node),
            right: Box::new(RirNode::Scan { rel: helper_rel }),
            left_keys: prefix_step.left_keys.clone(),
            right_keys: helper_right_keys,
            join_type: JoinType::Inner,
        };
        classes.extend(candidate.exposed_classes.iter().copied());
        // Remaining leaves: left keys pointed into the original layout; look
        // up each key's class and find it in the rewritten layout.
        for leaf_idx in candidate.pair_start + 2..linear.leaves.len() {
            let step = &linear.joins[leaf_idx - 1];
            let mut left_keys = Vec::new();
            for &lk in &step.left_keys {
                let class = class_at_state(linear, leaf_idx, lk)?;
                left_keys.push(classes.iter().position(|c| *c == class)?);
            }
            node = RirNode::Join {
                left: Box::new(node),
                right: Box::new(RirNode::Scan {
                    rel: linear.leaves[leaf_idx],
                }),
                left_keys,
                right_keys: step.right_keys.clone(),
                join_type: JoinType::Inner,
            };
            classes.extend(linear.leaf_classes[leaf_idx].iter().copied());
        }
        // Remap the final projection by class as well.
        let mut project = Vec::with_capacity(linear.project.len());
        for expr in &linear.project {
            let ProjectExpr::Column(col) = expr else {
                return None;
            };
            let class = *linear.final_classes.get(*col)?;
            let mapped = classes.iter().position(|c| *c == class)?;
            project.push(ProjectExpr::Column(mapped));
        }
        Some(RirNode::Project {
            input: Box::new(node),
            columns: project,
        })
    }
2037
2038    struct UnionFind {
2039        parent: Vec<usize>,
2040    }
2041
2042    impl UnionFind {
2043        fn new(len: usize) -> Self {
2044            Self {
2045                parent: (0..len).collect(),
2046            }
2047        }
2048
2049        fn find(&mut self, x: usize) -> usize {
2050            let p = self.parent[x];
2051            if p == x {
2052                x
2053            } else {
2054                let root = self.find(p);
2055                self.parent[x] = root;
2056                root
2057            }
2058        }
2059
2060        fn union(&mut self, a: usize, b: usize) {
2061            let ra = self.find(a);
2062            let rb = self.find(b);
2063            if ra != rb {
2064                self.parent[rb] = ra;
2065            }
2066        }
2067    }
2068}
2069
2070#[path = "optimizer/stream_schedule_pass.rs"]
2071pub mod stream_schedule_pass;
2072
/// Unit tests for `helper_split_pass::run` on a six-relation cyclic join.
///
/// The fixture joins six binary edge relations `R0..R5` left-deep over the
/// variable cycle a-b-c-d-e-f-a. Skew is controlled solely through the
/// distinct estimate of column 0 of `RelId(3)`.
#[cfg(test)]
mod helper_split_pass_tests {
    use std::collections::HashMap;

    use super::helper_split_pass;
    use xlog_core::{RelId, ScalarType, Schema};
    use xlog_ir::{CompiledRule, ExecutionPlan, JoinType, ProjectExpr, RirMeta, RirNode, Scc};
    use xlog_stats::{ColumnStats, StatsManager};

    /// Two-column `U32` schema shared by every edge relation `R0..R5`.
    fn edge_schema() -> Schema {
        Schema::new(vec![
            ("c0".to_string(), ScalarType::U32),
            ("c1".to_string(), ScalarType::U32),
        ])
    }

    /// Schema the pass is expected to assign to the helper relation.
    fn helper_schema() -> Schema {
        Schema::new(vec![
            ("c0".to_string(), ScalarType::U32),
            ("c1".to_string(), ScalarType::U32),
        ])
    }

    /// Maps `RelId(0)..=RelId(5)` to `edge_schema()`.
    fn schemas() -> HashMap<RelId, Schema> {
        (0..6)
            .map(|idx| (RelId(idx), edge_schema()))
            .collect::<HashMap<_, _>>()
    }

    /// Left-deep body for `out(a, b, c, d, f)` over `R0(a,b) R1(b,c) R2(c,d)
    /// R3(d,e) R4(e,f) R5(a,f)`.
    ///
    /// Left keys index the accumulated join output, two columns per relation
    /// joined so far. For example, `3` is R1's second column (`c`) and `9` is
    /// R4's second column (`f`).
    fn left_deep_fixture_body() -> RirNode {
        // Output: [a, b | b, c]
        let ab_bc = RirNode::Join {
            left: Box::new(RirNode::Scan { rel: RelId(0) }),
            right: Box::new(RirNode::Scan { rel: RelId(1) }),
            left_keys: vec![1],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // Joins R2 on c (column 3).
        let with_cd = RirNode::Join {
            left: Box::new(ab_bc),
            right: Box::new(RirNode::Scan { rel: RelId(2) }),
            left_keys: vec![3],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // Joins R3 on d (column 5).
        let with_de = RirNode::Join {
            left: Box::new(with_cd),
            right: Box::new(RirNode::Scan { rel: RelId(3) }),
            left_keys: vec![5],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // Joins R4 on e (column 7).
        let with_ef = RirNode::Join {
            left: Box::new(with_de),
            right: Box::new(RirNode::Scan { rel: RelId(4) }),
            left_keys: vec![7],
            right_keys: vec![0],
            join_type: JoinType::Inner,
        };
        // Closes the cycle: R5 on a (column 0) and f (column 9).
        let with_af = RirNode::Join {
            left: Box::new(with_ef),
            right: Box::new(RirNode::Scan { rel: RelId(5) }),
            left_keys: vec![0, 9],
            right_keys: vec![0, 1],
            join_type: JoinType::Inner,
        };
        // Keeps a, b, c, d, f.
        RirNode::Project {
            input: Box::new(with_af),
            columns: vec![
                ProjectExpr::Column(0),
                ProjectExpr::Column(1),
                ProjectExpr::Column(3),
                ProjectExpr::Column(5),
                ProjectExpr::Column(9),
            ],
        }
    }

    /// Plan with one non-recursive SCC holding the single `out` rule.
    fn plan() -> ExecutionPlan {
        ExecutionPlan {
            sccs: vec![Scc {
                id: 0,
                predicates: vec!["out".to_string()],
                is_recursive: false,
            }],
            strata: vec![],
            rules_by_scc: vec![vec![CompiledRule {
                head: "out".to_string(),
                body: left_deep_fixture_body(),
                meta: RirMeta::with_schema(Schema::new(vec![
                    ("a".to_string(), ScalarType::U32),
                    ("b".to_string(), ScalarType::U32),
                    ("c".to_string(), ScalarType::U32),
                    ("d".to_string(), ScalarType::U32),
                    ("f".to_string(), ScalarType::U32),
                ])),
            }]],
            est_memory_peak: 0,
            rel_arities: std::collections::HashMap::new(),
        }
    }

    /// Stats where all six relations have cardinality 8192. Only column 0 of
    /// `RelId(3)` (variable `d`) carries a distinct estimate, `distinct_d`.
    fn stats_for_de(distinct_d: u64) -> StatsManager {
        let mut stats = StatsManager::new();
        for idx in 0..6 {
            stats.register_relation(RelId(idx));
            stats.update_cardinality(RelId(idx), 8192);
        }
        let mut d_col = ColumnStats::new(0, ScalarType::U32);
        d_col.update_distinct(distinct_d);
        stats.add_column_stats(RelId(3), d_col);
        stats
    }

    /// True if any node in the tree rooted at `node` scans `rel`.
    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
        match node {
            RirNode::Scan { rel: scan_rel } => *scan_rel == rel,
            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
                contains_scan(left, rel) || contains_scan(right, rel)
            }
            RirNode::Project { input, .. }
            | RirNode::Filter { input, .. }
            | RirNode::Distinct { input, .. }
            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
            RirNode::Fixpoint {
                base, recursive, ..
            } => contains_scan(base, rel) || contains_scan(recursive, rel),
            RirNode::MultiWayJoin { inputs, .. } => {
                inputs.iter().any(|input| contains_scan(input, rel))
            }
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
            }
            RirNode::Unit => false,
        }
    }

    /// A single distinct `d` value gives a skew ratio of 8192 / 1. The pair
    /// R3/R4 should then become a helper that the rewritten `out` rule scans.
    /// Assumes 8192 meets `HEAVY_SKEW_RATIO`.
    #[test]
    fn helper_split_extracts_buried_pair() {
        let mut plan = plan();
        let schemas = schemas();
        let stats = stats_for_de(1);
        let specs = helper_split_pass::run(&mut plan, &schemas, &stats, |_| {
            ("__kclique_helper_6".to_string(), RelId(6))
        });

        assert_eq!(specs.len(), 1);
        assert_eq!(specs[0].name, "__kclique_helper_6");
        assert_eq!(specs[0].rel_id, RelId(6));
        assert_eq!(specs[0].schema, helper_schema());
        assert_eq!(specs[0].source_rels, [RelId(3), RelId(4)]);
        // The helper rule is inserted before the rewritten `out` rule.
        assert_eq!(plan.rules_by_scc[0].len(), 2);
        assert_eq!(plan.rules_by_scc[0][0].head, "__kclique_helper_6");
        assert_eq!(plan.rules_by_scc[0][1].head, "out");
        assert!(contains_scan(&plan.rules_by_scc[0][1].body, RelId(6)));
        assert!(plan.sccs[0]
            .predicates
            .iter()
            .any(|predicate| predicate == "__kclique_helper_6"));
    }

    /// A uniform `d` (8192 distinct values, ratio 1) must leave the plan
    /// untouched. Assumes a ratio of 1 is below `HEAVY_SKEW_RATIO`.
    #[test]
    fn helper_split_ignores_flat_distribution() {
        let mut plan = plan();
        let schemas = schemas();
        let stats = stats_for_de(8192);
        let specs = helper_split_pass::run(&mut plan, &schemas, &stats, |_| {
            ("__kclique_helper_6".to_string(), RelId(6))
        });

        assert!(specs.is_empty());
        assert_eq!(plan.rules_by_scc[0].len(), 1);
        assert!(!contains_scan(&plan.rules_by_scc[0][0].body, RelId(6)));
    }
}
2249
2250/// Selectivity-driven body rewriters for triangle and 4-cycle canonical lowered
2251/// shapes. `pub(super)` so `selectivity_pass::run` can dispatch into them.
2252mod reorder {
2253    use std::collections::HashMap;
2254    use xlog_core::RelId;
2255    use xlog_ir::rir::ProjectExpr;
2256    use xlog_ir::{JoinType, RirNode};
2257    use xlog_stats::StatsManager;
2258
2259    fn ac3(atom: u8, col: u8) -> u8 {
2260        atom * 2 + col
2261    }
2262    fn ac4(atom: u8, col: u8) -> u8 {
2263        atom * 2 + col
2264    }
2265    fn uf_find_n<const N: usize>(parent: &mut [u8; N], x: u8) -> u8 {
2266        let mut root = x;
2267        while parent[root as usize] != root {
2268            root = parent[root as usize];
2269        }
2270        let mut cur = x;
2271        while parent[cur as usize] != root {
2272            let next = parent[cur as usize];
2273            parent[cur as usize] = root;
2274            cur = next;
2275        }
2276        root
2277    }
2278    fn uf_union_n<const N: usize>(parent: &mut [u8; N], a: u8, b: u8) {
2279        let ra = uf_find_n(parent, a);
2280        let rb = uf_find_n(parent, b);
2281        if ra != rb {
2282            parent[rb as usize] = ra;
2283        }
2284    }
2285
2286    fn populated_card(stats: &StatsManager, rel: RelId) -> Option<u64> {
2287        stats
2288            .get_relation_stats(rel)
2289            .map(|s| s.cardinality)
2290            .filter(|c| *c > 0)
2291    }
2292
2293    // ---------------------------------------------------------
2294    // Triangle rewriter
2295    // ---------------------------------------------------------
2296
2297    struct TriangleSemantics {
2298        rel_xy: RelId,
2299        rel_yz: RelId,
2300        rel_xz: RelId,
2301    }
2302
2303    fn match_and_infer_triangle(body: &RirNode) -> Option<TriangleSemantics> {
2304        let RirNode::Project {
2305            input: outer_input,
2306            columns,
2307        } = body
2308        else {
2309            return None;
2310        };
2311        let RirNode::Join {
2312            left: l1,
2313            right: r1,
2314            left_keys: lk1,
2315            right_keys: rk1,
2316            join_type: jt1,
2317        } = outer_input.as_ref()
2318        else {
2319            return None;
2320        };
2321        if !matches!(jt1, JoinType::Inner) {
2322            return None;
2323        }
2324        let RirNode::Scan { rel: rel_third } = r1.as_ref() else {
2325            return None;
2326        };
2327        let RirNode::Join {
2328            left: l2,
2329            right: r2,
2330            left_keys: lk2,
2331            right_keys: rk2,
2332            join_type: jt2,
2333        } = l1.as_ref()
2334        else {
2335            return None;
2336        };
2337        if !matches!(jt2, JoinType::Inner) {
2338            return None;
2339        }
2340        let RirNode::Scan { rel: rel_inner_l } = l2.as_ref() else {
2341            return None;
2342        };
2343        let RirNode::Scan { rel: rel_inner_r } = r2.as_ref() else {
2344            return None;
2345        };
2346        if lk2.len() != 1 || rk2.len() != 1 || lk1.len() != 2 || rk1.len() != 2 {
2347            return None;
2348        }
2349        if columns.len() != 3 {
2350            return None;
2351        }
2352        if lk2[0] >= 2 || rk2[0] >= 2 {
2353            return None;
2354        }
2355        if lk1.iter().any(|k| *k >= 4) || rk1.iter().any(|k| *k >= 2) {
2356            return None;
2357        }
2358
2359        let mut parent = [0u8, 1, 2, 3, 4, 5];
2360        uf_union_n::<6>(&mut parent, ac3(0, lk2[0] as u8), ac3(1, rk2[0] as u8));
2361        for i in 0..2 {
2362            let inner_ac = match lk1[i] {
2363                0 => (0u8, 0u8),
2364                1 => (0, 1),
2365                2 => (1, 0),
2366                3 => (1, 1),
2367                _ => return None,
2368            };
2369            uf_union_n::<6>(
2370                &mut parent,
2371                ac3(inner_ac.0, inner_ac.1),
2372                ac3(2, rk1[i] as u8),
2373            );
2374        }
2375        let roots: [u8; 6] = std::array::from_fn(|i| uf_find_n::<6>(&mut parent, i as u8));
2376        let mut counts: HashMap<u8, u8> = HashMap::new();
2377        for r in &roots {
2378            *counts.entry(*r).or_insert(0) += 1;
2379        }
2380        if counts.len() != 3 || counts.values().any(|c| *c != 2) {
2381            return None;
2382        }
2383        let mut head_classes: [u8; 3] = [0; 3];
2384        for (i, pc) in columns.iter().enumerate() {
2385            let ProjectExpr::Column(k) = pc else {
2386                return None;
2387            };
2388            let outer_ac = match *k {
2389                0 => (0u8, 0u8),
2390                1 => (0, 1),
2391                2 => (1, 0),
2392                3 => (1, 1),
2393                4 => (2, 0),
2394                5 => (2, 1),
2395                _ => return None,
2396            };
2397            head_classes[i] = uf_find_n::<6>(&mut parent, ac3(outer_ac.0, outer_ac.1));
2398        }
2399        if head_classes[0] == head_classes[1]
2400            || head_classes[0] == head_classes[2]
2401            || head_classes[1] == head_classes[2]
2402        {
2403            return None;
2404        }
2405        let x_class = head_classes[0];
2406        let y_class = head_classes[1];
2407        let z_class = head_classes[2];
2408        let atom_classes = |a: u8| (roots[ac3(a, 0) as usize], roots[ac3(a, 1) as usize]);
2409        let atom_rels = [*rel_inner_l, *rel_inner_r, *rel_third];
2410        let mut rel_xy = None;
2411        let mut rel_yz = None;
2412        let mut rel_xz = None;
2413        for atom_idx in 0..3u8 {
2414            let (c0, c1) = atom_classes(atom_idx);
2415            let bx = c0 == x_class || c1 == x_class;
2416            let by = c0 == y_class || c1 == y_class;
2417            let bz = c0 == z_class || c1 == z_class;
2418            match (bx, by, bz) {
2419                (true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2420                (false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2421                (true, false, true) => rel_xz = Some(atom_rels[atom_idx as usize]),
2422                _ => return None,
2423            }
2424        }
2425        Some(TriangleSemantics {
2426            rel_xy: rel_xy?,
2427            rel_yz: rel_yz?,
2428            rel_xz: rel_xz?,
2429        })
2430    }
2431
2432    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
2433    #[allow(clippy::enum_variant_names)]
2434    enum TriangleInnerPair {
2435        YShared,
2436        XShared,
2437        ZShared,
2438    }
2439
2440    fn build_triangle_body(s: &TriangleSemantics, inner_pair: TriangleInnerPair) -> RirNode {
2441        let mk_scan = |r: RelId| RirNode::Scan { rel: r };
2442        match inner_pair {
2443            TriangleInnerPair::YShared => {
2444                let inner = RirNode::Join {
2445                    left: Box::new(mk_scan(s.rel_xy)),
2446                    right: Box::new(mk_scan(s.rel_yz)),
2447                    left_keys: vec![1],
2448                    right_keys: vec![0],
2449                    join_type: JoinType::Inner,
2450                };
2451                let outer = RirNode::Join {
2452                    left: Box::new(inner),
2453                    right: Box::new(mk_scan(s.rel_xz)),
2454                    left_keys: vec![0, 3],
2455                    right_keys: vec![0, 1],
2456                    join_type: JoinType::Inner,
2457                };
2458                RirNode::Project {
2459                    input: Box::new(outer),
2460                    columns: vec![
2461                        ProjectExpr::Column(0),
2462                        ProjectExpr::Column(1),
2463                        ProjectExpr::Column(3),
2464                    ],
2465                }
2466            }
2467            TriangleInnerPair::XShared => {
2468                let inner = RirNode::Join {
2469                    left: Box::new(mk_scan(s.rel_xy)),
2470                    right: Box::new(mk_scan(s.rel_xz)),
2471                    left_keys: vec![0],
2472                    right_keys: vec![0],
2473                    join_type: JoinType::Inner,
2474                };
2475                let outer = RirNode::Join {
2476                    left: Box::new(inner),
2477                    right: Box::new(mk_scan(s.rel_yz)),
2478                    left_keys: vec![1, 3],
2479                    right_keys: vec![0, 1],
2480                    join_type: JoinType::Inner,
2481                };
2482                RirNode::Project {
2483                    input: Box::new(outer),
2484                    columns: vec![
2485                        ProjectExpr::Column(0),
2486                        ProjectExpr::Column(1),
2487                        ProjectExpr::Column(3),
2488                    ],
2489                }
2490            }
2491            TriangleInnerPair::ZShared => {
2492                let inner = RirNode::Join {
2493                    left: Box::new(mk_scan(s.rel_xz)),
2494                    right: Box::new(mk_scan(s.rel_yz)),
2495                    left_keys: vec![1],
2496                    right_keys: vec![1],
2497                    join_type: JoinType::Inner,
2498                };
2499                let outer = RirNode::Join {
2500                    left: Box::new(inner),
2501                    right: Box::new(mk_scan(s.rel_xy)),
2502                    left_keys: vec![0, 2],
2503                    right_keys: vec![0, 1],
2504                    join_type: JoinType::Inner,
2505                };
2506                RirNode::Project {
2507                    input: Box::new(outer),
2508                    columns: vec![
2509                        ProjectExpr::Column(0),
2510                        ProjectExpr::Column(2),
2511                        ProjectExpr::Column(3),
2512                    ],
2513                }
2514            }
2515        }
2516    }
2517
2518    pub fn try_reorder_triangle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2519        let s = match_and_infer_triangle(body)?;
2520        let _ = (
2521            populated_card(stats, s.rel_xy)?,
2522            populated_card(stats, s.rel_yz)?,
2523            populated_card(stats, s.rel_xz)?,
2524        );
2525        let est_y = stats.estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0]);
2526        let est_x = stats.estimate_join_cardinality(s.rel_xy, s.rel_xz, &[0], &[0]);
2527        let est_z = stats.estimate_join_cardinality(s.rel_yz, s.rel_xz, &[1], &[1]);
2528        let mut best = (TriangleInnerPair::YShared, est_y);
2529        if est_x < best.1 {
2530            best = (TriangleInnerPair::XShared, est_x);
2531        }
2532        if est_z < best.1 {
2533            best = (TriangleInnerPair::ZShared, est_z);
2534        }
2535        let candidate = build_triangle_body(&s, best.0);
2536        // Skip when the candidate is structurally identical to
2537        // the input (no-op rewrite). RirNode doesn't impl
2538        // PartialEq, so compare via Debug — bodies are small
2539        // (≤ 6 Scans + 2 Joins + 1 Project) so the cost is
2540        // negligible relative to the optimizer's broader work.
2541        if format!("{:?}", candidate) == format!("{:?}", body) {
2542            return None;
2543        }
2544        Some(candidate)
2545    }
2546
2547    // ---------------------------------------------------------
2548    // 4-cycle rewriter
2549    // ---------------------------------------------------------
2550
2551    struct Cycle4Semantics {
2552        rel_wx: RelId,
2553        rel_xy: RelId,
2554        rel_yz: RelId,
2555        rel_zw: RelId,
2556    }
2557
2558    fn match_and_infer_4cycle(body: &RirNode) -> Option<Cycle4Semantics> {
2559        let RirNode::Project {
2560            input: outer_input,
2561            columns,
2562        } = body
2563        else {
2564            return None;
2565        };
2566        let RirNode::Join {
2567            left: outer_l,
2568            right: outer_r,
2569            left_keys: olk,
2570            right_keys: ork,
2571            join_type: ojt,
2572        } = outer_input.as_ref()
2573        else {
2574            return None;
2575        };
2576        if !matches!(ojt, JoinType::Inner) {
2577            return None;
2578        }
2579        let RirNode::Join {
2580            left: ll,
2581            right: lr,
2582            left_keys: ilk_l,
2583            right_keys: irk_l,
2584            join_type: ijt_l,
2585        } = outer_l.as_ref()
2586        else {
2587            return None;
2588        };
2589        if !matches!(ijt_l, JoinType::Inner) {
2590            return None;
2591        }
2592        let RirNode::Scan { rel: rel_ll } = ll.as_ref() else {
2593            return None;
2594        };
2595        let RirNode::Scan { rel: rel_lr } = lr.as_ref() else {
2596            return None;
2597        };
2598        let RirNode::Join {
2599            left: rl,
2600            right: rr,
2601            left_keys: ilk_r,
2602            right_keys: irk_r,
2603            join_type: ijt_r,
2604        } = outer_r.as_ref()
2605        else {
2606            return None;
2607        };
2608        if !matches!(ijt_r, JoinType::Inner) {
2609            return None;
2610        }
2611        let RirNode::Scan { rel: rel_rl } = rl.as_ref() else {
2612            return None;
2613        };
2614        let RirNode::Scan { rel: rel_rr } = rr.as_ref() else {
2615            return None;
2616        };
2617        if ilk_l.len() != 1 || irk_l.len() != 1 || ilk_r.len() != 1 || irk_r.len() != 1 {
2618            return None;
2619        }
2620        if olk.len() != 2 || ork.len() != 2 || columns.len() != 4 {
2621            return None;
2622        }
2623        if ilk_l[0] >= 2 || irk_l[0] >= 2 || ilk_r[0] >= 2 || irk_r[0] >= 2 {
2624            return None;
2625        }
2626        if olk.iter().any(|k| *k >= 4) || ork.iter().any(|k| *k >= 4) {
2627            return None;
2628        }
2629
2630        let mut parent = [0u8, 1, 2, 3, 4, 5, 6, 7];
2631        uf_union_n::<8>(&mut parent, ac4(0, ilk_l[0] as u8), ac4(1, irk_l[0] as u8));
2632        uf_union_n::<8>(&mut parent, ac4(2, ilk_r[0] as u8), ac4(3, irk_r[0] as u8));
2633        for i in 0..2 {
2634            let l_ac = match olk[i] {
2635                0 => (0u8, 0u8),
2636                1 => (0, 1),
2637                2 => (1, 0),
2638                3 => (1, 1),
2639                _ => return None,
2640            };
2641            let r_ac = match ork[i] {
2642                0 => (2u8, 0u8),
2643                1 => (2, 1),
2644                2 => (3, 0),
2645                3 => (3, 1),
2646                _ => return None,
2647            };
2648            uf_union_n::<8>(&mut parent, ac4(l_ac.0, l_ac.1), ac4(r_ac.0, r_ac.1));
2649        }
2650        let roots: [u8; 8] = std::array::from_fn(|i| uf_find_n::<8>(&mut parent, i as u8));
2651        let mut counts: HashMap<u8, u8> = HashMap::new();
2652        for r in &roots {
2653            *counts.entry(*r).or_insert(0) += 1;
2654        }
2655        if counts.len() != 4 || counts.values().any(|c| *c != 2) {
2656            return None;
2657        }
2658
2659        let mut head_classes: [u8; 4] = [0; 4];
2660        for (i, pc) in columns.iter().enumerate() {
2661            let ProjectExpr::Column(k) = pc else {
2662                return None;
2663            };
2664            let ac = match *k {
2665                0 => (0u8, 0u8),
2666                1 => (0, 1),
2667                2 => (1, 0),
2668                3 => (1, 1),
2669                4 => (2, 0),
2670                5 => (2, 1),
2671                6 => (3, 0),
2672                7 => (3, 1),
2673                _ => return None,
2674            };
2675            head_classes[i] = uf_find_n::<8>(&mut parent, ac4(ac.0, ac.1));
2676        }
2677        for i in 0..4 {
2678            for j in (i + 1)..4 {
2679                if head_classes[i] == head_classes[j] {
2680                    return None;
2681                }
2682            }
2683        }
2684        let w_class = head_classes[0];
2685        let x_class = head_classes[1];
2686        let y_class = head_classes[2];
2687        let z_class = head_classes[3];
2688        let atom_classes = |a: u8| (roots[ac4(a, 0) as usize], roots[ac4(a, 1) as usize]);
2689        let atom_rels = [*rel_ll, *rel_lr, *rel_rl, *rel_rr];
2690        let mut rel_wx = None;
2691        let mut rel_xy = None;
2692        let mut rel_yz = None;
2693        let mut rel_zw = None;
2694        for atom_idx in 0..4u8 {
2695            let (c0, c1) = atom_classes(atom_idx);
2696            let bw = c0 == w_class || c1 == w_class;
2697            let bx = c0 == x_class || c1 == x_class;
2698            let by = c0 == y_class || c1 == y_class;
2699            let bz = c0 == z_class || c1 == z_class;
2700            match (bw, bx, by, bz) {
2701                (true, true, false, false) => rel_wx = Some(atom_rels[atom_idx as usize]),
2702                (false, true, true, false) => rel_xy = Some(atom_rels[atom_idx as usize]),
2703                (false, false, true, true) => rel_yz = Some(atom_rels[atom_idx as usize]),
2704                (true, false, false, true) => rel_zw = Some(atom_rels[atom_idx as usize]),
2705                _ => return None,
2706            }
2707        }
2708        Some(Cycle4Semantics {
2709            rel_wx: rel_wx?,
2710            rel_xy: rel_xy?,
2711            rel_yz: rel_yz?,
2712            rel_zw: rel_zw?,
2713        })
2714    }
2715
2716    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
2717    enum Cycle4Grouping {
2718        Default,
2719        Alt,
2720    }
2721
2722    fn build_4cycle_body(s: &Cycle4Semantics, g: Cycle4Grouping) -> RirNode {
2723        let mk_scan = |r: RelId| RirNode::Scan { rel: r };
2724        match g {
2725            Cycle4Grouping::Default => {
2726                let il = RirNode::Join {
2727                    left: Box::new(mk_scan(s.rel_wx)),
2728                    right: Box::new(mk_scan(s.rel_xy)),
2729                    left_keys: vec![1],
2730                    right_keys: vec![0],
2731                    join_type: JoinType::Inner,
2732                };
2733                let ir = RirNode::Join {
2734                    left: Box::new(mk_scan(s.rel_yz)),
2735                    right: Box::new(mk_scan(s.rel_zw)),
2736                    left_keys: vec![1],
2737                    right_keys: vec![0],
2738                    join_type: JoinType::Inner,
2739                };
2740                let outer = RirNode::Join {
2741                    left: Box::new(il),
2742                    right: Box::new(ir),
2743                    left_keys: vec![0, 3],
2744                    right_keys: vec![3, 0],
2745                    join_type: JoinType::Inner,
2746                };
2747                RirNode::Project {
2748                    input: Box::new(outer),
2749                    columns: vec![
2750                        ProjectExpr::Column(0),
2751                        ProjectExpr::Column(1),
2752                        ProjectExpr::Column(3),
2753                        ProjectExpr::Column(5),
2754                    ],
2755                }
2756            }
2757            Cycle4Grouping::Alt => {
2758                let il = RirNode::Join {
2759                    left: Box::new(mk_scan(s.rel_xy)),
2760                    right: Box::new(mk_scan(s.rel_yz)),
2761                    left_keys: vec![1],
2762                    right_keys: vec![0],
2763                    join_type: JoinType::Inner,
2764                };
2765                let ir = RirNode::Join {
2766                    left: Box::new(mk_scan(s.rel_zw)),
2767                    right: Box::new(mk_scan(s.rel_wx)),
2768                    left_keys: vec![1],
2769                    right_keys: vec![0],
2770                    join_type: JoinType::Inner,
2771                };
2772                let outer = RirNode::Join {
2773                    left: Box::new(il),
2774                    right: Box::new(ir),
2775                    left_keys: vec![0, 3],
2776                    right_keys: vec![3, 0],
2777                    join_type: JoinType::Inner,
2778                };
2779                RirNode::Project {
2780                    input: Box::new(outer),
2781                    columns: vec![
2782                        ProjectExpr::Column(5),
2783                        ProjectExpr::Column(0),
2784                        ProjectExpr::Column(1),
2785                        ProjectExpr::Column(3),
2786                    ],
2787                }
2788            }
2789        }
2790    }
2791
2792    pub fn try_reorder_4cycle(body: &RirNode, stats: &StatsManager) -> Option<RirNode> {
2793        let s = match_and_infer_4cycle(body)?;
2794        let _ = (
2795            populated_card(stats, s.rel_wx)?,
2796            populated_card(stats, s.rel_xy)?,
2797            populated_card(stats, s.rel_yz)?,
2798            populated_card(stats, s.rel_zw)?,
2799        );
2800        let est_default = stats
2801            .estimate_join_cardinality(s.rel_wx, s.rel_xy, &[1], &[0])
2802            .saturating_add(stats.estimate_join_cardinality(s.rel_yz, s.rel_zw, &[1], &[0]));
2803        let est_alt = stats
2804            .estimate_join_cardinality(s.rel_xy, s.rel_yz, &[1], &[0])
2805            .saturating_add(stats.estimate_join_cardinality(s.rel_zw, s.rel_wx, &[1], &[0]));
2806        let chosen = if est_alt < est_default {
2807            Cycle4Grouping::Alt
2808        } else {
2809            Cycle4Grouping::Default
2810        };
2811        let candidate = build_4cycle_body(&s, chosen);
2812        if format!("{:?}", candidate) == format!("{:?}", body) {
2813            return None;
2814        }
2815        Some(candidate)
2816    }
2817}
2818
2819#[cfg(test)]
2820mod selectivity_pass_tests {
2821    use super::selectivity_pass;
2822    use crate::Compiler;
2823    use xlog_stats::StatsManager;
2824
2825    fn body_snapshots(plan: &xlog_ir::ExecutionPlan) -> Vec<String> {
2826        plan.rules_by_scc
2827            .iter()
2828            .flatten()
2829            .map(|r| format!("{:?}", r.body))
2830            .collect()
2831    }
2832
2833    #[test]
2834    fn selectivity_pass_is_noop_for_triangle_plan() {
2835        let mut compiler = Compiler::new();
2836        let plan = compiler
2837            .compile("tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z).")
2838            .expect("compile");
2839        let before = body_snapshots(&plan);
2840        let stats = StatsManager::new();
2841        let mut plan2 = plan.clone();
2842        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2843        let after = body_snapshots(&plan2);
2844        assert_eq!(
2845            before, after,
2846            "selectivity_pass must preserve every triangle rule body byte-for-byte"
2847        );
2848    }
2849
2850    #[test]
2851    fn selectivity_pass_is_noop_for_4cycle_plan() {
2852        let mut compiler = Compiler::new();
2853        let plan = compiler
2854            .compile("cycle4(W, X, Y, Z) :- e1(W, X), e2(X, Y), e3(Y, Z), e4(Z, W).")
2855            .expect("compile");
2856        let before = body_snapshots(&plan);
2857        let stats = StatsManager::new();
2858        let mut plan2 = plan.clone();
2859        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2860        let after = body_snapshots(&plan2);
2861        assert_eq!(
2862            before, after,
2863            "selectivity_pass must preserve every 4-cycle rule body byte-for-byte"
2864        );
2865    }
2866
2867    #[test]
2868    fn selectivity_pass_is_noop_for_recursive_scc() {
2869        let mut compiler = Compiler::new();
2870        let plan = compiler
2871            .compile(
2872                "edge(1, 2). edge(2, 3). \
2873                 reach(X, Y) :- edge(X, Y). \
2874                 reach(X, Z) :- reach(X, Y), edge(Y, Z).",
2875            )
2876            .expect("compile");
2877        let before = body_snapshots(&plan);
2878        let stats = StatsManager::new();
2879        let mut plan2 = plan.clone();
2880        selectivity_pass::run(&mut plan2, &stats, &std::collections::HashMap::new());
2881        let after = body_snapshots(&plan2);
2882        assert_eq!(
2883            before, after,
2884            "selectivity_pass must preserve recursive SCC bodies byte-for-byte"
2885        );
2886    }
2887
2888    // ---------------------------------------------------------
2889    // Selectivity-driven reordering tests
2890    // ---------------------------------------------------------
2891
2892    use xlog_core::RelId;
2893    use xlog_ir::plan::{CompiledRule, PlanBuilder, Scc};
2894    use xlog_ir::rir::ProjectExpr;
2895    use xlog_ir::{ExecutionPlan, JoinType, RirNode};
2896
2897    /// Build a hand-crafted canonical lowered triangle plan
2898    /// with three Scans at RelId(1), RelId(2), RelId(3) for
2899    /// (e_xy, e_yz, e_xz). Bypasses the optimizer entirely so
2900    /// the reordering check is a clean stats-→-pair-choice
2901    /// observation, not a confounded test of optimizer plus the rewriter.
2902    ///
2903    /// Default canonical shape (Y-shared inner): inner keys
2904    /// `[1]/[0]`, outer keys `[0,3]/[0,1]`, project `[0,1,3]`.
2905    fn synth_triangle_plan() -> ExecutionPlan {
2906        let inner = RirNode::Join {
2907            left: Box::new(RirNode::Scan { rel: RelId(1) }),
2908            right: Box::new(RirNode::Scan { rel: RelId(2) }),
2909            left_keys: vec![1],
2910            right_keys: vec![0],
2911            join_type: JoinType::Inner,
2912        };
2913        let outer = RirNode::Join {
2914            left: Box::new(inner),
2915            right: Box::new(RirNode::Scan { rel: RelId(3) }),
2916            left_keys: vec![0, 3],
2917            right_keys: vec![0, 1],
2918            join_type: JoinType::Inner,
2919        };
2920        let body = RirNode::Project {
2921            input: Box::new(outer),
2922            columns: vec![
2923                ProjectExpr::Column(0),
2924                ProjectExpr::Column(1),
2925                ProjectExpr::Column(3),
2926            ],
2927        };
2928        let mut builder = PlanBuilder::new();
2929        builder.add_scc(Scc {
2930            id: 0,
2931            predicates: vec!["tri".to_string()],
2932            is_recursive: false,
2933        });
2934        builder.add_rule(
2935            0,
2936            CompiledRule {
2937                head: "tri".to_string(),
2938                body,
2939                meta: Default::default(),
2940            },
2941        );
2942        builder.build()
2943    }
2944
2945    /// Seed a `StatsManager` with three triangle-edge
2946    /// cardinalities at the conventional RelIds (1, 2, 3) used
2947    /// by `synth_triangle_plan`.
2948    fn seed_triangle_stats(c1: u64, c2: u64, c3: u64) -> StatsManager {
2949        let mut stats = StatsManager::new();
2950        for (rid, card) in [(RelId(1), c1), (RelId(2), c2), (RelId(3), c3)] {
2951            stats.register_relation(rid);
2952            stats.update_cardinality(rid, card);
2953        }
2954        stats
2955    }
2956
2957    /// Inspect the (left RelId, right RelId) of the inner Join
2958    /// in a canonical lowered triangle body. Used by selectivity reordering
2959    /// checks.
2960    ///
2961    /// After `compile()` the body is a `MultiWayJoin` whose
2962    /// `fallback` field holds the post-selectivity-pass
2963    /// pre-promotion shape — that's where the inner-pair
2964    /// signature lives. The helper unwraps `MultiWayJoin →
2965    /// fallback` if needed before drilling into the binary
2966    /// Join structure.
2967    fn inspect_triangle_inner_pair(plan: &xlog_ir::ExecutionPlan) -> Option<(RelId, RelId)> {
2968        let body = &plan.rules_by_scc.iter().flatten().next()?.body;
2969        let body = match body {
2970            xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
2971            other => other,
2972        };
2973        let xlog_ir::RirNode::Project { input, .. } = body else {
2974            return None;
2975        };
2976        let xlog_ir::RirNode::Join { left, .. } = input.as_ref() else {
2977            return None;
2978        };
2979        let xlog_ir::RirNode::Join {
2980            left: l2,
2981            right: r2,
2982            ..
2983        } = left.as_ref()
2984        else {
2985            return None;
2986        };
2987        let xlog_ir::RirNode::Scan { rel: rel_l } = l2.as_ref() else {
2988            return None;
2989        };
2990        let xlog_ir::RirNode::Scan { rel: rel_r } = r2.as_ref() else {
2991            return None;
2992        };
2993        Some((*rel_l, *rel_r))
2994    }
2995
2996    /// Snapshot 1: cards favor `(e1, e2)` Y-shared inner.
2997    /// Triangle rule: `tri(X, Y, Z) :- e1(X, Y), e2(Y, Z), e3(X, Z)`.
2998    /// To make Y-shared smallest, give e1 + e2 small cards and
2999    /// e3 a large card so all pair products are dominated by
3000    /// pairs containing e3 — except the pair (e1, e2) which
3001    /// is the smallest product.
3002    #[test]
3003    fn selectivity_pass_picks_y_shared_inner_when_e1_e2_smallest() {
3004        let mut plan = synth_triangle_plan();
3005        // e1=10, e2=10, e3=100_000 → Y-shared (e1⋈e2) smallest.
3006        let stats = seed_triangle_stats(10, 10, 100_000);
3007        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3008        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3009        // Y-shared inner = (e_xy, e_yz) = (RelId(1), RelId(2)).
3010        assert!(
3011            pair == (RelId(1), RelId(2)) || pair == (RelId(2), RelId(1)),
3012            "expected (RelId(1), RelId(2)) for Y-shared; got {:?}",
3013            pair
3014        );
3015    }
3016
3017    /// Snapshot 2: cards favor `(e1, e3)` X-shared inner.
3018    /// e1 + e3 small, e2 large.
3019    #[test]
3020    fn selectivity_pass_picks_x_shared_inner_when_e1_e3_smallest() {
3021        let mut plan = synth_triangle_plan();
3022        // e1=10, e2=100_000, e3=10 → X-shared (e1⋈e3) smallest.
3023        let stats = seed_triangle_stats(10, 100_000, 10);
3024        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3025        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3026        // X-shared inner = (e_xy, e_xz) = (RelId(1), RelId(3)).
3027        assert!(
3028            pair == (RelId(1), RelId(3)) || pair == (RelId(3), RelId(1)),
3029            "expected (RelId(1), RelId(3)) for X-shared; got {:?}",
3030            pair
3031        );
3032    }
3033
3034    /// Snapshot 3: cards favor `(e2, e3)` Z-shared inner.
3035    /// e2 + e3 small, e1 large.
3036    #[test]
3037    fn selectivity_pass_picks_z_shared_inner_when_e2_e3_smallest() {
3038        let mut plan = synth_triangle_plan();
3039        // e1=100_000, e2=10, e3=10 → Z-shared (e2⋈e3) smallest.
3040        let stats = seed_triangle_stats(100_000, 10, 10);
3041        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3042        let pair = inspect_triangle_inner_pair(&plan).expect("inner pair");
3043        // Z-shared inner = (e_yz, e_xz) = (RelId(2), RelId(3)).
3044        assert!(
3045            pair == (RelId(2), RelId(3)) || pair == (RelId(3), RelId(2)),
3046            "expected (RelId(2), RelId(3)) for Z-shared; got {:?}",
3047            pair
3048        );
3049    }
3050
3051    /// Two snapshots produce different inner pairs. Pins
3052    /// "stats drive the order, not deterministic
3053    /// canonicalization." Deterministic canonicalization that
3054    /// ignores stats CANNOT pass this gate.
3055    #[test]
3056    fn selectivity_pass_two_snapshots_produce_different_inner_pairs() {
3057        let mut plan_a = synth_triangle_plan();
3058        let stats_a = seed_triangle_stats(10, 10, 100_000); // Y-shared
3059        selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
3060        let pair_a = inspect_triangle_inner_pair(&plan_a).expect("snapshot A pair");
3061
3062        let mut plan_b = synth_triangle_plan();
3063        let stats_b = seed_triangle_stats(100_000, 10, 10); // Z-shared
3064        selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
3065        let pair_b = inspect_triangle_inner_pair(&plan_b).expect("snapshot B pair");
3066
3067        let normalize = |(a, b): (RelId, RelId)| -> (RelId, RelId) {
3068            if a.0 <= b.0 {
3069                (a, b)
3070            } else {
3071                (b, a)
3072            }
3073        };
3074        assert_ne!(
3075            normalize(pair_a),
3076            normalize(pair_b),
3077            "two different stats snapshots must produce different inner pairs; \
3078             got A = {:?}, B = {:?}",
3079            pair_a,
3080            pair_b
3081        );
3082    }
3083
3084    /// Fallback edge case: relation cards present but no
3085    /// column statistics. The 10% default fallback inside
3086    /// `estimate_join_cardinality` means all three pair
3087    /// estimates collapse to roughly the same ratio. The pass
3088    /// either picks SOME pair or leaves the body unchanged;
3089    /// the test is tolerant by design and documents the
3090    /// uninformative-fallback case explicitly.
3091    #[test]
3092    fn selectivity_pass_with_only_relation_cards_may_pick_arbitrary_pair() {
3093        let mut plan = synth_triangle_plan();
3094        // All three cards equal — no column stats to break ties.
3095        let stats = seed_triangle_stats(100, 100, 100);
3096        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3097        // Either a triangle inner pair is identifiable (any of
3098        // the three) or the body stays unchanged. Both are OK.
3099        let _ = inspect_triangle_inner_pair(&plan);
3100    }
3101
3102    // ---------------------------------------------------------
3103    // 4-cycle compile-time reordering tests
3104    // ---------------------------------------------------------
3105
3106    /// Build a hand-crafted canonical lowered 4-cycle plan
3107    /// with four Scans at RelId(1), RelId(2), RelId(3), RelId(4)
3108    /// for (e_wx, e_xy, e_yz, e_zw). Bypasses the optimizer.
3109    /// Default canonical bushy shape: inner-left
3110    /// `(e_wx ⋈ e_xy)` on X, inner-right `(e_yz ⋈ e_zw)` on Z,
3111    /// outer keys `[0, 3] / [3, 0]`, project `[0, 1, 3, 5]`.
3112    fn synth_4cycle_plan() -> ExecutionPlan {
3113        let inner_left = RirNode::Join {
3114            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3115            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3116            left_keys: vec![1],
3117            right_keys: vec![0],
3118            join_type: JoinType::Inner,
3119        };
3120        let inner_right = RirNode::Join {
3121            left: Box::new(RirNode::Scan { rel: RelId(3) }),
3122            right: Box::new(RirNode::Scan { rel: RelId(4) }),
3123            left_keys: vec![1],
3124            right_keys: vec![0],
3125            join_type: JoinType::Inner,
3126        };
3127        let outer = RirNode::Join {
3128            left: Box::new(inner_left),
3129            right: Box::new(inner_right),
3130            left_keys: vec![0, 3],
3131            right_keys: vec![3, 0],
3132            join_type: JoinType::Inner,
3133        };
3134        let body = RirNode::Project {
3135            input: Box::new(outer),
3136            columns: vec![
3137                ProjectExpr::Column(0),
3138                ProjectExpr::Column(1),
3139                ProjectExpr::Column(3),
3140                ProjectExpr::Column(5),
3141            ],
3142        };
3143        let mut builder = PlanBuilder::new();
3144        builder.add_scc(Scc {
3145            id: 0,
3146            predicates: vec!["cyc".to_string()],
3147            is_recursive: false,
3148        });
3149        builder.add_rule(
3150            0,
3151            CompiledRule {
3152                head: "cyc".to_string(),
3153                body,
3154                meta: Default::default(),
3155            },
3156        );
3157        builder.build()
3158    }
3159
3160    fn seed_4cycle_stats(c1: u64, c2: u64, c3: u64, c4: u64) -> StatsManager {
3161        let mut stats = StatsManager::new();
3162        for (rid, card) in [
3163            (RelId(1), c1),
3164            (RelId(2), c2),
3165            (RelId(3), c3),
3166            (RelId(4), c4),
3167        ] {
3168            stats.register_relation(rid);
3169            stats.update_cardinality(rid, card);
3170        }
3171        stats
3172    }
3173
3174    /// Recover the 4-cycle inner-grouping signature: `(left_left,
3175    /// left_right, right_left, right_right)` Scan RelIds. Used
3176    /// to identify which grouping the rewriter chose.
3177    fn inspect_4cycle_grouping(
3178        plan: &xlog_ir::ExecutionPlan,
3179    ) -> Option<(RelId, RelId, RelId, RelId)> {
3180        let body = &plan.rules_by_scc.iter().flatten().next()?.body;
3181        let body = match body {
3182            xlog_ir::RirNode::MultiWayJoin { fallback, .. } => fallback.as_ref(),
3183            other => other,
3184        };
3185        let xlog_ir::RirNode::Project { input, .. } = body else {
3186            return None;
3187        };
3188        let xlog_ir::RirNode::Join { left, right, .. } = input.as_ref() else {
3189            return None;
3190        };
3191        let xlog_ir::RirNode::Join {
3192            left: ll,
3193            right: lr,
3194            ..
3195        } = left.as_ref()
3196        else {
3197            return None;
3198        };
3199        let xlog_ir::RirNode::Join {
3200            left: rl,
3201            right: rr,
3202            ..
3203        } = right.as_ref()
3204        else {
3205            return None;
3206        };
3207        let xlog_ir::RirNode::Scan { rel: r_ll } = ll.as_ref() else {
3208            return None;
3209        };
3210        let xlog_ir::RirNode::Scan { rel: r_lr } = lr.as_ref() else {
3211            return None;
3212        };
3213        let xlog_ir::RirNode::Scan { rel: r_rl } = rl.as_ref() else {
3214            return None;
3215        };
3216        let xlog_ir::RirNode::Scan { rel: r_rr } = rr.as_ref() else {
3217            return None;
3218        };
3219        Some((*r_ll, *r_lr, *r_rl, *r_rr))
3220    }
3221
3222    /// 4-cycle: cards favor Default grouping
3223    /// `(e_wx⋈e_xy on X) + (e_yz⋈e_zw on Z)`. Default cost is
3224    /// `est(WX⋈XY)+est(YZ⋈ZW) = 0.1*c1*c2 + 0.1*c3*c4`.
3225    /// Alt cost is `0.1*c2*c3 + 0.1*c4*c1`. Default smaller
3226    /// when `c1*c2 + c3*c4 < c2*c3 + c4*c1`. With
3227    /// (c1=10, c2=10, c3=100_000, c4=100_000):
3228    ///   default = 100 + 10^10 ≈ 10^10.
3229    ///   alt = 10^6 + 10^6 ≈ 2*10^6.
3230    /// → alt is smaller, so this fixture actually favors Alt.
3231    /// Use (c1=10, c2=10, c3=10, c4=10_000_000) instead:
3232    ///   default = 100 + 10^8 = 10^8.
3233    ///   alt = 100 + 10^8 = 10^8 (same).
3234    /// Need uneven c4 vs others: (c1=10, c2=10, c3=10_000_000, c4=10):
3235    ///   default = 100 + 10^8 = 10^8.
3236    ///   alt = 10^8 + 100 = 10^8 (same).
3237    /// Default favored when c1*c2 << c2*c3 AND c3*c4 << c4*c1.
3238    /// I.e., c1 small and c4 small relative to c2 and c3.
3239    /// (c1=10, c2=10_000, c3=10_000, c4=10):
3240    ///   default = 0.1*100_000 + 0.1*100_000 = 20_000.
3241    ///   alt = 0.1*100_000_000 + 0.1*100 = 10_000_010.
3242    /// → Default smaller. ✓
3243    #[test]
3244    fn selectivity_pass_4cycle_picks_default_grouping_when_corners_smallest() {
3245        let mut plan = synth_4cycle_plan();
3246        let stats = seed_4cycle_stats(10, 10_000, 10_000, 10);
3247        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3248        let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
3249        // Default: (e_wx, e_xy, e_yz, e_zw) = (RelId(1..4)).
3250        assert_eq!(
3251            (ll, lr, rl, rr),
3252            (RelId(1), RelId(2), RelId(3), RelId(4)),
3253            "expected Default grouping"
3254        );
3255    }
3256
3257    /// 4-cycle: cards favor Alt grouping
3258    /// `(e_xy⋈e_yz on Y) + (e_zw⋈e_wx on W)`. Alt smaller when
3259    /// `c2*c3 + c4*c1 < c1*c2 + c3*c4`. Use
3260    /// (c1=10_000, c2=10, c3=10, c4=10_000):
3261    ///   default = 0.1*100_000 + 0.1*100_000 = 20_000.
3262    ///   alt = 0.1*100 + 0.1*10^8 = 10_000_010.
3263    /// → Default still wins. Need c1*c2 LARGE and c3*c4 LARGE
3264    /// while c2*c3 SMALL and c4*c1 SMALL. Try
3265    /// (c1=10_000, c2=10_000, c3=10, c4=10):
3266    ///   default = 0.1*10^8 + 0.1*100 = 10_000_010.
3267    ///   alt = 0.1*100_000 + 0.1*100_000 = 20_000.
3268    /// → Alt smaller. ✓
3269    #[test]
3270    fn selectivity_pass_4cycle_picks_alt_grouping_when_diagonals_smallest() {
3271        let mut plan = synth_4cycle_plan();
3272        let stats = seed_4cycle_stats(10_000, 10_000, 10, 10);
3273        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3274        let (ll, lr, rl, rr) = inspect_4cycle_grouping(&plan).expect("grouping");
3275        // Alt: (e_xy, e_yz, e_zw, e_wx) = (RelId(2), RelId(3), RelId(4), RelId(1)).
3276        assert_eq!(
3277            (ll, lr, rl, rr),
3278            (RelId(2), RelId(3), RelId(4), RelId(1)),
3279            "expected Alt grouping"
3280        );
3281    }
3282
3283    /// Same plan, two stats snapshots → two different
3284    /// 4-cycle groupings. Pins "stats drive the choice" for
3285    /// 4-cycle.
3286    #[test]
3287    fn selectivity_pass_4cycle_two_snapshots_produce_different_groupings() {
3288        let mut plan_a = synth_4cycle_plan();
3289        let stats_a = seed_4cycle_stats(10, 10_000, 10_000, 10); // Default.
3290        selectivity_pass::run(&mut plan_a, &stats_a, &std::collections::HashMap::new());
3291        let g_a = inspect_4cycle_grouping(&plan_a).expect("grouping a");
3292
3293        let mut plan_b = synth_4cycle_plan();
3294        let stats_b = seed_4cycle_stats(10_000, 10_000, 10, 10); // Alt.
3295        selectivity_pass::run(&mut plan_b, &stats_b, &std::collections::HashMap::new());
3296        let g_b = inspect_4cycle_grouping(&plan_b).expect("grouping b");
3297
3298        assert_ne!(
3299            g_a, g_b,
3300            "two different stats snapshots must produce different 4-cycle groupings; \
3301             got A = {:?}, B = {:?}",
3302            g_a, g_b
3303        );
3304    }
3305
3306    /// 4-cycle missing-stats safety floor: any unseeded
3307    /// relation → body unchanged.
3308    #[test]
3309    fn selectivity_pass_4cycle_skips_when_card_missing() {
3310        let mut plan = synth_4cycle_plan();
3311        // Only seed 3 of 4.
3312        let mut stats = StatsManager::new();
3313        for rid in [RelId(1), RelId(2), RelId(3)] {
3314            stats.register_relation(rid);
3315            stats.update_cardinality(rid, 100);
3316        }
3317        let before = format!("{:?}", plan.rules_by_scc[0][0].body);
3318        selectivity_pass::run(&mut plan, &stats, &std::collections::HashMap::new());
3319        let after = format!("{:?}", plan.rules_by_scc[0][0].body);
3320        assert_eq!(
3321            before, after,
3322            "missing-stats safety floor must leave body unchanged"
3323        );
3324    }
3325}
3326
3327#[cfg(test)]
3328mod tests {
3329    use super::*;
3330    use xlog_core::ScalarType;
3331    use xlog_ir::{ConstValue, ProjectExpr};
3332    use xlog_stats::ColumnStats;
3333
3334    fn make_stats_manager() -> Arc<StatsManager> {
3335        let mut mgr = StatsManager::new();
3336
3337        // Register test relations with realistic statistics
3338        mgr.register_relation(RelId(1));
3339        mgr.update_cardinality(RelId(1), 10_000);
3340        mgr.update_byte_size(RelId(1), 320_000); // ~32 bytes per row
3341
3342        mgr.register_relation(RelId(2));
3343        mgr.update_cardinality(RelId(2), 5_000);
3344        mgr.update_byte_size(RelId(2), 160_000);
3345
3346        mgr.register_relation(RelId(3));
3347        mgr.update_cardinality(RelId(3), 1_000);
3348        mgr.update_byte_size(RelId(3), 32_000);
3349
3350        // Add column statistics for relation 1
3351        let mut col0 = ColumnStats::new(0, ScalarType::I64);
3352        col0.update_distinct(1000);
3353        col0.update_range(0, 10000);
3354        mgr.add_column_stats(RelId(1), col0);
3355
3356        let mut col1 = ColumnStats::new(1, ScalarType::I64);
3357        col1.update_distinct(100);
3358        mgr.add_column_stats(RelId(1), col1);
3359
3360        Arc::new(mgr)
3361    }
3362
3363    #[test]
3364    fn test_optimizer_new() {
3365        let stats = make_stats_manager();
3366        let optimizer = Optimizer::new(stats);
3367
3368        assert_eq!(optimizer.config().dp_threshold, 10);
3369        assert!(optimizer.config().enable_pushdown);
3370    }
3371
3372    #[test]
3373    fn test_optimizer_with_config() {
3374        let stats = make_stats_manager();
3375        let config = OptimizerConfig {
3376            dp_threshold: 5,
3377            enable_pushdown: false,
3378            ..Default::default()
3379        };
3380        let optimizer = Optimizer::with_config(stats, config);
3381
3382        assert_eq!(optimizer.config().dp_threshold, 5);
3383        assert!(!optimizer.config().enable_pushdown);
3384    }
3385
3386    #[test]
3387    fn test_estimate_scan_cost() {
3388        let stats = make_stats_manager();
3389        let optimizer = Optimizer::new(stats);
3390
3391        let scan = RirNode::Scan { rel: RelId(1) };
3392        let cost = optimizer.estimate_cost(&scan);
3393
3394        assert_eq!(cost.rows, 10_000);
3395        assert!(cost.gpu_mem > 0);
3396        assert_eq!(cost.transfers, 0); // Data on GPU
3397    }
3398
3399    #[test]
3400    fn test_estimate_scan_cost_unknown_relation() {
3401        let stats = Arc::new(StatsManager::new());
3402        let optimizer = Optimizer::new(stats);
3403
3404        let scan = RirNode::Scan { rel: RelId(999) };
3405        let cost = optimizer.estimate_cost(&scan);
3406
3407        // Should use defaults
3408        assert_eq!(cost.rows, 1000);
3409    }
3410
3411    #[test]
3412    fn test_estimate_filter_cost() {
3413        let stats = make_stats_manager();
3414        let optimizer = Optimizer::new(stats);
3415
3416        let filter = RirNode::Filter {
3417            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3418            predicate: Expr::Compare {
3419                left: Box::new(Expr::Column(0)),
3420                op: CompareOp::Eq,
3421                right: Box::new(Expr::Const(ConstValue::I64(42))),
3422            },
3423        };
3424
3425        let cost = optimizer.estimate_cost(&filter);
3426
3427        // Filter should reduce row count
3428        assert!(cost.rows < 10_000);
3429        assert!(cost.rows >= 1);
3430    }
3431
3432    #[test]
3433    fn test_estimate_join_cost() {
3434        let stats = make_stats_manager();
3435        let optimizer = Optimizer::new(stats);
3436
3437        let join = RirNode::Join {
3438            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3439            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3440            left_keys: vec![0],
3441            right_keys: vec![0],
3442            join_type: JoinType::Inner,
3443        };
3444
3445        let cost = optimizer.estimate_cost(&join);
3446
3447        // Should have positive estimates
3448        assert!(cost.rows > 0);
3449        assert!(cost.cpu_cost > 0.0);
3450        assert!(cost.gpu_mem > 0);
3451    }
3452
3453    #[test]
3454    fn test_estimate_join_cost_with_selectivity() {
3455        let mut mgr = StatsManager::new();
3456        mgr.register_relation(RelId(1));
3457        mgr.register_relation(RelId(2));
3458        mgr.update_cardinality(RelId(1), 1000);
3459        mgr.update_cardinality(RelId(2), 500);
3460
3461        // Record a join result to cache selectivity
3462        mgr.record_join_result(RelId(1), RelId(2), vec![0], vec![0], 500_000, 2500);
3463
3464        let optimizer = Optimizer::new(Arc::new(mgr));
3465
3466        let join = RirNode::Join {
3467            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3468            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3469            left_keys: vec![0],
3470            right_keys: vec![0],
3471            join_type: JoinType::Inner,
3472        };
3473
3474        let cost = optimizer.estimate_cost(&join);
3475
3476        // Should use cached selectivity for estimate
3477        assert!(cost.rows > 0);
3478    }
3479
3480    #[test]
3481    fn test_predicate_pushdown_simple_scan() {
3482        let stats = make_stats_manager();
3483        let optimizer = Optimizer::new(stats);
3484
3485        let scan = RirNode::Scan { rel: RelId(1) };
3486        let optimized = optimizer.optimize(scan);
3487
3488        // Scan should pass through unchanged
3489        assert!(matches!(optimized, RirNode::Scan { rel: RelId(1) }));
3490    }
3491
3492    #[test]
3493    fn test_predicate_pushdown_filter_on_scan() {
3494        let stats = make_stats_manager();
3495        let optimizer = Optimizer::new(stats);
3496
3497        let filter = RirNode::Filter {
3498            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3499            predicate: Expr::Compare {
3500                left: Box::new(Expr::Column(0)),
3501                op: CompareOp::Eq,
3502                right: Box::new(Expr::Const(ConstValue::I64(42))),
3503            },
3504        };
3505
3506        let optimized = optimizer.optimize(filter);
3507
3508        // Filter on scan should stay in place
3509        assert!(matches!(optimized, RirNode::Filter { .. }));
3510    }
3511
3512    #[test]
3513    fn test_predicate_pushdown_merges_filters() {
3514        let stats = make_stats_manager();
3515        let optimizer = Optimizer::new(stats);
3516
3517        let nested_filter = RirNode::Filter {
3518            input: Box::new(RirNode::Filter {
3519                input: Box::new(RirNode::Scan { rel: RelId(1) }),
3520                predicate: Expr::Compare {
3521                    left: Box::new(Expr::Column(0)),
3522                    op: CompareOp::Gt,
3523                    right: Box::new(Expr::Const(ConstValue::I64(0))),
3524                },
3525            }),
3526            predicate: Expr::Compare {
3527                left: Box::new(Expr::Column(0)),
3528                op: CompareOp::Lt,
3529                right: Box::new(Expr::Const(ConstValue::I64(100))),
3530            },
3531        };
3532
3533        let optimized = optimizer.optimize(nested_filter);
3534
3535        // Filters should be merged into AND
3536        if let RirNode::Filter { predicate, .. } = optimized {
3537            assert!(matches!(predicate, Expr::And(_)));
3538        } else {
3539            panic!("Expected Filter node");
3540        }
3541    }
3542
3543    #[test]
3544    fn test_predicate_pushdown_through_project() {
3545        let stats = make_stats_manager();
3546        let optimizer = Optimizer::new(stats);
3547
3548        // Filter on projected column that's a pass-through
3549        let plan = RirNode::Filter {
3550            input: Box::new(RirNode::Project {
3551                input: Box::new(RirNode::Scan { rel: RelId(1) }),
3552                columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(1)],
3553            }),
3554            predicate: Expr::Compare {
3555                left: Box::new(Expr::Column(0)),
3556                op: CompareOp::Eq,
3557                right: Box::new(Expr::Const(ConstValue::I64(42))),
3558            },
3559        };
3560
3561        let optimized = optimizer.optimize(plan);
3562
3563        // Filter should be pushed below project
3564        assert!(matches!(optimized, RirNode::Project { .. }));
3565        if let RirNode::Project { input, .. } = optimized {
3566            assert!(matches!(*input, RirNode::Filter { .. }));
3567        }
3568    }
3569
3570    #[test]
3571    fn test_predicate_pushdown_into_join() {
3572        let stats = make_stats_manager();
3573        let optimizer = Optimizer::new(stats);
3574
3575        // Filter on left side column only
3576        let plan = RirNode::Filter {
3577            input: Box::new(RirNode::Join {
3578                left: Box::new(RirNode::Scan { rel: RelId(1) }),
3579                right: Box::new(RirNode::Scan { rel: RelId(2) }),
3580                left_keys: vec![0],
3581                right_keys: vec![0],
3582                join_type: JoinType::Inner,
3583            }),
3584            predicate: Expr::Compare {
3585                left: Box::new(Expr::Column(0)), // Left side column
3586                op: CompareOp::Eq,
3587                right: Box::new(Expr::Const(ConstValue::I64(42))),
3588            },
3589        };
3590
3591        let optimized = optimizer.optimize(plan);
3592
3593        // Filter should be pushed into left side of join
3594        if let RirNode::Join { left, .. } = optimized {
3595            assert!(matches!(*left, RirNode::Filter { .. }));
3596        } else {
3597            panic!("Expected Join node");
3598        }
3599    }
3600
3601    #[test]
3602    fn test_plan_cost_total() {
3603        let cost = PlanCost {
3604            rows: 1000,
3605            cpu_cost: 100.0,
3606            gpu_mem: 1_000_000,
3607            transfers: 2,
3608        };
3609
3610        let total = cost.total_cost(100.0);
3611
3612        // cpu_cost + gpu_mem*0.001 + transfers*100
3613        // 100.0 + 1000.0 + 200.0 = 1300.0
3614        assert!((total - 1300.0).abs() < 0.001);
3615    }
3616
3617    #[test]
3618    fn test_plan_cost_then() {
3619        let cost1 = PlanCost {
3620            rows: 1000,
3621            cpu_cost: 50.0,
3622            gpu_mem: 500,
3623            transfers: 1,
3624        };
3625
3626        let cost2 = PlanCost {
3627            rows: 500,
3628            cpu_cost: 25.0,
3629            gpu_mem: 800,
3630            transfers: 1,
3631        };
3632
3633        let combined = cost1.then(cost2);
3634
3635        assert_eq!(combined.rows, 500); // Takes output rows from second
3636        assert_eq!(combined.cpu_cost, 75.0);
3637        assert_eq!(combined.gpu_mem, 800); // Peak memory
3638        assert_eq!(combined.transfers, 2);
3639    }
3640
3641    #[test]
3642    fn test_optimizer_config_default() {
3643        let config = OptimizerConfig::default();
3644
3645        assert_eq!(config.dp_threshold, 10);
3646        assert!((config.index_heat_threshold - 0.7).abs() < 0.001);
3647        assert!(config.enable_pushdown);
3648        assert!((config.default_filter_selectivity - 0.1).abs() < 0.001);
3649    }
3650
3651    #[test]
3652    fn test_should_use_greedy() {
3653        let stats = make_stats_manager();
3654        let config = OptimizerConfig {
3655            dp_threshold: 2,
3656            ..Default::default()
3657        };
3658        let optimizer = Optimizer::with_config(stats, config);
3659
3660        // Single relation: should NOT use greedy
3661        let single = RirNode::Scan { rel: RelId(1) };
3662        assert!(!optimizer.should_use_greedy(&single));
3663
3664        // Three relations: should use greedy (threshold is 2)
3665        let multi = RirNode::Join {
3666            left: Box::new(RirNode::Join {
3667                left: Box::new(RirNode::Scan { rel: RelId(1) }),
3668                right: Box::new(RirNode::Scan { rel: RelId(2) }),
3669                left_keys: vec![0],
3670                right_keys: vec![0],
3671                join_type: JoinType::Inner,
3672            }),
3673            right: Box::new(RirNode::Scan { rel: RelId(3) }),
3674            left_keys: vec![0],
3675            right_keys: vec![0],
3676            join_type: JoinType::Inner,
3677        };
3678        assert!(optimizer.should_use_greedy(&multi));
3679    }
3680
3681    #[test]
3682    fn test_recommend_indexes() {
3683        let mut mgr = StatsManager::new();
3684        mgr.register_relation(RelId(1));
3685        mgr.register_relation(RelId(2));
3686
3687        // Heat up relation 1 extensively
3688        for _ in 0..50 {
3689            mgr.record_access(RelId(1));
3690        }
3691
3692        let optimizer = Optimizer::new(Arc::new(mgr));
3693        let recommendations = optimizer.recommend_indexes();
3694
3695        assert!(recommendations.contains(&RelId(1)));
3696        assert!(!recommendations.contains(&RelId(2)));
3697    }
3698
3699    #[test]
3700    fn test_estimate_groupby_cost() {
3701        let stats = make_stats_manager();
3702        let optimizer = Optimizer::new(stats);
3703
3704        let groupby = RirNode::GroupBy {
3705            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3706            key_cols: vec![0],
3707            aggs: vec![(1, xlog_core::AggOp::Sum)],
3708        };
3709
3710        let cost = optimizer.estimate_cost(&groupby);
3711
3712        // GroupBy should reduce row count
3713        assert!(cost.rows < 10_000);
3714        assert!(cost.rows >= 1);
3715    }
3716
3717    #[test]
3718    fn test_estimate_union_cost() {
3719        let stats = make_stats_manager();
3720        let optimizer = Optimizer::new(stats);
3721
3722        let union = RirNode::Union {
3723            inputs: vec![
3724                RirNode::Scan { rel: RelId(1) },
3725                RirNode::Scan { rel: RelId(2) },
3726            ],
3727        };
3728
3729        let cost = optimizer.estimate_cost(&union);
3730
3731        // Union sums row counts
3732        assert_eq!(cost.rows, 15_000); // 10000 + 5000
3733    }
3734
3735    #[test]
3736    fn test_estimate_distinct_cost() {
3737        let stats = make_stats_manager();
3738        let optimizer = Optimizer::new(stats);
3739
3740        let distinct = RirNode::Distinct {
3741            input: Box::new(RirNode::Scan { rel: RelId(1) }),
3742            key_cols: vec![0],
3743        };
3744
3745        let cost = optimizer.estimate_cost(&distinct);
3746
3747        // Distinct reduces rows
3748        assert!(cost.rows <= 10_000);
3749        assert!(cost.rows >= 1);
3750    }
3751
3752    #[test]
3753    fn test_estimate_diff_cost() {
3754        let stats = make_stats_manager();
3755        let optimizer = Optimizer::new(stats);
3756
3757        let diff = RirNode::Diff {
3758            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3759            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3760        };
3761
3762        let cost = optimizer.estimate_cost(&diff);
3763
3764        // Diff reduces left side
3765        assert!(cost.rows <= 10_000);
3766        assert!(cost.rows >= 1);
3767    }
3768
3769    #[test]
3770    fn test_estimate_fixpoint_cost() {
3771        let stats = make_stats_manager();
3772        let optimizer = Optimizer::new(stats);
3773
3774        let fixpoint = RirNode::Fixpoint {
3775            scc_id: 0,
3776            base: Box::new(RirNode::Scan { rel: RelId(1) }),
3777            recursive: Box::new(RirNode::Scan { rel: RelId(1) }),
3778            delta_rel: RelId(10),
3779            full_rel: RelId(11),
3780        };
3781
3782        let cost = optimizer.estimate_cost(&fixpoint);
3783
3784        // Fixpoint accumulates rows across iterations
3785        assert!(cost.rows >= 10_000);
3786    }
3787
3788    #[test]
3789    fn test_predicate_selectivity_equality() {
3790        let stats = make_stats_manager();
3791        let optimizer = Optimizer::new(stats);
3792
3793        let scan = RirNode::Scan { rel: RelId(1) };
3794
3795        // Equality predicate
3796        let eq_pred = Expr::Compare {
3797            left: Box::new(Expr::Column(0)),
3798            op: CompareOp::Eq,
3799            right: Box::new(Expr::Const(ConstValue::I64(42))),
3800        };
3801
3802        let selectivity = optimizer.estimate_predicate_selectivity(&eq_pred, &scan);
3803
3804        // With 1000 distinct values, selectivity should be ~0.001
3805        assert!(selectivity < 0.01);
3806        assert!(selectivity > 0.0);
3807    }
3808
3809    #[test]
3810    fn test_predicate_selectivity_and() {
3811        let stats = make_stats_manager();
3812        let optimizer = Optimizer::new(stats);
3813
3814        let scan = RirNode::Scan { rel: RelId(1) };
3815
3816        // AND of two predicates
3817        let and_pred = Expr::And(vec![
3818            Expr::Compare {
3819                left: Box::new(Expr::Column(0)),
3820                op: CompareOp::Gt,
3821                right: Box::new(Expr::Const(ConstValue::I64(0))),
3822            },
3823            Expr::Compare {
3824                left: Box::new(Expr::Column(0)),
3825                op: CompareOp::Lt,
3826                right: Box::new(Expr::Const(ConstValue::I64(100))),
3827            },
3828        ]);
3829
3830        let selectivity = optimizer.estimate_predicate_selectivity(&and_pred, &scan);
3831
3832        // Product of individual selectivities (0.33 * 0.33 ≈ 0.11)
3833        assert!(selectivity < 0.5);
3834        assert!(selectivity > 0.0);
3835    }
3836
3837    #[test]
3838    fn test_predicate_selectivity_not() {
3839        let stats = make_stats_manager();
3840        let optimizer = Optimizer::new(stats);
3841
3842        let scan = RirNode::Scan { rel: RelId(1) };
3843
3844        // NOT of equality
3845        let not_pred = Expr::Not(Box::new(Expr::Compare {
3846            left: Box::new(Expr::Column(0)),
3847            op: CompareOp::Eq,
3848            right: Box::new(Expr::Const(ConstValue::I64(42))),
3849        }));
3850
3851        let selectivity = optimizer.estimate_predicate_selectivity(&not_pred, &scan);
3852
3853        // NOT(equality) should have high selectivity
3854        assert!(selectivity > 0.9);
3855    }
3856
3857    #[test]
3858    fn test_join_type_semi() {
3859        let stats = make_stats_manager();
3860        let optimizer = Optimizer::new(stats);
3861
3862        let semi_join = RirNode::Join {
3863            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3864            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3865            left_keys: vec![0],
3866            right_keys: vec![0],
3867            join_type: JoinType::Semi,
3868        };
3869
3870        let cost = optimizer.estimate_cost(&semi_join);
3871
3872        // Semi join outputs at most left side rows
3873        assert!(cost.rows <= 10_000);
3874    }
3875
3876    #[test]
3877    fn test_join_type_anti() {
3878        let stats = make_stats_manager();
3879        let optimizer = Optimizer::new(stats);
3880
3881        let anti_join = RirNode::Join {
3882            left: Box::new(RirNode::Scan { rel: RelId(1) }),
3883            right: Box::new(RirNode::Scan { rel: RelId(2) }),
3884            left_keys: vec![0],
3885            right_keys: vec![0],
3886            join_type: JoinType::Anti,
3887        };
3888
3889        let cost = optimizer.estimate_cost(&anti_join);
3890
3891        // Anti join outputs at most left side rows
3892        assert!(cost.rows <= 10_000);
3893    }
3894
3895    #[test]
3896    fn test_pushdown_disabled() {
3897        let stats = make_stats_manager();
3898        let config = OptimizerConfig {
3899            enable_pushdown: false,
3900            ..Default::default()
3901        };
3902        let optimizer = Optimizer::with_config(stats, config);
3903
3904        // Filter that could be pushed
3905        let plan = RirNode::Filter {
3906            input: Box::new(RirNode::Filter {
3907                input: Box::new(RirNode::Scan { rel: RelId(1) }),
3908                predicate: Expr::Compare {
3909                    left: Box::new(Expr::Column(0)),
3910                    op: CompareOp::Gt,
3911                    right: Box::new(Expr::Const(ConstValue::I64(0))),
3912                },
3913            }),
3914            predicate: Expr::Compare {
3915                left: Box::new(Expr::Column(0)),
3916                op: CompareOp::Lt,
3917                right: Box::new(Expr::Const(ConstValue::I64(100))),
3918            },
3919        };
3920
3921        let optimized = optimizer.optimize(plan.clone());
3922
3923        // With pushdown disabled, structure should remain the same
3924        // (outer filter, inner filter, scan)
3925        if let RirNode::Filter { input, .. } = optimized {
3926            assert!(matches!(*input, RirNode::Filter { .. }));
3927        } else {
3928            panic!("Expected Filter node");
3929        }
3930    }
3931
3932    #[test]
3933    fn test_collect_columns() {
3934        let expr = Expr::And(vec![
3935            Expr::Compare {
3936                left: Box::new(Expr::Column(0)),
3937                op: CompareOp::Eq,
3938                right: Box::new(Expr::Column(2)),
3939            },
3940            Expr::Compare {
3941                left: Box::new(Expr::Column(1)),
3942                op: CompareOp::Gt,
3943                right: Box::new(Expr::Const(ConstValue::I64(0))),
3944            },
3945        ]);
3946
3947        let cols = Optimizer::collect_columns(&expr);
3948
3949        assert!(cols.contains(&0));
3950        assert!(cols.contains(&1));
3951        assert!(cols.contains(&2));
3952    }
3953
3954    #[test]
3955    fn test_flatten_and() {
3956        let nested = Expr::And(vec![
3957            Expr::And(vec![
3958                Expr::Compare {
3959                    left: Box::new(Expr::Column(0)),
3960                    op: CompareOp::Eq,
3961                    right: Box::new(Expr::Const(ConstValue::I64(1))),
3962                },
3963                Expr::Compare {
3964                    left: Box::new(Expr::Column(1)),
3965                    op: CompareOp::Eq,
3966                    right: Box::new(Expr::Const(ConstValue::I64(2))),
3967                },
3968            ]),
3969            Expr::Compare {
3970                left: Box::new(Expr::Column(2)),
3971                op: CompareOp::Eq,
3972                right: Box::new(Expr::Const(ConstValue::I64(3))),
3973            },
3974        ]);
3975
3976        let flattened = Optimizer::flatten_and(&nested);
3977
3978        assert_eq!(flattened.len(), 3);
3979    }
3980
3981    #[test]
3982    fn test_conjoin_single() {
3983        let single = vec![Expr::Compare {
3984            left: Box::new(Expr::Column(0)),
3985            op: CompareOp::Eq,
3986            right: Box::new(Expr::Const(ConstValue::I64(42))),
3987        }];
3988
3989        let result = Optimizer::conjoin(single);
3990
3991        assert!(matches!(result, Expr::Compare { .. }));
3992    }
3993
3994    #[test]
3995    fn test_conjoin_multiple() {
3996        let multiple = vec![
3997            Expr::Compare {
3998                left: Box::new(Expr::Column(0)),
3999                op: CompareOp::Eq,
4000                right: Box::new(Expr::Const(ConstValue::I64(1))),
4001            },
4002            Expr::Compare {
4003                left: Box::new(Expr::Column(1)),
4004                op: CompareOp::Eq,
4005                right: Box::new(Expr::Const(ConstValue::I64(2))),
4006            },
4007        ];
4008
4009        let result = Optimizer::conjoin(multiple);
4010
4011        assert!(matches!(result, Expr::And(_)));
4012    }
4013
4014    #[test]
4015    fn test_predicate_pushdown_with_schemas() {
4016        // Regression test: ensure predicate pushdown uses schemas for accurate width estimation.
4017        // Without schemas, the optimizer could incorrectly remap column indices.
4018        let stats = make_stats_manager();
4019        let mut optimizer = Optimizer::new(stats);
4020
4021        // Set up schemas: left has 3 columns, right has 3 columns
4022        let left_schema = Schema::new(vec![
4023            ("c0".to_string(), xlog_core::ScalarType::Symbol),
4024            ("c1".to_string(), xlog_core::ScalarType::Symbol),
4025            ("c2".to_string(), xlog_core::ScalarType::Symbol),
4026        ]);
4027        let right_schema = Schema::new(vec![
4028            ("c0".to_string(), xlog_core::ScalarType::Symbol),
4029            ("c1".to_string(), xlog_core::ScalarType::Symbol),
4030            ("c2".to_string(), xlog_core::ScalarType::U32),
4031        ]);
4032
4033        let mut schemas = HashMap::new();
4034        schemas.insert(RelId(1), left_schema);
4035        schemas.insert(RelId(2), right_schema);
4036        optimizer.set_schemas(schemas);
4037
4038        // Filter on Column(5) which is in the right side (left_width=3, so column 5-3=2 in right)
4039        let plan = RirNode::Filter {
4040            input: Box::new(RirNode::Join {
4041                left: Box::new(RirNode::Scan { rel: RelId(1) }),
4042                right: Box::new(RirNode::Scan { rel: RelId(2) }),
4043                left_keys: vec![0],
4044                right_keys: vec![0],
4045                join_type: JoinType::Inner,
4046            }),
4047            predicate: Expr::Compare {
4048                left: Box::new(Expr::Column(5)), // Right side column (index 5 = 3 + 2)
4049                op: CompareOp::Ge,
4050                right: Box::new(Expr::Const(ConstValue::U32(4))),
4051            },
4052        };
4053
4054        let optimized = optimizer.optimize(plan);
4055
4056        // Filter should be pushed into right side of join with Column(2) (remapped from 5-3=2)
4057        if let RirNode::Join { right, .. } = optimized {
4058            if let RirNode::Filter { predicate, .. } = *right {
4059                if let Expr::Compare { left, .. } = predicate {
4060                    if let Expr::Column(idx) = *left {
4061                        assert_eq!(
4062                            idx, 2,
4063                            "Column should be remapped to 2 (5 - left_width(3) = 2)"
4064                        );
4065                    } else {
4066                        panic!("Expected Column expression");
4067                    }
4068                } else {
4069                    panic!("Expected Compare predicate");
4070                }
4071            } else {
4072                panic!("Expected Filter on right side of join");
4073            }
4074        } else {
4075            panic!("Expected Join node");
4076        }
4077    }
4078
4079    /// Optimizer fallback arms for `MultiWayJoin`.
4080    ///
4081    /// The promoter runs after `Optimizer::optimize` in `Compiler`, so
4082    /// these arms are unreachable in production. They exist for compile
4083    /// safety and to pin the documented semantics: `optimize` returns
4084    /// the node unchanged, `estimate_width` reports the head arity from
4085    /// `output_columns`, `estimate_cost` is the sum of input costs, and
4086    /// `find_column_relation` returns `None` under the optimizer fallback.
4087    ///
4088    /// Shape-agnostic coverage extends each test below to also exercise a
4089    /// synthesized four-input `MultiWayJoin` via [`build_4input_multiway`].
4090    /// This pins shape-agnosticism: the arms must NOT hard-code
4091    /// `inputs.len() == 3` or `output_columns.len() == 3`. The four-way
4092    /// promoter path produces real four-input bodies; these tests are the
4093    /// load-bearing guard against silent regression.
4094    fn build_canonical_triangle_multiway() -> RirNode {
4095        let scan_xy = RirNode::Scan { rel: RelId(1) };
4096        let scan_yz = RirNode::Scan { rel: RelId(2) };
4097        let scan_xz = RirNode::Scan { rel: RelId(3) };
4098        let inner_join = RirNode::Join {
4099            left: Box::new(scan_xy.clone()),
4100            right: Box::new(scan_yz.clone()),
4101            left_keys: vec![1],
4102            right_keys: vec![0],
4103            join_type: JoinType::Inner,
4104        };
4105        let outer_join = RirNode::Join {
4106            left: Box::new(inner_join),
4107            right: Box::new(scan_xz.clone()),
4108            left_keys: vec![0, 3],
4109            right_keys: vec![0, 1],
4110            join_type: JoinType::Inner,
4111        };
4112        let fallback = RirNode::Project {
4113            input: Box::new(outer_join),
4114            columns: vec![
4115                ProjectExpr::Column(0),
4116                ProjectExpr::Column(1),
4117                ProjectExpr::Column(3),
4118            ],
4119        };
4120        RirNode::MultiWayJoin {
4121            inputs: vec![scan_xy, scan_yz, scan_xz],
4122            slot_vars: vec![
4123                vec![Some(0), Some(1)],
4124                vec![Some(1), Some(2)],
4125                vec![Some(0), Some(2)],
4126            ],
4127            output_columns: vec![
4128                ProjectExpr::Column(0),
4129                ProjectExpr::Column(1),
4130                ProjectExpr::Column(3),
4131            ],
4132            fallback: Box::new(fallback),
4133            plan: None,
4134            var_order: None,
4135        }
4136    }
4137
4138    /// Synthesized four-input `MultiWayJoin` for shape-agnosticism testing.
4139    /// The original promoter shape is triangle-only, so this shape never
4140    /// reaches `Optimizer` through the production pipeline; the tests below
4141    /// exercise the optimizer arms directly.
4142    ///
4143    /// Inputs reuse `RelId(1, 2, 3, 1)` — RelId(1) repeats — so the
4144    /// stats manager registered in `make_stats_manager` covers all
4145    /// four scans. Cost floor is `2*10_000 + 5_000 + 1_000 = 26_000`.
4146    fn build_4input_multiway() -> RirNode {
4147        let scans = [RelId(1), RelId(2), RelId(3), RelId(1)]
4148            .map(|rel| RirNode::Scan { rel })
4149            .to_vec();
4150        // 4-cycle slot_vars [[A,B],[B,C],[C,D],[A,D]].
4151        let slot_vars = vec![
4152            vec![Some(0u32), Some(1)],
4153            vec![Some(1u32), Some(2)],
4154            vec![Some(2u32), Some(3)],
4155            vec![Some(0u32), Some(3)],
4156        ];
4157        // 4-arity head projection (no real semantic meaning — the
4158        // synthesized fallback is a stub).
4159        let output_columns = vec![
4160            ProjectExpr::Column(0),
4161            ProjectExpr::Column(1),
4162            ProjectExpr::Column(2),
4163            ProjectExpr::Column(3),
4164        ];
4165        // Stub fallback: the optimizer arms do not execute fallback,
4166        // so any RirNode is fine. Use Unit to keep the fixture small.
4167        let fallback = RirNode::Unit;
4168        RirNode::MultiWayJoin {
4169            inputs: scans,
4170            slot_vars,
4171            output_columns,
4172            fallback: Box::new(fallback),
4173            plan: None,
4174            var_order: None,
4175        }
4176    }
4177
4178    #[test]
4179    fn optimize_returns_multiway_unchanged() {
4180        let optimizer = Optimizer::new(make_stats_manager());
4181        for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4182            let optimized = optimizer.optimize(node.clone());
4183            match (&node, &optimized) {
4184                (
4185                    RirNode::MultiWayJoin {
4186                        inputs: a_in,
4187                        output_columns: a_out,
4188                        ..
4189                    },
4190                    RirNode::MultiWayJoin {
4191                        inputs: b_in,
4192                        output_columns: b_out,
4193                        ..
4194                    },
4195                ) => {
4196                    assert_eq!(a_in.len(), b_in.len());
4197                    assert_eq!(a_out.len(), b_out.len());
4198                }
4199                _ => panic!("optimize() must return a MultiWayJoin"),
4200            }
4201        }
4202    }
4203
4204    #[test]
4205    fn estimate_width_uses_output_columns_arity() {
4206        let optimizer = Optimizer::new(make_stats_manager());
4207        // Canonical triangle: 3 head columns.
4208        assert_eq!(
4209            optimizer.estimate_width(&build_canonical_triangle_multiway()),
4210            3
4211        );
4212        // 4-input synthesized: 4 head columns. Locks shape-
4213        // agnosticism — the arm must use output_columns.len(),
4214        // not a hard-coded 3.
4215        assert_eq!(optimizer.estimate_width(&build_4input_multiway()), 4);
4216    }
4217
4218    #[test]
4219    fn estimate_cost_sums_input_costs() {
4220        let optimizer = Optimizer::new(make_stats_manager());
4221
4222        // Canonical triangle: rels 1, 2, 3 with cardinalities
4223        // 10_000 + 5_000 + 1_000 = 16_000.
4224        let cost_tri = optimizer.estimate_cost(&build_canonical_triangle_multiway());
4225        assert!(
4226            cost_tri.rows >= 16_000,
4227            "expected cost.rows >= 16000, got {}",
4228            cost_tri.rows
4229        );
4230
4231        // 4-input synthesized: rels 1, 2, 3, 1 → 2*10_000 + 5_000 +
4232        // 1_000 = 26_000. The arm sums all four inputs; cost grows.
4233        // Locks shape-agnosticism — the arm must walk every entry
4234        // in `inputs`, not a hard-coded 3.
4235        let cost_4 = optimizer.estimate_cost(&build_4input_multiway());
4236        assert!(
4237            cost_4.rows >= 26_000,
4238            "expected 4-input cost.rows >= 26000, got {}",
4239            cost_4.rows
4240        );
4241        assert!(
4242            cost_4.rows > cost_tri.rows,
4243            "4-input cost ({}) must exceed triangle cost ({})",
4244            cost_4.rows,
4245            cost_tri.rows
4246        );
4247    }
4248
4249    #[test]
4250    fn find_column_relation_returns_none_for_multiway() {
4251        let optimizer = Optimizer::new(make_stats_manager());
4252        // Optimizer fallback guardrail: no column-to-input mapping is exposed
4253        // here. Half-mapped is more dangerous than None. The arm must return
4254        // None regardless of arity; the synthesized four-input shape catches a
4255        // future "let's just return inputs[col_idx % len]" patch.
4256        for node in [build_canonical_triangle_multiway(), build_4input_multiway()] {
4257            for col in 0..node.referenced_relations().len() {
4258                assert!(
4259                    optimizer.find_column_relation(&node, col).is_none(),
4260                    "find_column_relation must return None for any \
4261                     MultiWayJoin column (col={})",
4262                    col,
4263                );
4264            }
4265        }
4266    }
4267}