1use std::collections::{BTreeMap, HashMap, HashSet};
4
5use xlog_core::{Result, XlogError};
6
7use crate::pir::{ChoiceVarId, LeafId, PirGraph, PirNode, PirNodeId};
8
/// A propositional formula in conjunctive normal form.
///
/// Each clause is a disjunction of signed literals using the DIMACS convention:
/// a positive literal `v` stands for variable `v`, a negative literal `-v` for its
/// negation. Formulas built by `encode_cnf` number their variables `1..=num_vars`.
#[derive(Debug, Clone, Default)]
pub struct CnfFormula {
    num_vars: u32,
    clauses: Vec<Vec<i32>>,
}

impl CnfFormula {
    /// Number of variables declared by the formula (the `V` in `p cnf V C`).
    pub fn num_vars(&self) -> u32 {
        self.num_vars
    }

    /// The clauses of the formula, each a disjunction of signed literals.
    pub fn clauses(&self) -> &[Vec<i32>] {
        &self.clauses
    }

    /// Renders the formula as DIMACS CNF text.
    ///
    /// The output is a `c` comment line, the `p cnf <vars> <clauses>` header, and
    /// then one line per clause: its literals separated by spaces and terminated
    /// by `0`.
    pub fn to_dimacs(&self) -> String {
        // `write!` appends straight into `out` instead of allocating a temporary
        // `String` for every literal.
        use std::fmt::Write as _;

        let mut out = String::new();
        out.push_str("c xlog-prob cnf\n");
        // Formatting into a `String` cannot fail, so the `fmt::Result`s are ignored.
        let _ = writeln!(out, "p cnf {} {}", self.num_vars, self.clauses.len());
        for clause in &self.clauses {
            for lit in clause {
                let _ = write!(out, "{} ", lit);
            }
            out.push_str("0\n");
        }
        out
    }
}
38
/// Output of `encode_cnf`: the CNF formula together with the maps from PIR
/// entities to the CNF variables that represent them.
#[derive(Debug, Clone)]
pub struct CnfEncoding {
    /// The Tseitin-encoded formula.
    pub cnf: CnfFormula,
    /// CNF variable of each PIR node reachable from the encoded roots.
    /// A `Lit` node shares the variable of its leaf.
    pub node_var: BTreeMap<PirNodeId, u32>,
    /// CNF variable of each PIR leaf referenced by a reachable `Lit`/`NegLit` node.
    pub leaf_var: BTreeMap<LeafId, u32>,
    /// CNF variable of each choice variable tested by a reachable `Decision` node.
    pub choice_var: BTreeMap<ChoiceVarId, u32>,
}
46
/// 64-bit FNV-1a hash of `bytes`.
///
/// Starting from the FNV offset basis, each byte is XORed into the running state,
/// which is then multiplied (wrapping, i.e. mod 2^64) by the FNV prime.
fn fnv1a(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x100_0000_01b3;

    bytes.iter().fold(FNV_OFFSET_BASIS, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
58
59fn canonical_hashes(pir: &PirGraph, levels: &[Vec<PirNodeId>]) -> HashMap<PirNodeId, u64> {
63 let mut canon: HashMap<PirNodeId, u64> = HashMap::new();
64 for level in levels {
65 for &id in level {
66 let node = pir.node(id).unwrap();
67 let h = match node {
68 PirNode::Const(b) => fnv1a(&[0, *b as u8]),
69 PirNode::Lit { leaf } => {
70 let mut buf = [0u8; 5];
71 buf[0] = 1;
72 buf[1..5].copy_from_slice(&leaf.as_u32().to_le_bytes());
73 fnv1a(&buf)
74 }
75 PirNode::NegLit { leaf } => {
76 let mut buf = [0u8; 5];
77 buf[0] = 2;
78 buf[1..5].copy_from_slice(&leaf.as_u32().to_le_bytes());
79 fnv1a(&buf)
80 }
81 PirNode::And { children } => {
82 let mut child_h: Vec<u64> = children.iter().map(|c| canon[c]).collect();
83 child_h.sort();
84 let mut buf = vec![3u8];
85 for h in child_h {
86 buf.extend_from_slice(&h.to_le_bytes());
87 }
88 fnv1a(&buf)
89 }
90 PirNode::Or { children } => {
91 let mut child_h: Vec<u64> = children.iter().map(|c| canon[c]).collect();
92 child_h.sort();
93 let mut buf = vec![4u8];
94 for h in child_h {
95 buf.extend_from_slice(&h.to_le_bytes());
96 }
97 fnv1a(&buf)
98 }
99 PirNode::Decision {
100 var,
101 child_false,
102 child_true,
103 } => {
104 let mut buf = vec![5u8];
105 buf.extend_from_slice(&var.as_u32().to_le_bytes());
106 buf.extend_from_slice(&canon[child_false].to_le_bytes());
107 buf.extend_from_slice(&canon[child_true].to_le_bytes());
108 fnv1a(&buf)
109 }
110 };
111 canon.insert(id, h);
112 }
113 }
114 canon
115}
116
117pub fn canonical_pir_hash(pir: &PirGraph, roots: &[PirNodeId]) -> Result<u64> {
127 if roots.is_empty() {
128 return Err(XlogError::Compilation(
129 "Cannot compute canonical hash for empty PIR root set".to_string(),
130 ));
131 }
132 let levels = pir.levelize(roots)?;
133 let canon = canonical_hashes(pir, &levels);
134
135 let mut root_hashes: Vec<u64> = roots
137 .iter()
138 .map(|r| canon.get(r).copied().unwrap_or(0))
139 .collect();
140 root_hashes.sort();
141
142 let mut buf = Vec::with_capacity(1 + root_hashes.len() * 8);
143 buf.push(0xFFu8); for h in root_hashes {
145 buf.extend_from_slice(&h.to_le_bytes());
146 }
147 Ok(fnv1a(&buf))
148}
149
150pub fn encode_cnf(pir: &PirGraph, roots: &[PirNodeId]) -> Result<CnfEncoding> {
151 if roots.is_empty() {
152 return Err(XlogError::Compilation(
153 "Cannot encode CNF for empty PIR root set".to_string(),
154 ));
155 }
156
157 let mut visited: HashSet<PirNodeId> = HashSet::new();
158 let mut leaf_ids: HashSet<LeafId> = HashSet::new();
159 let mut choice_ids: HashSet<ChoiceVarId> = HashSet::new();
160
161 let mut stack: Vec<PirNodeId> = roots.to_vec();
162 while let Some(node_id) = stack.pop() {
163 if !visited.insert(node_id) {
164 continue;
165 }
166
167 let node = pir.node(node_id).ok_or_else(|| {
168 XlogError::Compilation(format!(
169 "Invalid PIR node id while encoding CNF: {:?}",
170 node_id
171 ))
172 })?;
173
174 match node {
175 PirNode::Const(_) => {}
176 PirNode::Lit { leaf } | PirNode::NegLit { leaf } => {
177 leaf_ids.insert(*leaf);
178 }
179 PirNode::And { children } | PirNode::Or { children } => {
180 stack.extend(children.iter().copied());
181 }
182 PirNode::Decision {
183 var,
184 child_false,
185 child_true,
186 } => {
187 choice_ids.insert(*var);
188 stack.push(*child_false);
189 stack.push(*child_true);
190 }
191 }
192 }
193
194 let mut leaf_list: Vec<LeafId> = leaf_ids.into_iter().collect();
195 leaf_list.sort();
196 let mut choice_list: Vec<ChoiceVarId> = choice_ids.into_iter().collect();
197 choice_list.sort();
198
199 let mut next_var: u32 = 1;
200 let mut leaf_var: BTreeMap<LeafId, u32> = BTreeMap::new();
201 for leaf in leaf_list {
202 leaf_var.insert(leaf, next_var);
203 next_var += 1;
204 }
205
206 let mut choice_var: BTreeMap<ChoiceVarId, u32> = BTreeMap::new();
207 for choice in choice_list {
208 choice_var.insert(choice, next_var);
209 next_var += 1;
210 }
211
212 let mut node_ids: Vec<PirNodeId> = visited.into_iter().collect();
213 node_ids.sort();
214
215 let mut node_var: BTreeMap<PirNodeId, u32> = BTreeMap::new();
216 for node_id in node_ids {
217 let node = pir.node(node_id).ok_or_else(|| {
218 XlogError::Compilation(format!(
219 "Invalid PIR node id while encoding CNF: {:?}",
220 node_id
221 ))
222 })?;
223
224 let var_id = match node {
225 PirNode::Lit { leaf } => *leaf_var.get(leaf).ok_or_else(|| {
226 XlogError::Compilation(format!(
227 "Missing CNF var for PIR leaf {:?} referenced by node {:?}",
228 leaf, node_id
229 ))
230 })?,
231 PirNode::NegLit { .. } => {
232 let v = next_var;
234 next_var += 1;
235 v
236 }
237 _ => {
238 let v = next_var;
239 next_var += 1;
240 v
241 }
242 };
243
244 node_var.insert(node_id, var_id);
245 }
246
247 let num_vars = next_var - 1;
248 let mut clauses: Vec<Vec<i32>> = Vec::new();
249
250 let levels = pir.levelize(roots)?;
251 for level in &levels {
252 for &node_id in level {
253 let node = pir.node(node_id).ok_or_else(|| {
254 XlogError::Compilation(format!(
255 "Invalid PIR node id while emitting CNF clauses: {:?}",
256 node_id
257 ))
258 })?;
259 let v = *node_var.get(&node_id).ok_or_else(|| {
260 XlogError::Compilation(format!(
261 "Missing CNF var for PIR node {:?} while emitting clauses",
262 node_id
263 ))
264 })? as i32;
265
266 match node {
267 PirNode::Const(true) => clauses.push(vec![v]),
268 PirNode::Const(false) => clauses.push(vec![-v]),
269 PirNode::Lit { .. } => {}
270 PirNode::NegLit { leaf } => {
271 let leaf_v = *leaf_var.get(leaf).ok_or_else(|| {
273 XlogError::Compilation(format!(
274 "Missing CNF var for NegLit leaf {:?} at node {:?}",
275 leaf, node_id
276 ))
277 })? as i32;
278 clauses.push(vec![v, leaf_v]); clauses.push(vec![-v, -leaf_v]); }
282 PirNode::And { children } => {
283 if children.is_empty() {
284 clauses.push(vec![v]);
285 continue;
286 }
287 for &child in children {
288 let c = *node_var.get(&child).ok_or_else(|| {
289 XlogError::Compilation(format!(
290 "Missing CNF var for AND child {:?} of {:?}",
291 child, node_id
292 ))
293 })? as i32;
294 clauses.push(vec![-v, c]);
295 }
296 let mut clause = Vec::with_capacity(children.len() + 1);
297 for &child in children {
298 let c = *node_var.get(&child).ok_or_else(|| {
299 XlogError::Compilation(format!(
300 "Missing CNF var for AND child {:?} of {:?}",
301 child, node_id
302 ))
303 })? as i32;
304 clause.push(-c);
305 }
306 clause.push(v);
307 clauses.push(clause);
308 }
309 PirNode::Or { children } => {
310 if children.is_empty() {
311 clauses.push(vec![-v]);
312 continue;
313 }
314 for &child in children {
315 let c = *node_var.get(&child).ok_or_else(|| {
316 XlogError::Compilation(format!(
317 "Missing CNF var for OR child {:?} of {:?}",
318 child, node_id
319 ))
320 })? as i32;
321 clauses.push(vec![-c, v]);
322 }
323 let mut clause = Vec::with_capacity(children.len() + 1);
324 clause.push(-v);
325 for &child in children {
326 let c = *node_var.get(&child).ok_or_else(|| {
327 XlogError::Compilation(format!(
328 "Missing CNF var for OR child {:?} of {:?}",
329 child, node_id
330 ))
331 })? as i32;
332 clause.push(c);
333 }
334 clauses.push(clause);
335 }
336 PirNode::Decision {
337 var,
338 child_false,
339 child_true,
340 } => {
341 let x = *choice_var.get(var).ok_or_else(|| {
342 XlogError::Compilation(format!(
343 "Missing CNF var for decision variable {:?} at node {:?}",
344 var, node_id
345 ))
346 })? as i32;
347
348 let f = *node_var.get(child_false).ok_or_else(|| {
349 XlogError::Compilation(format!(
350 "Missing CNF var for decision false child {:?} at node {:?}",
351 child_false, node_id
352 ))
353 })? as i32;
354 let t = *node_var.get(child_true).ok_or_else(|| {
355 XlogError::Compilation(format!(
356 "Missing CNF var for decision true child {:?} at node {:?}",
357 child_true, node_id
358 ))
359 })? as i32;
360
361 clauses.push(vec![-x, -t, v]); clauses.push(vec![x, -f, v]); clauses.push(vec![-x, t, -v]); clauses.push(vec![x, f, -v]); }
367 }
368 }
369 }
370
371 Ok(CnfEncoding {
372 cnf: CnfFormula { num_vars, clauses },
373 node_var,
374 leaf_var,
375 choice_var,
376 })
377}
378
379#[cfg(test)]
380mod tests {
381 use super::*;
382 use crate::pir::{ChoiceVarId, LeafId, PirGraph};
383
384 fn is_sat_with_unit_clauses(cnf: &CnfFormula, units: &[i32]) -> bool {
385 let num_vars = cnf.num_vars() as usize;
386 assert!(
387 num_vars <= 20,
388 "test sat checker only supports small CNFs (vars={})",
389 num_vars
390 );
391
392 let mut clauses: Vec<&[i32]> = cnf.clauses().iter().map(|c| c.as_slice()).collect();
393 let unit_clauses: Vec<Vec<i32>> = units.iter().map(|&u| vec![u]).collect();
394 for uc in &unit_clauses {
395 clauses.push(uc.as_slice());
396 }
397
398 'assign: for mask in 0..(1u64 << num_vars) {
399 for clause in &clauses {
400 let mut clause_sat = false;
401 for &lit in *clause {
402 let var = lit.unsigned_abs() as usize;
403 assert!(var >= 1 && var <= num_vars);
404 let bit = (mask >> (var - 1)) & 1;
405 let val = bit == 1;
406 let lit_sat = if lit > 0 { val } else { !val };
407 if lit_sat {
408 clause_sat = true;
409 break;
410 }
411 }
412 if !clause_sat {
413 continue 'assign;
414 }
415 }
416 return true;
417 }
418 false
419 }
420
421 #[test]
422 fn test_encode_cnf_does_not_force_root_assignment() {
423 let mut pir = PirGraph::new();
424 let a = pir.lit(LeafId::new(0));
425
426 let encoding = encode_cnf(&pir, &[a]).unwrap();
427 let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;
428
429 assert!(is_sat_with_unit_clauses(&encoding.cnf, &[-var_a]));
430 }
431
432 #[test]
433 fn test_tseitin_and_requires_both_children() {
434 let mut pir = PirGraph::new();
435 let a = pir.lit(LeafId::new(0));
436 let b = pir.lit(LeafId::new(1));
437 let root = pir.and(vec![a, b]);
438
439 let encoding = encode_cnf(&pir, &[root]).unwrap();
440 let var_root = *encoding.node_var.get(&root).unwrap() as i32;
441 let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;
442 let var_b = *encoding.leaf_var.get(&LeafId::new(1)).unwrap() as i32;
443
444 assert!(is_sat_with_unit_clauses(
445 &encoding.cnf,
446 &[var_root, var_a, var_b]
447 ));
448 assert!(!is_sat_with_unit_clauses(
449 &encoding.cnf,
450 &[var_root, var_a, -var_b]
451 ));
452 assert!(!is_sat_with_unit_clauses(
453 &encoding.cnf,
454 &[var_root, -var_a, var_b]
455 ));
456 assert!(!is_sat_with_unit_clauses(
457 &encoding.cnf,
458 &[var_root, -var_a, -var_b]
459 ));
460 }
461
462 #[test]
463 fn test_tseitin_or_requires_one_child() {
464 let mut pir = PirGraph::new();
465 let a = pir.lit(LeafId::new(0));
466 let b = pir.lit(LeafId::new(1));
467 let root = pir.or(vec![a, b]);
468
469 let encoding = encode_cnf(&pir, &[root]).unwrap();
470 let var_root = *encoding.node_var.get(&root).unwrap() as i32;
471 let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;
472 let var_b = *encoding.leaf_var.get(&LeafId::new(1)).unwrap() as i32;
473
474 assert!(is_sat_with_unit_clauses(
475 &encoding.cnf,
476 &[var_root, var_a, var_b]
477 ));
478 assert!(is_sat_with_unit_clauses(
479 &encoding.cnf,
480 &[var_root, var_a, -var_b]
481 ));
482 assert!(is_sat_with_unit_clauses(
483 &encoding.cnf,
484 &[var_root, -var_a, var_b]
485 ));
486 assert!(!is_sat_with_unit_clauses(
487 &encoding.cnf,
488 &[var_root, -var_a, -var_b]
489 ));
490 }
491
492 #[test]
493 fn test_tseitin_decision_mux_matches_choice_var() {
494 let mut pir = PirGraph::new();
495 let t = pir.const_true();
496 let f = pir.const_false();
497 let root = pir.decision(ChoiceVarId::new(0), f, t);
498
499 let encoding = encode_cnf(&pir, &[root]).unwrap();
500 let var_root = *encoding.node_var.get(&root).unwrap() as i32;
501 let x = *encoding.choice_var.get(&ChoiceVarId::new(0)).unwrap() as i32;
502
503 assert!(is_sat_with_unit_clauses(&encoding.cnf, &[var_root, x]));
504 assert!(!is_sat_with_unit_clauses(&encoding.cnf, &[var_root, -x]));
505 }
506
507 #[test]
508 fn test_dimacs_is_well_formed() {
509 let mut pir = PirGraph::new();
510 let a = pir.lit(LeafId::new(0));
511 let root = pir.or(vec![a]);
512
513 let encoding = encode_cnf(&pir, &[root]).unwrap();
514 let dimacs = encoding.cnf.to_dimacs();
515
516 assert!(dimacs.contains("\np cnf "));
517 assert!(dimacs.lines().any(|l| l.ends_with('0')));
518 }
519
520 #[test]
521 fn test_tseitin_neg_lit_uses_negated_polarity() {
522 let mut pir = PirGraph::new();
523 let a = pir.neg_lit(LeafId::new(0)); let root = pir.or(vec![a]);
525
526 let encoding = encode_cnf(&pir, &[root]).unwrap();
527 let var_root = *encoding.node_var.get(&root).unwrap() as i32;
528 let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;
529
530 assert!(
534 is_sat_with_unit_clauses(&encoding.cnf, &[var_root, -var_a]),
535 "root=true with leaf=false should be SAT"
536 );
537
538 assert!(
545 !is_sat_with_unit_clauses(&encoding.cnf, &[var_root, var_a]),
546 "root=true with leaf=true should be UNSAT (NegLit of true leaf is false)"
547 );
548 }
549}