Skip to main content

xlog_logic/
expand.rs

1//! Inline expansion of user-defined functions.
2
3use crate::ast::{ArithExpr, Atom, BodyLiteral, Comparison, FuncBody, FuncDef, IsExpr, Term, Univ};
4use crate::function::{FunctionError, FunctionRegistry};
5use std::collections::HashMap;
6
/// Context for inline expansion of user-defined functions.
///
/// Holds the registry used to resolve function names together with the
/// bookkeeping needed to bound how deeply calls may be nested during expansion.
pub struct ExpansionContext<'a> {
    /// Registry consulted to look up user-defined functions by name.
    registry: &'a FunctionRegistry,
    /// Number of function-call expansions currently in progress.
    depth: u32,
    /// Limit that `depth` is checked against before each call is expanded.
    max_depth: u32,
}
13
14impl<'a> ExpansionContext<'a> {
15    /// Create an expansion context with the given function registry and recursion limit.
16    pub fn new(registry: &'a FunctionRegistry, max_depth: u32) -> Self {
17        Self {
18            registry,
19            depth: 0,
20            max_depth,
21        }
22    }
23
24    /// Expand a function call to its body with arguments substituted
25    pub fn expand_call(
26        &mut self,
27        name: &str,
28        args: &[ArithExpr],
29    ) -> Result<ArithExpr, FunctionError> {
30        // Check depth limit
31        if self.depth >= self.max_depth {
32            return Err(FunctionError::MaxRecursionDepth {
33                name: name.to_string(),
34                depth: self.max_depth,
35            });
36        }
37
38        let func = self
39            .registry
40            .get(name)
41            .ok_or_else(|| FunctionError::UndefinedFunction {
42                name: name.to_string(),
43            })?;
44
45        // Build substitution map
46        let mut subst: HashMap<String, ArithExpr> = HashMap::new();
47        for (param, arg) in func.params.iter().zip(args.iter()) {
48            subst.insert(param.name.clone(), arg.clone());
49        }
50
51        // Expand body
52        self.depth += 1;
53        let result = self.expand_body(&func.body, &subst)?;
54        self.depth -= 1;
55
56        Ok(result)
57    }
58
59    fn expand_body(
60        &mut self,
61        body: &FuncBody,
62        subst: &HashMap<String, ArithExpr>,
63    ) -> Result<ArithExpr, FunctionError> {
64        match body {
65            FuncBody::Arithmetic(expr) => self.expand_expr(expr, subst),
66            FuncBody::Conditional(cond) => {
67                // Expand condition parts
68                let cond_left = self.expand_expr(&cond.cond_left, subst)?;
69                let cond_right = self.expand_expr(&cond.cond_right, subst)?;
70                let then_expr = self.expand_body(&cond.then_branch, subst)?;
71                let else_expr = self.expand_body(&cond.else_branch, subst)?;
72
73                // Return a conditional ArithExpr (need to represent this)
74                // For now, we create a structure that captures the conditional
75                // The actual evaluation will happen at runtime
76                Ok(ArithExpr::Conditional {
77                    cond_left: Box::new(cond_left),
78                    cond_op: cond.cond_op,
79                    cond_right: Box::new(cond_right),
80                    then_expr: Box::new(then_expr),
81                    else_expr: Box::new(else_expr),
82                })
83            }
84            FuncBody::Predicate { result, .. } => {
85                // Predicate bodies are expanded via expand_predicate_func at the rule level.
86                // When expand_body is called for a predicate body (shouldn't happen directly),
87                // return the result variable after substitution as an ArithExpr.
88                // The actual expansion to join literals happens via expand_predicate_func.
89                let result_var = subst
90                    .get(result)
91                    .cloned()
92                    .unwrap_or_else(|| ArithExpr::Variable(result.clone()));
93                Ok(result_var)
94            }
95        }
96    }
97
98    fn expand_expr(
99        &mut self,
100        expr: &ArithExpr,
101        subst: &HashMap<String, ArithExpr>,
102    ) -> Result<ArithExpr, FunctionError> {
103        match expr {
104            ArithExpr::Variable(name) => {
105                Ok(subst.get(name).cloned().unwrap_or_else(|| expr.clone()))
106            }
107            ArithExpr::Integer(_) | ArithExpr::Float(_) => Ok(expr.clone()),
108            ArithExpr::FuncCall { name, args } => {
109                // First expand arguments
110                let expanded_args: Result<Vec<_>, _> =
111                    args.iter().map(|a| self.expand_expr(a, subst)).collect();
112                let expanded_args = expanded_args?;
113
114                // Then expand the function call if it's a UDF
115                if self.registry.contains(name) {
116                    self.expand_call(name, &expanded_args)
117                } else {
118                    // Built-in function, just return with expanded args
119                    Ok(ArithExpr::FuncCall {
120                        name: name.clone(),
121                        args: expanded_args,
122                    })
123                }
124            }
125            ArithExpr::Add(l, r) => {
126                let el = self.expand_expr(l, subst)?;
127                let er = self.expand_expr(r, subst)?;
128                Ok(ArithExpr::Add(Box::new(el), Box::new(er)))
129            }
130            ArithExpr::Sub(l, r) => {
131                let el = self.expand_expr(l, subst)?;
132                let er = self.expand_expr(r, subst)?;
133                Ok(ArithExpr::Sub(Box::new(el), Box::new(er)))
134            }
135            ArithExpr::Mul(l, r) => {
136                let el = self.expand_expr(l, subst)?;
137                let er = self.expand_expr(r, subst)?;
138                Ok(ArithExpr::Mul(Box::new(el), Box::new(er)))
139            }
140            ArithExpr::Div(l, r) => {
141                let el = self.expand_expr(l, subst)?;
142                let er = self.expand_expr(r, subst)?;
143                Ok(ArithExpr::Div(Box::new(el), Box::new(er)))
144            }
145            ArithExpr::Mod(l, r) => {
146                let el = self.expand_expr(l, subst)?;
147                let er = self.expand_expr(r, subst)?;
148                Ok(ArithExpr::Mod(Box::new(el), Box::new(er)))
149            }
150            ArithExpr::Abs(e) => {
151                let ee = self.expand_expr(e, subst)?;
152                Ok(ArithExpr::Abs(Box::new(ee)))
153            }
154            ArithExpr::Min(l, r) => {
155                let el = self.expand_expr(l, subst)?;
156                let er = self.expand_expr(r, subst)?;
157                Ok(ArithExpr::Min(Box::new(el), Box::new(er)))
158            }
159            ArithExpr::Max(l, r) => {
160                let el = self.expand_expr(l, subst)?;
161                let er = self.expand_expr(r, subst)?;
162                Ok(ArithExpr::Max(Box::new(el), Box::new(er)))
163            }
164            ArithExpr::Pow(l, r) => {
165                let el = self.expand_expr(l, subst)?;
166                let er = self.expand_expr(r, subst)?;
167                Ok(ArithExpr::Pow(Box::new(el), Box::new(er)))
168            }
169            ArithExpr::Cast(e, t) => {
170                let ee = self.expand_expr(e, subst)?;
171                Ok(ArithExpr::Cast(Box::new(ee), *t))
172            }
173            ArithExpr::Conditional {
174                cond_left,
175                cond_op,
176                cond_right,
177                then_expr,
178                else_expr,
179            } => {
180                let cl = self.expand_expr(cond_left, subst)?;
181                let cr = self.expand_expr(cond_right, subst)?;
182                let te = self.expand_expr(then_expr, subst)?;
183                let ee = self.expand_expr(else_expr, subst)?;
184                Ok(ArithExpr::Conditional {
185                    cond_left: Box::new(cl),
186                    cond_op: *cond_op,
187                    cond_right: Box::new(cr),
188                    then_expr: Box::new(te),
189                    else_expr: Box::new(ee),
190                })
191            }
192        }
193    }
194
195    /// Expand a predicate-based function call to join literals.
196    ///
197    /// Predicate functions like `func get_parent(X) = P :- parent(X, P).`
198    /// expand to body literals that get added to the calling rule.
199    ///
200    /// Returns the expanded body literals and the result variable name.
201    #[allow(dead_code)] // reserved API: predicate-func expansion not yet wired
202    pub(crate) fn expand_predicate_func(
203        &self,
204        func: &FuncDef,
205        args: &[ArithExpr],
206    ) -> Result<(Vec<BodyLiteral>, String), FunctionError> {
207        match &func.body {
208            FuncBody::Predicate { result, body } => {
209                // Build substitution map from params to args
210                let mut subst: HashMap<String, ArithExpr> = HashMap::new();
211                for (param, arg) in func.params.iter().zip(args.iter()) {
212                    subst.insert(param.name.clone(), arg.clone());
213                }
214
215                // Substitute in body literals
216                let expanded_body: Vec<BodyLiteral> = body
217                    .iter()
218                    .map(|lit| self.substitute_literal(lit, &subst))
219                    .collect();
220
221                // The result variable becomes the output (substitute if mapped)
222                let result_var = self.substitute_var(result, &subst);
223
224                Ok((expanded_body, result_var))
225            }
226            _ => Err(FunctionError::UndefinedFunction {
227                name: func.name.clone(),
228            }),
229        }
230    }
231
232    /// Substitute variables in a body literal using the given substitution map.
233    fn substitute_literal(
234        &self,
235        lit: &BodyLiteral,
236        subst: &HashMap<String, ArithExpr>,
237    ) -> BodyLiteral {
238        match lit {
239            BodyLiteral::Positive(atom) => BodyLiteral::Positive(self.substitute_atom(atom, subst)),
240            BodyLiteral::Negated(atom) => BodyLiteral::Negated(self.substitute_atom(atom, subst)),
241            BodyLiteral::Epistemic(lit) => BodyLiteral::Epistemic(crate::ast::EpistemicLiteral {
242                op: lit.op,
243                negated: lit.negated,
244                atom: self.substitute_atom(&lit.atom, subst),
245            }),
246            BodyLiteral::Comparison(cmp) => BodyLiteral::Comparison(Comparison {
247                left: self.substitute_term(&cmp.left, subst),
248                op: cmp.op,
249                right: self.substitute_term(&cmp.right, subst),
250            }),
251            BodyLiteral::IsExpr(is_expr) => {
252                // Substitute in both the target variable and the expression
253                let target = self.substitute_var(&is_expr.target, subst);
254                // For ArithExpr substitution, we need to substitute variables
255                let expr = self.substitute_arith_expr(&is_expr.expr, subst);
256                BodyLiteral::IsExpr(IsExpr { target, expr })
257            }
258            BodyLiteral::Univ(univ) => BodyLiteral::Univ(Univ {
259                term: self.substitute_term(&univ.term, subst),
260                parts: self.substitute_term(&univ.parts, subst),
261            }),
262        }
263    }
264
265    /// Substitute variables in an atom.
266    fn substitute_atom(&self, atom: &Atom, subst: &HashMap<String, ArithExpr>) -> Atom {
267        Atom {
268            predicate: atom.predicate.clone(),
269            terms: atom
270                .terms
271                .iter()
272                .map(|t| self.substitute_term(t, subst))
273                .collect(),
274        }
275    }
276
277    /// Substitute a variable in a term.
278    fn substitute_term(&self, term: &Term, subst: &HashMap<String, ArithExpr>) -> Term {
279        match term {
280            Term::Variable(name) => {
281                if let Some(replacement) = subst.get(name) {
282                    match replacement {
283                        ArithExpr::Variable(new_name) => Term::Variable(new_name.clone()),
284                        ArithExpr::Integer(n) => Term::Integer(*n),
285                        ArithExpr::Float(f) => Term::Float(*f),
286                        // For complex expressions, we can't directly substitute into a Term,
287                        // so we keep the original variable (this is a limitation)
288                        _ => term.clone(),
289                    }
290                } else {
291                    term.clone()
292                }
293            }
294            _ => term.clone(),
295        }
296    }
297
298    /// Substitute variables in an arithmetic expression.
299    fn substitute_arith_expr(
300        &self,
301        expr: &ArithExpr,
302        subst: &HashMap<String, ArithExpr>,
303    ) -> ArithExpr {
304        match expr {
305            ArithExpr::Variable(name) => subst.get(name).cloned().unwrap_or_else(|| expr.clone()),
306            ArithExpr::Integer(_) | ArithExpr::Float(_) => expr.clone(),
307            ArithExpr::Add(l, r) => ArithExpr::Add(
308                Box::new(self.substitute_arith_expr(l, subst)),
309                Box::new(self.substitute_arith_expr(r, subst)),
310            ),
311            ArithExpr::Sub(l, r) => ArithExpr::Sub(
312                Box::new(self.substitute_arith_expr(l, subst)),
313                Box::new(self.substitute_arith_expr(r, subst)),
314            ),
315            ArithExpr::Mul(l, r) => ArithExpr::Mul(
316                Box::new(self.substitute_arith_expr(l, subst)),
317                Box::new(self.substitute_arith_expr(r, subst)),
318            ),
319            ArithExpr::Div(l, r) => ArithExpr::Div(
320                Box::new(self.substitute_arith_expr(l, subst)),
321                Box::new(self.substitute_arith_expr(r, subst)),
322            ),
323            ArithExpr::Mod(l, r) => ArithExpr::Mod(
324                Box::new(self.substitute_arith_expr(l, subst)),
325                Box::new(self.substitute_arith_expr(r, subst)),
326            ),
327            ArithExpr::Abs(e) => ArithExpr::Abs(Box::new(self.substitute_arith_expr(e, subst))),
328            ArithExpr::Min(l, r) => ArithExpr::Min(
329                Box::new(self.substitute_arith_expr(l, subst)),
330                Box::new(self.substitute_arith_expr(r, subst)),
331            ),
332            ArithExpr::Max(l, r) => ArithExpr::Max(
333                Box::new(self.substitute_arith_expr(l, subst)),
334                Box::new(self.substitute_arith_expr(r, subst)),
335            ),
336            ArithExpr::Pow(l, r) => ArithExpr::Pow(
337                Box::new(self.substitute_arith_expr(l, subst)),
338                Box::new(self.substitute_arith_expr(r, subst)),
339            ),
340            ArithExpr::Cast(e, t) => {
341                ArithExpr::Cast(Box::new(self.substitute_arith_expr(e, subst)), *t)
342            }
343            ArithExpr::FuncCall { name, args } => ArithExpr::FuncCall {
344                name: name.clone(),
345                args: args
346                    .iter()
347                    .map(|a| self.substitute_arith_expr(a, subst))
348                    .collect(),
349            },
350            ArithExpr::Conditional {
351                cond_left,
352                cond_op,
353                cond_right,
354                then_expr,
355                else_expr,
356            } => ArithExpr::Conditional {
357                cond_left: Box::new(self.substitute_arith_expr(cond_left, subst)),
358                cond_op: *cond_op,
359                cond_right: Box::new(self.substitute_arith_expr(cond_right, subst)),
360                then_expr: Box::new(self.substitute_arith_expr(then_expr, subst)),
361                else_expr: Box::new(self.substitute_arith_expr(else_expr, subst)),
362            },
363        }
364    }
365
366    /// Substitute a variable name using the substitution map.
367    fn substitute_var(&self, var: &str, subst: &HashMap<String, ArithExpr>) -> String {
368        if let Some(ArithExpr::Variable(new_name)) = subst.get(var) {
369            new_name.clone()
370        } else {
371            var.to_string()
372        }
373    }
374
375    /// Check if a function has a predicate body.
376    #[allow(dead_code)] // reserved API: predicate-func expansion not yet wired
377    pub(crate) fn is_predicate_func(&self, name: &str) -> bool {
378        self.registry
379            .get(name)
380            .map(|f| matches!(f.body, FuncBody::Predicate { .. }))
381            .unwrap_or(false)
382    }
383
384    /// Expand all function calls in an arithmetic expression.
385    /// Returns the expanded expression with all UDF calls inlined.
386    pub(crate) fn expand_expr_fully(
387        &mut self,
388        expr: &ArithExpr,
389    ) -> Result<ArithExpr, FunctionError> {
390        match expr {
391            ArithExpr::Variable(_) | ArithExpr::Integer(_) | ArithExpr::Float(_) => {
392                Ok(expr.clone())
393            }
394            ArithExpr::FuncCall { name, args } => {
395                // First expand arguments
396                let expanded_args: Result<Vec<_>, _> =
397                    args.iter().map(|a| self.expand_expr_fully(a)).collect();
398                let expanded_args = expanded_args?;
399
400                // Then expand the function call if it's a UDF
401                if self.registry.contains(name) {
402                    self.expand_call(name, &expanded_args)
403                } else {
404                    // Built-in function, just return with expanded args
405                    Ok(ArithExpr::FuncCall {
406                        name: name.clone(),
407                        args: expanded_args,
408                    })
409                }
410            }
411            ArithExpr::Add(l, r) => Ok(ArithExpr::Add(
412                Box::new(self.expand_expr_fully(l)?),
413                Box::new(self.expand_expr_fully(r)?),
414            )),
415            ArithExpr::Sub(l, r) => Ok(ArithExpr::Sub(
416                Box::new(self.expand_expr_fully(l)?),
417                Box::new(self.expand_expr_fully(r)?),
418            )),
419            ArithExpr::Mul(l, r) => Ok(ArithExpr::Mul(
420                Box::new(self.expand_expr_fully(l)?),
421                Box::new(self.expand_expr_fully(r)?),
422            )),
423            ArithExpr::Div(l, r) => Ok(ArithExpr::Div(
424                Box::new(self.expand_expr_fully(l)?),
425                Box::new(self.expand_expr_fully(r)?),
426            )),
427            ArithExpr::Mod(l, r) => Ok(ArithExpr::Mod(
428                Box::new(self.expand_expr_fully(l)?),
429                Box::new(self.expand_expr_fully(r)?),
430            )),
431            ArithExpr::Abs(e) => Ok(ArithExpr::Abs(Box::new(self.expand_expr_fully(e)?))),
432            ArithExpr::Min(l, r) => Ok(ArithExpr::Min(
433                Box::new(self.expand_expr_fully(l)?),
434                Box::new(self.expand_expr_fully(r)?),
435            )),
436            ArithExpr::Max(l, r) => Ok(ArithExpr::Max(
437                Box::new(self.expand_expr_fully(l)?),
438                Box::new(self.expand_expr_fully(r)?),
439            )),
440            ArithExpr::Pow(l, r) => Ok(ArithExpr::Pow(
441                Box::new(self.expand_expr_fully(l)?),
442                Box::new(self.expand_expr_fully(r)?),
443            )),
444            ArithExpr::Cast(e, t) => Ok(ArithExpr::Cast(Box::new(self.expand_expr_fully(e)?), *t)),
445            ArithExpr::Conditional {
446                cond_left,
447                cond_op,
448                cond_right,
449                then_expr,
450                else_expr,
451            } => Ok(ArithExpr::Conditional {
452                cond_left: Box::new(self.expand_expr_fully(cond_left)?),
453                cond_op: *cond_op,
454                cond_right: Box::new(self.expand_expr_fully(cond_right)?),
455                then_expr: Box::new(self.expand_expr_fully(then_expr)?),
456                else_expr: Box::new(self.expand_expr_fully(else_expr)?),
457            }),
458        }
459    }
460}
461
462use crate::ast::{Program, Rule};
463
464/// Expand all user-defined function calls in a program.
465/// Returns a new program with all UDF calls replaced by their expanded bodies.
466/// Expand all user-defined function calls in the program to inline arithmetic.
467pub fn expand_program_functions(
468    program: &Program,
469    max_depth: u32,
470) -> Result<Program, FunctionError> {
471    // Build function registry from program
472    let mut registry = FunctionRegistry::new();
473    for func in &program.functions {
474        registry.register(func.clone())?;
475    }
476
477    // If no functions defined, return program unchanged
478    if program.functions.is_empty() {
479        return Ok(program.clone());
480    }
481
482    let mut ctx = ExpansionContext::new(&registry, max_depth);
483
484    // Expand function calls in each rule
485    let expanded_rules: Result<Vec<Rule>, FunctionError> = program
486        .rules
487        .iter()
488        .map(|rule| expand_rule_functions(&mut ctx, rule))
489        .collect();
490
491    Ok(Program {
492        rules: expanded_rules?,
493        directives: program.directives.clone(),
494        queries: program.queries.clone(),
495        predicates: program.predicates.clone(),
496        constraints: program.constraints.clone(),
497        imports: program.imports.clone(),
498        functions: program.functions.clone(),
499        domains: program.domains.clone(),
500        prob_facts: program.prob_facts.clone(),
501        annotated_disjunctions: program.annotated_disjunctions.clone(),
502        evidence: program.evidence.clone(),
503        prob_queries: program.prob_queries.clone(),
504        neural_predicates: program.neural_predicates.clone(),
505        learnable_rules: program.learnable_rules.clone(),
506    })
507}
508
509/// Expand function calls in a single rule.
510fn expand_rule_functions(ctx: &mut ExpansionContext, rule: &Rule) -> Result<Rule, FunctionError> {
511    let expanded_body: Result<Vec<BodyLiteral>, FunctionError> = rule
512        .body
513        .iter()
514        .map(|lit| expand_literal_functions(ctx, lit))
515        .collect();
516
517    Ok(Rule {
518        head: rule.head.clone(),
519        body: expanded_body?,
520    })
521}
522
523/// Expand function calls in a body literal.
524fn expand_literal_functions(
525    ctx: &mut ExpansionContext,
526    lit: &BodyLiteral,
527) -> Result<BodyLiteral, FunctionError> {
528    match lit {
529        BodyLiteral::Positive(atom) => Ok(BodyLiteral::Positive(atom.clone())),
530        BodyLiteral::Negated(atom) => Ok(BodyLiteral::Negated(atom.clone())),
531        BodyLiteral::Epistemic(lit) => Ok(BodyLiteral::Epistemic(lit.clone())),
532        BodyLiteral::Comparison(cmp) => Ok(BodyLiteral::Comparison(cmp.clone())),
533        BodyLiteral::IsExpr(is_expr) => {
534            let expanded_expr = ctx.expand_expr_fully(&is_expr.expr)?;
535            Ok(BodyLiteral::IsExpr(IsExpr {
536                target: is_expr.target.clone(),
537                expr: expanded_expr,
538            }))
539        }
540        BodyLiteral::Univ(univ) => Ok(BodyLiteral::Univ(univ.clone())),
541    }
542}
543
544#[cfg(test)]
545mod tests {
546    use super::*;
547    use crate::ast::{FuncDef, FuncParam};
548
549    #[test]
550    fn test_simple_expansion() {
551        let mut reg = FunctionRegistry::new();
552
553        // func double(X) = X + X
554        let double = FuncDef {
555            name: "double".to_string(),
556            params: vec![FuncParam {
557                name: "X".to_string(),
558                typ: None,
559            }],
560            return_type: None,
561            body: FuncBody::Arithmetic(ArithExpr::Add(
562                Box::new(ArithExpr::Variable("X".to_string())),
563                Box::new(ArithExpr::Variable("X".to_string())),
564            )),
565            is_private: false,
566        };
567        reg.register(double).unwrap();
568
569        let mut ctx = ExpansionContext::new(&reg, 100);
570
571        // double(5) should expand to 5 + 5
572        let result = ctx.expand_call("double", &[ArithExpr::Integer(5)]).unwrap();
573
574        match result {
575            ArithExpr::Add(l, r) => {
576                assert!(matches!(*l, ArithExpr::Integer(5)));
577                assert!(matches!(*r, ArithExpr::Integer(5)));
578            }
579            _ => panic!("Expected Add expression"),
580        }
581    }
582
583    #[test]
584    fn test_nested_expansion() {
585        let mut reg = FunctionRegistry::new();
586
587        // func double(X) = X + X
588        let double = FuncDef {
589            name: "double".to_string(),
590            params: vec![FuncParam {
591                name: "X".to_string(),
592                typ: None,
593            }],
594            return_type: None,
595            body: FuncBody::Arithmetic(ArithExpr::Add(
596                Box::new(ArithExpr::Variable("X".to_string())),
597                Box::new(ArithExpr::Variable("X".to_string())),
598            )),
599            is_private: false,
600        };
601
602        // func quadruple(X) = double(double(X))
603        let quadruple = FuncDef {
604            name: "quadruple".to_string(),
605            params: vec![FuncParam {
606                name: "X".to_string(),
607                typ: None,
608            }],
609            return_type: None,
610            body: FuncBody::Arithmetic(ArithExpr::FuncCall {
611                name: "double".to_string(),
612                args: vec![ArithExpr::FuncCall {
613                    name: "double".to_string(),
614                    args: vec![ArithExpr::Variable("X".to_string())],
615                }],
616            }),
617            is_private: false,
618        };
619
620        reg.register(double).unwrap();
621        reg.register(quadruple).unwrap();
622
623        let mut ctx = ExpansionContext::new(&reg, 100);
624
625        // quadruple(2) should expand to (2 + 2) + (2 + 2)
626        let result = ctx
627            .expand_call("quadruple", &[ArithExpr::Integer(2)])
628            .unwrap();
629
630        // Result should be Add(Add(2, 2), Add(2, 2))
631        match &result {
632            ArithExpr::Add(l, r) => {
633                assert!(matches!(l.as_ref(), ArithExpr::Add(_, _)));
634                assert!(matches!(r.as_ref(), ArithExpr::Add(_, _)));
635            }
636            _ => panic!("Expected nested Add expression, got {:?}", result),
637        }
638    }
639
640    #[test]
641    fn test_max_recursion_depth() {
642        let mut reg = FunctionRegistry::new();
643
644        // func infinite(X) = infinite(X)
645        let infinite = FuncDef {
646            name: "infinite".to_string(),
647            params: vec![FuncParam {
648                name: "X".to_string(),
649                typ: None,
650            }],
651            return_type: None,
652            body: FuncBody::Arithmetic(ArithExpr::FuncCall {
653                name: "infinite".to_string(),
654                args: vec![ArithExpr::Variable("X".to_string())],
655            }),
656            is_private: false,
657        };
658        reg.register(infinite).unwrap();
659
660        let mut ctx = ExpansionContext::new(&reg, 10);
661
662        let result = ctx.expand_call("infinite", &[ArithExpr::Integer(1)]);
663        assert!(matches!(
664            result,
665            Err(FunctionError::MaxRecursionDepth { .. })
666        ));
667    }
668
669    #[test]
670    fn test_undefined_function() {
671        let reg = FunctionRegistry::new();
672        let mut ctx = ExpansionContext::new(&reg, 100);
673
674        let result = ctx.expand_call("undefined", &[ArithExpr::Integer(1)]);
675        assert!(matches!(
676            result,
677            Err(FunctionError::UndefinedFunction { .. })
678        ));
679    }
680
681    #[test]
682    fn test_builtin_function_passthrough() {
683        let mut reg = FunctionRegistry::new();
684
685        // func abs_x(X) = abs(X)
686        let abs_x = FuncDef {
687            name: "abs_x".to_string(),
688            params: vec![FuncParam {
689                name: "X".to_string(),
690                typ: None,
691            }],
692            return_type: None,
693            body: FuncBody::Arithmetic(ArithExpr::FuncCall {
694                name: "abs".to_string(),
695                args: vec![ArithExpr::Variable("X".to_string())],
696            }),
697            is_private: false,
698        };
699        reg.register(abs_x).unwrap();
700
701        let mut ctx = ExpansionContext::new(&reg, 100);
702
703        let result = ctx.expand_call("abs_x", &[ArithExpr::Integer(-5)]).unwrap();
704
705        // Should preserve abs call with substituted arg
706        match result {
707            ArithExpr::FuncCall { name, args } => {
708                assert_eq!(name, "abs");
709                assert_eq!(args.len(), 1);
710                assert!(matches!(args[0], ArithExpr::Integer(-5)));
711            }
712            _ => panic!("Expected FuncCall for builtin"),
713        }
714    }
715
716    #[test]
717    fn test_variable_substitution() {
718        let mut reg = FunctionRegistry::new();
719
720        // func add(X, Y) = X + Y
721        let add = FuncDef {
722            name: "add".to_string(),
723            params: vec![
724                FuncParam {
725                    name: "X".to_string(),
726                    typ: None,
727                },
728                FuncParam {
729                    name: "Y".to_string(),
730                    typ: None,
731                },
732            ],
733            return_type: None,
734            body: FuncBody::Arithmetic(ArithExpr::Add(
735                Box::new(ArithExpr::Variable("X".to_string())),
736                Box::new(ArithExpr::Variable("Y".to_string())),
737            )),
738            is_private: false,
739        };
740        reg.register(add).unwrap();
741
742        let mut ctx = ExpansionContext::new(&reg, 100);
743
744        // add(3, 7) should expand to 3 + 7
745        let result = ctx
746            .expand_call("add", &[ArithExpr::Integer(3), ArithExpr::Integer(7)])
747            .unwrap();
748
749        match result {
750            ArithExpr::Add(l, r) => {
751                assert!(matches!(*l, ArithExpr::Integer(3)));
752                assert!(matches!(*r, ArithExpr::Integer(7)));
753            }
754            _ => panic!("Expected Add expression"),
755        }
756    }
757
758    #[test]
759    fn test_expansion_with_variable_args() {
760        let mut reg = FunctionRegistry::new();
761
762        // func double(X) = X + X
763        let double = FuncDef {
764            name: "double".to_string(),
765            params: vec![FuncParam {
766                name: "X".to_string(),
767                typ: None,
768            }],
769            return_type: None,
770            body: FuncBody::Arithmetic(ArithExpr::Add(
771                Box::new(ArithExpr::Variable("X".to_string())),
772                Box::new(ArithExpr::Variable("X".to_string())),
773            )),
774            is_private: false,
775        };
776        reg.register(double).unwrap();
777
778        let mut ctx = ExpansionContext::new(&reg, 100);
779
780        // double(Y) should expand to Y + Y
781        let result = ctx
782            .expand_call("double", &[ArithExpr::Variable("Y".to_string())])
783            .unwrap();
784
785        match result {
786            ArithExpr::Add(l, r) => {
787                assert!(matches!(l.as_ref(), ArithExpr::Variable(n) if n == "Y"));
788                assert!(matches!(r.as_ref(), ArithExpr::Variable(n) if n == "Y"));
789            }
790            _ => panic!("Expected Add expression"),
791        }
792    }
793
794    #[test]
795    fn test_predicate_func_expansion() {
796        // func get_parent(X) = P :- parent(X, P).
797        // get_parent(alice) should expand to: parent(alice, P)
798
799        let func = FuncDef {
800            name: "get_parent".to_string(),
801            params: vec![FuncParam {
802                name: "X".to_string(),
803                typ: None,
804            }],
805            return_type: None,
806            body: FuncBody::Predicate {
807                result: "P".to_string(),
808                body: vec![BodyLiteral::Positive(Atom {
809                    predicate: "parent".to_string(),
810                    terms: vec![
811                        Term::Variable("X".to_string()),
812                        Term::Variable("P".to_string()),
813                    ],
814                })],
815            },
816            is_private: false,
817        };
818
819        let mut reg = FunctionRegistry::new();
820        reg.register(func).unwrap();
821
822        let ctx = ExpansionContext::new(&reg, 100);
823
824        // Call get_parent with "alice"
825        let args = vec![ArithExpr::Variable("alice".to_string())];
826        let func_def = reg.get("get_parent").unwrap();
827        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();
828
829        assert_eq!(result, "P");
830        assert_eq!(body.len(), 1);
831
832        // Check the expanded literal
833        if let BodyLiteral::Positive(atom) = &body[0] {
834            assert_eq!(atom.predicate, "parent");
835            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "alice"));
836            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "P"));
837        } else {
838            panic!("Expected Positive literal");
839        }
840    }
841
842    #[test]
843    fn test_predicate_func_with_constant_arg() {
844        // func get_child(P) = C :- parent(C, P).
845        // get_child(bob) should expand to: parent(C, bob)
846
847        let func = FuncDef {
848            name: "get_child".to_string(),
849            params: vec![FuncParam {
850                name: "P".to_string(),
851                typ: None,
852            }],
853            return_type: None,
854            body: FuncBody::Predicate {
855                result: "C".to_string(),
856                body: vec![BodyLiteral::Positive(Atom {
857                    predicate: "parent".to_string(),
858                    terms: vec![
859                        Term::Variable("C".to_string()),
860                        Term::Variable("P".to_string()),
861                    ],
862                })],
863            },
864            is_private: false,
865        };
866
867        let mut reg = FunctionRegistry::new();
868        reg.register(func).unwrap();
869
870        let ctx = ExpansionContext::new(&reg, 100);
871
872        // Call get_child with integer constant
873        let args = vec![ArithExpr::Integer(42)];
874        let func_def = reg.get("get_child").unwrap();
875        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();
876
877        assert_eq!(result, "C");
878        assert_eq!(body.len(), 1);
879
880        // Check the expanded literal has integer substituted
881        if let BodyLiteral::Positive(atom) = &body[0] {
882            assert_eq!(atom.predicate, "parent");
883            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "C"));
884            assert!(matches!(&atom.terms[1], Term::Integer(42)));
885        } else {
886            panic!("Expected Positive literal");
887        }
888    }
889
890    #[test]
891    fn test_predicate_func_multiple_body_literals() {
892        // func get_grandparent(X) = G :- parent(X, P), parent(P, G).
893        // get_grandparent(alice) should expand to: parent(alice, P), parent(P, G)
894
895        let func = FuncDef {
896            name: "get_grandparent".to_string(),
897            params: vec![FuncParam {
898                name: "X".to_string(),
899                typ: None,
900            }],
901            return_type: None,
902            body: FuncBody::Predicate {
903                result: "G".to_string(),
904                body: vec![
905                    BodyLiteral::Positive(Atom {
906                        predicate: "parent".to_string(),
907                        terms: vec![
908                            Term::Variable("X".to_string()),
909                            Term::Variable("P".to_string()),
910                        ],
911                    }),
912                    BodyLiteral::Positive(Atom {
913                        predicate: "parent".to_string(),
914                        terms: vec![
915                            Term::Variable("P".to_string()),
916                            Term::Variable("G".to_string()),
917                        ],
918                    }),
919                ],
920            },
921            is_private: false,
922        };
923
924        let mut reg = FunctionRegistry::new();
925        reg.register(func).unwrap();
926
927        let ctx = ExpansionContext::new(&reg, 100);
928
929        let args = vec![ArithExpr::Variable("alice".to_string())];
930        let func_def = reg.get("get_grandparent").unwrap();
931        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();
932
933        assert_eq!(result, "G");
934        assert_eq!(body.len(), 2);
935
936        // First literal: parent(alice, P)
937        if let BodyLiteral::Positive(atom) = &body[0] {
938            assert_eq!(atom.predicate, "parent");
939            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "alice"));
940            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "P"));
941        } else {
942            panic!("Expected Positive literal for first body");
943        }
944
945        // Second literal: parent(P, G)
946        if let BodyLiteral::Positive(atom) = &body[1] {
947            assert_eq!(atom.predicate, "parent");
948            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "P"));
949            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "G"));
950        } else {
951            panic!("Expected Positive literal for second body");
952        }
953    }
954
955    #[test]
956    fn test_is_predicate_func() {
957        let mut reg = FunctionRegistry::new();
958
959        // Arithmetic function
960        let arith_func = FuncDef {
961            name: "double".to_string(),
962            params: vec![FuncParam {
963                name: "X".to_string(),
964                typ: None,
965            }],
966            return_type: None,
967            body: FuncBody::Arithmetic(ArithExpr::Add(
968                Box::new(ArithExpr::Variable("X".to_string())),
969                Box::new(ArithExpr::Variable("X".to_string())),
970            )),
971            is_private: false,
972        };
973
974        // Predicate function
975        let pred_func = FuncDef {
976            name: "get_parent".to_string(),
977            params: vec![FuncParam {
978                name: "X".to_string(),
979                typ: None,
980            }],
981            return_type: None,
982            body: FuncBody::Predicate {
983                result: "P".to_string(),
984                body: vec![BodyLiteral::Positive(Atom {
985                    predicate: "parent".to_string(),
986                    terms: vec![
987                        Term::Variable("X".to_string()),
988                        Term::Variable("P".to_string()),
989                    ],
990                })],
991            },
992            is_private: false,
993        };
994
995        reg.register(arith_func).unwrap();
996        reg.register(pred_func).unwrap();
997
998        let ctx = ExpansionContext::new(&reg, 100);
999
1000        assert!(!ctx.is_predicate_func("double"));
1001        assert!(ctx.is_predicate_func("get_parent"));
1002        assert!(!ctx.is_predicate_func("nonexistent"));
1003    }
1004}