1use std::collections::HashMap;
37
/// A request to run a named network on a set of inputs.
///
/// Only the network name and the input indices are stored; what the indices
/// refer to is decided by the code that creates and consumes the call.
#[derive(Debug, Clone)]
pub struct NeuralCall {
    /// Name of the network this call targets.
    pub network: String,
    /// Input indices attached to this call, in the order they were supplied.
    pub input_indices: Vec<usize>,
}

impl NeuralCall {
    /// Builds a call to `network` carrying the given `input_indices`.
    ///
    /// The name is copied into an owned `String`; the index vector is moved
    /// into the call unchanged.
    pub fn new(network: &str, input_indices: Vec<usize>) -> Self {
        let network = String::from(network);
        NeuralCall {
            network,
            input_indices,
        }
    }

    /// Builds a call to `network` carrying exactly one input index, `index`.
    pub fn single(network: &str, index: usize) -> Self {
        NeuralCall {
            network: String::from(network),
            input_indices: vec![index],
        }
    }

    /// Returns the number of input indices carried by this call.
    pub fn num_inputs(&self) -> usize {
        let indices: &[usize] = &self.input_indices;
        indices.len()
    }
}
69
70pub struct BatchCollector {
75 calls: Vec<NeuralCall>,
77}
78
79impl BatchCollector {
80 pub fn new() -> Self {
82 Self { calls: Vec::new() }
83 }
84
85 pub fn with_capacity(capacity: usize) -> Self {
87 Self {
88 calls: Vec::with_capacity(capacity),
89 }
90 }
91
92 pub fn add(&mut self, call: NeuralCall) {
94 self.calls.push(call);
95 }
96
97 pub fn collect(&self) -> HashMap<String, Vec<&NeuralCall>> {
101 let mut batches: HashMap<String, Vec<&NeuralCall>> = HashMap::new();
102
103 for call in &self.calls {
104 batches.entry(call.network.clone()).or_default().push(call);
105 }
106
107 batches
108 }
109
110 pub fn indices_for_network(&self, network: &str) -> Vec<usize> {
115 self.calls
116 .iter()
117 .filter(|c| c.network == network)
118 .flat_map(|c| c.input_indices.iter().copied())
119 .collect()
120 }
121
122 pub fn network_names(&self) -> Vec<String> {
124 let mut names: Vec<String> = self
125 .calls
126 .iter()
127 .map(|c| c.network.clone())
128 .collect::<std::collections::HashSet<_>>()
129 .into_iter()
130 .collect();
131 names.sort();
132 names
133 }
134
135 pub fn call_count_for_network(&self, network: &str) -> usize {
137 self.calls.iter().filter(|c| c.network == network).count()
138 }
139
140 pub fn total_input_count(&self) -> usize {
142 self.calls.iter().map(|c| c.input_indices.len()).sum()
143 }
144
145 pub fn len(&self) -> usize {
147 self.calls.len()
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.calls.is_empty()
153 }
154
155 pub fn clear(&mut self) {
157 self.calls.clear();
158 }
159
160 pub fn iter(&self) -> impl Iterator<Item = &NeuralCall> {
162 self.calls.iter()
163 }
164
165 pub fn take(&mut self) -> Vec<NeuralCall> {
167 std::mem::take(&mut self.calls)
168 }
169}
170
171impl Default for BatchCollector {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
/// Outputs associated with one named network, stored as one `Vec<f64>` per
/// entry.
#[derive(Debug, Clone)]
pub struct BatchResult {
    /// Name of the network these outputs belong to.
    pub network: String,
    /// Output vectors, addressed by position.
    pub outputs: Vec<Vec<f64>>,
}

impl BatchResult {
    /// Builds a result for `network` that owns the given `outputs`.
    ///
    /// The name is copied into an owned `String`; `outputs` is moved in
    /// unchanged.
    pub fn new(network: &str, outputs: Vec<Vec<f64>>) -> Self {
        let network = network.to_owned();
        BatchResult { network, outputs }
    }

    /// Returns the output vector at position `index`, or `None` when `index`
    /// is out of range.
    pub fn get_output(&self, index: usize) -> Option<&Vec<f64>> {
        let rows: &[Vec<f64>] = &self.outputs;
        rows.get(index)
    }

    /// Returns the number of output vectors.
    pub fn len(&self) -> usize {
        let rows: &[Vec<f64>] = &self.outputs;
        rows.len()
    }

    /// Returns `true` when there are no output vectors.
    pub fn is_empty(&self) -> bool {
        let rows: &[Vec<f64>] = &self.outputs;
        rows.is_empty()
    }

    /// Iterates over the output vectors in stored order.
    pub fn iter(&self) -> impl Iterator<Item = &Vec<f64>> {
        let rows: &[Vec<f64>] = &self.outputs;
        rows.iter()
    }
}
218
219#[derive(Debug, Clone)]
224pub struct BatchMapping {
225 mappings: Vec<(String, usize)>,
227}
228
229impl BatchMapping {
230 pub fn from_collector(collector: &BatchCollector) -> Self {
232 let mut mappings = Vec::with_capacity(collector.len());
233 let mut network_counts: HashMap<String, usize> = HashMap::new();
234
235 for call in collector.iter() {
236 let idx = *network_counts.entry(call.network.clone()).or_insert(0);
237 mappings.push((call.network.clone(), idx));
238 *network_counts.get_mut(&call.network).unwrap() += 1;
239 }
240
241 Self { mappings }
242 }
243
244 pub fn get(&self, call_index: usize) -> Option<(&str, usize)> {
246 self.mappings
247 .get(call_index)
248 .map(|(net, idx)| (net.as_str(), *idx))
249 }
250
251 pub fn len(&self) -> usize {
253 self.mappings.len()
254 }
255
256 pub fn is_empty(&self) -> bool {
258 self.mappings.is_empty()
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Each call maps to its network plus a position counted separately per
    /// network, in insertion order.
    #[test]
    fn test_batch_mapping() {
        let mut collector = BatchCollector::new();
        for &(network, input) in &[("net1", 0usize), ("net2", 1), ("net1", 2)] {
            collector.add(NeuralCall::new(network, vec![input]));
        }

        let mapping = BatchMapping::from_collector(&collector);

        let expected = [("net1", 0usize), ("net2", 0), ("net1", 1)];
        for (call_index, want) in expected.iter().enumerate() {
            assert_eq!(mapping.get(call_index), Some(*want));
        }
    }

    /// `num_inputs` reports the length of the index list given to `new`.
    #[test]
    fn test_neural_call_num_inputs() {
        let indices = vec![1, 2, 3, 4];
        let call = NeuralCall::new("test", indices);
        assert_eq!(call.num_inputs(), 4);
    }

    /// `take` hands back every recorded call and leaves the collector empty.
    #[test]
    fn test_batch_collector_take() {
        let mut collector = BatchCollector::new();
        for input in 0..2 {
            collector.add(NeuralCall::new("net", vec![input]));
        }

        let taken = collector.take();
        assert_eq!(taken.len(), 2);
        assert!(collector.is_empty());
    }
}