1use std::collections::HashMap;
7use std::marker::PhantomData;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::{Arc, Mutex, OnceLock};
11
12use std::ffi::c_void;
13use xlog_core::{Result, Schema, XlogError};
14
15use crate::{
16 cuda_compat::{
17 AsKernelParam, DeviceParamStorage, DevicePtr, DeviceRepr, DeviceSlice,
18 IntoKernelParamStorage, LaunchAsync, LaunchConfig,
19 },
20 cuda_graph::{CapturedCudaGraph, CsmCudaGraphKey, CudaGraphNode},
21 memory::{validate_logical_row_count, CudaColumn, TrackedCudaSlice},
22 CudaBuffer, CudaDevice, CudaStream, CudaViewMut, GpuMemoryManager,
23};
24
25mod arithmetic;
26mod filter;
27mod fj;
28mod fj_delta;
29mod fj_delta_sparse;
30mod groupby;
31mod ilp;
32mod ilp_exact;
33mod io;
34mod kernel_loading;
35pub mod kernel_paths;
36mod launch_safe;
37mod probabilistic;
38mod relational;
39mod transfer;
40mod wcoj;
41mod wcoj_metadata;
42mod wcoj_project;
43
44pub use fj::{FjNode, FjPlan, FjSubAtom};
45pub use fj_delta::{FjDeltaCols, FJ_DELTA_MAX_DOMAIN};
46
/// Summary of a kernel-module loading pass.
///
/// NOTE(review): presumably populated during warmup when `XLOG_WARMUP_PROFILE=1`
/// (see `warmup_profiling_enabled`) — confirm against the loader that fills it.
#[derive(Debug, Clone, Default)]
pub struct PtxLoadProfile {
    /// Total time spent loading, in seconds (presumably wall-clock — confirm).
    pub total_sec: f64,
    /// Per-module load time in seconds, keyed by module name.
    pub per_module_sec: Vec<(String, f64)>,
    /// Number of modules that were loaded from a prebuilt cubin.
    pub cubin_loaded: u32,
    /// Number of modules that fell back to PTX.
    pub ptx_fallback: u32,
}
55
/// Reports whether warmup profiling was requested.
///
/// Profiling is on only when the `XLOG_WARMUP_PROFILE` environment variable is set
/// to exactly `"1"`. Any other value, an unset variable, or a non-UTF-8 value
/// disables it.
fn warmup_profiling_enabled() -> bool {
    matches!(std::env::var("XLOG_WARMUP_PROFILE").as_deref(), Ok("1"))
}
61
62fn detect_compute_capability(device: &Arc<CudaDevice>) -> Result<u32> {
64 let major = device
65 .inner()
66 .attribute(
67 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
68 )
69 .map_err(|e| XlogError::Kernel(format!("Failed to query SM major: {}", e)))?;
70 let minor = device
71 .inner()
72 .attribute(
73 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
74 )
75 .map_err(|e| XlogError::Kernel(format!("Failed to query SM minor: {}", e)))?;
76 Ok((major as u32) * 10 + (minor as u32))
77}
78
79#[cfg(test)]
80fn resolve_module_path(name: &str, cc: u32) -> Option<(std::path::PathBuf, bool)> {
81 kernel_paths::KernelArtifactLocator::from_env().resolve_module_path(name, cc)
82}
83
/// A candidate source from which a kernel module can be loaded.
///
/// `resolve_module_sources_with_locator` returns these in order: staged files
/// first, then the embedded portable PTX (when one exists).
#[derive(Debug)]
pub(crate) enum KernelModuleSource {
    /// A staged artifact on disk; `is_cubin` is the locator's flag for a
    /// prebuilt cubin as opposed to a PTX file.
    File { path: PathBuf, is_cubin: bool },
    /// Portable PTX text compiled into the binary.
    EmbeddedPortablePtx { ptx: &'static str },
}
89
90pub(crate) fn resolve_module_sources_with_locator(
91 name: &str,
92 cc: u32,
93 locator: &kernel_paths::KernelArtifactLocator,
94) -> Vec<KernelModuleSource> {
95 let mut sources: Vec<KernelModuleSource> = locator
96 .resolve_module_paths(name, cc)
97 .into_iter()
98 .filter(|(path, _)| !staged_artifact_is_stale(path))
103 .map(|(path, is_cubin)| KernelModuleSource::File { path, is_cubin })
104 .collect();
105
106 if let Some(ptx) = crate::embedded_kernel_data::portable_ptx(name) {
113 sources.push(KernelModuleSource::EmbeddedPortablePtx { ptx });
114 }
115 sources
116}
117
118fn staged_artifact_is_stale(path: &std::path::Path) -> bool {
127 let Some(file_name) = path.file_name().and_then(|n| n.to_str()) else {
128 return false;
129 };
130 let Some(expected) = crate::embedded_kernel_data::canonical_artifact_hash(file_name) else {
131 return false;
132 };
133 match std::fs::read(path) {
134 Ok(bytes) => fnv1a_64(&bytes) != expected,
135 Err(_) => false,
136 }
137}
138
/// 64-bit FNV-1a hash of `bytes`.
///
/// Each byte is XOR-ed into the state, and the state is then multiplied by the
/// FNV prime with wrapping arithmetic.
fn fnv1a_64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    bytes
        .iter()
        .fold(OFFSET_BASIS, |state, &b| (state ^ u64::from(b)).wrapping_mul(PRIME))
}
148
#[cfg(test)]
mod kernel_source_resolution_tests {
    use super::{
        kernel_paths::KernelArtifactLocator, resolve_module_sources_with_locator,
        KernelModuleSource,
    };
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Uniquely named scratch directory under the system temp dir.
    ///
    /// It is removed on drop, so a failing assertion no longer leaks it (the
    /// previous explicit `remove_dir_all` at the end of each test never ran on
    /// panic).
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates `<temp>/<prefix>-<pid>-<nanos>` on disk.
        fn new(prefix: &str) -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("system clock before UNIX_EPOCH")
                .as_nanos();
            let path =
                std::env::temp_dir().join(format!("{prefix}-{}-{nanos}", std::process::id()));
            fs::create_dir_all(&path).expect("create temp dir");
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup. Panicking here while a failed test is already
            // unwinding would abort the whole test binary.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// A staged cubin must not displace the sidecar portable PTX. Both are
    /// returned, cubin first.
    #[test]
    fn keeps_portable_ptx_fallback_when_cubin_exists() {
        let root = ScratchDir::new("xlog-kernel-fallback");
        let kernels = root.path().join("kernels");
        fs::create_dir_all(&kernels).expect("create kernels dir");
        let expected_cubin = kernels.join("fakekernel.sm_86.cubin");
        let expected_ptx = kernels.join("fakekernel.portable.ptx");
        fs::write(&expected_cubin, b"cubin").expect("write cubin");
        fs::write(&expected_ptx, b"ptx").expect("write ptx");

        let locator = KernelArtifactLocator::new(None, Some(kernels.clone()), None);
        let sources = resolve_module_sources_with_locator("fakekernel", 86, &locator);

        assert_eq!(sources.len(), 2);
        assert!(matches!(
            &sources[0],
            KernelModuleSource::File {
                path,
                is_cubin: true
            } if path == &expected_cubin
        ));
        assert!(matches!(
            &sources[1],
            KernelModuleSource::File {
                path,
                is_cubin: false
            } if path == &expected_ptx
        ));
    }

    #[test]
    fn fnv1a_64_matches_known_vectors() {
        assert_eq!(super::fnv1a_64(b""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(super::fnv1a_64(b"a"), 0xaf63_dc4c_8601_ec8c);
        assert_eq!(super::fnv1a_64(b"foobar"), 0x8594_4171_f739_67e8);
    }

    /// Files with no registered canonical hash, and files that do not exist, are
    /// never considered stale.
    #[test]
    fn staged_artifact_not_stale_without_canonical_hash() {
        let root = ScratchDir::new("xlog-kernel-stale");
        let unknown = root.path().join("definitely_not_a_real_kernel.sm_86.cubin");
        fs::write(&unknown, b"bytes").expect("write");
        assert!(!super::staged_artifact_is_stale(&unknown));
        assert!(!super::staged_artifact_is_stale(
            &root.path().join("missing.portable.ptx")
        ));
    }
}
234
235pub(crate) fn load_module_sources(name: &str, cc: u32) -> Result<Vec<KernelModuleSource>> {
240 debug_assert!(
241 crate::kernel_manifest_data::KERNEL_CU_NAMES.contains(&name),
242 "kernel module '{name}' is not in KERNEL_CU_NAMES manifest — update kernel_manifest_data.rs"
243 );
244 let locator = kernel_paths::KernelArtifactLocator::from_env();
245 let sources = resolve_module_sources_with_locator(name, cc, &locator);
246 if sources.is_empty() {
247 Err(XlogError::Kernel(format!(
248 "{name}: no cubin, sidecar portable PTX, or embedded portable PTX found"
249 )))
250 } else {
251 Ok(sources)
252 }
253}
254
/// Non-owning, typed view over a raw device pointer of `len` elements of `T`.
///
/// It is associated with a stream, and the lifetime `'a` is tied to the optional
/// source block it was carved from.
#[derive(Clone)]
pub(crate) struct RawCudaView<'a, T> {
    /// Base device address of the view.
    ptr: cudarc::driver::sys::CUdeviceptr,
    /// Number of `T` elements in the view.
    len: usize,
    /// Stream reported through `DeviceSlice::stream`.
    stream: Arc<CudaStream>,
    // Exposed through `runtime_block()`; it may be unread in some builds,
    // hence the allow.
    #[allow(dead_code)]
    source_block: Option<&'a crate::device_runtime::DeviceBlock>,
    /// Carries `'a` and `T` without storing a `T`.
    _marker: PhantomData<&'a [T]>,
}
276
/// Scratch storage for a multi-block `u32` scan, held as one device buffer per
/// level.
///
/// NOTE(review): presumably one level per recursion of the block-sum scan —
/// confirm against the scan implementation.
pub(crate) struct MultiblockScanScratchU32 {
    levels: Vec<TrackedCudaSlice<u32>>,
}
286
287impl MultiblockScanScratchU32 {
288 pub(crate) fn levels(&self) -> &[TrackedCudaSlice<u32>] {
289 &self.levels
290 }
291}
292
/// Handles to the kernel nodes inside a captured CSM CUDA graph.
///
/// NOTE(review): the node roles are inferred from field names only (count pass,
/// total computation, materialize pass) — confirm against the capture code.
pub(crate) struct CsmCudaGraphNodes {
    pub(crate) count: CudaGraphNode,
    pub(crate) total: CudaGraphNode,
    pub(crate) materialize: CudaGraphNode,
    /// Number of nodes in the captured graph.
    pub(crate) node_count: usize,
}
299
/// A cached, captured CSM CUDA graph together with the device buffers its
/// nodes were captured against.
///
/// Field order is also drop order: `graph` is declared, and therefore dropped,
/// before the buffers. NOTE(review): presumably the captured graph references
/// these buffers by address, so they must outlive it — keep this order.
pub(crate) struct CsmCudaGraphEntry {
    pub(crate) graph: CapturedCudaGraph,
    pub(crate) nodes: CsmCudaGraphNodes,
    pub(crate) per_probe_count: TrackedCudaSlice<u32>,
    pub(crate) per_probe_offsets: TrackedCudaSlice<u32>,
    pub(crate) d_logical_count: TrackedCudaSlice<u32>,
    pub(crate) d_overflow: TrackedCudaSlice<u8>,
    pub(crate) d_output_left: TrackedCudaSlice<u32>,
    pub(crate) d_output_right: TrackedCudaSlice<u32>,
    pub(crate) scan_scratch: MultiblockScanScratchU32,
    /// NOTE(review): presumably the probe row count the buffers were sized for.
    pub(crate) probe_capacity: u32,
    /// NOTE(review): presumably the output row count the buffers were sized for.
    pub(crate) output_capacity: u32,
}
313
314impl<'a, T> DeviceSlice<T> for RawCudaView<'a, T> {
315 fn len(&self) -> usize {
316 self.len
317 }
318
319 fn stream(&self) -> &Arc<CudaStream> {
320 &self.stream
321 }
322}
323
324impl<'a, T> DevicePtr<T> for RawCudaView<'a, T> {
325 fn device_ptr<'b>(
326 &'b self,
327 _stream: &'b CudaStream,
328 ) -> (
329 cudarc::driver::sys::CUdeviceptr,
330 cudarc::driver::SyncOnDrop<'b>,
331 ) {
332 (self.ptr, cudarc::driver::SyncOnDrop::Sync(None))
333 }
334}
335
336impl<'a, T> RawCudaView<'a, T> {
337 pub fn device_ptr(&self) -> &cudarc::driver::sys::CUdeviceptr {
338 &self.ptr
339 }
340
341 #[allow(dead_code)]
350 pub fn runtime_block(&self) -> Option<&'a crate::device_runtime::DeviceBlock> {
351 self.source_block
352 }
353}
354
355impl<'a, T: DeviceRepr> AsKernelParam for &RawCudaView<'a, T> {
356 fn as_kernel_param(&self) -> *mut c_void {
357 ((*self).device_ptr() as *const cudarc::driver::sys::CUdeviceptr)
358 .cast_mut()
359 .cast()
360 }
361}
362
363impl<'a, T: DeviceRepr> IntoKernelParamStorage for &'a RawCudaView<'a, T> {
364 type Storage = DeviceParamStorage<'a>;
365
366 fn into_kernel_param_storage(self) -> Self::Storage {
367 DeviceParamStorage::unsynced(self.ptr)
368 }
369}
370
/// Reusable device scratch buffers for radix sort.
///
/// `RadixSortScratch::new` sizes `keys_b`, `values_b` and `ranks` to `len`
/// elements. It sizes `hist` to 16 entries per sort block and `prefix` to 16
/// entries.
pub struct RadixSortScratch {
    /// Alternate key buffer (presumably the ping-pong target — confirm).
    keys_b: TrackedCudaSlice<u32>,
    /// Alternate value buffer (presumably the ping-pong target — confirm).
    values_b: TrackedCudaSlice<u32>,
    /// 16 entries per sort block.
    hist: TrackedCudaSlice<u32>,
    /// 16 entries.
    prefix: TrackedCudaSlice<u32>,
    ranks: TrackedCudaSlice<u32>,
    /// Element capacity of the per-element buffers (always at least 1).
    len: u32,
}
380
381impl RadixSortScratch {
382 pub fn new(provider: &CudaKernelProvider, n: u32) -> Result<Self> {
383 let memory = provider.memory();
384 let len = n.max(1);
385 let keys_b = memory.alloc::<u32>(len as usize)?;
386 let values_b = memory.alloc::<u32>(len as usize)?;
387 let ranks = memory.alloc::<u32>(len as usize)?;
388 let block_size = CudaKernelProvider::SORT_BLOCK_SIZE;
389 let grid_size = len.div_ceil(block_size).max(1);
390 let hist = memory.alloc::<u32>((grid_size as usize) * 16)?;
391 let prefix = memory.alloc::<u32>(16)?;
392 Ok(Self {
393 keys_b,
394 values_b,
395 hist,
396 prefix,
397 ranks,
398 len,
399 })
400 }
401
402 pub fn ensure_capacity(&mut self, provider: &CudaKernelProvider, n: u32) -> Result<()> {
403 if n <= self.len {
404 return Ok(());
405 }
406 *self = Self::new(provider, n)?;
407 Ok(())
408 }
409}
410
411pub const JOIN_MODULE: &str = "xlog_join";
413pub const DEDUP_MODULE: &str = "xlog_dedup";
414pub const GROUPBY_MODULE: &str = "xlog_groupby";
415pub const SCAN_MODULE: &str = "xlog_scan";
416pub const SORT_MODULE: &str = "xlog_sort";
417pub const FILTER_MODULE: &str = "xlog_filter";
418pub const SET_OPS_MODULE: &str = "xlog_set_ops";
419pub const PACK_MODULE: &str = "xlog_pack";
420pub const CIRCUIT_MODULE: &str = "xlog_circuit";
421pub const MC_SAMPLE_MODULE: &str = "xlog_mc_sample";
422pub const MC_EVAL_MODULE: &str = "xlog_mc_eval";
423pub const MC_RESIDENT_MODULE: &str = "xlog_mc_resident";
424pub const ARITH_MODULE: &str = "xlog_arith";
425pub const SAT_MODULE: &str = "xlog_sat";
426pub const D4_MODULE: &str = "xlog_d4";
427pub const NEURAL_MODULE: &str = "xlog_neural";
428pub const PIR_MODULE: &str = "xlog_pir";
429pub const CNF_MODULE: &str = "xlog_cnf";
430pub const CACHE_MODULE: &str = "xlog_cache";
431pub const WEIGHTS_MODULE: &str = "xlog_weights";
432pub const ILP_MODULE: &str = "xlog_ilp";
433pub const ILP_CREDIT_MODULE: &str = "xlog_ilp_credit";
434pub const ILP_EXACT_MODULE: &str = "xlog_ilp_exact";
435pub const EPISTEMIC_MODULE: &str = "xlog_epistemic";
436pub const WCOJ_MODULE: &str = "xlog_wcoj";
437
438const _: () = assert!(crate::kernel_manifest_data::KERNEL_CU_NAMES.len() == 25);
440
/// Kernel entry-point (symbol) names for the WCOJ kernels (triangle, 4-cycle and
/// 5–8 clique patterns, metadata builders, layout checks) and the `FJ_*`
/// expand/probe/delta kernels.
///
/// Each string must match a compiled kernel symbol exactly.
/// NOTE(review): presumably resolved against `WCOJ_MODULE` — confirm for the
/// `FJ_*` names. The `_u32`/`_u64` suffixes presumably select the key width.
pub mod wcoj_kernels {
    pub const WCOJ_BUILD_METADATA_MARK_BOUNDARIES_U32: &str =
        "wcoj_build_metadata_mark_boundaries_u32";
    pub const WCOJ_BUILD_METADATA_MARK_BOUNDARIES_U64: &str =
        "wcoj_build_metadata_mark_boundaries_u64";
    pub const WCOJ_BUILD_METADATA_SCATTER_U32: &str = "wcoj_build_metadata_scatter_u32";
    pub const WCOJ_BUILD_METADATA_SCATTER_U64: &str = "wcoj_build_metadata_scatter_u64";
    // Triangle pattern.
    pub const WCOJ_TRIANGLE_BUILD_HG_WORK_PLAN_U32: &str = "wcoj_triangle_build_hg_work_plan_u32";
    pub const WCOJ_TRIANGLE_COUNT_HG_U32: &str = "wcoj_triangle_count_hg_u32";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_COUNT_HG_U32: &str =
        "wcoj_triangle_groupby_root_count_hg_u32";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_SUM_HG_U32: &str = "wcoj_triangle_groupby_root_sum_hg_u32";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MIN_HG_U32: &str = "wcoj_triangle_groupby_root_min_hg_u32";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MAX_HG_U32: &str = "wcoj_triangle_groupby_root_max_hg_u32";
    pub const WCOJ_TRIANGLE_MATERIALIZE_HG_U32: &str = "wcoj_triangle_materialize_hg_u32";
    pub const WCOJ_TRIANGLE_BUILD_HG_WORK_PLAN_U64: &str = "wcoj_triangle_build_hg_work_plan_u64";
    pub const WCOJ_TRIANGLE_COUNT_HG_U64: &str = "wcoj_triangle_count_hg_u64";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_COUNT_HG_U64: &str =
        "wcoj_triangle_groupby_root_count_hg_u64";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_SUM_HG_U64: &str = "wcoj_triangle_groupby_root_sum_hg_u64";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MIN_HG_U64: &str = "wcoj_triangle_groupby_root_min_hg_u64";
    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MAX_HG_U64: &str = "wcoj_triangle_groupby_root_max_hg_u64";
    // Root-segment group-by reductions.
    pub const WCOJ_GROUPBY_ROOT_SEGMENT_SUM_COUNTS_U32: &str =
        "wcoj_groupby_root_segment_sum_counts_u32";
    pub const WCOJ_GROUPBY_ROOT_SEGMENT_SUM_VALUES_U64: &str =
        "wcoj_groupby_root_segment_sum_values_u64";
    pub const WCOJ_GROUPBY_ROOT_SEGMENT_MIN_VALUES_U64: &str =
        "wcoj_groupby_root_segment_min_values_u64";
    pub const WCOJ_GROUPBY_ROOT_SEGMENT_MAX_VALUES_U64: &str =
        "wcoj_groupby_root_segment_max_values_u64";
    pub const WCOJ_TRIANGLE_MATERIALIZE_HG_U64: &str = "wcoj_triangle_materialize_hg_u64";
    pub const WCOJ_TRIANGLE_COUNT_HG_CACHED_U32: &str = "wcoj_triangle_count_hg_cached_u32";
    pub const WCOJ_TRIANGLE_MATERIALIZE_HG_CACHED_U32: &str =
        "wcoj_triangle_materialize_hg_cached_u32";
    pub const WCOJ_SCAN_HG_BLOCK_COUNTS_U32: &str = "wcoj_scan_hg_block_counts_u32";
    pub const WCOJ_COMPUTE_TOTAL: &str = "wcoj_compute_total";
    pub const WCOJ_LAYOUT_CHECK_SORTED_UNIQUE_U32: &str = "wcoj_layout_check_sorted_unique_u32";
    pub const WCOJ_LAYOUT_CHECK_SORTED_UNIQUE_U64: &str = "wcoj_layout_check_sorted_unique_u64";
    // 4-cycle pattern. Only COUNT has a u64 group-by variant here.
    pub const WCOJ_4CYCLE_BUILD_E2_WORK_PREFIX_U32: &str = "wcoj_4cycle_build_e2_work_prefix_u32";
    pub const WCOJ_4CYCLE_BUILD_HG_WORK_PLAN_U32: &str = "wcoj_4cycle_build_hg_work_plan_u32";
    pub const WCOJ_4CYCLE_COUNT_HG_U32: &str = "wcoj_4cycle_count_hg_u32";
    pub const WCOJ_4CYCLE_GROUPBY_ROOT_COUNT_HG_U32: &str = "wcoj_4cycle_groupby_root_count_hg_u32";
    pub const WCOJ_4CYCLE_GROUPBY_ROOT_SUM_HG_U32: &str = "wcoj_4cycle_groupby_root_sum_hg_u32";
    pub const WCOJ_4CYCLE_GROUPBY_ROOT_MIN_HG_U32: &str = "wcoj_4cycle_groupby_root_min_hg_u32";
    pub const WCOJ_4CYCLE_GROUPBY_ROOT_MAX_HG_U32: &str = "wcoj_4cycle_groupby_root_max_hg_u32";
    pub const WCOJ_4CYCLE_MATERIALIZE_HG_U32: &str = "wcoj_4cycle_materialize_hg_u32";
    pub const WCOJ_4CYCLE_BUILD_E2_WORK_PREFIX_U64: &str = "wcoj_4cycle_build_e2_work_prefix_u64";
    pub const WCOJ_4CYCLE_BUILD_HG_WORK_PLAN_U64: &str = "wcoj_4cycle_build_hg_work_plan_u64";
    pub const WCOJ_4CYCLE_COUNT_HG_U64: &str = "wcoj_4cycle_count_hg_u64";
    pub const WCOJ_4CYCLE_GROUPBY_ROOT_COUNT_HG_U64: &str = "wcoj_4cycle_groupby_root_count_hg_u64";
    pub const WCOJ_4CYCLE_MATERIALIZE_HG_U64: &str = "wcoj_4cycle_materialize_hg_u64";
    // k-clique patterns (k = 5..=8).
    pub const WCOJ_CLIQUE5_COUNT_HG_U32: &str = "wcoj_clique5_count_hg_u32";
    pub const WCOJ_CLIQUE5_MATERIALIZE_HG_U32: &str = "wcoj_clique5_materialize_hg_u32";
    pub const WCOJ_CLIQUE5_COUNT_HG_U64: &str = "wcoj_clique5_count_hg_u64";
    pub const WCOJ_CLIQUE5_MATERIALIZE_HG_U64: &str = "wcoj_clique5_materialize_hg_u64";
    pub const WCOJ_CLIQUE6_COUNT_HG_U32: &str = "wcoj_clique6_count_hg_u32";
    pub const WCOJ_CLIQUE6_MATERIALIZE_HG_U32: &str = "wcoj_clique6_materialize_hg_u32";
    pub const WCOJ_CLIQUE6_COUNT_HG_U64: &str = "wcoj_clique6_count_hg_u64";
    pub const WCOJ_CLIQUE6_MATERIALIZE_HG_U64: &str = "wcoj_clique6_materialize_hg_u64";
    pub const WCOJ_CLIQUE7_COUNT_HG_U32: &str = "wcoj_clique7_count_hg_u32";
    pub const WCOJ_CLIQUE7_MATERIALIZE_HG_U32: &str = "wcoj_clique7_materialize_hg_u32";
    pub const WCOJ_CLIQUE7_COUNT_HG_U64: &str = "wcoj_clique7_count_hg_u64";
    pub const WCOJ_CLIQUE7_MATERIALIZE_HG_U64: &str = "wcoj_clique7_materialize_hg_u64";
    pub const WCOJ_CLIQUE8_COUNT_HG_U32: &str = "wcoj_clique8_count_hg_u32";
    pub const WCOJ_CLIQUE8_MATERIALIZE_HG_U32: &str = "wcoj_clique8_materialize_hg_u32";
    pub const WCOJ_CLIQUE8_COUNT_HG_U64: &str = "wcoj_clique8_count_hg_u64";
    pub const WCOJ_CLIQUE8_MATERIALIZE_HG_U64: &str = "wcoj_clique8_materialize_hg_u64";
    pub const WCOJ_CLIQUE5_GROUPBY_ROOT_COUNT_HG_U32: &str =
        "wcoj_clique5_groupby_root_count_hg_u32";
    pub const WCOJ_CLIQUE6_GROUPBY_ROOT_COUNT_HG_U32: &str =
        "wcoj_clique6_groupby_root_count_hg_u32";
    // FJ expand / probe-refine kernels.
    pub const FJ_EXPAND_WORK_PREFIX_U32: &str = "fj_expand_work_prefix_u32";
    pub const FJ_EXPAND_COUNT_U32: &str = "fj_expand_count_u32";
    pub const FJ_EXPAND_EMIT_U32: &str = "fj_expand_emit_u32";
    pub const FJ_PROBE_REFINE_U32: &str = "fj_probe_refine_u32";
    pub const FJ_EXPAND_COUNT_U64: &str = "fj_expand_count_u64";
    pub const FJ_EXPAND_EMIT_U64: &str = "fj_expand_emit_u64";
    pub const FJ_PROBE_REFINE_U64: &str = "fj_probe_refine_u64";
    pub const FJ_COUNT_MULTIPLICITY: &str = "fj_count_multiplicity";
    // FJ delta kernels (dense, then sparse variants).
    pub const FJ_DELTA_RANGE_U32: &str = "fj_delta_range_u32";
    pub const FJ_DELTA_MARK_U32: &str = "fj_delta_mark_u32";
    pub const FJ_DELTA_SUBTRACT_U32: &str = "fj_delta_subtract_u32";
    pub const FJ_DELTA_POPCOUNT: &str = "fj_delta_popcount";
    pub const FJ_DELTA_EMIT_U32: &str = "fj_delta_emit_u32";
    pub const FJ_DELTA_MAX_U32: &str = "fj_delta_max_u32";
    pub const FJ_DELTA_SPARSE_ESTIMATE: &str = "fj_delta_sparse_estimate";
    pub const FJ_DELTA_SPARSE_LOAD_R: &str = "fj_delta_sparse_load_r";
    pub const FJ_DELTA_SPARSE_INSERT_CANDIDATES: &str = "fj_delta_sparse_insert_candidates";
    pub const FJ_DELTA_SPARSE_MARK: &str = "fj_delta_sparse_mark";
    pub const FJ_DELTA_SPARSE_EMIT: &str = "fj_delta_sparse_emit";
}
538
/// Kernel entry-point names, presumably resolved against `MC_SAMPLE_MODULE`
/// (confirm).
pub mod mc_sample_kernels {
    pub const MC_SAMPLE_BERNOULLI: &str = "mc_sample_bernoulli";
}
543
/// Kernel entry-point names, presumably resolved against `MC_EVAL_MODULE`
/// (confirm).
pub mod mc_eval_kernels {
    pub const MC_EVAL_MASK_VAR: &str = "mc_eval_mask_var";
    // The symbol carries an extra `_choice` suffix not reflected in the const name.
    pub const MC_EVAL_MASK_AD: &str = "mc_eval_mask_ad_choice";
    pub const MC_EVAL_QUERY_EVIDENCE_TRUTH: &str = "mc_eval_query_evidence_truth";
    // The symbol has no `eval_` infix, unlike its siblings.
    pub const MC_EVAL_ACCUMULATE_COUNTS: &str = "mc_accumulate_counts";
}
551
/// Kernel entry-point names, presumably resolved against `MC_RESIDENT_MODULE`
/// (confirm).
pub mod mc_resident_kernels {
    pub const MC_RESIDENT_ENGINE: &str = "mc_resident_engine";
}
558
/// Kernel entry-point names for elementwise arithmetic, constant fill and
/// select, one symbol per element type.
///
/// NOTE(review): presumably resolved against `ARITH_MODULE` — confirm.
pub mod arith_kernels {
    pub const ARITH_BINARY_I64: &str = "arith_binary_i64";
    pub const ARITH_BINARY_I32: &str = "arith_binary_i32";
    pub const ARITH_BINARY_U64: &str = "arith_binary_u64";
    pub const ARITH_BINARY_U32: &str = "arith_binary_u32";
    pub const ARITH_BINARY_F64: &str = "arith_binary_f64";
    pub const ARITH_BINARY_F32: &str = "arith_binary_f32";
    pub const ARITH_ABS_I64: &str = "arith_abs_i64";
    pub const ARITH_ABS_I32: &str = "arith_abs_i32";
    pub const ARITH_ABS_F64: &str = "arith_abs_f64";
    pub const ARITH_ABS_F32: &str = "arith_abs_f32";
    pub const ARITH_POW_F64: &str = "arith_pow_f64";
    pub const ARITH_CAST: &str = "arith_cast";
    pub const ARITH_FILL_CONST_U32: &str = "arith_fill_const_u32";
    pub const ARITH_FILL_CONST_U64: &str = "arith_fill_const_u64";
    pub const ARITH_FILL_CONST_I64: &str = "arith_fill_const_i64";
    pub const ARITH_FILL_CONST_I32: &str = "arith_fill_const_i32";
    pub const ARITH_FILL_CONST_F64: &str = "arith_fill_const_f64";
    pub const ARITH_FILL_CONST_F32: &str = "arith_fill_const_f32";
    pub const ARITH_FILL_CONST_U8: &str = "arith_fill_const_u8";
    pub const ARITH_SELECT_I64: &str = "arith_select_i64";
    pub const ARITH_SELECT_I32: &str = "arith_select_i32";
    pub const ARITH_SELECT_U64: &str = "arith_select_u64";
    pub const ARITH_SELECT_U32: &str = "arith_select_u32";
    pub const ARITH_SELECT_F64: &str = "arith_select_f64";
    pub const ARITH_SELECT_F32: &str = "arith_select_f32";
}
588
/// Kernel entry-point names for the epistemic pipeline: candidate generation,
/// propagation and validation, then final materialization.
///
/// NOTE(review): presumably resolved against `EPISTEMIC_MODULE` — confirm.
pub mod epistemic_kernels {
    pub const EPISTEMIC_GENERATE_CANDIDATE_ASSUMPTIONS_U8: &str =
        "epistemic_generate_candidate_assumptions_u8";
    pub const EPISTEMIC_PROPAGATE_CANDIDATES_U8: &str = "epistemic_propagate_candidates_u8";
    pub const EPISTEMIC_VALIDATE_CANDIDATE_BITS_U8: &str = "epistemic_validate_candidate_bits_u8";
    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_U8: &str =
        "epistemic_populate_model_membership_u8";
    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_U8: &str =
        "epistemic_populate_model_membership_from_tuple_source_u8";
    // Arity-specialized variants, plus a generic arity-N fallback.
    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY1_U8: &str =
        "epistemic_populate_model_membership_from_tuple_source_arity1_u8";
    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY2_U8: &str =
        "epistemic_populate_model_membership_from_tuple_source_arity2_u8";
    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY3_U8: &str =
        "epistemic_populate_model_membership_from_tuple_source_arity3_u8";
    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY_N_U8: &str =
        "epistemic_populate_model_membership_from_tuple_source_arity_n_u8";
    pub const EPISTEMIC_VALIDATE_WORLD_VIEWS_U8: &str = "epistemic_validate_world_views_u8";
    pub const EPISTEMIC_VALIDATE_CONSTRAINTS_U8: &str = "epistemic_validate_constraints_u8";
    pub const EPISTEMIC_MATERIALIZE_ACCEPTED_CANDIDATES_U8: &str =
        "epistemic_materialize_accepted_candidates_u8";

    // Final-result materialization.
    pub const EPISTEMIC_MATERIALIZE_FINAL_RESULT_FLAGS_U8: &str =
        "epistemic_materialize_final_result_flags_u8";
    pub const EPISTEMIC_MATERIALIZE_FINAL_TUPLE_COLUMN_U8: &str =
        "epistemic_materialize_final_tuple_column_u8";
    pub const EPISTEMIC_BUILD_FINAL_TUPLE_ROW_MAP_U8: &str =
        "epistemic_build_final_tuple_row_map_u8";
    pub const EPISTEMIC_CLOSE_FINAL_TUPLE_REJECTIONS_U8: &str =
        "epistemic_close_final_tuple_rejections_u8";
}
637
/// Kernel entry-point names, presumably resolved against `NEURAL_MODULE`
/// (confirm).
pub mod neural_kernels {
    pub const NEURAL_FILL_AD_CHAIN_F32: &str = "neural_fill_ad_chain_f32";
    pub const NEURAL_SCATTER_AD_CHAIN_GRADS_F32: &str = "neural_scatter_ad_chain_grads_f32";
}
643
/// Kernel entry-point names, presumably resolved against `ILP_MODULE` (confirm).
pub mod ilp_kernels {
    pub const EXTRACT_NONZERO_INDICES: &str = "extract_nonzero_indices";
    // One variant per selected-id element type.
    pub const ILP_MARK_SELECTED_IDS_U32: &str = "ilp_mark_selected_ids_u32";
    pub const ILP_MARK_SELECTED_IDS_I32: &str = "ilp_mark_selected_ids_i32";
    pub const ILP_MARK_SELECTED_IDS_I64: &str = "ilp_mark_selected_ids_i64";
    pub const ILP_MARK_SELECTED_IDS_U64: &str = "ilp_mark_selected_ids_u64";
    pub const ILP_VALIDATE_SELECTED_IDS_U32: &str = "ilp_validate_selected_ids_u32";
    pub const ILP_VALIDATE_SELECTED_IDS_I32: &str = "ilp_validate_selected_ids_i32";
    pub const ILP_VALIDATE_SELECTED_IDS_I64: &str = "ilp_validate_selected_ids_i64";
    pub const ILP_VALIDATE_SELECTED_IDS_U64: &str = "ilp_validate_selected_ids_u64";
    pub const ILP_BROADCAST_CANDIDATE_FLAG: &str = "ilp_broadcast_candidate_flag";
    pub const ILP_COO_FILL_FROM_MASK: &str = "ilp_coo_fill_from_mask";
    pub const ILP_CSR_HISTOGRAM: &str = "ilp_csr_histogram";
    pub const ILP_REDUCE_SUM_F32: &str = "ilp_reduce_sum_f32";
    pub const ILP_REDUCE_SUM_F64: &str = "ilp_reduce_sum_f64";
}
661
/// Kernel entry-point names, presumably resolved against `ILP_CREDIT_MODULE`
/// (confirm).
pub mod ilp_credit_kernels {
    pub const ILP_COO_FILL: &str = "ilp_coo_fill";
    pub const ILP_CREDIT_FORWARD_F32: &str = "ilp_credit_forward_f32";
    pub const ILP_CREDIT_FORWARD_F64: &str = "ilp_credit_forward_f64";
    pub const ILP_CREDIT_BACKWARD_F32: &str = "ilp_credit_backward_f32";
    pub const ILP_CREDIT_BACKWARD_F64: &str = "ilp_credit_backward_f64";
}
670
/// Kernel entry-point names, presumably resolved against `ILP_EXACT_MODULE`
/// (confirm).
pub mod ilp_exact_kernels {
    pub const ILP_EXACT_SCORE: &str = "ilp_exact_score";
    pub const ILP_EXACT_SCORE_U32: &str = "ilp_exact_score_u32";
    pub const ILP_EXACT_SCORE_CHAIN_SMEM: &str = "ilp_exact_score_chain_smem";
    pub const ILP_EXACT_SCORE_CHAIN_SMEM_U32: &str = "ilp_exact_score_chain_smem_u32";
    pub const ILP_EXACT_SELECT_TOPK: &str = "ilp_exact_select_topk";
}
679
/// Kernel entry-point names, presumably resolved against `PIR_MODULE` (confirm).
pub mod pir_kernels {
    pub const PIR_PACK_KEYS: &str = "pir_pack_keys";
    pub const PIR_HASH_KEYS: &str = "pir_hash_keys";
    pub const PIR_MARK_UNIQUE: &str = "pir_mark_unique";
    pub const PIR_FIND_EXISTING: &str = "pir_find_existing";
    pub const PIR_MARK_NEW_GROUPS: &str = "pir_mark_new_groups";
    pub const PIR_BUILD_GROUP_IDS: &str = "pir_build_group_ids";
    pub const PIR_FILL_CHILD_PARENTS: &str = "pir_fill_child_parents";
    pub const PIR_MARK_UNIQUE_PAIRS: &str = "pir_mark_unique_pairs";
    pub const PIR_COMPACT_PAIRS: &str = "pir_compact_pairs";
    pub const PIR_COUNT_CHILDREN: &str = "pir_count_children";
    pub const PIR_WRITE_CHILD_OFFSETS: &str = "pir_write_child_offsets";
    pub const PIR_GATHER_CHILDREN: &str = "pir_gather_children";
    pub const PIR_BUILD_GRAPH_CHILD_COUNTS: &str = "pir_build_graph_child_counts";
    pub const PIR_SUM_COUNTS: &str = "pir_sum_counts";
    pub const PIR_EMIT_NODES_AND_IDS: &str = "pir_emit_nodes_and_ids";
    pub const PIR_UPDATE_COUNTS: &str = "pir_update_counts";
}
699
/// Kernel entry-point names, presumably resolved against `CNF_MODULE` (confirm).
pub mod cnf_kernels {
    pub const CNF_REACHABILITY_INIT: &str = "cnf_reachability_init";
    pub const CNF_REACHABILITY_BFS: &str = "cnf_reachability_bfs";
    pub const CNF_MARK_LEAF_CHOICE: &str = "cnf_mark_leaf_choice";
    pub const CNF_ASSIGN_LEAF_VAR: &str = "cnf_assign_leaf_var";
    pub const CNF_ASSIGN_CHOICE_VAR: &str = "cnf_assign_choice_var";
    pub const CNF_MARK_NODE_VARS: &str = "cnf_mark_node_vars";
    pub const CNF_COUNT_CLAUSES: &str = "cnf_count_clauses";
    pub const CNF_CAPTURE_LAST_COUNTS: &str = "cnf_capture_last_counts";
    pub const CNF_COMPUTE_LEAF_CHOICE_TOTALS: &str = "cnf_compute_leaf_choice_totals";
    pub const CNF_COMPUTE_TOTALS: &str = "cnf_compute_totals";
    pub const CNF_ASSIGN_NODE_VAR: &str = "cnf_assign_node_var";
    pub const CNF_EMIT_CLAUSES: &str = "cnf_emit_clauses";
    pub const CNF_SET_CLAUSE_END: &str = "cnf_set_clause_end";
}
716
/// Kernel entry-point names, presumably resolved against `WEIGHTS_MODULE`
/// (confirm).
pub mod weights_kernels {
    pub const WEIGHTS_FILL_LEAF: &str = "weights_fill_leaf";
    pub const WEIGHTS_FILL_CHOICE: &str = "weights_fill_choice";
    pub const WEIGHTS_COUNT_LIFT_EXACT: &str = "weights_count_lift_exact";
    pub const WEIGHTS_SET_EVIDENCE_FROM_NODES: &str = "weights_set_evidence_from_nodes";
    pub const WEIGHTS_APPLY_EVIDENCE: &str = "weights_apply_evidence";
    pub const WEIGHTS_MAP_NODES_TO_VARS: &str = "weights_map_nodes_to_vars";
    // Force/restore pairs.
    pub const WEIGHTS_FORCE_VAR_FALSE: &str = "weights_force_var_false";
    pub const WEIGHTS_RESTORE_VAR_FALSE: &str = "weights_restore_var_false";
    pub const WEIGHTS_FORCE_VAR_TRUE: &str = "weights_force_var_true";
    pub const WEIGHTS_RESTORE_VAR_TRUE: &str = "weights_restore_var_true";
    pub const WEIGHTS_COPY_SLOT_TO_BATCH: &str = "weights_copy_slot_to_batch";
    pub const WEIGHTS_APPLY_QUERY_VARS: &str = "weights_apply_query_vars";
    pub const WEIGHTS_RESTORE_QUERY_VARS: &str = "weights_restore_query_vars";
    // Batched apply/restore pairs.
    pub const WEIGHTS_APPLY_QUERY_VARS_FALSE_BATCHED: &str =
        "weights_apply_query_vars_false_batched";
    pub const WEIGHTS_RESTORE_QUERY_VARS_FALSE_BATCHED: &str =
        "weights_restore_query_vars_false_batched";
    pub const WEIGHTS_APPLY_QUERY_VARS_TRUE_BATCHED: &str = "weights_apply_query_vars_true_batched";
    pub const WEIGHTS_RESTORE_QUERY_VARS_TRUE_BATCHED: &str =
        "weights_restore_query_vars_true_batched";
}
740
/// Kernel entry-point names, presumably resolved against `D4_MODULE` (confirm).
pub mod d4_kernels {
    pub const D4_VALIDATE_CNF: &str = "d4_validate_cnf";
    pub const D4_LEVELIZE_COUNTS: &str = "d4_levelize_counts";
    pub const D4_LEVELIZE_EMIT: &str = "d4_levelize_emit";
    // Frontier kernels; `_DENSE` names are the dense-representation variants.
    pub const D4_FRONTIER_PREPARE: &str = "d4_frontier_prepare";
    pub const D4_FRONTIER_EXPAND: &str = "d4_frontier_expand";
    pub const D4_FRONTIER_PREPARE_DENSE: &str = "d4_frontier_prepare_dense";
    pub const D4_FRONTIER_EXPAND_DENSE: &str = "d4_frontier_expand_dense";
    pub const D4_COMPILE_COUNT: &str = "d4_compile_count";
    pub const D4_COMPILE_EMIT: &str = "d4_compile_emit";
    pub const D4_CAPTURE_EMIT_META: &str = "d4_capture_emit_meta";
    pub const D4_SUPPORT_LEVEL: &str = "d4_support_level";
    pub const D4_SUPPORT_SET_ROOT_BITS: &str = "d4_support_set_root_bits";
    pub const D4_SMOOTH_COUNT: &str = "d4_smooth_count";
    pub const D4_SMOOTH_WRAPPER_COUNTS: &str = "d4_smooth_wrapper_counts";
    pub const D4_SMOOTH_WRAPPER_EDGE_COUNTS_OR: &str = "d4_smooth_wrapper_edge_counts_or";
    pub const D4_SMOOTH_WRAPPER_EDGE_COUNTS_DEC: &str = "d4_smooth_wrapper_edge_counts_dec";
    pub const D4_SMOOTH_INIT_NODES: &str = "d4_smooth_init_nodes";
    pub const D4_SMOOTH_EMIT_LEVEL: &str = "d4_smooth_emit_level";
    pub const D4_SMOOTH_CHECK_EDGE_CAP: &str = "d4_smooth_check_edge_cap";
    pub const D4_MARK_VARS_IN_CLAUSES: &str = "d4_mark_vars_in_clauses";
    pub const D4_MARK_VARS_IN_CIRCUIT: &str = "d4_mark_vars_in_circuit";
    pub const D4_BUILD_FREE_VAR_MASK: &str = "d4_build_free_var_mask";
    // On-device assertion kernels.
    pub const D4_ASSERT_U32_EQ: &str = "d4_assert_u32_eq";
    pub const D4_ASSERT_BITSET_VAR: &str = "d4_assert_bitset_var";
    pub const D4_ASSERT_DENSE_VAR: &str = "d4_assert_dense_var";
    pub const D4_ASSERT_LEAF_ROOT_AND_DEGREE: &str = "d4_assert_leaf_root_and_degree";
}
776
/// Kernel entry-point names, presumably resolved against `JOIN_MODULE`
/// (confirm).
pub mod join_kernels {
    pub const HASH_JOIN_BUILD: &str = "hash_join_build";
    pub const HASH_JOIN_PROBE: &str = "hash_join_probe";
    pub const COMPUTE_COMPOSITE_HASH: &str = "compute_composite_hash";
    // `_v2` hash-join kernels.
    pub const HASH_JOIN_BUCKET_COUNT_V2: &str = "hash_join_bucket_count_v2";
    pub const HASH_JOIN_SCATTER_V2: &str = "hash_join_scatter_v2";
    pub const HASH_JOIN_PROBE_V2: &str = "hash_join_probe_v2";
    pub const HASH_JOIN_PROBE_V2_COUNT_PER_ROW: &str = "hash_join_probe_v2_count_per_row";
    pub const HASH_JOIN_PROBE_V2_MATERIALIZE: &str = "hash_join_probe_v2_materialize";
    pub const HASH_JOIN_TOTAL_FROM_SCAN: &str = "hash_join_total_from_scan";
    pub const HASH_JOIN_CSM_UNMATCHED_MASK: &str = "hash_join_csm_unmatched_mask";
    pub const HASH_JOIN_SEMI: &str = "hash_join_semi";
    pub const HASH_JOIN_ANTI: &str = "hash_join_anti";
    pub const INIT_HASH_TABLE: &str = "init_hash_table";
    pub const NESTED_LOOP_JOIN_INNER_U32_1KEY_PAIRS: &str = "nested_loop_join_inner_u32_1key_pairs";
    pub const SORT_MERGE_JOIN_INNER_U32_1KEY_PAIRS: &str = "sort_merge_join_inner_u32_1key_pairs";
}
808
/// Kernel entry-point names, presumably resolved against `DEDUP_MODULE`
/// (confirm).
pub mod dedup_kernels {
    pub const MARK_DUPLICATES: &str = "mark_duplicates";
    pub const MARK_UNIQUE_COLUMNAR: &str = "mark_unique_columnar";
    pub const MARK_UNIQUE_AND_SCAN_COLUMNAR: &str = "mark_unique_and_scan_columnar";
    pub const COMPACT_ROWS: &str = "compact_rows";
    pub const MARK_UNIQUE_FULL_ROW_BYTEWISE: &str = "mark_unique_full_row_bytewise";
    pub const MARK_DIFF_FULL_ROW_TYPED_SORTED: &str = "mark_diff_full_row_typed_sorted";
    pub const SMALL_SORT_FULL_ROW_INDICES_TYPED: &str = "small_sort_full_row_indices_typed";
}
819
/// Kernel entry-point names, presumably resolved against `GROUPBY_MODULE`
/// (confirm).
pub mod groupby_kernels {
    pub const DETECT_GROUP_BOUNDARIES: &str = "detect_group_boundaries";
    pub const DETECT_BOUNDARIES: &str = "detect_boundaries";
    pub const EXTRACT_GROUP_KEYS: &str = "extract_group_keys";
    pub const GROUP_IDS_FROM_BOUNDARIES: &str = "group_ids_from_boundaries";
    pub const GROUP_START_INDICES: &str = "group_start_indices";
    pub const CAPTURE_NUM_GROUPS: &str = "capture_num_groups";
    // Aggregates; `_U64` names are the 64-bit value variants.
    pub const GROUPBY_COUNT: &str = "groupby_count";
    pub const GROUPBY_SUM: &str = "groupby_sum";
    pub const GROUPBY_SUM_U64: &str = "groupby_sum_u64";
    pub const GROUPBY_MIN: &str = "groupby_min";
    pub const GROUPBY_MIN_U64: &str = "groupby_min_u64";
    pub const GROUPBY_MAX: &str = "groupby_max";
    pub const GROUPBY_MAX_U64: &str = "groupby_max_u64";
    // Log-sum-exp, split into max / sum-of-exp / final phases.
    pub const GROUPBY_LOGSUMEXP_MAX: &str = "groupby_logsumexp_max";
    pub const GROUPBY_LOGSUMEXP_SUMEXP: &str = "groupby_logsumexp_sumexp";
    pub const GROUPBY_LOGSUMEXP_FINAL: &str = "groupby_logsumexp_final";
}
839
/// Kernel function names for prefix-scan kernels.
///
/// These are looked up with `device.get_func(SCAN_MODULE, scan_kernels::NAME)` (see
/// `multiblock_scan_u32_inplace`). If a name does not match the module's symbol, the
/// lookup returns `None` and the caller reports "Failed to get ... kernel".
pub mod scan_kernels {
    pub const BLOCK_INCLUSIVE_SCAN: &str = "block_inclusive_scan";
    pub const ADD_BLOCK_OFFSETS: &str = "add_block_offsets";
    pub const EXCLUSIVE_SCAN_MASK: &str = "exclusive_scan_mask";
    pub const COUNT_MASK: &str = "count_mask";
    pub const MULTIBLOCK_SCAN_PHASE1: &str = "multiblock_scan_phase1";
    // u32 variant of phase 1 used by `multiblock_scan_u32_inplace`.
    pub const MULTIBLOCK_SCAN_U32_PHASE1: &str = "multiblock_scan_u32_phase1";
    // Also used on its own for the single-block (n <= 256) path.
    pub const MULTIBLOCK_SCAN_PHASE2: &str = "multiblock_scan_phase2";
    pub const MULTIBLOCK_SCAN_PHASE3: &str = "multiblock_scan_phase3";
}
852
/// Kernel function names grouped under `sort_kernels`.
///
/// Values are exact device-function symbol strings. NOTE(review): presumably resolved
/// by name via `get_func` against the sort module — confirm at the launch sites.
pub mod sort_kernels {
    pub const RADIX_HISTOGRAM: &str = "radix_histogram";
    pub const RADIX_SCATTER: &str = "radix_scatter";
    pub const COMPUTE_RANKS: &str = "compute_ranks";
    pub const RADIX_SCATTER_STABLE: &str = "radix_scatter_stable";
    pub const COMPUTE_DIGIT_PREFIX_SUMS: &str = "compute_digit_prefix_sums";
    pub const INIT_INDICES: &str = "init_indices";
    pub const APPLY_PERMUTATION_U32: &str = "apply_permutation_u32";
    pub const APPLY_PERMUTATION_BYTES: &str = "apply_permutation_bytes";

    // Key gathers that emit a u32 per row for 32-bit (and bool) key types.
    pub const GATHER_KEYS_I32_ORDERED_U32: &str = "gather_keys_i32_ordered_u32";
    pub const GATHER_KEYS_F32_ORDERED_U32: &str = "gather_keys_f32_ordered_u32";
    pub const GATHER_KEYS_BOOL_ORDERED_U32: &str = "gather_keys_bool_ordered_u32";

    // 64-bit key types are gathered as separate low/high u32 halves.
    pub const GATHER_KEYS_U64_LO_U32: &str = "gather_keys_u64_lo_u32";
    pub const GATHER_KEYS_U64_HI_U32: &str = "gather_keys_u64_hi_u32";

    pub const GATHER_KEYS_I64_LO_U32: &str = "gather_keys_i64_lo_u32";
    pub const GATHER_KEYS_I64_HI_U32: &str = "gather_keys_i64_hi_u32";

    pub const GATHER_KEYS_F64_LO_U32: &str = "gather_keys_f64_lo_u32";
    pub const GATHER_KEYS_F64_HI_U32: &str = "gather_keys_f64_hi_u32";
    pub const CHECK_ASCENDING_SORTED_U32: &str = "check_ascending_sorted_u32";
}
885
/// Kernel function names grouped under `filter_kernels`.
///
/// Values are exact device-function symbol strings. NOTE(review): presumably resolved
/// by name via `get_func` against the filter module. The `filter_compare_*` kernels
/// presumably take a `CompareOp` discriminant as `u8` — confirm at the launch sites.
pub mod filter_kernels {
    // Column-vs-constant comparisons, one entry point per element type.
    pub const FILTER_COMPARE_U32: &str = "filter_compare_u32";
    pub const FILTER_COMPARE_I64: &str = "filter_compare_i64";
    pub const FILTER_COMPARE_F64: &str = "filter_compare_f64";
    pub const FILTER_COMPARE_I32: &str = "filter_compare_i32";
    pub const FILTER_COMPARE_U64: &str = "filter_compare_u64";
    pub const FILTER_COMPARE_F32: &str = "filter_compare_f32";
    pub const FILTER_COMPARE_U8: &str = "filter_compare_u8";
    pub const FILTER_COMPARE_U32_SCAN_PHASE1: &str = "filter_compare_u32_scan_phase1";
    pub const FILTER_COMPARE_F64_SCAN_PHASE1: &str = "filter_compare_f64_scan_phase1";
    pub const FILTER_COMPARE_F32_SCAN_PHASE1: &str = "filter_compare_f32_scan_phase1";
    // `_col` variants: NOTE(review) presumably column-vs-column comparisons — confirm.
    pub const FILTER_COMPARE_U32_COL: &str = "filter_compare_u32_col";
    pub const FILTER_COMPARE_I32_COL: &str = "filter_compare_i32_col";
    pub const FILTER_COMPARE_I64_COL: &str = "filter_compare_i64_col";
    pub const FILTER_COMPARE_U64_COL: &str = "filter_compare_u64_col";
    pub const FILTER_COMPARE_F32_COL: &str = "filter_compare_f32_col";
    pub const FILTER_COMPARE_F64_COL: &str = "filter_compare_f64_col";
    pub const FILTER_COMPARE_U8_COL: &str = "filter_compare_u8_col";
    pub const FILL_U32_IOTA: &str = "fill_u32_iota";
    pub const FILL_U32_CONST: &str = "fill_u32_const";
    pub const MARK_RANDOM_VARS: &str = "mark_random_vars";
    pub const RANDOM_VAR_TO_BIT_FROM_LIST: &str = "random_var_to_bit_from_list";
    pub const CHECK_RANDOM_VAR_COUNT: &str = "check_random_var_count";
    pub const COMPACT_U32_BY_MASK: &str = "compact_u32_by_mask";
    pub const COMPACT_I64_BY_MASK: &str = "compact_i64_by_mask";
    pub const COMPACT_F64_BY_MASK: &str = "compact_f64_by_mask";
    pub const COMPACT_BYTES_BY_MASK: &str = "compact_bytes_by_mask";
    pub const CAPTURE_COMPACT_COUNT: &str = "capture_compact_count";
    pub const MASK_CLAMP_ROWS: &str = "mask_clamp_rows";
    pub const MASK_AND: &str = "mask_and";
    pub const MASK_OR: &str = "mask_or";
    pub const MASK_NOT: &str = "mask_not";
}
920
/// Kernel function names grouped under `set_ops_kernels`.
///
/// Values are exact device-function symbol strings. NOTE(review): presumably resolved
/// by name via `get_func` against the set-ops module — confirm at the launch sites.
pub mod set_ops_kernels {
    pub const CONCAT_U32: &str = "concat_u32";
    pub const CONCAT_BYTES: &str = "concat_bytes";
    pub const SORTED_DIFF_MARK: &str = "sorted_diff_mark";
}
927
/// Kernel function names grouped under `pack_kernels`.
///
/// Values are exact device-function symbol strings. NOTE(review): presumably resolved
/// by name via `get_func` against the pack module — confirm at the launch sites.
pub mod pack_kernels {
    pub const PACK_KEYS: &str = "pack_keys";
    pub const HASH_PACKED_KEYS: &str = "hash_packed_keys";
    pub const PACK_AND_HASH_KEYS: &str = "pack_and_hash_keys";
    pub const PACK_AND_HASH_KEYS_GENERIC: &str = "pack_and_hash_keys_generic";
    pub const PACK_KEYS_ALIGNED: &str = "pack_keys_aligned";
    pub const UNPACK_COLUMN: &str = "unpack_column";
    pub const UNPACK_COLUMN_COUNTED: &str = "unpack_column_counted";
    pub const GATHER_PACKED_ROWS: &str = "gather_packed_rows";
    pub const GATHER_PACKED_ROWS_COUNTED: &str = "gather_packed_rows_counted";
    pub const SCATTER_PACKED_ROWS: &str = "scatter_packed_rows";
    pub const COMPARE_PACKED_KEYS: &str = "compare_packed_keys";
    pub const PACK_BOOLS_TO_BITMAP: &str = "pack_bools_to_bitmap";
}
955
/// Kernel function names grouped under `circuit_kernels` (the `xgcf_*` family).
///
/// Values are exact device-function symbol strings. NOTE(review): presumably resolved
/// by name via `get_func` against the circuit module — confirm at the launch sites.
/// The `_cached` / `_batched` suffixes name distinct entry points, not aliases.
pub mod circuit_kernels {
    pub const XGCF_FORWARD_LEVEL: &str = "xgcf_forward_level";
    pub const XGCF_BACKWARD_LEVEL_PROPAGATE: &str = "xgcf_backward_level_propagate";
    pub const XGCF_BACKWARD_LEVEL_DECISION_GRAD: &str = "xgcf_backward_level_decision_grad";
    pub const XGCF_BACKWARD_LEVEL_LIT_GRAD: &str = "xgcf_backward_level_lit_grad";
    pub const XGCF_FREE_VAR_APPLY_GRAD: &str = "xgcf_free_var_apply_grad";
    pub const XGCF_FREE_VAR_REDUCE_STAGE: &str = "xgcf_free_var_reduce_stage";
    pub const XGCF_ADD_SCALAR: &str = "xgcf_add_scalar";
    pub const XGCF_FORWARD_LEVEL_CACHED: &str = "xgcf_forward_level_cached";
    pub const XGCF_EVAL_ALL_LEVELS_CACHED: &str = "xgcf_eval_all_levels_cached";
    pub const XGCF_EVAL_ALL_LEVELS_CACHED_BATCHED: &str = "xgcf_eval_all_levels_cached_batched";
    pub const XGCF_BACKWARD_LEVEL_PROPAGATE_CACHED: &str = "xgcf_backward_level_propagate_cached";
    pub const XGCF_BACKWARD_LEVEL_DECISION_GRAD_CACHED: &str =
        "xgcf_backward_level_decision_grad_cached";
    pub const XGCF_BACKWARD_LEVEL_LIT_GRAD_CACHED: &str = "xgcf_backward_level_lit_grad_cached";
    pub const XGCF_BACKWARD_ALL_LEVELS_CACHED: &str = "xgcf_backward_all_levels_cached";
    pub const XGCF_BACKWARD_ALL_LEVELS_CACHED_BATCHED: &str =
        "xgcf_backward_all_levels_cached_batched";
    pub const XGCF_FREE_VAR_APPLY_GRAD_CACHED: &str = "xgcf_free_var_apply_grad_cached";
    pub const XGCF_FREE_VAR_REDUCE_STAGE_CACHED: &str = "xgcf_free_var_reduce_stage_cached";
    pub const XGCF_ADD_SCALAR_CACHED: &str = "xgcf_add_scalar_cached";
    pub const XGCF_SET_ROOT_ADJ_CACHED_BATCHED: &str = "xgcf_set_root_adj_cached_batched";
    pub const XGCF_COPY_ROOT_CACHED: &str = "xgcf_copy_root_cached";
    pub const XGCF_COPY_ROOT_CACHED_META: &str = "xgcf_copy_root_cached_meta";
    pub const XGCF_COPY_ROOT_CACHED_META_BATCHED: &str = "xgcf_copy_root_cached_meta_batched";
}
983
/// Kernel function names grouped under `cache_kernels`.
///
/// Values are exact device-function symbol strings. NOTE(review): presumably resolved
/// by name via `get_func` against the cache module — confirm at the launch sites.
pub mod cache_kernels {
    pub const CACHE_CNF_HASH: &str = "cache_cnf_hash";
    pub const CACHE_LOOKUP_OR_INSERT: &str = "cache_lookup_or_insert";
    pub const CACHE_EVICT_LRU: &str = "cache_evict_lru";
    // One store kernel per element type, plus one for metadata.
    pub const CACHE_STORE_U8: &str = "cache_store_u8";
    pub const CACHE_STORE_U32: &str = "cache_store_u32";
    pub const CACHE_STORE_I32: &str = "cache_store_i32";
    pub const CACHE_STORE_F64: &str = "cache_store_f64";
    pub const CACHE_STORE_META: &str = "cache_store_meta";
}
995
/// Kernel function names grouped under `sat_kernels`.
///
/// Values are exact device-function symbol strings. NOTE(review): presumably resolved
/// by name via `get_func` against the SAT module — confirm at the launch sites.
pub mod sat_kernels {
    pub const SAT_CDCL_SOLVE: &str = "sat_cdcl_solve";
    pub const SAT_CHECK_MODEL: &str = "sat_check_model";
    pub const SAT_PROOF_MARK_NEEDED: &str = "sat_proof_mark_needed";
    pub const SAT_PROOF_CHECK: &str = "sat_proof_check";
    pub const SAT_ASSERT_STATUS: &str = "sat_assert_status";
    pub const SAT_ASSERT_OK: &str = "sat_assert_ok";
    pub const SAT_XGCF_CNF_COUNTS: &str = "sat_xgcf_cnf_counts";
    pub const SAT_XGCF_CNF_EMIT: &str = "sat_xgcf_cnf_emit";
    pub const SAT_XGCF_CNF_CAPTURE_LAST_COUNTS: &str = "sat_xgcf_cnf_capture_last_counts";
    pub const SAT_XGCF_CNF_COMPUTE_TOTALS: &str = "sat_xgcf_cnf_compute_totals";
    pub const SAT_CNF_WRITE_TERMINATOR: &str = "sat_cnf_write_terminator";
    pub const SAT_CNF_COPY_INTO: &str = "sat_cnf_copy_into";
    pub const SAT_SHIFT_OFFSETS: &str = "sat_shift_offsets";
    pub const SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE: &str = "sat_xgcf_write_root_unit_clause";
    pub const SAT_NOT_PHI_COUNTS: &str = "sat_not_phi_counts";
    pub const SAT_EMIT_NOT_PHI: &str = "sat_emit_not_phi";
}
1015
/// Default join output limit (1,000,000).
///
/// NOTE(review): the name suggests a row-count cap applied when the caller supplies
/// none — confirm the unit (rows vs. bytes) at the call sites.
pub const DEFAULT_JOIN_MAX_OUTPUT: usize = 1_000_000;

/// Threshold (4,000,000) tied to nested-loop join selection.
///
/// NOTE(review): presumably the maximum total work (e.g. left rows × right rows) for
/// which a nested-loop join is chosen over other strategies — confirm in the planner.
pub const NESTED_LOOP_TOTAL_THRESHOLD: u64 = 4_000_000;
1041
/// Comparison operator for filter predicates.
///
/// `#[repr(u8)]` with explicit discriminants 0..=5. NOTE(review): the discriminant is
/// presumably passed to the `filter_compare_*` kernels as a `u8`, so the numbering must
/// not change without updating the kernels — confirm at the launch sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum CompareOp {
    /// Equal.
    Eq = 0,
    /// Not equal.
    Ne = 1,
    /// Less than.
    Lt = 2,
    /// Less than or equal.
    Le = 3,
    /// Greater than.
    Gt = 4,
    /// Greater than or equal.
    Ge = 5,
}
1053
/// Kind of join to perform.
///
/// NOTE(review): the standard relational meanings are assumed below. `hash_join_semi`
/// and `hash_join_anti` kernels exist alongside the inner-join kernels; confirm the
/// dispatch per variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Emit one output row per matching (left, right) pair.
    Inner,
    /// Keep left rows that have at least one match (assumed).
    Semi,
    /// Keep left rows that have no match (assumed).
    Anti,
    /// Inner matches plus unmatched left rows (assumed).
    LeftOuter,
}
1066
/// Device buffers holding packed join keys together with their hashes.
struct PackedKeyData {
    /// Key hashes (NOTE(review): presumably one per row — confirm).
    hashes: crate::memory::TrackedCudaSlice<u64>,
    /// Packed key bytes (NOTE(review): presumably `key_bytes` bytes per row — confirm).
    packed_keys: crate::memory::TrackedCudaSlice<u8>,
    /// Width in bytes of one packed key.
    key_bytes: u32,
}
1076
/// Device-resident bucketed hash table owned by `JoinIndexV2`.
///
/// NOTE(review): the field names suggest a CSR-style layout. Under that reading,
/// `bucket_offsets[b]` is the start index in `bucket_entries` / `bucket_entry_hashes`
/// for the `bucket_counts[b]` entries of bucket `b`, and `hash & bucket_mask` selects
/// the bucket (which implies a power-of-two bucket count). Confirm against the build
/// kernels.
struct JoinHashTableV2 {
    bucket_counts: crate::memory::TrackedCudaSlice<u32>,
    bucket_offsets: crate::memory::TrackedCudaSlice<u32>,
    bucket_entries: crate::memory::TrackedCudaSlice<u32>,
    bucket_entry_hashes: crate::memory::TrackedCudaSlice<u64>,
    bucket_mask: u32,
}
1084
/// Public bucketed hash table whose fields mirror `JoinHashTableV2`.
///
/// NOTE(review): presumably the same CSR-style layout (per-bucket counts and offsets
/// into `bucket_entries` / `bucket_entry_hashes`, with `bucket_mask` selecting the
/// bucket) — confirm against the kernels that build it.
pub struct HashTableU64 {
    pub bucket_counts: crate::memory::TrackedCudaSlice<u32>,
    pub bucket_offsets: crate::memory::TrackedCudaSlice<u32>,
    pub bucket_entries: crate::memory::TrackedCudaSlice<u32>,
    pub bucket_entry_hashes: crate::memory::TrackedCudaSlice<u64>,
    pub bucket_mask: u32,
}
1093
/// Prebuilt hash index over the right-hand side of a join.
///
/// Owns the right side's packed key bytes plus a bucketed hash table over them.
/// `estimated_bytes` sums the device buffers it holds.
pub struct JoinIndexV2 {
    /// Number of rows on the indexed (right) side.
    right_num_rows: u32,
    /// Right-side key positions (NOTE(review): presumably column indices — confirm).
    right_keys: Vec<usize>,
    /// Width in bytes of one packed key.
    key_bytes: u32,
    /// Packed right-side key bytes.
    packed_keys: crate::memory::TrackedCudaSlice<u8>,
    /// Bucketed hash table over the packed keys.
    table: JoinHashTableV2,
}
1105
1106impl JoinIndexV2 {
1107 pub fn right_keys(&self) -> &[usize] {
1109 &self.right_keys
1110 }
1111
1112 pub fn right_num_rows(&self) -> u32 {
1114 self.right_num_rows
1115 }
1116
1117 pub fn estimated_bytes(&self) -> u64 {
1119 let mut bytes = 0u64;
1120 bytes = bytes.saturating_add(self.packed_keys.len() as u64);
1121 bytes = bytes.saturating_add(self.table.bucket_counts.len() as u64 * 4);
1122 bytes = bytes.saturating_add(self.table.bucket_offsets.len() as u64 * 4);
1123 bytes = bytes.saturating_add(self.table.bucket_entries.len() as u64 * 4);
1124 bytes = bytes.saturating_add(self.table.bucket_entry_hashes.len() as u64 * 8);
1125 bytes
1126 }
1127}
1128
/// Kernel launcher bound to one CUDA device and its GPU memory manager.
///
/// Besides the device and memory handles, it holds the module-load profile captured in
/// `new` and a set of relaxed atomic diagnostic counters: transfer accounting, the
/// deterministic-D2H gate, CSM / CUDA-graph counters, and WCOJ / k-clique metrics.
pub struct CudaKernelProvider {
    /// Device all kernels are looked up on and launched against.
    device: Arc<CudaDevice>,
    /// Memory manager used for scratch allocations (e.g. scan `block_sums`).
    memory: Arc<GpuMemoryManager>,
    /// Byte/call accounting for tracked host<->device copies.
    transfer_tracker: HostTransferTracker,
    /// Module-load timing returned by `load_all_kernel_modules`.
    /// NOTE(review): presumably `Some` only when warm-up profiling is enabled — confirm.
    ptx_load_profile: Option<PtxLoadProfile>,
    /// D2H counter exposed via `d2h_transfer_count` (incremented outside this chunk).
    d2h_transfer_count: AtomicU64,
    /// Number of reads made through `dtoh_small_metadata_untracked` / `dtoh_scalar_untracked`.
    untracked_metadata_dtoh_count: AtomicU64,
    /// When `true`, `check_deterministic_d2h` rejects tracked device-to-host copies.
    strict_deterministic_d2h: AtomicBool,
    /// Number of copies rejected by the strict deterministic-D2H gate.
    deterministic_d2h_violations: AtomicU64,
    /// Stream lazily acquired from the runtime's stream pool for recorded ops.
    recorded_op_stream: OnceLock<crate::device_runtime::StreamId>,
    /// CSM diagnostic counters (incremented outside this chunk; read via getters).
    csm_invocations: AtomicU64,
    csm_cuda_graph_captures: AtomicU64,
    csm_cuda_graph_launches: AtomicU64,
    csm_cuda_graph_fallbacks: AtomicU64,
    csm_cuda_graph_cache_hits: AtomicU64,
    /// Diagnostic counter exposed via `small_full_row_sort_invocations`.
    small_full_row_sort_invocations: AtomicU64,
    /// CSM CUDA-graph cache keyed by `CsmCudaGraphKey`.
    csm_cuda_graph_cache: Mutex<HashMap<CsmCudaGraphKey, CsmCudaGraphEntry>>,
    /// Bumped by `record_wcoj_layout_fast_path_hit`.
    wcoj_layout_fast_path_hit_count: AtomicU64,
    /// Bumped by `record_wcoj_layout_sort_invocation`.
    wcoj_layout_sort_invocation_count: AtomicU64,
    /// Number of k-clique metadata builds recorded.
    kclique_metadata_build_count: AtomicU64,
    /// Total nanoseconds recorded across k-clique metadata builds.
    kclique_metadata_build_nanos: AtomicU64,
    /// Bumped by `record_wcoj_triangle_hg_dispatch`.
    wcoj_triangle_hg_dispatch_count: AtomicU64,
    /// Most recent triangle phase timing; consumed by `take_wcoj_triangle_phase_timing`.
    #[cfg(feature = "wcoj-phase-timing")]
    last_triangle_phase_timing:
        std::sync::Mutex<Option<crate::wcoj_phase_timing::WcojTrianglePhaseTiming>>,
}
1238
/// Relaxed atomic counters describing host<->device copy traffic.
///
/// Ordinary copies are accumulated as device-to-host (`dtoh_*`) and host-to-device
/// (`htod_*`) byte/call totals. Launch-metadata uploads go into their own
/// `launch_metadata_htod_*` pair and are not included in [`HostTransferStats`].
#[derive(Default)]
struct HostTransferTracker {
    dtoh_bytes: AtomicU64,
    htod_bytes: AtomicU64,
    dtoh_calls: AtomicU64,
    htod_calls: AtomicU64,
    launch_metadata_htod_bytes: AtomicU64,
    launch_metadata_htod_calls: AtomicU64,
}

/// Point-in-time copy of the ordinary (non-launch-metadata) transfer counters.
///
/// Each field is read with its own relaxed load, so a snapshot taken while other
/// threads are recording is not guaranteed to be mutually consistent.
#[derive(Debug, Clone, Copy)]
pub struct HostTransferStats {
    /// Total bytes copied device -> host.
    pub dtoh_bytes: u64,
    /// Total bytes copied host -> device.
    pub htod_bytes: u64,
    /// Number of recorded device -> host copies.
    pub dtoh_calls: u64,
    /// Number of recorded host -> device copies.
    pub htod_calls: u64,
}

/// Point-in-time copy of the launch-metadata host -> device counters.
#[derive(Debug, Clone, Copy, Default)]
pub struct HostLaunchMetadataTransferStats {
    /// Total launch-metadata bytes copied host -> device.
    pub htod_bytes: u64,
    /// Number of launch-metadata host -> device copies.
    pub htod_calls: u64,
}

impl HostTransferTracker {
    /// Adds one call of `bytes` bytes to a (calls, bytes) counter pair.
    fn bump_pair(calls: &AtomicU64, total_bytes: &AtomicU64, bytes: u64) {
        calls.fetch_add(1, Ordering::Relaxed);
        total_bytes.fetch_add(bytes, Ordering::Relaxed);
    }

    /// Records one device -> host copy of `bytes` bytes.
    fn record_dtoh(&self, bytes: u64) {
        Self::bump_pair(&self.dtoh_calls, &self.dtoh_bytes, bytes);
    }

    /// Records one host -> device copy of `bytes` bytes.
    fn record_htod(&self, bytes: u64) {
        Self::bump_pair(&self.htod_calls, &self.htod_bytes, bytes);
    }

    /// Records one launch-metadata host -> device upload of `bytes` bytes.
    fn record_htod_launch_metadata(&self, bytes: u64) {
        Self::bump_pair(
            &self.launch_metadata_htod_calls,
            &self.launch_metadata_htod_bytes,
            bytes,
        );
    }

    /// Reads the ordinary transfer counters.
    fn snapshot(&self) -> HostTransferStats {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        HostTransferStats {
            dtoh_bytes: read(&self.dtoh_bytes),
            htod_bytes: read(&self.htod_bytes),
            dtoh_calls: read(&self.dtoh_calls),
            htod_calls: read(&self.htod_calls),
        }
    }

    /// Reads the launch-metadata counters.
    fn launch_metadata_snapshot(&self) -> HostLaunchMetadataTransferStats {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        HostLaunchMetadataTransferStats {
            htod_bytes: read(&self.launch_metadata_htod_bytes),
            htod_calls: read(&self.launch_metadata_htod_calls),
        }
    }

    /// Zeroes every counter, including the launch-metadata pair.
    fn reset(&self) {
        let counters = [
            &self.dtoh_bytes,
            &self.htod_bytes,
            &self.dtoh_calls,
            &self.htod_calls,
            &self.launch_metadata_htod_bytes,
            &self.launch_metadata_htod_calls,
        ];
        for counter in counters {
            counter.store(0, Ordering::Relaxed);
        }
    }
}
1306
1307impl CudaKernelProvider {
1308 pub fn new(device: Arc<CudaDevice>, memory: Arc<GpuMemoryManager>) -> Result<Self> {
1327 let profiling = warmup_profiling_enabled();
1328 let ptx_load_profile = Self::load_all_kernel_modules(&device, profiling)?;
1329
1330 Ok(Self {
1331 device,
1332 memory,
1333 transfer_tracker: HostTransferTracker::default(),
1334 ptx_load_profile,
1335 d2h_transfer_count: AtomicU64::new(0),
1336 untracked_metadata_dtoh_count: AtomicU64::new(0),
1337 strict_deterministic_d2h: AtomicBool::new(false),
1338 deterministic_d2h_violations: AtomicU64::new(0),
1339 recorded_op_stream: OnceLock::new(),
1340 csm_invocations: AtomicU64::new(0),
1341 csm_cuda_graph_captures: AtomicU64::new(0),
1342 csm_cuda_graph_launches: AtomicU64::new(0),
1343 csm_cuda_graph_fallbacks: AtomicU64::new(0),
1344 csm_cuda_graph_cache_hits: AtomicU64::new(0),
1345 small_full_row_sort_invocations: AtomicU64::new(0),
1346 csm_cuda_graph_cache: Mutex::new(HashMap::new()),
1347 wcoj_layout_fast_path_hit_count: AtomicU64::new(0),
1348 wcoj_layout_sort_invocation_count: AtomicU64::new(0),
1349 kclique_metadata_build_count: AtomicU64::new(0),
1350 kclique_metadata_build_nanos: AtomicU64::new(0),
1351 wcoj_triangle_hg_dispatch_count: AtomicU64::new(0),
1352 #[cfg(feature = "wcoj-phase-timing")]
1353 last_triangle_phase_timing: std::sync::Mutex::new(None),
1354 })
1355 }
1356
1357 pub fn with_runtime(device: Arc<CudaDevice>, memory: Arc<GpuMemoryManager>) -> Result<Self> {
1405 if memory.runtime().is_none() {
1406 return Err(XlogError::Kernel(
1407 "CudaKernelProvider::with_runtime requires a GpuMemoryManager built via \
1408 GpuMemoryManager::with_runtime; got a manager with no runtime attached"
1409 .to_string(),
1410 ));
1411 }
1412 Self::new(device, memory)
1413 }
1414
1415 fn env_flag(name: &str) -> bool {
1418 std::env::var(name)
1419 .map(|v| !v.is_empty() && v != "0")
1420 .unwrap_or(false)
1421 }
1422
1423 pub(crate) fn use_recorded_filters_env() -> bool {
1435 Self::env_flag("XLOG_USE_RECORDED_FILTERS") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1436 }
1437
1438 pub(crate) fn use_recorded_sort_env() -> bool {
1445 Self::env_flag("XLOG_USE_RECORDED_SORT") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1446 }
1447
1448 pub(crate) fn use_recorded_dedup_env() -> bool {
1453 Self::env_flag("XLOG_USE_RECORDED_DEDUP") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1454 }
1455
1456 pub(crate) fn use_recorded_groupby_env() -> bool {
1462 Self::env_flag("XLOG_USE_RECORDED_GROUPBY") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1463 }
1464
1465 pub(crate) fn use_recorded_hash_join_env() -> bool {
1473 Self::env_flag("XLOG_USE_RECORDED_HASH_JOIN") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1474 }
1475
1476 pub(crate) fn use_recorded_csm_env() -> bool {
1484 Self::env_flag("XLOG_USE_RECORDED_CSM") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1485 }
1486
1487 pub(crate) fn use_csm_cuda_graph_env() -> bool {
1493 Self::env_flag("XLOG_USE_CSM_CUDA_GRAPH") || Self::env_flag("XLOG_USE_CUDA_GRAPHS")
1494 }
1495
1496 #[doc(hidden)]
1510 pub fn csm_invocations(&self) -> u64 {
1511 self.csm_invocations.load(Ordering::Relaxed)
1512 }
1513
1514 #[doc(hidden)]
1515 pub fn csm_cuda_graph_captures(&self) -> u64 {
1516 self.csm_cuda_graph_captures.load(Ordering::Relaxed)
1517 }
1518
1519 #[doc(hidden)]
1520 pub fn csm_cuda_graph_launches(&self) -> u64 {
1521 self.csm_cuda_graph_launches.load(Ordering::Relaxed)
1522 }
1523
1524 #[doc(hidden)]
1525 pub fn csm_cuda_graph_fallbacks(&self) -> u64 {
1526 self.csm_cuda_graph_fallbacks.load(Ordering::Relaxed)
1527 }
1528
1529 #[doc(hidden)]
1530 pub fn csm_cuda_graph_cache_hits(&self) -> u64 {
1531 self.csm_cuda_graph_cache_hits.load(Ordering::Relaxed)
1532 }
1533
1534 #[doc(hidden)]
1535 pub fn small_full_row_sort_invocations(&self) -> u64 {
1536 self.small_full_row_sort_invocations.load(Ordering::Relaxed)
1537 }
1538
1539 pub(crate) fn recorded_op_stream_or_init(&self) -> Option<crate::device_runtime::StreamId> {
1559 if let Some(s) = self.recorded_op_stream.get() {
1560 return Some(*s);
1561 }
1562 let runtime = self.memory.runtime()?;
1563 let stream = runtime.stream_pool().acquire().ok()?;
1564 let _ = self.recorded_op_stream.set(stream);
1565 self.recorded_op_stream.get().copied()
1566 }
1567
1568 #[cfg(feature = "wcoj-phase-timing")]
1578 pub fn take_wcoj_triangle_phase_timing(
1579 &self,
1580 ) -> Option<crate::wcoj_phase_timing::WcojTrianglePhaseTiming> {
1581 self.last_triangle_phase_timing
1582 .lock()
1583 .ok()
1584 .and_then(|mut g| g.take())
1585 }
1586
1587 #[cfg(feature = "wcoj-phase-timing")]
1591 #[allow(dead_code)]
1592 pub(crate) fn put_wcoj_triangle_phase_timing(
1593 &self,
1594 timing: crate::wcoj_phase_timing::WcojTrianglePhaseTiming,
1595 ) {
1596 if let Ok(mut g) = self.last_triangle_phase_timing.lock() {
1597 *g = Some(timing);
1598 }
1599 }
1600
1601 pub fn wcoj_layout_fast_path_hit_count(&self) -> u64 {
1608 self.wcoj_layout_fast_path_hit_count.load(Ordering::Relaxed)
1609 }
1610
1611 pub fn wcoj_triangle_hg_dispatch_count(&self) -> u64 {
1614 self.wcoj_triangle_hg_dispatch_count.load(Ordering::Relaxed)
1615 }
1616
1617 pub fn reset_wcoj_layout_fast_path_hit_count(&self) {
1620 self.wcoj_layout_fast_path_hit_count
1621 .store(0, Ordering::Relaxed);
1622 }
1623
1624 pub fn wcoj_layout_sort_invocation_count(&self) -> u64 {
1627 self.wcoj_layout_sort_invocation_count
1628 .load(Ordering::Relaxed)
1629 }
1630
1631 pub fn reset_wcoj_layout_sort_invocation_count(&self) {
1633 self.wcoj_layout_sort_invocation_count
1634 .store(0, Ordering::Relaxed);
1635 }
1636
1637 pub fn kclique_metadata_build_count(&self) -> u64 {
1640 self.kclique_metadata_build_count.load(Ordering::Relaxed)
1641 }
1642
1643 pub fn kclique_metadata_build_nanos(&self) -> u64 {
1646 self.kclique_metadata_build_nanos.load(Ordering::Relaxed)
1647 }
1648
1649 pub fn reset_kclique_metadata_build_metrics(&self) {
1651 self.kclique_metadata_build_count
1652 .store(0, Ordering::Relaxed);
1653 self.kclique_metadata_build_nanos
1654 .store(0, Ordering::Relaxed);
1655 }
1656
1657 pub(crate) fn record_wcoj_layout_fast_path_hit(&self) {
1661 self.wcoj_layout_fast_path_hit_count
1662 .fetch_add(1, Ordering::Relaxed);
1663 }
1664
1665 pub(crate) fn record_wcoj_layout_sort_invocation(&self) {
1667 self.wcoj_layout_sort_invocation_count
1668 .fetch_add(1, Ordering::Relaxed);
1669 }
1670
1671 pub(crate) fn record_kclique_metadata_build_nanos(&self, nanos: u128) {
1673 self.kclique_metadata_build_count
1674 .fetch_add(1, Ordering::Relaxed);
1675 let nanos = u64::try_from(nanos).unwrap_or(u64::MAX);
1676 self.kclique_metadata_build_nanos
1677 .fetch_add(nanos, Ordering::Relaxed);
1678 }
1679
1680 #[doc(hidden)]
1683 pub fn record_wcoj_triangle_hg_dispatch(&self) {
1684 self.wcoj_triangle_hg_dispatch_count
1685 .fetch_add(1, Ordering::Relaxed);
1686 }
1687
1688 pub fn device(&self) -> &Arc<CudaDevice> {
1690 &self.device
1691 }
1692
1693 pub fn memory(&self) -> &Arc<GpuMemoryManager> {
1695 &self.memory
1696 }
1697
1698 pub fn ptx_load_profile(&self) -> Option<&PtxLoadProfile> {
1700 self.ptx_load_profile.as_ref()
1701 }
1702
1703 pub fn reset_host_transfer_stats(&self) {
1705 self.transfer_tracker.reset();
1706 }
1707
1708 pub fn host_transfer_stats(&self) -> HostTransferStats {
1710 self.transfer_tracker.snapshot()
1711 }
1712
1713 pub fn host_launch_metadata_transfer_stats(&self) -> HostLaunchMetadataTransferStats {
1716 self.transfer_tracker.launch_metadata_snapshot()
1717 }
1718
1719 pub fn d2h_transfer_count(&self) -> u64 {
1725 self.d2h_transfer_count.load(Ordering::Relaxed)
1726 }
1727
1728 pub fn reset_d2h_transfer_count(&self) {
1730 self.d2h_transfer_count.store(0, Ordering::Relaxed);
1731 }
1732
1733 pub fn untracked_metadata_dtoh_count(&self) -> u64 {
1736 self.untracked_metadata_dtoh_count.load(Ordering::Relaxed)
1737 }
1738
1739 pub fn reset_untracked_metadata_dtoh_count(&self) {
1741 self.untracked_metadata_dtoh_count
1742 .store(0, Ordering::Relaxed);
1743 }
1744
1745 pub fn enable_strict_deterministic_d2h(&self) {
1762 self.strict_deterministic_d2h.store(true, Ordering::Relaxed);
1763 }
1764
1765 pub fn disable_strict_deterministic_d2h(&self) {
1767 self.strict_deterministic_d2h
1768 .store(false, Ordering::Relaxed);
1769 }
1770
1771 pub fn strict_deterministic_d2h_enabled(&self) -> bool {
1773 self.strict_deterministic_d2h.load(Ordering::Relaxed)
1774 }
1775
1776 pub fn deterministic_d2h_violation_count(&self) -> u64 {
1778 self.deterministic_d2h_violations.load(Ordering::Relaxed)
1779 }
1780
1781 pub fn reset_deterministic_d2h_violations(&self) {
1783 self.deterministic_d2h_violations
1784 .store(0, Ordering::Relaxed);
1785 }
1786
1787 pub(crate) fn check_deterministic_d2h(&self, op: &'static str, bytes: u64) -> Result<()> {
1793 if self.strict_deterministic_d2h.load(Ordering::Relaxed) {
1794 self.deterministic_d2h_violations
1795 .fetch_add(1, Ordering::Relaxed);
1796 return Err(XlogError::Execution(format!(
1797 "deterministic D2H gate: {} attempted to copy {} bytes from device to host",
1798 op, bytes
1799 )));
1800 }
1801 Ok(())
1802 }
1803
1804 fn dtoh_sync_copy_into_tracked<T: DeviceRepr, Src: DevicePtr<T>>(
1805 &self,
1806 src: &Src,
1807 dst: &mut [T],
1808 ) -> Result<()> {
1809 let bytes = std::mem::size_of::<T>()
1810 .checked_mul(dst.len())
1811 .ok_or_else(|| XlogError::Kernel("dtoh size overflow".to_string()))?;
1812 self.check_deterministic_d2h("dtoh_sync_copy_into_tracked", bytes as u64)?;
1813 self.transfer_tracker.record_dtoh(bytes as u64);
1814 self.device
1815 .inner()
1816 .dtoh_sync_copy_into(src, dst)
1817 .map_err(|e| XlogError::Kernel(format!("Failed to copy from device: {}", e)))
1818 }
1819
    /// Largest request, in bytes, that `dtoh_small_metadata_untracked` accepts.
    /// Larger requests are rejected with an error instead of being copied.
    pub const DTOH_SMALL_METADATA_MAX_BYTES: usize = 4096;
1825
1826 pub fn dtoh_small_metadata_untracked<T: DeviceRepr + Default + Copy>(
1851 &self,
1852 src: &crate::memory::TrackedCudaSlice<T>,
1853 count: usize,
1854 ) -> Result<Vec<T>> {
1855 let bytes = count.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
1856 XlogError::Kernel("dtoh_small_metadata_untracked: byte size overflow".to_string())
1857 })?;
1858 if bytes > Self::DTOH_SMALL_METADATA_MAX_BYTES {
1859 return Err(XlogError::Kernel(format!(
1860 "dtoh_small_metadata_untracked: requested {} bytes exceeds metadata cap of {} bytes \
1861 (this is metadata-only; use download_column* for data-plane transfers)",
1862 bytes,
1863 Self::DTOH_SMALL_METADATA_MAX_BYTES
1864 )));
1865 }
1866 if count > src.len() {
1867 return Err(XlogError::Kernel(format!(
1868 "dtoh_small_metadata_untracked: count={count} > src.len={}",
1869 src.len()
1870 )));
1871 }
1872 if count == 0 {
1873 return Ok(Vec::new());
1874 }
1875 let slice = src.try_slice(0..count).ok_or_else(|| {
1876 XlogError::Kernel(format!(
1877 "dtoh_small_metadata_untracked: try_slice(0..{count}) failed"
1878 ))
1879 })?;
1880 let mut buf: Vec<T> = vec![T::default(); count];
1881 self.untracked_metadata_dtoh_count
1882 .fetch_add(1, Ordering::Relaxed);
1883 self.device
1884 .inner()
1885 .dtoh_sync_copy_into(&slice, &mut buf)
1886 .map_err(|e| {
1887 XlogError::Kernel(format!("dtoh_small_metadata_untracked: copy failed: {}", e))
1888 })?;
1889 Ok(buf)
1890 }
1891
1892 pub fn dtoh_scalar_untracked<T: DeviceRepr + Default + Copy>(
1900 &self,
1901 src: &crate::memory::TrackedCudaSlice<T>,
1902 index: usize,
1903 ) -> Result<T> {
1904 if index >= src.len() {
1905 return Err(XlogError::Kernel(format!(
1906 "dtoh_scalar_untracked: index={} >= len={}",
1907 index,
1908 src.len()
1909 )));
1910 }
1911 let slice = src.try_slice(index..index + 1).ok_or_else(|| {
1912 XlogError::Kernel(format!(
1913 "dtoh_scalar_untracked: slice failed at index={}",
1914 index
1915 ))
1916 })?;
1917 let mut buf = [T::default()];
1918 self.untracked_metadata_dtoh_count
1919 .fetch_add(1, Ordering::Relaxed);
1920 self.device
1921 .inner()
1922 .dtoh_sync_copy_into(&slice, &mut buf)
1923 .map_err(|e| XlogError::Kernel(format!("dtoh_scalar_untracked: copy failed: {}", e)))?;
1924 Ok(buf[0])
1925 }
1926
1927 pub fn htod_sync_copy_into_tracked<T: DeviceRepr, Dst: cudarc::driver::DevicePtrMut<T>>(
1929 &self,
1930 src: &[T],
1931 dst: &mut Dst,
1932 ) -> Result<()> {
1933 let bytes = std::mem::size_of::<T>()
1934 .checked_mul(src.len())
1935 .ok_or_else(|| XlogError::Kernel("htod size overflow".to_string()))?;
1936 self.transfer_tracker.record_htod(bytes as u64);
1937 self.device
1938 .inner()
1939 .htod_sync_copy_into(src, dst)
1940 .map_err(|e| XlogError::Kernel(format!("Failed to copy to device: {}", e)))
1941 }
1942
1943 pub fn htod_sync_copy_tracked<T: DeviceRepr>(
1946 &self,
1947 src: &[T],
1948 ) -> Result<cudarc::driver::CudaSlice<T>> {
1949 let bytes = std::mem::size_of::<T>()
1950 .checked_mul(src.len())
1951 .ok_or_else(|| XlogError::Kernel("htod size overflow".to_string()))?;
1952 self.transfer_tracker.record_htod(bytes as u64);
1953 self.device
1954 .inner()
1955 .htod_sync_copy(src)
1956 .map_err(|e| XlogError::Kernel(format!("Failed to copy to device: {}", e)))
1957 }
1958
1959 pub fn htod_launch_metadata_sync_copy_into<
1962 T: DeviceRepr,
1963 Dst: cudarc::driver::DevicePtrMut<T>,
1964 >(
1965 &self,
1966 src: &[T],
1967 dst: &mut Dst,
1968 ) -> Result<()> {
1969 let bytes = std::mem::size_of::<T>()
1970 .checked_mul(src.len())
1971 .ok_or_else(|| XlogError::Kernel("launch metadata htod size overflow".to_string()))?;
1972 self.transfer_tracker
1973 .record_htod_launch_metadata(bytes as u64);
1974 self.device
1975 .inner()
1976 .htod_sync_copy_into(src, dst)
1977 .map_err(|e| {
1978 XlogError::Kernel(format!("Failed to copy launch metadata to device: {}", e))
1979 })
1980 }
1981
1982 pub(crate) fn htod_launch_metadata_async_copy_one<T: DeviceRepr>(
1985 &self,
1986 src: &T,
1987 dst: &TrackedCudaSlice<T>,
1988 stream: &CudaStream,
1989 context: &str,
1990 ) -> Result<()> {
1991 let bytes = std::mem::size_of::<T>();
1992 self.transfer_tracker
1993 .record_htod_launch_metadata(bytes as u64);
1994 unsafe {
1995 let res = cudarc::driver::sys::cuMemcpyHtoDAsync_v2(
1996 *dst.device_ptr(),
1997 src as *const T as *const c_void,
1998 bytes,
1999 stream.cu_stream(),
2000 );
2001 if res != cudarc::driver::sys::cudaError_enum::CUDA_SUCCESS {
2002 return Err(XlogError::Kernel(format!(
2003 "{context}: launch metadata H2D failed: {res:?}"
2004 )));
2005 }
2006 }
2007 Ok(())
2008 }
2009
2010 pub fn exclusive_scan_u32_inplace(
2039 &self,
2040 data: &mut crate::memory::TrackedCudaSlice<u32>,
2041 n: u32,
2042 ) -> Result<()> {
2043 if n as usize > data.len() {
2044 return Err(XlogError::Kernel(format!(
2045 "exclusive_scan_u32_inplace: n={} exceeds slice len={}",
2046 n,
2047 data.len()
2048 )));
2049 }
2050 self.multiblock_scan_u32_inplace(data, n)
2051 }
2052
    /// Scans the first `n` elements of `data` in place using the `scan_kernels` launch
    /// sequence, synchronizing the device after each launch.
    ///
    /// `exclusive_scan_u32_inplace` calls this after checking `n <= data.len()`; the
    /// bound is not re-checked here.
    ///
    /// Paths:
    /// - `n == 0`: no-op.
    /// - `n <= 256`: one block of 256 threads runs `multiblock_scan_phase2` on `(data, n)`.
    /// - `n > 256`:
    ///   1. `multiblock_scan_u32_phase1` runs over `num_blocks = ceil(n / 256)` blocks.
    ///      It writes one `u32` per block into a freshly allocated `block_sums`.
    ///   2. `block_sums` is scanned recursively with this same function.
    ///   3. `multiblock_scan_phase3` combines `block_sums` back into `data`.
    ///
    ///   NOTE(review): `block_sums` presumably holds per-block totals, which phase 3 adds
    ///   as per-block offsets; confirm against the kernel source.
    fn multiblock_scan_u32_inplace(
        &self,
        data: &mut crate::memory::TrackedCudaSlice<u32>,
        n: u32,
    ) -> Result<()> {
        if n == 0 {
            return Ok(());
        }

        let device = self.device.inner();
        // Threads per block; also the element count each block covers.
        let block_size = 256u32;

        if n <= block_size {
            // Single-block fast path: no block_sums buffer or recursion needed.
            let phase2_fn = device
                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
                .ok_or_else(|| {
                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
                })?;

            // SAFETY (review): relies on the kernel's parameter list matching `(data, n)`
            // and on it touching at most `n` elements of `data`. Neither is visible here;
            // confirm against the kernel source.
            unsafe {
                phase2_fn.clone().launch(
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (&mut *data, n),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase2 failed: {}", e)))?;

            self.device.synchronize()?;
            return Ok(());
        }

        let num_blocks = n.div_ceil(block_size);
        // One u32 slot per block, written by phase 1.
        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;

        let phase1_u32_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
            })?;

        // SAFETY (review): relies on the kernel's parameter list matching
        // `(data, block_sums, n)` and on it writing at most `num_blocks` entries of
        // `block_sums`; confirm against the kernel source.
        unsafe {
            phase1_u32_fn.clone().launch(
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &mut block_sums, n),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_u32_phase1 failed: {}", e)))?;
        self.device.synchronize()?;

        // Always true on this path: n > 256 implies num_blocks >= 2.
        if num_blocks > 1 {
            self.multiblock_scan_u32_inplace(&mut block_sums, num_blocks)?;
        }

        let phase3_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
            })?;

        // SAFETY (review): relies on the kernel's parameter list matching
        // `(data, block_sums, n)`; confirm against the kernel source.
        unsafe {
            phase3_fn.clone().launch(
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &block_sums, n),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;

        self.device.synchronize()?;
        Ok(())
    }
2138
    /// Scans the first `n` `u32` elements of `data` in place, enqueuing every
    /// kernel on `cu_stream`; this function performs no host synchronization.
    ///
    /// Strategy by element count (`block_size` is fixed at 256 threads):
    /// - `n == 0`: no-op.
    /// - `n <= 256`: one single-block launch of `MULTIBLOCK_SCAN_PHASE2`.
    /// - otherwise: allocate `ceil(n / 256)` `u32` block sums, register them with
    ///   `runtime` for first (write) use on `launch_stream`, launch
    ///   `MULTIBLOCK_SCAN_U32_PHASE1` with one block per 256-element chunk
    ///   (presumably writing each chunk's total into `block_sums`), recursively
    ///   scan `block_sums` when there is more than one chunk, launch
    ///   `MULTIBLOCK_SCAN_PHASE3` to combine, and finally report the scratch use
    ///   back to `runtime` via `finish_block_use`.
    ///
    /// Whether the scan is inclusive or exclusive is decided by the kernels,
    /// which are not visible here. NOTE(review): confirm against the `.cu` source.
    /// `launch_stream` is presumably the runtime's id for the same stream as
    /// `cu_stream`; TODO confirm.
    ///
    /// # Errors
    /// Returns `XlogError::Kernel` if a kernel lookup or launch fails, if
    /// `runtime` rejects the scratch buffer, or if the scratch allocation has no
    /// runtime block. Allocation failures from `self.memory.alloc` propagate
    /// via `?`.
    ///
    /// NOTE(review): the "no runtime block" check runs only after all kernels
    /// have been enqueued, so on that error `data` has already been modified.
    /// Also, the early `?` returns after `prepare_first_use` skip
    /// `finish_block_use` for `block_sums`. Confirm the runtime tolerates that.
    pub(crate) fn multiblock_scan_u32_inplace_on_stream(
        &self,
        data: &mut crate::memory::TrackedCudaSlice<u32>,
        n: u32,
        cu_stream: &cudarc::driver::CudaStream,
        launch_stream: crate::device_runtime::StreamId,
        runtime: &crate::device_runtime::XlogDeviceRuntime,
    ) -> Result<()> {
        if n == 0 {
            return Ok(());
        }
        let device = self.device.inner();
        let block_size = 256u32;

        // Small input: a single block scans everything; no scratch is needed.
        if n <= block_size {
            let phase2_fn = device
                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
                .ok_or_else(|| {
                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
                })?;
            unsafe {
                // SAFETY (assumed): the argument tuple must match the kernel's CUDA
                // signature and the kernel touches at most `n` elements; verify in
                // the `.cu` source.
                phase2_fn.clone().launch_on_stream(
                    cu_stream,
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (&mut *data, n),
                )
            }
            .map_err(|e| {
                XlogError::Kernel(format!("multiblock_scan_phase2 (on_stream) failed: {}", e))
            })?;
            return Ok(());
        }

        // One block-sum slot per 256-element chunk.
        let num_blocks = n.div_ceil(block_size);
        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
        // Register the fresh scratch allocation with the runtime before the
        // kernels below write it on `launch_stream`.
        runtime
            .prepare_first_use(
                &block_sums,
                launch_stream,
                crate::device_runtime::Access::Write,
            )
            .map_err(|e| {
                XlogError::Kernel(format!(
                    "multiblock_scan_u32_inplace_on_stream: prepare block_sums failed: {}",
                    e
                ))
            })?;

        let phase1_u32_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
            })?;
        unsafe {
            // SAFETY (assumed): `block_sums` holds `num_blocks` elements, one per
            // launched block; the argument tuple must match the CUDA signature.
            phase1_u32_fn.clone().launch_on_stream(
                cu_stream,
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &mut block_sums, n),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!(
                "multiblock_scan_u32_phase1 (on_stream) failed: {}",
                e
            ))
        })?;

        // With more than one chunk, the per-chunk totals themselves need a scan;
        // recurse on the same stream (each level allocates its own scratch).
        if num_blocks > 1 {
            self.multiblock_scan_u32_inplace_on_stream(
                &mut block_sums,
                num_blocks,
                cu_stream,
                launch_stream,
                runtime,
            )?;
        }

        let phase3_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
            })?;
        unsafe {
            // SAFETY (assumed): same grid as phase 1, so each block reads its own
            // `block_sums` slot; the argument tuple must match the CUDA signature.
            phase3_fn.clone().launch_on_stream(
                cu_stream,
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &block_sums, n),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!("multiblock_scan_phase3 (on_stream) failed: {}", e))
        })?;

        // Tell the runtime the scratch write on `launch_stream` is complete
        // (enqueued) before `block_sums` is dropped at the end of this scope.
        if let Some(b) = block_sums.runtime_block() {
            runtime
                .finish_block_use(
                    crate::device_runtime::BlockId::from_block(b),
                    launch_stream,
                    crate::device_runtime::Access::Write,
                )
                .map_err(|e| {
                    XlogError::Kernel(format!(
                        "multiblock_scan_u32_inplace_on_stream: finish_block_use \
                         for intermediate block_sums failed: {}",
                        e
                    ))
                })?;
        } else {
            return Err(XlogError::Kernel(
                "multiblock_scan_u32_inplace_on_stream: intermediate block_sums has no \
                 runtime block — caller must use a runtime-backed manager"
                    .to_string(),
            ));
        }
        Ok(())
    }
2295
2296 pub(crate) fn multiblock_scan_u32_scratch_for_len(
2299 &self,
2300 mut n: u32,
2301 ) -> Result<MultiblockScanScratchU32> {
2302 let block_size = 256u32;
2303 let mut levels = Vec::new();
2304 while n > block_size {
2305 let num_blocks = n.div_ceil(block_size);
2306 levels.push(self.memory.alloc::<u32>(num_blocks as usize)?);
2307 n = num_blocks;
2308 }
2309 Ok(MultiblockScanScratchU32 { levels })
2310 }
2311
2312 pub(crate) fn multiblock_scan_u32_inplace_on_stream_with_scratch(
2319 &self,
2320 data: &mut crate::memory::TrackedCudaSlice<u32>,
2321 n: u32,
2322 cu_stream: &cudarc::driver::CudaStream,
2323 scratch: &mut MultiblockScanScratchU32,
2324 ) -> Result<()> {
2325 self.multiblock_scan_u32_inplace_on_stream_with_scratch_levels(
2326 data,
2327 n,
2328 cu_stream,
2329 &mut scratch.levels,
2330 )
2331 }
2332
    /// Recursive worker behind `multiblock_scan_u32_inplace_on_stream_with_scratch`.
    /// It scans the first `n` `u32` elements of `data` in place on `cu_stream`.
    /// Block-sum storage comes from the pre-allocated `scratch_levels` instead of
    /// the allocator. The error strings call this the "graph scratch" path,
    /// presumably for use under CUDA graph capture; TODO confirm.
    ///
    /// `scratch_levels[0]` receives this level's block sums. The remaining levels
    /// are handed to the recursive call that scans those sums. This function
    /// does no host synchronization and no runtime
    /// `prepare_first_use`/`finish_block_use` bookkeeping.
    ///
    /// # Errors
    /// Returns `XlogError::Kernel` if a scratch level is missing, if a level holds
    /// fewer than `ceil(n / 256)` elements, or if a kernel lookup or launch fails.
    fn multiblock_scan_u32_inplace_on_stream_with_scratch_levels(
        &self,
        data: &mut crate::memory::TrackedCudaSlice<u32>,
        n: u32,
        cu_stream: &cudarc::driver::CudaStream,
        scratch_levels: &mut [TrackedCudaSlice<u32>],
    ) -> Result<()> {
        if n == 0 {
            return Ok(());
        }
        let device = self.device.inner();
        let block_size = 256u32;

        // Small input: one block scans everything and no scratch level is consumed.
        if n <= block_size {
            let phase2_fn = device
                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
                .ok_or_else(|| {
                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
                })?;
            unsafe {
                // SAFETY (assumed): the argument tuple must match the kernel's CUDA
                // signature and the kernel touches at most `n` elements.
                phase2_fn.clone().launch_on_stream(
                    cu_stream,
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (&mut *data, n),
                )
            }
            .map_err(|e| {
                XlogError::Kernel(format!(
                    "multiblock_scan_phase2 (graph scratch) failed: {}",
                    e
                ))
            })?;
            return Ok(());
        }

        let num_blocks = n.div_ceil(block_size);
        // Take this level's block-sum buffer; `rest` feeds the recursive level.
        let (block_sums, rest) = scratch_levels.split_first_mut().ok_or_else(|| {
            XlogError::Kernel(format!(
                "multiblock_scan_u32_inplace_on_stream_with_scratch: missing scratch level \
                 for n={n}, num_blocks={num_blocks}"
            ))
        })?;
        // The kernels index one slot per launched block, so the level must be
        // at least `num_blocks` long.
        if block_sums.len() < num_blocks as usize {
            return Err(XlogError::Kernel(format!(
                "multiblock_scan_u32_inplace_on_stream_with_scratch: scratch level too small \
                 (have {}, need {})",
                block_sums.len(),
                num_blocks
            )));
        }

        let phase1_u32_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
            })?;
        unsafe {
            // SAFETY (assumed): `block_sums` was checked above to hold at least
            // `num_blocks` elements; the argument tuple must match the CUDA signature.
            phase1_u32_fn.clone().launch_on_stream(
                cu_stream,
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &mut *block_sums, n),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!(
                "multiblock_scan_u32_phase1 (graph scratch) failed: {}",
                e
            ))
        })?;

        // More than one chunk: scan the per-chunk totals using the next scratch level.
        if num_blocks > 1 {
            self.multiblock_scan_u32_inplace_on_stream_with_scratch_levels(
                block_sums, num_blocks, cu_stream, rest,
            )?;
        }

        let phase3_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
            })?;
        unsafe {
            // SAFETY (assumed): same grid as phase 1; the argument tuple must match
            // the CUDA signature.
            phase3_fn.clone().launch_on_stream(
                cu_stream,
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &*block_sums, n),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!(
                "multiblock_scan_phase3 (graph scratch) failed: {}",
                e
            ))
        })?;
        Ok(())
    }
2444
    /// View-based twin of `multiblock_scan_u32_inplace_on_stream`. It scans the
    /// first `n` `u32` elements of the borrowed device view `data` in place,
    /// enqueuing every kernel on `cu_stream` with no host synchronization.
    ///
    /// - `n == 0`: no-op.
    /// - `n <= 256`: one single-block launch of `MULTIBLOCK_SCAN_PHASE2`.
    /// - otherwise: allocate `ceil(n / 256)` `u32` block sums and register them
    ///   with `runtime` on `launch_stream`. Then launch `MULTIBLOCK_SCAN_U32_PHASE1`.
    ///   When there is more than one chunk, scan the block sums with the
    ///   `TrackedCudaSlice` variant `multiblock_scan_u32_inplace_on_stream`.
    ///   Finally launch `MULTIBLOCK_SCAN_PHASE3` and call `finish_block_use`.
    ///
    /// Only the scratch `block_sums` goes through runtime bookkeeping here; `data`
    /// itself is not registered with `runtime` by this function.
    ///
    /// # Errors
    /// Returns `XlogError::Kernel` if a kernel lookup or launch fails, if `runtime`
    /// rejects the scratch buffer, or if the scratch allocation has no runtime
    /// block.
    ///
    /// NOTE(review): as in the slice variant, the "no runtime block" error is
    /// raised only after all kernels are enqueued, and early `?` returns skip
    /// `finish_block_use`. Confirm this is intended.
    pub(crate) fn multiblock_scan_u32_view_inplace_on_stream(
        &self,
        data: &mut CudaViewMut<'_, u32>,
        n: u32,
        cu_stream: &cudarc::driver::CudaStream,
        launch_stream: crate::device_runtime::StreamId,
        runtime: &crate::device_runtime::XlogDeviceRuntime,
    ) -> Result<()> {
        if n == 0 {
            return Ok(());
        }
        let device = self.device.inner();
        let block_size = 256u32;

        // Small input: one block scans everything; no scratch needed.
        if n <= block_size {
            let phase2_fn = device
                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
                .ok_or_else(|| {
                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
                })?;
            unsafe {
                // SAFETY (assumed): the argument tuple must match the kernel's CUDA
                // signature and the kernel touches at most `n` elements.
                phase2_fn.clone().launch_on_stream(
                    cu_stream,
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (data, n),
                )
            }
            .map_err(|e| {
                XlogError::Kernel(format!(
                    "multiblock_scan_phase2 (view on_stream) failed: {}",
                    e
                ))
            })?;
            return Ok(());
        }

        let num_blocks = n.div_ceil(block_size);
        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
        // Register the fresh scratch allocation before kernels write it.
        runtime
            .prepare_first_use(
                &block_sums,
                launch_stream,
                crate::device_runtime::Access::Write,
            )
            .map_err(|e| {
                XlogError::Kernel(format!(
                    "multiblock_scan_u32_view_inplace_on_stream: prepare block_sums failed: {}",
                    e
                ))
            })?;

        let phase1_u32_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
            })?;
        unsafe {
            // SAFETY (assumed): `block_sums` holds one slot per launched block; the
            // argument tuple must match the CUDA signature.
            phase1_u32_fn.clone().launch_on_stream(
                cu_stream,
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &mut block_sums, n),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!(
                "multiblock_scan_u32_phase1 (view on_stream) failed: {}",
                e
            ))
        })?;

        // Scan the per-chunk totals. `block_sums` is a tracked slice, so the
        // slice-based variant handles the recursion.
        if num_blocks > 1 {
            self.multiblock_scan_u32_inplace_on_stream(
                &mut block_sums,
                num_blocks,
                cu_stream,
                launch_stream,
                runtime,
            )?;
        }

        let phase3_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
            })?;
        unsafe {
            // SAFETY (assumed): same grid as phase 1; the argument tuple must match
            // the CUDA signature.
            phase3_fn.clone().launch_on_stream(
                cu_stream,
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &block_sums, n),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!(
                "multiblock_scan_phase3 (view on_stream) failed: {}",
                e
            ))
        })?;

        // Report the completed (enqueued) scratch write to the runtime before
        // `block_sums` is dropped.
        if let Some(b) = block_sums.runtime_block() {
            runtime
                .finish_block_use(
                    crate::device_runtime::BlockId::from_block(b),
                    launch_stream,
                    crate::device_runtime::Access::Write,
                )
                .map_err(|e| {
                    XlogError::Kernel(format!(
                        "multiblock_scan_u32_view_inplace_on_stream: finish_block_use \
                         for intermediate block_sums failed: {}",
                        e
                    ))
                })?;
        } else {
            return Err(XlogError::Kernel(
                "multiblock_scan_u32_view_inplace_on_stream: intermediate block_sums has no \
                 runtime block — caller must use a runtime-backed manager"
                    .to_string(),
            ));
        }
        Ok(())
    }
2593
    /// Blocking scan of the first `n` `u32` elements of the device view `data`,
    /// done in place.
    ///
    /// Kernels are launched with `launch` rather than `launch_on_stream`.
    /// `self.device.synchronize()` is called after the single-block path, after
    /// phase 1, and after phase 3, so the scan is complete when this returns `Ok`.
    /// No runtime bookkeeping is performed for the scratch `block_sums`.
    ///
    /// - `n == 0`: no-op.
    /// - `n <= 256`: one single-block `MULTIBLOCK_SCAN_PHASE2` launch.
    /// - otherwise: phase 1 into `ceil(n / 256)` block sums, a recursive
    ///   `multiblock_scan_u32_inplace` over those sums when there is more than
    ///   one chunk, then phase 3.
    ///
    /// # Errors
    /// Returns `XlogError::Kernel` on kernel lookup or launch failure. Allocation
    /// and synchronization errors propagate via `?`.
    fn multiblock_scan_u32_view_inplace(
        &self,
        data: &mut CudaViewMut<'_, u32>,
        n: u32,
    ) -> Result<()> {
        if n == 0 {
            return Ok(());
        }

        let device = self.device.inner();
        let block_size = 256u32;

        // Small input: one block scans everything; no scratch needed.
        if n <= block_size {
            let phase2_fn = device
                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
                .ok_or_else(|| {
                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
                })?;

            unsafe {
                // SAFETY (assumed): the argument tuple must match the kernel's CUDA
                // signature and the kernel touches at most `n` elements.
                phase2_fn.clone().launch(
                    LaunchConfig {
                        grid_dim: (1, 1, 1),
                        block_dim: (block_size, 1, 1),
                        shared_mem_bytes: 0,
                    },
                    (data, n),
                )
            }
            .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase2 failed: {}", e)))?;

            self.device.synchronize()?;
            return Ok(());
        }

        let num_blocks = n.div_ceil(block_size);
        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;

        let phase1_u32_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
            })?;

        unsafe {
            // SAFETY (assumed): `block_sums` holds one slot per launched block; the
            // argument tuple must match the CUDA signature.
            phase1_u32_fn.clone().launch(
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &mut block_sums, n),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_u32_phase1 failed: {}", e)))?;
        self.device.synchronize()?;

        // Scan the per-chunk totals with the blocking slice-based variant.
        if num_blocks > 1 {
            self.multiblock_scan_u32_inplace(&mut block_sums, num_blocks)?;
        }

        let phase3_fn = device
            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
            .ok_or_else(|| {
                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
            })?;

        unsafe {
            // SAFETY (assumed): same grid as phase 1; the argument tuple must match
            // the CUDA signature.
            phase3_fn.clone().launch(
                LaunchConfig {
                    grid_dim: (num_blocks, 1, 1),
                    block_dim: (block_size, 1, 1),
                    shared_mem_bytes: 0,
                },
                (&mut *data, &block_sums, n),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;

        self.device.synchronize()?;
        Ok(())
    }
2679
2680 pub fn device_row_count(&self, buffer: &CudaBuffer) -> Result<usize> {
2685 if let Some(n) = buffer.cached_row_count() {
2686 return Ok(n as usize);
2687 }
2688 let mut host_rows = [0u32];
2689 self.device
2690 .inner()
2691 .dtoh_sync_copy_into(buffer.num_rows_device(), &mut host_rows)
2692 .map_err(|e| XlogError::Kernel(format!("Failed to read row count: {}", e)))?;
2693 buffer.set_cached_row_count_if_unset(host_rows[0]);
2694 Ok(host_rows[0] as usize)
2695 }
2696
2697 pub fn validated_logical_row_count(&self, buffer: &CudaBuffer) -> Result<usize> {
2702 let logical_rows = self.device_row_count(buffer)?;
2703 validate_logical_row_count(buffer.num_rows(), logical_rows)
2704 }
2705
2706 fn clone_device_row_count(&self, buffer: &CudaBuffer) -> Result<TrackedCudaSlice<u32>> {
2707 let mut d_num_rows = self.memory.alloc::<u32>(1)?;
2708 self.device
2709 .inner()
2710 .dtod_copy(buffer.num_rows_device(), &mut d_num_rows)
2711 .map_err(|e| XlogError::Kernel(format!("Failed to copy row count: {}", e)))?;
2712 Ok(d_num_rows)
2713 }
2714
2715 fn upload_device_row_count(&self, row_count: u32) -> Result<TrackedCudaSlice<u32>> {
2716 let mut d_num_rows = self.memory.alloc::<u32>(1)?;
2717 self.htod_launch_metadata_sync_copy_into(&[row_count], &mut d_num_rows)
2718 .map_err(|e| XlogError::Kernel(format!("Failed to upload row count: {}", e)))?;
2719 Ok(d_num_rows)
2720 }
2721
2722 fn buffer_from_columns_with_device_count(
2723 &self,
2724 columns: Vec<CudaColumn>,
2725 row_cap: u64,
2726 schema: Schema,
2727 src: &CudaBuffer,
2728 ) -> Result<CudaBuffer> {
2729 let d_num_rows = self.clone_device_row_count(src)?;
2730 Ok(CudaBuffer::from_columns(
2731 columns, row_cap, d_num_rows, schema,
2732 ))
2733 }
2734
2735 fn column_bytes_view<'a>(
2736 &self,
2737 col: &'a CudaColumn,
2738 num_bytes: usize,
2739 ) -> Result<RawCudaView<'a, u8>> {
2740 if col.num_bytes() < num_bytes {
2741 return Err(XlogError::Kernel(format!(
2742 "Column has {} bytes but {} required",
2743 col.num_bytes(),
2744 num_bytes
2745 )));
2746 }
2747 let ptr = *col.device_ptr();
2748 Ok(RawCudaView {
2749 ptr,
2750 len: num_bytes,
2751 stream: col.stream().clone(),
2752 source_block: col.runtime_block(),
2753 _marker: PhantomData,
2754 })
2755 }
2756
2757 fn bytes_as_u32_view<'a>(
2758 &self,
2759 bytes: &'a TrackedCudaSlice<u8>,
2760 num_elements: usize,
2761 ) -> Result<RawCudaView<'a, u32>> {
2762 let required_bytes = num_elements * std::mem::size_of::<u32>();
2763 if bytes.len() < required_bytes {
2764 return Err(XlogError::Kernel(format!(
2765 "Packed keys have {} bytes but {} required for {} u32 elements",
2766 bytes.len(),
2767 required_bytes,
2768 num_elements
2769 )));
2770 }
2771 let ptr = *bytes.device_ptr();
2772 if !(ptr as usize).is_multiple_of(std::mem::align_of::<u32>()) {
2773 return Err(XlogError::Kernel(
2774 "Packed keys device pointer is not u32-aligned".to_string(),
2775 ));
2776 }
2777 Ok(RawCudaView {
2778 ptr,
2779 len: num_elements,
2780 stream: bytes.stream().clone(),
2781 source_block: bytes.runtime_block(),
2782 _marker: PhantomData,
2783 })
2784 }
2785
2786 fn column_as_u32_view<'a>(
2788 &self,
2789 col: &'a CudaColumn,
2790 num_elements: usize,
2791 ) -> Result<RawCudaView<'a, u32>> {
2792 let required_bytes = num_elements * std::mem::size_of::<u32>();
2793 if col.num_bytes() < required_bytes {
2794 return Err(XlogError::Kernel(format!(
2795 "Column has {} bytes but {} required for {} u32 elements",
2796 col.num_bytes(),
2797 required_bytes,
2798 num_elements
2799 )));
2800 }
2801 let ptr = *col.device_ptr();
2802 if !(ptr as usize).is_multiple_of(std::mem::align_of::<u32>()) {
2803 return Err(XlogError::Kernel(
2804 "Column device pointer is not u32-aligned".to_string(),
2805 ));
2806 }
2807 Ok(RawCudaView {
2808 ptr,
2809 len: num_elements,
2810 stream: col.stream().clone(),
2811 source_block: col.runtime_block(),
2812 _marker: PhantomData,
2813 })
2814 }
2815
2816 fn column_as_u64_view<'a>(
2817 &self,
2818 col: &'a CudaColumn,
2819 num_elements: usize,
2820 ) -> Result<RawCudaView<'a, u64>> {
2821 let required_bytes = num_elements * std::mem::size_of::<u64>();
2822 if col.num_bytes() < required_bytes {
2823 return Err(XlogError::Kernel(format!(
2824 "Column has {} bytes but {} required for {} u64 elements",
2825 col.num_bytes(),
2826 required_bytes,
2827 num_elements
2828 )));
2829 }
2830 let ptr = *col.device_ptr();
2831 if !(ptr as usize).is_multiple_of(std::mem::align_of::<u64>()) {
2832 return Err(XlogError::Kernel(
2833 "Column device pointer is not u64-aligned".to_string(),
2834 ));
2835 }
2836 Ok(RawCudaView {
2837 ptr,
2838 len: num_elements,
2839 stream: col.stream().clone(),
2840 source_block: col.runtime_block(),
2841 _marker: PhantomData,
2842 })
2843 }
2844
2845 fn column_as_f64_view<'a>(
2847 &self,
2848 col: &'a CudaColumn,
2849 num_elements: usize,
2850 ) -> Result<RawCudaView<'a, f64>> {
2851 let required_bytes = num_elements * std::mem::size_of::<f64>();
2852 if col.num_bytes() < required_bytes {
2853 return Err(XlogError::Kernel(format!(
2854 "Column has {} bytes but {} required for {} f64 elements",
2855 col.num_bytes(),
2856 required_bytes,
2857 num_elements
2858 )));
2859 }
2860 let ptr = *col.device_ptr();
2861 if !(ptr as usize).is_multiple_of(std::mem::align_of::<f64>()) {
2862 return Err(XlogError::Kernel(
2863 "Column device pointer is not f64-aligned".to_string(),
2864 ));
2865 }
2866 Ok(RawCudaView {
2867 ptr,
2868 len: num_elements,
2869 stream: col.stream().clone(),
2870 source_block: col.runtime_block(),
2871 _marker: PhantomData,
2872 })
2873 }
2874
2875 pub fn create_empty_buffer(&self, schema: Schema) -> Result<CudaBuffer> {
2886 let mut columns = Vec::with_capacity(schema.arity());
2887 for _ in 0..schema.arity() {
2888 columns.push(self.memory.alloc::<u8>(0)?.into());
2890 }
2891 self.buffer_from_columns(columns, 0, schema)
2892 }
2893
2894 pub fn create_zero_arity_buffer(&self, schema: Schema, rows: u32) -> Result<CudaBuffer> {
2902 debug_assert_eq!(
2903 schema.arity(),
2904 0,
2905 "create_zero_arity_buffer requires arity 0"
2906 );
2907 self.buffer_from_columns(Vec::new(), u64::from(rows), schema)
2908 }
2909
2910 pub(crate) fn buffer_from_columns(
2911 &self,
2912 columns: Vec<CudaColumn>,
2913 row_cap: u64,
2914 schema: Schema,
2915 ) -> Result<CudaBuffer> {
2916 let row_u32 = u32::try_from(row_cap)
2917 .map_err(|_| XlogError::Kernel(format!("Row capacity {} exceeds u32::MAX", row_cap)))?;
2918 let mut d_num_rows = self.memory.alloc::<u32>(1)?;
2919 self.htod_launch_metadata_sync_copy_into(&[row_u32], &mut d_num_rows)
2920 .map_err(|e| XlogError::Kernel(format!("Failed to set row count: {}", e)))?;
2921 Ok(CudaBuffer::from_columns_with_host_count(
2922 columns, row_cap, d_num_rows, schema, row_u32,
2923 ))
2924 }
2925
2926 fn combine_schemas(&self, left: &Schema, right: &Schema) -> Schema {
2928 let mut columns = left.columns.clone();
2929 columns.extend(right.columns.iter().cloned());
2930 let mut sort_labels = left.sort_labels().to_vec();
2931 sort_labels.extend(right.sort_labels().iter().cloned());
2932 Schema::new(columns)
2933 .with_sort_labels(sort_labels)
2934 .expect("combined schema sort labels match column arity")
2935 }
2936
2937 fn schemas_type_compatible(&self, a: &Schema, b: &Schema) -> bool {
2942 if a.arity() != b.arity() {
2943 return false;
2944 }
2945 for i in 0..a.arity() {
2946 if a.column_type(i) != b.column_type(i) {
2947 return false;
2948 }
2949 }
2950 true
2951 }
2952}
2953
2954#[cfg(test)]
2955mod tests {
2956 use super::*;
2957 use crate::device_runtime::{
2958 AsyncCudaResource, DeviceMemoryResource, GlobalDeviceBudget, LoggingResource, NullSink,
2959 StreamPool, XlogDeviceRuntime,
2960 };
2961 use xlog_core::{AggOp, MemoryBudget, ScalarType};
2962
2963 fn has_cuda_device() -> bool {
2964 CudaDevice::new(0).is_ok()
2965 }
2966
2967 #[test]
2968 fn test_kernel_artifact_locator_precedence_order() {
2969 use super::kernel_paths::KernelArtifactLocator;
2970 use std::fs;
2971 use std::path::PathBuf;
2972
2973 let root = std::env::temp_dir().join(format!(
2974 "xlog-kernel-paths-{}-{}",
2975 std::process::id(),
2976 std::time::SystemTime::now()
2977 .duration_since(std::time::UNIX_EPOCH)
2978 .expect("system clock before UNIX_EPOCH")
2979 .as_nanos()
2980 ));
2981 let cubin_dir = root.join("cubin");
2982 let package_dir = root.join("bin").join("kernels");
2983 let out_dir = root.join("out");
2984 fs::create_dir_all(&cubin_dir).expect("create cubin dir");
2985 fs::create_dir_all(&package_dir).expect("create package kernels dir");
2986 fs::create_dir_all(&out_dir).expect("create out dir");
2987
2988 let name = "xlog_join";
2989 let cc = 75;
2990 let cubin_path = cubin_dir.join(format!("{name}.sm_{cc}.cubin"));
2991 let package_path = package_dir.join(format!("{name}.sm_{cc}.cubin"));
2992 let out_path = out_dir.join(format!("{name}.sm_{cc}.cubin"));
2993 fs::write(&cubin_path, b"cubin").expect("write cubin file");
2994 fs::write(&package_path, b"package").expect("write package file");
2995 fs::write(&out_path, b"out").expect("write out file");
2996
2997 let locator = KernelArtifactLocator::new(
2998 Some(cubin_dir.clone()),
2999 Some(package_dir.clone()),
3000 Some(out_dir.clone()),
3001 );
3002
3003 let (path, is_cubin) = locator
3004 .resolve_module_path(name, cc)
3005 .expect("expected a kernel artifact");
3006 assert_eq!(path, cubin_path);
3007 assert!(is_cubin);
3008
3009 fs::remove_file(&cubin_path).expect("remove cubin file");
3010 let (path, is_cubin) = locator
3011 .resolve_module_path(name, cc)
3012 .expect("expected package kernel artifact");
3013 assert_eq!(path, package_path);
3014 assert!(is_cubin);
3015
3016 fs::remove_file(&package_path).expect("remove package file");
3017 let (path, is_cubin) = locator
3018 .resolve_module_path(name, cc)
3019 .expect("expected out dir kernel artifact");
3020 assert_eq!(path, out_path);
3021 assert!(is_cubin);
3022
3023 let _ = fs::remove_dir_all(PathBuf::from(&root));
3024 }
3025
3026 #[test]
3027 fn test_module_resolution_finds_portable_ptx() {
3028 for name in crate::kernel_manifest_data::KERNEL_CU_NAMES {
3031 let result = resolve_module_path(name, 999);
3032 assert!(
3033 result.is_some(),
3034 "resolve_module_path({name}, 999) should find portable PTX"
3035 );
3036 let (path, is_cubin) = result.unwrap();
3037 assert!(
3038 !is_cubin,
3039 "{name}: expected portable PTX fallback, got cubin"
3040 );
3041 assert!(
3042 path.to_str().unwrap().ends_with(".portable.ptx"),
3043 "{name}: path should end with .portable.ptx, got {:?}",
3044 path
3045 );
3046 }
3047 }
3048
3049 #[test]
3050 fn test_module_resolution_falls_back_to_embedded_portable_ptx() {
3051 use super::kernel_paths::KernelArtifactLocator;
3052
3053 let locator = KernelArtifactLocator::new(None, None, None);
3054 for name in crate::kernel_manifest_data::KERNEL_CU_NAMES {
3055 let sources = resolve_module_sources_with_locator(name, 999, &locator);
3056 assert_eq!(
3057 sources.len(),
3058 1,
3059 "{name}: expected only embedded portable PTX fallback"
3060 );
3061
3062 match &sources[0] {
3063 KernelModuleSource::EmbeddedPortablePtx { ptx } => {
3064 assert!(
3065 ptx.contains(".entry"),
3066 "{name}: embedded PTX should contain CUDA entry points"
3067 );
3068 }
3069 KernelModuleSource::File { path, .. } => {
3070 panic!(
3071 "{name}: expected embedded portable PTX fallback, got file {}",
3072 path.display()
3073 );
3074 }
3075 }
3076 }
3077 }
3078
3079 #[test]
3080 fn test_embedded_portable_ptx_manifest_matches_kernel_manifest() {
3081 let embedded_names: std::collections::BTreeSet<_> =
3082 crate::embedded_kernel_data::EMBEDDED_PORTABLE_PTX
3083 .iter()
3084 .map(|artifact| artifact.name)
3085 .collect();
3086 let manifest_names: std::collections::BTreeSet<_> =
3087 crate::kernel_manifest_data::KERNEL_CU_NAMES
3088 .iter()
3089 .copied()
3090 .collect();
3091
3092 assert_eq!(
3093 embedded_names, manifest_names,
3094 "embedded portable PTX table should cover every runtime kernel module"
3095 );
3096 }
3097
3098 #[test]
3099 fn test_kernel_provider_creation() {
3100 if !has_cuda_device() {
3101 eprintln!("Skipping test: no CUDA device available");
3102 return;
3103 }
3104
3105 let device = Arc::new(CudaDevice::new(0).expect("Failed to create device"));
3106 let budget = MemoryBudget::with_limit(1024 * 1024 * 1024); let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
3108
3109 let provider = CudaKernelProvider::new(device.clone(), memory.clone());
3110 assert!(
3111 provider.is_ok(),
3112 "Failed to create kernel provider: {:?}",
3113 provider.err()
3114 );
3115
3116 let provider = provider.unwrap();
3117 assert!(Arc::ptr_eq(provider.device(), &device));
3118 assert!(Arc::ptr_eq(provider.memory(), &memory));
3119 }
3120
3121 #[test]
3122 fn test_kernel_functions_accessible() {
3123 if !has_cuda_device() {
3124 eprintln!("Skipping test: no CUDA device available");
3125 return;
3126 }
3127
3128 let device = Arc::new(CudaDevice::new(0).expect("Failed to create device"));
3129 let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
3130 let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
3131
3132 let _provider =
3133 CudaKernelProvider::new(device.clone(), memory).expect("Failed to create provider");
3134
3135 let inner = device.inner();
3137
3138 let build_fn = inner.get_func(JOIN_MODULE, join_kernels::HASH_JOIN_BUILD);
3140 assert!(
3141 build_fn.is_some(),
3142 "hash_join_build function should be accessible"
3143 );
3144
3145 let probe_fn = inner.get_func(JOIN_MODULE, join_kernels::HASH_JOIN_PROBE);
3146 assert!(
3147 probe_fn.is_some(),
3148 "hash_join_probe function should be accessible"
3149 );
3150
3151 let mark_fn = inner.get_func(DEDUP_MODULE, dedup_kernels::MARK_DUPLICATES);
3153 assert!(
3154 mark_fn.is_some(),
3155 "mark_duplicates function should be accessible"
3156 );
3157
3158 let compact_fn = inner.get_func(DEDUP_MODULE, dedup_kernels::COMPACT_ROWS);
3159 assert!(
3160 compact_fn.is_some(),
3161 "compact_rows function should be accessible"
3162 );
3163
3164 let boundaries_fn =
3166 inner.get_func(GROUPBY_MODULE, groupby_kernels::DETECT_GROUP_BOUNDARIES);
3167 assert!(
3168 boundaries_fn.is_some(),
3169 "detect_group_boundaries function should be accessible"
3170 );
3171
3172 let count_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_COUNT);
3173 assert!(
3174 count_fn.is_some(),
3175 "groupby_count function should be accessible"
3176 );
3177
3178 let sum_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM);
3179 assert!(
3180 sum_fn.is_some(),
3181 "groupby_sum function should be accessible"
3182 );
3183
3184 let min_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN);
3185 assert!(
3186 min_fn.is_some(),
3187 "groupby_min function should be accessible"
3188 );
3189
3190 let max_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX);
3191 assert!(
3192 max_fn.is_some(),
3193 "groupby_max function should be accessible"
3194 );
3195
3196 let xgcf_forward = inner.get_func(CIRCUIT_MODULE, "xgcf_forward_level");
3198 assert!(
3199 xgcf_forward.is_some(),
3200 "xgcf_forward_level function should be accessible"
3201 );
3202
3203 let xgcf_backward_propagate =
3204 inner.get_func(CIRCUIT_MODULE, "xgcf_backward_level_propagate");
3205 assert!(
3206 xgcf_backward_propagate.is_some(),
3207 "xgcf_backward_level_propagate function should be accessible"
3208 );
3209
3210 let xgcf_backward_decision_grad =
3211 inner.get_func(CIRCUIT_MODULE, "xgcf_backward_level_decision_grad");
3212 assert!(
3213 xgcf_backward_decision_grad.is_some(),
3214 "xgcf_backward_level_decision_grad function should be accessible"
3215 );
3216
3217 let xgcf_backward_lit_grad = inner.get_func(CIRCUIT_MODULE, "xgcf_backward_level_lit_grad");
3218 assert!(
3219 xgcf_backward_lit_grad.is_some(),
3220 "xgcf_backward_level_lit_grad function should be accessible"
3221 );
3222
3223 let neural_fill = inner.get_func("xlog_neural", "neural_fill_ad_chain_f32");
3225 assert!(
3226 neural_fill.is_some(),
3227 "neural_fill_ad_chain_f32 function should be accessible"
3228 );
3229 let neural_scatter = inner.get_func("xlog_neural", "neural_scatter_ad_chain_grads_f32");
3230 assert!(
3231 neural_scatter.is_some(),
3232 "neural_scatter_ad_chain_grads_f32 function should be accessible"
3233 );
3234 }
3235
3236 #[test]
3237 fn test_module_names_unique() {
3238 assert_ne!(JOIN_MODULE, DEDUP_MODULE);
3240 assert_ne!(JOIN_MODULE, GROUPBY_MODULE);
3241 assert_ne!(DEDUP_MODULE, GROUPBY_MODULE);
3242 }
3243
3244 fn create_test_provider() -> Option<CudaKernelProvider> {
3246 if !has_cuda_device() {
3247 return None;
3248 }
3249 let device = Arc::new(CudaDevice::new(0).ok()?);
3250 let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
3251 let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
3252 CudaKernelProvider::new(device, memory).ok()
3253 }
3254
3255 fn create_test_provider_with_runtime() -> Option<(CudaKernelProvider, Arc<XlogDeviceRuntime>)> {
3256 if !has_cuda_device() {
3257 return None;
3258 }
3259 let device = Arc::new(CudaDevice::new(0).ok()?);
3260 let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
3261 let sink = Arc::new(NullSink::new());
3262 let async_resource: Box<dyn DeviceMemoryResource + Send + Sync> = Box::new(
3263 AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool)),
3264 );
3265 let logging: Box<dyn DeviceMemoryResource + Send + Sync> =
3266 Box::new(LoggingResource::new(async_resource, sink));
3267 let budget: Box<dyn DeviceMemoryResource + Send + Sync> =
3268 Box::new(GlobalDeviceBudget::new(logging, 1024 * 1024 * 1024));
3269 let runtime = Arc::new(XlogDeviceRuntime::with_resource(
3270 Arc::clone(&device),
3271 0,
3272 pool,
3273 budget,
3274 ));
3275 let memory = Arc::new(GpuMemoryManager::with_runtime(
3276 Arc::clone(&device),
3277 MemoryBudget::with_limit(1024 * 1024 * 1024),
3278 Arc::clone(&runtime),
3279 ));
3280 let provider = CudaKernelProvider::with_runtime(device, memory).ok()?;
3281 Some((provider, runtime))
3282 }
3283
3284 #[test]
3285 fn test_recorded_join_index_build_runs_on_runtime_stream() {
3286 let (provider, runtime) = match create_test_provider_with_runtime() {
3287 Some(fixture) => fixture,
3288 None => {
3289 eprintln!("Skipping test: no CUDA device available");
3290 return;
3291 }
3292 };
3293 let stream = runtime.stream_pool().acquire().expect("recorded stream");
3294 let left = create_test_buffer(&provider, &[1, 2, 3, 4], "key");
3295 let right = create_test_buffer(&provider, &[1, 2, 3, 4], "key");
3296
3297 let index = provider
3298 .build_join_index_v2_recorded(&right, &[0], stream)
3299 .expect("recorded join-index build");
3300 let joined = provider
3301 .hash_join_v2_with_index_recorded(
3302 &left,
3303 &right,
3304 &[0],
3305 &[0],
3306 JoinType::Inner,
3307 &index,
3308 None,
3309 stream,
3310 )
3311 .expect("recorded indexed join consumes recorded build");
3312 runtime
3313 .stream_pool()
3314 .resolve(stream)
3315 .expect("stream resolves")
3316 .synchronize()
3317 .expect("recorded stream synchronized");
3318
3319 assert_eq!(index.right_num_rows(), 4);
3320 assert_eq!(index.right_keys(), &[0]);
3321 assert_eq!(provider.device_row_count(&joined).expect("joined rows"), 4);
3322 }
3323
3324 fn create_test_buffer(
3326 provider: &CudaKernelProvider,
3327 data: &[u32],
3328 col_name: &str,
3329 ) -> CudaBuffer {
3330 let schema = Schema::new(vec![(col_name.to_string(), ScalarType::U32)]);
3331 let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
3332
3333 let mut col = provider.memory().alloc::<u8>(bytes.len()).expect("alloc");
3334 provider
3335 .device()
3336 .inner()
3337 .htod_sync_copy_into(&bytes, &mut col)
3338 .expect("htod");
3339
3340 provider
3341 .buffer_from_columns(vec![col.into()], data.len() as u64, schema)
3342 .expect("buffer")
3343 }
3344
3345 fn create_empty_test_buffer(provider: &CudaKernelProvider, schema: Schema) -> CudaBuffer {
3347 let mut columns = Vec::with_capacity(schema.arity());
3348 for _ in 0..schema.arity() {
3349 columns.push(provider.memory().alloc::<u8>(0).expect("alloc").into());
3350 }
3351 provider
3352 .buffer_from_columns(columns, 0, schema)
3353 .expect("buffer")
3354 }
3355
3356 fn read_buffer_u32(provider: &CudaKernelProvider, buffer: &CudaBuffer, col: usize) -> Vec<u32> {
3358 if buffer.is_empty() || buffer.column(col).is_none() {
3359 return vec![];
3360 }
3361 let num_rows = buffer.num_rows() as usize;
3362 let mut bytes = vec![0u8; num_rows * 4];
3363 provider
3364 .device()
3365 .inner()
3366 .dtoh_sync_copy_into(buffer.column(col).unwrap(), &mut bytes)
3367 .expect("dtoh");
3368 bytes
3369 .chunks_exact(4)
3370 .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
3371 .collect()
3372 }
3373
3374 #[test]
3375 fn test_compact_device_mask_respects_mask_len_smaller_than_row_cap() {
3376 let provider = match create_test_provider() {
3377 Some(p) => p,
3378 None => {
3379 eprintln!("Skipping test: no CUDA device available");
3380 return;
3381 }
3382 };
3383
3384 let schema = Schema::new(vec![("id".to_string(), ScalarType::U32)]);
3385 let base = create_test_buffer(&provider, &[1, 2, 3, 4, 5, 6, 7, 8], "id");
3386
3387 let row_cap = 16u64;
3388 let data: Vec<u32> = (0..row_cap as u32).collect();
3389 let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
3390 let mut col = provider.memory().alloc::<u8>(bytes.len()).expect("alloc");
3391 provider
3392 .device()
3393 .inner()
3394 .htod_sync_copy_into(&bytes, &mut col)
3395 .expect("htod");
3396 let expanded = provider
3397 .buffer_from_columns_with_device_count(vec![col.into()], row_cap, schema, &base)
3398 .expect("buffer");
3399
3400 let mask: Vec<u8> = vec![1, 0, 1, 0, 1, 0, 1, 0];
3401 let (prefix_sum, count) = provider.prefix_sum_mask(&mask).expect("prefix sum");
3402
3403 let mut d_mask = provider.memory().alloc::<u8>(mask.len()).expect("alloc");
3404 provider
3405 .device()
3406 .inner()
3407 .htod_sync_copy_into(&mask, &mut d_mask)
3408 .expect("mask htod");
3409
3410 let mut d_prefix = provider
3411 .memory()
3412 .alloc::<u32>(prefix_sum.len())
3413 .expect("alloc");
3414 provider
3415 .device()
3416 .inner()
3417 .htod_sync_copy_into(&prefix_sum, &mut d_prefix)
3418 .expect("prefix htod");
3419
3420 let mut d_out_count = provider.memory().alloc::<u32>(1).expect("alloc");
3421 provider
3422 .device()
3423 .inner()
3424 .htod_sync_copy_into(&[count], &mut d_out_count)
3425 .expect("count htod");
3426
3427 let compacted = provider
3428 .compact_buffer_by_device_mask_device_count(&expanded, &d_mask, &d_prefix, d_out_count)
3429 .expect("compact");
3430
3431 assert_eq!(compacted.num_rows(), mask.len() as u64);
3432 let device_rows = provider.device_row_count(&compacted).expect("row count");
3433 assert_eq!(device_rows as u32, count);
3434 }
3435
3436 #[test]
3437 fn test_clone_buffer_preserves_device_count() {
3438 let provider = match create_test_provider() {
3439 Some(p) => p,
3440 None => {
3441 eprintln!("Skipping test: no CUDA device available");
3442 return;
3443 }
3444 };
3445
3446 let schema = Schema::new(vec![("id".to_string(), ScalarType::U32)]);
3447 let ids: Vec<u32> = vec![10, 20, 30];
3448 let buffer = provider
3449 .create_buffer_from_slices(&[bytemuck::cast_slice(&ids)], schema)
3450 .unwrap();
3451
3452 let cloned = provider.clone_buffer(&buffer).unwrap();
3453
3454 let mut host_count = [0u32];
3455 provider
3456 .device()
3457 .inner()
3458 .dtoh_sync_copy_into(cloned.num_rows_device(), &mut host_count)
3459 .unwrap();
3460 assert_eq!(host_count[0], 3);
3461 }
3462
3463 #[test]
3472 fn test_clone_buffer_preserves_cached_row_count() {
3473 let provider = match create_test_provider() {
3474 Some(p) => p,
3475 None => {
3476 eprintln!("Skipping test: no CUDA device available");
3477 return;
3478 }
3479 };
3480
3481 let schema = Schema::new(vec![("id".to_string(), ScalarType::U32)]);
3482 let ids: Vec<u32> = vec![7, 11, 13, 17];
3483 let source = provider
3484 .create_buffer_from_slices(&[bytemuck::cast_slice(&ids)], schema)
3485 .unwrap();
3486 assert_eq!(
3490 source.cached_row_count(),
3491 Some(4),
3492 "source buffer should have its cached row count populated by \
3493 create_buffer_from_slices"
3494 );
3495
3496 let cloned = provider.clone_buffer(&source).unwrap();
3497
3498 assert_eq!(
3499 cloned.cached_row_count(),
3500 Some(4),
3501 "clone_buffer must propagate cached_row_count from source to clone",
3502 );
3503 }
3504
3505 #[test]
3508 fn test_hash_join_empty_inputs() {
3509 let provider = match create_test_provider() {
3510 Some(p) => p,
3511 None => {
3512 eprintln!("Skipping test: no CUDA device available");
3513 return;
3514 }
3515 };
3516
3517 let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3518 let empty = create_empty_test_buffer(&provider, schema.clone());
3519
3520 let result = provider.hash_join(&empty, &empty, &[0], &[0]);
3522 assert!(result.is_ok());
3523 assert!(result.unwrap().is_empty());
3524 }
3525
3526 #[test]
3527 fn test_hash_join_validation() {
3528 let provider = match create_test_provider() {
3529 Some(p) => p,
3530 None => {
3531 eprintln!("Skipping test: no CUDA device available");
3532 return;
3533 }
3534 };
3535
3536 let left = create_test_buffer(&provider, &[1, 2, 3], "left_key");
3537 let right = create_test_buffer(&provider, &[2, 3, 4], "right_key");
3538
3539 let result = provider.hash_join(&left, &right, &[], &[0]);
3541 assert!(result.is_err());
3542
3543 let result = provider.hash_join(&left, &right, &[0], &[0, 0]);
3545 assert!(result.is_err());
3546 }
3547
3548 #[test]
3551 fn test_dedup_empty_input() {
3552 let provider = match create_test_provider() {
3553 Some(p) => p,
3554 None => {
3555 eprintln!("Skipping test: no CUDA device available");
3556 return;
3557 }
3558 };
3559
3560 let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3561 let empty = create_empty_test_buffer(&provider, schema);
3562
3563 let result = provider.dedup(&empty, &[0]);
3564 assert!(result.is_ok());
3565 assert!(result.unwrap().is_empty());
3566 }
3567
3568 #[test]
3569 fn test_dedup_validation() {
3570 let provider = match create_test_provider() {
3571 Some(p) => p,
3572 None => {
3573 eprintln!("Skipping test: no CUDA device available");
3574 return;
3575 }
3576 };
3577
3578 let buffer = create_test_buffer(&provider, &[1, 1, 2, 2, 3], "key");
3579
3580 let result = provider.dedup(&buffer, &[]);
3582 assert!(result.is_err());
3583 }
3584
3585 #[test]
3586 fn test_dedup_with_duplicates() {
3587 let provider = match create_test_provider() {
3588 Some(p) => p,
3589 None => {
3590 eprintln!("Skipping test: no CUDA device available");
3591 return;
3592 }
3593 };
3594
3595 let buffer = create_test_buffer(&provider, &[3, 1, 2, 1, 3, 2], "key");
3597 let deduped = provider.dedup(&buffer, &[0]).unwrap();
3598
3599 let dedup_count = provider
3600 .device_row_count(&deduped)
3601 .expect("read dedup row count");
3602 assert_eq!(dedup_count, 3, "Should have 3 unique values");
3603
3604 let result = provider.download_column::<u32>(&deduped, 0).unwrap();
3605 assert_eq!(result, vec![1, 2, 3]);
3607 }
3608
3609 #[test]
3610 fn test_dedup_larger_input() {
3611 let provider = match create_test_provider() {
3612 Some(p) => p,
3613 None => {
3614 eprintln!("Skipping test: no CUDA device available");
3615 return;
3616 }
3617 };
3618
3619 let a: Vec<u32> = (0..500).collect();
3621 let b: Vec<u32> = (250..750).collect();
3622 let input: Vec<u32> = a.iter().chain(b.iter()).copied().collect();
3623
3624 let buffer = create_test_buffer(&provider, &input, "key");
3625 let deduped = provider.dedup(&buffer, &[0]).unwrap();
3626
3627 let dedup_count = provider
3628 .device_row_count(&deduped)
3629 .expect("read dedup row count");
3630 assert_eq!(dedup_count, 750, "Should have 750 unique values (0..750)");
3631
3632 let result = provider.download_column::<u32>(&deduped, 0).unwrap();
3634 let is_sorted = result.windows(2).all(|w| w[0] <= w[1]);
3635 assert!(is_sorted, "Output should be sorted");
3636
3637 let expected: Vec<u32> = (0..750).collect();
3639 assert_eq!(result, expected);
3640 }
3641
3642 #[test]
3645 fn test_union_empty_inputs() {
3646 let provider = match create_test_provider() {
3647 Some(p) => p,
3648 None => {
3649 eprintln!("Skipping test: no CUDA device available");
3650 return;
3651 }
3652 };
3653
3654 let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3655 let empty = create_empty_test_buffer(&provider, schema.clone());
3656
3657 let result = provider.union(&empty, &empty);
3659 assert!(result.is_ok());
3660 assert!(result.unwrap().is_empty());
3661
3662 let a = create_test_buffer(&provider, &[1, 2, 3], "key");
3664 let empty2 = create_empty_test_buffer(&provider, schema);
3665 let result = provider.union(&a, &empty2);
3666 assert!(result.is_ok());
3667 let result = result.unwrap();
3668 assert_eq!(result.num_rows(), 3);
3669 }
3670
3671 #[test]
3672 fn test_union_schema_type_mismatch() {
3673 let provider = match create_test_provider() {
3674 Some(p) => p,
3675 None => {
3676 eprintln!("Skipping test: no CUDA device available");
3677 return;
3678 }
3679 };
3680
3681 let a = create_test_buffer(&provider, &[1, 2], "col_a");
3682 let b = create_test_buffer(&provider, &[3, 4], "col_b");
3683
3684 let result = provider.union(&a, &b);
3686 assert!(result.is_ok());
3687
3688 let two_col_schema = Schema::new(vec![
3690 ("x".to_string(), ScalarType::U32),
3691 ("y".to_string(), ScalarType::U32),
3692 ]);
3693 let c = provider
3694 .create_buffer_from_u32_columns(&[&[1, 2], &[3, 4]], two_col_schema)
3695 .unwrap();
3696 let result = provider.union(&a, &c);
3697 assert!(result.is_err());
3698 }
3699
3700 #[test]
3703 fn test_diff_empty_inputs() {
3704 let provider = match create_test_provider() {
3705 Some(p) => p,
3706 None => {
3707 eprintln!("Skipping test: no CUDA device available");
3708 return;
3709 }
3710 };
3711
3712 let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3713 let empty = create_empty_test_buffer(&provider, schema.clone());
3714
3715 let result = provider.diff(&empty, &empty);
3717 assert!(result.is_ok());
3718 assert!(result.unwrap().is_empty());
3719
3720 let a = create_test_buffer(&provider, &[1, 2, 3], "key");
3722 let empty2 = create_empty_test_buffer(&provider, schema);
3723 let result = provider.diff(&a, &empty2);
3724 assert!(result.is_ok());
3725 let result = result.unwrap();
3726 assert_eq!(result.num_rows(), 3);
3727 }
3728
3729 #[test]
3730 fn test_diff_basic() {
3731 let provider = match create_test_provider() {
3732 Some(p) => p,
3733 None => {
3734 eprintln!("Skipping test: no CUDA device available");
3735 return;
3736 }
3737 };
3738
3739 let a = create_test_buffer(&provider, &[1, 2, 3, 4, 5], "key");
3740 let b = create_test_buffer(&provider, &[2, 4], "key");
3741
3742 let result = provider.diff(&a, &b);
3743 assert!(result.is_ok());
3744 let result = result.unwrap();
3745 assert_eq!(result.num_rows(), 3); let values = read_buffer_u32(&provider, &result, 0);
3748 assert_eq!(values, vec![1, 3, 5]);
3749 }
3750
3751 #[test]
3752 fn test_diff_all_filtered_out() {
3753 let provider = match create_test_provider() {
3754 Some(p) => p,
3755 None => {
3756 eprintln!("Skipping test: no CUDA device available");
3757 return;
3758 }
3759 };
3760
3761 let a = create_test_buffer(&provider, &[1, 2, 3], "key");
3762 let b = create_test_buffer(&provider, &[1, 2, 3, 4, 5], "key");
3763
3764 let result = provider.diff(&a, &b);
3765 assert!(result.is_ok());
3766 assert!(result.unwrap().is_empty());
3767 }
3768
3769 #[test]
3770 fn test_diff_schema_mismatch() {
3771 let provider = match create_test_provider() {
3772 Some(p) => p,
3773 None => {
3774 eprintln!("Skipping test: no CUDA device available");
3775 return;
3776 }
3777 };
3778
3779 let a = create_test_buffer(&provider, &[1, 2], "col_a");
3781 let b = create_test_buffer(&provider, &[1, 2], "col_b");
3782 let result = provider.diff(&a, &b);
3783 assert!(
3784 result.is_ok(),
3785 "Same types with different names should succeed"
3786 );
3787
3788 let schema_2col = Schema::new(vec![
3790 ("c0".to_string(), ScalarType::U32),
3791 ("c1".to_string(), ScalarType::U32),
3792 ]);
3793
3794 let bytes_2col: Vec<u8> = [1u32, 2, 3, 4]
3795 .iter()
3796 .flat_map(|v| v.to_le_bytes())
3797 .collect();
3798 let mut col0 = provider
3799 .memory()
3800 .alloc::<u8>(bytes_2col.len() / 2)
3801 .expect("alloc");
3802 let mut col1 = provider
3803 .memory()
3804 .alloc::<u8>(bytes_2col.len() / 2)
3805 .expect("alloc");
3806 provider
3807 .device()
3808 .inner()
3809 .htod_sync_copy_into(&bytes_2col[..8], &mut col0)
3810 .expect("htod");
3811 provider
3812 .device()
3813 .inner()
3814 .htod_sync_copy_into(&bytes_2col[8..], &mut col1)
3815 .expect("htod");
3816 let buffer_2col = provider
3817 .buffer_from_columns(vec![col0.into(), col1.into()], 2, schema_2col)
3818 .expect("buffer");
3819
3820 let buffer_1col = create_test_buffer(&provider, &[1, 2], "c0");
3821
3822 let result = provider.diff(&buffer_2col, &buffer_1col);
3823 assert!(result.is_err(), "Different arities should fail");
3824 }
3825
3826 #[test]
3829 fn test_groupby_empty_input() {
3830 let provider = match create_test_provider() {
3831 Some(p) => p,
3832 None => {
3833 eprintln!("Skipping test: no CUDA device available");
3834 return;
3835 }
3836 };
3837
3838 let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3839 let empty = create_empty_test_buffer(&provider, schema);
3840
3841 let result = provider.groupby_agg(&empty, &[0], AggOp::Count, 0);
3842 assert!(result.is_ok());
3843 assert!(result.unwrap().is_empty());
3844 }
3845
3846 #[test]
3847 fn test_groupby_validation() {
3848 let provider = match create_test_provider() {
3849 Some(p) => p,
3850 None => {
3851 eprintln!("Skipping test: no CUDA device available");
3852 return;
3853 }
3854 };
3855
3856 let buffer = create_test_buffer(&provider, &[1, 1, 2, 2, 3], "key");
3857
3858 let result = provider.groupby_agg(&buffer, &[], AggOp::Count, 0);
3860 assert!(result.is_err());
3861
3862 let result = provider.groupby_agg(&buffer, &[0], AggOp::Count, 5);
3864 assert!(result.is_err());
3865 }
3866
3867 #[test]
3868 fn test_groupby_logsumexp() {
3869 let provider = match create_test_provider() {
3870 Some(p) => p,
3871 None => {
3872 eprintln!("Skipping test: no CUDA device available");
3873 return;
3874 }
3875 };
3876
3877 let keys: Vec<u32> = vec![1, 1, 2, 2];
3881 let values: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
3882
3883 let schema = Schema::new(vec![
3884 ("key".to_string(), ScalarType::U32),
3885 ("value".to_string(), ScalarType::F64),
3886 ]);
3887
3888 let key_bytes: Vec<u8> = keys.iter().flat_map(|v| v.to_le_bytes()).collect();
3890 let mut key_col = provider
3891 .memory()
3892 .alloc::<u8>(key_bytes.len())
3893 .expect("alloc key");
3894 provider
3895 .device()
3896 .inner()
3897 .htod_sync_copy_into(&key_bytes, &mut key_col)
3898 .expect("upload key");
3899
3900 let val_bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
3902 let mut val_col = provider
3903 .memory()
3904 .alloc::<u8>(val_bytes.len())
3905 .expect("alloc val");
3906 provider
3907 .device()
3908 .inner()
3909 .htod_sync_copy_into(&val_bytes, &mut val_col)
3910 .expect("upload val");
3911
3912 let buffer = provider
3913 .buffer_from_columns(vec![key_col.into(), val_col.into()], 4, schema)
3914 .expect("buffer");
3915
3916 let result = provider.groupby_agg(&buffer, &[0], AggOp::LogSumExp, 1);
3918 assert!(
3919 result.is_ok(),
3920 "groupby_agg with LogSumExp should succeed: {:?}",
3921 result.err()
3922 );
3923
3924 let result = result.unwrap();
3925 let group_count = provider
3926 .device_row_count(&result)
3927 .expect("read group count");
3928 assert_eq!(group_count, 2, "Should have 2 groups");
3929
3930 let result_values = provider
3932 .download_column::<f64>(&result, 1)
3933 .expect("download result");
3934
3935 let expected_0 = 2.0_f64 + ((-1.0_f64).exp() + 1.0_f64).ln(); let expected_1 = 4.0_f64 + ((-1.0_f64).exp() + 1.0_f64).ln(); let tolerance = 1e-5;
3942 assert!(
3943 (result_values[0] - expected_0).abs() < tolerance,
3944 "Group 0 logsumexp mismatch: got {}, expected {}",
3945 result_values[0],
3946 expected_0
3947 );
3948 assert!(
3949 (result_values[1] - expected_1).abs() < tolerance,
3950 "Group 1 logsumexp mismatch: got {}, expected {}",
3951 result_values[1],
3952 expected_1
3953 );
3954 }
3955
3956 #[test]
3959 fn test_combine_schemas() {
3960 let provider = match create_test_provider() {
3961 Some(p) => p,
3962 None => {
3963 eprintln!("Skipping test: no CUDA device available");
3964 return;
3965 }
3966 };
3967
3968 let left = Schema::new(vec![("a".to_string(), ScalarType::U32)]);
3969 let right = Schema::new(vec![("b".to_string(), ScalarType::U64)]);
3970
3971 let combined = provider.combine_schemas(&left, &right);
3972 assert_eq!(combined.arity(), 2);
3973 assert_eq!(combined.column_type(0), Some(ScalarType::U32));
3974 assert_eq!(combined.column_type(1), Some(ScalarType::U64));
3975 }
3976
3977 #[test]
3978 fn test_groupby_result_schema() {
3979 let provider = match create_test_provider() {
3980 Some(p) => p,
3981 None => {
3982 eprintln!("Skipping test: no CUDA device available");
3983 return;
3984 }
3985 };
3986
3987 let input = Schema::new(vec![
3988 ("key".to_string(), ScalarType::U32),
3989 ("value".to_string(), ScalarType::U32),
3990 ]);
3991
3992 let count_schema =
3994 provider.groupby_multi_agg_result_schema(&input, &[0], &[(1, AggOp::Count)]);
3995 assert_eq!(count_schema.arity(), 2);
3996 assert_eq!(count_schema.column_type(1), Some(ScalarType::U64));
3997
3998 let sum_schema = provider.groupby_multi_agg_result_schema(&input, &[0], &[(1, AggOp::Sum)]);
4000 assert_eq!(sum_schema.arity(), 2);
4001 assert_eq!(sum_schema.column_type(1), Some(ScalarType::U64));
4002
4003 let min_schema = provider.groupby_multi_agg_result_schema(&input, &[0], &[(1, AggOp::Min)]);
4005 assert_eq!(min_schema.arity(), 2);
4006 assert_eq!(min_schema.column_type(1), Some(ScalarType::U32));
4007 }
4008
4009 #[test]
4010 fn test_groupby_multi_agg_sum_returns_u64_schema() {
4011 let provider = match create_test_provider() {
4012 Some(p) => p,
4013 None => {
4014 eprintln!("Skipping test: no CUDA device");
4015 return;
4016 }
4017 };
4018
4019 let schema = Schema::new(vec![
4020 ("key".to_string(), ScalarType::U32),
4021 ("val".to_string(), ScalarType::U32),
4022 ]);
4023
4024 let result_schema =
4025 provider.groupby_multi_agg_result_schema(&schema, &[0], &[(1, AggOp::Sum)]);
4026
4027 assert_eq!(
4029 result_schema.column_type(1),
4030 Some(ScalarType::U64),
4031 "Sum aggregation should return U64 type, not U32"
4032 );
4033 }
4034
4035 #[test]
4036 fn test_join_custom_max_output() {
4037 let provider = match create_test_provider() {
4038 Some(p) => p,
4039 None => {
4040 eprintln!("Skipping test: no CUDA device available");
4041 return;
4042 }
4043 };
4044
4045 let left = create_test_buffer(&provider, &[1, 1, 1, 1, 2, 2, 2, 2], "left_key");
4050 let right = create_test_buffer(&provider, &[1, 1, 1, 2, 2, 2], "right_key");
4051
4052 let result_limited = provider
4054 .hash_join_v2_with_limit(&left, &right, &[0], &[0], JoinType::Inner, Some(10))
4055 .expect("join with limit should succeed");
4056 assert!(
4057 result_limited.num_rows() <= 10,
4058 "With limit 10, got {} rows but expected at most 10",
4059 result_limited.num_rows()
4060 );
4061
4062 let result_unlimited = provider
4064 .hash_join_v2_with_limit(&left, &right, &[0], &[0], JoinType::Inner, None)
4065 .expect("join without limit should succeed");
4066 assert_eq!(
4067 result_unlimited.num_rows(),
4068 24,
4069 "Without limit, expected 24 rows but got {}",
4070 result_unlimited.num_rows()
4071 );
4072
4073 let result_legacy = provider
4075 .hash_join_v2(&left, &right, &[0], &[0], JoinType::Inner)
4076 .expect("legacy hash_join_v2 should succeed");
4077 assert_eq!(
4078 result_legacy.num_rows(),
4079 24,
4080 "Legacy API without limit, expected 24 rows but got {}",
4081 result_legacy.num_rows()
4082 );
4083 }
4084
4085 fn create_arith_test_provider() -> Option<CudaKernelProvider> {
4089 if !has_cuda_device() {
4090 return None;
4091 }
4092 let device = Arc::new(CudaDevice::new(0).ok()?);
4093 let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
4094 let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
4095 CudaKernelProvider::new(device, memory).ok()
4096 }
4097
4098 fn create_i64_buffer(provider: &CudaKernelProvider, data: &[i64]) -> CudaBuffer {
4100 let schema = Schema::new(vec![("col".to_string(), ScalarType::I64)]);
4101 provider
4102 .create_buffer_from_slice::<i64>(data, schema)
4103 .unwrap()
4104 }
4105
4106 fn create_f64_buffer(provider: &CudaKernelProvider, data: &[f64]) -> CudaBuffer {
4108 let schema = Schema::new(vec![("col".to_string(), ScalarType::F64)]);
4109 provider
4110 .create_buffer_from_slice::<f64>(data, schema)
4111 .unwrap()
4112 }
4113
4114 #[test]
4115 fn test_add_columns_i64() {
4116 let Some(provider) = create_arith_test_provider() else {
4117 eprintln!("Skipping test: no CUDA device available");
4118 return;
4119 };
4120
4121 let a = create_i64_buffer(&provider, &[1, 2, 3, 4, 5]);
4122 let b = create_i64_buffer(&provider, &[10, 20, 30, 40, 50]);
4123
4124 let result = provider.add_columns(&a, &b).unwrap();
4125 let values = provider.download_column::<i64>(&result, 0).unwrap();
4126
4127 assert_eq!(values, vec![11, 22, 33, 44, 55]);
4128 }
4129
4130 #[test]
4131 fn test_sub_columns_i64() {
4132 let Some(provider) = create_arith_test_provider() else {
4133 eprintln!("Skipping test: no CUDA device available");
4134 return;
4135 };
4136
4137 let a = create_i64_buffer(&provider, &[10, 20, 30, 40, 50]);
4138 let b = create_i64_buffer(&provider, &[1, 2, 3, 4, 5]);
4139
4140 let result = provider.sub_columns(&a, &b).unwrap();
4141 let values = provider.download_column::<i64>(&result, 0).unwrap();
4142
4143 assert_eq!(values, vec![9, 18, 27, 36, 45]);
4144 }
4145
4146 #[test]
4147 fn test_mul_columns_i64() {
4148 let Some(provider) = create_arith_test_provider() else {
4149 eprintln!("Skipping test: no CUDA device available");
4150 return;
4151 };
4152
4153 let a = create_i64_buffer(&provider, &[2, 3, 4, 5, 6]);
4154 let b = create_i64_buffer(&provider, &[3, 4, 5, 6, 7]);
4155
4156 let result = provider.mul_columns(&a, &b).unwrap();
4157 let values = provider.download_column::<i64>(&result, 0).unwrap();
4158
4159 assert_eq!(values, vec![6, 12, 20, 30, 42]);
4160 }
4161
4162 #[test]
4163 fn test_div_columns_i64() {
4164 let Some(provider) = create_arith_test_provider() else {
4165 eprintln!("Skipping test: no CUDA device available");
4166 return;
4167 };
4168
4169 let a = create_i64_buffer(&provider, &[100, 200, 300, 400]);
4170 let b = create_i64_buffer(&provider, &[10, 20, 30, 40]);
4171
4172 let result = provider.div_columns(&a, &b).unwrap();
4173 let values = provider.download_column::<i64>(&result, 0).unwrap();
4174
4175 assert_eq!(values, vec![10, 10, 10, 10]);
4176 }
4177
4178 #[test]
4179 fn test_div_columns_by_zero() {
4180 let Some(provider) = create_arith_test_provider() else {
4181 eprintln!("Skipping test: no CUDA device available");
4182 return;
4183 };
4184
4185 let a = create_i64_buffer(&provider, &[10, 20, 30]);
4186 let b = create_i64_buffer(&provider, &[2, 0, 3]); let result = provider.div_columns(&a, &b).unwrap();
4189 let values = provider.download_column::<i64>(&result, 0).unwrap();
4190
4191 assert_eq!(values, vec![5, i64::MAX, 10]);
4193 }
4194
4195 #[test]
4196 fn test_mod_columns_i64() {
4197 let Some(provider) = create_arith_test_provider() else {
4198 eprintln!("Skipping test: no CUDA device available");
4199 return;
4200 };
4201
4202 let a = create_i64_buffer(&provider, &[17, 23, 100, 7]);
4203 let b = create_i64_buffer(&provider, &[5, 7, 30, 3]);
4204
4205 let result = provider.mod_columns(&a, &b).unwrap();
4206 let values = provider.download_column::<i64>(&result, 0).unwrap();
4207
4208 assert_eq!(values, vec![2, 2, 10, 1]);
4209 }
4210
4211 #[test]
4212 fn test_mod_columns_by_zero() {
4213 let Some(provider) = create_arith_test_provider() else {
4214 eprintln!("Skipping test: no CUDA device available");
4215 return;
4216 };
4217
4218 let a = create_i64_buffer(&provider, &[10, 20]);
4219 let b = create_i64_buffer(&provider, &[3, 0]); let result = provider.mod_columns(&a, &b).unwrap();
4222 let values = provider.download_column::<i64>(&result, 0).unwrap();
4223
4224 assert_eq!(values, vec![1, 0]);
4226 }
4227
4228 #[test]
4229 fn test_abs_column_i64() {
4230 let Some(provider) = create_arith_test_provider() else {
4231 eprintln!("Skipping test: no CUDA device available");
4232 return;
4233 };
4234
4235 let a = create_i64_buffer(&provider, &[-5, 10, -15, 20, 0]);
4236
4237 let result = provider.abs_column(&a).unwrap();
4238 let values = provider.download_column::<i64>(&result, 0).unwrap();
4239
4240 assert_eq!(values, vec![5, 10, 15, 20, 0]);
4241 }
4242
4243 #[test]
4244 fn test_min_columns_i64() {
4245 let Some(provider) = create_arith_test_provider() else {
4246 eprintln!("Skipping test: no CUDA device available");
4247 return;
4248 };
4249
4250 let a = create_i64_buffer(&provider, &[5, 10, 15, 20]);
4251 let b = create_i64_buffer(&provider, &[3, 12, 10, 25]);
4252
4253 let result = provider.min_columns(&a, &b).unwrap();
4254 let values = provider.download_column::<i64>(&result, 0).unwrap();
4255
4256 assert_eq!(values, vec![3, 10, 10, 20]);
4257 }
4258
4259 #[test]
4260 fn test_max_columns_i64() {
4261 let Some(provider) = create_arith_test_provider() else {
4262 eprintln!("Skipping test: no CUDA device available");
4263 return;
4264 };
4265
4266 let a = create_i64_buffer(&provider, &[5, 10, 15, 20]);
4267 let b = create_i64_buffer(&provider, &[3, 12, 10, 25]);
4268
4269 let result = provider.max_columns(&a, &b).unwrap();
4270 let values = provider.download_column::<i64>(&result, 0).unwrap();
4271
4272 assert_eq!(values, vec![5, 12, 15, 25]);
4273 }
4274
4275 #[test]
4276 fn test_add_columns_f64() {
4277 let Some(provider) = create_arith_test_provider() else {
4278 eprintln!("Skipping test: no CUDA device available");
4279 return;
4280 };
4281
4282 let a = create_f64_buffer(&provider, &[1.5, 2.5, 3.5]);
4283 let b = create_f64_buffer(&provider, &[0.5, 1.5, 2.5]);
4284
4285 let result = provider.add_columns(&a, &b).unwrap();
4286 let values = provider.download_column::<f64>(&result, 0).unwrap();
4287
4288 assert_eq!(values, vec![2.0, 4.0, 6.0]);
4289 }
4290
4291 #[test]
4292 fn test_mul_columns_f64() {
4293 let Some(provider) = create_arith_test_provider() else {
4294 eprintln!("Skipping test: no CUDA device available");
4295 return;
4296 };
4297
4298 let a = create_f64_buffer(&provider, &[2.0, 3.0, 4.0]);
4299 let b = create_f64_buffer(&provider, &[1.5, 2.0, 2.5]);
4300
4301 let result = provider.mul_columns(&a, &b).unwrap();
4302 let values = provider.download_column::<f64>(&result, 0).unwrap();
4303
4304 assert_eq!(values, vec![3.0, 6.0, 10.0]);
4305 }
4306
4307 #[test]
4308 fn test_div_columns_f64_by_zero() {
4309 let Some(provider) = create_arith_test_provider() else {
4310 eprintln!("Skipping test: no CUDA device available");
4311 return;
4312 };
4313
4314 let a = create_f64_buffer(&provider, &[1.0, -1.0, 0.0]);
4315 let b = create_f64_buffer(&provider, &[0.0, 0.0, 0.0]);
4316
4317 let result = provider.div_columns(&a, &b).unwrap();
4318 let values = provider.download_column::<f64>(&result, 0).unwrap();
4319
4320 assert!(values[0].is_infinite() && values[0].is_sign_positive());
4322 assert!(values[1].is_infinite() && values[1].is_sign_negative());
4323 assert!(values[2].is_nan());
4324 }
4325
4326 #[test]
4327 fn test_pow_columns() {
4328 let Some(provider) = create_arith_test_provider() else {
4329 eprintln!("Skipping test: no CUDA device available");
4330 return;
4331 };
4332
4333 let base = create_i64_buffer(&provider, &[2, 3, 4, 5]);
4334 let exp = create_i64_buffer(&provider, &[3, 2, 2, 1]);
4335
4336 let result = provider.pow_columns(&base, &exp).unwrap();
4337 let values = provider.download_column::<f64>(&result, 0).unwrap();
4338
4339 assert_eq!(values, vec![8.0, 9.0, 16.0, 5.0]);
4341 }
4342
4343 #[test]
4344 fn test_pow_columns_fractional_exp() {
4345 let Some(provider) = create_arith_test_provider() else {
4346 eprintln!("Skipping test: no CUDA device available");
4347 return;
4348 };
4349
4350 let base = create_f64_buffer(&provider, &[4.0, 9.0, 27.0]);
4351 let exp = create_f64_buffer(&provider, &[0.5, 0.5, 1.0 / 3.0]);
4352
4353 let result = provider.pow_columns(&base, &exp).unwrap();
4354 let values = provider.download_column::<f64>(&result, 0).unwrap();
4355
4356 assert!((values[0] - 2.0).abs() < 1e-10);
4358 assert!((values[1] - 3.0).abs() < 1e-10);
4359 assert!((values[2] - 3.0).abs() < 1e-10);
4360 }
4361
4362 #[test]
4363 fn test_cast_i64_to_f64() {
4364 let Some(provider) = create_arith_test_provider() else {
4365 eprintln!("Skipping test: no CUDA device available");
4366 return;
4367 };
4368
4369 let a = create_i64_buffer(&provider, &[1, 2, 3, 4, 5]);
4370
4371 let result = provider.cast_column(&a, ScalarType::F64).unwrap();
4372 let values = provider.download_column::<f64>(&result, 0).unwrap();
4373
4374 assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
4375 }
4376
4377 #[test]
4378 fn test_cast_f64_to_i64() {
4379 let Some(provider) = create_arith_test_provider() else {
4380 eprintln!("Skipping test: no CUDA device available");
4381 return;
4382 };
4383
4384 let a = create_f64_buffer(&provider, &[1.9, 2.1, 3.5, 4.0, 5.7]);
4385
4386 let result = provider.cast_column(&a, ScalarType::I64).unwrap();
4387 let values = provider.download_column::<i64>(&result, 0).unwrap();
4388
4389 assert_eq!(values, vec![1, 2, 3, 4, 5]);
4391 }
4392
4393 #[test]
4394 fn test_cast_i64_to_i32() {
4395 let Some(provider) = create_arith_test_provider() else {
4396 eprintln!("Skipping test: no CUDA device available");
4397 return;
4398 };
4399
4400 let a = create_i64_buffer(&provider, &[1, 2, 3, 100, 200]);
4401
4402 let result = provider.cast_column(&a, ScalarType::I32).unwrap();
4403 let values = provider.download_column::<i32>(&result, 0).unwrap();
4404
4405 assert_eq!(values, vec![1, 2, 3, 100, 200]);
4406 }
4407
4408 #[test]
4409 fn test_arithmetic_row_count_mismatch() {
4410 let Some(provider) = create_arith_test_provider() else {
4411 eprintln!("Skipping test: no CUDA device available");
4412 return;
4413 };
4414
4415 let a = create_i64_buffer(&provider, &[1, 2, 3]);
4416 let b = create_i64_buffer(&provider, &[1, 2]); let result = provider.add_columns(&a, &b);
4419 assert!(result.is_err());
4420 let err = result.err().unwrap();
4421 assert!(err.to_string().contains("Row count mismatch"));
4422 }
4423
4424 #[test]
4425 fn test_arithmetic_empty_buffers() {
4426 let Some(provider) = create_arith_test_provider() else {
4427 eprintln!("Skipping test: no CUDA device available");
4428 return;
4429 };
4430
4431 let a = create_i64_buffer(&provider, &[]);
4432 let b = create_i64_buffer(&provider, &[]);
4433
4434 let result = provider.add_columns(&a, &b).unwrap();
4435 let values = provider.download_column::<i64>(&result, 0).unwrap();
4436
4437 assert_eq!(values, Vec::<i64>::new());
4438 }
4439
4440 #[test]
4441 fn test_wrapping_arithmetic_overflow() {
4442 let Some(provider) = create_arith_test_provider() else {
4443 eprintln!("Skipping test: no CUDA device available");
4444 return;
4445 };
4446
4447 let a = create_i64_buffer(&provider, &[i64::MAX, i64::MIN]);
4448 let b = create_i64_buffer(&provider, &[1, -1]);
4449
4450 let add_result = provider.add_columns(&a, &b).unwrap();
4452 let add_values = provider.download_column::<i64>(&add_result, 0).unwrap();
4453 assert_eq!(add_values[0], i64::MIN); assert_eq!(add_values[1], i64::MAX); }
4456
4457 #[test]
4458 fn test_abs_column_f64() {
4459 let Some(provider) = create_arith_test_provider() else {
4460 eprintln!("Skipping test: no CUDA device available");
4461 return;
4462 };
4463
4464 let a = create_f64_buffer(&provider, &[-1.5, 2.5, -3.5, 0.0]);
4465
4466 let result = provider.abs_column(&a).unwrap();
4467 let values = provider.download_column::<f64>(&result, 0).unwrap();
4468
4469 assert_eq!(values, vec![1.5, 2.5, 3.5, 0.0]);
4470 }
4471
4472 #[test]
4473 fn test_min_max_columns_f64() {
4474 let Some(provider) = create_arith_test_provider() else {
4475 eprintln!("Skipping test: no CUDA device available");
4476 return;
4477 };
4478
4479 let a = create_f64_buffer(&provider, &[1.5, 5.0, 3.0]);
4480 let b = create_f64_buffer(&provider, &[2.0, 3.0, 4.0]);
4481
4482 let min_result = provider.min_columns(&a, &b).unwrap();
4483 let min_values = provider.download_column::<f64>(&min_result, 0).unwrap();
4484 assert_eq!(min_values, vec![1.5, 3.0, 3.0]);
4485
4486 let max_result = provider.max_columns(&a, &b).unwrap();
4487 let max_values = provider.download_column::<f64>(&max_result, 0).unwrap();
4488 assert_eq!(max_values, vec![2.0, 5.0, 4.0]);
4489 }
4490}