Skip to main content

xlog_logic/
typeinfer.rs

//! Type inference for user-defined functions.
//!
//! Reserved API: this module is not yet wired into the main compilation pipeline.

5use crate::ast::{ArithExpr, FuncDef};
6use std::collections::HashMap;
7use xlog_core::ScalarType;
8
9/// Type inference context
10#[derive(Debug, Default)]
11pub(crate) struct TypeContext {
12    /// Known variable types
13    bindings: HashMap<String, ScalarType>,
14}
15
16impl TypeContext {
17    pub fn new() -> Self {
18        Self::default()
19    }
20
21    /// Bind a variable to a type
22    pub fn bind(&mut self, name: &str, typ: ScalarType) {
23        self.bindings.insert(name.to_string(), typ);
24    }
25
26    /// Get a variable's type
27    pub fn get(&self, name: &str) -> Option<ScalarType> {
28        self.bindings.get(name).copied()
29    }
30
31    /// Infer type of an expression
32    pub fn infer_expr(&self, expr: &ArithExpr) -> Option<ScalarType> {
33        match expr {
34            ArithExpr::Variable(name) => self.get(name),
35            ArithExpr::Integer(_) => Some(ScalarType::I64),
36            ArithExpr::Float(_) => Some(ScalarType::F64),
37            ArithExpr::Add(l, r)
38            | ArithExpr::Sub(l, r)
39            | ArithExpr::Mul(l, r)
40            | ArithExpr::Div(l, r)
41            | ArithExpr::Mod(l, r) => {
42                let lt = self.infer_expr(l)?;
43                let rt = self.infer_expr(r)?;
44                // Numeric promotion: if either is f64, result is f64
45                if lt == ScalarType::F64 || rt == ScalarType::F64 {
46                    Some(ScalarType::F64)
47                } else {
48                    Some(lt)
49                }
50            }
51            ArithExpr::Cast(_, t) => Some(*t),
52            ArithExpr::Abs(e) => self.infer_expr(e),
53            ArithExpr::Min(l, r) | ArithExpr::Max(l, r) | ArithExpr::Pow(l, r) => {
54                let lt = self.infer_expr(l)?;
55                let rt = self.infer_expr(r)?;
56                if lt == ScalarType::F64 || rt == ScalarType::F64 {
57                    Some(ScalarType::F64)
58                } else {
59                    Some(lt)
60                }
61            }
62            ArithExpr::FuncCall { .. } => None, // Need registry lookup
63            ArithExpr::Conditional {
64                then_expr,
65                else_expr,
66                ..
67            } => {
68                // Type is the common type of both branches
69                let then_t = self.infer_expr(then_expr)?;
70                let else_t = self.infer_expr(else_expr)?;
71                if then_t == else_t {
72                    Some(then_t)
73                } else if then_t == ScalarType::F64 || else_t == ScalarType::F64 {
74                    Some(ScalarType::F64)
75                } else {
76                    Some(then_t)
77                }
78            }
79        }
80    }
81}
82
83/// Infer parameter types from function definition
84pub(crate) fn infer_param_types(func: &FuncDef) -> Vec<Option<ScalarType>> {
85    func.params.iter().map(|p| p.typ).collect()
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infer_literal() {
        let ctx = TypeContext::new();
        assert_eq!(
            ctx.infer_expr(&ArithExpr::Integer(5)),
            Some(ScalarType::I64)
        );
        assert_eq!(
            ctx.infer_expr(&ArithExpr::Float(3.25)),
            Some(ScalarType::F64)
        );
    }

    #[test]
    fn test_infer_variable() {
        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::F64);
        assert_eq!(
            ctx.infer_expr(&ArithExpr::Variable("X".into())),
            Some(ScalarType::F64)
        );
    }

    #[test]
    fn test_rebind_overwrites_previous_type() {
        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::I64);
        ctx.bind("X", ScalarType::F64);
        assert_eq!(ctx.get("X"), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_numeric_promotion() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Add(
            Box::new(ArithExpr::Integer(5)),
            Box::new(ArithExpr::Float(3.0)),
        );
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_integer_arithmetic_stays_integer() {
        let ctx = TypeContext::new();
        let div = ArithExpr::Div(
            Box::new(ArithExpr::Integer(7)),
            Box::new(ArithExpr::Integer(2)),
        );
        assert_eq!(ctx.infer_expr(&div), Some(ScalarType::I64));

        let modulo = ArithExpr::Mod(
            Box::new(ArithExpr::Integer(7)),
            Box::new(ArithExpr::Integer(2)),
        );
        assert_eq!(ctx.infer_expr(&modulo), Some(ScalarType::I64));
    }

    #[test]
    fn test_infer_unknown_variable() {
        let ctx = TypeContext::new();
        assert_eq!(ctx.infer_expr(&ArithExpr::Variable("Unknown".into())), None);
    }

    #[test]
    fn test_infer_unknown_operand_propagates() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Add(
            Box::new(ArithExpr::Variable("Unknown".into())),
            Box::new(ArithExpr::Integer(1)),
        );
        assert_eq!(ctx.infer_expr(&expr), None);
    }

    #[test]
    fn test_infer_cast() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Cast(Box::new(ArithExpr::Integer(5)), ScalarType::F64);
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_cast_ignores_operand() {
        // A cast reports its target type even when the operand is untypeable.
        let ctx = TypeContext::new();
        let expr = ArithExpr::Cast(
            Box::new(ArithExpr::Variable("Unknown".into())),
            ScalarType::I64,
        );
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));
    }

    #[test]
    fn test_infer_abs() {
        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::I64);
        let expr = ArithExpr::Abs(Box::new(ArithExpr::Variable("X".into())));
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));
    }

    #[test]
    fn test_infer_min_max() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Min(
            Box::new(ArithExpr::Integer(5)),
            Box::new(ArithExpr::Integer(3)),
        );
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));

        let expr_float = ArithExpr::Max(
            Box::new(ArithExpr::Integer(5)),
            Box::new(ArithExpr::Float(3.0)),
        );
        assert_eq!(ctx.infer_expr(&expr_float), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_pow_promotion() {
        let ctx = TypeContext::new();
        let int_pow = ArithExpr::Pow(
            Box::new(ArithExpr::Integer(2)),
            Box::new(ArithExpr::Integer(3)),
        );
        assert_eq!(ctx.infer_expr(&int_pow), Some(ScalarType::I64));

        let float_pow = ArithExpr::Pow(
            Box::new(ArithExpr::Integer(2)),
            Box::new(ArithExpr::Float(0.5)),
        );
        assert_eq!(ctx.infer_expr(&float_pow), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_conditional() {
        use crate::ast::CompOp;

        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::I64);
        let expr = ArithExpr::Conditional {
            cond_left: Box::new(ArithExpr::Variable("X".into())),
            cond_op: CompOp::Lt,
            cond_right: Box::new(ArithExpr::Integer(0)),
            then_expr: Box::new(ArithExpr::Integer(1)),
            else_expr: Box::new(ArithExpr::Integer(2)),
        };
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));
    }

    #[test]
    fn test_infer_conditional_mixed_branches() {
        use crate::ast::CompOp;

        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::I64);

        // Branches of different numeric types promote to F64, in either order.
        let int_then_float = ArithExpr::Conditional {
            cond_left: Box::new(ArithExpr::Variable("X".into())),
            cond_op: CompOp::Lt,
            cond_right: Box::new(ArithExpr::Integer(0)),
            then_expr: Box::new(ArithExpr::Integer(1)),
            else_expr: Box::new(ArithExpr::Float(2.0)),
        };
        assert_eq!(ctx.infer_expr(&int_then_float), Some(ScalarType::F64));

        let float_then_int = ArithExpr::Conditional {
            cond_left: Box::new(ArithExpr::Variable("X".into())),
            cond_op: CompOp::Lt,
            cond_right: Box::new(ArithExpr::Integer(0)),
            then_expr: Box::new(ArithExpr::Float(1.0)),
            else_expr: Box::new(ArithExpr::Integer(2)),
        };
        assert_eq!(ctx.infer_expr(&float_then_int), Some(ScalarType::F64));

        // An untypeable branch makes the whole conditional untypeable.
        let unknown_else = ArithExpr::Conditional {
            cond_left: Box::new(ArithExpr::Variable("X".into())),
            cond_op: CompOp::Lt,
            cond_right: Box::new(ArithExpr::Integer(0)),
            then_expr: Box::new(ArithExpr::Integer(1)),
            else_expr: Box::new(ArithExpr::Variable("Unknown".into())),
        };
        assert_eq!(ctx.infer_expr(&unknown_else), None);
    }

    #[test]
    fn test_infer_param_types() {
        use crate::ast::{FuncBody, FuncParam};

        let func = FuncDef {
            name: "test".to_string(),
            params: vec![
                FuncParam {
                    name: "X".to_string(),
                    typ: Some(ScalarType::I64),
                },
                FuncParam {
                    name: "Y".to_string(),
                    typ: None,
                },
            ],
            return_type: None,
            body: FuncBody::Arithmetic(ArithExpr::Integer(1)),
            is_private: false,
        };

        let types = infer_param_types(&func);
        assert_eq!(types.len(), 2);
        assert_eq!(types[0], Some(ScalarType::I64));
        assert_eq!(types[1], None);
    }
}