1use xlog_core::{Result, XlogError};
4
5use super::{McProgram, McSamplingMethod};
6
7#[derive(Debug, Clone, PartialEq, Eq)]
9pub enum ForceabilityReason {
10 AllForceable,
11 ContainsDerivedEvidence,
12 ContainsNegativeAdHeadEvidence,
13 NoEvidence,
14}
15
16#[derive(Debug, Clone)]
18pub struct EvidenceForcing {
19 pub force_mask: Vec<u8>,
20 pub forced_value: Vec<u8>,
21 pub forceable: bool,
22 pub reason: ForceabilityReason,
23}
24
25impl McProgram {
26 pub(super) fn resolve_sampling_method(
28 &self,
29 requested: Option<McSamplingMethod>,
30 ) -> Result<(McSamplingMethod, EvidenceForcing)> {
31 let forcing = self.compile_evidence_forcing()?;
32 let method = match requested {
33 Some(McSamplingMethod::EvidenceClamping) => {
34 if !forcing.forceable {
35 return Err(XlogError::Execution(format!(
36 "Cannot use EvidenceClamping: {:?}",
37 forcing.reason
38 )));
39 }
40 McSamplingMethod::EvidenceClamping
41 }
42 Some(McSamplingMethod::Rejection) => McSamplingMethod::Rejection,
43 None => {
44 if forcing.forceable {
45 McSamplingMethod::EvidenceClamping
46 } else {
47 McSamplingMethod::Rejection
48 }
49 }
50 };
51 Ok((method, forcing))
52 }
53
54 pub fn compile_evidence_forcing(&self) -> Result<EvidenceForcing> {
55 let num_vars = self.bernoulli_probs.len();
56 let mut force_mask = vec![0u8; num_vars];
57 let mut forced_value = vec![0u8; num_vars];
58
59 if self.evidence.is_empty() {
60 return Ok(EvidenceForcing {
61 force_mask,
62 forced_value,
63 forceable: false,
64 reason: ForceabilityReason::NoEvidence,
65 });
66 }
67
68 for (atom, expected) in &self.evidence {
69 if let Some(spec) = self.prob_facts.iter().find(|s| &s.atom == atom) {
71 force_mask[spec.var_idx] = 1;
72 forced_value[spec.var_idx] = if *expected { 1 } else { 0 };
73 continue;
74 }
75
76 let mut found_ad = false;
78 for ad in &self.annotated_disjunctions {
79 if let Some(choice_idx) = ad.choices.iter().position(|c| c == atom) {
80 if !*expected {
81 return Ok(EvidenceForcing {
83 force_mask: vec![0u8; num_vars],
84 forced_value: vec![0u8; num_vars],
85 forceable: false,
86 reason: ForceabilityReason::ContainsNegativeAdHeadEvidence,
87 });
88 }
89
90 let num_decision_vars = ad.decision_vars.len();
91 if choice_idx < num_decision_vars {
92 for i in 0..choice_idx {
94 force_mask[ad.decision_vars[i]] = 1;
95 forced_value[ad.decision_vars[i]] = 0;
96 }
97 force_mask[ad.decision_vars[choice_idx]] = 1;
98 forced_value[ad.decision_vars[choice_idx]] = 1;
99 } else {
100 for &dv in &ad.decision_vars {
102 force_mask[dv] = 1;
103 forced_value[dv] = 0;
104 }
105 }
106 found_ad = true;
107 break;
108 }
109 }
110 if found_ad {
111 continue;
112 }
113
114 return Ok(EvidenceForcing {
116 force_mask: vec![0u8; num_vars],
117 forced_value: vec![0u8; num_vars],
118 forceable: false,
119 reason: ForceabilityReason::ContainsDerivedEvidence,
120 });
121 }
122
123 Ok(EvidenceForcing {
124 force_mask,
125 forced_value,
126 forceable: true,
127 reason: ForceabilityReason::AllForceable,
128 })
129 }
130}