// Source: xlog_runtime/executor/expression.rs

//! Expression evaluation methods for the Executor.
//!
//! Production GPU-accelerated implementations of filtering, predicate-mask
//! evaluation, arithmetic expression evaluation, projection, and mask
//! combination (AND / OR / NOT / constant fill).
5
6use cudarc::driver::LaunchConfig;
7use xlog_core::{Result, ScalarType, Schema, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::provider::{arith_kernels, filter_kernels, ARITH_MODULE, FILTER_MODULE};
10use xlog_cuda::{CudaBuffer, LaunchAsync};
11use xlog_ir::{CompareOp, ConstValue, Expr, ProjectExpr};
12
13use super::Executor;
14
15impl Executor {
16    /// Check if an expression may produce a floating-point result.
17    pub(crate) fn expr_may_be_float(expr: &Expr, schema: &Schema) -> bool {
18        match expr {
19            Expr::Column(col_idx) => matches!(
20                schema.column_type(*col_idx),
21                Some(ScalarType::F32 | ScalarType::F64)
22            ),
23            Expr::Const(ConstValue::F32(_) | ConstValue::F64(_)) => true,
24            Expr::Cast(_, ScalarType::F32 | ScalarType::F64) => true,
25            Expr::Add(l, r)
26            | Expr::Sub(l, r)
27            | Expr::Mul(l, r)
28            | Expr::Div(l, r)
29            | Expr::Mod(l, r)
30            | Expr::Min(l, r)
31            | Expr::Max(l, r)
32            | Expr::Pow(l, r) => {
33                Self::expr_may_be_float(l, schema) || Self::expr_may_be_float(r, schema)
34            }
35            Expr::Abs(inner) | Expr::Cast(inner, _) => Self::expr_may_be_float(inner, schema),
36            _ => false,
37        }
38    }
39
40    /// Execute a Filter node using GPU predicate evaluation.
41    pub fn execute_filter(&self, input: &CudaBuffer, predicate: &Expr) -> Result<CudaBuffer> {
42        if input.is_empty() {
43            return self.create_empty_buffer(input.schema().clone());
44        }
45
46        let mask = self.eval_predicate_mask_gpu(predicate, input)?;
47        self.provider.filter_by_device_mask(input, &mask)
48    }
49
50    pub(crate) fn eval_predicate_mask_gpu(
51        &self,
52        expr: &Expr,
53        input: &CudaBuffer,
54    ) -> Result<TrackedCudaSlice<u8>> {
55        if input.num_rows() > u32::MAX as u64 {
56            return Err(XlogError::Execution(format!(
57                "Predicate evaluation supports at most {} rows, got {}",
58                u32::MAX,
59                input.num_rows()
60            )));
61        }
62        let n = input.num_rows() as u32;
63
64        match expr {
65            Expr::Column(col_idx) => {
66                let col_type = input
67                    .schema()
68                    .column_type(*col_idx)
69                    .ok_or_else(|| XlogError::Execution(format!("Column {} not found", col_idx)))?;
70                if col_type == ScalarType::Bool {
71                    let col_buf = self.wrap_single_column(input, *col_idx)?;
72                    let zero = self.provider.create_constant_column_with_device_count(
73                        &[0u8],
74                        ScalarType::Bool,
75                        input.num_rows(),
76                        input.num_rows_device(),
77                    )?;
78                    return self.compare_buffers_mask(&col_buf, &zero, CompareOp::Ne);
79                }
80                self.mask_filled(n, 1)
81            }
82            Expr::Const(ConstValue::Bool(b)) => self.mask_filled(n, if *b { 1 } else { 0 }),
83            Expr::Const(_) => self.mask_filled(n, 1),
84            Expr::Compare { left, op, right } => {
85                let use_float = Self::expr_may_be_float(left, input.schema())
86                    || Self::expr_may_be_float(right, input.schema());
87
88                let mut left_buf = self.evaluate_arith_expr(left, input)?;
89                let mut right_buf = self.evaluate_arith_expr(right, input)?;
90
91                if use_float {
92                    left_buf = self.provider.cast_column(&left_buf, ScalarType::F64)?;
93                    right_buf = self.provider.cast_column(&right_buf, ScalarType::F64)?;
94                }
95
96                self.compare_buffers_mask(&left_buf, &right_buf, *op)
97            }
98            Expr::And(exprs) => {
99                if exprs.is_empty() {
100                    return self.mask_filled(n, 1);
101                }
102                let mut mask = self.eval_predicate_mask_gpu(&exprs[0], input)?;
103                for expr in &exprs[1..] {
104                    let next = self.eval_predicate_mask_gpu(expr, input)?;
105                    mask = self.mask_and(&mask, &next, n)?;
106                }
107                Ok(mask)
108            }
109            Expr::Or(exprs) => {
110                if exprs.is_empty() {
111                    return self.mask_filled(n, 0);
112                }
113                let mut mask = self.eval_predicate_mask_gpu(&exprs[0], input)?;
114                for expr in &exprs[1..] {
115                    let next = self.eval_predicate_mask_gpu(expr, input)?;
116                    mask = self.mask_or(&mask, &next, n)?;
117                }
118                Ok(mask)
119            }
120            Expr::Not(inner) => {
121                let mask = self.eval_predicate_mask_gpu(inner, input)?;
122                self.mask_not(&mask, n)
123            }
124            Expr::Add(_, _)
125            | Expr::Sub(_, _)
126            | Expr::Mul(_, _)
127            | Expr::Div(_, _)
128            | Expr::Mod(_, _)
129            | Expr::Abs(_)
130            | Expr::Min(_, _)
131            | Expr::Max(_, _)
132            | Expr::Pow(_, _)
133            | Expr::Cast(_, _)
134            | Expr::Conditional { .. } => Err(XlogError::Execution(
135                "Arithmetic expression cannot be evaluated as boolean predicate".into(),
136            )),
137        }
138    }
139
140    fn compare_buffers_mask(
141        &self,
142        left: &CudaBuffer,
143        right: &CudaBuffer,
144        op: CompareOp,
145    ) -> Result<TrackedCudaSlice<u8>> {
146        if left.arity() != 1 || right.arity() != 1 {
147            return Err(XlogError::Execution(
148                "Compare requires single-column buffers".into(),
149            ));
150        }
151        if left.num_rows() != right.num_rows() {
152            return Err(XlogError::Execution(
153                "Compare requires matching row counts".into(),
154            ));
155        }
156        if left.num_rows() > u32::MAX as u64 {
157            return Err(XlogError::Execution(format!(
158                "Compare supports at most {} rows, got {}",
159                u32::MAX,
160                left.num_rows()
161            )));
162        }
163        if left.is_empty() {
164            return self.provider.memory().alloc::<u8>(0).map_err(|e| {
165                XlogError::execution_ctx("compare_buffers_mask", "allocate empty mask", &e)
166            });
167        }
168
169        let left_type = left
170            .schema()
171            .column_type(0)
172            .ok_or_else(|| XlogError::Execution("Missing left column type".into()))?;
173        let right_type = right
174            .schema()
175            .column_type(0)
176            .ok_or_else(|| XlogError::Execution("Missing right column type".into()))?;
177
178        if left_type != right_type {
179            return Err(XlogError::Execution(
180                "Compare requires matching column types".into(),
181            ));
182        }
183
184        let kernel = match left_type {
185            ScalarType::U32 | ScalarType::Symbol => filter_kernels::FILTER_COMPARE_U32_COL,
186            ScalarType::U64 => filter_kernels::FILTER_COMPARE_U64_COL,
187            ScalarType::I32 => filter_kernels::FILTER_COMPARE_I32_COL,
188            ScalarType::I64 => filter_kernels::FILTER_COMPARE_I64_COL,
189            ScalarType::F32 => filter_kernels::FILTER_COMPARE_F32_COL,
190            ScalarType::F64 => filter_kernels::FILTER_COMPARE_F64_COL,
191            ScalarType::Bool => filter_kernels::FILTER_COMPARE_U8_COL,
192        };
193
194        let left_col = left
195            .column(0)
196            .ok_or_else(|| XlogError::Execution("Missing left column".into()))?;
197        let right_col = right
198            .column(0)
199            .ok_or_else(|| XlogError::Execution("Missing right column".into()))?;
200
201        let num_rows = left.num_rows() as u32;
202        let mut d_mask = self.provider.memory().alloc::<u8>(num_rows as usize)?;
203
204        let func = self
205            .provider
206            .device()
207            .inner()
208            .get_func(FILTER_MODULE, kernel)
209            .ok_or_else(|| XlogError::Execution("filter compare kernel not found".into()))?;
210        let config = LaunchConfig::for_num_elems(num_rows);
211
212        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
213        unsafe {
214            func.clone().launch(
215                config,
216                (left_col, right_col, num_rows, op as u8, &mut d_mask),
217            )
218        }
219        .map_err(|e| XlogError::execution_ctx("compare_buffers_mask", "filter compare", &e))?;
220
221        Ok(d_mask)
222    }
223
224    fn mask_and(
225        &self,
226        left: &TrackedCudaSlice<u8>,
227        right: &TrackedCudaSlice<u8>,
228        n: u32,
229    ) -> Result<TrackedCudaSlice<u8>> {
230        let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
231        if n == 0 {
232            return Ok(out);
233        }
234
235        let func = self
236            .provider
237            .device()
238            .inner()
239            .get_func(FILTER_MODULE, filter_kernels::MASK_AND)
240            .ok_or_else(|| XlogError::Execution("mask_and kernel not found".into()))?;
241        let config = LaunchConfig::for_num_elems(n);
242
243        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
244        unsafe { func.clone().launch(config, (left, right, &mut out, n)) }
245            .map_err(|e| XlogError::execution_ctx("mask_and", "launch kernel", &e))?;
246
247        Ok(out)
248    }
249
250    fn mask_or(
251        &self,
252        left: &TrackedCudaSlice<u8>,
253        right: &TrackedCudaSlice<u8>,
254        n: u32,
255    ) -> Result<TrackedCudaSlice<u8>> {
256        let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
257        if n == 0 {
258            return Ok(out);
259        }
260
261        let func = self
262            .provider
263            .device()
264            .inner()
265            .get_func(FILTER_MODULE, filter_kernels::MASK_OR)
266            .ok_or_else(|| XlogError::Execution("mask_or kernel not found".into()))?;
267        let config = LaunchConfig::for_num_elems(n);
268
269        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
270        unsafe { func.clone().launch(config, (left, right, &mut out, n)) }
271            .map_err(|e| XlogError::execution_ctx("mask_or", "launch kernel", &e))?;
272
273        Ok(out)
274    }
275
276    fn mask_not(&self, input: &TrackedCudaSlice<u8>, n: u32) -> Result<TrackedCudaSlice<u8>> {
277        let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
278        if n == 0 {
279            return Ok(out);
280        }
281
282        let func = self
283            .provider
284            .device()
285            .inner()
286            .get_func(FILTER_MODULE, filter_kernels::MASK_NOT)
287            .ok_or_else(|| XlogError::Execution("mask_not kernel not found".into()))?;
288        let config = LaunchConfig::for_num_elems(n);
289
290        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
291        unsafe { func.clone().launch(config, (input, &mut out, n)) }
292            .map_err(|e| XlogError::execution_ctx("mask_not", "launch kernel", &e))?;
293
294        Ok(out)
295    }
296
297    fn mask_filled(&self, n: u32, value: u8) -> Result<TrackedCudaSlice<u8>> {
298        let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
299        if n == 0 {
300            return Ok(out);
301        }
302
303        if value == 0 {
304            self.provider
305                .device()
306                .inner()
307                .memset_zeros(&mut out)
308                .map_err(|e| XlogError::execution_ctx("mask_filled", "mask memset", &e))?;
309            return Ok(out);
310        }
311
312        let func = self
313            .provider
314            .device()
315            .inner()
316            .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U8)
317            .ok_or_else(|| XlogError::Execution("arith fill kernel not found".into()))?;
318        let config = LaunchConfig::for_num_elems(n);
319
320        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
321        unsafe { func.clone().launch(config, (value, n, &mut out)) }
322            .map_err(|e| XlogError::execution_ctx("mask_filled", "mask fill", &e))?;
323
324        Ok(out)
325    }
326
327    pub(crate) fn wrap_single_column(
328        &self,
329        buffer: &CudaBuffer,
330        col_idx: usize,
331    ) -> Result<CudaBuffer> {
332        let col_type = buffer
333            .schema()
334            .column_type(col_idx)
335            .ok_or_else(|| XlogError::Execution(format!("Column {} not found", col_idx)))?;
336        let schema = Schema::new(vec![("expr".to_string(), col_type)]);
337
338        if buffer.is_empty() {
339            return self.create_empty_buffer(schema);
340        }
341
342        let num_rows = buffer.num_rows();
343        let bytes = (num_rows as usize)
344            .checked_mul(col_type.size_bytes())
345            .ok_or_else(|| XlogError::Execution("Column size overflow".into()))?;
346
347        let src_col = buffer
348            .column(col_idx)
349            .ok_or_else(|| XlogError::Execution(format!("Column {} not found", col_idx)))?;
350        let mut dst_col = self.provider.memory().alloc::<u8>(bytes)?;
351        if bytes > 0 {
352            self.provider
353                .device()
354                .inner()
355                .dtod_copy(src_col, &mut dst_col)
356                .map_err(|e| XlogError::execution_ctx("wrap_single_column", "copy column", &e))?;
357        }
358
359        let d_num_rows = self.clone_device_row_count(buffer)?;
360        self.provider.device().synchronize()?;
361        Ok(CudaBuffer::from_columns(
362            vec![dst_col.into()],
363            num_rows,
364            d_num_rows,
365            schema,
366        ))
367    }
368
369    /// Evaluate an arithmetic expression on a buffer, producing a single-column result
370    ///
371    /// This method recursively evaluates arithmetic expressions (Add, Sub, Mul, Div, etc.)
372    /// by delegating to the CUDA kernel provider for GPU-accelerated operations.
373    pub(crate) fn evaluate_arith_expr(
374        &self,
375        expr: &Expr,
376        input: &CudaBuffer,
377    ) -> Result<CudaBuffer> {
378        match expr {
379            Expr::Column(idx) => {
380                // Extract the column as a single-column buffer without host round-trip
381                self.wrap_single_column(input, *idx)
382            }
383            Expr::Const(val) => {
384                // Create a column filled with the constant value
385                let (bytes, col_type) = self.const_to_bytes_and_type(val);
386                self.provider.create_constant_column_with_device_count(
387                    &bytes,
388                    col_type,
389                    input.num_rows(),
390                    input.num_rows_device(),
391                )
392            }
393            Expr::Add(l, r) => {
394                let left = self.evaluate_arith_expr(l, input)?;
395                let right = self.evaluate_arith_expr(r, input)?;
396                self.provider.add_columns(&left, &right)
397            }
398            Expr::Sub(l, r) => {
399                let left = self.evaluate_arith_expr(l, input)?;
400                let right = self.evaluate_arith_expr(r, input)?;
401                self.provider.sub_columns(&left, &right)
402            }
403            Expr::Mul(l, r) => {
404                let left = self.evaluate_arith_expr(l, input)?;
405                let right = self.evaluate_arith_expr(r, input)?;
406                self.provider.mul_columns(&left, &right)
407            }
408            Expr::Div(l, r) => {
409                let left = self.evaluate_arith_expr(l, input)?;
410                let right = self.evaluate_arith_expr(r, input)?;
411                self.provider.div_columns(&left, &right)
412            }
413            Expr::Mod(l, r) => {
414                let left = self.evaluate_arith_expr(l, input)?;
415                let right = self.evaluate_arith_expr(r, input)?;
416                self.provider.mod_columns(&left, &right)
417            }
418            Expr::Abs(inner) => {
419                let val = self.evaluate_arith_expr(inner, input)?;
420                self.provider.abs_column(&val)
421            }
422            Expr::Min(l, r) => {
423                let left = self.evaluate_arith_expr(l, input)?;
424                let right = self.evaluate_arith_expr(r, input)?;
425                self.provider.min_columns(&left, &right)
426            }
427            Expr::Max(l, r) => {
428                let left = self.evaluate_arith_expr(l, input)?;
429                let right = self.evaluate_arith_expr(r, input)?;
430                self.provider.max_columns(&left, &right)
431            }
432            Expr::Pow(base, exp) => {
433                let base_buf = self.evaluate_arith_expr(base, input)?;
434                let exp_buf = self.evaluate_arith_expr(exp, input)?;
435                self.provider.pow_columns(&base_buf, &exp_buf)
436            }
437            Expr::Cast(inner, target_type) => {
438                let val = self.evaluate_arith_expr(inner, input)?;
439                self.provider.cast_column(&val, *target_type)
440            }
441            Expr::Conditional {
442                condition,
443                then_expr,
444                else_expr,
445            } => {
446                // Evaluate condition to get boolean mask
447                let mask_slice = self.eval_predicate_mask_gpu(condition, input)?;
448
449                // Convert mask slice to a CudaBuffer for select_columns
450                let d_num_rows = self.clone_device_row_count(input)?;
451                let mask_buffer = CudaBuffer::from_columns(
452                    vec![mask_slice.into()],
453                    input.num_rows(),
454                    d_num_rows,
455                    Schema::new(vec![("mask".to_string(), ScalarType::Bool)]),
456                );
457
458                // Evaluate both branches
459                let then_buf = self.evaluate_arith_expr(then_expr, input)?;
460                let else_buf = self.evaluate_arith_expr(else_expr, input)?;
461
462                // Select based on mask
463                self.provider
464                    .select_columns(&mask_buffer, &then_buf, &else_buf)
465            }
466            _ => Err(XlogError::Execution(format!(
467                "Unsupported expression in arithmetic evaluation: {:?}",
468                expr
469            ))),
470        }
471    }
472
473    /// Convert a ConstValue to raw bytes and ScalarType
474    pub(crate) fn const_to_bytes_and_type(&self, val: &ConstValue) -> (Vec<u8>, ScalarType) {
475        match val {
476            ConstValue::U32(v) => (v.to_le_bytes().to_vec(), ScalarType::U32),
477            ConstValue::U64(v) => (v.to_le_bytes().to_vec(), ScalarType::U64),
478            ConstValue::I32(v) => (v.to_le_bytes().to_vec(), ScalarType::I32),
479            ConstValue::I64(v) => (v.to_le_bytes().to_vec(), ScalarType::I64),
480            ConstValue::F32(v) => (v.to_le_bytes().to_vec(), ScalarType::F32),
481            ConstValue::F64(v) => (v.to_le_bytes().to_vec(), ScalarType::F64),
482            ConstValue::Bool(v) => (vec![if *v { 1u8 } else { 0u8 }], ScalarType::Bool),
483            ConstValue::Symbol(s) => (
484                xlog_core::symbol::intern(s).to_le_bytes().to_vec(),
485                ScalarType::Symbol,
486            ),
487        }
488    }
489
490    /// Execute a Project node
491    ///
492    /// Selects and reorders columns according to the projection list.
493    /// Supports both column pass-through and computed expressions.
494    pub(crate) fn execute_project(
495        &self,
496        input: &CudaBuffer,
497        columns: &[ProjectExpr],
498    ) -> Result<CudaBuffer> {
499        if input.is_empty() {
500            // Build projected schema
501            let projected_schema = self.project_schema(input.schema(), columns)?;
502            return self.create_empty_buffer(projected_schema);
503        }
504
505        // Build result columns as single-column CudaBuffers
506        let mut result_buffers: Vec<CudaBuffer> = Vec::with_capacity(columns.len());
507        let mut result_types: Vec<ScalarType> = Vec::with_capacity(columns.len());
508
509        for proj_expr in columns {
510            match proj_expr {
511                ProjectExpr::Column(col_idx) => {
512                    // Use extract_column to get a single-column buffer
513                    let col_buffer = self.provider.extract_column(input, *col_idx)?;
514                    let col_type = input
515                        .schema()
516                        .column_type(*col_idx)
517                        .unwrap_or(ScalarType::U64);
518                    result_types.push(col_type);
519                    result_buffers.push(col_buffer);
520                }
521                ProjectExpr::Computed(expr, result_type) => {
522                    // Evaluate the arithmetic expression to get a single-column buffer
523                    let computed_buffer = self.evaluate_arith_expr(expr, input)?;
524                    result_types.push(*result_type);
525                    result_buffers.push(computed_buffer);
526                }
527            }
528        }
529
530        let projected_schema = self.project_schema(input.schema(), columns)?;
531        let mut output = self
532            .provider
533            .combine_columns(result_buffers, result_types)?;
534        output.schema = projected_schema;
535        Ok(output)
536    }
537
538    /// Build a projected schema from ProjectExpr list
539    pub(crate) fn project_schema(&self, input: &Schema, columns: &[ProjectExpr]) -> Result<Schema> {
540        let mut projected_columns: Vec<(String, ScalarType)> = Vec::with_capacity(columns.len());
541        let mut projected_sort_labels: Vec<String> = Vec::with_capacity(columns.len());
542        for proj_expr in columns {
543            match proj_expr {
544                ProjectExpr::Column(col_idx) => {
545                    if let Some((name, ty)) = input.columns.get(*col_idx) {
546                        projected_columns.push((name.clone(), *ty));
547                        projected_sort_labels.push(
548                            input
549                                .column_sort_label(*col_idx)
550                                .unwrap_or(name)
551                                .to_string(),
552                        );
553                    } else {
554                        return Err(XlogError::Execution(format!(
555                            "Column index {} out of bounds",
556                            col_idx
557                        )));
558                    }
559                }
560                ProjectExpr::Computed(_expr, result_type) => {
561                    // Computed columns get a generated name
562                    let col_name = format!("computed_{}", projected_columns.len());
563                    projected_columns.push((col_name, *result_type));
564                    projected_sort_labels.push(format!("computed_{}", projected_sort_labels.len()));
565                }
566            }
567        }
568        Schema::new(projected_columns)
569            .with_sort_labels(projected_sort_labels)
570            .map_err(XlogError::Execution)
571    }
572}