Skip to main content

xlog_cuda/provider/
arithmetic.rs

1//! Arithmetic column operations: add, sub, mul, div, mod, abs, min, max, pow, cast, select, combine.
2
3use crate::{DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig};
4use xlog_core::{Result, ScalarType, Schema, XlogError};
5
6use super::{arith_kernels, ARITH_MODULE};
7use crate::memory::TrackedCudaSlice;
8use crate::CudaBuffer;
9
10impl super::CudaKernelProvider {
11    /// Create a column filled with a constant value
12    ///
13    /// # Arguments
14    /// * `value_bytes` - The raw bytes of the constant value (in little-endian format)
15    /// * `col_type` - The ScalarType of the column
16    /// * `num_rows` - Number of rows to create
17    ///
18    /// # Returns
19    /// A new single-column CudaBuffer filled with the constant value
20    pub fn create_constant_column(
21        &self,
22        value_bytes: &[u8],
23        col_type: ScalarType,
24        num_rows: u64,
25    ) -> Result<CudaBuffer> {
26        if num_rows == 0 {
27            let schema = Schema::new(vec![("const".to_string(), col_type)]);
28            return self.create_empty_buffer(schema);
29        }
30
31        let elem_size = col_type.size_bytes();
32        if value_bytes.len() != elem_size {
33            return Err(XlogError::Kernel(format!(
34                "Value bytes length {} doesn't match type size {}",
35                value_bytes.len(),
36                elem_size
37            )));
38        }
39
40        if num_rows > u32::MAX as u64 {
41            return Err(XlogError::Kernel(format!(
42                "Constant column supports at most {} rows, got {}",
43                u32::MAX,
44                num_rows
45            )));
46        }
47
48        let total_bytes = (num_rows as usize)
49            .checked_mul(elem_size)
50            .ok_or_else(|| XlogError::Kernel("Constant column size overflow".to_string()))?;
51
52        let mut dst_col = self.memory.alloc::<u8>(total_bytes)?;
53        let n = num_rows as u32;
54
55        macro_rules! launch_fill_const {
56            ($kernel:expr, $value:expr) => {{
57                let func = self
58                    .device
59                    .inner()
60                    .get_func(ARITH_MODULE, $kernel)
61                    .ok_or_else(|| XlogError::Kernel("arith fill kernel not found".to_string()))?;
62                let config = LaunchConfig::for_num_elems(n);
63                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
64                unsafe { func.clone().launch(config, ($value, n, &mut dst_col)) }
65                    .map_err(|e| XlogError::Kernel(format!("fill const failed: {}", e)))?;
66            }};
67        }
68
69        match col_type {
70            ScalarType::U32 | ScalarType::Symbol => {
71                let value = u32::from_le_bytes(value_bytes.try_into().unwrap());
72                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U32, value);
73            }
74            ScalarType::U64 => {
75                let value = u64::from_le_bytes(value_bytes.try_into().unwrap());
76                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U64, value);
77            }
78            ScalarType::I64 => {
79                let value = i64::from_le_bytes(value_bytes.try_into().unwrap());
80                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I64, value);
81            }
82            ScalarType::I32 => {
83                let value = i32::from_le_bytes(value_bytes.try_into().unwrap());
84                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I32, value);
85            }
86            ScalarType::F64 => {
87                let value = f64::from_le_bytes(value_bytes.try_into().unwrap());
88                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F64, value);
89            }
90            ScalarType::F32 => {
91                let value = f32::from_le_bytes(value_bytes.try_into().unwrap());
92                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F32, value);
93            }
94            ScalarType::Bool => {
95                let value = value_bytes[0];
96                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U8, value);
97            }
98        }
99
100        self.device.synchronize()?;
101
102        let schema = Schema::new(vec![("const".to_string(), col_type)]);
103        self.buffer_from_columns(vec![dst_col.into()], num_rows, schema)
104    }
105
106    /// Create a constant column sized to `row_cap` while preserving device row count from `d_num_rows_src`.
107    pub fn create_constant_column_with_device_count(
108        &self,
109        value_bytes: &[u8],
110        col_type: ScalarType,
111        row_cap: u64,
112        d_num_rows_src: &TrackedCudaSlice<u32>,
113    ) -> Result<CudaBuffer> {
114        if row_cap == 0 {
115            let schema = Schema::new(vec![("const".to_string(), col_type)]);
116            return self.create_empty_buffer(schema);
117        }
118
119        let elem_size = col_type.size_bytes();
120        if value_bytes.len() != elem_size {
121            return Err(XlogError::Kernel(format!(
122                "Value bytes length {} doesn't match type size {}",
123                value_bytes.len(),
124                elem_size
125            )));
126        }
127
128        if row_cap > u32::MAX as u64 {
129            return Err(XlogError::Kernel(format!(
130                "Constant column supports at most {} rows, got {}",
131                u32::MAX,
132                row_cap
133            )));
134        }
135
136        let total_bytes = (row_cap as usize)
137            .checked_mul(elem_size)
138            .ok_or_else(|| XlogError::Kernel("Constant column size overflow".to_string()))?;
139
140        let mut dst_col = self.memory.alloc::<u8>(total_bytes)?;
141        let n = row_cap as u32;
142
143        macro_rules! launch_fill_const {
144            ($kernel:expr, $value:expr) => {{
145                let func = self
146                    .device
147                    .inner()
148                    .get_func(ARITH_MODULE, $kernel)
149                    .ok_or_else(|| XlogError::Kernel("arith fill kernel not found".to_string()))?;
150                let config = LaunchConfig::for_num_elems(n);
151                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
152                unsafe { func.clone().launch(config, ($value, n, &mut dst_col)) }
153                    .map_err(|e| XlogError::Kernel(format!("fill const failed: {}", e)))?;
154            }};
155        }
156
157        match col_type {
158            ScalarType::U32 | ScalarType::Symbol => {
159                let value = u32::from_le_bytes(value_bytes.try_into().unwrap());
160                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U32, value);
161            }
162            ScalarType::U64 => {
163                let value = u64::from_le_bytes(value_bytes.try_into().unwrap());
164                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U64, value);
165            }
166            ScalarType::I64 => {
167                let value = i64::from_le_bytes(value_bytes.try_into().unwrap());
168                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I64, value);
169            }
170            ScalarType::I32 => {
171                let value = i32::from_le_bytes(value_bytes.try_into().unwrap());
172                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_I32, value);
173            }
174            ScalarType::F64 => {
175                let value = f64::from_le_bytes(value_bytes.try_into().unwrap());
176                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F64, value);
177            }
178            ScalarType::F32 => {
179                let value = f32::from_le_bytes(value_bytes.try_into().unwrap());
180                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_F32, value);
181            }
182            ScalarType::Bool => {
183                let value = value_bytes[0];
184                launch_fill_const!(arith_kernels::ARITH_FILL_CONST_U8, value);
185            }
186        }
187
188        self.device.synchronize()?;
189
190        let schema = Schema::new(vec![("const".to_string(), col_type)]);
191        let mut d_num_rows = self.memory.alloc::<u32>(1)?;
192        self.device
193            .inner()
194            .dtod_copy(d_num_rows_src, &mut d_num_rows)
195            .map_err(|e| XlogError::Kernel(format!("Failed to copy row count: {}", e)))?;
196
197        Ok(CudaBuffer::from_columns(
198            vec![dst_col.into()],
199            row_cap,
200            d_num_rows,
201            schema,
202        ))
203    }
204
205    /// Element-wise addition of two single-column buffers
206    ///
207    /// Performs element-wise addition using GPU kernels.
208    /// Uses wrapping arithmetic for integer overflow.
209    ///
210    /// # Arguments
211    /// * `a` - First operand buffer (single column)
212    /// * `b` - Second operand buffer (single column)
213    ///
214    /// # Returns
215    /// A new CudaBuffer containing the element-wise sum
216    ///
217    /// # Errors
218    /// Returns `XlogError::Kernel` if:
219    /// - Row counts don't match
220    /// - Buffers are not single-column
221    /// - Type is not supported for arithmetic
222    pub fn add_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
223        match a.schema().column_type(0) {
224            Some(ScalarType::I64) => {
225                self.binary_arith_op_device::<i64>(a, b, 0, arith_kernels::ARITH_BINARY_I64)
226            }
227            Some(ScalarType::I32) => {
228                self.binary_arith_op_device::<i32>(a, b, 0, arith_kernels::ARITH_BINARY_I32)
229            }
230            Some(ScalarType::U64) => {
231                self.binary_arith_op_device::<u64>(a, b, 0, arith_kernels::ARITH_BINARY_U64)
232            }
233            Some(ScalarType::U32 | ScalarType::Symbol) => {
234                self.binary_arith_op_device::<u32>(a, b, 0, arith_kernels::ARITH_BINARY_U32)
235            }
236            Some(ScalarType::F64) => {
237                self.binary_arith_op_device::<f64>(a, b, 0, arith_kernels::ARITH_BINARY_F64)
238            }
239            Some(ScalarType::F32) => {
240                self.binary_arith_op_device::<f32>(a, b, 0, arith_kernels::ARITH_BINARY_F32)
241            }
242            other => Err(XlogError::Kernel(format!(
243                "Arithmetic not supported for {:?}",
244                other
245            ))),
246        }
247    }
248
249    /// Element-wise subtraction of two single-column buffers
250    ///
251    /// Performs element-wise subtraction using GPU kernels.
252    /// Uses wrapping arithmetic for integer overflow.
253    ///
254    /// # Arguments
255    /// * `a` - First operand buffer (single column)
256    /// * `b` - Second operand buffer (single column)
257    ///
258    /// # Returns
259    /// A new CudaBuffer containing the element-wise difference
260    ///
261    /// # Errors
262    /// Returns `XlogError::Kernel` if:
263    /// - Row counts don't match
264    /// - Buffers are not single-column
265    /// - Type is not supported for arithmetic
266    pub fn sub_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
267        match a.schema().column_type(0) {
268            Some(ScalarType::I64) => {
269                self.binary_arith_op_device::<i64>(a, b, 1, arith_kernels::ARITH_BINARY_I64)
270            }
271            Some(ScalarType::I32) => {
272                self.binary_arith_op_device::<i32>(a, b, 1, arith_kernels::ARITH_BINARY_I32)
273            }
274            Some(ScalarType::U64) => {
275                self.binary_arith_op_device::<u64>(a, b, 1, arith_kernels::ARITH_BINARY_U64)
276            }
277            Some(ScalarType::U32 | ScalarType::Symbol) => {
278                self.binary_arith_op_device::<u32>(a, b, 1, arith_kernels::ARITH_BINARY_U32)
279            }
280            Some(ScalarType::F64) => {
281                self.binary_arith_op_device::<f64>(a, b, 1, arith_kernels::ARITH_BINARY_F64)
282            }
283            Some(ScalarType::F32) => {
284                self.binary_arith_op_device::<f32>(a, b, 1, arith_kernels::ARITH_BINARY_F32)
285            }
286            other => Err(XlogError::Kernel(format!(
287                "Arithmetic not supported for {:?}",
288                other
289            ))),
290        }
291    }
292
293    /// Element-wise multiplication of two single-column buffers
294    ///
295    /// Performs element-wise multiplication using GPU kernels.
296    /// Uses wrapping arithmetic for integer overflow.
297    ///
298    /// # Arguments
299    /// * `a` - First operand buffer (single column)
300    /// * `b` - Second operand buffer (single column)
301    ///
302    /// # Returns
303    /// A new CudaBuffer containing the element-wise product
304    ///
305    /// # Errors
306    /// Returns `XlogError::Kernel` if:
307    /// - Row counts don't match
308    /// - Buffers are not single-column
309    /// - Type is not supported for arithmetic
310    pub fn mul_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
311        match a.schema().column_type(0) {
312            Some(ScalarType::I64) => {
313                self.binary_arith_op_device::<i64>(a, b, 2, arith_kernels::ARITH_BINARY_I64)
314            }
315            Some(ScalarType::I32) => {
316                self.binary_arith_op_device::<i32>(a, b, 2, arith_kernels::ARITH_BINARY_I32)
317            }
318            Some(ScalarType::U64) => {
319                self.binary_arith_op_device::<u64>(a, b, 2, arith_kernels::ARITH_BINARY_U64)
320            }
321            Some(ScalarType::U32 | ScalarType::Symbol) => {
322                self.binary_arith_op_device::<u32>(a, b, 2, arith_kernels::ARITH_BINARY_U32)
323            }
324            Some(ScalarType::F64) => {
325                self.binary_arith_op_device::<f64>(a, b, 2, arith_kernels::ARITH_BINARY_F64)
326            }
327            Some(ScalarType::F32) => {
328                self.binary_arith_op_device::<f32>(a, b, 2, arith_kernels::ARITH_BINARY_F32)
329            }
330            other => Err(XlogError::Kernel(format!(
331                "Arithmetic not supported for {:?}",
332                other
333            ))),
334        }
335    }
336
337    /// Element-wise division of two single-column buffers
338    ///
339    /// Performs element-wise division using GPU kernels.
340    /// For signed integers, division by zero returns i64::MAX/i32::MAX.
341    /// For unsigned integers, division by zero returns u64::MAX/u32::MAX.
342    /// For floats, division by zero produces Inf/NaN as per IEEE 754.
343    ///
344    /// # Arguments
345    /// * `a` - Dividend buffer (single column)
346    /// * `b` - Divisor buffer (single column)
347    ///
348    /// # Returns
349    /// A new CudaBuffer containing the element-wise quotient
350    ///
351    /// # Errors
352    /// Returns `XlogError::Kernel` if:
353    /// - Row counts don't match
354    /// - Buffers are not single-column
355    /// - Type is not supported for arithmetic
356    pub fn div_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
357        match a.schema().column_type(0) {
358            Some(ScalarType::I64) => {
359                self.binary_arith_op_device::<i64>(a, b, 3, arith_kernels::ARITH_BINARY_I64)
360            }
361            Some(ScalarType::I32) => {
362                self.binary_arith_op_device::<i32>(a, b, 3, arith_kernels::ARITH_BINARY_I32)
363            }
364            Some(ScalarType::U64) => {
365                self.binary_arith_op_device::<u64>(a, b, 3, arith_kernels::ARITH_BINARY_U64)
366            }
367            Some(ScalarType::U32 | ScalarType::Symbol) => {
368                self.binary_arith_op_device::<u32>(a, b, 3, arith_kernels::ARITH_BINARY_U32)
369            }
370            Some(ScalarType::F64) => {
371                self.binary_arith_op_device::<f64>(a, b, 3, arith_kernels::ARITH_BINARY_F64)
372            }
373            Some(ScalarType::F32) => {
374                self.binary_arith_op_device::<f32>(a, b, 3, arith_kernels::ARITH_BINARY_F32)
375            }
376            other => Err(XlogError::Kernel(format!(
377                "Arithmetic not supported for {:?}",
378                other
379            ))),
380        }
381    }
382
383    /// Element-wise modulo of two single-column buffers
384    ///
385    /// Performs element-wise modulo using GPU kernels.
386    /// For integers, modulo by zero returns 0.
387    /// For floats, modulo by zero produces NaN as per IEEE 754.
388    ///
389    /// # Arguments
390    /// * `a` - Dividend buffer (single column)
391    /// * `b` - Divisor buffer (single column)
392    ///
393    /// # Returns
394    /// A new CudaBuffer containing the element-wise remainder
395    ///
396    /// # Errors
397    /// Returns `XlogError::Kernel` if:
398    /// - Row counts don't match
399    /// - Buffers are not single-column
400    /// - Type is not supported for arithmetic
401    pub fn mod_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
402        match a.schema().column_type(0) {
403            Some(ScalarType::I64) => {
404                self.binary_arith_op_device::<i64>(a, b, 4, arith_kernels::ARITH_BINARY_I64)
405            }
406            Some(ScalarType::I32) => {
407                self.binary_arith_op_device::<i32>(a, b, 4, arith_kernels::ARITH_BINARY_I32)
408            }
409            Some(ScalarType::U64) => {
410                self.binary_arith_op_device::<u64>(a, b, 4, arith_kernels::ARITH_BINARY_U64)
411            }
412            Some(ScalarType::U32 | ScalarType::Symbol) => {
413                self.binary_arith_op_device::<u32>(a, b, 4, arith_kernels::ARITH_BINARY_U32)
414            }
415            Some(ScalarType::F64) => {
416                self.binary_arith_op_device::<f64>(a, b, 4, arith_kernels::ARITH_BINARY_F64)
417            }
418            Some(ScalarType::F32) => {
419                self.binary_arith_op_device::<f32>(a, b, 4, arith_kernels::ARITH_BINARY_F32)
420            }
421            other => Err(XlogError::Kernel(format!(
422                "Arithmetic not supported for {:?}",
423                other
424            ))),
425        }
426    }
427
428    /// Element-wise absolute value of a single-column buffer
429    ///
430    /// Performs element-wise absolute value using GPU kernels.
431    ///
432    /// # Arguments
433    /// * `a` - Input buffer (single column)
434    ///
435    /// # Returns
436    /// A new CudaBuffer containing the absolute values
437    ///
438    /// # Errors
439    /// Returns `XlogError::Kernel` if:
440    /// - Buffer is not single-column
441    /// - Type is not supported for arithmetic
442    pub fn abs_column(&self, a: &CudaBuffer) -> Result<CudaBuffer> {
443        if a.arity() != 1 {
444            return Err(XlogError::Kernel(
445                "Arithmetic requires single-column buffers".into(),
446            ));
447        }
448
449        if a.num_rows() == 0 {
450            return self.create_empty_buffer(a.schema().clone());
451        }
452
453        let n: u32 = a.num_rows().try_into().map_err(|_| {
454            XlogError::Kernel(format!(
455                "abs_column: row count {} exceeds u32::MAX",
456                a.num_rows()
457            ))
458        })?;
459        let col = a
460            .column(0)
461            .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
462        let config = LaunchConfig::for_num_elems(n);
463
464        match a.schema().column_type(0) {
465            Some(ScalarType::I64) => {
466                let expected_bytes = (n as usize)
467                    .checked_mul(std::mem::size_of::<i64>())
468                    .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
469                if col.num_bytes() != expected_bytes {
470                    return Err(XlogError::Kernel(format!(
471                        "Column 0 has {} bytes but expected {} for {} rows",
472                        col.num_bytes(),
473                        expected_bytes,
474                        a.num_rows()
475                    )));
476                }
477                let mut out = self.memory.alloc::<u8>(expected_bytes)?;
478                let func = self
479                    .device
480                    .inner()
481                    .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_I64)
482                    .ok_or_else(|| XlogError::Kernel("arith_abs_i64 not found".into()))?;
483                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
484                unsafe { func.clone().launch(config, (col, n, &mut out)) }
485                    .map_err(|e| XlogError::Kernel(format!("abs_i64 failed: {}", e)))?;
486                self.device.synchronize()?;
487                self.buffer_from_columns_with_device_count(
488                    vec![out.into()],
489                    a.num_rows(),
490                    a.schema.clone(),
491                    a,
492                )
493            }
494            Some(ScalarType::I32) => {
495                let expected_bytes = (n as usize)
496                    .checked_mul(std::mem::size_of::<i32>())
497                    .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
498                if col.num_bytes() != expected_bytes {
499                    return Err(XlogError::Kernel(format!(
500                        "Column 0 has {} bytes but expected {} for {} rows",
501                        col.num_bytes(),
502                        expected_bytes,
503                        a.num_rows()
504                    )));
505                }
506                let mut out = self.memory.alloc::<u8>(expected_bytes)?;
507                let func = self
508                    .device
509                    .inner()
510                    .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_I32)
511                    .ok_or_else(|| XlogError::Kernel("arith_abs_i32 not found".into()))?;
512                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
513                unsafe { func.clone().launch(config, (col, n, &mut out)) }
514                    .map_err(|e| XlogError::Kernel(format!("abs_i32 failed: {}", e)))?;
515                self.device.synchronize()?;
516                self.buffer_from_columns_with_device_count(
517                    vec![out.into()],
518                    a.num_rows(),
519                    a.schema.clone(),
520                    a,
521                )
522            }
523            Some(ScalarType::F64) => {
524                let expected_bytes = (n as usize)
525                    .checked_mul(std::mem::size_of::<f64>())
526                    .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
527                if col.num_bytes() != expected_bytes {
528                    return Err(XlogError::Kernel(format!(
529                        "Column 0 has {} bytes but expected {} for {} rows",
530                        col.num_bytes(),
531                        expected_bytes,
532                        a.num_rows()
533                    )));
534                }
535                let mut out = self.memory.alloc::<u8>(expected_bytes)?;
536                let func = self
537                    .device
538                    .inner()
539                    .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_F64)
540                    .ok_or_else(|| XlogError::Kernel("arith_abs_f64 not found".into()))?;
541                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
542                unsafe { func.clone().launch(config, (col, n, &mut out)) }
543                    .map_err(|e| XlogError::Kernel(format!("abs_f64 failed: {}", e)))?;
544                self.device.synchronize()?;
545                self.buffer_from_columns_with_device_count(
546                    vec![out.into()],
547                    a.num_rows(),
548                    a.schema.clone(),
549                    a,
550                )
551            }
552            Some(ScalarType::F32) => {
553                let expected_bytes = (n as usize)
554                    .checked_mul(std::mem::size_of::<f32>())
555                    .ok_or_else(|| XlogError::Kernel("abs_column size overflow".into()))?;
556                if col.num_bytes() != expected_bytes {
557                    return Err(XlogError::Kernel(format!(
558                        "Column 0 has {} bytes but expected {} for {} rows",
559                        col.num_bytes(),
560                        expected_bytes,
561                        a.num_rows()
562                    )));
563                }
564                let mut out = self.memory.alloc::<u8>(expected_bytes)?;
565                let func = self
566                    .device
567                    .inner()
568                    .get_func(ARITH_MODULE, arith_kernels::ARITH_ABS_F32)
569                    .ok_or_else(|| XlogError::Kernel("arith_abs_f32 not found".into()))?;
570                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
571                unsafe { func.clone().launch(config, (col, n, &mut out)) }
572                    .map_err(|e| XlogError::Kernel(format!("abs_f32 failed: {}", e)))?;
573                self.device.synchronize()?;
574                self.buffer_from_columns_with_device_count(
575                    vec![out.into()],
576                    a.num_rows(),
577                    a.schema.clone(),
578                    a,
579                )
580            }
581            Some(ScalarType::U32 | ScalarType::U64 | ScalarType::Bool | ScalarType::Symbol) => {
582                self.clone_buffer(a)
583            }
584            other => Err(XlogError::Kernel(format!(
585                "Abs not supported for {:?}",
586                other
587            ))),
588        }
589    }
590
591    /// Element-wise minimum of two single-column buffers
592    ///
593    /// Performs element-wise minimum using GPU kernels.
594    ///
595    /// # Arguments
596    /// * `a` - First operand buffer (single column)
597    /// * `b` - Second operand buffer (single column)
598    ///
599    /// # Returns
600    /// A new CudaBuffer containing the element-wise minimums
601    ///
602    /// # Errors
603    /// Returns `XlogError::Kernel` if:
604    /// - Row counts don't match
605    /// - Buffers are not single-column
606    /// - Type is not supported for arithmetic
607    pub fn min_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
608        match a.schema().column_type(0) {
609            Some(ScalarType::I64) => {
610                self.binary_arith_op_device::<i64>(a, b, 5, arith_kernels::ARITH_BINARY_I64)
611            }
612            Some(ScalarType::I32) => {
613                self.binary_arith_op_device::<i32>(a, b, 5, arith_kernels::ARITH_BINARY_I32)
614            }
615            Some(ScalarType::U64) => {
616                self.binary_arith_op_device::<u64>(a, b, 5, arith_kernels::ARITH_BINARY_U64)
617            }
618            Some(ScalarType::U32 | ScalarType::Symbol) => {
619                self.binary_arith_op_device::<u32>(a, b, 5, arith_kernels::ARITH_BINARY_U32)
620            }
621            Some(ScalarType::F64) => {
622                self.binary_arith_op_device::<f64>(a, b, 5, arith_kernels::ARITH_BINARY_F64)
623            }
624            Some(ScalarType::F32) => {
625                self.binary_arith_op_device::<f32>(a, b, 5, arith_kernels::ARITH_BINARY_F32)
626            }
627            other => Err(XlogError::Kernel(format!(
628                "Arithmetic not supported for {:?}",
629                other
630            ))),
631        }
632    }
633
634    /// Element-wise maximum of two single-column buffers
635    ///
636    /// Performs element-wise maximum using GPU kernels.
637    ///
638    /// # Arguments
639    /// * `a` - First operand buffer (single column)
640    /// * `b` - Second operand buffer (single column)
641    ///
642    /// # Returns
643    /// A new CudaBuffer containing the element-wise maximums
644    ///
645    /// # Errors
646    /// Returns `XlogError::Kernel` if:
647    /// - Row counts don't match
648    /// - Buffers are not single-column
649    /// - Type is not supported for arithmetic
650    pub fn max_columns(&self, a: &CudaBuffer, b: &CudaBuffer) -> Result<CudaBuffer> {
651        match a.schema().column_type(0) {
652            Some(ScalarType::I64) => {
653                self.binary_arith_op_device::<i64>(a, b, 6, arith_kernels::ARITH_BINARY_I64)
654            }
655            Some(ScalarType::I32) => {
656                self.binary_arith_op_device::<i32>(a, b, 6, arith_kernels::ARITH_BINARY_I32)
657            }
658            Some(ScalarType::U64) => {
659                self.binary_arith_op_device::<u64>(a, b, 6, arith_kernels::ARITH_BINARY_U64)
660            }
661            Some(ScalarType::U32 | ScalarType::Symbol) => {
662                self.binary_arith_op_device::<u32>(a, b, 6, arith_kernels::ARITH_BINARY_U32)
663            }
664            Some(ScalarType::F64) => {
665                self.binary_arith_op_device::<f64>(a, b, 6, arith_kernels::ARITH_BINARY_F64)
666            }
667            Some(ScalarType::F32) => {
668                self.binary_arith_op_device::<f32>(a, b, 6, arith_kernels::ARITH_BINARY_F32)
669            }
670            other => Err(XlogError::Kernel(format!(
671                "Arithmetic not supported for {:?}",
672                other
673            ))),
674        }
675    }
676
677    /// Element-wise power of two single-column buffers
678    ///
679    /// Converts both operands to f64, computes x^y on the GPU, and returns f64 result.
680    /// This matches the behavior of most database systems where pow() returns a float.
681    ///
682    /// # Arguments
683    /// * `base` - Base values buffer (single column)
684    /// * `exp` - Exponent values buffer (single column)
685    ///
686    /// # Returns
687    /// A new CudaBuffer containing the element-wise powers as f64
688    ///
689    /// # Errors
690    /// Returns `XlogError::Kernel` if:
691    /// - Row counts don't match
692    /// - Buffers are not single-column
693    /// - Type is not supported for arithmetic
694    pub fn pow_columns(&self, base: &CudaBuffer, exp: &CudaBuffer) -> Result<CudaBuffer> {
695        if base.num_rows() != exp.num_rows() {
696            return Err(XlogError::Kernel("Row count mismatch".into()));
697        }
698        if base.arity() != 1 || exp.arity() != 1 {
699            return Err(XlogError::Kernel(
700                "Arithmetic requires single-column buffers".into(),
701            ));
702        }
703
704        if base.num_rows() == 0 {
705            let schema = Schema::new(vec![("result".to_string(), ScalarType::F64)]);
706            return self.create_empty_buffer(schema);
707        }
708
709        let n: u32 = base.num_rows().try_into().map_err(|_| {
710            XlogError::Kernel(format!(
711                "pow_columns: row count {} exceeds u32::MAX",
712                base.num_rows()
713            ))
714        })?;
715
716        let base_f64_buf = if base.schema().column_type(0) == Some(ScalarType::F64) {
717            None
718        } else {
719            Some(self.cast_column(base, ScalarType::F64)?)
720        };
721        let base_buf = base_f64_buf.as_ref().unwrap_or(base);
722
723        let exp_f64_buf = if exp.schema().column_type(0) == Some(ScalarType::F64) {
724            None
725        } else {
726            Some(self.cast_column(exp, ScalarType::F64)?)
727        };
728        let exp_buf = exp_f64_buf.as_ref().unwrap_or(exp);
729
730        let base_col = base_buf
731            .column(0)
732            .ok_or_else(|| XlogError::Kernel("Missing base column".into()))?;
733        let exp_col = exp_buf
734            .column(0)
735            .ok_or_else(|| XlogError::Kernel("Missing exp column".into()))?;
736
737        let expected_bytes = (n as usize)
738            .checked_mul(std::mem::size_of::<f64>())
739            .ok_or_else(|| XlogError::Kernel("pow_columns size overflow".into()))?;
740        if base_col.num_bytes() != expected_bytes || exp_col.num_bytes() != expected_bytes {
741            return Err(XlogError::Kernel(format!(
742                "pow_columns: expected {} bytes for {} rows",
743                expected_bytes,
744                base.num_rows()
745            )));
746        }
747
748        let mut out = self.memory.alloc::<u8>(expected_bytes)?;
749        let func = self
750            .device
751            .inner()
752            .get_func(ARITH_MODULE, arith_kernels::ARITH_POW_F64)
753            .ok_or_else(|| XlogError::Kernel("arith_pow_f64 not found".into()))?;
754        let config = LaunchConfig::for_num_elems(n);
755
756        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
757        unsafe {
758            func.clone()
759                .launch(config, (base_col, exp_col, n, &mut out))
760        }
761        .map_err(|e| XlogError::Kernel(format!("pow_f64 failed: {}", e)))?;
762
763        self.device.synchronize()?;
764
765        let schema = Schema::new(vec![("result".to_string(), ScalarType::F64)]);
766        self.buffer_from_columns_with_device_count(vec![out.into()], base.num_rows(), schema, base)
767    }
768
769    /// Conditional select between two single-column buffers based on a boolean mask.
770    ///
771    /// For each row: out[i] = mask[i] ? then_vals[i] : else_vals[i]
772    ///
773    /// # Arguments
774    /// * `mask` - Boolean mask buffer (single column, type Bool/u8)
775    /// * `then_vals` - Values to select when mask is true
776    /// * `else_vals` - Values to select when mask is false
777    ///
778    /// # Returns
779    /// A new CudaBuffer with values selected based on the mask
780    ///
781    /// # Errors
782    /// Returns `XlogError::Kernel` if:
783    /// - Row counts don't match
784    /// - Buffers are not single-column
785    /// - Types of then/else values don't match
786    pub fn select_columns(
787        &self,
788        mask: &CudaBuffer,
789        then_vals: &CudaBuffer,
790        else_vals: &CudaBuffer,
791    ) -> Result<CudaBuffer> {
792        if mask.num_rows() != then_vals.num_rows() || mask.num_rows() != else_vals.num_rows() {
793            return Err(XlogError::Kernel("Row count mismatch in select".into()));
794        }
795        if mask.arity() != 1 || then_vals.arity() != 1 || else_vals.arity() != 1 {
796            return Err(XlogError::Kernel(
797                "Select requires single-column buffers".into(),
798            ));
799        }
800
801        let then_type = then_vals.schema().column_type(0);
802        let else_type = else_vals.schema().column_type(0);
803        if then_type != else_type {
804            return Err(XlogError::Kernel(format!(
805                "Type mismatch in select: then={:?}, else={:?}",
806                then_type, else_type
807            )));
808        }
809
810        if mask.num_rows() == 0 {
811            let result_type = then_type.unwrap_or(ScalarType::I64);
812            let schema = Schema::new(vec![("result".to_string(), result_type)]);
813            return self.create_empty_buffer(schema);
814        }
815
816        let n: u32 = mask.num_rows().try_into().map_err(|_| {
817            XlogError::Kernel(format!(
818                "select_columns: row count {} exceeds u32::MAX",
819                mask.num_rows()
820            ))
821        })?;
822
823        let mask_col = mask
824            .column(0)
825            .ok_or_else(|| XlogError::Kernel("Missing mask column".into()))?;
826        let then_col = then_vals
827            .column(0)
828            .ok_or_else(|| XlogError::Kernel("Missing then column".into()))?;
829        let else_col = else_vals
830            .column(0)
831            .ok_or_else(|| XlogError::Kernel("Missing else column".into()))?;
832
833        let result_type = then_type.unwrap_or(ScalarType::I64);
834        let elem_size = result_type.size_bytes();
835        let expected_bytes = (n as usize)
836            .checked_mul(elem_size)
837            .ok_or_else(|| XlogError::Kernel("select_columns size overflow".into()))?;
838
839        let mut out = self.memory.alloc::<u8>(expected_bytes)?;
840
841        let kernel_name = match result_type {
842            ScalarType::I64 => arith_kernels::ARITH_SELECT_I64,
843            ScalarType::I32 => arith_kernels::ARITH_SELECT_I32,
844            ScalarType::U64 => arith_kernels::ARITH_SELECT_U64,
845            ScalarType::U32 | ScalarType::Symbol => arith_kernels::ARITH_SELECT_U32,
846            ScalarType::F64 => arith_kernels::ARITH_SELECT_F64,
847            ScalarType::F32 => arith_kernels::ARITH_SELECT_F32,
848            ScalarType::Bool => {
849                // Bool is stored as u8, treat as u8 select (use fill + mask trick)
850                // For simplicity, cast to u32 and back
851                return self.select_columns_bool(mask, then_vals, else_vals);
852            }
853        };
854
855        let func = self
856            .device
857            .inner()
858            .get_func(ARITH_MODULE, kernel_name)
859            .ok_or_else(|| XlogError::Kernel(format!("{} not found", kernel_name)))?;
860        let config = LaunchConfig::for_num_elems(n);
861
862        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
863        unsafe {
864            func.clone()
865                .launch(config, (mask_col, then_col, else_col, n, &mut out))
866        }
867        .map_err(|e| XlogError::Kernel(format!("select kernel failed: {}", e)))?;
868
869        self.device.synchronize()?;
870
871        let schema = Schema::new(vec![("result".to_string(), result_type)]);
872        self.buffer_from_columns_with_device_count(vec![out.into()], mask.num_rows(), schema, mask)
873    }
874
875    /// Helper for select_columns when result type is Bool
876    fn select_columns_bool(
877        &self,
878        mask: &CudaBuffer,
879        then_vals: &CudaBuffer,
880        else_vals: &CudaBuffer,
881    ) -> Result<CudaBuffer> {
882        // Cast bool columns to u32, select, then cast back
883        let then_u32 = self.cast_column(then_vals, ScalarType::U32)?;
884        let else_u32 = self.cast_column(else_vals, ScalarType::U32)?;
885        let result_u32 = self.select_columns(mask, &then_u32, &else_u32)?;
886        self.cast_column(&result_u32, ScalarType::Bool)
887    }
888
889    /// Cast a single-column buffer to a different type
890    ///
891    /// Casts data on the GPU using the arithmetic cast kernel.
892    ///
893    /// # Arguments
894    /// * `a` - Input buffer (single column)
895    /// * `target` - Target scalar type
896    ///
897    /// # Returns
898    /// A new CudaBuffer with the cast values
899    ///
900    /// # Errors
901    /// Returns `XlogError::Kernel` if:
902    /// - Buffer is not single-column
903    /// - Source or target type is not supported for casting
904    pub fn cast_column(&self, a: &CudaBuffer, target: ScalarType) -> Result<CudaBuffer> {
905        if a.arity() != 1 {
906            return Err(XlogError::Kernel(
907                "Cast requires single-column buffer".into(),
908            ));
909        }
910
911        let source_type = a
912            .schema()
913            .column_type(0)
914            .ok_or_else(|| XlogError::Kernel("Missing column type".into()))?;
915
916        let schema = Schema::new(vec![("result".to_string(), target)]);
917
918        if a.num_rows() == 0 {
919            return self.create_empty_buffer(schema);
920        }
921
922        let n: u32 = a.num_rows().try_into().map_err(|_| {
923            XlogError::Kernel(format!(
924                "cast_column: row count {} exceeds u32::MAX",
925                a.num_rows()
926            ))
927        })?;
928
929        let src_col = a
930            .column(0)
931            .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
932        let src_bytes = (n as usize)
933            .checked_mul(source_type.size_bytes())
934            .ok_or_else(|| XlogError::Kernel("cast_column size overflow".into()))?;
935        if src_col.num_bytes() != src_bytes {
936            return Err(XlogError::Kernel(format!(
937                "Column 0 has {} bytes but expected {} for {} rows",
938                src_col.num_bytes(),
939                src_bytes,
940                a.num_rows()
941            )));
942        }
943
944        let dst_bytes = (n as usize)
945            .checked_mul(target.size_bytes())
946            .ok_or_else(|| XlogError::Kernel("cast_column size overflow".into()))?;
947        let mut out = self.memory.alloc::<u8>(dst_bytes)?;
948
949        let func = self
950            .device
951            .inner()
952            .get_func(ARITH_MODULE, arith_kernels::ARITH_CAST)
953            .ok_or_else(|| XlogError::Kernel("arith_cast not found".into()))?;
954        let config = LaunchConfig::for_num_elems(n);
955
956        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
957        unsafe {
958            func.clone().launch(
959                config,
960                (
961                    src_col,
962                    &mut out,
963                    n,
964                    source_type.to_code(),
965                    target.to_code(),
966                ),
967            )
968        }
969        .map_err(|e| XlogError::Kernel(format!("cast failed: {}", e)))?;
970
971        self.device.synchronize()?;
972
973        self.buffer_from_columns_with_device_count(vec![out.into()], a.num_rows(), schema, a)
974    }
975
976    /// Helper for binary arithmetic operations on device.
977    fn binary_arith_op_device<T: DeviceRepr>(
978        &self,
979        a: &CudaBuffer,
980        b: &CudaBuffer,
981        op: u8,
982        kernel: &str,
983    ) -> Result<CudaBuffer> {
984        if a.num_rows() != b.num_rows() {
985            return Err(XlogError::Kernel("Row count mismatch".into()));
986        }
987        if a.arity() != 1 || b.arity() != 1 {
988            return Err(XlogError::Kernel(
989                "Arithmetic requires single-column buffers".into(),
990            ));
991        }
992        if a.schema().column_type(0) != b.schema().column_type(0) {
993            return Err(XlogError::Kernel(
994                "Arithmetic requires matching column types".into(),
995            ));
996        }
997        if a.num_rows() == 0 {
998            return self.create_empty_buffer(a.schema.clone());
999        }
1000
1001        let n: u32 = a.num_rows().try_into().map_err(|_| {
1002            XlogError::Kernel(format!(
1003                "arith: row count {} exceeds u32::MAX",
1004                a.num_rows()
1005            ))
1006        })?;
1007
1008        let expected_bytes = (n as usize)
1009            .checked_mul(std::mem::size_of::<T>())
1010            .ok_or_else(|| XlogError::Kernel("arith output size overflow".into()))?;
1011
1012        let col_a = a
1013            .column(0)
1014            .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
1015        let col_b = b
1016            .column(0)
1017            .ok_or_else(|| XlogError::Kernel("Missing column 0".into()))?;
1018
1019        if col_a.num_bytes() != expected_bytes || col_b.num_bytes() != expected_bytes {
1020            return Err(XlogError::Kernel(format!(
1021                "Arithmetic expects {} bytes per column for {} rows",
1022                expected_bytes,
1023                a.num_rows()
1024            )));
1025        }
1026
1027        let mut out = self.memory.alloc::<u8>(expected_bytes)?;
1028        let func = self
1029            .device
1030            .inner()
1031            .get_func(ARITH_MODULE, kernel)
1032            .ok_or_else(|| XlogError::Kernel("arith kernel not found".into()))?;
1033        let config = LaunchConfig::for_num_elems(n);
1034
1035        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1036        unsafe { func.clone().launch(config, (col_a, col_b, n, op, &mut out)) }
1037            .map_err(|e| XlogError::Kernel(format!("arith binary failed: {}", e)))?;
1038
1039        self.device.synchronize()?;
1040        self.buffer_from_columns_with_device_count(
1041            vec![out.into()],
1042            a.num_rows(),
1043            a.schema.clone(),
1044            a,
1045        )
1046    }
1047
1048    /// Combine multiple single-column buffers into a multi-column buffer
1049    ///
1050    /// # Arguments
1051    /// * `columns` - Vector of single-column CudaBuffers to combine
1052    /// * `types` - Vector of ScalarTypes for each column
1053    ///
1054    /// # Returns
1055    /// A new CudaBuffer with all columns combined
1056    pub fn combine_columns(
1057        &self,
1058        columns: Vec<CudaBuffer>,
1059        types: Vec<ScalarType>,
1060    ) -> Result<CudaBuffer> {
1061        if columns.is_empty() {
1062            let schema_cols: Vec<(String, ScalarType)> = types
1063                .iter()
1064                .enumerate()
1065                .map(|(i, t)| (format!("col_{}", i), *t))
1066                .collect();
1067            let schema = Schema::new(schema_cols);
1068            return self.create_empty_buffer(schema);
1069        }
1070
1071        let row_cap = columns[0].row_cap;
1072
1073        // Verify all columns have the same row capacity and are single-column
1074        for (i, col) in columns.iter().enumerate() {
1075            if col.row_cap != row_cap {
1076                return Err(XlogError::Kernel(format!(
1077                    "Column {} has row capacity {}, expected {}",
1078                    i, col.row_cap, row_cap
1079                )));
1080            }
1081            if col.arity() != 1 {
1082                return Err(XlogError::Kernel(format!(
1083                    "Column {} buffer has {} columns, expected 1",
1084                    i,
1085                    col.arity()
1086                )));
1087            }
1088        }
1089
1090        let device = self.device.inner();
1091        let mut d_num_rows = self.memory.alloc::<u32>(1)?;
1092        device
1093            .dtod_copy(columns[0].num_rows_device(), &mut d_num_rows)
1094            .map_err(|e| XlogError::Kernel(format!("Failed to copy row count: {}", e)))?;
1095        self.device.synchronize()?;
1096
1097        let mut result_columns = Vec::with_capacity(columns.len());
1098        for (i, col_buf) in columns.into_iter().enumerate() {
1099            let src_col = col_buf
1100                .columns
1101                .into_iter()
1102                .next()
1103                .ok_or_else(|| XlogError::Kernel(format!("Column {} buffer has no data", i)))?;
1104            result_columns.push(src_col);
1105        }
1106
1107        let schema_cols: Vec<(String, ScalarType)> = types
1108            .iter()
1109            .enumerate()
1110            .map(|(i, t)| (format!("col_{}", i), *t))
1111            .collect();
1112        let schema = Schema::new(schema_cols);
1113
1114        Ok(CudaBuffer::from_columns(
1115            result_columns,
1116            row_cap,
1117            d_num_rows,
1118            schema,
1119        ))
1120    }
1121}