// xlog_logic/compile.rs

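//! Compilation pipeline for XLOG programs
//!
//! This module provides the main entry point for compiling XLOG source code
//! into execution plans. The core phases of the compilation process are:
//!
//! 1. **Parsing**: Convert source text to AST (`parser::parse_program`)
//! 2. **Stratification**: Analyze negation/aggregation dependencies (`stratify::stratify`)
//! 3. **Lowering**: Transform AST to Relational IR (`lower::Lowerer::lower_program`)
//!
//! Between parsing and stratification the AST is desugared (queries and
//! constraints become rules), meta and list builtins are normalized, the
//! magic-sets rewrite is applied, and negation safety is validated. After
//! lowering, the plan runs through the helper-split pass, the optimizer, the
//! selectivity pass, and multiway-join promotion.
//!
//! The `Compiler` struct orchestrates these phases and provides a single
//! entry point via the `compile` method.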
12
13use std::path::{Path, PathBuf};
14
15use xlog_core::{Result, XlogError};
16use xlog_ir::ExecutionPlan;
17use xlog_stats::{StatsManager, StatsSnapshot};
18
19use crate::compiler_config::CompilerConfig;
20use crate::list_normalize::normalize_list_builtins;
21use crate::lower::Lowerer;
22use crate::magic_sets::rewrite_magic_sets;
23use crate::meta_normalize::normalize_meta_builtins;
24use crate::module::ModuleError;
25use crate::optimizer::Optimizer;
26use crate::parser::parse_program;
27use crate::resolver::ModuleResolver;
28use crate::stratify::stratify;
29use crate::{BodyLiteral, Program, Query, Rule as AstRule, Term};
30
/// The XLOG compiler orchestrates the full compilation pipeline.
///
/// A `Compiler` owns a single [`Lowerer`]. The relation IDs and schemas held
/// by that lowerer are exposed through [`Compiler::rel_ids`] and
/// [`Compiler::schemas`]; [`Compiler::reset`] swaps in a fresh lowerer.
///
/// # Example
///
/// ```ignore
/// use xlog_logic::compile::Compiler;
///
/// let mut compiler = Compiler::new();
/// let plan = compiler.compile(r#"
///     edge(1, 2).
///     edge(2, 3).
///     reach(X, Y) :- edge(X, Y).
///     reach(X, Z) :- reach(X, Y), edge(Y, Z).
/// "#)?;
/// ```
pub struct Compiler {
    /// Lowering state; also the backing store for `rel_ids()` and `schemas()`.
    lowerer: Lowerer,
}
49
50use std::collections::{HashMap, HashSet};
51use std::sync::Arc;
52use xlog_core::{RelId, Schema};
53
54impl Default for Compiler {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl Compiler {
61    /// Create a new compiler instance.
62    pub fn new() -> Self {
63        Self {
64            lowerer: Lowerer::new(),
65        }
66    }
67
68    /// Set the maximum active rules for TensorMaskedJoin (16..=128).
69    pub fn set_max_active_rules(&mut self, max: usize) {
70        self.lowerer.set_max_active_rules(max);
71    }
72
73    /// Compile XLOG source code into an execution plan.
74    ///
75    /// This is the main entry point for compilation. It chains together:
76    /// 1. Parsing (source → AST)
77    /// 2. Stratification (analyze dependencies, check for cycles)
78    /// 3. Lowering (AST → Relational IR execution plan)
79    ///
80    /// # Arguments
81    ///
82    /// * `source` - The XLOG source code as a string
83    ///
84    /// # Returns
85    ///
86    /// * `Ok(ExecutionPlan)` - The compiled execution plan ready for execution
87    /// * `Err(XlogError)` - If any compilation phase fails:
88    ///   - `XlogError::Parse` - Syntax errors in the source
89    ///   - `XlogError::StratificationCycle` - Unstratifiable negation/aggregation
90    ///   - `XlogError::Compilation` - Other semantic errors
91    ///
92    /// # Example
93    ///
94    /// ```ignore
95    /// let mut compiler = Compiler::new();
96    ///
97    /// // Compile a simple transitive closure program
98    /// let plan = compiler.compile(r#"
99    ///     edge(1, 2).
100    ///     edge(2, 3).
101    ///     reach(X, Y) :- edge(X, Y).
102    ///     reach(X, Z) :- reach(X, Y), edge(Y, Z).
103    /// "#)?;
104    ///
105    /// // The plan can now be executed by xlog-runtime
106    /// ```
107    pub fn compile(&mut self, source: &str) -> Result<ExecutionPlan> {
108        self.compile_with_stats_snapshot(source, None)
109    }
110
111    /// Compile XLOG source code into an execution plan, optionally seeding the optimizer
112    /// with a runtime statistics snapshot.
113    ///
114    /// This entry point delegates through the composable config-aware API
115    /// with `CompilerConfig::default()`, which preserves existing triangle,
116    /// 4-cycle, recursive, and selectivity-aware dispatch behavior
117    /// bit-identically.
118    pub fn compile_with_stats_snapshot(
119        &mut self,
120        source: &str,
121        stats_snapshot: Option<&StatsSnapshot>,
122    ) -> Result<ExecutionPlan> {
123        self.compile_with_config_and_stats_snapshot(
124            source,
125            &CompilerConfig::default(),
126            stats_snapshot,
127        )
128    }
129
130    /// Composable entry point that accepts a `CompilerConfig`.
131    ///
132    /// Default-config callers should keep using `compile()` /
133    /// `compile_with_stats_snapshot()`. This entry point exists so callers can
134    /// enable the variable-ordering cost model per call without an environment
135    /// override.
136    pub fn compile_with_config_and_stats_snapshot(
137        &mut self,
138        source: &str,
139        config: &CompilerConfig,
140        stats_snapshot: Option<&StatsSnapshot>,
141    ) -> Result<ExecutionPlan> {
142        let program = parse_program(source)?;
143        self.compile_program_with_config_and_stats_snapshot(&program, config, stats_snapshot)
144    }
145
146    /// Compile a parsed XLOG program into an execution plan.
147    ///
148    /// This is useful for callers that want to inspect the AST (facts, queries,
149    /// constraints) while compiling without reparsing.
150    pub fn compile_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
151        self.compile_program_with_stats_snapshot(program, None)
152    }
153
154    /// Compile a parsed XLOG program into an execution plan, optionally seeding the optimizer.
155    ///
156    /// Delegates to [`Self::compile_program_with_config_and_stats_snapshot`]
157    /// with `CompilerConfig::default()`.
158    pub fn compile_program_with_stats_snapshot(
159        &mut self,
160        program: &Program,
161        stats_snapshot: Option<&StatsSnapshot>,
162    ) -> Result<ExecutionPlan> {
163        self.compile_program_with_config_and_stats_snapshot(
164            program,
165            &CompilerConfig::default(),
166            stats_snapshot,
167        )
168    }
169
170    /// Composable program-level entry point.
171    ///
172    /// `config` is currently consumed only by the promoter when it wires the
173    /// variable-ordering cost model. With `CompilerConfig::default()`, the
174    /// promoter keeps the default variable order.
175    pub fn compile_program_with_config_and_stats_snapshot(
176        &mut self,
177        program: &Program,
178        config: &CompilerConfig,
179        stats_snapshot: Option<&StatsSnapshot>,
180    ) -> Result<ExecutionPlan> {
181        let program = desugar_queries_and_constraints(program);
182        let program = normalize_meta_builtins(&program)?;
183        let program = normalize_list_builtins(&program)?;
184        let program = rewrite_magic_sets(&program)?.program;
185        validate_negation_safety(&program)?;
186
187        // Phase 2: Stratify (analyze dependencies, detect cycles)
188        let strata = stratify(&program).map_err(map_stratification_to_naf_error)?;
189
190        // Convert strata to the format expected by the lowerer
191        let strata_preds: Vec<Vec<String>> = strata.into_iter().map(|s| s.predicates).collect();
192
193        // Phase 3: Lower AST to execution plan
194        self.lowerer.set_strata(strata_preds);
195
196        // If we have predicate names for the snapshot, use them to seed lowering-time
197        // join ordering with better cardinality estimates.
198        let mut cardinality_hints: HashMap<String, u64> = HashMap::new();
199        if let Some(snapshot) = stats_snapshot {
200            if !snapshot.rel_names.is_empty() {
201                let rel_name_by_id: HashMap<RelId, &str> = snapshot
202                    .rel_names
203                    .iter()
204                    .map(|(id, name)| (*id, name.as_str()))
205                    .collect();
206                for rel in &snapshot.relations {
207                    if let Some(name) = rel_name_by_id.get(&rel.rel_id) {
208                        cardinality_hints.insert((*name).to_string(), rel.cardinality);
209                    }
210                }
211            }
212        }
213        self.lowerer.set_cardinality_hints(cardinality_hints);
214
215        let mut plan = self.lowerer.lower_program(&program)?;
216
217        // Phase 4: Optimize (predicate pushdown + cost-aware rewrites)
218        //
219        // Seed statistics with any known fact cardinalities so cost estimation has
220        // at least a baseline for EDB relations.
221        let mut mgr = StatsManager::new();
222        let mut fact_counts: HashMap<String, u64> = HashMap::new();
223        for fact in program.facts() {
224            *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
225        }
226
227        for (pred, rel_id) in self.lowerer.rel_ids() {
228            mgr.register_relation(*rel_id);
229            let rows = fact_counts.get(pred).copied().unwrap_or(0);
230            if rows > 0 {
231                mgr.update_cardinality(*rel_id, rows);
232                if let Some(schema) = self.lowerer.schemas().get(pred) {
233                    mgr.update_byte_size(*rel_id, rows * schema.row_size_bytes() as u64);
234                }
235            }
236        }
237
238        if let Some(snapshot) = stats_snapshot {
239            if snapshot.rel_names.is_empty() {
240                mgr.merge_snapshot(snapshot);
241            } else {
242                let rel_name_by_id: HashMap<RelId, &str> = snapshot
243                    .rel_names
244                    .iter()
245                    .map(|(id, name)| (*id, name.as_str()))
246                    .collect();
247
248                for rel in &snapshot.relations {
249                    let Some(pred) = rel_name_by_id.get(&rel.rel_id) else {
250                        continue;
251                    };
252                    let Some(rel_id) = self.lowerer.rel_ids().get(*pred) else {
253                        continue;
254                    };
255
256                    let mut remapped = rel.clone();
257                    remapped.rel_id = *rel_id;
258
259                    if let Some(schema) = self.lowerer.schemas().get(*pred) {
260                        remapped.column_stats.retain(|col| {
261                            col.col_idx < schema.arity()
262                                && schema.column_type(col.col_idx) == Some(col.dtype)
263                        });
264                    } else {
265                        remapped.column_stats.clear();
266                    }
267
268                    mgr.register_relation(*rel_id);
269                    if let Some(stats) = mgr.get_relation_stats_mut(*rel_id) {
270                        *stats = remapped;
271                    }
272                }
273
274                for js in &snapshot.join_selectivities {
275                    if js.left_keys.len() != js.right_keys.len() {
276                        continue;
277                    }
278
279                    let Some(left_pred) = rel_name_by_id.get(&js.left_rel) else {
280                        continue;
281                    };
282                    let Some(right_pred) = rel_name_by_id.get(&js.right_rel) else {
283                        continue;
284                    };
285                    let Some(&left_id) = self.lowerer.rel_ids().get(*left_pred) else {
286                        continue;
287                    };
288                    let Some(&right_id) = self.lowerer.rel_ids().get(*right_pred) else {
289                        continue;
290                    };
291
292                    let Some(left_schema) = self.lowerer.schemas().get(*left_pred) else {
293                        continue;
294                    };
295                    let Some(right_schema) = self.lowerer.schemas().get(*right_pred) else {
296                        continue;
297                    };
298                    if js.left_keys.iter().any(|&k| k >= left_schema.arity())
299                        || js.right_keys.iter().any(|&k| k >= right_schema.arity())
300                    {
301                        continue;
302                    }
303
304                    mgr.set_join_selectivity(
305                        left_id,
306                        right_id,
307                        js.left_keys.clone(),
308                        js.right_keys.clone(),
309                        js.selectivity,
310                    );
311                }
312            }
313        }
314
315        // Build schemas by RelId for the optimizer
316        let schemas_by_rel_id: HashMap<RelId, Schema> = self
317            .lowerer
318            .rel_ids()
319            .iter()
320            .filter_map(|(pred, rel_id)| {
321                self.lowerer
322                    .schemas()
323                    .get(pred)
324                    .map(|schema| (*rel_id, schema.clone()))
325            })
326            .collect();
327
328        let stats_arc = Arc::new(mgr);
329
330        crate::optimizer::helper_split_pass::run(
331            &mut plan,
332            &schemas_by_rel_id,
333            &stats_arc,
334            |schema| self.lowerer.create_helper_relation(schema),
335        );
336
337        let schemas_by_rel_id: HashMap<RelId, Schema> = self
338            .lowerer
339            .rel_ids()
340            .iter()
341            .filter_map(|(pred, rel_id)| {
342                self.lowerer
343                    .schemas()
344                    .get(pred)
345                    .map(|schema| (*rel_id, schema.clone()))
346            })
347            .collect();
348
349        let mut optimizer = Optimizer::new(Arc::clone(&stats_arc));
350        optimizer.set_schemas(schemas_by_rel_id);
351        for rules in &mut plan.rules_by_scc {
352            for rule in rules {
353                rule.body = optimizer.optimize(rule.body.clone());
354            }
355        }
356
357        // Selectivity-aware reordering pass. Runs BETWEEN the optimizer loop
358        // and promote_multiway.
359        // Locked compile-pipeline ordering:
360        //   lower → helper_split_pass → optimizer → selectivity_pass → promote_multiway
361        //
362        // Takes `rel_ids` so per-body Scans can be resolved against
363        // `StatsManager`. Behavior on empty stats / unseeded relations is
364        // no-op (safety floor).
365        crate::optimizer::selectivity_pass::run(&mut plan, &stats_arc, self.lowerer.rel_ids());
366
367        // Promote eligible triangle subtrees to RirNode::MultiWayJoin. Runs
368        // *after* the optimizer so the optimizer never has to learn the new
369        // variant. Fallback identity preserves binary-join semantics on
370        // dispatch decline.
371        //
372        // Pass the lowerer's predicate→RelId map so the promoter can gate
373        // recursive-SCC bodies on the count of in-SCC Scans (≤ 1 = promote,
374        // ≥ 2 = skip).
375        //
376        // Also pass `&stats_arc` and the caller-provided `&CompilerConfig`.
377        // With `CompilerConfig::default()` (`Disabled`), the promoter never
378        // sets `var_order` and default dispatch is bit-identical.
379        crate::promote::promote_multiway(&mut plan, self.lowerer.rel_ids(), &stats_arc, config);
380
381        let schemas_by_rel_id: HashMap<RelId, Schema> = self
382            .lowerer
383            .rel_ids()
384            .iter()
385            .filter_map(|(pred, rel_id)| {
386                self.lowerer
387                    .schemas()
388                    .get(pred)
389                    .map(|schema| (*rel_id, schema.clone()))
390            })
391            .collect();
392
393        crate::optimizer::helper_split_pass::run_kclique_specs(
394            &mut plan,
395            &schemas_by_rel_id,
396            |schema| self.lowerer.create_helper_relation(schema),
397        );
398
399        Ok(plan)
400    }
401
402    /// Reset the compiler state for a fresh compilation.
403    ///
404    /// This creates a new lowerer, clearing any cached schemas or relation IDs
405    /// from previous compilations.
406    pub fn reset(&mut self) {
407        self.lowerer = Lowerer::new();
408    }
409
410    /// Get the mapping from predicate names to relation IDs after compilation.
411    ///
412    /// This mapping is needed to register relations in the executor with
413    /// the correct RelIds.
414    pub fn rel_ids(&self) -> &HashMap<String, RelId> {
415        self.lowerer.rel_ids()
416    }
417
418    /// Get the inferred schemas for predicates after compilation.
419    ///
420    /// These schemas are needed to create GPU buffers with correct column types.
421    pub fn schemas(&self) -> &HashMap<String, Schema> {
422        self.lowerer.schemas()
423    }
424}
425
426fn desugar_queries_and_constraints(program: &Program) -> Program {
427    let mut out = program.clone();
428
429    // Constraints: `:- body.` becomes `__xlog_constraint_i(1) :- body.`
430    for (i, constraint) in program.constraints.iter().enumerate() {
431        let pred = format!("__xlog_constraint_{}", i);
432        out.rules.push(AstRule {
433            head: crate::ast::Atom {
434                predicate: pred,
435                terms: vec![Term::Integer(1)],
436            },
437            body: constraint.body.clone(),
438        });
439    }
440
441    // Queries: `?- atom.` becomes `__xlog_query_i(Vars...) :- atom.`
442    for (i, Query { atom }) in program.queries.iter().enumerate() {
443        let pred = format!("__xlog_query_{}", i);
444
445        let mut head_terms: Vec<Term> = Vec::new();
446        let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
447
448        for term in &atom.terms {
449            for name in term.variables() {
450                if seen.insert(name) {
451                    head_terms.push(Term::Variable(name.to_string()));
452                }
453            }
454        }
455
456        if head_terms.is_empty() {
457            head_terms.push(Term::Integer(1));
458        }
459
460        out.rules.push(AstRule {
461            head: crate::ast::Atom {
462                predicate: pred,
463                terms: head_terms,
464            },
465            body: vec![BodyLiteral::Positive(atom.clone())],
466        });
467    }
468
469    out
470}
471
472fn validate_negation_safety(program: &Program) -> Result<()> {
473    for rule in &program.rules {
474        validate_body_naf_safety(&rule.body, &format!("rule {}", rule.head.predicate))?;
475    }
476    for (idx, constraint) in program.constraints.iter().enumerate() {
477        validate_body_naf_safety(&constraint.body, &format!("constraint {}", idx))?;
478    }
479    for (idx, learnable) in program.learnable_rules.iter().enumerate() {
480        validate_body_naf_safety(&learnable.body, &format!("learnable rule {}", idx))?;
481    }
482    Ok(())
483}
484
485fn validate_body_naf_safety(body: &[BodyLiteral], context: &str) -> Result<()> {
486    let mut bound: HashSet<String> = HashSet::new();
487    for lit in body {
488        match lit {
489            BodyLiteral::Positive(atom) => {
490                for name in atom.variables() {
491                    bound.insert(name.to_string());
492                }
493            }
494            BodyLiteral::Negated(atom) => {
495                for name in atom.variables() {
496                    if !bound.contains(name) {
497                        return Err(naf_error(format!(
498                            "unbound variable {} in negated atom {}/{} in {}; bind it before not with a positive atom or deterministic is expression, or use '_' for existential positions",
499                            name,
500                            atom.predicate,
501                            atom.arity(),
502                            context
503                        )));
504                    }
505                }
506            }
507            BodyLiteral::IsExpr(is_expr) => {
508                bound.insert(is_expr.target.clone());
509            }
510            BodyLiteral::Epistemic(_) => {}
511            BodyLiteral::Comparison(_) | BodyLiteral::Univ(_) => {}
512        }
513    }
514    Ok(())
515}
516
517fn map_stratification_to_naf_error(err: XlogError) -> XlogError {
518    match err {
519        XlogError::StratificationCycle(cycle) => naf_error(format!(
520            "deterministic not atom must be stratified; cycle through negation or aggregation: {}",
521            cycle.join(" -> ")
522        )),
523        other => other,
524    }
525}
526
527fn naf_error(message: impl Into<String>) -> XlogError {
528    XlogError::Compilation(format!("negation safety error: {}", message.into()))
529}
530
531/// Convenience function to compile source in one call.
532///
533/// This creates a short-lived compiler and compiles the source.
534/// For multiple compilations, prefer creating a `Compiler` instance directly.
535///
536/// # Example
537///
538/// ```ignore
539/// use xlog_logic::compile::compile;
540///
541/// let plan = compile("edge(1, 2). reach(X, Y) :- edge(X, Y).")?;
542/// ```
543pub fn compile(source: &str) -> Result<ExecutionPlan> {
544    let mut compiler = Compiler::new();
545    compiler.compile(source)
546}
547
548/// Load and validate modules for a source file.
549///
550/// This function:
551/// 1. Determines the module path from the entry file name
552/// 2. Loads the entry module and all its dependencies
553/// 3. Validates imports (checks for conflicts, private predicates, etc.)
554///
555/// # Arguments
556///
557/// * `entry_file` - Path to the main .xlog file
558/// * `search_paths` - Additional directories to search for modules
559///
560/// # Returns
561///
562/// The loaded module resolver with all dependencies resolved, or an error
563/// if module resolution fails.
564pub fn load_modules(
565    entry_file: &Path,
566    search_paths: Vec<PathBuf>,
567) -> std::result::Result<ModuleResolver, ModuleError> {
568    let mut resolver = ModuleResolver::new(search_paths);
569
570    // Determine base directory and module path
571    let base_dir = entry_file.parent().unwrap_or(Path::new("."));
572    let module_name = entry_file
573        .file_stem()
574        .and_then(|s| s.to_str())
575        .unwrap_or("main");
576
577    // Load entry module (recursively loads dependencies)
578    resolver.load_module(base_dir, &[module_name.to_string()])?;
579
580    Ok(resolver)
581}
582
583#[cfg(test)]
584mod tests {
585    use super::*;
586    use xlog_core::ScalarType;
587    use xlog_ir::RirNode;
588    use xlog_stats::ColumnStats;
589    use xlog_stats::RelationStats;
590    use xlog_stats::StatsManager;
591
592    #[test]
593    fn test_compiler_new() {
594        let compiler = Compiler::new();
595        // Just verify it can be created
596        drop(compiler);
597    }
598
599    #[test]
600    fn test_compile_fact() {
601        let mut compiler = Compiler::new();
602        let result = compiler.compile("edge(1, 2).");
603        assert!(result.is_ok(), "Failed to compile fact: {:?}", result.err());
604    }
605
606    #[test]
607    fn test_compile_simple_rule() {
608        let mut compiler = Compiler::new();
609        let result = compiler.compile(
610            r#"
611            edge(1, 2).
612            reach(X, Y) :- edge(X, Y).
613        "#,
614        );
615        assert!(
616            result.is_ok(),
617            "Failed to compile simple rule: {:?}",
618            result.err()
619        );
620
621        let plan = result.unwrap();
622        assert!(!plan.sccs.is_empty(), "Expected at least one SCC");
623    }
624
625    #[test]
626    fn test_compile_transitive_closure() {
627        let mut compiler = Compiler::new();
628        let result = compiler.compile(
629            r#"
630            edge(1, 2).
631            edge(2, 3).
632            edge(3, 4).
633            reach(X, Y) :- edge(X, Y).
634            reach(X, Z) :- reach(X, Y), edge(Y, Z).
635        "#,
636        );
637        assert!(result.is_ok(), "Failed to compile TC: {:?}", result.err());
638
639        let plan = result.unwrap();
640        // Should have SCCs for edge and reach
641        assert!(!plan.sccs.is_empty());
642    }
643
644    #[test]
645    fn test_compile_with_negation() {
646        let mut compiler = Compiler::new();
647        let result = compiler.compile(
648            r#"
649            node(1).
650            node(2).
651            node(3).
652            edge(1, 2).
653            isolated(X) :- node(X), not edge(X, _).
654        "#,
655        );
656        assert!(
657            result.is_ok(),
658            "Failed to compile with negation: {:?}",
659            result.err()
660        );
661    }
662
663    #[test]
664    fn test_compile_with_comparison() {
665        let mut compiler = Compiler::new();
666        let result = compiler.compile(
667            r#"
668            value(1).
669            value(5).
670            value(10).
671            value(15).
672            small(X) :- value(X), X < 10.
673        "#,
674        );
675        assert!(
676            result.is_ok(),
677            "Failed to compile with comparison: {:?}",
678            result.err()
679        );
680    }
681
682    #[test]
683    fn test_schema_infers_from_rule_body_types() {
684        let mut compiler = Compiler::new();
685        let result = compiler.compile(
686            r#"
687            edge(1, 2).
688            edge(2, 3).
689            reach(X, Y) :- edge(X, Y).
690        "#,
691        );
692        assert!(
693            result.is_ok(),
694            "Failed to compile rule for schema inference: {:?}",
695            result.err()
696        );
697
698        let schema = compiler
699            .schemas()
700            .get("reach")
701            .expect("missing reach schema");
702        assert_eq!(
703            schema.column_type(0),
704            Some(ScalarType::U32),
705            "reach column 0 should match edge column type"
706        );
707        assert_eq!(
708            schema.column_type(1),
709            Some(ScalarType::U32),
710            "reach column 1 should match edge column type"
711        );
712    }
713
714    #[test]
715    fn test_compile_unstratifiable_fails() {
716        let mut compiler = Compiler::new();
717        let result = compiler.compile(
718            r#"
719            p :- not q.
720            q :- not p.
721        "#,
722        );
723        assert!(result.is_err(), "Should fail with stratification cycle");
724    }
725
726    #[test]
727    fn test_compile_syntax_error_fails() {
728        let mut compiler = Compiler::new();
729        let result = compiler.compile("edge(1, 2"); // Missing closing paren and period
730        assert!(result.is_err(), "Should fail with syntax error");
731    }
732
733    #[test]
734    fn test_compile_convenience_function() {
735        let result = compile("edge(1, 2).");
736        assert!(
737            result.is_ok(),
738            "Convenience compile failed: {:?}",
739            result.err()
740        );
741    }
742
743    #[test]
744    fn test_compiler_reset() {
745        let mut compiler = Compiler::new();
746
747        // First compilation
748        let result1 = compiler.compile("edge(1, 2).");
749        assert!(result1.is_ok());
750
751        // Reset and compile again
752        compiler.reset();
753        let result2 = compiler.compile("node(1). node(2).");
754        assert!(result2.is_ok());
755    }
756
757    #[test]
758    fn test_compile_with_pred_decl() {
759        let mut compiler = Compiler::new();
760        let result = compiler.compile(
761            r#"
762            pred edge(u32, u32).
763            edge(1, 2).
764            edge(2, 3).
765            reach(X, Y) :- edge(X, Y).
766        "#,
767        );
768        assert!(
769            result.is_ok(),
770            "Failed to compile with pred decl: {:?}",
771            result.err()
772        );
773    }
774
775    #[test]
776    fn test_compile_multi_stratum() {
777        let mut compiler = Compiler::new();
778        let result = compiler.compile(
779            r#"
780            // Base facts
781            edge(1, 2).
782            edge(2, 3).
783            edge(3, 1).
784
785            // Stratum 0: edge (base)
786            // Stratum 1: reach (depends on edge, recursive)
787            reach(X, Y) :- edge(X, Y).
788            reach(X, Z) :- reach(X, Y), edge(Y, Z).
789
790            // Stratum 2: non_reach (negates reach)
791            all_pairs(X, Y) :- edge(X, Z), edge(Y, W).
792            non_reach(X, Y) :- all_pairs(X, Y), not reach(X, Y).
793        "#,
794        );
795        assert!(
796            result.is_ok(),
797            "Failed to compile multi-stratum: {:?}",
798            result.err()
799        );
800
801        let plan = result.unwrap();
802        // Should have multiple strata
803        assert!(!plan.strata.is_empty(), "Expected multiple strata");
804    }
805
806    #[test]
807    fn test_compile_aggregation() {
808        let mut compiler = Compiler::new();
809        let result = compiler.compile(
810            r#"
811            edge(1, 2).
812            edge(1, 3).
813            edge(2, 3).
814            out_degree(X, count(Y)) :- edge(X, Y).
815        "#,
816        );
817        assert!(
818            result.is_ok(),
819            "Failed to compile with aggregation: {:?}",
820            result.err()
821        );
822
823        let plan = result.unwrap();
824        let out_degree_rules: Vec<_> = plan
825            .rules_by_scc
826            .iter()
827            .flatten()
828            .filter(|r| r.head == "out_degree")
829            .collect();
830        assert_eq!(out_degree_rules.len(), 1, "Expected one out_degree rule");
831
832        // Aggregation lowering should produce a GroupBy node (wrapped in a Project to match head order).
833        let body = &out_degree_rules[0].body;
834        match body {
835            RirNode::Project { input, .. } => {
836                assert!(
837                    matches!(input.as_ref(), RirNode::GroupBy { .. }),
838                    "Expected Project(GroupBy(..)), got {:?}",
839                    input
840                );
841            }
842            other => panic!("Expected Project(GroupBy(..)), got {:?}", other),
843        }
844    }
845
846    #[test]
847    fn test_compile_with_stats_snapshot() {
848        let mut compiler = Compiler::new();
849        let source = r#"
850            edge(1, 2).
851            edge(2, 3).
852            reach(X, Y) :- edge(X, Y).
853        "#;
854
855        let _ = compiler.compile(source).expect("Initial compile failed");
856        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");
857
858        let mut mgr = StatsManager::new();
859        mgr.register_relation(edge_id);
860        mgr.update_cardinality(edge_id, 42);
861        let snapshot = mgr.snapshot();
862
863        let plan = compiler
864            .compile_with_stats_snapshot(source, Some(&snapshot))
865            .expect("Compile with snapshot failed");
866        assert!(!plan.sccs.is_empty());
867    }
868
/// A snapshot that carries `rel_names` is applied by relation name, and its
/// cardinalities decide the join order of `out(X) :- edge(X), foo(X).`
///
/// The snapshot marks `edge` as small (10 rows) and `foo` as big (10 000 rows).
/// Whether the optimizer emits a plain `Join` or promotes it to a `ChainJoin`,
/// the test expects `foo` to be scanned on the left and `edge` on the right;
/// for a `ChainJoin` the captured fallback join must show the same order.
#[test]
fn test_compile_with_named_stats_snapshot_reorders_joins() {
    /// Skips any chain of `Project` wrappers and returns the first other node.
    fn strip_projections(mut node: &RirNode) -> &RirNode {
        while let RirNode::Project { input, .. } = node {
            node = input;
        }
        node
    }

    let mut compiler = Compiler::new();
    let source = r#"
            foo(1).
            edge(1).
            out(X) :- edge(X), foo(X).
        "#;

    // Snapshot uses different RelIds than the compiler will assign for this program.
    // Map: RelId(0) -> edge (small), RelId(1) -> foo (big)
    let mut edge_stats = RelationStats::new(RelId(0));
    edge_stats.update_cardinality(10);
    let mut foo_stats = RelationStats::new(RelId(1));
    foo_stats.update_cardinality(10_000);

    let snapshot = StatsSnapshot {
        relations: vec![edge_stats, foo_stats],
        join_selectivities: Vec::new(),
        rel_names: vec![
            (RelId(0), "edge".to_string()),
            (RelId(1), "foo".to_string()),
        ],
    };

    let plan = compiler
        .compile_with_stats_snapshot(source, Some(&snapshot))
        .expect("Compile with named snapshot failed");

    let foo_id = *compiler.rel_ids().get("foo").expect("foo rel_id missing");
    let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");

    // One shared check for every join shape: `foo` on the left, the smaller
    // `edge` on the right (build side). The context label and the offending
    // node are included in the failure message so that an ordering regression
    // can be diagnosed from the test output alone.
    let assert_foo_then_edge = |left: &RirNode, right: &RirNode, context: &str| {
        assert!(
            matches!(*left, RirNode::Scan { rel } if rel == foo_id),
            "{}: expected left input to scan `foo`, got {:?}",
            context,
            left
        );
        assert!(
            matches!(*right, RirNode::Scan { rel } if rel == edge_id),
            "{}: expected right input to scan `edge`, got {:?}",
            context,
            right
        );
    };

    let out_rule = plan
        .rules_by_scc
        .iter()
        .flatten()
        .find(|r| r.head == "out")
        .expect("out rule missing");

    // Peel projections to reach the join.
    match strip_projections(&out_rule.body) {
        RirNode::ChainJoin {
            left,
            right,
            fallback,
            ..
        } => {
            // ChainJoin promotion wraps eligible two-atom joins after
            // stats-aware ordering. The chain node and its captured
            // fallback must agree on the build-side choice.
            assert_foo_then_edge(&**left, &**right, "ChainJoin");

            match strip_projections(fallback.as_ref()) {
                RirNode::Join { left, right, .. } => {
                    assert_foo_then_edge(&**left, &**right, "ChainJoin fallback Join");
                }
                other => panic!("Expected ChainJoin fallback Join node, got {:?}", other),
            }
        }
        RirNode::Join { left, right, .. } => {
            // Prefer building on the smaller relation (right/build side).
            assert_foo_then_edge(&**left, &**right, "Join");
        }
        other => panic!("Expected Join node, got {:?}", other),
    }
}
947
/// XLOG program shared by the helper-split tests: six binary fact relations
/// (`ab`, `bc`, `cd`, `de`, `ef`, `af`) and a single `out` rule whose body
/// joins all six of them.
fn helper_split_source() -> &'static str {
    const HELPER_SPLIT_PROGRAM: &str = r#"
            ab(0, 0). bc(0, 0). cd(0, 0). de(0, 0). ef(0, 0). af(0, 0).
            out(A, B, C, D, F) :-
                ab(A, B),
                bc(B, C),
                cd(C, D),
                de(D, E),
                ef(E, F),
                af(A, F).
        "#;
    HELPER_SPLIT_PROGRAM
}
960
961    fn helper_split_snapshot(distinct_d: u64) -> StatsSnapshot {
962        let mut snapshot_relations = Vec::new();
963        for (idx, name) in ["ab", "bc", "cd", "de", "ef", "af"].iter().enumerate() {
964            let mut rel_stats = RelationStats::new(RelId(idx as u32));
965            rel_stats.update_cardinality(8192);
966            if *name == "de" {
967                let mut d_col = ColumnStats::new(0, ScalarType::U32);
968                d_col.update_distinct(distinct_d);
969                rel_stats.add_column(d_col);
970            }
971            snapshot_relations.push(rel_stats);
972        }
973        StatsSnapshot {
974            relations: snapshot_relations,
975            join_selectivities: Vec::new(),
976            rel_names: ["ab", "bc", "cd", "de", "ef", "af"]
977                .iter()
978                .enumerate()
979                .map(|(idx, name)| (RelId(idx as u32), (*name).to_string()))
980                .collect(),
981        }
982    }
983
/// With `helper_split_snapshot(1)` (a single distinct value recorded for the
/// `de` column) the compiler is expected to allocate a `__kclique_helper_*`
/// relation, emit exactly one rule for it whose body is a `ChainJoin`, and
/// have the `out` rule scan that helper relation.
#[test]
fn test_compile_with_named_stats_snapshot_creates_helper_relation() {
    let snapshot = helper_split_snapshot(1);
    let mut compiler = Compiler::new();
    let plan = compiler
        .compile_with_stats_snapshot(helper_split_source(), Some(&snapshot))
        .expect("compile with helper stats");

    // Find the synthesized helper relation by its name prefix.
    let (helper_name, helper_rel) = compiler
        .rel_ids()
        .iter()
        .find(|(name, _)| name.starts_with("__kclique_helper_"))
        .map(|(name, rel)| (name.clone(), *rel))
        .expect("helper relation allocated");

    // Exactly one rule may define the helper relation.
    let helper_rules: Vec<_> = plan
        .rules_by_scc
        .iter()
        .flatten()
        .filter(|rule| rule.head == helper_name)
        .collect();
    assert_eq!(helper_rules.len(), 1);

    let helper_rule = helper_rules[0];
    assert!(
        matches!(helper_rule.body, RirNode::ChainJoin { .. }),
        "helper split output should be eligible for ChainJoin promotion"
    );

    // The rewritten `out` rule has to read from the helper relation.
    let out_rule = plan
        .rules_by_scc
        .iter()
        .flatten()
        .find(|rule| rule.head == "out")
        .expect("out rule");
    assert!(contains_scan(&out_rule.body, helper_rel));
}
1027
/// With `helper_split_snapshot(8192)` (the `de` column's distinct count equals
/// the 8192 cardinality used for every relation) no `__kclique_helper_*`
/// relation may be allocated and `out` must remain a single rule.
#[test]
fn test_compile_with_flat_named_stats_keeps_original_rule() {
    let snapshot = helper_split_snapshot(8192);
    let mut compiler = Compiler::new();
    let plan = compiler
        .compile_with_stats_snapshot(helper_split_source(), Some(&snapshot))
        .expect("compile with flat stats");

    // No helper relation should have been registered with the compiler.
    let helper_allocated = compiler
        .rel_ids()
        .keys()
        .any(|name| name.starts_with("__kclique_helper_"));
    assert!(!helper_allocated);

    // The original `out` rule is kept as the only rule for that head.
    let out_rule_count = plan
        .rules_by_scc
        .iter()
        .flatten()
        .filter(|rule| rule.head == "out")
        .count();
    assert_eq!(out_rule_count, 1);
}
1048
1049    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
1050        match node {
1051            RirNode::Scan { rel: scan_rel } => *scan_rel == rel,
1052            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
1053                contains_scan(left, rel) || contains_scan(right, rel)
1054            }
1055            RirNode::Project { input, .. }
1056            | RirNode::Filter { input, .. }
1057            | RirNode::Distinct { input, .. }
1058            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
1059            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
1060            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
1061            RirNode::Fixpoint {
1062                base, recursive, ..
1063            } => contains_scan(base, rel) || contains_scan(recursive, rel),
1064            RirNode::MultiWayJoin { inputs, .. } => {
1065                inputs.iter().any(|input| contains_scan(input, rel))
1066            }
1067            RirNode::TensorMaskedJoin { rel_index, .. } => {
1068                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
1069            }
1070            RirNode::Unit => false,
1071        }
1072    }
1073}