Skip to main content

xlog_logic/
lower.rs

1//! Lowering from AST to IR
2//!
3//! This module transforms Datalog programs (AST) into the Relational IR (RIR)
4//! representation for execution. The lowering process:
5//!
6//! 1. Infers schemas from facts and predicate declarations
7//! 2. Tracks variable positions across atoms for join key computation
8//! 3. Builds left-deep join trees for multi-atom rule bodies
9//! 4. Handles negation via set difference (Diff) nodes
10//! 5. Wraps recursive predicates in Fixpoint nodes
11//! 6. Projects to match head variables
12
13use std::collections::{HashMap, HashSet};
14
15use xlog_core::{symbol, AggOp as CoreAggOp, RelId, Result, ScalarType, Schema, XlogError};
16use xlog_ir::{
17    CompareOp, CompiledRule, ConstValue, ExecutionPlan, Expr, JoinType, PlanBuilder, ProjectExpr,
18    RirMeta, RirNode, Scc, Stratum as IrStratum,
19};
20
21use crate::ast::{
22    AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Comparison, IsExpr, LearnableRule, PredColumn,
23    Program, Rule, Term, TypeRef,
24};
25use crate::stratify::{build_dependency_graph, find_sccs_for_lowering, DepType};
26
27struct JoinPlan<'a> {
28    node: RirNode,
29    leaf_order: Vec<&'a Atom>,
30    leaf_order_idx: Vec<usize>,
31    var_pos: HashMap<String, usize>,
32    width: usize,
33    est_rows: f64,
34    total_cost: f64,
35}
36
37fn pred_columns_for_decl(pred_decl: &crate::ast::PredDecl) -> Vec<PredColumn> {
38    if pred_decl.columns.is_empty() {
39        pred_decl
40            .types
41            .iter()
42            .cloned()
43            .map(|typ| PredColumn { name: None, typ })
44            .collect()
45    } else {
46        pred_decl.columns.clone()
47    }
48}
49
50fn resolve_pred_column_type(
51    predicate: &str,
52    index: usize,
53    typ: &TypeRef,
54    domains: &HashMap<String, ScalarType>,
55) -> Result<ScalarType> {
56    match typ {
57        TypeRef::Scalar(ty) => Ok(*ty),
58        TypeRef::Domain(name) => domains.get(name).copied().ok_or_else(|| {
59            XlogError::Compilation(format!(
60                "v0.8.5 unknown domain alias '{}' in predicate '{}' column {}",
61                name, predicate, index
62            ))
63        }),
64        TypeRef::List(_) | TypeRef::Term | TypeRef::Compound | TypeRef::PredRef => {
65            Ok(ScalarType::U64)
66        }
67    }
68}
69
70fn validate_lowerable_terms(program: &Program) -> Result<()> {
71    for rule in &program.rules {
72        validate_atom_terms(&rule.head, "rule head")?;
73        for lit in &rule.body {
74            match lit {
75                BodyLiteral::Positive(atom) => validate_atom_terms(atom, "positive body atom")?,
76                BodyLiteral::Negated(atom) => validate_atom_terms(atom, "negated body atom")?,
77                BodyLiteral::Epistemic(_) => {}
78                BodyLiteral::Comparison(cmp) => {
79                    validate_term_lowerable(&cmp.left, "comparison left operand")?;
80                    validate_term_lowerable(&cmp.right, "comparison right operand")?;
81                }
82                BodyLiteral::IsExpr(_) => {}
83                BodyLiteral::Univ(_) => {
84                    return Err(XlogError::Compilation(
85                        "v0.8.5 meta error: univ literal was not normalized before lowering"
86                            .to_string(),
87                    ));
88                }
89            }
90        }
91    }
92    for constraint in &program.constraints {
93        for lit in &constraint.body {
94            match lit {
95                BodyLiteral::Positive(atom) => validate_atom_terms(atom, "constraint body atom")?,
96                BodyLiteral::Negated(atom) => {
97                    validate_atom_terms(atom, "constraint negated body atom")?
98                }
99                BodyLiteral::Epistemic(_) => {}
100                BodyLiteral::Comparison(cmp) => {
101                    validate_term_lowerable(&cmp.left, "constraint comparison left operand")?;
102                    validate_term_lowerable(&cmp.right, "constraint comparison right operand")?;
103                }
104                BodyLiteral::IsExpr(_) => {}
105                BodyLiteral::Univ(_) => {
106                    return Err(XlogError::Compilation(
107                        "v0.8.5 meta error: univ literal was not normalized before lowering"
108                            .to_string(),
109                    ));
110                }
111            }
112        }
113    }
114    for query in &program.queries {
115        validate_atom_terms(&query.atom, "query atom")?;
116    }
117    for pf in &program.prob_facts {
118        validate_atom_terms(&pf.atom, "probabilistic fact")?;
119    }
120    for ad in &program.annotated_disjunctions {
121        for choice in &ad.choices {
122            validate_atom_terms(&choice.atom, "annotated disjunction choice")?;
123        }
124    }
125    for evidence in &program.evidence {
126        validate_atom_terms(&evidence.atom, "evidence atom")?;
127    }
128    for query in &program.prob_queries {
129        validate_atom_terms(&query.atom, "probabilistic query")?;
130    }
131    for neural in &program.neural_predicates {
132        validate_atom_terms(&neural.predicate, "neural predicate")?;
133    }
134    for learnable in &program.learnable_rules {
135        validate_atom_terms(&learnable.head, "learnable rule head")?;
136        for lit in &learnable.body {
137            if let BodyLiteral::Positive(atom) = lit {
138                validate_atom_terms(atom, "learnable rule body")?;
139            }
140        }
141    }
142    Ok(())
143}
144
145fn validate_atom_terms(atom: &Atom, context: &str) -> Result<()> {
146    for term in &atom.terms {
147        validate_term_lowerable(term, context)?;
148    }
149    Ok(())
150}
151
152fn validate_term_lowerable(term: &Term, context: &str) -> Result<()> {
153    match term {
154        Term::List(_) => Err(term_not_lowerable_error(context, "list")),
155        Term::Cons { .. } => Err(term_not_lowerable_error(context, "cons")),
156        Term::Compound { .. } => Err(term_not_lowerable_error(context, "compound")),
157        Term::PredRef(_) => Err(term_not_lowerable_error(context, "predref")),
158        Term::Variable(_)
159        | Term::Anonymous
160        | Term::Integer(_)
161        | Term::Float(_)
162        | Term::String(_)
163        | Term::Symbol(_)
164        | Term::Aggregate(_) => Ok(()),
165    }
166}
167
168fn term_not_lowerable_error(context: &str, kind: &str) -> XlogError {
169    XlogError::Compilation(format!(
170        "term form '{}' in {} is parsed but not lowerable by this execution path",
171        kind, context
172    ))
173}
174
175fn term_kind_for_lowering_error(term: &Term) -> &'static str {
176    match term {
177        Term::List(_) => "list",
178        Term::Cons { .. } => "cons",
179        Term::Compound { .. } => "compound",
180        Term::PredRef(_) => "predref",
181        Term::Variable(_)
182        | Term::Anonymous
183        | Term::Integer(_)
184        | Term::Float(_)
185        | Term::String(_)
186        | Term::Symbol(_)
187        | Term::Aggregate(_) => "term",
188    }
189}
190
191/// Lowerer transforms AST programs into RIR execution plans.
192pub struct Lowerer {
193    /// Inferred or declared schemas for each predicate
194    schemas: HashMap<String, Schema>,
195    /// Stratification result (predicates grouped by strata)
196    strata: Vec<Vec<String>>,
197    /// Estimated cardinality per predicate (for join ordering)
198    est_cardinality: HashMap<String, u64>,
199    /// Optional cardinality hints per predicate (e.g., from runtime statistics).
200    cardinality_hints: HashMap<String, u64>,
201    /// Next available relation ID
202    next_rel_id: u32,
203    /// Mapping from predicate names to relation IDs
204    rel_ids: HashMap<String, RelId>,
205    /// SCCs for the program (from stratification)
206    sccs: Vec<Scc>,
207    /// Maximum active rules for TensorMaskedJoin (default 32)
208    max_active_rules: usize,
209}
210
211impl Default for Lowerer {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl Lowerer {
218    /// Create a new lowerer instance
219    pub fn new() -> Self {
220        Self {
221            schemas: HashMap::new(),
222            strata: Vec::new(),
223            est_cardinality: HashMap::new(),
224            cardinality_hints: HashMap::new(),
225            next_rel_id: 0,
226            rel_ids: HashMap::new(),
227            sccs: Vec::new(),
228            max_active_rules: 32,
229        }
230    }
231
232    /// Set the maximum active rules for TensorMaskedJoin.
233    pub fn set_max_active_rules(&mut self, max: usize) {
234        self.max_active_rules = max;
235    }
236
237    /// Set the stratification result for ordering
238    pub(crate) fn set_strata(&mut self, strata: Vec<Vec<String>>) {
239        self.strata = strata;
240    }
241
242    /// Set cardinality hints (typically sourced from runtime statistics snapshots).
243    ///
244    /// These hints are used by lowering-time join ordering when available.
245    pub(crate) fn set_cardinality_hints(&mut self, hints: HashMap<String, u64>) {
246        self.cardinality_hints = hints;
247    }
248
249    /// Get the mapping from predicate names to relation IDs
250    pub fn rel_ids(&self) -> &HashMap<String, RelId> {
251        &self.rel_ids
252    }
253
254    /// Get the inferred schemas for predicates
255    pub fn schemas(&self) -> &HashMap<String, Schema> {
256        &self.schemas
257    }
258
259    pub(crate) fn create_helper_relation(&mut self, schema: Schema) -> (String, RelId) {
260        let name = format!("__kclique_helper_{}", self.next_rel_id);
261        let rel_id = self.get_or_create_rel_id(&name);
262        self.schemas.insert(name.clone(), schema);
263        (name, rel_id)
264    }
265
266    /// Get or allocate a relation ID for a predicate
267    fn get_or_create_rel_id(&mut self, name: &str) -> RelId {
268        if let Some(&id) = self.rel_ids.get(name) {
269            id
270        } else {
271            let id = RelId(self.next_rel_id);
272            self.next_rel_id += 1;
273            self.rel_ids.insert(name.to_string(), id);
274            id
275        }
276    }
277
278    /// Infer schemas from facts and predicate declarations
279    fn infer_schemas(&mut self, program: &Program) -> Result<()> {
280        let domains: HashMap<String, ScalarType> = program
281            .domains
282            .iter()
283            .map(|domain| (domain.name.clone(), domain.typ))
284            .collect();
285
286        // First, use explicit predicate declarations
287        for pred_decl in &program.predicates {
288            let declared_columns = pred_columns_for_decl(pred_decl);
289            let columns: Vec<(String, ScalarType)> = declared_columns
290                .iter()
291                .enumerate()
292                .map(|(i, col)| {
293                    let name = col.name.clone().unwrap_or_else(|| format!("c{}", i));
294                    resolve_pred_column_type(&pred_decl.name, i, &col.typ, &domains)
295                        .map(|ty| (name, ty))
296                })
297                .collect::<Result<Vec<_>>>()?;
298            self.schemas
299                .insert(pred_decl.name.clone(), Schema::new(columns));
300        }
301
302        // Then, infer from facts (if no declaration exists)
303        for rule in program.facts() {
304            let pred = &rule.head.predicate;
305            if !self.schemas.contains_key(pred) {
306                let columns: Vec<(String, ScalarType)> = rule
307                    .head
308                    .terms
309                    .iter()
310                    .enumerate()
311                    .map(|(i, term)| {
312                        let ty = infer_term_type(term);
313                        (format!("c{}", i), ty)
314                    })
315                    .collect();
316                self.schemas.insert(pred.clone(), Schema::new(columns));
317            }
318        }
319
320        // Finally, infer from rule heads if we still don't have a schema
321        for rule in &program.rules {
322            let pred = &rule.head.predicate;
323            if !self.schemas.contains_key(pred) {
324                // Use default U64 type for variables
325                let columns: Vec<(String, ScalarType)> = rule
326                    .head
327                    .terms
328                    .iter()
329                    .enumerate()
330                    .map(|(i, term)| {
331                        let ty = match term {
332                            Term::Variable(name) => self
333                                .infer_head_term_type_from_body(rule, name)
334                                .unwrap_or_else(|| infer_term_type(term)),
335                            _ => infer_term_type(term),
336                        };
337                        (format!("c{}", i), ty)
338                    })
339                    .collect();
340                let schema = Schema::new(columns)
341                    .with_sort_labels(sort_labels_from_terms(&rule.head.terms))
342                    .expect("rule head sort labels match inferred schema arity");
343                self.schemas.insert(pred.clone(), schema);
344            }
345        }
346
347        // Also infer from rule bodies for EDB predicates that only appear in bodies
348        for rule in &program.rules {
349            for lit in &rule.body {
350                let atom = match lit {
351                    BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
352                    BodyLiteral::Epistemic(_)
353                    | BodyLiteral::Comparison(_)
354                    | BodyLiteral::IsExpr(_)
355                    | BodyLiteral::Univ(_) => continue,
356                };
357                let pred = &atom.predicate;
358                if self.schemas.contains_key(pred) {
359                    continue;
360                }
361                let columns: Vec<(String, ScalarType)> = atom
362                    .terms
363                    .iter()
364                    .enumerate()
365                    .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
366                    .collect();
367                let schema = Schema::new(columns)
368                    .with_sort_labels(sort_labels_from_terms(&atom.terms))
369                    .expect("body sort labels match inferred schema arity");
370                self.schemas.insert(pred.clone(), schema);
371            }
372        }
373
374        // Ensure schemas exist for probabilistic facts and annotated disjunctions
375        for pf in &program.prob_facts {
376            let pred = &pf.atom.predicate;
377            if self.schemas.contains_key(pred) {
378                continue;
379            }
380            let columns: Vec<(String, ScalarType)> = pf
381                .atom
382                .terms
383                .iter()
384                .enumerate()
385                .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
386                .collect();
387            self.schemas.insert(pred.clone(), Schema::new(columns));
388        }
389
390        for ad in &program.annotated_disjunctions {
391            for choice in &ad.choices {
392                let pred = &choice.atom.predicate;
393                if self.schemas.contains_key(pred) {
394                    continue;
395                }
396                let columns: Vec<(String, ScalarType)> = choice
397                    .atom
398                    .terms
399                    .iter()
400                    .enumerate()
401                    .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
402                    .collect();
403                self.schemas.insert(pred.clone(), Schema::new(columns));
404            }
405        }
406
407        Ok(())
408    }
409
410    fn infer_head_term_type_from_body(&self, rule: &Rule, var_name: &str) -> Option<ScalarType> {
411        for lit in &rule.body {
412            let atom = match lit {
413                BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
414                BodyLiteral::Epistemic(_)
415                | BodyLiteral::Comparison(_)
416                | BodyLiteral::IsExpr(_)
417                | BodyLiteral::Univ(_) => continue,
418            };
419            let schema = self.schemas.get(&atom.predicate)?;
420            for (idx, term) in atom.terms.iter().enumerate() {
421                if let Term::Variable(name) = term {
422                    if name == var_name {
423                        if let Some(ty) = schema.column_type(idx) {
424                            return Some(ty);
425                        }
426                    }
427                }
428            }
429        }
430        None
431    }
432
433    fn infer_cardinalities(&mut self, program: &Program) {
434        self.est_cardinality.clear();
435
436        let mut fact_counts: HashMap<String, u64> = HashMap::new();
437        for fact in program.facts() {
438            *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
439        }
440
441        for pred in self.schemas.keys() {
442            let est = self
443                .cardinality_hints
444                .get(pred)
445                .copied()
446                .or_else(|| fact_counts.get(pred).copied())
447                .unwrap_or(1000)
448                .max(1);
449            self.est_cardinality.insert(pred.clone(), est);
450        }
451    }
452
453    /// Build SCCs from the dependency graph
454    fn build_sccs(&mut self, program: &Program) {
455        let graph = build_dependency_graph(program);
456        let scc_groups = find_sccs_for_lowering(&graph);
457
458        self.sccs.clear();
459        for (id, predicates) in scc_groups.iter().enumerate() {
460            // An SCC is recursive if it has more than one predicate
461            // or if a single predicate depends on itself positively
462            let is_recursive = if predicates.len() > 1 {
463                true
464            } else {
465                let pred = &predicates[0];
466                graph
467                    .outgoing(pred)
468                    .iter()
469                    .any(|e| e.to == *pred && e.dep_type == DepType::Positive)
470            };
471
472            self.sccs.push(Scc {
473                id: id as u32,
474                predicates: predicates.clone(),
475                is_recursive,
476            });
477        }
478    }
479
480    /// Lower an entire program to an execution plan
481    pub fn lower_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
482        validate_lowerable_terms(program)?;
483        // Infer schemas
484        self.infer_schemas(program)?;
485        self.infer_cardinalities(program);
486
487        // Pre-allocate RelIds for declared predicates so schema-only programs
488        // can populate relation stores before any facts or executable rules
489        // mention those relations. This keeps ILP candidate generation and
490        // runtime relation upload aligned with declared schemas.
491        for pred_decl in &program.predicates {
492            self.get_or_create_rel_id(&pred_decl.name);
493        }
494
495        // Build SCCs
496        self.build_sccs(program);
497
498        // Build execution plan
499        let mut builder = PlanBuilder::new();
500
501        // Add SCCs to the builder
502        for scc in &self.sccs {
503            builder.add_scc(scc.clone());
504        }
505
506        // Build strata from our strata field
507        for (id, preds) in self.strata.iter().enumerate() {
508            // Find which SCCs belong to this stratum
509            let scc_ids: Vec<u32> = self
510                .sccs
511                .iter()
512                .filter(|scc| scc.predicates.iter().any(|p| preds.contains(p)))
513                .map(|scc| scc.id)
514                .collect();
515
516            if !scc_ids.is_empty() {
517                builder.add_stratum(IrStratum {
518                    id: id as u32,
519                    sccs: scc_ids,
520                });
521            }
522        }
523
524        // Lower each rule
525        let mut rules_by_pred: HashMap<String, Vec<&Rule>> = HashMap::new();
526        for rule in program.proper_rules() {
527            rules_by_pred
528                .entry(rule.head.predicate.clone())
529                .or_default()
530                .push(rule);
531        }
532
533        // Add facts as scan-only rules
534        for fact in program.facts() {
535            let pred = &fact.head.predicate;
536            let scc_id = self.find_scc_for_predicate(pred);
537            let rel_id = self.get_or_create_rel_id(pred);
538
539            let body = RirNode::Scan { rel: rel_id };
540            let meta = self.create_meta_for_predicate(pred);
541
542            builder.add_rule(
543                scc_id,
544                CompiledRule {
545                    head: pred.clone(),
546                    body,
547                    meta,
548                },
549            );
550        }
551
552        // Lower proper rules
553        for (pred, rules) in &rules_by_pred {
554            let scc_id = self.find_scc_for_predicate(pred);
555
556            for rule in rules {
557                let body = self.lower_rule(rule)?;
558                let meta = self.create_meta_for_predicate(pred);
559
560                builder.add_rule(
561                    scc_id,
562                    CompiledRule {
563                        head: pred.clone(),
564                        body,
565                        meta,
566                    },
567                );
568            }
569        }
570
571        // Lower learnable rules into tensor-masked joins.
572        // Pre-allocate RelIds for ALL learnable predicates (heads + bodies)
573        // so every lower_learnable_rule snapshot is complete.
574        for learnable in &program.learnable_rules {
575            self.get_or_create_rel_id(&learnable.head.predicate);
576            for lit in &learnable.body {
577                if let BodyLiteral::Positive(atom) = lit {
578                    self.get_or_create_rel_id(&atom.predicate);
579                }
580            }
581        }
582        for learnable in &program.learnable_rules {
583            let head_pred = &learnable.head.predicate;
584            let scc_id = self.find_scc_for_predicate(head_pred);
585            let body = self.lower_learnable_rule(learnable)?;
586            let meta = self.create_meta_for_predicate(head_pred);
587            builder.add_rule(
588                scc_id,
589                CompiledRule {
590                    head: head_pred.clone(),
591                    body,
592                    meta,
593                },
594            );
595        }
596
597        let mut plan = builder.build();
598        // Record relation arities for downstream generic multiway shape
599        // promoters that size Scan leaves from these values.
600        // One pre-pass over the AST covers every predicate the lowerer
601        // assigned a RelId: rule heads, positive/negated body atoms,
602        // and facts.
603        for rule in program.proper_rules() {
604            if let Some(&id) = self.rel_ids.get(&rule.head.predicate) {
605                plan.rel_arities.insert(id, rule.head.terms.len());
606            }
607            for lit in &rule.body {
608                let atom = match lit {
609                    BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a,
610                    _ => continue,
611                };
612                if let Some(&id) = self.rel_ids.get(&atom.predicate) {
613                    plan.rel_arities.insert(id, atom.terms.len());
614                }
615            }
616        }
617        for fact in program.facts() {
618            if let Some(&id) = self.rel_ids.get(&fact.head.predicate) {
619                plan.rel_arities.insert(id, fact.head.terms.len());
620            }
621        }
622        Ok(plan)
623    }
624
625    /// Find the SCC ID for a predicate
626    fn find_scc_for_predicate(&self, pred: &str) -> u32 {
627        self.sccs
628            .iter()
629            .find(|scc| scc.predicates.contains(&pred.to_string()))
630            .map(|scc| scc.id)
631            .unwrap_or(0)
632    }
633
634    /// Create metadata for a predicate
635    fn create_meta_for_predicate(&self, pred: &str) -> RirMeta {
636        let schema = self
637            .schemas
638            .get(pred)
639            .cloned()
640            .unwrap_or_else(|| Schema::new(vec![]));
641        RirMeta::with_schema(schema)
642    }
643
644    /// Lower a learnable rule template into a TensorMaskedJoin node.
645    /// Validates that the body has exactly two positive atoms.
646    /// Sorts rel_index by RelId for deterministic tensor dimension mapping.
647    /// Uses get_or_create_rel_id for heads so head-only predicates are handled.
648    fn lower_learnable_rule(&mut self, rule: &LearnableRule) -> Result<RirNode> {
649        // Validate body shape before indexing fixed body positions.
650        if rule.body.len() != 2 {
651            return Err(XlogError::Compilation(format!(
652                "learnable rule '{}' requires exactly 2 body literals, got {}",
653                rule.mask_name,
654                rule.body.len()
655            )));
656        }
657        for (idx, lit) in rule.body.iter().enumerate() {
658            match lit {
659                BodyLiteral::Positive(_) => {}
660                _ => {
661                    return Err(XlogError::Compilation(format!(
662                        "learnable rule '{}' body[{}]: only positive atoms allowed",
663                        rule.mask_name, idx
664                    )));
665                }
666            }
667        }
668
669        // Sort by RelId for deterministic tensor dimension mapping.
670        let mut rel_index: Vec<(RelId, String)> = self
671            .rel_ids()
672            .iter()
673            .map(|(name, id)| (*id, name.clone()))
674            .collect();
675        rel_index.sort_by_key(|(id, _)| id.0);
676        let schema_size = rel_index.len();
677
678        let (left_keys, right_keys) =
679            self.extract_template_join_keys(&rule.body[0], &rule.body[1])?;
680
681        let head_rel_name = rule.head.predicate.clone();
682        // Allocate lazily because head-only predicates may not have a RelId yet.
683        let head_rel_id = self.get_or_create_rel_id(&head_rel_name);
684
685        // Compute head projection: map head variables to join result columns.
686        // Join result layout: [left_col_0..left_col_n, right_col_0..right_col_m].
687        let left_atom = rule.body[0].atom().unwrap();
688        let right_atom = rule.body[1].atom().unwrap();
689        let left_arity = left_atom.terms.len();
690
691        // Build variable -> first-occurrence column mapping over joined result
692        let mut var_to_col: HashMap<String, usize> = HashMap::new();
693        for (i, term) in left_atom.terms.iter().enumerate() {
694            if let Some(name) = term.variable_name() {
695                var_to_col.entry(name.to_string()).or_insert(i);
696            }
697        }
698        for (i, term) in right_atom.terms.iter().enumerate() {
699            if let Some(name) = term.variable_name() {
700                var_to_col.entry(name.to_string()).or_insert(left_arity + i);
701            }
702        }
703
704        let mut head_projection: Vec<usize> = Vec::new();
705        for term in &rule.head.terms {
706            if let Some(name) = term.variable_name() {
707                let col = var_to_col.get(name).ok_or_else(|| {
708                    XlogError::Compilation(format!(
709                        "Learnable rule head variable '{}' not found in body atoms \
710                         ({}, {}). All head variables must appear in the body.",
711                        name, left_atom.predicate, right_atom.predicate,
712                    ))
713                })?;
714                head_projection.push(*col);
715            } else {
716                return Err(XlogError::Compilation(format!(
717                    "Learnable rule head must contain only variables, \
718                     found constant {:?} in head of '{}'",
719                    term, head_rel_name,
720                )));
721            }
722        }
723
724        // Infer schema for head predicate from the learnable rule if not already set.
725        // The head's column types come from the projected join columns.
726        if !self.schemas.contains_key(&head_rel_name) {
727            let columns: Vec<(String, ScalarType)> = head_projection
728                .iter()
729                .enumerate()
730                .map(|(i, &col)| {
731                    // Determine the type from left or right atom's schema
732                    let ty = if col < left_arity {
733                        self.schemas
734                            .get(&left_atom.predicate)
735                            .and_then(|s| s.column_type(col))
736                            .unwrap_or(ScalarType::U32)
737                    } else {
738                        self.schemas
739                            .get(&right_atom.predicate)
740                            .and_then(|s| s.column_type(col - left_arity))
741                            .unwrap_or(ScalarType::U32)
742                    };
743                    (format!("c{}", i), ty)
744                })
745                .collect();
746            self.schemas
747                .insert(head_rel_name.clone(), Schema::new(columns));
748        }
749
750        Ok(RirNode::TensorMaskedJoin {
751            mask_name: rule.mask_name.clone(),
752            schema_size,
753            left_keys,
754            right_keys,
755            rel_index,
756            head_rel_name,
757            head_rel_id,
758            max_active_rules: self.max_active_rules,
759            head_projection,
760        })
761    }
762
763    /// Extract join keys from two body literals' shared variables.
764    /// For `b1(X, Z), b2(Z, Y)`, the shared variable Z gives left_keys=[1], right_keys=[0].
765    fn extract_template_join_keys(
766        &self,
767        left: &BodyLiteral,
768        right: &BodyLiteral,
769    ) -> Result<(Vec<usize>, Vec<usize>)> {
770        let left_atom = left
771            .atom()
772            .ok_or_else(|| XlogError::Compilation("Learnable body[0] is not an atom".into()))?;
773        let right_atom = right
774            .atom()
775            .ok_or_else(|| XlogError::Compilation("Learnable body[1] is not an atom".into()))?;
776
777        let mut left_keys = Vec::new();
778        let mut right_keys = Vec::new();
779
780        for (li, lt) in left_atom.terms.iter().enumerate() {
781            if let Some(lname) = lt.variable_name() {
782                for (ri, rt) in right_atom.terms.iter().enumerate() {
783                    if let Some(rname) = rt.variable_name() {
784                        if lname == rname {
785                            left_keys.push(li);
786                            right_keys.push(ri);
787                        }
788                    }
789                }
790            }
791        }
792
793        Ok((left_keys, right_keys))
794    }
795
796    /// Lower a single rule to an RIR node
797    fn lower_rule(&mut self, rule: &Rule) -> Result<RirNode> {
798        if let Some(lit) = rule.body.iter().find_map(|lit| match lit {
799            BodyLiteral::Epistemic(lit) => Some(lit),
800            _ => None,
801        }) {
802            return Err(XlogError::UnsupportedEpistemicConstruct {
803                construct: "RIR lowering boundary".to_string(),
804                context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
805            });
806        }
807
808        // Split body literals.
809        let (positive_atoms, negated_atoms, comparisons, is_exprs) =
810            Self::split_body_literals(&rule.body);
811
812        // Allocate RelIds for all body predicates in source order so join planning
813        // does not influence identifier assignment.
814        for lit in &rule.body {
815            match lit {
816                BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => {
817                    self.get_or_create_rel_id(&atom.predicate);
818                }
819                BodyLiteral::Epistemic(_)
820                | BodyLiteral::Comparison(_)
821                | BodyLiteral::IsExpr(_)
822                | BodyLiteral::Univ(_) => {}
823            }
824        }
825
826        // Plan positive atoms (join tree shape + leaf order).
827        //
828        // Rules with no positive atoms are legal for nullary/ground heads in our
829        // probabilistic profiles (e.g. `q() :- not p().`). Lower them by seeding
830        // the body with a unit relation ({()}) and applying filters/negations.
831        let (positive_root, leaf_order) = if positive_atoms.is_empty() {
832            (RirNode::Unit, Vec::new())
833        } else {
834            self.plan_positive_atoms(&positive_atoms)?
835        };
836
837        // Build variable environment from the planned leaf order (matches join output layout:
838        // left subtree columns then right subtree columns).
839        let mut var_env = VariableEnv::new();
840        let mut current_col = 0;
841        for atom in &leaf_order {
842            let schema = self.schemas.get(&atom.predicate);
843            for (i, term) in atom.terms.iter().enumerate() {
844                if let Term::Variable(name) = term {
845                    if name == "_" {
846                        continue;
847                    }
848                    var_env.add_occurrence(name, atom.predicate.clone(), i, current_col + i);
849                    // Also record the type for this variable (first occurrence wins)
850                    if !var_env.types.contains_key(name) {
851                        let typ = schema
852                            .and_then(|s| s.column_type(i))
853                            .unwrap_or(ScalarType::I64); // Default to I64 for arithmetic
854                        var_env.types.insert(name.to_string(), typ);
855                    }
856                }
857            }
858            current_col += atom.terms.len();
859        }
860        var_env.total_cols = current_col;
861
862        // Lower the body starting from the planned positive join root.
863        let body_node = self.lower_body_parts(
864            positive_root,
865            &negated_atoms,
866            &comparisons,
867            &is_exprs,
868            &mut var_env,
869        )?;
870
871        if rule.has_aggregation() {
872            return self.lower_aggregate_rule(&rule.head, body_node, &var_env);
873        }
874
875        // Project to head terms (variables and constants).
876        let projection_exprs = self.compute_head_projection(&rule.head, &var_env)?;
877
878        if Self::is_identity_projection(&projection_exprs, var_env.column_count()) {
879            Ok(body_node)
880        } else {
881            Ok(RirNode::Project {
882                input: Box::new(body_node),
883                columns: projection_exprs,
884            })
885        }
886    }
887
888    fn split_body_literals(
889        body: &[BodyLiteral],
890    ) -> (Vec<&Atom>, Vec<&Atom>, Vec<&Comparison>, Vec<&IsExpr>) {
891        let mut positive_atoms: Vec<&Atom> = Vec::new();
892        let mut negated_atoms: Vec<&Atom> = Vec::new();
893        let mut comparisons: Vec<&Comparison> = Vec::new();
894        let mut is_exprs: Vec<&IsExpr> = Vec::new();
895
896        for lit in body {
897            match lit {
898                BodyLiteral::Positive(atom) => positive_atoms.push(atom),
899                BodyLiteral::Negated(atom) => negated_atoms.push(atom),
900                BodyLiteral::Epistemic(_) => {}
901                BodyLiteral::Comparison(cmp) => comparisons.push(cmp),
902                BodyLiteral::IsExpr(is_expr) => is_exprs.push(is_expr),
903                BodyLiteral::Univ(_) => {}
904            }
905        }
906
907        (positive_atoms, negated_atoms, comparisons, is_exprs)
908    }
909
910    fn atom_vars(atom: &Atom) -> std::collections::HashSet<String> {
911        atom.terms
912            .iter()
913            .flat_map(|t| t.variables().into_iter())
914            .filter(|name| *name != "_")
915            .map(ToOwned::to_owned)
916            .collect()
917    }
918
919    fn estimate_atom_rows(&self, atom: &Atom) -> f64 {
920        let base = self
921            .est_cardinality
922            .get(&atom.predicate)
923            .copied()
924            .unwrap_or(1000)
925            .max(1) as f64;
926
927        let const_count = atom
928            .terms
929            .iter()
930            .filter(|t| term_to_const_value(t).is_some())
931            .count();
932
933        // Equality constants are usually selective; use a conservative default.
934        let selectivity = 0.1_f64.powi(const_count as i32);
935        (base * selectivity).max(1.0)
936    }
937
938    fn build_cartesian_join(
939        &self,
940        left: RirNode,
941        right: RirNode,
942        left_width: usize,
943        right_width: usize,
944    ) -> RirNode {
945        // Implement cross join by appending a constant key column to both inputs and joining on it,
946        // then projecting away the constant columns.
947        let left_const_col =
948            ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
949        let right_const_col =
950            ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
951
952        let mut left_cols: Vec<ProjectExpr> = (0..left_width).map(ProjectExpr::Column).collect();
953        left_cols.push(left_const_col);
954        let left_aug = RirNode::Project {
955            input: Box::new(left),
956            columns: left_cols,
957        };
958
959        let mut right_cols: Vec<ProjectExpr> = (0..right_width).map(ProjectExpr::Column).collect();
960        right_cols.push(right_const_col);
961        let right_aug = RirNode::Project {
962            input: Box::new(right),
963            columns: right_cols,
964        };
965
966        let joined = RirNode::Join {
967            left: Box::new(left_aug),
968            right: Box::new(right_aug),
969            left_keys: vec![left_width],
970            right_keys: vec![right_width],
971            join_type: JoinType::Inner,
972        };
973
974        let mut keep: Vec<ProjectExpr> = Vec::with_capacity(left_width + right_width);
975        keep.extend((0..left_width).map(ProjectExpr::Column));
976        let right_start = left_width + 1;
977        keep.extend((right_start..right_start + right_width).map(ProjectExpr::Column));
978
979        RirNode::Project {
980            input: Box::new(joined),
981            columns: keep,
982        }
983    }
984
985    fn make_leaf_plan<'a>(&mut self, atom: &'a Atom, orig_idx: usize) -> Result<JoinPlan<'a>> {
986        let rel_id = self.get_or_create_rel_id(&atom.predicate);
987        let scan = RirNode::Scan { rel: rel_id };
988        let node = self.apply_constant_filters(scan, atom, 0)?;
989
990        let mut var_pos: HashMap<String, usize> = HashMap::new();
991        for (i, term) in atom.terms.iter().enumerate() {
992            if let Term::Variable(name) = term {
993                if name != "_" {
994                    var_pos.entry(name.clone()).or_insert(i);
995                }
996            }
997        }
998
999        let est_rows = self.estimate_atom_rows(atom);
1000        Ok(JoinPlan {
1001            node,
1002            leaf_order: vec![atom],
1003            leaf_order_idx: vec![orig_idx],
1004            var_pos,
1005            width: atom.terms.len(),
1006            est_rows,
1007            total_cost: est_rows,
1008        })
1009    }
1010
1011    fn join_plans<'a>(&self, left: &JoinPlan<'a>, right: &JoinPlan<'a>) -> JoinPlan<'a> {
1012        let shared_vars: Vec<&String> = left
1013            .var_pos
1014            .keys()
1015            .filter(|v| right.var_pos.contains_key(*v))
1016            .collect();
1017
1018        let node = if shared_vars.is_empty() {
1019            self.build_cartesian_join(
1020                left.node.clone(),
1021                right.node.clone(),
1022                left.width,
1023                right.width,
1024            )
1025        } else {
1026            let mut key_pairs: Vec<(usize, usize)> = shared_vars
1027                .iter()
1028                .filter_map(|v| {
1029                    Some((
1030                        left.var_pos.get(*v).copied()?,
1031                        right.var_pos.get(*v).copied()?,
1032                    ))
1033                })
1034                .collect();
1035            key_pairs.sort_unstable();
1036
1037            let (left_keys, right_keys): (Vec<usize>, Vec<usize>) = key_pairs.into_iter().unzip();
1038
1039            RirNode::Join {
1040                left: Box::new(left.node.clone()),
1041                right: Box::new(right.node.clone()),
1042                left_keys,
1043                right_keys,
1044                join_type: JoinType::Inner,
1045            }
1046        };
1047
1048        let mut leaf_order = left.leaf_order.clone();
1049        leaf_order.extend(right.leaf_order.iter().copied());
1050
1051        let mut leaf_order_idx = left.leaf_order_idx.clone();
1052        leaf_order_idx.extend_from_slice(&right.leaf_order_idx);
1053
1054        let mut var_pos = left.var_pos.clone();
1055        for (var, pos) in &right.var_pos {
1056            var_pos.entry(var.clone()).or_insert(left.width + *pos);
1057        }
1058
1059        let shared = shared_vars.len();
1060        let mut selectivity = if shared == 0 {
1061            1.0
1062        } else {
1063            0.1_f64.powi(shared as i32)
1064        };
1065        if shared == 0 {
1066            // Penalize cartesian joins strongly.
1067            selectivity *= 1.0e6;
1068        }
1069
1070        let output_rows = (left.est_rows * right.est_rows * selectivity).max(1.0);
1071
1072        // Hash join cost is sensitive to which side is build (right) and probe (left).
1073        let build_cost = right.est_rows;
1074        let probe_cost = left.est_rows * 0.5;
1075        let total_cost = left.total_cost + right.total_cost + build_cost + probe_cost + output_rows;
1076
1077        JoinPlan {
1078            node,
1079            leaf_order,
1080            leaf_order_idx,
1081            var_pos,
1082            width: left.width + right.width,
1083            est_rows: output_rows,
1084            total_cost,
1085        }
1086    }
1087
1088    fn plan_positive_atoms_bushy<'a>(
1089        &mut self,
1090        atoms: &[&'a Atom],
1091    ) -> Result<(RirNode, Vec<&'a Atom>)> {
1092        let n = atoms.len();
1093        if n == 0 {
1094            return Err(XlogError::Compilation("Empty rule body".to_string()));
1095        }
1096        if n == 1 {
1097            let plan = self.make_leaf_plan(atoms[0], 0)?;
1098            return Ok((plan.node, plan.leaf_order));
1099        }
1100
1101        let size = 1usize << n;
1102        let mut best: Vec<Option<JoinPlan<'a>>> = (0..size).map(|_| None).collect();
1103
1104        for (i, atom) in atoms.iter().enumerate() {
1105            best[1usize << i] = Some(self.make_leaf_plan(atom, i)?);
1106        }
1107
1108        fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1109            for (ai, bi) in a.iter().zip(b.iter()) {
1110                if ai != bi {
1111                    return ai < bi;
1112                }
1113            }
1114            a.len() < b.len()
1115        }
1116
1117        for mask in 1..size {
1118            if mask.count_ones() <= 1 {
1119                continue;
1120            }
1121
1122            let mut best_for_mask: Option<JoinPlan<'a>> = None;
1123
1124            let mut sub = (mask - 1) & mask;
1125            while sub > 0 {
1126                let a = sub;
1127                let b = mask ^ a;
1128                if b == 0 {
1129                    sub = (sub - 1) & mask;
1130                    continue;
1131                }
1132
1133                let (Some(plan_a), Some(plan_b)) = (&best[a], &best[b]) else {
1134                    sub = (sub - 1) & mask;
1135                    continue;
1136                };
1137
1138                // Consider both orientations: A ⋈ B and B ⋈ A.
1139                for (left, right) in [(plan_a, plan_b), (plan_b, plan_a)] {
1140                    let cand = self.join_plans(left, right);
1141                    let replace = match &best_for_mask {
1142                        None => true,
1143                        Some(current) => {
1144                            if cand.total_cost < current.total_cost {
1145                                true
1146                            } else if (cand.total_cost - current.total_cost).abs() < 1e-9 {
1147                                lex_lt(&cand.leaf_order_idx, &current.leaf_order_idx)
1148                            } else {
1149                                false
1150                            }
1151                        }
1152                    };
1153
1154                    if replace {
1155                        best_for_mask = Some(cand);
1156                    }
1157                }
1158
1159                sub = (sub - 1) & mask;
1160            }
1161
1162            best[mask] = best_for_mask;
1163        }
1164
1165        let full_mask = size - 1;
1166        if let Some(plan) = best[full_mask].take() {
1167            return Ok((plan.node, plan.leaf_order));
1168        }
1169
1170        // Should be unreachable, but fall back to greedy ordering.
1171        let ordered = self.order_positive_atoms_greedy(atoms);
1172        let mut dummy_env = VariableEnv::new();
1173        let node = self.build_join_tree(&ordered, &mut dummy_env)?;
1174        Ok((node, ordered))
1175    }
1176
1177    fn plan_positive_atoms<'a>(&mut self, atoms: &[&'a Atom]) -> Result<(RirNode, Vec<&'a Atom>)> {
1178        if atoms.len() <= 1 {
1179            if atoms.is_empty() {
1180                return Err(XlogError::Compilation("Empty rule body".to_string()));
1181            }
1182            let plan = self.make_leaf_plan(atoms[0], 0)?;
1183            return Ok((plan.node, plan.leaf_order));
1184        }
1185
1186        const MAX_BUSHY_DP_ATOMS: usize = 10;
1187        if atoms.len() <= MAX_BUSHY_DP_ATOMS {
1188            return self.plan_positive_atoms_bushy(atoms);
1189        }
1190
1191        // Greedy bushy join planning for large rules (scales beyond exponential DP).
1192        self.plan_positive_atoms_bushy_greedy(atoms)
1193    }
1194
1195    fn plan_positive_atoms_bushy_greedy<'a>(
1196        &mut self,
1197        atoms: &[&'a Atom],
1198    ) -> Result<(RirNode, Vec<&'a Atom>)> {
1199        if atoms.is_empty() {
1200            return Err(XlogError::Compilation("Empty rule body".to_string()));
1201        }
1202
1203        fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1204            for (ai, bi) in a.iter().zip(b.iter()) {
1205                if ai != bi {
1206                    return ai < bi;
1207                }
1208            }
1209            a.len() < b.len()
1210        }
1211
1212        let mut plans: Vec<JoinPlan<'a>> = Vec::with_capacity(atoms.len());
1213        for (idx, atom) in atoms.iter().enumerate() {
1214            plans.push(self.make_leaf_plan(atom, idx)?);
1215        }
1216
1217        while plans.len() > 1 {
1218            let mut best_pair: Option<(usize, usize, JoinPlan<'a>)> = None;
1219
1220            for i in 0..plans.len() {
1221                for j in (i + 1)..plans.len() {
1222                    let a = &plans[i];
1223                    let b = &plans[j];
1224
1225                    let cand_ab = self.join_plans(a, b);
1226                    let cand_ba = self.join_plans(b, a);
1227
1228                    let cand = if cand_ab.total_cost < cand_ba.total_cost
1229                        || (cand_ab.total_cost - cand_ba.total_cost).abs() < 1e-9
1230                            && lex_lt(&cand_ab.leaf_order_idx, &cand_ba.leaf_order_idx)
1231                    {
1232                        cand_ab
1233                    } else {
1234                        cand_ba
1235                    };
1236
1237                    let replace = match &best_pair {
1238                        None => true,
1239                        Some((_bi, _bj, best)) => {
1240                            if cand.total_cost < best.total_cost {
1241                                true
1242                            } else if (cand.total_cost - best.total_cost).abs() < 1e-9 {
1243                                lex_lt(&cand.leaf_order_idx, &best.leaf_order_idx)
1244                            } else {
1245                                false
1246                            }
1247                        }
1248                    };
1249
1250                    if replace {
1251                        best_pair = Some((i, j, cand));
1252                    }
1253                }
1254            }
1255
1256            let Some((i, j, joined)) = best_pair else {
1257                break;
1258            };
1259
1260            // Remove joined inputs from the plan list and replace with the join.
1261            let (a, b) = if i < j { (i, j) } else { (j, i) };
1262            plans.remove(b);
1263            plans.remove(a);
1264            plans.push(joined);
1265        }
1266
1267        let plan = plans
1268            .pop()
1269            .ok_or_else(|| XlogError::Compilation("Join planning failed".to_string()))?;
1270        Ok((plan.node, plan.leaf_order))
1271    }
1272
1273    fn order_positive_atoms_greedy<'a>(&self, atoms: &[&'a Atom]) -> Vec<&'a Atom> {
1274        let mut remaining: Vec<(usize, &Atom)> = atoms.iter().copied().enumerate().collect();
1275        let mut ordered: Vec<&Atom> = Vec::with_capacity(atoms.len());
1276        let mut bound_vars: HashSet<String> = HashSet::new();
1277
1278        while !remaining.is_empty() {
1279            let pick_idx = if ordered.is_empty() {
1280                remaining
1281                    .iter()
1282                    .enumerate()
1283                    .min_by(|(_, a), (_, b)| {
1284                        let (ai, aa) = **a;
1285                        let (bi, bb) = **b;
1286                        self.estimate_atom_rows(aa)
1287                            .partial_cmp(&self.estimate_atom_rows(bb))
1288                            .unwrap_or(std::cmp::Ordering::Equal)
1289                            .then(ai.cmp(&bi))
1290                    })
1291                    .map(|(idx, _)| idx)
1292                    .unwrap()
1293            } else {
1294                remaining
1295                    .iter()
1296                    .enumerate()
1297                    .min_by(|(_, a), (_, b)| {
1298                        let (ai, aa) = **a;
1299                        let (bi, bb) = **b;
1300
1301                        let a_vars = Self::atom_vars(aa);
1302                        let b_vars = Self::atom_vars(bb);
1303
1304                        let a_shared = a_vars.intersection(&bound_vars).count();
1305                        let b_shared = b_vars.intersection(&bound_vars).count();
1306
1307                        let a_score = if a_shared == 0 {
1308                            self.estimate_atom_rows(aa) * 1.0e12
1309                        } else {
1310                            self.estimate_atom_rows(aa) / a_shared as f64
1311                        };
1312                        let b_score = if b_shared == 0 {
1313                            self.estimate_atom_rows(bb) * 1.0e12
1314                        } else {
1315                            self.estimate_atom_rows(bb) / b_shared as f64
1316                        };
1317
1318                        a_score
1319                            .partial_cmp(&b_score)
1320                            .unwrap_or(std::cmp::Ordering::Equal)
1321                            .then(ai.cmp(&bi))
1322                    })
1323                    .map(|(idx, _)| idx)
1324                    .unwrap()
1325            };
1326
1327            let (_orig_idx, atom) = remaining.remove(pick_idx);
1328            ordered.push(atom);
1329            bound_vars.extend(Self::atom_vars(atom));
1330        }
1331
1332        ordered
1333    }
1334
1335    fn lower_body_parts(
1336        &mut self,
1337        positive_root: RirNode,
1338        negated_atoms: &[&Atom],
1339        comparisons: &[&Comparison],
1340        is_exprs: &[&IsExpr],
1341        var_env: &mut VariableEnv,
1342    ) -> Result<RirNode> {
1343        let mut result = positive_root;
1344
1345        // Apply comparisons as filters.
1346        for cmp in comparisons {
1347            result = self.apply_comparison(result, cmp, var_env)?;
1348        }
1349
1350        // Apply is-expressions (must be after atoms that bind the input variables).
1351        for is_expr in is_exprs {
1352            result = self.lower_is_expr(is_expr, result, var_env)?;
1353        }
1354
1355        // Handle negated atoms via Diff / semi-join.
1356        for neg_atom in negated_atoms {
1357            result = self.apply_negation(result, neg_atom, var_env)?;
1358        }
1359
1360        Ok(result)
1361    }
1362
1363    /// Build a left-deep join tree from positive atoms
1364    fn build_join_tree(&mut self, atoms: &[&Atom], var_env: &mut VariableEnv) -> Result<RirNode> {
1365        if atoms.is_empty() {
1366            return Err(XlogError::Compilation("Empty rule body".to_string()));
1367        }
1368
1369        // Start with the first atom as a scan
1370        let first_atom = atoms[0];
1371        let rel_id = self.get_or_create_rel_id(&first_atom.predicate);
1372        let mut result = RirNode::Scan { rel: rel_id };
1373        let mut result_vars = self.collect_atom_vars(first_atom);
1374        let mut result_width = first_atom.terms.len();
1375
1376        // Apply constant filters if any
1377        result = self.apply_constant_filters(result, first_atom, 0)?;
1378
1379        // Join with remaining atoms (left-deep)
1380        for atom in atoms.iter().skip(1) {
1381            let right_rel_id = self.get_or_create_rel_id(&atom.predicate);
1382            let right_scan = RirNode::Scan { rel: right_rel_id };
1383
1384            // Apply constant filters to the right side
1385            let right_filtered = self.apply_constant_filters(right_scan, atom, 0)?;
1386
1387            // Compute join keys based on shared variables
1388            let (left_keys, right_keys) = self.compute_join_keys(&result_vars, atom, result_width);
1389
1390            if left_keys.is_empty() {
1391                // Cartesian product (no shared variables)
1392                result = RirNode::Join {
1393                    left: Box::new(result),
1394                    right: Box::new(right_filtered),
1395                    left_keys: vec![],
1396                    right_keys: vec![],
1397                    join_type: JoinType::Inner,
1398                };
1399            } else {
1400                result = RirNode::Join {
1401                    left: Box::new(result),
1402                    right: Box::new(right_filtered),
1403                    left_keys,
1404                    right_keys,
1405                    join_type: JoinType::Inner,
1406                };
1407            }
1408
1409            // Update result vars for the next iteration
1410            for (i, term) in atom.terms.iter().enumerate() {
1411                if let Term::Variable(name) = term {
1412                    result_vars.push((name.clone(), result_width + i));
1413                }
1414            }
1415            result_width += atom.terms.len();
1416        }
1417
1418        // Update var_env with final positions
1419        var_env.total_cols = result_width;
1420
1421        Ok(result)
1422    }
1423
1424    /// Collect variable names and their positions within an atom
1425    fn collect_atom_vars(&self, atom: &Atom) -> Vec<(String, usize)> {
1426        atom.terms
1427            .iter()
1428            .enumerate()
1429            .filter_map(|(i, term)| {
1430                if let Term::Variable(name) = term {
1431                    Some((name.clone(), i))
1432                } else {
1433                    None
1434                }
1435            })
1436            .collect()
1437    }
1438
1439    /// Compute join keys between the current result and a new atom
1440    fn compute_join_keys(
1441        &self,
1442        left_vars: &[(String, usize)],
1443        right_atom: &Atom,
1444        _left_width: usize,
1445    ) -> (Vec<usize>, Vec<usize>) {
1446        let mut left_keys = Vec::new();
1447        let mut right_keys = Vec::new();
1448
1449        for (right_idx, term) in right_atom.terms.iter().enumerate() {
1450            if let Term::Variable(name) = term {
1451                // Find if this variable exists in the left side
1452                for (left_name, left_idx) in left_vars {
1453                    if left_name == name {
1454                        left_keys.push(*left_idx);
1455                        right_keys.push(right_idx);
1456                        break; // Only use first occurrence for join key
1457                    }
1458                }
1459            }
1460        }
1461
1462        (left_keys, right_keys)
1463    }
1464
1465    /// Apply constant filters for an atom
1466    fn apply_constant_filters(
1467        &self,
1468        input: RirNode,
1469        atom: &Atom,
1470        _base_col: usize,
1471    ) -> Result<RirNode> {
1472        let mut filters = Vec::new();
1473        let mut first_var_col: HashMap<&str, usize> = HashMap::new();
1474        let schema = self.schemas.get(&atom.predicate).ok_or_else(|| {
1475            XlogError::Compilation(format!("Missing schema for predicate {}", atom.predicate))
1476        })?;
1477
1478        for (i, term) in atom.terms.iter().enumerate() {
1479            if let Term::Variable(name) = term {
1480                if name != "_" {
1481                    if let Some(&first) = first_var_col.get(name.as_str()) {
1482                        filters.push(Expr::Compare {
1483                            left: Box::new(Expr::Column(first)),
1484                            op: CompareOp::Eq,
1485                            right: Box::new(Expr::Column(i)),
1486                        });
1487                    } else {
1488                        first_var_col.insert(name.as_str(), i);
1489                    }
1490                }
1491            }
1492
1493            let col_type = schema.column_type(i).ok_or_else(|| {
1494                XlogError::Compilation(format!(
1495                    "Missing column type for {} column {}",
1496                    atom.predicate, i
1497                ))
1498            })?;
1499            if let Some(const_val) = term_to_typed_const_value(term, col_type)? {
1500                filters.push(Expr::Compare {
1501                    left: Box::new(Expr::Column(i)),
1502                    op: CompareOp::Eq,
1503                    right: Box::new(Expr::Const(const_val)),
1504                });
1505            }
1506        }
1507
1508        if filters.is_empty() {
1509            Ok(input)
1510        } else {
1511            let predicate = if filters.len() == 1 {
1512                filters.pop().unwrap()
1513            } else {
1514                Expr::And(filters)
1515            };
1516
1517            Ok(RirNode::Filter {
1518                input: Box::new(input),
1519                predicate,
1520            })
1521        }
1522    }
1523
1524    /// Apply a comparison as a filter
1525    fn apply_comparison(
1526        &self,
1527        input: RirNode,
1528        cmp: &Comparison,
1529        var_env: &VariableEnv,
1530    ) -> Result<RirNode> {
1531        let (left_expr, right_expr) = match (&cmp.left, &cmp.right) {
1532            (Term::Variable(name), term) => {
1533                let col = var_env.get_column(name).ok_or_else(|| {
1534                    XlogError::Compilation(format!("Variable {} not found in environment", name))
1535                })?;
1536                let typ = var_env.get_type(name).ok_or_else(|| {
1537                    XlogError::Compilation(format!("Missing type for variable {}", name))
1538                })?;
1539                if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1540                    (Expr::Column(col), Expr::Const(const_val))
1541                } else {
1542                    (
1543                        self.term_to_expr(&cmp.left, var_env)?,
1544                        self.term_to_expr(&cmp.right, var_env)?,
1545                    )
1546                }
1547            }
1548            (term, Term::Variable(name)) => {
1549                let col = var_env.get_column(name).ok_or_else(|| {
1550                    XlogError::Compilation(format!("Variable {} not found in environment", name))
1551                })?;
1552                let typ = var_env.get_type(name).ok_or_else(|| {
1553                    XlogError::Compilation(format!("Missing type for variable {}", name))
1554                })?;
1555                if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1556                    (Expr::Const(const_val), Expr::Column(col))
1557                } else {
1558                    (
1559                        self.term_to_expr(&cmp.left, var_env)?,
1560                        self.term_to_expr(&cmp.right, var_env)?,
1561                    )
1562                }
1563            }
1564            _ => (
1565                self.term_to_expr(&cmp.left, var_env)?,
1566                self.term_to_expr(&cmp.right, var_env)?,
1567            ),
1568        };
1569
1570        let op = match cmp.op {
1571            CompOp::Eq => CompareOp::Eq,
1572            CompOp::Ne => CompareOp::Ne,
1573            CompOp::Lt => CompareOp::Lt,
1574            CompOp::Le => CompareOp::Le,
1575            CompOp::Gt => CompareOp::Gt,
1576            CompOp::Ge => CompareOp::Ge,
1577        };
1578
1579        Ok(RirNode::Filter {
1580            input: Box::new(input),
1581            predicate: Expr::Compare {
1582                left: Box::new(left_expr),
1583                op,
1584                right: Box::new(right_expr),
1585            },
1586        })
1587    }
1588
1589    /// Convert a term to an expression
1590    fn term_to_expr(&self, term: &Term, var_env: &VariableEnv) -> Result<Expr> {
1591        match term {
1592            Term::Variable(name) => {
1593                if let Some(col) = var_env.get_column(name) {
1594                    Ok(Expr::Column(col))
1595                } else {
1596                    Err(XlogError::Compilation(format!(
1597                        "Variable {} not found in environment",
1598                        name
1599                    )))
1600                }
1601            }
1602            Term::Anonymous => Err(XlogError::Compilation(
1603                "Anonymous wildcard '_' not allowed in comparisons".to_string(),
1604            )),
1605            Term::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
1606            Term::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
1607            Term::String(s) => Ok(Expr::Const(ConstValue::Symbol(s.clone()))),
1608            Term::Symbol(id) => Ok(Expr::Const(ConstValue::Symbol(symbol::resolve(*id)))),
1609            Term::Aggregate(_) => Err(XlogError::Compilation(
1610                "Aggregates not allowed in comparisons".to_string(),
1611            )),
1612            Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => Err(
1613                term_not_lowerable_error("comparison", term_kind_for_lowering_error(term)),
1614            ),
1615        }
1616    }
1617
1618    /// Apply negation via set difference
1619    fn apply_negation(
1620        &mut self,
1621        input: RirNode,
1622        neg_atom: &Atom,
1623        var_env: &VariableEnv,
1624    ) -> Result<RirNode> {
1625        let rel_id = self.get_or_create_rel_id(&neg_atom.predicate);
1626        let neg_scan = RirNode::Scan { rel: rel_id };
1627
1628        // Apply constant filters to the negated atom
1629        let neg_filtered = self.apply_constant_filters(neg_scan, neg_atom, 0)?;
1630
1631        // Find which columns from the input correspond to variables in the negated atom
1632        let mut input_cols = Vec::new();
1633        let mut neg_cols = Vec::new();
1634
1635        for (neg_idx, term) in neg_atom.terms.iter().enumerate() {
1636            if let Term::Variable(name) = term {
1637                if let Some(col) = var_env.get_column(name) {
1638                    input_cols.push(col);
1639                    neg_cols.push(neg_idx);
1640                }
1641            }
1642        }
1643
1644        if input_cols.is_empty() {
1645            // No shared variables - this is an existence check
1646            // If the negated relation is non-empty, result is empty
1647            // This is a special case we handle with anti-join
1648            Ok(RirNode::Diff {
1649                left: Box::new(input),
1650                right: Box::new(neg_filtered),
1651            })
1652        } else {
1653            // Project the negated atom to only the shared variable columns
1654            let neg_projected = if neg_cols.len() < neg_atom.terms.len() {
1655                let neg_proj_exprs: Vec<ProjectExpr> =
1656                    neg_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1657                RirNode::Project {
1658                    input: Box::new(neg_filtered),
1659                    columns: neg_proj_exprs,
1660                }
1661            } else {
1662                neg_filtered
1663            };
1664
1665            // Project input to matching columns for the diff, then diff
1666            // Actually, for proper anti-join semantics we need to be careful.
1667            // The Diff operation subtracts matching tuples.
1668            // We need to project input to the shared columns, diff, then rejoin.
1669
1670            // Simpler approach: project input to shared columns, diff with negated,
1671            // then rejoin with original
1672            let input_proj_exprs: Vec<ProjectExpr> =
1673                input_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1674            let input_projected = RirNode::Project {
1675                input: Box::new(input.clone()),
1676                columns: input_proj_exprs,
1677            };
1678
1679            // The Diff gives us the keys that should be kept
1680            let kept_keys = RirNode::Diff {
1681                left: Box::new(input_projected),
1682                right: Box::new(neg_projected),
1683            };
1684
1685            // Join back with original input to get full tuples
1686            // This effectively filters the input to only rows where the key
1687            // is not in the negated relation
1688            Ok(RirNode::Join {
1689                left: Box::new(input),
1690                right: Box::new(kept_keys),
1691                left_keys: input_cols.clone(),
1692                right_keys: (0..input_cols.len()).collect(),
1693                join_type: JoinType::Semi,
1694            })
1695        }
1696    }
1697
1698    fn is_identity_projection(proj: &[ProjectExpr], input_cols: usize) -> bool {
1699        if proj.len() != input_cols {
1700            return false;
1701        }
1702        proj.iter()
1703            .enumerate()
1704            .all(|(i, e)| matches!(e, ProjectExpr::Column(c) if *c == i))
1705    }
1706
1707    /// Build a projection list that matches the rule head term order.
1708    ///
1709    /// For non-aggregate rules this supports:
1710    /// - Variables (column passthrough)
1711    /// - Constants (computed constant columns)
1712    fn compute_head_projection(
1713        &self,
1714        head: &Atom,
1715        var_env: &VariableEnv,
1716    ) -> Result<Vec<ProjectExpr>> {
1717        let mut cols = Vec::with_capacity(head.terms.len());
1718
1719        for term in &head.terms {
1720            match term {
1721                Term::Variable(name) => {
1722                    let col = var_env
1723                        .get_column(name)
1724                        .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1725                    cols.push(ProjectExpr::Column(col));
1726                }
1727                Term::Anonymous => {
1728                    return Err(XlogError::Compilation(
1729                        "Anonymous wildcard '_' not allowed in rule head".to_string(),
1730                    ));
1731                }
1732                Term::Aggregate(_) => {
1733                    return Err(XlogError::Compilation(
1734                        "Aggregate term in non-aggregate rule head".to_string(),
1735                    ));
1736                }
1737                Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1738                    let (expr, typ) = term_to_project_const_expr(term)?;
1739                    cols.push(ProjectExpr::Computed(expr, typ));
1740                }
1741                Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1742                    return Err(term_not_lowerable_error(
1743                        "rule head projection",
1744                        term_kind_for_lowering_error(term),
1745                    ));
1746                }
1747            }
1748        }
1749
1750        Ok(cols)
1751    }
1752
1753    /// Lower an aggregate rule head into `GroupBy` + final projection.
1754    fn lower_aggregate_rule(
1755        &mut self,
1756        head: &Atom,
1757        body: RirNode,
1758        var_env: &VariableEnv,
1759    ) -> Result<RirNode> {
1760        // Collect unique group keys in head order.
1761        let mut key_vars: Vec<String> = Vec::new();
1762        let mut key_var_to_pos: HashMap<String, usize> = HashMap::new();
1763        let mut key_src_cols: Vec<usize> = Vec::new();
1764
1765        // Collect unique aggregate specs (op, var) in head order.
1766        let mut agg_specs: Vec<(AggOp, String)> = Vec::new();
1767        let mut agg_to_pos: HashMap<(AggOp, String), usize> = HashMap::new();
1768        let mut value_vars: Vec<String> = Vec::new();
1769        let mut value_var_to_pos: HashMap<String, usize> = HashMap::new();
1770        let mut value_src_cols: Vec<usize> = Vec::new();
1771
1772        for term in &head.terms {
1773            match term {
1774                Term::Variable(name) => {
1775                    if !key_var_to_pos.contains_key(name) {
1776                        let col = var_env
1777                            .get_column(name)
1778                            .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1779                        let pos = key_vars.len();
1780                        key_vars.push(name.clone());
1781                        key_var_to_pos.insert(name.clone(), pos);
1782                        key_src_cols.push(col);
1783                    }
1784                }
1785                Term::Aggregate(agg) => {
1786                    let key = (agg.op, agg.variable.clone());
1787                    if let std::collections::hash_map::Entry::Vacant(entry) = agg_to_pos.entry(key)
1788                    {
1789                        // Ensure the aggregated variable is bound.
1790                        let col = var_env
1791                            .get_column(&agg.variable)
1792                            .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1793
1794                        // Ensure the value variable exists in the groupby input.
1795                        let value_pos = *value_var_to_pos
1796                            .entry(agg.variable.clone())
1797                            .or_insert_with(|| {
1798                                let p = value_vars.len();
1799                                value_vars.push(agg.variable.clone());
1800                                value_src_cols.push(col);
1801                                p
1802                            });
1803
1804                        let agg_pos = agg_specs.len();
1805                        agg_specs.push((agg.op, agg.variable.clone()));
1806                        entry.insert(agg_pos);
1807
1808                        // Keep clippy happy about unused value_pos in insert_with closure.
1809                        let _ = value_pos;
1810                    }
1811                }
1812                Term::Anonymous => {
1813                    return Err(XlogError::Compilation(
1814                        "Anonymous wildcard '_' not allowed in rule head".to_string(),
1815                    ));
1816                }
1817                Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1818                    // Constants are allowed in the head; they are projected after aggregation.
1819                }
1820                Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1821                    return Err(term_not_lowerable_error(
1822                        "aggregate rule head",
1823                        term_kind_for_lowering_error(term),
1824                    ));
1825                }
1826            }
1827        }
1828
1829        if agg_specs.is_empty() {
1830            return Err(XlogError::Compilation(
1831                "Rule marked as aggregate but no aggregate terms found".to_string(),
1832            ));
1833        }
1834
1835        // Build groupby input: [keys..., values...]. For global aggregates (no keys),
1836        // synthesize a constant key column so GroupBy is well-defined.
1837        let mut group_input_cols: Vec<ProjectExpr> = Vec::new();
1838        let mut key_cols: Vec<usize> = Vec::new();
1839
1840        if key_src_cols.is_empty() {
1841            group_input_cols.push(ProjectExpr::Computed(
1842                Expr::Const(ConstValue::U32(0)),
1843                ScalarType::U32,
1844            ));
1845            key_cols.push(0);
1846        } else {
1847            for (i, &col) in key_src_cols.iter().enumerate() {
1848                group_input_cols.push(ProjectExpr::Column(col));
1849                key_cols.push(i);
1850            }
1851        }
1852
1853        let value_offset = group_input_cols.len();
1854        for &col in &value_src_cols {
1855            group_input_cols.push(ProjectExpr::Column(col));
1856        }
1857
1858        let group_input = RirNode::Project {
1859            input: Box::new(body),
1860            columns: group_input_cols,
1861        };
1862
1863        // Build multi-aggregation spec list (value_col indices are in the group_input schema).
1864        let mut aggs: Vec<(usize, CoreAggOp)> = Vec::with_capacity(agg_specs.len());
1865        for (op, var) in &agg_specs {
1866            let value_pos = *value_var_to_pos
1867                .get(var)
1868                .ok_or_else(|| XlogError::UnsafeVariable(var.clone()))?;
1869            let value_col = value_offset + value_pos;
1870            aggs.push((value_col, convert_agg_op(op)));
1871        }
1872
1873        let groupby = RirNode::GroupBy {
1874            input: Box::new(group_input),
1875            key_cols,
1876            aggs,
1877        };
1878
1879        // Final projection to match head term order:
1880        // - variables map to group key columns
1881        // - aggregates map to groupby output agg columns (after keys)
1882        // - constants are computed columns
1883        let key_count = if key_src_cols.is_empty() {
1884            1
1885        } else {
1886            key_vars.len()
1887        };
1888
1889        let mut final_proj: Vec<ProjectExpr> = Vec::with_capacity(head.terms.len());
1890        for term in &head.terms {
1891            match term {
1892                Term::Variable(name) => {
1893                    let idx = if key_src_cols.is_empty() {
1894                        // Global aggregates have no key vars in the output; binding a variable in the head
1895                        // is a semantic error because it would be unbound.
1896                        return Err(XlogError::UnsafeVariable(name.clone()));
1897                    } else {
1898                        *key_var_to_pos
1899                            .get(name)
1900                            .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?
1901                    };
1902                    final_proj.push(ProjectExpr::Column(idx));
1903                }
1904                Term::Aggregate(agg) => {
1905                    let pos = *agg_to_pos
1906                        .get(&(agg.op, agg.variable.clone()))
1907                        .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1908                    final_proj.push(ProjectExpr::Column(key_count + pos));
1909                }
1910                Term::Anonymous => {
1911                    return Err(XlogError::Compilation(
1912                        "Anonymous wildcard '_' not allowed in rule head".to_string(),
1913                    ));
1914                }
1915                Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1916                    let (expr, typ) = term_to_project_const_expr(term)?;
1917                    final_proj.push(ProjectExpr::Computed(expr, typ));
1918                }
1919                Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1920                    return Err(term_not_lowerable_error(
1921                        "aggregate rule projection",
1922                        term_kind_for_lowering_error(term),
1923                    ));
1924                }
1925            }
1926        }
1927
1928        if final_proj.is_empty() {
1929            return Err(XlogError::Compilation(
1930                "Aggregate rule produced empty head projection".to_string(),
1931            ));
1932        }
1933
1934        Ok(RirNode::Project {
1935            input: Box::new(groupby),
1936            columns: final_proj,
1937        })
1938    }
1939
1940    /// Infer the result type of an arithmetic expression (strict same-type)
1941    pub(crate) fn infer_arith_type(
1942        &self,
1943        expr: &ArithExpr,
1944        var_env: &VariableEnv,
1945    ) -> Result<ScalarType> {
1946        match expr {
1947            ArithExpr::Variable(name) => var_env.get_type(name).ok_or_else(|| {
1948                XlogError::Compilation(format!("Unknown variable {} in arithmetic", name))
1949            }),
1950            ArithExpr::Integer(_) => Ok(ScalarType::I64),
1951            ArithExpr::Float(_) => Ok(ScalarType::F64),
1952
1953            ArithExpr::Add(l, r)
1954            | ArithExpr::Sub(l, r)
1955            | ArithExpr::Mul(l, r)
1956            | ArithExpr::Div(l, r) => {
1957                let lt = self.infer_arith_type(l, var_env)?;
1958                let rt = self.infer_arith_type(r, var_env)?;
1959
1960                if lt != rt {
1961                    return Err(XlogError::Compilation(format!(
1962                        "Type mismatch in arithmetic: {:?} vs {:?}. Use cast() for conversion.",
1963                        lt, rt
1964                    )));
1965                }
1966
1967                if !Self::is_numeric_type(&lt) {
1968                    return Err(XlogError::Compilation(format!(
1969                        "Arithmetic requires numeric type, got {:?}",
1970                        lt
1971                    )));
1972                }
1973
1974                Ok(lt)
1975            }
1976
1977            ArithExpr::Mod(l, r) => {
1978                let lt = self.infer_arith_type(l, var_env)?;
1979                let rt = self.infer_arith_type(r, var_env)?;
1980
1981                if lt != rt {
1982                    return Err(XlogError::Compilation(format!(
1983                        "Type mismatch in mod: {:?} vs {:?}",
1984                        lt, rt
1985                    )));
1986                }
1987
1988                if matches!(lt, ScalarType::F32 | ScalarType::F64) {
1989                    return Err(XlogError::Compilation(
1990                        "Modulo (%) not supported for floating point".into(),
1991                    ));
1992                }
1993
1994                Ok(lt)
1995            }
1996
1997            ArithExpr::Abs(inner) => {
1998                let t = self.infer_arith_type(inner, var_env)?;
1999                if !Self::is_numeric_type(&t) {
2000                    return Err(XlogError::Compilation(format!(
2001                        "abs requires numeric type, got {:?}",
2002                        t
2003                    )));
2004                }
2005                Ok(t)
2006            }
2007
2008            ArithExpr::Min(l, r) | ArithExpr::Max(l, r) => {
2009                let lt = self.infer_arith_type(l, var_env)?;
2010                let rt = self.infer_arith_type(r, var_env)?;
2011
2012                if lt != rt {
2013                    return Err(XlogError::Compilation(format!(
2014                        "Type mismatch in min/max: {:?} vs {:?}",
2015                        lt, rt
2016                    )));
2017                }
2018
2019                if !Self::is_numeric_type(&lt) {
2020                    return Err(XlogError::Compilation(format!(
2021                        "min/max requires numeric type, got {:?}",
2022                        lt
2023                    )));
2024                }
2025
2026                Ok(lt)
2027            }
2028
2029            ArithExpr::Pow(base, exp) => {
2030                let base_t = self.infer_arith_type(base, var_env)?;
2031                let exp_t = self.infer_arith_type(exp, var_env)?;
2032
2033                if !Self::is_numeric_type(&base_t) || !Self::is_numeric_type(&exp_t) {
2034                    return Err(XlogError::Compilation(format!(
2035                        "pow requires numeric operands, got {:?} and {:?}",
2036                        base_t, exp_t
2037                    )));
2038                }
2039
2040                // pow always returns f64 (standard math behavior)
2041                Ok(ScalarType::F64)
2042            }
2043
2044            ArithExpr::Cast(_, target) => Ok(*target),
2045
2046            ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2047                "User-defined function '{}' must be inlined before lowering",
2048                name
2049            ))),
2050
2051            ArithExpr::Conditional {
2052                then_expr,
2053                else_expr,
2054                ..
2055            } => {
2056                // Both branches must have the same type
2057                let then_type = self.infer_arith_type(then_expr, var_env)?;
2058                let else_type = self.infer_arith_type(else_expr, var_env)?;
2059                if then_type != else_type {
2060                    return Err(XlogError::Compilation(format!(
2061                        "Conditional branches have different types: {:?} vs {:?}",
2062                        then_type, else_type
2063                    )));
2064                }
2065                Ok(then_type)
2066            }
2067        }
2068    }
2069
2070    fn is_numeric_type(t: &ScalarType) -> bool {
2071        matches!(
2072            t,
2073            ScalarType::I32
2074                | ScalarType::I64
2075                | ScalarType::U32
2076                | ScalarType::U64
2077                | ScalarType::F32
2078                | ScalarType::F64
2079        )
2080    }
2081
2082    /// Convert ArithExpr to IR Expr
2083    fn arith_to_expr(&self, arith: &ArithExpr, var_env: &VariableEnv) -> Result<Expr> {
2084        match arith {
2085            ArithExpr::Variable(name) => {
2086                let col = var_env.get_column(name).ok_or_else(|| {
2087                    XlogError::Compilation(format!(
2088                        "Variable {} not bound before use in arithmetic",
2089                        name
2090                    ))
2091                })?;
2092                Ok(Expr::Column(col))
2093            }
2094            ArithExpr::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
2095            ArithExpr::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
2096
2097            ArithExpr::Add(l, r) => Ok(Expr::Add(
2098                Box::new(self.arith_to_expr(l, var_env)?),
2099                Box::new(self.arith_to_expr(r, var_env)?),
2100            )),
2101            ArithExpr::Sub(l, r) => Ok(Expr::Sub(
2102                Box::new(self.arith_to_expr(l, var_env)?),
2103                Box::new(self.arith_to_expr(r, var_env)?),
2104            )),
2105            ArithExpr::Mul(l, r) => Ok(Expr::Mul(
2106                Box::new(self.arith_to_expr(l, var_env)?),
2107                Box::new(self.arith_to_expr(r, var_env)?),
2108            )),
2109            ArithExpr::Div(l, r) => Ok(Expr::Div(
2110                Box::new(self.arith_to_expr(l, var_env)?),
2111                Box::new(self.arith_to_expr(r, var_env)?),
2112            )),
2113            ArithExpr::Mod(l, r) => Ok(Expr::Mod(
2114                Box::new(self.arith_to_expr(l, var_env)?),
2115                Box::new(self.arith_to_expr(r, var_env)?),
2116            )),
2117
2118            ArithExpr::Abs(e) => Ok(Expr::Abs(Box::new(self.arith_to_expr(e, var_env)?))),
2119            ArithExpr::Min(l, r) => Ok(Expr::Min(
2120                Box::new(self.arith_to_expr(l, var_env)?),
2121                Box::new(self.arith_to_expr(r, var_env)?),
2122            )),
2123            ArithExpr::Max(l, r) => Ok(Expr::Max(
2124                Box::new(self.arith_to_expr(l, var_env)?),
2125                Box::new(self.arith_to_expr(r, var_env)?),
2126            )),
2127            ArithExpr::Pow(l, r) => Ok(Expr::Pow(
2128                Box::new(self.arith_to_expr(l, var_env)?),
2129                Box::new(self.arith_to_expr(r, var_env)?),
2130            )),
2131            ArithExpr::Cast(e, t) => Ok(Expr::Cast(Box::new(self.arith_to_expr(e, var_env)?), *t)),
2132
2133            ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2134                "User-defined function '{}' must be inlined before lowering",
2135                name
2136            ))),
2137
2138            ArithExpr::Conditional {
2139                cond_left,
2140                cond_op,
2141                cond_right,
2142                then_expr,
2143                else_expr,
2144            } => {
2145                // Convert AST comparison operator to IR comparison operator
2146                let ir_cond_op = match cond_op {
2147                    CompOp::Eq => CompareOp::Eq,
2148                    CompOp::Ne => CompareOp::Ne,
2149                    CompOp::Lt => CompareOp::Lt,
2150                    CompOp::Le => CompareOp::Le,
2151                    CompOp::Gt => CompareOp::Gt,
2152                    CompOp::Ge => CompareOp::Ge,
2153                };
2154
2155                // Build the condition as a Compare expression
2156                let condition = Expr::Compare {
2157                    left: Box::new(self.arith_to_expr(cond_left, var_env)?),
2158                    op: ir_cond_op,
2159                    right: Box::new(self.arith_to_expr(cond_right, var_env)?),
2160                };
2161
2162                // Build then and else expressions (recursive for nested conditionals)
2163                let then_ir = self.arith_to_expr(then_expr, var_env)?;
2164                let else_ir = self.arith_to_expr(else_expr, var_env)?;
2165
2166                Ok(Expr::Conditional {
2167                    condition: Box::new(condition),
2168                    then_expr: Box::new(then_ir),
2169                    else_expr: Box::new(else_ir),
2170                })
2171            }
2172        }
2173    }
2174
2175    /// Lower an is-expression to a Project node with computed column
2176    fn lower_is_expr(
2177        &mut self,
2178        is_expr: &IsExpr,
2179        input: RirNode,
2180        var_env: &mut VariableEnv,
2181    ) -> Result<RirNode> {
2182        // 1. Verify target is NOT already bound
2183        if var_env.contains(&is_expr.target) {
2184            return Err(XlogError::Compilation(format!(
2185                "Variable {} already bound; 'is' requires fresh variable",
2186                is_expr.target
2187            )));
2188        }
2189
2190        // 2. Verify all variables in expression are bound
2191        for var in is_expr.expr.variables() {
2192            if !var_env.contains(var) {
2193                return Err(XlogError::Compilation(format!(
2194                    "Variable {} used in arithmetic but not bound",
2195                    var
2196                )));
2197            }
2198        }
2199
2200        // 3. Infer result type
2201        let result_type = self.infer_arith_type(&is_expr.expr, var_env)?;
2202
2203        // 4. Convert expression to IR
2204        let ir_expr = self.arith_to_expr(&is_expr.expr, var_env)?;
2205
2206        // 5. Build projection: pass through all existing columns + add computed column
2207        let num_cols = var_env.column_count();
2208        let mut proj_exprs: Vec<ProjectExpr> = (0..num_cols).map(ProjectExpr::Column).collect();
2209        proj_exprs.push(ProjectExpr::Computed(ir_expr, result_type));
2210
2211        // 6. Bind the new variable
2212        var_env.bind(&is_expr.target, num_cols, result_type);
2213
2214        Ok(RirNode::Project {
2215            input: Box::new(input),
2216            columns: proj_exprs,
2217        })
2218    }
2219}
2220
2221/// Track variable occurrences and column positions
2222pub(crate) struct VariableEnv {
2223    /// Maps variable name to list of (predicate, position in atom, global column)
2224    occurrences: HashMap<String, Vec<(String, usize, usize)>>,
2225    /// Total columns in current result
2226    total_cols: usize,
2227    /// Maps variable name to its type (for type inference)
2228    types: HashMap<String, ScalarType>,
2229}
2230
2231impl VariableEnv {
2232    fn new() -> Self {
2233        Self {
2234            occurrences: HashMap::new(),
2235            total_cols: 0,
2236            types: HashMap::new(),
2237        }
2238    }
2239
2240    fn add_occurrence(&mut self, var: &str, pred: String, atom_pos: usize, global_col: usize) {
2241        self.occurrences
2242            .entry(var.to_string())
2243            .or_default()
2244            .push((pred, atom_pos, global_col));
2245    }
2246
2247    fn get_column(&self, var: &str) -> Option<usize> {
2248        self.occurrences
2249            .get(var)
2250            .and_then(|occs| occs.first())
2251            .map(|(_, _, col)| *col)
2252    }
2253
2254    /// Bind a variable to a column with a specific type (for type inference)
2255    fn bind(&mut self, name: &str, column: usize, typ: ScalarType) {
2256        self.types.insert(name.to_string(), typ);
2257        // Also add occurrence for column lookup
2258        self.occurrences
2259            .entry(name.to_string())
2260            .or_default()
2261            .push(("".to_string(), 0, column));
2262        // Update total_cols to account for the new computed column
2263        // This is critical for chained is-expressions where each adds a column
2264        if column >= self.total_cols {
2265            self.total_cols = column + 1;
2266        }
2267    }
2268
2269    /// Get the type of a bound variable
2270    fn get_type(&self, name: &str) -> Option<ScalarType> {
2271        self.types.get(name).copied()
2272    }
2273
2274    /// Check if a variable is bound
2275    fn contains(&self, name: &str) -> bool {
2276        self.occurrences.contains_key(name)
2277    }
2278
2279    /// Get the current column count (for adding new computed columns)
2280    fn column_count(&self) -> usize {
2281        self.total_cols
2282    }
2283}
2284
2285/// Infer the type of a term
2286fn infer_term_type(term: &Term) -> ScalarType {
2287    match term {
2288        Term::Variable(_) | Term::Anonymous => ScalarType::U64, // Default for variables
2289        Term::Integer(i) => {
2290            if *i >= 0 && *i <= u32::MAX as i64 {
2291                ScalarType::U32
2292            } else {
2293                ScalarType::I64
2294            }
2295        }
2296        Term::Float(_) => ScalarType::F64,
2297        Term::String(_) | Term::Symbol(_) => ScalarType::Symbol,
2298        Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
2299            ScalarType::U64
2300        }
2301        Term::Aggregate(agg) => match agg.op {
2302            AggOp::Count => ScalarType::U32,
2303            AggOp::Sum => ScalarType::U64,
2304            AggOp::Min | AggOp::Max => ScalarType::U32,
2305            AggOp::LogSumExp => ScalarType::F64,
2306        },
2307    }
2308}
2309
2310fn sort_labels_from_terms(terms: &[Term]) -> Vec<String> {
2311    terms
2312        .iter()
2313        .enumerate()
2314        .map(|(idx, term)| match term {
2315            Term::Variable(name) if !name.trim().is_empty() => name.clone(),
2316            Term::Aggregate(agg) => format!("{:?}_{}", agg.op, agg.variable),
2317            Term::List(_) => format!("list{}", idx),
2318            Term::Cons { .. } => format!("cons{}", idx),
2319            Term::Compound { functor, .. } => functor.clone(),
2320            Term::PredRef(name) => name.clone(),
2321            _ => format!("c{}", idx),
2322        })
2323        .collect()
2324}
2325
2326/// Convert a term to a constant value (if it is a constant)
2327fn term_to_const_value(term: &Term) -> Option<ConstValue> {
2328    match term {
2329        Term::Integer(i) => Some(ConstValue::I64(*i)),
2330        Term::Float(f) => Some(ConstValue::F64(*f)),
2331        Term::String(s) => Some(ConstValue::Symbol(s.clone())),
2332        Term::Symbol(id) => Some(ConstValue::Symbol(symbol::resolve(*id))),
2333        Term::Variable(_)
2334        | Term::Anonymous
2335        | Term::Aggregate(_)
2336        | Term::List(_)
2337        | Term::Cons { .. }
2338        | Term::Compound { .. }
2339        | Term::PredRef(_) => None,
2340    }
2341}
2342
2343fn term_to_typed_const_value(term: &Term, expected: ScalarType) -> Result<Option<ConstValue>> {
2344    let const_val = match term {
2345        Term::Integer(i) => match expected {
2346            ScalarType::U32 => {
2347                if *i >= 0 && *i <= u32::MAX as i64 {
2348                    ConstValue::U32(*i as u32)
2349                } else {
2350                    return Err(XlogError::Compilation(format!(
2351                        "Integer literal {} out of range for {:?}",
2352                        i, expected
2353                    )));
2354                }
2355            }
2356            ScalarType::U64 => {
2357                if *i >= 0 {
2358                    ConstValue::U64(*i as u64)
2359                } else {
2360                    return Err(XlogError::Compilation(format!(
2361                        "Integer literal {} out of range for {:?}",
2362                        i, expected
2363                    )));
2364                }
2365            }
2366            ScalarType::I32 => {
2367                if *i >= i32::MIN as i64 && *i <= i32::MAX as i64 {
2368                    ConstValue::I32(*i as i32)
2369                } else {
2370                    return Err(XlogError::Compilation(format!(
2371                        "Integer literal {} out of range for {:?}",
2372                        i, expected
2373                    )));
2374                }
2375            }
2376            ScalarType::I64 => ConstValue::I64(*i),
2377            ScalarType::F32 => {
2378                let value = *i as f64;
2379                if value < f32::MIN as f64 || value > f32::MAX as f64 {
2380                    return Err(XlogError::Compilation(format!(
2381                        "Integer literal {} out of range for {:?}",
2382                        i, expected
2383                    )));
2384                }
2385                ConstValue::F32(value as f32)
2386            }
2387            ScalarType::F64 => ConstValue::F64(*i as f64),
2388            ScalarType::Bool => {
2389                if *i == 0 || *i == 1 {
2390                    ConstValue::Bool(*i == 1)
2391                } else {
2392                    return Err(XlogError::Compilation(format!(
2393                        "Integer literal {} not valid for {:?}",
2394                        i, expected
2395                    )));
2396                }
2397            }
2398            ScalarType::Symbol => {
2399                return Err(XlogError::Compilation(format!(
2400                    "Integer literal {} not valid for {:?}",
2401                    i, expected
2402                )));
2403            }
2404        },
2405        Term::Float(f) => match expected {
2406            ScalarType::F32 => {
2407                if !f.is_finite() {
2408                    return Err(XlogError::Compilation(format!(
2409                        "Float literal {} not valid for {:?}",
2410                        f, expected
2411                    )));
2412                }
2413                if *f < f32::MIN as f64 || *f > f32::MAX as f64 {
2414                    return Err(XlogError::Compilation(format!(
2415                        "Float literal {} out of range for {:?}",
2416                        f, expected
2417                    )));
2418                }
2419                ConstValue::F32(*f as f32)
2420            }
2421            ScalarType::F64 => ConstValue::F64(*f),
2422            ScalarType::U32
2423            | ScalarType::U64
2424            | ScalarType::I32
2425            | ScalarType::I64
2426            | ScalarType::Bool
2427            | ScalarType::Symbol => {
2428                return Err(XlogError::Compilation(format!(
2429                    "Float literal {} not valid for {:?}",
2430                    f, expected
2431                )));
2432            }
2433        },
2434        Term::String(s) => {
2435            if expected == ScalarType::Symbol {
2436                ConstValue::Symbol(s.clone())
2437            } else {
2438                return Err(XlogError::Compilation(format!(
2439                    "String literal {} not valid for {:?}",
2440                    s, expected
2441                )));
2442            }
2443        }
2444        Term::Symbol(id) => {
2445            if expected == ScalarType::Symbol {
2446                ConstValue::Symbol(symbol::resolve(*id))
2447            } else {
2448                return Err(XlogError::Compilation(format!(
2449                    "Symbol literal {} not valid for {:?}",
2450                    symbol::resolve(*id),
2451                    expected
2452                )));
2453            }
2454        }
2455        Term::Variable(_)
2456        | Term::Anonymous
2457        | Term::Aggregate(_)
2458        | Term::List(_)
2459        | Term::Cons { .. }
2460        | Term::Compound { .. }
2461        | Term::PredRef(_) => return Ok(None),
2462    };
2463
2464    Ok(Some(const_val))
2465}
2466
2467fn term_to_project_const_expr(term: &Term) -> Result<(Expr, ScalarType)> {
2468    match term {
2469        Term::Integer(i) => {
2470            if *i >= 0 && *i <= u32::MAX as i64 {
2471                Ok((Expr::Const(ConstValue::U32(*i as u32)), ScalarType::U32))
2472            } else {
2473                Ok((Expr::Const(ConstValue::I64(*i)), ScalarType::I64))
2474            }
2475        }
2476        Term::Float(f) => Ok((Expr::Const(ConstValue::F64(*f)), ScalarType::F64)),
2477        Term::String(s) => Ok((
2478            Expr::Const(ConstValue::Symbol(s.clone())),
2479            ScalarType::Symbol,
2480        )),
2481        Term::Symbol(id) => Ok((
2482            Expr::Const(ConstValue::Symbol(symbol::resolve(*id))),
2483            ScalarType::Symbol,
2484        )),
2485        Term::Variable(_)
2486        | Term::Anonymous
2487        | Term::Aggregate(_)
2488        | Term::List(_)
2489        | Term::Cons { .. }
2490        | Term::Compound { .. }
2491        | Term::PredRef(_) => Err(XlogError::Compilation("Expected constant term".to_string())),
2492    }
2493}
2494
2495/// Convert AST AggOp to core AggOp
2496fn convert_agg_op(op: &AggOp) -> CoreAggOp {
2497    match op {
2498        AggOp::Count => CoreAggOp::Count,
2499        AggOp::Sum => CoreAggOp::Sum,
2500        AggOp::Min => CoreAggOp::Min,
2501        AggOp::Max => CoreAggOp::Max,
2502        AggOp::LogSumExp => CoreAggOp::LogSumExp,
2503    }
2504}
2505
// NOTE: `find_sccs_for_lowering` is imported from `crate::stratify` at the top of this
// file, so the export this comment originally asked for appears to be in place already.
2508
#[cfg(test)]
mod arith_type_tests {
    use super::*;
    use crate::ast::ArithExpr;

    /// Builds the arithmetic expression `X + Y`.
    fn x_plus_y() -> ArithExpr {
        let lhs = ArithExpr::Variable("X".to_string());
        let rhs = ArithExpr::Variable("Y".to_string());
        ArithExpr::Add(Box::new(lhs), Box::new(rhs))
    }

    /// `X + Y` with both operands bound as i64 must infer i64.
    #[test]
    fn test_arith_type_inference_same_type() {
        let mut var_env = VariableEnv::new();
        var_env.bind("X", 0, ScalarType::I64);
        var_env.bind("Y", 1, ScalarType::I64);

        let lowerer = Lowerer::new();
        let inferred = lowerer.infer_arith_type(&x_plus_y(), &var_env);

        assert!(inferred.is_ok());
        assert_eq!(inferred.unwrap(), ScalarType::I64);
    }

    /// `X + Y` with X bound as i64 and Y as f64 must be rejected.
    #[test]
    fn test_arith_type_inference_mismatch() {
        let mut var_env = VariableEnv::new();
        var_env.bind("X", 0, ScalarType::I64);
        var_env.bind("Y", 1, ScalarType::F64);

        let lowerer = Lowerer::new();
        let inferred = lowerer.infer_arith_type(&x_plus_y(), &var_env);

        assert!(inferred.is_err());
    }
}
2547
2548#[cfg(test)]
2549mod tests {
2550    use super::*;
2551    use crate::ast::*;
2552
2553    fn pred_decl(name: &str, types: Vec<ScalarType>) -> PredDecl {
2554        let type_refs: Vec<TypeRef> = types.into_iter().map(TypeRef::Scalar).collect();
2555        let columns = type_refs
2556            .iter()
2557            .cloned()
2558            .map(|typ| PredColumn { name: None, typ })
2559            .collect();
2560        PredDecl {
2561            name: name.to_string(),
2562            types: type_refs,
2563            columns,
2564            is_private: false,
2565        }
2566    }
2567
2568    /// Helper to create a simple edge atom
2569    fn edge_atom(x: &str, y: &str) -> Atom {
2570        Atom {
2571            predicate: "edge".to_string(),
2572            terms: vec![Term::Variable(x.to_string()), Term::Variable(y.to_string())],
2573        }
2574    }
2575
2576    /// Helper to create a reach atom
2577    fn reach_atom(x: &str, y: &str) -> Atom {
2578        Atom {
2579            predicate: "reach".to_string(),
2580            terms: vec![Term::Variable(x.to_string()), Term::Variable(y.to_string())],
2581        }
2582    }
2583
2584    /// Helper to create a node atom
2585    fn node_atom(x: &str) -> Atom {
2586        Atom {
2587            predicate: "node".to_string(),
2588            terms: vec![Term::Variable(x.to_string())],
2589        }
2590    }
2591
2592    #[test]
2593    fn test_lowerer_new() {
2594        let lowerer = Lowerer::new();
2595        assert!(lowerer.schemas.is_empty());
2596        assert!(lowerer.strata.is_empty());
2597        assert_eq!(lowerer.next_rel_id, 0);
2598    }
2599
2600    #[test]
2601    fn test_get_or_create_rel_id() {
2602        let mut lowerer = Lowerer::new();
2603        let id1 = lowerer.get_or_create_rel_id("edge");
2604        let id2 = lowerer.get_or_create_rel_id("reach");
2605        let id3 = lowerer.get_or_create_rel_id("edge");
2606
2607        assert_eq!(id1, RelId(0));
2608        assert_eq!(id2, RelId(1));
2609        assert_eq!(id3, RelId(0)); // Same as id1
2610    }
2611
2612    #[test]
2613    fn test_infer_schemas_from_facts() {
2614        let mut program = Program::new();
2615        program.rules.push(Rule {
2616            head: Atom {
2617                predicate: "edge".to_string(),
2618                terms: vec![Term::Integer(1), Term::Integer(2)],
2619            },
2620            body: vec![],
2621        });
2622
2623        let mut lowerer = Lowerer::new();
2624        lowerer.infer_schemas(&program).unwrap();
2625
2626        assert!(lowerer.schemas.contains_key("edge"));
2627        let schema = lowerer.schemas.get("edge").unwrap();
2628        assert_eq!(schema.arity(), 2);
2629    }
2630
2631    #[test]
2632    fn test_lower_simple_rule() {
2633        // reach(X, Y) :- edge(X, Y).
2634        let rule = Rule {
2635            head: reach_atom("X", "Y"),
2636            body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
2637        };
2638
2639        let mut lowerer = Lowerer::new();
2640        lowerer.schemas.insert(
2641            "edge".to_string(),
2642            Schema::new(vec![
2643                ("c0".to_string(), ScalarType::U32),
2644                ("c1".to_string(), ScalarType::U32),
2645            ]),
2646        );
2647
2648        let result = lowerer.lower_rule(&rule);
2649        assert!(result.is_ok());
2650
2651        let node = result.unwrap();
2652        // Should be just a scan (no projection needed since columns match)
2653        assert!(matches!(node, RirNode::Scan { .. }));
2654    }
2655
2656    #[test]
2657    fn test_lower_join_rule() {
2658        // reach(X, Z) :- reach(X, Y), edge(Y, Z).
2659        let rule = Rule {
2660            head: Atom {
2661                predicate: "reach".to_string(),
2662                terms: vec![
2663                    Term::Variable("X".to_string()),
2664                    Term::Variable("Z".to_string()),
2665                ],
2666            },
2667            body: vec![
2668                BodyLiteral::Positive(reach_atom("X", "Y")),
2669                BodyLiteral::Positive(edge_atom("Y", "Z")),
2670            ],
2671        };
2672
2673        let mut lowerer = Lowerer::new();
2674        lowerer.schemas.insert(
2675            "reach".to_string(),
2676            Schema::new(vec![
2677                ("c0".to_string(), ScalarType::U32),
2678                ("c1".to_string(), ScalarType::U32),
2679            ]),
2680        );
2681        lowerer.schemas.insert(
2682            "edge".to_string(),
2683            Schema::new(vec![
2684                ("c0".to_string(), ScalarType::U32),
2685                ("c1".to_string(), ScalarType::U32),
2686            ]),
2687        );
2688
2689        let result = lowerer.lower_rule(&rule);
2690        assert!(result.is_ok());
2691
2692        let node = result.unwrap();
2693        // Should be Project(Join(Scan, Scan))
2694        if let RirNode::Project { input, columns } = node {
2695            // X from reach (col 0), Z from edge (col 3)
2696            assert_eq!(
2697                columns,
2698                vec![ProjectExpr::Column(0), ProjectExpr::Column(3)]
2699            );
2700            assert!(matches!(*input, RirNode::Join { .. }));
2701            if let RirNode::Join {
2702                left_keys,
2703                right_keys,
2704                ..
2705            } = *input
2706            {
2707                assert_eq!(left_keys, vec![1]); // Y in reach (position 1)
2708                assert_eq!(right_keys, vec![0]); // Y in edge (position 0)
2709            }
2710        } else {
2711            panic!("Expected Project node");
2712        }
2713    }
2714
2715    #[test]
2716    fn test_join_order_prefers_smaller_relation() {
2717        // out(X) :- big(X), small(X).
2718        let rule = Rule {
2719            head: Atom {
2720                predicate: "out".to_string(),
2721                terms: vec![Term::Variable("X".to_string())],
2722            },
2723            body: vec![
2724                BodyLiteral::Positive(Atom {
2725                    predicate: "big".to_string(),
2726                    terms: vec![Term::Variable("X".to_string())],
2727                }),
2728                BodyLiteral::Positive(Atom {
2729                    predicate: "small".to_string(),
2730                    terms: vec![Term::Variable("X".to_string())],
2731                }),
2732            ],
2733        };
2734
2735        let mut lowerer = Lowerer::new();
2736        lowerer.schemas.insert(
2737            "big".to_string(),
2738            Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2739        );
2740        lowerer.schemas.insert(
2741            "small".to_string(),
2742            Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2743        );
2744
2745        // Ensure stable RelIds independent of join order.
2746        let big_id = lowerer.get_or_create_rel_id("big");
2747        let small_id = lowerer.get_or_create_rel_id("small");
2748        assert_eq!(big_id, RelId(0));
2749        assert_eq!(small_id, RelId(1));
2750
2751        // Prefer scanning the smaller relation first.
2752        lowerer.est_cardinality.insert("big".to_string(), 10_000);
2753        lowerer.est_cardinality.insert("small".to_string(), 10);
2754
2755        let node = lowerer.lower_rule(&rule).unwrap();
2756        let join = match node {
2757            RirNode::Project { input, .. } => *input,
2758            other => other,
2759        };
2760
2761        match join {
2762            RirNode::Join { left, right, .. } => {
2763                // Prefer building the hash table on the smaller relation (right/build side).
2764                assert!(matches!(*left, RirNode::Scan { rel } if rel == big_id));
2765                assert!(matches!(*right, RirNode::Scan { rel } if rel == small_id));
2766            }
2767            other => panic!("Expected Join node, got {:?}", other),
2768        }
2769    }
2770
2771    #[test]
2772    fn test_lower_negation() {
2773        // isolated(X) :- node(X), not edge(X, _).
2774        let rule = Rule {
2775            head: Atom {
2776                predicate: "isolated".to_string(),
2777                terms: vec![Term::Variable("X".to_string())],
2778            },
2779            body: vec![
2780                BodyLiteral::Positive(node_atom("X")),
2781                BodyLiteral::Negated(Atom {
2782                    predicate: "edge".to_string(),
2783                    terms: vec![
2784                        Term::Variable("X".to_string()),
2785                        Term::Variable("_".to_string()),
2786                    ],
2787                }),
2788            ],
2789        };
2790
2791        let mut lowerer = Lowerer::new();
2792        lowerer.schemas.insert(
2793            "node".to_string(),
2794            Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2795        );
2796        lowerer.schemas.insert(
2797            "edge".to_string(),
2798            Schema::new(vec![
2799                ("c0".to_string(), ScalarType::U32),
2800                ("c1".to_string(), ScalarType::U32),
2801            ]),
2802        );
2803
2804        let result = lowerer.lower_rule(&rule);
2805        assert!(result.is_ok());
2806
2807        // The result should involve a Diff or semi-join for negation
2808        let node = result.unwrap();
2809        // Verify the structure contains the negation handling
2810        fn contains_diff_or_semi(node: &RirNode) -> bool {
2811            match node {
2812                RirNode::Diff { .. } => true,
2813                RirNode::Join {
2814                    join_type: JoinType::Semi,
2815                    ..
2816                } => true,
2817                RirNode::Join { left, right, .. } => {
2818                    contains_diff_or_semi(left) || contains_diff_or_semi(right)
2819                }
2820                RirNode::Project { input, .. } => contains_diff_or_semi(input),
2821                RirNode::Filter { input, .. } => contains_diff_or_semi(input),
2822                _ => false,
2823            }
2824        }
2825        assert!(contains_diff_or_semi(&node));
2826    }
2827
2828    #[test]
2829    fn test_lower_comparison() {
2830        // greater(X, Y) :- pair(X, Y), X > Y.
2831        let rule = Rule {
2832            head: Atom {
2833                predicate: "greater".to_string(),
2834                terms: vec![
2835                    Term::Variable("X".to_string()),
2836                    Term::Variable("Y".to_string()),
2837                ],
2838            },
2839            body: vec![
2840                BodyLiteral::Positive(Atom {
2841                    predicate: "pair".to_string(),
2842                    terms: vec![
2843                        Term::Variable("X".to_string()),
2844                        Term::Variable("Y".to_string()),
2845                    ],
2846                }),
2847                BodyLiteral::Comparison(Comparison {
2848                    left: Term::Variable("X".to_string()),
2849                    op: CompOp::Gt,
2850                    right: Term::Variable("Y".to_string()),
2851                }),
2852            ],
2853        };
2854
2855        let mut lowerer = Lowerer::new();
2856        lowerer.schemas.insert(
2857            "pair".to_string(),
2858            Schema::new(vec![
2859                ("c0".to_string(), ScalarType::U32),
2860                ("c1".to_string(), ScalarType::U32),
2861            ]),
2862        );
2863
2864        let result = lowerer.lower_rule(&rule);
2865        assert!(result.is_ok());
2866
2867        let node = result.unwrap();
2868        // Should contain a Filter node
2869        fn contains_filter(node: &RirNode) -> bool {
2870            match node {
2871                RirNode::Filter { .. } => true,
2872                RirNode::Project { input, .. } => contains_filter(input),
2873                RirNode::Join { left, right, .. } => {
2874                    contains_filter(left) || contains_filter(right)
2875                }
2876                _ => false,
2877            }
2878        }
2879        assert!(contains_filter(&node));
2880    }
2881
2882    #[test]
2883    fn test_lower_constant_filter() {
2884        // specific_edge(Y) :- edge(1, Y).
2885        let rule = Rule {
2886            head: Atom {
2887                predicate: "specific_edge".to_string(),
2888                terms: vec![Term::Variable("Y".to_string())],
2889            },
2890            body: vec![BodyLiteral::Positive(Atom {
2891                predicate: "edge".to_string(),
2892                terms: vec![Term::Integer(1), Term::Variable("Y".to_string())],
2893            })],
2894        };
2895
2896        let mut lowerer = Lowerer::new();
2897        lowerer.schemas.insert(
2898            "edge".to_string(),
2899            Schema::new(vec![
2900                ("c0".to_string(), ScalarType::U32),
2901                ("c1".to_string(), ScalarType::U32),
2902            ]),
2903        );
2904
2905        let result = lowerer.lower_rule(&rule);
2906        assert!(result.is_ok());
2907
2908        let node = result.unwrap();
2909        // Should contain a Filter for the constant 1
2910        fn has_const_filter(node: &RirNode) -> bool {
2911            match node {
2912                RirNode::Filter {
2913                    predicate: Expr::Compare { right, .. },
2914                    ..
2915                } => matches!(**right, Expr::Const(_)),
2916                RirNode::Project { input, .. } => has_const_filter(input),
2917                _ => false,
2918            }
2919        }
2920        assert!(has_const_filter(&node));
2921    }
2922
2923    #[test]
2924    fn test_lower_repeated_variable_filter() {
2925        // self_loop(X) :- edge(X, X).
2926        let rule = Rule {
2927            head: Atom {
2928                predicate: "self_loop".to_string(),
2929                terms: vec![Term::Variable("X".to_string())],
2930            },
2931            body: vec![BodyLiteral::Positive(Atom {
2932                predicate: "edge".to_string(),
2933                terms: vec![
2934                    Term::Variable("X".to_string()),
2935                    Term::Variable("X".to_string()),
2936                ],
2937            })],
2938        };
2939
2940        let mut lowerer = Lowerer::new();
2941        lowerer.schemas.insert(
2942            "edge".to_string(),
2943            Schema::new(vec![
2944                ("c0".to_string(), ScalarType::U32),
2945                ("c1".to_string(), ScalarType::U32),
2946            ]),
2947        );
2948
2949        let node = lowerer.lower_rule(&rule).expect("lower_rule failed");
2950
2951        fn has_col_eq_filter(node: &RirNode) -> bool {
2952            match node {
2953                RirNode::Filter { predicate, .. } => match predicate {
2954                    Expr::Compare {
2955                        left,
2956                        op: CompareOp::Eq,
2957                        right,
2958                    } => {
2959                        matches!((&**left, &**right), (Expr::Column(0), Expr::Column(1)))
2960                            || matches!((&**left, &**right), (Expr::Column(1), Expr::Column(0)))
2961                    }
2962                    Expr::And(exprs) => exprs.iter().any(|e| match e {
2963                        Expr::Compare {
2964                            left,
2965                            op: CompareOp::Eq,
2966                            right,
2967                        } => {
2968                            matches!((&**left, &**right), (Expr::Column(0), Expr::Column(1)))
2969                                || matches!((&**left, &**right), (Expr::Column(1), Expr::Column(0)))
2970                        }
2971                        _ => false,
2972                    }),
2973                    _ => false,
2974                },
2975                RirNode::Project { input, .. } => has_col_eq_filter(input),
2976                _ => false,
2977            }
2978        }
2979
2980        assert!(has_col_eq_filter(&node));
2981    }
2982
2983    #[test]
2984    fn test_lower_program_simple() {
2985        let mut program = Program::new();
2986
2987        // edge(1, 2).
2988        program.rules.push(Rule {
2989            head: Atom {
2990                predicate: "edge".to_string(),
2991                terms: vec![Term::Integer(1), Term::Integer(2)],
2992            },
2993            body: vec![],
2994        });
2995
2996        // reach(X, Y) :- edge(X, Y).
2997        program.rules.push(Rule {
2998            head: reach_atom("X", "Y"),
2999            body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
3000        });
3001
3002        let mut lowerer = Lowerer::new();
3003        lowerer.set_strata(vec![vec!["edge".to_string()], vec!["reach".to_string()]]);
3004
3005        let result = lowerer.lower_program(&program);
3006        assert!(result.is_ok());
3007
3008        let plan = result.unwrap();
3009        assert!(!plan.sccs.is_empty());
3010    }
3011
3012    #[test]
3013    fn test_variable_env() {
3014        let mut env = VariableEnv::new();
3015        env.add_occurrence("X", "edge".to_string(), 0, 0);
3016        env.add_occurrence("Y", "edge".to_string(), 1, 1);
3017        env.add_occurrence("Y", "node".to_string(), 0, 2);
3018
3019        assert_eq!(env.get_column("X"), Some(0));
3020        assert_eq!(env.get_column("Y"), Some(1)); // First occurrence
3021        assert_eq!(env.get_column("Z"), None);
3022    }
3023
3024    #[test]
3025    fn test_infer_term_type() {
3026        assert_eq!(
3027            infer_term_type(&Term::Variable("X".to_string())),
3028            ScalarType::U64
3029        );
3030        assert_eq!(infer_term_type(&Term::Integer(42)), ScalarType::U32);
3031        assert_eq!(infer_term_type(&Term::Integer(i64::MAX)), ScalarType::I64);
3032        assert_eq!(infer_term_type(&Term::Float(3.25)), ScalarType::F64);
3033        assert_eq!(
3034            infer_term_type(&Term::Symbol(symbol::intern("foo"))),
3035            ScalarType::Symbol
3036        );
3037    }
3038
3039    #[test]
3040    fn test_convert_agg_op() {
3041        assert_eq!(convert_agg_op(&AggOp::Count), CoreAggOp::Count);
3042        assert_eq!(convert_agg_op(&AggOp::Sum), CoreAggOp::Sum);
3043        assert_eq!(convert_agg_op(&AggOp::Min), CoreAggOp::Min);
3044        assert_eq!(convert_agg_op(&AggOp::Max), CoreAggOp::Max);
3045        assert_eq!(convert_agg_op(&AggOp::LogSumExp), CoreAggOp::LogSumExp);
3046    }
3047
3048    #[test]
3049    fn test_variable_env_bind_updates_total_cols() {
3050        // Test that bind() properly updates total_cols for chained is-expressions
3051        let mut env = VariableEnv::new();
3052        env.total_cols = 2; // Simulate 2 columns from atoms
3053
3054        // Bind first computed variable at column 2
3055        env.bind("A", 2, ScalarType::I64);
3056        assert_eq!(
3057            env.column_count(),
3058            3,
3059            "total_cols should be 3 after first bind"
3060        );
3061        assert_eq!(env.get_column("A"), Some(2));
3062
3063        // Bind second computed variable at column 3
3064        env.bind("B", 3, ScalarType::I64);
3065        assert_eq!(
3066            env.column_count(),
3067            4,
3068            "total_cols should be 4 after second bind"
3069        );
3070        assert_eq!(env.get_column("B"), Some(3));
3071    }
3072
3073    #[test]
3074    fn test_lower_chained_is_expressions() {
3075        // result(A, B) :- input(X, Y), A is X + Y, B is A * 2.
3076        // This tests that chained is-expressions correctly update column indices
3077        let rule = Rule {
3078            head: Atom {
3079                predicate: "result".to_string(),
3080                terms: vec![
3081                    Term::Variable("A".to_string()),
3082                    Term::Variable("B".to_string()),
3083                ],
3084            },
3085            body: vec![
3086                BodyLiteral::Positive(Atom {
3087                    predicate: "input".to_string(),
3088                    terms: vec![
3089                        Term::Variable("X".to_string()),
3090                        Term::Variable("Y".to_string()),
3091                    ],
3092                }),
3093                BodyLiteral::IsExpr(IsExpr {
3094                    target: "A".to_string(),
3095                    expr: ArithExpr::Add(
3096                        Box::new(ArithExpr::Variable("X".to_string())),
3097                        Box::new(ArithExpr::Variable("Y".to_string())),
3098                    ),
3099                }),
3100                BodyLiteral::IsExpr(IsExpr {
3101                    target: "B".to_string(),
3102                    expr: ArithExpr::Mul(
3103                        Box::new(ArithExpr::Variable("A".to_string())),
3104                        Box::new(ArithExpr::Integer(2)),
3105                    ),
3106                }),
3107            ],
3108        };
3109
3110        let mut lowerer = Lowerer::new();
3111        lowerer.schemas.insert(
3112            "input".to_string(),
3113            Schema::new(vec![
3114                ("c0".to_string(), ScalarType::I64),
3115                ("c1".to_string(), ScalarType::I64),
3116            ]),
3117        );
3118
3119        let result = lowerer.lower_rule(&rule);
3120        assert!(
3121            result.is_ok(),
3122            "Lowering chained is-expressions should succeed: {:?}",
3123            result.err()
3124        );
3125
3126        let node = result.unwrap();
3127
3128        // The structure should be:
3129        // Project([col 2, col 3]) <-- final projection for A, B
3130        //   Project([col 0, col 1, col 2, A*2]) <-- second is-expr adds B at col 3
3131        //     Project([col 0, col 1, X+Y]) <-- first is-expr adds A at col 2
3132        //       Scan(input)
3133
3134        // Verify we have nested Project nodes
3135        fn count_projects(node: &RirNode) -> usize {
3136            match node {
3137                RirNode::Project { input, .. } => 1 + count_projects(input),
3138                _ => 0,
3139            }
3140        }
3141
3142        // We expect 3 Project nodes: 2 for is-expressions + 1 for final head projection
3143        let project_count = count_projects(&node);
3144        assert!(
3145            project_count >= 2,
3146            "Expected at least 2 Project nodes for chained is-exprs, got {}",
3147            project_count
3148        );
3149
3150        // Verify the final projection references columns 2 and 3 (A and B)
3151        if let RirNode::Project { columns, .. } = &node {
3152            assert_eq!(columns.len(), 2, "Head has 2 variables");
3153            // A should be at column 2, B at column 3
3154            assert_eq!(columns[0], ProjectExpr::Column(2), "A should be column 2");
3155            assert_eq!(columns[1], ProjectExpr::Column(3), "B should be column 3");
3156        } else {
3157            panic!("Expected top-level Project node");
3158        }
3159    }
3160
/// A `u64` column type declared via `pred` must carry through to comparison lowering:
/// the integer literal in `Count >= 3` is expected to lower to `ConstValue::U64(3)`.
#[test]
fn test_u64_comparison_type_from_pred_decl() {
    /// Searches `Filter`, `Project` and `Join` nodes (depth-first, left before right)
    /// for the first comparison whose right-hand side is a constant, and returns that
    /// constant. Any other node kind ends the search with `None`.
    fn find_compare_const(node: &RirNode) -> Option<&ConstValue> {
        match node {
            RirNode::Filter { predicate, input } => {
                // Constant found directly on this filter's predicate, if any.
                let direct = match predicate {
                    Expr::Compare { right, .. } => match right.as_ref() {
                        Expr::Const(val) => Some(val),
                        _ => None,
                    },
                    _ => None,
                };
                direct.or_else(|| find_compare_const(input))
            }
            RirNode::Project { input, .. } => find_compare_const(input),
            RirNode::Join { left, right, .. } => match find_compare_const(left) {
                Some(val) => Some(val),
                None => find_compare_const(right),
            },
            _ => None,
        }
    }

    // Local builders that keep the program construction below readable.
    let var = |name: &str| Term::Variable(name.to_string());
    let atom = |predicate: &str, terms: Vec<Term>| Atom {
        predicate: predicate.to_string(),
        terms,
    };

    let mut program = Program::new();

    // pred count_data(symbol, u64).
    // pred big_count(symbol, u64).
    for name in ["count_data", "big_count"] {
        program
            .predicates
            .push(pred_decl(name, vec![ScalarType::Symbol, ScalarType::U64]));
    }

    // count_data(alice, 5).
    let count_fact = Rule {
        head: atom(
            "count_data",
            vec![
                Term::Symbol(xlog_core::symbol::intern("alice")),
                Term::Integer(5),
            ],
        ),
        body: Vec::new(),
    };

    // big_count(Name, Count) :- count_data(Name, Count), Count >= 3.
    let big_count_rule = Rule {
        head: atom("big_count", vec![var("Name"), var("Count")]),
        body: vec![
            BodyLiteral::Positive(atom("count_data", vec![var("Name"), var("Count")])),
            BodyLiteral::Comparison(Comparison {
                left: var("Count"),
                op: CompOp::Ge,
                right: Term::Integer(3),
            }),
        ],
    };

    // Order matters: the rule under test is looked up by index further down.
    program.rules.push(count_fact);
    program.rules.push(big_count_rule);

    let mut lowerer = Lowerer::new();
    lowerer.infer_schemas(&program).unwrap();

    // The inferred schema must reflect the declared column types.
    let schema = lowerer
        .schemas
        .get("count_data")
        .expect("schema for count_data");
    let expected_columns = [
        (0, ScalarType::Symbol, "First column should be Symbol"),
        (1, ScalarType::U64, "Second column should be U64"),
    ];
    for (index, expected, message) in expected_columns {
        assert_eq!(schema.column_type(index), Some(expected), "{}", message);
    }

    // Lower the rule that contains the comparison.
    lowerer.set_strata(vec![
        vec!["count_data".to_string()],
        vec!["big_count".to_string()],
    ]);
    lowerer.build_sccs(&program);

    // Index 1 is the big_count rule (index 0 is the count_data fact).
    let result = lowerer.lower_rule(&program.rules[1]);
    assert!(
        result.is_ok(),
        "Lowering should succeed: {:?}",
        result.err()
    );
    let node = result.unwrap();

    // The comparison constant should be U64(3), not I64(3).
    match find_compare_const(&node) {
        None => panic!("Should find a constant in comparison"),
        Some(ConstValue::U64(v)) => assert_eq!(*v, 3, "Value should be 3"),
        Some(other) => panic!("Expected U64(3), got {:?}", other),
    }
}
3278
/// Same check as the plain `u64` comparison test, but the compared column is produced
/// by an aggregation rule: the literal in `Count >= 2` is expected to lower to
/// `ConstValue::U64(2)`.
#[test]
fn test_u64_comparison_with_aggregation() {
    use crate::ast::AggExpr;

    /// Searches `Filter`, `Project` and `Join` nodes (depth-first, left before right)
    /// for the first comparison whose right-hand side is a constant, and returns that
    /// constant. Any other node kind ends the search with `None`.
    fn find_compare_const(node: &RirNode) -> Option<&ConstValue> {
        match node {
            RirNode::Filter { predicate, input } => {
                // Constant found directly on this filter's predicate, if any.
                let direct = match predicate {
                    Expr::Compare { right, .. } => match right.as_ref() {
                        Expr::Const(val) => Some(val),
                        _ => None,
                    },
                    _ => None,
                };
                direct.or_else(|| find_compare_const(input))
            }
            RirNode::Project { input, .. } => find_compare_const(input),
            RirNode::Join { left, right, .. } => match find_compare_const(left) {
                Some(val) => Some(val),
                None => find_compare_const(right),
            },
            _ => None,
        }
    }

    // Local builders that keep the program construction below readable.
    let var = |name: &str| Term::Variable(name.to_string());
    let sym = |name: &'static str| Term::Symbol(xlog_core::symbol::intern(name));
    let atom = |predicate: &str, terms: Vec<Term>| Atom {
        predicate: predicate.to_string(),
        terms,
    };

    let mut program = Program::new();

    // pred reports_to(symbol, symbol).
    // pred direct_count(symbol, u64).
    // pred big_manager(symbol, u64).
    let declarations = [
        ("reports_to", ScalarType::Symbol),
        ("direct_count", ScalarType::U64),
        ("big_manager", ScalarType::U64),
    ];
    for (name, second_column) in declarations {
        program
            .predicates
            .push(pred_decl(name, vec![ScalarType::Symbol, second_column]));
    }

    // reports_to(alice, bob).
    // reports_to(carol, bob).
    for employee in ["alice", "carol"] {
        program.rules.push(Rule {
            head: atom("reports_to", vec![sym(employee), sym("bob")]),
            body: Vec::new(),
        });
    }

    // direct_count(Mgr, count(Emp)) :- reports_to(Emp, Mgr).
    program.rules.push(Rule {
        head: atom(
            "direct_count",
            vec![
                var("Mgr"),
                Term::Aggregate(AggExpr {
                    op: AggOp::Count,
                    variable: "Emp".to_string(),
                }),
            ],
        ),
        body: vec![BodyLiteral::Positive(atom(
            "reports_to",
            vec![var("Emp"), var("Mgr")],
        ))],
    });

    // big_manager(Mgr, Count) :- direct_count(Mgr, Count), Count >= 2.
    program.rules.push(Rule {
        head: atom("big_manager", vec![var("Mgr"), var("Count")]),
        body: vec![
            BodyLiteral::Positive(atom("direct_count", vec![var("Mgr"), var("Count")])),
            BodyLiteral::Comparison(Comparison {
                left: var("Count"),
                op: CompOp::Ge,
                right: Term::Integer(2),
            }),
        ],
    });

    let mut lowerer = Lowerer::new();
    lowerer.infer_schemas(&program).unwrap();

    // The inferred schema must reflect the declared column types.
    let schema = lowerer
        .schemas
        .get("direct_count")
        .expect("schema for direct_count");
    let expected_columns = [
        (0, ScalarType::Symbol, "First column should be Symbol"),
        (1, ScalarType::U64, "Second column should be U64"),
    ];
    for (index, expected, message) in expected_columns {
        assert_eq!(schema.column_type(index), Some(expected), "{}", message);
    }

    lowerer.set_strata(vec![
        vec!["reports_to".to_string()],
        vec!["direct_count".to_string()],
        vec!["big_manager".to_string()],
    ]);
    lowerer.build_sccs(&program);

    // Index 3 is the big_manager rule: two facts, then the aggregation rule, precede it.
    let result = lowerer.lower_rule(&program.rules[3]);
    assert!(
        result.is_ok(),
        "Lowering should succeed: {:?}",
        result.err()
    );
    let node = result.unwrap();

    // The comparison constant should be U64(2), not I64(2).
    match find_compare_const(&node) {
        None => panic!("Should find a constant in comparison"),
        Some(ConstValue::U64(v)) => assert_eq!(*v, 2, "Value should be 2"),
        Some(other) => panic!("Expected U64(2), got {:?}", other),
    }
}
3436}