Skip to main content

xlog_prob/
cnf.rs

1//! CNF emission for PIR via Tseitin encoding (DIMACS).
2
3use std::collections::{BTreeMap, HashMap, HashSet};
4
5use xlog_core::{Result, XlogError};
6
7use crate::pir::{ChoiceVarId, LeafId, PirGraph, PirNode, PirNodeId};
8
/// A CNF formula in DIMACS-style representation.
///
/// Each clause is a list of signed literals, where a negative literal negates the
/// variable with the same absolute value. The DIMACS clause terminator `0` is not
/// stored; it is only written by [`CnfFormula::to_dimacs`].
#[derive(Debug, Clone, Default)]
pub struct CnfFormula {
    /// Variable count written into the `p cnf` header.
    num_vars: u32,
    /// Clauses in emission order, without the trailing `0`.
    clauses: Vec<Vec<i32>>,
}

impl CnfFormula {
    /// Number of variables declared in the DIMACS header.
    pub fn num_vars(&self) -> u32 {
        self.num_vars
    }

    /// The clauses of the formula, in emission order.
    pub fn clauses(&self) -> &[Vec<i32>] {
        self.clauses.as_slice()
    }

    /// Render the formula as DIMACS text.
    ///
    /// The output consists of:
    /// - a `c` comment line;
    /// - the `p cnf <vars> <clauses>` header;
    /// - one line per clause, with every literal followed by a single space and the
    ///   line terminated by `0`.
    pub fn to_dimacs(&self) -> String {
        let preamble = String::from("c xlog-prob cnf\n")
            + &format!("p cnf {} {}\n", self.num_vars, self.clauses.len());
        self.clauses.iter().fold(preamble, |mut text, clause| {
            text.extend(clause.iter().map(|lit| format!("{} ", lit)));
            text + "0\n"
        })
    }
}
38
/// Output of [`encode_cnf`]: the Tseitin CNF together with the DIMACS variable
/// assigned to each PIR entity it mentions.
#[derive(Debug, Clone)]
pub struct CnfEncoding {
    /// The encoded formula.
    pub cnf: CnfFormula,
    /// DIMACS variable of every PIR node reachable from the encoded roots.
    /// A `Lit` node shares its leaf's variable; every other node has its own.
    pub node_var: BTreeMap<PirNodeId, u32>,
    /// DIMACS variable of every leaf referenced by a reachable `Lit`/`NegLit` node.
    pub leaf_var: BTreeMap<LeafId, u32>,
    /// DIMACS variable of every choice variable referenced by a reachable `Decision` node.
    pub choice_var: BTreeMap<ChoiceVarId, u32>,
}
46
/// FNV-1a 64-bit hash of `bytes`.
///
/// The result depends only on the input bytes (no random seed), so it is
/// deterministic and process-independent, which makes it suitable for
/// persistent cache keys.
fn fnv1a(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes.iter().fold(OFFSET_BASIS, |acc, &byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
58
59/// Compute a canonical hash for each reachable PIR node.
60/// The hash depends only on node structure (variant + leaf/choice IDs + children's
61/// canonical hashes), not on PirNodeId numeric values.
62fn canonical_hashes(pir: &PirGraph, levels: &[Vec<PirNodeId>]) -> HashMap<PirNodeId, u64> {
63    let mut canon: HashMap<PirNodeId, u64> = HashMap::new();
64    for level in levels {
65        for &id in level {
66            let node = pir.node(id).unwrap();
67            let h = match node {
68                PirNode::Const(b) => fnv1a(&[0, *b as u8]),
69                PirNode::Lit { leaf } => {
70                    let mut buf = [0u8; 5];
71                    buf[0] = 1;
72                    buf[1..5].copy_from_slice(&leaf.as_u32().to_le_bytes());
73                    fnv1a(&buf)
74                }
75                PirNode::NegLit { leaf } => {
76                    let mut buf = [0u8; 5];
77                    buf[0] = 2;
78                    buf[1..5].copy_from_slice(&leaf.as_u32().to_le_bytes());
79                    fnv1a(&buf)
80                }
81                PirNode::And { children } => {
82                    let mut child_h: Vec<u64> = children.iter().map(|c| canon[c]).collect();
83                    child_h.sort();
84                    let mut buf = vec![3u8];
85                    for h in child_h {
86                        buf.extend_from_slice(&h.to_le_bytes());
87                    }
88                    fnv1a(&buf)
89                }
90                PirNode::Or { children } => {
91                    let mut child_h: Vec<u64> = children.iter().map(|c| canon[c]).collect();
92                    child_h.sort();
93                    let mut buf = vec![4u8];
94                    for h in child_h {
95                        buf.extend_from_slice(&h.to_le_bytes());
96                    }
97                    fnv1a(&buf)
98                }
99                PirNode::Decision {
100                    var,
101                    child_false,
102                    child_true,
103                } => {
104                    let mut buf = vec![5u8];
105                    buf.extend_from_slice(&var.as_u32().to_le_bytes());
106                    buf.extend_from_slice(&canon[child_false].to_le_bytes());
107                    buf.extend_from_slice(&canon[child_true].to_le_bytes());
108                    fnv1a(&buf)
109                }
110            };
111            canon.insert(id, h);
112        }
113    }
114    canon
115}
116
117/// Compute a process-independent canonical hash of the PIR structure.
118///
119/// This hash depends only on the semantic structure of the PIR graph (node types,
120/// leaf/choice IDs, children's structural hashes), NOT on `PirNodeId` numeric
121/// values. Two processes that build the same XLOG program will produce the same
122/// canonical hash even if `HashMap`-based interning assigns different node IDs.
123///
124/// Used as the `cnf_hash` component of the disk cache key to enable cross-process
125/// cache hits.
126pub fn canonical_pir_hash(pir: &PirGraph, roots: &[PirNodeId]) -> Result<u64> {
127    if roots.is_empty() {
128        return Err(XlogError::Compilation(
129            "Cannot compute canonical hash for empty PIR root set".to_string(),
130        ));
131    }
132    let levels = pir.levelize(roots)?;
133    let canon = canonical_hashes(pir, &levels);
134
135    // Combine root hashes in a deterministic order (sorted by canonical hash).
136    let mut root_hashes: Vec<u64> = roots
137        .iter()
138        .map(|r| canon.get(r).copied().unwrap_or(0))
139        .collect();
140    root_hashes.sort();
141
142    let mut buf = Vec::with_capacity(1 + root_hashes.len() * 8);
143    buf.push(0xFFu8); // tag to distinguish from node hashes
144    for h in root_hashes {
145        buf.extend_from_slice(&h.to_le_bytes());
146    }
147    Ok(fnv1a(&buf))
148}
149
150pub fn encode_cnf(pir: &PirGraph, roots: &[PirNodeId]) -> Result<CnfEncoding> {
151    if roots.is_empty() {
152        return Err(XlogError::Compilation(
153            "Cannot encode CNF for empty PIR root set".to_string(),
154        ));
155    }
156
157    let mut visited: HashSet<PirNodeId> = HashSet::new();
158    let mut leaf_ids: HashSet<LeafId> = HashSet::new();
159    let mut choice_ids: HashSet<ChoiceVarId> = HashSet::new();
160
161    let mut stack: Vec<PirNodeId> = roots.to_vec();
162    while let Some(node_id) = stack.pop() {
163        if !visited.insert(node_id) {
164            continue;
165        }
166
167        let node = pir.node(node_id).ok_or_else(|| {
168            XlogError::Compilation(format!(
169                "Invalid PIR node id while encoding CNF: {:?}",
170                node_id
171            ))
172        })?;
173
174        match node {
175            PirNode::Const(_) => {}
176            PirNode::Lit { leaf } | PirNode::NegLit { leaf } => {
177                leaf_ids.insert(*leaf);
178            }
179            PirNode::And { children } | PirNode::Or { children } => {
180                stack.extend(children.iter().copied());
181            }
182            PirNode::Decision {
183                var,
184                child_false,
185                child_true,
186            } => {
187                choice_ids.insert(*var);
188                stack.push(*child_false);
189                stack.push(*child_true);
190            }
191        }
192    }
193
194    let mut leaf_list: Vec<LeafId> = leaf_ids.into_iter().collect();
195    leaf_list.sort();
196    let mut choice_list: Vec<ChoiceVarId> = choice_ids.into_iter().collect();
197    choice_list.sort();
198
199    let mut next_var: u32 = 1;
200    let mut leaf_var: BTreeMap<LeafId, u32> = BTreeMap::new();
201    for leaf in leaf_list {
202        leaf_var.insert(leaf, next_var);
203        next_var += 1;
204    }
205
206    let mut choice_var: BTreeMap<ChoiceVarId, u32> = BTreeMap::new();
207    for choice in choice_list {
208        choice_var.insert(choice, next_var);
209        next_var += 1;
210    }
211
212    let mut node_ids: Vec<PirNodeId> = visited.into_iter().collect();
213    node_ids.sort();
214
215    let mut node_var: BTreeMap<PirNodeId, u32> = BTreeMap::new();
216    for node_id in node_ids {
217        let node = pir.node(node_id).ok_or_else(|| {
218            XlogError::Compilation(format!(
219                "Invalid PIR node id while encoding CNF: {:?}",
220                node_id
221            ))
222        })?;
223
224        let var_id = match node {
225            PirNode::Lit { leaf } => *leaf_var.get(leaf).ok_or_else(|| {
226                XlogError::Compilation(format!(
227                    "Missing CNF var for PIR leaf {:?} referenced by node {:?}",
228                    leaf, node_id
229                ))
230            })?,
231            PirNode::NegLit { .. } => {
232                // NegLit gets its own variable, which will be constrained to !leaf_var
233                let v = next_var;
234                next_var += 1;
235                v
236            }
237            _ => {
238                let v = next_var;
239                next_var += 1;
240                v
241            }
242        };
243
244        node_var.insert(node_id, var_id);
245    }
246
247    let num_vars = next_var - 1;
248    let mut clauses: Vec<Vec<i32>> = Vec::new();
249
250    let levels = pir.levelize(roots)?;
251    for level in &levels {
252        for &node_id in level {
253            let node = pir.node(node_id).ok_or_else(|| {
254                XlogError::Compilation(format!(
255                    "Invalid PIR node id while emitting CNF clauses: {:?}",
256                    node_id
257                ))
258            })?;
259            let v = *node_var.get(&node_id).ok_or_else(|| {
260                XlogError::Compilation(format!(
261                    "Missing CNF var for PIR node {:?} while emitting clauses",
262                    node_id
263                ))
264            })? as i32;
265
266            match node {
267                PirNode::Const(true) => clauses.push(vec![v]),
268                PirNode::Const(false) => clauses.push(vec![-v]),
269                PirNode::Lit { .. } => {}
270                PirNode::NegLit { leaf } => {
271                    // NegLit uses opposite polarity: node_var <-> !leaf_var
272                    let leaf_v = *leaf_var.get(leaf).ok_or_else(|| {
273                        XlogError::Compilation(format!(
274                            "Missing CNF var for NegLit leaf {:?} at node {:?}",
275                            leaf, node_id
276                        ))
277                    })? as i32;
278                    // v <-> !leaf_v  means:  (v | leaf_v) & (!v | !leaf_v)
279                    clauses.push(vec![v, leaf_v]); // v=false -> leaf_v=true (contrapositive: !leaf_v -> v)
280                    clauses.push(vec![-v, -leaf_v]); // v=true -> leaf_v=false
281                }
282                PirNode::And { children } => {
283                    if children.is_empty() {
284                        clauses.push(vec![v]);
285                        continue;
286                    }
287                    for &child in children {
288                        let c = *node_var.get(&child).ok_or_else(|| {
289                            XlogError::Compilation(format!(
290                                "Missing CNF var for AND child {:?} of {:?}",
291                                child, node_id
292                            ))
293                        })? as i32;
294                        clauses.push(vec![-v, c]);
295                    }
296                    let mut clause = Vec::with_capacity(children.len() + 1);
297                    for &child in children {
298                        let c = *node_var.get(&child).ok_or_else(|| {
299                            XlogError::Compilation(format!(
300                                "Missing CNF var for AND child {:?} of {:?}",
301                                child, node_id
302                            ))
303                        })? as i32;
304                        clause.push(-c);
305                    }
306                    clause.push(v);
307                    clauses.push(clause);
308                }
309                PirNode::Or { children } => {
310                    if children.is_empty() {
311                        clauses.push(vec![-v]);
312                        continue;
313                    }
314                    for &child in children {
315                        let c = *node_var.get(&child).ok_or_else(|| {
316                            XlogError::Compilation(format!(
317                                "Missing CNF var for OR child {:?} of {:?}",
318                                child, node_id
319                            ))
320                        })? as i32;
321                        clauses.push(vec![-c, v]);
322                    }
323                    let mut clause = Vec::with_capacity(children.len() + 1);
324                    clause.push(-v);
325                    for &child in children {
326                        let c = *node_var.get(&child).ok_or_else(|| {
327                            XlogError::Compilation(format!(
328                                "Missing CNF var for OR child {:?} of {:?}",
329                                child, node_id
330                            ))
331                        })? as i32;
332                        clause.push(c);
333                    }
334                    clauses.push(clause);
335                }
336                PirNode::Decision {
337                    var,
338                    child_false,
339                    child_true,
340                } => {
341                    let x = *choice_var.get(var).ok_or_else(|| {
342                        XlogError::Compilation(format!(
343                            "Missing CNF var for decision variable {:?} at node {:?}",
344                            var, node_id
345                        ))
346                    })? as i32;
347
348                    let f = *node_var.get(child_false).ok_or_else(|| {
349                        XlogError::Compilation(format!(
350                            "Missing CNF var for decision false child {:?} at node {:?}",
351                            child_false, node_id
352                        ))
353                    })? as i32;
354                    let t = *node_var.get(child_true).ok_or_else(|| {
355                        XlogError::Compilation(format!(
356                            "Missing CNF var for decision true child {:?} at node {:?}",
357                            child_true, node_id
358                        ))
359                    })? as i32;
360
361                    // v <-> (x ? t : f)
362                    clauses.push(vec![-x, -t, v]); // (x & t) -> v
363                    clauses.push(vec![x, -f, v]); // (!x & f) -> v
364                    clauses.push(vec![-x, t, -v]); // (v & x) -> t
365                    clauses.push(vec![x, f, -v]); // (v & !x) -> f
366                }
367            }
368        }
369    }
370
371    Ok(CnfEncoding {
372        cnf: CnfFormula { num_vars, clauses },
373        node_var,
374        leaf_var,
375        choice_var,
376    })
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pir::{ChoiceVarId, LeafId, PirGraph};

    /// Brute-force satisfiability check of `cnf` conjoined with the unit literals `units`.
    ///
    /// Enumerates all `2^num_vars` assignments, so it only supports tiny formulas.
    fn is_sat_with_unit_clauses(cnf: &CnfFormula, units: &[i32]) -> bool {
        let num_vars = cnf.num_vars() as usize;
        assert!(
            num_vars <= 20,
            "test sat checker only supports small CNFs (vars={})",
            num_vars
        );

        // Whether `lit` is true under the assignment `mask`, where bit i-1 holds the
        // value of variable i.
        let lit_holds = |mask: u64, lit: i32| -> bool {
            let var = lit.unsigned_abs() as usize;
            assert!((1..=num_vars).contains(&var), "literal {} out of range", lit);
            let val = ((mask >> (var - 1)) & 1) == 1;
            if lit > 0 {
                val
            } else {
                !val
            }
        };

        (0..(1u64 << num_vars)).any(|mask| {
            cnf.clauses()
                .iter()
                .all(|clause| clause.iter().any(|&lit| lit_holds(mask, lit)))
                && units.iter().all(|&u| lit_holds(mask, u))
        })
    }

    #[test]
    fn test_encode_cnf_does_not_force_root_assignment() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));

        let encoding = encode_cnf(&pir, &[a]).unwrap();
        let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;

        assert!(is_sat_with_unit_clauses(&encoding.cnf, &[-var_a]));
    }

    #[test]
    fn test_encode_cnf_rejects_empty_roots() {
        let pir = PirGraph::new();
        assert!(encode_cnf(&pir, &[]).is_err());
        assert!(canonical_pir_hash(&pir, &[]).is_err());
    }

    #[test]
    fn test_tseitin_and_requires_both_children() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));
        let b = pir.lit(LeafId::new(1));
        let root = pir.and(vec![a, b]);

        let encoding = encode_cnf(&pir, &[root]).unwrap();
        let var_root = *encoding.node_var.get(&root).unwrap() as i32;
        let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;
        let var_b = *encoding.leaf_var.get(&LeafId::new(1)).unwrap() as i32;

        assert!(is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, var_a, var_b]
        ));
        assert!(!is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, var_a, -var_b]
        ));
        assert!(!is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, -var_a, var_b]
        ));
        assert!(!is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, -var_a, -var_b]
        ));
    }

    #[test]
    fn test_tseitin_or_requires_one_child() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));
        let b = pir.lit(LeafId::new(1));
        let root = pir.or(vec![a, b]);

        let encoding = encode_cnf(&pir, &[root]).unwrap();
        let var_root = *encoding.node_var.get(&root).unwrap() as i32;
        let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;
        let var_b = *encoding.leaf_var.get(&LeafId::new(1)).unwrap() as i32;

        assert!(is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, var_a, var_b]
        ));
        assert!(is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, var_a, -var_b]
        ));
        assert!(is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, -var_a, var_b]
        ));
        assert!(!is_sat_with_unit_clauses(
            &encoding.cnf,
            &[var_root, -var_a, -var_b]
        ));
    }

    #[test]
    fn test_tseitin_decision_mux_matches_choice_var() {
        let mut pir = PirGraph::new();
        let t = pir.const_true();
        let f = pir.const_false();
        let root = pir.decision(ChoiceVarId::new(0), f, t);

        let encoding = encode_cnf(&pir, &[root]).unwrap();
        let var_root = *encoding.node_var.get(&root).unwrap() as i32;
        let x = *encoding.choice_var.get(&ChoiceVarId::new(0)).unwrap() as i32;

        assert!(is_sat_with_unit_clauses(&encoding.cnf, &[var_root, x]));
        assert!(!is_sat_with_unit_clauses(&encoding.cnf, &[var_root, -x]));
    }

    #[test]
    fn test_dimacs_is_well_formed() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));
        let root = pir.or(vec![a]);

        let encoding = encode_cnf(&pir, &[root]).unwrap();
        let dimacs = encoding.cnf.to_dimacs();

        assert!(dimacs.contains("\np cnf "));
        assert!(dimacs.lines().any(|l| l.ends_with('0')));
    }

    #[test]
    fn test_tseitin_neg_lit_uses_negated_polarity() {
        let mut pir = PirGraph::new();
        let a = pir.neg_lit(LeafId::new(0)); // NegLit instead of Lit
        let root = pir.or(vec![a]);

        let encoding = encode_cnf(&pir, &[root]).unwrap();
        let var_root = *encoding.node_var.get(&root).unwrap() as i32;
        let var_a = *encoding.leaf_var.get(&LeafId::new(0)).unwrap() as i32;

        // The CNF encodes  root <-> OR(neg)  and  neg <-> !leaf.
        // So root=true forces neg=true, which forces leaf=false.
        assert!(
            is_sat_with_unit_clauses(&encoding.cnf, &[var_root, -var_a]),
            "root=true with leaf=false should be SAT"
        );
        assert!(
            !is_sat_with_unit_clauses(&encoding.cnf, &[var_root, var_a]),
            "root=true with leaf=true should be UNSAT (NegLit of true leaf is false)"
        );
    }

    #[test]
    fn test_canonical_hash_ignores_node_ids_and_root_order() {
        let mut first = PirGraph::new();
        let a1 = first.lit(LeafId::new(0));
        let b1 = first.lit(LeafId::new(1));
        let and1 = first.and(vec![a1, b1]);

        // Same structure built in a different order, so the node ids differ.
        let mut second = PirGraph::new();
        let b2 = second.lit(LeafId::new(1));
        let a2 = second.lit(LeafId::new(0));
        let and2 = second.and(vec![b2, a2]);

        assert_eq!(
            canonical_pir_hash(&first, &[and1]).unwrap(),
            canonical_pir_hash(&second, &[and2]).unwrap()
        );
        assert_eq!(
            canonical_pir_hash(&first, &[a1, b1]).unwrap(),
            canonical_pir_hash(&first, &[b1, a1]).unwrap()
        );
    }

    #[test]
    fn test_canonical_hash_distinguishes_and_from_or() {
        let mut pir = PirGraph::new();
        let a = pir.lit(LeafId::new(0));
        let b = pir.lit(LeafId::new(1));
        let conj = pir.and(vec![a, b]);
        let disj = pir.or(vec![a, b]);

        assert_ne!(
            canonical_pir_hash(&pir, &[conj]).unwrap(),
            canonical_pir_hash(&pir, &[disj]).unwrap()
        );
    }
}