Skip to main content

xlog_neural/
batch.rs

1//! Batched Neural Evaluation
2//!
3//! This module provides infrastructure for grouping neural predicate calls
4//! by network name, enabling efficient batched GPU evaluation.
5//!
6//! # Why Batching?
7//!
8//! In DeepProbLog-style programs, the same neural network may be called many times:
9//!
10//! ```text
11//! nn(mnist_net, [X], Y, [0..9]) :: digit(X, Y).
12//! addition(X, Y, Z) :- digit(X, LeftDigit), digit(Y, RightDigit), Z is LeftDigit + RightDigit.
13//! ```
14//!
15//! For a query like `addition(img1, img2, Z)`, we need to evaluate `mnist_net`
16//! twice (once for each digit). Instead of two separate forward passes, we batch
17//! them into a single `mnist_net([img1, img2])` call for GPU efficiency.
18//!
19//! # Usage
20//!
21//! ```
22//! use xlog_neural::batch::{BatchCollector, NeuralCall};
23//!
24//! let mut collector = BatchCollector::new();
25//!
26//! // During proof search, collect neural calls
27//! collector.add(NeuralCall::new("mnist", vec![0])); // digit(img[0], Y)
28//! collector.add(NeuralCall::new("mnist", vec![1])); // digit(img[1], Y)
29//!
30//! // Group by network for batched evaluation
31//! let batches = collector.collect();
32//! let mnist_indices = collector.indices_for_network("mnist");
33//! // mnist_indices = [0, 1] - evaluate both in one forward pass
34//! ```
35
36use std::collections::HashMap;
37
/// A single neural predicate call.
///
/// Records which network to call and which input indices to use
/// from the active tensor source.
///
/// Two calls compare equal when both the network name and the index list
/// (including its order) match, so calls can be checked with `assert_eq!`
/// or stored in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NeuralCall {
    /// Name of the neural network (must match registered network)
    pub network: String,
    /// Indices into the active tensor source
    pub input_indices: Vec<usize>,
}

impl NeuralCall {
    /// Create a new neural call.
    ///
    /// `network` is copied into an owned `String`; `input_indices` is stored
    /// exactly as given (no sorting, no deduplication).
    pub fn new(network: &str, input_indices: Vec<usize>) -> Self {
        Self {
            network: network.to_owned(),
            input_indices,
        }
    }

    /// Create a call with a single input index.
    ///
    /// Equivalent to `NeuralCall::new(network, vec![index])`.
    pub fn single(network: &str, index: usize) -> Self {
        Self::new(network, vec![index])
    }

    /// Number of inputs in this call.
    pub fn num_inputs(&self) -> usize {
        self.input_indices.len()
    }
}
69
70/// Collects neural predicate calls for batched evaluation.
71///
72/// During proof search, calls are accumulated. Before neural evaluation,
73/// they are grouped by network name for efficient batched forward passes.
74pub struct BatchCollector {
75    /// All collected calls in order
76    calls: Vec<NeuralCall>,
77}
78
79impl BatchCollector {
80    /// Create a new empty collector.
81    pub fn new() -> Self {
82        Self { calls: Vec::new() }
83    }
84
85    /// Create a collector with pre-allocated capacity.
86    pub fn with_capacity(capacity: usize) -> Self {
87        Self {
88            calls: Vec::with_capacity(capacity),
89        }
90    }
91
92    /// Add a neural call to the collector.
93    pub fn add(&mut self, call: NeuralCall) {
94        self.calls.push(call);
95    }
96
97    /// Group calls by network name for batched evaluation.
98    ///
99    /// Returns a map from network name to list of calls for that network.
100    pub fn collect(&self) -> HashMap<String, Vec<&NeuralCall>> {
101        let mut batches: HashMap<String, Vec<&NeuralCall>> = HashMap::new();
102
103        for call in &self.calls {
104            batches.entry(call.network.clone()).or_default().push(call);
105        }
106
107        batches
108    }
109
110    /// Get all input indices for a specific network.
111    ///
112    /// These indices can be used to gather inputs from the tensor source
113    /// into a batched tensor for the forward pass.
114    pub fn indices_for_network(&self, network: &str) -> Vec<usize> {
115        self.calls
116            .iter()
117            .filter(|c| c.network == network)
118            .flat_map(|c| c.input_indices.iter().copied())
119            .collect()
120    }
121
122    /// Get the names of all networks that have been called.
123    pub fn network_names(&self) -> Vec<String> {
124        let mut names: Vec<String> = self
125            .calls
126            .iter()
127            .map(|c| c.network.clone())
128            .collect::<std::collections::HashSet<_>>()
129            .into_iter()
130            .collect();
131        names.sort();
132        names
133    }
134
135    /// Get the number of calls for a specific network.
136    pub fn call_count_for_network(&self, network: &str) -> usize {
137        self.calls.iter().filter(|c| c.network == network).count()
138    }
139
140    /// Get the total number of input indices across all calls.
141    pub fn total_input_count(&self) -> usize {
142        self.calls.iter().map(|c| c.input_indices.len()).sum()
143    }
144
145    /// Total number of calls.
146    pub fn len(&self) -> usize {
147        self.calls.len()
148    }
149
150    /// Check if the collector is empty.
151    pub fn is_empty(&self) -> bool {
152        self.calls.is_empty()
153    }
154
155    /// Clear all collected calls.
156    pub fn clear(&mut self) {
157        self.calls.clear();
158    }
159
160    /// Iterate over all calls.
161    pub fn iter(&self) -> impl Iterator<Item = &NeuralCall> {
162        self.calls.iter()
163    }
164
165    /// Take ownership of all calls.
166    pub fn take(&mut self) -> Vec<NeuralCall> {
167        std::mem::take(&mut self.calls)
168    }
169}
170
171impl Default for BatchCollector {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
/// Outputs produced by one batched evaluation of a single network.
///
/// Holds one output row for every input that was part of the batch.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Network that produced these results
    pub network: String,
    /// Output probability distributions, one per input
    /// `outputs[i]` is the softmax output for the i-th input in the batch
    pub outputs: Vec<Vec<f64>>,
}

impl BatchResult {
    /// Build a result for `network` from its output rows.
    pub fn new(network: &str, outputs: Vec<Vec<f64>>) -> Self {
        let network = String::from(network);
        BatchResult { network, outputs }
    }

    /// Output row at position `index` of the batch, or `None` when `index`
    /// is past the last row.
    pub fn get_output(&self, index: usize) -> Option<&Vec<f64>> {
        let rows: &[Vec<f64>] = &self.outputs;
        rows.get(index)
    }

    /// How many output rows this batch holds.
    pub fn len(&self) -> usize {
        let rows: &[Vec<f64>] = &self.outputs;
        rows.len()
    }

    /// `true` when the batch holds no output rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Walk the output rows in batch order.
    pub fn iter(&self) -> impl Iterator<Item = &Vec<f64>> {
        self.outputs.as_slice().iter()
    }
}
218
219/// Mapping from call index to batch result index.
220///
221/// When calls are batched, we need to track which result corresponds
222/// to which original call for reconstructing per-call outputs.
223#[derive(Debug, Clone)]
224pub struct BatchMapping {
225    /// For each original call index, the (batch_network, index_in_batch)
226    mappings: Vec<(String, usize)>,
227}
228
229impl BatchMapping {
230    /// Create a new mapping from a collector.
231    pub fn from_collector(collector: &BatchCollector) -> Self {
232        let mut mappings = Vec::with_capacity(collector.len());
233        let mut network_counts: HashMap<String, usize> = HashMap::new();
234
235        for call in collector.iter() {
236            let idx = *network_counts.entry(call.network.clone()).or_insert(0);
237            mappings.push((call.network.clone(), idx));
238            *network_counts.get_mut(&call.network).unwrap() += 1;
239        }
240
241        Self { mappings }
242    }
243
244    /// Look up which batch result to use for a given call index.
245    pub fn get(&self, call_index: usize) -> Option<(&str, usize)> {
246        self.mappings
247            .get(call_index)
248            .map(|(net, idx)| (net.as_str(), *idx))
249    }
250
251    /// Number of mapped calls.
252    pub fn len(&self) -> usize {
253        self.mappings.len()
254    }
255
256    /// Check if the mapping is empty.
257    pub fn is_empty(&self) -> bool {
258        self.mappings.is_empty()
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Positions restart at zero for each network and grow per repeated call.
    #[test]
    fn test_batch_mapping() {
        let mut collector = BatchCollector::new();
        for &(network, index) in &[("net1", 0usize), ("net2", 1), ("net1", 2)] {
            collector.add(NeuralCall::new(network, vec![index]));
        }

        let mapping = BatchMapping::from_collector(&collector);

        let expected: [(&str, usize); 3] = [("net1", 0), ("net2", 0), ("net1", 1)];
        for (call_index, want) in expected.iter().enumerate() {
            assert_eq!(mapping.get(call_index), Some(*want));
        }
    }

    /// `num_inputs` reports the length of the stored index list.
    #[test]
    fn test_neural_call_num_inputs() {
        assert_eq!(NeuralCall::new("test", vec![1, 2, 3, 4]).num_inputs(), 4);
    }

    /// `take` hands back every call and leaves the collector empty.
    #[test]
    fn test_batch_collector_take() {
        let mut collector = BatchCollector::new();
        for index in 0..2 {
            collector.add(NeuralCall::new("net", vec![index]));
        }

        let drained = collector.take();

        assert_eq!(drained.len(), 2);
        assert!(collector.is_empty());
    }
}