1use crate::handle::{EmbeddingHandle, NetworkHandle};
11use std::collections::HashMap;
12
/// Per-network settings handed to `NetworkRegistry::register`, which builds a
/// `NetworkHandle` from them via `NetworkHandle::from_config`.
///
/// Start from one of the preset constructors ([`NetworkConfig::default`],
/// [`NetworkConfig::deterministic`], [`NetworkConfig::with_top_k`]) and refine the
/// result with the chainable by-value setters.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct NetworkConfig {
    /// Name the network is registered under in the registry.
    pub name: String,

    /// Whether batching is enabled.
    pub batching: bool,

    /// Optional top-k setting; `None` when unset (see [`NetworkConfig::with_top_k`]).
    pub k: Option<usize>,

    /// Deterministic-mode flag (set by [`NetworkConfig::deterministic`]).
    pub det: bool,

    /// Whether caching is enabled.
    pub cache_enabled: bool,

    /// Cache size; how it is interpreted is up to the consumer of this config
    /// (NOTE(review): presumably only meaningful when `cache_enabled` is true).
    pub cache_size: usize,
}

impl NetworkConfig {
    /// Cache size used by every preset constructor.
    pub const DEFAULT_CACHE_SIZE: usize = 10_000;

    /// Baseline configuration for a network called `name`.
    ///
    /// Batching on, no top-k, non-deterministic, cache enabled with
    /// [`NetworkConfig::DEFAULT_CACHE_SIZE`].
    ///
    /// Note: this is an inherent constructor that takes a name; it is *not*
    /// `Default::default`.
    #[must_use]
    pub fn default(name: &str) -> Self {
        Self {
            name: name.to_string(),
            batching: true,
            k: None,
            det: false,
            cache_enabled: true,
            cache_size: Self::DEFAULT_CACHE_SIZE,
        }
    }

    /// Same as [`NetworkConfig::default`] but with `det` set to `true`.
    #[must_use]
    pub fn deterministic(name: &str) -> Self {
        Self::default(name).det(true)
    }

    /// Same as [`NetworkConfig::default`] but with `k` set to `Some(k)`.
    #[must_use]
    pub fn with_top_k(name: &str, k: usize) -> Self {
        Self::default(name).k(Some(k))
    }

    /// Enables or disables batching.
    #[must_use]
    pub fn batching(mut self, enabled: bool) -> Self {
        self.batching = enabled;
        self
    }

    /// Sets (or clears, with `None`) the top-k value.
    #[must_use]
    pub fn k(mut self, k: Option<usize>) -> Self {
        self.k = k;
        self
    }

    /// Sets the deterministic-mode flag.
    #[must_use]
    pub fn det(mut self, det: bool) -> Self {
        self.det = det;
        self
    }

    /// Sets the cache flag and the cache size together.
    ///
    /// `size` is stored even when `enabled` is `false`.
    #[must_use]
    pub fn cache(mut self, enabled: bool, size: usize) -> Self {
        self.cache_enabled = enabled;
        self.cache_size = size;
        self
    }
}
111
112pub struct NetworkRegistry {
119 networks: HashMap<String, NetworkHandle>,
121 embeddings: HashMap<String, EmbeddingHandle>,
123}
124
125impl NetworkRegistry {
126 pub fn new() -> Self {
128 Self {
129 networks: HashMap::new(),
130 embeddings: HashMap::new(),
131 }
132 }
133
134 pub fn register(&mut self, config: NetworkConfig) {
138 let handle = NetworkHandle::from_config(&config);
139 self.networks.insert(config.name, handle);
140 }
141
142 pub fn get(&self, name: &str) -> Option<&NetworkHandle> {
144 self.networks.get(name)
145 }
146
147 pub fn get_mut(&mut self, name: &str) -> Option<&mut NetworkHandle> {
149 self.networks.get_mut(name)
150 }
151
152 pub fn contains(&self, name: &str) -> bool {
154 self.networks.contains_key(name)
155 }
156
157 pub fn unregister(&mut self, name: &str) -> Option<NetworkHandle> {
159 self.networks.remove(name)
160 }
161
162 pub fn set_train_mode(&mut self, train: bool) {
167 for handle in self.networks.values_mut() {
168 handle.train_mode = train;
169 }
170 }
171
172 pub fn names(&self) -> Vec<&str> {
174 self.networks.keys().map(|s| s.as_str()).collect()
175 }
176
177 pub fn len(&self) -> usize {
179 self.networks.len()
180 }
181
182 pub fn is_empty(&self) -> bool {
184 self.networks.is_empty()
185 }
186
187 pub fn clear(&mut self) {
189 self.networks.clear();
190 }
191
192 pub fn iter(&self) -> impl Iterator<Item = (&str, &NetworkHandle)> {
194 self.networks.iter().map(|(k, v)| (k.as_str(), v))
195 }
196
197 pub fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut NetworkHandle)> {
199 self.networks.iter_mut().map(|(k, v)| (k.as_str(), v))
200 }
201
202 pub fn register_embedding(&mut self, handle: EmbeddingHandle) {
204 self.embeddings.insert(handle.name.clone(), handle);
205 }
206
207 pub fn get_embedding(&self, name: &str) -> Option<&EmbeddingHandle> {
209 self.embeddings.get(name)
210 }
211
212 pub fn get_embedding_mut(&mut self, name: &str) -> Option<&mut EmbeddingHandle> {
214 self.embeddings.get_mut(name)
215 }
216
217 pub fn contains_embedding(&self, name: &str) -> bool {
219 self.embeddings.contains_key(name)
220 }
221}
222
223impl Default for NetworkRegistry {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
#[cfg(test)]
mod tests {
    // The glob import also pulls in the parent module's private `use` items
    // (`EmbeddingHandle`, `NetworkHandle`, `HashMap`), so no separate import of
    // `crate::handle::EmbeddingHandle` is needed.
    use super::*;

    #[test]
    fn test_config_default() {
        let config = NetworkConfig::default("test");
        assert_eq!(config.name, "test");
        assert!(config.batching);
        assert!(config.k.is_none());
        assert!(!config.det);
        assert!(config.cache_enabled);
        assert_eq!(config.cache_size, 10000);
    }

    #[test]
    fn test_config_deterministic() {
        let config = NetworkConfig::deterministic("det_test");
        assert!(config.det);
    }

    #[test]
    fn test_config_with_top_k() {
        let config = NetworkConfig::with_top_k("top_k_test", 5);
        assert_eq!(config.k, Some(5));
    }

    #[test]
    fn test_config_builder() {
        let config = NetworkConfig::default("builder_test")
            .batching(false)
            .k(Some(3))
            .det(true)
            .cache(false, 0);

        assert!(!config.batching);
        assert_eq!(config.k, Some(3));
        assert!(config.det);
        assert!(!config.cache_enabled);
        assert_eq!(config.cache_size, 0);
    }

    #[test]
    fn test_registry_new() {
        let registry = NetworkRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_register_get() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("net1"));

        assert!(registry.contains("net1"));
        assert!(registry.get("net1").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_registry_register_same_name_replaces() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("dup"));
        registry.register(NetworkConfig::deterministic("dup"));

        assert_eq!(registry.len(), 1);
        assert!(registry.contains("dup"));
    }

    #[test]
    fn test_registry_get_mut_and_unregister() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("net1"));

        assert!(registry.get_mut("net1").is_some());
        assert!(registry.get_mut("missing").is_none());

        assert!(registry.unregister("net1").is_some());
        assert!(!registry.contains("net1"));
        assert!(registry.unregister("net1").is_none());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_iter() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("a"));
        registry.register(NetworkConfig::default("b"));

        let names: Vec<&str> = registry.iter().map(|(name, _)| name).collect();
        assert_eq!(names.len(), 2);
    }

    #[test]
    fn test_registry_names_and_clear() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("b"));
        registry.register(NetworkConfig::default("a"));

        // `names` has no defined order, so sort before comparing.
        let mut names = registry.names();
        names.sort_unstable();
        assert_eq!(names, vec!["a", "b"]);

        registry.clear();
        assert!(registry.is_empty());
        assert!(registry.names().is_empty());
    }

    #[test]
    fn test_registry_set_train_mode() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("a"));
        registry.register(NetworkConfig::default("b"));

        registry.set_train_mode(true);
        assert!(registry.iter().all(|(_, handle)| handle.train_mode));

        registry.set_train_mode(false);
        assert!(registry.iter().all(|(_, handle)| !handle.train_mode));
    }

    #[test]
    fn test_registry_iter_mut() {
        let mut registry = NetworkRegistry::new();
        registry.register(NetworkConfig::default("a"));

        for (_, handle) in registry.iter_mut() {
            handle.train_mode = true;
        }
        assert!(registry.get("a").unwrap().train_mode);
    }

    #[test]
    fn test_registry_embedding_register_get() {
        let mut registry = NetworkRegistry::new();
        let handle = EmbeddingHandle::new("embed1".to_string(), true, 64, 100);
        registry.register_embedding(handle);

        assert!(registry.contains_embedding("embed1"));
        assert!(!registry.contains_embedding("nonexistent"));

        let h = registry.get_embedding("embed1").unwrap();
        assert_eq!(h.dim, 64);
        assert_eq!(h.vocab_size, 100);
    }

    #[test]
    fn test_registry_embedding_get_mut() {
        let mut registry = NetworkRegistry::new();
        let handle = EmbeddingHandle::new("embed1".to_string(), true, 64, 100);
        registry.register_embedding(handle);

        let h = registry.get_embedding_mut("embed1").unwrap();
        h.trainable = false;
        assert!(!registry.get_embedding("embed1").unwrap().trainable);
    }
}