Skip to main content

xlog_neural/
registry.rs

//! Network registry for managing registered neural networks.
//!
//! The registry is the central point for managing all neural networks used
//! in a probabilistic logic program. It handles:
//!
//! - Registration of networks with their configurations
//! - Train/eval mode switching for all networks
//! - Network lookup by name
//! - Registration and lookup of embeddings by name
9
10use crate::handle::{EmbeddingHandle, NetworkHandle};
11use std::collections::HashMap;
12
/// Configuration for registering a neural network.
///
/// This mirrors the DeepProbLog `register_network` options. Start from one of
/// the constructors ([`NetworkConfig::default`], [`NetworkConfig::deterministic`],
/// [`NetworkConfig::with_top_k`]) and adjust it with the builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct NetworkConfig {
    /// Unique name identifying this network (must match `nn()` declarations).
    pub name: String,

    /// Whether to batch inputs for efficient GPU processing.
    /// When true, multiple queries are grouped into a single forward pass.
    pub batching: bool,

    /// Top-k sampling: if `Some(k)`, only consider the top `k` outputs.
    /// Useful for large output spaces where most classes have near-zero probability.
    pub k: Option<usize>,

    /// Deterministic mode: use argmax instead of probabilistic sampling.
    /// Useful for debugging and when you want reproducible results.
    pub det: bool,

    /// Whether to cache network outputs.
    /// Caching avoids redundant forward passes for repeated inputs.
    pub cache_enabled: bool,

    /// Maximum number of entries in the output cache.
    pub cache_size: usize,
}

impl NetworkConfig {
    /// Output-cache capacity used by every constructor unless overridden
    /// with [`NetworkConfig::cache`].
    pub const DEFAULT_CACHE_SIZE: usize = 10_000;

    /// Create a default configuration for a network with the given name.
    ///
    /// Default settings:
    /// - batching: true
    /// - k: None (consider all outputs)
    /// - det: false (probabilistic mode)
    /// - cache_enabled: true
    /// - cache_size: [`NetworkConfig::DEFAULT_CACHE_SIZE`] (10000)
    ///
    /// This is the single source of truth for defaults; the other
    /// constructors are expressed in terms of it.
    pub fn default(name: &str) -> Self {
        Self {
            name: name.to_string(),
            batching: true,
            k: None,
            det: false,
            cache_enabled: true,
            cache_size: Self::DEFAULT_CACHE_SIZE,
        }
    }

    /// Create a configuration for a deterministic network.
    ///
    /// Identical to [`NetworkConfig::default`] except that `det` is `true`.
    pub fn deterministic(name: &str) -> Self {
        Self::default(name).det(true)
    }

    /// Create a configuration with top-k sampling.
    ///
    /// Identical to [`NetworkConfig::default`] except that `k` is `Some(k)`.
    pub fn with_top_k(name: &str, k: usize) -> Self {
        Self::default(name).k(Some(k))
    }

    /// Builder method to set batching.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn batching(mut self, enabled: bool) -> Self {
        self.batching = enabled;
        self
    }

    /// Builder method to set top-k (`None` considers all outputs).
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn k(mut self, k: Option<usize>) -> Self {
        self.k = k;
        self
    }

    /// Builder method to set deterministic mode.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn det(mut self, det: bool) -> Self {
        self.det = det;
        self
    }

    /// Builder method to set both whether caching is enabled and the cache size.
    #[must_use = "builder methods return the updated config; the original is consumed"]
    pub fn cache(mut self, enabled: bool, size: usize) -> Self {
        self.cache_enabled = enabled;
        self.cache_size = size;
        self
    }
}
111
112/// Registry for managing neural networks.
113///
114/// The registry maintains a collection of `NetworkHandle` instances,
115/// each identified by a unique name. Networks are registered with
116/// configurations and then have their PyTorch modules attached via
117/// the Python API.
118pub struct NetworkRegistry {
119    /// Map from network name to handle
120    networks: HashMap<String, NetworkHandle>,
121    /// Map from embedding name to handle
122    embeddings: HashMap<String, EmbeddingHandle>,
123}
124
125impl NetworkRegistry {
126    /// Create a new empty registry.
127    pub fn new() -> Self {
128        Self {
129            networks: HashMap::new(),
130            embeddings: HashMap::new(),
131        }
132    }
133
134    /// Register a network with the given configuration.
135    ///
136    /// If a network with the same name already exists, it will be replaced.
137    pub fn register(&mut self, config: NetworkConfig) {
138        let handle = NetworkHandle::from_config(&config);
139        self.networks.insert(config.name, handle);
140    }
141
142    /// Get a reference to a network handle by name.
143    pub fn get(&self, name: &str) -> Option<&NetworkHandle> {
144        self.networks.get(name)
145    }
146
147    /// Get a mutable reference to a network handle by name.
148    pub fn get_mut(&mut self, name: &str) -> Option<&mut NetworkHandle> {
149        self.networks.get_mut(name)
150    }
151
152    /// Check if a network is registered.
153    pub fn contains(&self, name: &str) -> bool {
154        self.networks.contains_key(name)
155    }
156
157    /// Remove a network from the registry.
158    pub fn unregister(&mut self, name: &str) -> Option<NetworkHandle> {
159        self.networks.remove(name)
160    }
161
162    /// Set train mode for all registered networks.
163    ///
164    /// This affects both the `train_mode` flag on handles and should
165    /// be used to call `.train()` or `.eval()` on PyTorch modules.
166    pub fn set_train_mode(&mut self, train: bool) {
167        for handle in self.networks.values_mut() {
168            handle.train_mode = train;
169        }
170    }
171
172    /// Get the names of all registered networks.
173    pub fn names(&self) -> Vec<&str> {
174        self.networks.keys().map(|s| s.as_str()).collect()
175    }
176
177    /// Get the number of registered networks.
178    pub fn len(&self) -> usize {
179        self.networks.len()
180    }
181
182    /// Check if the registry is empty.
183    pub fn is_empty(&self) -> bool {
184        self.networks.is_empty()
185    }
186
187    /// Remove all networks from the registry.
188    pub fn clear(&mut self) {
189        self.networks.clear();
190    }
191
192    /// Iterate over all network handles.
193    pub fn iter(&self) -> impl Iterator<Item = (&str, &NetworkHandle)> {
194        self.networks.iter().map(|(k, v)| (k.as_str(), v))
195    }
196
197    /// Iterate mutably over all network handles.
198    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut NetworkHandle)> {
199        self.networks.iter_mut().map(|(k, v)| (k.as_str(), v))
200    }
201
202    /// Register an embedding with the given handle.
203    pub fn register_embedding(&mut self, handle: EmbeddingHandle) {
204        self.embeddings.insert(handle.name.clone(), handle);
205    }
206
207    /// Get a reference to an embedding handle by name.
208    pub fn get_embedding(&self, name: &str) -> Option<&EmbeddingHandle> {
209        self.embeddings.get(name)
210    }
211
212    /// Get a mutable reference to an embedding handle by name.
213    pub fn get_embedding_mut(&mut self, name: &str) -> Option<&mut EmbeddingHandle> {
214        self.embeddings.get_mut(name)
215    }
216
217    /// Check if an embedding is registered.
218    pub fn contains_embedding(&self, name: &str) -> bool {
219        self.embeddings.contains_key(name)
220    }
221}
222
223impl Default for NetworkRegistry {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    // Listed explicitly so every dependency of this module is declared up front.
    use crate::handle::EmbeddingHandle;

    #[test]
    fn test_config_default() {
        let config = NetworkConfig::default("test");
        assert_eq!(config.name, "test");
        assert!(config.batching);
        assert!(config.k.is_none());
        assert!(!config.det);
        assert!(config.cache_enabled);
        assert_eq!(config.cache_size, 10000);
    }

    #[test]
    fn test_config_deterministic() {
        let config = NetworkConfig::deterministic("det_test");
        assert_eq!(config.name, "det_test");
        assert!(config.det);
        assert!(config.k.is_none());
    }

    #[test]
    fn test_config_with_top_k() {
        let config = NetworkConfig::with_top_k("top_k_test", 5);
        assert_eq!(config.name, "top_k_test");
        assert_eq!(config.k, Some(5));
        assert!(!config.det);
    }

    #[test]
    fn test_config_builder() {
        let config = NetworkConfig::default("builder_test")
            .batching(false)
            .k(Some(3))
            .det(true)
            .cache(false, 0);

        assert!(!config.batching);
        assert_eq!(config.k, Some(3));
        assert!(config.det);
        assert!(!config.cache_enabled);
        assert_eq!(config.cache_size, 0);
    }

    #[test]
    fn test_registry_new() {
        let registry = NetworkRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_register_get() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("net1"));

        assert!(registry.contains("net1"));
        assert!(registry.get("net1").is_some());
        assert!(registry.get_mut("net1").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_register_replaces_existing() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("net"));
        registry.register(NetworkConfig::deterministic("net"));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("net"));
    }

    #[test]
    fn test_registry_unregister() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("net1"));

        assert!(registry.unregister("net1").is_some());
        assert!(!registry.contains("net1"));
        assert!(registry.unregister("net1").is_none());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_iter() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("a"));
        registry.register(NetworkConfig::default("b"));

        // Iteration order is unspecified, so sort before comparing.
        let mut names: Vec<&str> = registry.iter().map(|(name, _)| name).collect();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b"]);
    }

    #[test]
    fn test_registry_names() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("b"));
        registry.register(NetworkConfig::default("a"));

        let mut names = registry.names();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b"]);
    }

    #[test]
    fn test_registry_set_train_mode() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("a"));
        registry.register(NetworkConfig::default("b"));

        registry.set_train_mode(true);
        assert!(registry.iter().all(|(_, h)| h.train_mode));

        registry.set_train_mode(false);
        assert!(registry.iter().all(|(_, h)| !h.train_mode));
    }

    #[test]
    fn test_registry_clear() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("a"));
        registry.register(NetworkConfig::default("b"));

        registry.clear();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.contains("a"));
    }

    #[test]
    fn test_registry_embedding_register_get() {
        let mut registry = NetworkRegistry::new();
        let handle = EmbeddingHandle::new("embed1".to_string(), true, 64, 100);
        registry.register_embedding(handle);

        assert!(registry.contains_embedding("embed1"));
        assert!(!registry.contains_embedding("nonexistent"));

        let h = registry.get_embedding("embed1").unwrap();
        assert_eq!(h.dim, 64);
        assert_eq!(h.vocab_size, 100);
    }

    #[test]
    fn test_registry_embedding_get_mut() {
        let mut registry = NetworkRegistry::new();
        let handle = EmbeddingHandle::new("embed1".to_string(), true, 64, 100);
        registry.register_embedding(handle);

        let h = registry.get_embedding_mut("embed1").unwrap();
        h.trainable = false;
        assert!(!registry.get_embedding("embed1").unwrap().trainable);
    }
}