1use crate::ast::{ArithExpr, FuncDef};
6use std::collections::HashMap;
7use xlog_core::ScalarType;
8
9#[derive(Debug, Default)]
11pub(crate) struct TypeContext {
12 bindings: HashMap<String, ScalarType>,
14}
15
16impl TypeContext {
17 pub fn new() -> Self {
18 Self::default()
19 }
20
21 pub fn bind(&mut self, name: &str, typ: ScalarType) {
23 self.bindings.insert(name.to_string(), typ);
24 }
25
26 pub fn get(&self, name: &str) -> Option<ScalarType> {
28 self.bindings.get(name).copied()
29 }
30
31 pub fn infer_expr(&self, expr: &ArithExpr) -> Option<ScalarType> {
33 match expr {
34 ArithExpr::Variable(name) => self.get(name),
35 ArithExpr::Integer(_) => Some(ScalarType::I64),
36 ArithExpr::Float(_) => Some(ScalarType::F64),
37 ArithExpr::Add(l, r)
38 | ArithExpr::Sub(l, r)
39 | ArithExpr::Mul(l, r)
40 | ArithExpr::Div(l, r)
41 | ArithExpr::Mod(l, r) => {
42 let lt = self.infer_expr(l)?;
43 let rt = self.infer_expr(r)?;
44 if lt == ScalarType::F64 || rt == ScalarType::F64 {
46 Some(ScalarType::F64)
47 } else {
48 Some(lt)
49 }
50 }
51 ArithExpr::Cast(_, t) => Some(*t),
52 ArithExpr::Abs(e) => self.infer_expr(e),
53 ArithExpr::Min(l, r) | ArithExpr::Max(l, r) | ArithExpr::Pow(l, r) => {
54 let lt = self.infer_expr(l)?;
55 let rt = self.infer_expr(r)?;
56 if lt == ScalarType::F64 || rt == ScalarType::F64 {
57 Some(ScalarType::F64)
58 } else {
59 Some(lt)
60 }
61 }
62 ArithExpr::FuncCall { .. } => None, ArithExpr::Conditional {
64 then_expr,
65 else_expr,
66 ..
67 } => {
68 let then_t = self.infer_expr(then_expr)?;
70 let else_t = self.infer_expr(else_expr)?;
71 if then_t == else_t {
72 Some(then_t)
73 } else if then_t == ScalarType::F64 || else_t == ScalarType::F64 {
74 Some(ScalarType::F64)
75 } else {
76 Some(then_t)
77 }
78 }
79 }
80 }
81}
82
83pub(crate) fn infer_param_types(func: &FuncDef) -> Vec<Option<ScalarType>> {
85 func.params.iter().map(|p| p.typ).collect()
86}
87
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infer_literal() {
        let ctx = TypeContext::new();
        assert_eq!(
            ctx.infer_expr(&ArithExpr::Integer(5)),
            Some(ScalarType::I64)
        );
        assert_eq!(
            ctx.infer_expr(&ArithExpr::Float(3.25)),
            Some(ScalarType::F64)
        );
    }

    #[test]
    fn test_infer_variable() {
        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::F64);
        assert_eq!(
            ctx.infer_expr(&ArithExpr::Variable("X".into())),
            Some(ScalarType::F64)
        );
    }

    #[test]
    fn test_bind_overwrites_previous_type() {
        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::I64);
        ctx.bind("X", ScalarType::F64);
        assert_eq!(ctx.get("X"), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_numeric_promotion() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Add(
            Box::new(ArithExpr::Integer(5)),
            Box::new(ArithExpr::Float(3.0)),
        );
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_integer_mod_stays_i64() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Mod(
            Box::new(ArithExpr::Integer(7)),
            Box::new(ArithExpr::Integer(2)),
        );
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));
    }

    #[test]
    fn test_infer_pow_promotes_to_f64() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Pow(
            Box::new(ArithExpr::Float(2.0)),
            Box::new(ArithExpr::Integer(3)),
        );
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_unknown_variable() {
        let ctx = TypeContext::new();
        assert_eq!(ctx.infer_expr(&ArithExpr::Variable("Unknown".into())), None);
    }

    #[test]
    fn test_infer_unknown_operand_propagates_none() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Mul(
            Box::new(ArithExpr::Integer(2)),
            Box::new(ArithExpr::Variable("Missing".into())),
        );
        assert_eq!(ctx.infer_expr(&expr), None);
    }

    #[test]
    fn test_infer_cast() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Cast(Box::new(ArithExpr::Integer(5)), ScalarType::F64);
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_abs() {
        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::I64);
        let expr = ArithExpr::Abs(Box::new(ArithExpr::Variable("X".into())));
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));
    }

    #[test]
    fn test_infer_min_max() {
        let ctx = TypeContext::new();
        let expr = ArithExpr::Min(
            Box::new(ArithExpr::Integer(5)),
            Box::new(ArithExpr::Integer(3)),
        );
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));

        let expr_float = ArithExpr::Max(
            Box::new(ArithExpr::Integer(5)),
            Box::new(ArithExpr::Float(3.0)),
        );
        assert_eq!(ctx.infer_expr(&expr_float), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_conditional() {
        use crate::ast::CompOp;

        let mut ctx = TypeContext::new();
        ctx.bind("X", ScalarType::I64);
        let expr = ArithExpr::Conditional {
            cond_left: Box::new(ArithExpr::Variable("X".into())),
            cond_op: CompOp::Lt,
            cond_right: Box::new(ArithExpr::Integer(0)),
            then_expr: Box::new(ArithExpr::Integer(1)),
            else_expr: Box::new(ArithExpr::Integer(2)),
        };
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::I64));
    }

    #[test]
    fn test_infer_conditional_mixed_branches_promote() {
        use crate::ast::CompOp;

        let ctx = TypeContext::new();
        let expr = ArithExpr::Conditional {
            cond_left: Box::new(ArithExpr::Integer(1)),
            cond_op: CompOp::Lt,
            cond_right: Box::new(ArithExpr::Integer(0)),
            then_expr: Box::new(ArithExpr::Integer(1)),
            else_expr: Box::new(ArithExpr::Float(2.5)),
        };
        assert_eq!(ctx.infer_expr(&expr), Some(ScalarType::F64));
    }

    #[test]
    fn test_infer_param_types() {
        use crate::ast::{FuncBody, FuncParam};

        let func = FuncDef {
            name: "test".to_string(),
            params: vec![
                FuncParam {
                    name: "X".to_string(),
                    typ: Some(ScalarType::I64),
                },
                FuncParam {
                    name: "Y".to_string(),
                    typ: None,
                },
            ],
            return_type: None,
            body: FuncBody::Arithmetic(ArithExpr::Integer(1)),
            is_private: false,
        };

        let types = infer_param_types(&func);
        assert_eq!(types.len(), 2);
        assert_eq!(types[0], Some(ScalarType::I64));
        assert_eq!(types[1], None);
    }
}