Skip to main content

xlog_neural/
handle.rs

//! Network handle for managing PyTorch modules.
//!
//! Each registered neural network is represented by a `NetworkHandle`, which holds:
//! - The PyTorch module (`nn.Module`)
//! - An optional optimizer for training
//! - An optional learning-rate scheduler
//! - Configuration flags for batching, caching, etc.
//!
//! The module, optimizer, and scheduler slots exist only when the `python` feature is enabled.
8
9#[cfg(feature = "python")]
10use pyo3::PyObject;
11
12/// Handle to a registered neural network.
13///
14/// This struct holds the PyTorch module and associated training state.
15/// When the `python` feature is enabled, it can hold PyO3 PyObject references.
16#[derive(Debug)]
17pub struct NetworkHandle {
18    /// Unique name identifying this network
19    pub name: String,
20
21    /// The PyTorch nn.Module (set via Python API)
22    /// Only available with the `python` feature
23    #[cfg(feature = "python")]
24    pub module: Option<PyObject>,
25
26    /// The optimizer for training (e.g., Adam, SGD)
27    /// Only available with the `python` feature
28    #[cfg(feature = "python")]
29    pub optimizer: Option<PyObject>,
30
31    /// Learning rate scheduler
32    /// Only available with the `python` feature
33    #[cfg(feature = "python")]
34    pub scheduler: Option<PyObject>,
35
36    /// Whether to batch inputs for efficient GPU processing
37    pub batching: bool,
38
39    /// Top-k sampling: if Some(k), only consider top k outputs
40    pub k: Option<usize>,
41
42    /// Deterministic mode: use argmax instead of sampling
43    pub det: bool,
44
45    /// Whether the network is in training mode
46    pub train_mode: bool,
47
48    /// Whether output caching is enabled
49    pub cache_enabled: bool,
50
51    /// Maximum number of cached outputs
52    pub cache_size: usize,
53}
54
55impl NetworkHandle {
56    /// Create a new network handle with the given name and default settings.
57    pub fn new(name: String) -> Self {
58        Self {
59            name,
60            #[cfg(feature = "python")]
61            module: None,
62            #[cfg(feature = "python")]
63            optimizer: None,
64            #[cfg(feature = "python")]
65            scheduler: None,
66            batching: true,
67            k: None,
68            det: false,
69            train_mode: false,
70            cache_enabled: true,
71            cache_size: 10000,
72        }
73    }
74
75    /// Create a handle from a configuration.
76    pub fn from_config(config: &crate::NetworkConfig) -> Self {
77        Self {
78            name: config.name.clone(),
79            #[cfg(feature = "python")]
80            module: None,
81            #[cfg(feature = "python")]
82            optimizer: None,
83            #[cfg(feature = "python")]
84            scheduler: None,
85            batching: config.batching,
86            k: config.k,
87            det: config.det,
88            train_mode: false,
89            cache_enabled: config.cache_enabled,
90            cache_size: config.cache_size,
91        }
92    }
93
94    /// Check if the PyTorch module has been set.
95    #[cfg(feature = "python")]
96    pub fn has_module(&self) -> bool {
97        self.module.is_some()
98    }
99
100    /// Check if the PyTorch module has been set.
101    /// Without Python feature, always returns false.
102    #[cfg(not(feature = "python"))]
103    /// Report whether a Python module/tensor handle is attached.
104    pub fn has_module(&self) -> bool {
105        false
106    }
107
108    /// Check if an optimizer has been configured.
109    #[cfg(feature = "python")]
110    pub fn has_optimizer(&self) -> bool {
111        self.optimizer.is_some()
112    }
113
114    /// Check if an optimizer has been configured.
115    /// Without Python feature, always returns false.
116    #[cfg(not(feature = "python"))]
117    pub fn has_optimizer(&self) -> bool {
118        false
119    }
120
121    /// Check if a scheduler has been configured.
122    #[cfg(feature = "python")]
123    pub fn has_scheduler(&self) -> bool {
124        self.scheduler.is_some()
125    }
126
127    /// Check if a scheduler has been configured.
128    /// Without Python feature, always returns false.
129    #[cfg(not(feature = "python"))]
130    pub fn has_scheduler(&self) -> bool {
131        false
132    }
133
134    /// Set the PyTorch module.
135    #[cfg(feature = "python")]
136    pub fn set_module(&mut self, module: PyObject) {
137        self.module = Some(module);
138    }
139
140    /// Set the optimizer.
141    #[cfg(feature = "python")]
142    pub fn set_optimizer(&mut self, optimizer: PyObject) {
143        self.optimizer = Some(optimizer);
144    }
145
146    /// Set the learning rate scheduler.
147    #[cfg(feature = "python")]
148    pub fn set_scheduler(&mut self, scheduler: PyObject) {
149        self.scheduler = Some(scheduler);
150    }
151
152    /// Get a reference to the PyTorch module.
153    #[cfg(feature = "python")]
154    pub fn module(&self) -> Option<&PyObject> {
155        self.module.as_ref()
156    }
157
158    /// Get a reference to the optimizer.
159    #[cfg(feature = "python")]
160    pub fn optimizer(&self) -> Option<&PyObject> {
161        self.optimizer.as_ref()
162    }
163
164    /// Get a reference to the scheduler.
165    #[cfg(feature = "python")]
166    pub fn scheduler(&self) -> Option<&PyObject> {
167        self.scheduler.as_ref()
168    }
169
170    /// Clear the module and training state.
171    #[cfg(feature = "python")]
172    pub fn clear(&mut self) {
173        self.module = None;
174        self.optimizer = None;
175        self.scheduler = None;
176    }
177
178    /// Clear the module and training state.
179    /// Without Python feature, this is a no-op.
180    #[cfg(not(feature = "python"))]
181    pub fn clear(&mut self) {
182        // No-op without Python feature
183    }
184}
185
/// Handle to a registered embedding module.
///
/// Wraps either a trainable `nn.Embedding` or a frozen `torch.Tensor`, and is
/// created from Python through `CompiledProgram.register_embedding()`. The
/// `module` slot is only compiled in when the `python` feature is enabled.
#[derive(Debug)]
pub struct EmbeddingHandle {
    /// Unique name; matches the corresponding `nn()` declaration.
    pub name: String,

    #[cfg(feature = "python")]
    /// The underlying PyTorch `nn.Embedding` or tensor, once attached.
    pub module: Option<PyObject>,

    /// When `true`, gradients flow through this embedding.
    pub trainable: bool,

    /// Length of each embedding vector (second axis of the weight matrix).
    pub dim: usize,

    /// Number of rows in the embedding table (first axis of the weight matrix).
    pub vocab_size: usize,
}

impl EmbeddingHandle {
    /// Build a handle with no PyTorch object attached yet.
    ///
    /// `trainable`, `dim`, and `vocab_size` are stored exactly as given.
    pub fn new(name: String, trainable: bool, dim: usize, vocab_size: usize) -> Self {
        Self {
            name,
            trainable,
            dim,
            vocab_size,
            #[cfg(feature = "python")]
            module: None,
        }
    }

    /// Report whether a PyTorch module/tensor has been attached.
    #[cfg(feature = "python")]
    pub fn has_module(&self) -> bool {
        self.module.is_some()
    }

    /// Report whether a PyTorch module/tensor has been attached.
    ///
    /// Built without the `python` feature there is no slot to fill, so the
    /// answer is always `false`.
    #[cfg(not(feature = "python"))]
    pub fn has_module(&self) -> bool {
        false
    }

    /// Attach the PyTorch module/tensor, overwriting any earlier one.
    #[cfg(feature = "python")]
    pub fn set_module(&mut self, module: PyObject) {
        self.module = Some(module);
    }

    /// Borrow the attached PyTorch module/tensor, if there is one.
    #[cfg(feature = "python")]
    pub fn module(&self) -> Option<&PyObject> {
        self.module.as_ref()
    }
}
247
248impl Default for NetworkHandle {
249    fn default() -> Self {
250        Self::new(String::new())
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` must produce every documented default, including the ones the
    /// original test did not pin (`k`, `det`, caching, scheduler).
    #[test]
    fn test_handle_new() {
        let handle = NetworkHandle::new("test".to_string());
        assert_eq!(handle.name, "test");
        assert!(!handle.has_module());
        assert!(!handle.has_optimizer());
        assert!(!handle.has_scheduler());
        assert!(handle.batching);
        assert_eq!(handle.k, None);
        assert!(!handle.det);
        assert!(!handle.train_mode);
        assert!(handle.cache_enabled);
        assert_eq!(handle.cache_size, 10000);
    }

    /// `Default` is an unnamed handle with the same settings as `new`.
    #[test]
    fn test_handle_default_matches_new() {
        let handle = NetworkHandle::default();
        assert!(handle.name.is_empty());
        let reference = NetworkHandle::new(String::new());
        assert_eq!(format!("{:?}", handle), format!("{:?}", reference));
    }

    #[test]
    fn test_handle_from_config() {
        let config = crate::NetworkConfig {
            name: "configured".to_string(),
            batching: false,
            k: Some(5),
            det: true,
            cache_enabled: false,
            cache_size: 500,
        };

        let handle = NetworkHandle::from_config(&config);
        assert_eq!(handle.name, "configured");
        assert!(!handle.batching);
        assert_eq!(handle.k, Some(5));
        assert!(handle.det);
        assert!(!handle.cache_enabled);
        assert_eq!(handle.cache_size, 500);
        // Not part of the config: a fresh handle is never in training mode
        // and has nothing attached yet.
        assert!(!handle.train_mode);
        assert!(!handle.has_module());
        assert!(!handle.has_optimizer());
        assert!(!handle.has_scheduler());
    }

    /// `clear` only drops the Python-side objects; settings survive.
    #[test]
    fn test_handle_clear_keeps_settings() {
        let mut handle = NetworkHandle::new("clearable".to_string());
        handle.batching = false;
        handle.k = Some(3);
        handle.clear();
        assert!(!handle.has_module());
        assert!(!handle.has_optimizer());
        assert!(!handle.has_scheduler());
        assert_eq!(handle.name, "clearable");
        assert!(!handle.batching);
        assert_eq!(handle.k, Some(3));
    }

    #[test]
    fn test_embedding_handle_new() {
        let handle = EmbeddingHandle::new("test_embed".to_string(), true, 64, 1000);
        assert_eq!(handle.name, "test_embed");
        assert!(handle.trainable);
        assert_eq!(handle.dim, 64);
        assert_eq!(handle.vocab_size, 1000);
        assert!(!handle.has_module());
    }

    #[test]
    fn test_embedding_handle_frozen() {
        let handle = EmbeddingHandle::new("frozen".to_string(), false, 128, 500);
        assert!(!handle.trainable);
        assert_eq!(handle.dim, 128);
        assert_eq!(handle.vocab_size, 500);
        assert!(!handle.has_module());
    }
}