1use std::path::{Path, PathBuf};
14
15use xlog_core::{Result, XlogError};
16use xlog_ir::ExecutionPlan;
17use xlog_stats::{StatsManager, StatsSnapshot};
18
19use crate::compiler_config::CompilerConfig;
20use crate::list_normalize::normalize_list_builtins;
21use crate::lower::Lowerer;
22use crate::magic_sets::rewrite_magic_sets;
23use crate::meta_normalize::normalize_meta_builtins;
24use crate::module::ModuleError;
25use crate::optimizer::Optimizer;
26use crate::parser::parse_program;
27use crate::resolver::ModuleResolver;
28use crate::stratify::stratify;
29use crate::{BodyLiteral, Program, Query, Rule as AstRule, Term};
30
/// Compiler front end that turns source text or a parsed `Program` into an `ExecutionPlan`.
///
/// All state is held by the wrapped `Lowerer`; the compile methods reuse it across calls
/// and only `reset` replaces it with a fresh one.
pub struct Compiler {
    /// Lowerer that performs the lowering and backs `rel_ids()` / `schemas()`.
    lowerer: Lowerer,
}
49
50use std::collections::{HashMap, HashSet};
51use std::sync::Arc;
52use xlog_core::{RelId, Schema};
53
54impl Default for Compiler {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl Compiler {
61 pub fn new() -> Self {
63 Self {
64 lowerer: Lowerer::new(),
65 }
66 }
67
68 pub fn set_max_active_rules(&mut self, max: usize) {
70 self.lowerer.set_max_active_rules(max);
71 }
72
73 pub fn compile(&mut self, source: &str) -> Result<ExecutionPlan> {
108 self.compile_with_stats_snapshot(source, None)
109 }
110
111 pub fn compile_with_stats_snapshot(
119 &mut self,
120 source: &str,
121 stats_snapshot: Option<&StatsSnapshot>,
122 ) -> Result<ExecutionPlan> {
123 self.compile_with_config_and_stats_snapshot(
124 source,
125 &CompilerConfig::default(),
126 stats_snapshot,
127 )
128 }
129
130 pub fn compile_with_config_and_stats_snapshot(
137 &mut self,
138 source: &str,
139 config: &CompilerConfig,
140 stats_snapshot: Option<&StatsSnapshot>,
141 ) -> Result<ExecutionPlan> {
142 let program = parse_program(source)?;
143 self.compile_program_with_config_and_stats_snapshot(&program, config, stats_snapshot)
144 }
145
146 pub fn compile_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
151 self.compile_program_with_stats_snapshot(program, None)
152 }
153
154 pub fn compile_program_with_stats_snapshot(
159 &mut self,
160 program: &Program,
161 stats_snapshot: Option<&StatsSnapshot>,
162 ) -> Result<ExecutionPlan> {
163 self.compile_program_with_config_and_stats_snapshot(
164 program,
165 &CompilerConfig::default(),
166 stats_snapshot,
167 )
168 }
169
170 pub fn compile_program_with_config_and_stats_snapshot(
176 &mut self,
177 program: &Program,
178 config: &CompilerConfig,
179 stats_snapshot: Option<&StatsSnapshot>,
180 ) -> Result<ExecutionPlan> {
181 let program = desugar_queries_and_constraints(program);
182 let program = normalize_meta_builtins(&program)?;
183 let program = normalize_list_builtins(&program)?;
184 let program = rewrite_magic_sets(&program)?.program;
185 validate_negation_safety(&program)?;
186
187 let strata = stratify(&program).map_err(map_stratification_to_naf_error)?;
189
190 let strata_preds: Vec<Vec<String>> = strata.into_iter().map(|s| s.predicates).collect();
192
193 self.lowerer.set_strata(strata_preds);
195
196 let mut cardinality_hints: HashMap<String, u64> = HashMap::new();
199 if let Some(snapshot) = stats_snapshot {
200 if !snapshot.rel_names.is_empty() {
201 let rel_name_by_id: HashMap<RelId, &str> = snapshot
202 .rel_names
203 .iter()
204 .map(|(id, name)| (*id, name.as_str()))
205 .collect();
206 for rel in &snapshot.relations {
207 if let Some(name) = rel_name_by_id.get(&rel.rel_id) {
208 cardinality_hints.insert((*name).to_string(), rel.cardinality);
209 }
210 }
211 }
212 }
213 self.lowerer.set_cardinality_hints(cardinality_hints);
214
215 let mut plan = self.lowerer.lower_program(&program)?;
216
217 let mut mgr = StatsManager::new();
222 let mut fact_counts: HashMap<String, u64> = HashMap::new();
223 for fact in program.facts() {
224 *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
225 }
226
227 for (pred, rel_id) in self.lowerer.rel_ids() {
228 mgr.register_relation(*rel_id);
229 let rows = fact_counts.get(pred).copied().unwrap_or(0);
230 if rows > 0 {
231 mgr.update_cardinality(*rel_id, rows);
232 if let Some(schema) = self.lowerer.schemas().get(pred) {
233 mgr.update_byte_size(*rel_id, rows * schema.row_size_bytes() as u64);
234 }
235 }
236 }
237
238 if let Some(snapshot) = stats_snapshot {
239 if snapshot.rel_names.is_empty() {
240 mgr.merge_snapshot(snapshot);
241 } else {
242 let rel_name_by_id: HashMap<RelId, &str> = snapshot
243 .rel_names
244 .iter()
245 .map(|(id, name)| (*id, name.as_str()))
246 .collect();
247
248 for rel in &snapshot.relations {
249 let Some(pred) = rel_name_by_id.get(&rel.rel_id) else {
250 continue;
251 };
252 let Some(rel_id) = self.lowerer.rel_ids().get(*pred) else {
253 continue;
254 };
255
256 let mut remapped = rel.clone();
257 remapped.rel_id = *rel_id;
258
259 if let Some(schema) = self.lowerer.schemas().get(*pred) {
260 remapped.column_stats.retain(|col| {
261 col.col_idx < schema.arity()
262 && schema.column_type(col.col_idx) == Some(col.dtype)
263 });
264 } else {
265 remapped.column_stats.clear();
266 }
267
268 mgr.register_relation(*rel_id);
269 if let Some(stats) = mgr.get_relation_stats_mut(*rel_id) {
270 *stats = remapped;
271 }
272 }
273
274 for js in &snapshot.join_selectivities {
275 if js.left_keys.len() != js.right_keys.len() {
276 continue;
277 }
278
279 let Some(left_pred) = rel_name_by_id.get(&js.left_rel) else {
280 continue;
281 };
282 let Some(right_pred) = rel_name_by_id.get(&js.right_rel) else {
283 continue;
284 };
285 let Some(&left_id) = self.lowerer.rel_ids().get(*left_pred) else {
286 continue;
287 };
288 let Some(&right_id) = self.lowerer.rel_ids().get(*right_pred) else {
289 continue;
290 };
291
292 let Some(left_schema) = self.lowerer.schemas().get(*left_pred) else {
293 continue;
294 };
295 let Some(right_schema) = self.lowerer.schemas().get(*right_pred) else {
296 continue;
297 };
298 if js.left_keys.iter().any(|&k| k >= left_schema.arity())
299 || js.right_keys.iter().any(|&k| k >= right_schema.arity())
300 {
301 continue;
302 }
303
304 mgr.set_join_selectivity(
305 left_id,
306 right_id,
307 js.left_keys.clone(),
308 js.right_keys.clone(),
309 js.selectivity,
310 );
311 }
312 }
313 }
314
315 let schemas_by_rel_id: HashMap<RelId, Schema> = self
317 .lowerer
318 .rel_ids()
319 .iter()
320 .filter_map(|(pred, rel_id)| {
321 self.lowerer
322 .schemas()
323 .get(pred)
324 .map(|schema| (*rel_id, schema.clone()))
325 })
326 .collect();
327
328 let stats_arc = Arc::new(mgr);
329
330 crate::optimizer::helper_split_pass::run(
331 &mut plan,
332 &schemas_by_rel_id,
333 &stats_arc,
334 |schema| self.lowerer.create_helper_relation(schema),
335 );
336
337 let schemas_by_rel_id: HashMap<RelId, Schema> = self
338 .lowerer
339 .rel_ids()
340 .iter()
341 .filter_map(|(pred, rel_id)| {
342 self.lowerer
343 .schemas()
344 .get(pred)
345 .map(|schema| (*rel_id, schema.clone()))
346 })
347 .collect();
348
349 let mut optimizer = Optimizer::new(Arc::clone(&stats_arc));
350 optimizer.set_schemas(schemas_by_rel_id);
351 for rules in &mut plan.rules_by_scc {
352 for rule in rules {
353 rule.body = optimizer.optimize(rule.body.clone());
354 }
355 }
356
357 crate::optimizer::selectivity_pass::run(&mut plan, &stats_arc, self.lowerer.rel_ids());
366
367 crate::promote::promote_multiway(&mut plan, self.lowerer.rel_ids(), &stats_arc, config);
380
381 let schemas_by_rel_id: HashMap<RelId, Schema> = self
382 .lowerer
383 .rel_ids()
384 .iter()
385 .filter_map(|(pred, rel_id)| {
386 self.lowerer
387 .schemas()
388 .get(pred)
389 .map(|schema| (*rel_id, schema.clone()))
390 })
391 .collect();
392
393 crate::optimizer::helper_split_pass::run_kclique_specs(
394 &mut plan,
395 &schemas_by_rel_id,
396 |schema| self.lowerer.create_helper_relation(schema),
397 );
398
399 Ok(plan)
400 }
401
402 pub fn reset(&mut self) {
407 self.lowerer = Lowerer::new();
408 }
409
410 pub fn rel_ids(&self) -> &HashMap<String, RelId> {
415 self.lowerer.rel_ids()
416 }
417
418 pub fn schemas(&self) -> &HashMap<String, Schema> {
422 self.lowerer.schemas()
423 }
424}
425
426fn desugar_queries_and_constraints(program: &Program) -> Program {
427 let mut out = program.clone();
428
429 for (i, constraint) in program.constraints.iter().enumerate() {
431 let pred = format!("__xlog_constraint_{}", i);
432 out.rules.push(AstRule {
433 head: crate::ast::Atom {
434 predicate: pred,
435 terms: vec![Term::Integer(1)],
436 },
437 body: constraint.body.clone(),
438 });
439 }
440
441 for (i, Query { atom }) in program.queries.iter().enumerate() {
443 let pred = format!("__xlog_query_{}", i);
444
445 let mut head_terms: Vec<Term> = Vec::new();
446 let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
447
448 for term in &atom.terms {
449 for name in term.variables() {
450 if seen.insert(name) {
451 head_terms.push(Term::Variable(name.to_string()));
452 }
453 }
454 }
455
456 if head_terms.is_empty() {
457 head_terms.push(Term::Integer(1));
458 }
459
460 out.rules.push(AstRule {
461 head: crate::ast::Atom {
462 predicate: pred,
463 terms: head_terms,
464 },
465 body: vec![BodyLiteral::Positive(atom.clone())],
466 });
467 }
468
469 out
470}
471
472fn validate_negation_safety(program: &Program) -> Result<()> {
473 for rule in &program.rules {
474 validate_body_naf_safety(&rule.body, &format!("rule {}", rule.head.predicate))?;
475 }
476 for (idx, constraint) in program.constraints.iter().enumerate() {
477 validate_body_naf_safety(&constraint.body, &format!("constraint {}", idx))?;
478 }
479 for (idx, learnable) in program.learnable_rules.iter().enumerate() {
480 validate_body_naf_safety(&learnable.body, &format!("learnable rule {}", idx))?;
481 }
482 Ok(())
483}
484
485fn validate_body_naf_safety(body: &[BodyLiteral], context: &str) -> Result<()> {
486 let mut bound: HashSet<String> = HashSet::new();
487 for lit in body {
488 match lit {
489 BodyLiteral::Positive(atom) => {
490 for name in atom.variables() {
491 bound.insert(name.to_string());
492 }
493 }
494 BodyLiteral::Negated(atom) => {
495 for name in atom.variables() {
496 if !bound.contains(name) {
497 return Err(naf_error(format!(
498 "unbound variable {} in negated atom {}/{} in {}; bind it before not with a positive atom or deterministic is expression, or use '_' for existential positions",
499 name,
500 atom.predicate,
501 atom.arity(),
502 context
503 )));
504 }
505 }
506 }
507 BodyLiteral::IsExpr(is_expr) => {
508 bound.insert(is_expr.target.clone());
509 }
510 BodyLiteral::Epistemic(_) => {}
511 BodyLiteral::Comparison(_) | BodyLiteral::Univ(_) => {}
512 }
513 }
514 Ok(())
515}
516
517fn map_stratification_to_naf_error(err: XlogError) -> XlogError {
518 match err {
519 XlogError::StratificationCycle(cycle) => naf_error(format!(
520 "deterministic not atom must be stratified; cycle through negation or aggregation: {}",
521 cycle.join(" -> ")
522 )),
523 other => other,
524 }
525}
526
527fn naf_error(message: impl Into<String>) -> XlogError {
528 XlogError::Compilation(format!("negation safety error: {}", message.into()))
529}
530
531pub fn compile(source: &str) -> Result<ExecutionPlan> {
544 let mut compiler = Compiler::new();
545 compiler.compile(source)
546}
547
548pub fn load_modules(
565 entry_file: &Path,
566 search_paths: Vec<PathBuf>,
567) -> std::result::Result<ModuleResolver, ModuleError> {
568 let mut resolver = ModuleResolver::new(search_paths);
569
570 let base_dir = entry_file.parent().unwrap_or(Path::new("."));
572 let module_name = entry_file
573 .file_stem()
574 .and_then(|s| s.to_str())
575 .unwrap_or("main");
576
577 resolver.load_module(base_dir, &[module_name.to_string()])?;
579
580 Ok(resolver)
581}
582
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::ScalarType;
    use xlog_ir::RirNode;
    use xlog_stats::ColumnStats;
    use xlog_stats::RelationStats;
    use xlog_stats::StatsManager;

    /// Relation names used by the helper-split fixtures; position `i` is given `RelId(i)`.
    const HELPER_SPLIT_RELATIONS: [&str; 6] = ["ab", "bc", "cd", "de", "ef", "af"];

    /// Peels off any chain of `Project` nodes and returns the first non-`Project` node.
    fn strip_projects(mut node: &RirNode) -> &RirNode {
        while let RirNode::Project { input, .. } = node {
            node = input;
        }
        node
    }

    /// Asserts that `left` and `right` are scans of the expected relations.
    fn assert_scan_pair(
        left: &RirNode,
        right: &RirNode,
        expected_left: RelId,
        expected_right: RelId,
    ) {
        assert!(matches!(left, RirNode::Scan { rel } if *rel == expected_left));
        assert!(matches!(right, RirNode::Scan { rel } if *rel == expected_right));
    }

    /// Constructing a compiler must not panic.
    #[test]
    fn test_compiler_new() {
        let _compiler = Compiler::new();
    }

    #[test]
    fn test_compile_fact() {
        let mut compiler = Compiler::new();
        let result = compiler.compile("edge(1, 2).");
        assert!(result.is_ok(), "Failed to compile fact: {:?}", result.err());
    }

    #[test]
    fn test_compile_simple_rule() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            reach(X, Y) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile simple rule: {:?}",
            result.err()
        );

        let plan = result.unwrap();
        assert!(!plan.sccs.is_empty(), "Expected at least one SCC");
    }

    #[test]
    fn test_compile_transitive_closure() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            edge(2, 3).
            edge(3, 4).
            reach(X, Y) :- edge(X, Y).
            reach(X, Z) :- reach(X, Y), edge(Y, Z).
        "#,
        );
        assert!(result.is_ok(), "Failed to compile TC: {:?}", result.err());

        let plan = result.unwrap();
        assert!(!plan.sccs.is_empty());
    }

    #[test]
    fn test_compile_with_negation() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            node(1).
            node(2).
            node(3).
            edge(1, 2).
            isolated(X) :- node(X), not edge(X, _).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with negation: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_compile_with_comparison() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            value(1).
            value(5).
            value(10).
            value(15).
            small(X) :- value(X), X < 10.
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with comparison: {:?}",
            result.err()
        );
    }

    /// A derived relation's schema should take its column types from the rule body.
    #[test]
    fn test_schema_infers_from_rule_body_types() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            edge(2, 3).
            reach(X, Y) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile rule for schema inference: {:?}",
            result.err()
        );

        let schema = compiler
            .schemas()
            .get("reach")
            .expect("missing reach schema");
        assert_eq!(
            schema.column_type(0),
            Some(ScalarType::U32),
            "reach column 0 should match edge column type"
        );
        assert_eq!(
            schema.column_type(1),
            Some(ScalarType::U32),
            "reach column 1 should match edge column type"
        );
    }

    /// Mutual recursion through negation cannot be stratified and must be rejected.
    #[test]
    fn test_compile_unstratifiable_fails() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            p :- not q.
            q :- not p.
        "#,
        );
        assert!(result.is_err(), "Should fail with stratification cycle");
    }

    #[test]
    fn test_compile_syntax_error_fails() {
        let mut compiler = Compiler::new();
        let result = compiler.compile("edge(1, 2");
        assert!(result.is_err(), "Should fail with syntax error");
    }

    #[test]
    fn test_compile_convenience_function() {
        let result = compile("edge(1, 2).");
        assert!(
            result.is_ok(),
            "Convenience compile failed: {:?}",
            result.err()
        );
    }

    /// A compiler must remain usable after `reset`.
    #[test]
    fn test_compiler_reset() {
        let mut compiler = Compiler::new();

        let result1 = compiler.compile("edge(1, 2).");
        assert!(result1.is_ok());

        compiler.reset();
        let result2 = compiler.compile("node(1). node(2).");
        assert!(result2.is_ok());
    }

    #[test]
    fn test_compile_with_pred_decl() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            pred edge(u32, u32).
            edge(1, 2).
            edge(2, 3).
            reach(X, Y) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with pred decl: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_compile_multi_stratum() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            // Base facts
            edge(1, 2).
            edge(2, 3).
            edge(3, 1).

            // Stratum 0: edge (base)
            // Stratum 1: reach (depends on edge, recursive)
            reach(X, Y) :- edge(X, Y).
            reach(X, Z) :- reach(X, Y), edge(Y, Z).

            // Stratum 2: non_reach (negates reach)
            all_pairs(X, Y) :- edge(X, Z), edge(Y, W).
            non_reach(X, Y) :- all_pairs(X, Y), not reach(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile multi-stratum: {:?}",
            result.err()
        );

        let plan = result.unwrap();
        // Only non-emptiness is checked, so the message says exactly that.
        assert!(!plan.strata.is_empty(), "Expected at least one stratum");
    }

    /// An aggregate in the head should lower to `Project(GroupBy(..))`.
    #[test]
    fn test_compile_aggregation() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            edge(1, 3).
            edge(2, 3).
            out_degree(X, count(Y)) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with aggregation: {:?}",
            result.err()
        );

        let plan = result.unwrap();
        let out_degree_rules: Vec<_> = plan
            .rules_by_scc
            .iter()
            .flatten()
            .filter(|r| r.head == "out_degree")
            .collect();
        assert_eq!(out_degree_rules.len(), 1, "Expected one out_degree rule");

        let body = &out_degree_rules[0].body;
        match body {
            RirNode::Project { input, .. } => {
                assert!(
                    matches!(input.as_ref(), RirNode::GroupBy { .. }),
                    "Expected Project(GroupBy(..)), got {:?}",
                    input
                );
            }
            other => panic!("Expected Project(GroupBy(..)), got {:?}", other),
        }
    }

    /// A snapshot without relation names (taken from a `StatsManager`) must be accepted.
    #[test]
    fn test_compile_with_stats_snapshot() {
        let mut compiler = Compiler::new();
        let source = r#"
            edge(1, 2).
            edge(2, 3).
            reach(X, Y) :- edge(X, Y).
        "#;

        let _ = compiler.compile(source).expect("Initial compile failed");
        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");

        let mut mgr = StatsManager::new();
        mgr.register_relation(edge_id);
        mgr.update_cardinality(edge_id, 42);
        let snapshot = mgr.snapshot();

        let plan = compiler
            .compile_with_stats_snapshot(source, Some(&snapshot))
            .expect("Compile with snapshot failed");
        assert!(!plan.sccs.is_empty());
    }

    /// With a named snapshot reporting `foo` as far larger than `edge`, the join for `out`
    /// is expected to scan `foo` on the left and `edge` on the right.
    #[test]
    fn test_compile_with_named_stats_snapshot_reorders_joins() {
        let mut compiler = Compiler::new();
        let source = r#"
            foo(1).
            edge(1).
            out(X) :- edge(X), foo(X).
        "#;

        let mut edge_stats = RelationStats::new(RelId(0));
        edge_stats.update_cardinality(10);
        let mut foo_stats = RelationStats::new(RelId(1));
        foo_stats.update_cardinality(10_000);

        let snapshot = StatsSnapshot {
            relations: vec![edge_stats, foo_stats],
            join_selectivities: Vec::new(),
            rel_names: vec![
                (RelId(0), "edge".to_string()),
                (RelId(1), "foo".to_string()),
            ],
        };

        let plan = compiler
            .compile_with_stats_snapshot(source, Some(&snapshot))
            .expect("Compile with named snapshot failed");

        let foo_id = *compiler.rel_ids().get("foo").expect("foo rel_id missing");
        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");

        let out_rule = plan
            .rules_by_scc
            .iter()
            .flatten()
            .find(|r| r.head == "out")
            .expect("out rule missing");

        match strip_projects(&out_rule.body) {
            RirNode::ChainJoin {
                left,
                right,
                fallback,
                ..
            } => {
                assert_scan_pair(left, right, foo_id, edge_id);

                match strip_projects(fallback.as_ref()) {
                    RirNode::Join { left, right, .. } => {
                        assert_scan_pair(left, right, foo_id, edge_id);
                    }
                    other => panic!("Expected ChainJoin fallback Join node, got {:?}", other),
                }
            }
            RirNode::Join { left, right, .. } => {
                assert_scan_pair(left, right, foo_id, edge_id);
            }
            other => panic!("Expected Join node, got {:?}", other),
        }
    }

    /// Six-relation cyclic join used by the helper-split tests.
    fn helper_split_source() -> &'static str {
        r#"
            ab(0, 0). bc(0, 0). cd(0, 0). de(0, 0). ef(0, 0). af(0, 0).
            out(A, B, C, D, F) :-
                ab(A, B),
                bc(B, C),
                cd(C, D),
                de(D, E),
                ef(E, F),
                af(A, F).
        "#
    }

    /// Named snapshot giving every fixture relation cardinality 8192; `de` additionally gets
    /// column 0 statistics with `distinct_d` distinct values.
    fn helper_split_snapshot(distinct_d: u64) -> StatsSnapshot {
        let relations = HELPER_SPLIT_RELATIONS
            .iter()
            .enumerate()
            .map(|(idx, name)| {
                let mut rel_stats = RelationStats::new(RelId(idx as u32));
                rel_stats.update_cardinality(8192);
                if *name == "de" {
                    let mut d_col = ColumnStats::new(0, ScalarType::U32);
                    d_col.update_distinct(distinct_d);
                    rel_stats.add_column(d_col);
                }
                rel_stats
            })
            .collect();

        StatsSnapshot {
            relations,
            join_selectivities: Vec::new(),
            rel_names: HELPER_SPLIT_RELATIONS
                .iter()
                .enumerate()
                .map(|(idx, name)| (RelId(idx as u32), (*name).to_string()))
                .collect(),
        }
    }

    /// With a single distinct value in `de` column 0, a `__kclique_helper_` relation with
    /// exactly one `ChainJoin` rule is expected, and `out` must scan it.
    #[test]
    fn test_compile_with_named_stats_snapshot_creates_helper_relation() {
        let mut compiler = Compiler::new();
        let snapshot = helper_split_snapshot(1);
        let plan = compiler
            .compile_with_stats_snapshot(helper_split_source(), Some(&snapshot))
            .expect("compile with helper stats");
        let helper = compiler
            .rel_ids()
            .iter()
            .find_map(|(name, rel)| {
                name.starts_with("__kclique_helper_")
                    .then_some((name.clone(), *rel))
            })
            .expect("helper relation allocated");

        let helper_rule_count = plan
            .rules_by_scc
            .iter()
            .flatten()
            .filter(|rule| rule.head == helper.0)
            .count();
        assert_eq!(helper_rule_count, 1);

        let helper_rule = plan
            .rules_by_scc
            .iter()
            .flatten()
            .find(|rule| rule.head == helper.0)
            .expect("helper rule");
        assert!(
            matches!(helper_rule.body, RirNode::ChainJoin { .. }),
            "helper split output should be eligible for ChainJoin promotion"
        );

        let out_rule = plan
            .rules_by_scc
            .iter()
            .flatten()
            .find(|rule| rule.head == "out")
            .expect("out rule");
        assert!(contains_scan(&out_rule.body, helper.1));
    }

    /// With as many distinct values as rows, no helper relation may be created and `out`
    /// keeps its single rule.
    #[test]
    fn test_compile_with_flat_named_stats_keeps_original_rule() {
        let mut compiler = Compiler::new();
        let snapshot = helper_split_snapshot(8192);
        let plan = compiler
            .compile_with_stats_snapshot(helper_split_source(), Some(&snapshot))
            .expect("compile with flat stats");

        assert!(!compiler
            .rel_ids()
            .keys()
            .any(|name| name.starts_with("__kclique_helper_")));
        let out_rules = plan
            .rules_by_scc
            .iter()
            .flatten()
            .filter(|rule| rule.head == "out")
            .count();
        assert_eq!(out_rules, 1);
    }

    /// Returns true when `node` or any of its inputs reads `rel`.
    ///
    /// The match has no wildcard arm on purpose: a new `RirNode` variant then fails to
    /// compile here instead of being silently ignored.
    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
        match node {
            RirNode::Scan { rel: scan_rel } => *scan_rel == rel,
            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
                contains_scan(left, rel) || contains_scan(right, rel)
            }
            RirNode::Project { input, .. }
            | RirNode::Filter { input, .. }
            | RirNode::Distinct { input, .. }
            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
            RirNode::Fixpoint {
                base, recursive, ..
            } => contains_scan(base, rel) || contains_scan(recursive, rel),
            RirNode::MultiWayJoin { inputs, .. } => {
                inputs.iter().any(|input| contains_scan(input, rel))
            }
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
            }
            RirNode::Unit => false,
        }
    }
}