// xlog_logic/incremental_parse.rs

//! Statement-level parser session cache for incremental workflows.
2
3use std::collections::HashMap;
4use std::hash::{Hash, Hasher};
5use std::path::{Path, PathBuf};
6use std::time::Duration;
7
8use xlog_core::{Result, XlogError};
9
10use crate::ast::{Directives, Program};
11use crate::parser::parse_program;
12
/// Where one source statement sits in its file: a byte range plus the
/// line/column position of its first character.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatementSpan {
    /// Byte offset of the first byte of the statement.
    pub start: usize,
    /// Byte offset at which the statement ends.
    pub end: usize,
    /// Line on which the statement starts (first line is 1).
    pub line: usize,
    /// Column at which the statement starts (first column is 1), counted in
    /// bytes from the beginning of the line.
    pub column: usize,
}
25
26/// One statement unit discovered in a source file.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct StatementUnit {
29    /// Statement text.
30    pub text: String,
31    /// Statement source span.
32    pub span: StatementSpan,
33    /// Stable hash of the statement text.
34    pub hash: u64,
35}
36
/// Counters describing how much work one incremental parse did and avoided.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ParseCacheStats {
    /// Statements whose parse was taken from the cache.
    pub hits: usize,
    /// Statements that had to be parsed during this call.
    pub misses: usize,
    /// Previously cached statements of this source that were invalidated.
    pub invalidated: usize,
    /// Cached source files that were dropped through module invalidation.
    pub module_invalidations: usize,
    /// How many statements the parsed source contains.
    pub statement_count: usize,
    /// Estimated cost of a full parse, measured in statements.
    pub full_parse_units: usize,
    /// Estimated cost of the incremental parse, measured in statements.
    pub incremental_parse_units: usize,
    /// Wall-clock duration of the incremental parse call.
    pub elapsed: Duration,
}

impl ParseCacheStats {
    /// Ratio of full-parse work to incremental-parse work.
    ///
    /// When no statement had to be parsed the ratio is undefined, so the
    /// full-parse unit count (at least `1`) is returned instead.
    pub fn estimated_speedup(&self) -> f64 {
        match self.incremental_parse_units {
            0 => self.full_parse_units.max(1) as f64,
            parsed => self.full_parse_units as f64 / parsed as f64,
        }
    }
}
67
68/// Result from parsing source through a [`ParserSession`].
69#[derive(Debug, Clone)]
70pub struct IncrementalParseResult {
71    /// Parsed program assembled from statement fragments.
72    pub program: Program,
73    /// Statement units discovered in source order.
74    pub statements: Vec<StatementUnit>,
75    /// Cache statistics for this parse.
76    pub stats: ParseCacheStats,
77}
78
79#[derive(Debug, Clone)]
80struct CachedStatement {
81    hash: u64,
82    text: String,
83    program: Program,
84}
85
86#[derive(Debug, Clone, Default)]
87struct CachedSource {
88    statements: Vec<CachedStatement>,
89    imports: Vec<Vec<String>>,
90}
91
92/// Incremental parser cache keyed by source path.
93#[derive(Debug, Default)]
94pub struct ParserSession {
95    sources: HashMap<PathBuf, CachedSource>,
96    module_invalidations: usize,
97}
98
99impl ParserSession {
100    /// Create an empty parser session.
101    pub fn new() -> Self {
102        Self::default()
103    }
104
105    /// Split source text into statement units with byte and line/column spans.
106    pub fn split_statements(source: &str) -> Vec<StatementUnit> {
107        split_statements(source)
108    }
109
110    /// Parse source associated with a path, reusing unchanged statement parses.
111    pub fn parse_path(
112        &mut self,
113        path: impl AsRef<Path>,
114        source: &str,
115    ) -> Result<IncrementalParseResult> {
116        let started = std::time::Instant::now();
117        let path = path.as_ref().to_path_buf();
118        let units = split_statements(source);
119        let previous = self.sources.get(&path);
120
121        let mut stats = ParseCacheStats {
122            statement_count: units.len(),
123            full_parse_units: units.len(),
124            module_invalidations: self.module_invalidations,
125            ..ParseCacheStats::default()
126        };
127        self.module_invalidations = 0;
128
129        let mut parsed_statements = Vec::with_capacity(units.len());
130        for (idx, unit) in units.iter().enumerate() {
131            if let Some(prev) = previous.and_then(|src| src.statements.get(idx)) {
132                if prev.hash == unit.hash && prev.text == unit.text {
133                    stats.hits += 1;
134                    parsed_statements.push(prev.clone());
135                    continue;
136                }
137            }
138
139            stats.misses += 1;
140            let program = parse_program(&unit.text).map_err(|err| {
141                XlogError::Parse(format!(
142                    "incremental parse error at {}:{} (bytes {}..{}): {}",
143                    unit.span.line, unit.span.column, unit.span.start, unit.span.end, err
144                ))
145            })?;
146            parsed_statements.push(CachedStatement {
147                hash: unit.hash,
148                text: unit.text.clone(),
149                program,
150            });
151        }
152
153        if let Some(previous) = previous {
154            let retained = parsed_statements
155                .iter()
156                .filter(|stmt| {
157                    previous
158                        .statements
159                        .iter()
160                        .any(|prev| prev.hash == stmt.hash && prev.text == stmt.text)
161                })
162                .count();
163            stats.invalidated = previous.statements.len().saturating_sub(retained);
164        }
165
166        stats.incremental_parse_units = stats.misses;
167        stats.elapsed = started.elapsed();
168
169        let mut program = Program::new();
170        let mut imports = Vec::new();
171        for cached in &parsed_statements {
172            append_program(&mut program, cached.program.clone());
173            imports.extend(cached.program.imports.iter().map(|u| u.module_path.clone()));
174        }
175
176        self.sources.insert(
177            path,
178            CachedSource {
179                statements: parsed_statements,
180                imports,
181            },
182        );
183
184        Ok(IncrementalParseResult {
185            program,
186            statements: units,
187            stats,
188        })
189    }
190
191    /// Invalidate one module path and cached sources that import it by final path segment.
192    pub fn invalidate_module(&mut self, path: impl AsRef<Path>) -> usize {
193        let path = path.as_ref();
194        let module_name = path
195            .file_stem()
196            .and_then(|s| s.to_str())
197            .map(str::to_string);
198        let mut removed = Vec::new();
199        for (cached_path, source) in &self.sources {
200            let direct = cached_path == path;
201            let dependent = module_name.as_ref().is_some_and(|name| {
202                source
203                    .imports
204                    .iter()
205                    .any(|parts| parts.last().is_some_and(|part| part == name))
206            });
207            if direct || dependent {
208                removed.push(cached_path.clone());
209            }
210        }
211        let count = removed.len();
212        for path in removed {
213            self.sources.remove(&path);
214        }
215        self.module_invalidations = self.module_invalidations.saturating_add(count);
216        count
217    }
218
219    /// Return the number of cached source files.
220    pub fn cached_source_count(&self) -> usize {
221        self.sources.len()
222    }
223}
224
225fn split_statements(source: &str) -> Vec<StatementUnit> {
226    let line_starts = line_starts(source);
227    let mut out = Vec::new();
228    let mut start = 0usize;
229    let mut in_string = false;
230    let mut escaped = false;
231    let mut in_comment = false;
232    let mut line_start = 0usize;
233
234    for (idx, ch) in source.char_indices() {
235        if in_comment {
236            if ch == '\n' {
237                in_comment = false;
238                let segment = source[start..idx].trim_start();
239                if segment.starts_with("#pragma") {
240                    push_statement(source, &line_starts, start, idx, &mut out);
241                    start = idx + ch.len_utf8();
242                } else if segment.starts_with("//") || segment.is_empty() {
243                    start = idx + ch.len_utf8();
244                }
245                line_start = idx + ch.len_utf8();
246            }
247            continue;
248        }
249
250        if in_string {
251            if escaped {
252                escaped = false;
253            } else if ch == '\\' {
254                escaped = true;
255            } else if ch == '"' {
256                in_string = false;
257            }
258            continue;
259        }
260
261        if ch == '"' {
262            in_string = true;
263            continue;
264        }
265
266        if ch == '/' && source[idx..].starts_with("//") {
267            if source[start..idx].trim().is_empty() {
268                start = idx;
269            }
270            in_comment = true;
271            continue;
272        }
273
274        if ch == '\n' {
275            if source[start..idx].trim_start().starts_with("#pragma") {
276                push_statement(source, &line_starts, start, idx, &mut out);
277                start = idx + ch.len_utf8();
278            }
279            line_start = idx + ch.len_utf8();
280            continue;
281        }
282
283        if ch == '.' && !is_decimal_point(source, idx) {
284            push_statement(source, &line_starts, start, idx + ch.len_utf8(), &mut out);
285            start = idx + ch.len_utf8();
286        }
287
288        if idx == line_start && ch.is_whitespace() {
289            line_start = idx + ch.len_utf8();
290        }
291    }
292
293    if source[start..].trim().is_empty() {
294        return out;
295    }
296    push_statement(source, &line_starts, start, source.len(), &mut out);
297    out
298}
299
300fn push_statement(
301    source: &str,
302    line_starts: &[usize],
303    start: usize,
304    end: usize,
305    out: &mut Vec<StatementUnit>,
306) {
307    let text = source[start..end].trim().to_string();
308    if text.is_empty() || text.starts_with("//") {
309        return;
310    }
311    let trimmed_start = source[start..end]
312        .find(|c: char| !c.is_whitespace())
313        .map(|offset| start + offset)
314        .unwrap_or(start);
315    let (line, column) = line_col(line_starts, trimmed_start);
316    out.push(StatementUnit {
317        hash: stable_hash(&text),
318        text,
319        span: StatementSpan {
320            start: trimmed_start,
321            end,
322            line,
323            column,
324        },
325    });
326}
327
/// Byte offsets at which each line of `source` begins.
///
/// The first entry is always `0`; every `\n` contributes the offset of the
/// byte that follows it.
fn line_starts(source: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(source.match_indices('\n').map(|(newline, _)| newline + 1))
        .collect()
}
337
/// Translate a byte offset into a one-based `(line, column)` pair.
///
/// `line_starts` must be the ascending, non-empty list of line-start offsets.
/// The column is the byte distance from the start of the line, plus one.
fn line_col(line_starts: &[usize], byte: usize) -> (usize, usize) {
    // Index of the last line that starts at or before `byte`.
    let row = line_starts
        .partition_point(|&begin| begin <= byte)
        .saturating_sub(1);
    let offset_in_line = byte.saturating_sub(line_starts[row]);
    (row + 1, offset_in_line + 1)
}
345
/// Report whether the one-byte character at `idx` has an ASCII digit directly
/// on both sides, which is how a `.` inside a number is told apart from a
/// statement terminator.
fn is_decimal_point(source: &str, idx: usize) -> bool {
    let before = &source[..idx];
    let after = &source[idx + 1..];
    let digit_before = before
        .chars()
        .next_back()
        .is_some_and(|c| c.is_ascii_digit());
    let digit_after = after.chars().next().is_some_and(|c| c.is_ascii_digit());
    digit_before && digit_after
}
351
/// Hash `text` with a freshly constructed `DefaultHasher`.
///
/// Equal text always produces the same value within one build of the program.
/// The standard library does not promise that `DefaultHasher` keeps the same
/// algorithm across Rust releases, so these values should not be persisted.
fn stable_hash(text: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(text, &mut state);
    Hasher::finish(&state)
}
357
358fn append_program(target: &mut Program, fragment: Program) {
359    target.imports.extend(fragment.imports);
360    target.functions.extend(fragment.functions);
361    target.domains.extend(fragment.domains);
362    target.predicates.extend(fragment.predicates);
363    target.rules.extend(fragment.rules);
364    target.constraints.extend(fragment.constraints);
365    target.queries.extend(fragment.queries);
366    target.prob_facts.extend(fragment.prob_facts);
367    target
368        .annotated_disjunctions
369        .extend(fragment.annotated_disjunctions);
370    target.evidence.extend(fragment.evidence);
371    target.prob_queries.extend(fragment.prob_queries);
372    target.neural_predicates.extend(fragment.neural_predicates);
373    target.learnable_rules.extend(fragment.learnable_rules);
374    merge_directives(&mut target.directives, fragment.directives);
375}
376
377fn merge_directives(target: &mut Directives, fragment: Directives) {
378    if fragment.prob_engine.is_some() {
379        target.prob_engine = fragment.prob_engine;
380    }
381    if fragment.prob_cache.is_some() {
382        target.prob_cache = fragment.prob_cache;
383    }
384    if fragment.prob_samples.is_some() {
385        target.prob_samples = fragment.prob_samples;
386    }
387    if fragment.prob_seed.is_some() {
388        target.prob_seed = fragment.prob_seed;
389    }
390    if fragment.prob_confidence.is_some() {
391        target.prob_confidence = fragment.prob_confidence;
392    }
393    if fragment.prob_method.is_some() {
394        target.prob_method = fragment.prob_method;
395    }
396    if fragment.prob_max_nonmonotone_iterations.is_some() {
397        target.prob_max_nonmonotone_iterations = fragment.prob_max_nonmonotone_iterations;
398    }
399    if fragment.max_recursion_depth.is_some() {
400        target.max_recursion_depth = fragment.max_recursion_depth;
401    }
402    if fragment.magic_sets.is_some() {
403        target.magic_sets = fragment.magic_sets;
404    }
405}