1use std::collections::HashMap;
4
5use xlog_core::{Result, XlogError};
6
7use crate::kc::ddnnf::{DdnnfEdge, DdnnfNodeKind, DecisionDnnf};
8
/// Kind tag of a single XGCF node, stored as one byte.
///
/// The discriminant values are fixed explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum XgcfNodeType {
    /// Constant false; evaluates to `-inf` in log space.
    Const0 = 0,
    /// Constant true; evaluates to `0.0` in log space.
    Const1 = 1,
    /// Signed literal leaf (see `Xgcf::lit`).
    Lit = 2,
    /// Conjunction; child values are added in log space.
    And = 3,
    /// Disjunction; child values are combined with log-sum-exp.
    Or = 4,
    /// Decision on a variable with a false child and a true child.
    Decision = 5,
}
19
20#[derive(Debug, Clone)]
21pub struct Xgcf {
22 pub node_type: Vec<XgcfNodeType>,
23 pub child_offsets: Vec<u32>,
24 pub child_indices: Vec<u32>,
25 pub lit: Vec<i32>,
26 pub decision_var: Vec<u32>,
27 pub decision_child_false: Vec<u32>,
28 pub decision_child_true: Vec<u32>,
29 pub roots: Vec<u32>,
30 pub level_offsets: Vec<u32>,
31 pub level_nodes: Vec<u32>,
32}
33
34impl Xgcf {
35 pub fn from_ddnnf(ddnnf: &DecisionDnnf) -> Result<Self> {
36 XgcfBuilder::new(ddnnf).build()
37 }
38
39 pub fn smooth_random_vars(&self, is_random_var: &[bool]) -> Result<Self> {
47 XgcfSmoother::new(self, is_random_var)?.smooth()
48 }
49
50 pub fn eval_log_wmc<F>(&self, var_log_weights: F) -> Result<f64>
51 where
52 F: Fn(u32) -> (f64, f64),
53 {
54 if self.roots.len() != 1 {
55 return Err(XlogError::Compilation(format!(
56 "XGCF eval expects exactly 1 root, got {}",
57 self.roots.len()
58 )));
59 }
60
61 let n = self.node_type.len();
62 if self.child_offsets.len() != n + 1 {
63 return Err(XlogError::Compilation(format!(
64 "XGCF invariant violation: child_offsets len {} != num_nodes+1 ({})",
65 self.child_offsets.len(),
66 n + 1
67 )));
68 }
69 if self.lit.len() != n
70 || self.decision_var.len() != n
71 || self.decision_child_false.len() != n
72 || self.decision_child_true.len() != n
73 {
74 return Err(XlogError::Compilation(
75 "XGCF invariant violation: per-node arrays length mismatch".to_string(),
76 ));
77 }
78 if self.level_offsets.is_empty() || *self.level_offsets.first().unwrap() != 0 {
79 return Err(XlogError::Compilation(
80 "XGCF invariant violation: level_offsets must start at 0".to_string(),
81 ));
82 }
83 if *self.level_offsets.last().unwrap() != self.level_nodes.len() as u32 {
84 return Err(XlogError::Compilation(
85 "XGCF invariant violation: level_offsets last != level_nodes.len".to_string(),
86 ));
87 }
88
89 fn logsumexp(values: &[f64]) -> f64 {
90 let mut max = f64::NEG_INFINITY;
91 for &v in values {
92 if v > max {
93 max = v;
94 }
95 }
96 if max.is_infinite() {
97 return max;
98 }
99 let mut sum = 0.0;
100 for &v in values {
101 sum += (v - max).exp();
102 }
103 max + sum.ln()
104 }
105
106 let mut values: Vec<f64> = vec![0.0; n];
107
108 let num_levels = self.level_offsets.len() - 1;
109 for level in 0..num_levels {
110 let start = self.level_offsets[level] as usize;
111 let end = self.level_offsets[level + 1] as usize;
112 for &node_u32 in &self.level_nodes[start..end] {
113 let idx = node_u32 as usize;
114 let v = match self.node_type[idx] {
115 XgcfNodeType::Const0 => f64::NEG_INFINITY,
116 XgcfNodeType::Const1 => 0.0,
117 XgcfNodeType::Lit => {
118 let lit = self.lit[idx];
119 if lit == 0 {
120 return Err(XlogError::Compilation(format!(
121 "XGCF invariant violation: LIT node {} has lit=0",
122 idx
123 )));
124 }
125 let var = lit.unsigned_abs();
126 let (t, f) = var_log_weights(var);
127 if lit > 0 {
128 t
129 } else {
130 f
131 }
132 }
133 XgcfNodeType::And => {
134 let c0 = self.child_offsets[idx] as usize;
135 let c1 = self.child_offsets[idx + 1] as usize;
136 if c0 == c1 {
137 return Err(XlogError::Compilation(format!(
138 "XGCF eval error: AND node {} has no children",
139 idx
140 )));
141 }
142 let mut acc = 0.0;
143 for &child in &self.child_indices[c0..c1] {
144 acc += values[child as usize];
145 }
146 acc
147 }
148 XgcfNodeType::Or => {
149 let c0 = self.child_offsets[idx] as usize;
150 let c1 = self.child_offsets[idx + 1] as usize;
151 if c0 == c1 {
152 return Err(XlogError::Compilation(format!(
153 "XGCF eval error: OR node {} has no children",
154 idx
155 )));
156 }
157 let mut branch_vals: Vec<f64> = Vec::with_capacity(c1 - c0);
158 for &child in &self.child_indices[c0..c1] {
159 branch_vals.push(values[child as usize]);
160 }
161 logsumexp(&branch_vals)
162 }
163 XgcfNodeType::Decision => {
164 let var = self.decision_var[idx];
165 if var == 0 {
166 return Err(XlogError::Compilation(format!(
167 "XGCF invariant violation: DECISION node {} has var=0",
168 idx
169 )));
170 }
171 let child_false = self.decision_child_false[idx] as usize;
172 let child_true = self.decision_child_true[idx] as usize;
173 let (t, f) = var_log_weights(var);
174 logsumexp(&[f + values[child_false], t + values[child_true]])
175 }
176 };
177 values[idx] = v;
178 }
179 }
180
181 Ok(values[self.roots[0] as usize])
182 }
183
184 pub fn eval_log_wmc_and_grads(
185 &self,
186 var_log_weights: &[(f64, f64)],
187 ) -> Result<(f64, Vec<f64>, Vec<f64>)> {
188 if self.roots.len() != 1 {
189 return Err(XlogError::Compilation(format!(
190 "XGCF eval expects exactly 1 root, got {}",
191 self.roots.len()
192 )));
193 }
194
195 let n = self.node_type.len();
196 if self.child_offsets.len() != n + 1 {
197 return Err(XlogError::Compilation(format!(
198 "XGCF invariant violation: child_offsets len {} != num_nodes+1 ({})",
199 self.child_offsets.len(),
200 n + 1
201 )));
202 }
203 if self.lit.len() != n
204 || self.decision_var.len() != n
205 || self.decision_child_false.len() != n
206 || self.decision_child_true.len() != n
207 {
208 return Err(XlogError::Compilation(
209 "XGCF invariant violation: per-node arrays length mismatch".to_string(),
210 ));
211 }
212 if self.level_offsets.is_empty() || *self.level_offsets.first().unwrap() != 0 {
213 return Err(XlogError::Compilation(
214 "XGCF invariant violation: level_offsets must start at 0".to_string(),
215 ));
216 }
217 if *self.level_offsets.last().unwrap() != self.level_nodes.len() as u32 {
218 return Err(XlogError::Compilation(
219 "XGCF invariant violation: level_offsets last != level_nodes.len".to_string(),
220 ));
221 }
222
223 let mut max_var: u32 = 0;
224 for (&ty, &lit) in self.node_type.iter().zip(self.lit.iter()) {
225 if ty == XgcfNodeType::Lit && lit != 0 {
226 max_var = max_var.max(lit.unsigned_abs());
227 }
228 }
229 for &var in &self.decision_var {
230 max_var = max_var.max(var);
231 }
232
233 let weights_len = (max_var as usize) + 1;
234 if var_log_weights.len() < weights_len {
235 return Err(XlogError::Compilation(format!(
236 "XGCF eval expects weight table len >= {}, got {}",
237 weights_len,
238 var_log_weights.len()
239 )));
240 }
241
242 fn logsumexp(values: &[f64]) -> f64 {
243 let mut max = f64::NEG_INFINITY;
244 for &v in values {
245 if v > max {
246 max = v;
247 }
248 }
249 if max.is_infinite() {
250 return max;
251 }
252 let mut sum = 0.0;
253 for &v in values {
254 sum += (v - max).exp();
255 }
256 max + sum.ln()
257 }
258
259 let mut values: Vec<f64> = vec![0.0; n];
260
261 let num_levels = self.level_offsets.len() - 1;
262 for level in 0..num_levels {
263 let start = self.level_offsets[level] as usize;
264 let end = self.level_offsets[level + 1] as usize;
265 for &node_u32 in &self.level_nodes[start..end] {
266 let idx = node_u32 as usize;
267 let v = match self.node_type[idx] {
268 XgcfNodeType::Const0 => f64::NEG_INFINITY,
269 XgcfNodeType::Const1 => 0.0,
270 XgcfNodeType::Lit => {
271 let lit = self.lit[idx];
272 if lit == 0 {
273 return Err(XlogError::Compilation(format!(
274 "XGCF invariant violation: LIT node {} has lit=0",
275 idx
276 )));
277 }
278 let var = lit.unsigned_abs();
279 let (t, f) = var_log_weights[var as usize];
280 if lit > 0 {
281 t
282 } else {
283 f
284 }
285 }
286 XgcfNodeType::And => {
287 let c0 = self.child_offsets[idx] as usize;
288 let c1 = self.child_offsets[idx + 1] as usize;
289 if c0 == c1 {
290 return Err(XlogError::Compilation(format!(
291 "XGCF eval error: AND node {} has no children",
292 idx
293 )));
294 }
295 let mut acc = 0.0;
296 for &child in &self.child_indices[c0..c1] {
297 acc += values[child as usize];
298 }
299 acc
300 }
301 XgcfNodeType::Or => {
302 let c0 = self.child_offsets[idx] as usize;
303 let c1 = self.child_offsets[idx + 1] as usize;
304 if c0 == c1 {
305 return Err(XlogError::Compilation(format!(
306 "XGCF eval error: OR node {} has no children",
307 idx
308 )));
309 }
310 let mut branch_vals: Vec<f64> = Vec::with_capacity(c1 - c0);
311 for &child in &self.child_indices[c0..c1] {
312 branch_vals.push(values[child as usize]);
313 }
314 logsumexp(&branch_vals)
315 }
316 XgcfNodeType::Decision => {
317 let var = self.decision_var[idx];
318 if var == 0 {
319 return Err(XlogError::Compilation(format!(
320 "XGCF invariant violation: DECISION node {} has var=0",
321 idx
322 )));
323 }
324 let child_false = self.decision_child_false[idx] as usize;
325 let child_true = self.decision_child_true[idx] as usize;
326 let (t, f) = var_log_weights[var as usize];
327 logsumexp(&[f + values[child_false], t + values[child_true]])
328 }
329 };
330 values[idx] = v;
331 }
332 }
333
334 let root_idx = self.roots[0] as usize;
335 let log_z = values[root_idx];
336
337 let mut adj: Vec<f64> = vec![0.0; n];
338 adj[root_idx] = 1.0;
339
340 let mut grad_true: Vec<f64> = vec![0.0; weights_len];
341 let mut grad_false: Vec<f64> = vec![0.0; weights_len];
342
343 for level in (0..num_levels).rev() {
344 let start = self.level_offsets[level] as usize;
345 let end = self.level_offsets[level + 1] as usize;
346 for &node_u32 in &self.level_nodes[start..end] {
347 let idx = node_u32 as usize;
348 let a = adj[idx];
349 if a == 0.0 {
350 continue;
351 }
352
353 match self.node_type[idx] {
354 XgcfNodeType::Const0 | XgcfNodeType::Const1 => {}
355 XgcfNodeType::Lit => {
356 let lit = self.lit[idx];
357 if lit == 0 {
358 return Err(XlogError::Compilation(format!(
359 "XGCF invariant violation: LIT node {} has lit=0",
360 idx
361 )));
362 }
363 let var = lit.unsigned_abs() as usize;
364 if lit > 0 {
365 grad_true[var] += a;
366 } else {
367 grad_false[var] += a;
368 }
369 }
370 XgcfNodeType::And => {
371 let c0 = self.child_offsets[idx] as usize;
372 let c1 = self.child_offsets[idx + 1] as usize;
373 for &child in &self.child_indices[c0..c1] {
374 adj[child as usize] += a;
375 }
376 }
377 XgcfNodeType::Or => {
378 let parent_v = values[idx];
379 if parent_v.is_infinite() && parent_v.is_sign_negative() {
380 continue;
381 }
382 let c0 = self.child_offsets[idx] as usize;
383 let c1 = self.child_offsets[idx + 1] as usize;
384 for &child in &self.child_indices[c0..c1] {
385 let child_idx = child as usize;
386 let child_v = values[child_idx];
387 if child_v.is_infinite() && child_v.is_sign_negative() {
388 continue;
389 }
390 let ratio = (child_v - parent_v).exp();
391 if ratio != 0.0 {
392 adj[child_idx] += a * ratio;
393 }
394 }
395 }
396 XgcfNodeType::Decision => {
397 let var = self.decision_var[idx] as usize;
398 if var == 0 {
399 return Err(XlogError::Compilation(format!(
400 "XGCF invariant violation: DECISION node {} has var=0",
401 idx
402 )));
403 }
404
405 let parent_v = values[idx];
406 if parent_v.is_infinite() && parent_v.is_sign_negative() {
407 continue;
408 }
409
410 let child_false = self.decision_child_false[idx] as usize;
411 let child_true = self.decision_child_true[idx] as usize;
412 let (t, f) = var_log_weights[var];
413
414 let vf = values[child_false];
415 let vt = values[child_true];
416
417 let mut p_false = 0.0;
418 let mut p_true = 0.0;
419 if !(vf.is_infinite() && vf.is_sign_negative()) {
420 p_false = (f + vf - parent_v).exp();
421 }
422 if !(vt.is_infinite() && vt.is_sign_negative()) {
423 p_true = (t + vt - parent_v).exp();
424 }
425
426 if p_false != 0.0 {
427 adj[child_false] += a * p_false;
428 grad_false[var] += a * p_false;
429 }
430 if p_true != 0.0 {
431 adj[child_true] += a * p_true;
432 grad_true[var] += a * p_true;
433 }
434 }
435 }
436 }
437 }
438
439 Ok((log_z, grad_true, grad_false))
440 }
441}
442
443fn max_var_in_circuit(circuit: &Xgcf) -> Result<u32> {
444 let mut max_var: u32 = 0;
445 for (&ty, &lit) in circuit.node_type.iter().zip(circuit.lit.iter()) {
446 if ty == XgcfNodeType::Lit {
447 if lit == 0 {
448 return Err(XlogError::Compilation(
449 "XGCF invariant violation: LIT node has lit=0".to_string(),
450 ));
451 }
452 max_var = max_var.max(lit.unsigned_abs());
453 }
454 }
455 for (&ty, &var) in circuit.node_type.iter().zip(circuit.decision_var.iter()) {
456 if ty == XgcfNodeType::Decision {
457 if var == 0 {
458 return Err(XlogError::Compilation(
459 "XGCF invariant violation: DECISION node has var=0".to_string(),
460 ));
461 }
462 max_var = max_var.max(var);
463 }
464 }
465 Ok(max_var)
466}
467
/// Writes the merge of two ascending slices into `out` (cleared first).
///
/// A value at the head of both inputs is emitted once. Duplicates inside a
/// single input are not collapsed.
fn merge_union_sorted(a: &[u32], b: &[u32], out: &mut Vec<u32>) {
    out.clear();
    let mut rest_a = a;
    let mut rest_b = b;
    while let (Some((&x, tail_a)), Some((&y, tail_b))) =
        (rest_a.split_first(), rest_b.split_first())
    {
        match x.cmp(&y) {
            std::cmp::Ordering::Equal => {
                out.push(x);
                rest_a = tail_a;
                rest_b = tail_b;
            }
            std::cmp::Ordering::Less => {
                out.push(x);
                rest_a = tail_a;
            }
            std::cmp::Ordering::Greater => {
                out.push(y);
                rest_b = tail_b;
            }
        }
    }
    // At most one of the two remainders is non-empty here.
    out.extend_from_slice(rest_a);
    out.extend_from_slice(rest_b);
}
494
/// Writes into `out` (cleared first) the elements of ascending `a` that have no
/// matching element in ascending `b`.
///
/// Each element of `b` cancels at most one element of `a`.
fn sorted_difference(a: &[u32], b: &[u32], out: &mut Vec<u32>) {
    out.clear();
    let mut rhs = b.iter().copied().peekable();
    for &x in a {
        // Discard everything in `b` that is strictly smaller than `x`.
        while rhs.next_if(|&y| y < x).is_some() {}
        // Consume a matching element if there is one; otherwise keep `x`.
        if rhs.next_if_eq(&x).is_none() {
            out.push(x);
        }
    }
}
513
/// Inserts `var` into the ascending vector `sorted` at its ordered position,
/// doing nothing when it is already present.
fn insert_sorted_unique(sorted: &mut Vec<u32>, var: u32) {
    if let Err(slot) = sorted.binary_search(&var) {
        sorted.insert(slot, var);
    }
}
520
521fn compute_random_support(circuit: &Xgcf, is_random_var: &[bool]) -> Result<Vec<Vec<u32>>> {
522 let n = circuit.node_type.len();
523 let mut support: Vec<Vec<u32>> = vec![Vec::new(); n];
524
525 let num_levels = circuit.level_offsets.len().saturating_sub(1);
526 for level in 0..num_levels {
527 let start = circuit.level_offsets[level] as usize;
528 let end = circuit.level_offsets[level + 1] as usize;
529 for &node_u32 in &circuit.level_nodes[start..end] {
530 let idx = node_u32 as usize;
531 match circuit.node_type[idx] {
532 XgcfNodeType::Const0 | XgcfNodeType::Const1 => {}
533 XgcfNodeType::Lit => {
534 let lit = circuit.lit[idx];
535 if lit == 0 {
536 return Err(XlogError::Compilation(format!(
537 "XGCF invariant violation: LIT node {} has lit=0",
538 idx
539 )));
540 }
541 let var = lit.unsigned_abs() as usize;
542 if var < is_random_var.len() && is_random_var[var] {
543 support[idx].push(var as u32);
544 }
545 }
546 XgcfNodeType::And | XgcfNodeType::Or => {
547 let c0 = circuit.child_offsets[idx] as usize;
548 let c1 = circuit.child_offsets[idx + 1] as usize;
549 let mut acc: Vec<u32> = Vec::new();
550 let mut tmp: Vec<u32> = Vec::new();
551 for &child in &circuit.child_indices[c0..c1] {
552 let child_idx = child as usize;
553 merge_union_sorted(&acc, &support[child_idx], &mut tmp);
554 std::mem::swap(&mut acc, &mut tmp);
555 }
556 support[idx] = acc;
557 }
558 XgcfNodeType::Decision => {
559 let var = circuit.decision_var[idx];
560 if var == 0 {
561 return Err(XlogError::Compilation(format!(
562 "XGCF invariant violation: DECISION node {} has var=0",
563 idx
564 )));
565 }
566 let child_false = circuit.decision_child_false[idx] as usize;
567 let child_true = circuit.decision_child_true[idx] as usize;
568
569 let mut acc: Vec<u32> = Vec::new();
570 merge_union_sorted(&support[child_false], &support[child_true], &mut acc);
571
572 let var_usize = var as usize;
573 if var_usize < is_random_var.len() && is_random_var[var_usize] {
574 insert_sorted_unique(&mut acc, var);
575 }
576 support[idx] = acc;
577 }
578 }
579 }
580 }
581
582 Ok(support)
583}
584
585struct XgcfSmoother<'a> {
586 input: &'a Xgcf,
587 is_random_var: &'a [bool],
588 support: Vec<Vec<u32>>,
589}
590
591impl<'a> XgcfSmoother<'a> {
592 fn new(input: &'a Xgcf, is_random_var: &'a [bool]) -> Result<Self> {
593 let n = input.node_type.len();
594 if input.child_offsets.len() != n + 1 {
595 return Err(XlogError::Compilation(format!(
596 "XGCF invariant violation: child_offsets len {} != num_nodes+1 ({})",
597 input.child_offsets.len(),
598 n + 1
599 )));
600 }
601 if input.lit.len() != n
602 || input.decision_var.len() != n
603 || input.decision_child_false.len() != n
604 || input.decision_child_true.len() != n
605 {
606 return Err(XlogError::Compilation(
607 "XGCF invariant violation: per-node arrays length mismatch".to_string(),
608 ));
609 }
610
611 let max_var = max_var_in_circuit(input)?;
612 if is_random_var.len() <= (max_var as usize) {
613 return Err(XlogError::Compilation(format!(
614 "XGCF smoothing expects is_random_var len >= {}, got {}",
615 (max_var as usize) + 1,
616 is_random_var.len()
617 )));
618 }
619
620 let support = compute_random_support(input, is_random_var)?;
621 Ok(Self {
622 input,
623 is_random_var,
624 support,
625 })
626 }
627
628 fn smooth(&self) -> Result<Xgcf> {
629 XgcfSmoothBuilder::new(self.input).smooth(self.is_random_var, &self.support)
630 }
631}
632
633struct XgcfSmoothBuilder<'a> {
634 input: &'a Xgcf,
635 node_type: Vec<XgcfNodeType>,
636 child_offsets: Vec<u32>,
637 child_indices: Vec<u32>,
638 lit: Vec<i32>,
639 decision_var: Vec<u32>,
640 decision_child_false: Vec<u32>,
641 decision_child_true: Vec<u32>,
642 old_to_new: Vec<Option<u32>>,
643 lit_cache: HashMap<i32, u32>,
644 tautology_cache: HashMap<u32, u32>,
645 const0: u32,
646 const1: u32,
647}
648
649impl<'a> XgcfSmoothBuilder<'a> {
650 fn new(input: &'a Xgcf) -> Self {
651 let mut b = Self {
652 input,
653 node_type: Vec::new(),
654 child_offsets: Vec::new(),
655 child_indices: Vec::new(),
656 lit: Vec::new(),
657 decision_var: Vec::new(),
658 decision_child_false: Vec::new(),
659 decision_child_true: Vec::new(),
660 old_to_new: Vec::new(),
661 lit_cache: HashMap::new(),
662 tautology_cache: HashMap::new(),
663 const0: 0,
664 const1: 0,
665 };
666
667 b.const0 = b.push_const(false);
668 b.const1 = b.push_const(true);
669 b
670 }
671
672 fn push_base_node(&mut self, ty: XgcfNodeType) -> u32 {
673 let idx = u32::try_from(self.node_type.len()).expect("XGCF node index overflow");
674 self.node_type.push(ty);
675 self.child_offsets.push(self.child_indices.len() as u32);
676 self.lit.push(0);
677 self.decision_var.push(0);
678 self.decision_child_false.push(0);
679 self.decision_child_true.push(0);
680 idx
681 }
682
683 fn push_const(&mut self, value: bool) -> u32 {
684 self.push_base_node(if value {
685 XgcfNodeType::Const1
686 } else {
687 XgcfNodeType::Const0
688 })
689 }
690
691 fn get_lit_node(&mut self, lit: i32) -> Result<u32> {
692 if lit == 0 {
693 return Err(XlogError::Compilation(
694 "Cannot create XGCF LIT for 0 literal".to_string(),
695 ));
696 }
697 if let Some(&idx) = self.lit_cache.get(&lit) {
698 return Ok(idx);
699 }
700 let idx = self.push_base_node(XgcfNodeType::Lit);
701 self.lit[idx as usize] = lit;
702 self.lit_cache.insert(lit, idx);
703 Ok(idx)
704 }
705
706 fn push_and(&mut self, mut children: Vec<u32>) -> Result<u32> {
707 if children.contains(&self.const0) {
708 return Ok(self.const0);
709 }
710 children.retain(|&c| c != self.const1);
711 children.sort();
712 children.dedup();
713 match children.as_slice() {
714 [] => Ok(self.const1),
715 [only] => Ok(*only),
716 _ => {
717 let idx = self.push_base_node(XgcfNodeType::And);
718 self.child_indices.extend_from_slice(&children);
719 Ok(idx)
720 }
721 }
722 }
723
724 fn push_or(&mut self, mut children: Vec<u32>) -> Result<u32> {
725 children.retain(|&c| c != self.const0);
726 children.sort();
727 children.dedup();
728 match children.as_slice() {
729 [] => Ok(self.const0),
730 [only] => Ok(*only),
731 _ => {
732 let idx = self.push_base_node(XgcfNodeType::Or);
733 self.child_indices.extend_from_slice(&children);
734 Ok(idx)
735 }
736 }
737 }
738
739 fn push_decision(&mut self, var: u32, child_false: u32, child_true: u32) -> Result<u32> {
740 if var == 0 {
741 return Err(XlogError::Compilation(
742 "Cannot create XGCF DECISION with var=0".to_string(),
743 ));
744 }
745 let idx = self.push_base_node(XgcfNodeType::Decision);
746 self.decision_var[idx as usize] = var;
747 self.decision_child_false[idx as usize] = child_false;
748 self.decision_child_true[idx as usize] = child_true;
749 Ok(idx)
750 }
751
752 fn tautology_decision(&mut self, var: u32) -> Result<u32> {
753 if let Some(&idx) = self.tautology_cache.get(&var) {
754 return Ok(idx);
755 }
756 let idx = self.push_decision(var, self.const1, self.const1)?;
757 self.tautology_cache.insert(var, idx);
758 Ok(idx)
759 }
760
761 fn smooth(mut self, is_random_var: &[bool], support: &[Vec<u32>]) -> Result<Xgcf> {
762 let n = self.input.node_type.len();
763 self.old_to_new = vec![None; n];
764
765 let num_levels = self.input.level_offsets.len().saturating_sub(1);
766 for level in 0..num_levels {
767 let start = self.input.level_offsets[level] as usize;
768 let end = self.input.level_offsets[level + 1] as usize;
769 for &node_u32 in &self.input.level_nodes[start..end] {
770 let idx = node_u32 as usize;
771
772 let new_idx = match self.input.node_type[idx] {
773 XgcfNodeType::Const0 => self.const0,
774 XgcfNodeType::Const1 => self.const1,
775 XgcfNodeType::Lit => {
776 let lit = self.input.lit[idx];
777 self.get_lit_node(lit)?
778 }
779 XgcfNodeType::And => {
780 let c0 = self.input.child_offsets[idx] as usize;
781 let c1 = self.input.child_offsets[idx + 1] as usize;
782 let mut children: Vec<u32> = Vec::with_capacity(c1 - c0);
783 for &child in &self.input.child_indices[c0..c1] {
784 let child_idx = child as usize;
785 let mapped = self.old_to_new[child_idx].ok_or_else(|| {
786 XlogError::Compilation(format!(
787 "XGCF smoothing error: missing mapped child {} for AND node {}",
788 child_idx, idx
789 ))
790 })?;
791 children.push(mapped);
792 }
793 self.push_and(children)?
794 }
795 XgcfNodeType::Or => {
796 let parent_support = &support[idx];
797 let c0 = self.input.child_offsets[idx] as usize;
798 let c1 = self.input.child_offsets[idx + 1] as usize;
799 let mut wrapped_children: Vec<u32> = Vec::with_capacity(c1 - c0);
800 let mut missing: Vec<u32> = Vec::new();
801 for &child in &self.input.child_indices[c0..c1] {
802 let child_idx = child as usize;
803 let child_new = self.old_to_new[child_idx].ok_or_else(|| {
804 XlogError::Compilation(format!(
805 "XGCF smoothing error: missing mapped child {} for OR node {}",
806 child_idx, idx
807 ))
808 })?;
809
810 let child_support = &support[child_idx];
811 sorted_difference(parent_support, child_support, &mut missing);
812
813 if missing.is_empty() {
814 wrapped_children.push(child_new);
815 } else {
816 let mut and_children: Vec<u32> =
817 Vec::with_capacity(1 + missing.len());
818 and_children.push(child_new);
819 for &var in &missing {
820 let var_usize = var as usize;
821 if var_usize < is_random_var.len() && is_random_var[var_usize] {
822 and_children.push(self.tautology_decision(var)?);
823 }
824 }
825 wrapped_children.push(self.push_and(and_children)?);
826 }
827 }
828 self.push_or(wrapped_children)?
829 }
830 XgcfNodeType::Decision => {
831 let var = self.input.decision_var[idx];
832 let child_false_old = self.input.decision_child_false[idx] as usize;
833 let child_true_old = self.input.decision_child_true[idx] as usize;
834
835 let child_false_new = self.old_to_new[child_false_old].ok_or_else(|| {
836 XlogError::Compilation(format!(
837 "XGCF smoothing error: missing mapped decision false child {} for node {}",
838 child_false_old, idx
839 ))
840 })?;
841 let child_true_new = self.old_to_new[child_true_old].ok_or_else(|| {
842 XlogError::Compilation(format!(
843 "XGCF smoothing error: missing mapped decision true child {} for node {}",
844 child_true_old, idx
845 ))
846 })?;
847
848 let mut union_children: Vec<u32> = Vec::new();
849 merge_union_sorted(
850 &support[child_false_old],
851 &support[child_true_old],
852 &mut union_children,
853 );
854
855 if union_children.binary_search(&var).is_ok() {
856 return Err(XlogError::Compilation(format!(
857 "XGCF smoothing error: decision var {} appears in child support at node {}",
858 var, idx
859 )));
860 }
861
862 let mut missing: Vec<u32> = Vec::new();
863
864 sorted_difference(&union_children, &support[child_false_old], &mut missing);
865 let new_false = if missing.is_empty() {
866 child_false_new
867 } else {
868 let mut and_children: Vec<u32> = Vec::with_capacity(1 + missing.len());
869 and_children.push(child_false_new);
870 for &v in &missing {
871 let v_usize = v as usize;
872 if v_usize < is_random_var.len() && is_random_var[v_usize] {
873 and_children.push(self.tautology_decision(v)?);
874 }
875 }
876 self.push_and(and_children)?
877 };
878
879 sorted_difference(&union_children, &support[child_true_old], &mut missing);
880 let new_true = if missing.is_empty() {
881 child_true_new
882 } else {
883 let mut and_children: Vec<u32> = Vec::with_capacity(1 + missing.len());
884 and_children.push(child_true_new);
885 for &v in &missing {
886 let v_usize = v as usize;
887 if v_usize < is_random_var.len() && is_random_var[v_usize] {
888 and_children.push(self.tautology_decision(v)?);
889 }
890 }
891 self.push_and(and_children)?
892 };
893
894 self.push_decision(var, new_false, new_true)?
895 }
896 };
897
898 self.old_to_new[idx] = Some(new_idx);
899 }
900 }
901
902 self.child_offsets.push(self.child_indices.len() as u32);
904
905 let mut roots: Vec<u32> = Vec::with_capacity(self.input.roots.len());
906 for &root in &self.input.roots {
907 let idx = root as usize;
908 let mapped = self.old_to_new[idx].ok_or_else(|| {
909 XlogError::Compilation(format!(
910 "XGCF smoothing error: missing mapped root node {}",
911 idx
912 ))
913 })?;
914 roots.push(mapped);
915 }
916
917 let (level_offsets, level_nodes) = XgcfBuilder::levelize(
918 &self.node_type,
919 &self.child_offsets,
920 &self.child_indices,
921 &self.decision_child_false,
922 &self.decision_child_true,
923 &roots,
924 )?;
925
926 Ok(Xgcf {
927 node_type: self.node_type,
928 child_offsets: self.child_offsets,
929 child_indices: self.child_indices,
930 lit: self.lit,
931 decision_var: self.decision_var,
932 decision_child_false: self.decision_child_false,
933 decision_child_true: self.decision_child_true,
934 roots,
935 level_offsets,
936 level_nodes,
937 })
938 }
939}
940
941struct XgcfBuilder<'a> {
942 ddnnf: &'a DecisionDnnf,
943 node_type: Vec<XgcfNodeType>,
944 child_offsets: Vec<u32>,
945 child_indices: Vec<u32>,
946 lit: Vec<i32>,
947 decision_var: Vec<u32>,
948 decision_child_false: Vec<u32>,
949 decision_child_true: Vec<u32>,
950 lit_cache: HashMap<i32, u32>,
951 ddnnf_cache: HashMap<u32, u32>,
952 const0: u32,
953 const1: u32,
954}
955
956impl<'a> XgcfBuilder<'a> {
957 fn new(ddnnf: &'a DecisionDnnf) -> Self {
958 let mut builder = Self {
959 ddnnf,
960 node_type: Vec::new(),
961 child_offsets: Vec::new(),
962 child_indices: Vec::new(),
963 lit: Vec::new(),
964 decision_var: Vec::new(),
965 decision_child_false: Vec::new(),
966 decision_child_true: Vec::new(),
967 lit_cache: HashMap::new(),
968 ddnnf_cache: HashMap::new(),
969 const0: 0,
970 const1: 0,
971 };
972
973 builder.const0 = builder.push_const(false);
974 builder.const1 = builder.push_const(true);
975 builder
976 }
977
978 fn push_base_node(&mut self, ty: XgcfNodeType) -> u32 {
979 let idx = u32::try_from(self.node_type.len()).expect("XGCF node index overflow");
980 self.node_type.push(ty);
981 self.child_offsets.push(self.child_indices.len() as u32);
982 self.lit.push(0);
983 self.decision_var.push(0);
984 self.decision_child_false.push(0);
985 self.decision_child_true.push(0);
986 idx
987 }
988
989 fn push_const(&mut self, value: bool) -> u32 {
990 self.push_base_node(if value {
991 XgcfNodeType::Const1
992 } else {
993 XgcfNodeType::Const0
994 })
995 }
996
997 fn get_lit_node(&mut self, lit: i32) -> Result<u32> {
998 if lit == 0 {
999 return Err(XlogError::Compilation(
1000 "Cannot create XGCF LIT for 0 literal".to_string(),
1001 ));
1002 }
1003 if let Some(&idx) = self.lit_cache.get(&lit) {
1004 return Ok(idx);
1005 }
1006 let idx = self.push_base_node(XgcfNodeType::Lit);
1007 self.lit[idx as usize] = lit;
1008 self.lit_cache.insert(lit, idx);
1009 Ok(idx)
1010 }
1011
1012 fn push_and(&mut self, mut children: Vec<u32>) -> Result<u32> {
1013 if children.contains(&self.const0) {
1014 return Ok(self.const0);
1015 }
1016 children.retain(|&c| c != self.const1);
1017 children.sort();
1018 children.dedup();
1019 match children.as_slice() {
1020 [] => Ok(self.const1),
1021 [only] => Ok(*only),
1022 _ => {
1023 let idx = self.push_base_node(XgcfNodeType::And);
1024 self.child_indices.extend_from_slice(&children);
1025 Ok(idx)
1026 }
1027 }
1028 }
1029
1030 fn push_or(&mut self, mut children: Vec<u32>) -> Result<u32> {
1031 children.retain(|&c| c != self.const0);
1032 children.sort();
1033 children.dedup();
1034 match children.as_slice() {
1035 [] => Ok(self.const0),
1036 [only] => Ok(*only),
1037 _ => {
1038 let idx = self.push_base_node(XgcfNodeType::Or);
1039 self.child_indices.extend_from_slice(&children);
1040 Ok(idx)
1041 }
1042 }
1043 }
1044
1045 fn push_decision(&mut self, var: u32, child_false: u32, child_true: u32) -> Result<u32> {
1046 if var == 0 {
1047 return Err(XlogError::Compilation(
1048 "Cannot create XGCF DECISION with var=0".to_string(),
1049 ));
1050 }
1051 let idx = self.push_base_node(XgcfNodeType::Decision);
1052 self.decision_var[idx as usize] = var;
1053 self.decision_child_false[idx as usize] = child_false;
1054 self.decision_child_true[idx as usize] = child_true;
1055 Ok(idx)
1056 }
1057
1058 fn build(mut self) -> Result<Xgcf> {
1059 let root = self.convert_ddnnf_node(self.ddnnf.root())?;
1060
1061 self.child_offsets.push(self.child_indices.len() as u32);
1063
1064 let roots = vec![root];
1065 let (level_offsets, level_nodes) = Self::levelize(
1066 &self.node_type,
1067 &self.child_offsets,
1068 &self.child_indices,
1069 &self.decision_child_false,
1070 &self.decision_child_true,
1071 &roots,
1072 )?;
1073
1074 Ok(Xgcf {
1075 node_type: self.node_type,
1076 child_offsets: self.child_offsets,
1077 child_indices: self.child_indices,
1078 lit: self.lit,
1079 decision_var: self.decision_var,
1080 decision_child_false: self.decision_child_false,
1081 decision_child_true: self.decision_child_true,
1082 roots,
1083 level_offsets,
1084 level_nodes,
1085 })
1086 }
1087
1088 fn convert_ddnnf_node(&mut self, node_id: u32) -> Result<u32> {
1089 if let Some(&idx) = self.ddnnf_cache.get(&node_id) {
1090 return Ok(idx);
1091 }
1092 let kind = self.ddnnf.node_kind(node_id).ok_or_else(|| {
1093 XlogError::Compilation(format!("XGCF build error: unknown DDNNF node {}", node_id))
1094 })?;
1095
1096 let idx = match kind {
1097 DdnnfNodeKind::True => self.const1,
1098 DdnnfNodeKind::False => self.const0,
1099 DdnnfNodeKind::And => {
1100 let out = self.ddnnf.outgoing_edge_indices(node_id).ok_or_else(|| {
1101 XlogError::Compilation(format!(
1102 "XGCF build error: AND node {} has no outgoing edges",
1103 node_id
1104 ))
1105 })?;
1106 let mut child_nodes: Vec<u32> = Vec::with_capacity(out.len());
1107 for &edge_idx in out {
1108 child_nodes.push(self.convert_ddnnf_edge_branch(edge_idx, None)?);
1109 }
1110 self.push_and(child_nodes)?
1111 }
1112 DdnnfNodeKind::Or => {
1113 let out = self.ddnnf.outgoing_edge_indices(node_id).ok_or_else(|| {
1114 XlogError::Compilation(format!(
1115 "XGCF build error: OR node {} has no outgoing edges",
1116 node_id
1117 ))
1118 })?;
1119 if out.len() == 2 {
1120 let e0 = self.ddnnf.edge(out[0]).ok_or_else(|| {
1121 XlogError::Compilation(format!("XGCF build error: missing edge {}", out[0]))
1122 })?;
1123 let e1 = self.ddnnf.edge(out[1]).ok_or_else(|| {
1124 XlogError::Compilation(format!("XGCF build error: missing edge {}", out[1]))
1125 })?;
1126
1127 if let Some((var, edge_true, edge_false)) = infer_decision_var(e0, e1)? {
1128 let edge_true = out[edge_true];
1129 let edge_false = out[edge_false];
1130 let child_true =
1131 self.convert_ddnnf_edge_branch(edge_true, Some(var as i32))?;
1132 let child_false =
1133 self.convert_ddnnf_edge_branch(edge_false, Some(-(var as i32)))?;
1134 self.push_decision(var, child_false, child_true)?
1135 } else {
1136 let mut child_nodes: Vec<u32> = Vec::with_capacity(out.len());
1137 for &edge_idx in out {
1138 child_nodes.push(self.convert_ddnnf_edge_branch(edge_idx, None)?);
1139 }
1140 self.push_or(child_nodes)?
1141 }
1142 } else {
1143 let mut child_nodes: Vec<u32> = Vec::with_capacity(out.len());
1144 for &edge_idx in out {
1145 child_nodes.push(self.convert_ddnnf_edge_branch(edge_idx, None)?);
1146 }
1147 self.push_or(child_nodes)?
1148 }
1149 }
1150 };
1151
1152 self.ddnnf_cache.insert(node_id, idx);
1153 Ok(idx)
1154 }
1155
1156 fn convert_ddnnf_edge_branch(&mut self, edge_idx: usize, drop_lit: Option<i32>) -> Result<u32> {
1157 let edge = self.ddnnf.edge(edge_idx).ok_or_else(|| {
1158 XlogError::Compilation(format!("XGCF build error: missing edge {}", edge_idx))
1159 })?;
1160
1161 let child = self.convert_ddnnf_node(edge.to)?;
1162
1163 let mut children: Vec<u32> = Vec::new();
1164 children.push(child);
1165
1166 let mut dropped = false;
1167 for &lit in &edge.lits {
1168 if let Some(dl) = drop_lit {
1169 if !dropped && lit == dl {
1170 dropped = true;
1171 continue;
1172 }
1173 }
1174 children.push(self.get_lit_node(lit)?);
1175 }
1176
1177 if let Some(dl) = drop_lit {
1178 if !dropped {
1179 return Err(XlogError::Compilation(format!(
1180 "XGCF build error: expected to drop literal {} on edge {}->{} but not present",
1181 dl, edge.from, edge.to
1182 )));
1183 }
1184 }
1185
1186 self.push_and(children)
1187 }
1188
1189 fn levelize(
1190 node_type: &[XgcfNodeType],
1191 child_offsets: &[u32],
1192 child_indices: &[u32],
1193 decision_child_false: &[u32],
1194 decision_child_true: &[u32],
1195 roots: &[u32],
1196 ) -> Result<(Vec<u32>, Vec<u32>)> {
1197 let n = node_type.len();
1198 let mut levels: Vec<Option<u32>> = vec![None; n];
1199 let mut visiting: Vec<bool> = vec![false; n];
1200
1201 #[allow(clippy::too_many_arguments)]
1202 fn level_of(
1203 idx: usize,
1204 node_type: &[XgcfNodeType],
1205 child_offsets: &[u32],
1206 child_indices: &[u32],
1207 decision_child_false: &[u32],
1208 decision_child_true: &[u32],
1209 levels: &mut [Option<u32>],
1210 visiting: &mut [bool],
1211 ) -> Result<u32> {
1212 if let Some(lvl) = levels[idx] {
1213 return Ok(lvl);
1214 }
1215 if visiting[idx] {
1216 return Err(XlogError::Compilation(format!(
1217 "XGCF levelize error: cycle detected at node {}",
1218 idx
1219 )));
1220 }
1221 visiting[idx] = true;
1222
1223 let lvl = match node_type[idx] {
1224 XgcfNodeType::Const0 | XgcfNodeType::Const1 | XgcfNodeType::Lit => 0,
1225 XgcfNodeType::And | XgcfNodeType::Or => {
1226 let c0 = child_offsets[idx] as usize;
1227 let c1 = child_offsets[idx + 1] as usize;
1228 let mut max_child = 0u32;
1229 for &child in &child_indices[c0..c1] {
1230 max_child = max_child.max(level_of(
1231 child as usize,
1232 node_type,
1233 child_offsets,
1234 child_indices,
1235 decision_child_false,
1236 decision_child_true,
1237 levels,
1238 visiting,
1239 )?);
1240 }
1241 max_child + 1
1242 }
1243 XgcfNodeType::Decision => {
1244 let lf = level_of(
1245 decision_child_false[idx] as usize,
1246 node_type,
1247 child_offsets,
1248 child_indices,
1249 decision_child_false,
1250 decision_child_true,
1251 levels,
1252 visiting,
1253 )?;
1254 let lt = level_of(
1255 decision_child_true[idx] as usize,
1256 node_type,
1257 child_offsets,
1258 child_indices,
1259 decision_child_false,
1260 decision_child_true,
1261 levels,
1262 visiting,
1263 )?;
1264 lf.max(lt) + 1
1265 }
1266 };
1267
1268 visiting[idx] = false;
1269 levels[idx] = Some(lvl);
1270 Ok(lvl)
1271 }
1272
1273 for &root in roots {
1274 level_of(
1275 root as usize,
1276 node_type,
1277 child_offsets,
1278 child_indices,
1279 decision_child_false,
1280 decision_child_true,
1281 &mut levels,
1282 &mut visiting,
1283 )?;
1284 }
1285
1286 let max_level = levels.iter().flatten().copied().max().unwrap_or(0);
1287 let mut buckets: Vec<Vec<u32>> = vec![Vec::new(); (max_level as usize) + 1];
1288 for (i, lvl) in levels.iter().enumerate().take(n) {
1289 let Some(lvl) = lvl else {
1290 continue;
1291 };
1292 buckets[*lvl as usize].push(i as u32);
1293 }
1294
1295 let mut level_offsets: Vec<u32> = Vec::with_capacity(buckets.len() + 1);
1296 let mut level_nodes: Vec<u32> = Vec::new();
1297 level_offsets.push(0);
1298 for bucket in buckets {
1299 level_nodes.extend(bucket);
1300 level_offsets.push(level_nodes.len() as u32);
1301 }
1302 Ok((level_offsets, level_nodes))
1303 }
1304}
1305
1306fn infer_decision_var(e0: &DdnnfEdge, e1: &DdnnfEdge) -> Result<Option<(u32, usize, usize)>> {
1307 fn sign_map(lits: &[i32]) -> Result<HashMap<u32, bool>> {
1308 let mut map: HashMap<u32, bool> = HashMap::new();
1309 for &lit in lits {
1310 let var = lit.unsigned_abs();
1311 let sign = lit > 0;
1312 if let Some(prev) = map.insert(var, sign) {
1313 if prev != sign {
1314 return Err(XlogError::Compilation(format!(
1315 "XGCF build error: conflicting literals {} and {} in same branch",
1316 var, lit
1317 )));
1318 }
1319 }
1320 }
1321 Ok(map)
1322 }
1323
1324 let m0 = sign_map(&e0.lits)?;
1325 let m1 = sign_map(&e1.lits)?;
1326
1327 let mut candidates: Vec<u32> = Vec::new();
1328 for (var, &s0) in &m0 {
1329 if let Some(&s1) = m1.get(var) {
1330 if s0 != s1 {
1331 candidates.push(*var);
1332 }
1333 }
1334 }
1335
1336 if candidates.len() != 1 {
1337 return Ok(None);
1338 }
1339 let var = candidates[0];
1340
1341 let edge0_is_true = m0.get(&var).copied().unwrap_or(false);
1342 let (edge_true, edge_false) = if edge0_is_true {
1343 (0usize, 1usize)
1344 } else {
1345 (1usize, 0usize)
1346 };
1347
1348 Ok(Some((var, edge_true, edge_false)))
1349}
1350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kc::ddnnf::DecisionDnnf;

    /// Parses `nnf`, converts it to XGCF and checks that both representations
    /// evaluate to the same log weighted model count under `weights`.
    fn assert_same_log_wmc<F>(nnf: &str, weights: F)
    where
        F: Fn(u32) -> (f64, f64) + Copy,
    {
        let ddnnf = DecisionDnnf::parse_str(nnf).unwrap();
        let xgcf = Xgcf::from_ddnnf(&ddnnf).unwrap();

        let a = ddnnf.eval_log_wmc(weights).unwrap();
        let b = xgcf.eval_log_wmc(weights).unwrap();
        assert!((a - b).abs() < 1e-9, "ddnnf={} xgcf={}", a, b);
    }

    /// A single OR whose two edges carry literals 1 and -1.
    #[test]
    fn test_xgcf_matches_ddnnf_on_single_decision() {
        let nnf = r#"
o 1 0
t 2 0
f 3 0
1 2 1 0
1 3 -1 0
"#;
        let p = 0.37_f64;
        assert_same_log_wmc(nnf, |var: u32| match var {
            1 => (p.ln(), (1.0 - p).ln()),
            _ => panic!("unexpected var {}", var),
        });
    }

    /// An OR on variable 1 whose -1 edge leads to a second OR on variable 2.
    #[test]
    fn test_xgcf_matches_ddnnf_on_two_stage_decision() {
        let nnf = r#"
o 1 0
o 2 0
t 3 0
f 4 0
1 3 1 0
1 2 -1 0
2 3 2 0
2 4 -2 0
"#;
        let p1 = 0.2_f64;
        let p2 = 0.6_f64;
        assert_same_log_wmc(nnf, |var: u32| match var {
            1 => (p1.ln(), (1.0 - p1).ln()),
            2 => (p2.ln(), (1.0 - p2).ln()),
            _ => panic!("unexpected var {}", var),
        });
    }
}