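//! `xlog_cuda/cuda_compat.rs`: compatibility layer that reimplements the old
//! cudarc-style kernel-launch API (`LaunchAsync` plus kernel-parameter storage)
//! on top of CUDA 13-compatible raw kernel launches.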

1use std::ffi::c_void;
2
3use cudarc::driver::{self, SyncOnDrop};
4
5pub use cudarc::driver::{
6    sys, CudaSlice, CudaStream, CudaView, CudaViewMut, DevicePtr, DevicePtrMut, DeviceRepr,
7    DeviceSlice, DriverError, LaunchConfig, ValidAsZeroBits,
8};
9
10pub use crate::device::CudaFunction;
11
mod sealed {
    /// Private supertrait that closes `KernelScalar` to the primitive types sealed below.
    pub trait KernelScalarSealed {}

    /// Emit an empty `KernelScalarSealed` impl for every listed type.
    macro_rules! seal {
        ($($scalar:ty),+ $(,)?) => {
            $(
                impl KernelScalarSealed for $scalar {}
            )+
        };
    }

    seal!(bool);
    seal!(i8, i16, i32, i64, i128, isize);
    seal!(u8, u16, u32, u64, u128, usize);
    seal!(f32, f64);
}
25
/// Host-side storage whose address can be placed in the kernel-parameter array
/// handed to `cuLaunchKernel`.
///
/// The returned pointer is derived from `&self`, so it is only meaningful while
/// the storage value is alive and has not been moved.
pub trait KernelParamStorage {
    fn as_kernel_param(&self) -> *mut c_void;
}

/// Owns a by-value scalar kernel argument so that its address stays stable for
/// as long as this wrapper lives.
#[derive(Debug)]
pub struct ScalarParamStorage<T>(T);

impl<T> KernelParamStorage for ScalarParamStorage<T> {
    fn as_kernel_param(&self) -> *mut c_void {
        let value: *const T = &self.0;
        value as *mut c_void
    }
}
39
40#[derive(Debug)]
41pub struct DeviceParamStorage<'a> {
42    ptr: driver::sys::CUdeviceptr,
43    _sync: Option<SyncOnDrop<'a>>,
44}
45
46impl<'a> DeviceParamStorage<'a> {
47    pub fn synced(ptr: driver::sys::CUdeviceptr, sync: SyncOnDrop<'a>) -> Self {
48        Self {
49            ptr,
50            _sync: Some(sync),
51        }
52    }
53
54    pub fn unsynced(ptr: driver::sys::CUdeviceptr) -> Self {
55        Self { ptr, _sync: None }
56    }
57}
58
59impl KernelParamStorage for DeviceParamStorage<'_> {
60    fn as_kernel_param(&self) -> *mut c_void {
61        (&self.ptr as *const driver::sys::CUdeviceptr)
62            .cast_mut()
63            .cast()
64    }
65}
66
67/// Backwards-compatible `as_kernel_param()` helper for manual raw launch lists.
68pub trait AsKernelParam {
69    fn as_kernel_param(&self) -> *mut c_void;
70}
71
72/// Convert a launch argument into storage that lives until `cuLaunchKernel` runs.
73pub trait IntoKernelParamStorage {
74    type Storage: KernelParamStorage;
75
76    fn into_kernel_param_storage(self) -> Self::Storage;
77}
78
79/// Scalar kernel parameters that can be copied directly into launch storage.
80pub trait KernelScalar:
81    sealed::KernelScalarSealed
82    + cudarc::driver::DeviceRepr
83    + Copy
84    + 'static
85    + AsKernelParam
86    + IntoKernelParamStorage
87{
88}
89
90macro_rules! impl_kernel_scalar {
91    ($($ty:ty),* $(,)?) => {
92        $(
93            impl KernelScalar for $ty {}
94
95            impl AsKernelParam for $ty {
96                fn as_kernel_param(&self) -> *mut c_void {
97                    (self as *const $ty).cast_mut().cast()
98                }
99            }
100
101            impl AsKernelParam for &$ty {
102                fn as_kernel_param(&self) -> *mut c_void {
103                    (*self as *const $ty).cast_mut().cast()
104                }
105            }
106
107            impl AsKernelParam for &mut $ty {
108                fn as_kernel_param(&self) -> *mut c_void {
109                    (*self as *const $ty).cast_mut().cast()
110                }
111            }
112
113            impl IntoKernelParamStorage for $ty {
114                type Storage = ScalarParamStorage<$ty>;
115
116                fn into_kernel_param_storage(self) -> Self::Storage {
117                    ScalarParamStorage(self)
118                }
119            }
120        )*
121    };
122}
123
124impl_kernel_scalar!(bool, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64);
125
126impl<'a, T> IntoKernelParamStorage for &'a CudaSlice<T> {
127    type Storage = DeviceParamStorage<'a>;
128
129    fn into_kernel_param_storage(self) -> Self::Storage {
130        let stream = self.stream();
131        let (ptr, sync) = cudarc::driver::DevicePtr::device_ptr(self, stream);
132        DeviceParamStorage::synced(ptr, sync)
133    }
134}
135
136impl<T> IntoKernelParamStorage for &mut CudaSlice<T> {
137    type Storage = DeviceParamStorage<'static>;
138
139    fn into_kernel_param_storage(self) -> Self::Storage {
140        let stream = self.stream().clone();
141        let (ptr, sync) = cudarc::driver::DevicePtrMut::device_ptr_mut(self, &stream);
142        std::mem::forget(sync);
143        DeviceParamStorage::unsynced(ptr)
144    }
145}
146
147impl<'a, 'b, T> IntoKernelParamStorage for &'a CudaView<'b, T> {
148    type Storage = DeviceParamStorage<'a>;
149
150    fn into_kernel_param_storage(self) -> Self::Storage {
151        let stream = self.stream();
152        let (ptr, sync) = cudarc::driver::DevicePtr::device_ptr(self, stream);
153        DeviceParamStorage::synced(ptr, sync)
154    }
155}
156
157impl<'a, 'b, T> IntoKernelParamStorage for &'a CudaViewMut<'b, T> {
158    type Storage = DeviceParamStorage<'a>;
159
160    fn into_kernel_param_storage(self) -> Self::Storage {
161        let stream = self.stream();
162        let (ptr, sync) = cudarc::driver::DevicePtr::device_ptr(self, stream);
163        DeviceParamStorage::synced(ptr, sync)
164    }
165}
166
167impl<'a, 'b, T> IntoKernelParamStorage for &'a mut CudaViewMut<'b, T> {
168    type Storage = DeviceParamStorage<'static>;
169
170    fn into_kernel_param_storage(self) -> Self::Storage {
171        let stream = self.stream().clone();
172        let (ptr, sync) = cudarc::driver::DevicePtrMut::device_ptr_mut(self, &stream);
173        std::mem::forget(sync);
174        DeviceParamStorage::unsynced(ptr)
175    }
176}
177
178/// Old cudarc-style launch trait reimplemented on top of CUDA 13-compatible
179/// raw kernel launches.
180///
181/// # Safety
182/// Implementors must preserve CUDA's launch semantics and must not let kernel
183/// parameter storage or referenced device memory expire before the launch is
184/// enqueued on the target stream.
185pub unsafe trait LaunchAsync<Params> {
186    /// Launch a kernel on the function's default stream.
187    ///
188    /// # Safety
189    /// `params` must match the underlying CUDA kernel ABI exactly, and all
190    /// referenced device pointers must stay valid until the launch is enqueued.
191    unsafe fn launch(self, cfg: LaunchConfig, params: Params) -> Result<(), DriverError>;
192
193    /// Launch a kernel on an explicit CUDA stream.
194    ///
195    /// # Safety
196    /// The caller must uphold the same ABI and lifetime guarantees as `launch`
197    /// and must ensure `stream` is valid for the target device.
198    unsafe fn launch_on_stream(
199        self,
200        stream: &CudaStream,
201        cfg: LaunchConfig,
202        params: Params,
203    ) -> Result<(), DriverError>;
204
205    /// Launch a cooperative kernel.
206    ///
207    /// # Safety
208    /// The caller must uphold the same ABI and lifetime guarantees as `launch`
209    /// and must also ensure the kernel/configuration satisfies CUDA cooperative
210    /// launch requirements.
211    unsafe fn launch_cooperative(
212        self,
213        cfg: LaunchConfig,
214        params: Params,
215    ) -> Result<(), DriverError>;
216}
217
218unsafe impl LaunchAsync<&mut [*mut c_void]> for CudaFunction {
219    unsafe fn launch(
220        self,
221        cfg: LaunchConfig,
222        params: &mut [*mut c_void],
223    ) -> Result<(), DriverError> {
224        self.launch_raw(cfg, params)
225    }
226
227    unsafe fn launch_on_stream(
228        self,
229        stream: &CudaStream,
230        cfg: LaunchConfig,
231        params: &mut [*mut c_void],
232    ) -> Result<(), DriverError> {
233        self.launch_raw_on_stream(stream, cfg, params)
234    }
235
236    unsafe fn launch_cooperative(
237        self,
238        cfg: LaunchConfig,
239        params: &mut [*mut c_void],
240    ) -> Result<(), DriverError> {
241        self.launch_raw_cooperative(cfg, params)
242    }
243}
244
245unsafe impl LaunchAsync<&mut Vec<*mut c_void>> for CudaFunction {
246    unsafe fn launch(
247        self,
248        cfg: LaunchConfig,
249        params: &mut Vec<*mut c_void>,
250    ) -> Result<(), DriverError> {
251        self.launch_raw(cfg, params)
252    }
253
254    unsafe fn launch_on_stream(
255        self,
256        stream: &CudaStream,
257        cfg: LaunchConfig,
258        params: &mut Vec<*mut c_void>,
259    ) -> Result<(), DriverError> {
260        self.launch_raw_on_stream(stream, cfg, params)
261    }
262
263    unsafe fn launch_cooperative(
264        self,
265        cfg: LaunchConfig,
266        params: &mut Vec<*mut c_void>,
267    ) -> Result<(), DriverError> {
268        self.launch_raw_cooperative(cfg, params)
269    }
270}
271
272macro_rules! impl_launch_tuple {
273    ([$($var:ident),*], [$($idx:tt),*]) => {
274        #[allow(non_snake_case)]
275        unsafe impl<$($var: IntoKernelParamStorage),*> LaunchAsync<($($var,)*)> for CudaFunction {
276            unsafe fn launch(
277                self,
278                cfg: LaunchConfig,
279                params: ($($var,)*),
280            ) -> Result<(), DriverError> {
281                let ($($var,)*) = params;
282                $(let $var = $var.into_kernel_param_storage();)*
283                let mut raw = [$( $var.as_kernel_param(), )*];
284                self.launch_raw(cfg, &mut raw)
285            }
286
287            unsafe fn launch_on_stream(
288                self,
289                stream: &CudaStream,
290                cfg: LaunchConfig,
291                params: ($($var,)*),
292            ) -> Result<(), DriverError> {
293                let ($($var,)*) = params;
294                $(let $var = $var.into_kernel_param_storage();)*
295                let mut raw = [$( $var.as_kernel_param(), )*];
296                self.launch_raw_on_stream(stream, cfg, &mut raw)
297            }
298
299            unsafe fn launch_cooperative(
300                self,
301                cfg: LaunchConfig,
302                params: ($($var,)*),
303            ) -> Result<(), DriverError> {
304                let ($($var,)*) = params;
305                $(let $var = $var.into_kernel_param_storage();)*
306                let mut raw = [$( $var.as_kernel_param(), )*];
307                self.launch_raw_cooperative(cfg, &mut raw)
308            }
309        }
310    };
311}
312
313impl_launch_tuple!([A], [0]);
314impl_launch_tuple!([A, B], [0, 1]);
315impl_launch_tuple!([A, B, C], [0, 1, 2]);
316impl_launch_tuple!([A, B, C, D], [0, 1, 2, 3]);
317impl_launch_tuple!([A, B, C, D, E], [0, 1, 2, 3, 4]);
318impl_launch_tuple!([A, B, C, D, E, F], [0, 1, 2, 3, 4, 5]);
319impl_launch_tuple!([A, B, C, D, E, F, G], [0, 1, 2, 3, 4, 5, 6]);
320impl_launch_tuple!([A, B, C, D, E, F, G, H], [0, 1, 2, 3, 4, 5, 6, 7]);
321impl_launch_tuple!([A, B, C, D, E, F, G, H, I], [0, 1, 2, 3, 4, 5, 6, 7, 8]);
322impl_launch_tuple!(
323    [A, B, C, D, E, F, G, H, I, J],
324    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
325);
326impl_launch_tuple!(
327    [A, B, C, D, E, F, G, H, I, J, K],
328    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
329);
330impl_launch_tuple!(
331    [A, B, C, D, E, F, G, H, I, J, K, L],
332    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
333);
334impl_launch_tuple!(
335    [A, B, C, D, E, F, G, H, I, J, K, L, M],
336    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
337);