1use std::collections::{BTreeMap, HashMap};
4use std::hash::{Hash, Hasher};
5
6use xlog_core::{Result, XlogError};
7use xlog_logic::ast::{
8 AggExpr, AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Evidence, ProbQuery, Program, Rule, Term,
9};
10use xlog_logic::stratify::{
11 analyze_stratification, build_dependency_graph, find_sccs_for_lowering, stratify,
12};
13
14use crate::wfs::{evaluate_wfs_rules, WfsAtom, WfsConfig, WfsLiteral, WfsRule};
15
16use crate::aggregates::{AggState, AggStateKey};
17use crate::pir::{ChoiceVarId, LeafId, PirGraph, PirNodeId};
18
/// A ground constant that can appear as an argument of a [`GroundAtom`].
///
/// The derived ordering follows the variant declaration order, then the payload.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Value {
    /// Signed 64-bit integer constant.
    I64(i64),
    /// Floating-point constant kept as a raw `u64` (`value_from_term` stores
    /// `f64::to_bits`), which is what allows the `Eq`/`Ord`/`Hash` derives.
    F64(u64),
    /// Symbol constant identified by a `u32` id.
    Symbol(u32),
    /// Owned string constant.
    String(String),
}

impl From<i64> for Value {
    /// Wraps an integer as [`Value::I64`].
    fn from(int: i64) -> Self {
        Value::I64(int)
    }
}

impl From<u32> for Value {
    /// Wraps a symbol id as [`Value::Symbol`].
    fn from(symbol: u32) -> Self {
        Value::Symbol(symbol)
    }
}

impl From<String> for Value {
    /// Wraps an owned string as [`Value::String`].
    fn from(text: String) -> Self {
        Value::String(text)
    }
}
44
45#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
46pub struct GroundAtom {
47 pub predicate: String,
48 pub args: Vec<Value>,
49}
50
51impl GroundAtom {
52 pub fn new(predicate: impl Into<String>, args: Vec<Value>) -> Self {
53 Self {
54 predicate: predicate.into(),
55 args,
56 }
57 }
58}
59
/// Describes which annotated disjunction a choice variable was created for.
///
/// One record is stored per choice variable by `compile_annotated_disjunction`.
#[derive(Debug, Clone, PartialEq)]
pub struct ChoiceSource {
    /// The explicitly listed alternatives of the disjunction together with their
    /// declared probabilities (the implicit "none" outcome is not included).
    pub choices: Vec<(GroundAtom, f64)>,
    /// Index of the outcome this variable decides; `compile_annotated_disjunction`
    /// stores the variable's position in the decision chain here.
    pub choice_index: usize,
    /// Always `None` in the code visible in this file.
    /// NOTE(review): presumably identifies an external source when set — confirm.
    pub source_id: Option<usize>,
}
71
/// Outcome recorded for one aggregate-lifting attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateLiftStatus {
    Fired,
    FallbackExactEnumeration,
    Declined,
}

impl AggregateLiftStatus {
    /// Returns the stable snake_case label of the status.
    pub fn as_str(self) -> &'static str {
        use AggregateLiftStatus::{Declined, FallbackExactEnumeration, Fired};

        match self {
            Declined => "declined",
            FallbackExactEnumeration => "fallback_exact_enumeration",
            Fired => "fired",
        }
    }
}
88
/// Diagnostic record about one aggregate-lifting decision.
///
/// Reports are accumulated in a `Vec` during rule evaluation and returned in
/// `Provenance::aggregate_lifting`. The code that fills the fields is not part
/// of this view, so the per-field notes below are assumptions to be confirmed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregateLiftReport {
    /// NOTE(review): presumably the head predicate of the aggregate rule — confirm.
    pub predicate: String,
    /// NOTE(review): presumably the group-by values of the reported group — confirm.
    pub group_key: Vec<Value>,
    /// NOTE(review): presumably the aggregate operator name — confirm.
    pub operator: String,
    /// NOTE(review): presumably describes where the finite value domain came from — confirm.
    pub finite_domain_source: String,
    /// NOTE(review): presumably the number of rows with certain provenance — confirm.
    pub deterministic_rows: usize,
    /// NOTE(review): presumably the number of rows with uncertain provenance — confirm.
    pub uncertain_rows: usize,
    /// NOTE(review): presumably the size of the finite domain considered — confirm.
    pub domain_size: usize,
    /// NOTE(review): presumably the limit the lifting decision was checked against — confirm.
    pub cap: usize,
    /// Whether lifting fired, fell back to exact enumeration, or was declined.
    pub status: AggregateLiftStatus,
    /// Free-form explanation accompanying `status`.
    pub reason: String,
    /// NOTE(review): presumably the outcome count of naive enumeration — confirm.
    pub naive_outcomes: u128,
    /// NOTE(review): presumably the state count of the dynamic-programming lift — confirm.
    pub dynamic_programming_states: usize,
}
104
105#[derive(Debug, Clone)]
106struct Relation {
107 tuples: BTreeMap<Vec<Value>, PirNodeId>,
108}
109
110impl Relation {
111 fn new() -> Self {
112 Self {
113 tuples: BTreeMap::new(),
114 }
115 }
116
117 fn get(&self, tuple: &[Value]) -> Option<PirNodeId> {
118 self.tuples.get(tuple).copied()
119 }
120
121 fn is_empty(&self) -> bool {
122 self.tuples.is_empty()
123 }
124
125 fn insert_or(&mut self, tuple: Vec<Value>, formula: PirNodeId, builder: &mut PirBuilder) {
126 let entry = self
127 .tuples
128 .entry(tuple)
129 .or_insert_with(|| builder.const_false());
130 *entry = builder.or(vec![*entry, formula]);
131 }
132}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135enum PirKey {
136 Const(bool),
137 Lit(LeafId),
138 NegLit(LeafId),
139 And(Vec<PirNodeId>),
140 Or(Vec<PirNodeId>),
141 Decision {
142 var: ChoiceVarId,
143 child_false: PirNodeId,
144 child_true: PirNodeId,
145 },
146}
147
148impl Hash for PirKey {
149 fn hash<H: Hasher>(&self, state: &mut H) {
150 match self {
151 PirKey::Const(b) => {
152 0u8.hash(state);
153 b.hash(state);
154 }
155 PirKey::Lit(l) => {
156 1u8.hash(state);
157 l.hash(state);
158 }
159 PirKey::NegLit(l) => {
160 5u8.hash(state);
161 l.hash(state);
162 }
163 PirKey::And(children) => {
164 2u8.hash(state);
165 children.hash(state);
166 }
167 PirKey::Or(children) => {
168 3u8.hash(state);
169 children.hash(state);
170 }
171 PirKey::Decision {
172 var,
173 child_false,
174 child_true,
175 } => {
176 4u8.hash(state);
177 var.hash(state);
178 child_false.hash(state);
179 child_true.hash(state);
180 }
181 }
182 }
183}
184
185#[derive(Debug)]
186struct PirBuilder {
187 pir: PirGraph,
188 intern: HashMap<PirKey, PirNodeId>,
189 const_true: PirNodeId,
190 const_false: PirNodeId,
191 or_children: HashMap<PirNodeId, Vec<PirNodeId>>,
194 and_children: HashMap<PirNodeId, Vec<PirNodeId>>,
196}
197
198impl PirBuilder {
199 fn new() -> Self {
200 let mut pir = PirGraph::new();
201 let const_true = pir.const_true();
202 let const_false = pir.const_false();
203
204 let mut intern = HashMap::new();
205 intern.insert(PirKey::Const(true), const_true);
206 intern.insert(PirKey::Const(false), const_false);
207
208 Self {
209 pir,
210 intern,
211 const_true,
212 const_false,
213 or_children: HashMap::new(),
214 and_children: HashMap::new(),
215 }
216 }
217
218 fn finish(self) -> PirGraph {
219 self.pir
220 }
221
222 fn const_true(&self) -> PirNodeId {
223 self.const_true
224 }
225
226 fn const_false(&self) -> PirNodeId {
227 self.const_false
228 }
229
230 fn lit(&mut self, leaf: LeafId) -> PirNodeId {
231 let key = PirKey::Lit(leaf);
232 if let Some(&id) = self.intern.get(&key) {
233 return id;
234 }
235 let id = self.pir.lit(leaf);
236 self.intern.insert(key, id);
237 id
238 }
239
240 fn neg_lit(&mut self, leaf: LeafId) -> PirNodeId {
241 let key = PirKey::NegLit(leaf);
242 if let Some(&id) = self.intern.get(&key) {
243 return id;
244 }
245 let id = self.pir.neg_lit(leaf);
246 self.intern.insert(key, id);
247 id
248 }
249
250 fn and(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
251 let mut flat: Vec<PirNodeId> = Vec::with_capacity(children.len());
254 for c in children {
255 match self.and_children.get(&c) {
256 Some(sub) => flat.extend_from_slice(sub),
257 None => flat.push(c),
258 }
259 }
260 let mut children = flat;
261 children.retain(|&c| c != self.const_true);
262 if children.contains(&self.const_false) {
263 return self.const_false;
264 }
265 if children.is_empty() {
266 return self.const_true;
267 }
268 if children.len() == 1 {
269 return children[0];
270 }
271 children.sort_by_key(|id| id.as_u32());
272 children.dedup();
273 if children.len() > 1 {
275 let members = children.clone();
276 children.retain(|c| match self.or_children.get(c) {
277 Some(sub) => !sub.iter().any(|s| {
278 s != c
279 && members
280 .binary_search_by_key(&s.as_u32(), |m| m.as_u32())
281 .is_ok()
282 }),
283 None => true,
284 });
285 }
286 if children.len() == 1 {
287 return children[0];
288 }
289 let key = PirKey::And(children.clone());
290 if let Some(&id) = self.intern.get(&key) {
291 return id;
292 }
293 let id = self.pir.and(children.clone());
294 self.intern.insert(key, id);
295 self.and_children.insert(id, children);
296 id
297 }
298
299 fn or(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
300 let mut flat: Vec<PirNodeId> = Vec::with_capacity(children.len());
302 for c in children {
303 match self.or_children.get(&c) {
304 Some(sub) => flat.extend_from_slice(sub),
305 None => flat.push(c),
306 }
307 }
308 let mut children = flat;
309 children.retain(|&c| c != self.const_false);
310 if children.contains(&self.const_true) {
311 return self.const_true;
312 }
313 if children.is_empty() {
314 return self.const_false;
315 }
316 if children.len() == 1 {
317 return children[0];
318 }
319 children.sort_by_key(|id| id.as_u32());
320 children.dedup();
321 if children.len() > 1 {
323 let members = children.clone();
324 children.retain(|c| match self.and_children.get(c) {
325 Some(sub) => !sub.iter().any(|s| {
326 s != c
327 && members
328 .binary_search_by_key(&s.as_u32(), |m| m.as_u32())
329 .is_ok()
330 }),
331 None => true,
332 });
333 }
334 if children.len() == 1 {
335 return children[0];
336 }
337 let key = PirKey::Or(children.clone());
338 if let Some(&id) = self.intern.get(&key) {
339 return id;
340 }
341 let id = self.pir.or(children.clone());
342 self.intern.insert(key, id);
343 self.or_children.insert(id, children);
344 id
345 }
346
347 fn decision(
348 &mut self,
349 var: ChoiceVarId,
350 child_false: PirNodeId,
351 child_true: PirNodeId,
352 ) -> PirNodeId {
353 if child_false == child_true {
354 return child_true;
355 }
356 let key = PirKey::Decision {
357 var,
358 child_false,
359 child_true,
360 };
361 if let Some(&id) = self.intern.get(&key) {
362 return id;
363 }
364 let id = self.pir.decision(var, child_false, child_true);
365 self.intern.insert(key, id);
366 id
367 }
368
369 fn choice_lit(&mut self, var: ChoiceVarId, is_true: bool) -> PirNodeId {
370 if is_true {
371 self.decision(var, self.const_false(), self.const_true())
372 } else {
373 self.decision(var, self.const_true(), self.const_false())
374 }
375 }
376}
377
378#[derive(Debug)]
380pub struct Provenance {
381 pub pir: PirGraph,
382 pub leaf_probs: BTreeMap<LeafId, f64>,
383 pub choice_probs: BTreeMap<ChoiceVarId, (f64, f64)>,
384 tuple_formulas: BTreeMap<GroundAtom, PirNodeId>,
385 pub queries: Vec<GroundAtom>,
386 pub evidence: Vec<(GroundAtom, bool)>,
387 pub leaf_atoms: BTreeMap<LeafId, GroundAtom>,
388 pub choice_sources: BTreeMap<ChoiceVarId, ChoiceSource>,
389 pub aggregate_lifting: Vec<AggregateLiftReport>,
390}
391
392impl Provenance {
393 pub fn query_formula(&self, predicate: &str, args: &[Value]) -> Option<PirNodeId> {
394 self.tuple_formulas
395 .get(&GroundAtom::new(predicate, args.to_vec()))
396 .copied()
397 }
398
399 pub fn leaf_atom(&self, leaf: LeafId) -> Option<&GroundAtom> {
400 self.leaf_atoms.get(&leaf)
401 }
402
403 pub fn choice_source(&self, var: ChoiceVarId) -> Option<&ChoiceSource> {
404 self.choice_sources.get(&var)
405 }
406
407 pub fn atoms_with_formulas(&self) -> impl Iterator<Item = (&GroundAtom, PirNodeId)> + '_ {
408 self.tuple_formulas.iter().map(|(atom, &id)| (atom, id))
409 }
410}
411
412pub fn extract_from_source(source: &str) -> Result<Provenance> {
413 let program = xlog_logic::parse_program(source)?;
414 extract_from_program(&program)
415}
416
417pub fn extract_from_program(program: &Program) -> Result<Provenance> {
418 let _ = stratify(program)?;
420
421 let mut builder = PirBuilder::new();
422
423 let mut leaf_probs: BTreeMap<LeafId, f64> = BTreeMap::new();
424 let mut choice_probs: BTreeMap<ChoiceVarId, (f64, f64)> = BTreeMap::new();
425 let mut leaf_atoms: BTreeMap<LeafId, GroundAtom> = BTreeMap::new();
426 let mut choice_sources: BTreeMap<ChoiceVarId, ChoiceSource> = BTreeMap::new();
427 let mut aggregate_lifting: Vec<AggregateLiftReport> = Vec::new();
428
429 let mut store: BTreeMap<String, Relation> = BTreeMap::new();
430
431 for fact in program.facts() {
433 let key = atom_key_from_ground_atom(&fact.head)?;
434 let rel = store
435 .entry(key.predicate.clone())
436 .or_insert_with(Relation::new);
437 rel.insert_or(key.args.clone(), builder.const_true(), &mut builder);
438 }
439
440 let mut next_leaf: u32 = 0;
442 for pf in &program.prob_facts {
443 validate_prob(pf.prob, "probabilistic fact")?;
444 let key = atom_key_from_ground_atom(&pf.atom)?;
445 let leaf = LeafId::new(next_leaf);
446 next_leaf = next_leaf.checked_add(1).ok_or_else(|| {
447 XlogError::Compilation("probabilistic fact leaf id overflow".to_string())
448 })?;
449 leaf_probs.insert(leaf, pf.prob);
450 leaf_atoms.insert(leaf, key.clone());
451
452 let rel = store
453 .entry(key.predicate.clone())
454 .or_insert_with(Relation::new);
455 rel.insert_or(key.args.clone(), builder.lit(leaf), &mut builder);
456 }
457
458 let mut next_choice: u32 = 0;
460 for ad in &program.annotated_disjunctions {
461 if ad.choices.is_empty() {
462 return Err(XlogError::Compilation(
463 "Annotated disjunction must contain at least one choice".to_string(),
464 ));
465 }
466 let (vars, outcome_formulas) = compile_annotated_disjunction(
467 ad,
468 &mut next_choice,
469 &mut choice_probs,
470 &mut choice_sources,
471 &mut builder,
472 )?;
473 let _ = vars;
474
475 for (pf, formula) in ad.choices.iter().zip(outcome_formulas) {
476 let key = atom_key_from_ground_atom(&pf.atom)?;
477 let rel = store
478 .entry(key.predicate.clone())
479 .or_insert_with(Relation::new);
480 rel.insert_or(key.args.clone(), formula, &mut builder);
481 }
482 }
483
484 let graph = build_dependency_graph(program);
486 for pred in &graph.predicates {
487 store.entry(pred.clone()).or_insert_with(Relation::new);
488 }
489
490 let strat_result = analyze_stratification(program);
492 let sccs = find_sccs_for_lowering(&graph);
493
494 let non_monotone_scc_preds: std::collections::HashSet<String> = strat_result
498 .sccs
499 .iter()
500 .enumerate()
501 .filter(|(i, _)| strat_result.non_monotone_sccs.contains(i))
502 .flat_map(|(_, scc)| scc.iter().cloned())
503 .collect();
504
505 let mut rules_by_head: BTreeMap<String, Vec<Rule>> = BTreeMap::new();
506 for rule in program.proper_rules() {
507 rules_by_head
509 .entry(rule.head.predicate.clone())
510 .or_default()
511 .push(rule.clone());
512 }
513
514 for scc in sccs {
515 let mut scc_rules: Vec<Rule> = Vec::new();
516 for pred in &scc {
517 if let Some(rules) = rules_by_head.get(pred) {
518 scc_rules.extend(rules.iter().cloned());
519 }
520 }
521 if scc_rules.is_empty() {
522 continue;
523 }
524
525 let is_non_monotone = scc.iter().any(|p| non_monotone_scc_preds.contains(p));
527
528 if is_non_monotone {
529 eval_non_monotone_scc_with_wfs(&scc, &scc_rules, &mut store, &mut builder)?;
531 } else {
532 let recursive = is_recursive_scc(&scc, &scc_rules);
533 if recursive {
534 eval_recursive_scc(
535 &scc,
536 &scc_rules,
537 &mut store,
538 &mut builder,
539 &mut aggregate_lifting,
540 )?;
541 } else {
542 eval_non_recursive_scc(
543 &scc_rules,
544 &mut store,
545 &mut builder,
546 &mut aggregate_lifting,
547 )?;
548 }
549 }
550 }
551
552 let mut tuple_formulas: BTreeMap<GroundAtom, PirNodeId> = BTreeMap::new();
554 for (pred, rel) in &store {
555 for (tuple, formula) in &rel.tuples {
556 tuple_formulas.insert(GroundAtom::new(pred.clone(), tuple.clone()), *formula);
557 }
558 }
559
560 let mut queries: Vec<GroundAtom> = Vec::new();
561 for ProbQuery { atom } in &program.prob_queries {
562 queries.push(atom_key_from_ground_atom(atom)?);
563 }
564
565 let mut evidence: Vec<(GroundAtom, bool)> = Vec::new();
566 for Evidence { atom, value } in &program.evidence {
567 evidence.push((atom_key_from_ground_atom(atom)?, *value));
568 }
569
570 Ok(Provenance {
571 pir: builder.finish(),
572 leaf_probs,
573 choice_probs,
574 tuple_formulas,
575 queries,
576 evidence,
577 leaf_atoms,
578 choice_sources,
579 aggregate_lifting,
580 })
581}
582
583pub(crate) fn validate_prob(p: f64, what: &str) -> Result<()> {
584 if !(0.0..=1.0).contains(&p) || p.is_nan() {
585 return Err(XlogError::Compilation(format!(
586 "Invalid probability {} for {} (expected 0<=p<=1)",
587 p, what
588 )));
589 }
590 Ok(())
591}
592
593pub(crate) fn atom_key_from_ground_atom(atom: &Atom) -> Result<GroundAtom> {
594 let mut args = Vec::with_capacity(atom.terms.len());
595 for term in &atom.terms {
596 if !term.is_constant() {
597 return Err(XlogError::Compilation(format!(
598 "Expected ground atom, found non-constant term in {}",
599 atom.predicate
600 )));
601 }
602 args.push(value_from_term(term)?);
603 }
604 Ok(GroundAtom::new(atom.predicate.clone(), args))
605}
606
607pub(crate) fn value_from_term(term: &Term) -> Result<Value> {
608 match term {
609 Term::Integer(i) => Ok(Value::I64(*i)),
610 Term::Float(f) => Ok(Value::F64(f.to_bits())),
611 Term::String(s) => Ok(Value::String(s.clone())),
612 Term::Symbol(id) => Ok(Value::Symbol(*id)),
613 Term::Variable(_) | Term::Anonymous | Term::Aggregate(_) => Err(XlogError::Compilation(
614 "Non-constant term cannot be converted to a value".to_string(),
615 )),
616 Term::List(_) => Err(unsupported_probabilistic_term_error(
617 "value conversion",
618 "list",
619 )),
620 Term::Cons { .. } => Err(unsupported_probabilistic_term_error(
621 "value conversion",
622 "cons",
623 )),
624 Term::Compound { .. } => Err(unsupported_probabilistic_term_error(
625 "value conversion",
626 "compound",
627 )),
628 Term::PredRef(_) => Err(unsupported_probabilistic_term_error(
629 "value conversion",
630 "predref",
631 )),
632 }
633}
634
635fn unsupported_probabilistic_term_error(context: &str, kind: &str) -> XlogError {
636 XlogError::Compilation(format!(
637 "high-level term form '{}' is parsed but not supported in probabilistic provenance {} until a lowering/materialization path exists",
638 kind, context
639 ))
640}
641
642fn compile_annotated_disjunction(
643 ad: &xlog_logic::ast::AnnotatedDisjunction,
644 next_choice: &mut u32,
645 choice_probs: &mut BTreeMap<ChoiceVarId, (f64, f64)>,
646 choice_sources: &mut BTreeMap<ChoiceVarId, ChoiceSource>,
647 builder: &mut PirBuilder,
648) -> Result<(Vec<ChoiceVarId>, Vec<PirNodeId>)> {
649 for pf in &ad.choices {
650 validate_prob(pf.prob, "annotated disjunction choice")?;
651 let _ = atom_key_from_ground_atom(&pf.atom)?;
652 }
653
654 let explicit_choices: Vec<(GroundAtom, f64)> = ad
655 .choices
656 .iter()
657 .map(|pf| {
658 let atom = atom_key_from_ground_atom(&pf.atom).unwrap();
659 (atom, pf.prob)
660 })
661 .collect();
662
663 let mut probs: Vec<f64> = ad.choices.iter().map(|pf| pf.prob).collect();
664 let sum: f64 = probs.iter().copied().sum();
665 let eps = 1e-12;
666 if sum > 1.0 + eps {
667 return Err(XlogError::Compilation(format!(
668 "Annotated disjunction probabilities sum to {} (> 1.0)",
669 sum
670 )));
671 }
672
673 let mut has_none = false;
674 let none_prob = (1.0 - sum).max(0.0);
675 if none_prob > eps {
676 probs.push(none_prob);
677 has_none = true;
678 }
679
680 let m = probs.len();
681 if m == 1 {
682 return Ok((Vec::new(), vec![builder.const_true()]));
683 }
684
685 let mut vars: Vec<ChoiceVarId> = Vec::with_capacity(m.saturating_sub(1));
686 let mut remaining = 1.0f64;
687 for (i, &p_i) in probs.iter().enumerate().take(m - 1) {
688 let cond_true = if remaining <= 0.0 {
689 0.0
690 } else {
691 p_i / remaining
692 };
693 validate_prob(cond_true, "annotated disjunction conditional")?;
694 let cond_false = 1.0 - cond_true;
695 let var = ChoiceVarId::new(*next_choice);
696 *next_choice = (*next_choice).checked_add(1).ok_or_else(|| {
697 XlogError::Compilation("annotated disjunction choice id overflow".to_string())
698 })?;
699 vars.push(var);
700 choice_probs.insert(var, (cond_true, cond_false));
701 choice_sources.insert(
702 var,
703 ChoiceSource {
704 choices: explicit_choices.clone(),
705 choice_index: i,
706 source_id: None,
707 },
708 );
709 remaining -= p_i;
710 }
711
712 let mut outcome_formulas: Vec<PirNodeId> = Vec::new();
713 for i in 0..ad.choices.len() {
714 let mut conds: Vec<PirNodeId> = Vec::new();
715 for (j, &var) in vars.iter().enumerate() {
716 if j < i {
717 conds.push(builder.choice_lit(var, false));
718 } else if j == i {
719 conds.push(builder.choice_lit(var, true));
720 break;
721 }
722 }
723 outcome_formulas.push(builder.and(conds));
724 }
725
726 if has_none {
727 }
730
731 Ok((vars, outcome_formulas))
732}
733
734fn is_recursive_scc(scc: &[String], rules: &[Rule]) -> bool {
735 if scc.len() > 1 {
736 return true;
737 }
738 let Some(only) = scc.first() else {
739 return false;
740 };
741 for rule in rules {
742 for lit in &rule.body {
743 if let BodyLiteral::Positive(atom) = lit {
744 if &atom.predicate == only {
745 return true;
746 }
747 }
748 }
749 }
750 false
751}
752
753fn eval_non_recursive_scc(
754 rules: &[Rule],
755 store: &mut BTreeMap<String, Relation>,
756 builder: &mut PirBuilder,
757 aggregate_lifting: &mut Vec<AggregateLiftReport>,
758) -> Result<()> {
759 for rule in rules {
760 let derived = eval_rule(
761 rule,
762 store,
763 &BTreeMap::new(),
764 None,
765 builder,
766 aggregate_lifting,
767 )?;
768 let rel = store
769 .entry(rule.head.predicate.clone())
770 .or_insert_with(Relation::new);
771 for (tuple, formula) in derived {
772 rel.insert_or(tuple, formula, builder);
773 }
774 }
775 Ok(())
776}
777
778const MAX_PROVENANCE_ITERATIONS: usize = 1024;
779
780fn eval_recursive_scc(
781 scc: &[String],
782 rules: &[Rule],
783 store: &mut BTreeMap<String, Relation>,
784 builder: &mut PirBuilder,
785 aggregate_lifting: &mut Vec<AggregateLiftReport>,
786) -> Result<()> {
787 let scc_set: std::collections::HashSet<&str> = scc.iter().map(|s| s.as_str()).collect();
788
789 let mut full: BTreeMap<String, Relation> = BTreeMap::new();
791 for pred in scc {
792 let rel = store.get(pred).cloned().unwrap_or_else(Relation::new);
793 full.insert(pred.clone(), rel);
794 }
795
796 let mut delta: BTreeMap<String, Relation> = BTreeMap::new();
798 for rule in rules {
799 let derived = eval_rule(rule, store, &full, None, builder, aggregate_lifting)?;
800 if derived.is_empty() {
801 continue;
802 }
803 let head = rule.head.predicate.clone();
804 let delta_rel = delta.entry(head.clone()).or_insert_with(Relation::new);
805 let full_rel = full.entry(head).or_insert_with(Relation::new);
806 for (tuple, proof) in derived {
807 let old = full_rel.get(&tuple).unwrap_or(builder.const_false());
808 let combined = builder.or(vec![old, proof]);
809 if combined != old {
810 full_rel.tuples.insert(tuple.clone(), combined);
811 delta_rel.insert_or(tuple, proof, builder);
812 }
813 }
814 }
815
816 let mut reached_fixpoint = false;
817 for _ in 0..MAX_PROVENANCE_ITERATIONS {
818 let any_delta = delta.values().any(|r| !r.is_empty());
819 if !any_delta {
820 reached_fixpoint = true;
821 break;
822 }
823
824 let full_prev = full.clone();
825 let delta_prev = delta.clone();
826 delta.clear();
827
828 for rule in rules {
829 let body_indices: Vec<usize> = rule
830 .body
831 .iter()
832 .enumerate()
833 .filter_map(|(i, lit)| match lit {
834 BodyLiteral::Positive(atom) if scc_set.contains(atom.predicate.as_str()) => {
835 let pred = &atom.predicate;
836 let non_empty =
837 delta_prev.get(pred).map(|r| !r.is_empty()).unwrap_or(false);
838 non_empty.then_some(i)
839 }
840 _ => None,
841 })
842 .collect();
843 if body_indices.is_empty() {
844 continue;
845 }
846
847 let mut derived_all: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
848 for idx in body_indices {
849 let derived = eval_rule(
850 rule,
851 store,
852 &full_prev,
853 Some((idx, &delta_prev)),
854 builder,
855 aggregate_lifting,
856 )?;
857 for (tuple, proof) in derived {
858 let entry = derived_all
859 .entry(tuple)
860 .or_insert_with(|| builder.const_false());
861 *entry = builder.or(vec![*entry, proof]);
862 }
863 }
864
865 if derived_all.is_empty() {
866 continue;
867 }
868
869 let head = rule.head.predicate.clone();
870 let delta_rel = delta.entry(head.clone()).or_insert_with(Relation::new);
871 let full_rel = full.entry(head).or_insert_with(Relation::new);
872 for (tuple, proof) in derived_all {
873 let old = full_rel.get(&tuple).unwrap_or(builder.const_false());
874 let combined = builder.or(vec![old, proof]);
875 if combined != old {
876 full_rel.tuples.insert(tuple.clone(), combined);
877 delta_rel.insert_or(tuple, proof, builder);
878 }
879 }
880 }
881 }
882 if !reached_fixpoint {
883 return Err(XlogError::Compilation(format!(
884 "Provenance iteration limit ({}) exceeded for SCC {:?}",
885 MAX_PROVENANCE_ITERATIONS, scc
886 )));
887 }
888
889 for (pred, rel) in full {
891 store.insert(pred, rel);
892 }
893
894 Ok(())
895}
896
897fn eval_non_monotone_scc_with_wfs(
907 scc: &[String],
908 rules: &[Rule],
909 store: &mut BTreeMap<String, Relation>,
910 builder: &mut PirBuilder,
911) -> Result<()> {
912 let scc_set: std::collections::HashSet<&str> = scc.iter().map(|s| s.as_str()).collect();
913
914 let mut wfs_rules: Vec<WfsRule> = Vec::new();
917
918 for rule in rules {
919 let grounded = ground_rule_for_wfs(rule, store, &scc_set, builder)?;
921 wfs_rules.extend(grounded);
922 }
923
924 if wfs_rules.is_empty() {
925 return Ok(());
927 }
928
929 let wfs_result = evaluate_wfs_rules(&wfs_rules, &mut builder.pir, &WfsConfig::default())?;
931
932 for (wfs_atom, prov) in wfs_result.true_set {
935 let rel = store
936 .entry(wfs_atom.predicate.clone())
937 .or_insert_with(Relation::new);
938 rel.insert_or(wfs_atom.args, prov, builder);
939 }
940
941 Ok(())
942}
943
944fn ground_rule_for_wfs(
949 rule: &Rule,
950 store: &BTreeMap<String, Relation>,
951 scc_set: &std::collections::HashSet<&str>,
952 builder: &mut PirBuilder,
953) -> Result<Vec<WfsRule>> {
954 let mut bindings: Vec<(HashMap<String, Value>, PirNodeId)> =
956 vec![(HashMap::new(), builder.const_true())];
957
958 let mut wfs_body_template: Vec<(usize, bool)> = Vec::new(); for (idx, lit) in rule.body.iter().enumerate() {
963 match lit {
964 BodyLiteral::Positive(atom) => {
965 if scc_set.contains(atom.predicate.as_str()) {
966 wfs_body_template.push((idx, true));
968 } else {
969 let rel = store.get(&atom.predicate);
971 let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
972
973 for (binding, prov) in bindings {
974 if let Some(rel) = rel {
975 for (tuple, tuple_prov) in &rel.tuples {
976 let mut new_binding = binding.clone();
977 if unify_atom(atom, tuple, &mut new_binding)? {
978 let new_prov = builder.and(vec![prov, *tuple_prov]);
979 next_bindings.push((new_binding, new_prov));
980 }
981 }
982 }
983 }
985 bindings = next_bindings;
986 if bindings.is_empty() {
987 return Ok(Vec::new());
988 }
989 }
990 }
991 BodyLiteral::Negated(atom) => {
992 if scc_set.contains(atom.predicate.as_str()) {
993 wfs_body_template.push((idx, false));
995 } else {
996 let rel = store.get(&atom.predicate);
998 let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
999
1000 for (binding, prov) in bindings {
1001 let all_bound = atom.terms.iter().all(|t| match t {
1003 Term::Variable(v) => binding.contains_key(v),
1004 _ => true,
1005 });
1006
1007 if !all_bound {
1008 continue;
1010 }
1011
1012 if let Some(rel) = rel {
1013 let mut matching_provs: Vec<PirNodeId> = Vec::new();
1015 for (tuple, tuple_prov) in &rel.tuples {
1016 let mut test_binding = binding.clone();
1017 if unify_atom(atom, tuple, &mut test_binding)? {
1018 matching_provs.push(*tuple_prov);
1019 }
1020 }
1021
1022 if matching_provs.is_empty() {
1023 next_bindings.push((binding, prov));
1025 } else {
1026 let combined = builder.or(matching_provs);
1028 let neg_prov = negate_provenance(combined, builder);
1029 let new_prov = builder.and(vec![prov, neg_prov]);
1030 next_bindings.push((binding, new_prov));
1031 }
1032 } else {
1033 next_bindings.push((binding, prov));
1035 }
1036 }
1037 bindings = next_bindings;
1038 if bindings.is_empty() {
1039 return Ok(Vec::new());
1040 }
1041 }
1042 }
1043 BodyLiteral::Epistemic(lit) => {
1044 return Err(XlogError::UnsupportedEpistemicConstruct {
1045 construct: "probabilistic WFS grounding".to_string(),
1046 context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
1047 });
1048 }
1049 BodyLiteral::Comparison(cmp) => {
1050 let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
1051 for (binding, prov) in bindings {
1052 if eval_comparison(cmp.op, &cmp.left, &cmp.right, &binding)? {
1053 next_bindings.push((binding, prov));
1054 }
1055 }
1056 bindings = next_bindings;
1057 if bindings.is_empty() {
1058 return Ok(Vec::new());
1059 }
1060 }
1061 BodyLiteral::IsExpr(is_expr) => {
1062 let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
1063 for (mut binding, prov) in bindings {
1064 if binding.contains_key(&is_expr.target) {
1065 return Err(XlogError::Compilation(format!(
1066 "Is-expression target {} is already bound",
1067 is_expr.target
1068 )));
1069 }
1070 let v = eval_arith_expr(&is_expr.expr, &binding)?;
1071 binding.insert(is_expr.target.clone(), v);
1072 next_bindings.push((binding, prov));
1073 }
1074 bindings = next_bindings;
1075 if bindings.is_empty() {
1076 return Ok(Vec::new());
1077 }
1078 }
1079 BodyLiteral::Univ(_) => {
1080 return Err(XlogError::Compilation(
1081 "univ literal was not normalized before provenance extraction".to_string(),
1082 ));
1083 }
1084 }
1085 }
1086
1087 let mut result: Vec<WfsRule> = Vec::new();
1089
1090 for (binding, external_prov) in bindings {
1091 let mut wfs_body: Vec<WfsLiteral> = Vec::new();
1093
1094 for &(idx, is_positive) in &wfs_body_template {
1095 let atom = match &rule.body[idx] {
1096 BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a,
1097 _ => continue,
1098 };
1099
1100 let mut args: Vec<Value> = Vec::new();
1102 for term in &atom.terms {
1103 match term {
1104 Term::Variable(name) => {
1105 if let Some(v) = binding.get(name) {
1106 args.push(v.clone());
1107 } else {
1108 continue;
1111 }
1112 }
1113 _ => {
1114 args.push(value_from_term(term)?);
1115 }
1116 }
1117 }
1118
1119 let wfs_atom = WfsAtom::new(atom.predicate.clone(), args);
1120 if is_positive {
1121 wfs_body.push(WfsLiteral::Positive(wfs_atom));
1122 } else {
1123 wfs_body.push(WfsLiteral::Negative(wfs_atom));
1124 }
1125 }
1126
1127 let mut head_args: Vec<Value> = Vec::new();
1129 for term in &rule.head.terms {
1130 match term {
1131 Term::Variable(name) => {
1132 if let Some(v) = binding.get(name) {
1133 head_args.push(v.clone());
1134 } else {
1135 continue;
1137 }
1138 }
1139 _ => {
1140 head_args.push(value_from_term(term)?);
1141 }
1142 }
1143 }
1144
1145 let wfs_head = WfsAtom::new(rule.head.predicate.clone(), head_args);
1146 result.push(WfsRule::new(wfs_head, wfs_body, external_prov));
1147 }
1148
1149 Ok(result)
1150}
1151
1152fn negate_provenance(prov: PirNodeId, builder: &mut PirBuilder) -> PirNodeId {
1160 use crate::pir::PirNode;
1161 match builder.pir.node(prov).cloned() {
1162 Some(PirNode::Const(b)) => {
1163 if b {
1164 builder.const_false()
1165 } else {
1166 builder.const_true()
1167 }
1168 }
1169 Some(PirNode::Lit { leaf }) => builder.neg_lit(leaf),
1170 Some(PirNode::NegLit { leaf }) => builder.lit(leaf), Some(PirNode::And { children }) => {
1172 let neg_children: Vec<PirNodeId> = children
1174 .iter()
1175 .map(|&c| negate_provenance(c, builder))
1176 .collect();
1177 builder.or(neg_children)
1178 }
1179 Some(PirNode::Or { children }) => {
1180 let neg_children: Vec<PirNodeId> = children
1182 .iter()
1183 .map(|&c| negate_provenance(c, builder))
1184 .collect();
1185 builder.and(neg_children)
1186 }
1187 Some(PirNode::Decision {
1188 var,
1189 child_false,
1190 child_true,
1191 }) => {
1192 let neg_false = negate_provenance(child_false, builder);
1194 let neg_true = negate_provenance(child_true, builder);
1195 builder.decision(var, neg_false, neg_true)
1196 }
1197 None => prov,
1198 }
1199}
1200
1201fn eval_rule(
1206 rule: &Rule,
1207 global: &BTreeMap<String, Relation>,
1208 full_scc: &BTreeMap<String, Relation>,
1209 delta_scc: Option<(usize, &BTreeMap<String, Relation>)>,
1210 builder: &mut PirBuilder,
1211 aggregate_lifting: &mut Vec<AggregateLiftReport>,
1212) -> Result<BTreeMap<Vec<Value>, PirNodeId>> {
1213 let mut states: Vec<(HashMap<String, Value>, PirNodeId)> =
1214 vec![(HashMap::new(), builder.const_true())];
1215
1216 for (idx, lit) in rule.body.iter().enumerate() {
1217 let mut next_states: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
1218 match lit {
1219 BodyLiteral::Positive(atom) => {
1220 let rel = select_relation(atom, idx, global, full_scc, delta_scc)?;
1221 for (binding, prov) in states {
1222 for (tuple, tuple_prov) in &rel.tuples {
1223 let mut binding2 = binding.clone();
1224 if unify_atom(atom, tuple, &mut binding2)? {
1225 let prov2 = builder.and(vec![prov, *tuple_prov]);
1226 next_states.push((binding2, prov2));
1227 }
1228 }
1229 }
1230 }
1231 BodyLiteral::Comparison(cmp) => {
1232 for (binding, prov) in states {
1233 if eval_comparison(cmp.op, &cmp.left, &cmp.right, &binding)? {
1234 next_states.push((binding, prov));
1235 }
1236 }
1237 }
1238 BodyLiteral::IsExpr(is_expr) => {
1239 for (mut binding, prov) in states {
1240 if binding.contains_key(&is_expr.target) {
1241 return Err(XlogError::Compilation(format!(
1242 "Is-expression target {} is already bound",
1243 is_expr.target
1244 )));
1245 }
1246 let v = eval_arith_expr(&is_expr.expr, &binding)?;
1247 binding.insert(is_expr.target.clone(), v);
1248 next_states.push((binding, prov));
1249 }
1250 }
1251 BodyLiteral::Negated(atom) => {
1252 let rel = if let Some(r) = full_scc.get(&atom.predicate) {
1259 r
1260 } else if let Some(r) = global.get(&atom.predicate) {
1261 r
1262 } else {
1263 for (binding, prov) in states {
1265 let all_bound = atom.terms.iter().all(|t| match t {
1267 Term::Variable(v) => binding.contains_key(v),
1268 _ => true,
1269 });
1270 if all_bound {
1271 next_states.push((binding, prov));
1272 }
1273 }
1274 states = next_states;
1275 if states.is_empty() {
1276 break;
1277 }
1278 continue;
1279 };
1280
1281 for (binding, prov) in states {
1282 let all_bound = atom.terms.iter().all(|t| match t {
1285 Term::Variable(v) => binding.contains_key(v),
1286 _ => true,
1287 });
1288 if !all_bound {
1289 continue;
1291 }
1292
1293 let mut matching_provs: Vec<PirNodeId> = Vec::new();
1295 for (tuple, tuple_prov) in &rel.tuples {
1296 let mut binding2 = binding.clone();
1297 if unify_atom(atom, tuple, &mut binding2)? {
1298 matching_provs.push(*tuple_prov);
1300 }
1301 }
1302
1303 if matching_provs.is_empty() {
1304 next_states.push((binding, prov));
1306 } else {
1307 let combined_tuple_prov = builder.or(matching_provs);
1312 let neg_prov = negate_provenance(combined_tuple_prov, builder);
1313 let new_prov = builder.and(vec![prov, neg_prov]);
1314 next_states.push((binding, new_prov));
1315 }
1316 }
1317 }
1318 BodyLiteral::Epistemic(lit) => {
1319 return Err(XlogError::UnsupportedEpistemicConstruct {
1320 construct: "probabilistic provenance evaluation".to_string(),
1321 context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
1322 });
1323 }
1324 BodyLiteral::Univ(_) => {
1325 return Err(XlogError::Compilation(
1326 "univ literal was not normalized before provenance extraction".to_string(),
1327 ));
1328 }
1329 }
1330 states = next_states;
1331 if states.is_empty() {
1332 break;
1333 }
1334 }
1335
1336 if rule.has_aggregation() {
1337 eval_aggregate_head_provenance(&rule.head, states, builder, aggregate_lifting)
1338 } else {
1339 let mut out: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
1340 for (binding, prov) in states {
1341 let head_tuple = materialize_head(&rule.head, &binding)?;
1342 let entry = out
1343 .entry(head_tuple)
1344 .or_insert_with(|| builder.const_false());
1345 *entry = builder.or(vec![*entry, prov]);
1346 }
1347 Ok(out)
1348 }
1349}
1350
/// Largest number of uncertain rows a single group may contain when its aggregates are
/// evaluated through the general factorized path; larger groups are rejected with an error.
const MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS: usize = 16;
/// Largest number of uncertain rows a single group may contain when every aggregate in the
/// head is a count and the count-lifting path is used; larger groups are rejected with an error.
const MAX_EXACT_PROB_COUNT_LIFT_ROWS: usize = 64;
1353
/// One grounded body row that feeds an aggregate head.
#[derive(Debug, Clone)]
struct AggregateProvRow {
    /// Variable bindings produced by the rule body for this row.
    binding: HashMap<String, Value>,
    /// Provenance node under which this row is present.
    prov: PirNodeId,
}
1359
/// Computes the provenance of every tuple an aggregate head can produce.
///
/// `states` are the `(binding, provenance)` pairs that survived the rule body. They are
/// merged per identical binding, grouped by the head's group-key variables, and each group
/// is split into rows whose provenance is the constant true ("always" rows) and rows with
/// any other non-false provenance ("uncertain" rows). Rows that are constant false are
/// dropped.
///
/// Two evaluation paths exist per group:
///
/// * when every aggregate in the head is a count, one formula per possible number of
///   present uncertain rows is built by `count_lift_formulas`;
/// * otherwise the outcomes are enumerated by `factorized_aggregate_outcomes`.
///
/// In both paths an outcome in which the group has no always row and no uncertain row is
/// present emits no tuple, and one report per aggregate is appended to `aggregate_lifting`.
///
/// # Errors
///
/// Returns a compilation error when a group exceeds the uncertain-row cap of its path, and
/// propagates errors from head planning, row validation, aggregate state updates and tuple
/// materialization. A group-key variable missing from a row yields `UnsafeVariable`.
fn eval_aggregate_head_provenance(
    head: &Atom,
    states: Vec<(HashMap<String, Value>, PirNodeId)>,
    builder: &mut PirBuilder,
    aggregate_lifting: &mut Vec<AggregateLiftReport>,
) -> Result<BTreeMap<Vec<Value>, PirNodeId>> {
    let (key_vars, key_var_to_pos, agg_specs, agg_to_pos) = aggregate_head_plan(head)?;

    // Merge states that share the exact same binding: their provenances are alternatives
    // for one and the same row, so they are disjoined instead of being counted twice.
    let mut deduped_states: BTreeMap<Vec<(String, Value)>, AggregateProvRow> = BTreeMap::new();
    for (binding, prov) in states {
        let key = canonical_binding_key(&binding);
        match deduped_states.get_mut(&key) {
            Some(row) => {
                row.prov = builder.or(vec![row.prov, prov]);
            }
            None => {
                deduped_states.insert(key, AggregateProvRow { binding, prov });
            }
        }
    }

    /// Rows that share the same values for the head's group-key variables.
    #[derive(Debug)]
    struct GroupRows {
        key: Vec<Value>,
        rows: Vec<AggregateProvRow>,
    }

    // Partition the merged rows by the values of the group-key variables.
    let mut groups: BTreeMap<Vec<Value>, GroupRows> = BTreeMap::new();
    for row in deduped_states.into_values() {
        let mut key: Vec<Value> = Vec::with_capacity(key_vars.len());
        for name in &key_vars {
            let v = row
                .binding
                .get(name)
                .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
            key.push(v.clone());
        }
        groups
            .entry(key.clone())
            .or_insert_with(|| GroupRows {
                key,
                rows: Vec::new(),
            })
            .rows
            .push(row);
    }

    let mut out: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
    // The count-lifting path applies only when every aggregate in the head is a count.
    let count_only = agg_specs.iter().all(|(op, _)| *op == AggOp::Count);
    for group in groups.into_values() {
        // Split rows by whether their provenance is a constant.
        let mut always_rows: Vec<AggregateProvRow> = Vec::new();
        let mut uncertain_rows: Vec<AggregateProvRow> = Vec::new();
        for row in group.rows {
            match pir_const_value(builder, row.prov) {
                Some(true) => always_rows.push(row),
                // A constant-false row can never be present; drop it.
                Some(false) => {}
                None => uncertain_rows.push(row),
            }
        }

        if always_rows.is_empty() && uncertain_rows.is_empty() {
            continue;
        }
        if count_only {
            if uncertain_rows.len() > MAX_EXACT_PROB_COUNT_LIFT_ROWS {
                return Err(XlogError::Compilation(format!(
                    "count aggregate lifting finite domain cap exceeded for predicate {} group {:?}: {} uncertain rows > cap {}; use prob_engine = mc or reduce the finite aggregate domain",
                    head.predicate,
                    group.key,
                    uncertain_rows.len(),
                    MAX_EXACT_PROB_COUNT_LIFT_ROWS
                )));
            }
            validate_count_lift_rows(&agg_specs, &always_rows, &uncertain_rows)?;
            record_aggregate_lift_reports(
                aggregate_lifting,
                head,
                &group.key,
                &agg_specs,
                always_rows.len(),
                uncertain_rows.len(),
                AggregateLiftStatus::Fired,
                "finite count domain lifted with exact cardinality dynamic programming",
                MAX_EXACT_PROB_COUNT_LIFT_ROWS,
                count_lift_dp_states(uncertain_rows.len()),
            );
            // The index into `count_formulas` is the number of uncertain rows that are present.
            let count_formulas = count_lift_formulas(&uncertain_rows, builder);
            for (selected_uncertain_rows, proof) in count_formulas.into_iter().enumerate() {
                // With no always row and no uncertain row present the group is empty,
                // so no tuple is emitted for it.
                if always_rows.is_empty() && selected_uncertain_rows == 0 {
                    continue;
                }
                let count_value = always_rows.len() + selected_uncertain_rows;
                let tuple =
                    materialize_count_lift_tuple(head, &group.key, &key_var_to_pos, count_value)?;
                let entry = out.entry(tuple).or_insert_with(|| builder.const_false());
                *entry = builder.or(vec![*entry, proof]);
            }
            continue;
        }

        if uncertain_rows.len() > MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS {
            return Err(XlogError::Compilation(format!(
                "exact probabilistic aggregate domain cap exceeded for predicate {} group {:?}: {} uncertain rows > cap {}; use prob_engine = mc or reduce the finite aggregate domain",
                head.predicate,
                group.key,
                uncertain_rows.len(),
                MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS
            )));
        }
        let (outcomes, dp_states) =
            factorized_aggregate_outcomes(&agg_specs, &always_rows, &uncertain_rows, builder)?;
        record_aggregate_lift_reports(
            aggregate_lifting,
            head,
            &group.key,
            &agg_specs,
            always_rows.len(),
            uncertain_rows.len(),
            AggregateLiftStatus::Fired,
            "finite outcome domain folded with factorized aggregate-state dynamic programming",
            MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS,
            dp_states,
        );

        for (agg_states, selected_any, proof) in outcomes {
            // An outcome with no always row and no selected uncertain row describes an
            // empty group, which produces no head tuple.
            if always_rows.is_empty() && !selected_any {
                continue;
            }

            let tuple = materialize_aggregate_tuple(
                head,
                &group.key,
                &key_var_to_pos,
                &agg_specs,
                &agg_to_pos,
                &agg_states,
            )?;
            // Different outcomes may finish to the same tuple; their proofs are alternatives.
            let entry = out.entry(tuple).or_insert_with(|| builder.const_false());
            *entry = builder.or(vec![*entry, proof]);
        }
    }

    Ok(out)
}
1506
/// Enumerates the distinct aggregate outcomes of one group together with their proofs.
///
/// The aggregate states are first seeded with all `always_rows`. The `uncertain_rows` are
/// then processed one at a time: every state reached so far branches into a "row present"
/// successor (state updated with the row, proof conjoined with the row's provenance) and a
/// "row absent" successor (state unchanged, proof conjoined with the negated provenance).
/// Successors that share the same state key and the same "some uncertain row was selected"
/// flag are merged by disjoining their proofs.
///
/// Returns the final outcomes as `(aggregate states, selected_any, proof)` triples, plus
/// the total number of table entries summed over the initial table and the table after
/// each uncertain row.
///
/// # Errors
///
/// Propagates errors from `update_aggregate_states`.
#[allow(clippy::type_complexity)]
fn factorized_aggregate_outcomes(
    agg_specs: &[(AggOp, String)],
    always_rows: &[AggregateProvRow],
    uncertain_rows: &[AggregateProvRow],
    builder: &mut PirBuilder,
) -> Result<(Vec<(Vec<AggState>, bool, PirNodeId)>, usize)> {
    use std::collections::btree_map::Entry;

    /// Key under which a vector of aggregate states is stored in the table.
    fn states_key(states: &[AggState]) -> Vec<AggStateKey> {
        states.iter().map(AggState::dp_key).collect()
    }

    // Rows that are always present contribute to every outcome.
    let mut base: Vec<AggState> = agg_specs.iter().map(|(op, _)| AggState::new(*op)).collect();
    for row in always_rows {
        update_aggregate_states(&mut base, agg_specs, row)?;
    }

    // Table: (state key, any uncertain row selected) -> (states, proof of reaching them).
    let mut dp: BTreeMap<(Vec<AggStateKey>, bool), (Vec<AggState>, PirNodeId)> = BTreeMap::new();
    let true_proof = builder.const_true();
    dp.insert((states_key(&base), false), (base, true_proof));
    let mut dp_states = dp.len();

    for row in uncertain_rows {
        let absent = negate_provenance(row.prov, builder);
        let mut next: BTreeMap<(Vec<AggStateKey>, bool), (Vec<AggState>, PirNodeId)> =
            BTreeMap::new();
        for ((key, selected_any), (states, proof)) in dp {
            // Branch 1: the row is present and feeds every aggregate.
            let mut present_states = states.clone();
            update_aggregate_states(&mut present_states, agg_specs, row)?;
            let present_key = states_key(&present_states);
            let present_proof = builder.and(vec![proof, row.prov]);
            match next.entry((present_key, true)) {
                Entry::Occupied(mut entry) => {
                    entry.get_mut().1 = builder.or(vec![entry.get().1, present_proof]);
                }
                Entry::Vacant(entry) => {
                    entry.insert((present_states, present_proof));
                }
            }

            // Branch 2: the row is absent, so states and the selection flag carry over.
            let absent_proof = builder.and(vec![proof, absent]);
            match next.entry((key, selected_any)) {
                Entry::Occupied(mut entry) => {
                    entry.get_mut().1 = builder.or(vec![entry.get().1, absent_proof]);
                }
                Entry::Vacant(entry) => {
                    entry.insert((states, absent_proof));
                }
            }
        }
        dp = next;
        dp_states += dp.len();
    }

    let outcomes = dp
        .into_iter()
        .map(|((_, selected_any), (states, proof))| (states, selected_any, proof))
        .collect();
    Ok((outcomes, dp_states))
}
1583
1584fn validate_count_lift_rows(
1585 agg_specs: &[(AggOp, String)],
1586 always_rows: &[AggregateProvRow],
1587 uncertain_rows: &[AggregateProvRow],
1588) -> Result<()> {
1589 for (_, var) in agg_specs {
1590 for row in always_rows.iter().chain(uncertain_rows.iter()) {
1591 if !row.binding.contains_key(var) {
1592 return Err(XlogError::UnsafeVariable(var.clone()));
1593 }
1594 }
1595 }
1596 Ok(())
1597}
1598
1599fn count_lift_formulas(
1600 uncertain_rows: &[AggregateProvRow],
1601 builder: &mut PirBuilder,
1602) -> Vec<PirNodeId> {
1603 let n = uncertain_rows.len();
1604 let mut dp = vec![builder.const_false(); n + 1];
1605 dp[0] = builder.const_true();
1606
1607 for (idx, row) in uncertain_rows.iter().enumerate() {
1608 let mut next = vec![builder.const_false(); n + 1];
1609 let present = row.prov;
1610 let absent = negate_provenance(row.prov, builder);
1611 for selected in 0..=idx {
1612 let absent_case = builder.and(vec![dp[selected], absent]);
1613 next[selected] = builder.or(vec![next[selected], absent_case]);
1614
1615 let present_case = builder.and(vec![dp[selected], present]);
1616 next[selected + 1] = builder.or(vec![next[selected + 1], present_case]);
1617 }
1618 dp = next;
1619 }
1620
1621 dp
1622}
1623
1624fn materialize_count_lift_tuple(
1625 head: &Atom,
1626 group_key: &[Value],
1627 key_var_to_pos: &HashMap<String, usize>,
1628 count_value: usize,
1629) -> Result<Vec<Value>> {
1630 let count_value: i64 = count_value
1631 .try_into()
1632 .map_err(|_| XlogError::Compilation("count() overflowed i64".to_string()))?;
1633 let mut tuple: Vec<Value> = Vec::with_capacity(head.terms.len());
1634 for term in &head.terms {
1635 match term {
1636 Term::Variable(name) => {
1637 let pos = *key_var_to_pos.get(name).ok_or_else(|| {
1638 XlogError::Compilation(format!(
1639 "Aggregate head variable {} is not a group key",
1640 name
1641 ))
1642 })?;
1643 tuple.push(group_key[pos].clone());
1644 }
1645 Term::Aggregate(AggExpr {
1646 op: AggOp::Count, ..
1647 }) => tuple.push(Value::I64(count_value)),
1648 Term::Aggregate(AggExpr { op, .. }) => {
1649 return Err(XlogError::Compilation(format!(
1650 "Internal aggregate lift state mismatch for {}",
1651 agg_op_label(*op)
1652 )));
1653 }
1654 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1655 tuple.push(value_from_term(term)?);
1656 }
1657 Term::Anonymous => unreachable!("aggregate head plan rejects anonymous terms"),
1658 Term::List(_) => {
1659 return Err(unsupported_probabilistic_term_error(
1660 "aggregate head materialization",
1661 "list",
1662 ));
1663 }
1664 Term::Cons { .. } => {
1665 return Err(unsupported_probabilistic_term_error(
1666 "aggregate head materialization",
1667 "cons",
1668 ));
1669 }
1670 Term::Compound { .. } => {
1671 return Err(unsupported_probabilistic_term_error(
1672 "aggregate head materialization",
1673 "compound",
1674 ));
1675 }
1676 Term::PredRef(_) => {
1677 return Err(unsupported_probabilistic_term_error(
1678 "aggregate head materialization",
1679 "predref",
1680 ));
1681 }
1682 }
1683 }
1684 Ok(tuple)
1685}
1686
1687#[allow(clippy::too_many_arguments)]
1688fn record_aggregate_lift_reports(
1689 aggregate_lifting: &mut Vec<AggregateLiftReport>,
1690 head: &Atom,
1691 group_key: &[Value],
1692 agg_specs: &[(AggOp, String)],
1693 deterministic_rows: usize,
1694 uncertain_rows: usize,
1695 status: AggregateLiftStatus,
1696 reason: &str,
1697 cap: usize,
1698 dynamic_programming_states: usize,
1699) {
1700 for (op, _) in agg_specs {
1701 aggregate_lifting.push(AggregateLiftReport {
1702 predicate: head.predicate.clone(),
1703 group_key: group_key.to_vec(),
1704 operator: agg_op_label(*op).to_string(),
1705 finite_domain_source: "grounded body rows".to_string(),
1706 deterministic_rows,
1707 uncertain_rows,
1708 domain_size: deterministic_rows + uncertain_rows,
1709 cap,
1710 status,
1711 reason: reason.to_string(),
1712 naive_outcomes: naive_outcome_count(uncertain_rows),
1713 dynamic_programming_states,
1714 });
1715 }
1716}
1717
1718fn agg_op_label(op: AggOp) -> &'static str {
1719 match op {
1720 AggOp::Count => "count",
1721 AggOp::Sum => "sum",
1722 AggOp::Min => "min",
1723 AggOp::Max => "max",
1724 AggOp::LogSumExp => "logsumexp",
1725 }
1726}
1727
/// Returns `2^uncertain_rows`, the number of present/absent combinations of the uncertain
/// rows, saturating at `u128::MAX` once the power no longer fits into a `u128`.
fn naive_outcome_count(uncertain_rows: usize) -> u128 {
    match uncertain_rows {
        shift if shift < u128::BITS as usize => 1u128 << shift,
        _ => u128::MAX,
    }
}
1735
/// Returns the triangular number `1 + 2 + ... + (uncertain_rows + 1)`, reported as the
/// number of dynamic-programming states of the count-lifting path.
fn count_lift_dp_states(uncertain_rows: usize) -> usize {
    let layers = uncertain_rows + 1;
    layers * (layers + 1) / 2
}
1739
/// Result of `aggregate_head_plan`, in order:
///
/// 1. the distinct head variables (the group key), in order of first appearance;
/// 2. a map from each group-key variable to its position in that list;
/// 3. the distinct `(operator, variable)` aggregates of the head, in order of first appearance;
/// 4. a map from each such aggregate to its position in that list.
type AggregatePlan = (
    Vec<String>,
    HashMap<String, usize>,
    Vec<(AggOp, String)>,
    HashMap<(AggOp, String), usize>,
);
1746
1747fn aggregate_head_plan(head: &Atom) -> Result<AggregatePlan> {
1748 let mut key_vars: Vec<String> = Vec::new();
1749 let mut key_var_to_pos: HashMap<String, usize> = HashMap::new();
1750 let mut agg_specs: Vec<(AggOp, String)> = Vec::new();
1751 let mut agg_to_pos: HashMap<(AggOp, String), usize> = HashMap::new();
1752
1753 for term in &head.terms {
1754 match term {
1755 Term::Variable(name) => {
1756 if !key_var_to_pos.contains_key(name) {
1757 let pos = key_vars.len();
1758 key_vars.push(name.clone());
1759 key_var_to_pos.insert(name.clone(), pos);
1760 }
1761 }
1762 Term::Aggregate(agg) => {
1763 let key = (agg.op, agg.variable.clone());
1764 if let std::collections::hash_map::Entry::Vacant(entry) =
1765 agg_to_pos.entry(key.clone())
1766 {
1767 let pos = agg_specs.len();
1768 agg_specs.push(key);
1769 entry.insert(pos);
1770 }
1771 }
1772 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {}
1773 Term::Anonymous => {
1774 return Err(XlogError::Compilation(format!(
1775 "Anonymous variable in aggregate head of {} is not supported",
1776 head.predicate
1777 )));
1778 }
1779 Term::List(_) => {
1780 return Err(unsupported_probabilistic_term_error(
1781 "aggregate head planning",
1782 "list",
1783 ));
1784 }
1785 Term::Cons { .. } => {
1786 return Err(unsupported_probabilistic_term_error(
1787 "aggregate head planning",
1788 "cons",
1789 ));
1790 }
1791 Term::Compound { .. } => {
1792 return Err(unsupported_probabilistic_term_error(
1793 "aggregate head planning",
1794 "compound",
1795 ));
1796 }
1797 Term::PredRef(_) => {
1798 return Err(unsupported_probabilistic_term_error(
1799 "aggregate head planning",
1800 "predref",
1801 ));
1802 }
1803 }
1804 }
1805
1806 Ok((key_vars, key_var_to_pos, agg_specs, agg_to_pos))
1807}
1808
1809fn canonical_binding_key(binding: &HashMap<String, Value>) -> Vec<(String, Value)> {
1810 let mut key: Vec<(String, Value)> = binding
1811 .iter()
1812 .map(|(name, value)| (name.clone(), value.clone()))
1813 .collect();
1814 key.sort();
1815 key
1816}
1817
1818fn pir_const_value(builder: &PirBuilder, node: PirNodeId) -> Option<bool> {
1819 match builder.pir.node(node) {
1820 Some(crate::pir::PirNode::Const(value)) => Some(*value),
1821 _ => None,
1822 }
1823}
1824
1825fn update_aggregate_states(
1826 states: &mut [AggState],
1827 agg_specs: &[(AggOp, String)],
1828 row: &AggregateProvRow,
1829) -> Result<()> {
1830 for (idx, (op, var)) in agg_specs.iter().enumerate() {
1831 let v = row
1832 .binding
1833 .get(var)
1834 .ok_or_else(|| XlogError::UnsafeVariable(var.clone()))?;
1835 states[idx].update(*op, v)?;
1836 }
1837 Ok(())
1838}
1839
1840fn materialize_aggregate_tuple(
1841 head: &Atom,
1842 group_key: &[Value],
1843 key_var_to_pos: &HashMap<String, usize>,
1844 agg_specs: &[(AggOp, String)],
1845 agg_to_pos: &HashMap<(AggOp, String), usize>,
1846 agg_states: &[AggState],
1847) -> Result<Vec<Value>> {
1848 let mut tuple: Vec<Value> = Vec::with_capacity(head.terms.len());
1849 for term in &head.terms {
1850 match term {
1851 Term::Variable(name) => {
1852 let pos = *key_var_to_pos.get(name).ok_or_else(|| {
1853 XlogError::Compilation(format!(
1854 "Aggregate head variable {} is not a group key",
1855 name
1856 ))
1857 })?;
1858 tuple.push(group_key[pos].clone());
1859 }
1860 Term::Aggregate(AggExpr { op, variable }) => {
1861 let idx = *agg_to_pos
1862 .get(&(*op, variable.clone()))
1863 .expect("agg_to_pos missing");
1864 let spec = agg_specs
1865 .get(idx)
1866 .expect("aggregate state index should have a spec");
1867 tuple.push(agg_states[idx].finish(spec.0)?);
1868 }
1869 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1870 tuple.push(value_from_term(term)?);
1871 }
1872 Term::Anonymous => unreachable!("aggregate head plan rejects anonymous terms"),
1873 Term::List(_) => {
1874 return Err(unsupported_probabilistic_term_error(
1875 "aggregate head materialization",
1876 "list",
1877 ));
1878 }
1879 Term::Cons { .. } => {
1880 return Err(unsupported_probabilistic_term_error(
1881 "aggregate head materialization",
1882 "cons",
1883 ));
1884 }
1885 Term::Compound { .. } => {
1886 return Err(unsupported_probabilistic_term_error(
1887 "aggregate head materialization",
1888 "compound",
1889 ));
1890 }
1891 Term::PredRef(_) => {
1892 return Err(unsupported_probabilistic_term_error(
1893 "aggregate head materialization",
1894 "predref",
1895 ));
1896 }
1897 }
1898 }
1899 Ok(tuple)
1900}
1901
1902fn select_relation<'a>(
1903 atom: &Atom,
1904 body_index: usize,
1905 global: &'a BTreeMap<String, Relation>,
1906 full_scc: &'a BTreeMap<String, Relation>,
1907 delta_scc: Option<(usize, &'a BTreeMap<String, Relation>)>,
1908) -> Result<&'a Relation> {
1909 if let Some((delta_index, delta_map)) = delta_scc {
1910 if delta_index == body_index {
1911 return delta_map.get(&atom.predicate).ok_or_else(|| {
1912 XlogError::Compilation(format!(
1913 "Missing delta relation for predicate {}",
1914 atom.predicate
1915 ))
1916 });
1917 }
1918 }
1919 if let Some(rel) = full_scc.get(&atom.predicate) {
1920 return Ok(rel);
1921 }
1922 global
1923 .get(&atom.predicate)
1924 .ok_or_else(|| XlogError::Compilation(format!("Unknown predicate {}", atom.predicate)))
1925}
1926
1927pub(crate) fn unify_atom(
1928 atom: &Atom,
1929 tuple: &[Value],
1930 binding: &mut HashMap<String, Value>,
1931) -> Result<bool> {
1932 if atom.terms.len() != tuple.len() {
1933 return Err(XlogError::Compilation(format!(
1934 "Arity mismatch for {}: atom has {}, tuple has {}",
1935 atom.predicate,
1936 atom.terms.len(),
1937 tuple.len()
1938 )));
1939 }
1940 for (term, value) in atom.terms.iter().zip(tuple.iter()) {
1941 match term {
1942 Term::Variable(name) => match binding.get(name) {
1943 Some(existing) => {
1944 if existing != value {
1945 return Ok(false);
1946 }
1947 }
1948 None => {
1949 binding.insert(name.clone(), value.clone());
1950 }
1951 },
1952 Term::Anonymous => {}
1953 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1954 if &value_from_term(term)? != value {
1955 return Ok(false);
1956 }
1957 }
1958 Term::Aggregate(AggExpr { op: _, variable: _ }) => {
1959 return Err(XlogError::Compilation(
1960 "Aggregation not supported in provenance extraction".to_string(),
1961 ));
1962 }
1963 Term::List(_) => {
1964 return Err(unsupported_probabilistic_term_error("unification", "list"))
1965 }
1966 Term::Cons { .. } => {
1967 return Err(unsupported_probabilistic_term_error("unification", "cons"))
1968 }
1969 Term::Compound { .. } => {
1970 return Err(unsupported_probabilistic_term_error(
1971 "unification",
1972 "compound",
1973 ));
1974 }
1975 Term::PredRef(_) => {
1976 return Err(unsupported_probabilistic_term_error(
1977 "unification",
1978 "predref",
1979 ))
1980 }
1981 }
1982 }
1983 Ok(true)
1984}
1985
1986fn materialize_head(head: &Atom, binding: &HashMap<String, Value>) -> Result<Vec<Value>> {
1987 let mut out = Vec::with_capacity(head.terms.len());
1988 for term in &head.terms {
1989 match term {
1990 Term::Variable(name) => {
1991 let v = binding.get(name).ok_or_else(|| {
1992 XlogError::Compilation(format!(
1993 "Unbound head variable {} in {}",
1994 name, head.predicate
1995 ))
1996 })?;
1997 out.push(v.clone());
1998 }
1999 Term::Anonymous => {
2000 return Err(XlogError::Compilation(format!(
2001 "Anonymous variable in head of {} is not supported",
2002 head.predicate
2003 )));
2004 }
2005 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
2006 out.push(value_from_term(term)?);
2007 }
2008 Term::Aggregate(AggExpr {
2009 op: AggOp::Count,
2010 variable: _,
2011 })
2012 | Term::Aggregate(AggExpr {
2013 op: AggOp::Sum,
2014 variable: _,
2015 })
2016 | Term::Aggregate(AggExpr {
2017 op: AggOp::Min,
2018 variable: _,
2019 })
2020 | Term::Aggregate(AggExpr {
2021 op: AggOp::Max,
2022 variable: _,
2023 })
2024 | Term::Aggregate(AggExpr {
2025 op: AggOp::LogSumExp,
2026 variable: _,
2027 }) => {
2028 return Err(XlogError::Compilation(
2029 "Aggregation not supported in provenance extraction".to_string(),
2030 ));
2031 }
2032 Term::List(_) => {
2033 return Err(unsupported_probabilistic_term_error(
2034 "head materialization",
2035 "list",
2036 ));
2037 }
2038 Term::Cons { .. } => {
2039 return Err(unsupported_probabilistic_term_error(
2040 "head materialization",
2041 "cons",
2042 ));
2043 }
2044 Term::Compound { .. } => {
2045 return Err(unsupported_probabilistic_term_error(
2046 "head materialization",
2047 "compound",
2048 ));
2049 }
2050 Term::PredRef(_) => {
2051 return Err(unsupported_probabilistic_term_error(
2052 "head materialization",
2053 "predref",
2054 ));
2055 }
2056 }
2057 }
2058 Ok(out)
2059}
2060
2061pub(crate) fn eval_comparison(
2062 op: CompOp,
2063 left: &Term,
2064 right: &Term,
2065 binding: &HashMap<String, Value>,
2066) -> Result<bool> {
2067 let l = resolve_term(left, binding)?;
2068 let r = resolve_term(right, binding)?;
2069 match (l, r) {
2070 (Value::I64(a), Value::I64(b)) => Ok(compare_ord(op, a.cmp(&b))),
2071 (Value::F64(a_bits), Value::F64(b_bits)) => {
2072 let a = f64::from_bits(a_bits);
2073 let b = f64::from_bits(b_bits);
2074 match op {
2075 CompOp::Eq => Ok(a == b),
2076 CompOp::Ne => Ok(a != b),
2077 CompOp::Lt => Ok(a < b),
2078 CompOp::Le => Ok(a <= b),
2079 CompOp::Gt => Ok(a > b),
2080 CompOp::Ge => Ok(a >= b),
2081 }
2082 }
2083 (Value::Symbol(a), Value::Symbol(b)) => Ok(compare_ord(op, a.cmp(&b))),
2084 (Value::String(a), Value::String(b)) => Ok(compare_ord(op, a.cmp(&b))),
2085 _ => Err(XlogError::Compilation(
2086 "Comparison between differing types is not supported".to_string(),
2087 )),
2088 }
2089}
2090
2091pub(crate) fn compare_ord(op: CompOp, ord: std::cmp::Ordering) -> bool {
2092 use std::cmp::Ordering;
2093 match op {
2094 CompOp::Eq => ord == Ordering::Equal,
2095 CompOp::Ne => ord != Ordering::Equal,
2096 CompOp::Lt => ord == Ordering::Less,
2097 CompOp::Le => ord == Ordering::Less || ord == Ordering::Equal,
2098 CompOp::Gt => ord == Ordering::Greater,
2099 CompOp::Ge => ord == Ordering::Greater || ord == Ordering::Equal,
2100 }
2101}
2102
2103pub(crate) fn resolve_term(term: &Term, binding: &HashMap<String, Value>) -> Result<Value> {
2104 match term {
2105 Term::Variable(name) => binding.get(name).cloned().ok_or_else(|| {
2106 XlogError::Compilation(format!("Unbound variable {} in comparison", name))
2107 }),
2108 Term::Anonymous => Err(XlogError::Compilation(
2109 "Anonymous variable not allowed in comparison".to_string(),
2110 )),
2111 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
2112 value_from_term(term)
2113 }
2114 Term::Aggregate(_) => Err(XlogError::Compilation(
2115 "Aggregation not supported in provenance extraction".to_string(),
2116 )),
2117 Term::List(_) => Err(unsupported_probabilistic_term_error("comparison", "list")),
2118 Term::Cons { .. } => Err(unsupported_probabilistic_term_error("comparison", "cons")),
2119 Term::Compound { .. } => Err(unsupported_probabilistic_term_error(
2120 "comparison",
2121 "compound",
2122 )),
2123 Term::PredRef(_) => Err(unsupported_probabilistic_term_error(
2124 "comparison",
2125 "predref",
2126 )),
2127 }
2128}
2129
2130pub(crate) fn eval_arith_expr(expr: &ArithExpr, binding: &HashMap<String, Value>) -> Result<Value> {
2131 match expr {
2132 ArithExpr::Variable(name) => binding.get(name).cloned().ok_or_else(|| {
2133 XlogError::Compilation(format!("Unbound variable {} in arithmetic", name))
2134 }),
2135 ArithExpr::Integer(i) => Ok(Value::I64(*i)),
2136 ArithExpr::Float(f) => Ok(Value::F64(f.to_bits())),
2137 ArithExpr::Add(l, r) => eval_bin_op(l, r, binding, |a, b| a + b, |a, b| a + b),
2138 ArithExpr::Sub(l, r) => eval_bin_op(l, r, binding, |a, b| a - b, |a, b| a - b),
2139 ArithExpr::Mul(l, r) => eval_bin_op(l, r, binding, |a, b| a * b, |a, b| a * b),
2140 ArithExpr::Div(l, r) => eval_bin_op(l, r, binding, |a, b| a / b, |a, b| a / b),
2141 ArithExpr::Mod(l, r) => eval_bin_op(l, r, binding, |a, b| a % b, |a, b| a % b),
2142 ArithExpr::Abs(e) => match eval_arith_expr(e, binding)? {
2143 Value::I64(i) => Ok(Value::I64(i.abs())),
2144 Value::F64(bits) => {
2145 let f = f64::from_bits(bits).abs();
2146 Ok(Value::F64(f.to_bits()))
2147 }
2148 _ => Err(XlogError::Compilation(
2149 "abs() requires numeric input".to_string(),
2150 )),
2151 },
2152 ArithExpr::Min(l, r) => eval_bin_op(l, r, binding, |a, b| a.min(b), |a, b| a.min(b)),
2153 ArithExpr::Max(l, r) => eval_bin_op(l, r, binding, |a, b| a.max(b), |a, b| a.max(b)),
2154 ArithExpr::Pow(l, r) => {
2155 let a = eval_arith_expr(l, binding)?;
2156 let b = eval_arith_expr(r, binding)?;
2157 match (a, b) {
2158 (Value::I64(a), Value::I64(b)) => {
2159 Ok(Value::I64(a.pow(u32::try_from(b).map_err(|_| {
2160 XlogError::Compilation("pow exponent must fit in u32".to_string())
2161 })?)))
2162 }
2163 (Value::F64(a), Value::F64(b)) => Ok(Value::F64(
2164 f64::from_bits(a).powf(f64::from_bits(b)).to_bits(),
2165 )),
2166 _ => Err(XlogError::Compilation(
2167 "pow requires numeric inputs of same type".to_string(),
2168 )),
2169 }
2170 }
2171 ArithExpr::Cast(e, _ty) => {
2172 eval_arith_expr(e, binding)
2175 }
2176 ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2177 "Function call `{}` must be expanded before provenance extraction",
2178 name
2179 ))),
2180 ArithExpr::Conditional { .. } => Err(XlogError::Compilation(
2181 "Conditional expressions must be expanded before provenance extraction".to_string(),
2182 )),
2183 }
2184}
2185
2186pub(crate) fn eval_bin_op<FInt, FFloat>(
2187 l: &ArithExpr,
2188 r: &ArithExpr,
2189 binding: &HashMap<String, Value>,
2190 op_int: FInt,
2191 op_float: FFloat,
2192) -> Result<Value>
2193where
2194 FInt: FnOnce(i64, i64) -> i64,
2195 FFloat: FnOnce(f64, f64) -> f64,
2196{
2197 let a = eval_arith_expr(l, binding)?;
2198 let b = eval_arith_expr(r, binding)?;
2199 match (a, b) {
2200 (Value::I64(a), Value::I64(b)) => Ok(Value::I64(op_int(a, b))),
2201 (Value::F64(a), Value::F64(b)) => Ok(Value::F64(
2202 op_float(f64::from_bits(a), f64::from_bits(b)).to_bits(),
2203 )),
2204 _ => Err(XlogError::Compilation(
2205 "Arithmetic operation requires matching numeric types".to_string(),
2206 )),
2207 }
2208}