1use std::ffi::c_void;
43
44use xlog_core::{Result, ScalarType, Schema, XlogError};
45
46use super::{wcoj_kernels, CudaKernelProvider, WCOJ_MODULE};
47use crate::device_runtime::StreamId;
48use crate::launch::LaunchRecorder;
49use crate::memory::{CudaColumn, TrackedCudaSlice};
50use crate::CudaBuffer;
51use crate::{AsKernelParam, LaunchAsync, LaunchConfig};
52
53const BLOCK_SIZE: u32 = 256;
54
/// One atom's share of a free-join node: which input it reads and which plan
/// variables its next (not yet consumed) columns correspond to.
#[derive(Clone, Debug)]
pub struct FjSubAtom {
    /// Position of the atom in the `inputs` slice handed to the free-join entry points.
    pub input_idx: usize,
    /// Plan variable ids, one per column this sub-atom consumes, in column order.
    pub var_positions: Vec<usize>,
}
68
69#[derive(Debug, Clone)]
73pub struct FjNode {
74 pub cover: FjSubAtom,
75 pub probes: Vec<FjSubAtom>,
76}
77
78#[derive(Debug, Clone)]
82pub struct FjPlan {
83 pub num_vars: usize,
85 pub nodes: Vec<FjNode>,
87 pub output_vars: Vec<usize>,
91}
92
/// Identifies what a frontier column holds.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ColTag {
    /// Values of the given plan variable.
    Var(usize),
    /// Lower row bound into the given (normalized) input.
    RangeLo(usize),
    /// Upper row bound into the given (normalized) input.
    RangeHi(usize),
}
106
/// A frontier column: its tag paired with the raw device bytes backing it
/// (element type is given by `tag_type` for the join width).
type FrontierCol = (ColTag, TrackedCudaSlice<u8>);
108
/// Element width of variable columns for a free-join run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FjWidth {
    /// Variables are stored as `u32`.
    U32,
    /// Variables are stored as `u64`.
    U64,
}
117
118impl FjWidth {
119 fn var_bytes(self) -> usize {
120 match self {
121 Self::U32 => std::mem::size_of::<u32>(),
122 Self::U64 => std::mem::size_of::<u64>(),
123 }
124 }
125
126 fn var_type(self) -> ScalarType {
127 match self {
128 Self::U32 => ScalarType::U32,
129 Self::U64 => ScalarType::U64,
130 }
131 }
132
133 fn count_kernel(self) -> &'static str {
134 match self {
135 Self::U32 => wcoj_kernels::FJ_EXPAND_COUNT_U32,
136 Self::U64 => wcoj_kernels::FJ_EXPAND_COUNT_U64,
137 }
138 }
139
140 fn emit_kernel(self) -> &'static str {
141 match self {
142 Self::U32 => wcoj_kernels::FJ_EXPAND_EMIT_U32,
143 Self::U64 => wcoj_kernels::FJ_EXPAND_EMIT_U64,
144 }
145 }
146
147 fn probe_kernel(self) -> &'static str {
148 match self {
149 Self::U32 => wcoj_kernels::FJ_PROBE_REFINE_U32,
150 Self::U64 => wcoj_kernels::FJ_PROBE_REFINE_U64,
151 }
152 }
153}
154
155fn tag_type(tag: ColTag, width: FjWidth) -> ScalarType {
158 match tag {
159 ColTag::Var(_) => width.var_type(),
160 ColTag::RangeLo(_) | ColTag::RangeHi(_) => ScalarType::U32,
161 }
162}
163
/// What the free-join driver produces.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FjMode {
    /// Emit every joined tuple, projected onto the plan's output variables.
    Materialize,
    /// Emit one `(group value, u64 count)` row per value of the single output variable.
    CountByRoot,
}
176
177fn owned_col_ptr(buf: &CudaBuffer, idx: usize, ctx: &str) -> Result<u64> {
178 match buf.column(idx) {
179 Some(CudaColumn::Owned(s)) => Ok(*s.device_ptr()),
180 Some(_) => Err(XlogError::Kernel(format!(
181 "{ctx}: input column {idx} must be owned"
182 ))),
183 None => Err(XlogError::Kernel(format!(
184 "{ctx}: input column {idx} not found"
185 ))),
186 }
187}
188
189fn find_col<'a>(
190 cols: &'a [FrontierCol],
191 tag: ColTag,
192 ctx: &str,
193) -> Result<&'a TrackedCudaSlice<u8>> {
194 cols.iter()
195 .find(|(t, _)| *t == tag)
196 .map(|(_, s)| s)
197 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: frontier column {tag:?} missing")))
198}
199
200fn validate_plan(plan: &FjPlan, arities: &[usize], mode: FjMode, ctx: &str) -> Result<Vec<usize>> {
203 if plan.nodes.is_empty() {
204 return Err(XlogError::Kernel(format!("{ctx}: plan has no nodes")));
205 }
206 let mut bound = vec![false; plan.num_vars];
207 let mut bind_order: Vec<usize> = Vec::new();
208 let mut consumed = vec![0usize; arities.len()];
209 let check_sub = |sub: &FjSubAtom, consumed: &[usize], what: &str| -> Result<()> {
210 if sub.input_idx >= arities.len() {
211 return Err(XlogError::Kernel(format!(
212 "{ctx}: {what} input_idx {} out of bounds ({} inputs)",
213 sub.input_idx,
214 arities.len()
215 )));
216 }
217 if sub.var_positions.is_empty() {
218 return Err(XlogError::Kernel(format!(
219 "{ctx}: {what} on input {} has no variables",
220 sub.input_idx
221 )));
222 }
223 if consumed[sub.input_idx] + sub.var_positions.len() > arities[sub.input_idx] {
224 return Err(XlogError::Kernel(format!(
225 "{ctx}: {what} over-consumes input {} (arity {}, consumed {}, +{})",
226 sub.input_idx,
227 arities[sub.input_idx],
228 consumed[sub.input_idx],
229 sub.var_positions.len()
230 )));
231 }
232 for &v in &sub.var_positions {
233 if v >= plan.num_vars {
234 return Err(XlogError::Kernel(format!(
235 "{ctx}: {what} variable {v} out of bounds (num_vars {})",
236 plan.num_vars
237 )));
238 }
239 }
240 Ok(())
241 };
242 for (k, node) in plan.nodes.iter().enumerate() {
243 check_sub(&node.cover, &consumed, "cover")?;
244 for &v in &node.cover.var_positions {
245 if bound[v] {
246 return Err(XlogError::Kernel(format!(
247 "{ctx}: node {k} cover rebinds variable {v}"
248 )));
249 }
250 bound[v] = true;
251 bind_order.push(v);
252 }
253 consumed[node.cover.input_idx] += node.cover.var_positions.len();
254 let mut seen_atoms = vec![node.cover.input_idx];
255 for probe in &node.probes {
256 check_sub(probe, &consumed, "probe")?;
257 if seen_atoms.contains(&probe.input_idx) {
258 return Err(XlogError::Kernel(format!(
259 "{ctx}: node {k} touches input {} more than once",
260 probe.input_idx
261 )));
262 }
263 seen_atoms.push(probe.input_idx);
264 for &v in &probe.var_positions {
265 if !bound[v] {
266 return Err(XlogError::Kernel(format!(
267 "{ctx}: node {k} probes unbound variable {v}"
268 )));
269 }
270 }
271 consumed[probe.input_idx] += probe.var_positions.len();
272 }
273 }
274 for (i, (&used, &arity)) in consumed.iter().zip(arities.iter()).enumerate() {
275 match mode {
276 FjMode::Materialize => {
277 if used != arity {
278 return Err(XlogError::Kernel(format!(
279 "{ctx}: plan consumes {used}/{arity} columns of input {i} \
280 (materialization requires full consumption)"
281 )));
282 }
283 }
284 FjMode::CountByRoot => {
288 if used == 0 {
289 return Err(XlogError::Kernel(format!(
290 "{ctx}: count plan never touches input {i} \
291 (untouched atoms have no live range)"
292 )));
293 }
294 }
295 }
296 }
297 if plan.output_vars.is_empty() {
298 return Err(XlogError::Kernel(format!("{ctx}: empty output_vars")));
299 }
300 if mode == FjMode::CountByRoot && plan.output_vars.len() != 1 {
301 return Err(XlogError::Kernel(format!(
302 "{ctx}: count plans take exactly one output (group) variable, got {}",
303 plan.output_vars.len()
304 )));
305 }
306 for &v in &plan.output_vars {
307 if v >= plan.num_vars || !bound[v] {
308 return Err(XlogError::Kernel(format!(
309 "{ctx}: output variable {v} is never bound"
310 )));
311 }
312 }
313 Ok(bind_order)
314}
315
316impl CudaKernelProvider {
317 pub fn free_join_execute_u32_recorded(
336 &self,
337 inputs: &[&CudaBuffer],
338 plan: &FjPlan,
339 launch_stream: StreamId,
340 ) -> Result<CudaBuffer> {
341 self.free_join_execute_recorded_impl(
342 inputs,
343 plan,
344 launch_stream,
345 FjWidth::U32,
346 FjMode::Materialize,
347 "free_join_execute_u32_recorded",
348 )
349 }
350
351 pub fn free_join_execute_u64_recorded(
355 &self,
356 inputs: &[&CudaBuffer],
357 plan: &FjPlan,
358 launch_stream: StreamId,
359 ) -> Result<CudaBuffer> {
360 self.free_join_execute_recorded_impl(
361 inputs,
362 plan,
363 launch_stream,
364 FjWidth::U64,
365 FjMode::Materialize,
366 "free_join_execute_u64_recorded",
367 )
368 }
369
370 pub fn free_join_count_by_root_u32_recorded(
384 &self,
385 inputs: &[&CudaBuffer],
386 plan: &FjPlan,
387 launch_stream: StreamId,
388 ) -> Result<CudaBuffer> {
389 self.free_join_execute_recorded_impl(
390 inputs,
391 plan,
392 launch_stream,
393 FjWidth::U32,
394 FjMode::CountByRoot,
395 "free_join_count_by_root_u32_recorded",
396 )
397 }
398
399 #[allow(clippy::too_many_lines)]
400 fn free_join_execute_recorded_impl(
401 &self,
402 inputs: &[&CudaBuffer],
403 plan: &FjPlan,
404 launch_stream: StreamId,
405 width: FjWidth,
406 mode: FjMode,
407 ctx: &str,
408 ) -> Result<CudaBuffer> {
409 if self.memory().runtime().is_none() {
410 return Err(XlogError::Kernel(format!(
411 "{ctx} requires a runtime-backed GpuMemoryManager \
412 (constructed via with_runtime)"
413 )));
414 }
415 if inputs.is_empty() {
416 return Err(XlogError::Kernel(format!("{ctx}: no inputs")));
417 }
418 let arities: Vec<usize> = inputs.iter().map(|b| b.arity()).collect();
419 let bind_order = validate_plan(plan, &arities, mode, ctx)?;
420
421 let mut norm: Vec<CudaBuffer> = Vec::with_capacity(inputs.len());
426 for input in inputs {
427 let normalized = match (width, input.arity()) {
428 (FjWidth::U32, 2) => self.wcoj_layout_u32_recorded(input, launch_stream)?,
429 (FjWidth::U32, _) => self.wcoj_layout_sort_u32_recorded(input, launch_stream)?,
430 (FjWidth::U64, 2) => self.wcoj_layout_u64_recorded(input, launch_stream)?,
431 (FjWidth::U64, _) => self.wcoj_layout_sort_u64_recorded(input, launch_stream)?,
432 };
433 norm.push(normalized);
434 }
435 let mut n_rows: Vec<u32> = Vec::with_capacity(norm.len());
436 for buf in &norm {
437 let n = match buf.cached_row_count() {
438 Some(c) => c,
439 None => self.dtoh_scalar_untracked::<u32>(buf.num_rows_device(), 0)?,
440 };
441 n_rows.push(n);
442 }
443
444 let out_schema = match mode {
445 FjMode::Materialize => Schema::new(
446 plan.output_vars
447 .iter()
448 .map(|v| (format!("v{v}"), width.var_type()))
449 .collect(),
450 ),
451 FjMode::CountByRoot => Schema::new(vec![
452 (format!("v{}", plan.output_vars[0]), width.var_type()),
453 ("count".to_string(), ScalarType::U64),
454 ]),
455 };
456 if n_rows.iter().any(|&n| n == 0) {
458 return self.create_empty_buffer(out_schema);
459 }
460
461 let runtime = self.memory().runtime().ok_or_else(|| {
462 XlogError::Kernel(format!("{ctx} requires a runtime-backed GpuMemoryManager"))
463 })?;
464 let cu_stream = runtime
465 .stream_pool()
466 .resolve(launch_stream)
467 .ok_or_else(|| {
468 XlogError::Kernel(format!(
469 "{ctx}: launch_stream StreamId({}) does not resolve",
470 launch_stream.0
471 ))
472 })?;
473
474 let mut frontier: Vec<FrontierCol> = Vec::new();
477 let mut count: u32 = 1;
478 let mut frontier_cap: u32 = 1;
485 let mut consumed = vec![0usize; inputs.len()];
486
487 for node in &plan.nodes {
488 let a = node.cover.input_idx;
489 let c = node.cover.var_positions.len();
490 let depth = consumed[a];
491 let cover_live = frontier.iter().any(|(t, _)| *t == ColTag::RangeLo(a));
492
493 let (total_work, work_prefix) = if cover_live {
499 let mut wp = self.memory().alloc::<u32>(count as usize + 1)?;
500 let lo_col = find_col(&frontier, ColTag::RangeLo(a), ctx)?;
501 let hi_col = find_col(&frontier, ColTag::RangeHi(a), ctx)?;
502 let mut rec = LaunchRecorder::new_strict(launch_stream);
503 rec.read(lo_col);
504 rec.read(hi_col);
505 rec.write(&wp);
506 rec.preflight(runtime)
507 .map_err(|e| XlogError::Kernel(format!("{ctx}: wp preflight failed: {e}")))?;
508 let kernel = self
509 .device()
510 .inner()
511 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_EXPAND_WORK_PREFIX_U32)
512 .ok_or_else(|| {
513 XlogError::Kernel("fj_expand_work_prefix_u32 kernel not found".to_string())
514 })?;
515 let grid = count.div_ceil(BLOCK_SIZE);
516 unsafe {
520 kernel
521 .clone()
522 .launch_on_stream(
523 &cu_stream,
524 LaunchConfig {
525 grid_dim: (grid, 1, 1),
526 block_dim: (BLOCK_SIZE, 1, 1),
527 shared_mem_bytes: 0,
528 },
529 (lo_col, hi_col, count, &mut wp),
530 )
531 .map_err(|e| {
532 XlogError::Kernel(format!(
533 "fj_expand_work_prefix_u32 launch failed: {e}"
534 ))
535 })?;
536 }
537 self.multiblock_scan_u32_inplace_on_stream(
538 &mut wp,
539 count + 1,
540 &cu_stream,
541 launch_stream,
542 runtime,
543 )?;
544 rec.commit(runtime)
545 .map_err(|e| XlogError::Kernel(format!("{ctx}: wp commit failed: {e}")))?;
546 cu_stream
547 .synchronize()
548 .map_err(|e| XlogError::Kernel(format!("{ctx}: wp sync failed: {e}")))?;
549 let total = self.dtoh_scalar_untracked::<u32>(&wp, count as usize)?;
550 (u64::from(total), Some(wp))
551 } else {
552 ((count as u64) * (n_rows[a] as u64), None)
553 };
554 if total_work == 0 {
555 return self.create_empty_buffer(out_schema);
556 }
557 if total_work > u64::from(u32::MAX - 1) {
558 return Err(XlogError::Kernel(format!(
559 "{ctx}: expansion work {total_work} exceeds the u32 work-index \
560 space (frontier budget)"
561 )));
562 }
563 let total_work = total_work as u32;
564
565 let identity = depth + c >= arities[a];
577 let is_fusable = |pr: &FjSubAtom| {
588 consumed[pr.input_idx] + pr.var_positions.len() >= arities[pr.input_idx]
589 && pr
590 .var_positions
591 .iter()
592 .all(|v| node.cover.var_positions.contains(v))
593 };
594 let fused: Vec<&FjSubAtom> = node.probes.iter().filter(|p| is_fusable(p)).collect();
595 let count_ran = !identity || !fused.is_empty();
598 let mut fused_desc: Vec<u64> = Vec::new();
602 for pr in &fused {
603 let p = pr.input_idx;
604 let live = frontier.iter().any(|(t, _)| *t == ColTag::RangeLo(p));
605 fused_desc.push(pr.var_positions.len() as u64);
606 fused_desc.push(u64::from(live));
607 if live {
608 fused_desc.push(*find_col(&frontier, ColTag::RangeLo(p), ctx)?.device_ptr());
609 fused_desc.push(*find_col(&frontier, ColTag::RangeHi(p), ctx)?.device_ptr());
610 } else {
611 fused_desc.push(0);
612 fused_desc.push(0);
613 }
614 fused_desc.push(u64::from(n_rows[p]));
615 for i in consumed[p]..consumed[p] + pr.var_positions.len() {
616 fused_desc.push(owned_col_ptr(&norm[p], i, ctx)?);
617 }
618 for v in &pr.var_positions {
619 fused_desc.push(
620 node.cover
621 .var_positions
622 .iter()
623 .position(|cv| cv == v)
624 .expect("fusable probe keys are cover variables")
625 as u64,
626 );
627 }
628 }
629 let d_fused_desc: Option<TrackedCudaSlice<u64>> = if fused_desc.is_empty() {
630 None
631 } else {
632 let mut tbl = self.memory().alloc::<u64>(fused_desc.len())?;
633 self.htod_launch_metadata_sync_copy_into(&fused_desc, &mut tbl)
634 .map_err(|e| {
635 XlogError::Kernel(format!("{ctx}: htod fused-probe table failed: {e}"))
636 })?;
637 Some(tbl)
638 };
639 let cover_ptrs: Vec<u64> = (depth..depth + c)
640 .map(|i| owned_col_ptr(&norm[a], i, ctx))
641 .collect::<Result<_>>()?;
642 let mut d_cover_tbl = self.memory().alloc::<u64>(c)?;
643 self.htod_launch_metadata_sync_copy_into(&cover_ptrs, &mut d_cover_tbl)
644 .map_err(|e| XlogError::Kernel(format!("{ctx}: htod cover table failed: {e}")))?;
645 let mut marks = self.memory().alloc::<u32>(total_work as usize + 1)?;
646 let has_parent_range: u32 = u32::from(cover_live);
647 let const_lo: u32 = 0;
648 let const_hi: u32 = n_rows[a];
649 let null_ptr: u64 = 0;
650 if count_ran {
651 let mut rec = LaunchRecorder::new_strict(launch_stream);
652 rec.read(norm[a].num_rows_device());
653 for i in depth..depth + c {
654 rec.read_column(norm[a].column(i).expect("validated cover column"));
655 }
656 rec.read(&d_cover_tbl);
657 if let Some(wp) = work_prefix.as_ref() {
658 rec.read(wp);
659 rec.read(find_col(&frontier, ColTag::RangeLo(a), ctx)?);
660 }
661 if let Some(d) = d_fused_desc.as_ref() {
662 rec.read(d);
663 for pr in &fused {
664 let p = pr.input_idx;
665 rec.read(norm[p].num_rows_device());
666 for i in consumed[p]..consumed[p] + pr.var_positions.len() {
667 rec.read_column(norm[p].column(i).expect("validated probe column"));
668 }
669 if frontier.iter().any(|(t, _)| *t == ColTag::RangeLo(p)) {
670 rec.read(find_col(&frontier, ColTag::RangeLo(p), ctx)?);
671 rec.read(find_col(&frontier, ColTag::RangeHi(p), ctx)?);
672 }
673 }
674 }
675 rec.write(&marks);
676 rec.preflight(runtime).map_err(|e| {
677 XlogError::Kernel(format!("{ctx}: count preflight failed: {e}"))
678 })?;
679 let kernel = self
680 .device()
681 .inner()
682 .get_func(WCOJ_MODULE, width.count_kernel())
683 .ok_or_else(|| {
684 XlogError::Kernel(format!("{} kernel not found", width.count_kernel()))
685 })?;
686 let grid = total_work.div_ceil(BLOCK_SIZE);
687 let c_u32 = c as u32;
688 let identity_u32: u32 = u32::from(identity);
689 let n_fused_u32: u32 = fused.len() as u32;
690 unsafe {
697 let parent_lo_param = match work_prefix.as_ref() {
698 Some(_) => find_col(&frontier, ColTag::RangeLo(a), ctx)?.as_kernel_param(),
699 None => null_ptr.as_kernel_param(),
700 };
701 let wp_param = match work_prefix.as_ref() {
702 Some(wp) => wp.as_kernel_param(),
703 None => null_ptr.as_kernel_param(),
704 };
705 let mut params: Vec<*mut c_void> = vec![
706 (&d_cover_tbl).as_kernel_param(),
707 c_u32.as_kernel_param(),
708 parent_lo_param,
709 wp_param,
710 has_parent_range.as_kernel_param(),
711 const_lo.as_kernel_param(),
712 const_hi.as_kernel_param(),
713 count.as_kernel_param(),
714 total_work.as_kernel_param(),
715 identity_u32.as_kernel_param(),
716 match d_fused_desc.as_ref() {
717 Some(d) => d.as_kernel_param(),
718 None => null_ptr.as_kernel_param(),
719 },
720 n_fused_u32.as_kernel_param(),
721 (&marks).as_kernel_param(),
722 ];
723 kernel
724 .clone()
725 .launch_on_stream(
726 &cu_stream,
727 LaunchConfig {
728 grid_dim: (grid, 1, 1),
729 block_dim: (BLOCK_SIZE, 1, 1),
730 shared_mem_bytes: 0,
731 },
732 &mut params,
733 )
734 .map_err(|e| {
735 XlogError::Kernel(format!(
736 "{} launch failed: {e}",
737 width.count_kernel()
738 ))
739 })?;
740 }
741 self.multiblock_scan_u32_inplace_on_stream(
742 &mut marks,
743 total_work + 1,
744 &cu_stream,
745 launch_stream,
746 runtime,
747 )?;
748 rec.commit(runtime)
749 .map_err(|e| XlogError::Kernel(format!("{ctx}: count commit failed: {e}")))?;
750 }
751 let n_children = if count_ran {
752 cu_stream
753 .synchronize()
754 .map_err(|e| XlogError::Kernel(format!("{ctx}: count sync failed: {e}")))?;
755 self.dtoh_scalar_untracked::<u32>(&marks, total_work as usize)?
756 } else {
757 total_work
758 };
759 if n_children == 0 {
760 return self.create_empty_buffer(out_schema);
761 }
762
763 let var_bytes = (n_children as usize) * width.var_bytes();
774 let range_bytes = (n_children as usize) * std::mem::size_of::<u32>();
775 let mut parent_copy_var_ptrs: Vec<u64> = Vec::new();
776 let mut child_copy_var_ptrs: Vec<u64> = Vec::new();
777 let mut parent_copy_range_ptrs: Vec<u64> = Vec::new();
778 let mut child_copy_range_ptrs: Vec<u64> = Vec::new();
779 let mut child_cols: Vec<FrontierCol> = Vec::new();
780 for (tag, slice) in &frontier {
781 if matches!(tag, ColTag::RangeLo(x) | ColTag::RangeHi(x) if *x == a) {
782 continue; }
784 let is_var = matches!(tag, ColTag::Var(_));
785 let dst =
786 self.memory()
787 .alloc::<u8>(if is_var { var_bytes } else { range_bytes })?;
788 if is_var {
789 parent_copy_var_ptrs.push(*slice.device_ptr());
790 child_copy_var_ptrs.push(*dst.device_ptr());
791 } else {
792 parent_copy_range_ptrs.push(*slice.device_ptr());
793 child_copy_range_ptrs.push(*dst.device_ptr());
794 }
795 child_cols.push((*tag, dst));
796 }
797 let n_copy_var = parent_copy_var_ptrs.len();
798 let n_copy_range = parent_copy_range_ptrs.len();
799 let mut child_var_ptrs: Vec<u64> = Vec::with_capacity(c);
800 for &v in &node.cover.var_positions {
801 let dst = self.memory().alloc::<u8>(var_bytes)?;
802 child_var_ptrs.push(*dst.device_ptr());
803 child_cols.push((ColTag::Var(v), dst));
804 }
805 let keep_cover = depth + c < arities[a];
806 if keep_cover {
807 let lo = self.memory().alloc::<u8>(range_bytes)?;
808 let hi = self.memory().alloc::<u8>(range_bytes)?;
809 child_cols.push((ColTag::RangeLo(a), lo));
810 child_cols.push((ColTag::RangeHi(a), hi));
811 }
812 let upload_tbl = |ptrs: &[u64]| -> Result<TrackedCudaSlice<u64>> {
814 let mut tbl = self.memory().alloc::<u64>(ptrs.len().max(1))?;
815 if !ptrs.is_empty() {
816 self.htod_launch_metadata_sync_copy_into(ptrs, &mut tbl)
817 .map_err(|e| {
818 XlogError::Kernel(format!("{ctx}: htod pointer table failed: {e}"))
819 })?;
820 }
821 Ok(tbl)
822 };
823 let d_parent_copy_var_tbl = upload_tbl(&parent_copy_var_ptrs)?;
824 let d_child_copy_var_tbl = upload_tbl(&child_copy_var_ptrs)?;
825 let d_parent_copy_range_tbl = upload_tbl(&parent_copy_range_ptrs)?;
826 let d_child_copy_range_tbl = upload_tbl(&child_copy_range_ptrs)?;
827 let d_child_var_tbl = upload_tbl(&child_var_ptrs)?;
828
829 struct ProbePlan {
832 input_idx: usize,
833 n_cols: u32,
834 data_tbl: TrackedCudaSlice<u64>,
835 key_tbl: TrackedCudaSlice<u64>,
836 live: bool,
837 keep: bool,
838 out_lo: Option<TrackedCudaSlice<u8>>,
839 out_hi: Option<TrackedCudaSlice<u8>>,
840 }
841 let mut probe_plans: Vec<ProbePlan> = Vec::with_capacity(node.probes.len());
842 for probe in node.probes.iter().filter(|pr| !is_fusable(pr)) {
843 let p = probe.input_idx;
844 let p_len = probe.var_positions.len();
845 let p_depth = consumed[p];
846 let data_ptrs: Vec<u64> = (p_depth..p_depth + p_len)
847 .map(|i| owned_col_ptr(&norm[p], i, ctx))
848 .collect::<Result<_>>()?;
849 let key_ptrs: Vec<u64> = probe
850 .var_positions
851 .iter()
852 .map(|&v| Ok(*find_col(&child_cols, ColTag::Var(v), ctx)?.device_ptr()))
853 .collect::<Result<_>>()?;
854 let live = child_cols.iter().any(|(t, _)| *t == ColTag::RangeLo(p));
855 let keep = p_depth + p_len < arities[p];
856 let (out_lo, out_hi) = if keep {
857 (
858 Some(self.memory().alloc::<u8>(range_bytes)?),
859 Some(self.memory().alloc::<u8>(range_bytes)?),
860 )
861 } else {
862 (None, None)
863 };
864 probe_plans.push(ProbePlan {
865 input_idx: p,
866 n_cols: p_len as u32,
867 data_tbl: upload_tbl(&data_ptrs)?,
868 key_tbl: upload_tbl(&key_ptrs)?,
869 live,
870 keep,
871 out_lo,
872 out_hi,
873 });
874 }
875 let mask: Option<TrackedCudaSlice<u8>> = if probe_plans.is_empty() {
876 None
877 } else {
878 Some(self.memory().alloc::<u8>(n_children as usize)?)
879 };
880
881 {
882 let mut rec = LaunchRecorder::new_strict(launch_stream);
883 for i in depth..depth + c {
884 rec.read_column(norm[a].column(i).expect("validated cover column"));
885 }
886 rec.read(&d_cover_tbl);
887 if count_ran {
888 rec.read(&marks);
889 }
890 if let Some(wp) = work_prefix.as_ref() {
891 rec.read(wp);
892 rec.read(find_col(&frontier, ColTag::RangeLo(a), ctx)?);
893 rec.read(find_col(&frontier, ColTag::RangeHi(a), ctx)?);
894 }
895 for (_, slice) in &frontier {
896 rec.read(slice);
897 }
898 rec.read(&d_parent_copy_var_tbl);
899 rec.read(&d_child_copy_var_tbl);
900 rec.read(&d_parent_copy_range_tbl);
901 rec.read(&d_child_copy_range_tbl);
902 rec.read(&d_child_var_tbl);
903 for (_, slice) in &child_cols {
904 rec.write(slice);
905 }
906 for pp in probe_plans.iter() {
907 let p = pp.input_idx;
908 for i in consumed[p]..consumed[p] + pp.n_cols as usize {
909 rec.read_column(norm[p].column(i).expect("validated probe column"));
910 }
911 rec.read(norm[p].num_rows_device());
912 rec.read(&pp.data_tbl);
913 rec.read(&pp.key_tbl);
914 if let Some(lo) = pp.out_lo.as_ref() {
915 rec.write(lo);
916 }
917 if let Some(hi) = pp.out_hi.as_ref() {
918 rec.write(hi);
919 }
920 }
921 if let Some(m) = mask.as_ref() {
922 rec.write(m);
923 }
924 rec.preflight(runtime)
925 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit preflight failed: {e}")))?;
926
927 let emit_kernel = self
928 .device()
929 .inner()
930 .get_func(WCOJ_MODULE, width.emit_kernel())
931 .ok_or_else(|| {
932 XlogError::Kernel(format!("{} kernel not found", width.emit_kernel()))
933 })?;
934 let grid = total_work.div_ceil(BLOCK_SIZE);
935 let c_u32 = c as u32;
936 let n_copy_var_u32 = n_copy_var as u32;
937 let n_copy_range_u32 = n_copy_range as u32;
938 let keep_cover_u32 = u32::from(keep_cover);
939 unsafe {
949 let parent_lo_param = match work_prefix.as_ref() {
950 Some(_) => find_col(&frontier, ColTag::RangeLo(a), ctx)?.as_kernel_param(),
951 None => null_ptr.as_kernel_param(),
952 };
953 let parent_hi_param = match work_prefix.as_ref() {
954 Some(_) => find_col(&frontier, ColTag::RangeHi(a), ctx)?.as_kernel_param(),
955 None => null_ptr.as_kernel_param(),
956 };
957 let wp_param = match work_prefix.as_ref() {
958 Some(wp) => wp.as_kernel_param(),
959 None => null_ptr.as_kernel_param(),
960 };
961 let cover_lo_param = if keep_cover {
962 find_col(&child_cols, ColTag::RangeLo(a), ctx)?.as_kernel_param()
963 } else {
964 null_ptr.as_kernel_param()
965 };
966 let cover_hi_param = if keep_cover {
967 find_col(&child_cols, ColTag::RangeHi(a), ctx)?.as_kernel_param()
968 } else {
969 null_ptr.as_kernel_param()
970 };
971 let mut params: Vec<*mut c_void> = vec![
972 (&d_cover_tbl).as_kernel_param(),
973 c_u32.as_kernel_param(),
974 parent_lo_param,
975 parent_hi_param,
976 wp_param,
977 has_parent_range.as_kernel_param(),
978 const_lo.as_kernel_param(),
979 const_hi.as_kernel_param(),
980 count.as_kernel_param(),
981 total_work.as_kernel_param(),
982 if count_ran {
985 (&marks).as_kernel_param()
986 } else {
987 null_ptr.as_kernel_param()
988 },
989 (&d_parent_copy_var_tbl).as_kernel_param(),
990 (&d_child_copy_var_tbl).as_kernel_param(),
991 n_copy_var_u32.as_kernel_param(),
992 (&d_parent_copy_range_tbl).as_kernel_param(),
993 (&d_child_copy_range_tbl).as_kernel_param(),
994 n_copy_range_u32.as_kernel_param(),
995 (&d_child_var_tbl).as_kernel_param(),
996 keep_cover_u32.as_kernel_param(),
997 cover_lo_param,
998 cover_hi_param,
999 ];
1000 emit_kernel
1001 .clone()
1002 .launch_on_stream(
1003 &cu_stream,
1004 LaunchConfig {
1005 grid_dim: (grid, 1, 1),
1006 block_dim: (BLOCK_SIZE, 1, 1),
1007 shared_mem_bytes: 0,
1008 },
1009 &mut params,
1010 )
1011 .map_err(|e| {
1012 XlogError::Kernel(format!("{} launch failed: {e}", width.emit_kernel()))
1013 })?;
1014 }
1015
1016 let probe_kernel = self
1018 .device()
1019 .inner()
1020 .get_func(WCOJ_MODULE, width.probe_kernel())
1021 .ok_or_else(|| {
1022 XlogError::Kernel(format!("{} kernel not found", width.probe_kernel()))
1023 })?;
1024 let probe_grid = n_children.div_ceil(BLOCK_SIZE);
1025 for (probe_idx, pp) in probe_plans.iter().enumerate() {
1026 let p = pp.input_idx;
1027 let has_range = u32::from(pp.live);
1028 let p_const_lo: u32 = 0;
1029 let p_const_hi: u32 = n_rows[p];
1030 let keep_u32 = u32::from(pp.keep);
1031 let combine: u32 = u32::from(probe_idx > 0);
1032 unsafe {
1038 let in_lo_param = if pp.live {
1039 find_col(&child_cols, ColTag::RangeLo(p), ctx)?.as_kernel_param()
1040 } else {
1041 null_ptr.as_kernel_param()
1042 };
1043 let in_hi_param = if pp.live {
1044 find_col(&child_cols, ColTag::RangeHi(p), ctx)?.as_kernel_param()
1045 } else {
1046 null_ptr.as_kernel_param()
1047 };
1048 let out_lo_param = match pp.out_lo.as_ref() {
1049 Some(lo) => lo.as_kernel_param(),
1050 None => null_ptr.as_kernel_param(),
1051 };
1052 let out_hi_param = match pp.out_hi.as_ref() {
1053 Some(hi) => hi.as_kernel_param(),
1054 None => null_ptr.as_kernel_param(),
1055 };
1056 let mask_ref = mask.as_ref().expect("mask exists when probes exist");
1057 let mut params: Vec<*mut c_void> = vec![
1058 (&pp.data_tbl).as_kernel_param(),
1059 pp.n_cols.as_kernel_param(),
1060 (&pp.key_tbl).as_kernel_param(),
1061 in_lo_param,
1062 in_hi_param,
1063 has_range.as_kernel_param(),
1064 p_const_lo.as_kernel_param(),
1065 p_const_hi.as_kernel_param(),
1066 n_children.as_kernel_param(),
1067 keep_u32.as_kernel_param(),
1068 out_lo_param,
1069 out_hi_param,
1070 mask_ref.as_kernel_param(),
1071 combine.as_kernel_param(),
1072 ];
1073 probe_kernel
1074 .clone()
1075 .launch_on_stream(
1076 &cu_stream,
1077 LaunchConfig {
1078 grid_dim: (probe_grid, 1, 1),
1079 block_dim: (BLOCK_SIZE, 1, 1),
1080 shared_mem_bytes: 0,
1081 },
1082 &mut params,
1083 )
1084 .map_err(|e| {
1085 XlogError::Kernel(format!(
1086 "{} launch failed: {e}",
1087 width.probe_kernel()
1088 ))
1089 })?;
1090 }
1091 }
1092 rec.commit(runtime)
1093 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit commit failed: {e}")))?;
1094 }
1095
1096 consumed[a] += c;
1098 for probe in &node.probes {
1099 consumed[probe.input_idx] += probe.var_positions.len();
1100 }
1101 for pr in &fused {
1104 let p = pr.input_idx;
1105 child_cols.retain(
1106 |(t, _)| !matches!(t, ColTag::RangeLo(x) | ColTag::RangeHi(x) if *x == p),
1107 );
1108 }
1109 for pp in &mut probe_plans {
1112 let p = pp.input_idx;
1113 child_cols.retain(
1114 |(t, _)| !matches!(t, ColTag::RangeLo(x) | ColTag::RangeHi(x) if *x == p),
1115 );
1116 if pp.keep {
1117 child_cols.push((
1118 ColTag::RangeLo(p),
1119 pp.out_lo.take().expect("keep implies out_lo"),
1120 ));
1121 child_cols.push((
1122 ColTag::RangeHi(p),
1123 pp.out_hi.take().expect("keep implies out_hi"),
1124 ));
1125 }
1126 }
1127
1128 if let Some(mask) = mask {
1130 let tags: Vec<ColTag> = child_cols.iter().map(|(t, _)| *t).collect();
1131 let schema = Schema::new(
1135 tags.iter()
1136 .enumerate()
1137 .map(|(i, t)| (format!("f{i}"), tag_type(*t, width)))
1138 .collect(),
1139 );
1140 let d_nr = self.memory().alloc::<u32>(1)?;
1141 self.htod_launch_metadata_async_copy_one(
1142 &n_children,
1143 &d_nr,
1144 &cu_stream,
1145 &format!("{ctx}: frontier num_rows"),
1146 )?;
1147 let columns: Vec<CudaColumn> =
1148 child_cols.drain(..).map(|(_, s)| s.into()).collect();
1149 let staging = CudaBuffer::from_columns_with_host_count(
1150 columns,
1151 u64::from(n_children),
1152 d_nr,
1153 schema,
1154 n_children,
1155 );
1156 let compacted = self.compact_buffer_by_device_mask_counted_recorded(
1157 &staging,
1158 &mask,
1159 launch_stream,
1160 )?;
1161 let new_count = compacted.cached_row_count().ok_or_else(|| {
1162 XlogError::Kernel(format!("{ctx}: compaction lost its row count"))
1163 })?;
1164 if new_count == 0 {
1165 return self.create_empty_buffer(out_schema);
1166 }
1167 let mut new_frontier: Vec<FrontierCol> = Vec::with_capacity(tags.len());
1168 for (tag, col) in tags.into_iter().zip(compacted.columns.into_iter()) {
1169 let CudaColumn::Owned(slice) = col else {
1170 return Err(XlogError::Kernel(format!(
1171 "{ctx}: compaction produced a non-owned column"
1172 )));
1173 };
1174 new_frontier.push((tag, slice));
1175 }
1176 frontier = new_frontier;
1177 count = new_count;
1178 } else {
1179 frontier = child_cols;
1180 count = n_children;
1181 }
1182 frontier_cap = n_children;
1183 }
1184
1185 if mode == FjMode::CountByRoot {
1191 let group_var = plan.output_vars[0];
1192 let mut lo_ptrs: Vec<u64> = Vec::new();
1193 let mut hi_ptrs: Vec<u64> = Vec::new();
1194 for (t, s) in &frontier {
1195 if let ColTag::RangeLo(x) = t {
1196 lo_ptrs.push(*s.device_ptr());
1197 hi_ptrs.push(*find_col(&frontier, ColTag::RangeHi(*x), ctx)?.device_ptr());
1198 }
1199 }
1200 let upload_tbl = |ptrs: &[u64]| -> Result<TrackedCudaSlice<u64>> {
1201 let mut tbl = self.memory().alloc::<u64>(ptrs.len().max(1))?;
1202 if !ptrs.is_empty() {
1203 self.htod_launch_metadata_sync_copy_into(ptrs, &mut tbl)
1204 .map_err(|e| {
1205 XlogError::Kernel(format!("{ctx}: htod range table failed: {e}"))
1206 })?;
1207 }
1208 Ok(tbl)
1209 };
1210 let d_lo_tbl = upload_tbl(&lo_ptrs)?;
1211 let d_hi_tbl = upload_tbl(&hi_ptrs)?;
1212 let mut mult = self
1216 .memory()
1217 .alloc::<u8>(frontier_cap as usize * std::mem::size_of::<u64>())?;
1218 {
1219 let mut rec = LaunchRecorder::new_strict(launch_stream);
1220 for (t, s) in &frontier {
1221 if matches!(t, ColTag::RangeLo(_) | ColTag::RangeHi(_)) {
1222 rec.read(s);
1223 }
1224 }
1225 rec.read(&d_lo_tbl);
1226 rec.read(&d_hi_tbl);
1227 rec.write(&mult);
1228 rec.preflight(runtime).map_err(|e| {
1229 XlogError::Kernel(format!("{ctx}: multiplicity preflight failed: {e}"))
1230 })?;
1231 let kernel = self
1232 .device()
1233 .inner()
1234 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_COUNT_MULTIPLICITY)
1235 .ok_or_else(|| {
1236 XlogError::Kernel("fj_count_multiplicity kernel not found".to_string())
1237 })?;
1238 let grid = count.div_ceil(BLOCK_SIZE);
1239 let n_ranges = lo_ptrs.len() as u32;
1240 unsafe {
1244 kernel
1245 .clone()
1246 .launch_on_stream(
1247 &cu_stream,
1248 LaunchConfig {
1249 grid_dim: (grid, 1, 1),
1250 block_dim: (BLOCK_SIZE, 1, 1),
1251 shared_mem_bytes: 0,
1252 },
1253 (&d_lo_tbl, &d_hi_tbl, n_ranges, count, &mut mult),
1254 )
1255 .map_err(|e| {
1256 XlogError::Kernel(format!("fj_count_multiplicity launch failed: {e}"))
1257 })?;
1258 }
1259 rec.commit(runtime).map_err(|e| {
1260 XlogError::Kernel(format!("{ctx}: multiplicity commit failed: {e}"))
1261 })?;
1262 }
1263 let key_idx = frontier
1264 .iter()
1265 .position(|(t, _)| *t == ColTag::Var(group_var))
1266 .ok_or_else(|| {
1267 XlogError::Kernel(format!("{ctx}: group variable {group_var} missing"))
1268 })?;
1269 let (_, key_col) = frontier.swap_remove(key_idx);
1270 let d_nr = self.memory().alloc::<u32>(1)?;
1271 self.htod_launch_metadata_async_copy_one(
1272 &count,
1273 &d_nr,
1274 &cu_stream,
1275 &format!("{ctx}: staging num_rows"),
1276 )?;
1277 let staging_schema = Schema::new(vec![
1278 (format!("v{group_var}"), width.var_type()),
1279 ("count".to_string(), ScalarType::U64),
1280 ]);
1281 let staging = CudaBuffer::from_columns_with_host_count(
1282 vec![key_col.into(), mult.into()],
1283 u64::from(frontier_cap),
1284 d_nr,
1285 staging_schema,
1286 count,
1287 );
1288 return self.groupby_multi_agg_recorded(
1289 &staging,
1290 &[0],
1291 &[(1, xlog_core::AggOp::Sum)],
1292 launch_stream,
1293 );
1294 }
1295
1296 let perm: Vec<usize> = plan
1301 .output_vars
1302 .iter()
1303 .map(|&v| {
1304 frontier
1305 .iter()
1306 .position(|(t, _)| *t == ColTag::Var(v))
1307 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: output variable {v} missing")))
1308 })
1309 .collect::<Result<_>>()?;
1310 let schema = Schema::new(
1311 frontier
1312 .iter()
1313 .enumerate()
1314 .map(|(i, (t, _))| (format!("f{i}"), tag_type(*t, width)))
1315 .collect(),
1316 );
1317 let d_nr = self.memory().alloc::<u32>(1)?;
1318 self.htod_launch_metadata_async_copy_one(
1319 &count,
1320 &d_nr,
1321 &cu_stream,
1322 &format!("{ctx}: result num_rows"),
1323 )?;
1324 let columns: Vec<CudaColumn> = frontier.into_iter().map(|(_, s)| s.into()).collect();
1325 let src = CudaBuffer::from_columns_with_host_count(
1329 columns,
1330 u64::from(frontier_cap),
1331 d_nr,
1332 schema,
1333 count,
1334 );
1335 let projected =
1336 self.wcoj_project_output_columns_recorded(&src, &perm, out_schema, launch_stream)?;
1337 let distinct_outputs: std::collections::BTreeSet<usize> =
1341 plan.output_vars.iter().copied().collect();
1342 if distinct_outputs.len() < bind_order.len() {
1343 return self.dedup_full_row_recorded(&projected, launch_stream);
1344 }
1345 Ok(projected)
1346 }
1347}