1use crate::{DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig};
4use xlog_core::{Result, ScalarType, Schema, XlogError};
5
6use super::{arith_kernels, ARITH_MODULE};
7use crate::memory::TrackedCudaSlice;
8use crate::CudaBuffer;
9
10impl super::CudaKernelProvider {
11 pub fn create_constant_column(
21 &self,
22 value_bytes: &[u8],
23 col_type: ScalarType,
24 num_rows: u64,
25 ) -> Result<CudaBuffer> {
26 if num_rows == 0 {
27 let schema = Schema::new(vec![("const".to_string(), col_type)]);
28 return self.create_empty_buffer(schema);
29 }
30
31 let elem_size = col_type.size_bytes();
32 if value_bytes.len() != elem_size {
33 return Err(XlogError::Kernel(format!(
34 "Value bytes length {} doesn't match type size {}",
35 value_bytes.len(),
36 elem_size
37 )));
38 }
39
40 if num_rows > u32::MAX as u64 {
41 return Err(XlogError::Kernel(format!(
42 "Constant column supports at most {} rows, got {}",
43 u32::MAX,
44 num_rows
45 )));
46 }
47
48 let total_bytes = (num_rows as usize)
49 .checked_mul(elem_size)
50 .ok_or_else(|| XlogError::Kernel("Constant column size overflow".to_string()))?;
51
52 let mut dst_col = self.memory.alloc::<u8>(total_bytes)?;
53 let n = num_rows as u32;
54
55 macro_rules! launch_fill_const {
56 ($kernel:expr, $value:expr) => {{
57 let func = self
58 .device
59 .inner()
60 .get_func(ARITH_MODULE, $kernel)
61 .ok_or_else(|| XlogError::Kernel("arith fill kernel not found".to_string()))?;
62 let config = LaunchConfig::for_num_elems(n);
63 unsafe { func.clone().launch(config, ($value, n, &mut dst_col)) }
65 .map_err(|e| XlogError::Kernel(format!("fill const failed: {}", e)))?;
66 }};
67 }
68
69 match col_type {
70 ScalarType::U32 | ScalarType::Symbol => {
71 let value = u32::from_le_bytes(value_bytes.try_into().unwrap());
72 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U32, value);
73 }
74 ScalarType::U64 => {
75 let value = u64::from_le_bytes(value_bytes.try_into().unwrap());
76 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U64, value);
77 }
78 ScalarType::I64 => {
79 let value = i64::from_le_bytes(value_bytes.try_into().unwrap());
80 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I64, value);
81 }
82 ScalarType::I32 => {
83 let value = i32::from_le_bytes(value_bytes.try_into().unwrap());
84 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I32, value);
85 }
86 ScalarType::F64 => {
87 let value = f64::from_le_bytes(value_bytes.try_into().unwrap());
88 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F64, value);
89 }
90 ScalarType::F32 => {
91 let value = f32::from_le_bytes(value_bytes.try_into().unwrap());
92 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F32, value);
93 }
94 ScalarType::Bool => {
95 let value = value_bytes[0];
96 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U8, value);
97 }
98 }
99
100 self.device.synchronize()?;
101
102 let schema = Schema::new(vec![("const".to_string(), col_type)]);
103 self.buffer_from_columns(vec![dst_col.into()], num_rows, schema)
104 }
105
106 pub fn create_constant_column_with_device_count(
108 &self,
109 value_bytes: &[u8],
110 col_type: ScalarType,
111 row_cap: u64,
112 d_num_rows_src: &TrackedCudaSlice<u32>,
113 ) -> Result<CudaBuffer> {
114 if row_cap == 0 {
115 let schema = Schema::new(vec![("const".to_string(), col_type)]);
116 return self.create_empty_buffer(schema);
117 }
118
119 let elem_size = col_type.size_bytes();
120 if value_bytes.len() != elem_size {
121 return Err(XlogError::Kernel(format!(
122 "Value bytes length {} doesn't match type size {}",
123 value_bytes.len(),
124 elem_size
125 )));
126 }
127
128 if row_cap > u32::MAX as u64 {
129 return Err(XlogError::Kernel(format!(
130 "Constant column supports at most {} rows, got {}",
131 u32::MAX,
132 row_cap
133 )));
134 }
135
136 let total_bytes = (row_cap as usize)
137 .checked_mul(elem_size)
138 .ok_or_else(|| XlogError::Kernel("Constant column size overflow".to_string()))?;
139
140 let mut dst_col = self.memory.alloc::<u8>(total_bytes)?;
141 let n = row_cap as u32;
142
143 macro_rules! launch_fill_const {
144 ($kernel:expr, $value:expr) => {{
145 let func = self
146 .device
147 .inner()
148 .get_func(ARITH_MODULE, $kernel)
149 .ok_or_else(|| XlogError::Kernel("arith fill kernel not found".to_string()))?;
150 let config = LaunchConfig::for_num_elems(n);
151 unsafe { func.clone().launch(config, ($value, n, &mut dst_col)) }
153 .map_err(|e| XlogError::Kernel(format!("fill const failed: {}", e)))?;
154 }};
155 }
156
157 match col_type {
158 ScalarType::U32 | ScalarType::Symbol => {
159 let value = u32::from_le_bytes(value_bytes.try_into().unwrap());
160 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U32, value);
161 }
162 ScalarType::U64 => {
163 let value = u64::from_le_bytes(value_bytes.try_into().unwrap());
164 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U64, value);
165 }
166 ScalarType::I64 => {
167 let value = i64::from_le_bytes(value_bytes.try_into().unwrap());
168 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I64, value);
169 }
170 ScalarType::I32 => {
171 let value = i32::from_le_bytes(value_bytes.try_into().unwrap());
172 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I32, value);
173 }
174 ScalarType::F64 => {
175 let value = f64::from_le_bytes(value_bytes.try_into().unwrap());
176 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F64, value);
177 }
178 ScalarType::F32 => {
179 let value = f32::from_le_bytes(value_bytes.try_into().unwrap());
180 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F32, value);
181 }
182 ScalarType::Bool => {
183 let value = value_bytes[0];
184 launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U8, value);
185 }
186 }
187
188 self.device.synchronize()?;
189
190 let schema = Schema::new(vec![("const".to_string(), col_type)]);
191 let mut d_num_rows = self.memory.alloc::<u32>(1)?;
192 self.device
193 .inner()
194 .dtod_copy(d_num_rows_src, &mut d_num_rows)
195 .map_err(|e| XlogError::Kernel(format!("Failed to copy row count: {}", e)))?;
196
197 Ok(CudaBuffer::from_columns(
198 vec![dst_col.into()],
199 row_cap,
200 d_num_rows,
201 schema,
202 ))
203 }
204
205 pub fn add_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
223 match a.schema().column_type(0) {
224 Some(ScalarType::I64) => {
225 self.binary_arith_op_device::<i64>(a, b, 0, arith_kernels::ARITH_BINARY_I64)
226 }
227 Some(ScalarType::I32) => {
228 self.binary_arith_op_device::<i32>(a, b, 0, arith_kernels::ARITH_BINARY_I32)
229 }
230 Some(ScalarType::U64) => {
231 self.binary_arith_op_device::<u64>(a, b, 0, arith_kernels::ARITH_BINARY_U64)
232 }
233 Some(ScalarType::U32 | ScalarType::Symbol) => {
234 self.binary_arith_op_device::<u32>(a, b, 0, arith_kernels::ARITH_BINARY_U32)
235 }
236 Some(ScalarType::F64) => {
237 self.binary_arith_op_device::<f64>(a, b, 0, arith_kernels::ARITH_BINARY_F64)
238 }
239 Some(ScalarType::F32) => {
240 self.binary_arith_op_device::<f32>(a, b, 0, arith_kernels::ARITH_BINARY_F32)
241 }
242 other => Err(XlogError::Kernel(format!(
243 "Arithmetic not supported for {:?}",
244 other
245 ))),
246 }
247 }
248
249 pub fn sub_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
267 match a.schema().column_type(0) {
268 Some(ScalarType::I64) => {
269 self.binary_arith_op_device::<i64>(a, b, 1, arith_kernels::ARITH_BINARY_I64)
270 }
271 Some(ScalarType::I32) => {
272 self.binary_arith_op_device::<i32>(a, b, 1, arith_kernels::ARITH_BINARY_I32)
273 }
274 Some(ScalarType::U64) => {
275 self.binary_arith_op_device::<u64>(a, b, 1, arith_kernels::ARITH_BINARY_U64)
276 }
277 Some(ScalarType::U32 | ScalarType::Symbol) => {
278 self.binary_arith_op_device::<u32>(a, b, 1, arith_kernels::ARITH_BINARY_U32)
279 }
280 Some(ScalarType::F64) => {
281 self.binary_arith_op_device::<f64>(a, b, 1, arith_kernels::ARITH_BINARY_F64)
282 }
283 Some(ScalarType::F32) => {
284 self.binary_arith_op_device::<f32>(a, b, 1, arith_kernels::ARITH_BINARY_F32)
285 }
286 other => Err(XlogError::Kernel(format!(
287 "Arithmetic not supported for {:?}",
288 other
289 ))),
290 }
291 }
292
293 pub fn mul_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
311 match a.schema().column_type(0) {
312 Some(ScalarType::I64) => {
313 self.binary_arith_op_device::<i64>(a, b, 2, arith_kernels::ARITH_BINARY_I64)
314 }
315 Some(ScalarType::I32) => {
316 self.binary_arith_op_device::<i32>(a, b, 2, arith_kernels::ARITH_BINARY_I32)
317 }
318 Some(ScalarType::U64) => {
319 self.binary_arith_op_device::<u64>(a, b, 2, arith_kernels::ARITH_BINARY_U64)
320 }
321 Some(ScalarType::U32 | ScalarType::Symbol) => {
322 self.binary_arith_op_device::<u32>(a, b, 2, arith_kernels::ARITH_BINARY_U32)
323 }
324 Some(ScalarType::F64) => {
325 self.binary_arith_op_device::<f64>(a, b, 2, arith_kernels::ARITH_BINARY_F64)
326 }
327 Some(ScalarType::F32) => {
328 self.binary_arith_op_device::<f32>(a, b, 2, arith_kernels::ARITH_BINARY_F32)
329 }
330 other => Err(XlogError::Kernel(format!(
331 "Arithmetic not supported for {:?}",
332 other
333 ))),
334 }
335 }
336
337 pub fn div_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
357 match a.schema().column_type(0) {
358 Some(ScalarType::I64) => {
359 self.binary_arith_op_device::<i64>(a, b, 3, arith_kernels::ARITH_BINARY_I64)
360 }
361 Some(ScalarType::I32) => {
362 self.binary_arith_op_device::<i32>(a, b, 3, arith_kernels::ARITH_BINARY_I32)
363 }
364 Some(ScalarType::U64) => {
365 self.binary_arith_op_device::<u64>(a, b, 3, arith_kernels::ARITH_BINARY_U64)
366 }
367 Some(ScalarType::U32 | ScalarType::Symbol) => {
368 self.binary_arith_op_device::<u32>(a, b, 3, arith_kernels::ARITH_BINARY_U32)
369 }
370 Some(ScalarType::F64) => {
371 self.binary_arith_op_device::<f64>(a, b, 3, arith_kernels::ARITH_BINARY_F64)
372 }
373 Some(ScalarType::F32) => {
374 self.binary_arith_op_device::<f32>(a, b, 3, arith_kernels::ARITH_BINARY_F32)
375 }
376 other => Err(XlogError::Kernel(format!(
377 "Arithmetic not supported for {:?}",
378 other
379 ))),
380 }
381 }
382
383 pub fn mod_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
402 match a.schema().column_type(0) {
403 Some(ScalarType::I64) => {
404 self.binary_arith_op_device::<i64>(a, b, 4, arith_kernels::ARITH_BINARY_I64)
405 }
406 Some(ScalarType::I32) => {
407 self.binary_arith_op_device::<i32>(a, b, 4, arith_kernels::ARITH_BINARY_I32)
408 }
409 Some(ScalarType::U64) => {
410 self.binary_arith_op_device::<u64>(a, b, 4, arith_kernels::ARITH_BINARY_U64)
411 }
412 Some(ScalarType::U32 | ScalarType::Symbol) => {
413 self.binary_arith_op_device::<u32>(a, b, 4, arith_kernels::ARITH_BINARY_U32)
414 }
415 Some(ScalarType::F64) => {
416 self.binary_arith_op_device::<f64>(a, b, 4, arith_kernels::ARITH_BINARY_F64)
417 }
418 Some(ScalarType::F32) => {
419 self.binary_arith_op_device::<f32>(a, b, 4, arith_kernels::ARITH_BINARY_F32)
420 }
421 other => Err(XlogError::Kernel(format!(
422 "Arithmetic not supported for {:?}",
423 other
424 ))),
425 }
426 }
427
428 pub fn abs_column(&self, a: &CudaBuffer) -> Result<CudaBuffer> {
443 if a.arity() != 1 {
444 return Err(XlogError::Kernel(
445 "Arithmetic requires single-column buffers".into(),
446 ));
447 }
448
449 if a.num_rows() == 0 {
450 return self.create_empty_buffer(a.schema().clone());
451 }
452
453 let n: u32 = a.num_rows().try_into().map_err(|_| {
454 XlogError::Kernel(format!(
455 "abs_column: row count {} exceeds u32::MAX",
456 a.num_rows()
457 ))
458 })?;
459 let col = a
460 .column(0)
461 .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
462 let config = LaunchConfig::for_num_elems(n);
463
464 match a.schema().column_type(0) {
465 Some(ScalarType::I64) => {
466 let expected_bytes = (n as usize)
467 .checked_mul(std::mem::size_of::<i64>())
468 .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
469 if col.num_bytes() != expected_bytes {
470 return Err(XlogError::Kernel(format!(
471 "Column 0 has {} bytes but expected {} for {} rows",
472 col.num_bytes(),
473 expected_bytes,
474 a.num_rows()
475 )));
476 }
477 let mut out = self.memory.alloc::<u8>(expected_bytes)?;
478 let func = self
479 .device
480 .inner()
481 .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_I64)
482 .ok_or_else(|| XlogError::Kernel("arith_abs_i64 not found".into()))?;
483 unsafe { func.clone().launch(config, (col, n, &mut out)) }
485 .map_err(|e| XlogError::Kernel(format!("abs_i64 failed: {}", e)))?;
486 self.device.synchronize()?;
487 self.buffer_from_columns_with_device_count(
488 vec![out.into()],
489 a.num_rows(),
490 a.schema.clone(),
491 a,
492 )
493 }
494 Some(ScalarType::I32) => {
495 let expected_bytes = (n as usize)
496 .checked_mul(std::mem::size_of::<i32>())
497 .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
498 if col.num_bytes() != expected_bytes {
499 return Err(XlogError::Kernel(format!(
500 "Column 0 has {} bytes but expected {} for {} rows",
501 col.num_bytes(),
502 expected_bytes,
503 a.num_rows()
504 )));
505 }
506 let mut out = self.memory.alloc::<u8>(expected_bytes)?;
507 let func = self
508 .device
509 .inner()
510 .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_I32)
511 .ok_or_else(|| XlogError::Kernel("arith_abs_i32 not found".into()))?;
512 unsafe { func.clone().launch(config, (col, n, &mut out)) }
514 .map_err(|e| XlogError::Kernel(format!("abs_i32 failed: {}", e)))?;
515 self.device.synchronize()?;
516 self.buffer_from_columns_with_device_count(
517 vec![out.into()],
518 a.num_rows(),
519 a.schema.clone(),
520 a,
521 )
522 }
523 Some(ScalarType::F64) => {
524 let expected_bytes = (n as usize)
525 .checked_mul(std::mem::size_of::<f64>())
526 .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
527 if col.num_bytes() != expected_bytes {
528 return Err(XlogError::Kernel(format!(
529 "Column 0 has {} bytes but expected {} for {} rows",
530 col.num_bytes(),
531 expected_bytes,
532 a.num_rows()
533 )));
534 }
535 let mut out = self.memory.alloc::<u8>(expected_bytes)?;
536 let func = self
537 .device
538 .inner()
539 .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_F64)
540 .ok_or_else(|| XlogError::Kernel("arith_abs_f64 not found".into()))?;
541 unsafe { func.clone().launch(config, (col, n, &mut out)) }
543 .map_err(|e| XlogError::Kernel(format!("abs_f64 failed: {}", e)))?;
544 self.device.synchronize()?;
545 self.buffer_from_columns_with_device_count(
546 vec![out.into()],
547 a.num_rows(),
548 a.schema.clone(),
549 a,
550 )
551 }
552 Some(ScalarType::F32) => {
553 let expected_bytes = (n as usize)
554 .checked_mul(std::mem::size_of::<f32>())
555 .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
556 if col.num_bytes() != expected_bytes {
557 return Err(XlogError::Kernel(format!(
558 "Column 0 has {} bytes but expected {} for {} rows",
559 col.num_bytes(),
560 expected_bytes,
561 a.num_rows()
562 )));
563 }
564 let mut out = self.memory.alloc::<u8>(expected_bytes)?;
565 let func = self
566 .device
567 .inner()
568 .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_F32)
569 .ok_or_else(|| XlogError::Kernel("arith_abs_f32 not found".into()))?;
570 unsafe { func.clone().launch(config, (col, n, &mut out)) }
572 .map_err(|e| XlogError::Kernel(format!("abs_f32 failed: {}", e)))?;
573 self.device.synchronize()?;
574 self.buffer_from_columns_with_device_count(
575 vec![out.into()],
576 a.num_rows(),
577 a.schema.clone(),
578 a,
579 )
580 }
581 Some(ScalarType::U32 | ScalarType::U64 | ScalarType::Bool | ScalarType::Symbol) => {
582 self.clone_buffer(a)
583 }
584 other => Err(XlogError::Kernel(format!(
585 "Abs not supported for {:?}",
586 other
587 ))),
588 }
589 }
590
591 pub fn min_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
608 match a.schema().column_type(0) {
609 Some(ScalarType::I64) => {
610 self.binary_arith_op_device::<i64>(a, b, 5, arith_kernels::ARITH_BINARY_I64)
611 }
612 Some(ScalarType::I32) => {
613 self.binary_arith_op_device::<i32>(a, b, 5, arith_kernels::ARITH_BINARY_I32)
614 }
615 Some(ScalarType::U64) => {
616 self.binary_arith_op_device::<u64>(a, b, 5, arith_kernels::ARITH_BINARY_U64)
617 }
618 Some(ScalarType::U32 | ScalarType::Symbol) => {
619 self.binary_arith_op_device::<u32>(a, b, 5, arith_kernels::ARITH_BINARY_U32)
620 }
621 Some(ScalarType::F64) => {
622 self.binary_arith_op_device::<f64>(a, b, 5, arith_kernels::ARITH_BINARY_F64)
623 }
624 Some(ScalarType::F32) => {
625 self.binary_arith_op_device::<f32>(a, b, 5, arith_kernels::ARITH_BINARY_F32)
626 }
627 other => Err(XlogError::Kernel(format!(
628 "Arithmetic not supported for {:?}",
629 other
630 ))),
631 }
632 }
633
634 pub fn max_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
651 match a.schema().column_type(0) {
652 Some(ScalarType::I64) => {
653 self.binary_arith_op_device::<i64>(a, b, 6, arith_kernels::ARITH_BINARY_I64)
654 }
655 Some(ScalarType::I32) => {
656 self.binary_arith_op_device::<i32>(a, b, 6, arith_kernels::ARITH_BINARY_I32)
657 }
658 Some(ScalarType::U64) => {
659 self.binary_arith_op_device::<u64>(a, b, 6, arith_kernels::ARITH_BINARY_U64)
660 }
661 Some(ScalarType::U32 | ScalarType::Symbol) => {
662 self.binary_arith_op_device::<u32>(a, b, 6, arith_kernels::ARITH_BINARY_U32)
663 }
664 Some(ScalarType::F64) => {
665 self.binary_arith_op_device::<f64>(a, b, 6, arith_kernels::ARITH_BINARY_F64)
666 }
667 Some(ScalarType::F32) => {
668 self.binary_arith_op_device::<f32>(a, b, 6, arith_kernels::ARITH_BINARY_F32)
669 }
670 other => Err(XlogError::Kernel(format!(
671 "Arithmetic not supported for {:?}",
672 other
673 ))),
674 }
675 }
676
677 pub fn pow_columns(&self, base: &CudaBuffer, exp: &CudaBuffer) -> Result<CudaBuffer> {
695 if base.num_rows() != exp.num_rows() {
696 return Err(XlogError::Kernel("Row count mismatch".into()));
697 }
698 if base.arity() != 1 || exp.arity() != 1 {
699 return Err(XlogError::Kernel(
700 "Arithmetic requires single-column buffers".into(),
701 ));
702 }
703
704 if base.num_rows() == 0 {
705 let schema = Schema::new(vec![("result".to_string(), ScalarType::F64)]);
706 return self.create_empty_buffer(schema);
707 }
708
709 let n: u32 = base.num_rows().try_into().map_err(|_| {
710 XlogError::Kernel(format!(
711 "pow_columns: row count {} exceeds u32::MAX",
712 base.num_rows()
713 ))
714 })?;
715
716 let base_f64_buf = if base.schema().column_type(0) == Some(ScalarType::F64) {
717 None
718 } else {
719 Some(self.cast_column(base, ScalarType::F64)?)
720 };
721 let base_buf = base_f64_buf.as_ref().unwrap_or(base);
722
723 let exp_f64_buf = if exp.schema().column_type(0) == Some(ScalarType::F64) {
724 None
725 } else {
726 Some(self.cast_column(exp, ScalarType::F64)?)
727 };
728 let exp_buf = exp_f64_buf.as_ref().unwrap_or(exp);
729
730 let base_col = base_buf
731 .column(0)
732 .ok_or_else(|| XlogError::Kernel("Missing base column".into()))?;
733 let exp_col = exp_buf
734 .column(0)
735 .ok_or_else(|| XlogError::Kernel("Missing exp column".into()))?;
736
737 let expected_bytes = (n as usize)
738 .checked_mul(std::mem::size_of::<f64>())
739 .ok_or_else(|| XlogError::Kernel("pow_columns size overflow".into()))?;
740 if base_col.num_bytes() != expected_bytes || exp_col.num_bytes() != expected_bytes {
741 return Err(XlogError::Kernel(format!(
742 "pow_columns: expected {} bytes for {} rows",
743 expected_bytes,
744 base.num_rows()
745 )));
746 }
747
748 let mut out = self.memory.alloc::<u8>(expected_bytes)?;
749 let func = self
750 .device
751 .inner()
752 .get_func(ARITH_MODULE, arith_kernels::ARITH_POW_F64)
753 .ok_or_else(|| XlogError::Kernel("arith_pow_f64 not found".into()))?;
754 let config = LaunchConfig::for_num_elems(n);
755
756 unsafe {
758 func.clone()
759 .launch(config, (base_col, exp_col, n, &mut out))
760 }
761 .map_err(|e| XlogError::Kernel(format!("pow_f64 failed: {}", e)))?;
762
763 self.device.synchronize()?;
764
765 let schema = Schema::new(vec![("result".to_string(), ScalarType::F64)]);
766 self.buffer_from_columns_with_device_count(vec![out.into()], base.num_rows(), schema, base)
767 }
768
769 pub fn select_columns(
787 &self,
788 mask: &CudaBuffer,
789 then_vals: &CudaBuffer,
790 else_vals: &CudaBuffer,
791 ) -> Result<CudaBuffer> {
792 if mask.num_rows() != then_vals.num_rows() || mask.num_rows() != else_vals.num_rows() {
793 return Err(XlogError::Kernel("Row count mismatch in select".into()));
794 }
795 if mask.arity() != 1 || then_vals.arity() != 1 || else_vals.arity() != 1 {
796 return Err(XlogError::Kernel(
797 "Select requires single-column buffers".into(),
798 ));
799 }
800
801 let then_type = then_vals.schema().column_type(0);
802 let else_type = else_vals.schema().column_type(0);
803 if then_type != else_type {
804 return Err(XlogError::Kernel(format!(
805 "Type mismatch in select: then={:?}, else={:?}",
806 then_type, else_type
807 )));
808 }
809
810 if mask.num_rows() == 0 {
811 let result_type = then_type.unwrap_or(ScalarType::I64);
812 let schema = Schema::new(vec![("result".to_string(), result_type)]);
813 return self.create_empty_buffer(schema);
814 }
815
816 let n: u32 = mask.num_rows().try_into().map_err(|_| {
817 XlogError::Kernel(format!(
818 "select_columns: row count {} exceeds u32::MAX",
819 mask.num_rows()
820 ))
821 })?;
822
823 let mask_col = mask
824 .column(0)
825 .ok_or_else(|| XlogError::Kernel("Missing mask column".into()))?;
826 let then_col = then_vals
827 .column(0)
828 .ok_or_else(|| XlogError::Kernel("Missing then column".into()))?;
829 let else_col = else_vals
830 .column(0)
831 .ok_or_else(|| XlogError::Kernel("Missing else column".into()))?;
832
833 let result_type = then_type.unwrap_or(ScalarType::I64);
834 let elem_size = result_type.size_bytes();
835 let expected_bytes = (n as usize)
836 .checked_mul(elem_size)
837 .ok_or_else(|| XlogError::Kernel("select_columns size overflow".into()))?;
838
839 let mut out = self.memory.alloc::<u8>(expected_bytes)?;
840
841 let kernel_name = match result_type {
842 ScalarType::I64 => arith_kernels::ARITH_SELECT_I64,
843 ScalarType::I32 => arith_kernels::ARITH_SELECT_I32,
844 ScalarType::U64 => arith_kernels::ARITH_SELECT_U64,
845 ScalarType::U32 | ScalarType::Symbol => arith_kernels::ARITH_SELECT_U32,
846 ScalarType::F64 => arith_kernels::ARITH_SELECT_F64,
847 ScalarType::F32 => arith_kernels::ARITH_SELECT_F32,
848 ScalarType::Bool => {
849 return self.select_columns_bool(mask, then_vals, else_vals);
852 }
853 };
854
855 let func = self
856 .device
857 .inner()
858 .get_func(ARITH_MODULE, kernel_name)
859 .ok_or_else(|| XlogError::Kernel(format!("{} not found", kernel_name)))?;
860 let config = LaunchConfig::for_num_elems(n);
861
862 unsafe {
864 func.clone()
865 .launch(config, (mask_col, then_col, else_col, n, &mut out))
866 }
867 .map_err(|e| XlogError::Kernel(format!("select kernel failed: {}", e)))?;
868
869 self.device.synchronize()?;
870
871 let schema = Schema::new(vec![("result".to_string(), result_type)]);
872 self.buffer_from_columns_with_device_count(vec![out.into()], mask.num_rows(), schema, mask)
873 }
874
875 fn select_columns_bool(
877 &self,
878 mask: &CudaBuffer,
879 then_vals: &CudaBuffer,
880 else_vals: &CudaBuffer,
881 ) -> Result<CudaBuffer> {
882 let then_u32 = self.cast_column(then_vals, ScalarType::U32)?;
884 let else_u32 = self.cast_column(else_vals, ScalarType::U32)?;
885 let result_u32 = self.select_columns(mask, &then_u32, &else_u32)?;
886 self.cast_column(&result_u32, ScalarType::Bool)
887 }
888
889 pub fn cast_column(&self, a: &CudaBuffer, target: ScalarType) -> Result<CudaBuffer> {
905 if a.arity() != 1 {
906 return Err(XlogError::Kernel(
907 "Cast requires single-column buffer".into(),
908 ));
909 }
910
911 let source_type = a
912 .schema()
913 .column_type(0)
914 .ok_or_else(|| XlogError::Kernel("Missing column type".into()))?;
915
916 let schema = Schema::new(vec![("result".to_string(), target)]);
917
918 if a.num_rows() == 0 {
919 return self.create_empty_buffer(schema);
920 }
921
922 let n: u32 = a.num_rows().try_into().map_err(|_| {
923 XlogError::Kernel(format!(
924 "cast_column: row count {} exceeds u32::MAX",
925 a.num_rows()
926 ))
927 })?;
928
929 let src_col = a
930 .column(0)
931 .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
932 let src_bytes = (n as usize)
933 .checked_mul(source_type.size_bytes())
934 .ok_or_else(|| XlogError::Kernel("cast_column size overflow".into()))?;
935 if src_col.num_bytes() != src_bytes {
936 return Err(XlogError::Kernel(format!(
937 "Column 0 has {} bytes but expected {} for {} rows",
938 src_col.num_bytes(),
939 src_bytes,
940 a.num_rows()
941 )));
942 }
943
944 let dst_bytes = (n as usize)
945 .checked_mul(target.size_bytes())
946 .ok_or_else(|| XlogError::Kernel("cast_column size overflow".into()))?;
947 let mut out = self.memory.alloc::<u8>(dst_bytes)?;
948
949 let func = self
950 .device
951 .inner()
952 .get_func(ARITH_MODULE, arith_kernels::ARITH_CAST)
953 .ok_or_else(|| XlogError::Kernel("arith_cast not found".into()))?;
954 let config = LaunchConfig::for_num_elems(n);
955
956 unsafe {
958 func.clone().launch(
959 config,
960 (
961 src_col,
962 &mut out,
963 n,
964 source_type.to_code(),
965 target.to_code(),
966 ),
967 )
968 }
969 .map_err(|e| XlogError::Kernel(format!("cast failed: {}", e)))?;
970
971 self.device.synchronize()?;
972
973 self.buffer_from_columns_with_device_count(vec![out.into()], a.num_rows(), schema, a)
974 }
975
976 fn binary_arith_op_device<T: DeviceRepr>(
978 &self,
979 a: &CudaBuffer,
980 b: &CudaBuffer,
981 op: u8,
982 kernel: &str,
983 ) -> Result<CudaBuffer> {
984 if a.num_rows() != b.num_rows() {
985 return Err(XlogError::Kernel("Row count mismatch".into()));
986 }
987 if a.arity() != 1 || b.arity() != 1 {
988 return Err(XlogError::Kernel(
989 "Arithmetic requires single-column buffers".into(),
990 ));
991 }
992 if a.schema().column_type(0) != b.schema().column_type(0) {
993 return Err(XlogError::Kernel(
994 "Arithmetic requires matching column types".into(),
995 ));
996 }
997 if a.num_rows() == 0 {
998 return self.create_empty_buffer(a.schema.clone());
999 }
1000
1001 let n: u32 = a.num_rows().try_into().map_err(|_| {
1002 XlogError::Kernel(format!(
1003 "arith: row count {} exceeds u32::MAX",
1004 a.num_rows()
1005 ))
1006 })?;
1007
1008 let expected_bytes = (n as usize)
1009 .checked_mul(std::mem::size_of::<T>())
1010 .ok_or_else(|| XlogError::Kernel("arith output size overflow".into()))?;
1011
1012 let col_a = a
1013 .column(0)
1014 .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
1015 let col_b = b
1016 .column(0)
1017 .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
1018
1019 if col_a.num_bytes() != expected_bytes || col_b.num_bytes() != expected_bytes {
1020 return Err(XlogError::Kernel(format!(
1021 "Arithmetic expects {} bytes per column for {} rows",
1022 expected_bytes,
1023 a.num_rows()
1024 )));
1025 }
1026
1027 let mut out = self.memory.alloc::<u8>(expected_bytes)?;
1028 let func = self
1029 .device
1030 .inner()
1031 .get_func(ARITH_MODULE, kernel)
1032 .ok_or_else(|| XlogError::Kernel("arith kernel not found".into()))?;
1033 let config = LaunchConfig::for_num_elems(n);
1034
1035 unsafe { func.clone().launch(config, (col_a, col_b, n, op, &mut out)) }
1037 .map_err(|e| XlogError::Kernel(format!("arith binary failed: {}", e)))?;
1038
1039 self.device.synchronize()?;
1040 self.buffer_from_columns_with_device_count(
1041 vec![out.into()],
1042 a.num_rows(),
1043 a.schema.clone(),
1044 a,
1045 )
1046 }
1047
1048 pub fn combine_columns(
1057 &self,
1058 columns: Vec<CudaBuffer>,
1059 types: Vec<ScalarType>,
1060 ) -> Result<CudaBuffer> {
1061 if columns.is_empty() {
1062 let schema_cols: Vec<(String, ScalarType)> = types
1063 .iter()
1064 .enumerate()
1065 .map(|(i, t)| (format!("col_{}", i), *t))
1066 .collect();
1067 let schema = Schema::new(schema_cols);
1068 return self.create_empty_buffer(schema);
1069 }
1070
1071 let row_cap = columns[0].row_cap;
1072
1073 for (i, col) in columns.iter().enumerate() {
1075 if col.row_cap != row_cap {
1076 return Err(XlogError::Kernel(format!(
1077 "Column {} has row capacity {}, expected {}",
1078 i, col.row_cap, row_cap
1079 )));
1080 }
1081 if col.arity() != 1 {
1082 return Err(XlogError::Kernel(format!(
1083 "Column {} buffer has {} columns, expected 1",
1084 i,
1085 col.arity()
1086 )));
1087 }
1088 }
1089
1090 let device = self.device.inner();
1091 let mut d_num_rows = self.memory.alloc::<u32>(1)?;
1092 device
1093 .dtod_copy(columns[0].num_rows_device(), &mut d_num_rows)
1094 .map_err(|e| XlogError::Kernel(format!("Failed to copy row count: {}", e)))?;
1095 self.device.synchronize()?;
1096
1097 let mut result_columns = Vec::with_capacity(columns.len());
1098 for (i, col_buf) in columns.into_iter().enumerate() {
1099 let src_col = col_buf
1100 .columns
1101 .into_iter()
1102 .next()
1103 .ok_or_else(|| XlogError::Kernel(format!("Column {} buffer has no data", i)))?;
1104 result_columns.push(src_col);
1105 }
1106
1107 let schema_cols: Vec<(String, ScalarType)> = types
1108 .iter()
1109 .enumerate()
1110 .map(|(i, t)| (format!("col_{}", i), *t))
1111 .collect();
1112 let schema = Schema::new(schema_cols);
1113
1114 Ok(CudaBuffer::from_columns(
1115 result_columns,
1116 row_cap,
1117 d_num_rows,
1118 schema,
1119 ))
1120 }
1121}