1use crate::ast::{ArithExpr, Atom, BodyLiteral, Comparison, FuncBody, FuncDef, IsExpr, Term, Univ};
4use crate::function::{FunctionError, FunctionRegistry};
5use std::collections::HashMap;
6
/// State carried while inlining calls to user-defined functions into
/// arithmetic expressions.
pub struct ExpansionContext<'a> {
    /// Registry that function names are looked up in.
    registry: &'a FunctionRegistry,
    /// Number of `expand_call` invocations currently in progress.
    depth: u32,
    /// Nesting level at which `expand_call` refuses to go any deeper.
    max_depth: u32,
}
13
14impl<'a> ExpansionContext<'a> {
15 pub fn new(registry: &'a FunctionRegistry, max_depth: u32) -> Self {
17 Self {
18 registry,
19 depth: 0,
20 max_depth,
21 }
22 }
23
24 pub fn expand_call(
26 &mut self,
27 name: &str,
28 args: &[ArithExpr],
29 ) -> Result<ArithExpr, FunctionError> {
30 if self.depth >= self.max_depth {
32 return Err(FunctionError::MaxRecursionDepth {
33 name: name.to_string(),
34 depth: self.max_depth,
35 });
36 }
37
38 let func = self
39 .registry
40 .get(name)
41 .ok_or_else(|| FunctionError::UndefinedFunction {
42 name: name.to_string(),
43 })?;
44
45 let mut subst: HashMap<String, ArithExpr> = HashMap::new();
47 for (param, arg) in func.params.iter().zip(args.iter()) {
48 subst.insert(param.name.clone(), arg.clone());
49 }
50
51 self.depth += 1;
53 let result = self.expand_body(&func.body, &subst)?;
54 self.depth -= 1;
55
56 Ok(result)
57 }
58
59 fn expand_body(
60 &mut self,
61 body: &FuncBody,
62 subst: &HashMap<String, ArithExpr>,
63 ) -> Result<ArithExpr, FunctionError> {
64 match body {
65 FuncBody::Arithmetic(expr) => self.expand_expr(expr, subst),
66 FuncBody::Conditional(cond) => {
67 let cond_left = self.expand_expr(&cond.cond_left, subst)?;
69 let cond_right = self.expand_expr(&cond.cond_right, subst)?;
70 let then_expr = self.expand_body(&cond.then_branch, subst)?;
71 let else_expr = self.expand_body(&cond.else_branch, subst)?;
72
73 Ok(ArithExpr::Conditional {
77 cond_left: Box::new(cond_left),
78 cond_op: cond.cond_op,
79 cond_right: Box::new(cond_right),
80 then_expr: Box::new(then_expr),
81 else_expr: Box::new(else_expr),
82 })
83 }
84 FuncBody::Predicate { result, .. } => {
85 let result_var = subst
90 .get(result)
91 .cloned()
92 .unwrap_or_else(|| ArithExpr::Variable(result.clone()));
93 Ok(result_var)
94 }
95 }
96 }
97
98 fn expand_expr(
99 &mut self,
100 expr: &ArithExpr,
101 subst: &HashMap<String, ArithExpr>,
102 ) -> Result<ArithExpr, FunctionError> {
103 match expr {
104 ArithExpr::Variable(name) => {
105 Ok(subst.get(name).cloned().unwrap_or_else(|| expr.clone()))
106 }
107 ArithExpr::Integer(_) | ArithExpr::Float(_) => Ok(expr.clone()),
108 ArithExpr::FuncCall { name, args } => {
109 let expanded_args: Result<Vec<_>, _> =
111 args.iter().map(|a| self.expand_expr(a, subst)).collect();
112 let expanded_args = expanded_args?;
113
114 if self.registry.contains(name) {
116 self.expand_call(name, &expanded_args)
117 } else {
118 Ok(ArithExpr::FuncCall {
120 name: name.clone(),
121 args: expanded_args,
122 })
123 }
124 }
125 ArithExpr::Add(l, r) => {
126 let el = self.expand_expr(l, subst)?;
127 let er = self.expand_expr(r, subst)?;
128 Ok(ArithExpr::Add(Box::new(el), Box::new(er)))
129 }
130 ArithExpr::Sub(l, r) => {
131 let el = self.expand_expr(l, subst)?;
132 let er = self.expand_expr(r, subst)?;
133 Ok(ArithExpr::Sub(Box::new(el), Box::new(er)))
134 }
135 ArithExpr::Mul(l, r) => {
136 let el = self.expand_expr(l, subst)?;
137 let er = self.expand_expr(r, subst)?;
138 Ok(ArithExpr::Mul(Box::new(el), Box::new(er)))
139 }
140 ArithExpr::Div(l, r) => {
141 let el = self.expand_expr(l, subst)?;
142 let er = self.expand_expr(r, subst)?;
143 Ok(ArithExpr::Div(Box::new(el), Box::new(er)))
144 }
145 ArithExpr::Mod(l, r) => {
146 let el = self.expand_expr(l, subst)?;
147 let er = self.expand_expr(r, subst)?;
148 Ok(ArithExpr::Mod(Box::new(el), Box::new(er)))
149 }
150 ArithExpr::Abs(e) => {
151 let ee = self.expand_expr(e, subst)?;
152 Ok(ArithExpr::Abs(Box::new(ee)))
153 }
154 ArithExpr::Min(l, r) => {
155 let el = self.expand_expr(l, subst)?;
156 let er = self.expand_expr(r, subst)?;
157 Ok(ArithExpr::Min(Box::new(el), Box::new(er)))
158 }
159 ArithExpr::Max(l, r) => {
160 let el = self.expand_expr(l, subst)?;
161 let er = self.expand_expr(r, subst)?;
162 Ok(ArithExpr::Max(Box::new(el), Box::new(er)))
163 }
164 ArithExpr::Pow(l, r) => {
165 let el = self.expand_expr(l, subst)?;
166 let er = self.expand_expr(r, subst)?;
167 Ok(ArithExpr::Pow(Box::new(el), Box::new(er)))
168 }
169 ArithExpr::Cast(e, t) => {
170 let ee = self.expand_expr(e, subst)?;
171 Ok(ArithExpr::Cast(Box::new(ee), *t))
172 }
173 ArithExpr::Conditional {
174 cond_left,
175 cond_op,
176 cond_right,
177 then_expr,
178 else_expr,
179 } => {
180 let cl = self.expand_expr(cond_left, subst)?;
181 let cr = self.expand_expr(cond_right, subst)?;
182 let te = self.expand_expr(then_expr, subst)?;
183 let ee = self.expand_expr(else_expr, subst)?;
184 Ok(ArithExpr::Conditional {
185 cond_left: Box::new(cl),
186 cond_op: *cond_op,
187 cond_right: Box::new(cr),
188 then_expr: Box::new(te),
189 else_expr: Box::new(ee),
190 })
191 }
192 }
193 }
194
195 #[allow(dead_code)] pub(crate) fn expand_predicate_func(
203 &self,
204 func: &FuncDef,
205 args: &[ArithExpr],
206 ) -> Result<(Vec<BodyLiteral>, String), FunctionError> {
207 match &func.body {
208 FuncBody::Predicate { result, body } => {
209 let mut subst: HashMap<String, ArithExpr> = HashMap::new();
211 for (param, arg) in func.params.iter().zip(args.iter()) {
212 subst.insert(param.name.clone(), arg.clone());
213 }
214
215 let expanded_body: Vec<BodyLiteral> = body
217 .iter()
218 .map(|lit| self.substitute_literal(lit, &subst))
219 .collect();
220
221 let result_var = self.substitute_var(result, &subst);
223
224 Ok((expanded_body, result_var))
225 }
226 _ => Err(FunctionError::UndefinedFunction {
227 name: func.name.clone(),
228 }),
229 }
230 }
231
232 fn substitute_literal(
234 &self,
235 lit: &BodyLiteral,
236 subst: &HashMap<String, ArithExpr>,
237 ) -> BodyLiteral {
238 match lit {
239 BodyLiteral::Positive(atom) => BodyLiteral::Positive(self.substitute_atom(atom, subst)),
240 BodyLiteral::Negated(atom) => BodyLiteral::Negated(self.substitute_atom(atom, subst)),
241 BodyLiteral::Epistemic(lit) => BodyLiteral::Epistemic(crate::ast::EpistemicLiteral {
242 op: lit.op,
243 negated: lit.negated,
244 atom: self.substitute_atom(&lit.atom, subst),
245 }),
246 BodyLiteral::Comparison(cmp) => BodyLiteral::Comparison(Comparison {
247 left: self.substitute_term(&cmp.left, subst),
248 op: cmp.op,
249 right: self.substitute_term(&cmp.right, subst),
250 }),
251 BodyLiteral::IsExpr(is_expr) => {
252 let target = self.substitute_var(&is_expr.target, subst);
254 let expr = self.substitute_arith_expr(&is_expr.expr, subst);
256 BodyLiteral::IsExpr(IsExpr { target, expr })
257 }
258 BodyLiteral::Univ(univ) => BodyLiteral::Univ(Univ {
259 term: self.substitute_term(&univ.term, subst),
260 parts: self.substitute_term(&univ.parts, subst),
261 }),
262 }
263 }
264
265 fn substitute_atom(&self, atom: &Atom, subst: &HashMap<String, ArithExpr>) -> Atom {
267 Atom {
268 predicate: atom.predicate.clone(),
269 terms: atom
270 .terms
271 .iter()
272 .map(|t| self.substitute_term(t, subst))
273 .collect(),
274 }
275 }
276
277 fn substitute_term(&self, term: &Term, subst: &HashMap<String, ArithExpr>) -> Term {
279 match term {
280 Term::Variable(name) => {
281 if let Some(replacement) = subst.get(name) {
282 match replacement {
283 ArithExpr::Variable(new_name) => Term::Variable(new_name.clone()),
284 ArithExpr::Integer(n) => Term::Integer(*n),
285 ArithExpr::Float(f) => Term::Float(*f),
286 _ => term.clone(),
289 }
290 } else {
291 term.clone()
292 }
293 }
294 _ => term.clone(),
295 }
296 }
297
298 fn substitute_arith_expr(
300 &self,
301 expr: &ArithExpr,
302 subst: &HashMap<String, ArithExpr>,
303 ) -> ArithExpr {
304 match expr {
305 ArithExpr::Variable(name) => subst.get(name).cloned().unwrap_or_else(|| expr.clone()),
306 ArithExpr::Integer(_) | ArithExpr::Float(_) => expr.clone(),
307 ArithExpr::Add(l, r) => ArithExpr::Add(
308 Box::new(self.substitute_arith_expr(l, subst)),
309 Box::new(self.substitute_arith_expr(r, subst)),
310 ),
311 ArithExpr::Sub(l, r) => ArithExpr::Sub(
312 Box::new(self.substitute_arith_expr(l, subst)),
313 Box::new(self.substitute_arith_expr(r, subst)),
314 ),
315 ArithExpr::Mul(l, r) => ArithExpr::Mul(
316 Box::new(self.substitute_arith_expr(l, subst)),
317 Box::new(self.substitute_arith_expr(r, subst)),
318 ),
319 ArithExpr::Div(l, r) => ArithExpr::Div(
320 Box::new(self.substitute_arith_expr(l, subst)),
321 Box::new(self.substitute_arith_expr(r, subst)),
322 ),
323 ArithExpr::Mod(l, r) => ArithExpr::Mod(
324 Box::new(self.substitute_arith_expr(l, subst)),
325 Box::new(self.substitute_arith_expr(r, subst)),
326 ),
327 ArithExpr::Abs(e) => ArithExpr::Abs(Box::new(self.substitute_arith_expr(e, subst))),
328 ArithExpr::Min(l, r) => ArithExpr::Min(
329 Box::new(self.substitute_arith_expr(l, subst)),
330 Box::new(self.substitute_arith_expr(r, subst)),
331 ),
332 ArithExpr::Max(l, r) => ArithExpr::Max(
333 Box::new(self.substitute_arith_expr(l, subst)),
334 Box::new(self.substitute_arith_expr(r, subst)),
335 ),
336 ArithExpr::Pow(l, r) => ArithExpr::Pow(
337 Box::new(self.substitute_arith_expr(l, subst)),
338 Box::new(self.substitute_arith_expr(r, subst)),
339 ),
340 ArithExpr::Cast(e, t) => {
341 ArithExpr::Cast(Box::new(self.substitute_arith_expr(e, subst)), *t)
342 }
343 ArithExpr::FuncCall { name, args } => ArithExpr::FuncCall {
344 name: name.clone(),
345 args: args
346 .iter()
347 .map(|a| self.substitute_arith_expr(a, subst))
348 .collect(),
349 },
350 ArithExpr::Conditional {
351 cond_left,
352 cond_op,
353 cond_right,
354 then_expr,
355 else_expr,
356 } => ArithExpr::Conditional {
357 cond_left: Box::new(self.substitute_arith_expr(cond_left, subst)),
358 cond_op: *cond_op,
359 cond_right: Box::new(self.substitute_arith_expr(cond_right, subst)),
360 then_expr: Box::new(self.substitute_arith_expr(then_expr, subst)),
361 else_expr: Box::new(self.substitute_arith_expr(else_expr, subst)),
362 },
363 }
364 }
365
366 fn substitute_var(&self, var: &str, subst: &HashMap<String, ArithExpr>) -> String {
368 if let Some(ArithExpr::Variable(new_name)) = subst.get(var) {
369 new_name.clone()
370 } else {
371 var.to_string()
372 }
373 }
374
375 #[allow(dead_code)] pub(crate) fn is_predicate_func(&self, name: &str) -> bool {
378 self.registry
379 .get(name)
380 .map(|f| matches!(f.body, FuncBody::Predicate { .. }))
381 .unwrap_or(false)
382 }
383
384 pub(crate) fn expand_expr_fully(
387 &mut self,
388 expr: &ArithExpr,
389 ) -> Result<ArithExpr, FunctionError> {
390 match expr {
391 ArithExpr::Variable(_) | ArithExpr::Integer(_) | ArithExpr::Float(_) => {
392 Ok(expr.clone())
393 }
394 ArithExpr::FuncCall { name, args } => {
395 let expanded_args: Result<Vec<_>, _> =
397 args.iter().map(|a| self.expand_expr_fully(a)).collect();
398 let expanded_args = expanded_args?;
399
400 if self.registry.contains(name) {
402 self.expand_call(name, &expanded_args)
403 } else {
404 Ok(ArithExpr::FuncCall {
406 name: name.clone(),
407 args: expanded_args,
408 })
409 }
410 }
411 ArithExpr::Add(l, r) => Ok(ArithExpr::Add(
412 Box::new(self.expand_expr_fully(l)?),
413 Box::new(self.expand_expr_fully(r)?),
414 )),
415 ArithExpr::Sub(l, r) => Ok(ArithExpr::Sub(
416 Box::new(self.expand_expr_fully(l)?),
417 Box::new(self.expand_expr_fully(r)?),
418 )),
419 ArithExpr::Mul(l, r) => Ok(ArithExpr::Mul(
420 Box::new(self.expand_expr_fully(l)?),
421 Box::new(self.expand_expr_fully(r)?),
422 )),
423 ArithExpr::Div(l, r) => Ok(ArithExpr::Div(
424 Box::new(self.expand_expr_fully(l)?),
425 Box::new(self.expand_expr_fully(r)?),
426 )),
427 ArithExpr::Mod(l, r) => Ok(ArithExpr::Mod(
428 Box::new(self.expand_expr_fully(l)?),
429 Box::new(self.expand_expr_fully(r)?),
430 )),
431 ArithExpr::Abs(e) => Ok(ArithExpr::Abs(Box::new(self.expand_expr_fully(e)?))),
432 ArithExpr::Min(l, r) => Ok(ArithExpr::Min(
433 Box::new(self.expand_expr_fully(l)?),
434 Box::new(self.expand_expr_fully(r)?),
435 )),
436 ArithExpr::Max(l, r) => Ok(ArithExpr::Max(
437 Box::new(self.expand_expr_fully(l)?),
438 Box::new(self.expand_expr_fully(r)?),
439 )),
440 ArithExpr::Pow(l, r) => Ok(ArithExpr::Pow(
441 Box::new(self.expand_expr_fully(l)?),
442 Box::new(self.expand_expr_fully(r)?),
443 )),
444 ArithExpr::Cast(e, t) => Ok(ArithExpr::Cast(Box::new(self.expand_expr_fully(e)?), *t)),
445 ArithExpr::Conditional {
446 cond_left,
447 cond_op,
448 cond_right,
449 then_expr,
450 else_expr,
451 } => Ok(ArithExpr::Conditional {
452 cond_left: Box::new(self.expand_expr_fully(cond_left)?),
453 cond_op: *cond_op,
454 cond_right: Box::new(self.expand_expr_fully(cond_right)?),
455 then_expr: Box::new(self.expand_expr_fully(then_expr)?),
456 else_expr: Box::new(self.expand_expr_fully(else_expr)?),
457 }),
458 }
459 }
460}
461
462use crate::ast::{Program, Rule};
463
464pub fn expand_program_functions(
468 program: &Program,
469 max_depth: u32,
470) -> Result<Program, FunctionError> {
471 let mut registry = FunctionRegistry::new();
473 for func in &program.functions {
474 registry.register(func.clone())?;
475 }
476
477 if program.functions.is_empty() {
479 return Ok(program.clone());
480 }
481
482 let mut ctx = ExpansionContext::new(®istry, max_depth);
483
484 let expanded_rules: Result<Vec<Rule>, FunctionError> = program
486 .rules
487 .iter()
488 .map(|rule| expand_rule_functions(&mut ctx, rule))
489 .collect();
490
491 Ok(Program {
492 rules: expanded_rules?,
493 directives: program.directives.clone(),
494 queries: program.queries.clone(),
495 predicates: program.predicates.clone(),
496 constraints: program.constraints.clone(),
497 imports: program.imports.clone(),
498 functions: program.functions.clone(),
499 domains: program.domains.clone(),
500 prob_facts: program.prob_facts.clone(),
501 annotated_disjunctions: program.annotated_disjunctions.clone(),
502 evidence: program.evidence.clone(),
503 prob_queries: program.prob_queries.clone(),
504 neural_predicates: program.neural_predicates.clone(),
505 learnable_rules: program.learnable_rules.clone(),
506 })
507}
508
509fn expand_rule_functions(ctx: &mut ExpansionContext, rule: &Rule) -> Result<Rule, FunctionError> {
511 let expanded_body: Result<Vec<BodyLiteral>, FunctionError> = rule
512 .body
513 .iter()
514 .map(|lit| expand_literal_functions(ctx, lit))
515 .collect();
516
517 Ok(Rule {
518 head: rule.head.clone(),
519 body: expanded_body?,
520 })
521}
522
523fn expand_literal_functions(
525 ctx: &mut ExpansionContext,
526 lit: &BodyLiteral,
527) -> Result<BodyLiteral, FunctionError> {
528 match lit {
529 BodyLiteral::Positive(atom) => Ok(BodyLiteral::Positive(atom.clone())),
530 BodyLiteral::Negated(atom) => Ok(BodyLiteral::Negated(atom.clone())),
531 BodyLiteral::Epistemic(lit) => Ok(BodyLiteral::Epistemic(lit.clone())),
532 BodyLiteral::Comparison(cmp) => Ok(BodyLiteral::Comparison(cmp.clone())),
533 BodyLiteral::IsExpr(is_expr) => {
534 let expanded_expr = ctx.expand_expr_fully(&is_expr.expr)?;
535 Ok(BodyLiteral::IsExpr(IsExpr {
536 target: is_expr.target.clone(),
537 expr: expanded_expr,
538 }))
539 }
540 BodyLiteral::Univ(univ) => Ok(BodyLiteral::Univ(univ.clone())),
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::{FuncDef, FuncParam};

    /// Shorthand for an arithmetic variable node.
    fn var(name: &str) -> ArithExpr {
        ArithExpr::Variable(name.to_string())
    }

    /// Shorthand for a function-call node.
    fn call(name: &str, args: Vec<ArithExpr>) -> ArithExpr {
        ArithExpr::FuncCall {
            name: name.to_string(),
            args,
        }
    }

    /// Builds a public, untyped function definition with the given body.
    fn func_def(name: &str, params: &[&str], body: FuncBody) -> FuncDef {
        FuncDef {
            name: name.to_string(),
            params: params
                .iter()
                .map(|p| FuncParam {
                    name: (*p).to_string(),
                    typ: None,
                })
                .collect(),
            return_type: None,
            body,
            is_private: false,
        }
    }

    /// `double(X) = X + X`
    fn double_def() -> FuncDef {
        func_def(
            "double",
            &["X"],
            FuncBody::Arithmetic(ArithExpr::Add(Box::new(var("X")), Box::new(var("X")))),
        )
    }

    /// Positive literal `parent(first, second)` over two variables.
    fn parent_atom(first: &str, second: &str) -> BodyLiteral {
        BodyLiteral::Positive(Atom {
            predicate: "parent".to_string(),
            terms: vec![
                Term::Variable(first.to_string()),
                Term::Variable(second.to_string()),
            ],
        })
    }

    /// Single-parameter predicate-style function definition.
    fn predicate_def(name: &str, param: &str, result: &str, body: Vec<BodyLiteral>) -> FuncDef {
        func_def(
            name,
            &[param],
            FuncBody::Predicate {
                result: result.to_string(),
                body,
            },
        )
    }

    /// Registry pre-populated with `defs`.
    fn registry_with(defs: Vec<FuncDef>) -> FunctionRegistry {
        let mut reg = FunctionRegistry::new();
        for def in defs {
            reg.register(def).unwrap();
        }
        reg
    }

    #[test]
    fn test_simple_expansion() {
        let reg = registry_with(vec![double_def()]);
        let mut ctx = ExpansionContext::new(&reg, 100);

        let result = ctx.expand_call("double", &[ArithExpr::Integer(5)]).unwrap();

        match result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(*l, ArithExpr::Integer(5)));
                assert!(matches!(*r, ArithExpr::Integer(5)));
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_nested_expansion() {
        // quadruple(X) = double(double(X))
        let quadruple = func_def(
            "quadruple",
            &["X"],
            FuncBody::Arithmetic(call("double", vec![call("double", vec![var("X")])])),
        );
        let reg = registry_with(vec![double_def(), quadruple]);
        let mut ctx = ExpansionContext::new(&reg, 100);

        let result = ctx
            .expand_call("quadruple", &[ArithExpr::Integer(2)])
            .unwrap();

        match &result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(l.as_ref(), ArithExpr::Add(_, _)));
                assert!(matches!(r.as_ref(), ArithExpr::Add(_, _)));
            }
            _ => panic!("Expected nested Add expression, got {:?}", result),
        }
    }

    #[test]
    fn test_max_recursion_depth() {
        // infinite(X) = infinite(X)
        let infinite = func_def(
            "infinite",
            &["X"],
            FuncBody::Arithmetic(call("infinite", vec![var("X")])),
        );
        let reg = registry_with(vec![infinite]);
        let mut ctx = ExpansionContext::new(&reg, 10);

        let result = ctx.expand_call("infinite", &[ArithExpr::Integer(1)]);
        assert!(matches!(
            result,
            Err(FunctionError::MaxRecursionDepth { .. })
        ));
    }

    #[test]
    fn test_undefined_function() {
        let reg = FunctionRegistry::new();
        let mut ctx = ExpansionContext::new(&reg, 100);

        let result = ctx.expand_call("undefined", &[ArithExpr::Integer(1)]);
        assert!(matches!(
            result,
            Err(FunctionError::UndefinedFunction { .. })
        ));
    }

    #[test]
    fn test_builtin_function_passthrough() {
        // abs_x(X) = abs(X), where `abs` is not a registered function.
        let abs_x = func_def(
            "abs_x",
            &["X"],
            FuncBody::Arithmetic(call("abs", vec![var("X")])),
        );
        let reg = registry_with(vec![abs_x]);
        let mut ctx = ExpansionContext::new(&reg, 100);

        let result = ctx.expand_call("abs_x", &[ArithExpr::Integer(-5)]).unwrap();

        match result {
            ArithExpr::FuncCall { name, args } => {
                assert_eq!(name, "abs");
                assert_eq!(args.len(), 1);
                assert!(matches!(args[0], ArithExpr::Integer(-5)));
            }
            _ => panic!("Expected FuncCall for builtin"),
        }
    }

    #[test]
    fn test_variable_substitution() {
        // add(X, Y) = X + Y
        let add = func_def(
            "add",
            &["X", "Y"],
            FuncBody::Arithmetic(ArithExpr::Add(Box::new(var("X")), Box::new(var("Y")))),
        );
        let reg = registry_with(vec![add]);
        let mut ctx = ExpansionContext::new(&reg, 100);

        let result = ctx
            .expand_call("add", &[ArithExpr::Integer(3), ArithExpr::Integer(7)])
            .unwrap();

        match result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(*l, ArithExpr::Integer(3)));
                assert!(matches!(*r, ArithExpr::Integer(7)));
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_expansion_with_variable_args() {
        let reg = registry_with(vec![double_def()]);
        let mut ctx = ExpansionContext::new(&reg, 100);

        let result = ctx.expand_call("double", &[var("Y")]).unwrap();

        match result {
            ArithExpr::Add(l, r) => {
                assert!(matches!(l.as_ref(), ArithExpr::Variable(n) if n == "Y"));
                assert!(matches!(r.as_ref(), ArithExpr::Variable(n) if n == "Y"));
            }
            _ => panic!("Expected Add expression"),
        }
    }

    #[test]
    fn test_predicate_func_expansion() {
        let func = predicate_def("get_parent", "X", "P", vec![parent_atom("X", "P")]);
        let reg = registry_with(vec![func]);
        let ctx = ExpansionContext::new(&reg, 100);

        let args = vec![var("alice")];
        let func_def = reg.get("get_parent").unwrap();
        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();

        assert_eq!(result, "P");
        assert_eq!(body.len(), 1);

        if let BodyLiteral::Positive(atom) = &body[0] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "alice"));
            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "P"));
        } else {
            panic!("Expected Positive literal");
        }
    }

    #[test]
    fn test_predicate_func_with_constant_arg() {
        let func = predicate_def("get_child", "P", "C", vec![parent_atom("C", "P")]);
        let reg = registry_with(vec![func]);
        let ctx = ExpansionContext::new(&reg, 100);

        let args = vec![ArithExpr::Integer(42)];
        let func_def = reg.get("get_child").unwrap();
        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();

        assert_eq!(result, "C");
        assert_eq!(body.len(), 1);

        if let BodyLiteral::Positive(atom) = &body[0] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "C"));
            assert!(matches!(&atom.terms[1], Term::Integer(42)));
        } else {
            panic!("Expected Positive literal");
        }
    }

    #[test]
    fn test_predicate_func_multiple_body_literals() {
        let func = predicate_def(
            "get_grandparent",
            "X",
            "G",
            vec![parent_atom("X", "P"), parent_atom("P", "G")],
        );
        let reg = registry_with(vec![func]);
        let ctx = ExpansionContext::new(&reg, 100);

        let args = vec![var("alice")];
        let func_def = reg.get("get_grandparent").unwrap();
        let (body, result) = ctx.expand_predicate_func(func_def, &args).unwrap();

        assert_eq!(result, "G");
        assert_eq!(body.len(), 2);

        if let BodyLiteral::Positive(atom) = &body[0] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "alice"));
            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "P"));
        } else {
            panic!("Expected Positive literal for first body");
        }

        if let BodyLiteral::Positive(atom) = &body[1] {
            assert_eq!(atom.predicate, "parent");
            assert!(matches!(&atom.terms[0], Term::Variable(v) if v == "P"));
            assert!(matches!(&atom.terms[1], Term::Variable(v) if v == "G"));
        } else {
            panic!("Expected Positive literal for second body");
        }
    }

    #[test]
    fn test_is_predicate_func() {
        let pred_func = predicate_def("get_parent", "X", "P", vec![parent_atom("X", "P")]);
        let reg = registry_with(vec![double_def(), pred_func]);
        let ctx = ExpansionContext::new(&reg, 100);

        assert!(!ctx.is_predicate_func("double"));
        assert!(ctx.is_predicate_func("get_parent"));
        assert!(!ctx.is_predicate_func("nonexistent"));
    }
}