1use crate::ast::{Program, UseDecl};
4use crate::module::{module_path_to_string, LoadedModule, ModuleError, ModulePath};
5use crate::parser::parse_program;
6use std::collections::{HashMap, HashSet};
7use std::fs;
8use std::path::{Path, PathBuf};
9
10pub struct ModuleResolver {
12 search_paths: Vec<PathBuf>,
14 loaded: HashMap<String, LoadedModule>,
16 loading: Vec<ModulePath>,
18}
19
20impl ModuleResolver {
21 pub fn new(search_paths: Vec<PathBuf>) -> Self {
23 Self {
24 search_paths,
25 loaded: HashMap::new(),
26 loading: Vec::new(),
27 }
28 }
29
30 pub fn find_module_file(&self, base_dir: &Path, module_path: &[String]) -> Option<PathBuf> {
32 let relative_path = format!("{}.xlog", module_path.join("/"));
33
34 let candidate = base_dir.join(&relative_path);
36 if candidate.exists() {
37 return Some(candidate);
38 }
39
40 for search_path in &self.search_paths {
42 let candidate = search_path.join(&relative_path);
43 if candidate.exists() {
44 return Some(candidate);
45 }
46 }
47
48 None
49 }
50
51 fn searched_paths(&self, base_dir: &Path, module_path: &[String]) -> Vec<PathBuf> {
53 let relative_path = format!("{}.xlog", module_path.join("/"));
54 let mut searched = vec![base_dir.join(&relative_path)];
55 for sp in &self.search_paths {
56 searched.push(sp.join(&relative_path));
57 }
58 searched
59 }
60
61 fn check_cycle(&self, module_path: &[String]) -> Option<Vec<ModulePath>> {
63 let path_str = module_path_to_string(module_path);
64 for (i, loading_path) in self.loading.iter().enumerate() {
65 if module_path_to_string(loading_path) == path_str {
66 let mut cycle: Vec<ModulePath> = self.loading[i..].to_vec();
68 cycle.push(module_path.to_vec());
69 return Some(cycle);
70 }
71 }
72 None
73 }
74
75 pub fn extract_exports(program: &Program) -> (HashSet<String>, HashSet<String>) {
78 let mut pred_exports = HashSet::new();
79 let mut func_exports = HashSet::new();
80
81 for pred in &program.predicates {
83 if !pred.is_private {
84 pred_exports.insert(pred.name.clone());
85 }
86 }
87
88 for rule in &program.rules {
90 let is_private = program
92 .predicates
93 .iter()
94 .any(|p| p.name == rule.head.predicate && p.is_private);
95 if !is_private {
96 pred_exports.insert(rule.head.predicate.clone());
97 }
98 }
99
100 for func in &program.functions {
102 if !func.is_private {
103 func_exports.insert(func.name.clone());
104 }
105 }
106
107 (pred_exports, func_exports)
108 }
109
110 pub fn load_module(
112 &mut self,
113 base_dir: &Path,
114 module_path: &[String],
115 ) -> Result<&LoadedModule, ModuleError> {
116 let path_key = module_path_to_string(module_path);
117
118 if self.loaded.contains_key(&path_key) {
120 return Ok(self.loaded.get(&path_key).unwrap());
121 }
122
123 if let Some(cycle) = self.check_cycle(module_path) {
125 return Err(ModuleError::CircularImport { cycle });
126 }
127
128 let source_file = self
130 .find_module_file(base_dir, module_path)
131 .ok_or_else(|| ModuleError::NotFound {
132 path: module_path.to_vec(),
133 searched: self.searched_paths(base_dir, module_path),
134 })?;
135
136 self.loading.push(module_path.to_vec());
138
139 let source = fs::read_to_string(&source_file).map_err(|e| ModuleError::ParseError {
141 path: source_file.clone(),
142 message: e.to_string(),
143 })?;
144
145 let program = parse_program(&source).map_err(|e| ModuleError::ParseError {
146 path: source_file.clone(),
147 message: e.to_string(),
148 })?;
149
150 let (exports, function_exports) = Self::extract_exports(&program);
152
153 let module_dir = source_file.parent().unwrap_or(base_dir);
155 for import in &program.imports {
156 self.load_module(module_dir, &import.module_path)?;
157 }
158
159 self.loading.pop();
161
162 let module = LoadedModule {
164 path: module_path.to_vec(),
165 source_file,
166 exports,
167 function_exports,
168 program,
169 };
170
171 self.loaded.insert(path_key.clone(), module);
172 Ok(self.loaded.get(&path_key).unwrap())
173 }
174
175 pub fn check_import(&self, module_path: &[String], predicate: &str) -> Result<(), ModuleError> {
177 let path_key = module_path_to_string(module_path);
178 let module = self
179 .loaded
180 .get(&path_key)
181 .ok_or_else(|| ModuleError::NotFound {
182 path: module_path.to_vec(),
183 searched: vec![],
184 })?;
185
186 if !module.exports.contains(predicate) {
187 return Err(ModuleError::PredicateNotFound {
188 name: predicate.to_string(),
189 module: module_path.to_vec(),
190 });
191 }
192
193 Ok(())
194 }
195
196 #[allow(clippy::type_complexity)]
199 pub fn validate_imports(
200 &self,
201 program: &Program,
202 ) -> Result<(HashMap<String, ModulePath>, HashMap<String, ModulePath>), ModuleError> {
203 let mut imported_predicates: HashMap<String, ModulePath> = HashMap::new();
204 let mut imported_functions: HashMap<String, ModulePath> = HashMap::new();
205
206 for use_decl in &program.imports {
207 let module = self
208 .loaded
209 .get(&module_path_to_string(&use_decl.module_path))
210 .expect("module should be loaded");
211
212 let all_exports: HashSet<String> = module
214 .exports
215 .iter()
216 .chain(module.function_exports.iter())
217 .cloned()
218 .collect();
219
220 let names_to_import: Vec<String> = match &use_decl.imports {
221 Some(specific) => specific.clone(),
222 None => all_exports.iter().cloned().collect(),
223 };
224
225 for name in names_to_import {
226 let is_predicate = module.exports.contains(&name);
228 let is_function = module.function_exports.contains(&name);
229
230 if !is_predicate && !is_function {
231 return Err(ModuleError::PredicateNotFound {
232 name: name.clone(),
233 module: use_decl.module_path.clone(),
234 });
235 }
236
237 if is_predicate {
239 if let Some(prev_module) = imported_predicates.get(&name) {
240 if prev_module != &use_decl.module_path {
241 return Err(ModuleError::ImportConflict {
242 name,
243 module1: prev_module.clone(),
244 module2: use_decl.module_path.clone(),
245 });
246 }
247 }
248 imported_predicates.insert(name.clone(), use_decl.module_path.clone());
249 }
250
251 if is_function {
253 if let Some(prev_module) = imported_functions.get(&name) {
254 if prev_module != &use_decl.module_path {
255 return Err(ModuleError::ImportConflict {
256 name,
257 module1: prev_module.clone(),
258 module2: use_decl.module_path.clone(),
259 });
260 }
261 }
262 imported_functions.insert(name.clone(), use_decl.module_path.clone());
263 }
264 }
265 }
266
267 Ok((imported_predicates, imported_functions))
268 }
269
270 pub fn get_module(&self, module_path: &[String]) -> Option<&LoadedModule> {
272 self.loaded.get(&module_path_to_string(module_path))
273 }
274
275 pub fn is_loaded(&self, module_path: &str) -> bool {
277 self.loaded.contains_key(module_path)
278 }
279
280 pub fn loaded_modules(&self) -> Vec<&str> {
282 self.loaded.keys().map(|s| s.as_str()).collect()
283 }
284
285 fn imported_item_set(use_decl: &UseDecl) -> Option<HashSet<String>> {
286 match &use_decl.imports {
287 Some(items) if !items.is_empty() => Some(items.iter().cloned().collect()),
288 _ => None,
289 }
290 }
291
292 fn import_merge_key(
293 module_path: &[String],
294 imported_items: Option<&HashSet<String>>,
295 ) -> String {
296 let mut key = module_path_to_string(module_path);
297 if let Some(items) = imported_items {
298 let mut sorted_items = items.iter().cloned().collect::<Vec<_>>();
299 sorted_items.sort();
300 key.push_str("::{");
301 key.push_str(&sorted_items.join(","));
302 key.push('}');
303 } else {
304 key.push_str("::*");
305 }
306 key
307 }
308
309 fn merge_import_closure(
310 &self,
311 program: &mut Program,
312 use_decl: &UseDecl,
313 merged_imports: &mut HashSet<String>,
314 ) -> Result<(), ModuleError> {
315 let path_key = module_path_to_string(&use_decl.module_path);
316 let loaded_module = self
317 .loaded
318 .get(&path_key)
319 .ok_or_else(|| ModuleError::NotFound {
320 path: use_decl.module_path.clone(),
321 searched: vec![],
322 })?;
323
324 for nested_use in &loaded_module.program.imports {
325 self.merge_import_closure(program, nested_use, merged_imports)?;
326 }
327
328 let imported_items = Self::imported_item_set(use_decl);
329 let merge_key = Self::import_merge_key(&use_decl.module_path, imported_items.as_ref());
330 if merged_imports.insert(merge_key) {
331 program.merge_from(&loaded_module.program, imported_items.as_ref());
332 }
333 Ok(())
334 }
335
336 pub fn merge_imports(&self, mut program: Program) -> Result<Program, ModuleError> {
345 let entry_rules = std::mem::take(&mut program.rules);
346 let mut merged_imports = HashSet::new();
347 for use_decl in &program.imports.clone() {
348 self.merge_import_closure(&mut program, use_decl, &mut merged_imports)?;
349 }
350 program.rules.extend(entry_rules);
351
352 Ok(program)
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Writes `content` to `<dir>/<name>.xlog` and returns that path.
    fn create_test_module(dir: &Path, name: &str, content: &str) -> PathBuf {
        let file_path = dir.join(format!("{}.xlog", name));
        fs::File::create(&file_path)
            .and_then(|mut handle| handle.write_all(content.as_bytes()))
            .unwrap();
        file_path
    }

    /// A one-segment module path such as `["graph"]`.
    fn single_segment(name: &str) -> Vec<String> {
        vec![name.to_string()]
    }

    /// A resolver that only looks next to the importing file.
    fn resolver_without_search_paths() -> ModuleResolver {
        ModuleResolver::new(Vec::new())
    }

    #[test]
    fn test_find_module_file() {
        let workspace = TempDir::new().unwrap();
        create_test_module(workspace.path(), "graph", "edge(1, 2).");

        let located = resolver_without_search_paths()
            .find_module_file(workspace.path(), &single_segment("graph"));
        assert!(located.is_some());
    }

    #[test]
    fn test_module_not_found() {
        let workspace = TempDir::new().unwrap();
        let mut resolver = resolver_without_search_paths();

        let outcome = resolver.load_module(workspace.path(), &single_segment("nonexistent"));
        assert!(matches!(outcome, Err(ModuleError::NotFound { .. })));
    }

    #[test]
    fn test_circular_import() {
        let workspace = TempDir::new().unwrap();
        create_test_module(workspace.path(), "a", "use b.");
        create_test_module(workspace.path(), "b", "use a.");

        let mut resolver = resolver_without_search_paths();
        let outcome = resolver.load_module(workspace.path(), &single_segment("a"));
        assert!(matches!(outcome, Err(ModuleError::CircularImport { .. })));
    }

    #[test]
    fn test_load_simple_module() {
        let workspace = TempDir::new().unwrap();
        create_test_module(
            workspace.path(),
            "math",
            r#"
            pred add(u32, u32, u32).
            add(1, 2, 3).
            "#,
        );

        let mut resolver = resolver_without_search_paths();
        let module = resolver
            .load_module(workspace.path(), &single_segment("math"))
            .expect("math module should load");
        assert!(module.exports.contains("add"));
    }

    #[test]
    fn test_private_not_exported() {
        let workspace = TempDir::new().unwrap();
        create_test_module(
            workspace.path(),
            "graph",
            r#"
            pred edge(u32, u32).
            private pred helper(u32).
            edge(1, 2).
            helper(1).
            "#,
        );

        let mut resolver = resolver_without_search_paths();
        let module = resolver
            .load_module(workspace.path(), &single_segment("graph"))
            .expect("graph module should load");
        assert!(module.exports.contains("edge"));
        assert!(!module.exports.contains("helper"));
    }

    #[test]
    fn test_search_paths() {
        let workspace = TempDir::new().unwrap();
        let lib_dir = workspace.path().join("lib");
        fs::create_dir(&lib_dir).unwrap();
        create_test_module(&lib_dir, "stdlib", "helper(1).");

        let resolver = ModuleResolver::new(vec![lib_dir.clone()]);
        let located = resolver
            .find_module_file(workspace.path(), &single_segment("stdlib"))
            .expect("module should be found on the search path");
        assert!(located.starts_with(&lib_dir));
    }

    #[test]
    fn test_function_exports() {
        let workspace = TempDir::new().unwrap();
        create_test_module(
            workspace.path(),
            "mathfuncs",
            r#"
            func square(X) = X * X.
            func cube(X) = X * X * X.
            private func helper(X) = X.
            "#,
        );

        let mut resolver = resolver_without_search_paths();
        let module = resolver
            .load_module(workspace.path(), &single_segment("mathfuncs"))
            .expect("mathfuncs module should load");

        for public_fn in ["square", "cube"] {
            assert!(module.function_exports.contains(public_fn));
        }
        assert!(!module.function_exports.contains("helper"));
    }

    #[test]
    fn test_mixed_exports() {
        let workspace = TempDir::new().unwrap();
        create_test_module(
            workspace.path(),
            "mixed",
            r#"
            pred value(i64).
            value(42).
            func double(X) = X * 2.
            "#,
        );

        let mut resolver = resolver_without_search_paths();
        let module = resolver
            .load_module(workspace.path(), &single_segment("mixed"))
            .expect("mixed module should load");

        assert!(module.exports.contains("value"));
        assert!(module.function_exports.contains("double"));
    }
}
501}