Skip to main content

xlog_prob/
pir.rs

1//! Provenance IR (PIR) for probabilistic compilation.
2
3use xlog_core::{Result, XlogError};
4
/// Index of a node inside a `PirGraph`.
///
/// The graph hands these out when nodes are added; `from_u32` rebuilds one from a raw
/// index and `as_u32` recovers that index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PirNodeId(u32);

impl PirNodeId {
    /// Wraps a raw node index without checking it against any graph.
    pub fn from_u32(id: u32) -> Self {
        PirNodeId(id)
    }

    /// Returns the raw node index.
    pub fn as_u32(self) -> u32 {
        let PirNodeId(raw) = self;
        raw
    }
}
17
/// Identifier of a leaf referenced by `Lit` and `NegLit` nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct LeafId(u32);

impl LeafId {
    /// Wraps a raw leaf index.
    pub fn new(id: u32) -> Self {
        LeafId(id)
    }

    /// Returns the raw leaf index.
    pub fn as_u32(self) -> u32 {
        let LeafId(raw) = self;
        raw
    }
}
30
/// Identifier of the choice variable tested by a `Decision` node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ChoiceVarId(u32);

impl ChoiceVarId {
    /// Wraps a raw choice-variable index.
    pub fn new(id: u32) -> Self {
        ChoiceVarId(id)
    }

    /// Returns the raw choice-variable index.
    pub fn as_u32(self) -> u32 {
        match self {
            ChoiceVarId(raw) => raw,
        }
    }
}
43
44#[derive(Debug, Clone, PartialEq)]
45pub enum PirNode {
46    Const(bool),
47    Lit {
48        leaf: LeafId,
49    },
50    NegLit {
51        leaf: LeafId,
52    }, // Negated leaf: weight (1-p, p)
53    And {
54        children: Vec<PirNodeId>,
55    },
56    Or {
57        children: Vec<PirNodeId>,
58    },
59    Decision {
60        var: ChoiceVarId,
61        child_false: PirNodeId,
62        child_true: PirNodeId,
63    },
64}
65
66#[derive(Debug, Default, Clone)]
67pub struct PirGraph {
68    nodes: Vec<PirNode>,
69}
70
71impl PirGraph {
72    pub fn new() -> Self {
73        Self { nodes: Vec::new() }
74    }
75
76    pub fn len(&self) -> usize {
77        self.nodes.len()
78    }
79
80    pub fn is_empty(&self) -> bool {
81        self.nodes.is_empty()
82    }
83
84    pub fn nodes(&self) -> &[PirNode] {
85        &self.nodes
86    }
87
88    pub(crate) fn nodes_mut(&mut self) -> &mut [PirNode] {
89        &mut self.nodes
90    }
91
92    pub fn node(&self, id: PirNodeId) -> Option<&PirNode> {
93        self.nodes.get(id.0 as usize)
94    }
95
96    fn push_node(&mut self, node: PirNode) -> PirNodeId {
97        let id = PirNodeId(u32::try_from(self.nodes.len()).expect("PIR node id overflow"));
98        self.nodes.push(node);
99        id
100    }
101
102    pub fn const_true(&mut self) -> PirNodeId {
103        self.push_node(PirNode::Const(true))
104    }
105
106    pub fn const_false(&mut self) -> PirNodeId {
107        self.push_node(PirNode::Const(false))
108    }
109
110    pub fn lit(&mut self, leaf: LeafId) -> PirNodeId {
111        self.push_node(PirNode::Lit { leaf })
112    }
113
114    pub fn neg_lit(&mut self, leaf: LeafId) -> PirNodeId {
115        self.push_node(PirNode::NegLit { leaf })
116    }
117
118    pub fn and(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
119        self.push_node(PirNode::And { children })
120    }
121
122    pub fn or(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
123        self.push_node(PirNode::Or { children })
124    }
125
126    pub fn decision(
127        &mut self,
128        var: ChoiceVarId,
129        child_false: PirNodeId,
130        child_true: PirNodeId,
131    ) -> PirNodeId {
132        self.push_node(PirNode::Decision {
133            var,
134            child_false,
135            child_true,
136        })
137    }
138
139    pub fn levelize(&self, roots: &[PirNodeId]) -> Result<Vec<Vec<PirNodeId>>> {
140        use std::collections::{HashMap, HashSet};
141
142        let mut visiting: HashSet<PirNodeId> = HashSet::new();
143        let mut levels: HashMap<PirNodeId, u32> = HashMap::new();
144
145        fn compute_level(
146            graph: &PirGraph,
147            id: PirNodeId,
148            visiting: &mut HashSet<PirNodeId>,
149            levels: &mut HashMap<PirNodeId, u32>,
150        ) -> Result<u32> {
151            if let Some(&lvl) = levels.get(&id) {
152                return Ok(lvl);
153            }
154            if !visiting.insert(id) {
155                return Err(XlogError::Compilation(format!(
156                    "Cycle detected while levelizing PIR at node {:?}",
157                    id
158                )));
159            }
160
161            let node = graph.node(id).ok_or_else(|| {
162                XlogError::Compilation(format!("Invalid PIR node id while levelizing: {:?}", id))
163            })?;
164
165            let lvl = match node {
166                PirNode::Const(_) | PirNode::Lit { .. } | PirNode::NegLit { .. } => 0,
167                PirNode::And { children } | PirNode::Or { children } => {
168                    let mut max_child = 0u32;
169                    for &child in children {
170                        let child_lvl = compute_level(graph, child, visiting, levels)?;
171                        max_child = max_child.max(child_lvl);
172                    }
173                    max_child + 1
174                }
175                PirNode::Decision {
176                    child_false,
177                    child_true,
178                    ..
179                } => {
180                    let lf = compute_level(graph, *child_false, visiting, levels)?;
181                    let lt = compute_level(graph, *child_true, visiting, levels)?;
182                    lf.max(lt) + 1
183                }
184            };
185
186            visiting.remove(&id);
187            levels.insert(id, lvl);
188            Ok(lvl)
189        }
190
191        for &root in roots {
192            compute_level(self, root, &mut visiting, &mut levels)?;
193        }
194
195        let max_level = levels.values().copied().max().unwrap_or(0);
196        let mut buckets: Vec<Vec<PirNodeId>> = vec![Vec::new(); (max_level as usize) + 1];
197
198        let mut ids: Vec<PirNodeId> = levels.keys().copied().collect();
199        ids.sort_by_key(|id| id.0);
200        for id in ids {
201            let lvl = levels[&id] as usize;
202            buckets[lvl].push(id);
203        }
204
205        Ok(buckets)
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_levelize_simple_dag() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));
        let b = pir.lit(LeafId::new(1));
        let and_ab = pir.and(vec![a, b]);
        let root = pir.or(vec![and_ab, a]);

        let levels = pir.levelize(&[root]).unwrap();

        // Leaves at 0, `and_ab` one above them, `root` one above its highest child.
        assert_eq!(levels, vec![vec![a, b], vec![and_ab], vec![root]]);
    }

    #[test]
    fn test_levelize_decision_node() {
        let mut pir = PirGraph::new();
        let t = pir.const_true();
        let f = pir.const_false();
        let d = pir.decision(ChoiceVarId::new(0), f, t);

        let levels = pir.levelize(&[d]).unwrap();
        assert_eq!(levels, vec![vec![t, f], vec![d]]);
    }

    #[test]
    fn test_levelize_shared_and_unreachable_nodes() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));
        let orphan = pir.neg_lit(LeafId::new(1));
        let left = pir.and(vec![a]);
        let right = pir.or(vec![a]);
        let root = pir.and(vec![left, right]);

        // `a` is shared and `left` is also a root; each must still appear exactly once.
        let levels = pir.levelize(&[root, left]).unwrap();
        assert_eq!(levels, vec![vec![a], vec![left, right], vec![root]]);
        assert!(levels.iter().all(|lvl| !lvl.contains(&orphan)));
    }

    #[test]
    fn test_levelize_empty_and_is_level_one() {
        let mut pir = PirGraph::new();
        let empty = pir.and(vec![]);

        let levels = pir.levelize(&[empty]).unwrap();
        assert_eq!(levels, vec![vec![], vec![empty]]);
    }

    #[test]
    fn test_levelize_detects_cycle() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));
        let n = pir.and(vec![a]);
        // Only reachable through `nodes_mut`: make `n` its own child.
        pir.nodes_mut()[n.as_u32() as usize] = PirNode::And {
            children: vec![a, n],
        };

        let err = pir.levelize(&[n]).unwrap_err();
        assert!(matches!(&err, XlogError::Compilation(msg) if msg.contains("Cycle detected")));
    }

    #[test]
    fn test_levelize_rejects_dangling_id() {
        let pir = PirGraph::new();

        let err = pir.levelize(&[PirNodeId::from_u32(0)]).unwrap_err();
        assert!(
            matches!(&err, XlogError::Compilation(msg) if msg.contains("Invalid PIR node id"))
        );
    }
}