Skip to main content

xlog_neural/
tensor_source.rs

1//! Tensor Source Registry
2//!
3//! This module manages external tensor data (images, embeddings, etc.) that
4//! can be indexed by neural predicates during proof search.
5//!
6//! # Architecture
7//!
8//! In DeepProbLog-style programs, neural predicates reference external data:
9//!
10//! ```text
11//! nn(mnist_net, [X], Y, [0..9]) :: digit(X, Y).
12//! ?- digit(42, Y).  // X=42 indexes into tensor source
13//! ```
14//!
15//! The tensor source registry:
16//! - Stores named tensor sources (train, test, etc.)
17//! - Tracks the "active" source for current evaluation
18//! - Validates indices are within bounds
19//! - Holds PyTorch tensors via PyO3 (when `python` feature enabled)
20//!
21//! # Usage
22//!
23//! ```
24//! use xlog_neural::tensor_source::{TensorSourceRegistry, TensorMetadata};
25//!
26//! let mut registry = TensorSourceRegistry::new();
27//!
28//! // Add sources with metadata
29//! registry.add_with_metadata("train", TensorMetadata::new(60000, vec![1, 28, 28]));
30//! registry.add_with_metadata("test", TensorMetadata::new(10000, vec![1, 28, 28]));
31//!
32//! // Set active source
33//! registry.set_active("train").unwrap();
34//!
35//! // Validate index before neural call
36//! registry.check_index(42).unwrap();
37//! ```
38
39use std::collections::HashMap;
40use thiserror::Error;
41
42#[cfg(feature = "python")]
43use pyo3::PyObject;
44
45/// Errors from tensor source operations.
46#[derive(Error, Debug)]
47#[non_exhaustive]
48pub enum TensorSourceError {
49    /// Tensor source not found in registry
50    #[error("Tensor source '{0}' not found")]
51    NotFound(String),
52
53    /// No active tensor source set
54    #[error("No active tensor source set")]
55    NoActive,
56
57    /// Index out of bounds for the active source
58    #[error("Index {0} out of bounds for source with {1} entries")]
59    IndexOutOfBounds(usize, usize),
60}
61
/// Metadata about a tensor source (without the actual tensor data).
///
/// Used for validation and introspection without requiring the Python GIL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorMetadata {
    /// Number of samples in the tensor (first dimension).
    pub size: usize,
    /// Shape of each sample (excluding the batch dimension).
    pub shape: Vec<usize>,
    /// Data type as a string (e.g., "float32", "float64").
    pub dtype: String,
}

impl TensorMetadata {
    /// Create metadata with the default dtype (`"float32"`).
    pub fn new(size: usize, shape: Vec<usize>) -> Self {
        Self::with_dtype(size, shape, "float32")
    }

    /// Create metadata with an explicit dtype.
    pub fn with_dtype(size: usize, shape: Vec<usize>, dtype: &str) -> Self {
        Self {
            size,
            shape,
            dtype: dtype.to_string(),
        }
    }

    /// Total number of elements per sample (the product of `shape`).
    ///
    /// An empty `shape` (scalar samples) yields 1, the empty product.
    pub fn sample_numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Full shape including the batch dimension: `[size, shape...]`.
    pub fn full_shape(&self) -> Vec<usize> {
        // The final length is known up front, so allocate exactly once.
        let mut full = Vec::with_capacity(self.shape.len() + 1);
        full.push(self.size);
        full.extend_from_slice(&self.shape);
        full
    }
}
106
107/// Internal storage for a tensor source (used only with python feature).
108#[cfg(feature = "python")]
109struct TensorSource {
110    /// The actual PyTorch tensor
111    tensor: PyObject,
112}
113
114/// Registry for managing tensor sources.
115///
116/// Tensor sources are named collections of data (e.g., "train", "test")
117/// that neural predicates can index into.
118pub struct TensorSourceRegistry {
119    /// Map from source name to source data
120    #[cfg(feature = "python")]
121    sources: HashMap<String, TensorSource>,
122    #[cfg(not(feature = "python"))]
123    sources: HashMap<String, TensorMetadata>,
124
125    /// Currently active source name
126    active: Option<String>,
127
128    /// Metadata stored separately for non-python access
129    #[cfg(feature = "python")]
130    metadata: HashMap<String, TensorMetadata>,
131}
132
133impl TensorSourceRegistry {
134    /// Create a new empty registry.
135    pub fn new() -> Self {
136        Self {
137            sources: HashMap::new(),
138            active: None,
139            #[cfg(feature = "python")]
140            metadata: HashMap::new(),
141        }
142    }
143
144    /// Add a tensor source with metadata only (for testing without Python).
145    pub fn add_with_metadata(&mut self, name: &str, metadata: TensorMetadata) {
146        #[cfg(feature = "python")]
147        {
148            self.metadata.insert(name.to_string(), metadata);
149        }
150        #[cfg(not(feature = "python"))]
151        {
152            self.sources.insert(name.to_string(), metadata);
153        }
154
155        // Auto-set first source as active
156        if self.active.is_none() {
157            self.active = Some(name.to_string());
158        }
159    }
160
161    /// Add a tensor source with PyTorch tensor.
162    #[cfg(feature = "python")]
163    pub fn add(&mut self, name: &str, tensor: PyObject, metadata: TensorMetadata) {
164        let source = TensorSource { tensor };
165        self.sources.insert(name.to_string(), source);
166        self.metadata.insert(name.to_string(), metadata);
167
168        // Auto-set first source as active
169        if self.active.is_none() {
170            self.active = Some(name.to_string());
171        }
172    }
173
174    /// Set the active tensor source.
175    pub fn set_active(&mut self, name: &str) -> Result<(), TensorSourceError> {
176        #[cfg(feature = "python")]
177        let exists = self.metadata.contains_key(name);
178        #[cfg(not(feature = "python"))]
179        let exists = self.sources.contains_key(name);
180
181        if exists {
182            self.active = Some(name.to_string());
183            Ok(())
184        } else {
185            Err(TensorSourceError::NotFound(name.to_string()))
186        }
187    }
188
189    /// Get the name of the active source.
190    pub fn active_name(&self) -> Option<&str> {
191        self.active.as_deref()
192    }
193
194    /// Get the size of the active source.
195    pub fn active_size(&self) -> Result<usize, TensorSourceError> {
196        match &self.active {
197            Some(name) => {
198                #[cfg(feature = "python")]
199                let meta = self.metadata.get(name);
200                #[cfg(not(feature = "python"))]
201                let meta = self.sources.get(name);
202
203                meta.map(|m| m.size)
204                    .ok_or_else(|| TensorSourceError::NotFound(name.clone()))
205            }
206            None => Err(TensorSourceError::NoActive),
207        }
208    }
209
210    /// Get the PyTorch tensor for the active source.
211    #[cfg(feature = "python")]
212    pub fn get_active(&self) -> Result<&PyObject, TensorSourceError> {
213        match &self.active {
214            Some(name) => self
215                .sources
216                .get(name)
217                .map(|s| &s.tensor)
218                .ok_or_else(|| TensorSourceError::NotFound(name.clone())),
219            None => Err(TensorSourceError::NoActive),
220        }
221    }
222
223    /// Get the PyTorch tensor for a specific named source (regardless of which
224    /// source is active). Used by the Stage-B existential-join forward, which
225    /// reads the per-event feature batch from a fixed `nsr_domain` source while
226    /// the per-query examples source stays active.
227    #[cfg(feature = "python")]
228    pub fn get_named(&self, name: &str) -> Result<&PyObject, TensorSourceError> {
229        self.sources
230            .get(name)
231            .map(|s| &s.tensor)
232            .ok_or_else(|| TensorSourceError::NotFound(name.to_string()))
233    }
234
235    /// Get metadata for a specific source.
236    pub fn get_metadata(&self, name: &str) -> Option<&TensorMetadata> {
237        #[cfg(feature = "python")]
238        {
239            self.metadata.get(name)
240        }
241        #[cfg(not(feature = "python"))]
242        {
243            self.sources.get(name)
244        }
245    }
246
247    /// Check if a source exists.
248    pub fn contains(&self, name: &str) -> bool {
249        #[cfg(feature = "python")]
250        {
251            self.metadata.contains_key(name)
252        }
253        #[cfg(not(feature = "python"))]
254        {
255            self.sources.contains_key(name)
256        }
257    }
258
259    /// Check if an index is valid for the active source.
260    pub fn check_index(&self, index: usize) -> Result<(), TensorSourceError> {
261        let size = self.active_size()?;
262        if index < size {
263            Ok(())
264        } else {
265            Err(TensorSourceError::IndexOutOfBounds(index, size))
266        }
267    }
268
269    /// Validate multiple indices at once.
270    pub fn validate_indices(&self, indices: &[usize]) -> Result<(), TensorSourceError> {
271        let size = self.active_size()?;
272        for &idx in indices {
273            if idx >= size {
274                return Err(TensorSourceError::IndexOutOfBounds(idx, size));
275            }
276        }
277        Ok(())
278    }
279
280    /// Get names of all sources.
281    pub fn source_names(&self) -> Vec<String> {
282        #[cfg(feature = "python")]
283        {
284            self.metadata.keys().cloned().collect()
285        }
286        #[cfg(not(feature = "python"))]
287        {
288            self.sources.keys().cloned().collect()
289        }
290    }
291
292    /// Number of sources.
293    pub fn len(&self) -> usize {
294        #[cfg(feature = "python")]
295        {
296            self.metadata.len()
297        }
298        #[cfg(not(feature = "python"))]
299        {
300            self.sources.len()
301        }
302    }
303
304    /// Check if registry is empty.
305    pub fn is_empty(&self) -> bool {
306        self.len() == 0
307    }
308
309    /// Remove a source.
310    pub fn remove(&mut self, name: &str) {
311        #[cfg(feature = "python")]
312        {
313            self.sources.remove(name);
314            self.metadata.remove(name);
315        }
316        #[cfg(not(feature = "python"))]
317        {
318            self.sources.remove(name);
319        }
320
321        // Clear active if removed
322        if self.active.as_deref() == Some(name) {
323            self.active = None;
324        }
325    }
326
327    /// Clear all sources.
328    pub fn clear(&mut self) {
329        self.sources.clear();
330        #[cfg(feature = "python")]
331        self.metadata.clear();
332        self.active = None;
333    }
334
335    /// Iterate over source names and metadata.
336    pub fn iter(&self) -> impl Iterator<Item = (&str, &TensorMetadata)> {
337        #[cfg(feature = "python")]
338        {
339            self.metadata.iter().map(|(k, v)| (k.as_str(), v))
340        }
341        #[cfg(not(feature = "python"))]
342        {
343            self.sources.iter().map(|(k, v)| (k.as_str(), v))
344        }
345    }
346}
347
348impl Default for TensorSourceRegistry {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
#[cfg(test)]
mod tests {
    //! Unit tests for metadata helpers and registry bookkeeping. All tests use
    //! `add_with_metadata`, so they run with or without the `python` feature.

    use super::*;

    #[test]
    fn test_metadata_sample_numel() {
        let meta = TensorMetadata::new(100, vec![3, 224, 224]);
        assert_eq!(meta.sample_numel(), 3 * 224 * 224);
    }

    #[test]
    fn test_metadata_full_shape() {
        let meta = TensorMetadata::new(1000, vec![1, 28, 28]);
        assert_eq!(meta.full_shape(), vec![1000, 1, 28, 28]);
    }

    #[test]
    fn test_registry_default() {
        let registry = TensorSourceRegistry::default();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_first_source_becomes_active() {
        let mut registry = TensorSourceRegistry::new();
        registry.add_with_metadata("train", TensorMetadata::new(10, vec![2]));
        registry.add_with_metadata("test", TensorMetadata::new(5, vec![2]));
        assert_eq!(registry.active_name(), Some("train"));
        assert_eq!(registry.active_size().unwrap(), 10);
        assert_eq!(registry.len(), 2);

        registry.set_active("test").unwrap();
        assert_eq!(registry.active_size().unwrap(), 5);
    }

    #[test]
    fn test_set_active_unknown_source() {
        let mut registry = TensorSourceRegistry::new();
        assert!(matches!(
            registry.set_active("missing"),
            Err(TensorSourceError::NotFound(name)) if name == "missing"
        ));
        assert_eq!(registry.active_name(), None);
    }

    #[test]
    fn test_index_bounds() {
        let mut registry = TensorSourceRegistry::new();
        registry.add_with_metadata("train", TensorMetadata::new(3, vec![1]));
        assert!(registry.check_index(2).is_ok());
        assert!(matches!(
            registry.check_index(3),
            Err(TensorSourceError::IndexOutOfBounds(3, 3))
        ));
        assert!(registry.validate_indices(&[0, 1, 2]).is_ok());
        assert!(matches!(
            registry.validate_indices(&[0, 7, 9]),
            Err(TensorSourceError::IndexOutOfBounds(7, 3))
        ));
    }

    #[test]
    fn test_no_active_source() {
        let registry = TensorSourceRegistry::new();
        assert!(matches!(
            registry.check_index(0),
            Err(TensorSourceError::NoActive)
        ));
    }

    #[test]
    fn test_remove_clears_active() {
        let mut registry = TensorSourceRegistry::new();
        registry.add_with_metadata("train", TensorMetadata::new(3, vec![1]));
        registry.remove("train");
        assert!(registry.is_empty());
        assert!(!registry.contains("train"));
        assert_eq!(registry.active_name(), None);
    }
}