Skip to main content

xlog_logic/
stratify.rs

1//! Stratification analysis for negation and aggregation
2
3use crate::ast::{BodyLiteral, Program};
4use std::collections::{HashMap, HashSet};
5use xlog_core::{Result, XlogError};
6
/// Kind of dependency a rule head has on one of its body predicates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DepType {
    /// Body atom used positively; recursion within the same stratum is allowed.
    Positive,
    /// Negated (or epistemic) body atom; the dependency must sit in a strictly
    /// lower stratum.
    Negative,
    /// Body atom consumed by an aggregating rule; also requires a strictly
    /// lower stratum.
    Aggregate,
}
14
15/// Dependency graph edge
16#[derive(Debug, Clone)]
17pub(crate) struct DepEdge {
18    pub from: String,
19    pub to: String,
20    pub dep_type: DepType,
21}
22
23/// Dependency graph for stratification analysis.
24#[derive(Debug, Default)]
25pub struct DependencyGraph {
26    /// Set of all predicate names in the graph.
27    pub predicates: HashSet<String>,
28    pub(crate) edges: Vec<DepEdge>,
29}
30
31impl DependencyGraph {
32    /// Create an empty dependency graph.
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Add a predicate node to the graph.
38    pub fn add_predicate(&mut self, name: String) {
39        self.predicates.insert(name);
40    }
41
42    pub(crate) fn add_edge(&mut self, from: String, to: String, dep_type: DepType) {
43        self.predicates.insert(from.clone());
44        self.predicates.insert(to.clone());
45        self.edges.push(DepEdge { from, to, dep_type });
46    }
47
48    pub(crate) fn outgoing(&self, pred: &str) -> Vec<&DepEdge> {
49        self.edges.iter().filter(|e| e.from == pred).collect()
50    }
51}
52
53/// Build dependency graph from program
54pub fn build_dependency_graph(program: &Program) -> DependencyGraph {
55    let mut graph = DependencyGraph::new();
56
57    for rule in &program.rules {
58        let head = &rule.head.predicate;
59        graph.add_predicate(head.clone());
60
61        for lit in &rule.body {
62            match lit {
63                BodyLiteral::Positive(atom) => {
64                    graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Positive);
65                }
66                BodyLiteral::Negated(atom) => {
67                    graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Negative);
68                }
69                BodyLiteral::Epistemic(lit) => {
70                    graph.add_edge(head.clone(), lit.atom.predicate.clone(), DepType::Negative);
71                }
72                BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => {}
73            }
74        }
75
76        if rule.has_aggregation() {
77            for lit in &rule.body {
78                if let BodyLiteral::Positive(atom) = lit {
79                    graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Aggregate);
80                }
81            }
82        }
83    }
84
85    // Learnable rules: head depends on body predicates.
86    // At runtime, TensorMaskedJoin dynamically selects which relations
87    // to join, but for stratification we conservatively register the
88    // template's body predicates as positive dependencies.
89    for lr in &program.learnable_rules {
90        let head = &lr.head.predicate;
91        graph.add_predicate(head.clone());
92        for body_lit in &lr.body {
93            if let Some(atom) = body_lit.atom() {
94                graph.add_predicate(atom.predicate.clone());
95                graph.add_edge(head.clone(), atom.predicate.clone(), DepType::Positive);
96            }
97        }
98    }
99
100    graph
101}
102
103/// Find strongly connected components using Tarjan's algorithm
104/// Returns SCCs in reverse topological order (dependencies first)
105fn find_sccs(graph: &DependencyGraph) -> Vec<Vec<String>> {
106    let mut index_counter = 0;
107    let mut stack = Vec::new();
108    let mut indices: HashMap<String, usize> = HashMap::new();
109    let mut lowlinks: HashMap<String, usize> = HashMap::new();
110    let mut on_stack: HashSet<String> = HashSet::new();
111    let mut sccs: Vec<Vec<String>> = Vec::new();
112
113    #[allow(clippy::too_many_arguments)]
114    fn strongconnect(
115        v: &str,
116        graph: &DependencyGraph,
117        index_counter: &mut usize,
118        stack: &mut Vec<String>,
119        indices: &mut HashMap<String, usize>,
120        lowlinks: &mut HashMap<String, usize>,
121        on_stack: &mut HashSet<String>,
122        sccs: &mut Vec<Vec<String>>,
123    ) {
124        indices.insert(v.to_string(), *index_counter);
125        lowlinks.insert(v.to_string(), *index_counter);
126        *index_counter += 1;
127        stack.push(v.to_string());
128        on_stack.insert(v.to_string());
129
130        for edge in graph.outgoing(v) {
131            let w = &edge.to;
132            if !indices.contains_key(w) {
133                strongconnect(
134                    w,
135                    graph,
136                    index_counter,
137                    stack,
138                    indices,
139                    lowlinks,
140                    on_stack,
141                    sccs,
142                );
143                let low_v = *lowlinks.get(v).unwrap();
144                let low_w = *lowlinks.get(w).unwrap();
145                lowlinks.insert(v.to_string(), low_v.min(low_w));
146            } else if on_stack.contains(w) {
147                let low_v = *lowlinks.get(v).unwrap();
148                let idx_w = *indices.get(w).unwrap();
149                lowlinks.insert(v.to_string(), low_v.min(idx_w));
150            }
151        }
152
153        let low_v = *lowlinks.get(v).unwrap();
154        let idx_v = *indices.get(v).unwrap();
155        if low_v == idx_v {
156            let mut scc = Vec::new();
157            loop {
158                let w = stack.pop().unwrap();
159                on_stack.remove(&w);
160                scc.push(w.clone());
161                if w == v {
162                    break;
163                }
164            }
165            sccs.push(scc);
166        }
167    }
168
169    for pred in &graph.predicates {
170        if !indices.contains_key(pred) {
171            strongconnect(
172                pred,
173                graph,
174                &mut index_counter,
175                &mut stack,
176                &mut indices,
177                &mut lowlinks,
178                &mut on_stack,
179                &mut sccs,
180            );
181        }
182    }
183
184    sccs
185}
186
187/// Check for cycles through negation/aggregation in an SCC
188fn check_scc_for_negation_cycle(scc: &[String], graph: &DependencyGraph) -> Option<Vec<String>> {
189    if scc.len() == 1 {
190        let pred = &scc[0];
191        for edge in graph.outgoing(pred) {
192            if edge.to == *pred && edge.dep_type != DepType::Positive {
193                return Some(vec![pred.clone()]);
194            }
195        }
196        return None;
197    }
198
199    let scc_set: HashSet<&str> = scc.iter().map(|s| s.as_str()).collect();
200    for pred in scc {
201        for edge in graph.outgoing(pred) {
202            if scc_set.contains(edge.to.as_str()) && edge.dep_type != DepType::Positive {
203                return Some(scc.to_vec());
204            }
205        }
206    }
207    None
208}
209
/// One stratum produced by [`stratify`].
#[derive(Debug, Clone)]
pub struct Stratum {
    /// Position of this stratum in evaluation order; 0 is evaluated first.
    pub id: usize,
    /// Names of the predicates evaluated in this stratum.
    pub predicates: Vec<String>,
}
218
/// Detailed stratification information used by probabilistic inference.
#[derive(Debug, Clone)]
pub struct StratificationResult {
    /// SCCs in evaluation order, dependencies first.
    pub sccs: Vec<Vec<String>>,
    /// Indices into `sccs` of the SCCs that contain a cycle through negation
    /// or aggregation (non-monotone SCCs).
    pub non_monotone_sccs: HashSet<usize>,
    /// Stratum number per predicate. As populated by `analyze_stratification`,
    /// only predicates of monotone SCCs receive an entry; predicates of
    /// non-monotone SCCs are absent.
    pub strata: HashMap<String, usize>,
}
229
230/// Perform stratification analysis
231pub fn stratify(program: &Program) -> Result<Vec<Stratum>> {
232    let graph = build_dependency_graph(program);
233    let sccs = find_sccs(&graph);
234
235    for scc in &sccs {
236        if let Some(cycle) = check_scc_for_negation_cycle(scc, &graph) {
237            // Non-monotone programs are now supported for exact_ddnnf via WFS
238            // Only reject for non-probabilistic programs that require stratification
239            if !program.is_probabilistic_profile() {
240                return Err(XlogError::StratificationCycle(cycle));
241            }
242            // Probabilistic programs (both exact_ddnnf and mc) can handle non-monotone cycles:
243            // - exact_ddnnf uses WFS to compute well-founded model
244            // - mc uses sampling which naturally handles non-stratified programs
245        }
246    }
247
248    let mut stratum_map: HashMap<String, usize> = HashMap::new();
249    let mut max_stratum = 0;
250
251    // Tarjan produces SCCs in reverse topological order.
252    // Since edges go from "dependent" to "dependency" (head -> body predicate),
253    // SCCs of dependencies come first. Process in order to ensure dependencies
254    // are assigned strata before dependents.
255    for scc in &sccs {
256        let mut min_stratum = 0;
257        for pred in scc {
258            for edge in graph.outgoing(pred) {
259                if let Some(&dep_stratum) = stratum_map.get(&edge.to) {
260                    let required = match edge.dep_type {
261                        DepType::Positive => dep_stratum,
262                        DepType::Negative | DepType::Aggregate => dep_stratum + 1,
263                    };
264                    min_stratum = min_stratum.max(required);
265                }
266            }
267        }
268        for pred in scc {
269            stratum_map.insert(pred.clone(), min_stratum);
270        }
271        max_stratum = max_stratum.max(min_stratum);
272    }
273
274    let mut strata: Vec<Stratum> = (0..=max_stratum)
275        .map(|id| Stratum {
276            id,
277            predicates: vec![],
278        })
279        .collect();
280
281    for (pred, stratum) in stratum_map {
282        strata[stratum].predicates.push(pred);
283    }
284
285    strata.retain(|s| !s.predicates.is_empty());
286    for (i, stratum) in strata.iter_mut().enumerate() {
287        stratum.id = i;
288    }
289
290    Ok(strata)
291}
292
293/// Analyze stratification for probabilistic inference
294/// Returns detailed information about SCCs and which ones are non-monotone
295pub fn analyze_stratification(program: &Program) -> StratificationResult {
296    let graph = build_dependency_graph(program);
297    let sccs = find_sccs(&graph);
298
299    let mut non_monotone_sccs: HashSet<usize> = HashSet::new();
300    for (i, scc) in sccs.iter().enumerate() {
301        if check_scc_for_negation_cycle(scc, &graph).is_some() {
302            non_monotone_sccs.insert(i);
303        }
304    }
305
306    // Compute strata for predicates in stratified SCCs
307    let mut strata: HashMap<String, usize> = HashMap::new();
308    let mut max_stratum = 0;
309
310    for (scc_idx, scc) in sccs.iter().enumerate() {
311        if non_monotone_sccs.contains(&scc_idx) {
312            continue; // Skip non-monotone SCCs for stratum assignment
313        }
314
315        let mut min_stratum = 0;
316        for pred in scc {
317            for edge in graph.outgoing(pred) {
318                if let Some(&dep_stratum) = strata.get(&edge.to) {
319                    let required = match edge.dep_type {
320                        DepType::Positive => dep_stratum,
321                        DepType::Negative | DepType::Aggregate => dep_stratum + 1,
322                    };
323                    min_stratum = min_stratum.max(required);
324                }
325            }
326        }
327        for pred in scc {
328            strata.insert(pred.clone(), min_stratum);
329        }
330        max_stratum = max_stratum.max(min_stratum);
331    }
332
333    StratificationResult {
334        sccs,
335        non_monotone_sccs,
336        strata,
337    }
338}
339
340/// Find SCCs for the lowering phase
341/// Returns SCCs in reverse topological order (dependencies first)
342pub fn find_sccs_for_lowering(graph: &DependencyGraph) -> Vec<Vec<String>> {
343    find_sccs(graph)
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::*;

    /// Variable term shorthand.
    fn var(name: &'static str) -> Term {
        Term::Variable(name.into())
    }

    /// Atom shorthand.
    fn atom(predicate: &'static str, terms: Vec<Term>) -> Atom {
        Atom {
            predicate: predicate.into(),
            terms,
        }
    }

    /// Rule shorthand; an empty body makes the rule a fact.
    fn rule(head: Atom, body: Vec<BodyLiteral>) -> Rule {
        Rule { head, body }
    }

    /// Id of the stratum containing `pred`; panics if it is not assigned.
    fn stratum_id(strata: &[Stratum], pred: &str) -> usize {
        strata
            .iter()
            .find(|s| s.predicates.iter().any(|p| p.as_str() == pred))
            .map(|s| s.id)
            .unwrap_or_else(|| panic!("predicate {} not assigned to any stratum", pred))
    }

    /// `edge(1, 2). reach(X, Y) :- edge(X, Y). reach(X, Z) :- reach(X, Y), edge(Y, Z).`
    fn create_tc_program() -> Program {
        let mut program = Program::new();
        program.rules.push(rule(
            atom("edge", vec![Term::Integer(1), Term::Integer(2)]),
            vec![],
        ));
        program.rules.push(rule(
            atom("reach", vec![var("X"), var("Y")]),
            vec![BodyLiteral::Positive(atom("edge", vec![var("X"), var("Y")]))],
        ));
        program.rules.push(rule(
            atom("reach", vec![var("X"), var("Z")]),
            vec![
                BodyLiteral::Positive(atom("reach", vec![var("X"), var("Y")])),
                BodyLiteral::Positive(atom("edge", vec![var("Y"), var("Z")])),
            ],
        ));
        program
    }

    /// `node(1..3). edge(1, 2). isolated(X) :- node(X), not edge(X, Y).`
    fn create_isolated_program() -> Program {
        let mut program = Program::new();
        for i in 1..=3 {
            program
                .rules
                .push(rule(atom("node", vec![Term::Integer(i)]), vec![]));
        }
        program.rules.push(rule(
            atom("edge", vec![Term::Integer(1), Term::Integer(2)]),
            vec![],
        ));
        program.rules.push(rule(
            atom("isolated", vec![var("X")]),
            vec![
                BodyLiteral::Positive(atom("node", vec![var("X")])),
                BodyLiteral::Negated(atom("edge", vec![var("X"), var("Y")])),
            ],
        ));
        program
    }

    /// `p :- not q. q :- not p.`
    fn create_unstratifiable_program() -> Program {
        let mut program = Program::new();
        program.rules.push(rule(
            atom("p", vec![]),
            vec![BodyLiteral::Negated(atom("q", vec![]))],
        ));
        program.rules.push(rule(
            atom("q", vec![]),
            vec![BodyLiteral::Negated(atom("p", vec![]))],
        ));
        program
    }

    #[test]
    fn test_stratify_simple() {
        let program = create_tc_program();
        let result = stratify(&program);
        assert!(result.is_ok(), "Stratification failed: {:?}", result.err());
        let strata = result.unwrap();
        // Only positive dependencies: everything fits in one stratum.
        assert_eq!(strata.len(), 1, "Expected a single stratum, got {:?}", strata);
        assert_eq!(stratum_id(&strata, "edge"), stratum_id(&strata, "reach"));
    }

    #[test]
    fn test_stratify_with_negation() {
        let program = create_isolated_program();
        let result = stratify(&program);
        assert!(result.is_ok(), "Stratification failed: {:?}", result.err());
        let strata = result.unwrap();
        assert!(
            strata.len() >= 2,
            "Expected at least 2 strata, got {}",
            strata.len()
        );
        assert!(
            stratum_id(&strata, "isolated") > stratum_id(&strata, "edge"),
            "isolated must be evaluated after edge: {:?}",
            strata
        );
        for (expected_id, stratum) in strata.iter().enumerate() {
            assert_eq!(stratum.id, expected_id, "stratum ids must be contiguous");
        }
    }

    #[test]
    fn test_stratify_cycle_through_negation() {
        let program = create_unstratifiable_program();
        match stratify(&program) {
            Err(XlogError::StratificationCycle(preds)) => {
                assert!(preds.contains(&"p".to_string()), "cycle {:?} lacks p", preds);
                assert!(preds.contains(&"q".to_string()), "cycle {:?} lacks q", preds);
            }
            other => panic!("Expected StratificationCycle error, got {:?}", other),
        }
    }

    #[test]
    fn test_stratify_probabilistic_non_monotone_allows_exact_ddnnf() {
        // Non-monotone programs are supported with exact_ddnnf via WFS.
        let mut program = create_unstratifiable_program();
        program.directives.prob_engine = Some(ProbEngine::ExactDdnnf);

        let result = stratify(&program);
        assert!(
            result.is_ok(),
            "Expected exact_ddnnf to allow non-monotone recursion (via WFS), got: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_stratify_probabilistic_non_monotone_allows_mc() {
        let mut program = create_unstratifiable_program();
        program.directives.prob_engine = Some(ProbEngine::Mc);

        let result = stratify(&program);
        assert!(
            result.is_ok(),
            "Expected mc to allow non-monotone recursion, got: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_dependency_graph_construction() {
        let program = create_tc_program();
        let graph = build_dependency_graph(&program);
        assert!(graph.predicates.contains("edge"));
        assert!(graph.predicates.contains("reach"));
        let reach_deps = graph.outgoing("reach");
        assert!(!reach_deps.is_empty());
        assert!(reach_deps
            .iter()
            .any(|e| e.to == "edge" && e.dep_type == DepType::Positive));
        assert!(reach_deps
            .iter()
            .any(|e| e.to == "reach" && e.dep_type == DepType::Positive));
    }

    #[test]
    fn test_dependency_graph_negative_edge() {
        let graph = build_dependency_graph(&create_isolated_program());
        let deps = graph.outgoing("isolated");
        assert!(deps
            .iter()
            .any(|e| e.to == "edge" && e.dep_type == DepType::Negative));
        assert!(deps
            .iter()
            .any(|e| e.to == "node" && e.dep_type == DepType::Positive));
    }

    #[test]
    fn test_find_sccs_for_lowering_dependencies_first() {
        let graph = build_dependency_graph(&create_tc_program());
        let sccs = find_sccs_for_lowering(&graph);
        let position = |pred: &str| {
            sccs.iter()
                .position(|scc| scc.iter().any(|p| p.as_str() == pred))
                .expect("predicate missing from SCCs")
        };
        assert!(
            position("edge") < position("reach"),
            "edge's SCC must precede reach's SCC: {:?}",
            sccs
        );
    }

    #[test]
    fn test_analyze_stratification_detects_non_monotone() {
        let program = create_unstratifiable_program(); // p :- not q. q :- not p.
        let result = analyze_stratification(&program);

        assert!(
            !result.non_monotone_sccs.is_empty(),
            "Should detect non-monotone SCC"
        );
        // The SCC containing p and q should be marked as non-monotone.
        let has_non_monotone = result.sccs.iter().enumerate().any(|(i, scc)| {
            result.non_monotone_sccs.contains(&i)
                && (scc.contains(&"p".to_string()) || scc.contains(&"q".to_string()))
        });
        assert!(has_non_monotone, "SCC with p/q should be non-monotone");
    }

    #[test]
    fn test_analyze_stratification_positive_recursion_is_monotone() {
        let result = analyze_stratification(&create_tc_program());
        assert!(
            result.non_monotone_sccs.is_empty(),
            "Positive recursion is monotone"
        );
        assert!(result.strata.contains_key("reach"));
        assert_eq!(result.strata.get("reach"), result.strata.get("edge"));
    }

    #[test]
    fn test_analyze_stratification_stratified_program() {
        let program = create_isolated_program(); // isolated(X) :- node(X), not edge(X, Y).
        let result = analyze_stratification(&program);

        assert!(
            result.non_monotone_sccs.is_empty(),
            "Stratified program has no non-monotone SCCs"
        );
        assert!(
            result.strata.contains_key("isolated"),
            "isolated should have a stratum"
        );
        assert!(
            result.strata.contains_key("edge"),
            "edge should have a stratum"
        );

        // isolated depends negatively on edge, so isolated.stratum > edge.stratum.
        let isolated_stratum = result.strata.get("isolated").unwrap();
        let edge_stratum = result.strata.get("edge").unwrap();
        assert!(
            isolated_stratum > edge_stratum,
            "isolated should be in higher stratum than edge"
        );
    }
}