1use xlog_core::{Result, XlogError};
4
/// Handle to a node of a [`PirGraph`]; the wrapped value is the node's position in the
/// graph's node table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PirNodeId(u32);

impl PirNodeId {
    /// Wraps a raw index as a node id. The index is not checked against any graph.
    pub fn from_u32(id: u32) -> Self {
        PirNodeId(id)
    }

    /// Returns the raw index wrapped by this id.
    pub fn as_u32(self) -> u32 {
        let PirNodeId(raw) = self;
        raw
    }
}
17
/// Identifier of the leaf referenced by `Lit` / `NegLit` nodes.
///
/// NOTE(review): presumably an index into an external table of base atoms — confirm with
/// the code that assigns leaf ids.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct LeafId(u32);

impl LeafId {
    /// Wraps a raw leaf number.
    pub fn new(id: u32) -> Self {
        LeafId(id)
    }

    /// Returns the raw leaf number.
    pub fn as_u32(self) -> u32 {
        let LeafId(raw) = self;
        raw
    }
}
30
/// Identifier of the choice variable a `Decision` node branches on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ChoiceVarId(u32);

impl ChoiceVarId {
    /// Wraps a raw choice-variable number.
    pub fn new(id: u32) -> Self {
        ChoiceVarId(id)
    }

    /// Returns the raw choice-variable number.
    pub fn as_u32(self) -> u32 {
        let ChoiceVarId(raw) = self;
        raw
    }
}
43
44#[derive(Debug, Clone, PartialEq)]
45pub enum PirNode {
46 Const(bool),
47 Lit {
48 leaf: LeafId,
49 },
50 NegLit {
51 leaf: LeafId,
52 }, And {
54 children: Vec<PirNodeId>,
55 },
56 Or {
57 children: Vec<PirNodeId>,
58 },
59 Decision {
60 var: ChoiceVarId,
61 child_false: PirNodeId,
62 child_true: PirNodeId,
63 },
64}
65
66#[derive(Debug, Default, Clone)]
67pub struct PirGraph {
68 nodes: Vec<PirNode>,
69}
70
71impl PirGraph {
72 pub fn new() -> Self {
73 Self { nodes: Vec::new() }
74 }
75
76 pub fn len(&self) -> usize {
77 self.nodes.len()
78 }
79
80 pub fn is_empty(&self) -> bool {
81 self.nodes.is_empty()
82 }
83
84 pub fn nodes(&self) -> &[PirNode] {
85 &self.nodes
86 }
87
88 pub(crate) fn nodes_mut(&mut self) -> &mut [PirNode] {
89 &mut self.nodes
90 }
91
92 pub fn node(&self, id: PirNodeId) -> Option<&PirNode> {
93 self.nodes.get(id.0 as usize)
94 }
95
96 fn push_node(&mut self, node: PirNode) -> PirNodeId {
97 let id = PirNodeId(u32::try_from(self.nodes.len()).expect("PIR node id overflow"));
98 self.nodes.push(node);
99 id
100 }
101
102 pub fn const_true(&mut self) -> PirNodeId {
103 self.push_node(PirNode::Const(true))
104 }
105
106 pub fn const_false(&mut self) -> PirNodeId {
107 self.push_node(PirNode::Const(false))
108 }
109
110 pub fn lit(&mut self, leaf: LeafId) -> PirNodeId {
111 self.push_node(PirNode::Lit { leaf })
112 }
113
114 pub fn neg_lit(&mut self, leaf: LeafId) -> PirNodeId {
115 self.push_node(PirNode::NegLit { leaf })
116 }
117
118 pub fn and(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
119 self.push_node(PirNode::And { children })
120 }
121
122 pub fn or(&mut self, children: Vec<PirNodeId>) -> PirNodeId {
123 self.push_node(PirNode::Or { children })
124 }
125
126 pub fn decision(
127 &mut self,
128 var: ChoiceVarId,
129 child_false: PirNodeId,
130 child_true: PirNodeId,
131 ) -> PirNodeId {
132 self.push_node(PirNode::Decision {
133 var,
134 child_false,
135 child_true,
136 })
137 }
138
139 pub fn levelize(&self, roots: &[PirNodeId]) -> Result<Vec<Vec<PirNodeId>>> {
140 use std::collections::{HashMap, HashSet};
141
142 let mut visiting: HashSet<PirNodeId> = HashSet::new();
143 let mut levels: HashMap<PirNodeId, u32> = HashMap::new();
144
145 fn compute_level(
146 graph: &PirGraph,
147 id: PirNodeId,
148 visiting: &mut HashSet<PirNodeId>,
149 levels: &mut HashMap<PirNodeId, u32>,
150 ) -> Result<u32> {
151 if let Some(&lvl) = levels.get(&id) {
152 return Ok(lvl);
153 }
154 if !visiting.insert(id) {
155 return Err(XlogError::Compilation(format!(
156 "Cycle detected while levelizing PIR at node {:?}",
157 id
158 )));
159 }
160
161 let node = graph.node(id).ok_or_else(|| {
162 XlogError::Compilation(format!("Invalid PIR node id while levelizing: {:?}", id))
163 })?;
164
165 let lvl = match node {
166 PirNode::Const(_) | PirNode::Lit { .. } | PirNode::NegLit { .. } => 0,
167 PirNode::And { children } | PirNode::Or { children } => {
168 let mut max_child = 0u32;
169 for &child in children {
170 let child_lvl = compute_level(graph, child, visiting, levels)?;
171 max_child = max_child.max(child_lvl);
172 }
173 max_child + 1
174 }
175 PirNode::Decision {
176 child_false,
177 child_true,
178 ..
179 } => {
180 let lf = compute_level(graph, *child_false, visiting, levels)?;
181 let lt = compute_level(graph, *child_true, visiting, levels)?;
182 lf.max(lt) + 1
183 }
184 };
185
186 visiting.remove(&id);
187 levels.insert(id, lvl);
188 Ok(lvl)
189 }
190
191 for &root in roots {
192 compute_level(self, root, &mut visiting, &mut levels)?;
193 }
194
195 let max_level = levels.values().copied().max().unwrap_or(0);
196 let mut buckets: Vec<Vec<PirNodeId>> = vec![Vec::new(); (max_level as usize) + 1];
197
198 let mut ids: Vec<PirNodeId> = levels.keys().copied().collect();
199 ids.sort_by_key(|id| id.0);
200 for id in ids {
201 let lvl = levels[&id] as usize;
202 buckets[lvl].push(id);
203 }
204
205 Ok(buckets)
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_levelize_simple_dag() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId(0));
        let b = pir.lit(LeafId(1));
        let and_ab = pir.and(vec![a, b]);
        let root = pir.or(vec![and_ab, a]);

        let levels = pir.levelize(&[root]).unwrap();

        // Leaves at level 0, each gate one above its deepest child; the shared leaf `a`
        // must appear exactly once and buckets are sorted by id.
        assert_eq!(levels, vec![vec![a, b], vec![and_ab], vec![root]]);
    }

    #[test]
    fn test_levelize_decision_node() {
        let mut pir = PirGraph::new();
        let t = pir.const_true();
        let f = pir.const_false();
        let d = pir.decision(ChoiceVarId(0), f, t);

        let levels = pir.levelize(&[d]).unwrap();

        assert_eq!(levels, vec![vec![t, f], vec![d]]);
    }

    #[test]
    fn test_levelize_skips_unreachable_nodes() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId(0));
        let _unreachable = pir.neg_lit(LeafId(1));
        let root = pir.and(vec![a]);

        assert_eq!(pir.levelize(&[root]).unwrap(), vec![vec![a], vec![root]]);
    }

    #[test]
    fn test_levelize_childless_gate_is_level_one() {
        let mut pir = PirGraph::new();
        let empty_or = pir.or(vec![]);

        assert_eq!(pir.levelize(&[empty_or]).unwrap(), vec![vec![], vec![empty_or]]);
    }

    #[test]
    fn test_levelize_rejects_invalid_id() {
        let pir = PirGraph::new();

        let result = pir.levelize(&[PirNodeId::from_u32(7)]);

        assert!(matches!(result, Err(XlogError::Compilation(_))));
    }

    #[test]
    fn test_levelize_rejects_cycle() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId(0));
        let gate = pir.and(vec![a]);
        // Rewire the gate so that it lists itself as a child.
        pir.nodes_mut()[gate.as_u32() as usize] = PirNode::And {
            children: vec![a, gate],
        };

        let result = pir.levelize(&[gate]);

        assert!(matches!(result, Err(XlogError::Compilation(_))));
    }
}