1use std::collections::HashMap;
4use std::hash::{Hash, Hasher};
5use std::path::{Path, PathBuf};
6use std::time::Duration;
7
8use xlog_core::{Result, XlogError};
9
10use crate::ast::{Directives, Program};
11use crate::parser::parse_program;
12
/// Location of one statement inside the source text it was split from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StatementSpan {
    /// Byte offset of the statement's first non-whitespace character.
    pub start: usize,
    /// Exclusive end byte offset of the statement's source segment.
    pub end: usize,
    /// 1-based line number containing `start`.
    pub line: usize,
    /// 1-based column of `start`, counted in bytes from the beginning of its line.
    pub column: usize,
}
25
26#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct StatementUnit {
29 pub text: String,
31 pub span: StatementSpan,
33 pub hash: u64,
35}
36
/// Counters describing how much work one `ParserSession::parse_path` call reused.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ParseCacheStats {
    /// Statements whose parsed fragment was served from the session cache.
    pub hits: usize,
    /// Statements that had to be parsed.
    pub misses: usize,
    /// Previously cached statements for this path that were not carried over.
    pub invalidated: usize,
    /// Cache entries removed by `ParserSession::invalidate_module` and not yet
    /// reported by an earlier parse.
    pub module_invalidations: usize,
    /// Number of statements the source was split into.
    pub statement_count: usize,
    /// Statements a from-scratch parse would have processed.
    pub full_parse_units: usize,
    /// Statements actually parsed by this call.
    pub incremental_parse_units: usize,
    /// Wall-clock time spent splitting and parsing statements.
    pub elapsed: Duration,
}

impl ParseCacheStats {
    /// Estimated speedup of this incremental parse over a full reparse.
    ///
    /// Computed as `full_parse_units / incremental_parse_units`. When nothing had to
    /// be reparsed the divisor would be zero, so `full_parse_units` (but never less
    /// than `1.0`) is returned instead.
    pub fn estimated_speedup(&self) -> f64 {
        match self.incremental_parse_units {
            0 => self.full_parse_units.max(1) as f64,
            reparsed => self.full_parse_units as f64 / reparsed as f64,
        }
    }
}
67
68#[derive(Debug, Clone)]
70pub struct IncrementalParseResult {
71 pub program: Program,
73 pub statements: Vec<StatementUnit>,
75 pub stats: ParseCacheStats,
77}
78
79#[derive(Debug, Clone)]
80struct CachedStatement {
81 hash: u64,
82 text: String,
83 program: Program,
84}
85
86#[derive(Debug, Clone, Default)]
87struct CachedSource {
88 statements: Vec<CachedStatement>,
89 imports: Vec<Vec<String>>,
90}
91
92#[derive(Debug, Default)]
94pub struct ParserSession {
95 sources: HashMap<PathBuf, CachedSource>,
96 module_invalidations: usize,
97}
98
99impl ParserSession {
100 pub fn new() -> Self {
102 Self::default()
103 }
104
105 pub fn split_statements(source: &str) -> Vec<StatementUnit> {
107 split_statements(source)
108 }
109
110 pub fn parse_path(
112 &mut self,
113 path: impl AsRef<Path>,
114 source: &str,
115 ) -> Result<IncrementalParseResult> {
116 let started = std::time::Instant::now();
117 let path = path.as_ref().to_path_buf();
118 let units = split_statements(source);
119 let previous = self.sources.get(&path);
120
121 let mut stats = ParseCacheStats {
122 statement_count: units.len(),
123 full_parse_units: units.len(),
124 module_invalidations: self.module_invalidations,
125 ..ParseCacheStats::default()
126 };
127 self.module_invalidations = 0;
128
129 let mut parsed_statements = Vec::with_capacity(units.len());
130 for (idx, unit) in units.iter().enumerate() {
131 if let Some(prev) = previous.and_then(|src| src.statements.get(idx)) {
132 if prev.hash == unit.hash && prev.text == unit.text {
133 stats.hits += 1;
134 parsed_statements.push(prev.clone());
135 continue;
136 }
137 }
138
139 stats.misses += 1;
140 let program = parse_program(&unit.text).map_err(|err| {
141 XlogError::Parse(format!(
142 "incremental parse error at {}:{} (bytes {}..{}): {}",
143 unit.span.line, unit.span.column, unit.span.start, unit.span.end, err
144 ))
145 })?;
146 parsed_statements.push(CachedStatement {
147 hash: unit.hash,
148 text: unit.text.clone(),
149 program,
150 });
151 }
152
153 if let Some(previous) = previous {
154 let retained = parsed_statements
155 .iter()
156 .filter(|stmt| {
157 previous
158 .statements
159 .iter()
160 .any(|prev| prev.hash == stmt.hash && prev.text == stmt.text)
161 })
162 .count();
163 stats.invalidated = previous.statements.len().saturating_sub(retained);
164 }
165
166 stats.incremental_parse_units = stats.misses;
167 stats.elapsed = started.elapsed();
168
169 let mut program = Program::new();
170 let mut imports = Vec::new();
171 for cached in &parsed_statements {
172 append_program(&mut program, cached.program.clone());
173 imports.extend(cached.program.imports.iter().map(|u| u.module_path.clone()));
174 }
175
176 self.sources.insert(
177 path,
178 CachedSource {
179 statements: parsed_statements,
180 imports,
181 },
182 );
183
184 Ok(IncrementalParseResult {
185 program,
186 statements: units,
187 stats,
188 })
189 }
190
191 pub fn invalidate_module(&mut self, path: impl AsRef<Path>) -> usize {
193 let path = path.as_ref();
194 let module_name = path
195 .file_stem()
196 .and_then(|s| s.to_str())
197 .map(str::to_string);
198 let mut removed = Vec::new();
199 for (cached_path, source) in &self.sources {
200 let direct = cached_path == path;
201 let dependent = module_name.as_ref().is_some_and(|name| {
202 source
203 .imports
204 .iter()
205 .any(|parts| parts.last().is_some_and(|part| part == name))
206 });
207 if direct || dependent {
208 removed.push(cached_path.clone());
209 }
210 }
211 let count = removed.len();
212 for path in removed {
213 self.sources.remove(&path);
214 }
215 self.module_invalidations = self.module_invalidations.saturating_add(count);
216 count
217 }
218
219 pub fn cached_source_count(&self) -> usize {
221 self.sources.len()
222 }
223}
224
225fn split_statements(source: &str) -> Vec<StatementUnit> {
226 let line_starts = line_starts(source);
227 let mut out = Vec::new();
228 let mut start = 0usize;
229 let mut in_string = false;
230 let mut escaped = false;
231 let mut in_comment = false;
232 let mut line_start = 0usize;
233
234 for (idx, ch) in source.char_indices() {
235 if in_comment {
236 if ch == '\n' {
237 in_comment = false;
238 let segment = source[start..idx].trim_start();
239 if segment.starts_with("#pragma") {
240 push_statement(source, &line_starts, start, idx, &mut out);
241 start = idx + ch.len_utf8();
242 } else if segment.starts_with("//") || segment.is_empty() {
243 start = idx + ch.len_utf8();
244 }
245 line_start = idx + ch.len_utf8();
246 }
247 continue;
248 }
249
250 if in_string {
251 if escaped {
252 escaped = false;
253 } else if ch == '\\' {
254 escaped = true;
255 } else if ch == '"' {
256 in_string = false;
257 }
258 continue;
259 }
260
261 if ch == '"' {
262 in_string = true;
263 continue;
264 }
265
266 if ch == '/' && source[idx..].starts_with("//") {
267 if source[start..idx].trim().is_empty() {
268 start = idx;
269 }
270 in_comment = true;
271 continue;
272 }
273
274 if ch == '\n' {
275 if source[start..idx].trim_start().starts_with("#pragma") {
276 push_statement(source, &line_starts, start, idx, &mut out);
277 start = idx + ch.len_utf8();
278 }
279 line_start = idx + ch.len_utf8();
280 continue;
281 }
282
283 if ch == '.' && !is_decimal_point(source, idx) {
284 push_statement(source, &line_starts, start, idx + ch.len_utf8(), &mut out);
285 start = idx + ch.len_utf8();
286 }
287
288 if idx == line_start && ch.is_whitespace() {
289 line_start = idx + ch.len_utf8();
290 }
291 }
292
293 if source[start..].trim().is_empty() {
294 return out;
295 }
296 push_statement(source, &line_starts, start, source.len(), &mut out);
297 out
298}
299
300fn push_statement(
301 source: &str,
302 line_starts: &[usize],
303 start: usize,
304 end: usize,
305 out: &mut Vec<StatementUnit>,
306) {
307 let text = source[start..end].trim().to_string();
308 if text.is_empty() || text.starts_with("//") {
309 return;
310 }
311 let trimmed_start = source[start..end]
312 .find(|c: char| !c.is_whitespace())
313 .map(|offset| start + offset)
314 .unwrap_or(start);
315 let (line, column) = line_col(line_starts, trimmed_start);
316 out.push(StatementUnit {
317 hash: stable_hash(&text),
318 text,
319 span: StatementSpan {
320 start: trimmed_start,
321 end,
322 line,
323 column,
324 },
325 });
326}
327
/// Byte offsets at which each line of `source` begins.
///
/// Line 1 always starts at byte 0, and each `\n` opens a new line right after it. A
/// trailing newline therefore yields a final entry equal to `source.len()`.
fn line_starts(source: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(source.match_indices('\n').map(|(newline, _)| newline + 1))
        .collect()
}
337
/// Converts a byte offset into a 1-based `(line, column)` pair.
///
/// `line_starts` must be the ascending line-start table produced by `line_starts`.
/// The column counts bytes, not characters, from the start of the line.
fn line_col(line_starts: &[usize], byte: usize) -> (usize, usize) {
    // Index of the last line that begins at or before `byte`.
    let line_idx = line_starts
        .partition_point(|&begin| begin <= byte)
        .saturating_sub(1);
    let column = byte.saturating_sub(line_starts[line_idx]) + 1;
    (line_idx + 1, column)
}
345
/// Reports whether the `.` at byte `idx` is a decimal point, i.e. sits between two ASCII digits.
///
/// `idx` must be the byte index of a `.` character in `source`.
fn is_decimal_point(source: &str, idx: usize) -> bool {
    let (before, after) = (&source[..idx], &source[idx + 1..]);
    let digit_before = before.chars().next_back().is_some_and(|c| c.is_ascii_digit());
    let digit_after = after.chars().next().is_some_and(|c| c.is_ascii_digit());
    digit_before && digit_after
}
351
/// Hashes statement text with std's `DefaultHasher` using its fixed default keys.
///
/// The result is deterministic within one build of the program, which is all the
/// in-memory cache needs. std does not specify the algorithm, however, so values may
/// differ across Rust releases and must not be persisted or compared between builds.
fn stable_hash(text: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::default();
    Hash::hash(text, &mut state);
    Hasher::finish(&state)
}
357
358fn append_program(target: &mut Program, fragment: Program) {
359 target.imports.extend(fragment.imports);
360 target.functions.extend(fragment.functions);
361 target.domains.extend(fragment.domains);
362 target.predicates.extend(fragment.predicates);
363 target.rules.extend(fragment.rules);
364 target.constraints.extend(fragment.constraints);
365 target.queries.extend(fragment.queries);
366 target.prob_facts.extend(fragment.prob_facts);
367 target
368 .annotated_disjunctions
369 .extend(fragment.annotated_disjunctions);
370 target.evidence.extend(fragment.evidence);
371 target.prob_queries.extend(fragment.prob_queries);
372 target.neural_predicates.extend(fragment.neural_predicates);
373 target.learnable_rules.extend(fragment.learnable_rules);
374 merge_directives(&mut target.directives, fragment.directives);
375}
376
377fn merge_directives(target: &mut Directives, fragment: Directives) {
378 if fragment.prob_engine.is_some() {
379 target.prob_engine = fragment.prob_engine;
380 }
381 if fragment.prob_cache.is_some() {
382 target.prob_cache = fragment.prob_cache;
383 }
384 if fragment.prob_samples.is_some() {
385 target.prob_samples = fragment.prob_samples;
386 }
387 if fragment.prob_seed.is_some() {
388 target.prob_seed = fragment.prob_seed;
389 }
390 if fragment.prob_confidence.is_some() {
391 target.prob_confidence = fragment.prob_confidence;
392 }
393 if fragment.prob_method.is_some() {
394 target.prob_method = fragment.prob_method;
395 }
396 if fragment.prob_max_nonmonotone_iterations.is_some() {
397 target.prob_max_nonmonotone_iterations = fragment.prob_max_nonmonotone_iterations;
398 }
399 if fragment.max_recursion_depth.is_some() {
400 target.max_recursion_depth = fragment.max_recursion_depth;
401 }
402 if fragment.magic_sets.is_some() {
403 target.magic_sets = fragment.magic_sets;
404 }
405}