Skip to main content

xlog_cuda/
dlpack.rs

1//! DLPack interop for zero-copy GPU exchange.
2//!
3//! This module provides a minimal DLPack implementation for exporting XLOG GPU
4//! buffers to other ecosystems (e.g., Python cuDF) without device↔host copies.
5
6use std::ffi::c_void;
7use std::sync::Arc;
8
9use xlog_core::{Result, ScalarType, Schema, XlogError};
10
11use crate::memory::{validate_logical_row_count, CudaBuffer, CudaColumn};
12use crate::provider::CudaKernelProvider;
13use crate::CudaDevice;
14
/// Integer tag stored in [`DLDevice::device_type`] identifying where a tensor's memory lives.
pub type DLDeviceType = i32;

/// Device tag for host (CPU) memory; mirrors DLPack's `kDLCPU`.
pub const K_DLCPU: DLDeviceType = 1;
/// Device tag for CUDA device memory; the only device type this module imports or exports.
pub const K_DLCUDA: DLDeviceType = 2;

/// Integer tag stored in [`DLDataType::code`] identifying the element type family.
pub type DLDataTypeCode = u8;
/// Signed integer elements (used here for `I32` / `I64`).
pub const K_DLINT: DLDataTypeCode = 0;
/// Unsigned integer elements (used here for `U32` / `U64` / `Symbol`).
pub const K_DLUINT: DLDataTypeCode = 1;
/// Floating-point elements (used here for `F32` / `F64`).
pub const K_DLFLOAT: DLDataTypeCode = 2;
/// Boolean elements (used here for `Bool`, stored as 8 bits per value).
pub const K_DLBOOL: DLDataTypeCode = 6;
25
/// Device descriptor attached to every [`DLTensor`].
///
/// `#[repr(C)]` because the struct is exchanged with other DLPack implementations by pointer;
/// the field order must not change.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct DLDevice {
    /// Kind of device, e.g. [`K_DLCUDA`].
    pub device_type: DLDeviceType,
    /// Index of the device within its kind (compared against the provider's device ordinal
    /// on import).
    pub device_id: i32,
}
32
/// Element type descriptor attached to every [`DLTensor`].
///
/// `#[repr(C)]` because the struct is exchanged with other DLPack implementations by pointer;
/// the field order must not change.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct DLDataType {
    /// Type family, one of the `K_DL*` data type codes.
    pub code: DLDataTypeCode,
    /// Width of one lane in bits (this module handles 8, 32 and 64).
    pub bits: u8,
    /// Number of lanes per element; this module only produces and accepts `1`.
    pub lanes: u16,
}
40
/// Plain (non-owning) tensor description as laid out by the DLPack C ABI.
///
/// `#[repr(C)]` because the struct is exchanged with other DLPack implementations by pointer;
/// the field order must not change.
#[repr(C)]
#[derive(Debug)]
pub struct DLTensor {
    /// Base address of the tensor memory; `byte_offset` is added to reach the first element.
    pub data: *mut c_void,
    /// Device on which `data` lives.
    pub device: DLDevice,
    /// Number of dimensions; this module only handles `1`.
    pub ndim: i32,
    /// Element type of the tensor.
    pub dtype: DLDataType,
    /// Pointer to the per-dimension extents (`ndim` values).
    pub shape: *mut i64,
    /// Pointer to the per-dimension strides, or null; this module exports null and treats a
    /// null pointer on import as contiguous.
    pub strides: *mut i64,
    /// Offset in bytes from `data` to the first element.
    pub byte_offset: u64,
}
52
/// Optional C callback that releases a [`DLManagedTensor`] and everything it owns.
pub type DLDeleter = Option<unsafe extern "C" fn(*mut DLManagedTensor)>;

/// A [`DLTensor`] bundled with the state needed to release it.
///
/// `#[repr(C)]` because the struct is exchanged with other DLPack implementations by pointer;
/// the field order must not change.
#[repr(C)]
#[derive(Debug)]
pub struct DLManagedTensor {
    /// The tensor being shared.
    pub dl_tensor: DLTensor,
    /// Opaque producer-owned pointer; tensors exported by this module store a leaked
    /// `Box<DlpackCtx>` here.
    pub manager_ctx: *mut c_void,
    /// Callback invoked with a pointer to this struct to release the tensor; may be `None`.
    pub deleter: DLDeleter,
}
62
/// Producer-side state kept alive for as long as an exported tensor exists.
///
/// A boxed `DlpackCtx` is stored behind `DLManagedTensor::manager_ctx` and dropped by
/// `dlpack_deleter`.
// The fields are never read back; they exist only so that dropping the context releases them.
#[allow(dead_code)]
struct DlpackCtx {
    /// Shared handle on the buffer whose column memory the exported tensor points into.
    buffer: Arc<CudaBuffer>,
    /// Backing storage for the exported tensor's `shape` pointer.
    shape: Box<[i64]>,
}
68
69unsafe extern "C" fn dlpack_deleter(ptr: *mut DLManagedTensor) {
70    if ptr.is_null() {
71        return;
72    }
73    // SAFETY: ptr is non-null (checked above); manager_ctx was set from a Box<DlpackCtx> raw pointer
74    let ctx_ptr = unsafe { (*ptr).manager_ctx as *mut DlpackCtx };
75    if !ctx_ptr.is_null() {
76        // SAFETY: ctx_ptr was originally created via Box::into_raw; we are the sole owner
77        unsafe {
78            drop(Box::from_raw(ctx_ptr));
79        }
80    }
81    // SAFETY: ptr was originally created via Box::into_raw in DlpackManagedTensor::into_raw; we are the sole owner
82    unsafe {
83        drop(Box::from_raw(ptr));
84    }
85}
86
87fn scalar_to_dl_dtype(ty: ScalarType) -> DLDataType {
88    match ty {
89        ScalarType::U32 | ScalarType::Symbol => DLDataType {
90            code: K_DLUINT,
91            bits: 32,
92            lanes: 1,
93        },
94        ScalarType::U64 => DLDataType {
95            code: K_DLUINT,
96            bits: 64,
97            lanes: 1,
98        },
99        ScalarType::I32 => DLDataType {
100            code: K_DLINT,
101            bits: 32,
102            lanes: 1,
103        },
104        ScalarType::I64 => DLDataType {
105            code: K_DLINT,
106            bits: 64,
107            lanes: 1,
108        },
109        ScalarType::F32 => DLDataType {
110            code: K_DLFLOAT,
111            bits: 32,
112            lanes: 1,
113        },
114        ScalarType::F64 => DLDataType {
115            code: K_DLFLOAT,
116            bits: 64,
117            lanes: 1,
118        },
119        ScalarType::Bool => DLDataType {
120            code: K_DLBOOL,
121            bits: 8,
122            lanes: 1,
123        },
124    }
125}
126
127fn dl_dtype_to_scalar(dtype: DLDataType) -> Result<ScalarType> {
128    if dtype.lanes != 1 {
129        return Err(XlogError::Kernel(format!(
130            "Unsupported DLPack dtype lanes {} (expected 1)",
131            dtype.lanes
132        )));
133    }
134    match (dtype.code, dtype.bits) {
135        (K_DLUINT, 32) => Ok(ScalarType::U32),
136        (K_DLUINT, 64) => Ok(ScalarType::U64),
137        (K_DLINT, 32) => Ok(ScalarType::I32),
138        (K_DLINT, 64) => Ok(ScalarType::I64),
139        (K_DLFLOAT, 32) => Ok(ScalarType::F32),
140        (K_DLFLOAT, 64) => Ok(ScalarType::F64),
141        // XLOG represents bool as one byte per row today (not bitpacked).
142        (K_DLBOOL, 8) => Ok(ScalarType::Bool),
143        _ => Err(XlogError::Kernel(format!(
144            "Unsupported DLPack dtype code={} bits={} lanes={}",
145            dtype.code, dtype.bits, dtype.lanes
146        ))),
147    }
148}
149
/// Owned DLPack tensor handle.
///
/// Dropping this value calls the tensor's DLPack deleter (when one is set), which releases
/// whatever the producer attached to the tensor. For tensors exported by `DlpackTable::column`
/// that means dropping one shared reference to the underlying `CudaBuffer`.
pub struct DlpackManagedTensor {
    /// Raw pointer to the managed tensor; may be null, in which case drop does nothing.
    ptr: *mut DLManagedTensor,
}

// SAFETY: DLPack tensors are GPU device pointers with a deleter callback.
// GPU memory is accessible from any CPU thread, and the deleter is a plain
// function pointer with no thread affinity. The pointer is never dereferenced
// concurrently — it is only read during column access and freed on drop.
// NOTE(review): `as_ptr` hands the raw pointer to callers; this argument assumes they do not
// mutate the tensor through it from several threads — confirm against call sites.
unsafe impl Send for DlpackManagedTensor {}
unsafe impl Sync for DlpackManagedTensor {}
163
164impl DlpackManagedTensor {
165    /// Construct an owned DLPack tensor from a raw pointer.
166    ///
167    /// # Safety
168    /// `ptr` must be a valid `DLManagedTensor*` obtained from a DLPack producer, and ownership
169    /// must be transferred to the caller (the returned value will call the DLPack deleter on drop).
170    pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Self {
171        Self { ptr }
172    }
173
174    pub fn as_ptr(&self) -> *mut DLManagedTensor {
175        self.ptr
176    }
177
178    pub fn into_raw(self) -> *mut DLManagedTensor {
179        let ptr = self.ptr;
180        std::mem::forget(self);
181        ptr
182    }
183}
184
185impl Drop for DlpackManagedTensor {
186    fn drop(&mut self) {
187        // SAFETY: ptr is non-null (checked), was created by the DLPack producer; deleter is the registered cleanup function
188        unsafe {
189            if !self.ptr.is_null() {
190                if let Some(deleter) = (*self.ptr).deleter {
191                    deleter(self.ptr);
192                }
193            }
194        }
195    }
196}
197
198unsafe fn dlpack_tensor_info(
199    provider: &CudaKernelProvider,
200    tensor: &DlpackManagedTensor,
201) -> Result<(u64, ScalarType, cudarc::driver::sys::CUdeviceptr, usize)> {
202    let ptr = tensor.as_ptr();
203    if ptr.is_null() {
204        return Err(XlogError::Kernel(
205            "Null DLManagedTensor pointer".to_string(),
206        ));
207    }
208
209    // SAFETY: ptr is non-null (checked above); DlpackManagedTensor holds a valid DLManagedTensor for its lifetime
210    let dl = unsafe { &(*ptr).dl_tensor };
211
212    if dl.device.device_type != K_DLCUDA {
213        return Err(XlogError::Kernel(format!(
214            "Unsupported DLPack device type {} (expected CUDA)",
215            dl.device.device_type
216        )));
217    }
218    if dl.device.device_id != provider.device().ordinal() as i32 {
219        return Err(XlogError::Kernel(format!(
220            "DLPack tensor device_id {} does not match provider device_id {}",
221            dl.device.device_id,
222            provider.device().ordinal()
223        )));
224    }
225
226    if dl.ndim != 1 {
227        return Err(XlogError::Kernel(format!(
228            "Unsupported DLPack ndim {} (expected 1)",
229            dl.ndim
230        )));
231    }
232    if dl.shape.is_null() {
233        return Err(XlogError::Kernel("DLPack tensor shape is null".to_string()));
234    }
235    if !dl.strides.is_null() {
236        // SAFETY: dl.strides is non-null (checked above); points to a valid C array of ndim elements
237        let stride0 = unsafe { *dl.strides };
238        if stride0 != 1 {
239            return Err(XlogError::Kernel(format!(
240                "Non-contiguous DLPack tensor stride {} (expected 1)",
241                stride0
242            )));
243        }
244    }
245
246    // SAFETY: dl.shape is non-null (checked above); points to a valid C array of ndim elements
247    let shape0 = unsafe { *dl.shape };
248    if shape0 < 0 {
249        return Err(XlogError::Kernel(format!(
250            "Negative DLPack tensor shape {}",
251            shape0
252        )));
253    }
254    let num_rows = shape0 as u64;
255
256    let scalar = dl_dtype_to_scalar(dl.dtype)?;
257    let elem_size = scalar.size_bytes();
258    if dl.byte_offset % (elem_size as u64) != 0 {
259        return Err(XlogError::Kernel(format!(
260            "DLPack byte_offset {} is not aligned to element size {}",
261            dl.byte_offset, elem_size
262        )));
263    }
264
265    if dl.data.is_null() && num_rows > 0 {
266        return Err(XlogError::Kernel(
267            "DLPack tensor data pointer is null".to_string(),
268        ));
269    }
270
271    let base = dl.data as usize;
272    let ptr_with_offset = base
273        .checked_add(dl.byte_offset as usize)
274        .ok_or_else(|| XlogError::Kernel("DLPack data pointer overflow".to_string()))?;
275
276    if ptr_with_offset % elem_size != 0 {
277        return Err(XlogError::Kernel(
278            "DLPack tensor data is not properly aligned".to_string(),
279        ));
280    }
281
282    let len_bytes = usize::try_from(num_rows)
283        .ok()
284        .and_then(|n| n.checked_mul(elem_size))
285        .ok_or_else(|| XlogError::Kernel("DLPack tensor length overflow".to_string()))?;
286
287    Ok((num_rows, scalar, ptr_with_offset as u64, len_bytes))
288}
289
290fn dlpack_logical_row_count(device: &Arc<CudaDevice>, buffer: &CudaBuffer) -> Result<usize> {
291    if let Some(cached_rows) = buffer.cached_row_count() {
292        return validate_logical_row_count(buffer.num_rows(), cached_rows as usize);
293    }
294
295    let mut host_rows = [0u32];
296    device
297        .inner()
298        .dtoh_sync_copy_into(buffer.num_rows_device(), &mut host_rows)
299        .map_err(|e| XlogError::Kernel(format!("Failed to read row count: {}", e)))?;
300    buffer.set_cached_row_count_if_unset(host_rows[0]);
301    validate_logical_row_count(buffer.num_rows(), host_rows[0] as usize)
302}
303
/// A table-like wrapper that can export individual columns as DLPack tensors without copies.
///
/// The underlying `CudaBuffer` is reference-counted so multiple DLPack exports can share it;
/// each exported tensor holds its own clone of the `Arc`.
pub struct DlpackTable {
    /// Buffer whose columns are exported.
    buffer: Arc<CudaBuffer>,
    /// Device handle used to read the buffer's row count when it is not cached.
    cuda_device: Arc<CudaDevice>,
    /// Device descriptor copied into every exported tensor.
    device: DLDevice,
}
312
313impl DlpackTable {
314    pub fn column(&self, col_idx: usize) -> Result<DlpackManagedTensor> {
315        let logical_rows = dlpack_logical_row_count(&self.cuda_device, &self.buffer)?;
316        let dtype =
317            self.buffer.schema().column_type(col_idx).ok_or_else(|| {
318                XlogError::Kernel(format!("Column index {} out of bounds", col_idx))
319            })?;
320
321        let col = self
322            .buffer
323            .columns
324            .get(col_idx)
325            .ok_or_else(|| XlogError::Kernel(format!("Column {} not found", col_idx)))?;
326
327        let device_ptr = *col.device_ptr() as usize as *mut c_void;
328
329        let mut ctx = Box::new(DlpackCtx {
330            buffer: self.buffer.clone(),
331            shape: vec![logical_rows as i64].into_boxed_slice(),
332        });
333        let shape_ptr = ctx.shape.as_mut_ptr();
334
335        let dl_tensor = DLTensor {
336            data: device_ptr,
337            device: self.device,
338            ndim: 1,
339            dtype: scalar_to_dl_dtype(dtype),
340            shape: shape_ptr,
341            strides: std::ptr::null_mut(),
342            byte_offset: 0,
343        };
344
345        let managed = Box::new(DLManagedTensor {
346            dl_tensor,
347            manager_ctx: Box::into_raw(ctx) as *mut c_void,
348            deleter: Some(dlpack_deleter),
349        });
350
351        Ok(DlpackManagedTensor {
352            ptr: Box::into_raw(managed),
353        })
354    }
355}
356
357impl CudaKernelProvider {
358    /// Convert a `CudaBuffer` into a DLPack-exportable table without device↔host copies.
359    ///
360    /// Export each column with `DlpackTable::column(...)`.
361    pub fn to_dlpack_table(&self, buffer: CudaBuffer) -> DlpackTable {
362        DlpackTable {
363            buffer: Arc::new(buffer),
364            cuda_device: Arc::clone(self.device()),
365            device: DLDevice {
366                device_type: K_DLCUDA,
367                device_id: self.device().ordinal() as i32,
368            },
369        }
370    }
371
372    /// Import one DLPack tensor per column as a zero-copy `CudaBuffer`.
373    ///
374    /// The returned buffer owns the DLPack tensors and will call their deleters on drop.
375    pub fn from_dlpack_tensors(&self, tensors: Vec<DlpackManagedTensor>) -> Result<CudaBuffer> {
376        if tensors.is_empty() {
377            return self.create_empty_buffer(Schema::new(vec![]));
378        }
379
380        let mut columns = Vec::with_capacity(tensors.len());
381        let mut schema_cols = Vec::with_capacity(tensors.len());
382        let mut num_rows: Option<u64> = None;
383
384        for (i, tensor) in tensors.into_iter().enumerate() {
385            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
386            let (rows, ty, ptr, len_bytes) = unsafe { dlpack_tensor_info(self, &tensor)? };
387            if let Some(n) = num_rows {
388                if rows != n {
389                    return Err(XlogError::Kernel(
390                        "DLPack column row counts do not match".to_string(),
391                    ));
392                }
393            } else {
394                num_rows = Some(rows);
395            }
396
397            schema_cols.push((format!("col_{}", i), ty));
398            columns.push(CudaColumn::dlpack(
399                ptr,
400                len_bytes,
401                self.device().inner().stream().clone(),
402                tensor,
403            ));
404        }
405
406        let schema = xlog_core::Schema::new(schema_cols);
407        self.buffer_from_columns(columns, num_rows.unwrap_or(0), schema)
408    }
409
410    /// Import DLPack column tensors with an explicit schema (type-checked).
411    pub fn from_dlpack_tensors_with_schema(
412        &self,
413        schema: xlog_core::Schema,
414        tensors: Vec<DlpackManagedTensor>,
415    ) -> Result<CudaBuffer> {
416        if schema.arity() != tensors.len() {
417            return Err(XlogError::Kernel(format!(
418                "Schema arity {} does not match tensor count {}",
419                schema.arity(),
420                tensors.len()
421            )));
422        }
423
424        if tensors.is_empty() {
425            return self.create_empty_buffer(schema);
426        }
427
428        let mut columns = Vec::with_capacity(tensors.len());
429        let mut num_rows: Option<u64> = None;
430
431        for (i, tensor) in tensors.into_iter().enumerate() {
432            // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
433            let (rows, ty, ptr, len_bytes) = unsafe { dlpack_tensor_info(self, &tensor)? };
434            let expected = schema.column_type(i).ok_or_else(|| {
435                XlogError::Kernel(format!("Missing schema type for column {}", i))
436            })?;
437            if !expected.dlpack_compatible(ty) {
438                return Err(XlogError::Kernel(format!(
439                    "DLPack column {} dtype {:?} does not match schema {:?}",
440                    i, ty, expected
441                )));
442            }
443
444            if let Some(n) = num_rows {
445                if rows != n {
446                    return Err(XlogError::Kernel(
447                        "DLPack column row counts do not match".to_string(),
448                    ));
449                }
450            } else {
451                num_rows = Some(rows);
452            }
453
454            columns.push(CudaColumn::dlpack(
455                ptr,
456                len_bytes,
457                self.device().inner().stream().clone(),
458                tensor,
459            ));
460        }
461
462        self.buffer_from_columns(columns, num_rows.unwrap_or(0), schema)
463    }
464}