Skip to main content

xlog_logic/
magic_sets.rs

1//! Magic-set rewriting for the deterministic language-completeness subset.
2
3use std::collections::{BTreeSet, HashMap, HashSet};
4
5use xlog_core::{Result, XlogError};
6
7use crate::ast::{Atom, BodyLiteral, MagicSetsMode, Program, Rule, Term};
8
/// Outcome of one magic-set rewrite attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MagicSetStatus {
    /// The program carried no `magic_sets` directive, or the directive switched
    /// rewriting off; the program is returned untouched.
    Disabled,
    /// A supported bound recursive query was found and the program was rewritten.
    Applied,
    /// `auto` mode found an unsafe or inapplicable program; the program is returned
    /// untouched and the reasons are recorded in the report.
    Declined,
}
19
20/// Human- and test-readable metadata for a magic-set rewrite attempt.
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct MagicSetReport {
23    /// Final status.
24    pub status: MagicSetStatus,
25    /// Generated magic predicate names.
26    pub generated_predicates: Vec<String>,
27    /// Adorned recursive predicates, formatted as `predicate/adornment`.
28    pub adorned_predicates: Vec<String>,
29    /// Reasons the rewrite declined.
30    pub declined_reasons: Vec<String>,
31}
32
33/// Rewritten program plus its report.
34#[derive(Debug, Clone)]
35pub struct MagicSetRewrite {
36    /// Program after rewriting, or the original program when disabled/declined.
37    pub program: Program,
38    /// Rewrite metadata.
39    pub report: MagicSetReport,
40}
41
/// A predicate paired with a bound/free flag for each argument position.
///
/// The derived ordering compares `pred` first and `pattern` second, which fixes the
/// iteration order when adornments are stored in an ordered set.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct Adornment {
    /// Name of the adorned predicate.
    pred: String,
    /// One flag per argument position: `true` means bound, `false` means free.
    pattern: Vec<bool>,
}
47
48#[derive(Debug, Clone)]
49struct Seed {
50    pred: String,
51    pattern: Vec<bool>,
52    terms: Vec<Term>,
53}
54
55/// Rewrite supported bound recursive queries with magic predicates.
56pub fn rewrite_magic_sets(program: &Program) -> Result<MagicSetRewrite> {
57    let mode = program.directives.magic_sets;
58    if mode.is_none() || mode == Some(MagicSetsMode::Off) {
59        return Ok(with_status(program, MagicSetStatus::Disabled));
60    }
61    let mode = mode.expect("checked above");
62
63    let recursive = recursive_predicates(program);
64    let seeds = collect_query_seeds(program, &recursive);
65    if seeds.is_empty() {
66        return decline_or_error(
67            program,
68            mode,
69            vec!["no bound recursive query eligible for magic_sets".to_string()],
70        );
71    }
72
73    if program.is_probabilistic_profile() {
74        return decline_or_error(
75            program,
76            mode,
77            vec!["probabilistic profiles are handled outside deterministic magic_sets".to_string()],
78        );
79    }
80
81    let target_preds: BTreeSet<String> = seeds.iter().map(|seed| seed.pred.clone()).collect();
82    let unsafe_reasons = unsupported_target_reasons(program, &target_preds, &recursive);
83    if !unsafe_reasons.is_empty() {
84        return decline_or_error(program, mode, unsafe_reasons);
85    }
86
87    let mut adornments = initial_adornments(&seeds);
88    expand_adornments(program, &target_preds, &mut adornments)?;
89
90    let mut generated_predicates: BTreeSet<String> = BTreeSet::new();
91    let mut adorned_predicates: BTreeSet<String> = BTreeSet::new();
92    for adornment in &adornments {
93        generated_predicates.insert(magic_predicate(&adornment.pred, &adornment.pattern));
94        adorned_predicates.insert(format!(
95            "{}/{}",
96            adornment.pred,
97            adornment_key(&adornment.pattern)
98        ));
99    }
100
101    let mut rewritten = program.clone();
102    rewritten.rules = rewrite_rules(program, &target_preds, &adornments, &seeds)?;
103
104    Ok(MagicSetRewrite {
105        program: rewritten,
106        report: MagicSetReport {
107            status: MagicSetStatus::Applied,
108            generated_predicates: generated_predicates.into_iter().collect(),
109            adorned_predicates: adorned_predicates.into_iter().collect(),
110            declined_reasons: Vec::new(),
111        },
112    })
113}
114
115fn with_status(program: &Program, status: MagicSetStatus) -> MagicSetRewrite {
116    MagicSetRewrite {
117        program: program.clone(),
118        report: MagicSetReport {
119            status,
120            generated_predicates: Vec::new(),
121            adorned_predicates: Vec::new(),
122            declined_reasons: Vec::new(),
123        },
124    }
125}
126
127fn decline_or_error(
128    program: &Program,
129    mode: MagicSetsMode,
130    reasons: Vec<String>,
131) -> Result<MagicSetRewrite> {
132    if mode == MagicSetsMode::On {
133        return Err(magic_error(reasons.join("; ")));
134    }
135    Ok(MagicSetRewrite {
136        program: program.clone(),
137        report: MagicSetReport {
138            status: MagicSetStatus::Declined,
139            generated_predicates: Vec::new(),
140            adorned_predicates: Vec::new(),
141            declined_reasons: reasons,
142        },
143    })
144}
145
146fn magic_error(message: impl Into<String>) -> XlogError {
147    XlogError::Compilation(format!("magic_sets error: {}", message.into()))
148}
149
150fn collect_query_seeds(program: &Program, recursive: &HashSet<String>) -> Vec<Seed> {
151    let mut seen = HashSet::new();
152    let mut out = Vec::new();
153    for query in &program.queries {
154        if !recursive.contains(&query.atom.predicate) {
155            continue;
156        }
157        let pattern: Vec<bool> = query.atom.terms.iter().map(is_seed_term).collect();
158        if !pattern.iter().any(|bound| *bound) {
159            continue;
160        }
161        if !query
162            .atom
163            .terms
164            .iter()
165            .zip(&pattern)
166            .all(|(term, bound)| !*bound || is_supported_magic_term(term))
167        {
168            continue;
169        }
170        let terms = bound_terms(&query.atom, &pattern);
171        let key = format!(
172            "{}:{}:{:?}",
173            query.atom.predicate,
174            adornment_key(&pattern),
175            terms
176        );
177        if seen.insert(key) {
178            out.push(Seed {
179                pred: query.atom.predicate.clone(),
180                pattern,
181                terms,
182            });
183        }
184    }
185    out
186}
187
188fn initial_adornments(seeds: &[Seed]) -> BTreeSet<Adornment> {
189    seeds
190        .iter()
191        .map(|seed| Adornment {
192            pred: seed.pred.clone(),
193            pattern: seed.pattern.clone(),
194        })
195        .collect()
196}
197
198fn expand_adornments(
199    program: &Program,
200    target_preds: &BTreeSet<String>,
201    adornments: &mut BTreeSet<Adornment>,
202) -> Result<()> {
203    let mut changed = true;
204    while changed {
205        changed = false;
206        let snapshot: Vec<Adornment> = adornments.iter().cloned().collect();
207        for adornment in snapshot {
208            for rule in program
209                .rules
210                .iter()
211                .filter(|rule| rule.head.predicate == adornment.pred)
212            {
213                for discovered in discover_body_adornments(rule, &adornment.pattern, target_preds)?
214                {
215                    changed |= adornments.insert(discovered);
216                }
217            }
218        }
219    }
220    Ok(())
221}
222
223fn discover_body_adornments(
224    rule: &Rule,
225    head_pattern: &[bool],
226    target_preds: &BTreeSet<String>,
227) -> Result<Vec<Adornment>> {
228    let mut bound = head_bound_variables(&rule.head, head_pattern);
229    let mut out = Vec::new();
230    for lit in &rule.body {
231        match lit {
232            BodyLiteral::Positive(atom) => {
233                if target_preds.contains(&atom.predicate) {
234                    let pattern = atom_adornment(atom, &bound);
235                    if !pattern.iter().any(|is_bound| *is_bound) {
236                        return Err(magic_error(format!(
237                            "recursive call {}/{} has no bound argument under supported SIPS",
238                            atom.predicate,
239                            atom.arity()
240                        )));
241                    }
242                    out.push(Adornment {
243                        pred: atom.predicate.clone(),
244                        pattern,
245                    });
246                }
247                bind_atom_variables(atom, &mut bound);
248            }
249            BodyLiteral::Comparison(_)
250            | BodyLiteral::Epistemic(_)
251            | BodyLiteral::IsExpr(_)
252            | BodyLiteral::Negated(_)
253            | BodyLiteral::Univ(_) => {}
254        }
255    }
256    Ok(out)
257}
258
259fn rewrite_rules(
260    program: &Program,
261    target_preds: &BTreeSet<String>,
262    adornments: &BTreeSet<Adornment>,
263    seeds: &[Seed],
264) -> Result<Vec<Rule>> {
265    let mut out: Vec<Rule> = program
266        .rules
267        .iter()
268        .filter(|rule| !target_preds.contains(&rule.head.predicate))
269        .cloned()
270        .collect();
271
272    let mut emitted = HashSet::new();
273    for seed in seeds {
274        let rule = Rule {
275            head: Atom {
276                predicate: magic_predicate(&seed.pred, &seed.pattern),
277                terms: seed.terms.clone(),
278            },
279            body: Vec::new(),
280        };
281        push_unique_rule(&mut out, &mut emitted, rule);
282    }
283
284    for adornment in adornments {
285        for rule in program
286            .rules
287            .iter()
288            .filter(|rule| rule.head.predicate == adornment.pred)
289        {
290            for magic_rule in propagation_rules(rule, &adornment.pattern, target_preds)? {
291                push_unique_rule(&mut out, &mut emitted, magic_rule);
292            }
293        }
294    }
295
296    for adornment in adornments {
297        for rule in program
298            .rules
299            .iter()
300            .filter(|rule| rule.head.predicate == adornment.pred)
301        {
302            let mut body = vec![BodyLiteral::Positive(magic_atom_for(
303                &rule.head,
304                &adornment.pattern,
305            ))];
306            body.extend(rule.body.clone());
307            out.push(Rule {
308                head: rule.head.clone(),
309                body,
310            });
311        }
312    }
313
314    Ok(out)
315}
316
317fn propagation_rules(
318    rule: &Rule,
319    head_pattern: &[bool],
320    target_preds: &BTreeSet<String>,
321) -> Result<Vec<Rule>> {
322    let caller_magic = magic_atom_for(&rule.head, head_pattern);
323    let mut prefix = vec![BodyLiteral::Positive(caller_magic.clone())];
324    let mut bound = head_bound_variables(&rule.head, head_pattern);
325    let mut out = Vec::new();
326
327    for lit in &rule.body {
328        let BodyLiteral::Positive(atom) = lit else {
329            continue;
330        };
331        if target_preds.contains(&atom.predicate) {
332            let pattern = atom_adornment(atom, &bound);
333            if !pattern.iter().any(|is_bound| *is_bound) {
334                return Err(magic_error(format!(
335                    "recursive call {}/{} has no bound argument under supported SIPS",
336                    atom.predicate,
337                    atom.arity()
338                )));
339            }
340            let head = magic_atom_for(atom, &pattern);
341            let is_trivial = prefix.len() == 1
342                && matches!(&prefix[0], BodyLiteral::Positive(prefix_atom) if *prefix_atom == head);
343            if !is_trivial {
344                out.push(Rule {
345                    head,
346                    body: prefix.clone(),
347                });
348            }
349        }
350        bind_atom_variables(atom, &mut bound);
351        prefix.push(lit.clone());
352    }
353
354    Ok(out)
355}
356
357fn unsupported_target_reasons(
358    program: &Program,
359    target_preds: &BTreeSet<String>,
360    recursive: &HashSet<String>,
361) -> Vec<String> {
362    let mut reasons = BTreeSet::new();
363    for rule in &program.rules {
364        if !target_preds.contains(&rule.head.predicate) {
365            continue;
366        }
367        if rule.has_negation() {
368            reasons.insert(format!(
369                "negation in recursive rule for {} is outside the supported magic_sets subset",
370                rule.head.predicate
371            ));
372        }
373        if rule.has_aggregation() || rule.body.iter().any(body_literal_has_aggregate) {
374            reasons.insert(format!(
375                "aggregation in recursive rule for {} is outside the supported magic_sets subset",
376                rule.head.predicate
377            ));
378        }
379        for lit in &rule.body {
380            match lit {
381                BodyLiteral::Positive(atom) => {
382                    if recursive.contains(&atom.predicate) && atom.predicate != rule.head.predicate
383                    {
384                        reasons.insert(format!(
385                            "mutual recursion through {} is outside the supported magic_sets subset",
386                            atom.predicate
387                        ));
388                    }
389                    if atom.predicate.starts_with("__xlog_meta_")
390                        || atom.predicate.starts_with("__xlog_list_")
391                    {
392                        reasons.insert(format!(
393                            "meta/list helper {} in recursive rule is outside the supported magic_sets subset",
394                            atom.predicate
395                        ));
396                    }
397                }
398                BodyLiteral::Negated(_) => {}
399                BodyLiteral::Comparison(_)
400                | BodyLiteral::Epistemic(_)
401                | BodyLiteral::IsExpr(_)
402                | BodyLiteral::Univ(_) => {
403                    reasons.insert(format!(
404                        "non-positive literal in recursive rule for {} is outside the supported magic_sets subset",
405                        rule.head.predicate
406                    ));
407                }
408            }
409        }
410    }
411    reasons.into_iter().collect()
412}
413
414fn recursive_predicates(program: &Program) -> HashSet<String> {
415    let mut deps: HashMap<String, HashSet<String>> = HashMap::new();
416    for rule in &program.rules {
417        let entry = deps.entry(rule.head.predicate.clone()).or_default();
418        for pred in rule.body_predicates() {
419            entry.insert(pred.to_string());
420        }
421    }
422    deps.keys()
423        .filter(|pred| reaches(pred, pred, &deps, &mut HashSet::new()))
424        .cloned()
425        .collect()
426}
427
/// Depth-first test of whether `target` is reachable from `start` over at least one
/// edge of `deps`.
///
/// `seen` records the predicates already expanded so each is explored at most once.
/// A `start` with no entry in `deps` reaches nothing.
fn reaches(
    start: &str,
    target: &str,
    deps: &HashMap<String, HashSet<String>>,
    seen: &mut HashSet<String>,
) -> bool {
    match deps.get(start) {
        None => false,
        Some(successors) => successors.iter().any(|pred| {
            // A direct edge wins; otherwise descend only into unvisited predicates.
            pred == target || (seen.insert(pred.clone()) && reaches(pred, target, deps, seen))
        }),
    }
}
447
448fn head_bound_variables(atom: &Atom, pattern: &[bool]) -> HashSet<String> {
449    atom.terms
450        .iter()
451        .zip(pattern)
452        .filter(|(_, bound)| **bound)
453        .flat_map(|(term, _)| term.variables().into_iter().map(str::to_string))
454        .collect()
455}
456
457fn atom_adornment(atom: &Atom, bound: &HashSet<String>) -> Vec<bool> {
458    atom.terms
459        .iter()
460        .map(|term| term_is_bound(term, bound))
461        .collect()
462}
463
464fn term_is_bound(term: &Term, bound: &HashSet<String>) -> bool {
465    match term {
466        Term::Variable(name) => bound.contains(name),
467        Term::Anonymous => false,
468        Term::List(items) => items.iter().all(|item| term_is_bound(item, bound)),
469        Term::Cons { head, tail } => term_is_bound(head, bound) && term_is_bound(tail, bound),
470        Term::Compound { args, .. } => args.iter().all(|arg| term_is_bound(arg, bound)),
471        Term::Integer(_)
472        | Term::Float(_)
473        | Term::String(_)
474        | Term::Symbol(_)
475        | Term::PredRef(_) => true,
476        Term::Aggregate(_) => false,
477    }
478}
479
480fn bind_atom_variables(atom: &Atom, bound: &mut HashSet<String>) {
481    for name in atom.variables() {
482        bound.insert(name.to_string());
483    }
484}
485
486fn body_literal_has_aggregate(lit: &BodyLiteral) -> bool {
487    match lit {
488        BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom_has_aggregate(atom),
489        BodyLiteral::Epistemic(lit) => atom_has_aggregate(&lit.atom),
490        BodyLiteral::Comparison(comparison) => {
491            term_has_aggregate(&comparison.left) || term_has_aggregate(&comparison.right)
492        }
493        BodyLiteral::IsExpr(_) => false,
494        BodyLiteral::Univ(univ) => {
495            term_has_aggregate(&univ.term) || term_has_aggregate(&univ.parts)
496        }
497    }
498}
499
500fn atom_has_aggregate(atom: &Atom) -> bool {
501    atom.terms.iter().any(term_has_aggregate)
502}
503
504fn term_has_aggregate(term: &Term) -> bool {
505    match term {
506        Term::Aggregate(_) => true,
507        Term::List(items) => items.iter().any(term_has_aggregate),
508        Term::Cons { head, tail } => term_has_aggregate(head) || term_has_aggregate(tail),
509        Term::Compound { args, .. } => args.iter().any(term_has_aggregate),
510        Term::Variable(_)
511        | Term::Anonymous
512        | Term::Integer(_)
513        | Term::Float(_)
514        | Term::String(_)
515        | Term::Symbol(_)
516        | Term::PredRef(_) => false,
517    }
518}
519
520fn is_seed_term(term: &Term) -> bool {
521    is_supported_magic_term(term) && !term.is_any_variable()
522}
523
524fn is_supported_magic_term(term: &Term) -> bool {
525    matches!(
526        term,
527        Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_)
528    )
529}
530
531fn bound_terms(atom: &Atom, pattern: &[bool]) -> Vec<Term> {
532    atom.terms
533        .iter()
534        .zip(pattern)
535        .filter(|(_, bound)| **bound)
536        .map(|(term, _)| term.clone())
537        .collect()
538}
539
540fn magic_atom_for(atom: &Atom, pattern: &[bool]) -> Atom {
541    Atom {
542        predicate: magic_predicate(&atom.predicate, pattern),
543        terms: bound_terms(atom, pattern),
544    }
545}
546
547fn magic_predicate(pred: &str, pattern: &[bool]) -> String {
548    format!("__xlog_magic_{}_{}", pred, adornment_key(pattern))
549}
550
/// Render a bound/free pattern as text: `b` for each bound position and `f` for each
/// free one, in argument order.
fn adornment_key(pattern: &[bool]) -> String {
    let mut key = String::with_capacity(pattern.len());
    for &is_bound in pattern {
        key.push(if is_bound { 'b' } else { 'f' });
    }
    key
}
557
558fn push_unique_rule(out: &mut Vec<Rule>, emitted: &mut HashSet<String>, rule: Rule) {
559    let key = format!("{:?}", rule);
560    if emitted.insert(key) {
561        out.push(rule);
562    }
563}