Skip to main content

xlog_prob/
provenance.rs

1//! Provenance extraction from XLOG programs into PIR.
2
3use std::collections::{BTreeMap, HashMap};
4use std::hash::{Hash, Hasher};
5
6use xlog_core::{Result, XlogError};
7use xlog_logic::ast::{
8    AggExpr, AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Evidence, ProbQuery, Program, Rule, Term,
9};
10use xlog_logic::stratify::{
11    analyze_stratification, build_dependency_graph, find_sccs_for_lowering, stratify,
12};
13
14use crate::wfs::{evaluate_wfs_rules, WfsAtom, WfsConfig, WfsLiteral, WfsRule};
15
16use crate::aggregates::{AggState, AggStateKey};
17use crate::pir::{ChoiceVarId, LeafId, PirGraph, PirNodeId};
18
/// A constant argument of a ground atom.
///
/// The derived `Ord` compares the variant first, in declaration order
/// (`I64` < `F64` < `Symbol` < `String`), and then the payload.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Value {
    /// Integer constant.
    I64(i64),
    /// Float constant stored as its IEEE-754 bit pattern (`f64::to_bits`) so the
    /// type can be `Eq`/`Hash`/`Ord`. Equality and ordering are therefore bitwise,
    /// not numeric (e.g. `0.0` and `-0.0` are different values).
    F64(u64),
    /// Symbol constant, identified by its `u32` id.
    Symbol(u32),
    /// String constant.
    String(String),
}

impl From<i64> for Value {
    fn from(n: i64) -> Value {
        Value::I64(n)
    }
}

impl From<u32> for Value {
    fn from(symbol_id: u32) -> Value {
        Value::Symbol(symbol_id)
    }
}

impl From<String> for Value {
    fn from(text: String) -> Value {
        Value::String(text)
    }
}
44
45#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
46pub struct GroundAtom {
47    pub predicate: String,
48    pub args: Vec<Value>,
49}
50
51impl GroundAtom {
52    pub fn new(predicate: impl Into<String>, args: Vec<Value>) -> Self {
53        Self {
54            predicate: predicate.into(),
55            args,
56        }
57    }
58}
59
/// Metadata for a single Bernoulli decision stage in an annotated disjunction.
///
/// `compile_annotated_disjunction` records one `ChoiceSource` per decision
/// variable it allocates; every stage of the same disjunction carries a copy of
/// the same `choices` list and differs only in `choice_index`.
#[derive(Debug, Clone, PartialEq)]
pub struct ChoiceSource {
    /// Explicit heads of the annotated disjunction, paired with probabilities.
    /// Does not include the synthetic implicit "none" branch.
    pub choices: Vec<(GroundAtom, f64)>,
    /// Position of this ChoiceVarId in the m-1 Bernoulli decision chain.
    pub choice_index: usize,
    /// Enclosing annotated-disjunction identity. `None` in v1.
    pub source_id: Option<usize>,
}
71
/// Outcome of one aggregate-lifting attempt, as stored in
/// `AggregateLiftReport::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateLiftStatus {
    Fired,
    FallbackExactEnumeration,
    Declined,
}

impl AggregateLiftStatus {
    /// The snake_case name of this status.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Fired => "fired",
            Self::FallbackExactEnumeration => "fallback_exact_enumeration",
            Self::Declined => "declined",
        }
    }
}
88
/// Diagnostic record for one aggregate-lifting decision, collected into
/// `Provenance::aggregate_lifting`.
///
/// NOTE(review): the code that fills these reports is outside this part of the
/// file; the field notes below are inferred from field names and should be
/// confirmed against the producer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregateLiftReport {
    /// Presumably the predicate whose aggregate was considered.
    pub predicate: String,
    /// Presumably the group-by key values of the aggregated group.
    pub group_key: Vec<Value>,
    /// Presumably the aggregate operator's name.
    pub operator: String,
    /// Presumably where the finite value domain was taken from — confirm.
    pub finite_domain_source: String,
    /// Presumably the number of rows that hold with certainty — confirm.
    pub deterministic_rows: usize,
    /// Presumably the number of rows whose presence is probabilistic — confirm.
    pub uncertain_rows: usize,
    /// Presumably the size of the aggregate's value domain — confirm.
    pub domain_size: usize,
    /// Presumably the limit the lifting was checked against — confirm.
    pub cap: usize,
    /// Whether lifting fired, fell back to exact enumeration, or was declined.
    pub status: AggregateLiftStatus,
    /// Presumably a human-readable explanation of `status`.
    pub reason: String,
    /// Presumably the outcome count a naive enumeration would need — confirm.
    pub naive_outcomes: u128,
    /// Presumably the number of dynamic-programming states used — confirm.
    pub dynamic_programming_states: usize,
}
104
105#[derive(Debug, Clone)]
106struct Relation {
107    tuples: BTreeMap<Vec<Value>, PirNodeId>,
108}
109
110impl Relation {
111    fn new() -> Self {
112        Self {
113            tuples: BTreeMap::new(),
114        }
115    }
116
117    fn get(&self, tuple: &[Value]) -> Option<PirNodeId> {
118        self.tuples.get(tuple).copied()
119    }
120
121    fn is_empty(&self) -> bool {
122        self.tuples.is_empty()
123    }
124
125    fn insert_or(&mut self, tuple: Vec<Value>, formula: PirNodeId, builder: &mut PirBuilder) {
126        let entry = self
127            .tuples
128            .entry(tuple)
129            .or_insert_with(|| builder.const_false());
130        *entry = builder.or(vec![*entry, formula]);
131    }
132}
133
/// Structural identity of a PIR node; the hash-consing key of `PirBuilder::intern`.
///
/// `And`/`Or` children are stored after the builder's normalization (sorted by
/// id, deduplicated), so equal keys mean the builder would emit the same node.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PirKey {
    Const(bool),
    Lit(LeafId),
    NegLit(LeafId),
    And(Vec<PirNodeId>),
    Or(Vec<PirNodeId>),
    Decision {
        var: ChoiceVarId,
        child_false: PirNodeId,
        child_true: PirNodeId,
    },
}

// Hand-written `Hash`: every variant hashes its own tag byte before its payload.
// Distinct tags keep variants with equal payloads (e.g. `Lit(l)` vs `NegLit(l)`)
// from colliding; a collision would only cost lookup time, because equality is
// still decided by the derived `Eq`. Equal keys always hash equally, so this
// stays consistent with `Eq`. When adding a variant, pick an unused tag.
impl Hash for PirKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            PirKey::Const(b) => {
                0u8.hash(state);
                b.hash(state);
            }
            PirKey::Lit(l) => {
                1u8.hash(state);
                l.hash(state);
            }
            PirKey::NegLit(l) => {
                5u8.hash(state);
                l.hash(state);
            }
            PirKey::And(children) => {
                2u8.hash(state);
                children.hash(state);
            }
            PirKey::Or(children) => {
                3u8.hash(state);
                children.hash(state);
            }
            PirKey::Decision {
                var,
                child_false,
                child_true,
            } => {
                4u8.hash(state);
                var.hash(state);
                child_false.hash(state);
                child_true.hash(state);
            }
        }
    }
}
184
185#[derive(Debug)]
186struct PirBuilder {
187    pir: PirGraph,
188    intern: HashMap<PirKey, PirNodeId>,
189    const_true: PirNodeId,
190    const_false: PirNodeId,
191    /// Children of interned Or nodes, used to flatten nested ORs and apply
192    /// absorption during normalization. Nodes absent from the map are opaque.
193    or_children: HashMap<PirNodeId, Vec<PirNodeId>>,
194    /// Children of interned And nodes (same role as `or_children`).
195    and_children: HashMap<PirNodeId, Vec<PirNodeId>>,
196}
197
198impl PirBuilder {
199    fn new() -> Self {
200        let mut pir = PirGraph::new();
201        let const_true = pir.const_true();
202        let const_false = pir.const_false();
203
204        let mut intern = HashMap::new();
205        intern.insert(PirKey::Const(true), const_true);
206        intern.insert(PirKey::Const(false), const_false);
207
208        Self {
209            pir,
210            intern,
211            const_true,
212            const_false,
213            or_children: HashMap::new(),
214            and_children: HashMap::new(),
215        }
216    }
217
218    fn finish(self) -> PirGraph {
219        self.pir
220    }
221
222    fn const_true(&self) -> PirNodeId {
223        self.const_true
224    }
225
226    fn const_false(&self) -> PirNodeId {
227        self.const_false
228    }
229
230    fn lit(&mut self, leaf: LeafId) -> PirNodeId {
231        let key = PirKey::Lit(leaf);
232        if let Some(&id) = self.intern.get(&key) {
233            return id;
234        }
235        let id = self.pir.lit(leaf);
236        self.intern.insert(key, id);
237        id
238    }
239
240    fn neg_lit(&mut self, leaf: LeafId) -> PirNodeId {
241        let key = PirKey::NegLit(leaf);
242        if let Some(&id) = self.intern.get(&key) {
243            return id;
244        }
245        let id = self.pir.neg_lit(leaf);
246        self.intern.insert(key, id);
247        id
248    }
249
250    fn and(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
251        // Flatten nested ANDs (associativity) so recursive-SCC provenance cannot
252        // grow syntactically forever while staying semantically fixed.
253        let mut flat: Vec<PirNodeId> = Vec::with_capacity(children.len());
254        for c in children {
255            match self.and_children.get(&c) {
256                Some(sub) => flat.extend_from_slice(sub),
257                None => flat.push(c),
258            }
259        }
260        let mut children = flat;
261        children.retain(|&c| c != self.const_true);
262        if children.contains(&self.const_false) {
263            return self.const_false;
264        }
265        if children.is_empty() {
266            return self.const_true;
267        }
268        if children.len() == 1 {
269            return children[0];
270        }
271        children.sort_by_key(|id| id.as_u32());
272        children.dedup();
273        // Absorption: a ∧ (a ∨ b) = a — drop any Or-child containing another member.
274        if children.len() > 1 {
275            let members = children.clone();
276            children.retain(|c| match self.or_children.get(c) {
277                Some(sub) => !sub.iter().any(|s| {
278                    s != c
279                        && members
280                            .binary_search_by_key(&s.as_u32(), |m| m.as_u32())
281                            .is_ok()
282                }),
283                None => true,
284            });
285        }
286        if children.len() == 1 {
287            return children[0];
288        }
289        let key = PirKey::And(children.clone());
290        if let Some(&id) = self.intern.get(&key) {
291            return id;
292        }
293        let id = self.pir.and(children.clone());
294        self.intern.insert(key, id);
295        self.and_children.insert(id, children);
296        id
297    }
298
299    fn or(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
300        // Flatten nested ORs (associativity) — see `and` for rationale.
301        let mut flat: Vec<PirNodeId> = Vec::with_capacity(children.len());
302        for c in children {
303            match self.or_children.get(&c) {
304                Some(sub) => flat.extend_from_slice(sub),
305                None => flat.push(c),
306            }
307        }
308        let mut children = flat;
309        children.retain(|&c| c != self.const_false);
310        if children.contains(&self.const_true) {
311            return self.const_true;
312        }
313        if children.is_empty() {
314            return self.const_false;
315        }
316        if children.len() == 1 {
317            return children[0];
318        }
319        children.sort_by_key(|id| id.as_u32());
320        children.dedup();
321        // Absorption: a ∨ (a ∧ b) = a — drop any And-child containing another member.
322        if children.len() > 1 {
323            let members = children.clone();
324            children.retain(|c| match self.and_children.get(c) {
325                Some(sub) => !sub.iter().any(|s| {
326                    s != c
327                        && members
328                            .binary_search_by_key(&s.as_u32(), |m| m.as_u32())
329                            .is_ok()
330                }),
331                None => true,
332            });
333        }
334        if children.len() == 1 {
335            return children[0];
336        }
337        let key = PirKey::Or(children.clone());
338        if let Some(&id) = self.intern.get(&key) {
339            return id;
340        }
341        let id = self.pir.or(children.clone());
342        self.intern.insert(key, id);
343        self.or_children.insert(id, children);
344        id
345    }
346
347    fn decision(
348        &mut self,
349        var: ChoiceVarId,
350        child_false: PirNodeId,
351        child_true: PirNodeId,
352    ) -> PirNodeId {
353        if child_false == child_true {
354            return child_true;
355        }
356        let key = PirKey::Decision {
357            var,
358            child_false,
359            child_true,
360        };
361        if let Some(&id) = self.intern.get(&key) {
362            return id;
363        }
364        let id = self.pir.decision(var, child_false, child_true);
365        self.intern.insert(key, id);
366        id
367    }
368
369    fn choice_lit(&mut self, var: ChoiceVarId, is_true: bool) -> PirNodeId {
370        if is_true {
371            self.decision(var, self.const_false(), self.const_true())
372        } else {
373            self.decision(var, self.const_true(), self.const_false())
374        }
375    }
376}
377
378/// Provenance extraction result: PIR graph plus per-tuple formulas and weight metadata.
379#[derive(Debug)]
380pub struct Provenance {
381    pub pir: PirGraph,
382    pub leaf_probs: BTreeMap<LeafId, f64>,
383    pub choice_probs: BTreeMap<ChoiceVarId, (f64, f64)>,
384    tuple_formulas: BTreeMap<GroundAtom, PirNodeId>,
385    pub queries: Vec<GroundAtom>,
386    pub evidence: Vec<(GroundAtom, bool)>,
387    pub leaf_atoms: BTreeMap<LeafId, GroundAtom>,
388    pub choice_sources: BTreeMap<ChoiceVarId, ChoiceSource>,
389    pub aggregate_lifting: Vec<AggregateLiftReport>,
390}
391
392impl Provenance {
393    pub fn query_formula(&self, predicate: &str, args: &[Value]) -> Option<PirNodeId> {
394        self.tuple_formulas
395            .get(&GroundAtom::new(predicate, args.to_vec()))
396            .copied()
397    }
398
399    pub fn leaf_atom(&self, leaf: LeafId) -> Option<&GroundAtom> {
400        self.leaf_atoms.get(&leaf)
401    }
402
403    pub fn choice_source(&self, var: ChoiceVarId) -> Option<&ChoiceSource> {
404        self.choice_sources.get(&var)
405    }
406
407    pub fn atoms_with_formulas(&self) -> impl Iterator<Item = (&GroundAtom, PirNodeId)> + '_ {
408        self.tuple_formulas.iter().map(|(atom, &id)| (atom, id))
409    }
410}
411
412pub fn extract_from_source(source: &str) -> Result<Provenance> {
413    let program = xlog_logic::parse_program(source)?;
414    extract_from_program(&program)
415}
416
/// Extract probabilistic provenance for every derived ground atom of `program`.
///
/// Steps, in order:
/// 1. `stratify` runs first so unsupported recursion patterns fail fast.
/// 2. Deterministic facts are inserted with formula `true`.
/// 3. Each probabilistic fact gets a fresh `LeafId` (allocated sequentially
///    from 0) whose probability and atom are recorded.
/// 4. Each annotated disjunction is lowered to a chain of Bernoulli choice
///    variables by `compile_annotated_disjunction`.
/// 5. Rules are evaluated SCC by SCC in the order given by
///    `find_sccs_for_lowering`: SCCs containing a predicate that
///    `analyze_stratification` places in a non-monotone SCC go through WFS,
///    other recursive SCCs through semi-naive iteration, the rest in one pass.
/// 6. Every tuple's formula is snapshotted, and queries/evidence are grounded.
///
/// # Errors
/// Returns `XlogError` from stratification, invalid probabilities, non-ground
/// fact/query/evidence atoms, id-counter overflow, empty annotated
/// disjunctions, or rule evaluation (including the recursive iteration limit).
pub fn extract_from_program(program: &Program) -> Result<Provenance> {
    // Stratify first to fail fast on unsupported recursion patterns.
    let _ = stratify(program)?;

    let mut builder = PirBuilder::new();

    let mut leaf_probs: BTreeMap<LeafId, f64> = BTreeMap::new();
    let mut choice_probs: BTreeMap<ChoiceVarId, (f64, f64)> = BTreeMap::new();
    let mut leaf_atoms: BTreeMap<LeafId, GroundAtom> = BTreeMap::new();
    let mut choice_sources: BTreeMap<ChoiceVarId, ChoiceSource> = BTreeMap::new();
    let mut aggregate_lifting: Vec<AggregateLiftReport> = Vec::new();

    // Predicate name -> relation (tuple -> formula) for everything derived so far.
    let mut store: BTreeMap<String, Relation> = BTreeMap::new();

    // Deterministic facts.
    for fact in program.facts() {
        let key = atom_key_from_ground_atom(&fact.head)?;
        let rel = store
            .entry(key.predicate.clone())
            .or_insert_with(Relation::new);
        rel.insert_or(key.args.clone(), builder.const_true(), &mut builder);
    }

    // Probabilistic facts: one leaf per fact; a fact repeated in the source is
    // OR-ed across its leaves by `insert_or`.
    let mut next_leaf: u32 = 0;
    for pf in &program.prob_facts {
        validate_prob(pf.prob, "probabilistic fact")?;
        let key = atom_key_from_ground_atom(&pf.atom)?;
        let leaf = LeafId::new(next_leaf);
        next_leaf = next_leaf.checked_add(1).ok_or_else(|| {
            XlogError::Compilation("probabilistic fact leaf id overflow".to_string())
        })?;
        leaf_probs.insert(leaf, pf.prob);
        leaf_atoms.insert(leaf, key.clone());

        let rel = store
            .entry(key.predicate.clone())
            .or_insert_with(Relation::new);
        rel.insert_or(key.args.clone(), builder.lit(leaf), &mut builder);
    }

    // Annotated disjunctions: lower to a chain of Bernoulli decisions.
    let mut next_choice: u32 = 0;
    for ad in &program.annotated_disjunctions {
        if ad.choices.is_empty() {
            return Err(XlogError::Compilation(
                "Annotated disjunction must contain at least one choice".to_string(),
            ));
        }
        let (vars, outcome_formulas) = compile_annotated_disjunction(
            ad,
            &mut next_choice,
            &mut choice_probs,
            &mut choice_sources,
            &mut builder,
        )?;
        // The variables are already recorded in `choice_probs`/`choice_sources`;
        // only the per-choice outcome formulas are needed here.
        let _ = vars;

        for (pf, formula) in ad.choices.iter().zip(outcome_formulas) {
            let key = atom_key_from_ground_atom(&pf.atom)?;
            let rel = store
                .entry(key.predicate.clone())
                .or_insert_with(Relation::new);
            rel.insert_or(key.args.clone(), formula, &mut builder);
        }
    }

    // Evaluate rules SCC-by-SCC (semi-naive for recursive SCCs).
    // Every predicate in the dependency graph gets a (possibly empty) relation.
    let graph = build_dependency_graph(program);
    for pred in &graph.predicates {
        store.entry(pred.clone()).or_insert_with(Relation::new);
    }

    // Use analyze_stratification to detect non-monotone SCCs
    let strat_result = analyze_stratification(program);
    let sccs = find_sccs_for_lowering(&graph);

    // Collect the predicates of every SCC that `analyze_stratification` flags as
    // non-monotone. The per-SCC check below looks predicates up by name, so it
    // does not depend on `analyze_stratification` and `find_sccs_for_lowering`
    // numbering their SCCs the same way.
    let non_monotone_scc_preds: std::collections::HashSet<String> = strat_result
        .sccs
        .iter()
        .enumerate()
        .filter(|(i, _)| strat_result.non_monotone_sccs.contains(i))
        .flat_map(|(_, scc)| scc.iter().cloned())
        .collect();

    let mut rules_by_head: BTreeMap<String, Vec<Rule>> = BTreeMap::new();
    for rule in program.proper_rules() {
        // Note: Negation is now supported via stratified evaluation and negate_provenance()
        rules_by_head
            .entry(rule.head.predicate.clone())
            .or_default()
            .push(rule.clone());
    }

    for scc in sccs {
        let mut scc_rules: Vec<Rule> = Vec::new();
        for pred in &scc {
            if let Some(rules) = rules_by_head.get(pred) {
                scc_rules.extend(rules.iter().cloned());
            }
        }
        // SCCs with no rules (e.g. fact-only predicates) have nothing to derive.
        if scc_rules.is_empty() {
            continue;
        }

        // Check if any predicate in this SCC is in a non-monotone cycle
        let is_non_monotone = scc.iter().any(|p| non_monotone_scc_preds.contains(p));

        if is_non_monotone {
            // Use WFS for non-monotone SCCs (cycles through negation)
            eval_non_monotone_scc_with_wfs(&scc, &scc_rules, &mut store, &mut builder)?;
        } else {
            let recursive = is_recursive_scc(&scc, &scc_rules);
            if recursive {
                eval_recursive_scc(
                    &scc,
                    &scc_rules,
                    &mut store,
                    &mut builder,
                    &mut aggregate_lifting,
                )?;
            } else {
                eval_non_recursive_scc(
                    &scc_rules,
                    &mut store,
                    &mut builder,
                    &mut aggregate_lifting,
                )?;
            }
        }
    }

    // Snapshot tuple formulas.
    let mut tuple_formulas: BTreeMap<GroundAtom, PirNodeId> = BTreeMap::new();
    for (pred, rel) in &store {
        for (tuple, formula) in &rel.tuples {
            tuple_formulas.insert(GroundAtom::new(pred.clone(), tuple.clone()), *formula);
        }
    }

    let mut queries: Vec<GroundAtom> = Vec::new();
    for ProbQuery { atom } in &program.prob_queries {
        queries.push(atom_key_from_ground_atom(atom)?);
    }

    let mut evidence: Vec<(GroundAtom, bool)> = Vec::new();
    for Evidence { atom, value } in &program.evidence {
        evidence.push((atom_key_from_ground_atom(atom)?, *value));
    }

    Ok(Provenance {
        pir: builder.finish(),
        leaf_probs,
        choice_probs,
        tuple_formulas,
        queries,
        evidence,
        leaf_atoms,
        choice_sources,
        aggregate_lifting,
    })
}
582
583pub(crate) fn validate_prob(p: f64, what: &str) -> Result<()> {
584    if !(0.0..=1.0).contains(&p) || p.is_nan() {
585        return Err(XlogError::Compilation(format!(
586            "Invalid probability {} for {} (expected 0<=p<=1)",
587            p, what
588        )));
589    }
590    Ok(())
591}
592
593pub(crate) fn atom_key_from_ground_atom(atom: &Atom) -> Result<GroundAtom> {
594    let mut args = Vec::with_capacity(atom.terms.len());
595    for term in &atom.terms {
596        if !term.is_constant() {
597            return Err(XlogError::Compilation(format!(
598                "Expected ground atom, found non-constant term in {}",
599                atom.predicate
600            )));
601        }
602        args.push(value_from_term(term)?);
603    }
604    Ok(GroundAtom::new(atom.predicate.clone(), args))
605}
606
607pub(crate) fn value_from_term(term: &Term) -> Result<Value> {
608    match term {
609        Term::Integer(i) => Ok(Value::I64(*i)),
610        Term::Float(f) => Ok(Value::F64(f.to_bits())),
611        Term::String(s) => Ok(Value::String(s.clone())),
612        Term::Symbol(id) => Ok(Value::Symbol(*id)),
613        Term::Variable(_) | Term::Anonymous | Term::Aggregate(_) => Err(XlogError::Compilation(
614            "Non-constant term cannot be converted to a value".to_string(),
615        )),
616        Term::List(_) => Err(unsupported_probabilistic_term_error(
617            "value conversion",
618            "list",
619        )),
620        Term::Cons { .. } => Err(unsupported_probabilistic_term_error(
621            "value conversion",
622            "cons",
623        )),
624        Term::Compound { .. } => Err(unsupported_probabilistic_term_error(
625            "value conversion",
626            "compound",
627        )),
628        Term::PredRef(_) => Err(unsupported_probabilistic_term_error(
629            "value conversion",
630            "predref",
631        )),
632    }
633}
634
635fn unsupported_probabilistic_term_error(context: &str, kind: &str) -> XlogError {
636    XlogError::Compilation(format!(
637        "high-level term form '{}' is parsed but not supported in probabilistic provenance {} until a lowering/materialization path exists",
638        kind, context
639    ))
640}
641
642fn compile_annotated_disjunction(
643    ad: &xlog_logic::ast::AnnotatedDisjunction,
644    next_choice: &mut u32,
645    choice_probs: &mut BTreeMap<ChoiceVarId, (f64, f64)>,
646    choice_sources: &mut BTreeMap<ChoiceVarId, ChoiceSource>,
647    builder: &mut PirBuilder,
648) -> Result<(Vec<ChoiceVarId>, Vec<PirNodeId>)> {
649    for pf in &ad.choices {
650        validate_prob(pf.prob, "annotated disjunction choice")?;
651        let _ = atom_key_from_ground_atom(&pf.atom)?;
652    }
653
654    let explicit_choices: Vec<(GroundAtom, f64)> = ad
655        .choices
656        .iter()
657        .map(|pf| {
658            let atom = atom_key_from_ground_atom(&pf.atom).unwrap();
659            (atom, pf.prob)
660        })
661        .collect();
662
663    let mut probs: Vec<f64> = ad.choices.iter().map(|pf| pf.prob).collect();
664    let sum: f64 = probs.iter().copied().sum();
665    let eps = 1e-12;
666    if sum > 1.0 + eps {
667        return Err(XlogError::Compilation(format!(
668            "Annotated disjunction probabilities sum to {} (> 1.0)",
669            sum
670        )));
671    }
672
673    let mut has_none = false;
674    let none_prob = (1.0 - sum).max(0.0);
675    if none_prob > eps {
676        probs.push(none_prob);
677        has_none = true;
678    }
679
680    let m = probs.len();
681    if m == 1 {
682        return Ok((Vec::new(), vec![builder.const_true()]));
683    }
684
685    let mut vars: Vec<ChoiceVarId> = Vec::with_capacity(m.saturating_sub(1));
686    let mut remaining = 1.0f64;
687    for (i, &p_i) in probs.iter().enumerate().take(m - 1) {
688        let cond_true = if remaining <= 0.0 {
689            0.0
690        } else {
691            p_i / remaining
692        };
693        validate_prob(cond_true, "annotated disjunction conditional")?;
694        let cond_false = 1.0 - cond_true;
695        let var = ChoiceVarId::new(*next_choice);
696        *next_choice = (*next_choice).checked_add(1).ok_or_else(|| {
697            XlogError::Compilation("annotated disjunction choice id overflow".to_string())
698        })?;
699        vars.push(var);
700        choice_probs.insert(var, (cond_true, cond_false));
701        choice_sources.insert(
702            var,
703            ChoiceSource {
704                choices: explicit_choices.clone(),
705                choice_index: i,
706                source_id: None,
707            },
708        );
709        remaining -= p_i;
710    }
711
712    let mut outcome_formulas: Vec<PirNodeId> = Vec::new();
713    for i in 0..ad.choices.len() {
714        let mut conds: Vec<PirNodeId> = Vec::new();
715        for (j, &var) in vars.iter().enumerate() {
716            if j < i {
717                conds.push(builder.choice_lit(var, false));
718            } else if j == i {
719                conds.push(builder.choice_lit(var, true));
720                break;
721            }
722        }
723        outcome_formulas.push(builder.and(conds));
724    }
725
726    if has_none {
727        // None branch consumes the final remaining probability; it produces no fact.
728        // We still need the decision variables so probabilities normalize.
729    }
730
731    Ok((vars, outcome_formulas))
732}
733
734fn is_recursive_scc(scc: &[String], rules: &[Rule]) -> bool {
735    if scc.len() > 1 {
736        return true;
737    }
738    let Some(only) = scc.first() else {
739        return false;
740    };
741    for rule in rules {
742        for lit in &rule.body {
743            if let BodyLiteral::Positive(atom) = lit {
744                if &atom.predicate == only {
745                    return true;
746                }
747            }
748        }
749    }
750    false
751}
752
753fn eval_non_recursive_scc(
754    rules: &[Rule],
755    store: &mut BTreeMap<String, Relation>,
756    builder: &mut PirBuilder,
757    aggregate_lifting: &mut Vec<AggregateLiftReport>,
758) -> Result<()> {
759    for rule in rules {
760        let derived = eval_rule(
761            rule,
762            store,
763            &BTreeMap::new(),
764            None,
765            builder,
766            aggregate_lifting,
767        )?;
768        let rel = store
769            .entry(rule.head.predicate.clone())
770            .or_insert_with(Relation::new);
771        for (tuple, formula) in derived {
772            rel.insert_or(tuple, formula, builder);
773        }
774    }
775    Ok(())
776}
777
778const MAX_PROVENANCE_ITERATIONS: usize = 1024;
779
780fn eval_recursive_scc(
781    scc: &[String],
782    rules: &[Rule],
783    store: &mut BTreeMap<String, Relation>,
784    builder: &mut PirBuilder,
785    aggregate_lifting: &mut Vec<AggregateLiftReport>,
786) -> Result<()> {
787    let scc_set: std::collections::HashSet<&str> = scc.iter().map(|s| s.as_str()).collect();
788
789    // Snapshot full relations for the SCC.
790    let mut full: BTreeMap<String, Relation> = BTreeMap::new();
791    for pred in scc {
792        let rel = store.get(pred).cloned().unwrap_or_else(Relation::new);
793        full.insert(pred.clone(), rel);
794    }
795
796    // Seed: evaluate all rules once against the current full snapshot.
797    let mut delta: BTreeMap<String, Relation> = BTreeMap::new();
798    for rule in rules {
799        let derived = eval_rule(rule, store, &full, None, builder, aggregate_lifting)?;
800        if derived.is_empty() {
801            continue;
802        }
803        let head = rule.head.predicate.clone();
804        let delta_rel = delta.entry(head.clone()).or_insert_with(Relation::new);
805        let full_rel = full.entry(head).or_insert_with(Relation::new);
806        for (tuple, proof) in derived {
807            let old = full_rel.get(&tuple).unwrap_or(builder.const_false());
808            let combined = builder.or(vec![old, proof]);
809            if combined != old {
810                full_rel.tuples.insert(tuple.clone(), combined);
811                delta_rel.insert_or(tuple, proof, builder);
812            }
813        }
814    }
815
816    let mut reached_fixpoint = false;
817    for _ in 0..MAX_PROVENANCE_ITERATIONS {
818        let any_delta = delta.values().any(|r| !r.is_empty());
819        if !any_delta {
820            reached_fixpoint = true;
821            break;
822        }
823
824        let full_prev = full.clone();
825        let delta_prev = delta.clone();
826        delta.clear();
827
828        for rule in rules {
829            let body_indices: Vec<usize> = rule
830                .body
831                .iter()
832                .enumerate()
833                .filter_map(|(i, lit)| match lit {
834                    BodyLiteral::Positive(atom) if scc_set.contains(atom.predicate.as_str()) => {
835                        let pred = &atom.predicate;
836                        let non_empty =
837                            delta_prev.get(pred).map(|r| !r.is_empty()).unwrap_or(false);
838                        non_empty.then_some(i)
839                    }
840                    _ => None,
841                })
842                .collect();
843            if body_indices.is_empty() {
844                continue;
845            }
846
847            let mut derived_all: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
848            for idx in body_indices {
849                let derived = eval_rule(
850                    rule,
851                    store,
852                    &full_prev,
853                    Some((idx, &delta_prev)),
854                    builder,
855                    aggregate_lifting,
856                )?;
857                for (tuple, proof) in derived {
858                    let entry = derived_all
859                        .entry(tuple)
860                        .or_insert_with(|| builder.const_false());
861                    *entry = builder.or(vec![*entry, proof]);
862                }
863            }
864
865            if derived_all.is_empty() {
866                continue;
867            }
868
869            let head = rule.head.predicate.clone();
870            let delta_rel = delta.entry(head.clone()).or_insert_with(Relation::new);
871            let full_rel = full.entry(head).or_insert_with(Relation::new);
872            for (tuple, proof) in derived_all {
873                let old = full_rel.get(&tuple).unwrap_or(builder.const_false());
874                let combined = builder.or(vec![old, proof]);
875                if combined != old {
876                    full_rel.tuples.insert(tuple.clone(), combined);
877                    delta_rel.insert_or(tuple, proof, builder);
878                }
879            }
880        }
881    }
882    if !reached_fixpoint {
883        return Err(XlogError::Compilation(format!(
884            "Provenance iteration limit ({}) exceeded for SCC {:?}",
885            MAX_PROVENANCE_ITERATIONS, scc
886        )));
887    }
888
889    // Write back SCC relations.
890    for (pred, rel) in full {
891        store.insert(pred, rel);
892    }
893
894    Ok(())
895}
896
897/// Evaluate a non-monotone SCC using Well-Founded Semantics.
898///
899/// This function handles SCCs that have cycles through negation. It:
900/// 1. Grounds the rules by enumerating all variable bindings from existing tuples
901/// 2. Converts ground rules to WFS rules
902/// 3. Calls WFS to compute the well-founded model
903/// 4. Stores the results (true atoms with provenance) back
904///
905/// Undefined atoms (those in a true cycle) get no provenance (probability 0).
906fn eval_non_monotone_scc_with_wfs(
907    scc: &[String],
908    rules: &[Rule],
909    store: &mut BTreeMap<String, Relation>,
910    builder: &mut PirBuilder,
911) -> Result<()> {
912    let scc_set: std::collections::HashSet<&str> = scc.iter().map(|s| s.as_str()).collect();
913
914    // Step 1: Ground all rules in the SCC
915    // We enumerate all possible variable bindings by iterating over existing tuples
916    let mut wfs_rules: Vec<WfsRule> = Vec::new();
917
918    for rule in rules {
919        // Ground this rule against the current store
920        let grounded = ground_rule_for_wfs(rule, store, &scc_set, builder)?;
921        wfs_rules.extend(grounded);
922    }
923
924    if wfs_rules.is_empty() {
925        // No ground rules, nothing to do
926        return Ok(());
927    }
928
929    // Step 2: Call WFS to compute the well-founded model
930    let wfs_result = evaluate_wfs_rules(&wfs_rules, &mut builder.pir, &WfsConfig::default())?;
931
932    // Step 3: Store the results back
933    // True atoms get their provenance, false/undefined atoms are not added
934    for (wfs_atom, prov) in wfs_result.true_set {
935        let rel = store
936            .entry(wfs_atom.predicate.clone())
937            .or_insert_with(Relation::new);
938        rel.insert_or(wfs_atom.args, prov, builder);
939    }
940
941    Ok(())
942}
943
944/// Ground a rule for WFS evaluation.
945///
946/// This generates all ground instances of a rule by iterating over existing tuples
947/// that match the body literals (excluding SCC predicates which are handled by WFS).
948fn ground_rule_for_wfs(
949    rule: &Rule,
950    store: &BTreeMap<String, Relation>,
951    scc_set: &std::collections::HashSet<&str>,
952    builder: &mut PirBuilder,
953) -> Result<Vec<WfsRule>> {
954    // Start with empty binding
955    let mut bindings: Vec<(HashMap<String, Value>, PirNodeId)> =
956        vec![(HashMap::new(), builder.const_true())];
957
958    // Collect body literals that are in the SCC (will become WFS body literals)
959    // and non-SCC literals (will be grounded now)
960    let mut wfs_body_template: Vec<(usize, bool)> = Vec::new(); // (body_index, is_positive)
961
962    for (idx, lit) in rule.body.iter().enumerate() {
963        match lit {
964            BodyLiteral::Positive(atom) => {
965                if scc_set.contains(atom.predicate.as_str()) {
966                    // This will become a WFS body literal
967                    wfs_body_template.push((idx, true));
968                } else {
969                    // Ground now by iterating over existing tuples
970                    let rel = store.get(&atom.predicate);
971                    let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
972
973                    for (binding, prov) in bindings {
974                        if let Some(rel) = rel {
975                            for (tuple, tuple_prov) in &rel.tuples {
976                                let mut new_binding = binding.clone();
977                                if unify_atom(atom, tuple, &mut new_binding)? {
978                                    let new_prov = builder.and(vec![prov, *tuple_prov]);
979                                    next_bindings.push((new_binding, new_prov));
980                                }
981                            }
982                        }
983                        // If relation doesn't exist, no tuples match
984                    }
985                    bindings = next_bindings;
986                    if bindings.is_empty() {
987                        return Ok(Vec::new());
988                    }
989                }
990            }
991            BodyLiteral::Negated(atom) => {
992                if scc_set.contains(atom.predicate.as_str()) {
993                    // This will become a WFS negative body literal
994                    wfs_body_template.push((idx, false));
995                } else {
996                    // Ground now: negation of non-SCC predicate
997                    let rel = store.get(&atom.predicate);
998                    let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
999
1000                    for (binding, prov) in bindings {
1001                        // Check if all variables in the negated atom are bound
1002                        let all_bound = atom.terms.iter().all(|t| match t {
1003                            Term::Variable(v) => binding.contains_key(v),
1004                            _ => true,
1005                        });
1006
1007                        if !all_bound {
1008                            // Skip unsafe negation
1009                            continue;
1010                        }
1011
1012                        if let Some(rel) = rel {
1013                            // Collect matching tuples
1014                            let mut matching_provs: Vec<PirNodeId> = Vec::new();
1015                            for (tuple, tuple_prov) in &rel.tuples {
1016                                let mut test_binding = binding.clone();
1017                                if unify_atom(atom, tuple, &mut test_binding)? {
1018                                    matching_provs.push(*tuple_prov);
1019                                }
1020                            }
1021
1022                            if matching_provs.is_empty() {
1023                                // No matches - closed world: negation succeeds
1024                                next_bindings.push((binding, prov));
1025                            } else {
1026                                // Negate the combined provenance
1027                                let combined = builder.or(matching_provs);
1028                                let neg_prov = negate_provenance(combined, builder);
1029                                let new_prov = builder.and(vec![prov, neg_prov]);
1030                                next_bindings.push((binding, new_prov));
1031                            }
1032                        } else {
1033                            // Relation doesn't exist - closed world: negation succeeds
1034                            next_bindings.push((binding, prov));
1035                        }
1036                    }
1037                    bindings = next_bindings;
1038                    if bindings.is_empty() {
1039                        return Ok(Vec::new());
1040                    }
1041                }
1042            }
1043            BodyLiteral::Epistemic(lit) => {
1044                return Err(XlogError::UnsupportedEpistemicConstruct {
1045                    construct: "probabilistic WFS grounding".to_string(),
1046                    context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
1047                });
1048            }
1049            BodyLiteral::Comparison(cmp) => {
1050                let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
1051                for (binding, prov) in bindings {
1052                    if eval_comparison(cmp.op, &cmp.left, &cmp.right, &binding)? {
1053                        next_bindings.push((binding, prov));
1054                    }
1055                }
1056                bindings = next_bindings;
1057                if bindings.is_empty() {
1058                    return Ok(Vec::new());
1059                }
1060            }
1061            BodyLiteral::IsExpr(is_expr) => {
1062                let mut next_bindings: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
1063                for (mut binding, prov) in bindings {
1064                    if binding.contains_key(&is_expr.target) {
1065                        return Err(XlogError::Compilation(format!(
1066                            "Is-expression target {} is already bound",
1067                            is_expr.target
1068                        )));
1069                    }
1070                    let v = eval_arith_expr(&is_expr.expr, &binding)?;
1071                    binding.insert(is_expr.target.clone(), v);
1072                    next_bindings.push((binding, prov));
1073                }
1074                bindings = next_bindings;
1075                if bindings.is_empty() {
1076                    return Ok(Vec::new());
1077                }
1078            }
1079            BodyLiteral::Univ(_) => {
1080                return Err(XlogError::Compilation(
1081                    "univ literal was not normalized before provenance extraction".to_string(),
1082                ));
1083            }
1084        }
1085    }
1086
1087    // Now create WFS rules for each binding
1088    let mut result: Vec<WfsRule> = Vec::new();
1089
1090    for (binding, external_prov) in bindings {
1091        // Build the WFS body from SCC literals
1092        let mut wfs_body: Vec<WfsLiteral> = Vec::new();
1093
1094        for &(idx, is_positive) in &wfs_body_template {
1095            let atom = match &rule.body[idx] {
1096                BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a,
1097                _ => continue,
1098            };
1099
1100            // Ground the atom with the current binding
1101            let mut args: Vec<Value> = Vec::new();
1102            for term in &atom.terms {
1103                match term {
1104                    Term::Variable(name) => {
1105                        if let Some(v) = binding.get(name) {
1106                            args.push(v.clone());
1107                        } else {
1108                            // Variable not bound - this shouldn't happen for well-formed rules
1109                            // Skip this ground instance
1110                            continue;
1111                        }
1112                    }
1113                    _ => {
1114                        args.push(value_from_term(term)?);
1115                    }
1116                }
1117            }
1118
1119            let wfs_atom = WfsAtom::new(atom.predicate.clone(), args);
1120            if is_positive {
1121                wfs_body.push(WfsLiteral::Positive(wfs_atom));
1122            } else {
1123                wfs_body.push(WfsLiteral::Negative(wfs_atom));
1124            }
1125        }
1126
1127        // Build the ground head
1128        let mut head_args: Vec<Value> = Vec::new();
1129        for term in &rule.head.terms {
1130            match term {
1131                Term::Variable(name) => {
1132                    if let Some(v) = binding.get(name) {
1133                        head_args.push(v.clone());
1134                    } else {
1135                        // Unbound head variable - skip this instance
1136                        continue;
1137                    }
1138                }
1139                _ => {
1140                    head_args.push(value_from_term(term)?);
1141                }
1142            }
1143        }
1144
1145        let wfs_head = WfsAtom::new(rule.head.predicate.clone(), head_args);
1146        result.push(WfsRule::new(wfs_head, wfs_body, external_prov));
1147    }
1148
1149    Ok(result)
1150}
1151
1152/// Negate a provenance formula, pushing negation to leaves (NNF form).
1153///
1154/// This implements the logical negation of a provenance formula by applying De Morgan's laws
1155/// to push negations down to the leaves. At the leaf level:
1156/// - `Lit { leaf }` becomes `NegLit { leaf }` (negated probabilistic fact)
1157/// - `NegLit { leaf }` becomes `Lit { leaf }` (double negation elimination)
1158/// - `Const(true)` becomes `Const(false)` and vice versa
1159fn negate_provenance(prov: PirNodeId, builder: &mut PirBuilder) -> PirNodeId {
1160    use crate::pir::PirNode;
1161    match builder.pir.node(prov).cloned() {
1162        Some(PirNode::Const(b)) => {
1163            if b {
1164                builder.const_false()
1165            } else {
1166                builder.const_true()
1167            }
1168        }
1169        Some(PirNode::Lit { leaf }) => builder.neg_lit(leaf),
1170        Some(PirNode::NegLit { leaf }) => builder.lit(leaf), // Double negation elimination
1171        Some(PirNode::And { children }) => {
1172            // De Morgan: not(A and B) = (not A) or (not B)
1173            let neg_children: Vec<PirNodeId> = children
1174                .iter()
1175                .map(|&c| negate_provenance(c, builder))
1176                .collect();
1177            builder.or(neg_children)
1178        }
1179        Some(PirNode::Or { children }) => {
1180            // De Morgan: not(A or B) = (not A) and (not B)
1181            let neg_children: Vec<PirNodeId> = children
1182                .iter()
1183                .map(|&c| negate_provenance(c, builder))
1184                .collect();
1185            builder.and(neg_children)
1186        }
1187        Some(PirNode::Decision {
1188            var,
1189            child_false,
1190            child_true,
1191        }) => {
1192            // Negate both branches
1193            let neg_false = negate_provenance(child_false, builder);
1194            let neg_true = negate_provenance(child_true, builder);
1195            builder.decision(var, neg_false, neg_true)
1196        }
1197        None => prov,
1198    }
1199}
1200
1201/// Evaluate a single rule and produce a map from head tuples to proof formulas.
1202///
1203/// `full_scc` is the per-SCC snapshot for recursive predicates; `delta_scc` is optional and
1204/// provides a delta relation for a specific body literal index.
1205fn eval_rule(
1206    rule: &Rule,
1207    global: &BTreeMap<String, Relation>,
1208    full_scc: &BTreeMap<String, Relation>,
1209    delta_scc: Option<(usize, &BTreeMap<String, Relation>)>,
1210    builder: &mut PirBuilder,
1211    aggregate_lifting: &mut Vec<AggregateLiftReport>,
1212) -> Result<BTreeMap<Vec<Value>, PirNodeId>> {
1213    let mut states: Vec<(HashMap<String, Value>, PirNodeId)> =
1214        vec![(HashMap::new(), builder.const_true())];
1215
1216    for (idx, lit) in rule.body.iter().enumerate() {
1217        let mut next_states: Vec<(HashMap<String, Value>, PirNodeId)> = Vec::new();
1218        match lit {
1219            BodyLiteral::Positive(atom) => {
1220                let rel = select_relation(atom, idx, global, full_scc, delta_scc)?;
1221                for (binding, prov) in states {
1222                    for (tuple, tuple_prov) in &rel.tuples {
1223                        let mut binding2 = binding.clone();
1224                        if unify_atom(atom, tuple, &mut binding2)? {
1225                            let prov2 = builder.and(vec![prov, *tuple_prov]);
1226                            next_states.push((binding2, prov2));
1227                        }
1228                    }
1229                }
1230            }
1231            BodyLiteral::Comparison(cmp) => {
1232                for (binding, prov) in states {
1233                    if eval_comparison(cmp.op, &cmp.left, &cmp.right, &binding)? {
1234                        next_states.push((binding, prov));
1235                    }
1236                }
1237            }
1238            BodyLiteral::IsExpr(is_expr) => {
1239                for (mut binding, prov) in states {
1240                    if binding.contains_key(&is_expr.target) {
1241                        return Err(XlogError::Compilation(format!(
1242                            "Is-expression target {} is already bound",
1243                            is_expr.target
1244                        )));
1245                    }
1246                    let v = eval_arith_expr(&is_expr.expr, &binding)?;
1247                    binding.insert(is_expr.target.clone(), v);
1248                    next_states.push((binding, prov));
1249                }
1250            }
1251            BodyLiteral::Negated(atom) => {
1252                // Stratified negation: for each binding, check if any matching tuple exists.
1253                // - If a matching tuple exists with provenance P, the negation has provenance "not P"
1254                // - If no matching tuple exists, the negation succeeds trivially (closed-world assumption)
1255                //
1256                // For negated literals, we only use the global store and full_scc snapshot,
1257                // never the delta (negation is evaluated against the complete relation).
1258                let rel = if let Some(r) = full_scc.get(&atom.predicate) {
1259                    r
1260                } else if let Some(r) = global.get(&atom.predicate) {
1261                    r
1262                } else {
1263                    // Predicate not found - closed world assumption: all negations succeed
1264                    for (binding, prov) in states {
1265                        // Ensure all variables in the negated atom are bound
1266                        let all_bound = atom.terms.iter().all(|t| match t {
1267                            Term::Variable(v) => binding.contains_key(v),
1268                            _ => true,
1269                        });
1270                        if all_bound {
1271                            next_states.push((binding, prov));
1272                        }
1273                    }
1274                    states = next_states;
1275                    if states.is_empty() {
1276                        break;
1277                    }
1278                    continue;
1279                };
1280
1281                for (binding, prov) in states {
1282                    // First, check if all variables in the negated atom are bound.
1283                    // Negation requires all variables to be bound (safety condition).
1284                    let all_bound = atom.terms.iter().all(|t| match t {
1285                        Term::Variable(v) => binding.contains_key(v),
1286                        _ => true,
1287                    });
1288                    if !all_bound {
1289                        // Skip this binding - variables must be bound before negation
1290                        continue;
1291                    }
1292
1293                    // Collect matching tuples and their provenances
1294                    let mut matching_provs: Vec<PirNodeId> = Vec::new();
1295                    for (tuple, tuple_prov) in &rel.tuples {
1296                        let mut binding2 = binding.clone();
1297                        if unify_atom(atom, tuple, &mut binding2)? {
1298                            // A match was found; we need its negated provenance
1299                            matching_provs.push(*tuple_prov);
1300                        }
1301                    }
1302
1303                    if matching_provs.is_empty() {
1304                        // No matching tuples - closed world assumption: negation succeeds trivially
1305                        next_states.push((binding, prov));
1306                    } else {
1307                        // For negation to succeed, ALL matching tuples must be "absent" (negated).
1308                        // If tuple can exist via multiple provenances (disjunction), we negate that.
1309                        // Negation of (proof_a or proof_b or ...) =
1310                        // (not proof_a) and (not proof_b) and ...
1311                        let combined_tuple_prov = builder.or(matching_provs);
1312                        let neg_prov = negate_provenance(combined_tuple_prov, builder);
1313                        let new_prov = builder.and(vec![prov, neg_prov]);
1314                        next_states.push((binding, new_prov));
1315                    }
1316                }
1317            }
1318            BodyLiteral::Epistemic(lit) => {
1319                return Err(XlogError::UnsupportedEpistemicConstruct {
1320                    construct: "probabilistic provenance evaluation".to_string(),
1321                    context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
1322                });
1323            }
1324            BodyLiteral::Univ(_) => {
1325                return Err(XlogError::Compilation(
1326                    "univ literal was not normalized before provenance extraction".to_string(),
1327                ));
1328            }
1329        }
1330        states = next_states;
1331        if states.is_empty() {
1332            break;
1333        }
1334    }
1335
1336    if rule.has_aggregation() {
1337        eval_aggregate_head_provenance(&rule.head, states, builder, aggregate_lifting)
1338    } else {
1339        let mut out: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
1340        for (binding, prov) in states {
1341            let head_tuple = materialize_head(&rule.head, &binding)?;
1342            let entry = out
1343                .entry(head_tuple)
1344                .or_insert_with(|| builder.const_false());
1345            *entry = builder.or(vec![*entry, prov]);
1346        }
1347        Ok(out)
1348    }
1349}
1350
/// Maximum number of uncertain (non-constant-proof) rows per group for exact lifting of
/// non-count aggregates; `eval_aggregate_head_provenance` rejects larger groups with an error.
const MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS: usize = 16;
/// Maximum number of uncertain rows per group on the count-only lifting path; larger groups
/// are rejected with an error.
const MAX_EXACT_PROB_COUNT_LIFT_ROWS: usize = 64;

/// One deduplicated body derivation that feeds an aggregate head.
#[derive(Debug, Clone)]
struct AggregateProvRow {
    /// Variable binding of the derivation.
    binding: HashMap<String, Value>,
    /// Proof formula: the OR of the proofs of all derivations sharing `binding`.
    prov: PirNodeId,
}
1359
/// Build head tuples and proof formulas for a rule whose head contains aggregates.
///
/// `states` are body derivations (binding plus proof), e.g. as produced by `eval_rule`.
/// Derivations with identical bindings are merged (proofs OR-ed) and then grouped by the
/// values of the head's group-key variables. Within a group, rows whose proof is the
/// constant true always contribute, constant-false rows are dropped, and the remaining rows
/// are "uncertain":
///
/// - Count-only heads enumerate every possible count with `count_lift_formulas`, for at
///   most `MAX_EXACT_PROB_COUNT_LIFT_ROWS` uncertain rows per group.
/// - Other aggregates fold outcomes with `factorized_aggregate_outcomes`, for at most
///   `MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS` uncertain rows per group.
///
/// Each resulting head tuple maps to the OR of the proofs of all outcomes producing it.
/// Outcomes in which the group is empty materialize no tuple. A lift report is recorded into
/// `aggregate_lifting` (via `record_aggregate_lift_reports`) for every group that is lifted.
///
/// # Errors
///
/// Fails if:
/// - a group-key variable is unbound in some derivation (`UnsafeVariable`);
/// - a count-only group has a row missing an aggregated variable;
/// - a group exceeds the relevant uncertain-row cap;
/// - a helper (head planning, aggregate folding, tuple materialization) fails.
fn eval_aggregate_head_provenance(
    head: &Atom,
    states: Vec<(HashMap<String, Value>, PirNodeId)>,
    builder: &mut PirBuilder,
    aggregate_lifting: &mut Vec<AggregateLiftReport>,
) -> Result<BTreeMap<Vec<Value>, PirNodeId>> {
    let (key_vars, key_var_to_pos, agg_specs, agg_to_pos) = aggregate_head_plan(head)?;

    // Merge derivations that bind every variable identically: such a row holds if any of its
    // derivations holds, so their proofs are OR-ed together.
    let mut deduped_states: BTreeMap<Vec<(String, Value)>, AggregateProvRow> = BTreeMap::new();
    for (binding, prov) in states {
        let key = canonical_binding_key(&binding);
        match deduped_states.get_mut(&key) {
            Some(row) => {
                row.prov = builder.or(vec![row.prov, prov]);
            }
            None => {
                deduped_states.insert(key, AggregateProvRow { binding, prov });
            }
        }
    }

    // Rows sharing the same group-key values (in `key_vars` order).
    #[derive(Debug)]
    struct GroupRows {
        key: Vec<Value>,
        rows: Vec<AggregateProvRow>,
    }

    let mut groups: BTreeMap<Vec<Value>, GroupRows> = BTreeMap::new();
    for row in deduped_states.into_values() {
        let mut key: Vec<Value> = Vec::with_capacity(key_vars.len());
        for name in &key_vars {
            let v = row
                .binding
                .get(name)
                .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
            key.push(v.clone());
        }
        groups
            .entry(key.clone())
            .or_insert_with(|| GroupRows {
                key,
                rows: Vec::new(),
            })
            .rows
            .push(row);
    }

    let mut out: BTreeMap<Vec<Value>, PirNodeId> = BTreeMap::new();
    let count_only = agg_specs.iter().all(|(op, _)| *op == AggOp::Count);
    for group in groups.into_values() {
        // Split rows by whether their proof is a known constant.
        let mut always_rows: Vec<AggregateProvRow> = Vec::new();
        let mut uncertain_rows: Vec<AggregateProvRow> = Vec::new();
        for row in group.rows {
            match pir_const_value(builder, row.prov) {
                Some(true) => always_rows.push(row),
                Some(false) => {}
                None => uncertain_rows.push(row),
            }
        }

        // Every row is constant-false: the group never exists.
        if always_rows.is_empty() && uncertain_rows.is_empty() {
            continue;
        }
        if count_only {
            if uncertain_rows.len() > MAX_EXACT_PROB_COUNT_LIFT_ROWS {
                return Err(XlogError::Compilation(format!(
                    "count aggregate lifting finite domain cap exceeded for predicate {} group {:?}: {} uncertain rows > cap {}; use prob_engine = mc or reduce the finite aggregate domain",
                    head.predicate,
                    group.key,
                    uncertain_rows.len(),
                    MAX_EXACT_PROB_COUNT_LIFT_ROWS
                )));
            }
            validate_count_lift_rows(&agg_specs, &always_rows, &uncertain_rows)?;
            record_aggregate_lift_reports(
                aggregate_lifting,
                head,
                &group.key,
                &agg_specs,
                always_rows.len(),
                uncertain_rows.len(),
                AggregateLiftStatus::Fired,
                "finite count domain lifted with exact cardinality dynamic programming",
                MAX_EXACT_PROB_COUNT_LIFT_ROWS,
                count_lift_dp_states(uncertain_rows.len()),
            );
            // count_formulas[k] is the proof that exactly k uncertain rows are present.
            let count_formulas = count_lift_formulas(&uncertain_rows, builder);
            for (selected_uncertain_rows, proof) in count_formulas.into_iter().enumerate() {
                // Zero rows in total: the group is empty, so no head tuple materializes.
                if always_rows.is_empty() && selected_uncertain_rows == 0 {
                    continue;
                }
                let count_value = always_rows.len() + selected_uncertain_rows;
                let tuple =
                    materialize_count_lift_tuple(head, &group.key, &key_var_to_pos, count_value)?;
                let entry = out.entry(tuple).or_insert_with(|| builder.const_false());
                *entry = builder.or(vec![*entry, proof]);
            }
            continue;
        }

        if uncertain_rows.len() > MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS {
            return Err(XlogError::Compilation(format!(
                "exact probabilistic aggregate domain cap exceeded for predicate {} group {:?}: {} uncertain rows > cap {}; use prob_engine = mc or reduce the finite aggregate domain",
                head.predicate,
                group.key,
                uncertain_rows.len(),
                MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS
            )));
        }
        let (outcomes, dp_states) =
            factorized_aggregate_outcomes(&agg_specs, &always_rows, &uncertain_rows, builder)?;
        record_aggregate_lift_reports(
            aggregate_lifting,
            head,
            &group.key,
            &agg_specs,
            always_rows.len(),
            uncertain_rows.len(),
            AggregateLiftStatus::Fired,
            "finite outcome domain folded with factorized aggregate-state dynamic programming",
            MAX_EXACT_PROB_AGG_UNCERTAIN_ROWS,
            dp_states,
        );

        for (agg_states, selected_any, proof) in outcomes {
            if always_rows.is_empty() && !selected_any {
                // No deterministic rows and no uncertain row selected: the group
                // is empty in this outcome, so no head tuple materializes.
                continue;
            }

            let tuple = materialize_aggregate_tuple(
                head,
                &group.key,
                &key_var_to_pos,
                &agg_specs,
                &agg_to_pos,
                &agg_states,
            )?;
            // Several outcomes may yield the same tuple; OR their proofs.
            let entry = out.entry(tuple).or_insert_with(|| builder.const_false());
            *entry = builder.or(vec![*entry, proof]);
        }
    }

    Ok(out)
}
1506
1507/// Factorized aggregate-outcome folding for non-count exact aggregates.
1508///
1509/// Instead of enumerating all `2^k` present/absent masks over the `k` uncertain
1510/// rows (one conjunctive PIR formula per mask), fold the rows one at a time
1511/// through a dynamic program keyed by the aggregate state reached so far.
1512/// Outcomes that agree on the aggregate state share one PIR sub-DAG, so the
1513/// emitted PIR is `O(k * #distinct-states)` instead of `O(2^k)` formulas.
1514///
1515/// Rows are folded in the same order as the previous mask enumeration
1516/// (deterministic rows first, then uncertain rows in index order), so every
1517/// outcome value is bit-identical to the enumerated result and the union of
1518/// worlds reaching each outcome is unchanged (identical query probabilities).
1519///
1520/// Returns the folded outcomes as `(aggregate states, any-uncertain-row-selected,
1521/// proof formula)` triples plus the total number of DP states visited.
1522#[allow(clippy::type_complexity)]
1523fn factorized_aggregate_outcomes(
1524    agg_specs: &[(AggOp, String)],
1525    always_rows: &[AggregateProvRow],
1526    uncertain_rows: &[AggregateProvRow],
1527    builder: &mut PirBuilder,
1528) -> Result<(Vec<(Vec<AggState>, bool, PirNodeId)>, usize)> {
1529    use std::collections::btree_map::Entry;
1530
1531    fn states_key(states: &[AggState]) -> Vec<AggStateKey> {
1532        states.iter().map(AggState::dp_key).collect()
1533    }
1534
1535    let mut base: Vec<AggState> = agg_specs.iter().map(|(op, _)| AggState::new(*op)).collect();
1536    for row in always_rows {
1537        update_aggregate_states(&mut base, agg_specs, row)?;
1538    }
1539
1540    let mut dp: BTreeMap<(Vec<AggStateKey>, bool), (Vec<AggState>, PirNodeId)> = BTreeMap::new();
1541    let true_proof = builder.const_true();
1542    dp.insert((states_key(&base), false), (base, true_proof));
1543    let mut dp_states = dp.len();
1544
1545    for row in uncertain_rows {
1546        let absent = negate_provenance(row.prov, builder);
1547        let mut next: BTreeMap<(Vec<AggStateKey>, bool), (Vec<AggState>, PirNodeId)> =
1548            BTreeMap::new();
1549        for ((key, selected_any), (states, proof)) in dp {
1550            let mut present_states = states.clone();
1551            update_aggregate_states(&mut present_states, agg_specs, row)?;
1552            let present_key = states_key(&present_states);
1553            let present_proof = builder.and(vec![proof, row.prov]);
1554            match next.entry((present_key, true)) {
1555                Entry::Occupied(mut entry) => {
1556                    entry.get_mut().1 = builder.or(vec![entry.get().1, present_proof]);
1557                }
1558                Entry::Vacant(entry) => {
1559                    entry.insert((present_states, present_proof));
1560                }
1561            }
1562
1563            let absent_proof = builder.and(vec![proof, absent]);
1564            match next.entry((key, selected_any)) {
1565                Entry::Occupied(mut entry) => {
1566                    entry.get_mut().1 = builder.or(vec![entry.get().1, absent_proof]);
1567                }
1568                Entry::Vacant(entry) => {
1569                    entry.insert((states, absent_proof));
1570                }
1571            }
1572        }
1573        dp = next;
1574        dp_states += dp.len();
1575    }
1576
1577    let outcomes = dp
1578        .into_iter()
1579        .map(|((_, selected_any), (states, proof))| (states, selected_any, proof))
1580        .collect();
1581    Ok((outcomes, dp_states))
1582}
1583
1584fn validate_count_lift_rows(
1585    agg_specs: &[(AggOp, String)],
1586    always_rows: &[AggregateProvRow],
1587    uncertain_rows: &[AggregateProvRow],
1588) -> Result<()> {
1589    for (_, var) in agg_specs {
1590        for row in always_rows.iter().chain(uncertain_rows.iter()) {
1591            if !row.binding.contains_key(var) {
1592                return Err(XlogError::UnsafeVariable(var.clone()));
1593            }
1594        }
1595    }
1596    Ok(())
1597}
1598
1599fn count_lift_formulas(
1600    uncertain_rows: &[AggregateProvRow],
1601    builder: &mut PirBuilder,
1602) -> Vec<PirNodeId> {
1603    let n = uncertain_rows.len();
1604    let mut dp = vec![builder.const_false(); n + 1];
1605    dp[0] = builder.const_true();
1606
1607    for (idx, row) in uncertain_rows.iter().enumerate() {
1608        let mut next = vec![builder.const_false(); n + 1];
1609        let present = row.prov;
1610        let absent = negate_provenance(row.prov, builder);
1611        for selected in 0..=idx {
1612            let absent_case = builder.and(vec![dp[selected], absent]);
1613            next[selected] = builder.or(vec![next[selected], absent_case]);
1614
1615            let present_case = builder.and(vec![dp[selected], present]);
1616            next[selected + 1] = builder.or(vec![next[selected + 1], present_case]);
1617        }
1618        dp = next;
1619    }
1620
1621    dp
1622}
1623
1624fn materialize_count_lift_tuple(
1625    head: &Atom,
1626    group_key: &[Value],
1627    key_var_to_pos: &HashMap<String, usize>,
1628    count_value: usize,
1629) -> Result<Vec<Value>> {
1630    let count_value: i64 = count_value
1631        .try_into()
1632        .map_err(|_| XlogError::Compilation("count() overflowed i64".to_string()))?;
1633    let mut tuple: Vec<Value> = Vec::with_capacity(head.terms.len());
1634    for term in &head.terms {
1635        match term {
1636            Term::Variable(name) => {
1637                let pos = *key_var_to_pos.get(name).ok_or_else(|| {
1638                    XlogError::Compilation(format!(
1639                        "Aggregate head variable {} is not a group key",
1640                        name
1641                    ))
1642                })?;
1643                tuple.push(group_key[pos].clone());
1644            }
1645            Term::Aggregate(AggExpr {
1646                op: AggOp::Count, ..
1647            }) => tuple.push(Value::I64(count_value)),
1648            Term::Aggregate(AggExpr { op, .. }) => {
1649                return Err(XlogError::Compilation(format!(
1650                    "Internal aggregate lift state mismatch for {}",
1651                    agg_op_label(*op)
1652                )));
1653            }
1654            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1655                tuple.push(value_from_term(term)?);
1656            }
1657            Term::Anonymous => unreachable!("aggregate head plan rejects anonymous terms"),
1658            Term::List(_) => {
1659                return Err(unsupported_probabilistic_term_error(
1660                    "aggregate head materialization",
1661                    "list",
1662                ));
1663            }
1664            Term::Cons { .. } => {
1665                return Err(unsupported_probabilistic_term_error(
1666                    "aggregate head materialization",
1667                    "cons",
1668                ));
1669            }
1670            Term::Compound { .. } => {
1671                return Err(unsupported_probabilistic_term_error(
1672                    "aggregate head materialization",
1673                    "compound",
1674                ));
1675            }
1676            Term::PredRef(_) => {
1677                return Err(unsupported_probabilistic_term_error(
1678                    "aggregate head materialization",
1679                    "predref",
1680                ));
1681            }
1682        }
1683    }
1684    Ok(tuple)
1685}
1686
1687#[allow(clippy::too_many_arguments)]
1688fn record_aggregate_lift_reports(
1689    aggregate_lifting: &mut Vec<AggregateLiftReport>,
1690    head: &Atom,
1691    group_key: &[Value],
1692    agg_specs: &[(AggOp, String)],
1693    deterministic_rows: usize,
1694    uncertain_rows: usize,
1695    status: AggregateLiftStatus,
1696    reason: &str,
1697    cap: usize,
1698    dynamic_programming_states: usize,
1699) {
1700    for (op, _) in agg_specs {
1701        aggregate_lifting.push(AggregateLiftReport {
1702            predicate: head.predicate.clone(),
1703            group_key: group_key.to_vec(),
1704            operator: agg_op_label(*op).to_string(),
1705            finite_domain_source: "grounded body rows".to_string(),
1706            deterministic_rows,
1707            uncertain_rows,
1708            domain_size: deterministic_rows + uncertain_rows,
1709            cap,
1710            status,
1711            reason: reason.to_string(),
1712            naive_outcomes: naive_outcome_count(uncertain_rows),
1713            dynamic_programming_states,
1714        });
1715    }
1716}
1717
1718fn agg_op_label(op: AggOp) -> &'static str {
1719    match op {
1720        AggOp::Count => "count",
1721        AggOp::Sum => "sum",
1722        AggOp::Min => "min",
1723        AggOp::Max => "max",
1724        AggOp::LogSumExp => "logsumexp",
1725    }
1726}
1727
/// Size of a naive enumeration over the uncertain rows, i.e.
/// `2^uncertain_rows` (every present/absent combination).
///
/// Saturates at `u128::MAX` once the exponent reaches 128.
fn naive_outcome_count(uncertain_rows: usize) -> u128 {
    // `checked_shl` yields `None` for shifts >= 128; a `usize` too large for
    // `u32` is far past that point as well.
    u32::try_from(uncertain_rows)
        .ok()
        .and_then(|shift| 1u128.checked_shl(shift))
        .unwrap_or(u128::MAX)
}
1735
/// Number of DP cells visited when lifting `count` over `uncertain_rows` rows.
///
/// After processing `k` uncertain rows, `k + 1` counts (`0..=k`) are
/// reachable. Summing over `k = 0..=n` gives the triangular number
/// `(n + 1)(n + 2) / 2`.
///
/// The result saturates at `usize::MAX`, mirroring `naive_outcome_count`.
/// The even factor is halved before multiplying, so the intermediate product
/// cannot overflow when the true result still fits in `usize`.
fn count_lift_dp_states(uncertain_rows: usize) -> usize {
    let rows_plus_one = uncertain_rows.saturating_add(1);
    let rows_plus_two = uncertain_rows.saturating_add(2);
    // Exactly one of two consecutive integers is even (absent saturation),
    // so dividing it by two is exact.
    let (even, other) = if rows_plus_one % 2 == 0 {
        (rows_plus_one, rows_plus_two)
    } else {
        (rows_plus_two, rows_plus_one)
    };
    (even / 2).saturating_mul(other)
}
1739
/// Result of `aggregate_head_plan` for an aggregate rule head:
///
/// 0. distinct head variables (the group keys), in first-occurrence order;
/// 1. map from each group-key variable to its index in (0);
/// 2. distinct `(operator, aggregated variable)` specs, in first-occurrence order;
/// 3. map from each spec to its index in (2).
type AggregatePlan = (
    Vec<String>,
    HashMap<String, usize>,
    Vec<(AggOp, String)>,
    HashMap<(AggOp, String), usize>,
);
1746
1747fn aggregate_head_plan(head: &Atom) -> Result<AggregatePlan> {
1748    let mut key_vars: Vec<String> = Vec::new();
1749    let mut key_var_to_pos: HashMap<String, usize> = HashMap::new();
1750    let mut agg_specs: Vec<(AggOp, String)> = Vec::new();
1751    let mut agg_to_pos: HashMap<(AggOp, String), usize> = HashMap::new();
1752
1753    for term in &head.terms {
1754        match term {
1755            Term::Variable(name) => {
1756                if !key_var_to_pos.contains_key(name) {
1757                    let pos = key_vars.len();
1758                    key_vars.push(name.clone());
1759                    key_var_to_pos.insert(name.clone(), pos);
1760                }
1761            }
1762            Term::Aggregate(agg) => {
1763                let key = (agg.op, agg.variable.clone());
1764                if let std::collections::hash_map::Entry::Vacant(entry) =
1765                    agg_to_pos.entry(key.clone())
1766                {
1767                    let pos = agg_specs.len();
1768                    agg_specs.push(key);
1769                    entry.insert(pos);
1770                }
1771            }
1772            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {}
1773            Term::Anonymous => {
1774                return Err(XlogError::Compilation(format!(
1775                    "Anonymous variable in aggregate head of {} is not supported",
1776                    head.predicate
1777                )));
1778            }
1779            Term::List(_) => {
1780                return Err(unsupported_probabilistic_term_error(
1781                    "aggregate head planning",
1782                    "list",
1783                ));
1784            }
1785            Term::Cons { .. } => {
1786                return Err(unsupported_probabilistic_term_error(
1787                    "aggregate head planning",
1788                    "cons",
1789                ));
1790            }
1791            Term::Compound { .. } => {
1792                return Err(unsupported_probabilistic_term_error(
1793                    "aggregate head planning",
1794                    "compound",
1795                ));
1796            }
1797            Term::PredRef(_) => {
1798                return Err(unsupported_probabilistic_term_error(
1799                    "aggregate head planning",
1800                    "predref",
1801                ));
1802            }
1803        }
1804    }
1805
1806    Ok((key_vars, key_var_to_pos, agg_specs, agg_to_pos))
1807}
1808
1809fn canonical_binding_key(binding: &HashMap<String, Value>) -> Vec<(String, Value)> {
1810    let mut key: Vec<(String, Value)> = binding
1811        .iter()
1812        .map(|(name, value)| (name.clone(), value.clone()))
1813        .collect();
1814    key.sort();
1815    key
1816}
1817
1818fn pir_const_value(builder: &PirBuilder, node: PirNodeId) -> Option<bool> {
1819    match builder.pir.node(node) {
1820        Some(crate::pir::PirNode::Const(value)) => Some(*value),
1821        _ => None,
1822    }
1823}
1824
1825fn update_aggregate_states(
1826    states: &mut [AggState],
1827    agg_specs: &[(AggOp, String)],
1828    row: &AggregateProvRow,
1829) -> Result<()> {
1830    for (idx, (op, var)) in agg_specs.iter().enumerate() {
1831        let v = row
1832            .binding
1833            .get(var)
1834            .ok_or_else(|| XlogError::UnsafeVariable(var.clone()))?;
1835        states[idx].update(*op, v)?;
1836    }
1837    Ok(())
1838}
1839
1840fn materialize_aggregate_tuple(
1841    head: &Atom,
1842    group_key: &[Value],
1843    key_var_to_pos: &HashMap<String, usize>,
1844    agg_specs: &[(AggOp, String)],
1845    agg_to_pos: &HashMap<(AggOp, String), usize>,
1846    agg_states: &[AggState],
1847) -> Result<Vec<Value>> {
1848    let mut tuple: Vec<Value> = Vec::with_capacity(head.terms.len());
1849    for term in &head.terms {
1850        match term {
1851            Term::Variable(name) => {
1852                let pos = *key_var_to_pos.get(name).ok_or_else(|| {
1853                    XlogError::Compilation(format!(
1854                        "Aggregate head variable {} is not a group key",
1855                        name
1856                    ))
1857                })?;
1858                tuple.push(group_key[pos].clone());
1859            }
1860            Term::Aggregate(AggExpr { op, variable }) => {
1861                let idx = *agg_to_pos
1862                    .get(&(*op, variable.clone()))
1863                    .expect("agg_to_pos missing");
1864                let spec = agg_specs
1865                    .get(idx)
1866                    .expect("aggregate state index should have a spec");
1867                tuple.push(agg_states[idx].finish(spec.0)?);
1868            }
1869            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1870                tuple.push(value_from_term(term)?);
1871            }
1872            Term::Anonymous => unreachable!("aggregate head plan rejects anonymous terms"),
1873            Term::List(_) => {
1874                return Err(unsupported_probabilistic_term_error(
1875                    "aggregate head materialization",
1876                    "list",
1877                ));
1878            }
1879            Term::Cons { .. } => {
1880                return Err(unsupported_probabilistic_term_error(
1881                    "aggregate head materialization",
1882                    "cons",
1883                ));
1884            }
1885            Term::Compound { .. } => {
1886                return Err(unsupported_probabilistic_term_error(
1887                    "aggregate head materialization",
1888                    "compound",
1889                ));
1890            }
1891            Term::PredRef(_) => {
1892                return Err(unsupported_probabilistic_term_error(
1893                    "aggregate head materialization",
1894                    "predref",
1895                ));
1896            }
1897        }
1898    }
1899    Ok(tuple)
1900}
1901
1902fn select_relation<'a>(
1903    atom: &Atom,
1904    body_index: usize,
1905    global: &'a BTreeMap<String, Relation>,
1906    full_scc: &'a BTreeMap<String, Relation>,
1907    delta_scc: Option<(usize, &'a BTreeMap<String, Relation>)>,
1908) -> Result<&'a Relation> {
1909    if let Some((delta_index, delta_map)) = delta_scc {
1910        if delta_index == body_index {
1911            return delta_map.get(&atom.predicate).ok_or_else(|| {
1912                XlogError::Compilation(format!(
1913                    "Missing delta relation for predicate {}",
1914                    atom.predicate
1915                ))
1916            });
1917        }
1918    }
1919    if let Some(rel) = full_scc.get(&atom.predicate) {
1920        return Ok(rel);
1921    }
1922    global
1923        .get(&atom.predicate)
1924        .ok_or_else(|| XlogError::Compilation(format!("Unknown predicate {}", atom.predicate)))
1925}
1926
1927pub(crate) fn unify_atom(
1928    atom: &Atom,
1929    tuple: &[Value],
1930    binding: &mut HashMap<String, Value>,
1931) -> Result<bool> {
1932    if atom.terms.len() != tuple.len() {
1933        return Err(XlogError::Compilation(format!(
1934            "Arity mismatch for {}: atom has {}, tuple has {}",
1935            atom.predicate,
1936            atom.terms.len(),
1937            tuple.len()
1938        )));
1939    }
1940    for (term, value) in atom.terms.iter().zip(tuple.iter()) {
1941        match term {
1942            Term::Variable(name) => match binding.get(name) {
1943                Some(existing) => {
1944                    if existing != value {
1945                        return Ok(false);
1946                    }
1947                }
1948                None => {
1949                    binding.insert(name.clone(), value.clone());
1950                }
1951            },
1952            Term::Anonymous => {}
1953            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1954                if &value_from_term(term)? != value {
1955                    return Ok(false);
1956                }
1957            }
1958            Term::Aggregate(AggExpr { op: _, variable: _ }) => {
1959                return Err(XlogError::Compilation(
1960                    "Aggregation not supported in provenance extraction".to_string(),
1961                ));
1962            }
1963            Term::List(_) => {
1964                return Err(unsupported_probabilistic_term_error("unification", "list"))
1965            }
1966            Term::Cons { .. } => {
1967                return Err(unsupported_probabilistic_term_error("unification", "cons"))
1968            }
1969            Term::Compound { .. } => {
1970                return Err(unsupported_probabilistic_term_error(
1971                    "unification",
1972                    "compound",
1973                ));
1974            }
1975            Term::PredRef(_) => {
1976                return Err(unsupported_probabilistic_term_error(
1977                    "unification",
1978                    "predref",
1979                ))
1980            }
1981        }
1982    }
1983    Ok(true)
1984}
1985
1986fn materialize_head(head: &Atom, binding: &HashMap<String, Value>) -> Result<Vec<Value>> {
1987    let mut out = Vec::with_capacity(head.terms.len());
1988    for term in &head.terms {
1989        match term {
1990            Term::Variable(name) => {
1991                let v = binding.get(name).ok_or_else(|| {
1992                    XlogError::Compilation(format!(
1993                        "Unbound head variable {} in {}",
1994                        name, head.predicate
1995                    ))
1996                })?;
1997                out.push(v.clone());
1998            }
1999            Term::Anonymous => {
2000                return Err(XlogError::Compilation(format!(
2001                    "Anonymous variable in head of {} is not supported",
2002                    head.predicate
2003                )));
2004            }
2005            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
2006                out.push(value_from_term(term)?);
2007            }
2008            Term::Aggregate(AggExpr {
2009                op: AggOp::Count,
2010                variable: _,
2011            })
2012            | Term::Aggregate(AggExpr {
2013                op: AggOp::Sum,
2014                variable: _,
2015            })
2016            | Term::Aggregate(AggExpr {
2017                op: AggOp::Min,
2018                variable: _,
2019            })
2020            | Term::Aggregate(AggExpr {
2021                op: AggOp::Max,
2022                variable: _,
2023            })
2024            | Term::Aggregate(AggExpr {
2025                op: AggOp::LogSumExp,
2026                variable: _,
2027            }) => {
2028                return Err(XlogError::Compilation(
2029                    "Aggregation not supported in provenance extraction".to_string(),
2030                ));
2031            }
2032            Term::List(_) => {
2033                return Err(unsupported_probabilistic_term_error(
2034                    "head materialization",
2035                    "list",
2036                ));
2037            }
2038            Term::Cons { .. } => {
2039                return Err(unsupported_probabilistic_term_error(
2040                    "head materialization",
2041                    "cons",
2042                ));
2043            }
2044            Term::Compound { .. } => {
2045                return Err(unsupported_probabilistic_term_error(
2046                    "head materialization",
2047                    "compound",
2048                ));
2049            }
2050            Term::PredRef(_) => {
2051                return Err(unsupported_probabilistic_term_error(
2052                    "head materialization",
2053                    "predref",
2054                ));
2055            }
2056        }
2057    }
2058    Ok(out)
2059}
2060
2061pub(crate) fn eval_comparison(
2062    op: CompOp,
2063    left: &Term,
2064    right: &Term,
2065    binding: &HashMap<String, Value>,
2066) -> Result<bool> {
2067    let l = resolve_term(left, binding)?;
2068    let r = resolve_term(right, binding)?;
2069    match (l, r) {
2070        (Value::I64(a), Value::I64(b)) => Ok(compare_ord(op, a.cmp(&b))),
2071        (Value::F64(a_bits), Value::F64(b_bits)) => {
2072            let a = f64::from_bits(a_bits);
2073            let b = f64::from_bits(b_bits);
2074            match op {
2075                CompOp::Eq => Ok(a == b),
2076                CompOp::Ne => Ok(a != b),
2077                CompOp::Lt => Ok(a < b),
2078                CompOp::Le => Ok(a <= b),
2079                CompOp::Gt => Ok(a > b),
2080                CompOp::Ge => Ok(a >= b),
2081            }
2082        }
2083        (Value::Symbol(a), Value::Symbol(b)) => Ok(compare_ord(op, a.cmp(&b))),
2084        (Value::String(a), Value::String(b)) => Ok(compare_ord(op, a.cmp(&b))),
2085        _ => Err(XlogError::Compilation(
2086            "Comparison between differing types is not supported".to_string(),
2087        )),
2088    }
2089}
2090
2091pub(crate) fn compare_ord(op: CompOp, ord: std::cmp::Ordering) -> bool {
2092    use std::cmp::Ordering;
2093    match op {
2094        CompOp::Eq => ord == Ordering::Equal,
2095        CompOp::Ne => ord != Ordering::Equal,
2096        CompOp::Lt => ord == Ordering::Less,
2097        CompOp::Le => ord == Ordering::Less || ord == Ordering::Equal,
2098        CompOp::Gt => ord == Ordering::Greater,
2099        CompOp::Ge => ord == Ordering::Greater || ord == Ordering::Equal,
2100    }
2101}
2102
2103pub(crate) fn resolve_term(term: &Term, binding: &HashMap<String, Value>) -> Result<Value> {
2104    match term {
2105        Term::Variable(name) => binding.get(name).cloned().ok_or_else(|| {
2106            XlogError::Compilation(format!("Unbound variable {} in comparison", name))
2107        }),
2108        Term::Anonymous => Err(XlogError::Compilation(
2109            "Anonymous variable not allowed in comparison".to_string(),
2110        )),
2111        Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
2112            value_from_term(term)
2113        }
2114        Term::Aggregate(_) => Err(XlogError::Compilation(
2115            "Aggregation not supported in provenance extraction".to_string(),
2116        )),
2117        Term::List(_) => Err(unsupported_probabilistic_term_error("comparison", "list")),
2118        Term::Cons { .. } => Err(unsupported_probabilistic_term_error("comparison", "cons")),
2119        Term::Compound { .. } => Err(unsupported_probabilistic_term_error(
2120            "comparison",
2121            "compound",
2122        )),
2123        Term::PredRef(_) => Err(unsupported_probabilistic_term_error(
2124            "comparison",
2125            "predref",
2126        )),
2127    }
2128}
2129
2130pub(crate) fn eval_arith_expr(expr: &ArithExpr, binding: &HashMap<String, Value>) -> Result<Value> {
2131    match expr {
2132        ArithExpr::Variable(name) => binding.get(name).cloned().ok_or_else(|| {
2133            XlogError::Compilation(format!("Unbound variable {} in arithmetic", name))
2134        }),
2135        ArithExpr::Integer(i) => Ok(Value::I64(*i)),
2136        ArithExpr::Float(f) => Ok(Value::F64(f.to_bits())),
2137        ArithExpr::Add(l, r) => eval_bin_op(l, r, binding, |a, b| a + b, |a, b| a + b),
2138        ArithExpr::Sub(l, r) => eval_bin_op(l, r, binding, |a, b| a - b, |a, b| a - b),
2139        ArithExpr::Mul(l, r) => eval_bin_op(l, r, binding, |a, b| a * b, |a, b| a * b),
2140        ArithExpr::Div(l, r) => eval_bin_op(l, r, binding, |a, b| a / b, |a, b| a / b),
2141        ArithExpr::Mod(l, r) => eval_bin_op(l, r, binding, |a, b| a % b, |a, b| a % b),
2142        ArithExpr::Abs(e) => match eval_arith_expr(e, binding)? {
2143            Value::I64(i) => Ok(Value::I64(i.abs())),
2144            Value::F64(bits) => {
2145                let f = f64::from_bits(bits).abs();
2146                Ok(Value::F64(f.to_bits()))
2147            }
2148            _ => Err(XlogError::Compilation(
2149                "abs() requires numeric input".to_string(),
2150            )),
2151        },
2152        ArithExpr::Min(l, r) => eval_bin_op(l, r, binding, |a, b| a.min(b), |a, b| a.min(b)),
2153        ArithExpr::Max(l, r) => eval_bin_op(l, r, binding, |a, b| a.max(b), |a, b| a.max(b)),
2154        ArithExpr::Pow(l, r) => {
2155            let a = eval_arith_expr(l, binding)?;
2156            let b = eval_arith_expr(r, binding)?;
2157            match (a, b) {
2158                (Value::I64(a), Value::I64(b)) => {
2159                    Ok(Value::I64(a.pow(u32::try_from(b).map_err(|_| {
2160                        XlogError::Compilation("pow exponent must fit in u32".to_string())
2161                    })?)))
2162                }
2163                (Value::F64(a), Value::F64(b)) => Ok(Value::F64(
2164                    f64::from_bits(a).powf(f64::from_bits(b)).to_bits(),
2165                )),
2166                _ => Err(XlogError::Compilation(
2167                    "pow requires numeric inputs of same type".to_string(),
2168                )),
2169            }
2170        }
2171        ArithExpr::Cast(e, _ty) => {
2172            // For provenance compilation we preserve the numeric value; the runtime has a full
2173            // type system, but provenance needs only deterministic evaluation.
2174            eval_arith_expr(e, binding)
2175        }
2176        ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2177            "Function call `{}` must be expanded before provenance extraction",
2178            name
2179        ))),
2180        ArithExpr::Conditional { .. } => Err(XlogError::Compilation(
2181            "Conditional expressions must be expanded before provenance extraction".to_string(),
2182        )),
2183    }
2184}
2185
2186pub(crate) fn eval_bin_op<FInt, FFloat>(
2187    l: &ArithExpr,
2188    r: &ArithExpr,
2189    binding: &HashMap<String, Value>,
2190    op_int: FInt,
2191    op_float: FFloat,
2192) -> Result<Value>
2193where
2194    FInt: FnOnce(i64, i64) -> i64,
2195    FFloat: FnOnce(f64, f64) -> f64,
2196{
2197    let a = eval_arith_expr(l, binding)?;
2198    let b = eval_arith_expr(r, binding)?;
2199    match (a, b) {
2200        (Value::I64(a), Value::I64(b)) => Ok(Value::I64(op_int(a, b))),
2201        (Value::F64(a), Value::F64(b)) => Ok(Value::F64(
2202            op_float(f64::from_bits(a), f64::from_bits(b)).to_bits(),
2203        )),
2204        _ => Err(XlogError::Compilation(
2205            "Arithmetic operation requires matching numeric types".to_string(),
2206        )),
2207    }
2208}