1use std::collections::{BTreeSet, HashMap, HashSet};
4
5use xlog_core::{Result, XlogError};
6
7use crate::ast::{Atom, BodyLiteral, MagicSetsMode, Program, Rule, Term};
8
/// Outcome of a magic-sets rewrite attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MagicSetStatus {
    /// The `magic_sets` directive was absent or `Off`; the program is returned unchanged.
    Disabled,
    /// The rewrite ran and the returned program contains the magic-set rules.
    Applied,
    /// The rewrite was skipped; the report lists the reasons.
    Declined,
}
19
20#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct MagicSetReport {
23 pub status: MagicSetStatus,
25 pub generated_predicates: Vec<String>,
27 pub adorned_predicates: Vec<String>,
29 pub declined_reasons: Vec<String>,
31}
32
33#[derive(Debug, Clone)]
35pub struct MagicSetRewrite {
36 pub program: Program,
38 pub report: MagicSetReport,
40}
41
/// A predicate name together with a bound/free pattern over its arguments.
///
/// The derived ordering (by `pred`, then `pattern`) is what makes iteration
/// over a `BTreeSet<Adornment>` deterministic.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct Adornment {
    /// Name of the adorned predicate.
    pred: String,
    /// One flag per argument position: `true` means bound, `false` means free.
    pattern: Vec<bool>,
}
47
48#[derive(Debug, Clone)]
49struct Seed {
50 pred: String,
51 pattern: Vec<bool>,
52 terms: Vec<Term>,
53}
54
55pub fn rewrite_magic_sets(program: &Program) -> Result<MagicSetRewrite> {
57 let mode = program.directives.magic_sets;
58 if mode.is_none() || mode == Some(MagicSetsMode::Off) {
59 return Ok(with_status(program, MagicSetStatus::Disabled));
60 }
61 let mode = mode.expect("checked above");
62
63 let recursive = recursive_predicates(program);
64 let seeds = collect_query_seeds(program, &recursive);
65 if seeds.is_empty() {
66 return decline_or_error(
67 program,
68 mode,
69 vec!["no bound recursive query eligible for magic_sets".to_string()],
70 );
71 }
72
73 if program.is_probabilistic_profile() {
74 return decline_or_error(
75 program,
76 mode,
77 vec!["probabilistic profiles are handled outside deterministic magic_sets".to_string()],
78 );
79 }
80
81 let target_preds: BTreeSet<String> = seeds.iter().map(|seed| seed.pred.clone()).collect();
82 let unsafe_reasons = unsupported_target_reasons(program, &target_preds, &recursive);
83 if !unsafe_reasons.is_empty() {
84 return decline_or_error(program, mode, unsafe_reasons);
85 }
86
87 let mut adornments = initial_adornments(&seeds);
88 expand_adornments(program, &target_preds, &mut adornments)?;
89
90 let mut generated_predicates: BTreeSet<String> = BTreeSet::new();
91 let mut adorned_predicates: BTreeSet<String> = BTreeSet::new();
92 for adornment in &adornments {
93 generated_predicates.insert(magic_predicate(&adornment.pred, &adornment.pattern));
94 adorned_predicates.insert(format!(
95 "{}/{}",
96 adornment.pred,
97 adornment_key(&adornment.pattern)
98 ));
99 }
100
101 let mut rewritten = program.clone();
102 rewritten.rules = rewrite_rules(program, &target_preds, &adornments, &seeds)?;
103
104 Ok(MagicSetRewrite {
105 program: rewritten,
106 report: MagicSetReport {
107 status: MagicSetStatus::Applied,
108 generated_predicates: generated_predicates.into_iter().collect(),
109 adorned_predicates: adorned_predicates.into_iter().collect(),
110 declined_reasons: Vec::new(),
111 },
112 })
113}
114
115fn with_status(program: &Program, status: MagicSetStatus) -> MagicSetRewrite {
116 MagicSetRewrite {
117 program: program.clone(),
118 report: MagicSetReport {
119 status,
120 generated_predicates: Vec::new(),
121 adorned_predicates: Vec::new(),
122 declined_reasons: Vec::new(),
123 },
124 }
125}
126
127fn decline_or_error(
128 program: &Program,
129 mode: MagicSetsMode,
130 reasons: Vec<String>,
131) -> Result<MagicSetRewrite> {
132 if mode == MagicSetsMode::On {
133 return Err(magic_error(reasons.join("; ")));
134 }
135 Ok(MagicSetRewrite {
136 program: program.clone(),
137 report: MagicSetReport {
138 status: MagicSetStatus::Declined,
139 generated_predicates: Vec::new(),
140 adorned_predicates: Vec::new(),
141 declined_reasons: reasons,
142 },
143 })
144}
145
146fn magic_error(message: impl Into<String>) -> XlogError {
147 XlogError::Compilation(format!("magic_sets error: {}", message.into()))
148}
149
150fn collect_query_seeds(program: &Program, recursive: &HashSet<String>) -> Vec<Seed> {
151 let mut seen = HashSet::new();
152 let mut out = Vec::new();
153 for query in &program.queries {
154 if !recursive.contains(&query.atom.predicate) {
155 continue;
156 }
157 let pattern: Vec<bool> = query.atom.terms.iter().map(is_seed_term).collect();
158 if !pattern.iter().any(|bound| *bound) {
159 continue;
160 }
161 if !query
162 .atom
163 .terms
164 .iter()
165 .zip(&pattern)
166 .all(|(term, bound)| !*bound || is_supported_magic_term(term))
167 {
168 continue;
169 }
170 let terms = bound_terms(&query.atom, &pattern);
171 let key = format!(
172 "{}:{}:{:?}",
173 query.atom.predicate,
174 adornment_key(&pattern),
175 terms
176 );
177 if seen.insert(key) {
178 out.push(Seed {
179 pred: query.atom.predicate.clone(),
180 pattern,
181 terms,
182 });
183 }
184 }
185 out
186}
187
188fn initial_adornments(seeds: &[Seed]) -> BTreeSet<Adornment> {
189 seeds
190 .iter()
191 .map(|seed| Adornment {
192 pred: seed.pred.clone(),
193 pattern: seed.pattern.clone(),
194 })
195 .collect()
196}
197
198fn expand_adornments(
199 program: &Program,
200 target_preds: &BTreeSet<String>,
201 adornments: &mut BTreeSet<Adornment>,
202) -> Result<()> {
203 let mut changed = true;
204 while changed {
205 changed = false;
206 let snapshot: Vec<Adornment> = adornments.iter().cloned().collect();
207 for adornment in snapshot {
208 for rule in program
209 .rules
210 .iter()
211 .filter(|rule| rule.head.predicate == adornment.pred)
212 {
213 for discovered in discover_body_adornments(rule, &adornment.pattern, target_preds)?
214 {
215 changed |= adornments.insert(discovered);
216 }
217 }
218 }
219 }
220 Ok(())
221}
222
223fn discover_body_adornments(
224 rule: &Rule,
225 head_pattern: &[bool],
226 target_preds: &BTreeSet<String>,
227) -> Result<Vec<Adornment>> {
228 let mut bound = head_bound_variables(&rule.head, head_pattern);
229 let mut out = Vec::new();
230 for lit in &rule.body {
231 match lit {
232 BodyLiteral::Positive(atom) => {
233 if target_preds.contains(&atom.predicate) {
234 let pattern = atom_adornment(atom, &bound);
235 if !pattern.iter().any(|is_bound| *is_bound) {
236 return Err(magic_error(format!(
237 "recursive call {}/{} has no bound argument under supported SIPS",
238 atom.predicate,
239 atom.arity()
240 )));
241 }
242 out.push(Adornment {
243 pred: atom.predicate.clone(),
244 pattern,
245 });
246 }
247 bind_atom_variables(atom, &mut bound);
248 }
249 BodyLiteral::Comparison(_)
250 | BodyLiteral::Epistemic(_)
251 | BodyLiteral::IsExpr(_)
252 | BodyLiteral::Negated(_)
253 | BodyLiteral::Univ(_) => {}
254 }
255 }
256 Ok(out)
257}
258
259fn rewrite_rules(
260 program: &Program,
261 target_preds: &BTreeSet<String>,
262 adornments: &BTreeSet<Adornment>,
263 seeds: &[Seed],
264) -> Result<Vec<Rule>> {
265 let mut out: Vec<Rule> = program
266 .rules
267 .iter()
268 .filter(|rule| !target_preds.contains(&rule.head.predicate))
269 .cloned()
270 .collect();
271
272 let mut emitted = HashSet::new();
273 for seed in seeds {
274 let rule = Rule {
275 head: Atom {
276 predicate: magic_predicate(&seed.pred, &seed.pattern),
277 terms: seed.terms.clone(),
278 },
279 body: Vec::new(),
280 };
281 push_unique_rule(&mut out, &mut emitted, rule);
282 }
283
284 for adornment in adornments {
285 for rule in program
286 .rules
287 .iter()
288 .filter(|rule| rule.head.predicate == adornment.pred)
289 {
290 for magic_rule in propagation_rules(rule, &adornment.pattern, target_preds)? {
291 push_unique_rule(&mut out, &mut emitted, magic_rule);
292 }
293 }
294 }
295
296 for adornment in adornments {
297 for rule in program
298 .rules
299 .iter()
300 .filter(|rule| rule.head.predicate == adornment.pred)
301 {
302 let mut body = vec![BodyLiteral::Positive(magic_atom_for(
303 &rule.head,
304 &adornment.pattern,
305 ))];
306 body.extend(rule.body.clone());
307 out.push(Rule {
308 head: rule.head.clone(),
309 body,
310 });
311 }
312 }
313
314 Ok(out)
315}
316
317fn propagation_rules(
318 rule: &Rule,
319 head_pattern: &[bool],
320 target_preds: &BTreeSet<String>,
321) -> Result<Vec<Rule>> {
322 let caller_magic = magic_atom_for(&rule.head, head_pattern);
323 let mut prefix = vec![BodyLiteral::Positive(caller_magic.clone())];
324 let mut bound = head_bound_variables(&rule.head, head_pattern);
325 let mut out = Vec::new();
326
327 for lit in &rule.body {
328 let BodyLiteral::Positive(atom) = lit else {
329 continue;
330 };
331 if target_preds.contains(&atom.predicate) {
332 let pattern = atom_adornment(atom, &bound);
333 if !pattern.iter().any(|is_bound| *is_bound) {
334 return Err(magic_error(format!(
335 "recursive call {}/{} has no bound argument under supported SIPS",
336 atom.predicate,
337 atom.arity()
338 )));
339 }
340 let head = magic_atom_for(atom, &pattern);
341 let is_trivial = prefix.len() == 1
342 && matches!(&prefix[0], BodyLiteral::Positive(prefix_atom) if *prefix_atom == head);
343 if !is_trivial {
344 out.push(Rule {
345 head,
346 body: prefix.clone(),
347 });
348 }
349 }
350 bind_atom_variables(atom, &mut bound);
351 prefix.push(lit.clone());
352 }
353
354 Ok(out)
355}
356
357fn unsupported_target_reasons(
358 program: &Program,
359 target_preds: &BTreeSet<String>,
360 recursive: &HashSet<String>,
361) -> Vec<String> {
362 let mut reasons = BTreeSet::new();
363 for rule in &program.rules {
364 if !target_preds.contains(&rule.head.predicate) {
365 continue;
366 }
367 if rule.has_negation() {
368 reasons.insert(format!(
369 "negation in recursive rule for {} is outside the supported magic_sets subset",
370 rule.head.predicate
371 ));
372 }
373 if rule.has_aggregation() || rule.body.iter().any(body_literal_has_aggregate) {
374 reasons.insert(format!(
375 "aggregation in recursive rule for {} is outside the supported magic_sets subset",
376 rule.head.predicate
377 ));
378 }
379 for lit in &rule.body {
380 match lit {
381 BodyLiteral::Positive(atom) => {
382 if recursive.contains(&atom.predicate) && atom.predicate != rule.head.predicate
383 {
384 reasons.insert(format!(
385 "mutual recursion through {} is outside the supported magic_sets subset",
386 atom.predicate
387 ));
388 }
389 if atom.predicate.starts_with("__xlog_meta_")
390 || atom.predicate.starts_with("__xlog_list_")
391 {
392 reasons.insert(format!(
393 "meta/list helper {} in recursive rule is outside the supported magic_sets subset",
394 atom.predicate
395 ));
396 }
397 }
398 BodyLiteral::Negated(_) => {}
399 BodyLiteral::Comparison(_)
400 | BodyLiteral::Epistemic(_)
401 | BodyLiteral::IsExpr(_)
402 | BodyLiteral::Univ(_) => {
403 reasons.insert(format!(
404 "non-positive literal in recursive rule for {} is outside the supported magic_sets subset",
405 rule.head.predicate
406 ));
407 }
408 }
409 }
410 }
411 reasons.into_iter().collect()
412}
413
414fn recursive_predicates(program: &Program) -> HashSet<String> {
415 let mut deps: HashMap<String, HashSet<String>> = HashMap::new();
416 for rule in &program.rules {
417 let entry = deps.entry(rule.head.predicate.clone()).or_default();
418 for pred in rule.body_predicates() {
419 entry.insert(pred.to_string());
420 }
421 }
422 deps.keys()
423 .filter(|pred| reaches(pred, pred, &deps, &mut HashSet::new()))
424 .cloned()
425 .collect()
426}
427
/// Depth-first search: is `target` reachable from `start` over one or more
/// edges of `deps`?
///
/// `seen` holds the nodes already expanded and is updated as the search
/// proceeds, so each node is expanded at most once. A `start` with no entry in
/// `deps` reaches nothing.
fn reaches(
    start: &str,
    target: &str,
    deps: &HashMap<String, HashSet<String>>,
    seen: &mut HashSet<String>,
) -> bool {
    match deps.get(start) {
        None => false,
        Some(successors) => successors.iter().any(|pred| {
            // Either a direct edge to the target, or a path through a
            // successor that has not been expanded yet.
            pred == target || (seen.insert(pred.clone()) && reaches(pred, target, deps, seen))
        }),
    }
}
447
448fn head_bound_variables(atom: &Atom, pattern: &[bool]) -> HashSet<String> {
449 atom.terms
450 .iter()
451 .zip(pattern)
452 .filter(|(_, bound)| **bound)
453 .flat_map(|(term, _)| term.variables().into_iter().map(str::to_string))
454 .collect()
455}
456
457fn atom_adornment(atom: &Atom, bound: &HashSet<String>) -> Vec<bool> {
458 atom.terms
459 .iter()
460 .map(|term| term_is_bound(term, bound))
461 .collect()
462}
463
464fn term_is_bound(term: &Term, bound: &HashSet<String>) -> bool {
465 match term {
466 Term::Variable(name) => bound.contains(name),
467 Term::Anonymous => false,
468 Term::List(items) => items.iter().all(|item| term_is_bound(item, bound)),
469 Term::Cons { head, tail } => term_is_bound(head, bound) && term_is_bound(tail, bound),
470 Term::Compound { args, .. } => args.iter().all(|arg| term_is_bound(arg, bound)),
471 Term::Integer(_)
472 | Term::Float(_)
473 | Term::String(_)
474 | Term::Symbol(_)
475 | Term::PredRef(_) => true,
476 Term::Aggregate(_) => false,
477 }
478}
479
480fn bind_atom_variables(atom: &Atom, bound: &mut HashSet<String>) {
481 for name in atom.variables() {
482 bound.insert(name.to_string());
483 }
484}
485
486fn body_literal_has_aggregate(lit: &BodyLiteral) -> bool {
487 match lit {
488 BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom_has_aggregate(atom),
489 BodyLiteral::Epistemic(lit) => atom_has_aggregate(&lit.atom),
490 BodyLiteral::Comparison(comparison) => {
491 term_has_aggregate(&comparison.left) || term_has_aggregate(&comparison.right)
492 }
493 BodyLiteral::IsExpr(_) => false,
494 BodyLiteral::Univ(univ) => {
495 term_has_aggregate(&univ.term) || term_has_aggregate(&univ.parts)
496 }
497 }
498}
499
500fn atom_has_aggregate(atom: &Atom) -> bool {
501 atom.terms.iter().any(term_has_aggregate)
502}
503
504fn term_has_aggregate(term: &Term) -> bool {
505 match term {
506 Term::Aggregate(_) => true,
507 Term::List(items) => items.iter().any(term_has_aggregate),
508 Term::Cons { head, tail } => term_has_aggregate(head) || term_has_aggregate(tail),
509 Term::Compound { args, .. } => args.iter().any(term_has_aggregate),
510 Term::Variable(_)
511 | Term::Anonymous
512 | Term::Integer(_)
513 | Term::Float(_)
514 | Term::String(_)
515 | Term::Symbol(_)
516 | Term::PredRef(_) => false,
517 }
518}
519
520fn is_seed_term(term: &Term) -> bool {
521 is_supported_magic_term(term) && !term.is_any_variable()
522}
523
524fn is_supported_magic_term(term: &Term) -> bool {
525 matches!(
526 term,
527 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_)
528 )
529}
530
531fn bound_terms(atom: &Atom, pattern: &[bool]) -> Vec<Term> {
532 atom.terms
533 .iter()
534 .zip(pattern)
535 .filter(|(_, bound)| **bound)
536 .map(|(term, _)| term.clone())
537 .collect()
538}
539
540fn magic_atom_for(atom: &Atom, pattern: &[bool]) -> Atom {
541 Atom {
542 predicate: magic_predicate(&atom.predicate, pattern),
543 terms: bound_terms(atom, pattern),
544 }
545}
546
547fn magic_predicate(pred: &str, pattern: &[bool]) -> String {
548 format!("__xlog_magic_{}_{}", pred, adornment_key(pattern))
549}
550
/// Renders a bound/free pattern as a string with one character per position:
/// `b` for bound (`true`) and `f` for free (`false`).
fn adornment_key(pattern: &[bool]) -> String {
    let mut key = String::with_capacity(pattern.len());
    for &is_bound in pattern {
        key.push(if is_bound { 'b' } else { 'f' });
    }
    key
}
557
558fn push_unique_rule(out: &mut Vec<Rule>, emitted: &mut HashSet<String>, rule: Rule) {
559 let key = format!("{:?}", rule);
560 if emitted.insert(key) {
561 out.push(rule);
562 }
563}