1#[cfg(feature = "python")]
10use pyo3::PyObject;
11
/// Handle for a single named network together with its runtime settings.
///
/// When the crate is built with the `python` feature the handle can also
/// hold optional Python objects (`module`, `optimizer`, `scheduler`). Without
/// that feature those three fields do not exist and the handle carries only
/// the plain settings below.
#[derive(Debug)]
pub struct NetworkHandle {
    /// Name of the network this handle refers to.
    pub name: String,

    /// Python object attached as the network module, if any.
    /// NOTE(review): presumably a neural-network module object — confirm.
    #[cfg(feature = "python")]
    pub module: Option<PyObject>,

    /// Python object attached as the optimizer, if any.
    #[cfg(feature = "python")]
    pub optimizer: Option<PyObject>,

    /// Python object attached as the scheduler, if any.
    /// NOTE(review): presumably a learning-rate scheduler — confirm.
    #[cfg(feature = "python")]
    pub scheduler: Option<PyObject>,

    /// Batching flag. `new` sets it to `true`; `from_config` copies
    /// `config.batching`.
    pub batching: bool,

    /// Optional `k` setting. `new` leaves it as `None`.
    /// NOTE(review): meaning is not visible in this file (top-k?) — confirm.
    pub k: Option<usize>,

    /// `det` flag. `new` sets it to `false`.
    /// NOTE(review): presumably "deterministic" — confirm.
    pub det: bool,

    /// Training-mode flag. Both constructors initialise it to `false`.
    pub train_mode: bool,

    /// Whether caching is enabled. `new` sets it to `true`.
    pub cache_enabled: bool,

    /// Cache size setting. `new` sets it to 10000.
    /// NOTE(review): unit (entries vs. bytes) is not visible here — confirm.
    pub cache_size: usize,
}
54
55impl NetworkHandle {
56 pub fn new(name: String) -> Self {
58 Self {
59 name,
60 #[cfg(feature = "python")]
61 module: None,
62 #[cfg(feature = "python")]
63 optimizer: None,
64 #[cfg(feature = "python")]
65 scheduler: None,
66 batching: true,
67 k: None,
68 det: false,
69 train_mode: false,
70 cache_enabled: true,
71 cache_size: 10000,
72 }
73 }
74
75 pub fn from_config(config: &crate::NetworkConfig) -> Self {
77 Self {
78 name: config.name.clone(),
79 #[cfg(feature = "python")]
80 module: None,
81 #[cfg(feature = "python")]
82 optimizer: None,
83 #[cfg(feature = "python")]
84 scheduler: None,
85 batching: config.batching,
86 k: config.k,
87 det: config.det,
88 train_mode: false,
89 cache_enabled: config.cache_enabled,
90 cache_size: config.cache_size,
91 }
92 }
93
94 #[cfg(feature = "python")]
96 pub fn has_module(&self) -> bool {
97 self.module.is_some()
98 }
99
100 #[cfg(not(feature = "python"))]
103 pub fn has_module(&self) -> bool {
105 false
106 }
107
108 #[cfg(feature = "python")]
110 pub fn has_optimizer(&self) -> bool {
111 self.optimizer.is_some()
112 }
113
114 #[cfg(not(feature = "python"))]
117 pub fn has_optimizer(&self) -> bool {
118 false
119 }
120
121 #[cfg(feature = "python")]
123 pub fn has_scheduler(&self) -> bool {
124 self.scheduler.is_some()
125 }
126
127 #[cfg(not(feature = "python"))]
130 pub fn has_scheduler(&self) -> bool {
131 false
132 }
133
134 #[cfg(feature = "python")]
136 pub fn set_module(&mut self, module: PyObject) {
137 self.module = Some(module);
138 }
139
140 #[cfg(feature = "python")]
142 pub fn set_optimizer(&mut self, optimizer: PyObject) {
143 self.optimizer = Some(optimizer);
144 }
145
146 #[cfg(feature = "python")]
148 pub fn set_scheduler(&mut self, scheduler: PyObject) {
149 self.scheduler = Some(scheduler);
150 }
151
152 #[cfg(feature = "python")]
154 pub fn module(&self) -> Option<&PyObject> {
155 self.module.as_ref()
156 }
157
158 #[cfg(feature = "python")]
160 pub fn optimizer(&self) -> Option<&PyObject> {
161 self.optimizer.as_ref()
162 }
163
164 #[cfg(feature = "python")]
166 pub fn scheduler(&self) -> Option<&PyObject> {
167 self.scheduler.as_ref()
168 }
169
170 #[cfg(feature = "python")]
172 pub fn clear(&mut self) {
173 self.module = None;
174 self.optimizer = None;
175 self.scheduler = None;
176 }
177
178 #[cfg(not(feature = "python"))]
181 pub fn clear(&mut self) {
182 }
184}
185
/// Handle for a single named embedding and its basic parameters.
///
/// With the `python` feature enabled the handle can also hold an optional
/// Python object in `module`; without the feature that field does not exist.
#[derive(Debug)]
pub struct EmbeddingHandle {
    /// Name of the embedding this handle refers to.
    pub name: String,

    /// Python object attached as the embedding module, if any.
    #[cfg(feature = "python")]
    pub module: Option<PyObject>,

    /// Whether the embedding is trainable.
    /// NOTE(review): `false` presumably means frozen weights — confirm.
    pub trainable: bool,

    /// Embedding dimension setting.
    /// NOTE(review): presumably the width of each embedding vector — confirm.
    pub dim: usize,

    /// Vocabulary size setting.
    /// NOTE(review): presumably the number of embedding rows — confirm.
    pub vocab_size: usize,
}
208
209impl EmbeddingHandle {
210 pub fn new(name: String, trainable: bool, dim: usize, vocab_size: usize) -> Self {
212 Self {
213 name,
214 #[cfg(feature = "python")]
215 module: None,
216 trainable,
217 dim,
218 vocab_size,
219 }
220 }
221
222 #[cfg(feature = "python")]
224 pub fn has_module(&self) -> bool {
225 self.module.is_some()
226 }
227
228 #[cfg(not(feature = "python"))]
231 pub fn has_module(&self) -> bool {
232 false
233 }
234
235 #[cfg(feature = "python")]
237 pub fn set_module(&mut self, module: PyObject) {
238 self.module = Some(module);
239 }
240
241 #[cfg(feature = "python")]
243 pub fn module(&self) -> Option<&PyObject> {
244 self.module.as_ref()
245 }
246}
247
248impl Default for NetworkHandle {
249 fn default() -> Self {
250 Self::new(String::new())
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that no module, optimizer or scheduler is attached.
    ///
    /// Holds for every freshly built handle regardless of whether the
    /// `python` feature is enabled.
    fn assert_nothing_attached(handle: &NetworkHandle) {
        assert!(!handle.has_module());
        assert!(!handle.has_optimizer());
        assert!(!handle.has_scheduler());
    }

    #[test]
    fn test_handle_new() {
        let handle = NetworkHandle::new("test".to_string());
        assert_eq!(handle.name, "test");
        assert_nothing_attached(&handle);
        // Every default set by `new` is pinned, not just a subset.
        assert!(handle.batching);
        assert_eq!(handle.k, None);
        assert!(!handle.det);
        assert!(!handle.train_mode);
        assert!(handle.cache_enabled);
        assert_eq!(handle.cache_size, 10_000);
    }

    #[test]
    fn test_handle_from_config() {
        let config = crate::NetworkConfig {
            name: "configured".to_string(),
            batching: false,
            k: Some(5),
            det: true,
            cache_enabled: false,
            cache_size: 500,
        };

        let handle = NetworkHandle::from_config(&config);
        assert_eq!(handle.name, "configured");
        assert!(!handle.batching);
        assert_eq!(handle.k, Some(5));
        assert!(handle.det);
        assert!(!handle.cache_enabled);
        assert_eq!(handle.cache_size, 500);
        // Not taken from the config: training mode starts off and nothing
        // is attached yet.
        assert!(!handle.train_mode);
        assert_nothing_attached(&handle);
    }

    #[test]
    fn test_handle_default_matches_new() {
        let handle = NetworkHandle::default();
        assert!(handle.name.is_empty());
        assert_nothing_attached(&handle);
        assert!(handle.batching);
        assert_eq!(handle.k, None);
        assert!(!handle.det);
        assert!(!handle.train_mode);
        assert!(handle.cache_enabled);
        assert_eq!(handle.cache_size, 10_000);
    }

    #[test]
    fn test_handle_clear_keeps_settings() {
        let mut handle = NetworkHandle::new("clearable".to_string());
        handle.train_mode = true;
        handle.k = Some(3);

        handle.clear();

        // `clear` only detaches Python objects; plain settings survive.
        assert_nothing_attached(&handle);
        assert_eq!(handle.name, "clearable");
        assert!(handle.train_mode);
        assert_eq!(handle.k, Some(3));
        assert!(handle.batching);
        assert!(handle.cache_enabled);
        assert_eq!(handle.cache_size, 10_000);
    }

    #[test]
    fn test_embedding_handle_new() {
        let handle = EmbeddingHandle::new("test_embed".to_string(), true, 64, 1000);
        assert_eq!(handle.name, "test_embed");
        assert!(handle.trainable);
        assert_eq!(handle.dim, 64);
        assert_eq!(handle.vocab_size, 1000);
        assert!(!handle.has_module());
    }

    #[test]
    fn test_embedding_handle_frozen() {
        let handle = EmbeddingHandle::new("frozen".to_string(), false, 128, 500);
        assert_eq!(handle.name, "frozen");
        assert!(!handle.trainable);
        assert_eq!(handle.dim, 128);
        assert_eq!(handle.vocab_size, 500);
        assert!(!handle.has_module());
    }
}