1use std::ffi::c_void;
7use std::sync::Arc;
8
9use xlog_core::{Result, ScalarType, Schema, XlogError};
10
11use crate::memory::{validate_logical_row_count, CudaBuffer, CudaColumn};
12use crate::provider::CudaKernelProvider;
13use crate::CudaDevice;
14
15pub type DLDeviceType = i32;
16
17pub const K_DLCPU: DLDeviceType = 1;
18pub const K_DLCUDA: DLDeviceType = 2;
19
20pub type DLDataTypeCode = u8;
21pub const K_DLINT: DLDataTypeCode = 0;
22pub const K_DLUINT: DLDataTypeCode = 1;
23pub const K_DLFLOAT: DLDataTypeCode = 2;
24pub const K_DLBOOL: DLDataTypeCode = 6;
25
26#[repr(C)]
27#[derive(Debug, Clone, Copy)]
28pub struct DLDevice {
29 pub device_type: DLDeviceType,
30 pub device_id: i32,
31}
32
33#[repr(C)]
34#[derive(Debug, Clone, Copy)]
35pub struct DLDataType {
36 pub code: DLDataTypeCode,
37 pub bits: u8,
38 pub lanes: u16,
39}
40
41#[repr(C)]
42#[derive(Debug)]
43pub struct DLTensor {
44 pub data: *mut c_void,
45 pub device: DLDevice,
46 pub ndim: i32,
47 pub dtype: DLDataType,
48 pub shape: *mut i64,
49 pub strides: *mut i64,
50 pub byte_offset: u64,
51}
52
53pub type DLDeleter = Option<unsafe extern "C" fn(*mut DLManagedTensor)>;
54
55#[repr(C)]
56#[derive(Debug)]
57pub struct DLManagedTensor {
58 pub dl_tensor: DLTensor,
59 pub manager_ctx: *mut c_void,
60 pub deleter: DLDeleter,
61}
62
63#[allow(dead_code)]
64struct DlpackCtx {
65 buffer: Arc<CudaBuffer>,
66 shape: Box<[i64]>,
67}
68
69unsafe extern "C" fn dlpack_deleter(ptr: *mut DLManagedTensor) {
70 if ptr.is_null() {
71 return;
72 }
73 let ctx_ptr = unsafe { (*ptr).manager_ctx as *mut DlpackCtx };
75 if !ctx_ptr.is_null() {
76 unsafe {
78 drop(Box::from_raw(ctx_ptr));
79 }
80 }
81 unsafe {
83 drop(Box::from_raw(ptr));
84 }
85}
86
87fn scalar_to_dl_dtype(ty: ScalarType) -> DLDataType {
88 match ty {
89 ScalarType::U32 | ScalarType::Symbol => DLDataType {
90 code: K_DLUINT,
91 bits: 32,
92 lanes: 1,
93 },
94 ScalarType::U64 => DLDataType {
95 code: K_DLUINT,
96 bits: 64,
97 lanes: 1,
98 },
99 ScalarType::I32 => DLDataType {
100 code: K_DLINT,
101 bits: 32,
102 lanes: 1,
103 },
104 ScalarType::I64 => DLDataType {
105 code: K_DLINT,
106 bits: 64,
107 lanes: 1,
108 },
109 ScalarType::F32 => DLDataType {
110 code: K_DLFLOAT,
111 bits: 32,
112 lanes: 1,
113 },
114 ScalarType::F64 => DLDataType {
115 code: K_DLFLOAT,
116 bits: 64,
117 lanes: 1,
118 },
119 ScalarType::Bool => DLDataType {
120 code: K_DLBOOL,
121 bits: 8,
122 lanes: 1,
123 },
124 }
125}
126
127fn dl_dtype_to_scalar(dtype: DLDataType) -> Result<ScalarType> {
128 if dtype.lanes != 1 {
129 return Err(XlogError::Kernel(format!(
130 "Unsupported DLPack dtype lanes {} (expected 1)",
131 dtype.lanes
132 )));
133 }
134 match (dtype.code, dtype.bits) {
135 (K_DLUINT, 32) => Ok(ScalarType::U32),
136 (K_DLUINT, 64) => Ok(ScalarType::U64),
137 (K_DLINT, 32) => Ok(ScalarType::I32),
138 (K_DLINT, 64) => Ok(ScalarType::I64),
139 (K_DLFLOAT, 32) => Ok(ScalarType::F32),
140 (K_DLFLOAT, 64) => Ok(ScalarType::F64),
141 (K_DLBOOL, 8) => Ok(ScalarType::Bool),
143 _ => Err(XlogError::Kernel(format!(
144 "Unsupported DLPack dtype code={} bits={} lanes={}",
145 dtype.code, dtype.bits, dtype.lanes
146 ))),
147 }
148}
149
150pub struct DlpackManagedTensor {
154 ptr: *mut DLManagedTensor,
155}
156
157unsafe impl Send for DlpackManagedTensor {}
162unsafe impl Sync for DlpackManagedTensor {}
163
164impl DlpackManagedTensor {
165 pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Self {
171 Self { ptr }
172 }
173
174 pub fn as_ptr(&self) -> *mut DLManagedTensor {
175 self.ptr
176 }
177
178 pub fn into_raw(self) -> *mut DLManagedTensor {
179 let ptr = self.ptr;
180 std::mem::forget(self);
181 ptr
182 }
183}
184
185impl Drop for DlpackManagedTensor {
186 fn drop(&mut self) {
187 unsafe {
189 if !self.ptr.is_null() {
190 if let Some(deleter) = (*self.ptr).deleter {
191 deleter(self.ptr);
192 }
193 }
194 }
195 }
196}
197
198unsafe fn dlpack_tensor_info(
199 provider: &CudaKernelProvider,
200 tensor: &DlpackManagedTensor,
201) -> Result<(u64, ScalarType, cudarc::driver::sys::CUdeviceptr, usize)> {
202 let ptr = tensor.as_ptr();
203 if ptr.is_null() {
204 return Err(XlogError::Kernel(
205 "Null DLManagedTensor pointer".to_string(),
206 ));
207 }
208
209 let dl = unsafe { &(*ptr).dl_tensor };
211
212 if dl.device.device_type != K_DLCUDA {
213 return Err(XlogError::Kernel(format!(
214 "Unsupported DLPack device type {} (expected CUDA)",
215 dl.device.device_type
216 )));
217 }
218 if dl.device.device_id != provider.device().ordinal() as i32 {
219 return Err(XlogError::Kernel(format!(
220 "DLPack tensor device_id {} does not match provider device_id {}",
221 dl.device.device_id,
222 provider.device().ordinal()
223 )));
224 }
225
226 if dl.ndim != 1 {
227 return Err(XlogError::Kernel(format!(
228 "Unsupported DLPack ndim {} (expected 1)",
229 dl.ndim
230 )));
231 }
232 if dl.shape.is_null() {
233 return Err(XlogError::Kernel("DLPack tensor shape is null".to_string()));
234 }
235 if !dl.strides.is_null() {
236 let stride0 = unsafe { *dl.strides };
238 if stride0 != 1 {
239 return Err(XlogError::Kernel(format!(
240 "Non-contiguous DLPack tensor stride {} (expected 1)",
241 stride0
242 )));
243 }
244 }
245
246 let shape0 = unsafe { *dl.shape };
248 if shape0 < 0 {
249 return Err(XlogError::Kernel(format!(
250 "Negative DLPack tensor shape {}",
251 shape0
252 )));
253 }
254 let num_rows = shape0 as u64;
255
256 let scalar = dl_dtype_to_scalar(dl.dtype)?;
257 let elem_size = scalar.size_bytes();
258 if dl.byte_offset % (elem_size as u64) != 0 {
259 return Err(XlogError::Kernel(format!(
260 "DLPack byte_offset {} is not aligned to element size {}",
261 dl.byte_offset, elem_size
262 )));
263 }
264
265 if dl.data.is_null() && num_rows > 0 {
266 return Err(XlogError::Kernel(
267 "DLPack tensor data pointer is null".to_string(),
268 ));
269 }
270
271 let base = dl.data as usize;
272 let ptr_with_offset = base
273 .checked_add(dl.byte_offset as usize)
274 .ok_or_else(|| XlogError::Kernel("DLPack data pointer overflow".to_string()))?;
275
276 if ptr_with_offset % elem_size != 0 {
277 return Err(XlogError::Kernel(
278 "DLPack tensor data is not properly aligned".to_string(),
279 ));
280 }
281
282 let len_bytes = usize::try_from(num_rows)
283 .ok()
284 .and_then(|n| n.checked_mul(elem_size))
285 .ok_or_else(|| XlogError::Kernel("DLPack tensor length overflow".to_string()))?;
286
287 Ok((num_rows, scalar, ptr_with_offset as u64, len_bytes))
288}
289
290fn dlpack_logical_row_count(device: &Arc<CudaDevice>, buffer: &CudaBuffer) -> Result<usize> {
291 if let Some(cached_rows) = buffer.cached_row_count() {
292 return validate_logical_row_count(buffer.num_rows(), cached_rows as usize);
293 }
294
295 let mut host_rows = [0u32];
296 device
297 .inner()
298 .dtoh_sync_copy_into(buffer.num_rows_device(), &mut host_rows)
299 .map_err(|e| XlogError::Kernel(format!("Failed to read row count: {}", e)))?;
300 buffer.set_cached_row_count_if_unset(host_rows[0]);
301 validate_logical_row_count(buffer.num_rows(), host_rows[0] as usize)
302}
303
304pub struct DlpackTable {
308 buffer: Arc<CudaBuffer>,
309 cuda_device: Arc<CudaDevice>,
310 device: DLDevice,
311}
312
313impl DlpackTable {
314 pub fn column(&self, col_idx: usize) -> Result<DlpackManagedTensor> {
315 let logical_rows = dlpack_logical_row_count(&self.cuda_device, &self.buffer)?;
316 let dtype =
317 self.buffer.schema().column_type(col_idx).ok_or_else(|| {
318 XlogError::Kernel(format!("Column index {} out of bounds", col_idx))
319 })?;
320
321 let col = self
322 .buffer
323 .columns
324 .get(col_idx)
325 .ok_or_else(|| XlogError::Kernel(format!("Column {} not found", col_idx)))?;
326
327 let device_ptr = *col.device_ptr() as usize as *mut c_void;
328
329 let mut ctx = Box::new(DlpackCtx {
330 buffer: self.buffer.clone(),
331 shape: vec![logical_rows as i64].into_boxed_slice(),
332 });
333 let shape_ptr = ctx.shape.as_mut_ptr();
334
335 let dl_tensor = DLTensor {
336 data: device_ptr,
337 device: self.device,
338 ndim: 1,
339 dtype: scalar_to_dl_dtype(dtype),
340 shape: shape_ptr,
341 strides: std::ptr::null_mut(),
342 byte_offset: 0,
343 };
344
345 let managed = Box::new(DLManagedTensor {
346 dl_tensor,
347 manager_ctx: Box::into_raw(ctx) as *mut c_void,
348 deleter: Some(dlpack_deleter),
349 });
350
351 Ok(DlpackManagedTensor {
352 ptr: Box::into_raw(managed),
353 })
354 }
355}
356
357impl CudaKernelProvider {
358 pub fn to_dlpack_table(&self, buffer: CudaBuffer) -> DlpackTable {
362 DlpackTable {
363 buffer: Arc::new(buffer),
364 cuda_device: Arc::clone(self.device()),
365 device: DLDevice {
366 device_type: K_DLCUDA,
367 device_id: self.device().ordinal() as i32,
368 },
369 }
370 }
371
372 pub fn from_dlpack_tensors(&self, tensors: Vec<DlpackManagedTensor>) -> Result<CudaBuffer> {
376 if tensors.is_empty() {
377 return self.create_empty_buffer(Schema::new(vec![]));
378 }
379
380 let mut columns = Vec::with_capacity(tensors.len());
381 let mut schema_cols = Vec::with_capacity(tensors.len());
382 let mut num_rows: Option<u64> = None;
383
384 for (i, tensor) in tensors.into_iter().enumerate() {
385 let (rows, ty, ptr, len_bytes) = unsafe { dlpack_tensor_info(self, &tensor)? };
387 if let Some(n) = num_rows {
388 if rows != n {
389 return Err(XlogError::Kernel(
390 "DLPack column row counts do not match".to_string(),
391 ));
392 }
393 } else {
394 num_rows = Some(rows);
395 }
396
397 schema_cols.push((format!("col_{}", i), ty));
398 columns.push(CudaColumn::dlpack(
399 ptr,
400 len_bytes,
401 self.device().inner().stream().clone(),
402 tensor,
403 ));
404 }
405
406 let schema = xlog_core::Schema::new(schema_cols);
407 self.buffer_from_columns(columns, num_rows.unwrap_or(0), schema)
408 }
409
410 pub fn from_dlpack_tensors_with_schema(
412 &self,
413 schema: xlog_core::Schema,
414 tensors: Vec<DlpackManagedTensor>,
415 ) -> Result<CudaBuffer> {
416 if schema.arity() != tensors.len() {
417 return Err(XlogError::Kernel(format!(
418 "Schema arity {} does not match tensor count {}",
419 schema.arity(),
420 tensors.len()
421 )));
422 }
423
424 if tensors.is_empty() {
425 return self.create_empty_buffer(schema);
426 }
427
428 let mut columns = Vec::with_capacity(tensors.len());
429 let mut num_rows: Option<u64> = None;
430
431 for (i, tensor) in tensors.into_iter().enumerate() {
432 let (rows, ty, ptr, len_bytes) = unsafe { dlpack_tensor_info(self, &tensor)? };
434 let expected = schema.column_type(i).ok_or_else(|| {
435 XlogError::Kernel(format!("Missing schema type for column {}", i))
436 })?;
437 if !expected.dlpack_compatible(ty) {
438 return Err(XlogError::Kernel(format!(
439 "DLPack column {} dtype {:?} does not match schema {:?}",
440 i, ty, expected
441 )));
442 }
443
444 if let Some(n) = num_rows {
445 if rows != n {
446 return Err(XlogError::Kernel(
447 "DLPack column row counts do not match".to_string(),
448 ));
449 }
450 } else {
451 num_rows = Some(rows);
452 }
453
454 columns.push(CudaColumn::dlpack(
455 ptr,
456 len_bytes,
457 self.device().inner().stream().clone(),
458 tensor,
459 ));
460 }
461
462 self.buffer_from_columns(columns, num_rows.unwrap_or(0), schema)
463 }
464}