Skip to main content

xlog_cuda/provider/
groupby.rs

1//! Groupby aggregation operations: groupby_agg, groupby_multi_agg.
2
3use crate::{LaunchAsync, LaunchConfig};
4use xlog_core::{AggOp, Result, ScalarType, Schema, XlogError};
5
6use super::{
7    arith_kernels, groupby_kernels, pack_kernels, scan_kernels, ARITH_MODULE, GROUPBY_MODULE,
8    PACK_MODULE, SCAN_MODULE,
9};
10use crate::memory::{CudaColumn, TrackedCudaSlice};
11use crate::CudaBuffer;
12
13impl super::CudaKernelProvider {
14    /// Perform groupby aggregation
15    ///
16    /// Assumes input is already sorted by key columns.
17    ///
18    /// # Arguments
19    /// * `input` - The input buffer
20    /// * `key_cols` - Column indices for grouping
21    /// * `agg` - Aggregation operation to perform
22    /// * `value_col` - Column index for the value to aggregate
23    ///
24    /// # Returns
25    /// A buffer with one row per group, containing key columns and aggregated value
26    ///
27    /// # Errors
28    /// Returns `XlogError::Kernel` if kernel execution fails
29    pub fn groupby_agg(
30        &self,
31        input: &CudaBuffer,
32        key_cols: &[usize],
33        agg: AggOp,
34        value_col: usize,
35    ) -> Result<CudaBuffer> {
36        self.groupby_multi_agg(input, key_cols, &[(value_col, agg)])
37    }
38
39    /// Multi-aggregation groupby
40    ///
41    /// Performs groupby with multiple aggregation operations at once.
42    /// This is more efficient than running separate groupby operations
43    /// because it only sorts and computes group boundaries once.
44    ///
45    /// # Arguments
46    /// * `buffer` - The input buffer
47    /// * `key_cols` - Column indices for grouping (currently only single-column supported)
48    /// * `aggs` - A slice of (value_col, AggOp) pairs specifying which aggregations to perform
49    ///
50    /// # Returns
51    /// A buffer with one row per group, containing key columns followed by aggregated values
52    /// in the same order as the `aggs` parameter
53    ///
54    /// # Errors
55    /// Returns `XlogError::Kernel` if kernel execution fails
56    ///
57    /// # Example
58    /// ```ignore
59    /// let result = provider.groupby_multi_agg(
60    ///     &buffer,
61    ///     &[0],  // group by column 0
62    ///     &[(1, AggOp::Sum), (1, AggOp::Count), (1, AggOp::Min)],
63    /// )?;
64    /// // result has columns: key, sum, count, min
65    /// ```
66    pub fn groupby_multi_agg(
67        &self,
68        buffer: &CudaBuffer,
69        key_cols: &[usize],
70        aggs: &[(usize, AggOp)],
71    ) -> Result<CudaBuffer> {
72        // Env-gated recorded dispatch. `groupby_multi_agg_recorded`
73        // is narrow to U32 / Symbol keys + Count /
74        // Sum / Min / Max aggs + ≤4 key columns. Mismatch
75        // (any other key type, LogSumExp, or >4 keys) falls
76        // through to the legacy path.
77        if Self::use_recorded_groupby_env()
78            && !key_cols.is_empty()
79            && !aggs.is_empty()
80            && key_cols.len() <= 4
81        {
82            if let Some(launch_stream) = self.recorded_op_stream_or_init() {
83                let keys_compatible = key_cols.iter().all(|&k| {
84                    matches!(
85                        buffer.schema.column_type(k),
86                        Some(ScalarType::U32) | Some(ScalarType::Symbol)
87                    )
88                });
89                let aggs_compatible = aggs.iter().all(|&(_, op)| {
90                    matches!(op, AggOp::Count | AggOp::Sum | AggOp::Min | AggOp::Max)
91                });
92                if keys_compatible && aggs_compatible {
93                    return self.groupby_multi_agg_recorded(buffer, key_cols, aggs, launch_stream);
94                }
95            }
96        }
97        let num_rows = self.device_row_count(buffer)?;
98        if num_rows == 0 {
99            let result_schema =
100                self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
101            return self.create_empty_buffer(result_schema);
102        }
103        if num_rows > u32::MAX as usize {
104            return Err(XlogError::Kernel(format!(
105                "GroupBy supports at most {} rows, got {}",
106                u32::MAX,
107                num_rows
108            )));
109        }
110
111        // Validate inputs
112        if key_cols.is_empty() {
113            return Err(XlogError::Kernel(
114                "GroupBy requires at least one key column".to_string(),
115            ));
116        }
117        if aggs.is_empty() {
118            return Err(XlogError::Kernel(
119                "GroupBy requires at least one aggregation".to_string(),
120            ));
121        }
122
123        // Validate key columns exist
124        for &key_col in key_cols {
125            if key_col >= buffer.arity() {
126                return Err(XlogError::Kernel(format!(
127                    "Key column {} out of bounds (arity {})",
128                    key_col,
129                    buffer.arity()
130                )));
131            }
132        }
133
134        // Validate all value columns exist and basic dtype constraints for current kernels.
135        for &(value_col, agg_op) in aggs {
136            if value_col >= buffer.arity() {
137                return Err(XlogError::Kernel(format!(
138                    "Value column {} out of bounds (arity {})",
139                    value_col,
140                    buffer.arity()
141                )));
142            }
143
144            let value_ty = buffer
145                .schema()
146                .column_type(value_col)
147                .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
148            match agg_op {
149                AggOp::Count => {}
150                // U64 value-column widening: values reduce through the u64-value
151                // kernels (the legacy groupby is the unfused baseline for
152                // u64-key WCOJ relations, whose value columns are U64).
153                AggOp::Sum | AggOp::Min | AggOp::Max => {
154                    if !matches!(value_ty, ScalarType::U32 | ScalarType::U64) {
155                        return Err(XlogError::Kernel(format!(
156                            "{:?} currently requires U32 or U64 values, got {:?}",
157                            agg_op, value_ty
158                        )));
159                    }
160                }
161                AggOp::LogSumExp => {
162                    if value_ty != ScalarType::F64 {
163                        return Err(XlogError::Kernel(format!(
164                            "LogSumExp requires F64 values, got {:?}",
165                            value_ty
166                        )));
167                    }
168                }
169            }
170        }
171
172        // Step 1: Sort buffer by key columns
173        let sorted = self.sort(buffer, key_cols)?;
174        let num_rows = self.device_row_count(&sorted)?;
175        if num_rows > u32::MAX as usize {
176            return Err(XlogError::Kernel(format!(
177                "GroupBy supports at most {} rows, got {}",
178                u32::MAX,
179                num_rows
180            )));
181        }
182        let num_rows = num_rows as u32;
183
184        // Step 2: Detect boundaries using detect_group_boundaries kernel over packed key bytes
185        let boundary_func = self
186            .device
187            .inner()
188            .get_func(GROUPBY_MODULE, groupby_kernels::DETECT_GROUP_BOUNDARIES)
189            .ok_or_else(|| {
190                XlogError::Kernel("detect_group_boundaries kernel not found".to_string())
191            })?;
192
193        // Allocate boundaries mask
194        let boundaries = self.memory.alloc::<u8>(num_rows as usize)?;
195
196        let packed = self.compute_hashes_and_pack_keys(&sorted, key_cols)?;
197        if packed.key_bytes == 0 || packed.key_bytes % 4 != 0 {
198            return Err(XlogError::Kernel(format!(
199                "GroupBy key packing produced {} bytes per row (expected multiple of 4); Bool keys are not supported",
200                packed.key_bytes
201            )));
202        }
203
204        let segments_per_row = (packed.key_bytes / 4) as usize;
205        let total_segments = (num_rows as usize) * segments_per_row;
206        let packed_u32 = self.bytes_as_u32_view(&packed.packed_keys, total_segments)?;
207
208        // Launch boundary detection
209        let block_size = 256u32;
210        let grid_size = num_rows.div_ceil(block_size);
211        let config = LaunchConfig {
212            grid_dim: (grid_size, 1, 1),
213            block_dim: (block_size, 1, 1),
214            shared_mem_bytes: 0,
215        };
216
217        // SAFETY: Kernel parameters match expected signature
218        unsafe {
219            boundary_func.clone().launch(
220                config,
221                (
222                    &packed_u32,
223                    num_rows,
224                    segments_per_row as u32,
225                    segments_per_row as u32,
226                    &boundaries,
227                ),
228            )
229        }
230        .map_err(|e| XlogError::Kernel(format!("detect_group_boundaries failed: {}", e)))?;
231
232        self.device.synchronize()?;
233
234        // Step 3: Compute group IDs on-device using prefix sum over boundaries.
235        let device = self.device.inner();
236        let num_blocks = grid_size;
237        let d_boundary_pos = self.memory.alloc::<u32>(num_rows as usize)?;
238        let mut d_block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
239
240        let phase1_fn = device
241            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE1)
242            .ok_or_else(|| {
243                XlogError::Kernel("Failed to get multiblock_scan_phase1 kernel".to_string())
244            })?;
245
246        // SAFETY: multiblock_scan_phase1(const uint8_t* mask, uint32_t* prefix_sum, uint32_t* block_sums, uint32_t n)
247        unsafe {
248            phase1_fn.clone().launch(
249                LaunchConfig {
250                    grid_dim: (num_blocks, 1, 1),
251                    block_dim: (block_size, 1, 1),
252                    shared_mem_bytes: 0,
253                },
254                (&boundaries, &d_boundary_pos, &d_block_sums, num_rows),
255            )
256        }
257        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase1 failed: {}", e)))?;
258
259        if num_blocks > 1 {
260            self.multiblock_scan_u32_inplace(&mut d_block_sums, num_blocks)?;
261
262            let phase3_fn = device
263                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
264                .ok_or_else(|| {
265                    XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
266                })?;
267
268            // SAFETY: multiblock_scan_phase3(uint32_t* prefix_sum, const uint32_t* block_offsets, uint32_t n)
269            unsafe {
270                phase3_fn.clone().launch(
271                    LaunchConfig {
272                        grid_dim: (num_blocks, 1, 1),
273                        block_dim: (block_size, 1, 1),
274                        shared_mem_bytes: 0,
275                    },
276                    (&d_boundary_pos, &d_block_sums, num_rows),
277                )
278            }
279            .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;
280        }
281
282        self.device.synchronize()?;
283        let d_num_groups = self.capture_num_groups(&d_boundary_pos, &boundaries, num_rows)?;
284        let row_cap = num_rows as u64;
285        let row_cap_usize = num_rows as usize;
286        let row_cap_u32 = num_rows;
287
288        let mut group_ids = self.memory.alloc::<u32>(num_rows as usize)?;
289        let mut group_first_idx = self.memory.alloc::<u32>(row_cap_usize)?;
290
291        let group_ids_fn = device
292            .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_IDS_FROM_BOUNDARIES)
293            .ok_or_else(|| {
294                XlogError::Kernel("group_ids_from_boundaries kernel not found".to_string())
295            })?;
296        let group_start_fn = device
297            .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_START_INDICES)
298            .ok_or_else(|| XlogError::Kernel("group_start_indices kernel not found".to_string()))?;
299
300        // SAFETY: group_ids_from_boundaries(boundaries, boundary_pos, num_rows, group_ids)
301        unsafe {
302            group_ids_fn.clone().launch(
303                config,
304                (&boundaries, &d_boundary_pos, num_rows, &mut group_ids),
305            )
306        }
307        .map_err(|e| XlogError::Kernel(format!("group_ids_from_boundaries failed: {}", e)))?;
308
309        // SAFETY: group_start_indices(boundaries, boundary_pos, num_rows, group_first_idx)
310        unsafe {
311            group_start_fn.clone().launch(
312                config,
313                (&boundaries, &d_boundary_pos, num_rows, &mut group_first_idx),
314            )
315        }
316        .map_err(|e| XlogError::Kernel(format!("group_start_indices failed: {}", e)))?;
317
318        self.device.synchronize()?;
319
320        // Step 4: For each (value_col, op) pair, run the appropriate kernel on-device.
321        let mut agg_columns: Vec<CudaColumn> = Vec::with_capacity(aggs.len());
322
323        for &(value_col, agg_op) in aggs {
324            let values = sorted
325                .column(value_col)
326                .ok_or_else(|| XlogError::Kernel("Value column not found".to_string()))?;
327
328            match agg_op {
329                AggOp::Count => {
330                    let output_bytes = row_cap_usize
331                        .checked_mul(std::mem::size_of::<u64>())
332                        .ok_or_else(|| {
333                            XlogError::Kernel("Count output size overflow".to_string())
334                        })?;
335                    let mut output = self.memory.alloc::<u8>(output_bytes)?;
336                    device.memset_zeros(&mut output).map_err(|e| {
337                        XlogError::Kernel(format!("Failed to zero count output: {}", e))
338                    })?;
339
340                    let count_func = device
341                        .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_COUNT)
342                        .ok_or_else(|| {
343                            XlogError::Kernel("groupby_count kernel not found".to_string())
344                        })?;
345
346                    // SAFETY: groupby_count(boundaries, group_ids, num_rows, counts)
347                    unsafe {
348                        count_func
349                            .clone()
350                            .launch(config, (&boundaries, &group_ids, num_rows, &output))
351                    }
352                    .map_err(|e| XlogError::Kernel(format!("groupby_count failed: {}", e)))?;
353
354                    self.device.synchronize()?;
355                    agg_columns.push(output.into());
356                }
357                AggOp::Sum => {
358                    let value_ty = sorted
359                        .schema()
360                        .column_type(value_col)
361                        .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
362                    let output_bytes = row_cap_usize
363                        .checked_mul(std::mem::size_of::<u64>())
364                        .ok_or_else(|| XlogError::Kernel("Sum output size overflow".to_string()))?;
365                    let mut output = self.memory.alloc::<u8>(output_bytes)?;
366                    device.memset_zeros(&mut output).map_err(|e| {
367                        XlogError::Kernel(format!("Failed to zero sum output: {}", e))
368                    })?;
369
370                    // U64 value-column widening: value columns reduce through the
371                    // u64-value sum kernel (same u64 accumulator).
372                    if value_ty == ScalarType::U64 {
373                        let values_view = self.column_as_u64_view(values, num_rows as usize)?;
374                        let sum_func = device
375                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM_U64)
376                            .ok_or_else(|| {
377                                XlogError::Kernel("groupby_sum_u64 kernel not found".to_string())
378                            })?;
379                        // SAFETY: groupby_sum_u64(values, group_ids, num_rows, sums)
380                        unsafe {
381                            sum_func
382                                .clone()
383                                .launch(config, (&values_view, &group_ids, num_rows, &output))
384                        }
385                        .map_err(|e| XlogError::Kernel(format!("groupby_sum_u64 failed: {}", e)))?;
386                    } else {
387                        let values_view = self.column_as_u32_view(values, num_rows as usize)?;
388                        let sum_func = device
389                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM)
390                            .ok_or_else(|| {
391                                XlogError::Kernel("groupby_sum kernel not found".to_string())
392                            })?;
393                        // SAFETY: groupby_sum(values, group_ids, num_rows, sums)
394                        unsafe {
395                            sum_func
396                                .clone()
397                                .launch(config, (&values_view, &group_ids, num_rows, &output))
398                        }
399                        .map_err(|e| XlogError::Kernel(format!("groupby_sum failed: {}", e)))?;
400                    }
401
402                    self.device.synchronize()?;
403                    agg_columns.push(output.into());
404                }
405                AggOp::Min => {
406                    let value_ty = sorted
407                        .schema()
408                        .column_type(value_col)
409                        .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
410                    if value_ty == ScalarType::U64 {
411                        // U64 value-column min path (output U64,
412                        // identity u64::MAX).
413                        let values_view = self.column_as_u64_view(values, num_rows as usize)?;
414                        let output_bytes = row_cap_usize
415                            .checked_mul(std::mem::size_of::<u64>())
416                            .ok_or_else(|| {
417                                XlogError::Kernel("Min output size overflow".to_string())
418                            })?;
419                        let mut output = self.memory.alloc::<u8>(output_bytes)?;
420                        let fill_fn = device
421                            .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U64)
422                            .ok_or_else(|| {
423                                XlogError::Kernel("arith_fill_const_u64 not found".to_string())
424                            })?;
425                        let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
426                        // SAFETY: arith_fill_const_u64(value, n, output)
427                        unsafe {
428                            fill_fn
429                                .clone()
430                                .launch(fill_config, (u64::MAX, row_cap_u32, &mut output))
431                        }
432                        .map_err(|e| {
433                            XlogError::Kernel(format!("Failed to init min output: {}", e))
434                        })?;
435
436                        let min_func = device
437                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN_U64)
438                            .ok_or_else(|| {
439                                XlogError::Kernel("groupby_min_u64 kernel not found".to_string())
440                            })?;
441                        // SAFETY: groupby_min_u64(values, group_ids, num_rows, mins)
442                        unsafe {
443                            min_func
444                                .clone()
445                                .launch(config, (&values_view, &group_ids, num_rows, &output))
446                        }
447                        .map_err(|e| XlogError::Kernel(format!("groupby_min_u64 failed: {}", e)))?;
448
449                        self.device.synchronize()?;
450                        agg_columns.push(output.into());
451                    } else {
452                        let values_view = self.column_as_u32_view(values, num_rows as usize)?;
453                        let output_bytes = row_cap_usize
454                            .checked_mul(std::mem::size_of::<u32>())
455                            .ok_or_else(|| {
456                                XlogError::Kernel("Min output size overflow".to_string())
457                            })?;
458                        let mut output = self.memory.alloc::<u8>(output_bytes)?;
459                        let fill_fn = device
460                            .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U32)
461                            .ok_or_else(|| {
462                                XlogError::Kernel("arith_fill_const_u32 not found".to_string())
463                            })?;
464                        let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
465                        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
466                        unsafe {
467                            fill_fn
468                                .clone()
469                                .launch(fill_config, (u32::MAX, row_cap_u32, &mut output))
470                        }
471                        .map_err(|e| {
472                            XlogError::Kernel(format!("Failed to init min output: {}", e))
473                        })?;
474
475                        let min_func = device
476                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN)
477                            .ok_or_else(|| {
478                                XlogError::Kernel("groupby_min kernel not found".to_string())
479                            })?;
480
481                        // SAFETY: groupby_min(values, group_ids, num_rows, mins)
482                        unsafe {
483                            min_func
484                                .clone()
485                                .launch(config, (&values_view, &group_ids, num_rows, &output))
486                        }
487                        .map_err(|e| XlogError::Kernel(format!("groupby_min failed: {}", e)))?;
488
489                        self.device.synchronize()?;
490                        agg_columns.push(output.into());
491                    }
492                }
493                AggOp::Max => {
494                    let value_ty = sorted
495                        .schema()
496                        .column_type(value_col)
497                        .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
498                    if value_ty == ScalarType::U64 {
499                        // U64 value-column max path (output U64,
500                        // identity 0).
501                        let values_view = self.column_as_u64_view(values, num_rows as usize)?;
502                        let output_bytes = row_cap_usize
503                            .checked_mul(std::mem::size_of::<u64>())
504                            .ok_or_else(|| {
505                                XlogError::Kernel("Max output size overflow".to_string())
506                            })?;
507                        let mut output = self.memory.alloc::<u8>(output_bytes)?;
508                        device.memset_zeros(&mut output).map_err(|e| {
509                            XlogError::Kernel(format!("Failed to zero max output: {}", e))
510                        })?;
511
512                        let max_func = device
513                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX_U64)
514                            .ok_or_else(|| {
515                                XlogError::Kernel("groupby_max_u64 kernel not found".to_string())
516                            })?;
517                        // SAFETY: groupby_max_u64(values, group_ids, num_rows, maxs)
518                        unsafe {
519                            max_func
520                                .clone()
521                                .launch(config, (&values_view, &group_ids, num_rows, &output))
522                        }
523                        .map_err(|e| XlogError::Kernel(format!("groupby_max_u64 failed: {}", e)))?;
524
525                        self.device.synchronize()?;
526                        agg_columns.push(output.into());
527                    } else {
528                        let values_view = self.column_as_u32_view(values, num_rows as usize)?;
529                        let output_bytes = row_cap_usize
530                            .checked_mul(std::mem::size_of::<u32>())
531                            .ok_or_else(|| {
532                                XlogError::Kernel("Max output size overflow".to_string())
533                            })?;
534                        let mut output = self.memory.alloc::<u8>(output_bytes)?;
535                        device.memset_zeros(&mut output).map_err(|e| {
536                            XlogError::Kernel(format!("Failed to zero max output: {}", e))
537                        })?;
538
539                        let max_func = device
540                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX)
541                            .ok_or_else(|| {
542                                XlogError::Kernel("groupby_max kernel not found".to_string())
543                            })?;
544
545                        // SAFETY: groupby_max(values, group_ids, num_rows, maxs)
546                        unsafe {
547                            max_func
548                                .clone()
549                                .launch(config, (&values_view, &group_ids, num_rows, &output))
550                        }
551                        .map_err(|e| XlogError::Kernel(format!("groupby_max failed: {}", e)))?;
552
553                        self.device.synchronize()?;
554                        agg_columns.push(output.into());
555                    }
556                }
557                AggOp::LogSumExp => {
558                    let values_f64 = self.column_as_f64_view(values, num_rows as usize)?;
559                    let output_bytes = row_cap_usize
560                        .checked_mul(std::mem::size_of::<f64>())
561                        .ok_or_else(|| {
562                            XlogError::Kernel("LogSumExp output size overflow".to_string())
563                        })?;
564                    let mut maxs = self.memory.alloc::<u8>(output_bytes)?;
565                    let mut sumexps = self.memory.alloc::<u8>(output_bytes)?;
566                    let results = self.memory.alloc::<u8>(output_bytes)?;
567
568                    let fill_f64 = device
569                        .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_F64)
570                        .ok_or_else(|| {
571                            XlogError::Kernel("arith_fill_const_f64 not found".to_string())
572                        })?;
573                    let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
574                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
575                    unsafe {
576                        fill_f64
577                            .clone()
578                            .launch(fill_config, (f64::NEG_INFINITY, row_cap_u32, &mut maxs))
579                    }
580                    .map_err(|e| XlogError::Kernel(format!("Failed to init maxs: {}", e)))?;
581                    device
582                        .memset_zeros(&mut sumexps)
583                        .map_err(|e| XlogError::Kernel(format!("Failed to init sumexps: {}", e)))?;
584
585                    let max_func = device
586                        .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_MAX)
587                        .ok_or_else(|| {
588                            XlogError::Kernel("groupby_logsumexp_max kernel not found".to_string())
589                        })?;
590
591                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
592                    unsafe {
593                        max_func
594                            .clone()
595                            .launch(config, (&values_f64, &group_ids, num_rows, &maxs))
596                    }
597                    .map_err(|e| {
598                        XlogError::Kernel(format!("groupby_logsumexp_max failed: {}", e))
599                    })?;
600
601                    self.device.synchronize()?;
602
603                    let sumexp_func = device
604                        .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_SUMEXP)
605                        .ok_or_else(|| {
606                            XlogError::Kernel(
607                                "groupby_logsumexp_sumexp kernel not found".to_string(),
608                            )
609                        })?;
610
611                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
612                    unsafe {
613                        sumexp_func
614                            .clone()
615                            .launch(config, (&values_f64, &group_ids, &maxs, num_rows, &sumexps))
616                    }
617                    .map_err(|e| {
618                        XlogError::Kernel(format!("groupby_logsumexp_sumexp failed: {}", e))
619                    })?;
620
621                    self.device.synchronize()?;
622
623                    let final_config = LaunchConfig::for_num_elems(row_cap_u32);
624                    let final_func = device
625                        .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_LOGSUMEXP_FINAL)
626                        .ok_or_else(|| {
627                            XlogError::Kernel(
628                                "groupby_logsumexp_final kernel not found".to_string(),
629                            )
630                        })?;
631
632                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
633                    unsafe {
634                        final_func.clone().launch(
635                            final_config,
636                            (&maxs, &sumexps, &d_num_groups, row_cap_u32, &results),
637                        )
638                    }
639                    .map_err(|e| {
640                        XlogError::Kernel(format!("groupby_logsumexp_final failed: {}", e))
641                    })?;
642
643                    self.device.synchronize()?;
644                    agg_columns.push(results.into());
645                }
646            }
647        }
648
649        // Step 5: Build output buffer with keys and aggregated values.
650        let mut result_columns: Vec<CudaColumn> = Vec::with_capacity(key_cols.len() + aggs.len());
651
652        let group_packed_bytes = row_cap_usize
653            .checked_mul(packed.key_bytes as usize)
654            .ok_or_else(|| XlogError::Kernel("GroupBy packed size overflow".to_string()))?;
655        let mut group_packed = self.memory.alloc::<u8>(group_packed_bytes)?;
656
657        let gather_fn = device
658            .get_func(PACK_MODULE, pack_kernels::GATHER_PACKED_ROWS_COUNTED)
659            .ok_or_else(|| {
660                XlogError::Kernel("gather_packed_rows_counted kernel not found".to_string())
661            })?;
662        let gather_config = LaunchConfig::for_num_elems(row_cap_u32);
663
664        // SAFETY: gather_packed_rows_counted(src_packed, row_size, indices, num_rows, capacity_rows, dst_packed)
665        unsafe {
666            gather_fn.clone().launch(
667                gather_config,
668                (
669                    &packed.packed_keys,
670                    packed.key_bytes,
671                    &group_first_idx,
672                    &d_num_groups,
673                    row_cap_u32,
674                    &mut group_packed,
675                ),
676            )
677        }
678        .map_err(|e| XlogError::Kernel(format!("gather_packed_rows failed: {}", e)))?;
679
680        let mut col_offsets: Vec<u32> = Vec::with_capacity(key_cols.len());
681        let mut col_sizes: Vec<u32> = Vec::with_capacity(key_cols.len());
682        let mut offset = 0u32;
683        for &key_col in key_cols {
684            let size = buffer
685                .schema()
686                .column_type(key_col)
687                .map(|t| t.size_bytes() as u32)
688                .unwrap_or(4);
689            col_offsets.push(offset);
690            col_sizes.push(size);
691            offset = offset
692                .checked_add(size)
693                .ok_or_else(|| XlogError::Kernel("GroupBy key size overflow".to_string()))?;
694        }
695
696        let unpack_fn = device
697            .get_func(PACK_MODULE, pack_kernels::UNPACK_COLUMN_COUNTED)
698            .ok_or_else(|| {
699                XlogError::Kernel("unpack_column_counted kernel not found".to_string())
700            })?;
701        let unpack_config = LaunchConfig::for_num_elems(row_cap_u32);
702
703        for idx in 0..key_cols.len() {
704            let col_size = col_sizes[idx];
705            let col_offset = col_offsets[idx];
706            let out_bytes = row_cap_usize
707                .checked_mul(col_size as usize)
708                .ok_or_else(|| XlogError::Kernel("GroupBy key column overflow".to_string()))?;
709            let mut out_col = self.memory.alloc::<u8>(out_bytes)?;
710
711            // SAFETY: unpack_column_counted(packed_input, row_size, col_offset, col_size, num_rows, capacity_rows, col_output)
712            unsafe {
713                unpack_fn.clone().launch(
714                    unpack_config,
715                    (
716                        &group_packed,
717                        packed.key_bytes,
718                        col_offset,
719                        col_size,
720                        &d_num_groups,
721                        row_cap_u32,
722                        &mut out_col,
723                    ),
724                )
725            }
726            .map_err(|e| XlogError::Kernel(format!("unpack_column failed: {}", e)))?;
727
728            result_columns.push(out_col.into());
729        }
730
731        result_columns.extend(agg_columns);
732
733        let result_schema = self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
734
735        Ok(CudaBuffer::from_columns(
736            result_columns,
737            row_cap,
738            d_num_groups,
739            result_schema,
740        ))
741    }
742
743    fn capture_num_groups(
744        &self,
745        boundary_pos: &TrackedCudaSlice<u32>,
746        boundaries: &TrackedCudaSlice<u8>,
747        num_rows: u32,
748    ) -> Result<TrackedCudaSlice<u32>> {
749        let mut d_num_groups = self.memory.alloc::<u32>(1)?;
750        let capture_fn = self
751            .device
752            .inner()
753            .get_func(GROUPBY_MODULE, groupby_kernels::CAPTURE_NUM_GROUPS)
754            .ok_or_else(|| XlogError::Kernel("capture_num_groups kernel not found".to_string()))?;
755        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
756        unsafe {
757            capture_fn.clone().launch(
758                LaunchConfig {
759                    grid_dim: (1, 1, 1),
760                    block_dim: (1, 1, 1),
761                    shared_mem_bytes: 0,
762                },
763                (boundary_pos, boundaries, num_rows, &mut d_num_groups),
764            )
765        }
766        .map_err(|e| XlogError::Kernel(format!("capture_num_groups failed: {}", e)))?;
767        Ok(d_num_groups)
768    }
769
770    /// Create result schema for multi-aggregation groupby
771    pub(crate) fn groupby_multi_agg_result_schema(
772        &self,
773        input: &Schema,
774        key_cols: &[usize],
775        aggs: &[(usize, AggOp)],
776    ) -> Schema {
777        let mut columns: Vec<(String, ScalarType)> = key_cols
778            .iter()
779            .filter_map(|&i| input.columns.get(i).cloned())
780            .collect();
781        let mut sort_labels: Vec<String> = key_cols
782            .iter()
783            .filter_map(|&i| {
784                input
785                    .column_sort_label(i)
786                    .map(ToString::to_string)
787                    .or_else(|| input.columns.get(i).map(|(name, _)| name.clone()))
788            })
789            .collect();
790
791        for (i, &(value_col, agg_op)) in aggs.iter().enumerate() {
792            let agg_name = match agg_op {
793                AggOp::Count => format!("count_{}", i),
794                AggOp::Sum => format!("sum_{}", i),
795                AggOp::Min => format!("min_{}", i),
796                AggOp::Max => format!("max_{}", i),
797                AggOp::LogSumExp => format!("logsumexp_{}", i),
798            };
799            // Return correct types for each aggregation
800            // Count and Sum use u64 to match predicate declarations and prevent overflow
801            let agg_type = match agg_op {
802                AggOp::Count => ScalarType::U64,
803                AggOp::Sum => ScalarType::U64,
804                // Value-width preserving min/max: preserve the value column's width
805                // (U64 values reduce to U64; everything else stays U32).
806                AggOp::Min | AggOp::Max => match input.columns.get(value_col).map(|(_, ty)| *ty) {
807                    Some(ScalarType::U64) => ScalarType::U64,
808                    _ => ScalarType::U32,
809                },
810                AggOp::LogSumExp => ScalarType::F64,
811            };
812            columns.push((agg_name, agg_type));
813            sort_labels.push(format!("aggregate_{}", i));
814        }
815
816        Schema::new(columns)
817            .with_sort_labels(sort_labels)
818            .expect("groupby result sort labels match column arity")
819    }
820
821    // ======================================================================
822    // Recorded GroupBy provider-level path
823    //
824    // Strict-recorder, launch_stream-routed sibling of `groupby_multi_agg`.
825    // Scope-narrow per the slice directive:
826    //   * U32 / Symbol key columns only (delegates to sort_recorded which has
827    //     the same constraint).
828    //   * Aggs: Count, Sum, Min, Max only. LogSumExp is a multi-kernel
829    //     chain (max → sumexp → final) and is outside this API surface.
830    //   * No legacy default-routed code is touched. The legacy
831    //     `groupby_multi_agg` and `groupby_agg` keep their semantics
832    //     bit-for-bit; runtime/planner wiring is NOT included.
833    // ======================================================================
834
835    /// Stream-aware variant of `pack_keys_gpu` (≤4 columns).
836    /// Mirrors the legacy fused `pack_and_hash_keys` launch on
837    /// `launch_stream`, then records the kernel's intermediate
838    /// returned outputs (`packed_keys`, `hashes`) against the runtime so that
839    /// downstream consumers / drops are correctly serialized
840    /// against `launch_stream`.
841    pub(super) fn pack_keys_gpu_on_stream(
842        &self,
843        buffer: &CudaBuffer,
844        key_cols: &[usize],
845        cu_stream: &cudarc::driver::CudaStream,
846        launch_stream: crate::device_runtime::StreamId,
847        runtime: &crate::device_runtime::XlogDeviceRuntime,
848    ) -> Result<crate::provider::PackedKeyData> {
849        use crate::launch::LaunchRecorder;
850
851        if key_cols.is_empty() {
852            return Err(XlogError::Kernel(
853                "pack_keys_gpu_on_stream: no key columns specified".to_string(),
854            ));
855        }
856        if key_cols.len() > 4 {
857            return Err(XlogError::Kernel(
858                "pack_keys_gpu_on_stream: max 4 key columns supported".to_string(),
859            ));
860        }
861        let num_rows = self.device_row_count(buffer)?;
862        if num_rows > u32::MAX as usize {
863            return Err(XlogError::Kernel(format!(
864                "pack_keys_gpu_on_stream supports at most {} rows, got {}",
865                u32::MAX,
866                num_rows
867            )));
868        }
869        let num_rows = num_rows as u32;
870
871        let mut col_sizes_host: Vec<u32> = Vec::with_capacity(key_cols.len());
872        let mut row_size: u32 = 0;
873        for &col_idx in key_cols {
874            let ty = buffer
875                .schema()
876                .column_type(col_idx)
877                .ok_or_else(|| XlogError::Kernel(format!("Invalid column index: {}", col_idx)))?;
878            let s = ty.size_bytes() as u32;
879            col_sizes_host.push(s);
880            row_size += s;
881        }
882
883        if num_rows == 0 {
884            return Ok(crate::provider::PackedKeyData {
885                hashes: self.memory.alloc::<u64>(0)?,
886                packed_keys: self.memory.alloc::<u8>(0)?,
887                key_bytes: row_size,
888            });
889        }
890
891        let packed_bytes = (num_rows as u64) * (row_size as u64);
892        let packed_slice = self.memory.alloc::<u8>(packed_bytes as usize)?;
893        let hash_slice = self.memory.alloc::<u64>(num_rows as usize)?;
894
895        let mut col_ptrs: [u64; 4] = [0; 4];
896        for (i, &col_idx) in key_cols.iter().enumerate() {
897            let col = buffer
898                .column(col_idx)
899                .ok_or_else(|| XlogError::Kernel(format!("Key column {} not found", col_idx)))?;
900            col_ptrs[i] = *col.device_ptr();
901        }
902        let mut packed_col_sizes = 0u64;
903        for (i, size) in col_sizes_host.iter().copied().enumerate() {
904            if size > u16::MAX as u32 {
905                return Err(XlogError::Kernel(format!(
906                    "pack_keys_gpu_on_stream: column element size {} exceeds 16-bit kernel argument",
907                    size
908                )));
909            }
910            packed_col_sizes |= (size as u64) << (i * 16);
911        }
912
913        // The pack kernel takes raw column pointers (`u64`)
914        // rather than typed `CudaColumn` kernel params, so the
915        // generic launch recorder cannot infer source-column
916        // lifetimes from the argument list. Record those reads
917        // explicitly before queueing the launch; this also
918        // enforces the strict external-memory policy for
919        // recorded paths.
920        let mut rec = LaunchRecorder::new_strict(launch_stream);
921        for &col_idx in key_cols {
922            let col = buffer
923                .column(col_idx)
924                .ok_or_else(|| XlogError::Kernel(format!("Key column {} not found", col_idx)))?;
925            rec.read_column(col);
926        }
927        rec.write(&packed_slice);
928        rec.write(&hash_slice);
929        rec.preflight(runtime).map_err(|e| {
930            XlogError::Kernel(format!(
931                "pack_keys_gpu_on_stream: launch recorder preflight failed: {}",
932                e
933            ))
934        })?;
935
936        let func = self
937            .device
938            .inner()
939            .get_func(PACK_MODULE, pack_kernels::PACK_AND_HASH_KEYS)
940            .ok_or_else(|| XlogError::Kernel("pack_and_hash_keys kernel not found".to_string()))?;
941        let block_size = 256u32;
942        let grid_size = num_rows.div_ceil(block_size);
943        let cfg = LaunchConfig {
944            grid_dim: (grid_size, 1, 1),
945            block_dim: (block_size, 1, 1),
946            shared_mem_bytes: 0,
947        };
948        // SAFETY: pack_and_hash_keys signature.
949        unsafe {
950            func.clone().launch_on_stream(
951                cu_stream,
952                cfg,
953                (
954                    col_ptrs[0],
955                    col_ptrs[1],
956                    col_ptrs[2],
957                    col_ptrs[3],
958                    packed_col_sizes,
959                    key_cols.len() as u32,
960                    num_rows,
961                    row_size,
962                    &packed_slice,
963                    &hash_slice,
964                ),
965            )
966        }
967        .map_err(|e| XlogError::Kernel(format!("pack_and_hash_keys (on_stream) failed: {}", e)))?;
968
969        // Record uses for buffers touched on launch_stream.
970        // `packed_slice` / `hash_slice` are fresh outputs that
971        // escape to the caller. The post-preflight-fresh path is
972        // valid because they were allocated by this helper before
973        // preflight and first used by the queued pack launch.
974        rec.commit(runtime).map_err(|e| {
975            XlogError::Kernel(format!(
976                "pack_keys_gpu_on_stream: launch recorder commit failed: {}",
977                e
978            ))
979        })?;
980
981        Ok(crate::provider::PackedKeyData {
982            hashes: hash_slice,
983            packed_keys: packed_slice,
984            key_bytes: row_size,
985        })
986    }
987
988    /// Async u8 zero-fill on `cu_stream` via `cuMemsetD8Async`.
989    /// Used by recorded GroupBy aggregations that need a
990    /// freshly zeroed output buffer (Count, Sum, Max).
991    fn memset_zeros_u8_on_stream(
992        &self,
993        buf: &mut TrackedCudaSlice<u8>,
994        cu_stream: &cudarc::driver::CudaStream,
995    ) -> Result<()> {
996        if buf.is_empty() {
997            return Ok(());
998        }
999        let ptr = *buf.device_ptr();
1000        let len = <TrackedCudaSlice<u8> as crate::DeviceSlice<u8>>::len(buf);
1001        // SAFETY: ptr is a live runtime-backed device pointer
1002        // for `len` bytes, cu_stream is a valid CUDA stream
1003        // owned by the runtime's pool. cuMemsetD8Async queues
1004        // and returns immediately.
1005        unsafe {
1006            let res = cudarc::driver::sys::cuMemsetD8Async(ptr, 0, len, cu_stream.cu_stream());
1007            if res != cudarc::driver::sys::cudaError_enum::CUDA_SUCCESS {
1008                return Err(XlogError::Kernel(format!(
1009                    "cuMemsetD8Async (groupby init) failed: {:?}",
1010                    res
1011                )));
1012            }
1013        }
1014        Ok(())
1015    }
1016
1017    /// Strict-recorder variant of [`Self::groupby_multi_agg`].
1018    ///
1019    /// Sort + pack + boundary detect + scan + capture-num-groups
1020    /// + group-id derivation + per-aggregation kernels + key
1021    ///   gather/unpack — every kernel runs on the caller-supplied
1022    ///   `launch_stream` via `launch_on_stream`. Composition with
1023    ///   existing recorded primitives:
1024    ///   * `sort_recorded` does the typed multi-column
1025    ///     sort and commits its own LaunchRecorder.
1026    ///   * `pack_keys_gpu_on_stream` runs the fused
1027    ///     pack+hash kernel on launch_stream and records its
1028    ///     buffers directly via `record_block_use`.
1029    ///   * `multiblock_scan_u32_inplace_on_stream`
1030    ///     drives the boundary-position scan tail.
1031    ///   * The groupby-specific chain has its own LaunchRecorder
1032    ///     for the boundary mask, group ids, group_first
1033    ///     indices, num_groups scalar, per-aggregation outputs,
1034    ///     and key gather/unpack outputs.
1035    ///
1036    /// Composition correctness: each recorder commits
1037    /// independently; the runtime's record-all + wait-all
1038    /// `last_use_events: Vec<CudaEvent>` semantics chain the
1039    /// deallocate safety end-to-end across the four primitive
1040    /// commits.
1041    ///
1042    /// # Scope (narrow)
1043    /// * U32 / Symbol key columns only (sort_recorded
1044    ///   constraint).
1045    /// * Aggs: Count, Sum, Min, Max. LogSumExp is rejected with
1046    ///   a structured error — its multi-kernel chain is
1047    ///   outside this recorded provider surface.
1048    /// * Manager must be runtime-backed.
1049    pub fn groupby_multi_agg_recorded(
1050        &self,
1051        buffer: &CudaBuffer,
1052        key_cols: &[usize],
1053        aggs: &[(usize, AggOp)],
1054        launch_stream: crate::device_runtime::StreamId,
1055    ) -> Result<CudaBuffer> {
1056        use crate::launch::LaunchRecorder;
1057
1058        let runtime = self.memory.runtime().ok_or_else(|| {
1059            XlogError::Kernel(
1060                "groupby_multi_agg_recorded requires a runtime-backed GpuMemoryManager".to_string(),
1061            )
1062        })?;
1063        let cu_stream = runtime
1064            .stream_pool()
1065            .resolve(launch_stream)
1066            .ok_or_else(|| {
1067                XlogError::Kernel(format!(
1068                    "groupby_multi_agg_recorded: launch_stream StreamId({}) does not resolve",
1069                    launch_stream.0
1070                ))
1071            })?;
1072
1073        let num_rows = self.device_row_count(buffer)?;
1074        if num_rows == 0 {
1075            let result_schema =
1076                self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
1077            return self.create_empty_buffer(result_schema);
1078        }
1079        if num_rows > u32::MAX as usize {
1080            return Err(XlogError::Kernel(format!(
1081                "GroupBy supports at most {} rows, got {}",
1082                u32::MAX,
1083                num_rows
1084            )));
1085        }
1086        if key_cols.is_empty() {
1087            return Err(XlogError::Kernel(
1088                "GroupBy requires at least one key column".to_string(),
1089            ));
1090        }
1091        if aggs.is_empty() {
1092            return Err(XlogError::Kernel(
1093                "GroupBy requires at least one aggregation".to_string(),
1094            ));
1095        }
1096        if key_cols.len() > 4 {
1097            return Err(XlogError::Kernel(
1098                "groupby_multi_agg_recorded: max 4 key columns supported (pack_keys constraint)"
1099                    .to_string(),
1100            ));
1101        }
1102        for &k in key_cols {
1103            if k >= buffer.arity() {
1104                return Err(XlogError::Kernel(format!(
1105                    "Key column {} out of bounds (arity {})",
1106                    k,
1107                    buffer.arity()
1108                )));
1109            }
1110            let ty = buffer
1111                .schema()
1112                .column_type(k)
1113                .ok_or_else(|| XlogError::Kernel("Key column has no type".to_string()))?;
1114            if !matches!(ty, ScalarType::U32 | ScalarType::Symbol) {
1115                return Err(XlogError::Kernel(format!(
1116                    "groupby_multi_agg_recorded: key column type {:?} unsupported (U32 / Symbol \
1117                     only); multi-type sort_recorded is deferred",
1118                    ty
1119                )));
1120            }
1121        }
1122        for &(value_col, agg_op) in aggs {
1123            if value_col >= buffer.arity() {
1124                return Err(XlogError::Kernel(format!(
1125                    "Value column {} out of bounds (arity {})",
1126                    value_col,
1127                    buffer.arity()
1128                )));
1129            }
1130            let value_ty = buffer
1131                .schema()
1132                .column_type(value_col)
1133                .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1134            match agg_op {
1135                AggOp::Count => {}
1136                AggOp::Sum => {
1137                    // Recorded groupby U64 value-column sum path: values reduce through the u64-value
1138                    // sum kernel (same u64 accumulator as the U32 path).
1139                    if !matches!(value_ty, ScalarType::U32 | ScalarType::U64) {
1140                        return Err(XlogError::Kernel(format!(
1141                            "Sum currently requires U32 or U64 values, got {:?}",
1142                            value_ty
1143                        )));
1144                    }
1145                }
1146                AggOp::Min | AggOp::Max => {
1147                    // Recorded groupby U64 value-column min/max path: values reduce through the
1148                    // u64-value min/max kernels (result preserves the
1149                    // value width, mirroring the legacy path).
1150                    if !matches!(value_ty, ScalarType::U32 | ScalarType::U64) {
1151                        return Err(XlogError::Kernel(format!(
1152                            "{:?} currently requires U32 or U64 values, got {:?}",
1153                            agg_op, value_ty
1154                        )));
1155                    }
1156                }
1157                AggOp::LogSumExp => {
1158                    return Err(XlogError::Kernel(
1159                        "groupby_multi_agg_recorded: LogSumExp not yet supported in the \
1160                        recorded path (multi-kernel chain deferred to a future implementation)"
1161                            .to_string(),
1162                    ));
1163                }
1164            }
1165        }
1166
1167        // Step 1: sort by key columns (recorded sort, U32/Symbol only).
1168        let sorted = self.sort_recorded(buffer, key_cols, launch_stream)?;
1169        let num_rows = self.device_row_count(&sorted)?;
1170        if num_rows > u32::MAX as usize {
1171            return Err(XlogError::Kernel(format!(
1172                "GroupBy supports at most {} rows, got {}",
1173                u32::MAX,
1174                num_rows
1175            )));
1176        }
1177        let num_rows = num_rows as u32;
1178        let row_cap_usize = num_rows as usize;
1179        let row_cap_u32 = num_rows;
1180        let row_cap_u64 = num_rows as u64;
1181
1182        // Step 2: pack keys on launch_stream.
1183        let packed =
1184            self.pack_keys_gpu_on_stream(&sorted, key_cols, &cu_stream, launch_stream, runtime)?;
1185        if packed.key_bytes == 0 || packed.key_bytes % 4 != 0 {
1186            return Err(XlogError::Kernel(format!(
1187                "GroupBy key packing produced {} bytes per row (expected multiple of 4); \
1188                 Bool keys are not supported",
1189                packed.key_bytes
1190            )));
1191        }
1192        let segments_per_row = (packed.key_bytes / 4) as usize;
1193        let total_segments = row_cap_usize * segments_per_row;
1194        let packed_u32 = self.bytes_as_u32_view(&packed.packed_keys, total_segments)?;
1195
1196        // Step 3: allocate ALL fresh runtime-backed buffers
1197        // BEFORE the GroupBy recorder (Rust drop order — the
1198        // recorder's 'b lifetime must outlive every borrow it
1199        // holds via post_preflight_fresh).
1200        let boundaries = self.memory.alloc::<u8>(row_cap_usize)?;
1201        let block_size = 256u32;
1202        let num_blocks = num_rows.div_ceil(block_size);
1203        let cfg = LaunchConfig {
1204            grid_dim: (num_blocks, 1, 1),
1205            block_dim: (block_size, 1, 1),
1206            shared_mem_bytes: 0,
1207        };
1208        let d_boundary_pos = self.memory.alloc::<u32>(row_cap_usize)?;
1209        let mut d_block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
1210        let mut d_num_groups = self.memory.alloc::<u32>(1)?;
1211        let mut group_ids = self.memory.alloc::<u32>(row_cap_usize)?;
1212        let mut group_first_idx = self.memory.alloc::<u32>(row_cap_usize)?;
1213
1214        // Per-aggregation outputs (allocated up front).
1215        let mut agg_outputs: Vec<TrackedCudaSlice<u8>> = Vec::with_capacity(aggs.len());
1216        for &(value_col, agg_op) in aggs {
1217            let elem_size = match agg_op {
1218                AggOp::Count | AggOp::Sum => std::mem::size_of::<u64>(),
1219                // Value-width preserving min/max: preserve the value column's width.
1220                AggOp::Min | AggOp::Max => match sorted.schema().column_type(value_col) {
1221                    Some(ScalarType::U64) => std::mem::size_of::<u64>(),
1222                    _ => std::mem::size_of::<u32>(),
1223                },
1224                AggOp::LogSumExp => unreachable!("rejected above"),
1225            };
1226            let bytes = row_cap_usize
1227                .checked_mul(elem_size)
1228                .ok_or_else(|| XlogError::Kernel("groupby agg output size overflow".to_string()))?;
1229            agg_outputs.push(self.memory.alloc::<u8>(bytes)?);
1230        }
1231
1232        // Key gather + unpack outputs.
1233        let group_packed_bytes = row_cap_usize
1234            .checked_mul(packed.key_bytes as usize)
1235            .ok_or_else(|| XlogError::Kernel("GroupBy packed size overflow".to_string()))?;
1236        let mut group_packed = self.memory.alloc::<u8>(group_packed_bytes)?;
1237
1238        let mut col_offsets: Vec<u32> = Vec::with_capacity(key_cols.len());
1239        let mut col_sizes: Vec<u32> = Vec::with_capacity(key_cols.len());
1240        let mut offset = 0u32;
1241        for &key_col in key_cols {
1242            let s = buffer
1243                .schema()
1244                .column_type(key_col)
1245                .map(|t| t.size_bytes() as u32)
1246                .unwrap_or(4);
1247            col_offsets.push(offset);
1248            col_sizes.push(s);
1249            offset = offset
1250                .checked_add(s)
1251                .ok_or_else(|| XlogError::Kernel("GroupBy key size overflow".to_string()))?;
1252        }
1253        let mut key_unpacked: Vec<TrackedCudaSlice<u8>> = Vec::with_capacity(key_cols.len());
1254        for &col_size in &col_sizes {
1255            let bytes = row_cap_usize
1256                .checked_mul(col_size as usize)
1257                .ok_or_else(|| XlogError::Kernel("GroupBy key column overflow".to_string()))?;
1258            key_unpacked.push(self.memory.alloc::<u8>(bytes)?);
1259        }
1260
1261        // Build the recorder. Reads BEFORE preflight: the
1262        // sorted buffer's value columns + num_rows_device, plus
1263        // the packed_keys produced by pack_keys_on_stream
1264        // (which already recorded its own writes against
1265        // launch_stream — we record reads here so the chain
1266        // ordering is explicit).
1267        let mut rec = LaunchRecorder::new_strict(launch_stream);
1268        rec.read(sorted.num_rows_device());
1269        // sort_recorded already recorded reads on every input
1270        // column on launch_stream; packed_keys is the new
1271        // launch_stream-resident input to the boundary chain.
1272        rec.read(&packed.packed_keys);
1273        for &(value_col, _) in aggs {
1274            let c = sorted.column(value_col).ok_or_else(|| {
1275                XlogError::Kernel(format!("Value column {} not found", value_col))
1276            })?;
1277            rec.read_column(c);
1278        }
1279        rec.write(&boundaries);
1280        rec.write(&d_boundary_pos);
1281        rec.write(&d_block_sums);
1282        rec.write(&d_num_groups);
1283        rec.write(&group_ids);
1284        rec.write(&group_first_idx);
1285        rec.write(&group_packed);
1286        for o in &agg_outputs {
1287            rec.write(o);
1288        }
1289        for k in &key_unpacked {
1290            rec.write(k);
1291        }
1292        rec.preflight(runtime).map_err(|e| {
1293            XlogError::Kernel(format!(
1294                "groupby_multi_agg_recorded: preflight failed: {}",
1295                e
1296            ))
1297        })?;
1298
1299        let device = self.device.inner();
1300
1301        // Step 4: detect_group_boundaries on launch_stream.
1302        let boundary_func = device
1303            .get_func(GROUPBY_MODULE, groupby_kernels::DETECT_GROUP_BOUNDARIES)
1304            .ok_or_else(|| {
1305                XlogError::Kernel("detect_group_boundaries kernel not found".to_string())
1306            })?;
1307        // SAFETY: detect_group_boundaries(packed_u32, num_rows, segments_per_row, segments_per_row, boundaries)
1308        unsafe {
1309            boundary_func.clone().launch_on_stream(
1310                &cu_stream,
1311                cfg,
1312                (
1313                    &packed_u32,
1314                    num_rows,
1315                    segments_per_row as u32,
1316                    segments_per_row as u32,
1317                    &boundaries,
1318                ),
1319            )
1320        }
1321        .map_err(|e| {
1322            XlogError::Kernel(format!("detect_group_boundaries (on_stream) failed: {}", e))
1323        })?;
1324
1325        // Step 5: multi-block scan over boundary mask (yielding boundary positions).
1326        let phase1_fn = device
1327            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE1)
1328            .ok_or_else(|| {
1329                XlogError::Kernel("Failed to get multiblock_scan_phase1 kernel".to_string())
1330            })?;
1331        // SAFETY: multiblock_scan_phase1(mask, prefix_sum, block_sums, n)
1332        unsafe {
1333            phase1_fn.clone().launch_on_stream(
1334                &cu_stream,
1335                LaunchConfig {
1336                    grid_dim: (num_blocks, 1, 1),
1337                    block_dim: (block_size, 1, 1),
1338                    shared_mem_bytes: 0,
1339                },
1340                (&boundaries, &d_boundary_pos, &d_block_sums, num_rows),
1341            )
1342        }
1343        .map_err(|e| {
1344            XlogError::Kernel(format!("multiblock_scan_phase1 (on_stream) failed: {}", e))
1345        })?;
1346
1347        if num_blocks > 1 {
1348            self.multiblock_scan_u32_inplace_on_stream(
1349                &mut d_block_sums,
1350                num_blocks,
1351                &cu_stream,
1352                launch_stream,
1353                runtime,
1354            )?;
1355            let phase3_fn = device
1356                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
1357                .ok_or_else(|| {
1358                    XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
1359                })?;
1360            // SAFETY: multiblock_scan_phase3(prefix_sum, block_offsets, n)
1361            unsafe {
1362                phase3_fn.clone().launch_on_stream(
1363                    &cu_stream,
1364                    LaunchConfig {
1365                        grid_dim: (num_blocks, 1, 1),
1366                        block_dim: (block_size, 1, 1),
1367                        shared_mem_bytes: 0,
1368                    },
1369                    (&d_boundary_pos, &d_block_sums, num_rows),
1370                )
1371            }
1372            .map_err(|e| {
1373                XlogError::Kernel(format!("multiblock_scan_phase3 (on_stream) failed: {}", e))
1374            })?;
1375        }
1376
1377        // Step 6: capture_num_groups on launch_stream.
1378        let capture_fn = device
1379            .get_func(GROUPBY_MODULE, groupby_kernels::CAPTURE_NUM_GROUPS)
1380            .ok_or_else(|| XlogError::Kernel("capture_num_groups kernel not found".to_string()))?;
1381        // SAFETY: capture_num_groups(boundary_pos, boundaries, num_rows, num_groups)
1382        unsafe {
1383            capture_fn.clone().launch_on_stream(
1384                &cu_stream,
1385                LaunchConfig {
1386                    grid_dim: (1, 1, 1),
1387                    block_dim: (1, 1, 1),
1388                    shared_mem_bytes: 0,
1389                },
1390                (&d_boundary_pos, &boundaries, num_rows, &mut d_num_groups),
1391            )
1392        }
1393        .map_err(|e| XlogError::Kernel(format!("capture_num_groups (on_stream) failed: {}", e)))?;
1394
1395        // Step 7: derive group_ids + group_first_idx on launch_stream.
1396        let group_ids_fn = device
1397            .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_IDS_FROM_BOUNDARIES)
1398            .ok_or_else(|| {
1399                XlogError::Kernel("group_ids_from_boundaries kernel not found".to_string())
1400            })?;
1401        let group_start_fn = device
1402            .get_func(GROUPBY_MODULE, groupby_kernels::GROUP_START_INDICES)
1403            .ok_or_else(|| XlogError::Kernel("group_start_indices kernel not found".to_string()))?;
1404        // SAFETY: matches kernel signatures.
1405        unsafe {
1406            group_ids_fn.clone().launch_on_stream(
1407                &cu_stream,
1408                cfg,
1409                (&boundaries, &d_boundary_pos, num_rows, &mut group_ids),
1410            )
1411        }
1412        .map_err(|e| {
1413            XlogError::Kernel(format!(
1414                "group_ids_from_boundaries (on_stream) failed: {}",
1415                e
1416            ))
1417        })?;
1418        unsafe {
1419            group_start_fn.clone().launch_on_stream(
1420                &cu_stream,
1421                cfg,
1422                (&boundaries, &d_boundary_pos, num_rows, &mut group_first_idx),
1423            )
1424        }
1425        .map_err(|e| XlogError::Kernel(format!("group_start_indices (on_stream) failed: {}", e)))?;
1426
1427        // Step 8: per-aggregation kernels.
1428        for ((value_col, agg_op), output) in aggs.iter().zip(agg_outputs.iter_mut()) {
1429            let values = sorted.column(*value_col).ok_or_else(|| {
1430                XlogError::Kernel(format!("Value column {} not found", value_col))
1431            })?;
1432            match agg_op {
1433                AggOp::Count => {
1434                    self.memset_zeros_u8_on_stream(output, &cu_stream)?;
1435                    let count_func = device
1436                        .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_COUNT)
1437                        .ok_or_else(|| {
1438                            XlogError::Kernel("groupby_count kernel not found".to_string())
1439                        })?;
1440                    // SAFETY: groupby_count(boundaries, group_ids, num_rows, counts)
1441                    unsafe {
1442                        count_func.clone().launch_on_stream(
1443                            &cu_stream,
1444                            cfg,
1445                            (&boundaries, &group_ids, num_rows, &*output),
1446                        )
1447                    }
1448                    .map_err(|e| {
1449                        XlogError::Kernel(format!("groupby_count (on_stream) failed: {}", e))
1450                    })?;
1451                }
1452                AggOp::Sum => {
1453                    self.memset_zeros_u8_on_stream(output, &cu_stream)?;
1454                    let value_ty = sorted
1455                        .schema()
1456                        .column_type(*value_col)
1457                        .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1458                    if value_ty == ScalarType::U64 {
1459                        let values_view = self.column_as_u64_view(values, row_cap_usize)?;
1460                        let sum_func = device
1461                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM_U64)
1462                            .ok_or_else(|| {
1463                                XlogError::Kernel("groupby_sum_u64 kernel not found".to_string())
1464                            })?;
1465                        // SAFETY: groupby_sum_u64(values, group_ids, num_rows, sums)
1466                        unsafe {
1467                            sum_func.clone().launch_on_stream(
1468                                &cu_stream,
1469                                cfg,
1470                                (&values_view, &group_ids, num_rows, &*output),
1471                            )
1472                        }
1473                        .map_err(|e| {
1474                            XlogError::Kernel(format!("groupby_sum_u64 (on_stream) failed: {}", e))
1475                        })?;
1476                    } else {
1477                        let values_view = self.column_as_u32_view(values, row_cap_usize)?;
1478                        let sum_func = device
1479                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM)
1480                            .ok_or_else(|| {
1481                                XlogError::Kernel("groupby_sum kernel not found".to_string())
1482                            })?;
1483                        // SAFETY: groupby_sum(values, group_ids, num_rows, sums)
1484                        unsafe {
1485                            sum_func.clone().launch_on_stream(
1486                                &cu_stream,
1487                                cfg,
1488                                (&values_view, &group_ids, num_rows, &*output),
1489                            )
1490                        }
1491                        .map_err(|e| {
1492                            XlogError::Kernel(format!("groupby_sum (on_stream) failed: {}", e))
1493                        })?;
1494                    }
1495                }
1496                AggOp::Min => {
1497                    let value_ty = sorted
1498                        .schema()
1499                        .column_type(*value_col)
1500                        .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1501                    let fill_config = LaunchConfig::for_num_elems(row_cap_u32);
1502                    if value_ty == ScalarType::U64 {
1503                        // U64 value-column min path (output U64,
1504                        // identity u64::MAX).
1505                        let fill_fn = device
1506                            .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U64)
1507                            .ok_or_else(|| {
1508                                XlogError::Kernel("arith_fill_const_u64 not found".to_string())
1509                            })?;
1510                        // SAFETY: arith_fill_const_u64(value, n, output)
1511                        unsafe {
1512                            fill_fn.clone().launch_on_stream(
1513                                &cu_stream,
1514                                fill_config,
1515                                (u64::MAX, row_cap_u32, &mut *output),
1516                            )
1517                        }
1518                        .map_err(|e| {
1519                            XlogError::Kernel(format!(
1520                                "arith_fill_const_u64 (on_stream) failed: {}",
1521                                e
1522                            ))
1523                        })?;
1524                        let values_view = self.column_as_u64_view(values, row_cap_usize)?;
1525                        let min_func = device
1526                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN_U64)
1527                            .ok_or_else(|| {
1528                                XlogError::Kernel("groupby_min_u64 kernel not found".to_string())
1529                            })?;
1530                        // SAFETY: groupby_min_u64(values, group_ids, num_rows, mins)
1531                        unsafe {
1532                            min_func.clone().launch_on_stream(
1533                                &cu_stream,
1534                                cfg,
1535                                (&values_view, &group_ids, num_rows, &*output),
1536                            )
1537                        }
1538                        .map_err(|e| {
1539                            XlogError::Kernel(format!("groupby_min_u64 (on_stream) failed: {}", e))
1540                        })?;
1541                    } else {
1542                        let fill_fn = device
1543                            .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_U32)
1544                            .ok_or_else(|| {
1545                                XlogError::Kernel("arith_fill_const_u32 not found".to_string())
1546                            })?;
1547                        // SAFETY: arith_fill_const_u32(value, n, output)
1548                        unsafe {
1549                            fill_fn.clone().launch_on_stream(
1550                                &cu_stream,
1551                                fill_config,
1552                                (u32::MAX, row_cap_u32, &mut *output),
1553                            )
1554                        }
1555                        .map_err(|e| {
1556                            XlogError::Kernel(format!(
1557                                "arith_fill_const_u32 (on_stream) failed: {}",
1558                                e
1559                            ))
1560                        })?;
1561                        let values_view = self.column_as_u32_view(values, row_cap_usize)?;
1562                        let min_func = device
1563                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN)
1564                            .ok_or_else(|| {
1565                                XlogError::Kernel("groupby_min kernel not found".to_string())
1566                            })?;
1567                        // SAFETY: groupby_min(values, group_ids, num_rows, mins)
1568                        unsafe {
1569                            min_func.clone().launch_on_stream(
1570                                &cu_stream,
1571                                cfg,
1572                                (&values_view, &group_ids, num_rows, &*output),
1573                            )
1574                        }
1575                        .map_err(|e| {
1576                            XlogError::Kernel(format!("groupby_min (on_stream) failed: {}", e))
1577                        })?;
1578                    }
1579                }
1580                AggOp::Max => {
1581                    self.memset_zeros_u8_on_stream(output, &cu_stream)?;
1582                    let value_ty = sorted
1583                        .schema()
1584                        .column_type(*value_col)
1585                        .ok_or_else(|| XlogError::Kernel("Value column has no type".to_string()))?;
1586                    if value_ty == ScalarType::U64 {
1587                        // U64 value-column max path (output U64,
1588                        // identity 0).
1589                        let values_view = self.column_as_u64_view(values, row_cap_usize)?;
1590                        let max_func = device
1591                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX_U64)
1592                            .ok_or_else(|| {
1593                                XlogError::Kernel("groupby_max_u64 kernel not found".to_string())
1594                            })?;
1595                        // SAFETY: groupby_max_u64(values, group_ids, num_rows, maxs)
1596                        unsafe {
1597                            max_func.clone().launch_on_stream(
1598                                &cu_stream,
1599                                cfg,
1600                                (&values_view, &group_ids, num_rows, &*output),
1601                            )
1602                        }
1603                        .map_err(|e| {
1604                            XlogError::Kernel(format!("groupby_max_u64 (on_stream) failed: {}", e))
1605                        })?;
1606                    } else {
1607                        let values_view = self.column_as_u32_view(values, row_cap_usize)?;
1608                        let max_func = device
1609                            .get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX)
1610                            .ok_or_else(|| {
1611                                XlogError::Kernel("groupby_max kernel not found".to_string())
1612                            })?;
1613                        // SAFETY: groupby_max(values, group_ids, num_rows, maxs)
1614                        unsafe {
1615                            max_func.clone().launch_on_stream(
1616                                &cu_stream,
1617                                cfg,
1618                                (&values_view, &group_ids, num_rows, &*output),
1619                            )
1620                        }
1621                        .map_err(|e| {
1622                            XlogError::Kernel(format!("groupby_max (on_stream) failed: {}", e))
1623                        })?;
1624                    }
1625                }
1626                AggOp::LogSumExp => unreachable!("rejected above"),
1627            }
1628        }
1629
1630        // Step 9: gather packed key rows by group_first_idx.
1631        let gather_fn = device
1632            .get_func(PACK_MODULE, pack_kernels::GATHER_PACKED_ROWS_COUNTED)
1633            .ok_or_else(|| {
1634                XlogError::Kernel("gather_packed_rows_counted kernel not found".to_string())
1635            })?;
1636        let gather_config = LaunchConfig::for_num_elems(row_cap_u32);
1637        // SAFETY: gather_packed_rows_counted(src_packed, row_size, indices, num_rows, capacity_rows, dst_packed)
1638        unsafe {
1639            gather_fn.clone().launch_on_stream(
1640                &cu_stream,
1641                gather_config,
1642                (
1643                    &packed.packed_keys,
1644                    packed.key_bytes,
1645                    &group_first_idx,
1646                    &d_num_groups,
1647                    row_cap_u32,
1648                    &mut group_packed,
1649                ),
1650            )
1651        }
1652        .map_err(|e| {
1653            XlogError::Kernel(format!(
1654                "gather_packed_rows_counted (on_stream) failed: {}",
1655                e
1656            ))
1657        })?;
1658
1659        // Step 10: unpack each key column from the gathered packed rows.
1660        let unpack_fn = device
1661            .get_func(PACK_MODULE, pack_kernels::UNPACK_COLUMN_COUNTED)
1662            .ok_or_else(|| {
1663                XlogError::Kernel("unpack_column_counted kernel not found".to_string())
1664            })?;
1665        let unpack_config = LaunchConfig::for_num_elems(row_cap_u32);
1666        for idx in 0..key_cols.len() {
1667            let col_size = col_sizes[idx];
1668            let col_offset = col_offsets[idx];
1669            // SAFETY: unpack_column_counted(packed, row_size, col_offset, col_size,
1670            // num_rows, capacity_rows, col_output)
1671            unsafe {
1672                unpack_fn.clone().launch_on_stream(
1673                    &cu_stream,
1674                    unpack_config,
1675                    (
1676                        &group_packed,
1677                        packed.key_bytes,
1678                        col_offset,
1679                        col_size,
1680                        &d_num_groups,
1681                        row_cap_u32,
1682                        &mut key_unpacked[idx],
1683                    ),
1684                )
1685            }
1686            .map_err(|e| {
1687                XlogError::Kernel(format!("unpack_column_counted (on_stream) failed: {}", e))
1688            })?;
1689        }
1690
1691        // Record fresh writes via post-preflight escape hatch.
1692        rec.commit(runtime).map_err(|e| {
1693            XlogError::Kernel(format!("groupby_multi_agg_recorded: commit failed: {}", e))
1694        })?;
1695
1696        // Step 11: build the result CudaBuffer (keys then aggs).
1697        let mut result_columns: Vec<CudaColumn> = Vec::with_capacity(key_cols.len() + aggs.len());
1698        for k in key_unpacked {
1699            result_columns.push(k.into());
1700        }
1701        for o in agg_outputs {
1702            result_columns.push(o.into());
1703        }
1704        let result_schema = self.groupby_multi_agg_result_schema(buffer.schema(), key_cols, aggs);
1705        Ok(CudaBuffer::from_columns(
1706            result_columns,
1707            row_cap_u64,
1708            d_num_groups,
1709            result_schema,
1710        ))
1711    }
1712
1713    /// Convenience single-aggregation entry, mirrors
1714    /// [`Self::groupby_agg`]. Forwards to
1715    /// [`Self::groupby_multi_agg_recorded`].
1716    pub fn groupby_agg_recorded(
1717        &self,
1718        input: &CudaBuffer,
1719        key_cols: &[usize],
1720        agg: AggOp,
1721        value_col: usize,
1722        launch_stream: crate::device_runtime::StreamId,
1723    ) -> Result<CudaBuffer> {
1724        self.groupby_multi_agg_recorded(input, key_cols, &[(value_col, agg)], launch_stream)
1725    }
1726}