1use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
4
5use xlog_core::{Result, XlogError};
6
/// Kind of a node in a [`DecisionDnnf`] circuit.
///
/// In the textual format each kind is introduced by a one-letter record tag:
/// `o` → [`Or`](DdnnfNodeKind::Or), `a` → [`And`](DdnnfNodeKind::And),
/// `t` → [`True`](DdnnfNodeKind::True), `f` → [`False`](DdnnfNodeKind::False).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DdnnfNodeKind {
    /// Disjunction (`o`): its value is the log-sum-exp of its outgoing edge values.
    Or,
    /// Conjunction (`a`): its value is the sum (in log space) of its outgoing edge values.
    And,
    /// Constant-true leaf (`t`); must not have outgoing edges.
    True,
    /// Constant-false leaf (`f`); must not have outgoing edges.
    False,
}
14
15#[derive(Debug, Clone)]
16pub struct DdnnfNode {
17 pub kind: DdnnfNodeKind,
18}
19
/// A directed, literal-labelled edge of a [`DecisionDnnf`].
///
/// Parsed from a record `<from> <to> [lit ...] 0`.
#[derive(Clone, Debug)]
pub struct DdnnfEdge {
    /// Id of the parent node (must be an OR or AND node).
    pub from: u32,
    /// Id of the child node.
    pub to: u32,
    /// Non-zero literals labelling the edge; the sign selects which of the variable's two
    /// weights is used during evaluation (positive → first, negative → second).
    pub lits: Vec<i32>,
}
26
27#[derive(Debug, Clone)]
28pub struct DecisionDnnf {
29 root: u32,
30 nodes: BTreeMap<u32, DdnnfNode>,
31 edges: Vec<DdnnfEdge>,
32 outgoing: BTreeMap<u32, Vec<usize>>,
33 max_var: u32,
34}
35
36impl DecisionDnnf {
37 pub fn root(&self) -> u32 {
38 self.root
39 }
40
41 pub fn max_var(&self) -> u32 {
42 self.max_var
43 }
44
45 pub fn node_kind(&self, node_id: u32) -> Option<DdnnfNodeKind> {
46 self.nodes.get(&node_id).map(|n| n.kind)
47 }
48
49 pub fn outgoing_edge_indices(&self, node_id: u32) -> Option<&[usize]> {
50 self.outgoing.get(&node_id).map(|v| v.as_slice())
51 }
52
53 pub fn edge(&self, edge_idx: usize) -> Option<&DdnnfEdge> {
54 self.edges.get(edge_idx)
55 }
56
57 pub fn parse_str(input: &str) -> Result<Self> {
58 let mut nodes: BTreeMap<u32, DdnnfNode> = BTreeMap::new();
59 let mut edges: Vec<DdnnfEdge> = Vec::new();
60 let mut targets: HashSet<u32> = HashSet::new();
61 let mut max_var: u32 = 0;
62
63 for (line_no, raw_line) in input.lines().enumerate() {
64 let line = raw_line.trim();
65 if line.is_empty() {
66 continue;
67 }
68
69 let mut tokens: Vec<&str> = line.split_whitespace().collect();
70 if tokens.is_empty() {
71 continue;
72 }
73
74 if tokens.last() != Some(&"0") {
75 return Err(XlogError::Compilation(format!(
76 "Decision-DNNF parse error at line {}: missing 0 terminator",
77 line_no + 1
78 )));
79 }
80 tokens.pop();
81 if tokens.is_empty() {
82 return Err(XlogError::Compilation(format!(
83 "Decision-DNNF parse error at line {}: empty record before terminator",
84 line_no + 1
85 )));
86 }
87
88 match tokens[0] {
89 "o" | "a" | "t" | "f" => {
90 if tokens.len() < 2 {
91 return Err(XlogError::Compilation(format!(
92 "Decision-DNNF parse error at line {}: node record missing id",
93 line_no + 1
94 )));
95 }
96 let id: u32 = tokens[1].parse().map_err(|_| {
97 XlogError::Compilation(format!(
98 "Decision-DNNF parse error at line {}: invalid node id '{}'",
99 line_no + 1,
100 tokens[1]
101 ))
102 })?;
103
104 let kind = match tokens[0] {
105 "o" => DdnnfNodeKind::Or,
106 "a" => DdnnfNodeKind::And,
107 "t" => DdnnfNodeKind::True,
108 "f" => DdnnfNodeKind::False,
109 _ => unreachable!(),
110 };
111
112 if nodes.insert(id, DdnnfNode { kind }).is_some() {
113 return Err(XlogError::Compilation(format!(
114 "Decision-DNNF parse error at line {}: duplicate node id {}",
115 line_no + 1,
116 id
117 )));
118 }
119 }
120 _ => {
121 if tokens.len() < 2 {
122 return Err(XlogError::Compilation(format!(
123 "Decision-DNNF parse error at line {}: edge record missing dst",
124 line_no + 1
125 )));
126 }
127 let from: u32 = tokens[0].parse().map_err(|_| {
128 XlogError::Compilation(format!(
129 "Decision-DNNF parse error at line {}: invalid edge src '{}'",
130 line_no + 1,
131 tokens[0]
132 ))
133 })?;
134 let to: u32 = tokens[1].parse().map_err(|_| {
135 XlogError::Compilation(format!(
136 "Decision-DNNF parse error at line {}: invalid edge dst '{}'",
137 line_no + 1,
138 tokens[1]
139 ))
140 })?;
141
142 let mut lits: Vec<i32> = Vec::new();
143 for &tok in &tokens[2..] {
144 let lit: i32 = tok.parse().map_err(|_| {
145 XlogError::Compilation(format!(
146 "Decision-DNNF parse error at line {}: invalid literal '{}'",
147 line_no + 1,
148 tok
149 ))
150 })?;
151 if lit == 0 {
152 return Err(XlogError::Compilation(format!(
153 "Decision-DNNF parse error at line {}: literal cannot be 0",
154 line_no + 1
155 )));
156 }
157 max_var = max_var.max(lit.unsigned_abs());
158 lits.push(lit);
159 }
160
161 let edge_id = edges.len();
162 edges.push(DdnnfEdge { from, to, lits });
163 targets.insert(to);
164
165 let _ = edge_id;
167 }
168 }
169 }
170
171 if nodes.is_empty() {
172 return Err(XlogError::Compilation(
173 "Decision-DNNF parse error: no nodes found".to_string(),
174 ));
175 }
176
177 for edge in &edges {
178 let from_kind = nodes.get(&edge.from).ok_or_else(|| {
179 XlogError::Compilation(format!(
180 "Decision-DNNF parse error: edge references unknown src node {}",
181 edge.from
182 ))
183 })?;
184 let _to_kind = nodes.get(&edge.to).ok_or_else(|| {
185 XlogError::Compilation(format!(
186 "Decision-DNNF parse error: edge references unknown dst node {}",
187 edge.to
188 ))
189 })?;
190
191 match from_kind.kind {
192 DdnnfNodeKind::Or | DdnnfNodeKind::And => {}
193 DdnnfNodeKind::True | DdnnfNodeKind::False => {
194 return Err(XlogError::Compilation(format!(
195 "Decision-DNNF parse error: leaf node {} cannot have outgoing edges",
196 edge.from
197 )));
198 }
199 }
200 }
201
202 let declared: BTreeSet<u32> = nodes.keys().copied().collect();
203 let target_set: BTreeSet<u32> = targets.into_iter().collect();
204 let roots: Vec<u32> = declared.difference(&target_set).copied().collect();
205 let root = match roots.as_slice() {
206 [only] => *only,
207 [] => {
208 return Err(XlogError::Compilation(
209 "Decision-DNNF parse error: could not infer root (no root candidates)"
210 .to_string(),
211 ))
212 }
213 many => {
214 return Err(XlogError::Compilation(format!(
215 "Decision-DNNF parse error: could not infer unique root (candidates: {:?})",
216 many
217 )))
218 }
219 };
220
221 let mut outgoing: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
222 for (idx, edge) in edges.iter().enumerate() {
223 outgoing.entry(edge.from).or_default().push(idx);
224 }
225
226 Self::check_acyclic(root, &nodes, &edges, &outgoing)?;
228
229 Ok(Self {
230 root,
231 nodes,
232 edges,
233 outgoing,
234 max_var,
235 })
236 }
237
238 fn check_acyclic(
239 root: u32,
240 nodes: &BTreeMap<u32, DdnnfNode>,
241 edges: &[DdnnfEdge],
242 outgoing: &BTreeMap<u32, Vec<usize>>,
243 ) -> Result<()> {
244 let mut visiting: HashSet<u32> = HashSet::new();
245 let mut visited: HashSet<u32> = HashSet::new();
246
247 fn dfs(
248 node_id: u32,
249 nodes: &BTreeMap<u32, DdnnfNode>,
250 edges: &[DdnnfEdge],
251 outgoing: &BTreeMap<u32, Vec<usize>>,
252 visiting: &mut HashSet<u32>,
253 visited: &mut HashSet<u32>,
254 ) -> Result<()> {
255 if visited.contains(&node_id) {
256 return Ok(());
257 }
258 if !visiting.insert(node_id) {
259 return Err(XlogError::Compilation(format!(
260 "Decision-DNNF parse error: cycle detected at node {}",
261 node_id
262 )));
263 }
264
265 let node = nodes.get(&node_id).ok_or_else(|| {
266 XlogError::Compilation(format!(
267 "Decision-DNNF parse error: unknown node {} during cycle check",
268 node_id
269 ))
270 })?;
271
272 match node.kind {
273 DdnnfNodeKind::True | DdnnfNodeKind::False => {}
274 DdnnfNodeKind::Or | DdnnfNodeKind::And => {
275 if let Some(out) = outgoing.get(&node_id) {
276 for &edge_idx in out {
277 let edge = &edges[edge_idx];
278 dfs(edge.to, nodes, edges, outgoing, visiting, visited)?;
279 }
280 }
281 }
282 }
283
284 visiting.remove(&node_id);
285 visited.insert(node_id);
286 Ok(())
287 }
288
289 dfs(root, nodes, edges, outgoing, &mut visiting, &mut visited)
290 }
291
292 pub fn eval_log_wmc<F>(&self, var_log_weights: F) -> Result<f64>
293 where
294 F: Fn(u32) -> (f64, f64),
295 {
296 let mut memo: HashMap<u32, f64> = HashMap::new();
297
298 fn logsumexp(values: &[f64]) -> f64 {
299 let mut max = f64::NEG_INFINITY;
300 for &v in values {
301 if v > max {
302 max = v;
303 }
304 }
305 if max.is_infinite() {
306 return max;
307 }
308 let mut sum = 0.0;
309 for &v in values {
310 sum += (v - max).exp();
311 }
312 max + sum.ln()
313 }
314
315 fn eval_node<F>(
316 node_id: u32,
317 ddnnf: &DecisionDnnf,
318 memo: &mut HashMap<u32, f64>,
319 var_log_weights: &F,
320 ) -> Result<f64>
321 where
322 F: Fn(u32) -> (f64, f64),
323 {
324 if let Some(&v) = memo.get(&node_id) {
325 return Ok(v);
326 }
327
328 let node = ddnnf.nodes.get(&node_id).ok_or_else(|| {
329 XlogError::Compilation(format!(
330 "Decision-DNNF eval error: unknown node {}",
331 node_id
332 ))
333 })?;
334
335 let value = match node.kind {
336 DdnnfNodeKind::True => 0.0,
337 DdnnfNodeKind::False => f64::NEG_INFINITY,
338 DdnnfNodeKind::And => {
339 let out = ddnnf.outgoing.get(&node_id).ok_or_else(|| {
340 XlogError::Compilation(format!(
341 "Decision-DNNF eval error: AND node {} has no children",
342 node_id
343 ))
344 })?;
345
346 let mut acc = 0.0;
347 for &edge_idx in out {
348 let edge = &ddnnf.edges[edge_idx];
349 let child = eval_node(edge.to, ddnnf, memo, var_log_weights)?;
350 let mut lit_sum = 0.0;
351 for &lit in &edge.lits {
352 let var = lit.unsigned_abs();
353 let (t, f) = var_log_weights(var);
354 lit_sum += if lit > 0 { t } else { f };
355 }
356 acc += lit_sum + child;
357 }
358 acc
359 }
360 DdnnfNodeKind::Or => {
361 let out = ddnnf.outgoing.get(&node_id).ok_or_else(|| {
362 XlogError::Compilation(format!(
363 "Decision-DNNF eval error: OR node {} has no children",
364 node_id
365 ))
366 })?;
367
368 let mut branch_vals: Vec<f64> = Vec::with_capacity(out.len());
369 for &edge_idx in out {
370 let edge = &ddnnf.edges[edge_idx];
371 let child = eval_node(edge.to, ddnnf, memo, var_log_weights)?;
372 let mut lit_sum = 0.0;
373 for &lit in &edge.lits {
374 let var = lit.unsigned_abs();
375 let (t, f) = var_log_weights(var);
376 lit_sum += if lit > 0 { t } else { f };
377 }
378 branch_vals.push(lit_sum + child);
379 }
380 logsumexp(&branch_vals)
381 }
382 };
383
384 memo.insert(node_id, value);
385 Ok(value)
386 }
387
388 eval_node(self.root, self, &mut memo, &var_log_weights)
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, expecting failure, and returns the error's display text.
    fn parse_err(input: &str) -> String {
        DecisionDnnf::parse_str(input).unwrap_err().to_string()
    }

    #[test]
    fn test_parse_and_eval_identity_variable() {
        let nnf = r#"
o 1 0
t 2 0
f 3 0
1 2 1 0
1 3 -1 0
"#;

        let ddnnf = DecisionDnnf::parse_str(nnf).unwrap();
        assert_eq!(ddnnf.root(), 1);
        assert_eq!(ddnnf.max_var(), 1);

        let p = 0.3_f64;
        let log_wmc = ddnnf
            .eval_log_wmc(|var| match var {
                1 => (p.ln(), (1.0 - p).ln()),
                _ => panic!("unexpected var {}", var),
            })
            .unwrap();

        assert!((log_wmc - p.ln()).abs() < 1e-9, "log_wmc={}", log_wmc);
    }

    #[test]
    fn test_eval_and_node_adds_edge_log_weights() {
        let nnf = "a 1 0\nt 2 0\n1 2 1 0\n1 2 2 0\n";
        let ddnnf = DecisionDnnf::parse_str(nnf).unwrap();
        assert_eq!(ddnnf.root(), 1);
        assert_eq!(ddnnf.max_var(), 2);
        assert_eq!(ddnnf.node_kind(1), Some(DdnnfNodeKind::And));
        assert_eq!(ddnnf.outgoing_edge_indices(1), Some(&[0, 1][..]));
        assert!(ddnnf.outgoing_edge_indices(2).is_none());
        assert_eq!(ddnnf.edge(1).map(|e| e.lits.clone()), Some(vec![2]));

        let (p1, p2) = (0.3_f64, 0.6_f64);
        let log_wmc = ddnnf
            .eval_log_wmc(|var| match var {
                1 => (p1.ln(), (1.0 - p1).ln()),
                2 => (p2.ln(), (1.0 - p2).ln()),
                _ => panic!("unexpected var {}", var),
            })
            .unwrap();

        assert!((log_wmc - (p1 * p2).ln()).abs() < 1e-9, "log_wmc={}", log_wmc);
    }

    #[test]
    fn test_parse_detects_missing_terminator() {
        let nnf = "t 1";
        let err = DecisionDnnf::parse_str(nnf).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("terminator"), "msg={}", msg);
    }

    #[test]
    fn test_parse_rejects_duplicate_node_id() {
        let msg = parse_err("t 1 0\nt 1 0\n");
        assert!(msg.contains("duplicate node id 1"), "msg={}", msg);
    }

    #[test]
    fn test_parse_rejects_zero_literal() {
        let msg = parse_err("o 1 0\nt 2 0\n1 2 0 0\n");
        assert!(msg.contains("literal cannot be 0"), "msg={}", msg);
    }

    #[test]
    fn test_parse_rejects_unknown_dst() {
        let msg = parse_err("o 1 0\n1 2 0\n");
        assert!(msg.contains("unknown dst node 2"), "msg={}", msg);
    }

    #[test]
    fn test_parse_rejects_leaf_with_outgoing_edge() {
        let msg = parse_err("t 1 0\nt 2 0\n1 2 0\n");
        assert!(
            msg.contains("leaf node 1 cannot have outgoing edges"),
            "msg={}",
            msg
        );
    }

    #[test]
    fn test_parse_rejects_ambiguous_root() {
        let msg = parse_err("t 1 0\nt 2 0\n");
        assert!(msg.contains("could not infer unique root"), "msg={}", msg);
    }

    #[test]
    fn test_parse_rejects_cycle() {
        let msg = parse_err("o 1 0\no 2 0\no 3 0\n1 2 0\n2 3 0\n3 2 0\n");
        assert!(msg.contains("cycle detected at node 2"), "msg={}", msg);
    }

    #[test]
    fn test_parse_rejects_empty_input() {
        let msg = parse_err("\n   \n");
        assert!(msg.contains("no nodes found"), "msg={}", msg);
    }
}