Skip to main content

xlog_prob/mc/
evidence.rs

1//! Evidence forcing for Monte Carlo sampling.
2
3use xlog_core::{Result, XlogError};
4
5use super::{McProgram, McSamplingMethod};
6
/// Outcome of checking whether the evidence can be clamped onto root Bernoulli variables.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ForceabilityReason {
    /// Every evidence atom is a probabilistic fact or an AD head observed `true`.
    AllForceable,
    /// Some evidence atom is neither a probabilistic fact nor an AD head.
    ContainsDerivedEvidence,
    /// Some AD head is observed `false`, which is not clamped.
    ContainsNegativeAdHeadEvidence,
    /// The program carries no evidence at all.
    NoEvidence,
}
15
16/// Compiled evidence forcing for the MC sampler.
17#[derive(Debug, Clone)]
18pub struct EvidenceForcing {
19    pub force_mask: Vec<u8>,
20    pub forced_value: Vec<u8>,
21    pub forceable: bool,
22    pub reason: ForceabilityReason,
23}
24
25impl McProgram {
26    /// Resolve the sampling method from config + evidence forceability.
27    pub(super) fn resolve_sampling_method(
28        &self,
29        requested: Option<McSamplingMethod>,
30    ) -> Result<(McSamplingMethod, EvidenceForcing)> {
31        let forcing = self.compile_evidence_forcing()?;
32        let method = match requested {
33            Some(McSamplingMethod::EvidenceClamping) => {
34                if !forcing.forceable {
35                    return Err(XlogError::Execution(format!(
36                        "Cannot use EvidenceClamping: {:?}",
37                        forcing.reason
38                    )));
39                }
40                McSamplingMethod::EvidenceClamping
41            }
42            Some(McSamplingMethod::Rejection) => McSamplingMethod::Rejection,
43            None => {
44                if forcing.forceable {
45                    McSamplingMethod::EvidenceClamping
46                } else {
47                    McSamplingMethod::Rejection
48                }
49            }
50        };
51        Ok((method, forcing))
52    }
53
54    pub fn compile_evidence_forcing(&self) -> Result<EvidenceForcing> {
55        let num_vars = self.bernoulli_probs.len();
56        let mut force_mask = vec![0u8; num_vars];
57        let mut forced_value = vec![0u8; num_vars];
58
59        if self.evidence.is_empty() {
60            return Ok(EvidenceForcing {
61                force_mask,
62                forced_value,
63                forceable: false,
64                reason: ForceabilityReason::NoEvidence,
65            });
66        }
67
68        for (atom, expected) in &self.evidence {
69            // Try to match against prob fact specs
70            if let Some(spec) = self.prob_facts.iter().find(|s| &s.atom == atom) {
71                force_mask[spec.var_idx] = 1;
72                forced_value[spec.var_idx] = if *expected { 1 } else { 0 };
73                continue;
74            }
75
76            // Try to match against AD choice atoms (positive evidence only)
77            let mut found_ad = false;
78            for ad in &self.annotated_disjunctions {
79                if let Some(choice_idx) = ad.choices.iter().position(|c| c == atom) {
80                    if !*expected {
81                        // evidence(ad_head, false) — not forceable in v0.5.1
82                        return Ok(EvidenceForcing {
83                            force_mask: vec![0u8; num_vars],
84                            forced_value: vec![0u8; num_vars],
85                            forceable: false,
86                            reason: ForceabilityReason::ContainsNegativeAdHeadEvidence,
87                        });
88                    }
89
90                    let num_decision_vars = ad.decision_vars.len();
91                    if choice_idx < num_decision_vars {
92                        // Force d_i = 0 for all i < choice_idx, d_{choice_idx} = 1
93                        for i in 0..choice_idx {
94                            force_mask[ad.decision_vars[i]] = 1;
95                            forced_value[ad.decision_vars[i]] = 0;
96                        }
97                        force_mask[ad.decision_vars[choice_idx]] = 1;
98                        forced_value[ad.decision_vars[choice_idx]] = 1;
99                    } else {
100                        // Last head (no none branch): force all decision vars to 0
101                        for &dv in &ad.decision_vars {
102                            force_mask[dv] = 1;
103                            forced_value[dv] = 0;
104                        }
105                    }
106                    found_ad = true;
107                    break;
108                }
109            }
110            if found_ad {
111                continue;
112            }
113
114            // Evidence atom not found in prob facts or AD choices → derived/deterministic
115            return Ok(EvidenceForcing {
116                force_mask: vec![0u8; num_vars],
117                forced_value: vec![0u8; num_vars],
118                forceable: false,
119                reason: ForceabilityReason::ContainsDerivedEvidence,
120            });
121        }
122
123        Ok(EvidenceForcing {
124            force_mask,
125            forced_value,
126            forceable: true,
127            reason: ForceabilityReason::AllForceable,
128        })
129    }
130}