xlog_neural/
tensor_source.rs1use std::collections::HashMap;
40use thiserror::Error;
41
42#[cfg(feature = "python")]
43use pyo3::PyObject;
44
45#[derive(Error, Debug)]
47#[non_exhaustive]
48pub enum TensorSourceError {
49 #[error("Tensor source '{0}' not found")]
51 NotFound(String),
52
53 #[error("No active tensor source set")]
55 NoActive,
56
57 #[error("Index {0} out of bounds for source with {1} entries")]
59 IndexOutOfBounds(usize, usize),
60}
61
/// Describes a tensor source: how many entries it holds and what one entry
/// looks like.
///
/// `PartialEq`/`Eq` are derived so metadata values can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorMetadata {
    /// Number of entries in the source; becomes the leading dimension of
    /// [`TensorMetadata::full_shape`].
    pub size: usize,
    /// Dimensions of a single entry (without the leading `size` dimension).
    pub shape: Vec<usize>,
    /// Name of the element type, e.g. `"float32"`.
    pub dtype: String,
}

impl TensorMetadata {
    /// Dtype name assigned by [`TensorMetadata::new`].
    const DEFAULT_DTYPE: &'static str = "float32";

    /// Creates metadata with the default dtype `"float32"`.
    pub fn new(size: usize, shape: Vec<usize>) -> Self {
        // Single construction path: `new` is `with_dtype` with the default name.
        Self::with_dtype(size, shape, Self::DEFAULT_DTYPE)
    }

    /// Creates metadata with an explicit dtype name.
    ///
    /// The dtype string is stored as given; it is not validated.
    pub fn with_dtype(size: usize, shape: Vec<usize>, dtype: &str) -> Self {
        Self {
            size,
            shape,
            dtype: dtype.to_owned(),
        }
    }

    /// Number of elements in one entry: the product of `shape`.
    ///
    /// An empty `shape` yields `1`.
    pub fn sample_numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Shape of the whole source: `size` followed by the per-entry `shape`.
    pub fn full_shape(&self) -> Vec<usize> {
        // The final length is known up front, so allocate exactly once.
        let mut full = Vec::with_capacity(self.shape.len() + 1);
        full.push(self.size);
        full.extend_from_slice(&self.shape);
        full
    }
}
106
107#[cfg(feature = "python")]
109struct TensorSource {
110 tensor: PyObject,
112}
113
114pub struct TensorSourceRegistry {
119 #[cfg(feature = "python")]
121 sources: HashMap<String, TensorSource>,
122 #[cfg(not(feature = "python"))]
123 sources: HashMap<String, TensorMetadata>,
124
125 active: Option<String>,
127
128 #[cfg(feature = "python")]
130 metadata: HashMap<String, TensorMetadata>,
131}
132
133impl TensorSourceRegistry {
134 pub fn new() -> Self {
136 Self {
137 sources: HashMap::new(),
138 active: None,
139 #[cfg(feature = "python")]
140 metadata: HashMap::new(),
141 }
142 }
143
144 pub fn add_with_metadata(&mut self, name: &str, metadata: TensorMetadata) {
146 #[cfg(feature = "python")]
147 {
148 self.metadata.insert(name.to_string(), metadata);
149 }
150 #[cfg(not(feature = "python"))]
151 {
152 self.sources.insert(name.to_string(), metadata);
153 }
154
155 if self.active.is_none() {
157 self.active = Some(name.to_string());
158 }
159 }
160
161 #[cfg(feature = "python")]
163 pub fn add(&mut self, name: &str, tensor: PyObject, metadata: TensorMetadata) {
164 let source = TensorSource { tensor };
165 self.sources.insert(name.to_string(), source);
166 self.metadata.insert(name.to_string(), metadata);
167
168 if self.active.is_none() {
170 self.active = Some(name.to_string());
171 }
172 }
173
174 pub fn set_active(&mut self, name: &str) -> Result<(), TensorSourceError> {
176 #[cfg(feature = "python")]
177 let exists = self.metadata.contains_key(name);
178 #[cfg(not(feature = "python"))]
179 let exists = self.sources.contains_key(name);
180
181 if exists {
182 self.active = Some(name.to_string());
183 Ok(())
184 } else {
185 Err(TensorSourceError::NotFound(name.to_string()))
186 }
187 }
188
189 pub fn active_name(&self) -> Option<&str> {
191 self.active.as_deref()
192 }
193
194 pub fn active_size(&self) -> Result<usize, TensorSourceError> {
196 match &self.active {
197 Some(name) => {
198 #[cfg(feature = "python")]
199 let meta = self.metadata.get(name);
200 #[cfg(not(feature = "python"))]
201 let meta = self.sources.get(name);
202
203 meta.map(|m| m.size)
204 .ok_or_else(|| TensorSourceError::NotFound(name.clone()))
205 }
206 None => Err(TensorSourceError::NoActive),
207 }
208 }
209
210 #[cfg(feature = "python")]
212 pub fn get_active(&self) -> Result<&PyObject, TensorSourceError> {
213 match &self.active {
214 Some(name) => self
215 .sources
216 .get(name)
217 .map(|s| &s.tensor)
218 .ok_or_else(|| TensorSourceError::NotFound(name.clone())),
219 None => Err(TensorSourceError::NoActive),
220 }
221 }
222
223 #[cfg(feature = "python")]
228 pub fn get_named(&self, name: &str) -> Result<&PyObject, TensorSourceError> {
229 self.sources
230 .get(name)
231 .map(|s| &s.tensor)
232 .ok_or_else(|| TensorSourceError::NotFound(name.to_string()))
233 }
234
235 pub fn get_metadata(&self, name: &str) -> Option<&TensorMetadata> {
237 #[cfg(feature = "python")]
238 {
239 self.metadata.get(name)
240 }
241 #[cfg(not(feature = "python"))]
242 {
243 self.sources.get(name)
244 }
245 }
246
247 pub fn contains(&self, name: &str) -> bool {
249 #[cfg(feature = "python")]
250 {
251 self.metadata.contains_key(name)
252 }
253 #[cfg(not(feature = "python"))]
254 {
255 self.sources.contains_key(name)
256 }
257 }
258
259 pub fn check_index(&self, index: usize) -> Result<(), TensorSourceError> {
261 let size = self.active_size()?;
262 if index < size {
263 Ok(())
264 } else {
265 Err(TensorSourceError::IndexOutOfBounds(index, size))
266 }
267 }
268
269 pub fn validate_indices(&self, indices: &[usize]) -> Result<(), TensorSourceError> {
271 let size = self.active_size()?;
272 for &idx in indices {
273 if idx >= size {
274 return Err(TensorSourceError::IndexOutOfBounds(idx, size));
275 }
276 }
277 Ok(())
278 }
279
280 pub fn source_names(&self) -> Vec<String> {
282 #[cfg(feature = "python")]
283 {
284 self.metadata.keys().cloned().collect()
285 }
286 #[cfg(not(feature = "python"))]
287 {
288 self.sources.keys().cloned().collect()
289 }
290 }
291
292 pub fn len(&self) -> usize {
294 #[cfg(feature = "python")]
295 {
296 self.metadata.len()
297 }
298 #[cfg(not(feature = "python"))]
299 {
300 self.sources.len()
301 }
302 }
303
304 pub fn is_empty(&self) -> bool {
306 self.len() == 0
307 }
308
309 pub fn remove(&mut self, name: &str) {
311 #[cfg(feature = "python")]
312 {
313 self.sources.remove(name);
314 self.metadata.remove(name);
315 }
316 #[cfg(not(feature = "python"))]
317 {
318 self.sources.remove(name);
319 }
320
321 if self.active.as_deref() == Some(name) {
323 self.active = None;
324 }
325 }
326
327 pub fn clear(&mut self) {
329 self.sources.clear();
330 #[cfg(feature = "python")]
331 self.metadata.clear();
332 self.active = None;
333 }
334
335 pub fn iter(&self) -> impl Iterator<Item = (&str, &TensorMetadata)> {
337 #[cfg(feature = "python")]
338 {
339 self.metadata.iter().map(|(k, v)| (k.as_str(), v))
340 }
341 #[cfg(not(feature = "python"))]
342 {
343 self.sources.iter().map(|(k, v)| (k.as_str(), v))
344 }
345 }
346}
347
348impl Default for TensorSourceRegistry {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Elements per entry are the product of the per-entry dimensions.
    #[test]
    fn test_metadata_sample_numel() {
        let per_entry_dims = vec![3, 224, 224];
        let meta = TensorMetadata::new(100, per_entry_dims);
        let expected = 3 * 224 * 224;
        assert_eq!(meta.sample_numel(), expected);
    }

    /// The full shape is the entry count followed by the per-entry dimensions.
    #[test]
    fn test_metadata_full_shape() {
        let meta = TensorMetadata::new(1000, vec![1, 28, 28]);
        let expected = vec![1000, 1, 28, 28];
        assert_eq!(meta.full_shape(), expected);
    }

    /// A default-constructed registry holds no sources.
    #[test]
    fn test_registry_default() {
        assert!(TensorSourceRegistry::default().is_empty());
    }
}