1use crate::{LaunchAsync, LaunchConfig};
4use xlog_core::{AggOp, Result, ScalarType, Schema, XlogError};
5
6use super::{
7 arith_kernels, groupby_kernels, pack_kernels, scan_kernels, ARITH_MODULE, GROUPBY_MODULE,
8 PACK_MODULE, SCAN_MODULE,
9};
10use crate::memory::{CudaColumn, TrackedCudaSlice};
11use crate::CudaBuffer;
12
13impl super::CudaKernelProvider {
14 pub fn groupby_agg(
30 &self,
31 input: &CudaBuffer,
32 key_cols: &[usize],
33 agg: AggOp,
34 value_col: usize,
35 ) -> Result<CudaBuffer> {
36 self.groupby_multi_agg(input, key_cols, &[(value_col, agg)])
37 }
38
39 pub fn groupby_multi_agg(
67 &self,
68 buffer: &CudaBuffer,
69 key_cols: &[usize],
70 aggs: &[(usize, AggOp)],
71 ) -> Result<CudaBuffer> {
72 if Self::use_recorded_groupby_env()
78 && !key_cols.is_empty()
79 && !aggs.is_empty()
80 && key_cols.len() <= 4
81 {
82 if let Some(launch_stream) = self.recorded_op_stream_or_init() {
83 let keys_compatible = key_cols.iter().all(|&k| {
84 matches!(
85 buffer.schema.column_type(k),
86 Some(ScalarType::U32) | Some(ScalarType::Symbol)
87 )
88 });
89 let aggs_compatible = aggs.iter().all(|&(_, op)| {
90 matches!(op, AggOp::Count | AggOp::Sum | AggOp::Min | AggOp::Max)
91 });
92 if keys_compatible && aggs_compatible {
93 return self.groupby_multi_agg_recorded(buffer, key_cols, aggs, launch_stream);
94 }
95 }
96 }
97 let num_rows = self.device_row_count(buffer)?;
98 if num_rows == 0 {
99 let result_schema =
100 self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
101 return self.create_empty_buffer(result_schema);
102 }
103 if num_rows > u32::MAX as usize {
104 return Err(XlogError::Kernel(format!(
105 "GroupBy supports at most {} rows, got {}",
106 u32::MAX,
107 num_rows
108 )));
109 }
110
111 if key_cols.is_empty() {
113 return Err(XlogError::Kernel(
114 "GroupBy requires at least one key column".to_string(),
115 ));
116 }
117 if aggs.is_empty() {
118 return Err(XlogError::Kernel(
119 "GroupBy requires at least one aggregation".to_string(),
120 ));
121 }
122
123 for &key_col in key_cols {
125 if key_col >= buffer.arity() {
126 return Err(XlogError::Kernel(format!(
127 "Key column {} out of bounds (arity {})",
128 key_col,
129 buffer.arity()
130 )));
131 }
132 }
133
134 for &(value_col, agg_op) in aggs {
136 if value_col >= buffer.arity() {
137 return Err(XlogError::Kernel(format!(
138 "Value column {} out of bounds (arity {})",
139 value_col,
140 buffer.arity()
141 )));
142 }
143
144 let value_ty = buffer
145 .schema()
146 .column_type(value_col)
147 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
148 match agg_op {
149 AggOp::Count => {}
150 AggOp::Sum | AggOp::Min | AggOp::Max => {
154 if !matches!(value_ty, ScalarType::U32 | ScalarType::U64) {
155 return Err(XlogError::Kernel(format!(
156 "{:?} currently requires U32 or U64 values, got {:?}",
157 agg_op, value_ty
158 )));
159 }
160 }
161 AggOp::LogSumExp => {
162 if value_ty != ScalarType::F64 {
163 return Err(XlogError::Kernel(format!(
164 "LogSumExp requires F64 values, got {:?}",
165 value_ty
166 )));
167 }
168 }
169 }
170 }
171
172 let sorted = self.sort(buffer, key_cols)?;
174 let num_rows = self.device_row_count(&sorted)?;
175 if num_rows > u32::MAX as usize {
176 return Err(XlogError::Kernel(format!(
177 "GroupBy supports at most {} rows, got {}",
178 u32::MAX,
179 num_rows
180 )));
181 }
182 let num_rows = num_rows as u32;
183
184 let boundary_func = self
186 .device
187 .inner()
188 .get_func(GROUPBY_MODULE, groupby_kernels::DETECT_GROUP_BOUNDARIES)
189 .ok_or_else(|| {
190 XlogError::Kernel("detect_group_boundaries kernel not found".to_string())
191 })?;
192
193 let boundaries = self.memory.alloc::<u8>(num_rows as usize)?;
195
196 let packed = self.compute_hashes_and_pack_keys(&sorted, key_cols)?;
197 if packed.key_bytes == 0 || packed.key_bytes % 4 != 0 {
198 return Err(XlogError::Kernel(format!(
199 "GroupBy key packing produced {} bytes per row (expected multiple of 4); Bool keys are not supported",
200 packed.key_bytes
201 )));
202 }
203
204 let segments_per_row = (packed.key_bytes / 4) as usize;
205 let total_segments = (num_rows as usize) * segments_per_row;
206 let packed_u32 = self.bytes_as_u32_view(&packed.packed_keys, total_segments)?;
207
208 let block_size = 256u32;
210 let grid_size = num_rows.div_ceil(block_size);
211 let config = LaunchConfig {
212 grid_dim: (grid_size, 1, 1),
213 block_dim: (block_size, 1, 1),
214 shared_mem_bytes: 0,
215 };
216
217 unsafe {
219 boundary_func.clone().launch(
220 config,
221 (
222 &packed_u32,
223 num_rows,
224 segments_per_row as u32,
225 segments_per_row as u32,
226 &boundaries,
227 ),
228 )
229 }
230 .map_err(|e| XlogError::Kernel(format!("detect_group_boundaries failed: {}", e)))?;
231
232 self.device.synchronize()?;
233
234 let device = self.device.inner();
236 let num_blocks = grid_size;
237 let d_boundary_pos = self.memory.alloc::<u32>(num_rows as usize)?;
238 let mut d_block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
239
240 let phase1_fn = device
241 .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE1)
242 .ok_or_else(|| {
243 XlogError::Kernel("Failed to get multiblock_scan_phase1 kernel".to_string())
244 })?;
245
246 unsafe {
248 phase1_fn.clone().launch(
249 LaunchConfig {
250 grid_dim: (num_blocks, 1, 1),
251 block_dim: (block_size, 1, 1),
252 shared_mem_bytes: 0,
253 },
254 (&boundaries, &d_boundary_pos, &d_block_sums, num_rows),
255 )
256 }
257 .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase1 failed: {}", e)))?;
258
259 if num_blocks > 1 {
260 self.multiblock_scan_u32_inplace(&mut d_block_sums, num_blocks)?;
261
262 let phase3_fn = device
263 .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
264 .ok_or_else(|| {
265 XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
266 })?;
267
268 unsafe {
270 phase3_fn.clone().launch(
271 LaunchConfig {
272 grid_dim: (num_blocks, 1, 1),
273 block_dim: (block_size, 1, 1),
274 shared_mem_bytes: 0,
275 },
276 (&d_boundary_pos, &d_block_sums, num_rows),
277 )
278 }
279 .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;
280 }
281
282 self.device.synchronize()?;
283 let d_num_groups = self.capture_num_groups(&d_boundary_pos, &boundaries, num_rows)?;
284 let row_cap = num_rows as u64;
285 let row_cap_usize = num_rows as usize;
286 let row_cap_u32 = num_rows;
287
288 let mut group_ids = self.memory.alloc::<u32>(num_rows as usize)?;
289 let mut group_first_idx = self.memory.alloc::<u32>(row_cap_usize)?;
290
291 let group_ids_fn = device
292 .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_IDS_FROM_BOUNDARIES)
293 .ok_or_else(|| {
294 XlogError::Kernel("group_ids_from_boundaries kernel not found".to_string())
295 })?;
296 let group_start_fn = device
297 .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_START_INDICES)
298 .ok_or_else(|| XlogError::Kernel("group_start_indices kernel not found".to_string()))?;
299
300 unsafe {
302 group_ids_fn.clone().launch(
303 config,
304 (&boundaries, &d_boundary_pos, num_rows, &mut group_ids),
305 )
306 }
307 .map_err(|e| XlogError::Kernel(format!("group_ids_from_boundaries failed: {}", e)))?;
308
309 unsafe {
311 group_start_fn.clone().launch(
312 config,
313 (&boundaries, &d_boundary_pos, num_rows, &mut group_first_idx),
314 )
315 }
316 .map_err(|e| XlogError::Kernel(format!("group_start_indices failed: {}", e)))?;
317
318 self.device.synchronize()?;
319
320 let mut agg_columns: Vec<CudaColumn> = Vec::with_capacity(aggs.len());
322
323 for &(value_col, agg_op) in aggs {
324 let values = sorted
325 .column(value_col)
326 .ok_or_else(|| XlogError::Kernel("Value column not found".to_string()))?;
327
328 match agg_op {
329 AggOp::Count => {
330 let output_bytes = row_cap_usize
331 .checked_mul(std::mem::size_of::<u64>())
332 .ok_or_else(|| {
333 XlogError::Kernel("Count output size overflow".to_string())
334 })?;
335 let mut output = self.memory.alloc::<u8>(output_bytes)?;
336 device.memset_zeros(&mut output).map_err(|e| {
337 XlogError::Kernel(format!("Failed to zero count output: {}", e))
338 })?;
339
340 let count_func = device
341 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_COUNT)
342 .ok_or_else(|| {
343 XlogError::Kernel("groupby_count kernel not found".to_string())
344 })?;
345
346 unsafe {
348 count_func
349 .clone()
350 .launch(config, (&boundaries, &group_ids, num_rows, &output))
351 }
352 .map_err(|e| XlogError::Kernel(format!("groupby_count failed: {}", e)))?;
353
354 self.device.synchronize()?;
355 agg_columns.push(output.into());
356 }
357 AggOp::Sum => {
358 let value_ty = sorted
359 .schema()
360 .column_type(value_col)
361 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
362 let output_bytes = row_cap_usize
363 .checked_mul(std::mem::size_of::<u64>())
364 .ok_or_else(|| XlogError::Kernel("Sum output size overflow".to_string()))?;
365 let mut output = self.memory.alloc::<u8>(output_bytes)?;
366 device.memset_zeros(&mut output).map_err(|e| {
367 XlogError::Kernel(format!("Failed to zero sum output: {}", e))
368 })?;
369
370 if value_ty == ScalarType::U64 {
373 let values_view = self.column_as_u64_view(values, num_rows as usize)?;
374 let sum_func = device
375 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM_U64)
376 .ok_or_else(|| {
377 XlogError::Kernel("groupby_sum_u64 kernel not found".to_string())
378 })?;
379 unsafe {
381 sum_func
382 .clone()
383 .launch(config, (&values_view, &group_ids, num_rows, &output))
384 }
385 .map_err(|e| XlogError::Kernel(format!("groupby_sum_u64 failed: {}", e)))?;
386 } else {
387 let values_view = self.column_as_u32_view(values, num_rows as usize)?;
388 let sum_func = device
389 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM)
390 .ok_or_else(|| {
391 XlogError::Kernel("groupby_sum kernel not found".to_string())
392 })?;
393 unsafe {
395 sum_func
396 .clone()
397 .launch(config, (&values_view, &group_ids, num_rows, &output))
398 }
399 .map_err(|e| XlogError::Kernel(format!("groupby_sum failed: {}", e)))?;
400 }
401
402 self.device.synchronize()?;
403 agg_columns.push(output.into());
404 }
405 AggOp::Min => {
406 let value_ty = sorted
407 .schema()
408 .column_type(value_col)
409 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
410 if value_ty == ScalarType::U64 {
411 let values_view = self.column_as_u64_view(values, num_rows as usize)?;
414 let output_bytes = row_cap_usize
415 .checked_mul(std::mem::size_of::<u64>())
416 .ok_or_else(|| {
417 XlogError::Kernel("Min output size overflow".to_string())
418 })?;
419 let mut output = self.memory.alloc::<u8>(output_bytes)?;
420 let fill_fn = device
421 .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U64)
422 .ok_or_else(|| {
423 XlogError::Kernel("arith_fill_const_u64 not found".to_string())
424 })?;
425 let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
426 unsafe {
428 fill_fn
429 .clone()
430 .launch(fill_config, (u64::MAX, row_cap_u32, &mut output))
431 }
432 .map_err(|e| {
433 XlogError::Kernel(format!("Failed to init min output: {}", e))
434 })?;
435
436 let min_func = device
437 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN_U64)
438 .ok_or_else(|| {
439 XlogError::Kernel("groupby_min_u64 kernel not found".to_string())
440 })?;
441 unsafe {
443 min_func
444 .clone()
445 .launch(config, (&values_view, &group_ids, num_rows, &output))
446 }
447 .map_err(|e| XlogError::Kernel(format!("groupby_min_u64 failed: {}", e)))?;
448
449 self.device.synchronize()?;
450 agg_columns.push(output.into());
451 } else {
452 let values_view = self.column_as_u32_view(values, num_rows as usize)?;
453 let output_bytes = row_cap_usize
454 .checked_mul(std::mem::size_of::<u32>())
455 .ok_or_else(|| {
456 XlogError::Kernel("Min output size overflow".to_string())
457 })?;
458 let mut output = self.memory.alloc::<u8>(output_bytes)?;
459 let fill_fn = device
460 .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U32)
461 .ok_or_else(|| {
462 XlogError::Kernel("arith_fill_const_u32 not found".to_string())
463 })?;
464 let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
465 unsafe {
467 fill_fn
468 .clone()
469 .launch(fill_config, (u32::MAX, row_cap_u32, &mut output))
470 }
471 .map_err(|e| {
472 XlogError::Kernel(format!("Failed to init min output: {}", e))
473 })?;
474
475 let min_func = device
476 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN)
477 .ok_or_else(|| {
478 XlogError::Kernel("groupby_min kernel not found".to_string())
479 })?;
480
481 unsafe {
483 min_func
484 .clone()
485 .launch(config, (&values_view, &group_ids, num_rows, &output))
486 }
487 .map_err(|e| XlogError::Kernel(format!("groupby_min failed: {}", e)))?;
488
489 self.device.synchronize()?;
490 agg_columns.push(output.into());
491 }
492 }
493 AggOp::Max => {
494 let value_ty = sorted
495 .schema()
496 .column_type(value_col)
497 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
498 if value_ty == ScalarType::U64 {
499 let values_view = self.column_as_u64_view(values, num_rows as usize)?;
502 let output_bytes = row_cap_usize
503 .checked_mul(std::mem::size_of::<u64>())
504 .ok_or_else(|| {
505 XlogError::Kernel("Max output size overflow".to_string())
506 })?;
507 let mut output = self.memory.alloc::<u8>(output_bytes)?;
508 device.memset_zeros(&mut output).map_err(|e| {
509 XlogError::Kernel(format!("Failed to zero max output: {}", e))
510 })?;
511
512 let max_func = device
513 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX_U64)
514 .ok_or_else(|| {
515 XlogError::Kernel("groupby_max_u64 kernel not found".to_string())
516 })?;
517 unsafe {
519 max_func
520 .clone()
521 .launch(config, (&values_view, &group_ids, num_rows, &output))
522 }
523 .map_err(|e| XlogError::Kernel(format!("groupby_max_u64 failed: {}", e)))?;
524
525 self.device.synchronize()?;
526 agg_columns.push(output.into());
527 } else {
528 let values_view = self.column_as_u32_view(values, num_rows as usize)?;
529 let output_bytes = row_cap_usize
530 .checked_mul(std::mem::size_of::<u32>())
531 .ok_or_else(|| {
532 XlogError::Kernel("Max output size overflow".to_string())
533 })?;
534 let mut output = self.memory.alloc::<u8>(output_bytes)?;
535 device.memset_zeros(&mut output).map_err(|e| {
536 XlogError::Kernel(format!("Failed to zero max output: {}", e))
537 })?;
538
539 let max_func = device
540 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX)
541 .ok_or_else(|| {
542 XlogError::Kernel("groupby_max kernel not found".to_string())
543 })?;
544
545 unsafe {
547 max_func
548 .clone()
549 .launch(config, (&values_view, &group_ids, num_rows, &output))
550 }
551 .map_err(|e| XlogError::Kernel(format!("groupby_max failed: {}", e)))?;
552
553 self.device.synchronize()?;
554 agg_columns.push(output.into());
555 }
556 }
557 AggOp::LogSumExp => {
558 let values_f64 = self.column_as_f64_view(values, num_rows as usize)?;
559 let output_bytes = row_cap_usize
560 .checked_mul(std::mem::size_of::<f64>())
561 .ok_or_else(|| {
562 XlogError::Kernel("LogSumExp output size overflow".to_string())
563 })?;
564 let mut maxs = self.memory.alloc::<u8>(output_bytes)?;
565 let mut sumexps = self.memory.alloc::<u8>(output_bytes)?;
566 let results = self.memory.alloc::<u8>(output_bytes)?;
567
568 let fill_f64 = device
569 .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_F64)
570 .ok_or_else(|| {
571 XlogError::Kernel("arith_fill_const_f64 not found".to_string())
572 })?;
573 let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
574 unsafe {
576 fill_f64
577 .clone()
578 .launch(fill_config, (f64::NEG_INFINITY, row_cap_u32, &mut maxs))
579 }
580 .map_err(|e| XlogError::Kernel(format!("Failed to init maxs: {}", e)))?;
581 device
582 .memset_zeros(&mut sumexps)
583 .map_err(|e| XlogError::Kernel(format!("Failed to init sumexps: {}", e)))?;
584
585 let max_func = device
586 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_MAX)
587 .ok_or_else(|| {
588 XlogError::Kernel("groupby_logsumexp_max kernel not found".to_string())
589 })?;
590
591 unsafe {
593 max_func
594 .clone()
595 .launch(config, (&values_f64, &group_ids, num_rows, &maxs))
596 }
597 .map_err(|e| {
598 XlogError::Kernel(format!("groupby_logsumexp_max failed: {}", e))
599 })?;
600
601 self.device.synchronize()?;
602
603 let sumexp_func = device
604 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_SUMEXP)
605 .ok_or_else(|| {
606 XlogError::Kernel(
607 "groupby_logsumexp_sumexp kernel not found".to_string(),
608 )
609 })?;
610
611 unsafe {
613 sumexp_func
614 .clone()
615 .launch(config, (&values_f64, &group_ids, &maxs, num_rows, &sumexps))
616 }
617 .map_err(|e| {
618 XlogError::Kernel(format!("groupby_logsumexp_sumexp failed: {}", e))
619 })?;
620
621 self.device.synchronize()?;
622
623 let final_config = LaunchConfig::for_num_elems(row_cap_u32);
624 let final_func = device
625 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_FINAL)
626 .ok_or_else(|| {
627 XlogError::Kernel(
628 "groupby_logsumexp_final kernel not found".to_string(),
629 )
630 })?;
631
632 unsafe {
634 final_func.clone().launch(
635 final_config,
636 (&maxs, &sumexps, &d_num_groups, row_cap_u32, &results),
637 )
638 }
639 .map_err(|e| {
640 XlogError::Kernel(format!("groupby_logsumexp_final failed: {}", e))
641 })?;
642
643 self.device.synchronize()?;
644 agg_columns.push(results.into());
645 }
646 }
647 }
648
649 let mut result_columns: Vec<CudaColumn> = Vec::with_capacity(key_cols.len() + aggs.len());
651
652 let group_packed_bytes = row_cap_usize
653 .checked_mul(packed.key_bytes as usize)
654 .ok_or_else(|| XlogError::Kernel("GroupBy packed size overflow".to_string()))?;
655 let mut group_packed = self.memory.alloc::<u8>(group_packed_bytes)?;
656
657 let gather_fn = device
658 .get_func(PACK_MODULE, pack_kernels::GATHER_PACKED_ROWS_COUNTED)
659 .ok_or_else(|| {
660 XlogError::Kernel("gather_packed_rows_counted kernel not found".to_string())
661 })?;
662 let gather_config = LaunchConfig::for_num_elems(row_cap_u32);
663
664 unsafe {
666 gather_fn.clone().launch(
667 gather_config,
668 (
669 &packed.packed_keys,
670 packed.key_bytes,
671 &group_first_idx,
672 &d_num_groups,
673 row_cap_u32,
674 &mut group_packed,
675 ),
676 )
677 }
678 .map_err(|e| XlogError::Kernel(format!("gather_packed_rows failed: {}", e)))?;
679
680 let mut col_offsets: Vec<u32> = Vec::with_capacity(key_cols.len());
681 let mut col_sizes: Vec<u32> = Vec::with_capacity(key_cols.len());
682 let mut offset = 0u32;
683 for &key_col in key_cols {
684 let size = buffer
685 .schema()
686 .column_type(key_col)
687 .map(|t| t.size_bytes() as u32)
688 .unwrap_or(4);
689 col_offsets.push(offset);
690 col_sizes.push(size);
691 offset = offset
692 .checked_add(size)
693 .ok_or_else(|| XlogError::Kernel("GroupBy key size overflow".to_string()))?;
694 }
695
696 let unpack_fn = device
697 .get_func(PACK_MODULE, pack_kernels::UNPACK_COLUMN_COUNTED)
698 .ok_or_else(|| {
699 XlogError::Kernel("unpack_column_counted kernel not found".to_string())
700 })?;
701 let unpack_config = LaunchConfig::for_num_elems(row_cap_u32);
702
703 for idx in 0..key_cols.len() {
704 let col_size = col_sizes[idx];
705 let col_offset = col_offsets[idx];
706 let out_bytes = row_cap_usize
707 .checked_mul(col_size as usize)
708 .ok_or_else(|| XlogError::Kernel("GroupBy key column overflow".to_string()))?;
709 let mut out_col = self.memory.alloc::<u8>(out_bytes)?;
710
711 unsafe {
713 unpack_fn.clone().launch(
714 unpack_config,
715 (
716 &group_packed,
717 packed.key_bytes,
718 col_offset,
719 col_size,
720 &d_num_groups,
721 row_cap_u32,
722 &mut out_col,
723 ),
724 )
725 }
726 .map_err(|e| XlogError::Kernel(format!("unpack_column failed: {}", e)))?;
727
728 result_columns.push(out_col.into());
729 }
730
731 result_columns.extend(agg_columns);
732
733 let result_schema = self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
734
735 Ok(CudaBuffer::from_columns(
736 result_columns,
737 row_cap,
738 d_num_groups,
739 result_schema,
740 ))
741 }
742
743 fn capture_num_groups(
744 &self,
745 boundary_pos: &TrackedCudaSlice<u32>,
746 boundaries: &TrackedCudaSlice<u8>,
747 num_rows: u32,
748 ) -> Result<TrackedCudaSlice<u32>> {
749 let mut d_num_groups = self.memory.alloc::<u32>(1)?;
750 let capture_fn = self
751 .device
752 .inner()
753 .get_func(GROUPBY_MODULE, groupby_kernels::CAPTURE_NUM_GROUPS)
754 .ok_or_else(|| XlogError::Kernel("capture_num_groups kernel not found".to_string()))?;
755 unsafe {
757 capture_fn.clone().launch(
758 LaunchConfig {
759 grid_dim: (1, 1, 1),
760 block_dim: (1, 1, 1),
761 shared_mem_bytes: 0,
762 },
763 (boundary_pos, boundaries, num_rows, &mut d_num_groups),
764 )
765 }
766 .map_err(|e| XlogError::Kernel(format!("capture_num_groups failed: {}", e)))?;
767 Ok(d_num_groups)
768 }
769
770 pub(crate) fn groupby_multi_agg_result_schema(
772 &self,
773 input: &Schema,
774 key_cols: &[usize],
775 aggs: &[(usize, AggOp)],
776 ) -> Schema {
777 let mut columns: Vec<(String, ScalarType)> = key_cols
778 .iter()
779 .filter_map(|&i| input.columns.get(i).cloned())
780 .collect();
781 let mut sort_labels: Vec<String> = key_cols
782 .iter()
783 .filter_map(|&i| {
784 input
785 .column_sort_label(i)
786 .map(ToString::to_string)
787 .or_else(|| input.columns.get(i).map(|(name, _)| name.clone()))
788 })
789 .collect();
790
791 for (i, &(value_col, agg_op)) in aggs.iter().enumerate() {
792 let agg_name = match agg_op {
793 AggOp::Count => format!("count_{}", i),
794 AggOp::Sum => format!("sum_{}", i),
795 AggOp::Min => format!("min_{}", i),
796 AggOp::Max => format!("max_{}", i),
797 AggOp::LogSumExp => format!("logsumexp_{}", i),
798 };
799 let agg_type = match agg_op {
802 AggOp::Count => ScalarType::U64,
803 AggOp::Sum => ScalarType::U64,
804 AggOp::Min | AggOp::Max => match input.columns.get(value_col).map(|(_, ty)| *ty) {
807 Some(ScalarType::U64) => ScalarType::U64,
808 _ => ScalarType::U32,
809 },
810 AggOp::LogSumExp => ScalarType::F64,
811 };
812 columns.push((agg_name, agg_type));
813 sort_labels.push(format!("aggregate_{}", i));
814 }
815
816 Schema::new(columns)
817 .with_sort_labels(sort_labels)
818 .expect("groupby result sort labels match column arity")
819 }
820
821 pub(super) fn pack_keys_gpu_on_stream(
842 &self,
843 buffer: &CudaBuffer,
844 key_cols: &[usize],
845 cu_stream: &cudarc::driver::CudaStream,
846 launch_stream: crate::device_runtime::StreamId,
847 runtime: &crate::device_runtime::XlogDeviceRuntime,
848 ) -> Result<crate::provider::PackedKeyData> {
849 use crate::launch::LaunchRecorder;
850
851 if key_cols.is_empty() {
852 return Err(XlogError::Kernel(
853 "pack_keys_gpu_on_stream: no key columns specified".to_string(),
854 ));
855 }
856 if key_cols.len() > 4 {
857 return Err(XlogError::Kernel(
858 "pack_keys_gpu_on_stream: max 4 key columns supported".to_string(),
859 ));
860 }
861 let num_rows = self.device_row_count(buffer)?;
862 if num_rows > u32::MAX as usize {
863 return Err(XlogError::Kernel(format!(
864 "pack_keys_gpu_on_stream supports at most {} rows, got {}",
865 u32::MAX,
866 num_rows
867 )));
868 }
869 let num_rows = num_rows as u32;
870
871 let mut col_sizes_host: Vec<u32> = Vec::with_capacity(key_cols.len());
872 let mut row_size: u32 = 0;
873 for &col_idx in key_cols {
874 let ty = buffer
875 .schema()
876 .column_type(col_idx)
877 .ok_or_else(|| XlogError::Kernel(format!("Invalid column index: {}", col_idx)))?;
878 let s = ty.size_bytes() as u32;
879 col_sizes_host.push(s);
880 row_size += s;
881 }
882
883 if num_rows == 0 {
884 return Ok(crate::provider::PackedKeyData {
885 hashes: self.memory.alloc::<u64>(0)?,
886 packed_keys: self.memory.alloc::<u8>(0)?,
887 key_bytes: row_size,
888 });
889 }
890
891 let packed_bytes = (num_rows as u64) * (row_size as u64);
892 let packed_slice = self.memory.alloc::<u8>(packed_bytes as usize)?;
893 let hash_slice = self.memory.alloc::<u64>(num_rows as usize)?;
894
895 let mut col_ptrs: [u64; 4] = [0; 4];
896 for (i, &col_idx) in key_cols.iter().enumerate() {
897 let col = buffer
898 .column(col_idx)
899 .ok_or_else(|| XlogError::Kernel(format!("Key column {} not found", col_idx)))?;
900 col_ptrs[i] = *col.device_ptr();
901 }
902 let mut packed_col_sizes = 0u64;
903 for (i, size) in col_sizes_host.iter().copied().enumerate() {
904 if size > u16::MAX as u32 {
905 return Err(XlogError::Kernel(format!(
906 "pack_keys_gpu_on_stream: column element size {} exceeds 16-bit kernel argument",
907 size
908 )));
909 }
910 packed_col_sizes |= (size as u64) << (i * 16);
911 }
912
913 let mut rec = LaunchRecorder::new_strict(launch_stream);
921 for &col_idx in key_cols {
922 let col = buffer
923 .column(col_idx)
924 .ok_or_else(|| XlogError::Kernel(format!("Key column {} not found", col_idx)))?;
925 rec.read_column(col);
926 }
927 rec.write(&packed_slice);
928 rec.write(&hash_slice);
929 rec.preflight(runtime).map_err(|e| {
930 XlogError::Kernel(format!(
931 "pack_keys_gpu_on_stream: launch recorder preflight failed: {}",
932 e
933 ))
934 })?;
935
936 let func = self
937 .device
938 .inner()
939 .get_func(PACK_MODULE, pack_kernels::PACK_AND_HASH_KEYS)
940 .ok_or_else(|| XlogError::Kernel("pack_and_hash_keys kernel not found".to_string()))?;
941 let block_size = 256u32;
942 let grid_size = num_rows.div_ceil(block_size);
943 let cfg = LaunchConfig {
944 grid_dim: (grid_size, 1, 1),
945 block_dim: (block_size, 1, 1),
946 shared_mem_bytes: 0,
947 };
948 unsafe {
950 func.clone().launch_on_stream(
951 cu_stream,
952 cfg,
953 (
954 col_ptrs[0],
955 col_ptrs[1],
956 col_ptrs[2],
957 col_ptrs[3],
958 packed_col_sizes,
959 key_cols.len() as u32,
960 num_rows,
961 row_size,
962 &packed_slice,
963 &hash_slice,
964 ),
965 )
966 }
967 .map_err(|e| XlogError::Kernel(format!("pack_and_hash_keys (on_stream) failed: {}", e)))?;
968
969 rec.commit(runtime).map_err(|e| {
975 XlogError::Kernel(format!(
976 "pack_keys_gpu_on_stream: launch recorder commit failed: {}",
977 e
978 ))
979 })?;
980
981 Ok(crate::provider::PackedKeyData {
982 hashes: hash_slice,
983 packed_keys: packed_slice,
984 key_bytes: row_size,
985 })
986 }
987
988 fn memset_zeros_u8_on_stream(
992 &self,
993 buf: &mut TrackedCudaSlice<u8>,
994 cu_stream: &cudarc::driver::CudaStream,
995 ) -> Result<()> {
996 if buf.is_empty() {
997 return Ok(());
998 }
999 let ptr = *buf.device_ptr();
1000 let len = <TrackedCudaSlice<u8> as crate::DeviceSlice<u8>>::len(buf);
1001 unsafe {
1006 let res = cudarc::driver::sys::cuMemsetD8Async(ptr, 0, len, cu_stream.cu_stream());
1007 if res != cudarc::driver::sys::cudaError_enum::CUDA_SUCCESS {
1008 return Err(XlogError::Kernel(format!(
1009 "cuMemsetD8Async (groupby init) failed: {:?}",
1010 res
1011 )));
1012 }
1013 }
1014 Ok(())
1015 }
1016
1017 pub fn groupby_multi_agg_recorded(
1050 &self,
1051 buffer: &CudaBuffer,
1052 key_cols: &[usize],
1053 aggs: &[(usize, AggOp)],
1054 launch_stream: crate::device_runtime::StreamId,
1055 ) -> Result<CudaBuffer> {
1056 use crate::launch::LaunchRecorder;
1057
1058 let runtime = self.memory.runtime().ok_or_else(|| {
1059 XlogError::Kernel(
1060 "groupby_multi_agg_recorded requires a runtime-backed GpuMemoryManager".to_string(),
1061 )
1062 })?;
1063 let cu_stream = runtime
1064 .stream_pool()
1065 .resolve(launch_stream)
1066 .ok_or_else(|| {
1067 XlogError::Kernel(format!(
1068 "groupby_multi_agg_recorded: launch_stream StreamId({}) does not resolve",
1069 launch_stream.0
1070 ))
1071 })?;
1072
1073 let num_rows = self.device_row_count(buffer)?;
1074 if num_rows == 0 {
1075 let result_schema =
1076 self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
1077 return self.create_empty_buffer(result_schema);
1078 }
1079 if num_rows > u32::MAX as usize {
1080 return Err(XlogError::Kernel(format!(
1081 "GroupBy supports at most {} rows, got {}",
1082 u32::MAX,
1083 num_rows
1084 )));
1085 }
1086 if key_cols.is_empty() {
1087 return Err(XlogError::Kernel(
1088 "GroupBy requires at least one key column".to_string(),
1089 ));
1090 }
1091 if aggs.is_empty() {
1092 return Err(XlogError::Kernel(
1093 "GroupBy requires at least one aggregation".to_string(),
1094 ));
1095 }
1096 if key_cols.len() > 4 {
1097 return Err(XlogError::Kernel(
1098 "groupby_multi_agg_recorded: max 4 key columns supported (pack_keys constraint)"
1099 .to_string(),
1100 ));
1101 }
1102 for &k in key_cols {
1103 if k >= buffer.arity() {
1104 return Err(XlogError::Kernel(format!(
1105 "Key column {} out of bounds (arity {})",
1106 k,
1107 buffer.arity()
1108 )));
1109 }
1110 let ty = buffer
1111 .schema()
1112 .column_type(k)
1113 .ok_or_else(|| XlogError::Kernel("Key column has no type".to_string()))?;
1114 if !matches!(ty, ScalarType::U32 | ScalarType::Symbol) {
1115 return Err(XlogError::Kernel(format!(
1116 "groupby_multi_agg_recorded: key column type {:?} unsupported (U32 / Symbol \
1117 only); multi-type sort_recorded is deferred",
1118 ty
1119 )));
1120 }
1121 }
1122 for &(value_col, agg_op) in aggs {
1123 if value_col >= buffer.arity() {
1124 return Err(XlogError::Kernel(format!(
1125 "Value column {} out of bounds (arity {})",
1126 value_col,
1127 buffer.arity()
1128 )));
1129 }
1130 let value_ty = buffer
1131 .schema()
1132 .column_type(value_col)
1133 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1134 match agg_op {
1135 AggOp::Count => {}
1136 AggOp::Sum => {
1137 if !matches!(value_ty, ScalarType::U32 | ScalarType::U64) {
1140 return Err(XlogError::Kernel(format!(
1141 "Sum currently requires U32 or U64 values, got {:?}",
1142 value_ty
1143 )));
1144 }
1145 }
1146 AggOp::Min | AggOp::Max => {
1147 if !matches!(value_ty, ScalarType::U32 | ScalarType::U64) {
1151 return Err(XlogError::Kernel(format!(
1152 "{:?} currently requires U32 or U64 values, got {:?}",
1153 agg_op, value_ty
1154 )));
1155 }
1156 }
1157 AggOp::LogSumExp => {
1158 return Err(XlogError::Kernel(
1159 "groupby_multi_agg_recorded: LogSumExp not yet supported in the \
1160 recorded path (multi-kernel chain deferred to a future implementation)"
1161 .to_string(),
1162 ));
1163 }
1164 }
1165 }
1166
1167 let sorted = self.sort_recorded(buffer, key_cols, launch_stream)?;
1169 let num_rows = self.device_row_count(&sorted)?;
1170 if num_rows > u32::MAX as usize {
1171 return Err(XlogError::Kernel(format!(
1172 "GroupBy supports at most {} rows, got {}",
1173 u32::MAX,
1174 num_rows
1175 )));
1176 }
1177 let num_rows = num_rows as u32;
1178 let row_cap_usize = num_rows as usize;
1179 let row_cap_u32 = num_rows;
1180 let row_cap_u64 = num_rows as u64;
1181
1182 let packed =
1184 self.pack_keys_gpu_on_stream(&sorted, key_cols, &cu_stream, launch_stream, runtime)?;
1185 if packed.key_bytes == 0 || packed.key_bytes % 4 != 0 {
1186 return Err(XlogError::Kernel(format!(
1187 "GroupBy key packing produced {} bytes per row (expected multiple of 4); \
1188 Bool keys are not supported",
1189 packed.key_bytes
1190 )));
1191 }
1192 let segments_per_row = (packed.key_bytes / 4) as usize;
1193 let total_segments = row_cap_usize * segments_per_row;
1194 let packed_u32 = self.bytes_as_u32_view(&packed.packed_keys, total_segments)?;
1195
1196 let boundaries = self.memory.alloc::<u8>(row_cap_usize)?;
1201 let block_size = 256u32;
1202 let num_blocks = num_rows.div_ceil(block_size);
1203 let cfg = LaunchConfig {
1204 grid_dim: (num_blocks, 1, 1),
1205 block_dim: (block_size, 1, 1),
1206 shared_mem_bytes: 0,
1207 };
1208 let d_boundary_pos = self.memory.alloc::<u32>(row_cap_usize)?;
1209 let mut d_block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
1210 let mut d_num_groups = self.memory.alloc::<u32>(1)?;
1211 let mut group_ids = self.memory.alloc::<u32>(row_cap_usize)?;
1212 let mut group_first_idx = self.memory.alloc::<u32>(row_cap_usize)?;
1213
1214 let mut agg_outputs: Vec<TrackedCudaSlice<u8>> = Vec::with_capacity(aggs.len());
1216 for &(value_col, agg_op) in aggs {
1217 let elem_size = match agg_op {
1218 AggOp::Count | AggOp::Sum => std::mem::size_of::<u64>(),
1219 AggOp::Min | AggOp::Max => match sorted.schema().column_type(value_col) {
1221 Some(ScalarType::U64) => std::mem::size_of::<u64>(),
1222 _ => std::mem::size_of::<u32>(),
1223 },
1224 AggOp::LogSumExp => unreachable!("rejected above"),
1225 };
1226 let bytes = row_cap_usize
1227 .checked_mul(elem_size)
1228 .ok_or_else(|| XlogError::Kernel("groupby agg output size overflow".to_string()))?;
1229 agg_outputs.push(self.memory.alloc::<u8>(bytes)?);
1230 }
1231
1232 let group_packed_bytes = row_cap_usize
1234 .checked_mul(packed.key_bytes as usize)
1235 .ok_or_else(|| XlogError::Kernel("GroupBy packed size overflow".to_string()))?;
1236 let mut group_packed = self.memory.alloc::<u8>(group_packed_bytes)?;
1237
1238 let mut col_offsets: Vec<u32> = Vec::with_capacity(key_cols.len());
1239 let mut col_sizes: Vec<u32> = Vec::with_capacity(key_cols.len());
1240 let mut offset = 0u32;
1241 for &key_col in key_cols {
1242 let s = buffer
1243 .schema()
1244 .column_type(key_col)
1245 .map(|t| t.size_bytes() as u32)
1246 .unwrap_or(4);
1247 col_offsets.push(offset);
1248 col_sizes.push(s);
1249 offset = offset
1250 .checked_add(s)
1251 .ok_or_else(|| XlogError::Kernel("GroupBy key size overflow".to_string()))?;
1252 }
1253 let mut key_unpacked: Vec<TrackedCudaSlice<u8>> = Vec::with_capacity(key_cols.len());
1254 for &col_size in &col_sizes {
1255 let bytes = row_cap_usize
1256 .checked_mul(col_size as usize)
1257 .ok_or_else(|| XlogError::Kernel("GroupBy key column overflow".to_string()))?;
1258 key_unpacked.push(self.memory.alloc::<u8>(bytes)?);
1259 }
1260
1261 let mut rec = LaunchRecorder::new_strict(launch_stream);
1268 rec.read(sorted.num_rows_device());
1269 rec.read(&packed.packed_keys);
1273 for &(value_col, _) in aggs {
1274 let c = sorted.column(value_col).ok_or_else(|| {
1275 XlogError::Kernel(format!("Value column {} not found", value_col))
1276 })?;
1277 rec.read_column(c);
1278 }
1279 rec.write(&boundaries);
1280 rec.write(&d_boundary_pos);
1281 rec.write(&d_block_sums);
1282 rec.write(&d_num_groups);
1283 rec.write(&group_ids);
1284 rec.write(&group_first_idx);
1285 rec.write(&group_packed);
1286 for o in &agg_outputs {
1287 rec.write(o);
1288 }
1289 for k in &key_unpacked {
1290 rec.write(k);
1291 }
1292 rec.preflight(runtime).map_err(|e| {
1293 XlogError::Kernel(format!(
1294 "groupby_multi_agg_recorded: preflight failed: {}",
1295 e
1296 ))
1297 })?;
1298
1299 let device = self.device.inner();
1300
1301 let boundary_func = device
1303 .get_func(GROUPBY_MODULE, groupby_kernels::DETECT_GROUP_BOUNDARIES)
1304 .ok_or_else(|| {
1305 XlogError::Kernel("detect_group_boundaries kernel not found".to_string())
1306 })?;
1307 unsafe {
1309 boundary_func.clone().launch_on_stream(
1310 &cu_stream,
1311 cfg,
1312 (
1313 &packed_u32,
1314 num_rows,
1315 segments_per_row as u32,
1316 segments_per_row as u32,
1317 &boundaries,
1318 ),
1319 )
1320 }
1321 .map_err(|e| {
1322 XlogError::Kernel(format!("detect_group_boundaries (on_stream) failed: {}", e))
1323 })?;
1324
1325 let phase1_fn = device
1327 .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE1)
1328 .ok_or_else(|| {
1329 XlogError::Kernel("Failed to get multiblock_scan_phase1 kernel".to_string())
1330 })?;
1331 unsafe {
1333 phase1_fn.clone().launch_on_stream(
1334 &cu_stream,
1335 LaunchConfig {
1336 grid_dim: (num_blocks, 1, 1),
1337 block_dim: (block_size, 1, 1),
1338 shared_mem_bytes: 0,
1339 },
1340 (&boundaries, &d_boundary_pos, &d_block_sums, num_rows),
1341 )
1342 }
1343 .map_err(|e| {
1344 XlogError::Kernel(format!("multiblock_scan_phase1 (on_stream) failed: {}", e))
1345 })?;
1346
1347 if num_blocks > 1 {
1348 self.multiblock_scan_u32_inplace_on_stream(
1349 &mut d_block_sums,
1350 num_blocks,
1351 &cu_stream,
1352 launch_stream,
1353 runtime,
1354 )?;
1355 let phase3_fn = device
1356 .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
1357 .ok_or_else(|| {
1358 XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
1359 })?;
1360 unsafe {
1362 phase3_fn.clone().launch_on_stream(
1363 &cu_stream,
1364 LaunchConfig {
1365 grid_dim: (num_blocks, 1, 1),
1366 block_dim: (block_size, 1, 1),
1367 shared_mem_bytes: 0,
1368 },
1369 (&d_boundary_pos, &d_block_sums, num_rows),
1370 )
1371 }
1372 .map_err(|e| {
1373 XlogError::Kernel(format!("multiblock_scan_phase3 (on_stream) failed: {}", e))
1374 })?;
1375 }
1376
1377 let capture_fn = device
1379 .get_func(GROUPBY_MODULE, groupby_kernels::CAPTURE_NUM_GROUPS)
1380 .ok_or_else(|| XlogError::Kernel("capture_num_groups kernel not found".to_string()))?;
1381 unsafe {
1383 capture_fn.clone().launch_on_stream(
1384 &cu_stream,
1385 LaunchConfig {
1386 grid_dim: (1, 1, 1),
1387 block_dim: (1, 1, 1),
1388 shared_mem_bytes: 0,
1389 },
1390 (&d_boundary_pos, &boundaries, num_rows, &mut d_num_groups),
1391 )
1392 }
1393 .map_err(|e| XlogError::Kernel(format!("capture_num_groups (on_stream) failed: {}", e)))?;
1394
1395 let group_ids_fn = device
1397 .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_IDS_FROM_BOUNDARIES)
1398 .ok_or_else(|| {
1399 XlogError::Kernel("group_ids_from_boundaries kernel not found".to_string())
1400 })?;
1401 let group_start_fn = device
1402 .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_START_INDICES)
1403 .ok_or_else(|| XlogError::Kernel("group_start_indices kernel not found".to_string()))?;
1404 unsafe {
1406 group_ids_fn.clone().launch_on_stream(
1407 &cu_stream,
1408 cfg,
1409 (&boundaries, &d_boundary_pos, num_rows, &mut group_ids),
1410 )
1411 }
1412 .map_err(|e| {
1413 XlogError::Kernel(format!(
1414 "group_ids_from_boundaries (on_stream) failed: {}",
1415 e
1416 ))
1417 })?;
1418 unsafe {
1419 group_start_fn.clone().launch_on_stream(
1420 &cu_stream,
1421 cfg,
1422 (&boundaries, &d_boundary_pos, num_rows, &mut group_first_idx),
1423 )
1424 }
1425 .map_err(|e| XlogError::Kernel(format!("group_start_indices (on_stream) failed: {}", e)))?;
1426
1427 for ((value_col, agg_op), output) in aggs.iter().zip(agg_outputs.iter_mut()) {
1429 let values = sorted.column(*value_col).ok_or_else(|| {
1430 XlogError::Kernel(format!("Value column {} not found", value_col))
1431 })?;
1432 match agg_op {
1433 AggOp::Count => {
1434 self.memset_zeros_u8_on_stream(output, &cu_stream)?;
1435 let count_func = device
1436 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_COUNT)
1437 .ok_or_else(|| {
1438 XlogError::Kernel("groupby_count kernel not found".to_string())
1439 })?;
1440 unsafe {
1442 count_func.clone().launch_on_stream(
1443 &cu_stream,
1444 cfg,
1445 (&boundaries, &group_ids, num_rows, &*output),
1446 )
1447 }
1448 .map_err(|e| {
1449 XlogError::Kernel(format!("groupby_count (on_stream) failed: {}", e))
1450 })?;
1451 }
1452 AggOp::Sum => {
1453 self.memset_zeros_u8_on_stream(output, &cu_stream)?;
1454 let value_ty = sorted
1455 .schema()
1456 .column_type(*value_col)
1457 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1458 if value_ty == ScalarType::U64 {
1459 let values_view = self.column_as_u64_view(values, row_cap_usize)?;
1460 let sum_func = device
1461 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM_U64)
1462 .ok_or_else(|| {
1463 XlogError::Kernel("groupby_sum_u64 kernel not found".to_string())
1464 })?;
1465 unsafe {
1467 sum_func.clone().launch_on_stream(
1468 &cu_stream,
1469 cfg,
1470 (&values_view, &group_ids, num_rows, &*output),
1471 )
1472 }
1473 .map_err(|e| {
1474 XlogError::Kernel(format!("groupby_sum_u64 (on_stream) failed: {}", e))
1475 })?;
1476 } else {
1477 let values_view = self.column_as_u32_view(values, row_cap_usize)?;
1478 let sum_func = device
1479 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM)
1480 .ok_or_else(|| {
1481 XlogError::Kernel("groupby_sum kernel not found".to_string())
1482 })?;
1483 unsafe {
1485 sum_func.clone().launch_on_stream(
1486 &cu_stream,
1487 cfg,
1488 (&values_view, &group_ids, num_rows, &*output),
1489 )
1490 }
1491 .map_err(|e| {
1492 XlogError::Kernel(format!("groupby_sum (on_stream) failed: {}", e))
1493 })?;
1494 }
1495 }
1496 AggOp::Min => {
1497 let value_ty = sorted
1498 .schema()
1499 .column_type(*value_col)
1500 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1501 let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
1502 if value_ty == ScalarType::U64 {
1503 let fill_fn = device
1506 .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U64)
1507 .ok_or_else(|| {
1508 XlogError::Kernel("arith_fill_const_u64 not found".to_string())
1509 })?;
1510 unsafe {
1512 fill_fn.clone().launch_on_stream(
1513 &cu_stream,
1514 fill_config,
1515 (u64::MAX, row_cap_u32, &mut *output),
1516 )
1517 }
1518 .map_err(|e| {
1519 XlogError::Kernel(format!(
1520 "arith_fill_const_u64 (on_stream) failed: {}",
1521 e
1522 ))
1523 })?;
1524 let values_view = self.column_as_u64_view(values, row_cap_usize)?;
1525 let min_func = device
1526 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN_U64)
1527 .ok_or_else(|| {
1528 XlogError::Kernel("groupby_min_u64 kernel not found".to_string())
1529 })?;
1530 unsafe {
1532 min_func.clone().launch_on_stream(
1533 &cu_stream,
1534 cfg,
1535 (&values_view, &group_ids, num_rows, &*output),
1536 )
1537 }
1538 .map_err(|e| {
1539 XlogError::Kernel(format!("groupby_min_u64 (on_stream) failed: {}", e))
1540 })?;
1541 } else {
1542 let fill_fn = device
1543 .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U32)
1544 .ok_or_else(|| {
1545 XlogError::Kernel("arith_fill_const_u32 not found".to_string())
1546 })?;
1547 unsafe {
1549 fill_fn.clone().launch_on_stream(
1550 &cu_stream,
1551 fill_config,
1552 (u32::MAX, row_cap_u32, &mut *output),
1553 )
1554 }
1555 .map_err(|e| {
1556 XlogError::Kernel(format!(
1557 "arith_fill_const_u32 (on_stream) failed: {}",
1558 e
1559 ))
1560 })?;
1561 let values_view = self.column_as_u32_view(values, row_cap_usize)?;
1562 let min_func = device
1563 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN)
1564 .ok_or_else(|| {
1565 XlogError::Kernel("groupby_min kernel not found".to_string())
1566 })?;
1567 unsafe {
1569 min_func.clone().launch_on_stream(
1570 &cu_stream,
1571 cfg,
1572 (&values_view, &group_ids, num_rows, &*output),
1573 )
1574 }
1575 .map_err(|e| {
1576 XlogError::Kernel(format!("groupby_min (on_stream) failed: {}", e))
1577 })?;
1578 }
1579 }
1580 AggOp::Max => {
1581 self.memset_zeros_u8_on_stream(output, &cu_stream)?;
1582 let value_ty = sorted
1583 .schema()
1584 .column_type(*value_col)
1585 .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1586 if value_ty == ScalarType::U64 {
1587 let values_view = self.column_as_u64_view(values, row_cap_usize)?;
1590 let max_func = device
1591 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX_U64)
1592 .ok_or_else(|| {
1593 XlogError::Kernel("groupby_max_u64 kernel not found".to_string())
1594 })?;
1595 unsafe {
1597 max_func.clone().launch_on_stream(
1598 &cu_stream,
1599 cfg,
1600 (&values_view, &group_ids, num_rows, &*output),
1601 )
1602 }
1603 .map_err(|e| {
1604 XlogError::Kernel(format!("groupby_max_u64 (on_stream) failed: {}", e))
1605 })?;
1606 } else {
1607 let values_view = self.column_as_u32_view(values, row_cap_usize)?;
1608 let max_func = device
1609 .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX)
1610 .ok_or_else(|| {
1611 XlogError::Kernel("groupby_max kernel not found".to_string())
1612 })?;
1613 unsafe {
1615 max_func.clone().launch_on_stream(
1616 &cu_stream,
1617 cfg,
1618 (&values_view, &group_ids, num_rows, &*output),
1619 )
1620 }
1621 .map_err(|e| {
1622 XlogError::Kernel(format!("groupby_max (on_stream) failed: {}", e))
1623 })?;
1624 }
1625 }
1626 AggOp::LogSumExp => unreachable!("rejected above"),
1627 }
1628 }
1629
1630 let gather_fn = device
1632 .get_func(PACK_MODULE, pack_kernels::GATHER_PACKED_ROWS_COUNTED)
1633 .ok_or_else(|| {
1634 XlogError::Kernel("gather_packed_rows_counted kernel not found".to_string())
1635 })?;
1636 let gather_config = LaunchConfig::for_num_elems(row_cap_u32);
1637 unsafe {
1639 gather_fn.clone().launch_on_stream(
1640 &cu_stream,
1641 gather_config,
1642 (
1643 &packed.packed_keys,
1644 packed.key_bytes,
1645 &group_first_idx,
1646 &d_num_groups,
1647 row_cap_u32,
1648 &mut group_packed,
1649 ),
1650 )
1651 }
1652 .map_err(|e| {
1653 XlogError::Kernel(format!(
1654 "gather_packed_rows_counted (on_stream) failed: {}",
1655 e
1656 ))
1657 })?;
1658
1659 let unpack_fn = device
1661 .get_func(PACK_MODULE, pack_kernels::UNPACK_COLUMN_COUNTED)
1662 .ok_or_else(|| {
1663 XlogError::Kernel("unpack_column_counted kernel not found".to_string())
1664 })?;
1665 let unpack_config = LaunchConfig::for_num_elems(row_cap_u32);
1666 for idx in 0..key_cols.len() {
1667 let col_size = col_sizes[idx];
1668 let col_offset = col_offsets[idx];
1669 unsafe {
1672 unpack_fn.clone().launch_on_stream(
1673 &cu_stream,
1674 unpack_config,
1675 (
1676 &group_packed,
1677 packed.key_bytes,
1678 col_offset,
1679 col_size,
1680 &d_num_groups,
1681 row_cap_u32,
1682 &mut key_unpacked[idx],
1683 ),
1684 )
1685 }
1686 .map_err(|e| {
1687 XlogError::Kernel(format!("unpack_column_counted (on_stream) failed: {}", e))
1688 })?;
1689 }
1690
1691 rec.commit(runtime).map_err(|e| {
1693 XlogError::Kernel(format!("groupby_multi_agg_recorded: commit failed: {}", e))
1694 })?;
1695
1696 let mut result_columns: Vec<CudaColumn> = Vec::with_capacity(key_cols.len() + aggs.len());
1698 for k in key_unpacked {
1699 result_columns.push(k.into());
1700 }
1701 for o in agg_outputs {
1702 result_columns.push(o.into());
1703 }
1704 let result_schema = self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
1705 Ok(CudaBuffer::from_columns(
1706 result_columns,
1707 row_cap_u64,
1708 d_num_groups,
1709 result_schema,
1710 ))
1711 }
1712
1713 pub fn groupby_agg_recorded(
1717 &self,
1718 input: &CudaBuffer,
1719 key_cols: &[usize],
1720 agg: AggOp,
1721 value_col: usize,
1722 launch_stream: crate::device_runtime::StreamId,
1723 ) -> Result<CudaBuffer> {
1724 self.groupby_multi_agg_recorded(input, key_cols, &[(value_col, agg)], launch_stream)
1725 }
1726}