1use std::collections::{HashMap, HashSet};
14
15use xlog_core::{symbol, AggOp as CoreAggOp, RelId, Result, ScalarType, Schema, XlogError};
16use xlog_ir::{
17 CompareOp, CompiledRule, ConstValue, ExecutionPlan, Expr, JoinType, PlanBuilder, ProjectExpr,
18 RirMeta, RirNode, Scc, Stratum as IrStratum,
19};
20
21use crate::ast::{
22 AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Comparison, IsExpr, LearnableRule, PredColumn,
23 Program, Rule, Term, TypeRef,
24};
25use crate::stratify::{build_dependency_graph, find_sccs_for_lowering, DepType};
26
/// Candidate join tree produced by the positive-atom join planners.
///
/// A plan covers a subset of a rule's positive body atoms. `make_leaf_plan`
/// builds the single-atom case and `join_plans` combines two plans.
struct JoinPlan<'a> {
    /// RIR operator tree producing the joined tuples.
    node: RirNode,
    /// Covered atoms in the order their columns appear in `node`'s output
    /// (left sub-plan's atoms first, then the right sub-plan's).
    leaf_order: Vec<&'a Atom>,
    /// Index of each `leaf_order` atom in the planner's input slice; used as a
    /// deterministic tie-breaker between equal-cost plans.
    leaf_order_idx: Vec<usize>,
    /// Output column of the first occurrence of each named variable.
    var_pos: HashMap<String, usize>,
    /// Number of output columns (sum of the covered atoms' arities).
    width: usize,
    /// Estimated number of output rows.
    est_rows: f64,
    /// Accumulated cost estimate for building this plan.
    total_cost: f64,
}
36
37fn pred_columns_for_decl(pred_decl: &crate::ast::PredDecl) -> Vec<PredColumn> {
38 if pred_decl.columns.is_empty() {
39 pred_decl
40 .types
41 .iter()
42 .cloned()
43 .map(|typ| PredColumn { name: None, typ })
44 .collect()
45 } else {
46 pred_decl.columns.clone()
47 }
48}
49
50fn resolve_pred_column_type(
51 predicate: &str,
52 index: usize,
53 typ: &TypeRef,
54 domains: &HashMap<String, ScalarType>,
55) -> Result<ScalarType> {
56 match typ {
57 TypeRef::Scalar(ty) => Ok(*ty),
58 TypeRef::Domain(name) => domains.get(name).copied().ok_or_else(|| {
59 XlogError::Compilation(format!(
60 "v0.8.5 unknown domain alias '{}' in predicate '{}' column {}",
61 name, predicate, index
62 ))
63 }),
64 TypeRef::List(_) | TypeRef::Term | TypeRef::Compound | TypeRef::PredRef => {
65 Ok(ScalarType::U64)
66 }
67 }
68}
69
70fn validate_lowerable_terms(program: &Program) -> Result<()> {
71 for rule in &program.rules {
72 validate_atom_terms(&rule.head, "rule head")?;
73 for lit in &rule.body {
74 match lit {
75 BodyLiteral::Positive(atom) => validate_atom_terms(atom, "positive body atom")?,
76 BodyLiteral::Negated(atom) => validate_atom_terms(atom, "negated body atom")?,
77 BodyLiteral::Epistemic(_) => {}
78 BodyLiteral::Comparison(cmp) => {
79 validate_term_lowerable(&cmp.left, "comparison left operand")?;
80 validate_term_lowerable(&cmp.right, "comparison right operand")?;
81 }
82 BodyLiteral::IsExpr(_) => {}
83 BodyLiteral::Univ(_) => {
84 return Err(XlogError::Compilation(
85 "v0.8.5 meta error: univ literal was not normalized before lowering"
86 .to_string(),
87 ));
88 }
89 }
90 }
91 }
92 for constraint in &program.constraints {
93 for lit in &constraint.body {
94 match lit {
95 BodyLiteral::Positive(atom) => validate_atom_terms(atom, "constraint body atom")?,
96 BodyLiteral::Negated(atom) => {
97 validate_atom_terms(atom, "constraint negated body atom")?
98 }
99 BodyLiteral::Epistemic(_) => {}
100 BodyLiteral::Comparison(cmp) => {
101 validate_term_lowerable(&cmp.left, "constraint comparison left operand")?;
102 validate_term_lowerable(&cmp.right, "constraint comparison right operand")?;
103 }
104 BodyLiteral::IsExpr(_) => {}
105 BodyLiteral::Univ(_) => {
106 return Err(XlogError::Compilation(
107 "v0.8.5 meta error: univ literal was not normalized before lowering"
108 .to_string(),
109 ));
110 }
111 }
112 }
113 }
114 for query in &program.queries {
115 validate_atom_terms(&query.atom, "query atom")?;
116 }
117 for pf in &program.prob_facts {
118 validate_atom_terms(&pf.atom, "probabilistic fact")?;
119 }
120 for ad in &program.annotated_disjunctions {
121 for choice in &ad.choices {
122 validate_atom_terms(&choice.atom, "annotated disjunction choice")?;
123 }
124 }
125 for evidence in &program.evidence {
126 validate_atom_terms(&evidence.atom, "evidence atom")?;
127 }
128 for query in &program.prob_queries {
129 validate_atom_terms(&query.atom, "probabilistic query")?;
130 }
131 for neural in &program.neural_predicates {
132 validate_atom_terms(&neural.predicate, "neural predicate")?;
133 }
134 for learnable in &program.learnable_rules {
135 validate_atom_terms(&learnable.head, "learnable rule head")?;
136 for lit in &learnable.body {
137 if let BodyLiteral::Positive(atom) = lit {
138 validate_atom_terms(atom, "learnable rule body")?;
139 }
140 }
141 }
142 Ok(())
143}
144
145fn validate_atom_terms(atom: &Atom, context: &str) -> Result<()> {
146 for term in &atom.terms {
147 validate_term_lowerable(term, context)?;
148 }
149 Ok(())
150}
151
152fn validate_term_lowerable(term: &Term, context: &str) -> Result<()> {
153 match term {
154 Term::List(_) => Err(term_not_lowerable_error(context, "list")),
155 Term::Cons { .. } => Err(term_not_lowerable_error(context, "cons")),
156 Term::Compound { .. } => Err(term_not_lowerable_error(context, "compound")),
157 Term::PredRef(_) => Err(term_not_lowerable_error(context, "predref")),
158 Term::Variable(_)
159 | Term::Anonymous
160 | Term::Integer(_)
161 | Term::Float(_)
162 | Term::String(_)
163 | Term::Symbol(_)
164 | Term::Aggregate(_) => Ok(()),
165 }
166}
167
168fn term_not_lowerable_error(context: &str, kind: &str) -> XlogError {
169 XlogError::Compilation(format!(
170 "term form '{}' in {} is parsed but not lowerable by this execution path",
171 kind, context
172 ))
173}
174
175fn term_kind_for_lowering_error(term: &Term) -> &'static str {
176 match term {
177 Term::List(_) => "list",
178 Term::Cons { .. } => "cons",
179 Term::Compound { .. } => "compound",
180 Term::PredRef(_) => "predref",
181 Term::Variable(_)
182 | Term::Anonymous
183 | Term::Integer(_)
184 | Term::Float(_)
185 | Term::String(_)
186 | Term::Symbol(_)
187 | Term::Aggregate(_) => "term",
188 }
189}
190
/// Lowers a parsed `Program` into an RIR `ExecutionPlan`.
///
/// Holds per-compilation state: inferred schemas, relation-id assignments,
/// SCCs, caller-supplied strata, and the cardinality estimates the join
/// planner uses.
pub struct Lowerer {
    /// Column schema per predicate (or helper relation) name.
    schemas: HashMap<String, Schema>,
    /// Predicate names per stratum, installed via `set_strata`; the index is
    /// used as the IR stratum id in `lower_program`.
    strata: Vec<Vec<String>>,
    /// Estimated row count per predicate, filled by `infer_cardinalities`.
    est_cardinality: HashMap<String, u64>,
    /// Caller-supplied row-count hints; preferred over fact counts.
    cardinality_hints: HashMap<String, u64>,
    /// Next relation id to hand out.
    next_rel_id: u32,
    /// Relation id assigned to each predicate / helper relation name.
    rel_ids: HashMap<String, RelId>,
    /// SCCs of the predicate dependency graph, rebuilt by `build_sccs`.
    sccs: Vec<Scc>,
    /// Copied into every `TensorMaskedJoin` emitted for learnable rules.
    max_active_rules: usize,
}
210
211impl Default for Lowerer {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
217impl Lowerer {
218 pub fn new() -> Self {
220 Self {
221 schemas: HashMap::new(),
222 strata: Vec::new(),
223 est_cardinality: HashMap::new(),
224 cardinality_hints: HashMap::new(),
225 next_rel_id: 0,
226 rel_ids: HashMap::new(),
227 sccs: Vec::new(),
228 max_active_rules: 32,
229 }
230 }
231
232 pub fn set_max_active_rules(&mut self, max: usize) {
234 self.max_active_rules = max;
235 }
236
237 pub(crate) fn set_strata(&mut self, strata: Vec<Vec<String>>) {
239 self.strata = strata;
240 }
241
242 pub(crate) fn set_cardinality_hints(&mut self, hints: HashMap<String, u64>) {
246 self.cardinality_hints = hints;
247 }
248
249 pub fn rel_ids(&self) -> &HashMap<String, RelId> {
251 &self.rel_ids
252 }
253
254 pub fn schemas(&self) -> &HashMap<String, Schema> {
256 &self.schemas
257 }
258
259 pub(crate) fn create_helper_relation(&mut self, schema: Schema) -> (String, RelId) {
260 let name = format!("__kclique_helper_{}", self.next_rel_id);
261 let rel_id = self.get_or_create_rel_id(&name);
262 self.schemas.insert(name.clone(), schema);
263 (name, rel_id)
264 }
265
266 fn get_or_create_rel_id(&mut self, name: &str) -> RelId {
268 if let Some(&id) = self.rel_ids.get(name) {
269 id
270 } else {
271 let id = RelId(self.next_rel_id);
272 self.next_rel_id += 1;
273 self.rel_ids.insert(name.to_string(), id);
274 id
275 }
276 }
277
278 fn infer_schemas(&mut self, program: &Program) -> Result<()> {
280 let domains: HashMap<String, ScalarType> = program
281 .domains
282 .iter()
283 .map(|domain| (domain.name.clone(), domain.typ))
284 .collect();
285
286 for pred_decl in &program.predicates {
288 let declared_columns = pred_columns_for_decl(pred_decl);
289 let columns: Vec<(String, ScalarType)> = declared_columns
290 .iter()
291 .enumerate()
292 .map(|(i, col)| {
293 let name = col.name.clone().unwrap_or_else(|| format!("c{}", i));
294 resolve_pred_column_type(&pred_decl.name, i, &col.typ, &domains)
295 .map(|ty| (name, ty))
296 })
297 .collect::<Result<Vec<_>>>()?;
298 self.schemas
299 .insert(pred_decl.name.clone(), Schema::new(columns));
300 }
301
302 for rule in program.facts() {
304 let pred = &rule.head.predicate;
305 if !self.schemas.contains_key(pred) {
306 let columns: Vec<(String, ScalarType)> = rule
307 .head
308 .terms
309 .iter()
310 .enumerate()
311 .map(|(i, term)| {
312 let ty = infer_term_type(term);
313 (format!("c{}", i), ty)
314 })
315 .collect();
316 self.schemas.insert(pred.clone(), Schema::new(columns));
317 }
318 }
319
320 for rule in &program.rules {
322 let pred = &rule.head.predicate;
323 if !self.schemas.contains_key(pred) {
324 let columns: Vec<(String, ScalarType)> = rule
326 .head
327 .terms
328 .iter()
329 .enumerate()
330 .map(|(i, term)| {
331 let ty = match term {
332 Term::Variable(name) => self
333 .infer_head_term_type_from_body(rule, name)
334 .unwrap_or_else(|| infer_term_type(term)),
335 _ => infer_term_type(term),
336 };
337 (format!("c{}", i), ty)
338 })
339 .collect();
340 let schema = Schema::new(columns)
341 .with_sort_labels(sort_labels_from_terms(&rule.head.terms))
342 .expect("rule head sort labels match inferred schema arity");
343 self.schemas.insert(pred.clone(), schema);
344 }
345 }
346
347 for rule in &program.rules {
349 for lit in &rule.body {
350 let atom = match lit {
351 BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
352 BodyLiteral::Epistemic(_)
353 | BodyLiteral::Comparison(_)
354 | BodyLiteral::IsExpr(_)
355 | BodyLiteral::Univ(_) => continue,
356 };
357 let pred = &atom.predicate;
358 if self.schemas.contains_key(pred) {
359 continue;
360 }
361 let columns: Vec<(String, ScalarType)> = atom
362 .terms
363 .iter()
364 .enumerate()
365 .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
366 .collect();
367 let schema = Schema::new(columns)
368 .with_sort_labels(sort_labels_from_terms(&atom.terms))
369 .expect("body sort labels match inferred schema arity");
370 self.schemas.insert(pred.clone(), schema);
371 }
372 }
373
374 for pf in &program.prob_facts {
376 let pred = &pf.atom.predicate;
377 if self.schemas.contains_key(pred) {
378 continue;
379 }
380 let columns: Vec<(String, ScalarType)> = pf
381 .atom
382 .terms
383 .iter()
384 .enumerate()
385 .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
386 .collect();
387 self.schemas.insert(pred.clone(), Schema::new(columns));
388 }
389
390 for ad in &program.annotated_disjunctions {
391 for choice in &ad.choices {
392 let pred = &choice.atom.predicate;
393 if self.schemas.contains_key(pred) {
394 continue;
395 }
396 let columns: Vec<(String, ScalarType)> = choice
397 .atom
398 .terms
399 .iter()
400 .enumerate()
401 .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
402 .collect();
403 self.schemas.insert(pred.clone(), Schema::new(columns));
404 }
405 }
406
407 Ok(())
408 }
409
410 fn infer_head_term_type_from_body(&self, rule: &Rule, var_name: &str) -> Option<ScalarType> {
411 for lit in &rule.body {
412 let atom = match lit {
413 BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
414 BodyLiteral::Epistemic(_)
415 | BodyLiteral::Comparison(_)
416 | BodyLiteral::IsExpr(_)
417 | BodyLiteral::Univ(_) => continue,
418 };
419 let schema = self.schemas.get(&atom.predicate)?;
420 for (idx, term) in atom.terms.iter().enumerate() {
421 if let Term::Variable(name) = term {
422 if name == var_name {
423 if let Some(ty) = schema.column_type(idx) {
424 return Some(ty);
425 }
426 }
427 }
428 }
429 }
430 None
431 }
432
433 fn infer_cardinalities(&mut self, program: &Program) {
434 self.est_cardinality.clear();
435
436 let mut fact_counts: HashMap<String, u64> = HashMap::new();
437 for fact in program.facts() {
438 *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
439 }
440
441 for pred in self.schemas.keys() {
442 let est = self
443 .cardinality_hints
444 .get(pred)
445 .copied()
446 .or_else(|| fact_counts.get(pred).copied())
447 .unwrap_or(1000)
448 .max(1);
449 self.est_cardinality.insert(pred.clone(), est);
450 }
451 }
452
453 fn build_sccs(&mut self, program: &Program) {
455 let graph = build_dependency_graph(program);
456 let scc_groups = find_sccs_for_lowering(&graph);
457
458 self.sccs.clear();
459 for (id, predicates) in scc_groups.iter().enumerate() {
460 let is_recursive = if predicates.len() > 1 {
463 true
464 } else {
465 let pred = &predicates[0];
466 graph
467 .outgoing(pred)
468 .iter()
469 .any(|e| e.to == *pred && e.dep_type == DepType::Positive)
470 };
471
472 self.sccs.push(Scc {
473 id: id as u32,
474 predicates: predicates.clone(),
475 is_recursive,
476 });
477 }
478 }
479
480 pub fn lower_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
482 validate_lowerable_terms(program)?;
483 self.infer_schemas(program)?;
485 self.infer_cardinalities(program);
486
487 for pred_decl in &program.predicates {
492 self.get_or_create_rel_id(&pred_decl.name);
493 }
494
495 self.build_sccs(program);
497
498 let mut builder = PlanBuilder::new();
500
501 for scc in &self.sccs {
503 builder.add_scc(scc.clone());
504 }
505
506 for (id, preds) in self.strata.iter().enumerate() {
508 let scc_ids: Vec<u32> = self
510 .sccs
511 .iter()
512 .filter(|scc| scc.predicates.iter().any(|p| preds.contains(p)))
513 .map(|scc| scc.id)
514 .collect();
515
516 if !scc_ids.is_empty() {
517 builder.add_stratum(IrStratum {
518 id: id as u32,
519 sccs: scc_ids,
520 });
521 }
522 }
523
524 let mut rules_by_pred: HashMap<String, Vec<&Rule>> = HashMap::new();
526 for rule in program.proper_rules() {
527 rules_by_pred
528 .entry(rule.head.predicate.clone())
529 .or_default()
530 .push(rule);
531 }
532
533 for fact in program.facts() {
535 let pred = &fact.head.predicate;
536 let scc_id = self.find_scc_for_predicate(pred);
537 let rel_id = self.get_or_create_rel_id(pred);
538
539 let body = RirNode::Scan { rel: rel_id };
540 let meta = self.create_meta_for_predicate(pred);
541
542 builder.add_rule(
543 scc_id,
544 CompiledRule {
545 head: pred.clone(),
546 body,
547 meta,
548 },
549 );
550 }
551
552 for (pred, rules) in &rules_by_pred {
554 let scc_id = self.find_scc_for_predicate(pred);
555
556 for rule in rules {
557 let body = self.lower_rule(rule)?;
558 let meta = self.create_meta_for_predicate(pred);
559
560 builder.add_rule(
561 scc_id,
562 CompiledRule {
563 head: pred.clone(),
564 body,
565 meta,
566 },
567 );
568 }
569 }
570
571 for learnable in &program.learnable_rules {
575 self.get_or_create_rel_id(&learnable.head.predicate);
576 for lit in &learnable.body {
577 if let BodyLiteral::Positive(atom) = lit {
578 self.get_or_create_rel_id(&atom.predicate);
579 }
580 }
581 }
582 for learnable in &program.learnable_rules {
583 let head_pred = &learnable.head.predicate;
584 let scc_id = self.find_scc_for_predicate(head_pred);
585 let body = self.lower_learnable_rule(learnable)?;
586 let meta = self.create_meta_for_predicate(head_pred);
587 builder.add_rule(
588 scc_id,
589 CompiledRule {
590 head: head_pred.clone(),
591 body,
592 meta,
593 },
594 );
595 }
596
597 let mut plan = builder.build();
598 for rule in program.proper_rules() {
604 if let Some(&id) = self.rel_ids.get(&rule.head.predicate) {
605 plan.rel_arities.insert(id, rule.head.terms.len());
606 }
607 for lit in &rule.body {
608 let atom = match lit {
609 BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a,
610 _ => continue,
611 };
612 if let Some(&id) = self.rel_ids.get(&atom.predicate) {
613 plan.rel_arities.insert(id, atom.terms.len());
614 }
615 }
616 }
617 for fact in program.facts() {
618 if let Some(&id) = self.rel_ids.get(&fact.head.predicate) {
619 plan.rel_arities.insert(id, fact.head.terms.len());
620 }
621 }
622 Ok(plan)
623 }
624
625 fn find_scc_for_predicate(&self, pred: &str) -> u32 {
627 self.sccs
628 .iter()
629 .find(|scc| scc.predicates.contains(&pred.to_string()))
630 .map(|scc| scc.id)
631 .unwrap_or(0)
632 }
633
634 fn create_meta_for_predicate(&self, pred: &str) -> RirMeta {
636 let schema = self
637 .schemas
638 .get(pred)
639 .cloned()
640 .unwrap_or_else(|| Schema::new(vec![]));
641 RirMeta::with_schema(schema)
642 }
643
    /// Lowers a learnable rule to a single `RirNode::TensorMaskedJoin`.
    ///
    /// The rule must have exactly two body literals, both positive atoms, and
    /// a head made only of variables that occur in the body. Join keys are the
    /// column pairs sharing a variable name (`extract_template_join_keys`).
    /// Head columns are projected from the concatenated `left ++ right` body
    /// row. If the head predicate has no schema yet, one is synthesized from
    /// the body schemas, defaulting to `U32` where a body type is unknown.
    ///
    /// # Errors
    /// `XlogError::Compilation` on a wrong body length, a non-positive body
    /// literal, a head variable missing from the body, or a non-variable head
    /// term.
    fn lower_learnable_rule(&mut self, rule: &LearnableRule) -> Result<RirNode> {
        if rule.body.len() != 2 {
            return Err(XlogError::Compilation(format!(
                "learnable rule '{}' requires exactly 2 body literals, got {}",
                rule.mask_name,
                rule.body.len()
            )));
        }
        for (idx, lit) in rule.body.iter().enumerate() {
            match lit {
                BodyLiteral::Positive(_) => {}
                _ => {
                    return Err(XlogError::Compilation(format!(
                        "learnable rule '{}' body[{}]: only positive atoms allowed",
                        rule.mask_name, idx
                    )));
                }
            }
        }

        // Snapshot of every relation registered so far, ordered by id. It is
        // taken BEFORE the head relation is (possibly) registered below, so
        // `schema_size` counts only relations known at this point.
        // `lower_program` pre-registers learnable heads and body predicates
        // before calling this.
        let mut rel_index: Vec<(RelId, String)> = self
            .rel_ids()
            .iter()
            .map(|(name, id)| (*id, name.clone()))
            .collect();
        rel_index.sort_by_key(|(id, _)| id.0);
        let schema_size = rel_index.len();

        let (left_keys, right_keys) =
            self.extract_template_join_keys(&rule.body[0], &rule.body[1])?;

        let head_rel_name = rule.head.predicate.clone();
        let head_rel_id = self.get_or_create_rel_id(&head_rel_name);

        // Cannot fail: `extract_template_join_keys` above already returned an
        // error if either literal had no atom.
        let left_atom = rule.body[0].atom().unwrap();
        let right_atom = rule.body[1].atom().unwrap();
        let left_arity = left_atom.terms.len();

        // Column of each variable's first occurrence in the `left ++ right` row.
        let mut var_to_col: HashMap<String, usize> = HashMap::new();
        for (i, term) in left_atom.terms.iter().enumerate() {
            if let Some(name) = term.variable_name() {
                var_to_col.entry(name.to_string()).or_insert(i);
            }
        }
        for (i, term) in right_atom.terms.iter().enumerate() {
            if let Some(name) = term.variable_name() {
                var_to_col.entry(name.to_string()).or_insert(left_arity + i);
            }
        }

        // Head column i is read from body column head_projection[i].
        let mut head_projection: Vec<usize> = Vec::new();
        for term in &rule.head.terms {
            if let Some(name) = term.variable_name() {
                let col = var_to_col.get(name).ok_or_else(|| {
                    XlogError::Compilation(format!(
                        "Learnable rule head variable '{}' not found in body atoms \
                         ({}, {}). All head variables must appear in the body.",
                        name, left_atom.predicate, right_atom.predicate,
                    ))
                })?;
                head_projection.push(*col);
            } else {
                return Err(XlogError::Compilation(format!(
                    "Learnable rule head must contain only variables, \
                     found constant {:?} in head of '{}'",
                    term, head_rel_name,
                )));
            }
        }

        // Synthesize a head schema from the body schemas when none exists.
        if !self.schemas.contains_key(&head_rel_name) {
            let columns: Vec<(String, ScalarType)> = head_projection
                .iter()
                .enumerate()
                .map(|(i, &col)| {
                    // Columns below `left_arity` come from the left atom.
                    let ty = if col < left_arity {
                        self.schemas
                            .get(&left_atom.predicate)
                            .and_then(|s| s.column_type(col))
                            .unwrap_or(ScalarType::U32)
                    } else {
                        self.schemas
                            .get(&right_atom.predicate)
                            .and_then(|s| s.column_type(col - left_arity))
                            .unwrap_or(ScalarType::U32)
                    };
                    (format!("c{}", i), ty)
                })
                .collect();
            self.schemas
                .insert(head_rel_name.clone(), Schema::new(columns));
        }

        Ok(RirNode::TensorMaskedJoin {
            mask_name: rule.mask_name.clone(),
            schema_size,
            left_keys,
            right_keys,
            rel_index,
            head_rel_name,
            head_rel_id,
            max_active_rules: self.max_active_rules,
            head_projection,
        })
    }
762
763 fn extract_template_join_keys(
766 &self,
767 left: &BodyLiteral,
768 right: &BodyLiteral,
769 ) -> Result<(Vec<usize>, Vec<usize>)> {
770 let left_atom = left
771 .atom()
772 .ok_or_else(|| XlogError::Compilation("Learnable body[0] is not an atom".into()))?;
773 let right_atom = right
774 .atom()
775 .ok_or_else(|| XlogError::Compilation("Learnable body[1] is not an atom".into()))?;
776
777 let mut left_keys = Vec::new();
778 let mut right_keys = Vec::new();
779
780 for (li, lt) in left_atom.terms.iter().enumerate() {
781 if let Some(lname) = lt.variable_name() {
782 for (ri, rt) in right_atom.terms.iter().enumerate() {
783 if let Some(rname) = rt.variable_name() {
784 if lname == rname {
785 left_keys.push(li);
786 right_keys.push(ri);
787 }
788 }
789 }
790 }
791 }
792
793 Ok((left_keys, right_keys))
794 }
795
    /// Lowers one proper rule body (plus head projection) to an `RirNode`.
    ///
    /// Rules containing an epistemic literal are rejected. Otherwise:
    /// 1. Relation ids are registered for all positive and negated body atoms,
    ///    in body order.
    /// 2. Positive atoms are join-planned (`RirNode::Unit` if there are none).
    /// 3. A `VariableEnv` is built over the joined row.
    /// 4. Negated atoms, comparisons and `is` expressions go to
    ///    `lower_body_parts`.
    /// 5. The result is either handed to `lower_aggregate_rule` or projected
    ///    onto the head.
    ///
    /// # Errors
    /// `XlogError::UnsupportedEpistemicConstruct` for epistemic literals, plus
    /// any error from planning, body lowering or head projection.
    fn lower_rule(&mut self, rule: &Rule) -> Result<RirNode> {
        // Epistemic literals cannot cross the RIR lowering boundary.
        if let Some(lit) = rule.body.iter().find_map(|lit| match lit {
            BodyLiteral::Epistemic(lit) => Some(lit),
            _ => None,
        }) {
            return Err(XlogError::UnsupportedEpistemicConstruct {
                construct: "RIR lowering boundary".to_string(),
                context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
            });
        }

        let (positive_atoms, negated_atoms, comparisons, is_exprs) =
            Self::split_body_literals(&rule.body);

        // Register relation ids in body order before planning. Planning may
        // reorder atoms, so doing it here keeps id assignment tied to source order.
        for lit in &rule.body {
            match lit {
                BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => {
                    self.get_or_create_rel_id(&atom.predicate);
                }
                BodyLiteral::Epistemic(_)
                | BodyLiteral::Comparison(_)
                | BodyLiteral::IsExpr(_)
                | BodyLiteral::Univ(_) => {}
            }
        }

        let (positive_root, leaf_order) = if positive_atoms.is_empty() {
            (RirNode::Unit, Vec::new())
        } else {
            self.plan_positive_atoms(&positive_atoms)?
        };

        // Map each named variable occurrence to its column in the joined row.
        // The joined row is the planned atoms' columns concatenated in
        // `leaf_order`. A variable's type comes from its first occurrence's
        // schema column, defaulting to I64 when unknown.
        let mut var_env = VariableEnv::new();
        let mut current_col = 0;
        for atom in &leaf_order {
            let schema = self.schemas.get(&atom.predicate);
            for (i, term) in atom.terms.iter().enumerate() {
                if let Term::Variable(name) = term {
                    if name == "_" {
                        continue;
                    }
                    var_env.add_occurrence(name, atom.predicate.clone(), i, current_col + i);
                    if !var_env.types.contains_key(name) {
                        let typ = schema
                            .and_then(|s| s.column_type(i))
                            .unwrap_or(ScalarType::I64);
                        var_env.types.insert(name.to_string(), typ);
                    }
                }
            }
            current_col += atom.terms.len();
        }
        var_env.total_cols = current_col;

        let body_node = self.lower_body_parts(
            positive_root,
            &negated_atoms,
            &comparisons,
            &is_exprs,
            &mut var_env,
        )?;

        if rule.has_aggregation() {
            return self.lower_aggregate_rule(&rule.head, body_node, &var_env);
        }

        let projection_exprs = self.compute_head_projection(&rule.head, &var_env)?;

        // Skip the Project node when it would just pass every column through.
        if Self::is_identity_projection(&projection_exprs, var_env.column_count()) {
            Ok(body_node)
        } else {
            Ok(RirNode::Project {
                input: Box::new(body_node),
                columns: projection_exprs,
            })
        }
    }
887
888 fn split_body_literals(
889 body: &[BodyLiteral],
890 ) -> (Vec<&Atom>, Vec<&Atom>, Vec<&Comparison>, Vec<&IsExpr>) {
891 let mut positive_atoms: Vec<&Atom> = Vec::new();
892 let mut negated_atoms: Vec<&Atom> = Vec::new();
893 let mut comparisons: Vec<&Comparison> = Vec::new();
894 let mut is_exprs: Vec<&IsExpr> = Vec::new();
895
896 for lit in body {
897 match lit {
898 BodyLiteral::Positive(atom) => positive_atoms.push(atom),
899 BodyLiteral::Negated(atom) => negated_atoms.push(atom),
900 BodyLiteral::Epistemic(_) => {}
901 BodyLiteral::Comparison(cmp) => comparisons.push(cmp),
902 BodyLiteral::IsExpr(is_expr) => is_exprs.push(is_expr),
903 BodyLiteral::Univ(_) => {}
904 }
905 }
906
907 (positive_atoms, negated_atoms, comparisons, is_exprs)
908 }
909
910 fn atom_vars(atom: &Atom) -> std::collections::HashSet<String> {
911 atom.terms
912 .iter()
913 .flat_map(|t| t.variables().into_iter())
914 .filter(|name| *name != "_")
915 .map(ToOwned::to_owned)
916 .collect()
917 }
918
919 fn estimate_atom_rows(&self, atom: &Atom) -> f64 {
920 let base = self
921 .est_cardinality
922 .get(&atom.predicate)
923 .copied()
924 .unwrap_or(1000)
925 .max(1) as f64;
926
927 let const_count = atom
928 .terms
929 .iter()
930 .filter(|t| term_to_const_value(t).is_some())
931 .count();
932
933 let selectivity = 0.1_f64.powi(const_count as i32);
935 (base * selectivity).max(1.0)
936 }
937
938 fn build_cartesian_join(
939 &self,
940 left: RirNode,
941 right: RirNode,
942 left_width: usize,
943 right_width: usize,
944 ) -> RirNode {
945 let left_const_col =
948 ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
949 let right_const_col =
950 ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
951
952 let mut left_cols: Vec<ProjectExpr> = (0..left_width).map(ProjectExpr::Column).collect();
953 left_cols.push(left_const_col);
954 let left_aug = RirNode::Project {
955 input: Box::new(left),
956 columns: left_cols,
957 };
958
959 let mut right_cols: Vec<ProjectExpr> = (0..right_width).map(ProjectExpr::Column).collect();
960 right_cols.push(right_const_col);
961 let right_aug = RirNode::Project {
962 input: Box::new(right),
963 columns: right_cols,
964 };
965
966 let joined = RirNode::Join {
967 left: Box::new(left_aug),
968 right: Box::new(right_aug),
969 left_keys: vec![left_width],
970 right_keys: vec![right_width],
971 join_type: JoinType::Inner,
972 };
973
974 let mut keep: Vec<ProjectExpr> = Vec::with_capacity(left_width + right_width);
975 keep.extend((0..left_width).map(ProjectExpr::Column));
976 let right_start = left_width + 1;
977 keep.extend((right_start..right_start + right_width).map(ProjectExpr::Column));
978
979 RirNode::Project {
980 input: Box::new(joined),
981 columns: keep,
982 }
983 }
984
985 fn make_leaf_plan<'a>(&mut self, atom: &'a Atom, orig_idx: usize) -> Result<JoinPlan<'a>> {
986 let rel_id = self.get_or_create_rel_id(&atom.predicate);
987 let scan = RirNode::Scan { rel: rel_id };
988 let node = self.apply_constant_filters(scan, atom, 0)?;
989
990 let mut var_pos: HashMap<String, usize> = HashMap::new();
991 for (i, term) in atom.terms.iter().enumerate() {
992 if let Term::Variable(name) = term {
993 if name != "_" {
994 var_pos.entry(name.clone()).or_insert(i);
995 }
996 }
997 }
998
999 let est_rows = self.estimate_atom_rows(atom);
1000 Ok(JoinPlan {
1001 node,
1002 leaf_order: vec![atom],
1003 leaf_order_idx: vec![orig_idx],
1004 var_pos,
1005 width: atom.terms.len(),
1006 est_rows,
1007 total_cost: est_rows,
1008 })
1009 }
1010
1011 fn join_plans<'a>(&self, left: &JoinPlan<'a>, right: &JoinPlan<'a>) -> JoinPlan<'a> {
1012 let shared_vars: Vec<&String> = left
1013 .var_pos
1014 .keys()
1015 .filter(|v| right.var_pos.contains_key(*v))
1016 .collect();
1017
1018 let node = if shared_vars.is_empty() {
1019 self.build_cartesian_join(
1020 left.node.clone(),
1021 right.node.clone(),
1022 left.width,
1023 right.width,
1024 )
1025 } else {
1026 let mut key_pairs: Vec<(usize, usize)> = shared_vars
1027 .iter()
1028 .filter_map(|v| {
1029 Some((
1030 left.var_pos.get(*v).copied()?,
1031 right.var_pos.get(*v).copied()?,
1032 ))
1033 })
1034 .collect();
1035 key_pairs.sort_unstable();
1036
1037 let (left_keys, right_keys): (Vec<usize>, Vec<usize>) = key_pairs.into_iter().unzip();
1038
1039 RirNode::Join {
1040 left: Box::new(left.node.clone()),
1041 right: Box::new(right.node.clone()),
1042 left_keys,
1043 right_keys,
1044 join_type: JoinType::Inner,
1045 }
1046 };
1047
1048 let mut leaf_order = left.leaf_order.clone();
1049 leaf_order.extend(right.leaf_order.iter().copied());
1050
1051 let mut leaf_order_idx = left.leaf_order_idx.clone();
1052 leaf_order_idx.extend_from_slice(&right.leaf_order_idx);
1053
1054 let mut var_pos = left.var_pos.clone();
1055 for (var, pos) in &right.var_pos {
1056 var_pos.entry(var.clone()).or_insert(left.width + *pos);
1057 }
1058
1059 let shared = shared_vars.len();
1060 let mut selectivity = if shared == 0 {
1061 1.0
1062 } else {
1063 0.1_f64.powi(shared as i32)
1064 };
1065 if shared == 0 {
1066 selectivity *= 1.0e6;
1068 }
1069
1070 let output_rows = (left.est_rows * right.est_rows * selectivity).max(1.0);
1071
1072 let build_cost = right.est_rows;
1074 let probe_cost = left.est_rows * 0.5;
1075 let total_cost = left.total_cost + right.total_cost + build_cost + probe_cost + output_rows;
1076
1077 JoinPlan {
1078 node,
1079 leaf_order,
1080 leaf_order_idx,
1081 var_pos,
1082 width: left.width + right.width,
1083 est_rows: output_rows,
1084 total_cost,
1085 }
1086 }
1087
1088 fn plan_positive_atoms_bushy<'a>(
1089 &mut self,
1090 atoms: &[&'a Atom],
1091 ) -> Result<(RirNode, Vec<&'a Atom>)> {
1092 let n = atoms.len();
1093 if n == 0 {
1094 return Err(XlogError::Compilation("Empty rule body".to_string()));
1095 }
1096 if n == 1 {
1097 let plan = self.make_leaf_plan(atoms[0], 0)?;
1098 return Ok((plan.node, plan.leaf_order));
1099 }
1100
1101 let size = 1usize << n;
1102 let mut best: Vec<Option<JoinPlan<'a>>> = (0..size).map(|_| None).collect();
1103
1104 for (i, atom) in atoms.iter().enumerate() {
1105 best[1usize << i] = Some(self.make_leaf_plan(atom, i)?);
1106 }
1107
1108 fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1109 for (ai, bi) in a.iter().zip(b.iter()) {
1110 if ai != bi {
1111 return ai < bi;
1112 }
1113 }
1114 a.len() < b.len()
1115 }
1116
1117 for mask in 1..size {
1118 if mask.count_ones() <= 1 {
1119 continue;
1120 }
1121
1122 let mut best_for_mask: Option<JoinPlan<'a>> = None;
1123
1124 let mut sub = (mask - 1) & mask;
1125 while sub > 0 {
1126 let a = sub;
1127 let b = mask ^ a;
1128 if b == 0 {
1129 sub = (sub - 1) & mask;
1130 continue;
1131 }
1132
1133 let (Some(plan_a), Some(plan_b)) = (&best[a], &best[b]) else {
1134 sub = (sub - 1) & mask;
1135 continue;
1136 };
1137
1138 for (left, right) in [(plan_a, plan_b), (plan_b, plan_a)] {
1140 let cand = self.join_plans(left, right);
1141 let replace = match &best_for_mask {
1142 None => true,
1143 Some(current) => {
1144 if cand.total_cost < current.total_cost {
1145 true
1146 } else if (cand.total_cost - current.total_cost).abs() < 1e-9 {
1147 lex_lt(&cand.leaf_order_idx, ¤t.leaf_order_idx)
1148 } else {
1149 false
1150 }
1151 }
1152 };
1153
1154 if replace {
1155 best_for_mask = Some(cand);
1156 }
1157 }
1158
1159 sub = (sub - 1) & mask;
1160 }
1161
1162 best[mask] = best_for_mask;
1163 }
1164
1165 let full_mask = size - 1;
1166 if let Some(plan) = best[full_mask].take() {
1167 return Ok((plan.node, plan.leaf_order));
1168 }
1169
1170 let ordered = self.order_positive_atoms_greedy(atoms);
1172 let mut dummy_env = VariableEnv::new();
1173 let node = self.build_join_tree(&ordered, &mut dummy_env)?;
1174 Ok((node, ordered))
1175 }
1176
1177 fn plan_positive_atoms<'a>(&mut self, atoms: &[&'a Atom]) -> Result<(RirNode, Vec<&'a Atom>)> {
1178 if atoms.len() <= 1 {
1179 if atoms.is_empty() {
1180 return Err(XlogError::Compilation("Empty rule body".to_string()));
1181 }
1182 let plan = self.make_leaf_plan(atoms[0], 0)?;
1183 return Ok((plan.node, plan.leaf_order));
1184 }
1185
1186 const MAX_BUSHY_DP_ATOMS: usize = 10;
1187 if atoms.len() <= MAX_BUSHY_DP_ATOMS {
1188 return self.plan_positive_atoms_bushy(atoms);
1189 }
1190
1191 self.plan_positive_atoms_bushy_greedy(atoms)
1193 }
1194
1195 fn plan_positive_atoms_bushy_greedy<'a>(
1196 &mut self,
1197 atoms: &[&'a Atom],
1198 ) -> Result<(RirNode, Vec<&'a Atom>)> {
1199 if atoms.is_empty() {
1200 return Err(XlogError::Compilation("Empty rule body".to_string()));
1201 }
1202
1203 fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1204 for (ai, bi) in a.iter().zip(b.iter()) {
1205 if ai != bi {
1206 return ai < bi;
1207 }
1208 }
1209 a.len() < b.len()
1210 }
1211
1212 let mut plans: Vec<JoinPlan<'a>> = Vec::with_capacity(atoms.len());
1213 for (idx, atom) in atoms.iter().enumerate() {
1214 plans.push(self.make_leaf_plan(atom, idx)?);
1215 }
1216
1217 while plans.len() > 1 {
1218 let mut best_pair: Option<(usize, usize, JoinPlan<'a>)> = None;
1219
1220 for i in 0..plans.len() {
1221 for j in (i + 1)..plans.len() {
1222 let a = &plans[i];
1223 let b = &plans[j];
1224
1225 let cand_ab = self.join_plans(a, b);
1226 let cand_ba = self.join_plans(b, a);
1227
1228 let cand = if cand_ab.total_cost < cand_ba.total_cost
1229 || (cand_ab.total_cost - cand_ba.total_cost).abs() < 1e-9
1230 && lex_lt(&cand_ab.leaf_order_idx, &cand_ba.leaf_order_idx)
1231 {
1232 cand_ab
1233 } else {
1234 cand_ba
1235 };
1236
1237 let replace = match &best_pair {
1238 None => true,
1239 Some((_bi, _bj, best)) => {
1240 if cand.total_cost < best.total_cost {
1241 true
1242 } else if (cand.total_cost - best.total_cost).abs() < 1e-9 {
1243 lex_lt(&cand.leaf_order_idx, &best.leaf_order_idx)
1244 } else {
1245 false
1246 }
1247 }
1248 };
1249
1250 if replace {
1251 best_pair = Some((i, j, cand));
1252 }
1253 }
1254 }
1255
1256 let Some((i, j, joined)) = best_pair else {
1257 break;
1258 };
1259
1260 let (a, b) = if i < j { (i, j) } else { (j, i) };
1262 plans.remove(b);
1263 plans.remove(a);
1264 plans.push(joined);
1265 }
1266
1267 let plan = plans
1268 .pop()
1269 .ok_or_else(|| XlogError::Compilation("Join planning failed".to_string()))?;
1270 Ok((plan.node, plan.leaf_order))
1271 }
1272
1273 fn order_positive_atoms_greedy<'a>(&self, atoms: &[&'a Atom]) -> Vec<&'a Atom> {
1274 let mut remaining: Vec<(usize, &Atom)> = atoms.iter().copied().enumerate().collect();
1275 let mut ordered: Vec<&Atom> = Vec::with_capacity(atoms.len());
1276 let mut bound_vars: HashSet<String> = HashSet::new();
1277
1278 while !remaining.is_empty() {
1279 let pick_idx = if ordered.is_empty() {
1280 remaining
1281 .iter()
1282 .enumerate()
1283 .min_by(|(_, a), (_, b)| {
1284 let (ai, aa) = **a;
1285 let (bi, bb) = **b;
1286 self.estimate_atom_rows(aa)
1287 .partial_cmp(&self.estimate_atom_rows(bb))
1288 .unwrap_or(std::cmp::Ordering::Equal)
1289 .then(ai.cmp(&bi))
1290 })
1291 .map(|(idx, _)| idx)
1292 .unwrap()
1293 } else {
1294 remaining
1295 .iter()
1296 .enumerate()
1297 .min_by(|(_, a), (_, b)| {
1298 let (ai, aa) = **a;
1299 let (bi, bb) = **b;
1300
1301 let a_vars = Self::atom_vars(aa);
1302 let b_vars = Self::atom_vars(bb);
1303
1304 let a_shared = a_vars.intersection(&bound_vars).count();
1305 let b_shared = b_vars.intersection(&bound_vars).count();
1306
1307 let a_score = if a_shared == 0 {
1308 self.estimate_atom_rows(aa) * 1.0e12
1309 } else {
1310 self.estimate_atom_rows(aa) / a_shared as f64
1311 };
1312 let b_score = if b_shared == 0 {
1313 self.estimate_atom_rows(bb) * 1.0e12
1314 } else {
1315 self.estimate_atom_rows(bb) / b_shared as f64
1316 };
1317
1318 a_score
1319 .partial_cmp(&b_score)
1320 .unwrap_or(std::cmp::Ordering::Equal)
1321 .then(ai.cmp(&bi))
1322 })
1323 .map(|(idx, _)| idx)
1324 .unwrap()
1325 };
1326
1327 let (_orig_idx, atom) = remaining.remove(pick_idx);
1328 ordered.push(atom);
1329 bound_vars.extend(Self::atom_vars(atom));
1330 }
1331
1332 ordered
1333 }
1334
1335 fn lower_body_parts(
1336 &mut self,
1337 positive_root: RirNode,
1338 negated_atoms: &[&Atom],
1339 comparisons: &[&Comparison],
1340 is_exprs: &[&IsExpr],
1341 var_env: &mut VariableEnv,
1342 ) -> Result<RirNode> {
1343 let mut result = positive_root;
1344
1345 for cmp in comparisons {
1347 result = self.apply_comparison(result, cmp, var_env)?;
1348 }
1349
1350 for is_expr in is_exprs {
1352 result = self.lower_is_expr(is_expr, result, var_env)?;
1353 }
1354
1355 for neg_atom in negated_atoms {
1357 result = self.apply_negation(result, neg_atom, var_env)?;
1358 }
1359
1360 Ok(result)
1361 }
1362
1363 fn build_join_tree(&mut self, atoms: &[&Atom], var_env: &mut VariableEnv) -> Result<RirNode> {
1365 if atoms.is_empty() {
1366 return Err(XlogError::Compilation("Empty rule body".to_string()));
1367 }
1368
1369 let first_atom = atoms[0];
1371 let rel_id = self.get_or_create_rel_id(&first_atom.predicate);
1372 let mut result = RirNode::Scan { rel: rel_id };
1373 let mut result_vars = self.collect_atom_vars(first_atom);
1374 let mut result_width = first_atom.terms.len();
1375
1376 result = self.apply_constant_filters(result, first_atom, 0)?;
1378
1379 for atom in atoms.iter().skip(1) {
1381 let right_rel_id = self.get_or_create_rel_id(&atom.predicate);
1382 let right_scan = RirNode::Scan { rel: right_rel_id };
1383
1384 let right_filtered = self.apply_constant_filters(right_scan, atom, 0)?;
1386
1387 let (left_keys, right_keys) = self.compute_join_keys(&result_vars, atom, result_width);
1389
1390 if left_keys.is_empty() {
1391 result = RirNode::Join {
1393 left: Box::new(result),
1394 right: Box::new(right_filtered),
1395 left_keys: vec![],
1396 right_keys: vec![],
1397 join_type: JoinType::Inner,
1398 };
1399 } else {
1400 result = RirNode::Join {
1401 left: Box::new(result),
1402 right: Box::new(right_filtered),
1403 left_keys,
1404 right_keys,
1405 join_type: JoinType::Inner,
1406 };
1407 }
1408
1409 for (i, term) in atom.terms.iter().enumerate() {
1411 if let Term::Variable(name) = term {
1412 result_vars.push((name.clone(), result_width + i));
1413 }
1414 }
1415 result_width += atom.terms.len();
1416 }
1417
1418 var_env.total_cols = result_width;
1420
1421 Ok(result)
1422 }
1423
1424 fn collect_atom_vars(&self, atom: &Atom) -> Vec<(String, usize)> {
1426 atom.terms
1427 .iter()
1428 .enumerate()
1429 .filter_map(|(i, term)| {
1430 if let Term::Variable(name) = term {
1431 Some((name.clone(), i))
1432 } else {
1433 None
1434 }
1435 })
1436 .collect()
1437 }
1438
1439 fn compute_join_keys(
1441 &self,
1442 left_vars: &[(String, usize)],
1443 right_atom: &Atom,
1444 _left_width: usize,
1445 ) -> (Vec<usize>, Vec<usize>) {
1446 let mut left_keys = Vec::new();
1447 let mut right_keys = Vec::new();
1448
1449 for (right_idx, term) in right_atom.terms.iter().enumerate() {
1450 if let Term::Variable(name) = term {
1451 for (left_name, left_idx) in left_vars {
1453 if left_name == name {
1454 left_keys.push(*left_idx);
1455 right_keys.push(right_idx);
1456 break; }
1458 }
1459 }
1460 }
1461
1462 (left_keys, right_keys)
1463 }
1464
1465 fn apply_constant_filters(
1467 &self,
1468 input: RirNode,
1469 atom: &Atom,
1470 _base_col: usize,
1471 ) -> Result<RirNode> {
1472 let mut filters = Vec::new();
1473 let mut first_var_col: HashMap<&str, usize> = HashMap::new();
1474 let schema = self.schemas.get(&atom.predicate).ok_or_else(|| {
1475 XlogError::Compilation(format!("Missing schema for predicate {}", atom.predicate))
1476 })?;
1477
1478 for (i, term) in atom.terms.iter().enumerate() {
1479 if let Term::Variable(name) = term {
1480 if name != "_" {
1481 if let Some(&first) = first_var_col.get(name.as_str()) {
1482 filters.push(Expr::Compare {
1483 left: Box::new(Expr::Column(first)),
1484 op: CompareOp::Eq,
1485 right: Box::new(Expr::Column(i)),
1486 });
1487 } else {
1488 first_var_col.insert(name.as_str(), i);
1489 }
1490 }
1491 }
1492
1493 let col_type = schema.column_type(i).ok_or_else(|| {
1494 XlogError::Compilation(format!(
1495 "Missing column type for {} column {}",
1496 atom.predicate, i
1497 ))
1498 })?;
1499 if let Some(const_val) = term_to_typed_const_value(term, col_type)? {
1500 filters.push(Expr::Compare {
1501 left: Box::new(Expr::Column(i)),
1502 op: CompareOp::Eq,
1503 right: Box::new(Expr::Const(const_val)),
1504 });
1505 }
1506 }
1507
1508 if filters.is_empty() {
1509 Ok(input)
1510 } else {
1511 let predicate = if filters.len() == 1 {
1512 filters.pop().unwrap()
1513 } else {
1514 Expr::And(filters)
1515 };
1516
1517 Ok(RirNode::Filter {
1518 input: Box::new(input),
1519 predicate,
1520 })
1521 }
1522 }
1523
1524 fn apply_comparison(
1526 &self,
1527 input: RirNode,
1528 cmp: &Comparison,
1529 var_env: &VariableEnv,
1530 ) -> Result<RirNode> {
1531 let (left_expr, right_expr) = match (&cmp.left, &cmp.right) {
1532 (Term::Variable(name), term) => {
1533 let col = var_env.get_column(name).ok_or_else(|| {
1534 XlogError::Compilation(format!("Variable {} not found in environment", name))
1535 })?;
1536 let typ = var_env.get_type(name).ok_or_else(|| {
1537 XlogError::Compilation(format!("Missing type for variable {}", name))
1538 })?;
1539 if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1540 (Expr::Column(col), Expr::Const(const_val))
1541 } else {
1542 (
1543 self.term_to_expr(&cmp.left, var_env)?,
1544 self.term_to_expr(&cmp.right, var_env)?,
1545 )
1546 }
1547 }
1548 (term, Term::Variable(name)) => {
1549 let col = var_env.get_column(name).ok_or_else(|| {
1550 XlogError::Compilation(format!("Variable {} not found in environment", name))
1551 })?;
1552 let typ = var_env.get_type(name).ok_or_else(|| {
1553 XlogError::Compilation(format!("Missing type for variable {}", name))
1554 })?;
1555 if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1556 (Expr::Const(const_val), Expr::Column(col))
1557 } else {
1558 (
1559 self.term_to_expr(&cmp.left, var_env)?,
1560 self.term_to_expr(&cmp.right, var_env)?,
1561 )
1562 }
1563 }
1564 _ => (
1565 self.term_to_expr(&cmp.left, var_env)?,
1566 self.term_to_expr(&cmp.right, var_env)?,
1567 ),
1568 };
1569
1570 let op = match cmp.op {
1571 CompOp::Eq => CompareOp::Eq,
1572 CompOp::Ne => CompareOp::Ne,
1573 CompOp::Lt => CompareOp::Lt,
1574 CompOp::Le => CompareOp::Le,
1575 CompOp::Gt => CompareOp::Gt,
1576 CompOp::Ge => CompareOp::Ge,
1577 };
1578
1579 Ok(RirNode::Filter {
1580 input: Box::new(input),
1581 predicate: Expr::Compare {
1582 left: Box::new(left_expr),
1583 op,
1584 right: Box::new(right_expr),
1585 },
1586 })
1587 }
1588
1589 fn term_to_expr(&self, term: &Term, var_env: &VariableEnv) -> Result<Expr> {
1591 match term {
1592 Term::Variable(name) => {
1593 if let Some(col) = var_env.get_column(name) {
1594 Ok(Expr::Column(col))
1595 } else {
1596 Err(XlogError::Compilation(format!(
1597 "Variable {} not found in environment",
1598 name
1599 )))
1600 }
1601 }
1602 Term::Anonymous => Err(XlogError::Compilation(
1603 "Anonymous wildcard '_' not allowed in comparisons".to_string(),
1604 )),
1605 Term::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
1606 Term::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
1607 Term::String(s) => Ok(Expr::Const(ConstValue::Symbol(s.clone()))),
1608 Term::Symbol(id) => Ok(Expr::Const(ConstValue::Symbol(symbol::resolve(*id)))),
1609 Term::Aggregate(_) => Err(XlogError::Compilation(
1610 "Aggregates not allowed in comparisons".to_string(),
1611 )),
1612 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => Err(
1613 term_not_lowerable_error("comparison", term_kind_for_lowering_error(term)),
1614 ),
1615 }
1616 }
1617
1618 fn apply_negation(
1620 &mut self,
1621 input: RirNode,
1622 neg_atom: &Atom,
1623 var_env: &VariableEnv,
1624 ) -> Result<RirNode> {
1625 let rel_id = self.get_or_create_rel_id(&neg_atom.predicate);
1626 let neg_scan = RirNode::Scan { rel: rel_id };
1627
1628 let neg_filtered = self.apply_constant_filters(neg_scan, neg_atom, 0)?;
1630
1631 let mut input_cols = Vec::new();
1633 let mut neg_cols = Vec::new();
1634
1635 for (neg_idx, term) in neg_atom.terms.iter().enumerate() {
1636 if let Term::Variable(name) = term {
1637 if let Some(col) = var_env.get_column(name) {
1638 input_cols.push(col);
1639 neg_cols.push(neg_idx);
1640 }
1641 }
1642 }
1643
1644 if input_cols.is_empty() {
1645 Ok(RirNode::Diff {
1649 left: Box::new(input),
1650 right: Box::new(neg_filtered),
1651 })
1652 } else {
1653 let neg_projected = if neg_cols.len() < neg_atom.terms.len() {
1655 let neg_proj_exprs: Vec<ProjectExpr> =
1656 neg_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1657 RirNode::Project {
1658 input: Box::new(neg_filtered),
1659 columns: neg_proj_exprs,
1660 }
1661 } else {
1662 neg_filtered
1663 };
1664
1665 let input_proj_exprs: Vec<ProjectExpr> =
1673 input_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1674 let input_projected = RirNode::Project {
1675 input: Box::new(input.clone()),
1676 columns: input_proj_exprs,
1677 };
1678
1679 let kept_keys = RirNode::Diff {
1681 left: Box::new(input_projected),
1682 right: Box::new(neg_projected),
1683 };
1684
1685 Ok(RirNode::Join {
1689 left: Box::new(input),
1690 right: Box::new(kept_keys),
1691 left_keys: input_cols.clone(),
1692 right_keys: (0..input_cols.len()).collect(),
1693 join_type: JoinType::Semi,
1694 })
1695 }
1696 }
1697
1698 fn is_identity_projection(proj: &[ProjectExpr], input_cols: usize) -> bool {
1699 if proj.len() != input_cols {
1700 return false;
1701 }
1702 proj.iter()
1703 .enumerate()
1704 .all(|(i, e)| matches!(e, ProjectExpr::Column(c) if *c == i))
1705 }
1706
1707 fn compute_head_projection(
1713 &self,
1714 head: &Atom,
1715 var_env: &VariableEnv,
1716 ) -> Result<Vec<ProjectExpr>> {
1717 let mut cols = Vec::with_capacity(head.terms.len());
1718
1719 for term in &head.terms {
1720 match term {
1721 Term::Variable(name) => {
1722 let col = var_env
1723 .get_column(name)
1724 .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1725 cols.push(ProjectExpr::Column(col));
1726 }
1727 Term::Anonymous => {
1728 return Err(XlogError::Compilation(
1729 "Anonymous wildcard '_' not allowed in rule head".to_string(),
1730 ));
1731 }
1732 Term::Aggregate(_) => {
1733 return Err(XlogError::Compilation(
1734 "Aggregate term in non-aggregate rule head".to_string(),
1735 ));
1736 }
1737 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1738 let (expr, typ) = term_to_project_const_expr(term)?;
1739 cols.push(ProjectExpr::Computed(expr, typ));
1740 }
1741 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1742 return Err(term_not_lowerable_error(
1743 "rule head projection",
1744 term_kind_for_lowering_error(term),
1745 ));
1746 }
1747 }
1748 }
1749
1750 Ok(cols)
1751 }
1752
1753 fn lower_aggregate_rule(
1755 &mut self,
1756 head: &Atom,
1757 body: RirNode,
1758 var_env: &VariableEnv,
1759 ) -> Result<RirNode> {
1760 let mut key_vars: Vec<String> = Vec::new();
1762 let mut key_var_to_pos: HashMap<String, usize> = HashMap::new();
1763 let mut key_src_cols: Vec<usize> = Vec::new();
1764
1765 let mut agg_specs: Vec<(AggOp, String)> = Vec::new();
1767 let mut agg_to_pos: HashMap<(AggOp, String), usize> = HashMap::new();
1768 let mut value_vars: Vec<String> = Vec::new();
1769 let mut value_var_to_pos: HashMap<String, usize> = HashMap::new();
1770 let mut value_src_cols: Vec<usize> = Vec::new();
1771
1772 for term in &head.terms {
1773 match term {
1774 Term::Variable(name) => {
1775 if !key_var_to_pos.contains_key(name) {
1776 let col = var_env
1777 .get_column(name)
1778 .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1779 let pos = key_vars.len();
1780 key_vars.push(name.clone());
1781 key_var_to_pos.insert(name.clone(), pos);
1782 key_src_cols.push(col);
1783 }
1784 }
1785 Term::Aggregate(agg) => {
1786 let key = (agg.op, agg.variable.clone());
1787 if let std::collections::hash_map::Entry::Vacant(entry) = agg_to_pos.entry(key)
1788 {
1789 let col = var_env
1791 .get_column(&agg.variable)
1792 .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1793
1794 let value_pos = *value_var_to_pos
1796 .entry(agg.variable.clone())
1797 .or_insert_with(|| {
1798 let p = value_vars.len();
1799 value_vars.push(agg.variable.clone());
1800 value_src_cols.push(col);
1801 p
1802 });
1803
1804 let agg_pos = agg_specs.len();
1805 agg_specs.push((agg.op, agg.variable.clone()));
1806 entry.insert(agg_pos);
1807
1808 let _ = value_pos;
1810 }
1811 }
1812 Term::Anonymous => {
1813 return Err(XlogError::Compilation(
1814 "Anonymous wildcard '_' not allowed in rule head".to_string(),
1815 ));
1816 }
1817 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1818 }
1820 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1821 return Err(term_not_lowerable_error(
1822 "aggregate rule head",
1823 term_kind_for_lowering_error(term),
1824 ));
1825 }
1826 }
1827 }
1828
1829 if agg_specs.is_empty() {
1830 return Err(XlogError::Compilation(
1831 "Rule marked as aggregate but no aggregate terms found".to_string(),
1832 ));
1833 }
1834
1835 let mut group_input_cols: Vec<ProjectExpr> = Vec::new();
1838 let mut key_cols: Vec<usize> = Vec::new();
1839
1840 if key_src_cols.is_empty() {
1841 group_input_cols.push(ProjectExpr::Computed(
1842 Expr::Const(ConstValue::U32(0)),
1843 ScalarType::U32,
1844 ));
1845 key_cols.push(0);
1846 } else {
1847 for (i, &col) in key_src_cols.iter().enumerate() {
1848 group_input_cols.push(ProjectExpr::Column(col));
1849 key_cols.push(i);
1850 }
1851 }
1852
1853 let value_offset = group_input_cols.len();
1854 for &col in &value_src_cols {
1855 group_input_cols.push(ProjectExpr::Column(col));
1856 }
1857
1858 let group_input = RirNode::Project {
1859 input: Box::new(body),
1860 columns: group_input_cols,
1861 };
1862
1863 let mut aggs: Vec<(usize, CoreAggOp)> = Vec::with_capacity(agg_specs.len());
1865 for (op, var) in &agg_specs {
1866 let value_pos = *value_var_to_pos
1867 .get(var)
1868 .ok_or_else(|| XlogError::UnsafeVariable(var.clone()))?;
1869 let value_col = value_offset + value_pos;
1870 aggs.push((value_col, convert_agg_op(op)));
1871 }
1872
1873 let groupby = RirNode::GroupBy {
1874 input: Box::new(group_input),
1875 key_cols,
1876 aggs,
1877 };
1878
1879 let key_count = if key_src_cols.is_empty() {
1884 1
1885 } else {
1886 key_vars.len()
1887 };
1888
1889 let mut final_proj: Vec<ProjectExpr> = Vec::with_capacity(head.terms.len());
1890 for term in &head.terms {
1891 match term {
1892 Term::Variable(name) => {
1893 let idx = if key_src_cols.is_empty() {
1894 return Err(XlogError::UnsafeVariable(name.clone()));
1897 } else {
1898 *key_var_to_pos
1899 .get(name)
1900 .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?
1901 };
1902 final_proj.push(ProjectExpr::Column(idx));
1903 }
1904 Term::Aggregate(agg) => {
1905 let pos = *agg_to_pos
1906 .get(&(agg.op, agg.variable.clone()))
1907 .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1908 final_proj.push(ProjectExpr::Column(key_count + pos));
1909 }
1910 Term::Anonymous => {
1911 return Err(XlogError::Compilation(
1912 "Anonymous wildcard '_' not allowed in rule head".to_string(),
1913 ));
1914 }
1915 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1916 let (expr, typ) = term_to_project_const_expr(term)?;
1917 final_proj.push(ProjectExpr::Computed(expr, typ));
1918 }
1919 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1920 return Err(term_not_lowerable_error(
1921 "aggregate rule projection",
1922 term_kind_for_lowering_error(term),
1923 ));
1924 }
1925 }
1926 }
1927
1928 if final_proj.is_empty() {
1929 return Err(XlogError::Compilation(
1930 "Aggregate rule produced empty head projection".to_string(),
1931 ));
1932 }
1933
1934 Ok(RirNode::Project {
1935 input: Box::new(groupby),
1936 columns: final_proj,
1937 })
1938 }
1939
1940 pub(crate) fn infer_arith_type(
1942 &self,
1943 expr: &ArithExpr,
1944 var_env: &VariableEnv,
1945 ) -> Result<ScalarType> {
1946 match expr {
1947 ArithExpr::Variable(name) => var_env.get_type(name).ok_or_else(|| {
1948 XlogError::Compilation(format!("Unknown variable {} in arithmetic", name))
1949 }),
1950 ArithExpr::Integer(_) => Ok(ScalarType::I64),
1951 ArithExpr::Float(_) => Ok(ScalarType::F64),
1952
1953 ArithExpr::Add(l, r)
1954 | ArithExpr::Sub(l, r)
1955 | ArithExpr::Mul(l, r)
1956 | ArithExpr::Div(l, r) => {
1957 let lt = self.infer_arith_type(l, var_env)?;
1958 let rt = self.infer_arith_type(r, var_env)?;
1959
1960 if lt != rt {
1961 return Err(XlogError::Compilation(format!(
1962 "Type mismatch in arithmetic: {:?} vs {:?}. Use cast() for conversion.",
1963 lt, rt
1964 )));
1965 }
1966
1967 if !Self::is_numeric_type(<) {
1968 return Err(XlogError::Compilation(format!(
1969 "Arithmetic requires numeric type, got {:?}",
1970 lt
1971 )));
1972 }
1973
1974 Ok(lt)
1975 }
1976
1977 ArithExpr::Mod(l, r) => {
1978 let lt = self.infer_arith_type(l, var_env)?;
1979 let rt = self.infer_arith_type(r, var_env)?;
1980
1981 if lt != rt {
1982 return Err(XlogError::Compilation(format!(
1983 "Type mismatch in mod: {:?} vs {:?}",
1984 lt, rt
1985 )));
1986 }
1987
1988 if matches!(lt, ScalarType::F32 | ScalarType::F64) {
1989 return Err(XlogError::Compilation(
1990 "Modulo (%) not supported for floating point".into(),
1991 ));
1992 }
1993
1994 Ok(lt)
1995 }
1996
1997 ArithExpr::Abs(inner) => {
1998 let t = self.infer_arith_type(inner, var_env)?;
1999 if !Self::is_numeric_type(&t) {
2000 return Err(XlogError::Compilation(format!(
2001 "abs requires numeric type, got {:?}",
2002 t
2003 )));
2004 }
2005 Ok(t)
2006 }
2007
2008 ArithExpr::Min(l, r) | ArithExpr::Max(l, r) => {
2009 let lt = self.infer_arith_type(l, var_env)?;
2010 let rt = self.infer_arith_type(r, var_env)?;
2011
2012 if lt != rt {
2013 return Err(XlogError::Compilation(format!(
2014 "Type mismatch in min/max: {:?} vs {:?}",
2015 lt, rt
2016 )));
2017 }
2018
2019 if !Self::is_numeric_type(<) {
2020 return Err(XlogError::Compilation(format!(
2021 "min/max requires numeric type, got {:?}",
2022 lt
2023 )));
2024 }
2025
2026 Ok(lt)
2027 }
2028
2029 ArithExpr::Pow(base, exp) => {
2030 let base_t = self.infer_arith_type(base, var_env)?;
2031 let exp_t = self.infer_arith_type(exp, var_env)?;
2032
2033 if !Self::is_numeric_type(&base_t) || !Self::is_numeric_type(&exp_t) {
2034 return Err(XlogError::Compilation(format!(
2035 "pow requires numeric operands, got {:?} and {:?}",
2036 base_t, exp_t
2037 )));
2038 }
2039
2040 Ok(ScalarType::F64)
2042 }
2043
2044 ArithExpr::Cast(_, target) => Ok(*target),
2045
2046 ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2047 "User-defined function '{}' must be inlined before lowering",
2048 name
2049 ))),
2050
2051 ArithExpr::Conditional {
2052 then_expr,
2053 else_expr,
2054 ..
2055 } => {
2056 let then_type = self.infer_arith_type(then_expr, var_env)?;
2058 let else_type = self.infer_arith_type(else_expr, var_env)?;
2059 if then_type != else_type {
2060 return Err(XlogError::Compilation(format!(
2061 "Conditional branches have different types: {:?} vs {:?}",
2062 then_type, else_type
2063 )));
2064 }
2065 Ok(then_type)
2066 }
2067 }
2068 }
2069
2070 fn is_numeric_type(t: &ScalarType) -> bool {
2071 matches!(
2072 t,
2073 ScalarType::I32
2074 | ScalarType::I64
2075 | ScalarType::U32
2076 | ScalarType::U64
2077 | ScalarType::F32
2078 | ScalarType::F64
2079 )
2080 }
2081
2082 fn arith_to_expr(&self, arith: &ArithExpr, var_env: &VariableEnv) -> Result<Expr> {
2084 match arith {
2085 ArithExpr::Variable(name) => {
2086 let col = var_env.get_column(name).ok_or_else(|| {
2087 XlogError::Compilation(format!(
2088 "Variable {} not bound before use in arithmetic",
2089 name
2090 ))
2091 })?;
2092 Ok(Expr::Column(col))
2093 }
2094 ArithExpr::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
2095 ArithExpr::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
2096
2097 ArithExpr::Add(l, r) => Ok(Expr::Add(
2098 Box::new(self.arith_to_expr(l, var_env)?),
2099 Box::new(self.arith_to_expr(r, var_env)?),
2100 )),
2101 ArithExpr::Sub(l, r) => Ok(Expr::Sub(
2102 Box::new(self.arith_to_expr(l, var_env)?),
2103 Box::new(self.arith_to_expr(r, var_env)?),
2104 )),
2105 ArithExpr::Mul(l, r) => Ok(Expr::Mul(
2106 Box::new(self.arith_to_expr(l, var_env)?),
2107 Box::new(self.arith_to_expr(r, var_env)?),
2108 )),
2109 ArithExpr::Div(l, r) => Ok(Expr::Div(
2110 Box::new(self.arith_to_expr(l, var_env)?),
2111 Box::new(self.arith_to_expr(r, var_env)?),
2112 )),
2113 ArithExpr::Mod(l, r) => Ok(Expr::Mod(
2114 Box::new(self.arith_to_expr(l, var_env)?),
2115 Box::new(self.arith_to_expr(r, var_env)?),
2116 )),
2117
2118 ArithExpr::Abs(e) => Ok(Expr::Abs(Box::new(self.arith_to_expr(e, var_env)?))),
2119 ArithExpr::Min(l, r) => Ok(Expr::Min(
2120 Box::new(self.arith_to_expr(l, var_env)?),
2121 Box::new(self.arith_to_expr(r, var_env)?),
2122 )),
2123 ArithExpr::Max(l, r) => Ok(Expr::Max(
2124 Box::new(self.arith_to_expr(l, var_env)?),
2125 Box::new(self.arith_to_expr(r, var_env)?),
2126 )),
2127 ArithExpr::Pow(l, r) => Ok(Expr::Pow(
2128 Box::new(self.arith_to_expr(l, var_env)?),
2129 Box::new(self.arith_to_expr(r, var_env)?),
2130 )),
2131 ArithExpr::Cast(e, t) => Ok(Expr::Cast(Box::new(self.arith_to_expr(e, var_env)?), *t)),
2132
2133 ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2134 "User-defined function '{}' must be inlined before lowering",
2135 name
2136 ))),
2137
2138 ArithExpr::Conditional {
2139 cond_left,
2140 cond_op,
2141 cond_right,
2142 then_expr,
2143 else_expr,
2144 } => {
2145 let ir_cond_op = match cond_op {
2147 CompOp::Eq => CompareOp::Eq,
2148 CompOp::Ne => CompareOp::Ne,
2149 CompOp::Lt => CompareOp::Lt,
2150 CompOp::Le => CompareOp::Le,
2151 CompOp::Gt => CompareOp::Gt,
2152 CompOp::Ge => CompareOp::Ge,
2153 };
2154
2155 let condition = Expr::Compare {
2157 left: Box::new(self.arith_to_expr(cond_left, var_env)?),
2158 op: ir_cond_op,
2159 right: Box::new(self.arith_to_expr(cond_right, var_env)?),
2160 };
2161
2162 let then_ir = self.arith_to_expr(then_expr, var_env)?;
2164 let else_ir = self.arith_to_expr(else_expr, var_env)?;
2165
2166 Ok(Expr::Conditional {
2167 condition: Box::new(condition),
2168 then_expr: Box::new(then_ir),
2169 else_expr: Box::new(else_ir),
2170 })
2171 }
2172 }
2173 }
2174
2175 fn lower_is_expr(
2177 &mut self,
2178 is_expr: &IsExpr,
2179 input: RirNode,
2180 var_env: &mut VariableEnv,
2181 ) -> Result<RirNode> {
2182 if var_env.contains(&is_expr.target) {
2184 return Err(XlogError::Compilation(format!(
2185 "Variable {} already bound; 'is' requires fresh variable",
2186 is_expr.target
2187 )));
2188 }
2189
2190 for var in is_expr.expr.variables() {
2192 if !var_env.contains(var) {
2193 return Err(XlogError::Compilation(format!(
2194 "Variable {} used in arithmetic but not bound",
2195 var
2196 )));
2197 }
2198 }
2199
2200 let result_type = self.infer_arith_type(&is_expr.expr, var_env)?;
2202
2203 let ir_expr = self.arith_to_expr(&is_expr.expr, var_env)?;
2205
2206 let num_cols = var_env.column_count();
2208 let mut proj_exprs: Vec<ProjectExpr> = (0..num_cols).map(ProjectExpr::Column).collect();
2209 proj_exprs.push(ProjectExpr::Computed(ir_expr, result_type));
2210
2211 var_env.bind(&is_expr.target, num_cols, result_type);
2213
2214 Ok(RirNode::Project {
2215 input: Box::new(input),
2216 columns: proj_exprs,
2217 })
2218 }
2219}
2220
2221pub(crate) struct VariableEnv {
2223 occurrences: HashMap<String, Vec<(String, usize, usize)>>,
2225 total_cols: usize,
2227 types: HashMap<String, ScalarType>,
2229}
2230
2231impl VariableEnv {
2232 fn new() -> Self {
2233 Self {
2234 occurrences: HashMap::new(),
2235 total_cols: 0,
2236 types: HashMap::new(),
2237 }
2238 }
2239
2240 fn add_occurrence(&mut self, var: &str, pred: String, atom_pos: usize, global_col: usize) {
2241 self.occurrences
2242 .entry(var.to_string())
2243 .or_default()
2244 .push((pred, atom_pos, global_col));
2245 }
2246
2247 fn get_column(&self, var: &str) -> Option<usize> {
2248 self.occurrences
2249 .get(var)
2250 .and_then(|occs| occs.first())
2251 .map(|(_, _, col)| *col)
2252 }
2253
2254 fn bind(&mut self, name: &str, column: usize, typ: ScalarType) {
2256 self.types.insert(name.to_string(), typ);
2257 self.occurrences
2259 .entry(name.to_string())
2260 .or_default()
2261 .push(("".to_string(), 0, column));
2262 if column >= self.total_cols {
2265 self.total_cols = column + 1;
2266 }
2267 }
2268
2269 fn get_type(&self, name: &str) -> Option<ScalarType> {
2271 self.types.get(name).copied()
2272 }
2273
2274 fn contains(&self, name: &str) -> bool {
2276 self.occurrences.contains_key(name)
2277 }
2278
2279 fn column_count(&self) -> usize {
2281 self.total_cols
2282 }
2283}
2284
2285fn infer_term_type(term: &Term) -> ScalarType {
2287 match term {
2288 Term::Variable(_) | Term::Anonymous => ScalarType::U64, Term::Integer(i) => {
2290 if *i >= 0 && *i <= u32::MAX as i64 {
2291 ScalarType::U32
2292 } else {
2293 ScalarType::I64
2294 }
2295 }
2296 Term::Float(_) => ScalarType::F64,
2297 Term::String(_) | Term::Symbol(_) => ScalarType::Symbol,
2298 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
2299 ScalarType::U64
2300 }
2301 Term::Aggregate(agg) => match agg.op {
2302 AggOp::Count => ScalarType::U32,
2303 AggOp::Sum => ScalarType::U64,
2304 AggOp::Min | AggOp::Max => ScalarType::U32,
2305 AggOp::LogSumExp => ScalarType::F64,
2306 },
2307 }
2308}
2309
2310fn sort_labels_from_terms(terms: &[Term]) -> Vec<String> {
2311 terms
2312 .iter()
2313 .enumerate()
2314 .map(|(idx, term)| match term {
2315 Term::Variable(name) if !name.trim().is_empty() => name.clone(),
2316 Term::Aggregate(agg) => format!("{:?}_{}", agg.op, agg.variable),
2317 Term::List(_) => format!("list{}", idx),
2318 Term::Cons { .. } => format!("cons{}", idx),
2319 Term::Compound { functor, .. } => functor.clone(),
2320 Term::PredRef(name) => name.clone(),
2321 _ => format!("c{}", idx),
2322 })
2323 .collect()
2324}
2325
2326fn term_to_const_value(term: &Term) -> Option<ConstValue> {
2328 match term {
2329 Term::Integer(i) => Some(ConstValue::I64(*i)),
2330 Term::Float(f) => Some(ConstValue::F64(*f)),
2331 Term::String(s) => Some(ConstValue::Symbol(s.clone())),
2332 Term::Symbol(id) => Some(ConstValue::Symbol(symbol::resolve(*id))),
2333 Term::Variable(_)
2334 | Term::Anonymous
2335 | Term::Aggregate(_)
2336 | Term::List(_)
2337 | Term::Cons { .. }
2338 | Term::Compound { .. }
2339 | Term::PredRef(_) => None,
2340 }
2341}
2342
2343fn term_to_typed_const_value(term: &Term, expected: ScalarType) -> Result<Option<ConstValue>> {
2344 let const_val = match term {
2345 Term::Integer(i) => match expected {
2346 ScalarType::U32 => {
2347 if *i >= 0 && *i <= u32::MAX as i64 {
2348 ConstValue::U32(*i as u32)
2349 } else {
2350 return Err(XlogError::Compilation(format!(
2351 "Integer literal {} out of range for {:?}",
2352 i, expected
2353 )));
2354 }
2355 }
2356 ScalarType::U64 => {
2357 if *i >= 0 {
2358 ConstValue::U64(*i as u64)
2359 } else {
2360 return Err(XlogError::Compilation(format!(
2361 "Integer literal {} out of range for {:?}",
2362 i, expected
2363 )));
2364 }
2365 }
2366 ScalarType::I32 => {
2367 if *i >= i32::MIN as i64 && *i <= i32::MAX as i64 {
2368 ConstValue::I32(*i as i32)
2369 } else {
2370 return Err(XlogError::Compilation(format!(
2371 "Integer literal {} out of range for {:?}",
2372 i, expected
2373 )));
2374 }
2375 }
2376 ScalarType::I64 => ConstValue::I64(*i),
2377 ScalarType::F32 => {
2378 let value = *i as f64;
2379 if value < f32::MIN as f64 || value > f32::MAX as f64 {
2380 return Err(XlogError::Compilation(format!(
2381 "Integer literal {} out of range for {:?}",
2382 i, expected
2383 )));
2384 }
2385 ConstValue::F32(value as f32)
2386 }
2387 ScalarType::F64 => ConstValue::F64(*i as f64),
2388 ScalarType::Bool => {
2389 if *i == 0 || *i == 1 {
2390 ConstValue::Bool(*i == 1)
2391 } else {
2392 return Err(XlogError::Compilation(format!(
2393 "Integer literal {} not valid for {:?}",
2394 i, expected
2395 )));
2396 }
2397 }
2398 ScalarType::Symbol => {
2399 return Err(XlogError::Compilation(format!(
2400 "Integer literal {} not valid for {:?}",
2401 i, expected
2402 )));
2403 }
2404 },
2405 Term::Float(f) => match expected {
2406 ScalarType::F32 => {
2407 if !f.is_finite() {
2408 return Err(XlogError::Compilation(format!(
2409 "Float literal {} not valid for {:?}",
2410 f, expected
2411 )));
2412 }
2413 if *f < f32::MIN as f64 || *f > f32::MAX as f64 {
2414 return Err(XlogError::Compilation(format!(
2415 "Float literal {} out of range for {:?}",
2416 f, expected
2417 )));
2418 }
2419 ConstValue::F32(*f as f32)
2420 }
2421 ScalarType::F64 => ConstValue::F64(*f),
2422 ScalarType::U32
2423 | ScalarType::U64
2424 | ScalarType::I32
2425 | ScalarType::I64
2426 | ScalarType::Bool
2427 | ScalarType::Symbol => {
2428 return Err(XlogError::Compilation(format!(
2429 "Float literal {} not valid for {:?}",
2430 f, expected
2431 )));
2432 }
2433 },
2434 Term::String(s) => {
2435 if expected == ScalarType::Symbol {
2436 ConstValue::Symbol(s.clone())
2437 } else {
2438 return Err(XlogError::Compilation(format!(
2439 "String literal {} not valid for {:?}",
2440 s, expected
2441 )));
2442 }
2443 }
2444 Term::Symbol(id) => {
2445 if expected == ScalarType::Symbol {
2446 ConstValue::Symbol(symbol::resolve(*id))
2447 } else {
2448 return Err(XlogError::Compilation(format!(
2449 "Symbol literal {} not valid for {:?}",
2450 symbol::resolve(*id),
2451 expected
2452 )));
2453 }
2454 }
2455 Term::Variable(_)
2456 | Term::Anonymous
2457 | Term::Aggregate(_)
2458 | Term::List(_)
2459 | Term::Cons { .. }
2460 | Term::Compound { .. }
2461 | Term::PredRef(_) => return Ok(None),
2462 };
2463
2464 Ok(Some(const_val))
2465}
2466
2467fn term_to_project_const_expr(term: &Term) -> Result<(Expr, ScalarType)> {
2468 match term {
2469 Term::Integer(i) => {
2470 if *i >= 0 && *i <= u32::MAX as i64 {
2471 Ok((Expr::Const(ConstValue::U32(*i as u32)), ScalarType::U32))
2472 } else {
2473 Ok((Expr::Const(ConstValue::I64(*i)), ScalarType::I64))
2474 }
2475 }
2476 Term::Float(f) => Ok((Expr::Const(ConstValue::F64(*f)), ScalarType::F64)),
2477 Term::String(s) => Ok((
2478 Expr::Const(ConstValue::Symbol(s.clone())),
2479 ScalarType::Symbol,
2480 )),
2481 Term::Symbol(id) => Ok((
2482 Expr::Const(ConstValue::Symbol(symbol::resolve(*id))),
2483 ScalarType::Symbol,
2484 )),
2485 Term::Variable(_)
2486 | Term::Anonymous
2487 | Term::Aggregate(_)
2488 | Term::List(_)
2489 | Term::Cons { .. }
2490 | Term::Compound { .. }
2491 | Term::PredRef(_) => Err(XlogError::Compilation("Expected constant term".to_string())),
2492 }
2493}
2494
2495fn convert_agg_op(op: &AggOp) -> CoreAggOp {
2497 match op {
2498 AggOp::Count => CoreAggOp::Count,
2499 AggOp::Sum => CoreAggOp::Sum,
2500 AggOp::Min => CoreAggOp::Min,
2501 AggOp::Max => CoreAggOp::Max,
2502 AggOp::LogSumExp => CoreAggOp::LogSumExp,
2503 }
2504}
2505
#[cfg(test)]
mod arith_type_tests {
    //! Tests for arithmetic type inference (`Lowerer::infer_arith_type`).

    use super::*;
    use crate::ast::ArithExpr;

    /// Builds the expression `X + Y` over two named variables.
    fn x_plus_y() -> ArithExpr {
        ArithExpr::Add(
            Box::new(ArithExpr::Variable("X".to_string())),
            Box::new(ArithExpr::Variable("Y".to_string())),
        )
    }

    #[test]
    fn test_arith_type_inference_same_type() {
        let lowerer = Lowerer::new();
        let mut env = VariableEnv::new();
        env.bind("X", 0, ScalarType::I64);
        env.bind("Y", 1, ScalarType::I64);

        // Two I64 operands produce an I64 result.
        let inferred = lowerer.infer_arith_type(&x_plus_y(), &env);
        assert_eq!(inferred.ok(), Some(ScalarType::I64));
    }

    #[test]
    fn test_arith_type_inference_mismatch() {
        let lowerer = Lowerer::new();
        let mut env = VariableEnv::new();
        env.bind("X", 0, ScalarType::I64);
        env.bind("Y", 1, ScalarType::F64);

        // Mixing I64 and F64 operands must be rejected.
        assert!(lowerer.infer_arith_type(&x_plus_y(), &env).is_err());
    }
}
2547
#[cfg(test)]
mod tests {
    //! Tests for lowering the surface AST into RIR: schema inference,
    //! relation-id allocation, rule lowering (scans, joins, negation,
    //! filters, chained `is` expressions) and small helper functions.

    use super::*;
    use crate::ast::*;

    /// Builds a `PredDecl` with unnamed columns whose `types` and `columns`
    /// both describe `types`, in order.
    fn pred_decl(name: &str, types: Vec<ScalarType>) -> PredDecl {
        let type_refs: Vec<TypeRef> = types.into_iter().map(TypeRef::Scalar).collect();
        let columns = type_refs
            .iter()
            .cloned()
            .map(|typ| PredColumn { name: None, typ })
            .collect();
        PredDecl {
            name: name.to_string(),
            types: type_refs,
            columns,
            is_private: false,
        }
    }

    /// Builds a schema whose columns are named `c0`, `c1`, ... and carry the
    /// given scalar types in order.
    fn schema_of(types: &[ScalarType]) -> Schema {
        Schema::new(
            types
                .iter()
                .enumerate()
                .map(|(i, ty)| (format!("c{}", i), *ty))
                .collect::<Vec<_>>(),
        )
    }

    /// Shorthand for a named variable term.
    fn var(name: &str) -> Term {
        Term::Variable(name.to_string())
    }

    /// Shorthand for the atom `predicate(terms...)`.
    fn atom(predicate: &str, terms: Vec<Term>) -> Atom {
        Atom {
            predicate: predicate.to_string(),
            terms,
        }
    }

    /// `edge(x, y)` with variable arguments.
    fn edge_atom(x: &str, y: &str) -> Atom {
        atom("edge", vec![var(x), var(y)])
    }

    /// `reach(x, y)` with variable arguments.
    fn reach_atom(x: &str, y: &str) -> Atom {
        atom("reach", vec![var(x), var(y)])
    }

    /// `node(x)` with a variable argument.
    fn node_atom(x: &str) -> Atom {
        atom("node", vec![var(x)])
    }

    /// Returns the constant on the right-hand side of the first
    /// `Expr::Compare` filter predicate found while walking down through
    /// filters, projections and both sides of joins (left side first).
    fn find_compare_const(node: &RirNode) -> Option<&ConstValue> {
        match node {
            RirNode::Filter { predicate, input } => {
                if let Expr::Compare { right, .. } = predicate {
                    if let Expr::Const(val) = right.as_ref() {
                        return Some(val);
                    }
                }
                find_compare_const(input)
            }
            RirNode::Project { input, .. } => find_compare_const(input),
            RirNode::Join { left, right, .. } => {
                find_compare_const(left).or_else(|| find_compare_const(right))
            }
            _ => None,
        }
    }

    #[test]
    fn test_lowerer_new() {
        let lowerer = Lowerer::new();
        assert!(lowerer.schemas.is_empty());
        assert!(lowerer.strata.is_empty());
        assert_eq!(lowerer.next_rel_id, 0);
    }

    #[test]
    fn test_get_or_create_rel_id() {
        let mut lowerer = Lowerer::new();
        let id1 = lowerer.get_or_create_rel_id("edge");
        let id2 = lowerer.get_or_create_rel_id("reach");
        let id3 = lowerer.get_or_create_rel_id("edge");

        assert_eq!(id1, RelId(0));
        assert_eq!(id2, RelId(1));
        // Asking again for a known predicate returns its existing id.
        assert_eq!(id3, RelId(0));
    }

    #[test]
    fn test_infer_schemas_from_facts() {
        // Fact: edge(1, 2).
        let mut program = Program::new();
        program.rules.push(Rule {
            head: atom("edge", vec![Term::Integer(1), Term::Integer(2)]),
            body: vec![],
        });

        let mut lowerer = Lowerer::new();
        lowerer
            .infer_schemas(&program)
            .expect("infer_schemas should succeed");

        let schema = lowerer
            .schemas
            .get("edge")
            .expect("a schema should be inferred for `edge`");
        assert_eq!(schema.arity(), 2);
    }

    #[test]
    fn test_lower_simple_rule() {
        // reach(X, Y) :- edge(X, Y).
        let rule = Rule {
            head: reach_atom("X", "Y"),
            body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
        };

        let mut lowerer = Lowerer::new();
        lowerer.schemas.insert(
            "edge".to_string(),
            schema_of(&[ScalarType::U32, ScalarType::U32]),
        );

        let node = lowerer
            .lower_rule(&rule)
            .expect("lowering reach(X, Y) :- edge(X, Y) should succeed");
        // The head reuses the body columns unchanged; the expected plan is a
        // bare scan.
        assert!(
            matches!(node, RirNode::Scan { .. }),
            "Expected Scan node, got {:?}",
            node
        );
    }

    #[test]
    fn test_lower_join_rule() {
        // reach(X, Z) :- reach(X, Y), edge(Y, Z).
        let rule = Rule {
            head: atom("reach", vec![var("X"), var("Z")]),
            body: vec![
                BodyLiteral::Positive(reach_atom("X", "Y")),
                BodyLiteral::Positive(edge_atom("Y", "Z")),
            ],
        };

        let mut lowerer = Lowerer::new();
        lowerer.schemas.insert(
            "reach".to_string(),
            schema_of(&[ScalarType::U32, ScalarType::U32]),
        );
        lowerer.schemas.insert(
            "edge".to_string(),
            schema_of(&[ScalarType::U32, ScalarType::U32]),
        );

        let node = lowerer
            .lower_rule(&rule)
            .expect("lowering the join rule should succeed");

        match node {
            RirNode::Project { input, columns } => {
                // The head keeps X (column 0) and Z (column 3) of the join row.
                assert_eq!(
                    columns,
                    vec![ProjectExpr::Column(0), ProjectExpr::Column(3)]
                );
                match *input {
                    RirNode::Join {
                        left_keys,
                        right_keys,
                        ..
                    } => {
                        // Joined on Y: column 1 of the left input, column 0 of
                        // the right input.
                        assert_eq!(left_keys, vec![1]);
                        assert_eq!(right_keys, vec![0]);
                    }
                    other => panic!("Expected Join under Project, got {:?}", other),
                }
            }
            other => panic!("Expected Project node, got {:?}", other),
        }
    }

    #[test]
    fn test_join_order_prefers_smaller_relation() {
        // out(X) :- big(X), small(X).
        let rule = Rule {
            head: atom("out", vec![var("X")]),
            body: vec![
                BodyLiteral::Positive(atom("big", vec![var("X")])),
                BodyLiteral::Positive(atom("small", vec![var("X")])),
            ],
        };

        let mut lowerer = Lowerer::new();
        lowerer
            .schemas
            .insert("big".to_string(), schema_of(&[ScalarType::U32]));
        lowerer
            .schemas
            .insert("small".to_string(), schema_of(&[ScalarType::U32]));

        // Allocate ids up front so the scans can be identified in the plan.
        let big_id = lowerer.get_or_create_rel_id("big");
        let small_id = lowerer.get_or_create_rel_id("small");
        assert_eq!(big_id, RelId(0));
        assert_eq!(small_id, RelId(1));

        lowerer.est_cardinality.insert("big".to_string(), 10_000);
        lowerer.est_cardinality.insert("small".to_string(), 10);

        let node = lowerer
            .lower_rule(&rule)
            .expect("lowering the two-way join should succeed");
        let join = match node {
            RirNode::Project { input, .. } => *input,
            other => other,
        };

        // Expected shape: `big` scanned on the left, `small` on the right.
        // NOTE(review): the body already lists `big` before `small`, so this
        // shape may also arise from plain body order; a variant with `small`
        // listed first would confirm the estimates actually drive the order.
        match join {
            RirNode::Join { left, right, .. } => {
                assert!(matches!(*left, RirNode::Scan { rel } if rel == big_id));
                assert!(matches!(*right, RirNode::Scan { rel } if rel == small_id));
            }
            other => panic!("Expected Join node, got {:?}", other),
        }
    }

    #[test]
    fn test_lower_negation() {
        // isolated(X) :- node(X), not edge(X, _).
        let rule = Rule {
            head: atom("isolated", vec![var("X")]),
            body: vec![
                BodyLiteral::Positive(node_atom("X")),
                BodyLiteral::Negated(atom("edge", vec![var("X"), var("_")])),
            ],
        };

        let mut lowerer = Lowerer::new();
        lowerer
            .schemas
            .insert("node".to_string(), schema_of(&[ScalarType::U32]));
        lowerer.schemas.insert(
            "edge".to_string(),
            schema_of(&[ScalarType::U32, ScalarType::U32]),
        );

        let node = lowerer
            .lower_rule(&rule)
            .expect("lowering the negated rule should succeed");

        // True if a `Diff` node or a semi-join appears anywhere beneath
        // joins, projections or filters.
        fn contains_diff_or_semi(node: &RirNode) -> bool {
            match node {
                RirNode::Diff { .. } => true,
                RirNode::Join {
                    join_type: JoinType::Semi,
                    ..
                } => true,
                RirNode::Join { left, right, .. } => {
                    contains_diff_or_semi(left) || contains_diff_or_semi(right)
                }
                RirNode::Project { input, .. } => contains_diff_or_semi(input),
                RirNode::Filter { input, .. } => contains_diff_or_semi(input),
                _ => false,
            }
        }
        assert!(
            contains_diff_or_semi(&node),
            "Expected a Diff or semi-join in {:?}",
            node
        );
    }

    #[test]
    fn test_lower_comparison() {
        // greater(X, Y) :- pair(X, Y), X > Y.
        let rule = Rule {
            head: atom("greater", vec![var("X"), var("Y")]),
            body: vec![
                BodyLiteral::Positive(atom("pair", vec![var("X"), var("Y")])),
                BodyLiteral::Comparison(Comparison {
                    left: var("X"),
                    op: CompOp::Gt,
                    right: var("Y"),
                }),
            ],
        };

        let mut lowerer = Lowerer::new();
        lowerer.schemas.insert(
            "pair".to_string(),
            schema_of(&[ScalarType::U32, ScalarType::U32]),
        );

        let node = lowerer
            .lower_rule(&rule)
            .expect("lowering the comparison rule should succeed");

        // True if a `Filter` appears beneath projections or joins.
        fn contains_filter(node: &RirNode) -> bool {
            match node {
                RirNode::Filter { .. } => true,
                RirNode::Project { input, .. } => contains_filter(input),
                RirNode::Join { left, right, .. } => {
                    contains_filter(left) || contains_filter(right)
                }
                _ => false,
            }
        }
        assert!(contains_filter(&node), "Expected a Filter in {:?}", node);
    }

    #[test]
    fn test_lower_constant_filter() {
        // specific_edge(Y) :- edge(1, Y).
        let rule = Rule {
            head: atom("specific_edge", vec![var("Y")]),
            body: vec![BodyLiteral::Positive(atom(
                "edge",
                vec![Term::Integer(1), var("Y")],
            ))],
        };

        let mut lowerer = Lowerer::new();
        lowerer.schemas.insert(
            "edge".to_string(),
            schema_of(&[ScalarType::U32, ScalarType::U32]),
        );

        let node = lowerer
            .lower_rule(&rule)
            .expect("lowering the constant-argument rule should succeed");

        // True if a Filter compares something against a constant, possibly
        // under a projection.
        fn has_const_filter(node: &RirNode) -> bool {
            match node {
                RirNode::Filter {
                    predicate: Expr::Compare { right, .. },
                    ..
                } => matches!(**right, Expr::Const(_)),
                RirNode::Project { input, .. } => has_const_filter(input),
                _ => false,
            }
        }
        assert!(
            has_const_filter(&node),
            "Expected a constant comparison filter in {:?}",
            node
        );
    }

    #[test]
    fn test_lower_repeated_variable_filter() {
        // self_loop(X) :- edge(X, X).
        let rule = Rule {
            head: atom("self_loop", vec![var("X")]),
            body: vec![BodyLiteral::Positive(atom("edge", vec![var("X"), var("X")]))],
        };

        let mut lowerer = Lowerer::new();
        lowerer.schemas.insert(
            "edge".to_string(),
            schema_of(&[ScalarType::U32, ScalarType::U32]),
        );

        let node = lowerer.lower_rule(&rule).expect("lower_rule failed");

        // True for `col0 = col1` in either operand order.
        fn is_col0_eq_col1(expr: &Expr) -> bool {
            match expr {
                Expr::Compare {
                    left,
                    op: CompareOp::Eq,
                    right,
                } => matches!(
                    (&**left, &**right),
                    (Expr::Column(0), Expr::Column(1)) | (Expr::Column(1), Expr::Column(0))
                ),
                _ => false,
            }
        }

        // True if a Filter (possibly under a projection) enforces
        // `col0 = col1`, either directly or as one conjunct of an `And`.
        fn has_col_eq_filter(node: &RirNode) -> bool {
            match node {
                RirNode::Filter { predicate, .. } => match predicate {
                    Expr::And(exprs) => exprs.iter().any(is_col0_eq_col1),
                    other => is_col0_eq_col1(other),
                },
                RirNode::Project { input, .. } => has_col_eq_filter(input),
                _ => false,
            }
        }

        assert!(
            has_col_eq_filter(&node),
            "Expected a col0 = col1 filter in {:?}",
            node
        );
    }

    #[test]
    fn test_lower_program_simple() {
        let mut program = Program::new();

        // Fact: edge(1, 2).
        program.rules.push(Rule {
            head: atom("edge", vec![Term::Integer(1), Term::Integer(2)]),
            body: vec![],
        });

        // Rule: reach(X, Y) :- edge(X, Y).
        program.rules.push(Rule {
            head: reach_atom("X", "Y"),
            body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
        });

        let mut lowerer = Lowerer::new();
        lowerer.set_strata(vec![vec!["edge".to_string()], vec!["reach".to_string()]]);

        let plan = lowerer
            .lower_program(&program)
            .expect("lowering the program should succeed");
        assert!(!plan.sccs.is_empty());
    }

    #[test]
    fn test_variable_env() {
        let mut env = VariableEnv::new();
        env.add_occurrence("X", "edge".to_string(), 0, 0);
        env.add_occurrence("Y", "edge".to_string(), 1, 1);
        env.add_occurrence("Y", "node".to_string(), 0, 2);

        assert_eq!(env.get_column("X"), Some(0));
        // Y is recorded twice; the first-recorded column (1) is expected.
        assert_eq!(env.get_column("Y"), Some(1));
        assert_eq!(env.get_column("Z"), None);
    }

    #[test]
    fn test_infer_term_type() {
        assert_eq!(infer_term_type(&var("X")), ScalarType::U64);
        assert_eq!(infer_term_type(&Term::Integer(42)), ScalarType::U32);
        assert_eq!(infer_term_type(&Term::Integer(i64::MAX)), ScalarType::I64);
        assert_eq!(infer_term_type(&Term::Float(3.25)), ScalarType::F64);
        assert_eq!(
            infer_term_type(&Term::Symbol(symbol::intern("foo"))),
            ScalarType::Symbol
        );
    }

    #[test]
    fn test_convert_agg_op() {
        assert_eq!(convert_agg_op(&AggOp::Count), CoreAggOp::Count);
        assert_eq!(convert_agg_op(&AggOp::Sum), CoreAggOp::Sum);
        assert_eq!(convert_agg_op(&AggOp::Min), CoreAggOp::Min);
        assert_eq!(convert_agg_op(&AggOp::Max), CoreAggOp::Max);
        assert_eq!(convert_agg_op(&AggOp::LogSumExp), CoreAggOp::LogSumExp);
    }

    #[test]
    fn test_variable_env_bind_updates_total_cols() {
        let mut env = VariableEnv::new();
        // Pretend two columns are already present in the row.
        env.total_cols = 2;

        env.bind("A", 2, ScalarType::I64);
        assert_eq!(
            env.column_count(),
            3,
            "total_cols should be 3 after first bind"
        );
        assert_eq!(env.get_column("A"), Some(2));

        env.bind("B", 3, ScalarType::I64);
        assert_eq!(
            env.column_count(),
            4,
            "total_cols should be 4 after second bind"
        );
        assert_eq!(env.get_column("B"), Some(3));
    }

    #[test]
    fn test_lower_chained_is_expressions() {
        // result(A, B) :- input(X, Y), A is X + Y, B is A * 2.
        let rule = Rule {
            head: atom("result", vec![var("A"), var("B")]),
            body: vec![
                BodyLiteral::Positive(atom("input", vec![var("X"), var("Y")])),
                BodyLiteral::IsExpr(IsExpr {
                    target: "A".to_string(),
                    expr: ArithExpr::Add(
                        Box::new(ArithExpr::Variable("X".to_string())),
                        Box::new(ArithExpr::Variable("Y".to_string())),
                    ),
                }),
                BodyLiteral::IsExpr(IsExpr {
                    target: "B".to_string(),
                    expr: ArithExpr::Mul(
                        Box::new(ArithExpr::Variable("A".to_string())),
                        Box::new(ArithExpr::Integer(2)),
                    ),
                }),
            ],
        };

        let mut lowerer = Lowerer::new();
        lowerer.schemas.insert(
            "input".to_string(),
            schema_of(&[ScalarType::I64, ScalarType::I64]),
        );

        let node = lowerer
            .lower_rule(&rule)
            .expect("Lowering chained is-expressions should succeed");

        // Counts the chain of directly nested Project nodes from the root.
        fn count_projects(node: &RirNode) -> usize {
            match node {
                RirNode::Project { input, .. } => 1 + count_projects(input),
                _ => 0,
            }
        }

        let project_count = count_projects(&node);
        assert!(
            project_count >= 2,
            "Expected at least 2 Project nodes for chained is-exprs, got {}",
            project_count
        );

        match &node {
            RirNode::Project { columns, .. } => {
                assert_eq!(columns.len(), 2, "Head has 2 variables");
                // `input` supplies columns 0-1 (X, Y); A and B follow.
                assert_eq!(columns[0], ProjectExpr::Column(2), "A should be column 2");
                assert_eq!(columns[1], ProjectExpr::Column(3), "B should be column 3");
            }
            other => panic!("Expected top-level Project node, got {:?}", other),
        }
    }

    #[test]
    fn test_u64_comparison_type_from_pred_decl() {
        let mut program = Program::new();

        // Declared count_data(Symbol, U64) with the fact count_data(alice, 5).
        program.predicates.push(pred_decl(
            "count_data",
            vec![ScalarType::Symbol, ScalarType::U64],
        ));
        program.rules.push(Rule {
            head: atom(
                "count_data",
                vec![
                    Term::Symbol(xlog_core::symbol::intern("alice")),
                    Term::Integer(5),
                ],
            ),
            body: vec![],
        });

        // big_count(Name, Count) :- count_data(Name, Count), Count >= 3.
        program.predicates.push(pred_decl(
            "big_count",
            vec![ScalarType::Symbol, ScalarType::U64],
        ));
        program.rules.push(Rule {
            head: atom("big_count", vec![var("Name"), var("Count")]),
            body: vec![
                BodyLiteral::Positive(atom("count_data", vec![var("Name"), var("Count")])),
                BodyLiteral::Comparison(Comparison {
                    left: var("Count"),
                    op: CompOp::Ge,
                    right: Term::Integer(3),
                }),
            ],
        });

        let mut lowerer = Lowerer::new();
        lowerer
            .infer_schemas(&program)
            .expect("infer_schemas should succeed");

        let schema = lowerer
            .schemas
            .get("count_data")
            .expect("schema for count_data");
        assert_eq!(
            schema.column_type(0),
            Some(ScalarType::Symbol),
            "First column should be Symbol"
        );
        assert_eq!(
            schema.column_type(1),
            Some(ScalarType::U64),
            "Second column should be U64"
        );

        lowerer.set_strata(vec![
            vec!["count_data".to_string()],
            vec!["big_count".to_string()],
        ]);
        lowerer.build_sccs(&program);

        let node = lowerer
            .lower_rule(&program.rules[1])
            .expect("Lowering should succeed");

        // The literal `3` is expected to take the declared U64 type of the
        // column it is compared with (on its own, a literal of this size
        // infers as U32 — see test_infer_term_type).
        match find_compare_const(&node) {
            Some(ConstValue::U64(v)) => assert_eq!(*v, 3, "Value should be 3"),
            other => panic!("Expected U64(3) in comparison, got {:?}", other),
        }
    }

    #[test]
    fn test_u64_comparison_with_aggregation() {
        let mut program = Program::new();

        // reports_to(Symbol, Symbol): alice and carol both report to bob.
        program.predicates.push(pred_decl(
            "reports_to",
            vec![ScalarType::Symbol, ScalarType::Symbol],
        ));
        program.rules.push(Rule {
            head: atom(
                "reports_to",
                vec![
                    Term::Symbol(xlog_core::symbol::intern("alice")),
                    Term::Symbol(xlog_core::symbol::intern("bob")),
                ],
            ),
            body: vec![],
        });
        program.rules.push(Rule {
            head: atom(
                "reports_to",
                vec![
                    Term::Symbol(xlog_core::symbol::intern("carol")),
                    Term::Symbol(xlog_core::symbol::intern("bob")),
                ],
            ),
            body: vec![],
        });

        // direct_count(Mgr, count<Emp>) :- reports_to(Emp, Mgr).
        program.predicates.push(pred_decl(
            "direct_count",
            vec![ScalarType::Symbol, ScalarType::U64],
        ));
        program.rules.push(Rule {
            head: atom(
                "direct_count",
                vec![
                    var("Mgr"),
                    Term::Aggregate(AggExpr {
                        op: AggOp::Count,
                        variable: "Emp".to_string(),
                    }),
                ],
            ),
            body: vec![BodyLiteral::Positive(atom(
                "reports_to",
                vec![var("Emp"), var("Mgr")],
            ))],
        });

        // big_manager(Mgr, Count) :- direct_count(Mgr, Count), Count >= 2.
        program.predicates.push(pred_decl(
            "big_manager",
            vec![ScalarType::Symbol, ScalarType::U64],
        ));
        program.rules.push(Rule {
            head: atom("big_manager", vec![var("Mgr"), var("Count")]),
            body: vec![
                BodyLiteral::Positive(atom("direct_count", vec![var("Mgr"), var("Count")])),
                BodyLiteral::Comparison(Comparison {
                    left: var("Count"),
                    op: CompOp::Ge,
                    right: Term::Integer(2),
                }),
            ],
        });

        let mut lowerer = Lowerer::new();
        lowerer
            .infer_schemas(&program)
            .expect("infer_schemas should succeed");

        let schema = lowerer
            .schemas
            .get("direct_count")
            .expect("schema for direct_count");
        assert_eq!(
            schema.column_type(0),
            Some(ScalarType::Symbol),
            "First column should be Symbol"
        );
        assert_eq!(
            schema.column_type(1),
            Some(ScalarType::U64),
            "Second column should be U64"
        );

        lowerer.set_strata(vec![
            vec!["reports_to".to_string()],
            vec!["direct_count".to_string()],
            vec!["big_manager".to_string()],
        ]);
        lowerer.build_sccs(&program);

        // rules[3] is the big_manager rule (after two facts and direct_count).
        let node = lowerer
            .lower_rule(&program.rules[3])
            .expect("Lowering should succeed");

        match find_compare_const(&node) {
            Some(ConstValue::U64(v)) => assert_eq!(*v, 2, "Value should be 2"),
            other => panic!("Expected U64(2) in comparison, got {:?}", other),
        }
    }
}