/// Raw per-class output of a neural classifier.
///
/// `values[i]` is the score associated with `labels[i]`; the two vectors are
/// expected to have the same length (checked in debug builds by
/// [`NeuralOutput::new`]).
#[derive(Debug, Clone, PartialEq)]
pub struct NeuralOutput {
    /// Per-class scores; `NeuralBridge` treats them as probabilities.
    pub values: Vec<f64>,
    /// Class labels, parallel to `values`.
    pub labels: Vec<String>,
}
38
39impl NeuralOutput {
40 pub fn new(values: Vec<f64>, labels: Vec<String>) -> Self {
42 debug_assert_eq!(
43 values.len(),
44 labels.len(),
45 "values and labels must have same length"
46 );
47 Self { values, labels }
48 }
49
50 pub fn with_integer_labels(values: Vec<f64>, labels: Vec<i64>) -> Self {
52 Self {
53 values,
54 labels: labels.into_iter().map(|i| i.to_string()).collect(),
55 }
56 }
57
58 pub fn num_classes(&self) -> usize {
60 self.values.len()
61 }
62
63 pub fn argmax(&self) -> Option<(usize, &str)> {
65 self.values
66 .iter()
67 .enumerate()
68 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
69 .map(|(idx, _)| (idx, self.labels[idx].as_str()))
70 }
71}
72
/// A single class probability paired with its label, as produced by
/// `NeuralBridge::to_ad_probabilities`.
#[derive(Debug, Clone, PartialEq)]
pub struct ADProbability {
    /// Class probability; clamped into `[epsilon, 1.0]` when built by `NeuralBridge`.
    pub probability: f64,
    /// Label of the class this probability belongs to.
    pub label: String,
}
83
/// A weighted leaf variable, as produced by `NeuralBridge`'s circuit-leaf
/// conversions.
#[derive(Debug, Clone, PartialEq)]
pub struct CircuitLeaf {
    /// Identifier of the leaf variable: the class index plus any offset.
    pub variable_id: usize,
    /// Leaf weight; clamped into `[epsilon, 1.0]` when built by `NeuralBridge`.
    pub weight: f64,
}
94
/// Converts `NeuralOutput` scores into `ADProbability` values,
/// log-probabilities, and `CircuitLeaf` weights.
///
/// All conversions floor probabilities at `epsilon` so that later logarithms
/// never see zero.
#[derive(Debug, Clone)]
pub struct NeuralBridge {
    /// Smallest probability the bridge emits (`1e-8` via `NeuralBridge::new`).
    /// It is also the near-zero-sum threshold used by `normalize`.
    epsilon: f64,
}
102
103impl NeuralBridge {
104 pub fn new() -> Self {
106 Self { epsilon: 1e-8 }
107 }
108
109 pub fn with_epsilon(epsilon: f64) -> Self {
111 debug_assert!(epsilon > 0.0, "epsilon must be positive");
112 Self { epsilon }
113 }
114
115 pub fn to_ad_probabilities(&self, output: &NeuralOutput) -> Vec<ADProbability> {
119 output
120 .values
121 .iter()
122 .zip(output.labels.iter())
123 .map(|(&prob, label)| ADProbability {
124 probability: prob.max(self.epsilon).min(1.0),
125 label: label.clone(),
126 })
127 .collect()
128 }
129
130 pub fn batch_to_circuit_leaves(&self, outputs: &[NeuralOutput]) -> Vec<Vec<CircuitLeaf>> {
134 outputs
135 .iter()
136 .map(|output| {
137 output
138 .values
139 .iter()
140 .enumerate()
141 .map(|(i, &weight)| CircuitLeaf {
142 variable_id: i,
143 weight: weight.max(self.epsilon).min(1.0),
144 })
145 .collect()
146 })
147 .collect()
148 }
149
150 pub fn to_log_probabilities(&self, output: &NeuralOutput) -> Vec<f64> {
156 output
157 .values
158 .iter()
159 .map(|&p| (p.max(self.epsilon)).ln())
160 .collect()
161 }
162
163 pub fn normalize(&self, output: &NeuralOutput) -> NeuralOutput {
167 let sum: f64 = output.values.iter().sum();
168 if sum.abs() < self.epsilon {
169 let uniform = 1.0 / output.values.len() as f64;
171 NeuralOutput {
172 values: vec![uniform; output.values.len()],
173 labels: output.labels.clone(),
174 }
175 } else {
176 NeuralOutput {
177 values: output.values.iter().map(|&v| v / sum).collect(),
178 labels: output.labels.clone(),
179 }
180 }
181 }
182
183 pub fn extract_gradient_weights(&self, output: &NeuralOutput) -> Vec<f64> {
188 output.values.clone()
189 }
190
191 pub fn probability_of(&self, output: &NeuralOutput, label: &str) -> Option<f64> {
193 output
194 .labels
195 .iter()
196 .position(|l| l == label)
197 .map(|idx| output.values[idx].max(self.epsilon))
198 }
199
200 pub fn log_probability_of(&self, output: &NeuralOutput, label: &str) -> Option<f64> {
202 self.probability_of(output, label).map(|p| p.ln())
203 }
204
205 pub fn to_circuit_leaves_with_offset(
210 &self,
211 output: &NeuralOutput,
212 variable_offset: usize,
213 ) -> Vec<CircuitLeaf> {
214 output
215 .values
216 .iter()
217 .enumerate()
218 .map(|(i, &weight)| CircuitLeaf {
219 variable_id: variable_offset + i,
220 weight: weight.max(self.epsilon).min(1.0),
221 })
222 .collect()
223 }
224}
225
226impl Default for NeuralBridge {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns a slice of string literals into the owned labels `NeuralOutput` stores.
    fn owned_labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_neural_output_argmax() {
        let output = NeuralOutput {
            values: vec![0.1, 0.7, 0.2],
            labels: owned_labels(&["a", "b", "c"]),
        };

        let best = output.argmax().expect("non-empty output must have an argmax");
        assert_eq!(best, (1, "b"));
    }

    #[test]
    fn test_probability_of_label() {
        let bridge = NeuralBridge::new();
        let output = NeuralOutput {
            values: vec![0.3, 0.5, 0.2],
            labels: owned_labels(&["cat", "dog", "bird"]),
        };

        for &(label, expected) in &[("cat", 0.3), ("dog", 0.5)] {
            let actual = bridge.probability_of(&output, label).unwrap();
            assert!((actual - expected).abs() < 1e-6);
        }
        assert_eq!(bridge.probability_of(&output, "fish"), None);
    }

    #[test]
    fn test_circuit_leaves_with_offset() {
        let bridge = NeuralBridge::new();
        let output = NeuralOutput {
            values: vec![0.4, 0.6],
            labels: owned_labels(&["x", "y"]),
        };

        let ids: Vec<usize> = bridge
            .to_circuit_leaves_with_offset(&output, 100)
            .iter()
            .map(|leaf| leaf.variable_id)
            .collect();
        assert_eq!(ids, vec![100, 101]);
    }
}
275}