Skip to main content

xlog_cuda/
device.rs

1//! CUDA device management
2//!
3//! This module keeps XLOG's historical single-stream device abstraction while
4//! targeting cudarc's newer CUDA 13-capable context/stream APIs.
5
6use std::collections::BTreeMap;
7use std::ffi::{c_void, CString};
8use std::path::Path;
9use std::sync::{Arc, RwLock};
10
11use cudarc::driver::result::{self, DriverError};
12use cudarc::driver::{
13    sys, CudaContext as CudarcContext, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr,
14    HostSlice, LaunchConfig, ValidAsZeroBits,
15};
16use cudarc::nvrtc::Ptx;
17use xlog_core::{Result, XlogError};
18
/// A CUDA module loaded into a device's context, together with the kernel
/// handles that were resolved from it at load time.
#[derive(Debug)]
struct LoadedModule {
    /// Driver handle for the loaded module. `CudaDeviceInner` unloads it when the
    /// module is replaced under the same name or when the device state is dropped.
    cu_module: sys::CUmodule,
    /// Kernel name -> function handle, for exactly the names requested at load time.
    functions: BTreeMap<String, sys::CUfunction>,
}

// SAFETY: `LoadedModule` only stores opaque CUDA driver handles and never
// dereferences them; `CudaDeviceInner` accesses it behind its `modules` RwLock
// and only copies the handle values out to pass them to driver calls.
// NOTE(review): soundness relies on the CUDA driver accepting these handles from
// any host thread — confirm against the driver API threading guarantees.
unsafe impl Send for LoadedModule {}
// SAFETY: see the `Send` impl above; a shared `&LoadedModule` exposes no interior
// mutability, so concurrent access only reads the handle values.
unsafe impl Sync for LoadedModule {}
27
/// Kernel handle bound to XLOG's default CUDA stream.
///
/// Produced by `CudaDeviceInner::get_func`. Cloning is cheap: it copies the raw
/// function handle and bumps the context/stream reference counts.
///
/// NOTE(review): this handle does not keep its module loaded. The module is
/// unloaded when another load replaces it under the same module name, or when the
/// owning `CudaDeviceInner` is dropped; any `CudaFunction` obtained earlier then
/// holds a stale `CUfunction`. Callers must not launch or query a handle after
/// either event.
#[derive(Debug, Clone)]
pub struct CudaFunction {
    /// Raw driver handle for the kernel; owned by a `LoadedModule`, not by this value.
    cu_function: sys::CUfunction,
    /// Context the kernel's module was loaded into; bound to the calling thread
    /// before launches and attribute reads.
    context: Arc<CudarcContext>,
    /// Stream used by `launch_raw` and `launch_raw_cooperative` (the device's
    /// default stream when obtained through `get_func`).
    stream: Arc<CudaStream>,
}
35
36impl CudaFunction {
37    pub(crate) unsafe fn launch_raw(
38        &self,
39        cfg: LaunchConfig,
40        params: &mut [*mut c_void],
41    ) -> std::result::Result<(), DriverError> {
42        self.context.bind_to_thread()?;
43        result::launch_kernel(
44            self.cu_function,
45            cfg.grid_dim,
46            cfg.block_dim,
47            cfg.shared_mem_bytes,
48            self.stream.cu_stream(),
49            params,
50        )
51    }
52
53    pub(crate) unsafe fn launch_raw_on_stream(
54        &self,
55        stream: &CudaStream,
56        cfg: LaunchConfig,
57        params: &mut [*mut c_void],
58    ) -> std::result::Result<(), DriverError> {
59        self.context.bind_to_thread()?;
60        result::launch_kernel(
61            self.cu_function,
62            cfg.grid_dim,
63            cfg.block_dim,
64            cfg.shared_mem_bytes,
65            stream.cu_stream(),
66            params,
67        )
68    }
69
70    pub(crate) unsafe fn launch_raw_cooperative(
71        &self,
72        cfg: LaunchConfig,
73        params: &mut [*mut c_void],
74    ) -> std::result::Result<(), DriverError> {
75        self.context.bind_to_thread()?;
76        result::launch_cooperative_kernel(
77            self.cu_function,
78            cfg.grid_dim,
79            cfg.block_dim,
80            cfg.shared_mem_bytes,
81            self.stream.cu_stream(),
82            params,
83        )
84    }
85
86    pub fn occupancy_available_dynamic_smem_per_block(
87        &self,
88        num_blocks: u32,
89        block_size: u32,
90    ) -> std::result::Result<usize, DriverError> {
91        let mut dynamic_smem_size: usize = 0;
92        unsafe {
93            sys::cuOccupancyAvailableDynamicSMemPerBlock(
94                &mut dynamic_smem_size,
95                self.cu_function,
96                num_blocks as std::ffi::c_int,
97                block_size as std::ffi::c_int,
98            )
99            .result()?
100        };
101        Ok(dynamic_smem_size)
102    }
103
104    pub fn occupancy_max_active_blocks_per_multiprocessor(
105        &self,
106        block_size: u32,
107        dynamic_smem_size: usize,
108        flags: Option<sys::CUoccupancy_flags_enum>,
109    ) -> std::result::Result<u32, DriverError> {
110        let mut num_blocks: std::ffi::c_int = 0;
111        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
112        unsafe {
113            sys::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
114                &mut num_blocks,
115                self.cu_function,
116                block_size as std::ffi::c_int,
117                dynamic_smem_size,
118                flags as std::ffi::c_uint,
119            )
120            .result()?
121        };
122        Ok(num_blocks as u32)
123    }
124
125    pub fn occupancy_max_active_clusters(
126        &self,
127        config: LaunchConfig,
128    ) -> std::result::Result<u32, DriverError> {
129        let mut num_clusters: std::ffi::c_int = 0;
130        let cfg = sys::CUlaunchConfig {
131            gridDimX: config.grid_dim.0,
132            gridDimY: config.grid_dim.1,
133            gridDimZ: config.grid_dim.2,
134            blockDimX: config.block_dim.0,
135            blockDimY: config.block_dim.1,
136            blockDimZ: config.block_dim.2,
137            sharedMemBytes: config.shared_mem_bytes,
138            hStream: self.stream.cu_stream(),
139            attrs: std::ptr::null_mut(),
140            numAttrs: 0,
141        };
142        unsafe {
143            sys::cuOccupancyMaxActiveClusters(&mut num_clusters, self.cu_function, &cfg).result()?
144        };
145        Ok(num_clusters as u32)
146    }
147
148    pub fn occupancy_max_potential_block_size(
149        &self,
150        block_size_to_dynamic_smem_size: extern "C" fn(block_size: std::ffi::c_int) -> usize,
151        dynamic_smem_size: usize,
152        block_size_limit: u32,
153        flags: Option<sys::CUoccupancy_flags_enum>,
154    ) -> std::result::Result<(u32, u32), DriverError> {
155        let mut min_grid_size: std::ffi::c_int = 0;
156        let mut block_size: std::ffi::c_int = 0;
157        let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
158        unsafe {
159            sys::cuOccupancyMaxPotentialBlockSizeWithFlags(
160                &mut min_grid_size,
161                &mut block_size,
162                self.cu_function,
163                Some(block_size_to_dynamic_smem_size),
164                dynamic_smem_size,
165                block_size_limit as std::ffi::c_int,
166                flags as std::ffi::c_uint,
167            )
168            .result()?
169        };
170        Ok((min_grid_size as u32, block_size as u32))
171    }
172
173    pub fn occupancy_max_potential_cluster_size(
174        &self,
175        config: LaunchConfig,
176    ) -> std::result::Result<u32, DriverError> {
177        let mut cluster_size: std::ffi::c_int = 0;
178        let cfg = sys::CUlaunchConfig {
179            gridDimX: config.grid_dim.0,
180            gridDimY: config.grid_dim.1,
181            gridDimZ: config.grid_dim.2,
182            blockDimX: config.block_dim.0,
183            blockDimY: config.block_dim.1,
184            blockDimZ: config.block_dim.2,
185            sharedMemBytes: config.shared_mem_bytes,
186            hStream: self.stream.cu_stream(),
187            attrs: std::ptr::null_mut(),
188            numAttrs: 0,
189        };
190        unsafe {
191            sys::cuOccupancyMaxPotentialClusterSize(&mut cluster_size, self.cu_function, &cfg)
192                .result()?
193        };
194        Ok(cluster_size as u32)
195    }
196
197    pub fn get_attribute(
198        &self,
199        attribute: sys::CUfunction_attribute_enum,
200    ) -> std::result::Result<i32, DriverError> {
201        self.context.bind_to_thread()?;
202        unsafe { result::function::get_function_attribute(self.cu_function, attribute) }
203    }
204
205    pub fn num_regs(&self) -> std::result::Result<i32, DriverError> {
206        self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_NUM_REGS)
207    }
208
209    pub fn shared_size_bytes(&self) -> std::result::Result<i32, DriverError> {
210        self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
211    }
212
213    pub fn const_size_bytes(&self) -> std::result::Result<i32, DriverError> {
214        self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)
215    }
216
217    pub fn local_size_bytes(&self) -> std::result::Result<i32, DriverError> {
218        self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
219    }
220
221    pub fn max_threads_per_block(&self) -> std::result::Result<i32, DriverError> {
222        self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
223    }
224
225    pub fn ptx_version(&self) -> std::result::Result<i32, DriverError> {
226        self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_PTX_VERSION)
227    }
228
229    pub fn binary_version(&self) -> std::result::Result<i32, DriverError> {
230        self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_BINARY_VERSION)
231    }
232
233    pub fn set_attribute(
234        &self,
235        attribute: sys::CUfunction_attribute_enum,
236        value: i32,
237    ) -> std::result::Result<(), DriverError> {
238        unsafe { result::function::set_function_attribute(self.cu_function, attribute, value) }
239    }
240
241    pub fn set_function_cache_config(
242        &self,
243        config: sys::CUfunc_cache,
244    ) -> std::result::Result<(), DriverError> {
245        unsafe { result::function::set_function_cache_config(self.cu_function, config) }
246    }
247}
248
/// Shared state behind a [`CudaDevice`]: the cudarc context, its default stream,
/// and the modules loaded into that context.
#[derive(Debug)]
pub struct CudaDeviceInner {
    /// cudarc context for this device; bound to the calling thread before module
    /// loads and unloads.
    context: Arc<CudarcContext>,
    /// The context's default stream (set up in `CudaDevice::new`). This type's
    /// allocation and copy helpers run on it, and kernels returned by `get_func`
    /// launch on it by default.
    stream: Arc<CudaStream>,
    /// Loaded modules keyed by the caller-chosen module name.
    modules: RwLock<BTreeMap<String, LoadedModule>>,
}
255
256impl Drop for CudaDeviceInner {
257    fn drop(&mut self) {
258        let _ = self.context.bind_to_thread();
259        if let Ok(modules) = self.modules.get_mut() {
260            for module in modules.values() {
261                let _ = unsafe { result::module::unload(module.cu_module) };
262            }
263            modules.clear();
264        }
265    }
266}
267
268impl CudaDeviceInner {
269    fn insert_module(
270        &self,
271        module_name: &str,
272        cu_module: sys::CUmodule,
273        kernels: &[&str],
274    ) -> std::result::Result<(), DriverError> {
275        let mut functions = BTreeMap::new();
276        for &kernel in kernels {
277            let name_c = CString::new(kernel).unwrap();
278            let cu_function = unsafe { result::module::get_function(cu_module, name_c) }?;
279            functions.insert(kernel.to_string(), cu_function);
280        }
281        let module = LoadedModule {
282            cu_module,
283            functions,
284        };
285
286        let mut modules = self.modules.write().unwrap();
287        if let Some(prev) = modules.insert(module_name.to_string(), module) {
288            unsafe { result::module::unload(prev.cu_module) }?;
289        }
290        Ok(())
291    }
292
293    pub fn stream(&self) -> &Arc<CudaStream> {
294        &self.stream
295    }
296
297    pub fn has_func(&self, module_name: &str, func_name: &str) -> bool {
298        let modules = self.modules.read().unwrap();
299        modules
300            .get(module_name)
301            .is_some_and(|module| module.functions.contains_key(func_name))
302    }
303
304    pub fn get_func(&self, module_name: &str, func_name: &str) -> Option<CudaFunction> {
305        let modules = self.modules.read().unwrap();
306        let cu_function = modules
307            .get(module_name)
308            .and_then(|module| module.functions.get(func_name))
309            .copied()?;
310        Some(CudaFunction {
311            cu_function,
312            context: self.context.clone(),
313            stream: self.stream.clone(),
314        })
315    }
316
317    pub fn load_file(
318        &self,
319        path: &Path,
320        module_name: &str,
321        kernels: &[&str],
322    ) -> std::result::Result<(), DriverError> {
323        self.context.bind_to_thread()?;
324        let name_c = CString::new(path.to_string_lossy().as_bytes()).unwrap();
325        let cu_module = result::module::load(name_c)?;
326        self.insert_module(module_name, cu_module, kernels)
327    }
328
329    pub fn load_ptx(
330        &self,
331        ptx: Ptx,
332        module_name: &str,
333        kernels: &[&str],
334    ) -> std::result::Result<(), DriverError> {
335        self.context.bind_to_thread()?;
336        let cu_module = if let Some(bytes) = ptx.as_bytes() {
337            unsafe { result::module::load_data(bytes.as_ptr() as *const _) }?
338        } else {
339            let src = CString::new(ptx.to_src()).unwrap();
340            unsafe { result::module::load_data(src.as_ptr() as *const _) }?
341        };
342        self.insert_module(module_name, cu_module, kernels)
343    }
344
345    /// Allocate an uninitialized device slice on this device stream.
346    ///
347    /// # Safety
348    ///
349    /// The caller must initialize the returned allocation before any device or
350    /// host read observes its contents.
351    pub unsafe fn alloc<T: DeviceRepr>(
352        &self,
353        len: usize,
354    ) -> std::result::Result<CudaSlice<T>, DriverError> {
355        self.stream.alloc(len)
356    }
357
358    pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(
359        &self,
360        len: usize,
361    ) -> std::result::Result<CudaSlice<T>, DriverError> {
362        self.stream.alloc_zeros(len)
363    }
364
365    pub fn memset_zeros<T: DeviceRepr + ValidAsZeroBits, Dst: DevicePtrMut<T>>(
366        &self,
367        dst: &mut Dst,
368    ) -> std::result::Result<(), DriverError> {
369        self.stream.memset_zeros(dst)?;
370        self.stream.synchronize()
371    }
372
373    pub fn htod_sync_copy_into<T: DeviceRepr, Dst: DevicePtrMut<T>, Src: HostSlice<T> + ?Sized>(
374        &self,
375        src: &Src,
376        dst: &mut Dst,
377    ) -> std::result::Result<(), DriverError> {
378        self.stream.memcpy_htod(src, dst)?;
379        self.stream.synchronize()
380    }
381
382    pub fn dtoh_sync_copy_into<T: DeviceRepr, Src: DevicePtr<T>, Dst: HostSlice<T> + ?Sized>(
383        &self,
384        src: &Src,
385        dst: &mut Dst,
386    ) -> std::result::Result<(), DriverError> {
387        self.stream.memcpy_dtoh(src, dst)?;
388        self.stream.synchronize()
389    }
390
391    pub fn htod_sync_copy<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
392        &self,
393        src: &Src,
394    ) -> std::result::Result<CudaSlice<T>, DriverError> {
395        let dst = self.stream.clone_htod(src)?;
396        self.stream.synchronize()?;
397        Ok(dst)
398    }
399
400    pub fn dtoh_sync_copy<T: DeviceRepr, Src: DevicePtr<T>>(
401        &self,
402        src: &Src,
403    ) -> std::result::Result<Vec<T>, DriverError> {
404        let dst = self.stream.clone_dtoh(src)?;
405        self.stream.synchronize()?;
406        Ok(dst)
407    }
408
409    pub fn dtod_copy<T, Src: DevicePtr<T>, Dst: DevicePtrMut<T>>(
410        &self,
411        src: &Src,
412        dst: &mut Dst,
413    ) -> std::result::Result<(), DriverError> {
414        self.stream.memcpy_dtod(src, dst)?;
415        self.stream.synchronize()
416    }
417
418    /// Wrap an existing CUDA device pointer in a typed cudarc slice.
419    ///
420    /// # Safety
421    ///
422    /// `cu_device_ptr` must point to a live allocation containing at least
423    /// `len * size_of::<T>()` bytes, and the resulting wrapper must not outlive
424    /// the allocation or alias another owner that will free it independently.
425    pub unsafe fn upgrade_device_ptr<T>(
426        &self,
427        cu_device_ptr: sys::CUdeviceptr,
428        len: usize,
429    ) -> CudaSlice<T> {
430        self.stream.upgrade_device_ptr(cu_device_ptr, len)
431    }
432
433    pub fn attribute(
434        &self,
435        attrib: sys::CUdevice_attribute,
436    ) -> std::result::Result<i32, DriverError> {
437        self.context.attribute(attrib)
438    }
439
440    pub fn synchronize(&self) -> std::result::Result<(), DriverError> {
441        self.stream.synchronize()
442    }
443
444    pub fn ordinal(&self) -> usize {
445        self.context.ordinal()
446    }
447}
448
449/// CUDA device wrapper for GPU operations.
450///
451/// This keeps XLOG's historical "device with a built-in default stream" API,
452/// but is backed by cudarc's newer `CudaContext` and `CudaStream`.
453pub struct CudaDevice {
454    device: Arc<CudaDeviceInner>,
455}
456
457impl CudaDevice {
458    /// Create a new CUDA device on the specified GPU ordinal.
459    pub fn new(ordinal: usize) -> Result<Self> {
460        let context = std::panic::catch_unwind(|| CudarcContext::new(ordinal))
461            .map_err(|_| {
462                XlogError::Kernel(format!(
463                    "Failed to create CUDA device {}: cudarc panicked during driver initialization",
464                    ordinal
465                ))
466            })?
467            .map_err(|e| {
468                XlogError::Kernel(format!("Failed to create CUDA device {}: {}", ordinal, e))
469            })?;
470
471        let stream = context.default_stream();
472        Ok(Self {
473            device: Arc::new(CudaDeviceInner {
474                context,
475                stream,
476                modules: RwLock::new(BTreeMap::new()),
477            }),
478        })
479    }
480
481    pub fn count() -> Result<i32> {
482        std::panic::catch_unwind(|| {
483            result::init()?;
484            result::device::get_count()
485        })
486        .map_err(|_| {
487            XlogError::Kernel(
488                "Failed to count CUDA devices: cudarc panicked during driver initialization"
489                    .to_string(),
490            )
491        })?
492        .map_err(|e| XlogError::Kernel(format!("Failed to count CUDA devices: {}", e)))
493    }
494
495    pub fn synchronize(&self) -> Result<()> {
496        self.device
497            .synchronize()
498            .map_err(|e| XlogError::Kernel(format!("Failed to synchronize device: {}", e)))
499    }
500
501    pub fn inner(&self) -> &Arc<CudaDeviceInner> {
502        &self.device
503    }
504
505    pub fn ordinal(&self) -> usize {
506        self.device.ordinal()
507    }
508}
509
510// Compile-time assertion: CudaDevice must be Send so pyxlog can use py.allow_threads().
511const _: () = {
512    fn _assert_send<T: Send>() {}
513    fn _check() {
514        _assert_send::<CudaDevice>();
515    }
516};
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Open GPU 0 for a test, or log why the calling test is being skipped.
    fn device_or_skip() -> Option<CudaDevice> {
        match CudaDevice::new(0) {
            Ok(device) => Some(device),
            Err(e) => {
                eprintln!("Skipping test: CUDA runtime unavailable: {}", e);
                None
            }
        }
    }

    #[test]
    fn test_device_creation() {
        let Some(device) = device_or_skip() else {
            return;
        };
        drop(device);
    }

    #[test]
    fn test_device_synchronize() {
        let Some(device) = device_or_skip() else {
            return;
        };
        let outcome = device.synchronize();
        assert!(outcome.is_ok(), "Failed to synchronize: {:?}", outcome.err());
    }

    #[test]
    fn test_device_ordinal() {
        let Some(device) = device_or_skip() else {
            return;
        };
        assert_eq!(device.ordinal(), 0);
    }

    #[test]
    fn test_device_inner_access() {
        let Some(device) = device_or_skip() else {
            return;
        };
        assert_eq!(device.inner().ordinal(), 0);
    }

    #[test]
    fn test_invalid_device_ordinal() {
        match CudaDevice::new(9999) {
            Ok(_) => panic!("Should fail with invalid ordinal"),
            Err(XlogError::Kernel(msg)) => {
                assert!(msg.contains("9999"), "Error should mention device ordinal");
            }
            Err(_) => panic!("Expected XlogError::Kernel"),
        }
    }
}
583}