1use std::collections::BTreeMap;
7use std::ffi::{c_void, CString};
8use std::path::Path;
9use std::sync::{Arc, RwLock};
10
11use cudarc::driver::result::{self, DriverError};
12use cudarc::driver::{
13 sys, CudaContext as CudarcContext, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr,
14 HostSlice, LaunchConfig, ValidAsZeroBits,
15};
16use cudarc::nvrtc::Ptx;
17use xlog_core::{Result, XlogError};
18
/// A CUDA module loaded into a device context, together with the kernel handles that were
/// resolved from it by name when it was loaded.
#[derive(Debug)]
struct LoadedModule {
    /// Raw driver module handle; `CudaDeviceInner` unloads it when the module is replaced
    /// under the same name or when the device state is dropped.
    cu_module: sys::CUmodule,
    /// Kernel name -> raw function handle inside `cu_module`.
    functions: BTreeMap<String, sys::CUfunction>,
}
24
// SAFETY (NOTE(review): confirm against the CUDA driver's thread-safety guarantees):
// `LoadedModule` only stores opaque driver handles (`CUmodule`, `CUfunction`). In this file
// they are only handed to driver calls (function lookup, unload, launches), and the map that
// owns them is guarded by the `RwLock` in `CudaDeviceInner`. Soundness therefore relies on
// the CUDA driver API accepting these handles from any thread.
unsafe impl Send for LoadedModule {}
unsafe impl Sync for LoadedModule {}
27
/// A kernel handle looked up from a module owned by `CudaDeviceInner`, bundled with the
/// context it belongs to and the stream that `launch_raw` / the occupancy helpers target.
///
/// Cloning is cheap: it copies the raw handle and bumps two `Arc` counts.
///
/// NOTE(review): the raw `CUfunction` does not keep its module alive. It becomes invalid
/// once the owning module is unloaded (on replacement or when the device state drops).
#[derive(Debug, Clone)]
pub struct CudaFunction {
    cu_function: sys::CUfunction,
    context: Arc<CudarcContext>,
    stream: Arc<CudaStream>,
}
35
36impl CudaFunction {
37 pub(crate) unsafe fn launch_raw(
38 &self,
39 cfg: LaunchConfig,
40 params: &mut [*mut c_void],
41 ) -> std::result::Result<(), DriverError> {
42 self.context.bind_to_thread()?;
43 result::launch_kernel(
44 self.cu_function,
45 cfg.grid_dim,
46 cfg.block_dim,
47 cfg.shared_mem_bytes,
48 self.stream.cu_stream(),
49 params,
50 )
51 }
52
53 pub(crate) unsafe fn launch_raw_on_stream(
54 &self,
55 stream: &CudaStream,
56 cfg: LaunchConfig,
57 params: &mut [*mut c_void],
58 ) -> std::result::Result<(), DriverError> {
59 self.context.bind_to_thread()?;
60 result::launch_kernel(
61 self.cu_function,
62 cfg.grid_dim,
63 cfg.block_dim,
64 cfg.shared_mem_bytes,
65 stream.cu_stream(),
66 params,
67 )
68 }
69
70 pub(crate) unsafe fn launch_raw_cooperative(
71 &self,
72 cfg: LaunchConfig,
73 params: &mut [*mut c_void],
74 ) -> std::result::Result<(), DriverError> {
75 self.context.bind_to_thread()?;
76 result::launch_cooperative_kernel(
77 self.cu_function,
78 cfg.grid_dim,
79 cfg.block_dim,
80 cfg.shared_mem_bytes,
81 self.stream.cu_stream(),
82 params,
83 )
84 }
85
86 pub fn occupancy_available_dynamic_smem_per_block(
87 &self,
88 num_blocks: u32,
89 block_size: u32,
90 ) -> std::result::Result<usize, DriverError> {
91 let mut dynamic_smem_size: usize = 0;
92 unsafe {
93 sys::cuOccupancyAvailableDynamicSMemPerBlock(
94 &mut dynamic_smem_size,
95 self.cu_function,
96 num_blocks as std::ffi::c_int,
97 block_size as std::ffi::c_int,
98 )
99 .result()?
100 };
101 Ok(dynamic_smem_size)
102 }
103
104 pub fn occupancy_max_active_blocks_per_multiprocessor(
105 &self,
106 block_size: u32,
107 dynamic_smem_size: usize,
108 flags: Option<sys::CUoccupancy_flags_enum>,
109 ) -> std::result::Result<u32, DriverError> {
110 let mut num_blocks: std::ffi::c_int = 0;
111 let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
112 unsafe {
113 sys::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
114 &mut num_blocks,
115 self.cu_function,
116 block_size as std::ffi::c_int,
117 dynamic_smem_size,
118 flags as std::ffi::c_uint,
119 )
120 .result()?
121 };
122 Ok(num_blocks as u32)
123 }
124
125 pub fn occupancy_max_active_clusters(
126 &self,
127 config: LaunchConfig,
128 ) -> std::result::Result<u32, DriverError> {
129 let mut num_clusters: std::ffi::c_int = 0;
130 let cfg = sys::CUlaunchConfig {
131 gridDimX: config.grid_dim.0,
132 gridDimY: config.grid_dim.1,
133 gridDimZ: config.grid_dim.2,
134 blockDimX: config.block_dim.0,
135 blockDimY: config.block_dim.1,
136 blockDimZ: config.block_dim.2,
137 sharedMemBytes: config.shared_mem_bytes,
138 hStream: self.stream.cu_stream(),
139 attrs: std::ptr::null_mut(),
140 numAttrs: 0,
141 };
142 unsafe {
143 sys::cuOccupancyMaxActiveClusters(&mut num_clusters, self.cu_function, &cfg).result()?
144 };
145 Ok(num_clusters as u32)
146 }
147
148 pub fn occupancy_max_potential_block_size(
149 &self,
150 block_size_to_dynamic_smem_size: extern "C" fn(block_size: std::ffi::c_int) -> usize,
151 dynamic_smem_size: usize,
152 block_size_limit: u32,
153 flags: Option<sys::CUoccupancy_flags_enum>,
154 ) -> std::result::Result<(u32, u32), DriverError> {
155 let mut min_grid_size: std::ffi::c_int = 0;
156 let mut block_size: std::ffi::c_int = 0;
157 let flags = flags.unwrap_or(sys::CUoccupancy_flags_enum::CU_OCCUPANCY_DEFAULT);
158 unsafe {
159 sys::cuOccupancyMaxPotentialBlockSizeWithFlags(
160 &mut min_grid_size,
161 &mut block_size,
162 self.cu_function,
163 Some(block_size_to_dynamic_smem_size),
164 dynamic_smem_size,
165 block_size_limit as std::ffi::c_int,
166 flags as std::ffi::c_uint,
167 )
168 .result()?
169 };
170 Ok((min_grid_size as u32, block_size as u32))
171 }
172
173 pub fn occupancy_max_potential_cluster_size(
174 &self,
175 config: LaunchConfig,
176 ) -> std::result::Result<u32, DriverError> {
177 let mut cluster_size: std::ffi::c_int = 0;
178 let cfg = sys::CUlaunchConfig {
179 gridDimX: config.grid_dim.0,
180 gridDimY: config.grid_dim.1,
181 gridDimZ: config.grid_dim.2,
182 blockDimX: config.block_dim.0,
183 blockDimY: config.block_dim.1,
184 blockDimZ: config.block_dim.2,
185 sharedMemBytes: config.shared_mem_bytes,
186 hStream: self.stream.cu_stream(),
187 attrs: std::ptr::null_mut(),
188 numAttrs: 0,
189 };
190 unsafe {
191 sys::cuOccupancyMaxPotentialClusterSize(&mut cluster_size, self.cu_function, &cfg)
192 .result()?
193 };
194 Ok(cluster_size as u32)
195 }
196
197 pub fn get_attribute(
198 &self,
199 attribute: sys::CUfunction_attribute_enum,
200 ) -> std::result::Result<i32, DriverError> {
201 self.context.bind_to_thread()?;
202 unsafe { result::function::get_function_attribute(self.cu_function, attribute) }
203 }
204
205 pub fn num_regs(&self) -> std::result::Result<i32, DriverError> {
206 self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_NUM_REGS)
207 }
208
209 pub fn shared_size_bytes(&self) -> std::result::Result<i32, DriverError> {
210 self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)
211 }
212
213 pub fn const_size_bytes(&self) -> std::result::Result<i32, DriverError> {
214 self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES)
215 }
216
217 pub fn local_size_bytes(&self) -> std::result::Result<i32, DriverError> {
218 self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES)
219 }
220
221 pub fn max_threads_per_block(&self) -> std::result::Result<i32, DriverError> {
222 self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK)
223 }
224
225 pub fn ptx_version(&self) -> std::result::Result<i32, DriverError> {
226 self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_PTX_VERSION)
227 }
228
229 pub fn binary_version(&self) -> std::result::Result<i32, DriverError> {
230 self.get_attribute(sys::CUfunction_attribute_enum::CU_FUNC_ATTRIBUTE_BINARY_VERSION)
231 }
232
233 pub fn set_attribute(
234 &self,
235 attribute: sys::CUfunction_attribute_enum,
236 value: i32,
237 ) -> std::result::Result<(), DriverError> {
238 unsafe { result::function::set_function_attribute(self.cu_function, attribute, value) }
239 }
240
241 pub fn set_function_cache_config(
242 &self,
243 config: sys::CUfunc_cache,
244 ) -> std::result::Result<(), DriverError> {
245 unsafe { result::function::set_function_cache_config(self.cu_function, config) }
246 }
247}
248
/// Shared state behind a `CudaDevice`: the driver context, the stream used for copies,
/// allocations and kernel launches, and the modules loaded so far, keyed by module name.
#[derive(Debug)]
pub struct CudaDeviceInner {
    context: Arc<CudarcContext>,
    /// Stream every operation of this type is issued on. `CudaDevice::new` sets it to the
    /// context's default stream.
    stream: Arc<CudaStream>,
    /// Loaded modules. A module is unloaded when another is registered under the same name,
    /// and all remaining modules are unloaded when this value is dropped.
    modules: RwLock<BTreeMap<String, LoadedModule>>,
}
255
256impl Drop for CudaDeviceInner {
257 fn drop(&mut self) {
258 let _ = self.context.bind_to_thread();
259 if let Ok(modules) = self.modules.get_mut() {
260 for module in modules.values() {
261 let _ = unsafe { result::module::unload(module.cu_module) };
262 }
263 modules.clear();
264 }
265 }
266}
267
268impl CudaDeviceInner {
269 fn insert_module(
270 &self,
271 module_name: &str,
272 cu_module: sys::CUmodule,
273 kernels: &[&str],
274 ) -> std::result::Result<(), DriverError> {
275 let mut functions = BTreeMap::new();
276 for &kernel in kernels {
277 let name_c = CString::new(kernel).unwrap();
278 let cu_function = unsafe { result::module::get_function(cu_module, name_c) }?;
279 functions.insert(kernel.to_string(), cu_function);
280 }
281 let module = LoadedModule {
282 cu_module,
283 functions,
284 };
285
286 let mut modules = self.modules.write().unwrap();
287 if let Some(prev) = modules.insert(module_name.to_string(), module) {
288 unsafe { result::module::unload(prev.cu_module) }?;
289 }
290 Ok(())
291 }
292
293 pub fn stream(&self) -> &Arc<CudaStream> {
294 &self.stream
295 }
296
297 pub fn has_func(&self, module_name: &str, func_name: &str) -> bool {
298 let modules = self.modules.read().unwrap();
299 modules
300 .get(module_name)
301 .is_some_and(|module| module.functions.contains_key(func_name))
302 }
303
304 pub fn get_func(&self, module_name: &str, func_name: &str) -> Option<CudaFunction> {
305 let modules = self.modules.read().unwrap();
306 let cu_function = modules
307 .get(module_name)
308 .and_then(|module| module.functions.get(func_name))
309 .copied()?;
310 Some(CudaFunction {
311 cu_function,
312 context: self.context.clone(),
313 stream: self.stream.clone(),
314 })
315 }
316
317 pub fn load_file(
318 &self,
319 path: &Path,
320 module_name: &str,
321 kernels: &[&str],
322 ) -> std::result::Result<(), DriverError> {
323 self.context.bind_to_thread()?;
324 let name_c = CString::new(path.to_string_lossy().as_bytes()).unwrap();
325 let cu_module = result::module::load(name_c)?;
326 self.insert_module(module_name, cu_module, kernels)
327 }
328
329 pub fn load_ptx(
330 &self,
331 ptx: Ptx,
332 module_name: &str,
333 kernels: &[&str],
334 ) -> std::result::Result<(), DriverError> {
335 self.context.bind_to_thread()?;
336 let cu_module = if let Some(bytes) = ptx.as_bytes() {
337 unsafe { result::module::load_data(bytes.as_ptr() as *const _) }?
338 } else {
339 let src = CString::new(ptx.to_src()).unwrap();
340 unsafe { result::module::load_data(src.as_ptr() as *const _) }?
341 };
342 self.insert_module(module_name, cu_module, kernels)
343 }
344
345 pub unsafe fn alloc<T: DeviceRepr>(
352 &self,
353 len: usize,
354 ) -> std::result::Result<CudaSlice<T>, DriverError> {
355 self.stream.alloc(len)
356 }
357
358 pub fn alloc_zeros<T: DeviceRepr + ValidAsZeroBits>(
359 &self,
360 len: usize,
361 ) -> std::result::Result<CudaSlice<T>, DriverError> {
362 self.stream.alloc_zeros(len)
363 }
364
365 pub fn memset_zeros<T: DeviceRepr + ValidAsZeroBits, Dst: DevicePtrMut<T>>(
366 &self,
367 dst: &mut Dst,
368 ) -> std::result::Result<(), DriverError> {
369 self.stream.memset_zeros(dst)?;
370 self.stream.synchronize()
371 }
372
373 pub fn htod_sync_copy_into<T: DeviceRepr, Dst: DevicePtrMut<T>, Src: HostSlice<T> + ?Sized>(
374 &self,
375 src: &Src,
376 dst: &mut Dst,
377 ) -> std::result::Result<(), DriverError> {
378 self.stream.memcpy_htod(src, dst)?;
379 self.stream.synchronize()
380 }
381
382 pub fn dtoh_sync_copy_into<T: DeviceRepr, Src: DevicePtr<T>, Dst: HostSlice<T> + ?Sized>(
383 &self,
384 src: &Src,
385 dst: &mut Dst,
386 ) -> std::result::Result<(), DriverError> {
387 self.stream.memcpy_dtoh(src, dst)?;
388 self.stream.synchronize()
389 }
390
391 pub fn htod_sync_copy<T: DeviceRepr, Src: HostSlice<T> + ?Sized>(
392 &self,
393 src: &Src,
394 ) -> std::result::Result<CudaSlice<T>, DriverError> {
395 let dst = self.stream.clone_htod(src)?;
396 self.stream.synchronize()?;
397 Ok(dst)
398 }
399
400 pub fn dtoh_sync_copy<T: DeviceRepr, Src: DevicePtr<T>>(
401 &self,
402 src: &Src,
403 ) -> std::result::Result<Vec<T>, DriverError> {
404 let dst = self.stream.clone_dtoh(src)?;
405 self.stream.synchronize()?;
406 Ok(dst)
407 }
408
409 pub fn dtod_copy<T, Src: DevicePtr<T>, Dst: DevicePtrMut<T>>(
410 &self,
411 src: &Src,
412 dst: &mut Dst,
413 ) -> std::result::Result<(), DriverError> {
414 self.stream.memcpy_dtod(src, dst)?;
415 self.stream.synchronize()
416 }
417
418 pub unsafe fn upgrade_device_ptr<T>(
426 &self,
427 cu_device_ptr: sys::CUdeviceptr,
428 len: usize,
429 ) -> CudaSlice<T> {
430 self.stream.upgrade_device_ptr(cu_device_ptr, len)
431 }
432
433 pub fn attribute(
434 &self,
435 attrib: sys::CUdevice_attribute,
436 ) -> std::result::Result<i32, DriverError> {
437 self.context.attribute(attrib)
438 }
439
440 pub fn synchronize(&self) -> std::result::Result<(), DriverError> {
441 self.stream.synchronize()
442 }
443
444 pub fn ordinal(&self) -> usize {
445 self.context.ordinal()
446 }
447}
448
/// Handle to one CUDA device.
///
/// Its state lives in a shared `CudaDeviceInner`, which `CudaDevice::inner` exposes for
/// lower-level operations (module loading, memory transfers, kernel lookup).
pub struct CudaDevice {
    device: Arc<CudaDeviceInner>,
}
456
457impl CudaDevice {
458 pub fn new(ordinal: usize) -> Result<Self> {
460 let context = std::panic::catch_unwind(|| CudarcContext::new(ordinal))
461 .map_err(|_| {
462 XlogError::Kernel(format!(
463 "Failed to create CUDA device {}: cudarc panicked during driver initialization",
464 ordinal
465 ))
466 })?
467 .map_err(|e| {
468 XlogError::Kernel(format!("Failed to create CUDA device {}: {}", ordinal, e))
469 })?;
470
471 let stream = context.default_stream();
472 Ok(Self {
473 device: Arc::new(CudaDeviceInner {
474 context,
475 stream,
476 modules: RwLock::new(BTreeMap::new()),
477 }),
478 })
479 }
480
481 pub fn count() -> Result<i32> {
482 std::panic::catch_unwind(|| {
483 result::init()?;
484 result::device::get_count()
485 })
486 .map_err(|_| {
487 XlogError::Kernel(
488 "Failed to count CUDA devices: cudarc panicked during driver initialization"
489 .to_string(),
490 )
491 })?
492 .map_err(|e| XlogError::Kernel(format!("Failed to count CUDA devices: {}", e)))
493 }
494
495 pub fn synchronize(&self) -> Result<()> {
496 self.device
497 .synchronize()
498 .map_err(|e| XlogError::Kernel(format!("Failed to synchronize device: {}", e)))
499 }
500
501 pub fn inner(&self) -> &Arc<CudaDeviceInner> {
502 &self.device
503 }
504
505 pub fn ordinal(&self) -> usize {
506 self.device.ordinal()
507 }
508}
509
510const _: () = {
512 fn _assert_send<T: Send>() {}
513 fn _check() {
514 _assert_send::<CudaDevice>();
515 }
516};
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens device 0. Returns `None`, after logging why, when no CUDA runtime is usable,
    /// so GPU-dependent tests skip instead of failing on machines without a GPU.
    fn device_or_skip() -> Option<CudaDevice> {
        match CudaDevice::new(0) {
            Ok(device) => Some(device),
            Err(e) => {
                eprintln!("Skipping test: CUDA runtime unavailable: {}", e);
                None
            }
        }
    }

    #[test]
    fn test_device_creation() {
        let Some(device) = device_or_skip() else {
            return;
        };
        drop(device);
    }

    #[test]
    fn test_device_synchronize() {
        let Some(device) = device_or_skip() else {
            return;
        };
        let result = device.synchronize();
        assert!(result.is_ok(), "Failed to synchronize: {:?}", result.err());
    }

    #[test]
    fn test_device_ordinal() {
        let Some(device) = device_or_skip() else {
            return;
        };
        assert_eq!(device.ordinal(), 0);
    }

    #[test]
    fn test_device_inner_access() {
        let Some(device) = device_or_skip() else {
            return;
        };
        assert_eq!(device.inner().ordinal(), 0);
    }

    #[test]
    fn test_invalid_device_ordinal() {
        match CudaDevice::new(9999) {
            Ok(_) => panic!("Should fail with invalid ordinal"),
            Err(XlogError::Kernel(msg)) => {
                assert!(msg.contains("9999"), "Error should mention device ordinal");
            }
            Err(_) => panic!("Expected XlogError::Kernel"),
        }
    }
}
583}