1use std::sync::Arc;
4
5use crate::{LaunchAsync, LaunchConfig};
6use xlog_core::{Result, ScalarType, Schema, XlogError};
7
8use super::{d4_kernels, pack_kernels, D4_MODULE, PACK_MODULE};
9use crate::CudaBuffer;
10
11impl super::CudaKernelProvider {
12 pub fn create_buffer_from_u32_columns(
26 &self,
27 columns: &[&[u32]],
28 schema: Schema,
29 ) -> Result<CudaBuffer> {
30 if columns.is_empty() {
31 return self.create_empty_buffer(schema);
32 }
33
34 let num_rows = columns[0].len();
35 for (i, col) in columns.iter().enumerate() {
36 if col.len() != num_rows {
37 return Err(XlogError::Kernel(format!(
38 "Column {} has {} rows but expected {}",
39 i,
40 col.len(),
41 num_rows
42 )));
43 }
44 }
45
46 let mut cuda_columns = Vec::with_capacity(columns.len());
47 for col_data in columns {
48 let bytes: Vec<u8> = col_data.iter().flat_map(|v| v.to_le_bytes()).collect();
49 let mut col = self.memory.alloc::<u8>(bytes.len())?;
50 self.htod_sync_copy_into_tracked(&bytes, &mut col)
51 .map_err(|e| XlogError::Kernel(format!("Failed to upload column: {}", e)))?;
52 cuda_columns.push(col.into());
53 }
54
55 self.buffer_from_columns(cuda_columns, num_rows as u64, schema)
56 }
57
58 pub fn create_buffer_from_slices(
76 &self,
77 slices: &[&[u8]],
78 schema: Schema,
79 ) -> Result<CudaBuffer> {
80 if slices.len() != schema.arity() {
81 return Err(XlogError::Kernel(format!(
82 "Slice count {} doesn't match schema arity {}",
83 slices.len(),
84 schema.arity()
85 )));
86 }
87
88 if slices.is_empty() {
89 return self.create_empty_buffer(schema);
90 }
91
92 let first_col_size = schema.column_type(0).map(|t| t.size_bytes()).unwrap_or(4);
93 let num_rows = slices[0].len() / first_col_size;
94
95 for (i, slice) in slices.iter().enumerate() {
97 let col_size = schema.column_type(i).map(|t| t.size_bytes()).unwrap_or(4);
98 let col_rows = slice.len() / col_size;
99 if col_rows != num_rows {
100 return Err(XlogError::Kernel(format!(
101 "Column {} has {} rows but expected {} rows (based on first column)",
102 i, col_rows, num_rows
103 )));
104 }
105 if slice.len() % col_size != 0 {
107 return Err(XlogError::Kernel(format!(
108 "Column {} slice length {} is not divisible by type size {}",
109 i,
110 slice.len(),
111 col_size
112 )));
113 }
114 }
115
116 let mut columns = Vec::with_capacity(slices.len());
117
118 for (i, slice) in slices.iter().enumerate() {
119 let mut col = self.memory.alloc::<u8>(slice.len())?;
120 self.htod_sync_copy_into_tracked(slice, &mut col)
121 .map_err(|e| XlogError::Kernel(format!("Failed to upload column {}: {}", i, e)))?;
122 columns.push(col.into());
123 }
124
125 self.buffer_from_columns(columns, num_rows as u64, schema)
126 }
127
128 pub fn to_arrow_device_record_batch(
135 &self,
136 buffer: CudaBuffer,
137 ) -> Result<crate::arrow_device::ArrowDeviceArrayOwned> {
138 use arrow::array::ArrayData;
139 use arrow::datatypes::{DataType, Field};
140 use arrow::ffi::to_ffi;
141
142 use crate::arrow_device::{ArrowDeviceArray, ARROW_DEVICE_CUDA};
143
144 let buffer = Arc::new(buffer);
145 let row_cap = buffer.num_rows();
146 let num_rows_u32 = u32::try_from(row_cap).map_err(|_| {
147 XlogError::Kernel(format!(
148 "Arrow device export supports at most {} rows, got {}",
149 u32::MAX,
150 row_cap
151 ))
152 })?;
153
154 let assert_fn = self
156 .device
157 .inner()
158 .get_func(D4_MODULE, d4_kernels::D4_ASSERT_U32_EQ)
159 .ok_or_else(|| XlogError::Kernel("d4_assert_u32_eq kernel not found".to_string()))?;
160 unsafe {
162 assert_fn.clone().launch(
163 LaunchConfig {
164 grid_dim: (1, 1, 1),
165 block_dim: (1, 1, 1),
166 shared_mem_bytes: 0,
167 },
168 (buffer.num_rows_device(), num_rows_u32),
169 )
170 }
171 .map_err(|e| XlogError::Kernel(format!("d4_assert_u32_eq failed: {}", e)))?;
172 self.device.synchronize()?;
173
174 let num_rows = usize::try_from(num_rows_u32)
175 .map_err(|_| XlogError::Kernel("Arrow device export row count overflow".to_string()))?;
176
177 let mut fields: Vec<Field> = Vec::with_capacity(buffer.arity());
178 let mut children: Vec<ArrayData> = Vec::with_capacity(buffer.arity());
179
180 for (col_idx, (name, scalar_type)) in buffer.schema().columns.iter().enumerate() {
181 let (field, child) =
182 self.build_arrow_device_child(&buffer, col_idx, name, *scalar_type, num_rows)?;
183 fields.push(field);
184 children.push(child);
185 }
186
187 let struct_type = DataType::Struct(fields.into());
188 let struct_data = ArrayData::builder(struct_type)
189 .len(num_rows)
190 .child_data(children)
191 .build()
192 .map_err(|e| XlogError::Kernel(format!("Arrow device export failed: {}", e)))?;
193
194 let (ffi_array, ffi_schema) =
195 to_ffi(&struct_data).map_err(|e| XlogError::Kernel(format!("{}", e)))?;
196 let array_ptr = Box::into_raw(Box::new(ffi_array));
197 let schema_ptr = Box::into_raw(Box::new(ffi_schema));
198
199 Ok(ArrowDeviceArray::new(
200 ARROW_DEVICE_CUDA,
201 self.device.ordinal() as i32,
202 array_ptr,
203 schema_ptr,
204 ))
205 }
206
207 #[cfg(feature = "arrow-device-import")]
211 pub fn from_arrow_device_record_batch(
212 &self,
213 device_array: crate::arrow_device::ArrowDeviceArrayOwned,
214 ) -> Result<CudaBuffer> {
215 use arrow::array::ArrayData;
216 use arrow::datatypes::DataType;
217 use arrow::ffi::from_ffi;
218 use std::sync::Arc;
219
220 use crate::arrow_device::{ArrowDeviceImport, ARROW_DEVICE_CUDA};
221 use crate::memory::CudaColumn;
222
223 let (device_type, device_id, ffi_array, ffi_schema) =
224 unsafe { device_array.into_ffi_parts() };
226
227 if device_type != ARROW_DEVICE_CUDA {
228 return Err(XlogError::Kernel(format!(
229 "Arrow device import requires CUDA device type={}, got {}",
230 ARROW_DEVICE_CUDA, device_type
231 )));
232 }
233 if device_id != self.device.ordinal() as i32 {
234 return Err(XlogError::Kernel(format!(
235 "Arrow device import device id mismatch: expected {}, got {}",
236 self.device.ordinal(),
237 device_id
238 )));
239 }
240
241 let data: ArrayData = unsafe { from_ffi(ffi_array, &ffi_schema) }
243 .map_err(|e| XlogError::Kernel(format!("Arrow device import failed: {}", e)))?;
244
245 let (fields, children) = match data.data_type() {
246 DataType::Struct(fields) => (fields.clone(), data.child_data().to_vec()),
247 other => {
248 return Err(XlogError::Kernel(format!(
249 "Arrow device import expects Struct, got {:?}",
250 other
251 )))
252 }
253 };
254
255 if data.offset() != 0 {
256 return Err(XlogError::Kernel(
257 "Arrow device import does not support non-zero offsets".to_string(),
258 ));
259 }
260 if data.null_count() > 0 || data.nulls().is_some() {
261 return Err(XlogError::Kernel(
262 "Arrow device import does not support nulls".to_string(),
263 ));
264 }
265
266 let num_rows = data.len();
267 if fields.len() != children.len() {
268 return Err(XlogError::Kernel(
269 "Arrow device import field/child length mismatch".to_string(),
270 ));
271 }
272
273 let keepalive = Arc::new(ArrowDeviceImport::new(data));
274 let mut columns = Vec::with_capacity(children.len());
275 let mut schema_cols = Vec::with_capacity(children.len());
276
277 for (field, child) in fields.iter().zip(children.iter()) {
278 if child.len() != num_rows {
279 return Err(XlogError::Kernel(
280 "Arrow device import child length mismatch".to_string(),
281 ));
282 }
283 if child.offset() != 0 {
284 return Err(XlogError::Kernel(
285 "Arrow device import does not support child offsets".to_string(),
286 ));
287 }
288 if child.null_count() > 0 || child.nulls().is_some() {
289 return Err(XlogError::Kernel(
290 "Arrow device import does not support child nulls".to_string(),
291 ));
292 }
293
294 let (scalar_type, elem_size) = Self::scalar_type_from_arrow_field(field)?;
295 let buffers = child.buffers();
296 let buf = buffers.first().ok_or_else(|| {
297 XlogError::Kernel("Arrow device import missing value buffer".to_string())
298 })?;
299 let len_bytes = buf.len();
300 let expected_bytes = num_rows.checked_mul(elem_size).ok_or_else(|| {
301 XlogError::Kernel("Arrow device import size overflow".to_string())
302 })?;
303 if len_bytes != expected_bytes {
304 return Err(XlogError::Kernel(format!(
305 "Arrow device import buffer size mismatch: expected {}, got {}",
306 expected_bytes, len_bytes
307 )));
308 }
309
310 let ptr = buf.as_ptr();
311 if ptr.is_null() && len_bytes > 0 {
312 return Err(XlogError::Kernel(
313 "Arrow device import got null buffer pointer".to_string(),
314 ));
315 }
316 let device_ptr = ptr as usize as cudarc::driver::sys::CUdeviceptr;
317 columns.push(CudaColumn::arrow_device(
318 device_ptr,
319 len_bytes,
320 self.device().inner().stream().clone(),
321 keepalive.clone(),
322 ));
323 schema_cols.push((field.name().to_string(), scalar_type));
324 }
325
326 let schema = Schema::new(schema_cols);
327 self.buffer_from_columns(columns, num_rows as u64, schema)
328 }
329
330 pub fn to_arrow_record_batch(
346 &self,
347 buffer: &CudaBuffer,
348 ) -> Result<arrow::record_batch::RecordBatch> {
349 use arrow::array::*;
350 use arrow::datatypes::{Field, Schema as ArrowSchema};
351 use arrow::record_batch::{RecordBatch, RecordBatchOptions};
352
353 let num_rows = self.device_row_count(buffer)?;
354
355 let fields: Vec<Field> = buffer
356 .schema
357 .columns
358 .iter()
359 .map(|(name, scalar_type)| Field::new(name, scalar_type.to_arrow_type(), false))
360 .collect();
361 let arrow_schema = Arc::new(ArrowSchema::new(fields));
362
363 let mut arrays: Vec<Arc<dyn Array>> = Vec::with_capacity(buffer.arity());
364
365 for (col_idx, (_, scalar_type)) in buffer.schema.columns.iter().enumerate() {
366 let col = buffer
367 .column(col_idx)
368 .ok_or_else(|| XlogError::Kernel(format!("Column {} not found", col_idx)))?;
369
370 if num_rows == 0 {
372 let array: Arc<dyn Array> = match scalar_type {
373 ScalarType::Bool => Arc::new(BooleanArray::from(Vec::<bool>::new())),
374 ScalarType::U32 => Arc::new(UInt32Array::from(Vec::<u32>::new())),
375 ScalarType::Symbol => Arc::new(xlog_core::symbol::to_arrow(&[])),
376 ScalarType::I32 => Arc::new(Int32Array::from(Vec::<i32>::new())),
377 ScalarType::U64 => Arc::new(UInt64Array::from(Vec::<u64>::new())),
378 ScalarType::I64 => Arc::new(Int64Array::from(Vec::<i64>::new())),
379 ScalarType::F32 => Arc::new(Float32Array::from(Vec::<f32>::new())),
380 ScalarType::F64 => Arc::new(Float64Array::from(Vec::<f64>::new())),
381 };
382 arrays.push(array);
383 continue;
384 }
385
386 let elem_size = scalar_type.size_bytes();
387 let num_bytes = num_rows
388 .checked_mul(elem_size)
389 .ok_or_else(|| XlogError::Kernel("Row byte size overflow".to_string()))?;
390 let mut bytes = vec![0u8; num_bytes];
391 let col_view = self.column_bytes_view(col, num_bytes)?;
392 self.device
393 .inner()
394 .dtoh_sync_copy_into(&col_view, &mut bytes)
395 .map_err(|e| XlogError::Kernel(format!("Failed to download column: {}", e)))?;
396
397 let array: Arc<dyn Array> = match scalar_type {
398 ScalarType::Bool => Arc::new(BooleanArray::from(
399 bytes.iter().map(|&b| b != 0).collect::<Vec<_>>(),
400 )),
401 ScalarType::U32 => {
402 let values: Vec<u32> = bytes
403 .chunks_exact(4)
404 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
405 .collect();
406 Arc::new(UInt32Array::from(values))
407 }
408 ScalarType::Symbol => {
409 let values: Vec<u32> = bytes
410 .chunks_exact(4)
411 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
412 .collect();
413 Arc::new(xlog_core::symbol::to_arrow(&values))
414 }
415 ScalarType::I32 => {
416 let values: Vec<i32> = bytes
417 .chunks_exact(4)
418 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
419 .collect();
420 Arc::new(Int32Array::from(values))
421 }
422 ScalarType::U64 => {
423 let values: Vec<u64> = bytes
424 .chunks_exact(8)
425 .map(|c| {
426 u64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
427 })
428 .collect();
429 Arc::new(UInt64Array::from(values))
430 }
431 ScalarType::I64 => {
432 let values: Vec<i64> = bytes
433 .chunks_exact(8)
434 .map(|c| {
435 i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
436 })
437 .collect();
438 Arc::new(Int64Array::from(values))
439 }
440 ScalarType::F32 => {
441 let values: Vec<f32> = bytes
442 .chunks_exact(4)
443 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
444 .collect();
445 Arc::new(Float32Array::from(values))
446 }
447 ScalarType::F64 => {
448 let values: Vec<f64> = bytes
449 .chunks_exact(8)
450 .map(|c| {
451 f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
452 })
453 .collect();
454 Arc::new(Float64Array::from(values))
455 }
456 };
457
458 arrays.push(array);
459 }
460
461 let options = RecordBatchOptions::new().with_row_count(Some(num_rows));
462 RecordBatch::try_new_with_options(arrow_schema, arrays, &options)
463 .map_err(|e| XlogError::Kernel(format!("Failed to create RecordBatch: {}", e)))
464 }
465
466 fn build_arrow_device_child(
467 &self,
468 buffer: &Arc<CudaBuffer>,
469 col_idx: usize,
470 name: &str,
471 scalar_type: ScalarType,
472 num_rows: usize,
473 ) -> Result<(arrow::datatypes::Field, arrow::array::ArrayData)> {
474 use arrow::array::ArrayData;
475 use arrow::buffer::Buffer;
476 use arrow::datatypes::{DataType, Field};
477 use std::collections::HashMap;
478 use std::ptr::NonNull;
479
480 use crate::arrow_device::ArrowCudaAllocation;
481
482 let col = buffer
483 .column(col_idx)
484 .ok_or_else(|| XlogError::Kernel(format!("Column {} not found", col_idx)))?;
485
486 let (dtype, metadata) = if scalar_type == ScalarType::Symbol {
487 let mut meta = HashMap::new();
488 meta.insert("xlog.symbol".to_string(), "true".to_string());
489 meta.insert("xlog.symbol_encoding".to_string(), "u32".to_string());
490 (DataType::UInt32, Some(meta))
491 } else {
492 (scalar_type.to_arrow_type(), None)
493 };
494
495 let field = match metadata {
496 Some(meta) => Field::new(name, dtype.clone(), false).with_metadata(meta),
497 None => Field::new(name, dtype.clone(), false),
498 };
499
500 let elem_size = match scalar_type {
501 ScalarType::Symbol => 4usize,
502 _ => scalar_type.size_bytes(),
503 };
504
505 let len_bytes = num_rows
506 .checked_mul(elem_size)
507 .ok_or_else(|| XlogError::Kernel("Arrow device export size overflow".to_string()))?;
508
509 let mut extra = Vec::new();
510 let (ptr, len) = match scalar_type {
511 ScalarType::Bool => {
512 let packed_len = num_rows.div_ceil(8);
513 let mut packed = self.memory.alloc::<u8>(packed_len)?;
514 let pack_fn = self
515 .device
516 .inner()
517 .get_func(PACK_MODULE, pack_kernels::PACK_BOOLS_TO_BITMAP)
518 .ok_or_else(|| {
519 XlogError::Kernel("pack_bools_to_bitmap kernel not found".to_string())
520 })?;
521 let block_size = 256u32;
522 let grid_size = (packed_len as u32).div_ceil(block_size);
523 unsafe {
525 pack_fn.clone().launch(
526 LaunchConfig {
527 grid_dim: (grid_size, 1, 1),
528 block_dim: (block_size, 1, 1),
529 shared_mem_bytes: 0,
530 },
531 (col, num_rows as u32, &mut packed),
532 )
533 }
534 .map_err(|e| XlogError::Kernel(format!("pack_bools_to_bitmap failed: {}", e)))?;
535 self.device.synchronize()?;
536 let ptr = *packed.device_ptr() as usize as *mut u8;
537 extra.push(packed);
538 (ptr, packed_len)
539 }
540 _ => {
541 let ptr = *col.device_ptr() as usize as *mut u8;
542 (ptr, len_bytes)
543 }
544 };
545
546 let alloc = Arc::new(ArrowCudaAllocation::new(Arc::clone(buffer), extra));
547 let nn = if len == 0 {
548 NonNull::dangling()
549 } else {
550 NonNull::new(ptr).ok_or_else(|| {
551 XlogError::Kernel("Arrow device export got null device pointer".to_string())
552 })?
553 };
554 let buf = unsafe { Buffer::from_custom_allocation(nn, len, alloc) };
556
557 let data = ArrayData::builder(dtype)
558 .len(num_rows)
559 .add_buffer(buf)
560 .build()
561 .map_err(|e| XlogError::Kernel(format!("Arrow device export failed: {}", e)))?;
562
563 Ok((field, data))
564 }
565
566 #[cfg(feature = "arrow-device-import")]
567 fn scalar_type_from_arrow_field(
568 field: &arrow::datatypes::Field,
569 ) -> Result<(ScalarType, usize)> {
570 use arrow::datatypes::DataType;
571
572 let is_symbol = field
574 .metadata()
575 .get("xlog.symbol")
576 .map(|v| v == "true")
577 .unwrap_or(false);
578
579 let scalar = match field.data_type() {
580 DataType::Boolean => {
581 return Err(XlogError::Kernel(
582 "Arrow device import does not support bit-packed bool yet".to_string(),
583 ))
584 }
585 DataType::UInt32 if is_symbol => ScalarType::Symbol,
586 dt => ScalarType::from_arrow_type(dt).ok_or_else(|| {
587 XlogError::Kernel(format!("Arrow device import unsupported type {:?}", dt))
588 })?,
589 };
590
591 let elem_size = match scalar {
592 ScalarType::Symbol => 4usize,
593 _ => scalar.size_bytes(),
594 };
595
596 Ok((scalar, elem_size))
597 }
598
599 pub fn from_arrow_record_batch(
612 &self,
613 record_batch: &arrow::record_batch::RecordBatch,
614 ) -> Result<CudaBuffer> {
615 use arrow::array::*;
616 use arrow::datatypes::DataType;
617
618 let num_rows = record_batch.num_rows() as u64;
619
620 if num_rows == 0 {
621 let columns: Vec<(String, ScalarType)> = record_batch
622 .schema()
623 .fields()
624 .iter()
625 .filter_map(|f| {
626 ScalarType::from_arrow_type(f.data_type()).map(|st| (f.name().clone(), st))
627 })
628 .collect();
629 return self.create_empty_buffer(Schema::new(columns));
630 }
631
632 let mut columns = Vec::with_capacity(record_batch.num_columns());
633 let mut schema_cols = Vec::with_capacity(record_batch.num_columns());
634
635 for (col_idx, field) in record_batch.schema().fields().iter().enumerate() {
636 let array = record_batch.column(col_idx);
637 let scalar_type = ScalarType::from_arrow_type(field.data_type()).ok_or_else(|| {
638 XlogError::Kernel(format!("Unsupported Arrow type: {:?}", field.data_type()))
639 })?;
640
641 let bytes: Vec<u8> = match field.data_type() {
642 DataType::Boolean => {
643 let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
644 arr.iter()
645 .map(|v| if v.unwrap_or(false) { 1u8 } else { 0u8 })
646 .collect()
647 }
648 DataType::UInt32 => {
649 let arr = array.as_any().downcast_ref::<UInt32Array>().unwrap();
650 arr.values().iter().flat_map(|v| v.to_le_bytes()).collect()
651 }
652 DataType::Int32 => {
653 let arr = array.as_any().downcast_ref::<Int32Array>().unwrap();
654 arr.values().iter().flat_map(|v| v.to_le_bytes()).collect()
655 }
656 DataType::UInt64 => {
657 let arr = array.as_any().downcast_ref::<UInt64Array>().unwrap();
658 arr.values().iter().flat_map(|v| v.to_le_bytes()).collect()
659 }
660 DataType::Int64 => {
661 let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
662 arr.values().iter().flat_map(|v| v.to_le_bytes()).collect()
663 }
664 DataType::Float32 => {
665 let arr = array.as_any().downcast_ref::<Float32Array>().unwrap();
666 arr.values().iter().flat_map(|v| v.to_le_bytes()).collect()
667 }
668 DataType::Float64 => {
669 let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
670 arr.values().iter().flat_map(|v| v.to_le_bytes()).collect()
671 }
672 _ => {
673 return Err(XlogError::Kernel(format!(
674 "Unsupported Arrow type: {:?}",
675 field.data_type()
676 )))
677 }
678 };
679
680 let mut d_col = self.memory.alloc::<u8>(bytes.len())?;
681 self.htod_sync_copy_into_tracked(&bytes, &mut d_col)
682 .map_err(|e| XlogError::Kernel(format!("Failed to upload column: {}", e)))?;
683
684 columns.push(d_col.into());
685 schema_cols.push((field.name().clone(), scalar_type));
686 }
687
688 self.buffer_from_columns(columns, num_rows, Schema::new(schema_cols))
689 }
690
691 pub fn to_arrow_ipc_stream(&self, buffer: &CudaBuffer) -> Result<Vec<u8>> {
698 use arrow::ipc::writer::StreamWriter;
699
700 let batch = self.to_arrow_record_batch(buffer)?;
701 let mut out = Vec::new();
702 let mut writer = StreamWriter::try_new(&mut out, &batch.schema())
703 .map_err(|e| XlogError::Kernel(format!("Failed to create Arrow IPC writer: {}", e)))?;
704 writer
705 .write(&batch)
706 .map_err(|e| XlogError::Kernel(format!("Failed to write Arrow RecordBatch: {}", e)))?;
707 writer
708 .finish()
709 .map_err(|e| XlogError::Kernel(format!("Failed to finish Arrow IPC stream: {}", e)))?;
710 Ok(out)
711 }
712
713 pub fn from_arrow_ipc_stream(&self, ipc: &[u8]) -> Result<CudaBuffer> {
717 use arrow::ipc::reader::StreamReader;
718 use std::io::Cursor;
719
720 let cursor = Cursor::new(ipc);
721 let mut reader = StreamReader::try_new(cursor, None)
722 .map_err(|e| XlogError::Kernel(format!("Failed to create Arrow IPC reader: {}", e)))?;
723
724 let batches = reader.by_ref();
725 let first = batches
726 .next()
727 .ok_or_else(|| {
728 XlogError::Kernel("Arrow IPC stream contained no record batches".to_string())
729 })?
730 .map_err(|e| XlogError::Kernel(format!("Failed to read Arrow RecordBatch: {}", e)))?;
731
732 if batches.next().is_some() {
733 return Err(XlogError::Kernel(
734 "Arrow IPC stream contains multiple record batches; this API expects exactly one"
735 .to_string(),
736 ));
737 }
738
739 self.from_arrow_record_batch(&first)
740 }
741
742 pub fn write_arrow_ipc_stream_file<P: AsRef<std::path::Path>>(
744 &self,
745 buffer: &CudaBuffer,
746 path: P,
747 ) -> Result<()> {
748 let bytes = self.to_arrow_ipc_stream(buffer)?;
749 std::fs::write(&path, bytes).map_err(|e| {
750 XlogError::Kernel(format!(
751 "Failed to write Arrow IPC stream to {}: {}",
752 path.as_ref().display(),
753 e
754 ))
755 })?;
756 Ok(())
757 }
758
759 pub fn read_arrow_ipc_stream_file<P: AsRef<std::path::Path>>(
761 &self,
762 path: P,
763 ) -> Result<CudaBuffer> {
764 let bytes = std::fs::read(&path).map_err(|e| {
765 XlogError::Kernel(format!(
766 "Failed to read Arrow IPC stream from {}: {}",
767 path.as_ref().display(),
768 e
769 ))
770 })?;
771 self.from_arrow_ipc_stream(&bytes)
772 }
773}