1use std::ffi::c_void;
2
3use cudarc::driver::{self, SyncOnDrop};
4
5pub use cudarc::driver::{
6 sys, CudaSlice, CudaStream, CudaView, CudaViewMut, DevicePtr, DevicePtrMut, DeviceRepr,
7 DeviceSlice, DriverError, LaunchConfig, ValidAsZeroBits,
8};
9
10pub use crate::device::CudaFunction;
11
mod sealed {
    /// Private marker trait used to seal the public scalar trait.
    ///
    /// This module is not `pub`, so code outside the enclosing module cannot
    /// name this trait and therefore cannot add implementations of it. Only
    /// the primitive types listed below implement it.
    pub trait KernelScalarSealed {}

    /// Emits an empty `KernelScalarSealed` impl for every listed type.
    ///
    /// Accepts a comma-separated list of types with an optional trailing comma.
    macro_rules! impl_kernel_scalar_sealed {
        ($($ty:ty),* $(,)?) => {
            $(impl KernelScalarSealed for $ty {})*
        };
    }

    impl_kernel_scalar_sealed!(
        bool, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64
    );
}
25
/// Storage that holds one kernel argument and can expose its address.
///
/// The launch code in this file converts every tuple element into a value of
/// this kind, then collects the pointers returned by
/// [`as_kernel_param`](KernelParamStorage::as_kernel_param) into the raw
/// argument array handed to the launch call.
pub trait KernelParamStorage {
    /// Returns the address of the stored argument as an untyped pointer.
    ///
    /// The pointer is derived from `&self`, so it refers to memory owned by
    /// this storage value.
    // NOTE(review): presumably the storage must outlive every use of the
    // returned pointer — confirm against the launch call's requirements.
    fn as_kernel_param(&self) -> *mut c_void;
}
30
/// Kernel-argument storage that owns a single value of type `T`.
///
/// The scalar `IntoKernelParamStorage` impls in this file wrap the scalar in
/// this type so that its address stays valid for as long as the storage lives.
#[derive(Debug)]
pub struct ScalarParamStorage<T>(T);
33
34impl<T> KernelParamStorage for ScalarParamStorage<T> {
35 fn as_kernel_param(&self) -> *mut c_void {
36 (&self.0 as *const T).cast_mut().cast()
37 }
38}
39
40#[derive(Debug)]
41pub struct DeviceParamStorage<'a> {
42 ptr: driver::sys::CUdeviceptr,
43 _sync: Option<SyncOnDrop<'a>>,
44}
45
46impl<'a> DeviceParamStorage<'a> {
47 pub fn synced(ptr: driver::sys::CUdeviceptr, sync: SyncOnDrop<'a>) -> Self {
48 Self {
49 ptr,
50 _sync: Some(sync),
51 }
52 }
53
54 pub fn unsynced(ptr: driver::sys::CUdeviceptr) -> Self {
55 Self { ptr, _sync: None }
56 }
57}
58
59impl KernelParamStorage for DeviceParamStorage<'_> {
60 fn as_kernel_param(&self) -> *mut c_void {
61 (&self.ptr as *const driver::sys::CUdeviceptr)
62 .cast_mut()
63 .cast()
64 }
65}
66
/// Values that can expose their own address as a kernel argument in place.
///
/// Unlike `IntoKernelParamStorage`, this does not consume the value: the
/// pointer is produced from a borrow of `self`.
pub trait AsKernelParam {
    /// Returns an untyped pointer to the argument value.
    // NOTE(review): presumably the borrowed value must stay alive and in place
    // while the pointer is in use — confirm with callers.
    fn as_kernel_param(&self) -> *mut c_void;
}
71
72pub trait IntoKernelParamStorage {
74 type Storage: KernelParamStorage;
75
76 fn into_kernel_param_storage(self) -> Self::Storage;
77}
78
79pub trait KernelScalar:
81 sealed::KernelScalarSealed
82 + cudarc::driver::DeviceRepr
83 + Copy
84 + 'static
85 + AsKernelParam
86 + IntoKernelParamStorage
87{
88}
89
90macro_rules! impl_kernel_scalar {
91 ($($ty:ty),* $(,)?) => {
92 $(
93 impl KernelScalar for $ty {}
94
95 impl AsKernelParam for $ty {
96 fn as_kernel_param(&self) -> *mut c_void {
97 (self as *const $ty).cast_mut().cast()
98 }
99 }
100
101 impl AsKernelParam for &$ty {
102 fn as_kernel_param(&self) -> *mut c_void {
103 (*self as *const $ty).cast_mut().cast()
104 }
105 }
106
107 impl AsKernelParam for &mut $ty {
108 fn as_kernel_param(&self) -> *mut c_void {
109 (*self as *const $ty).cast_mut().cast()
110 }
111 }
112
113 impl IntoKernelParamStorage for $ty {
114 type Storage = ScalarParamStorage<$ty>;
115
116 fn into_kernel_param_storage(self) -> Self::Storage {
117 ScalarParamStorage(self)
118 }
119 }
120 )*
121 };
122}
123
124impl_kernel_scalar!(bool, i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64);
125
126impl<'a, T> IntoKernelParamStorage for &'a CudaSlice<T> {
127 type Storage = DeviceParamStorage<'a>;
128
129 fn into_kernel_param_storage(self) -> Self::Storage {
130 let stream = self.stream();
131 let (ptr, sync) = cudarc::driver::DevicePtr::device_ptr(self, stream);
132 DeviceParamStorage::synced(ptr, sync)
133 }
134}
135
136impl<T> IntoKernelParamStorage for &mut CudaSlice<T> {
137 type Storage = DeviceParamStorage<'static>;
138
139 fn into_kernel_param_storage(self) -> Self::Storage {
140 let stream = self.stream().clone();
141 let (ptr, sync) = cudarc::driver::DevicePtrMut::device_ptr_mut(self, &stream);
142 std::mem::forget(sync);
143 DeviceParamStorage::unsynced(ptr)
144 }
145}
146
147impl<'a, 'b, T> IntoKernelParamStorage for &'a CudaView<'b, T> {
148 type Storage = DeviceParamStorage<'a>;
149
150 fn into_kernel_param_storage(self) -> Self::Storage {
151 let stream = self.stream();
152 let (ptr, sync) = cudarc::driver::DevicePtr::device_ptr(self, stream);
153 DeviceParamStorage::synced(ptr, sync)
154 }
155}
156
157impl<'a, 'b, T> IntoKernelParamStorage for &'a CudaViewMut<'b, T> {
158 type Storage = DeviceParamStorage<'a>;
159
160 fn into_kernel_param_storage(self) -> Self::Storage {
161 let stream = self.stream();
162 let (ptr, sync) = cudarc::driver::DevicePtr::device_ptr(self, stream);
163 DeviceParamStorage::synced(ptr, sync)
164 }
165}
166
167impl<'a, 'b, T> IntoKernelParamStorage for &'a mut CudaViewMut<'b, T> {
168 type Storage = DeviceParamStorage<'static>;
169
170 fn into_kernel_param_storage(self) -> Self::Storage {
171 let stream = self.stream().clone();
172 let (ptr, sync) = cudarc::driver::DevicePtrMut::device_ptr_mut(self, &stream);
173 std::mem::forget(sync);
174 DeviceParamStorage::unsynced(ptr)
175 }
176}
177
178pub unsafe trait LaunchAsync<Params> {
186 unsafe fn launch(self, cfg: LaunchConfig, params: Params) -> Result<(), DriverError>;
192
193 unsafe fn launch_on_stream(
199 self,
200 stream: &CudaStream,
201 cfg: LaunchConfig,
202 params: Params,
203 ) -> Result<(), DriverError>;
204
205 unsafe fn launch_cooperative(
212 self,
213 cfg: LaunchConfig,
214 params: Params,
215 ) -> Result<(), DriverError>;
216}
217
218unsafe impl LaunchAsync<&mut [*mut c_void]> for CudaFunction {
219 unsafe fn launch(
220 self,
221 cfg: LaunchConfig,
222 params: &mut [*mut c_void],
223 ) -> Result<(), DriverError> {
224 self.launch_raw(cfg, params)
225 }
226
227 unsafe fn launch_on_stream(
228 self,
229 stream: &CudaStream,
230 cfg: LaunchConfig,
231 params: &mut [*mut c_void],
232 ) -> Result<(), DriverError> {
233 self.launch_raw_on_stream(stream, cfg, params)
234 }
235
236 unsafe fn launch_cooperative(
237 self,
238 cfg: LaunchConfig,
239 params: &mut [*mut c_void],
240 ) -> Result<(), DriverError> {
241 self.launch_raw_cooperative(cfg, params)
242 }
243}
244
245unsafe impl LaunchAsync<&mut Vec<*mut c_void>> for CudaFunction {
246 unsafe fn launch(
247 self,
248 cfg: LaunchConfig,
249 params: &mut Vec<*mut c_void>,
250 ) -> Result<(), DriverError> {
251 self.launch_raw(cfg, params)
252 }
253
254 unsafe fn launch_on_stream(
255 self,
256 stream: &CudaStream,
257 cfg: LaunchConfig,
258 params: &mut Vec<*mut c_void>,
259 ) -> Result<(), DriverError> {
260 self.launch_raw_on_stream(stream, cfg, params)
261 }
262
263 unsafe fn launch_cooperative(
264 self,
265 cfg: LaunchConfig,
266 params: &mut Vec<*mut c_void>,
267 ) -> Result<(), DriverError> {
268 self.launch_raw_cooperative(cfg, params)
269 }
270}
271
272macro_rules! impl_launch_tuple {
273 ([$($var:ident),*], [$($idx:tt),*]) => {
274 #[allow(non_snake_case)]
275 unsafe impl<$($var: IntoKernelParamStorage),*> LaunchAsync<($($var,)*)> for CudaFunction {
276 unsafe fn launch(
277 self,
278 cfg: LaunchConfig,
279 params: ($($var,)*),
280 ) -> Result<(), DriverError> {
281 let ($($var,)*) = params;
282 $(let $var = $var.into_kernel_param_storage();)*
283 let mut raw = [$( $var.as_kernel_param(), )*];
284 self.launch_raw(cfg, &mut raw)
285 }
286
287 unsafe fn launch_on_stream(
288 self,
289 stream: &CudaStream,
290 cfg: LaunchConfig,
291 params: ($($var,)*),
292 ) -> Result<(), DriverError> {
293 let ($($var,)*) = params;
294 $(let $var = $var.into_kernel_param_storage();)*
295 let mut raw = [$( $var.as_kernel_param(), )*];
296 self.launch_raw_on_stream(stream, cfg, &mut raw)
297 }
298
299 unsafe fn launch_cooperative(
300 self,
301 cfg: LaunchConfig,
302 params: ($($var,)*),
303 ) -> Result<(), DriverError> {
304 let ($($var,)*) = params;
305 $(let $var = $var.into_kernel_param_storage();)*
306 let mut raw = [$( $var.as_kernel_param(), )*];
307 self.launch_raw_cooperative(cfg, &mut raw)
308 }
309 }
310 };
311}
312
313impl_launch_tuple!([A], [0]);
314impl_launch_tuple!([A, B], [0, 1]);
315impl_launch_tuple!([A, B, C], [0, 1, 2]);
316impl_launch_tuple!([A, B, C, D], [0, 1, 2, 3]);
317impl_launch_tuple!([A, B, C, D, E], [0, 1, 2, 3, 4]);
318impl_launch_tuple!([A, B, C, D, E, F], [0, 1, 2, 3, 4, 5]);
319impl_launch_tuple!([A, B, C, D, E, F, G], [0, 1, 2, 3, 4, 5, 6]);
320impl_launch_tuple!([A, B, C, D, E, F, G, H], [0, 1, 2, 3, 4, 5, 6, 7]);
321impl_launch_tuple!([A, B, C, D, E, F, G, H, I], [0, 1, 2, 3, 4, 5, 6, 7, 8]);
322impl_launch_tuple!(
323 [A, B, C, D, E, F, G, H, I, J],
324 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
325);
326impl_launch_tuple!(
327 [A, B, C, D, E, F, G, H, I, J, K],
328 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
329);
330impl_launch_tuple!(
331 [A, B, C, D, E, F, G, H, I, J, K, L],
332 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
333);
334impl_launch_tuple!(
335 [A, B, C, D, E, F, G, H, I, J, K, L, M],
336 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
337);