1use crate::ast::{BodyLiteral, Program};
4use std::collections::{HashMap, HashSet};
5use xlog_core::{Result, XlogError};
6
/// How a rule head depends on one of the predicates in its body.
///
/// Only `Positive` dependencies are monotone. `Negative` and `Aggregate`
/// dependencies push the dependent predicate into a strictly higher stratum,
/// and a cycle through either of them is not stratifiable.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub(crate) enum DepType {
    /// The body predicate appears as a plain positive literal.
    Positive,
    /// The body predicate appears under negation (epistemic literals are mapped here too).
    Negative,
    /// The body predicate appears in a rule that performs aggregation.
    Aggregate,
}
14
15#[derive(Debug, Clone)]
17pub(crate) struct DepEdge {
18 pub from: String,
19 pub to: String,
20 pub dep_type: DepType,
21}
22
23#[derive(Debug, Default)]
25pub struct DependencyGraph {
26 pub predicates: HashSet<String>,
28 pub(crate) edges: Vec<DepEdge>,
29}
30
31impl DependencyGraph {
32 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn add_predicate(&mut self, name: String) {
39 self.predicates.insert(name);
40 }
41
42 pub(crate) fn add_edge(&mut self, from: String, to: String, dep_type: DepType) {
43 self.predicates.insert(from.clone());
44 self.predicates.insert(to.clone());
45 self.edges.push(DepEdge { from, to, dep_type });
46 }
47
48 pub(crate) fn outgoing(&self, pred: &str) -> Vec<&DepEdge> {
49 self.edges.iter().filter(|e| e.from == pred).collect()
50 }
51}
52
53pub fn build_dependency_graph(program: &Program) -> DependencyGraph {
55 let mut graph = DependencyGraph::new();
56
57 for rule in &program.rules {
58 let head = &rule.head.predicate;
59 graph.add_predicate(head.clone());
60
61 for lit in &rule.body {
62 match lit {
63 BodyLiteral::Positive(atom) => {
64 graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Positive);
65 }
66 BodyLiteral::Negated(atom) => {
67 graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Negative);
68 }
69 BodyLiteral::Epistemic(lit) => {
70 graph.add_edge(head.clone(), lit.atom.predicate.clone(), DepType::Negative);
71 }
72 BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => {}
73 }
74 }
75
76 if rule.has_aggregation() {
77 for lit in &rule.body {
78 if let BodyLiteral::Positive(atom) = lit {
79 graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Aggregate);
80 }
81 }
82 }
83 }
84
85 for lr in &program.learnable_rules {
90 let head = &lr.head.predicate;
91 graph.add_predicate(head.clone());
92 for body_lit in &lr.body {
93 if let Some(atom) = body_lit.atom() {
94 graph.add_predicate(atom.predicate.clone());
95 graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Positive);
96 }
97 }
98 }
99
100 graph
101}
102
103fn find_sccs(graph: &DependencyGraph) -> Vec<Vec<String>> {
106 let mut index_counter = 0;
107 let mut stack = Vec::new();
108 let mut indices: HashMap<String, usize> = HashMap::new();
109 let mut lowlinks: HashMap<String, usize> = HashMap::new();
110 let mut on_stack: HashSet<String> = HashSet::new();
111 let mut sccs: Vec<Vec<String>> = Vec::new();
112
113 #[allow(clippy::too_many_arguments)]
114 fn strongconnect(
115 v: &str,
116 graph: &DependencyGraph,
117 index_counter: &mut usize,
118 stack: &mut Vec<String>,
119 indices: &mut HashMap<String, usize>,
120 lowlinks: &mut HashMap<String, usize>,
121 on_stack: &mut HashSet<String>,
122 sccs: &mut Vec<Vec<String>>,
123 ) {
124 indices.insert(v.to_string(), *index_counter);
125 lowlinks.insert(v.to_string(), *index_counter);
126 *index_counter += 1;
127 stack.push(v.to_string());
128 on_stack.insert(v.to_string());
129
130 for edge in graph.outgoing(v) {
131 let w = &edge.to;
132 if !indices.contains_key(w) {
133 strongconnect(
134 w,
135 graph,
136 index_counter,
137 stack,
138 indices,
139 lowlinks,
140 on_stack,
141 sccs,
142 );
143 let low_v = *lowlinks.get(v).unwrap();
144 let low_w = *lowlinks.get(w).unwrap();
145 lowlinks.insert(v.to_string(), low_v.min(low_w));
146 } else if on_stack.contains(w) {
147 let low_v = *lowlinks.get(v).unwrap();
148 let idx_w = *indices.get(w).unwrap();
149 lowlinks.insert(v.to_string(), low_v.min(idx_w));
150 }
151 }
152
153 let low_v = *lowlinks.get(v).unwrap();
154 let idx_v = *indices.get(v).unwrap();
155 if low_v == idx_v {
156 let mut scc = Vec::new();
157 loop {
158 let w = stack.pop().unwrap();
159 on_stack.remove(&w);
160 scc.push(w.clone());
161 if w == v {
162 break;
163 }
164 }
165 sccs.push(scc);
166 }
167 }
168
169 for pred in &graph.predicates {
170 if !indices.contains_key(pred) {
171 strongconnect(
172 pred,
173 graph,
174 &mut index_counter,
175 &mut stack,
176 &mut indices,
177 &mut lowlinks,
178 &mut on_stack,
179 &mut sccs,
180 );
181 }
182 }
183
184 sccs
185}
186
187fn check_scc_for_negation_cycle(scc: &[String], graph: &DependencyGraph) -> Option<Vec<String>> {
189 if scc.len() == 1 {
190 let pred = &scc[0];
191 for edge in graph.outgoing(pred) {
192 if edge.to == *pred && edge.dep_type != DepType::Positive {
193 return Some(vec![pred.clone()]);
194 }
195 }
196 return None;
197 }
198
199 let scc_set: HashSet<&str> = scc.iter().map(|s| s.as_str()).collect();
200 for pred in scc {
201 for edge in graph.outgoing(pred) {
202 if scc_set.contains(edge.to.as_str()) && edge.dep_type != DepType::Positive {
203 return Some(scc.to_vec());
204 }
205 }
206 }
207 None
208}
209
/// One evaluation layer produced by [`stratify`].
#[derive(Clone, Debug)]
pub struct Stratum {
    /// Position of this stratum; ids are dense and start at 0.
    pub id: usize,
    /// Predicates assigned to this stratum.
    pub predicates: Vec<String>,
}
218
/// Detailed stratification report produced by `analyze_stratification`.
#[derive(Clone, Debug)]
pub struct StratificationResult {
    /// Strongly connected components of the dependency graph, dependencies first.
    pub sccs: Vec<Vec<String>>,
    /// Indices into `sccs` of components that recurse through negation or aggregation.
    pub non_monotone_sccs: HashSet<usize>,
    /// Stratum of each predicate outside the non-monotone components.
    pub strata: HashMap<String, usize>,
}
229
230pub fn stratify(program: &Program) -> Result<Vec<Stratum>> {
232 let graph = build_dependency_graph(program);
233 let sccs = find_sccs(&graph);
234
235 for scc in &sccs {
236 if let Some(cycle) = check_scc_for_negation_cycle(scc, &graph) {
237 if !program.is_probabilistic_profile() {
240 return Err(XlogError::StratificationCycle(cycle));
241 }
242 }
246 }
247
248 let mut stratum_map: HashMap<String, usize> = HashMap::new();
249 let mut max_stratum = 0;
250
251 for scc in &sccs {
256 let mut min_stratum = 0;
257 for pred in scc {
258 for edge in graph.outgoing(pred) {
259 if let Some(&dep_stratum) = stratum_map.get(&edge.to) {
260 let required = match edge.dep_type {
261 DepType::Positive => dep_stratum,
262 DepType::Negative | DepType::Aggregate => dep_stratum + 1,
263 };
264 min_stratum = min_stratum.max(required);
265 }
266 }
267 }
268 for pred in scc {
269 stratum_map.insert(pred.clone(), min_stratum);
270 }
271 max_stratum = max_stratum.max(min_stratum);
272 }
273
274 let mut strata: Vec<Stratum> = (0..=max_stratum)
275 .map(|id| Stratum {
276 id,
277 predicates: vec![],
278 })
279 .collect();
280
281 for (pred, stratum) in stratum_map {
282 strata[stratum].predicates.push(pred);
283 }
284
285 strata.retain(|s| !s.predicates.is_empty());
286 for (i, stratum) in strata.iter_mut().enumerate() {
287 stratum.id = i;
288 }
289
290 Ok(strata)
291}
292
293pub fn analyze_stratification(program: &Program) -> StratificationResult {
296 let graph = build_dependency_graph(program);
297 let sccs = find_sccs(&graph);
298
299 let mut non_monotone_sccs: HashSet<usize> = HashSet::new();
300 for (i, scc) in sccs.iter().enumerate() {
301 if check_scc_for_negation_cycle(scc, &graph).is_some() {
302 non_monotone_sccs.insert(i);
303 }
304 }
305
306 let mut strata: HashMap<String, usize> = HashMap::new();
308 let mut max_stratum = 0;
309
310 for (scc_idx, scc) in sccs.iter().enumerate() {
311 if non_monotone_sccs.contains(&scc_idx) {
312 continue; }
314
315 let mut min_stratum = 0;
316 for pred in scc {
317 for edge in graph.outgoing(pred) {
318 if let Some(&dep_stratum) = strata.get(&edge.to) {
319 let required = match edge.dep_type {
320 DepType::Positive => dep_stratum,
321 DepType::Negative | DepType::Aggregate => dep_stratum + 1,
322 };
323 min_stratum = min_stratum.max(required);
324 }
325 }
326 }
327 for pred in scc {
328 strata.insert(pred.clone(), min_stratum);
329 }
330 max_stratum = max_stratum.max(min_stratum);
331 }
332
333 StratificationResult {
334 sccs,
335 non_monotone_sccs,
336 strata,
337 }
338}
339
340pub fn find_sccs_for_lowering(graph: &DependencyGraph) -> Vec<Vec<String>> {
343 find_sccs(graph)
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::*;

    /// Builds a variable term such as `X`.
    fn var(name: &str) -> Term {
        Term::Variable(name.into())
    }

    /// Builds the atom `predicate(terms...)`.
    fn mk_atom(predicate: &str, terms: Vec<Term>) -> Atom {
        Atom {
            predicate: predicate.into(),
            terms,
        }
    }

    /// Builds the rule `head :- body`; an empty `body` makes it a fact.
    fn mk_rule(head: Atom, body: Vec<BodyLiteral>) -> Rule {
        Rule { head, body }
    }

    /// Returns the id of the stratum containing `pred`, panicking if it has none.
    fn stratum_id_of(strata: &[Stratum], pred: &str) -> usize {
        strata
            .iter()
            .find(|s| s.predicates.iter().any(|p| p == pred))
            .map(|s| s.id)
            .unwrap_or_else(|| panic!("{} is not assigned to any stratum", pred))
    }

    /// `edge(1,2). reach(X,Y) :- edge(X,Y). reach(X,Z) :- reach(X,Y), edge(Y,Z).`
    fn create_tc_program() -> Program {
        let mut program = Program::new();
        program.rules.push(mk_rule(
            mk_atom("edge", vec![Term::Integer(1), Term::Integer(2)]),
            vec![],
        ));
        program.rules.push(mk_rule(
            mk_atom("reach", vec![var("X"), var("Y")]),
            vec![BodyLiteral::Positive(mk_atom("edge", vec![var("X"), var("Y")]))],
        ));
        program.rules.push(mk_rule(
            mk_atom("reach", vec![var("X"), var("Z")]),
            vec![
                BodyLiteral::Positive(mk_atom("reach", vec![var("X"), var("Y")])),
                BodyLiteral::Positive(mk_atom("edge", vec![var("Y"), var("Z")])),
            ],
        ));
        program
    }

    /// `node(1..3). edge(1,2). isolated(X) :- node(X), not edge(X,Y).`
    fn create_isolated_program() -> Program {
        let mut program = Program::new();
        for i in 1..=3 {
            program
                .rules
                .push(mk_rule(mk_atom("node", vec![Term::Integer(i)]), vec![]));
        }
        program.rules.push(mk_rule(
            mk_atom("edge", vec![Term::Integer(1), Term::Integer(2)]),
            vec![],
        ));
        program.rules.push(mk_rule(
            mk_atom("isolated", vec![var("X")]),
            vec![
                BodyLiteral::Positive(mk_atom("node", vec![var("X")])),
                BodyLiteral::Negated(mk_atom("edge", vec![var("X"), var("Y")])),
            ],
        ));
        program
    }

    /// `p :- not q. q :- not p.` (recursion through negation)
    fn create_unstratifiable_program() -> Program {
        let mut program = Program::new();
        program.rules.push(mk_rule(
            mk_atom("p", vec![]),
            vec![BodyLiteral::Negated(mk_atom("q", vec![]))],
        ));
        program.rules.push(mk_rule(
            mk_atom("q", vec![]),
            vec![BodyLiteral::Negated(mk_atom("p", vec![]))],
        ));
        program
    }

    #[test]
    fn test_stratify_simple() {
        let program = create_tc_program();
        let result = stratify(&program);
        assert!(result.is_ok(), "Stratification failed: {:?}", result.err());
    }

    #[test]
    fn test_stratify_with_negation() {
        let program = create_isolated_program();
        let strata =
            stratify(&program).unwrap_or_else(|e| panic!("Stratification failed: {:?}", e));
        assert!(
            strata.len() >= 2,
            "Expected at least 2 strata, got {}",
            strata.len()
        );
        assert!(
            stratum_id_of(&strata, "isolated") > stratum_id_of(&strata, "edge"),
            "isolated negates edge, so it must be in a higher stratum"
        );
    }

    #[test]
    fn test_stratify_cycle_through_negation() {
        let program = create_unstratifiable_program();
        match stratify(&program) {
            Err(XlogError::StratificationCycle(preds)) => {
                let mut preds = preds;
                preds.sort();
                assert_eq!(preds, vec!["p".to_string(), "q".to_string()]);
            }
            Err(other) => panic!("Expected StratificationCycle, got {:?}", other),
            Ok(strata) => panic!("Should fail with cycle through negation, got {:?}", strata),
        }
    }

    #[test]
    fn test_stratify_probabilistic_non_monotone_allows_exact_ddnnf() {
        let mut program = create_unstratifiable_program();
        program.directives.prob_engine = Some(ProbEngine::ExactDdnnf);

        let result = stratify(&program);
        assert!(
            result.is_ok(),
            "Expected exact_ddnnf to allow non-monotone recursion (via WFS), got: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_stratify_probabilistic_non_monotone_allows_mc() {
        let mut program = create_unstratifiable_program();
        program.directives.prob_engine = Some(ProbEngine::Mc);

        let result = stratify(&program);
        assert!(
            result.is_ok(),
            "Expected mc to allow non-monotone recursion, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_dependency_graph_construction() {
        let program = create_tc_program();
        let graph = build_dependency_graph(&program);
        assert!(graph.predicates.contains("edge"));
        assert!(graph.predicates.contains("reach"));
        let reach_deps = graph.outgoing("reach");
        assert!(!reach_deps.is_empty());
    }

    #[test]
    fn test_find_sccs_lists_dependencies_first() {
        let graph = build_dependency_graph(&create_tc_program());
        let sccs = find_sccs_for_lowering(&graph);
        let position = |pred: &str| {
            sccs.iter()
                .position(|scc| scc.iter().any(|p| p == pred))
                .unwrap_or_else(|| panic!("{} missing from SCCs", pred))
        };
        assert!(
            position("edge") < position("reach"),
            "edge must be emitted before reach"
        );
        assert_eq!(sccs[position("reach")], vec!["reach".to_string()]);
    }

    #[test]
    fn test_analyze_stratification_detects_non_monotone() {
        let program = create_unstratifiable_program();
        let result = analyze_stratification(&program);

        assert!(
            !result.non_monotone_sccs.is_empty(),
            "Should detect non-monotone SCC"
        );
        let has_non_monotone = result.sccs.iter().enumerate().any(|(i, scc)| {
            result.non_monotone_sccs.contains(&i)
                && (scc.contains(&"p".to_string()) || scc.contains(&"q".to_string()))
        });
        assert!(has_non_monotone, "SCC with p/q should be non-monotone");
    }

    #[test]
    fn test_analyze_stratification_stratified_program() {
        let program = create_isolated_program();
        let result = analyze_stratification(&program);

        assert!(
            result.non_monotone_sccs.is_empty(),
            "Stratified program has no non-monotone SCCs"
        );
        assert!(
            result.strata.contains_key("isolated"),
            "isolated should have a stratum"
        );
        assert!(
            result.strata.contains_key("edge"),
            "edge should have a stratum"
        );

        let isolated_stratum = result.strata.get("isolated").unwrap();
        let edge_stratum = result.strata.get("edge").unwrap();
        assert!(
            isolated_stratum > edge_stratum,
            "isolated should be in higher stratum than edge"
        );
    }
}