Skip to main content

xlog_logic/
resolver.rs

1//! Module resolution for XLOG programs.
2
3use crate::ast::{Program, UseDecl};
4use crate::module::{module_path_to_string, LoadedModule, ModuleError, ModulePath};
5use crate::parser::parse_program;
6use std::collections::{HashMap, HashSet};
7use std::fs;
8use std::path::{Path, PathBuf};
9
10/// Resolves and loads modules
11pub struct ModuleResolver {
12    /// Directories to search for modules
13    search_paths: Vec<PathBuf>,
14    /// Already loaded modules (path string -> module)
15    loaded: HashMap<String, LoadedModule>,
16    /// Currently loading (for cycle detection)
17    loading: Vec<ModulePath>,
18}
19
20impl ModuleResolver {
21    /// Create a new resolver with given search paths
22    pub fn new(search_paths: Vec<PathBuf>) -> Self {
23        Self {
24            search_paths,
25            loaded: HashMap::new(),
26            loading: Vec::new(),
27        }
28    }
29
30    /// Find the file for a module path
31    pub fn find_module_file(&self, base_dir: &Path, module_path: &[String]) -> Option<PathBuf> {
32        let relative_path = format!("{}.xlog", module_path.join("/"));
33
34        // Try relative to base_dir first
35        let candidate = base_dir.join(&relative_path);
36        if candidate.exists() {
37            return Some(candidate);
38        }
39
40        // Try search paths
41        for search_path in &self.search_paths {
42            let candidate = search_path.join(&relative_path);
43            if candidate.exists() {
44                return Some(candidate);
45            }
46        }
47
48        None
49    }
50
51    /// Get the list of searched paths for error reporting
52    fn searched_paths(&self, base_dir: &Path, module_path: &[String]) -> Vec<PathBuf> {
53        let relative_path = format!("{}.xlog", module_path.join("/"));
54        let mut searched = vec![base_dir.join(&relative_path)];
55        for sp in &self.search_paths {
56            searched.push(sp.join(&relative_path));
57        }
58        searched
59    }
60
61    /// Check if we're in a circular import
62    fn check_cycle(&self, module_path: &[String]) -> Option<Vec<ModulePath>> {
63        let path_str = module_path_to_string(module_path);
64        for (i, loading_path) in self.loading.iter().enumerate() {
65            if module_path_to_string(loading_path) == path_str {
66                // Found cycle - return the cycle path
67                let mut cycle: Vec<ModulePath> = self.loading[i..].to_vec();
68                cycle.push(module_path.to_vec());
69                return Some(cycle);
70            }
71        }
72        None
73    }
74
75    /// Extract exports from a parsed program
76    /// Returns (predicate exports, function exports)
77    pub fn extract_exports(program: &Program) -> (HashSet<String>, HashSet<String>) {
78        let mut pred_exports = HashSet::new();
79        let mut func_exports = HashSet::new();
80
81        // Add declared predicates that aren't private
82        for pred in &program.predicates {
83            if !pred.is_private {
84                pred_exports.insert(pred.name.clone());
85            }
86        }
87
88        // Add rule heads (all rules define public predicates unless declared private)
89        for rule in &program.rules {
90            // Check if this predicate was declared as private
91            let is_private = program
92                .predicates
93                .iter()
94                .any(|p| p.name == rule.head.predicate && p.is_private);
95            if !is_private {
96                pred_exports.insert(rule.head.predicate.clone());
97            }
98        }
99
100        // Add functions that aren't private
101        for func in &program.functions {
102            if !func.is_private {
103                func_exports.insert(func.name.clone());
104            }
105        }
106
107        (pred_exports, func_exports)
108    }
109
110    /// Load a module from a path
111    pub fn load_module(
112        &mut self,
113        base_dir: &Path,
114        module_path: &[String],
115    ) -> Result<&LoadedModule, ModuleError> {
116        let path_key = module_path_to_string(module_path);
117
118        // Already loaded?
119        if self.loaded.contains_key(&path_key) {
120            return Ok(self.loaded.get(&path_key).unwrap());
121        }
122
123        // Check for cycle
124        if let Some(cycle) = self.check_cycle(module_path) {
125            return Err(ModuleError::CircularImport { cycle });
126        }
127
128        // Find the file
129        let source_file = self
130            .find_module_file(base_dir, module_path)
131            .ok_or_else(|| ModuleError::NotFound {
132                path: module_path.to_vec(),
133                searched: self.searched_paths(base_dir, module_path),
134            })?;
135
136        // Mark as loading
137        self.loading.push(module_path.to_vec());
138
139        // Read and parse
140        let source = fs::read_to_string(&source_file).map_err(|e| ModuleError::ParseError {
141            path: source_file.clone(),
142            message: e.to_string(),
143        })?;
144
145        let program = parse_program(&source).map_err(|e| ModuleError::ParseError {
146            path: source_file.clone(),
147            message: e.to_string(),
148        })?;
149
150        // Extract exports
151        let (exports, function_exports) = Self::extract_exports(&program);
152
153        // Recursively load imports
154        let module_dir = source_file.parent().unwrap_or(base_dir);
155        for import in &program.imports {
156            self.load_module(module_dir, &import.module_path)?;
157        }
158
159        // Remove from loading
160        self.loading.pop();
161
162        // Store loaded module
163        let module = LoadedModule {
164            path: module_path.to_vec(),
165            source_file,
166            exports,
167            function_exports,
168            program,
169        };
170
171        self.loaded.insert(path_key.clone(), module);
172        Ok(self.loaded.get(&path_key).unwrap())
173    }
174
175    /// Check if a predicate can be imported from a module
176    pub fn check_import(&self, module_path: &[String], predicate: &str) -> Result<(), ModuleError> {
177        let path_key = module_path_to_string(module_path);
178        let module = self
179            .loaded
180            .get(&path_key)
181            .ok_or_else(|| ModuleError::NotFound {
182                path: module_path.to_vec(),
183                searched: vec![],
184            })?;
185
186        if !module.exports.contains(predicate) {
187            return Err(ModuleError::PredicateNotFound {
188                name: predicate.to_string(),
189                module: module_path.to_vec(),
190            });
191        }
192
193        Ok(())
194    }
195
196    /// Validate all imports in a program
197    /// Returns (predicate imports, function imports) mapped to their source modules
198    #[allow(clippy::type_complexity)]
199    pub fn validate_imports(
200        &self,
201        program: &Program,
202    ) -> Result<(HashMap<String, ModulePath>, HashMap<String, ModulePath>), ModuleError> {
203        let mut imported_predicates: HashMap<String, ModulePath> = HashMap::new();
204        let mut imported_functions: HashMap<String, ModulePath> = HashMap::new();
205
206        for use_decl in &program.imports {
207            let module = self
208                .loaded
209                .get(&module_path_to_string(&use_decl.module_path))
210                .expect("module should be loaded");
211
212            // Combine all available exports for wildcard imports
213            let all_exports: HashSet<String> = module
214                .exports
215                .iter()
216                .chain(module.function_exports.iter())
217                .cloned()
218                .collect();
219
220            let names_to_import: Vec<String> = match &use_decl.imports {
221                Some(specific) => specific.clone(),
222                None => all_exports.iter().cloned().collect(),
223            };
224
225            for name in names_to_import {
226                // Check if name exists as predicate or function
227                let is_predicate = module.exports.contains(&name);
228                let is_function = module.function_exports.contains(&name);
229
230                if !is_predicate && !is_function {
231                    return Err(ModuleError::PredicateNotFound {
232                        name: name.clone(),
233                        module: use_decl.module_path.clone(),
234                    });
235                }
236
237                // Check for conflicts with predicates
238                if is_predicate {
239                    if let Some(prev_module) = imported_predicates.get(&name) {
240                        if prev_module != &use_decl.module_path {
241                            return Err(ModuleError::ImportConflict {
242                                name,
243                                module1: prev_module.clone(),
244                                module2: use_decl.module_path.clone(),
245                            });
246                        }
247                    }
248                    imported_predicates.insert(name.clone(), use_decl.module_path.clone());
249                }
250
251                // Check for conflicts with functions
252                if is_function {
253                    if let Some(prev_module) = imported_functions.get(&name) {
254                        if prev_module != &use_decl.module_path {
255                            return Err(ModuleError::ImportConflict {
256                                name,
257                                module1: prev_module.clone(),
258                                module2: use_decl.module_path.clone(),
259                            });
260                        }
261                    }
262                    imported_functions.insert(name.clone(), use_decl.module_path.clone());
263                }
264            }
265        }
266
267        Ok((imported_predicates, imported_functions))
268    }
269
270    /// Get a loaded module by path
271    pub fn get_module(&self, module_path: &[String]) -> Option<&LoadedModule> {
272        self.loaded.get(&module_path_to_string(module_path))
273    }
274
275    /// Check if a module is loaded
276    pub fn is_loaded(&self, module_path: &str) -> bool {
277        self.loaded.contains_key(module_path)
278    }
279
280    /// Get all loaded module paths (for testing)
281    pub fn loaded_modules(&self) -> Vec<&str> {
282        self.loaded.keys().map(|s| s.as_str()).collect()
283    }
284
285    fn imported_item_set(use_decl: &UseDecl) -> Option<HashSet<String>> {
286        match &use_decl.imports {
287            Some(items) if !items.is_empty() => Some(items.iter().cloned().collect()),
288            _ => None,
289        }
290    }
291
292    fn import_merge_key(
293        module_path: &[String],
294        imported_items: Option<&HashSet<String>>,
295    ) -> String {
296        let mut key = module_path_to_string(module_path);
297        if let Some(items) = imported_items {
298            let mut sorted_items = items.iter().cloned().collect::<Vec<_>>();
299            sorted_items.sort();
300            key.push_str("::{");
301            key.push_str(&sorted_items.join(","));
302            key.push('}');
303        } else {
304            key.push_str("::*");
305        }
306        key
307    }
308
309    fn merge_import_closure(
310        &self,
311        program: &mut Program,
312        use_decl: &UseDecl,
313        merged_imports: &mut HashSet<String>,
314    ) -> Result<(), ModuleError> {
315        let path_key = module_path_to_string(&use_decl.module_path);
316        let loaded_module = self
317            .loaded
318            .get(&path_key)
319            .ok_or_else(|| ModuleError::NotFound {
320                path: use_decl.module_path.clone(),
321                searched: vec![],
322            })?;
323
324        for nested_use in &loaded_module.program.imports {
325            self.merge_import_closure(program, nested_use, merged_imports)?;
326        }
327
328        let imported_items = Self::imported_item_set(use_decl);
329        let merge_key = Self::import_merge_key(&use_decl.module_path, imported_items.as_ref());
330        if merged_imports.insert(merge_key) {
331            program.merge_from(&loaded_module.program, imported_items.as_ref());
332        }
333        Ok(())
334    }
335
336    /// Merge all imported modules into a program.
337    /// Returns a new program with all imports resolved and merged.
338    ///
339    /// # Arguments
340    /// * `program` - The main program with imports to resolve
341    ///
342    /// # Returns
343    /// The program with all imports merged in
344    pub fn merge_imports(&self, mut program: Program) -> Result<Program, ModuleError> {
345        let entry_rules = std::mem::take(&mut program.rules);
346        let mut merged_imports = HashSet::new();
347        for use_decl in &program.imports.clone() {
348            self.merge_import_closure(&mut program, use_decl, &mut merged_imports)?;
349        }
350        program.rules.extend(entry_rules);
351
352        Ok(program)
353    }
354}
355
356#[cfg(test)]
357mod tests {
358    use super::*;
359    use std::io::Write;
360    use tempfile::TempDir;
361
362    fn create_test_module(dir: &Path, name: &str, content: &str) -> PathBuf {
363        let path = dir.join(format!("{}.xlog", name));
364        let mut file = fs::File::create(&path).unwrap();
365        file.write_all(content.as_bytes()).unwrap();
366        path
367    }
368
369    #[test]
370    fn test_find_module_file() {
371        let tmp = TempDir::new().unwrap();
372        create_test_module(tmp.path(), "graph", "edge(1, 2).");
373
374        let resolver = ModuleResolver::new(vec![]);
375        let found = resolver.find_module_file(tmp.path(), &["graph".into()]);
376        assert!(found.is_some());
377    }
378
379    #[test]
380    fn test_module_not_found() {
381        let tmp = TempDir::new().unwrap();
382        let mut resolver = ModuleResolver::new(vec![]);
383
384        let result = resolver.load_module(tmp.path(), &["nonexistent".into()]);
385        assert!(matches!(result, Err(ModuleError::NotFound { .. })));
386    }
387
388    #[test]
389    fn test_circular_import() {
390        let tmp = TempDir::new().unwrap();
391        create_test_module(tmp.path(), "a", "use b.");
392        create_test_module(tmp.path(), "b", "use a.");
393
394        let mut resolver = ModuleResolver::new(vec![]);
395        let result = resolver.load_module(tmp.path(), &["a".into()]);
396        assert!(matches!(result, Err(ModuleError::CircularImport { .. })));
397    }
398
399    #[test]
400    fn test_load_simple_module() {
401        let tmp = TempDir::new().unwrap();
402        create_test_module(
403            tmp.path(),
404            "math",
405            r#"
406            pred add(u32, u32, u32).
407            add(1, 2, 3).
408        "#,
409        );
410
411        let mut resolver = ModuleResolver::new(vec![]);
412        let result = resolver.load_module(tmp.path(), &["math".into()]);
413        assert!(result.is_ok());
414        let module = result.unwrap();
415        assert!(module.exports.contains("add"));
416    }
417
418    #[test]
419    fn test_private_not_exported() {
420        let tmp = TempDir::new().unwrap();
421        create_test_module(
422            tmp.path(),
423            "graph",
424            r#"
425            pred edge(u32, u32).
426            private pred helper(u32).
427            edge(1, 2).
428            helper(1).
429        "#,
430        );
431
432        let mut resolver = ModuleResolver::new(vec![]);
433        let result = resolver.load_module(tmp.path(), &["graph".into()]);
434        assert!(result.is_ok());
435        let module = result.unwrap();
436        assert!(module.exports.contains("edge"));
437        assert!(!module.exports.contains("helper"));
438    }
439
440    #[test]
441    fn test_search_paths() {
442        let tmp = TempDir::new().unwrap();
443        let lib_dir = tmp.path().join("lib");
444        fs::create_dir(&lib_dir).unwrap();
445        create_test_module(&lib_dir, "stdlib", "helper(1).");
446
447        let resolver = ModuleResolver::new(vec![lib_dir.clone()]);
448        let found = resolver.find_module_file(tmp.path(), &["stdlib".into()]);
449        assert!(found.is_some());
450        assert!(found.unwrap().starts_with(&lib_dir));
451    }
452
453    #[test]
454    fn test_function_exports() {
455        let tmp = TempDir::new().unwrap();
456        create_test_module(
457            tmp.path(),
458            "mathfuncs",
459            r#"
460            func square(X) = X * X.
461            func cube(X) = X * X * X.
462            private func helper(X) = X.
463        "#,
464        );
465
466        let mut resolver = ModuleResolver::new(vec![]);
467        let result = resolver.load_module(tmp.path(), &["mathfuncs".into()]);
468        assert!(result.is_ok());
469        let module = result.unwrap();
470
471        // Public functions should be exported
472        assert!(module.function_exports.contains("square"));
473        assert!(module.function_exports.contains("cube"));
474
475        // Private function should not be exported
476        assert!(!module.function_exports.contains("helper"));
477    }
478
479    #[test]
480    fn test_mixed_exports() {
481        let tmp = TempDir::new().unwrap();
482        create_test_module(
483            tmp.path(),
484            "mixed",
485            r#"
486            pred value(i64).
487            value(42).
488            func double(X) = X * 2.
489        "#,
490        );
491
492        let mut resolver = ModuleResolver::new(vec![]);
493        let result = resolver.load_module(tmp.path(), &["mixed".into()]);
494        assert!(result.is_ok());
495        let module = result.unwrap();
496
497        // Both predicate and function exports should be present
498        assert!(module.exports.contains("value"));
499        assert!(module.function_exports.contains("double"));
500    }
501}