Skip to main content

xlog_logic/
ast.rs

1//! Abstract Syntax Tree for XLOG programs
2
3use xlog_core::ScalarType;
4
/// A term in an atom.
///
/// Terms are the arguments of atoms and comparisons: named or anonymous
/// variables, scalar literals, finite structured values (lists, cons cells,
/// compounds), static predicate references, and head aggregates.
#[derive(Debug, Clone, PartialEq)]
pub enum Term {
    /// Named logic variable (e.g. `X`).
    Variable(String),
    /// Anonymous wildcard `_` -- each occurrence is a fresh unnamed variable.
    Anonymous,
    /// Integer literal.
    Integer(i64),
    /// Floating-point literal.
    Float(f64),
    /// Quoted string literal.
    String(String),
    /// Interned symbol ID -- use `xlog_core::symbol::resolve(id)` to get the string.
    Symbol(u32),
    /// Finite list literal.
    List(Vec<Term>),
    /// Finite cons pattern `[Head | Tail]`.
    Cons {
        /// Head term.
        head: Box<Term>,
        /// Tail term.
        tail: Box<Term>,
    },
    /// Finite compound term.
    Compound {
        /// Functor name.
        functor: String,
        /// Compound arguments.
        args: Vec<Term>,
    },
    /// Static predicate reference.
    PredRef(String),
    /// Aggregate expression (e.g. `count(X)`).
    Aggregate(AggExpr),
}

impl Term {
    /// Returns true if this is a named variable.
    pub fn is_variable(&self) -> bool {
        matches!(self, Term::Variable(_))
    }

    /// Returns true if this is an anonymous wildcard `_`.
    pub fn is_anonymous(&self) -> bool {
        matches!(self, Term::Anonymous)
    }

    /// Returns true if this is any kind of variable (named or anonymous).
    pub fn is_any_variable(&self) -> bool {
        matches!(self, Term::Variable(_) | Term::Anonymous)
    }

    /// Returns true if this is an atomic scalar constant: an integer, float,
    /// string, or interned symbol.
    ///
    /// Structured terms (lists, cons cells, compounds) return `false` even when
    /// they contain no variables at all, as do predicate references, aggregates,
    /// and variables of either kind.
    pub fn is_constant(&self) -> bool {
        // Exhaustive on purpose: adding a `Term` variant forces an explicit
        // decision here instead of silently classifying it as a constant.
        match self {
            Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => true,
            Term::Variable(_)
            | Term::Anonymous
            | Term::List(_)
            | Term::Cons { .. }
            | Term::Compound { .. }
            | Term::PredRef(_)
            | Term::Aggregate(_) => false,
        }
    }

    /// Returns the variable name, or `None` for anonymous wildcards and all
    /// non-variable terms.
    pub fn variable_name(&self) -> Option<&str> {
        match self {
            Term::Variable(name) => Some(name),
            _ => None,
        }
    }

    /// Return all named variables referenced by this term.
    ///
    /// Structured terms are walked depth-first, left to right; duplicates are
    /// kept. Anonymous wildcards are skipped, and the variable inside an
    /// aggregate is not reported.
    pub fn variables(&self) -> Vec<&str> {
        match self {
            Term::Variable(name) => vec![name.as_str()],
            Term::List(items) => items.iter().flat_map(Term::variables).collect(),
            Term::Cons { head, tail } => {
                let mut vars = head.variables();
                vars.extend(tail.variables());
                vars
            }
            Term::Compound { args, .. } => args.iter().flat_map(Term::variables).collect(),
            Term::Anonymous
            | Term::Integer(_)
            | Term::Float(_)
            | Term::String(_)
            | Term::Symbol(_)
            | Term::PredRef(_)
            | Term::Aggregate(_) => vec![],
        }
    }
}

/// Aggregate expression, e.g. `count(X)`.
#[derive(Debug, Clone, PartialEq)]
pub struct AggExpr {
    /// The aggregation operator.
    pub op: AggOp,
    /// The variable being aggregated.
    pub variable: String,
}

/// Aggregation operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggOp {
    /// Count aggregation.
    Count,
    /// Sum aggregation.
    Sum,
    /// Minimum aggregation.
    Min,
    /// Maximum aggregation.
    Max,
    /// Log-sum-exp aggregation.
    LogSumExp,
}
124
125/// Arithmetic expression tree
126#[derive(Debug, Clone, PartialEq)]
127pub enum ArithExpr {
128    /// Variable reference.
129    Variable(String),
130    /// Integer literal.
131    Integer(i64),
132    /// Float literal.
133    Float(f64),
134
135    /// Addition.
136    Add(Box<ArithExpr>, Box<ArithExpr>),
137    /// Subtraction.
138    Sub(Box<ArithExpr>, Box<ArithExpr>),
139    /// Multiplication.
140    Mul(Box<ArithExpr>, Box<ArithExpr>),
141    /// Division.
142    Div(Box<ArithExpr>, Box<ArithExpr>),
143    /// Modulo.
144    Mod(Box<ArithExpr>, Box<ArithExpr>),
145
146    /// Absolute value.
147    Abs(Box<ArithExpr>),
148    /// Minimum of two values.
149    Min(Box<ArithExpr>, Box<ArithExpr>),
150    /// Maximum of two values.
151    Max(Box<ArithExpr>, Box<ArithExpr>),
152    /// Power (base, exponent).
153    Pow(Box<ArithExpr>, Box<ArithExpr>),
154
155    /// Type cast to the given scalar type.
156    Cast(Box<ArithExpr>, ScalarType),
157
158    /// User-defined function call
159    FuncCall {
160        /// Function name being invoked.
161        name: String,
162        /// Positional arguments supplied to the function.
163        args: Vec<ArithExpr>,
164    },
165
166    /// Conditional expression (for expanded function bodies)
167    Conditional {
168        /// Left operand of the condition.
169        cond_left: Box<ArithExpr>,
170        /// Comparison operator used in the condition.
171        cond_op: CompOp,
172        /// Right operand of the condition.
173        cond_right: Box<ArithExpr>,
174        /// Expression evaluated when the condition is true.
175        then_expr: Box<ArithExpr>,
176        /// Expression evaluated when the condition is false.
177        else_expr: Box<ArithExpr>,
178    },
179}
180
181impl ArithExpr {
182    /// Get all variable names used in this expression
183    pub fn variables(&self) -> Vec<&str> {
184        match self {
185            ArithExpr::Variable(name) => vec![name.as_str()],
186            ArithExpr::Integer(_) | ArithExpr::Float(_) => vec![],
187            ArithExpr::Add(l, r)
188            | ArithExpr::Sub(l, r)
189            | ArithExpr::Mul(l, r)
190            | ArithExpr::Div(l, r)
191            | ArithExpr::Mod(l, r)
192            | ArithExpr::Min(l, r)
193            | ArithExpr::Max(l, r)
194            | ArithExpr::Pow(l, r) => {
195                let mut vars = l.variables();
196                vars.extend(r.variables());
197                vars
198            }
199            ArithExpr::Abs(e) | ArithExpr::Cast(e, _) => e.variables(),
200            ArithExpr::FuncCall { args, .. } => args.iter().flat_map(|a| a.variables()).collect(),
201            ArithExpr::Conditional {
202                cond_left,
203                cond_right,
204                then_expr,
205                else_expr,
206                ..
207            } => {
208                let mut vars = cond_left.variables();
209                vars.extend(cond_right.variables());
210                vars.extend(then_expr.variables());
211                vars.extend(else_expr.variables());
212                vars
213            }
214        }
215    }
216}
217
218/// Is-expression for variable binding: Z is X + Y
219#[derive(Debug, Clone, PartialEq)]
220pub struct IsExpr {
221    /// Target variable (must be a fresh, unbound variable).
222    pub target: String,
223    /// Arithmetic expression to evaluate.
224    pub expr: ArithExpr,
225}
226
227/// An atom (predicate applied to terms)
228#[derive(Debug, Clone, PartialEq)]
229pub struct Atom {
230    /// Predicate name.
231    pub predicate: String,
232    /// Argument terms.
233    pub terms: Vec<Term>,
234}
235
236impl Atom {
237    /// Number of arguments.
238    pub fn arity(&self) -> usize {
239        self.terms.len()
240    }
241
242    /// Collect all named variables in this atom.
243    pub fn variables(&self) -> Vec<&str> {
244        self.terms.iter().flat_map(Term::variables).collect()
245    }
246}
247
/// Epistemic operator applied to an atom in a rule body.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpistemicOp {
    /// Known/believed true in the selected epistemic mode (`know p(...)`).
    Know,
    /// Possible/consistent in the selected epistemic mode (`possible p(...)`).
    Possible,
}

/// Epistemic atom literal in a rule body, e.g. `know p(X)` or its negated form.
#[derive(Debug, Clone, PartialEq)]
pub struct EpistemicLiteral {
    /// Epistemic operator.
    pub op: EpistemicOp,
    /// Whether this epistemic literal is explicitly negated.
    pub negated: bool,
    /// Atom under the epistemic operator.
    pub atom: Atom,
}

/// Comparison operator, shared by body comparisons and conditional expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}

/// A comparison between two terms in a rule body, e.g. `X < Y`.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Left operand.
    pub left: Term,
    /// Comparison operator.
    pub op: CompOp,
    /// Right operand.
    pub right: Term,
}

/// A finite univ expression (`Term =.. Parts`) in a rule body.
#[derive(Debug, Clone, PartialEq)]
pub struct Univ {
    /// Term side of the univ relation.
    pub term: Term,
    /// Parts-list side of the univ relation.
    pub parts: Term,
}
304
305/// A literal in the body of a rule
306#[derive(Debug, Clone, PartialEq)]
307pub enum BodyLiteral {
308    /// Positive atom.
309    Positive(Atom),
310    /// Negated atom (`not p(...)`).
311    Negated(Atom),
312    /// Epistemic atom (`know p(...)`, `possible p(...)`, or negated form).
313    Epistemic(EpistemicLiteral),
314    /// Arithmetic comparison (e.g. `X < Y`).
315    Comparison(Comparison),
316    /// Is-expression binding (e.g. `Z is X + Y`).
317    IsExpr(IsExpr),
318    /// Finite univ relation (`Term =.. Parts`).
319    Univ(Univ),
320}
321
322impl BodyLiteral {
323    /// Returns true if this is a positive literal.
324    pub fn is_positive(&self) -> bool {
325        matches!(self, BodyLiteral::Positive(_))
326    }
327
328    /// Returns true if this is a negated literal.
329    pub fn is_negated(&self) -> bool {
330        matches!(self, BodyLiteral::Negated(_))
331    }
332
333    /// Returns the atom if this is a positive or negated literal.
334    pub fn atom(&self) -> Option<&Atom> {
335        match self {
336            BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => Some(a),
337            BodyLiteral::Epistemic(lit) => Some(&lit.atom),
338            BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => None,
339        }
340    }
341
342    /// Collect all named variables referenced by this literal.
343    pub fn variables(&self) -> Vec<&str> {
344        match self {
345            BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a.variables(),
346            BodyLiteral::Epistemic(lit) => lit.atom.variables(),
347            BodyLiteral::Comparison(c) => {
348                let mut vars = vec![];
349                vars.extend(c.left.variables());
350                vars.extend(c.right.variables());
351                vars
352            }
353            BodyLiteral::IsExpr(is_expr) => {
354                let mut vars = is_expr.expr.variables();
355                vars.push(is_expr.target.as_str());
356                vars
357            }
358            BodyLiteral::Univ(univ) => {
359                let mut vars = univ.term.variables();
360                vars.extend(univ.parts.variables());
361                vars
362            }
363        }
364    }
365}
366
367/// A rule (head :- body)
368#[derive(Debug, Clone, PartialEq)]
369pub struct Rule {
370    /// Head atom of the rule.
371    pub head: Atom,
372    /// Body literals (empty for facts).
373    pub body: Vec<BodyLiteral>,
374}
375
376impl Rule {
377    /// Returns true if this rule is a ground fact (empty body).
378    pub fn is_fact(&self) -> bool {
379        self.body.is_empty()
380    }
381
382    /// Returns true if any body literal is negated.
383    pub fn has_negation(&self) -> bool {
384        self.body.iter().any(|l| l.is_negated())
385    }
386
387    /// Returns true if the head contains an aggregate term.
388    pub fn has_aggregation(&self) -> bool {
389        self.head
390            .terms
391            .iter()
392            .any(|t| matches!(t, Term::Aggregate(_)))
393    }
394
395    /// Collect predicate names from the body.
396    pub fn body_predicates(&self) -> Vec<&str> {
397        self.body
398            .iter()
399            .filter_map(|l| l.atom().map(|a| a.predicate.as_str()))
400            .collect()
401    }
402
403    /// Collect named variables from the head.
404    pub fn head_variables(&self) -> Vec<&str> {
405        self.head.variables()
406    }
407
408    /// Collect all named variables from the body.
409    pub fn body_variables(&self) -> Vec<&str> {
410        self.body.iter().flat_map(|l| l.variables()).collect()
411    }
412}
413
/// An integrity constraint (`:- body.`): a rule with no head.
#[derive(Debug, Clone, PartialEq)]
pub struct Constraint {
    /// Body literals whose conjunction must never be satisfiable.
    pub body: Vec<BodyLiteral>,
}

/// A query (`?- atom.`).
#[derive(Debug, Clone, PartialEq)]
pub struct Query {
    /// Query atom.
    pub atom: Atom,
}
427
/// Which probabilistic inference engine to run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbEngine {
    /// Exact inference via d-DNNF compilation.
    ExactDdnnf,
    /// Approximate inference via Monte Carlo sampling.
    Mc,
}

/// Whether compiled probabilistic circuits are cached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbCache {
    /// Enable circuit caching.
    On,
    /// Disable circuit caching.
    Off,
}

/// Which epistemic semantics to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpistemicMode {
    /// Gelfond-1991-style compatibility semantics, selected by `g91`.
    G91,
    /// Founded Autoepistemic Equilibrium Logic.
    Faeel,
}

/// Which Monte Carlo sampling method to use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbMethod {
    /// Rejection sampling.
    Rejection,
    /// Forceable evidence clamping.
    EvidenceClamping,
}

/// Magic-set rewrite mode for bound recursive deterministic queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MagicSetsMode {
    /// Apply the rewrite when the compiler can prove the supported safe subset.
    Auto,
    /// Require the rewrite and fail with a typed diagnostic if it is unsafe.
    On,
    /// Disable magic-set rewriting.
    Off,
}

/// Compilation/evaluation directives (e.g., `#pragma ...`).
///
/// Every field is optional: `None` means "not set by the program". The
/// `*_or_default` accessors supply the fallback for the settings that have one.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Directives {
    /// Override for the probabilistic inference engine.
    pub prob_engine: Option<ProbEngine>,
    /// Override for circuit caching.
    pub prob_cache: Option<ProbCache>,
    /// Monte Carlo sample count.
    pub prob_samples: Option<usize>,
    /// Monte Carlo deterministic RNG seed.
    pub prob_seed: Option<u64>,
    /// Monte Carlo confidence level.
    pub prob_confidence: Option<f64>,
    /// Monte Carlo sampling method.
    pub prob_method: Option<ProbMethod>,
    /// Maximum nonmonotone MC iterations.
    pub prob_max_nonmonotone_iterations: Option<usize>,
    /// Maximum UDF recursion depth.
    pub max_recursion_depth: Option<u32>,
    /// Override for epistemic semantics.
    pub epistemic_mode: Option<EpistemicMode>,
    /// Magic-set rewrite mode.
    pub magic_sets: Option<MagicSetsMode>,
}

impl Directives {
    /// Engine used when no `prob_engine` directive is given.
    const DEFAULT_PROB_ENGINE: ProbEngine = ProbEngine::ExactDdnnf;
    /// UDF recursion limit used when none is given.
    const DEFAULT_MAX_RECURSION_DEPTH: u32 = 1000;
    /// Epistemic semantics used when none is given.
    const DEFAULT_EPISTEMIC_MODE: EpistemicMode = EpistemicMode::Faeel;
    /// Monte Carlo sample count used when none is given.
    const DEFAULT_PROB_SAMPLES: usize = 10000;
    /// Monte Carlo seed used when none is given.
    const DEFAULT_PROB_SEED: u64 = 0;
    /// Monte Carlo confidence level used when none is given.
    const DEFAULT_PROB_CONFIDENCE: f64 = 0.95;
    /// Nonmonotone Monte Carlo iteration cap used when none is given.
    const DEFAULT_PROB_MAX_NONMONOTONE_ITERATIONS: usize = 1024;

    /// Probabilistic engine to use (`ExactDdnnf` unless overridden).
    pub fn prob_engine_or_default(&self) -> ProbEngine {
        self.prob_engine.unwrap_or(Self::DEFAULT_PROB_ENGINE)
    }

    /// Maximum UDF recursion depth (1000 unless overridden).
    pub fn max_recursion_depth_or_default(&self) -> u32 {
        self.max_recursion_depth
            .unwrap_or(Self::DEFAULT_MAX_RECURSION_DEPTH)
    }

    /// Epistemic semantics to use (FAEEL unless overridden).
    pub fn epistemic_mode_or_default(&self) -> EpistemicMode {
        self.epistemic_mode.unwrap_or(Self::DEFAULT_EPISTEMIC_MODE)
    }

    /// Monte Carlo sample count (10000 unless overridden).
    pub fn prob_samples_or_default(&self) -> usize {
        self.prob_samples.unwrap_or(Self::DEFAULT_PROB_SAMPLES)
    }

    /// Monte Carlo RNG seed (0 unless overridden).
    pub fn prob_seed_or_default(&self) -> u64 {
        self.prob_seed.unwrap_or(Self::DEFAULT_PROB_SEED)
    }

    /// Monte Carlo confidence level (0.95 unless overridden).
    pub fn prob_confidence_or_default(&self) -> f64 {
        self.prob_confidence.unwrap_or(Self::DEFAULT_PROB_CONFIDENCE)
    }

    /// Nonmonotone Monte Carlo iteration cap (1024 unless overridden).
    pub fn prob_max_nonmonotone_iterations_or_default(&self) -> usize {
        self.prob_max_nonmonotone_iterations
            .unwrap_or(Self::DEFAULT_PROB_MAX_NONMONOTONE_ITERATIONS)
    }
}
536
/// A probabilistic fact (`p::atom.`).
///
/// This type stores `prob` as given; it performs no range check itself.
#[derive(Debug, Clone, PartialEq)]
pub struct ProbFact {
    /// Probability weight `p`.
    pub prob: f64,
    /// Ground atom.
    pub atom: Atom,
}

/// Neural predicate declaration.
///
/// Neural predicates connect neural networks to probabilistic logic.
/// Syntax: `nn(network, [inputs], output, [labels]) :: pred(args).`
///
/// The neural network produces probability distributions over labels,
/// which become probabilistic facts in the logic program. When the label
/// list is omitted, the network produces embeddings instead.
///
/// # Examples
/// ```text
/// nn(mnist_net, [X], Y, [0,1,2,3,4,5,6,7,8,9]) :: digit(X, Y).
/// nn(encoder, [Text], Embedding) :: encode(Text, Embedding).
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralPredDecl {
    /// Name of the registered neural network.
    pub network: String,
    /// Input variable names (bound to tensor sources).
    pub inputs: Vec<String>,
    /// Output variable name.
    pub output: String,
    /// Optional classification labels (for classification networks).
    /// If `None`, the network produces embeddings.
    pub labels: Option<Vec<NeuralLabel>>,
    /// The predicate this neural network defines.
    pub predicate: Atom,
}

/// A label in a neural predicate classification.
///
/// Labels can be integers or symbols (identifiers).
#[derive(Debug, Clone, PartialEq)]
pub enum NeuralLabel {
    /// Integer label value.
    Integer(i64),
    /// Symbolic (string) label value.
    Symbol(String),
}
584
585/// A learnable rule template parameterized by a named tensor mask.
586/// Used for differentiable ILP — the mask selects which (body1, body2, head)
587/// combinations are active during execution.
588#[derive(Debug, Clone)]
589pub struct LearnableRule {
590    /// Name of the tensor mask controlling rule activation.
591    pub mask_name: String,
592    /// Head atom of the rule template.
593    pub head: Atom,
594    /// Body literals of the rule template.
595    pub body: Vec<BodyLiteral>,
596}
597
/// Annotated disjunction (`p1::a1; p2::a2.`).
#[derive(Debug, Clone, PartialEq)]
pub struct AnnotatedDisjunction {
    /// Disjunctive choices with their probability weights, in source order.
    pub choices: Vec<ProbFact>,
}

/// Evidence statement (`evidence(atom, true|false).`).
#[derive(Debug, Clone, PartialEq)]
pub struct Evidence {
    /// The observed atom.
    pub atom: Atom,
    /// Whether the atom is observed true or false.
    pub value: bool,
}

/// Probabilistic query statement (`query(atom).`).
#[derive(Debug, Clone, PartialEq)]
pub struct ProbQuery {
    /// The atom whose probability is being queried.
    pub atom: Atom,
}

/// Import statement: `use module.` or `use module::{pred1, pred2}.`
#[derive(Debug, Clone, PartialEq)]
pub struct UseDecl {
    /// Module path segments, e.g. `["utils", "math"]`.
    pub module_path: Vec<String>,
    /// Specific imports (`None` = import all public items).
    pub imports: Option<Vec<String>>,
}
629
/// Domain declaration: a named domain backed by a scalar type.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainDecl {
    /// Domain name.
    pub name: String,
    /// Scalar type for the domain.
    pub typ: ScalarType,
}

/// A type reference in source declarations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeRef {
    /// Built-in scalar type.
    Scalar(ScalarType),
    /// Domain alias resolved during semantic analysis.
    Domain(String),
    /// Finite homogeneous list type.
    List(Box<TypeRef>),
    /// Finite term type.
    Term,
    /// Finite compound term type.
    Compound,
    /// Static predicate reference type.
    PredRef,
}

/// Predicate declaration column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PredColumn {
    /// Optional source-level column name.
    pub name: Option<String>,
    /// Column type reference.
    pub typ: TypeRef,
}

/// Predicate declaration.
#[derive(Debug, Clone, PartialEq)]
pub struct PredDecl {
    /// Predicate name.
    pub name: String,
    // NOTE(review): `types` appears to duplicate `columns[i].typ`; presumably
    // the parser keeps both in sync -- confirm before relying on either one.
    /// Column types.
    pub types: Vec<TypeRef>,
    /// Declared columns, including optional names.
    pub columns: Vec<PredColumn>,
    /// Whether this predicate is module-private.
    ///
    /// Private predicates (and rules whose head uses them) are not exported
    /// by `Program::merge_from`.
    pub is_private: bool,
}

/// Function parameter with optional type annotation.
#[derive(Debug, Clone, PartialEq)]
pub struct FuncParam {
    /// Parameter name.
    pub name: String,
    /// Optional type annotation.
    pub typ: Option<ScalarType>,
}
686
/// Conditional function body: `if X < 0 then A else B`.
#[derive(Debug, Clone, PartialEq)]
pub struct CondExpr {
    /// Left side of the condition.
    pub cond_left: ArithExpr,
    /// Comparison operator.
    pub cond_op: CompOp,
    /// Right side of the condition.
    pub cond_right: ArithExpr,
    /// Body evaluated when the condition is true.
    pub then_branch: Box<FuncBody>,
    /// Body evaluated when the condition is false.
    pub else_branch: Box<FuncBody>,
}

/// Function body: arithmetic, conditional, or predicate-based.
#[derive(Debug, Clone, PartialEq)]
pub enum FuncBody {
    /// Pure arithmetic expression, e.g. `X * X`.
    Arithmetic(ArithExpr),
    /// Conditional expression, e.g. `if X < 0 then ...`.
    Conditional(CondExpr),
    /// Predicate-based body, e.g. `P :- parent(X, P)`.
    Predicate {
        /// Result variable.
        result: String,
        /// Body literals.
        body: Vec<BodyLiteral>,
    },
}

/// User-defined function.
#[derive(Debug, Clone, PartialEq)]
pub struct FuncDef {
    /// Function name.
    pub name: String,
    /// Parameters, in declaration order.
    pub params: Vec<FuncParam>,
    /// Optional return type annotation.
    pub return_type: Option<ScalarType>,
    /// Function body.
    pub body: FuncBody,
    /// Whether this function is module-private (not exported by
    /// `Program::merge_from`).
    pub is_private: bool,
}
732
/// A complete XLOG program: every declaration, rule, and directive parsed
/// from a source unit, grouped by kind.
#[derive(Debug, Clone, Default)]
pub struct Program {
    /// Import declarations (`use ...`).
    pub imports: Vec<UseDecl>,
    /// User-defined function definitions.
    pub functions: Vec<FuncDef>,
    /// Domain declarations.
    pub domains: Vec<DomainDecl>,
    /// Predicate type declarations.
    pub predicates: Vec<PredDecl>,
    /// Rules and facts (facts are rules with an empty body).
    pub rules: Vec<Rule>,
    /// Integrity constraints (`:- ...`).
    pub constraints: Vec<Constraint>,
    /// Queries (`?- ...`).
    pub queries: Vec<Query>,
    /// Probabilistic facts (`p::atom.`).
    pub prob_facts: Vec<ProbFact>,
    /// Annotated disjunctions.
    pub annotated_disjunctions: Vec<AnnotatedDisjunction>,
    /// Evidence statements.
    pub evidence: Vec<Evidence>,
    /// Probabilistic queries (`query(atom).`).
    pub prob_queries: Vec<ProbQuery>,
    /// Neural predicate declarations.
    pub neural_predicates: Vec<NeuralPredDecl>,
    /// Learnable rule templates (ILP).
    pub learnable_rules: Vec<LearnableRule>,
    /// Compilation directives.
    pub directives: Directives,
}
765
766impl Program {
767    /// Create an empty program.
768    pub fn new() -> Self {
769        Self::default()
770    }
771
772    /// Iterate over ground facts (rules with empty bodies).
773    pub fn facts(&self) -> impl Iterator<Item = &Rule> {
774        self.rules.iter().filter(|r| r.is_fact())
775    }
776
777    /// Iterate over proper rules (non-fact rules with bodies).
778    pub fn proper_rules(&self) -> impl Iterator<Item = &Rule> {
779        self.rules.iter().filter(|r| !r.is_fact())
780    }
781
782    /// Collect the set of predicate names defined (appearing as rule heads).
783    pub fn defined_predicates(&self) -> Vec<&str> {
784        self.rules
785            .iter()
786            .map(|r| r.head.predicate.as_str())
787            .collect::<std::collections::HashSet<_>>()
788            .into_iter()
789            .collect()
790    }
791
792    /// Returns true if this program uses probabilistic features.
793    pub fn is_probabilistic_profile(&self) -> bool {
794        !self.prob_facts.is_empty()
795            || !self.annotated_disjunctions.is_empty()
796            || !self.evidence.is_empty()
797            || !self.prob_queries.is_empty()
798            || self.directives.prob_engine.is_some()
799            || self.directives.prob_cache.is_some()
800            || self.directives.prob_samples.is_some()
801            || self.directives.prob_seed.is_some()
802            || self.directives.prob_confidence.is_some()
803            || self.directives.prob_method.is_some()
804            || self.directives.prob_max_nonmonotone_iterations.is_some()
805    }
806
807    /// Return the probabilistic engine (from directives, or the default).
808    pub fn prob_engine(&self) -> ProbEngine {
809        self.directives.prob_engine_or_default()
810    }
811
812    /// Merge another program's exports into this program.
813    /// Used for importing modules - adds predicates, functions, rules from the imported module.
814    /// Only merges public items (private items are not exported).
815    ///
816    /// # Arguments
817    /// * `other` - The program to merge from
818    /// * `imported_items` - Optional set of specific items to import. If None, imports all public items.
819    pub fn merge_from(
820        &mut self,
821        other: &Program,
822        imported_items: Option<&std::collections::HashSet<String>>,
823    ) {
824        use std::collections::HashSet;
825
826        // Track which predicates are private in the source
827        let private_preds: HashSet<&str> = other
828            .predicates
829            .iter()
830            .filter(|p| p.is_private)
831            .map(|p| p.name.as_str())
832            .collect();
833
834        let _private_funcs: HashSet<&str> = other
835            .functions
836            .iter()
837            .filter(|f| f.is_private)
838            .map(|f| f.name.as_str())
839            .collect();
840
841        // Merge predicate declarations (only public ones)
842        for pred in &other.predicates {
843            if pred.is_private {
844                continue;
845            }
846            // Check if this is in the import list (if specified)
847            if let Some(items) = imported_items {
848                if !items.contains(&pred.name) {
849                    continue;
850                }
851            }
852            // Avoid duplicate declarations
853            if !self.predicates.iter().any(|p| p.name == pred.name) {
854                self.predicates.push(pred.clone());
855            }
856        }
857
858        // Merge functions (only public ones)
859        for func in &other.functions {
860            if func.is_private {
861                continue;
862            }
863            if let Some(items) = imported_items {
864                if !items.contains(&func.name) {
865                    continue;
866                }
867            }
868            // Avoid duplicate functions
869            if !self.functions.iter().any(|f| f.name == func.name) {
870                self.functions.push(func.clone());
871            }
872        }
873
874        // Merge rules (facts and rules for public predicates)
875        for rule in &other.rules {
876            // Skip if the head predicate is private
877            if private_preds.contains(rule.head.predicate.as_str()) {
878                continue;
879            }
880            // Check import list for facts/rules
881            if let Some(items) = imported_items {
882                if !items.contains(&rule.head.predicate) {
883                    continue;
884                }
885            }
886            if !self.rules.iter().any(|existing| existing == rule) {
887                self.rules.push(rule.clone());
888            }
889        }
890
891        // Merge domains
892        for domain in &other.domains {
893            if !self.domains.iter().any(|d| d.name == domain.name) {
894                self.domains.push(domain.clone());
895            }
896        }
897    }
898}
899
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a rule with an empty body: `predicate(terms...)`.
    fn bodiless_rule(predicate: &str, terms: Vec<Term>) -> Rule {
        Rule {
            head: Atom {
                predicate: predicate.to_string(),
                terms,
            },
            body: vec![],
        }
    }

    #[test]
    fn test_term_variable() {
        let term = Term::Variable("X".to_string());
        assert!(term.is_variable());
        assert!(!term.is_constant());
    }

    #[test]
    fn test_term_constant() {
        let term = Term::Integer(42);
        assert!(!term.is_variable());
        assert!(term.is_constant());
    }

    #[test]
    fn test_structured_terms_are_not_constants() {
        assert!(!Term::List(vec![Term::Integer(1)]).is_constant());
        assert!(!Term::Anonymous.is_constant());
        assert!(Term::Symbol(0).is_constant());
    }

    #[test]
    fn test_atom_arity() {
        let atom = Atom {
            predicate: "edge".to_string(),
            terms: vec![Term::Integer(1), Term::Integer(2)],
        };
        assert_eq!(atom.arity(), 2);
    }

    #[test]
    fn test_atom_variables() {
        let atom = Atom {
            predicate: "edge".to_string(),
            terms: vec![Term::Variable("X".to_string()), Term::Integer(2)],
        };
        let vars = atom.variables();
        assert_eq!(vars, vec!["X"]);
    }

    #[test]
    fn test_rule_is_fact() {
        let fact = Rule {
            head: Atom {
                predicate: "edge".to_string(),
                terms: vec![Term::Integer(1), Term::Integer(2)],
            },
            body: vec![],
        };
        assert!(fact.is_fact());
    }

    #[test]
    fn test_rule_has_negation() {
        let rule = Rule {
            head: Atom {
                predicate: "isolated".to_string(),
                terms: vec![Term::Variable("X".to_string())],
            },
            body: vec![
                BodyLiteral::Positive(Atom {
                    predicate: "node".to_string(),
                    terms: vec![Term::Variable("X".to_string())],
                }),
                BodyLiteral::Negated(Atom {
                    predicate: "edge".to_string(),
                    terms: vec![
                        Term::Variable("X".to_string()),
                        Term::Variable("Y".to_string()),
                    ],
                }),
            ],
        };
        assert!(rule.has_negation());
    }

    #[test]
    fn test_program_facts() {
        let mut program = Program::new();
        program.rules.push(Rule {
            head: Atom {
                predicate: "edge".to_string(),
                terms: vec![Term::Integer(1), Term::Integer(2)],
            },
            body: vec![],
        });
        program.rules.push(Rule {
            head: Atom {
                predicate: "reach".to_string(),
                terms: vec![
                    Term::Variable("X".to_string()),
                    Term::Variable("Y".to_string()),
                ],
            },
            body: vec![BodyLiteral::Positive(Atom {
                predicate: "edge".to_string(),
                terms: vec![
                    Term::Variable("X".to_string()),
                    Term::Variable("Y".to_string()),
                ],
            })],
        });
        assert_eq!(program.facts().count(), 1);
        assert_eq!(program.proper_rules().count(), 1);
    }

    #[test]
    fn test_defined_predicates_are_unique() {
        let mut program = Program::new();
        for name in ["p", "q", "p"].iter() {
            program.rules.push(bodiless_rule(name, vec![]));
        }
        // Sort before comparing so the test does not depend on result order.
        let mut preds = program.defined_predicates();
        preds.sort_unstable();
        assert_eq!(preds, vec!["p", "q"]);
    }

    #[test]
    fn test_merge_from_respects_import_list_and_dedups() {
        let edge = bodiless_rule("edge", vec![Term::Integer(1), Term::Integer(2)]);
        let node = bodiless_rule("node", vec![Term::Integer(1)]);
        let mut other = Program::new();
        other.rules.push(edge.clone());
        other.rules.push(node.clone());

        // Selective import: only rules for `edge` are brought in.
        let mut wanted = std::collections::HashSet::new();
        wanted.insert("edge".to_string());
        let mut selective = Program::new();
        selective.merge_from(&other, Some(&wanted));
        assert_eq!(selective.rules, vec![edge.clone()]);

        // Import-all, applied twice, must not duplicate rules.
        let mut everything = Program::new();
        everything.merge_from(&other, None);
        everything.merge_from(&other, None);
        assert_eq!(everything.rules, vec![edge, node]);
    }

    #[test]
    fn test_arith_expr_structure() {
        let expr = ArithExpr::Add(
            Box::new(ArithExpr::Variable("X".to_string())),
            Box::new(ArithExpr::Integer(1)),
        );
        assert!(matches!(expr, ArithExpr::Add(_, _)));
    }

    #[test]
    fn test_is_expr_structure() {
        let is_expr = IsExpr {
            target: "Z".to_string(),
            expr: ArithExpr::Variable("Y".to_string()),
        };
        assert_eq!(is_expr.target, "Z");
    }
}