Skip to main content

xlog_neural/
bridge.rs

//! Neural → Probability Bridge
//!
//! This module converts neural network outputs (softmax probability distributions)
//! into the probabilistic logic constructs used by XLOG's inference engine.
//!
//! # Architecture
//!
//! A neural network produces a softmax output `[p1, p2, ..., pn]` over n labels.
//! That output is converted to:
//!
//! 1. **Annotated Disjunctions**: `p1::pred(X,l1); p2::pred(X,l2); ...`
//! 2. **Circuit Leaves**: Weighted leaf nodes for d-DNNF circuit evaluation
//! 3. **Log Probabilities**: For numerical stability in gradient computation
//!
//! # Example
//!
//! ```
//! use xlog_neural::bridge::{NeuralBridge, NeuralOutput};
//!
//! let output = NeuralOutput {
//!     values: vec![0.7, 0.2, 0.1],
//!     labels: vec!["a".to_string(), "b".to_string(), "c".to_string()],
//! };
//!
//! let bridge = NeuralBridge::new();
//! let probs = bridge.to_ad_probabilities(&output);
//! // probs[0] = ADProbability { probability: 0.7, label: "a" }
//! ```
29
/// Neural network output with a probability distribution over labels.
///
/// `values[i]` is the probability assigned to `labels[i]`, so both vectors are
/// expected to have the same length. Because the fields are public, that
/// invariant is only checked (in debug builds) by the constructors; the
/// accessors below tolerate a mismatch instead of panicking.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralOutput {
    /// Softmax probability values (should sum to ~1.0)
    pub values: Vec<f64>,
    /// Corresponding label names (strings or integer representations)
    pub labels: Vec<String>,
}

impl NeuralOutput {
    /// Create output with string labels.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `values` and `labels` differ in length.
    pub fn new(values: Vec<f64>, labels: Vec<String>) -> Self {
        debug_assert_eq!(
            values.len(),
            labels.len(),
            "values and labels must have same length"
        );
        Self { values, labels }
    }

    /// Create output with integer labels converted to strings.
    ///
    /// Each label is rendered with `i64::to_string`, so `-1` becomes `"-1"`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `values` and `labels` differ in length
    /// (the check is shared with [`NeuralOutput::new`]).
    pub fn with_integer_labels(values: Vec<f64>, labels: Vec<i64>) -> Self {
        // Delegate so both constructors enforce the same length invariant.
        Self::new(values, labels.into_iter().map(|i| i.to_string()).collect())
    }

    /// Number of classes/labels.
    ///
    /// This is the length of `values`.
    pub fn num_classes(&self) -> usize {
        self.values.len()
    }

    /// Get the argmax (most likely class) as `(index, label)`.
    ///
    /// Returns `None` when `values` is empty, or when there is no label at the
    /// winning index (i.e. `labels` is shorter than `values`).
    ///
    /// Values that cannot be ordered (NaN) are treated as equal to whatever
    /// they are compared with; when several entries compare as equally
    /// maximal, `Iterator::max_by` yields the last of them.
    pub fn argmax(&self) -> Option<(usize, &str)> {
        self.values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            // Checked lookup: a short `labels` vec yields `None`, not a panic.
            .and_then(|(idx, _)| self.labels.get(idx).map(|label| (idx, label.as_str())))
    }
}
72
/// Annotated disjunction probability component.
///
/// Represents one choice in an annotated disjunction: `probability::pred(X, label)`.
///
/// Derives `PartialEq` so converted outputs can be compared directly; the
/// probability is compared with plain `f64` equality.
#[derive(Debug, Clone, PartialEq)]
pub struct ADProbability {
    /// Probability weight (clamped to [epsilon, 1.0])
    pub probability: f64,
    /// Label value as string
    pub label: String,
}
83
/// Circuit leaf node for d-DNNF evaluation.
///
/// Each leaf corresponds to a probabilistic variable with a weight.
///
/// Derives `PartialEq` so generated leaves can be compared directly; the
/// weight is compared with plain `f64` equality.
#[derive(Debug, Clone, PartialEq)]
pub struct CircuitLeaf {
    /// Variable ID in the circuit
    pub variable_id: usize,
    /// Weight for weighted model counting
    pub weight: f64,
}
94
95/// Bridge for converting neural outputs to probabilistic constructs.
96///
97/// Handles numerical stability through epsilon clamping and normalization.
98pub struct NeuralBridge {
99    /// Minimum probability to prevent log(0)
100    epsilon: f64,
101}
102
103impl NeuralBridge {
104    /// Create a new bridge with default epsilon (1e-8).
105    pub fn new() -> Self {
106        Self { epsilon: 1e-8 }
107    }
108
109    /// Create a bridge with custom epsilon.
110    pub fn with_epsilon(epsilon: f64) -> Self {
111        debug_assert!(epsilon > 0.0, "epsilon must be positive");
112        Self { epsilon }
113    }
114
115    /// Convert softmax output to annotated disjunction probabilities.
116    ///
117    /// Each probability is clamped to [epsilon, 1.0] for numerical stability.
118    pub fn to_ad_probabilities(&self, output: &NeuralOutput) -> Vec<ADProbability> {
119        output
120            .values
121            .iter()
122            .zip(output.labels.iter())
123            .map(|(&prob, label)| ADProbability {
124                probability: prob.max(self.epsilon).min(1.0),
125                label: label.clone(),
126            })
127            .collect()
128    }
129
130    /// Convert batch of neural outputs to circuit leaf weights.
131    ///
132    /// Returns a 2D structure: `leaves[sample_idx][label_idx]`
133    pub fn batch_to_circuit_leaves(&self, outputs: &[NeuralOutput]) -> Vec<Vec<CircuitLeaf>> {
134        outputs
135            .iter()
136            .map(|output| {
137                output
138                    .values
139                    .iter()
140                    .enumerate()
141                    .map(|(i, &weight)| CircuitLeaf {
142                        variable_id: i,
143                        weight: weight.max(self.epsilon).min(1.0),
144                    })
145                    .collect()
146            })
147            .collect()
148    }
149
150    /// Convert to log probabilities for numerical stability.
151    ///
152    /// Log probabilities are used for:
153    /// - Computing NLL loss: `-log(p_true)`
154    /// - Avoiding underflow in product of many small probabilities
155    pub fn to_log_probabilities(&self, output: &NeuralOutput) -> Vec<f64> {
156        output
157            .values
158            .iter()
159            .map(|&p| (p.max(self.epsilon)).ln())
160            .collect()
161    }
162
163    /// Normalize probabilities to sum to 1.0.
164    ///
165    /// Useful when network outputs have small numerical errors.
166    pub fn normalize(&self, output: &NeuralOutput) -> NeuralOutput {
167        let sum: f64 = output.values.iter().sum();
168        if sum.abs() < self.epsilon {
169            // Avoid division by zero - uniform distribution
170            let uniform = 1.0 / output.values.len() as f64;
171            NeuralOutput {
172                values: vec![uniform; output.values.len()],
173                labels: output.labels.clone(),
174            }
175        } else {
176            NeuralOutput {
177                values: output.values.iter().map(|&v| v / sum).collect(),
178                labels: output.labels.clone(),
179            }
180        }
181    }
182
183    /// Extract raw weights for gradient computation.
184    ///
185    /// These weights are passed to the backward pass to compute
186    /// gradients w.r.t. the neural network parameters.
187    pub fn extract_gradient_weights(&self, output: &NeuralOutput) -> Vec<f64> {
188        output.values.clone()
189    }
190
191    /// Compute the probability of a specific label.
192    pub fn probability_of(&self, output: &NeuralOutput, label: &str) -> Option<f64> {
193        output
194            .labels
195            .iter()
196            .position(|l| l == label)
197            .map(|idx| output.values[idx].max(self.epsilon))
198    }
199
200    /// Compute the log probability of a specific label.
201    pub fn log_probability_of(&self, output: &NeuralOutput, label: &str) -> Option<f64> {
202        self.probability_of(output, label).map(|p| p.ln())
203    }
204
205    /// Create circuit leaves for a single sample with variable ID offset.
206    ///
207    /// Used when multiple samples share a circuit structure but have
208    /// different variable ID ranges.
209    pub fn to_circuit_leaves_with_offset(
210        &self,
211        output: &NeuralOutput,
212        variable_offset: usize,
213    ) -> Vec<CircuitLeaf> {
214        output
215            .values
216            .iter()
217            .enumerate()
218            .map(|(i, &weight)| CircuitLeaf {
219                variable_id: variable_offset + i,
220                weight: weight.max(self.epsilon).min(1.0),
221            })
222            .collect()
223    }
224}
225
226impl Default for NeuralBridge {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `NeuralOutput` fixture from borrowed slices.
    fn sample(values: &[f64], labels: &[&str]) -> NeuralOutput {
        NeuralOutput {
            values: values.to_vec(),
            labels: labels.iter().map(|name| name.to_string()).collect(),
        }
    }

    /// The largest value (0.7 at index 1) must win, together with its label.
    #[test]
    fn test_neural_output_argmax() {
        let output = sample(&[0.1, 0.7, 0.2], &["a", "b", "c"]);

        assert_eq!(output.argmax(), Some((1, "b")));
    }

    /// Known labels report their value (within 1e-6); unknown labels give `None`.
    #[test]
    fn test_probability_of_label() {
        let bridge = NeuralBridge::new();
        let output = sample(&[0.3, 0.5, 0.2], &["cat", "dog", "bird"]);

        for (label, expected) in [("cat", 0.3), ("dog", 0.5)] {
            let actual = bridge.probability_of(&output, label).unwrap();
            assert!((actual - expected).abs() < 1e-6);
        }
        assert!(bridge.probability_of(&output, "fish").is_none());
    }

    /// Variable IDs are the per-sample index shifted by the requested offset.
    #[test]
    fn test_circuit_leaves_with_offset() {
        let bridge = NeuralBridge::new();
        let output = sample(&[0.4, 0.6], &["x", "y"]);

        let ids: Vec<usize> = bridge
            .to_circuit_leaves_with_offset(&output, 100)
            .iter()
            .map(|leaf| leaf.variable_id)
            .collect();

        assert_eq!(ids[0], 100);
        assert_eq!(ids[1], 101);
    }
}