1use cudarc::driver::LaunchConfig;
7use xlog_core::{Result, ScalarType, Schema, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::provider::{arith_kernels, filter_kernels, ARITH_MODULE, FILTER_MODULE};
10use xlog_cuda::{CudaBuffer, LaunchAsync};
11use xlog_ir::{CompareOp, ConstValue, Expr, ProjectExpr};
12
13use super::Executor;
14
15impl Executor {
16 pub(crate) fn expr_may_be_float(expr: &Expr, schema: &Schema) -> bool {
18 match expr {
19 Expr::Column(col_idx) => matches!(
20 schema.column_type(*col_idx),
21 Some(ScalarType::F32 | ScalarType::F64)
22 ),
23 Expr::Const(ConstValue::F32(_) | ConstValue::F64(_)) => true,
24 Expr::Cast(_, ScalarType::F32 | ScalarType::F64) => true,
25 Expr::Add(l, r)
26 | Expr::Sub(l, r)
27 | Expr::Mul(l, r)
28 | Expr::Div(l, r)
29 | Expr::Mod(l, r)
30 | Expr::Min(l, r)
31 | Expr::Max(l, r)
32 | Expr::Pow(l, r) => {
33 Self::expr_may_be_float(l, schema) || Self::expr_may_be_float(r, schema)
34 }
35 Expr::Abs(inner) | Expr::Cast(inner, _) => Self::expr_may_be_float(inner, schema),
36 _ => false,
37 }
38 }
39
40 pub fn execute_filter(&self, input: &CudaBuffer, predicate: &Expr) -> Result<CudaBuffer> {
42 if input.is_empty() {
43 return self.create_empty_buffer(input.schema().clone());
44 }
45
46 let mask = self.eval_predicate_mask_gpu(predicate, input)?;
47 self.provider.filter_by_device_mask(input, &mask)
48 }
49
50 pub(crate) fn eval_predicate_mask_gpu(
51 &self,
52 expr: &Expr,
53 input: &CudaBuffer,
54 ) -> Result<TrackedCudaSlice<u8>> {
55 if input.num_rows() > u32::MAX as u64 {
56 return Err(XlogError::Execution(format!(
57 "Predicate evaluation supports at most {} rows, got {}",
58 u32::MAX,
59 input.num_rows()
60 )));
61 }
62 let n = input.num_rows() as u32;
63
64 match expr {
65 Expr::Column(col_idx) => {
66 let col_type = input
67 .schema()
68 .column_type(*col_idx)
69 .ok_or_else(|| XlogError::Execution(format!("Column {} not found", col_idx)))?;
70 if col_type == ScalarType::Bool {
71 let col_buf = self.wrap_single_column(input, *col_idx)?;
72 let zero = self.provider.create_constant_column_with_device_count(
73 &[0u8],
74 ScalarType::Bool,
75 input.num_rows(),
76 input.num_rows_device(),
77 )?;
78 return self.compare_buffers_mask(&col_buf, &zero, CompareOp::Ne);
79 }
80 self.mask_filled(n, 1)
81 }
82 Expr::Const(ConstValue::Bool(b)) => self.mask_filled(n, if *b { 1 } else { 0 }),
83 Expr::Const(_) => self.mask_filled(n, 1),
84 Expr::Compare { left, op, right } => {
85 let use_float = Self::expr_may_be_float(left, input.schema())
86 || Self::expr_may_be_float(right, input.schema());
87
88 let mut left_buf = self.evaluate_arith_expr(left, input)?;
89 let mut right_buf = self.evaluate_arith_expr(right, input)?;
90
91 if use_float {
92 left_buf = self.provider.cast_column(&left_buf, ScalarType::F64)?;
93 right_buf = self.provider.cast_column(&right_buf, ScalarType::F64)?;
94 }
95
96 self.compare_buffers_mask(&left_buf, &right_buf, *op)
97 }
98 Expr::And(exprs) => {
99 if exprs.is_empty() {
100 return self.mask_filled(n, 1);
101 }
102 let mut mask = self.eval_predicate_mask_gpu(&exprs[0], input)?;
103 for expr in &exprs[1..] {
104 let next = self.eval_predicate_mask_gpu(expr, input)?;
105 mask = self.mask_and(&mask, &next, n)?;
106 }
107 Ok(mask)
108 }
109 Expr::Or(exprs) => {
110 if exprs.is_empty() {
111 return self.mask_filled(n, 0);
112 }
113 let mut mask = self.eval_predicate_mask_gpu(&exprs[0], input)?;
114 for expr in &exprs[1..] {
115 let next = self.eval_predicate_mask_gpu(expr, input)?;
116 mask = self.mask_or(&mask, &next, n)?;
117 }
118 Ok(mask)
119 }
120 Expr::Not(inner) => {
121 let mask = self.eval_predicate_mask_gpu(inner, input)?;
122 self.mask_not(&mask, n)
123 }
124 Expr::Add(_, _)
125 | Expr::Sub(_, _)
126 | Expr::Mul(_, _)
127 | Expr::Div(_, _)
128 | Expr::Mod(_, _)
129 | Expr::Abs(_)
130 | Expr::Min(_, _)
131 | Expr::Max(_, _)
132 | Expr::Pow(_, _)
133 | Expr::Cast(_, _)
134 | Expr::Conditional { .. } => Err(XlogError::Execution(
135 "Arithmetic expression cannot be evaluated as boolean predicate".into(),
136 )),
137 }
138 }
139
140 fn compare_buffers_mask(
141 &self,
142 left: &CudaBuffer,
143 right: &CudaBuffer,
144 op: CompareOp,
145 ) -> Result<TrackedCudaSlice<u8>> {
146 if left.arity() != 1 || right.arity() != 1 {
147 return Err(XlogError::Execution(
148 "Compare requires single-column buffers".into(),
149 ));
150 }
151 if left.num_rows() != right.num_rows() {
152 return Err(XlogError::Execution(
153 "Compare requires matching row counts".into(),
154 ));
155 }
156 if left.num_rows() > u32::MAX as u64 {
157 return Err(XlogError::Execution(format!(
158 "Compare supports at most {} rows, got {}",
159 u32::MAX,
160 left.num_rows()
161 )));
162 }
163 if left.is_empty() {
164 return self.provider.memory().alloc::<u8>(0).map_err(|e| {
165 XlogError::execution_ctx("compare_buffers_mask", "allocate empty mask", &e)
166 });
167 }
168
169 let left_type = left
170 .schema()
171 .column_type(0)
172 .ok_or_else(|| XlogError::Execution("Missing left column type".into()))?;
173 let right_type = right
174 .schema()
175 .column_type(0)
176 .ok_or_else(|| XlogError::Execution("Missing right column type".into()))?;
177
178 if left_type != right_type {
179 return Err(XlogError::Execution(
180 "Compare requires matching column types".into(),
181 ));
182 }
183
184 let kernel = match left_type {
185 ScalarType::U32 | ScalarType::Symbol => filter_kernels::FILTER_COMPARE_U32_COL,
186 ScalarType::U64 => filter_kernels::FILTER_COMPARE_U64_COL,
187 ScalarType::I32 => filter_kernels::FILTER_COMPARE_I32_COL,
188 ScalarType::I64 => filter_kernels::FILTER_COMPARE_I64_COL,
189 ScalarType::F32 => filter_kernels::FILTER_COMPARE_F32_COL,
190 ScalarType::F64 => filter_kernels::FILTER_COMPARE_F64_COL,
191 ScalarType::Bool => filter_kernels::FILTER_COMPARE_U8_COL,
192 };
193
194 let left_col = left
195 .column(0)
196 .ok_or_else(|| XlogError::Execution("Missing left column".into()))?;
197 let right_col = right
198 .column(0)
199 .ok_or_else(|| XlogError::Execution("Missing right column".into()))?;
200
201 let num_rows = left.num_rows() as u32;
202 let mut d_mask = self.provider.memory().alloc::<u8>(num_rows as usize)?;
203
204 let func = self
205 .provider
206 .device()
207 .inner()
208 .get_func(FILTER_MODULE, kernel)
209 .ok_or_else(|| XlogError::Execution("filter compare kernel not found".into()))?;
210 let config = LaunchConfig::for_num_elems(num_rows);
211
212 unsafe {
214 func.clone().launch(
215 config,
216 (left_col, right_col, num_rows, op as u8, &mut d_mask),
217 )
218 }
219 .map_err(|e| XlogError::execution_ctx("compare_buffers_mask", "filter compare", &e))?;
220
221 Ok(d_mask)
222 }
223
224 fn mask_and(
225 &self,
226 left: &TrackedCudaSlice<u8>,
227 right: &TrackedCudaSlice<u8>,
228 n: u32,
229 ) -> Result<TrackedCudaSlice<u8>> {
230 let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
231 if n == 0 {
232 return Ok(out);
233 }
234
235 let func = self
236 .provider
237 .device()
238 .inner()
239 .get_func(FILTER_MODULE, filter_kernels::MASK_AND)
240 .ok_or_else(|| XlogError::Execution("mask_and kernel not found".into()))?;
241 let config = LaunchConfig::for_num_elems(n);
242
243 unsafe { func.clone().launch(config, (left, right, &mut out, n)) }
245 .map_err(|e| XlogError::execution_ctx("mask_and", "launch kernel", &e))?;
246
247 Ok(out)
248 }
249
250 fn mask_or(
251 &self,
252 left: &TrackedCudaSlice<u8>,
253 right: &TrackedCudaSlice<u8>,
254 n: u32,
255 ) -> Result<TrackedCudaSlice<u8>> {
256 let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
257 if n == 0 {
258 return Ok(out);
259 }
260
261 let func = self
262 .provider
263 .device()
264 .inner()
265 .get_func(FILTER_MODULE, filter_kernels::MASK_OR)
266 .ok_or_else(|| XlogError::Execution("mask_or kernel not found".into()))?;
267 let config = LaunchConfig::for_num_elems(n);
268
269 unsafe { func.clone().launch(config, (left, right, &mut out, n)) }
271 .map_err(|e| XlogError::execution_ctx("mask_or", "launch kernel", &e))?;
272
273 Ok(out)
274 }
275
276 fn mask_not(&self, input: &TrackedCudaSlice<u8>, n: u32) -> Result<TrackedCudaSlice<u8>> {
277 let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
278 if n == 0 {
279 return Ok(out);
280 }
281
282 let func = self
283 .provider
284 .device()
285 .inner()
286 .get_func(FILTER_MODULE, filter_kernels::MASK_NOT)
287 .ok_or_else(|| XlogError::Execution("mask_not kernel not found".into()))?;
288 let config = LaunchConfig::for_num_elems(n);
289
290 unsafe { func.clone().launch(config, (input, &mut out, n)) }
292 .map_err(|e| XlogError::execution_ctx("mask_not", "launch kernel", &e))?;
293
294 Ok(out)
295 }
296
297 fn mask_filled(&self, n: u32, value: u8) -> Result<TrackedCudaSlice<u8>> {
298 let mut out = self.provider.memory().alloc::<u8>(n as usize)?;
299 if n == 0 {
300 return Ok(out);
301 }
302
303 if value == 0 {
304 self.provider
305 .device()
306 .inner()
307 .memset_zeros(&mut out)
308 .map_err(|e| XlogError::execution_ctx("mask_filled", "mask memset", &e))?;
309 return Ok(out);
310 }
311
312 let func = self
313 .provider
314 .device()
315 .inner()
316 .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U8)
317 .ok_or_else(|| XlogError::Execution("arith fill kernel not found".into()))?;
318 let config = LaunchConfig::for_num_elems(n);
319
320 unsafe { func.clone().launch(config, (value, n, &mut out)) }
322 .map_err(|e| XlogError::execution_ctx("mask_filled", "mask fill", &e))?;
323
324 Ok(out)
325 }
326
327 pub(crate) fn wrap_single_column(
328 &self,
329 buffer: &CudaBuffer,
330 col_idx: usize,
331 ) -> Result<CudaBuffer> {
332 let col_type = buffer
333 .schema()
334 .column_type(col_idx)
335 .ok_or_else(|| XlogError::Execution(format!("Column {} not found", col_idx)))?;
336 let schema = Schema::new(vec![("expr".to_string(), col_type)]);
337
338 if buffer.is_empty() {
339 return self.create_empty_buffer(schema);
340 }
341
342 let num_rows = buffer.num_rows();
343 let bytes = (num_rows as usize)
344 .checked_mul(col_type.size_bytes())
345 .ok_or_else(|| XlogError::Execution("Column size overflow".into()))?;
346
347 let src_col = buffer
348 .column(col_idx)
349 .ok_or_else(|| XlogError::Execution(format!("Column {} not found", col_idx)))?;
350 let mut dst_col = self.provider.memory().alloc::<u8>(bytes)?;
351 if bytes > 0 {
352 self.provider
353 .device()
354 .inner()
355 .dtod_copy(src_col, &mut dst_col)
356 .map_err(|e| XlogError::execution_ctx("wrap_single_column", "copy column", &e))?;
357 }
358
359 let d_num_rows = self.clone_device_row_count(buffer)?;
360 self.provider.device().synchronize()?;
361 Ok(CudaBuffer::from_columns(
362 vec![dst_col.into()],
363 num_rows,
364 d_num_rows,
365 schema,
366 ))
367 }
368
369 pub(crate) fn evaluate_arith_expr(
374 &self,
375 expr: &Expr,
376 input: &CudaBuffer,
377 ) -> Result<CudaBuffer> {
378 match expr {
379 Expr::Column(idx) => {
380 self.wrap_single_column(input, *idx)
382 }
383 Expr::Const(val) => {
384 let (bytes, col_type) = self.const_to_bytes_and_type(val);
386 self.provider.create_constant_column_with_device_count(
387 &bytes,
388 col_type,
389 input.num_rows(),
390 input.num_rows_device(),
391 )
392 }
393 Expr::Add(l, r) => {
394 let left = self.evaluate_arith_expr(l, input)?;
395 let right = self.evaluate_arith_expr(r, input)?;
396 self.provider.add_columns(&left, &right)
397 }
398 Expr::Sub(l, r) => {
399 let left = self.evaluate_arith_expr(l, input)?;
400 let right = self.evaluate_arith_expr(r, input)?;
401 self.provider.sub_columns(&left, &right)
402 }
403 Expr::Mul(l, r) => {
404 let left = self.evaluate_arith_expr(l, input)?;
405 let right = self.evaluate_arith_expr(r, input)?;
406 self.provider.mul_columns(&left, &right)
407 }
408 Expr::Div(l, r) => {
409 let left = self.evaluate_arith_expr(l, input)?;
410 let right = self.evaluate_arith_expr(r, input)?;
411 self.provider.div_columns(&left, &right)
412 }
413 Expr::Mod(l, r) => {
414 let left = self.evaluate_arith_expr(l, input)?;
415 let right = self.evaluate_arith_expr(r, input)?;
416 self.provider.mod_columns(&left, &right)
417 }
418 Expr::Abs(inner) => {
419 let val = self.evaluate_arith_expr(inner, input)?;
420 self.provider.abs_column(&val)
421 }
422 Expr::Min(l, r) => {
423 let left = self.evaluate_arith_expr(l, input)?;
424 let right = self.evaluate_arith_expr(r, input)?;
425 self.provider.min_columns(&left, &right)
426 }
427 Expr::Max(l, r) => {
428 let left = self.evaluate_arith_expr(l, input)?;
429 let right = self.evaluate_arith_expr(r, input)?;
430 self.provider.max_columns(&left, &right)
431 }
432 Expr::Pow(base, exp) => {
433 let base_buf = self.evaluate_arith_expr(base, input)?;
434 let exp_buf = self.evaluate_arith_expr(exp, input)?;
435 self.provider.pow_columns(&base_buf, &exp_buf)
436 }
437 Expr::Cast(inner, target_type) => {
438 let val = self.evaluate_arith_expr(inner, input)?;
439 self.provider.cast_column(&val, *target_type)
440 }
441 Expr::Conditional {
442 condition,
443 then_expr,
444 else_expr,
445 } => {
446 let mask_slice = self.eval_predicate_mask_gpu(condition, input)?;
448
449 let d_num_rows = self.clone_device_row_count(input)?;
451 let mask_buffer = CudaBuffer::from_columns(
452 vec![mask_slice.into()],
453 input.num_rows(),
454 d_num_rows,
455 Schema::new(vec![("mask".to_string(), ScalarType::Bool)]),
456 );
457
458 let then_buf = self.evaluate_arith_expr(then_expr, input)?;
460 let else_buf = self.evaluate_arith_expr(else_expr, input)?;
461
462 self.provider
464 .select_columns(&mask_buffer, &then_buf, &else_buf)
465 }
466 _ => Err(XlogError::Execution(format!(
467 "Unsupported expression in arithmetic evaluation: {:?}",
468 expr
469 ))),
470 }
471 }
472
473 pub(crate) fn const_to_bytes_and_type(&self, val: &ConstValue) -> (Vec<u8>, ScalarType) {
475 match val {
476 ConstValue::U32(v) => (v.to_le_bytes().to_vec(), ScalarType::U32),
477 ConstValue::U64(v) => (v.to_le_bytes().to_vec(), ScalarType::U64),
478 ConstValue::I32(v) => (v.to_le_bytes().to_vec(), ScalarType::I32),
479 ConstValue::I64(v) => (v.to_le_bytes().to_vec(), ScalarType::I64),
480 ConstValue::F32(v) => (v.to_le_bytes().to_vec(), ScalarType::F32),
481 ConstValue::F64(v) => (v.to_le_bytes().to_vec(), ScalarType::F64),
482 ConstValue::Bool(v) => (vec![if *v { 1u8 } else { 0u8 }], ScalarType::Bool),
483 ConstValue::Symbol(s) => (
484 xlog_core::symbol::intern(s).to_le_bytes().to_vec(),
485 ScalarType::Symbol,
486 ),
487 }
488 }
489
490 pub(crate) fn execute_project(
495 &self,
496 input: &CudaBuffer,
497 columns: &[ProjectExpr],
498 ) -> Result<CudaBuffer> {
499 if input.is_empty() {
500 let projected_schema = self.project_schema(input.schema(), columns)?;
502 return self.create_empty_buffer(projected_schema);
503 }
504
505 let mut result_buffers: Vec<CudaBuffer> = Vec::with_capacity(columns.len());
507 let mut result_types: Vec<ScalarType> = Vec::with_capacity(columns.len());
508
509 for proj_expr in columns {
510 match proj_expr {
511 ProjectExpr::Column(col_idx) => {
512 let col_buffer = self.provider.extract_column(input, *col_idx)?;
514 let col_type = input
515 .schema()
516 .column_type(*col_idx)
517 .unwrap_or(ScalarType::U64);
518 result_types.push(col_type);
519 result_buffers.push(col_buffer);
520 }
521 ProjectExpr::Computed(expr, result_type) => {
522 let computed_buffer = self.evaluate_arith_expr(expr, input)?;
524 result_types.push(*result_type);
525 result_buffers.push(computed_buffer);
526 }
527 }
528 }
529
530 let projected_schema = self.project_schema(input.schema(), columns)?;
531 let mut output = self
532 .provider
533 .combine_columns(result_buffers, result_types)?;
534 output.schema = projected_schema;
535 Ok(output)
536 }
537
538 pub(crate) fn project_schema(&self, input: &Schema, columns: &[ProjectExpr]) -> Result<Schema> {
540 let mut projected_columns: Vec<(String, ScalarType)> = Vec::with_capacity(columns.len());
541 let mut projected_sort_labels: Vec<String> = Vec::with_capacity(columns.len());
542 for proj_expr in columns {
543 match proj_expr {
544 ProjectExpr::Column(col_idx) => {
545 if let Some((name, ty)) = input.columns.get(*col_idx) {
546 projected_columns.push((name.clone(), *ty));
547 projected_sort_labels.push(
548 input
549 .column_sort_label(*col_idx)
550 .unwrap_or(name)
551 .to_string(),
552 );
553 } else {
554 return Err(XlogError::Execution(format!(
555 "Column index {} out of bounds",
556 col_idx
557 )));
558 }
559 }
560 ProjectExpr::Computed(_expr, result_type) => {
561 let col_name = format!("computed_{}", projected_columns.len());
563 projected_columns.push((col_name, *result_type));
564 projected_sort_labels.push(format!("computed_{}", projected_sort_labels.len()));
565 }
566 }
567 }
568 Schema::new(projected_columns)
569 .with_sort_labels(projected_sort_labels)
570 .map_err(XlogError::Execution)
571 }
572}