Skip to main content

xlog_cuda/provider/
mod.rs

1//! CUDA kernel provider implementation
2//!
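//! This module provides the `CudaKernelProvider`, which manages pre-compiled
//! kernel modules (staged cubin or PTX artifacts, with embedded portable PTX as
//! the final fallback) for GPU execution of relational operations (join, dedup,
//! groupby) and the other kernel families named below (WCOJ, Monte Carlo,
//! arithmetic, ILP, epistemic, ...).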
5
6use std::collections::HashMap;
7use std::marker::PhantomData;
8use std::path::PathBuf;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::{Arc, Mutex, OnceLock};
11
12use std::ffi::c_void;
13use xlog_core::{Result, Schema, XlogError};
14
15use crate::{
16    cuda_compat::{
17        AsKernelParam, DeviceParamStorage, DevicePtr, DeviceRepr, DeviceSlice,
18        IntoKernelParamStorage, LaunchAsync, LaunchConfig,
19    },
20    cuda_graph::{CapturedCudaGraph, CsmCudaGraphKey, CudaGraphNode},
21    memory::{validate_logical_row_count, CudaColumn, TrackedCudaSlice},
22    CudaBuffer, CudaDevice, CudaStream, CudaViewMut, GpuMemoryManager,
23};
24
25mod arithmetic;
26mod filter;
27mod fj;
28mod fj_delta;
29mod fj_delta_sparse;
30mod groupby;
31mod ilp;
32mod ilp_exact;
33mod io;
34mod kernel_loading;
35pub mod kernel_paths;
36mod launch_safe;
37mod probabilistic;
38mod relational;
39mod transfer;
40mod wcoj;
41mod wcoj_metadata;
42mod wcoj_project;
43
44pub use fj::{FjNode, FjPlan, FjSubAtom};
45pub use fj_delta::{FjDeltaCols, FJ_DELTA_MAX_DOMAIN};
46
/// Timing summary for loading kernel modules during warm-up.
///
/// Only filled in when profiling is requested with `XLOG_WARMUP_PROFILE=1`;
/// otherwise it keeps its `Default` (all zero / empty) value.
#[derive(Clone, Debug, Default)]
pub struct PtxLoadProfile {
    /// Aggregate load time — presumably seconds, per the `_sec` suffix.
    pub total_sec: f64,
    /// `(module name, load time)` pairs, one entry per module — presumably in
    /// load order; TODO confirm against the populating code.
    pub per_module_sec: Vec<(String, f64)>,
    /// Number of modules loaded from a cubin artifact (NOTE(review): inferred
    /// from the name — confirm against the populating code).
    pub cubin_loaded: u32,
    /// Number of modules that fell back to PTX (NOTE(review): inferred from the
    /// name — confirm against the populating code).
    pub ptx_fallback: u32,
}
55
/// Whether kernel warm-up profiling was requested through the environment.
///
/// Profiling is on only when `XLOG_WARMUP_PROFILE` is exactly `"1"`. An unset
/// variable, a non-UTF-8 value, or any other string leaves it off.
fn warmup_profiling_enabled() -> bool {
    matches!(std::env::var("XLOG_WARMUP_PROFILE").as_deref(), Ok("1"))
}
61
62/// Detect device compute capability as a two-digit number (e.g. 75, 80, 120).
63fn detect_compute_capability(device: &Arc<CudaDevice>) -> Result<u32> {
64    let major = device
65        .inner()
66        .attribute(
67            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
68        )
69        .map_err(|e| XlogError::Kernel(format!("Failed to query SM major: {}", e)))?;
70    let minor = device
71        .inner()
72        .attribute(
73            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
74        )
75        .map_err(|e| XlogError::Kernel(format!("Failed to query SM minor: {}", e)))?;
76    Ok((major as u32) * 10 + (minor as u32))
77}
78
/// Test-only helper: ask the environment-configured artifact locator for a
/// single module path for `name` at compute capability `cc`.
///
/// The `bool` in the result is whatever `KernelArtifactLocator::resolve_module_path`
/// reports alongside the path (presumably "is a cubin" — confirm).
#[cfg(test)]
fn resolve_module_path(name: &str, cc: u32) -> Option<(std::path::PathBuf, bool)> {
    let locator = kernel_paths::KernelArtifactLocator::from_env();
    locator.resolve_module_path(name, cc)
}
83
/// One candidate location a kernel module can be loaded from.
///
/// `resolve_module_sources_with_locator` returns these in preference order,
/// with staged files first and the embedded portable PTX appended last.
#[derive(Debug)]
pub(crate) enum KernelModuleSource {
    /// A staged artifact on disk. `is_cubin` tells a cubin apart from a PTX file.
    File { path: PathBuf, is_cubin: bool },
    /// Portable PTX text compiled into this binary.
    EmbeddedPortablePtx { ptx: &'static str },
}
89
90pub(crate) fn resolve_module_sources_with_locator(
91    name: &str,
92    cc: u32,
93    locator: &kernel_paths::KernelArtifactLocator,
94) -> Vec<KernelModuleSource> {
95    let mut sources: Vec<KernelModuleSource> = locator
96        .resolve_module_paths(name, cc)
97        .into_iter()
98        // Skip any staged cubin/PTX whose bytes diverge from what this binary
99        // was built against. A stale staged artifact (kernel signature changed
100        // but the staged copy was never refreshed) otherwise loads "fine" and
101        // then launches a mismatched kernel into an illegal address.
102        .filter(|(path, _)| !staged_artifact_is_stale(path))
103        .map(|(path, is_cubin)| KernelModuleSource::File { path, is_cubin })
104        .collect();
105
106    // ALWAYS append the embedded portable PTX as the final fallback. It is
107    // compiled into this binary, so it can never be stale relative to the launch
108    // sites — it guarantees a signature-correct kernel even when every staged
109    // File artifact was skipped as stale or fails to load. (Previously this was
110    // suppressed whenever any portable-PTX *file* existed, which let a stale
111    // staged PTX shadow the fresh embedded one.)
112    if let Some(ptx) = crate::embedded_kernel_data::portable_ptx(name) {
113        sources.push(KernelModuleSource::EmbeddedPortablePtx { ptx });
114    }
115    sources
116}
117
118/// A staged cubin/PTX is "stale" when this binary embeds a canonical integrity
119/// hash for that artifact file name and the on-disk bytes do not match it — the
120/// staged artifact diverges from what this build produced. Loading such an
121/// artifact can launch a mismatched kernel into an illegal address, so it is
122/// skipped in favor of a fresh source. Artifacts with no embedded canonical
123/// hash (e.g. an arch this build did not produce) are NOT treated as stale — we
124/// can only validate what we built — nor are unreadable files (the loader
125/// surfaces the IO error).
126fn staged_artifact_is_stale(path: &std::path::Path) -> bool {
127    let Some(file_name) = path.file_name().and_then(|n| n.to_str()) else {
128        return false;
129    };
130    let Some(expected) = crate::embedded_kernel_data::canonical_artifact_hash(file_name) else {
131        return false;
132    };
133    match std::fs::read(path) {
134        Ok(bytes) => fnv1a_64(&bytes) != expected,
135        Err(_) => false,
136    }
137}
138
139/// FNV-1a 64-bit, matching the build-time hash in `crates/xlog-cuda/build.rs`.
140fn fnv1a_64(bytes: &[u8]) -> u64 {
141    let mut hash: u64 = 0xcbf2_9ce4_8422_2325;
142    for &byte in bytes {
143        hash ^= byte as u64;
144        hash = hash.wrapping_mul(0x0000_0100_0000_01b3);
145    }
146    hash
147}
148
149#[cfg(test)]
150mod kernel_source_resolution_tests {
151    use super::{
152        kernel_paths::KernelArtifactLocator, resolve_module_sources_with_locator,
153        KernelModuleSource,
154    };
155    use std::fs;
156
157    #[test]
158    fn keeps_portable_ptx_fallback_when_cubin_exists() {
159        let root = std::env::temp_dir().join(format!(
160            "xlog-kernel-fallback-{}-{}",
161            std::process::id(),
162            std::time::SystemTime::now()
163                .duration_since(std::time::UNIX_EPOCH)
164                .expect("system clock before UNIX_EPOCH")
165                .as_nanos()
166        ));
167        let kernels = root.join("kernels");
168        fs::create_dir_all(&kernels).expect("create kernels dir");
169        // Use a name this build does NOT produce, so neither file carries a
170        // canonical integrity hash (the staleness skip is exercised separately).
171        // This isolates the file-resolution precedence: cubin first, then the
172        // portable-PTX file as fallback.
173        fs::write(kernels.join("fakekernel.sm_86.cubin"), b"cubin").expect("write cubin");
174        fs::write(kernels.join("fakekernel.portable.ptx"), b"ptx").expect("write ptx");
175        let expected_cubin = kernels.join("fakekernel.sm_86.cubin");
176        let expected_ptx = kernels.join("fakekernel.portable.ptx");
177
178        let locator = KernelArtifactLocator::new(None, Some(kernels.clone()), None);
179        let sources = resolve_module_sources_with_locator("fakekernel", 86, &locator);
180
181        assert_eq!(sources.len(), 2);
182        assert!(matches!(
183            &sources[0],
184            KernelModuleSource::File {
185                path,
186                is_cubin: true
187            } if path == &expected_cubin
188        ));
189        assert!(matches!(
190            &sources[1],
191            KernelModuleSource::File {
192                path,
193                is_cubin: false
194            } if path == &expected_ptx
195        ));
196
197        fs::remove_dir_all(root).expect("remove temp kernels");
198    }
199
200    // Locks the FNV-1a contract between build.rs (which embeds canonical
201    // artifact hashes) and the runtime (which re-hashes staged artifacts). If
202    // these two implementations ever diverge, every staged artifact would read
203    // as "stale" — so these canonical FNV-1a 64-bit vectors must hold.
204    #[test]
205    fn fnv1a_64_matches_known_vectors() {
206        assert_eq!(super::fnv1a_64(b""), 0xcbf2_9ce4_8422_2325);
207        assert_eq!(super::fnv1a_64(b"a"), 0xaf63_dc4c_8601_ec8c);
208        assert_eq!(super::fnv1a_64(b"foobar"), 0x85944171_f73967e8);
209    }
210
211    // A file whose name this build did not produce has no canonical hash, so it
212    // is conservatively NOT treated as stale (we only validate what we built).
213    // A nonexistent path is likewise not "stale" — the loader surfaces IO.
214    #[test]
215    fn staged_artifact_not_stale_without_canonical_hash() {
216        let root = std::env::temp_dir().join(format!(
217            "xlog-kernel-stale-{}-{}",
218            std::process::id(),
219            std::time::SystemTime::now()
220                .duration_since(std::time::UNIX_EPOCH)
221                .expect("system clock before UNIX_EPOCH")
222                .as_nanos()
223        ));
224        fs::create_dir_all(&root).expect("create dir");
225        let unknown = root.join("definitely_not_a_real_kernel.sm_86.cubin");
226        fs::write(&unknown, b"bytes").expect("write");
227        assert!(!super::staged_artifact_is_stale(&unknown));
228        assert!(!super::staged_artifact_is_stale(
229            &root.join("missing.portable.ptx")
230        ));
231        fs::remove_dir_all(root).expect("remove temp dir");
232    }
233}
234
235/// Resolve a kernel module from sidecar artifacts or embedded portable PTX.
236///
237/// Asserts (in debug builds) that `name` is present in the kernel manifest,
238/// catching name/order drift between the manifest and provider load blocks.
239pub(crate) fn load_module_sources(name: &str, cc: u32) -> Result<Vec<KernelModuleSource>> {
240    debug_assert!(
241        crate::kernel_manifest_data::KERNEL_CU_NAMES.contains(&name),
242        "kernel module '{name}' is not in KERNEL_CU_NAMES manifest — update kernel_manifest_data.rs"
243    );
244    let locator = kernel_paths::KernelArtifactLocator::from_env();
245    let sources = resolve_module_sources_with_locator(name, cc, &locator);
246    if sources.is_empty() {
247        Err(XlogError::Kernel(format!(
248            "{name}: no cubin, sidecar portable PTX, or embedded portable PTX found"
249        )))
250    } else {
251        Ok(sources)
252    }
253}
254
255#[derive(Clone)]
256pub(crate) struct RawCudaView<'a, T> {
257    ptr: cudarc::driver::sys::CUdeviceptr,
258    len: usize,
259    stream: Arc<CudaStream>,
260    /// Optional back-reference to the source [`DeviceBlock`]
261    /// when this view borrows a region of a runtime-backed
262    /// allocation. The launch recorder uses this to attach
263    /// cross-stream uses without losing identity through view
264    /// construction. `None` for views built from external
265    /// memory or legacy paths; strict-mode launch recorders
266    /// reject `None` views.
267    ///
268    /// Read by [`RawCudaView::runtime_block`]; the field
269    /// itself is intentionally not directly exposed because
270    /// the lifetime of the back-reference is bound to the
271    /// view's `'a`.
272    #[allow(dead_code)]
273    source_block: Option<&'a crate::device_runtime::DeviceBlock>,
274    _marker: PhantomData<&'a [T]>,
275}
276
277/// Preallocated scratch layout for graph-capturable u32 multi-block scans.
278///
279/// The legacy stream-aware scan helper allocates recursive `block_sums`
280/// buffers inside the helper. CUDA Graph capture records concrete allocation
281/// addresses, so bounded CSM CUDA Graph replay needs the scan topology and
282/// scratch buffers to be fixed before capture begins.
283pub(crate) struct MultiblockScanScratchU32 {
284    levels: Vec<TrackedCudaSlice<u32>>,
285}
286
287impl MultiblockScanScratchU32 {
288    pub(crate) fn levels(&self) -> &[TrackedCudaSlice<u32>] {
289        &self.levels
290    }
291}
292
293pub(crate) struct CsmCudaGraphNodes {
294    pub(crate) count: CudaGraphNode,
295    pub(crate) total: CudaGraphNode,
296    pub(crate) materialize: CudaGraphNode,
297    pub(crate) node_count: usize,
298}
299
300pub(crate) struct CsmCudaGraphEntry {
301    pub(crate) graph: CapturedCudaGraph,
302    pub(crate) nodes: CsmCudaGraphNodes,
303    pub(crate) per_probe_count: TrackedCudaSlice<u32>,
304    pub(crate) per_probe_offsets: TrackedCudaSlice<u32>,
305    pub(crate) d_logical_count: TrackedCudaSlice<u32>,
306    pub(crate) d_overflow: TrackedCudaSlice<u8>,
307    pub(crate) d_output_left: TrackedCudaSlice<u32>,
308    pub(crate) d_output_right: TrackedCudaSlice<u32>,
309    pub(crate) scan_scratch: MultiblockScanScratchU32,
310    pub(crate) probe_capacity: u32,
311    pub(crate) output_capacity: u32,
312}
313
314impl<'a, T> DeviceSlice<T> for RawCudaView<'a, T> {
315    fn len(&self) -> usize {
316        self.len
317    }
318
319    fn stream(&self) -> &Arc<CudaStream> {
320        &self.stream
321    }
322}
323
324impl<'a, T> DevicePtr<T> for RawCudaView<'a, T> {
325    fn device_ptr<'b>(
326        &'b self,
327        _stream: &'b CudaStream,
328    ) -> (
329        cudarc::driver::sys::CUdeviceptr,
330        cudarc::driver::SyncOnDrop<'b>,
331    ) {
332        (self.ptr, cudarc::driver::SyncOnDrop::Sync(None))
333    }
334}
335
336impl<'a, T> RawCudaView<'a, T> {
337    pub fn device_ptr(&self) -> &cudarc::driver::sys::CUdeviceptr {
338        &self.ptr
339    }
340
341    /// Borrow the back-reference to the source
342    /// [`crate::device_runtime::DeviceBlock`], if this view was
343    /// constructed from a runtime-backed allocation. Returns
344    /// `None` for views built from external memory or legacy
345    /// paths.
346    ///
347    /// Public API reserved for the filter-class migration; no
348    /// production caller exists yet.
349    #[allow(dead_code)]
350    pub fn runtime_block(&self) -> Option<&'a crate::device_runtime::DeviceBlock> {
351        self.source_block
352    }
353}
354
355impl<'a, T: DeviceRepr> AsKernelParam for &RawCudaView<'a, T> {
356    fn as_kernel_param(&self) -> *mut c_void {
357        ((*self).device_ptr() as *const cudarc::driver::sys::CUdeviceptr)
358            .cast_mut()
359            .cast()
360    }
361}
362
363impl<'a, T: DeviceRepr> IntoKernelParamStorage for &'a RawCudaView<'a, T> {
364    type Storage = DeviceParamStorage<'a>;
365
366    fn into_kernel_param_storage(self) -> Self::Storage {
367        DeviceParamStorage::unsynced(self.ptr)
368    }
369}
370
371/// Scratch buffers for stable radix sorting of u32 key/value pairs.
372pub struct RadixSortScratch {
373    keys_b: TrackedCudaSlice<u32>,
374    values_b: TrackedCudaSlice<u32>,
375    hist: TrackedCudaSlice<u32>,
376    prefix: TrackedCudaSlice<u32>,
377    ranks: TrackedCudaSlice<u32>,
378    len: u32,
379}
380
381impl RadixSortScratch {
382    pub fn new(provider: &CudaKernelProvider, n: u32) -> Result<Self> {
383        let memory = provider.memory();
384        let len = n.max(1);
385        let keys_b = memory.alloc::<u32>(len as usize)?;
386        let values_b = memory.alloc::<u32>(len as usize)?;
387        let ranks = memory.alloc::<u32>(len as usize)?;
388        let block_size = CudaKernelProvider::SORT_BLOCK_SIZE;
389        let grid_size = len.div_ceil(block_size).max(1);
390        let hist = memory.alloc::<u32>((grid_size as usize) * 16)?;
391        let prefix = memory.alloc::<u32>(16)?;
392        Ok(Self {
393            keys_b,
394            values_b,
395            hist,
396            prefix,
397            ranks,
398            len,
399        })
400    }
401
402    pub fn ensure_capacity(&mut self, provider: &CudaKernelProvider, n: u32) -> Result<()> {
403        if n <= self.len {
404            return Ok(());
405        }
406        *self = Self::new(provider, n)?;
407        Ok(())
408    }
409}
410
411/// Module names for loaded PTX modules
412pub const JOIN_MODULE: &str = "xlog_join";
413pub const DEDUP_MODULE: &str = "xlog_dedup";
414pub const GROUPBY_MODULE: &str = "xlog_groupby";
415pub const SCAN_MODULE: &str = "xlog_scan";
416pub const SORT_MODULE: &str = "xlog_sort";
417pub const FILTER_MODULE: &str = "xlog_filter";
418pub const SET_OPS_MODULE: &str = "xlog_set_ops";
419pub const PACK_MODULE: &str = "xlog_pack";
420pub const CIRCUIT_MODULE: &str = "xlog_circuit";
421pub const MC_SAMPLE_MODULE: &str = "xlog_mc_sample";
422pub const MC_EVAL_MODULE: &str = "xlog_mc_eval";
423pub const MC_RESIDENT_MODULE: &str = "xlog_mc_resident";
424pub const ARITH_MODULE: &str = "xlog_arith";
425pub const SAT_MODULE: &str = "xlog_sat";
426pub const D4_MODULE: &str = "xlog_d4";
427pub const NEURAL_MODULE: &str = "xlog_neural";
428pub const PIR_MODULE: &str = "xlog_pir";
429pub const CNF_MODULE: &str = "xlog_cnf";
430pub const CACHE_MODULE: &str = "xlog_cache";
431pub const WEIGHTS_MODULE: &str = "xlog_weights";
432pub const ILP_MODULE: &str = "xlog_ilp";
433pub const ILP_CREDIT_MODULE: &str = "xlog_ilp_credit";
434pub const ILP_EXACT_MODULE: &str = "xlog_ilp_exact";
435pub const EPISTEMIC_MODULE: &str = "xlog_epistemic";
436pub const WCOJ_MODULE: &str = "xlog_wcoj";
437
438// Compile-time check: kernel manifest lists exactly 25 modules.
439const _: () = assert!(crate::kernel_manifest_data::KERNEL_CU_NAMES.len() == 25);
440
441/// Kernel function names in the GPU WCOJ module.
442pub mod wcoj_kernels {
443    pub const WCOJ_BUILD_METADATA_MARK_BOUNDARIES_U32: &str =
444        "wcoj_build_metadata_mark_boundaries_u32";
445    pub const WCOJ_BUILD_METADATA_MARK_BOUNDARIES_U64: &str =
446        "wcoj_build_metadata_mark_boundaries_u64";
447    pub const WCOJ_BUILD_METADATA_SCATTER_U32: &str = "wcoj_build_metadata_scatter_u32";
448    pub const WCOJ_BUILD_METADATA_SCATTER_U64: &str = "wcoj_build_metadata_scatter_u64";
449    pub const WCOJ_TRIANGLE_BUILD_HG_WORK_PLAN_U32: &str = "wcoj_triangle_build_hg_work_plan_u32";
450    pub const WCOJ_TRIANGLE_COUNT_HG_U32: &str = "wcoj_triangle_count_hg_u32";
451    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_COUNT_HG_U32: &str =
452        "wcoj_triangle_groupby_root_count_hg_u32";
453    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_SUM_HG_U32: &str = "wcoj_triangle_groupby_root_sum_hg_u32";
454    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MIN_HG_U32: &str = "wcoj_triangle_groupby_root_min_hg_u32";
455    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MAX_HG_U32: &str = "wcoj_triangle_groupby_root_max_hg_u32";
456    pub const WCOJ_TRIANGLE_MATERIALIZE_HG_U32: &str = "wcoj_triangle_materialize_hg_u32";
457    pub const WCOJ_TRIANGLE_BUILD_HG_WORK_PLAN_U64: &str = "wcoj_triangle_build_hg_work_plan_u64";
458    pub const WCOJ_TRIANGLE_COUNT_HG_U64: &str = "wcoj_triangle_count_hg_u64";
459    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_COUNT_HG_U64: &str =
460        "wcoj_triangle_groupby_root_count_hg_u64";
461    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_SUM_HG_U64: &str = "wcoj_triangle_groupby_root_sum_hg_u64";
462    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MIN_HG_U64: &str = "wcoj_triangle_groupby_root_min_hg_u64";
463    pub const WCOJ_TRIANGLE_GROUPBY_ROOT_MAX_HG_U64: &str = "wcoj_triangle_groupby_root_max_hg_u64";
464    pub const WCOJ_GROUPBY_ROOT_SEGMENT_SUM_COUNTS_U32: &str =
465        "wcoj_groupby_root_segment_sum_counts_u32";
466    pub const WCOJ_GROUPBY_ROOT_SEGMENT_SUM_VALUES_U64: &str =
467        "wcoj_groupby_root_segment_sum_values_u64";
468    pub const WCOJ_GROUPBY_ROOT_SEGMENT_MIN_VALUES_U64: &str =
469        "wcoj_groupby_root_segment_min_values_u64";
470    pub const WCOJ_GROUPBY_ROOT_SEGMENT_MAX_VALUES_U64: &str =
471        "wcoj_groupby_root_segment_max_values_u64";
472    pub const WCOJ_TRIANGLE_MATERIALIZE_HG_U64: &str = "wcoj_triangle_materialize_hg_u64";
473    pub const WCOJ_TRIANGLE_COUNT_HG_CACHED_U32: &str = "wcoj_triangle_count_hg_cached_u32";
474    pub const WCOJ_TRIANGLE_MATERIALIZE_HG_CACHED_U32: &str =
475        "wcoj_triangle_materialize_hg_cached_u32";
476    pub const WCOJ_SCAN_HG_BLOCK_COUNTS_U32: &str = "wcoj_scan_hg_block_counts_u32";
477    pub const WCOJ_COMPUTE_TOTAL: &str = "wcoj_compute_total";
478    pub const WCOJ_LAYOUT_CHECK_SORTED_UNIQUE_U32: &str = "wcoj_layout_check_sorted_unique_u32";
479    pub const WCOJ_LAYOUT_CHECK_SORTED_UNIQUE_U64: &str = "wcoj_layout_check_sorted_unique_u64";
480    pub const WCOJ_4CYCLE_BUILD_E2_WORK_PREFIX_U32: &str = "wcoj_4cycle_build_e2_work_prefix_u32";
481    pub const WCOJ_4CYCLE_BUILD_HG_WORK_PLAN_U32: &str = "wcoj_4cycle_build_hg_work_plan_u32";
482    pub const WCOJ_4CYCLE_COUNT_HG_U32: &str = "wcoj_4cycle_count_hg_u32";
483    pub const WCOJ_4CYCLE_GROUPBY_ROOT_COUNT_HG_U32: &str = "wcoj_4cycle_groupby_root_count_hg_u32";
484    pub const WCOJ_4CYCLE_GROUPBY_ROOT_SUM_HG_U32: &str = "wcoj_4cycle_groupby_root_sum_hg_u32";
485    pub const WCOJ_4CYCLE_GROUPBY_ROOT_MIN_HG_U32: &str = "wcoj_4cycle_groupby_root_min_hg_u32";
486    pub const WCOJ_4CYCLE_GROUPBY_ROOT_MAX_HG_U32: &str = "wcoj_4cycle_groupby_root_max_hg_u32";
487    pub const WCOJ_4CYCLE_MATERIALIZE_HG_U32: &str = "wcoj_4cycle_materialize_hg_u32";
488    pub const WCOJ_4CYCLE_BUILD_E2_WORK_PREFIX_U64: &str = "wcoj_4cycle_build_e2_work_prefix_u64";
489    pub const WCOJ_4CYCLE_BUILD_HG_WORK_PLAN_U64: &str = "wcoj_4cycle_build_hg_work_plan_u64";
490    pub const WCOJ_4CYCLE_COUNT_HG_U64: &str = "wcoj_4cycle_count_hg_u64";
491    pub const WCOJ_4CYCLE_GROUPBY_ROOT_COUNT_HG_U64: &str = "wcoj_4cycle_groupby_root_count_hg_u64";
492    pub const WCOJ_4CYCLE_MATERIALIZE_HG_U64: &str = "wcoj_4cycle_materialize_hg_u64";
493    // General-arity clique kernels (k=5..8 from a single template).
494    pub const WCOJ_CLIQUE5_COUNT_HG_U32: &str = "wcoj_clique5_count_hg_u32";
495    pub const WCOJ_CLIQUE5_MATERIALIZE_HG_U32: &str = "wcoj_clique5_materialize_hg_u32";
496    pub const WCOJ_CLIQUE5_COUNT_HG_U64: &str = "wcoj_clique5_count_hg_u64";
497    pub const WCOJ_CLIQUE5_MATERIALIZE_HG_U64: &str = "wcoj_clique5_materialize_hg_u64";
498    pub const WCOJ_CLIQUE6_COUNT_HG_U32: &str = "wcoj_clique6_count_hg_u32";
499    pub const WCOJ_CLIQUE6_MATERIALIZE_HG_U32: &str = "wcoj_clique6_materialize_hg_u32";
500    pub const WCOJ_CLIQUE6_COUNT_HG_U64: &str = "wcoj_clique6_count_hg_u64";
501    pub const WCOJ_CLIQUE6_MATERIALIZE_HG_U64: &str = "wcoj_clique6_materialize_hg_u64";
502    pub const WCOJ_CLIQUE7_COUNT_HG_U32: &str = "wcoj_clique7_count_hg_u32";
503    pub const WCOJ_CLIQUE7_MATERIALIZE_HG_U32: &str = "wcoj_clique7_materialize_hg_u32";
504    pub const WCOJ_CLIQUE7_COUNT_HG_U64: &str = "wcoj_clique7_count_hg_u64";
505    pub const WCOJ_CLIQUE7_MATERIALIZE_HG_U64: &str = "wcoj_clique7_materialize_hg_u64";
506    pub const WCOJ_CLIQUE8_COUNT_HG_U32: &str = "wcoj_clique8_count_hg_u32";
507    pub const WCOJ_CLIQUE8_MATERIALIZE_HG_U32: &str = "wcoj_clique8_materialize_hg_u32";
508    pub const WCOJ_CLIQUE8_COUNT_HG_U64: &str = "wcoj_clique8_count_hg_u64";
509    pub const WCOJ_CLIQUE8_MATERIALIZE_HG_U64: &str = "wcoj_clique8_materialize_hg_u64";
510    pub const WCOJ_CLIQUE5_GROUPBY_ROOT_COUNT_HG_U32: &str =
511        "wcoj_clique5_groupby_root_count_hg_u32";
512    pub const WCOJ_CLIQUE6_GROUPBY_ROOT_COUNT_HG_U32: &str =
513        "wcoj_clique6_groupby_root_count_hg_u32";
514    // Free Join frontier engine primitives. The work
515    // prefix kernel is width-agnostic (ranges are u32 row indices in
516    // every width class); count/emit/probe have u64 data twins.
517    pub const FJ_EXPAND_WORK_PREFIX_U32: &str = "fj_expand_work_prefix_u32";
518    pub const FJ_EXPAND_COUNT_U32: &str = "fj_expand_count_u32";
519    pub const FJ_EXPAND_EMIT_U32: &str = "fj_expand_emit_u32";
520    pub const FJ_PROBE_REFINE_U32: &str = "fj_probe_refine_u32";
521    pub const FJ_EXPAND_COUNT_U64: &str = "fj_expand_count_u64";
522    pub const FJ_EXPAND_EMIT_U64: &str = "fj_expand_emit_u64";
523    pub const FJ_PROBE_REFINE_U64: &str = "fj_probe_refine_u64";
524    pub const FJ_COUNT_MULTIPLICITY: &str = "fj_count_multiplicity";
525    // D3 S3 spike — factorized recursive delta novel-set pipeline.
526    pub const FJ_DELTA_RANGE_U32: &str = "fj_delta_range_u32";
527    pub const FJ_DELTA_MARK_U32: &str = "fj_delta_mark_u32";
528    pub const FJ_DELTA_SUBTRACT_U32: &str = "fj_delta_subtract_u32";
529    pub const FJ_DELTA_POPCOUNT: &str = "fj_delta_popcount";
530    pub const FJ_DELTA_EMIT_U32: &str = "fj_delta_emit_u32";
531    pub const FJ_DELTA_MAX_U32: &str = "fj_delta_max_u32";
532    pub const FJ_DELTA_SPARSE_ESTIMATE: &str = "fj_delta_sparse_estimate";
533    pub const FJ_DELTA_SPARSE_LOAD_R: &str = "fj_delta_sparse_load_r";
534    pub const FJ_DELTA_SPARSE_INSERT_CANDIDATES: &str = "fj_delta_sparse_insert_candidates";
535    pub const FJ_DELTA_SPARSE_MARK: &str = "fj_delta_sparse_mark";
536    pub const FJ_DELTA_SPARSE_EMIT: &str = "fj_delta_sparse_emit";
537}
538
539/// Kernel function names in the Monte Carlo sampling module
540pub mod mc_sample_kernels {
541    pub const MC_SAMPLE_BERNOULLI: &str = "mc_sample_bernoulli";
542}
543
544/// Kernel function names in the Monte Carlo evaluation module
545pub mod mc_eval_kernels {
546    pub const MC_EVAL_MASK_VAR: &str = "mc_eval_mask_var";
547    pub const MC_EVAL_MASK_AD: &str = "mc_eval_mask_ad_choice";
548    pub const MC_EVAL_QUERY_EVIDENCE_TRUTH: &str = "mc_eval_query_evidence_truth";
549    pub const MC_EVAL_ACCUMULATE_COUNTS: &str = "mc_accumulate_counts";
550}
551
552/// Kernel function names in the GPU-resident Datalog/MC engine module.
553pub mod mc_resident_kernels {
554    /// Single megakernel: evaluates all MC worlds to fixpoint and counts
555    /// query/evidence satisfaction with zero host interaction in-region.
556    pub const MC_RESIDENT_ENGINE: &str = "mc_resident_engine";
557}
558
559/// Kernel function names in the arithmetic module
560pub mod arith_kernels {
561    pub const ARITH_BINARY_I64: &str = "arith_binary_i64";
562    pub const ARITH_BINARY_I32: &str = "arith_binary_i32";
563    pub const ARITH_BINARY_U64: &str = "arith_binary_u64";
564    pub const ARITH_BINARY_U32: &str = "arith_binary_u32";
565    pub const ARITH_BINARY_F64: &str = "arith_binary_f64";
566    pub const ARITH_BINARY_F32: &str = "arith_binary_f32";
567    pub const ARITH_ABS_I64: &str = "arith_abs_i64";
568    pub const ARITH_ABS_I32: &str = "arith_abs_i32";
569    pub const ARITH_ABS_F64: &str = "arith_abs_f64";
570    pub const ARITH_ABS_F32: &str = "arith_abs_f32";
571    pub const ARITH_POW_F64: &str = "arith_pow_f64";
572    pub const ARITH_CAST: &str = "arith_cast";
573    pub const ARITH_FILL_CONST_U32: &str = "arith_fill_const_u32";
574    pub const ARITH_FILL_CONST_U64: &str = "arith_fill_const_u64";
575    pub const ARITH_FILL_CONST_I64: &str = "arith_fill_const_i64";
576    pub const ARITH_FILL_CONST_I32: &str = "arith_fill_const_i32";
577    pub const ARITH_FILL_CONST_F64: &str = "arith_fill_const_f64";
578    pub const ARITH_FILL_CONST_F32: &str = "arith_fill_const_f32";
579    pub const ARITH_FILL_CONST_U8: &str = "arith_fill_const_u8";
580    // Conditional select kernels
581    pub const ARITH_SELECT_I64: &str = "arith_select_i64";
582    pub const ARITH_SELECT_I32: &str = "arith_select_i32";
583    pub const ARITH_SELECT_U64: &str = "arith_select_u64";
584    pub const ARITH_SELECT_U32: &str = "arith_select_u32";
585    pub const ARITH_SELECT_F64: &str = "arith_select_f64";
586    pub const ARITH_SELECT_F32: &str = "arith_select_f32";
587}
588
589/// Kernel function names in the epistemic module.
590pub mod epistemic_kernels {
591    /// Device-side epistemic candidate-assumption generator.
592    pub const EPISTEMIC_GENERATE_CANDIDATE_ASSUMPTIONS_U8: &str =
593        "epistemic_generate_candidate_assumptions_u8";
594    /// Device-side epistemic candidate propagation staging kernel.
595    pub const EPISTEMIC_PROPAGATE_CANDIDATES_U8: &str = "epistemic_propagate_candidates_u8";
596    /// Device-side epistemic candidate bit validation kernel.
597    pub const EPISTEMIC_VALIDATE_CANDIDATE_BITS_U8: &str = "epistemic_validate_candidate_bits_u8";
598    /// Device-side model-membership staging kernel.
599    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_U8: &str =
600        "epistemic_populate_model_membership_u8";
601    /// Device-side tuple-source-backed model-membership kernel.
602    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_U8: &str =
603        "epistemic_populate_model_membership_from_tuple_source_u8";
604    /// Device-side arity-one tuple-key-backed model-membership kernel.
605    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY1_U8: &str =
606        "epistemic_populate_model_membership_from_tuple_source_arity1_u8";
607    /// Device-side arity-two tuple-key-backed model-membership kernel.
608    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY2_U8: &str =
609        "epistemic_populate_model_membership_from_tuple_source_arity2_u8";
610    /// Device-side arity-three tuple-key-backed model-membership kernel.
611    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY3_U8: &str =
612        "epistemic_populate_model_membership_from_tuple_source_arity3_u8";
613    /// Device-side generic-arity tuple-key-backed model-membership kernel.
614    pub const EPISTEMIC_POPULATE_MODEL_MEMBERSHIP_FROM_TUPLE_SOURCE_ARITY_N_U8: &str =
615        "epistemic_populate_model_membership_from_tuple_source_arity_n_u8";
616    /// Device-side world-view validation kernel.
617    pub const EPISTEMIC_VALIDATE_WORLD_VIEWS_U8: &str = "epistemic_validate_world_views_u8";
618    /// Device-side world-view integrity-constraint validation kernel.
619    pub const EPISTEMIC_VALIDATE_CONSTRAINTS_U8: &str = "epistemic_validate_constraints_u8";
620    /// Device-side accepted-candidate materialization staging kernel.
621    pub const EPISTEMIC_MATERIALIZE_ACCEPTED_CANDIDATES_U8: &str =
622        "epistemic_materialize_accepted_candidates_u8";
623
624    /// Device-side final-result flag materialization staging kernel.
625    pub const EPISTEMIC_MATERIALIZE_FINAL_RESULT_FLAGS_U8: &str =
626        "epistemic_materialize_final_result_flags_u8";
627    /// Device-side final tuple materialization kernel.
628    pub const EPISTEMIC_MATERIALIZE_FINAL_TUPLE_COLUMN_U8: &str =
629        "epistemic_materialize_final_tuple_column_u8";
630    /// Device-side final tuple row-map kernel.
631    pub const EPISTEMIC_BUILD_FINAL_TUPLE_ROW_MAP_U8: &str =
632        "epistemic_build_final_tuple_row_map_u8";
633    /// Device-side final tuple rejection-close kernel.
634    pub const EPISTEMIC_CLOSE_FINAL_TUPLE_REJECTIONS_U8: &str =
635        "epistemic_close_final_tuple_rejections_u8";
636}
637
638/// Kernel function names in the neural fast-path module.
639pub mod neural_kernels {
640    pub const NEURAL_FILL_AD_CHAIN_F32: &str = "neural_fill_ad_chain_f32";
641    pub const NEURAL_SCATTER_AD_CHAIN_GRADS_F32: &str = "neural_scatter_ad_chain_grads_f32";
642}
643
644/// Kernel function names in the ILP module.
645pub mod ilp_kernels {
646    pub const EXTRACT_NONZERO_INDICES: &str = "extract_nonzero_indices";
647    pub const ILP_MARK_SELECTED_IDS_U32: &str = "ilp_mark_selected_ids_u32";
648    pub const ILP_MARK_SELECTED_IDS_I32: &str = "ilp_mark_selected_ids_i32";
649    pub const ILP_MARK_SELECTED_IDS_I64: &str = "ilp_mark_selected_ids_i64";
650    pub const ILP_MARK_SELECTED_IDS_U64: &str = "ilp_mark_selected_ids_u64";
651    pub const ILP_VALIDATE_SELECTED_IDS_U32: &str = "ilp_validate_selected_ids_u32";
652    pub const ILP_VALIDATE_SELECTED_IDS_I32: &str = "ilp_validate_selected_ids_i32";
653    pub const ILP_VALIDATE_SELECTED_IDS_I64: &str = "ilp_validate_selected_ids_i64";
654    pub const ILP_VALIDATE_SELECTED_IDS_U64: &str = "ilp_validate_selected_ids_u64";
655    pub const ILP_BROADCAST_CANDIDATE_FLAG: &str = "ilp_broadcast_candidate_flag";
656    pub const ILP_COO_FILL_FROM_MASK: &str = "ilp_coo_fill_from_mask";
657    pub const ILP_CSR_HISTOGRAM: &str = "ilp_csr_histogram";
658    pub const ILP_REDUCE_SUM_F32: &str = "ilp_reduce_sum_f32";
659    pub const ILP_REDUCE_SUM_F64: &str = "ilp_reduce_sum_f64";
660}
661
/// Kernel function names in the ILP credit module.
///
/// Each value is an exact kernel entry-point name used for function lookup.
/// Note that `ILP_COO_FILL` here is distinct from
/// `ilp_kernels::ILP_COO_FILL_FROM_MASK`, which lives in the ILP module.
pub mod ilp_credit_kernels {
    pub const ILP_COO_FILL: &str = "ilp_coo_fill";
    // Forward / backward entry points, each in `f32` and `f64` variants.
    pub const ILP_CREDIT_FORWARD_F32: &str = "ilp_credit_forward_f32";
    pub const ILP_CREDIT_FORWARD_F64: &str = "ilp_credit_forward_f64";
    pub const ILP_CREDIT_BACKWARD_F32: &str = "ilp_credit_backward_f32";
    pub const ILP_CREDIT_BACKWARD_F64: &str = "ilp_credit_backward_f64";
}
670
/// Kernel function names in the native bounded exact-induction module.
///
/// Each value is an exact kernel entry-point name used for function lookup.
pub mod ilp_exact_kernels {
    // Scoring entry points; the `_u32` variants are separate kernels.
    pub const ILP_EXACT_SCORE: &str = "ilp_exact_score";
    pub const ILP_EXACT_SCORE_U32: &str = "ilp_exact_score_u32";
    pub const ILP_EXACT_SCORE_CHAIN_SMEM: &str = "ilp_exact_score_chain_smem";
    pub const ILP_EXACT_SCORE_CHAIN_SMEM_U32: &str = "ilp_exact_score_chain_smem_u32";
    pub const ILP_EXACT_SELECT_TOPK: &str = "ilp_exact_select_topk";
}
679
/// Kernel function names in the PIR interning module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod pir_kernels {
    pub const PIR_PACK_KEYS: &str = "pir_pack_keys";
    pub const PIR_HASH_KEYS: &str = "pir_hash_keys";
    pub const PIR_MARK_UNIQUE: &str = "pir_mark_unique";
    pub const PIR_FIND_EXISTING: &str = "pir_find_existing";
    pub const PIR_MARK_NEW_GROUPS: &str = "pir_mark_new_groups";
    pub const PIR_BUILD_GROUP_IDS: &str = "pir_build_group_ids";
    pub const PIR_FILL_CHILD_PARENTS: &str = "pir_fill_child_parents";
    pub const PIR_MARK_UNIQUE_PAIRS: &str = "pir_mark_unique_pairs";
    pub const PIR_COMPACT_PAIRS: &str = "pir_compact_pairs";
    pub const PIR_COUNT_CHILDREN: &str = "pir_count_children";
    pub const PIR_WRITE_CHILD_OFFSETS: &str = "pir_write_child_offsets";
    pub const PIR_GATHER_CHILDREN: &str = "pir_gather_children";
    pub const PIR_BUILD_GRAPH_CHILD_COUNTS: &str = "pir_build_graph_child_counts";
    pub const PIR_SUM_COUNTS: &str = "pir_sum_counts";
    pub const PIR_EMIT_NODES_AND_IDS: &str = "pir_emit_nodes_and_ids";
    pub const PIR_UPDATE_COUNTS: &str = "pir_update_counts";
}
699
/// Kernel function names in the GPU CNF encoder module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod cnf_kernels {
    pub const CNF_REACHABILITY_INIT: &str = "cnf_reachability_init";
    pub const CNF_REACHABILITY_BFS: &str = "cnf_reachability_bfs";
    pub const CNF_MARK_LEAF_CHOICE: &str = "cnf_mark_leaf_choice";
    pub const CNF_ASSIGN_LEAF_VAR: &str = "cnf_assign_leaf_var";
    pub const CNF_ASSIGN_CHOICE_VAR: &str = "cnf_assign_choice_var";
    pub const CNF_MARK_NODE_VARS: &str = "cnf_mark_node_vars";
    pub const CNF_COUNT_CLAUSES: &str = "cnf_count_clauses";
    pub const CNF_CAPTURE_LAST_COUNTS: &str = "cnf_capture_last_counts";
    pub const CNF_COMPUTE_LEAF_CHOICE_TOTALS: &str = "cnf_compute_leaf_choice_totals";
    pub const CNF_COMPUTE_TOTALS: &str = "cnf_compute_totals";
    pub const CNF_ASSIGN_NODE_VAR: &str = "cnf_assign_node_var";
    pub const CNF_EMIT_CLAUSES: &str = "cnf_emit_clauses";
    pub const CNF_SET_CLAUSE_END: &str = "cnf_set_clause_end";
}
716
/// Kernel function names in the weights module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod weights_kernels {
    pub const WEIGHTS_FILL_LEAF: &str = "weights_fill_leaf";
    pub const WEIGHTS_FILL_CHOICE: &str = "weights_fill_choice";
    pub const WEIGHTS_COUNT_LIFT_EXACT: &str = "weights_count_lift_exact";
    pub const WEIGHTS_SET_EVIDENCE_FROM_NODES: &str = "weights_set_evidence_from_nodes";
    pub const WEIGHTS_APPLY_EVIDENCE: &str = "weights_apply_evidence";
    pub const WEIGHTS_MAP_NODES_TO_VARS: &str = "weights_map_nodes_to_vars";
    // Force / restore pairs for a single variable.
    pub const WEIGHTS_FORCE_VAR_FALSE: &str = "weights_force_var_false";
    pub const WEIGHTS_RESTORE_VAR_FALSE: &str = "weights_restore_var_false";
    pub const WEIGHTS_FORCE_VAR_TRUE: &str = "weights_force_var_true";
    pub const WEIGHTS_RESTORE_VAR_TRUE: &str = "weights_restore_var_true";
    pub const WEIGHTS_COPY_SLOT_TO_BATCH: &str = "weights_copy_slot_to_batch";
    // Apply / restore pairs for query variables (unbatched, then batched).
    pub const WEIGHTS_APPLY_QUERY_VARS: &str = "weights_apply_query_vars";
    pub const WEIGHTS_RESTORE_QUERY_VARS: &str = "weights_restore_query_vars";
    pub const WEIGHTS_APPLY_QUERY_VARS_FALSE_BATCHED: &str =
        "weights_apply_query_vars_false_batched";
    pub const WEIGHTS_RESTORE_QUERY_VARS_FALSE_BATCHED: &str =
        "weights_restore_query_vars_false_batched";
    pub const WEIGHTS_APPLY_QUERY_VARS_TRUE_BATCHED: &str = "weights_apply_query_vars_true_batched";
    pub const WEIGHTS_RESTORE_QUERY_VARS_TRUE_BATCHED: &str =
        "weights_restore_query_vars_true_batched";
}
740
/// Kernel function names in the GPU Decision-DNNF compiler module
/// (CNF validation + circuit levelization).
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod d4_kernels {
    pub const D4_VALIDATE_CNF: &str = "d4_validate_cnf";
    pub const D4_LEVELIZE_COUNTS: &str = "d4_levelize_counts";
    pub const D4_LEVELIZE_EMIT: &str = "d4_levelize_emit";
    // BFS frontier expansion and unit propagation.
    pub const D4_FRONTIER_PREPARE: &str = "d4_frontier_prepare";
    pub const D4_FRONTIER_EXPAND: &str = "d4_frontier_expand";
    pub const D4_FRONTIER_PREPARE_DENSE: &str = "d4_frontier_prepare_dense";
    pub const D4_FRONTIER_EXPAND_DENSE: &str = "d4_frontier_expand_dense";
    // Per-frontier Decision-DNNF DFS worker (count+emit).
    pub const D4_COMPILE_COUNT: &str = "d4_compile_count";
    pub const D4_COMPILE_EMIT: &str = "d4_compile_emit";
    pub const D4_CAPTURE_EMIT_META: &str = "d4_capture_emit_meta";
    // GPU smoothing with random-variable support and wrapper emission.
    pub const D4_SUPPORT_LEVEL: &str = "d4_support_level";
    pub const D4_SUPPORT_SET_ROOT_BITS: &str = "d4_support_set_root_bits";
    pub const D4_SMOOTH_COUNT: &str = "d4_smooth_count";
    pub const D4_SMOOTH_WRAPPER_COUNTS: &str = "d4_smooth_wrapper_counts";
    pub const D4_SMOOTH_WRAPPER_EDGE_COUNTS_OR: &str = "d4_smooth_wrapper_edge_counts_or";
    pub const D4_SMOOTH_WRAPPER_EDGE_COUNTS_DEC: &str = "d4_smooth_wrapper_edge_counts_dec";
    pub const D4_SMOOTH_INIT_NODES: &str = "d4_smooth_init_nodes";
    pub const D4_SMOOTH_EMIT_LEVEL: &str = "d4_smooth_emit_level";
    pub const D4_SMOOTH_CHECK_EDGE_CAP: &str = "d4_smooth_check_edge_cap";
    // GPU free-variable mask: variables that occur in the clauses versus
    // variables that occur in the compiled circuit.
    pub const D4_MARK_VARS_IN_CLAUSES: &str = "d4_mark_vars_in_clauses";
    pub const D4_MARK_VARS_IN_CIRCUIT: &str = "d4_mark_vars_in_circuit";
    pub const D4_BUILD_FREE_VAR_MASK: &str = "d4_build_free_var_mask";
    // GPU-only assertions (tests + invariant enforcement without host reads).
    pub const D4_ASSERT_U32_EQ: &str = "d4_assert_u32_eq";
    pub const D4_ASSERT_BITSET_VAR: &str = "d4_assert_bitset_var";
    pub const D4_ASSERT_DENSE_VAR: &str = "d4_assert_dense_var";
    pub const D4_ASSERT_LEAF_ROOT_AND_DEGREE: &str = "d4_assert_leaf_root_and_degree";
}
776
/// Kernel function names in the join module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod join_kernels {
    pub const HASH_JOIN_BUILD: &str = "hash_join_build";
    pub const HASH_JOIN_PROBE: &str = "hash_join_probe";
    // V2 kernels for multi-column joins
    pub const COMPUTE_COMPOSITE_HASH: &str = "compute_composite_hash";
    pub const HASH_JOIN_BUCKET_COUNT_V2: &str = "hash_join_bucket_count_v2";
    pub const HASH_JOIN_SCATTER_V2: &str = "hash_join_scatter_v2";
    pub const HASH_JOIN_PROBE_V2: &str = "hash_join_probe_v2";
    pub const HASH_JOIN_PROBE_V2_COUNT_PER_ROW: &str = "hash_join_probe_v2_count_per_row";
    pub const HASH_JOIN_PROBE_V2_MATERIALIZE: &str = "hash_join_probe_v2_materialize";
    pub const HASH_JOIN_TOTAL_FROM_SCAN: &str = "hash_join_total_from_scan";
    pub const HASH_JOIN_CSM_UNMATCHED_MASK: &str = "hash_join_csm_unmatched_mask";
    pub const HASH_JOIN_SEMI: &str = "hash_join_semi";
    pub const HASH_JOIN_ANTI: &str = "hash_join_anti";
    pub const INIT_HASH_TABLE: &str = "init_hash_table";
    /// Nested-loop inner join (emit-pairs design). Reads
    /// the single key column from each side; emits matched
    /// `(left_idx, right_idx)` pairs as two parallel u32 arrays.
    /// Payload columns are materialized after the kernel via
    /// `gather_buffer_by_indices` in the provider fn.
    pub const NESTED_LOOP_JOIN_INNER_U32_1KEY_PAIRS: &str = "nested_loop_join_inner_u32_1key_pairs";
    /// Sort-merge inner join (emit-pairs design,
    /// caller-asserted pre-sorted inputs). Reads the single
    /// key column from each side, performs per-thread binary
    /// search on the right side to find matched-key runs,
    /// emits `(left_idx, right_idx)` pairs as two parallel
    /// u32 arrays. Payload columns materialize after the
    /// kernel via `gather_buffer_by_indices`.
    pub const SORT_MERGE_JOIN_INNER_U32_1KEY_PAIRS: &str = "sort_merge_join_inner_u32_1key_pairs";
}
808
/// Kernel function names in the dedup module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod dedup_kernels {
    pub const MARK_DUPLICATES: &str = "mark_duplicates";
    pub const MARK_UNIQUE_COLUMNAR: &str = "mark_unique_columnar";
    pub const MARK_UNIQUE_AND_SCAN_COLUMNAR: &str = "mark_unique_and_scan_columnar";
    pub const COMPACT_ROWS: &str = "compact_rows";
    pub const MARK_UNIQUE_FULL_ROW_BYTEWISE: &str = "mark_unique_full_row_bytewise";
    pub const MARK_DIFF_FULL_ROW_TYPED_SORTED: &str = "mark_diff_full_row_typed_sorted";
    pub const SMALL_SORT_FULL_ROW_INDICES_TYPED: &str = "small_sort_full_row_indices_typed";
}
819
/// Kernel function names in the groupby module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod groupby_kernels {
    // Group boundary / id construction.
    pub const DETECT_GROUP_BOUNDARIES: &str = "detect_group_boundaries";
    pub const DETECT_BOUNDARIES: &str = "detect_boundaries";
    pub const EXTRACT_GROUP_KEYS: &str = "extract_group_keys";
    pub const GROUP_IDS_FROM_BOUNDARIES: &str = "group_ids_from_boundaries";
    pub const GROUP_START_INDICES: &str = "group_start_indices";
    pub const CAPTURE_NUM_GROUPS: &str = "capture_num_groups";
    // Aggregates; `_U64` names are separate entry points from the unsuffixed ones.
    pub const GROUPBY_COUNT: &str = "groupby_count";
    pub const GROUPBY_SUM: &str = "groupby_sum";
    pub const GROUPBY_SUM_U64: &str = "groupby_sum_u64";
    pub const GROUPBY_MIN: &str = "groupby_min";
    pub const GROUPBY_MIN_U64: &str = "groupby_min_u64";
    pub const GROUPBY_MAX: &str = "groupby_max";
    pub const GROUPBY_MAX_U64: &str = "groupby_max_u64";
    // Log-sum-exp is split into max / sum-of-exp / final entry points.
    pub const GROUPBY_LOGSUMEXP_MAX: &str = "groupby_logsumexp_max";
    pub const GROUPBY_LOGSUMEXP_SUMEXP: &str = "groupby_logsumexp_sumexp";
    pub const GROUPBY_LOGSUMEXP_FINAL: &str = "groupby_logsumexp_final";
}
839
/// Kernel function names in the scan module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod scan_kernels {
    pub const BLOCK_INCLUSIVE_SCAN: &str = "block_inclusive_scan";
    pub const ADD_BLOCK_OFFSETS: &str = "add_block_offsets";
    pub const EXCLUSIVE_SCAN_MASK: &str = "exclusive_scan_mask";
    pub const COUNT_MASK: &str = "count_mask";
    // Multi-block scan kernels for large prefix sums
    pub const MULTIBLOCK_SCAN_PHASE1: &str = "multiblock_scan_phase1";
    pub const MULTIBLOCK_SCAN_U32_PHASE1: &str = "multiblock_scan_u32_phase1";
    pub const MULTIBLOCK_SCAN_PHASE2: &str = "multiblock_scan_phase2";
    pub const MULTIBLOCK_SCAN_PHASE3: &str = "multiblock_scan_phase3";
}
852
/// Kernel function names in the sort module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod sort_kernels {
    pub const RADIX_HISTOGRAM: &str = "radix_histogram";
    pub const RADIX_SCATTER: &str = "radix_scatter";
    pub const COMPUTE_RANKS: &str = "compute_ranks";
    pub const RADIX_SCATTER_STABLE: &str = "radix_scatter_stable";
    pub const COMPUTE_DIGIT_PREFIX_SUMS: &str = "compute_digit_prefix_sums";
    pub const INIT_INDICES: &str = "init_indices";
    pub const APPLY_PERMUTATION_U32: &str = "apply_permutation_u32";
    pub const APPLY_PERMUTATION_BYTES: &str = "apply_permutation_bytes";

    // `*_ordered_u32` key gathers for 32-bit-or-smaller column types;
    // presumably an order-preserving `u32` key encoding — confirm in the
    // kernel source.
    pub const GATHER_KEYS_I32_ORDERED_U32: &str = "gather_keys_i32_ordered_u32";
    pub const GATHER_KEYS_F32_ORDERED_U32: &str = "gather_keys_f32_ordered_u32";
    pub const GATHER_KEYS_BOOL_ORDERED_U32: &str = "gather_keys_bool_ordered_u32";

    // 64-bit column types have separate `_lo_u32` / `_hi_u32` gathers
    // (presumably one per 32-bit half of the key).
    pub const GATHER_KEYS_U64_LO_U32: &str = "gather_keys_u64_lo_u32";
    pub const GATHER_KEYS_U64_HI_U32: &str = "gather_keys_u64_hi_u32";

    pub const GATHER_KEYS_I64_LO_U32: &str = "gather_keys_i64_lo_u32";
    pub const GATHER_KEYS_I64_HI_U32: &str = "gather_keys_i64_hi_u32";

    pub const GATHER_KEYS_F64_LO_U32: &str = "gather_keys_f64_lo_u32";
    pub const GATHER_KEYS_F64_HI_U32: &str = "gather_keys_f64_hi_u32";
    /// Sort-merge sortedness-detection kernel — single-pass adjacent-
    /// pair check; atomically writes 0 to a u32 flag on
    /// `keys[i] > keys[i+1]`. Caller initializes flag to 1
    /// before launch, reads result post-launch. Used by the
    /// dispatch-site eligibility check at `execute_join` to
    /// validate caller-asserted sortedness before invoking
    /// `sort_merge_join_v2_inner_u32_1key`.
    pub const CHECK_ASCENDING_SORTED_U32: &str = "check_ascending_sorted_u32";
}
885
/// Kernel function names in the filter module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod filter_kernels {
    // Per-type comparison kernels (unsuffixed variants).
    pub const FILTER_COMPARE_U32: &str = "filter_compare_u32";
    pub const FILTER_COMPARE_I64: &str = "filter_compare_i64";
    pub const FILTER_COMPARE_F64: &str = "filter_compare_f64";
    pub const FILTER_COMPARE_I32: &str = "filter_compare_i32";
    pub const FILTER_COMPARE_U64: &str = "filter_compare_u64";
    pub const FILTER_COMPARE_F32: &str = "filter_compare_f32";
    pub const FILTER_COMPARE_U8: &str = "filter_compare_u8";
    // `_scan_phase1` variants (only u32 / f64 / f32 are provided).
    pub const FILTER_COMPARE_U32_SCAN_PHASE1: &str = "filter_compare_u32_scan_phase1";
    pub const FILTER_COMPARE_F64_SCAN_PHASE1: &str = "filter_compare_f64_scan_phase1";
    pub const FILTER_COMPARE_F32_SCAN_PHASE1: &str = "filter_compare_f32_scan_phase1";
    // `_col` variants (presumably column-vs-column comparisons — confirm).
    pub const FILTER_COMPARE_U32_COL: &str = "filter_compare_u32_col";
    pub const FILTER_COMPARE_I32_COL: &str = "filter_compare_i32_col";
    pub const FILTER_COMPARE_I64_COL: &str = "filter_compare_i64_col";
    pub const FILTER_COMPARE_U64_COL: &str = "filter_compare_u64_col";
    pub const FILTER_COMPARE_F32_COL: &str = "filter_compare_f32_col";
    pub const FILTER_COMPARE_F64_COL: &str = "filter_compare_f64_col";
    pub const FILTER_COMPARE_U8_COL: &str = "filter_compare_u8_col";
    pub const FILL_U32_IOTA: &str = "fill_u32_iota";
    pub const FILL_U32_CONST: &str = "fill_u32_const";
    pub const MARK_RANDOM_VARS: &str = "mark_random_vars";
    pub const RANDOM_VAR_TO_BIT_FROM_LIST: &str = "random_var_to_bit_from_list";
    pub const CHECK_RANDOM_VAR_COUNT: &str = "check_random_var_count";
    // Mask-driven compaction.
    pub const COMPACT_U32_BY_MASK: &str = "compact_u32_by_mask";
    pub const COMPACT_I64_BY_MASK: &str = "compact_i64_by_mask";
    pub const COMPACT_F64_BY_MASK: &str = "compact_f64_by_mask";
    pub const COMPACT_BYTES_BY_MASK: &str = "compact_bytes_by_mask";
    pub const CAPTURE_COMPACT_COUNT: &str = "capture_compact_count";
    // Mask utilities.
    pub const MASK_CLAMP_ROWS: &str = "mask_clamp_rows";
    pub const MASK_AND: &str = "mask_and";
    pub const MASK_OR: &str = "mask_or";
    pub const MASK_NOT: &str = "mask_not";
}
920
/// Kernel function names in the set_ops module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod set_ops_kernels {
    pub const CONCAT_U32: &str = "concat_u32";
    pub const CONCAT_BYTES: &str = "concat_bytes";
    pub const SORTED_DIFF_MARK: &str = "sorted_diff_mark";
}
927
/// Kernel function names in the pack module (GPU-side key packing).
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod pack_kernels {
    /// Pack multiple columns into row-major byte array
    pub const PACK_KEYS: &str = "pack_keys";
    /// Compute FNV-1a hash from packed keys
    pub const HASH_PACKED_KEYS: &str = "hash_packed_keys";
    /// Fused pack + hash in single pass (optimal for join key preparation)
    pub const PACK_AND_HASH_KEYS: &str = "pack_and_hash_keys";
    /// Fused pack + hash for arbitrary key column counts
    pub const PACK_AND_HASH_KEYS_GENERIC: &str = "pack_and_hash_keys_generic";
    /// Vectorized pack for 8-byte aligned columns
    pub const PACK_KEYS_ALIGNED: &str = "pack_keys_aligned";
    /// Unpack single column from packed row data
    pub const UNPACK_COLUMN: &str = "unpack_column";
    /// Unpack single column with device-resident row count
    pub const UNPACK_COLUMN_COUNTED: &str = "unpack_column_counted";
    /// Gather rows from packed data based on index array
    pub const GATHER_PACKED_ROWS: &str = "gather_packed_rows";
    /// Gather rows with device-resident row count
    pub const GATHER_PACKED_ROWS_COUNTED: &str = "gather_packed_rows_counted";
    /// Scatter write: distribute packed rows to non-contiguous output positions
    pub const SCATTER_PACKED_ROWS: &str = "scatter_packed_rows";
    /// Compare packed keys for equality
    pub const COMPARE_PACKED_KEYS: &str = "compare_packed_keys";
    /// Pack u8 bools into Arrow bitmap bytes
    pub const PACK_BOOLS_TO_BITMAP: &str = "pack_bools_to_bitmap";
}
955
/// Kernel function names in the circuit module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod circuit_kernels {
    // Per-level forward / backward entry points (uncached).
    pub const XGCF_FORWARD_LEVEL: &str = "xgcf_forward_level";
    pub const XGCF_BACKWARD_LEVEL_PROPAGATE: &str = "xgcf_backward_level_propagate";
    pub const XGCF_BACKWARD_LEVEL_DECISION_GRAD: &str = "xgcf_backward_level_decision_grad";
    pub const XGCF_BACKWARD_LEVEL_LIT_GRAD: &str = "xgcf_backward_level_lit_grad";
    pub const XGCF_FREE_VAR_APPLY_GRAD: &str = "xgcf_free_var_apply_grad";
    pub const XGCF_FREE_VAR_REDUCE_STAGE: &str = "xgcf_free_var_reduce_stage";
    pub const XGCF_ADD_SCALAR: &str = "xgcf_add_scalar";
    // `_cached` counterparts, plus all-levels and batched entry points.
    pub const XGCF_FORWARD_LEVEL_CACHED: &str = "xgcf_forward_level_cached";
    pub const XGCF_EVAL_ALL_LEVELS_CACHED: &str = "xgcf_eval_all_levels_cached";
    pub const XGCF_EVAL_ALL_LEVELS_CACHED_BATCHED: &str = "xgcf_eval_all_levels_cached_batched";
    pub const XGCF_BACKWARD_LEVEL_PROPAGATE_CACHED: &str = "xgcf_backward_level_propagate_cached";
    pub const XGCF_BACKWARD_LEVEL_DECISION_GRAD_CACHED: &str =
        "xgcf_backward_level_decision_grad_cached";
    pub const XGCF_BACKWARD_LEVEL_LIT_GRAD_CACHED: &str = "xgcf_backward_level_lit_grad_cached";
    pub const XGCF_BACKWARD_ALL_LEVELS_CACHED: &str = "xgcf_backward_all_levels_cached";
    pub const XGCF_BACKWARD_ALL_LEVELS_CACHED_BATCHED: &str =
        "xgcf_backward_all_levels_cached_batched";
    pub const XGCF_FREE_VAR_APPLY_GRAD_CACHED: &str = "xgcf_free_var_apply_grad_cached";
    pub const XGCF_FREE_VAR_REDUCE_STAGE_CACHED: &str = "xgcf_free_var_reduce_stage_cached";
    pub const XGCF_ADD_SCALAR_CACHED: &str = "xgcf_add_scalar_cached";
    pub const XGCF_SET_ROOT_ADJ_CACHED_BATCHED: &str = "xgcf_set_root_adj_cached_batched";
    pub const XGCF_COPY_ROOT_CACHED: &str = "xgcf_copy_root_cached";
    pub const XGCF_COPY_ROOT_CACHED_META: &str = "xgcf_copy_root_cached_meta";
    pub const XGCF_COPY_ROOT_CACHED_META_BATCHED: &str = "xgcf_copy_root_cached_meta_batched";
}
983
/// Kernel function names in the cache module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod cache_kernels {
    pub const CACHE_CNF_HASH: &str = "cache_cnf_hash";
    pub const CACHE_LOOKUP_OR_INSERT: &str = "cache_lookup_or_insert";
    pub const CACHE_EVICT_LRU: &str = "cache_evict_lru";
    // Typed store entry points (`u8`, `u32`, `i32`, `f64`) plus metadata.
    pub const CACHE_STORE_U8: &str = "cache_store_u8";
    pub const CACHE_STORE_U32: &str = "cache_store_u32";
    pub const CACHE_STORE_I32: &str = "cache_store_i32";
    pub const CACHE_STORE_F64: &str = "cache_store_f64";
    pub const CACHE_STORE_META: &str = "cache_store_meta";
}
995
/// Kernel function names in the SAT module.
///
/// Each value is an exact kernel entry-point name used for function lookup,
/// so it must match the kernel source byte-for-byte.
pub mod sat_kernels {
    pub const SAT_CDCL_SOLVE: &str = "sat_cdcl_solve";
    pub const SAT_CHECK_MODEL: &str = "sat_check_model";
    pub const SAT_PROOF_MARK_NEEDED: &str = "sat_proof_mark_needed";
    pub const SAT_PROOF_CHECK: &str = "sat_proof_check";
    pub const SAT_ASSERT_STATUS: &str = "sat_assert_status";
    pub const SAT_ASSERT_OK: &str = "sat_assert_ok";
    // CNF construction from an XGCF circuit and related helpers.
    pub const SAT_XGCF_CNF_COUNTS: &str = "sat_xgcf_cnf_counts";
    pub const SAT_XGCF_CNF_EMIT: &str = "sat_xgcf_cnf_emit";
    pub const SAT_XGCF_CNF_CAPTURE_LAST_COUNTS: &str = "sat_xgcf_cnf_capture_last_counts";
    pub const SAT_XGCF_CNF_COMPUTE_TOTALS: &str = "sat_xgcf_cnf_compute_totals";
    pub const SAT_CNF_WRITE_TERMINATOR: &str = "sat_cnf_write_terminator";
    pub const SAT_CNF_COPY_INTO: &str = "sat_cnf_copy_into";
    pub const SAT_SHIFT_OFFSETS: &str = "sat_shift_offsets";
    pub const SAT_XGCF_WRITE_ROOT_UNIT_CLAUSE: &str = "sat_xgcf_write_root_unit_clause";
    pub const SAT_NOT_PHI_COUNTS: &str = "sat_not_phi_counts";
    pub const SAT_EMIT_NOT_PHI: &str = "sat_emit_not_phi";
}
1015
/// Default maximum output size for join operations.
///
/// Caps join output so that joining large tables with high-cardinality key
/// matches cannot overflow memory.
// NOTE(review): the unit (rows vs. bytes) is not established in this file;
// confirm against the call sites that consume this default.
pub const DEFAULT_JOIN_MAX_OUTPUT: usize = 1_000_000;
1019
/// Nested-loop join eligibility threshold (Cartesian product
/// upper bound). The dispatcher routes to nested-loop iff
/// `num_left * num_right <= NESTED_LOOP_TOTAL_THRESHOLD`; the
/// provider validates the same invariant fail-closed before any
/// allocation.
///
/// This is the **single source of truth** for the threshold.
/// `xlog-runtime`'s dispatch site imports this constant; do NOT
/// redeclare in xlog-runtime (would create either drift risk or
/// a reverse `xlog-cuda → xlog-runtime` dep cycle).
///
/// Value (`4_000_000`) is grounded in the bench-spike at
/// `bench-spike/w42-nested-loop` HEAD `9c0cefc6` (see
/// `docs/evidence/2026-05-07-w42-bench-spike/README.md`):
/// the largest symmetric tested cell, `L=R=2000` (4M pairs), beat
/// hash join by 5.41×; the algorithmic crossover is extrapolated to
/// ~10000×10000 = 100M pairs. 4M therefore sits 25× below that
/// crossover in total pairs (5× per side), leaving a wide margin to
/// absorb production-kernel cost asymmetry.
///
/// The threshold also caps the emitted `(left_idx, right_idx)` index
/// arrays at 32 MB total (4M pairs × 4 bytes per `u32` × 2 arrays).
// NOTE(review): a "6×" margin figure has been quoted for this value, but
// 100M / 4M = 25× in total pairs; reconcile with the evidence README.
pub const NESTED_LOOP_TOTAL_THRESHOLD: u64 = 4_000_000;
1041
/// Comparison operators for filtering.
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so `op as u8`
/// yields a fixed numeric code for each operator: `Eq` = 0 through `Ge` = 5.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum CompareOp {
    /// Equality (`==`); code 0.
    Eq = 0,
    /// Inequality (`!=`); code 1.
    Ne = 1,
    /// Strictly less than (`<`); code 2.
    Lt = 2,
    /// Less than or equal (`<=`); code 3.
    Le = 3,
    /// Strictly greater than (`>`); code 4.
    Gt = 4,
    /// Greater than or equal (`>=`); code 5.
    Ge = 5,
}
1053
/// Join types for hash_join_v2.
///
/// Selects which rows the join emits and whether right-side columns appear
/// in the output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
    /// Inner join: keep rows whose keys match on both sides.
    Inner,
    /// Semi join: keep left rows with at least one right-side match; no
    /// right-side columns are produced.
    Semi,
    /// Anti join: keep left rows that have NO right-side match.
    Anti,
    /// Left outer join: keep every left row, filling the right-side
    /// columns with nulls where there is no match.
    LeftOuter,
}
1066
/// Result of packing key columns and computing hashes for join operations.
///
/// `packed_keys` is row-major with a stride of `key_bytes`, so its length is
/// presumably `num_rows * key_bytes` (NOTE(review): confirm at the producer).
struct PackedKeyData {
    /// Computed hash values (one per row)
    hashes: crate::memory::TrackedCudaSlice<u64>,
    /// Packed key data in row-major format
    packed_keys: crate::memory::TrackedCudaSlice<u8>,
    /// Total bytes per row (key stride)
    key_bytes: u32,
}
1076
/// Bucketed hash table stored inside a cached v2 join index (`JoinIndexV2`)
/// for the build (right) side.
///
/// NOTE(review): the field layout is identical to the public
/// [`HashTableU64`]; consider unifying the two if nothing depends on them
/// being distinct types.
struct JoinHashTableV2 {
    /// Presumably the number of entries in each bucket.
    bucket_counts: crate::memory::TrackedCudaSlice<u32>,
    /// Presumably each bucket's start offset into `bucket_entries`
    /// (e.g. an exclusive scan of `bucket_counts`) — confirm at the builder.
    bucket_offsets: crate::memory::TrackedCudaSlice<u32>,
    /// Presumably build-side row indices grouped by bucket.
    bucket_entries: crate::memory::TrackedCudaSlice<u32>,
    /// Presumably the full hash of each entry, parallel to `bucket_entries`.
    bucket_entry_hashes: crate::memory::TrackedCudaSlice<u64>,
    /// Mask used to map a hash to a bucket (presumably `num_buckets - 1`
    /// for a power-of-two bucket count — confirm).
    bucket_mask: u32,
}
1084
/// Bucketed hash table for u64 hashes.
///
/// All fields are public device buffers. The bucket semantics below are
/// inferred from field names — NOTE(review): confirm against the builder.
pub struct HashTableU64 {
    /// Presumably the number of entries in each bucket.
    pub bucket_counts: crate::memory::TrackedCudaSlice<u32>,
    /// Presumably each bucket's start offset into `bucket_entries`.
    pub bucket_offsets: crate::memory::TrackedCudaSlice<u32>,
    /// Presumably row indices grouped by bucket.
    pub bucket_entries: crate::memory::TrackedCudaSlice<u32>,
    /// Presumably the full `u64` hash of each entry, parallel to `bucket_entries`.
    pub bucket_entry_hashes: crate::memory::TrackedCudaSlice<u64>,
    /// Mask used to map a hash to a bucket (presumably `num_buckets - 1`).
    pub bucket_mask: u32,
}
1093
/// Cached build-side join index for v2 hash join.
///
/// This captures the packed key bytes and bucketed hash table layout for the build (right) side,
/// enabling reuse across repeated joins on the same relation + key columns.
///
/// Fields are private; use `right_keys`, `right_num_rows`, and
/// `estimated_bytes` to inspect a cached index.
pub struct JoinIndexV2 {
    /// Row count of the build-side buffer at index build time.
    right_num_rows: u32,
    /// Build-side key column indices this index was built for.
    right_keys: Vec<usize>,
    /// Bytes per packed key row in `packed_keys` (presumably the same stride
    /// as `PackedKeyData::key_bytes` — confirm at the builder).
    key_bytes: u32,
    /// Packed build-side key bytes.
    packed_keys: crate::memory::TrackedCudaSlice<u8>,
    /// Bucketed hash table over the build side.
    table: JoinHashTableV2,
}
1105
1106impl JoinIndexV2 {
1107    /// Key columns (indices) this index was built for.
1108    pub fn right_keys(&self) -> &[usize] {
1109        &self.right_keys
1110    }
1111
1112    /// Row count of the build-side buffer at index build time.
1113    pub fn right_num_rows(&self) -> u32 {
1114        self.right_num_rows
1115    }
1116
1117    /// Approximate device memory used by this cached index.
1118    pub fn estimated_bytes(&self) -> u64 {
1119        let mut bytes = 0u64;
1120        bytes = bytes.saturating_add(self.packed_keys.len() as u64);
1121        bytes = bytes.saturating_add(self.table.bucket_counts.len() as u64 * 4);
1122        bytes = bytes.saturating_add(self.table.bucket_offsets.len() as u64 * 4);
1123        bytes = bytes.saturating_add(self.table.bucket_entries.len() as u64 * 4);
1124        bytes = bytes.saturating_add(self.table.bucket_entry_hashes.len() as u64 * 8);
1125        bytes
1126    }
1127}
1128
/// CUDA kernel provider for xlog GPU operations.
///
/// Holds the CUDA device handle and GPU memory manager, and manages the
/// pre-compiled kernel modules used for GPU execution. The relational core is:
/// - **Join**: hash join with build/probe phases (including the multi-column
///   "v2" kernels and Semi/Anti variants), plus nested-loop and sort-merge
///   inner joins over a single `u32` key column
/// - **Dedup**: sort-based deduplication with prefix-sum compaction
/// - **GroupBy**: sorted-input group aggregation (count, sum, min, max,
///   logsumexp)
///
/// The provider also carries kernels for filtering, sorting, scans, set
/// operations, and key packing, and for the circuit / SAT / Decision-DNNF /
/// ILP / PIR pipelines; the `*_kernels` name tables in this module list
/// every entry point.
///
/// Modules are loaded at construction time and stored in the CUDA device
/// (the optional `PtxLoadProfile` distinguishes cubin loads from PTX
/// fallbacks). Kernel functions can be retrieved using `device.get_func()`.
///
/// Besides kernels, the provider keeps diagnostic state: host-transfer
/// counters, an opt-in strict deterministic-D2H gate, a bounded CSM CUDA
/// Graph replay cache, and several invocation counters read by tests.
///
/// # Example
/// ```ignore
/// use std::sync::Arc;
/// use xlog_cuda::{CudaDevice, GpuMemoryManager, CudaKernelProvider};
/// use xlog_core::MemoryBudget;
///
/// let device = Arc::new(CudaDevice::new(0)?);
/// let memory = Arc::new(GpuMemoryManager::new(device.clone(), MemoryBudget::default()));
/// let provider = CudaKernelProvider::new(device, memory)?;
/// ```
pub struct CudaKernelProvider {
    /// The CUDA device with loaded kernel modules
    device: Arc<CudaDevice>,
    /// GPU memory manager for kernel allocations
    memory: Arc<GpuMemoryManager>,
    /// Tracked host transfers for diagnostics
    transfer_tracker: HostTransferTracker,
    /// PTX load profiling data (populated only when XLOG_WARMUP_PROFILE=1)
    ptx_load_profile: Option<PtxLoadProfile>,
    /// Column-level D2H transfer counter (incremented by each download_column_* call)
    d2h_transfer_count: AtomicU64,
    /// Untracked control-plane metadata D2H read counter. Incremented by every
    /// `dtoh_scalar_untracked` / `dtoh_small_metadata_untracked` call. These are
    /// bounded metadata reads (row counts, scan totals) exempt from the
    /// data-plane transfer contract, but the GPU-resident MC engine's no-host
    /// gate must prove they are *also* zero inside the measured region — hence an
    /// explicit, resettable counter.
    untracked_metadata_dtoh_count: AtomicU64,
    /// Strict deterministic-Datalog D2H gate. When `true`, any data-plane D2H
    /// transfer (column downloads or `dtoh_sync_copy_into_tracked`) increments
    /// the violation counter and returns `XlogError::Execution` from the
    /// originating call. Metadata reads via `dtoh_scalar_untracked` are NOT
    /// gated. See [`CudaKernelProvider::enable_strict_deterministic_d2h`].
    strict_deterministic_d2h: AtomicBool,
    /// Cumulative count of deterministic-D2H gate violations observed since
    /// the last reset. Increments even on the failing path (the originating
    /// call still returns `Err`); kept for telemetry and tests.
    deterministic_d2h_violations: AtomicU64,
    /// Lazy-initialized non-default launch stream used by
    /// env-gated recorded-operator dispatch (filter, sort,
    /// dedup, GroupBy, hash-join). Cached for the provider's
    /// lifetime — the [`crate::device_runtime::StreamPool`]
    /// never returns streams to a free-list, so per-call
    /// acquire would saturate it. One stream per provider is
    /// sufficient because the recorder serializes work on it;
    /// multiple operations chain through commit-order events.
    recorded_op_stream: OnceLock<crate::device_runtime::StreamId>,
    /// Test/diagnostic-only counter for CSM (count-scan-materialize)
    /// invocations selected by the recorded hash-join dispatch.
    /// **Not part of any public stability guarantee** — its existence,
    /// shape, exposure, and increment semantics may change in any
    /// release. Used by the env-dispatch test suite to prove that CSM
    /// was actually selected for eligible Inner / LeftOuter cases (and
    /// not selected for Semi / Anti or when the env gate is off).
    csm_invocations: AtomicU64,
    /// Diagnostic counter for bounded CSM CUDA Graph captures.
    csm_cuda_graph_captures: AtomicU64,
    /// Diagnostic counter for bounded CSM CUDA Graph launches.
    csm_cuda_graph_launches: AtomicU64,
    /// Diagnostic counter for bounded CSM CUDA Graph ineligibility fallbacks.
    csm_cuda_graph_fallbacks: AtomicU64,
    /// Diagnostic counter for bounded CSM CUDA Graph cache replays.
    csm_cuda_graph_cache_hits: AtomicU64,
    /// Diagnostic counter for graph-mode small full-row set-maintenance
    /// sorts. This is test telemetry only; production correctness must not
    /// depend on the value.
    small_full_row_sort_invocations: AtomicU64,
    /// Bounded CSM CUDA Graph replay cache, keyed by `CsmCudaGraphKey` and
    /// guarded by a `Mutex` so the provider can be used through `&self`.
    csm_cuda_graph_cache: Mutex<HashMap<CsmCudaGraphKey, CsmCudaGraphEntry>>,
    /// Per-process counter of WCOJ layout fast-path hits. The
    /// fast-path skips `dedup_full_row_recorded` when the input
    /// is already strictly lex-sorted and full-row unique.
    /// Tests + the phase report binary read this counter to
    /// confirm the fast-path actually fired vs. silently fell
    /// through to the existing dedup pipeline.
    wcoj_layout_fast_path_hit_count: AtomicU64,
    /// Diagnostic counter for generic WCOJ layout-sort helper
    /// invocations. Used by K-clique dispatch-plan certifications to
    /// prove K-clique runtime dispatch no longer routes every edge
    /// through the old all-edge `wcoj_layout_sort_*_recorded` path.
    wcoj_layout_sort_invocation_count: AtomicU64,
    /// Diagnostic counter for K-clique leader-edge metadata builds.
    kclique_metadata_build_count: AtomicU64,
    /// Diagnostic counter for cumulative nanoseconds spent building K-clique
    /// leader-edge metadata.
    kclique_metadata_build_nanos: AtomicU64,
    /// Histogram-guided triangle WCOJ routing counter: successful dispatches
    /// accepted through the block-slice provider entry.
    wcoj_triangle_hg_dispatch_count: AtomicU64,
    /// Diagnostic-only: last WCOJ triangle dispatch's per-phase
    /// CUDA-event timings, populated by `wcoj_triangle_*_recorded`
    /// when the `wcoj-phase-timing` Cargo feature is on. Read by
    /// the `wcoj_phase_report` binary in xlog-integration. Field
    /// is absent when the feature is off, so production builds
    /// have zero overhead.
    #[cfg(feature = "wcoj-phase-timing")]
    last_triangle_phase_timing:
        std::sync::Mutex<Option<crate::wcoj_phase_timing::WcojTrianglePhaseTiming>>,
}
1238
/// Relaxed atomic tallies backing [`HostTransferStats`] and
/// [`HostLaunchMetadataTransferStats`]. Every counter starts at zero via
/// `Default`.
#[derive(Default)]
struct HostTransferTracker {
    /// Number of tracked device-to-host copies.
    dtoh_calls: AtomicU64,
    /// Total bytes moved by tracked device-to-host copies.
    dtoh_bytes: AtomicU64,
    /// Number of tracked data-plane host-to-device copies.
    htod_calls: AtomicU64,
    /// Total bytes moved by tracked data-plane host-to-device copies.
    htod_bytes: AtomicU64,
    /// Number of launch-metadata host-to-device uploads, kept apart from the
    /// data-plane `htod_*` counters.
    launch_metadata_htod_calls: AtomicU64,
    /// Total bytes of launch-metadata host-to-device uploads.
    launch_metadata_htod_bytes: AtomicU64,
}
1248
/// Point-in-time copy of the provider's data-plane host-transfer counters,
/// as returned by `CudaKernelProvider::host_transfer_stats`.
///
/// Launch-parameter uploads are not included here; they are reported
/// separately through `HostLaunchMetadataTransferStats`.
///
/// `Default` yields an all-zero snapshot, matching a freshly reset tracker.
/// `PartialEq`/`Eq` let callers compare snapshots directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HostTransferStats {
    /// Total bytes copied device → host by tracked transfers.
    pub dtoh_bytes: u64,
    /// Total bytes copied host → device by tracked data-plane transfers.
    pub htod_bytes: u64,
    /// Number of tracked device → host transfers.
    pub dtoh_calls: u64,
    /// Number of tracked data-plane host → device transfers.
    pub htod_calls: u64,
}
1256
/// Point-in-time copy of the launch-metadata host→device upload counters,
/// as returned by `CudaKernelProvider::host_launch_metadata_transfer_stats`.
///
/// These uploads are tracked separately from the data-plane
/// `HostTransferStats` counters. `PartialEq`/`Eq` let callers compare
/// snapshots directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HostLaunchMetadataTransferStats {
    /// Total bytes of launch metadata uploaded host → device.
    pub htod_bytes: u64,
    /// Number of launch-metadata uploads.
    pub htod_calls: u64,
}
1262
1263impl HostTransferTracker {
1264    fn record_dtoh(&self, bytes: u64) {
1265        self.dtoh_calls.fetch_add(1, Ordering::Relaxed);
1266        self.dtoh_bytes.fetch_add(bytes, Ordering::Relaxed);
1267    }
1268
1269    fn record_htod(&self, bytes: u64) {
1270        self.htod_calls.fetch_add(1, Ordering::Relaxed);
1271        self.htod_bytes.fetch_add(bytes, Ordering::Relaxed);
1272    }
1273
1274    fn record_htod_launch_metadata(&self, bytes: u64) {
1275        self.launch_metadata_htod_calls
1276            .fetch_add(1, Ordering::Relaxed);
1277        self.launch_metadata_htod_bytes
1278            .fetch_add(bytes, Ordering::Relaxed);
1279    }
1280
1281    fn snapshot(&self) -> HostTransferStats {
1282        HostTransferStats {
1283            dtoh_bytes: self.dtoh_bytes.load(Ordering::Relaxed),
1284            htod_bytes: self.htod_bytes.load(Ordering::Relaxed),
1285            dtoh_calls: self.dtoh_calls.load(Ordering::Relaxed),
1286            htod_calls: self.htod_calls.load(Ordering::Relaxed),
1287        }
1288    }
1289
1290    fn launch_metadata_snapshot(&self) -> HostLaunchMetadataTransferStats {
1291        HostLaunchMetadataTransferStats {
1292            htod_bytes: self.launch_metadata_htod_bytes.load(Ordering::Relaxed),
1293            htod_calls: self.launch_metadata_htod_calls.load(Ordering::Relaxed),
1294        }
1295    }
1296
1297    fn reset(&self) {
1298        self.dtoh_bytes.store(0, Ordering::Relaxed);
1299        self.htod_bytes.store(0, Ordering::Relaxed);
1300        self.dtoh_calls.store(0, Ordering::Relaxed);
1301        self.htod_calls.store(0, Ordering::Relaxed);
1302        self.launch_metadata_htod_bytes.store(0, Ordering::Relaxed);
1303        self.launch_metadata_htod_calls.store(0, Ordering::Relaxed);
1304    }
1305}
1306
1307impl CudaKernelProvider {
1308    /// Create a new CUDA kernel provider
1309    ///
1310    /// Loads all kernel modules into the CUDA device.
1311    /// Prefers cubin for the detected SM arch, falls back to portable PTX (sm_75+).
1312    ///
1313    /// # Arguments
1314    /// * `device` - The CUDA device to load modules into
1315    /// * `memory` - The GPU memory manager for kernel allocations
1316    ///
1317    /// # Errors
1318    /// Returns `XlogError::Kernel` if PTX loading fails
1319    ///
1320    /// # Example
1321    /// ```ignore
1322    /// let device = Arc::new(CudaDevice::new(0)?);
1323    /// let memory = Arc::new(GpuMemoryManager::new(device.clone(), MemoryBudget::default()));
1324    /// let provider = CudaKernelProvider::new(device, memory)?;
1325    /// ```
1326    pub fn new(device: Arc<CudaDevice>, memory: Arc<GpuMemoryManager>) -> Result<Self> {
1327        let profiling = warmup_profiling_enabled();
1328        let ptx_load_profile = Self::load_all_kernel_modules(&device, profiling)?;
1329
1330        Ok(Self {
1331            device,
1332            memory,
1333            transfer_tracker: HostTransferTracker::default(),
1334            ptx_load_profile,
1335            d2h_transfer_count: AtomicU64::new(0),
1336            untracked_metadata_dtoh_count: AtomicU64::new(0),
1337            strict_deterministic_d2h: AtomicBool::new(false),
1338            deterministic_d2h_violations: AtomicU64::new(0),
1339            recorded_op_stream: OnceLock::new(),
1340            csm_invocations: AtomicU64::new(0),
1341            csm_cuda_graph_captures: AtomicU64::new(0),
1342            csm_cuda_graph_launches: AtomicU64::new(0),
1343            csm_cuda_graph_fallbacks: AtomicU64::new(0),
1344            csm_cuda_graph_cache_hits: AtomicU64::new(0),
1345            small_full_row_sort_invocations: AtomicU64::new(0),
1346            csm_cuda_graph_cache: Mutex::new(HashMap::new()),
1347            wcoj_layout_fast_path_hit_count: AtomicU64::new(0),
1348            wcoj_layout_sort_invocation_count: AtomicU64::new(0),
1349            kclique_metadata_build_count: AtomicU64::new(0),
1350            kclique_metadata_build_nanos: AtomicU64::new(0),
1351            wcoj_triangle_hg_dispatch_count: AtomicU64::new(0),
1352            #[cfg(feature = "wcoj-phase-timing")]
1353            last_triangle_phase_timing: std::sync::Mutex::new(None),
1354        })
1355    }
1356
1357    /// Construct a provider whose `GpuMemoryManager` must already
1358    /// have a v0.6 [`crate::device_runtime::XlogDeviceRuntime`]
1359    /// attached via [`GpuMemoryManager::with_runtime`].
1360    ///
1361    /// Equivalent to [`Self::new`] in every respect — same kernel
1362    /// loading, same field initialization — but **rejects** managers
1363    /// that lack a runtime. This guards against the misconfiguration
1364    /// in which a caller asks for runtime-routed provider semantics
1365    /// (by calling `with_runtime`) but supplies a legacy manager
1366    /// built via [`GpuMemoryManager::new`]; without the check, the
1367    /// resulting provider would silently keep using the cudarc
1368    /// default allocator and the runtime budget/logging stack would
1369    /// never observe the allocations the caller expected to be
1370    /// routed through it.
1371    ///
1372    /// Note: a runtime-routed manager passed to [`Self::new`] still
1373    /// routes correctly — `alloc::<T>` and `alloc_raw` consult
1374    /// `memory.runtime()` regardless of which provider constructor
1375    /// was used. `with_runtime` exists for callers that want the
1376    /// requirement enforced at construction time, not for
1377    /// correctness of the routing itself.
1378    ///
1379    /// This is the **opt-in** runtime entry point for providers.
1380    /// `Self::new` continues to accept managers without a runtime
1381    /// (the legacy default) and remains the production constructor
1382    /// until the runtime stack is certified end-to-end.
1383    ///
1384    /// # Errors
1385    /// Returns `XlogError::Kernel` if `memory.runtime()` is `None`,
1386    /// or anything `Self::new` would return.
1387    ///
1388    /// # Example
1389    /// ```ignore
1390    /// let device = Arc::new(CudaDevice::new(0)?);
1391    /// let runtime = Arc::new(XlogDeviceRuntime::with_resource(
1392    ///     Arc::clone(&device),
1393    ///     0,
1394    ///     Arc::new(StreamPool::with_defaults(Arc::clone(&device))),
1395    ///     Box::new(AsyncCudaResource::new(/* ... */)),
1396    /// ));
1397    /// let memory = Arc::new(GpuMemoryManager::with_runtime(
1398    ///     Arc::clone(&device),
1399    ///     MemoryBudget::default(),
1400    ///     runtime,
1401    /// ));
1402    /// let provider = CudaKernelProvider::with_runtime(device, memory)?;
1403    /// ```
1404    pub fn with_runtime(device: Arc<CudaDevice>, memory: Arc<GpuMemoryManager>) -> Result<Self> {
1405        if memory.runtime().is_none() {
1406            return Err(XlogError::Kernel(
1407                "CudaKernelProvider::with_runtime requires a GpuMemoryManager built via \
1408                 GpuMemoryManager::with_runtime; got a manager with no runtime attached"
1409                    .to_string(),
1410            ));
1411        }
1412        Self::new(device, memory)
1413    }
1414
1415    /// Internal: parse a "boolean" env var. Empty / unset / `"0"`
1416    /// → false; any other value → true.
1417    fn env_flag(name: &str) -> bool {
1418        std::env::var(name)
1419            .map(|v| !v.is_empty() && v != "0")
1420            .unwrap_or(false)
1421    }
1422
1423    /// Whether the recorded filter dispatch is enabled via env.
1424    ///
1425    /// Returns `true` when either `XLOG_USE_RECORDED_FILTERS` or
1426    /// the umbrella `XLOG_USE_RECORDED_OPS` env var is set.
1427    /// Combined with a runtime-backed manager, this routes
1428    /// `filter::<T>` through the recorded launch path.
1429    ///
1430    /// Env-gated rather than default-on so the migration is
1431    /// opt-in for real callers; the existing legacy paths remain
1432    /// the production default until the runtime stack is
1433    /// certified end-to-end.
1434    pub(crate) fn use_recorded_filters_env() -> bool {
1435        Self::env_flag("XLOG_USE_RECORDED_FILTERS") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1436    }
1437
1438    /// Whether the recorded sort dispatch is enabled via env.
1439    /// Reads `XLOG_USE_RECORDED_SORT` or the umbrella
1440    /// `XLOG_USE_RECORDED_OPS`. The recorded-sort path is narrowed
1441    /// to U32 / Symbol keys only — the public
1442    /// `sort()` dispatcher checks both this env flag AND key
1443    /// type compatibility before routing.
1444    pub(crate) fn use_recorded_sort_env() -> bool {
1445        Self::env_flag("XLOG_USE_RECORDED_SORT") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1446    }
1447
1448    /// Whether the recorded full-row dedup dispatch is enabled
1449    /// via env. Reads `XLOG_USE_RECORDED_DEDUP` or the umbrella
1450    /// `XLOG_USE_RECORDED_OPS`. `dedup_full_row_recorded` is
1451    /// narrow to all-U32 / Symbol columns.
1452    pub(crate) fn use_recorded_dedup_env() -> bool {
1453        Self::env_flag("XLOG_USE_RECORDED_DEDUP") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1454    }
1455
1456    /// Whether the recorded GroupBy dispatch is enabled via
1457    /// env. Reads `XLOG_USE_RECORDED_GROUPBY` or
1458    /// `XLOG_USE_RECORDED_OPS`. `groupby_multi_agg_recorded`
1459    /// supports U32 / Symbol keys + Count / Sum / Min / Max
1460    /// aggs only.
1461    pub(crate) fn use_recorded_groupby_env() -> bool {
1462        Self::env_flag("XLOG_USE_RECORDED_GROUPBY") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1463    }
1464
1465    /// Whether the recorded hash-join dispatch is enabled via
1466    /// env. Reads `XLOG_USE_RECORDED_HASH_JOIN` or
1467    /// `XLOG_USE_RECORDED_OPS`. `hash_join_v2_recorded` and
1468    /// `hash_join_v2_with_index_recorded` cover all four join
1469    /// types (Inner / Semi / Anti / LeftOuter); the only
1470    /// hard constraint inherited from `pack_keys` is `≤4`
1471    /// key columns.
1472    pub(crate) fn use_recorded_hash_join_env() -> bool {
1473        Self::env_flag("XLOG_USE_RECORDED_HASH_JOIN") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1474    }
1475
1476    /// Whether the recorded CSM (count-scan-materialize)
1477    /// dispatch is enabled via env. Reads `XLOG_USE_RECORDED_CSM`
1478    /// or `XLOG_USE_RECORDED_OPS`. CSM is a sub-strategy of the
1479    /// recorded hash-join: it is consulted only after the
1480    /// recorded path has already been selected, and only for
1481    /// `JoinType::Inner` / `JoinType::LeftOuter` where a CSM
1482    /// implementation exists. `Semi` / `Anti` are not affected.
1483    pub(crate) fn use_recorded_csm_env() -> bool {
1484        Self::env_flag("XLOG_USE_RECORDED_CSM") || Self::env_flag("XLOG_USE_RECORDED_OPS")
1485    }
1486
1487    /// Whether the bounded CSM CUDA Graph path is enabled.
1488    ///
1489    /// This is narrower than `XLOG_USE_RECORDED_CSM`: callers must first select
1490    /// the recorded CSM hash-join path, then opt into graph capture/replay with
1491    /// `XLOG_USE_CSM_CUDA_GRAPH=1` (or the broader `XLOG_USE_CUDA_GRAPHS=1`).
1492    pub(crate) fn use_csm_cuda_graph_env() -> bool {
1493        Self::env_flag("XLOG_USE_CSM_CUDA_GRAPH") || Self::env_flag("XLOG_USE_CUDA_GRAPHS")
1494    }
1495
1496    /// Test/diagnostic-only telemetry: number of times the recorded
1497    /// hash-join dispatch routed through a CSM (count-scan-materialize)
1498    /// method since this provider was created. Increments once per
1499    /// dispatched call across all four CSM methods (Inner / LeftOuter,
1500    /// non-indexed / indexed). Used by `test_csm_env_dispatch` to
1501    /// prove dispatch selection.
1502    ///
1503    /// **Not part of any public stability guarantee.** Hidden from
1504    /// rustdoc with `#[doc(hidden)]` so it does not appear in
1505    /// generated API docs; the symbol remains callable from
1506    /// integration tests within this crate but production callers
1507    /// must not depend on it. May be renamed, gated behind a cargo
1508    /// feature, or withdrawn in any release without notice.
1509    #[doc(hidden)]
1510    pub fn csm_invocations(&self) -> u64 {
1511        self.csm_invocations.load(Ordering::Relaxed)
1512    }
1513
1514    #[doc(hidden)]
1515    pub fn csm_cuda_graph_captures(&self) -> u64 {
1516        self.csm_cuda_graph_captures.load(Ordering::Relaxed)
1517    }
1518
1519    #[doc(hidden)]
1520    pub fn csm_cuda_graph_launches(&self) -> u64 {
1521        self.csm_cuda_graph_launches.load(Ordering::Relaxed)
1522    }
1523
1524    #[doc(hidden)]
1525    pub fn csm_cuda_graph_fallbacks(&self) -> u64 {
1526        self.csm_cuda_graph_fallbacks.load(Ordering::Relaxed)
1527    }
1528
1529    #[doc(hidden)]
1530    pub fn csm_cuda_graph_cache_hits(&self) -> u64 {
1531        self.csm_cuda_graph_cache_hits.load(Ordering::Relaxed)
1532    }
1533
1534    #[doc(hidden)]
1535    pub fn small_full_row_sort_invocations(&self) -> u64 {
1536        self.small_full_row_sort_invocations.load(Ordering::Relaxed)
1537    }
1538
1539    /// Lazily acquire one non-default launch stream from the
1540    /// runtime's [`crate::device_runtime::StreamPool`] for
1541    /// recorded-operator dispatch, and cache it for this
1542    /// provider's lifetime. Shared across all env-gated
1543    /// recorded paths (filter, sort, dedup, GroupBy,
1544    /// hash-join) — a single stream is sufficient because the
1545    /// recorder serializes work on it; multiple operations
1546    /// chain naturally through commit-order events.
1547    ///
1548    /// Returns `None` when:
1549    ///   * the manager has no runtime attached
1550    ///     (`memory.runtime() == None`), or
1551    ///   * the stream pool is at capacity and `acquire` fails.
1552    ///
1553    /// On a lost race during first init the loser leaks one
1554    /// stream (the pool keeps it alive); both winners cache
1555    /// the same `StreamId`. Acceptable cost — practical pool
1556    /// sizes are large compared to the number of providers
1557    /// per process.
1558    pub(crate) fn recorded_op_stream_or_init(&self) -> Option<crate::device_runtime::StreamId> {
1559        if let Some(s) = self.recorded_op_stream.get() {
1560            return Some(*s);
1561        }
1562        let runtime = self.memory.runtime()?;
1563        let stream = runtime.stream_pool().acquire().ok()?;
1564        let _ = self.recorded_op_stream.set(stream);
1565        self.recorded_op_stream.get().copied()
1566    }
1567
1568    /// Take the per-phase WCOJ triangle dispatch timings recorded
1569    /// by the most recent `wcoj_triangle_*_recorded` call. Reading
1570    /// clears the slot — designed for one-shot consumption by the
1571    /// `wcoj_phase_report` binary in xlog-integration. Returns
1572    /// `None` if no triangle dispatch has fired since the last
1573    /// read (or since construction).
1574    ///
1575    /// Compiled in only with the `wcoj-phase-timing` Cargo
1576    /// feature; production builds have no such method.
1577    #[cfg(feature = "wcoj-phase-timing")]
1578    pub fn take_wcoj_triangle_phase_timing(
1579        &self,
1580    ) -> Option<crate::wcoj_phase_timing::WcojTrianglePhaseTiming> {
1581        self.last_triangle_phase_timing
1582            .lock()
1583            .ok()
1584            .and_then(|mut g| g.take())
1585    }
1586
1587    /// Internal: store the phase timings produced by a triangle
1588    /// dispatch. Overwrites any prior unread slot — the report
1589    /// binary is expected to read after every `execute_plan`.
1590    #[cfg(feature = "wcoj-phase-timing")]
1591    #[allow(dead_code)]
1592    pub(crate) fn put_wcoj_triangle_phase_timing(
1593        &self,
1594        timing: crate::wcoj_phase_timing::WcojTrianglePhaseTiming,
1595    ) {
1596        if let Ok(mut g) = self.last_triangle_phase_timing.lock() {
1597            *g = Some(timing);
1598        }
1599    }
1600
1601    /// Number of times `wcoj_layout_*_recorded` short-circuited
1602    /// to the fast-path (recorded clone) instead of running
1603    /// `dedup_full_row_recorded`. Increments by 1 per
1604    /// fast-path hit (3 hits per dispatch when all inputs are
1605    /// already sorted+unique). Used by tests + the phase
1606    /// report to confirm the fast-path fired.
1607    pub fn wcoj_layout_fast_path_hit_count(&self) -> u64 {
1608        self.wcoj_layout_fast_path_hit_count.load(Ordering::Relaxed)
1609    }
1610
1611    /// Histogram-guided block-slice triangle WCOJ test/diagnostic counter:
1612    /// successful dispatches that routed through the provider entry.
1613    pub fn wcoj_triangle_hg_dispatch_count(&self) -> u64 {
1614        self.wcoj_triangle_hg_dispatch_count.load(Ordering::Relaxed)
1615    }
1616
1617    /// Reset the fast-path hit counter to 0. Tests use this to
1618    /// scope counter assertions to a single dispatch.
1619    pub fn reset_wcoj_layout_fast_path_hit_count(&self) {
1620        self.wcoj_layout_fast_path_hit_count
1621            .store(0, Ordering::Relaxed);
1622    }
1623
1624    /// Number of calls to `wcoj_layout_sort_*_recorded` since the
1625    /// last reset. Diagnostic-only; used by dispatch-plan certification.
1626    pub fn wcoj_layout_sort_invocation_count(&self) -> u64 {
1627        self.wcoj_layout_sort_invocation_count
1628            .load(Ordering::Relaxed)
1629    }
1630
1631    /// Reset the WCOJ layout-sort invocation counter to 0.
1632    pub fn reset_wcoj_layout_sort_invocation_count(&self) {
1633        self.wcoj_layout_sort_invocation_count
1634            .store(0, Ordering::Relaxed);
1635    }
1636
1637    /// Number of K-clique leader-edge metadata builds since the
1638    /// last reset.
1639    pub fn kclique_metadata_build_count(&self) -> u64 {
1640        self.kclique_metadata_build_count.load(Ordering::Relaxed)
1641    }
1642
1643    /// Cumulative nanoseconds spent building K-clique leader-edge
1644    /// metadata since the last reset.
1645    pub fn kclique_metadata_build_nanos(&self) -> u64 {
1646        self.kclique_metadata_build_nanos.load(Ordering::Relaxed)
1647    }
1648
1649    /// Reset K-clique metadata build diagnostics.
1650    pub fn reset_kclique_metadata_build_metrics(&self) {
1651        self.kclique_metadata_build_count
1652            .store(0, Ordering::Relaxed);
1653        self.kclique_metadata_build_nanos
1654            .store(0, Ordering::Relaxed);
1655    }
1656
1657    /// Internal: increment the fast-path counter. Called by
1658    /// `wcoj_layout_*_recorded` after a successful fast-path
1659    /// branch. Not part of any public stability guarantee.
1660    pub(crate) fn record_wcoj_layout_fast_path_hit(&self) {
1661        self.wcoj_layout_fast_path_hit_count
1662            .fetch_add(1, Ordering::Relaxed);
1663    }
1664
1665    /// Internal: increment the generic WCOJ layout-sort counter.
1666    pub(crate) fn record_wcoj_layout_sort_invocation(&self) {
1667        self.wcoj_layout_sort_invocation_count
1668            .fetch_add(1, Ordering::Relaxed);
1669    }
1670
1671    /// Internal: record a K-clique leader-edge metadata build.
1672    pub(crate) fn record_kclique_metadata_build_nanos(&self, nanos: u128) {
1673        self.kclique_metadata_build_count
1674            .fetch_add(1, Ordering::Relaxed);
1675        let nanos = u64::try_from(nanos).unwrap_or(u64::MAX);
1676        self.kclique_metadata_build_nanos
1677            .fetch_add(nanos, Ordering::Relaxed);
1678    }
1679
1680    /// Runtime hook: record a successful histogram-guided block-slice triangle
1681    /// dispatch.
1682    #[doc(hidden)]
1683    pub fn record_wcoj_triangle_hg_dispatch(&self) {
1684        self.wcoj_triangle_hg_dispatch_count
1685            .fetch_add(1, Ordering::Relaxed);
1686    }
1687
1688    /// Get the CUDA device
1689    pub fn device(&self) -> &Arc<CudaDevice> {
1690        &self.device
1691    }
1692
1693    /// Get the GPU memory manager
1694    pub fn memory(&self) -> &Arc<GpuMemoryManager> {
1695        &self.memory
1696    }
1697
1698    /// Get PTX load profiling data (only populated when XLOG_WARMUP_PROFILE=1).
1699    pub fn ptx_load_profile(&self) -> Option<&PtxLoadProfile> {
1700        self.ptx_load_profile.as_ref()
1701    }
1702
1703    /// Reset tracked host transfer statistics.
1704    pub fn reset_host_transfer_stats(&self) {
1705        self.transfer_tracker.reset();
1706    }
1707
1708    /// Snapshot tracked host transfer statistics.
1709    pub fn host_transfer_stats(&self) -> HostTransferStats {
1710        self.transfer_tracker.snapshot()
1711    }
1712
1713    /// Snapshot launch-parameter H2D uploads tracked separately from
1714    /// `host_transfer_stats`.
1715    pub fn host_launch_metadata_transfer_stats(&self) -> HostLaunchMetadataTransferStats {
1716        self.transfer_tracker.launch_metadata_snapshot()
1717    }
1718
1719    /// Read the column-level D2H transfer counter.
1720    ///
1721    /// This counter increments once per `download_column_*` call, enabling
1722    /// callers (e.g. the ILP trainer) to assert that no column downloads
1723    /// occurred during a performance-critical section.
1724    pub fn d2h_transfer_count(&self) -> u64 {
1725        self.d2h_transfer_count.load(Ordering::Relaxed)
1726    }
1727
1728    /// Reset the column-level D2H transfer counter to zero.
1729    pub fn reset_d2h_transfer_count(&self) {
1730        self.d2h_transfer_count.store(0, Ordering::Relaxed);
1731    }
1732
1733    /// Count of untracked control-plane metadata D2H reads
1734    /// (`dtoh_scalar_untracked` + `dtoh_small_metadata_untracked`).
1735    pub fn untracked_metadata_dtoh_count(&self) -> u64 {
1736        self.untracked_metadata_dtoh_count.load(Ordering::Relaxed)
1737    }
1738
1739    /// Reset the untracked metadata D2H read counter to zero.
1740    pub fn reset_untracked_metadata_dtoh_count(&self) {
1741        self.untracked_metadata_dtoh_count
1742            .store(0, Ordering::Relaxed);
1743    }
1744
1745    /// Enable the strict deterministic-Datalog D2H gate.
1746    ///
1747    /// While enabled, any data-plane device-to-host transfer (column downloads
1748    /// via `download_column` / `download_column_untracked`, and any internal
1749    /// transfer routed through `dtoh_sync_copy_into_tracked`) increments
1750    /// [`CudaKernelProvider::deterministic_d2h_violation_count`] and returns
1751    /// `XlogError::Execution` from the originating call.
1752    ///
1753    /// Metadata reads via [`CudaKernelProvider::dtoh_scalar_untracked`] are
1754    /// allowed and never trip the gate.
1755    ///
1756    /// Default is `false`; the runtime opts in via
1757    /// `RuntimeConfig::strict_deterministic_d2h`. v0.5.5 ships the gate
1758    /// opt-in only — known-violating relational paths (set difference,
1759    /// join count/materialize) are scheduled for replacement before the
1760    /// default flips.
1761    pub fn enable_strict_deterministic_d2h(&self) {
1762        self.strict_deterministic_d2h.store(true, Ordering::Relaxed);
1763    }
1764
1765    /// Disable the strict deterministic-Datalog D2H gate.
1766    pub fn disable_strict_deterministic_d2h(&self) {
1767        self.strict_deterministic_d2h
1768            .store(false, Ordering::Relaxed);
1769    }
1770
1771    /// Returns whether the strict deterministic-Datalog D2H gate is enabled.
1772    pub fn strict_deterministic_d2h_enabled(&self) -> bool {
1773        self.strict_deterministic_d2h.load(Ordering::Relaxed)
1774    }
1775
1776    /// Cumulative deterministic-D2H gate violations since the last reset.
1777    pub fn deterministic_d2h_violation_count(&self) -> u64 {
1778        self.deterministic_d2h_violations.load(Ordering::Relaxed)
1779    }
1780
1781    /// Reset the deterministic-D2H violation counter to zero.
1782    pub fn reset_deterministic_d2h_violations(&self) {
1783        self.deterministic_d2h_violations
1784            .store(0, Ordering::Relaxed);
1785    }
1786
1787    /// Chokepoint for the deterministic-D2H gate.
1788    ///
1789    /// If the gate is enabled, increments the violation counter and returns
1790    /// `XlogError::Execution` naming the offending operation and byte count.
1791    /// If the gate is disabled, returns `Ok(())` cheaply.
1792    pub(crate) fn check_deterministic_d2h(&self, op: &'static str, bytes: u64) -> Result<()> {
1793        if self.strict_deterministic_d2h.load(Ordering::Relaxed) {
1794            self.deterministic_d2h_violations
1795                .fetch_add(1, Ordering::Relaxed);
1796            return Err(XlogError::Execution(format!(
1797                "deterministic D2H gate: {} attempted to copy {} bytes from device to host",
1798                op, bytes
1799            )));
1800        }
1801        Ok(())
1802    }
1803
1804    fn dtoh_sync_copy_into_tracked<T: DeviceRepr, Src: DevicePtr<T>>(
1805        &self,
1806        src: &Src,
1807        dst: &mut [T],
1808    ) -> Result<()> {
1809        let bytes = std::mem::size_of::<T>()
1810            .checked_mul(dst.len())
1811            .ok_or_else(|| XlogError::Kernel("dtoh size overflow".to_string()))?;
1812        self.check_deterministic_d2h("dtoh_sync_copy_into_tracked", bytes as u64)?;
1813        self.transfer_tracker.record_dtoh(bytes as u64);
1814        self.device
1815            .inner()
1816            .dtoh_sync_copy_into(src, dst)
1817            .map_err(|e| XlogError::Kernel(format!("Failed to copy from device: {}", e)))
1818    }
1819
1820    /// Hard cap (in bytes) for [`Self::dtoh_small_metadata_untracked`].
1821    /// Set deliberately small (4 KB) so the helper cannot become a
1822    /// general-purpose vector D2H escape hatch — it's strictly for
1823    /// classifier histograms and similar small metadata round-trips.
1824    pub const DTOH_SMALL_METADATA_MAX_BYTES: usize = 4096;
1825
1826    /// Read a small metadata vector (≤ [`Self::DTOH_SMALL_METADATA_MAX_BYTES`])
1827    /// from device to host WITHOUT updating the D2H transfer tracker.
1828    ///
1829    /// Sibling of [`Self::dtoh_scalar_untracked`] for callers that need
1830    /// a few bucket counts (the WCOJ skew classifier reads a 3 × 64 ×
1831    /// `u32` = 768-byte histogram in one go) instead of `count` separate
1832    /// scalar reads. Like `dtoh_scalar_untracked`, this method is
1833    /// whitelisted by the strict deterministic-D2H gate
1834    /// ([`Self::enable_strict_deterministic_d2h`]) — it does NOT trip
1835    /// the gate, on purpose, because metadata reads are part of the
1836    /// determinism contract (just like a scalar `total` after a scan).
1837    ///
1838    /// # Hard contract — DO NOT WIDEN THE CAP
1839    /// The 4 KB cap is the contract. If a caller wants a larger D2H,
1840    /// it's a data-plane transfer and must go through the tracked
1841    /// `download_column*` path. Widening this cap turns the helper
1842    /// into a backdoor for tracked-bypass column reads, which would
1843    /// silently invalidate the strict deterministic-D2H gate.
1844    ///
1845    /// # Errors
1846    ///   * `XlogError::Kernel` if `count * size_of::<T>()` exceeds
1847    ///     `DTOH_SMALL_METADATA_MAX_BYTES`.
1848    ///   * `XlogError::Kernel` if `count` exceeds the device slice's
1849    ///     length, or if the inner sync copy fails.
1850    pub fn dtoh_small_metadata_untracked<T: DeviceRepr + Default + Copy>(
1851        &self,
1852        src: &crate::memory::TrackedCudaSlice<T>,
1853        count: usize,
1854    ) -> Result<Vec<T>> {
1855        let bytes = count.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
1856            XlogError::Kernel("dtoh_small_metadata_untracked: byte size overflow".to_string())
1857        })?;
1858        if bytes > Self::DTOH_SMALL_METADATA_MAX_BYTES {
1859            return Err(XlogError::Kernel(format!(
1860                "dtoh_small_metadata_untracked: requested {} bytes exceeds metadata cap of {} bytes \
1861                 (this is metadata-only; use download_column* for data-plane transfers)",
1862                bytes,
1863                Self::DTOH_SMALL_METADATA_MAX_BYTES
1864            )));
1865        }
1866        if count > src.len() {
1867            return Err(XlogError::Kernel(format!(
1868                "dtoh_small_metadata_untracked: count={count} > src.len={}",
1869                src.len()
1870            )));
1871        }
1872        if count == 0 {
1873            return Ok(Vec::new());
1874        }
1875        let slice = src.try_slice(0..count).ok_or_else(|| {
1876            XlogError::Kernel(format!(
1877                "dtoh_small_metadata_untracked: try_slice(0..{count}) failed"
1878            ))
1879        })?;
1880        let mut buf: Vec<T> = vec![T::default(); count];
1881        self.untracked_metadata_dtoh_count
1882            .fetch_add(1, Ordering::Relaxed);
1883        self.device
1884            .inner()
1885            .dtoh_sync_copy_into(&slice, &mut buf)
1886            .map_err(|e| {
1887                XlogError::Kernel(format!("dtoh_small_metadata_untracked: copy failed: {}", e))
1888            })?;
1889        Ok(buf)
1890    }
1891
1892    /// Read a single scalar from device to host WITHOUT updating the
1893    /// D2H transfer tracker. Use ONLY for metadata reads (e.g. total_nnz
1894    /// after an exclusive scan), never for data-plane transfers.
1895    ///
1896    /// This makes the "metadata != data-plane" contract explicit and
1897    /// auditable: callers that bypass tracking must call this method
1898    /// (which is grep-able) rather than reaching for device().inner().
1899    pub fn dtoh_scalar_untracked<T: DeviceRepr + Default + Copy>(
1900        &self,
1901        src: &crate::memory::TrackedCudaSlice<T>,
1902        index: usize,
1903    ) -> Result<T> {
1904        if index >= src.len() {
1905            return Err(XlogError::Kernel(format!(
1906                "dtoh_scalar_untracked: index={} >= len={}",
1907                index,
1908                src.len()
1909            )));
1910        }
1911        let slice = src.try_slice(index..index + 1).ok_or_else(|| {
1912            XlogError::Kernel(format!(
1913                "dtoh_scalar_untracked: slice failed at index={}",
1914                index
1915            ))
1916        })?;
1917        let mut buf = [T::default()];
1918        self.untracked_metadata_dtoh_count
1919            .fetch_add(1, Ordering::Relaxed);
1920        self.device
1921            .inner()
1922            .dtoh_sync_copy_into(&slice, &mut buf)
1923            .map_err(|e| XlogError::Kernel(format!("dtoh_scalar_untracked: copy failed: {}", e)))?;
1924        Ok(buf[0])
1925    }
1926
1927    /// Upload host data to device while recording data-plane H2D transfer stats.
1928    pub fn htod_sync_copy_into_tracked<T: DeviceRepr, Dst: cudarc::driver::DevicePtrMut<T>>(
1929        &self,
1930        src: &[T],
1931        dst: &mut Dst,
1932    ) -> Result<()> {
1933        let bytes = std::mem::size_of::<T>()
1934            .checked_mul(src.len())
1935            .ok_or_else(|| XlogError::Kernel("htod size overflow".to_string()))?;
1936        self.transfer_tracker.record_htod(bytes as u64);
1937        self.device
1938            .inner()
1939            .htod_sync_copy_into(src, dst)
1940            .map_err(|e| XlogError::Kernel(format!("Failed to copy to device: {}", e)))
1941    }
1942
1943    /// Allocate a CUDA slice from host data while recording data-plane H2D
1944    /// transfer stats.
1945    pub fn htod_sync_copy_tracked<T: DeviceRepr>(
1946        &self,
1947        src: &[T],
1948    ) -> Result<cudarc::driver::CudaSlice<T>> {
1949        let bytes = std::mem::size_of::<T>()
1950            .checked_mul(src.len())
1951            .ok_or_else(|| XlogError::Kernel("htod size overflow".to_string()))?;
1952        self.transfer_tracker.record_htod(bytes as u64);
1953        self.device
1954            .inner()
1955            .htod_sync_copy(src)
1956            .map_err(|e| XlogError::Kernel(format!("Failed to copy to device: {}", e)))
1957    }
1958
1959    /// Upload bounded launch metadata from host to device while recording it in
1960    /// the launch-metadata subcounter.
1961    pub fn htod_launch_metadata_sync_copy_into<
1962        T: DeviceRepr,
1963        Dst: cudarc::driver::DevicePtrMut<T>,
1964    >(
1965        &self,
1966        src: &[T],
1967        dst: &mut Dst,
1968    ) -> Result<()> {
1969        let bytes = std::mem::size_of::<T>()
1970            .checked_mul(src.len())
1971            .ok_or_else(|| XlogError::Kernel("launch metadata htod size overflow".to_string()))?;
1972        self.transfer_tracker
1973            .record_htod_launch_metadata(bytes as u64);
1974        self.device
1975            .inner()
1976            .htod_sync_copy_into(src, dst)
1977            .map_err(|e| {
1978                XlogError::Kernel(format!("Failed to copy launch metadata to device: {}", e))
1979            })
1980    }
1981
1982    /// Upload one launch-metadata scalar to device on a caller-owned stream
1983    /// while recording the transfer in the launch-metadata H2D counters.
1984    pub(crate) fn htod_launch_metadata_async_copy_one<T: DeviceRepr>(
1985        &self,
1986        src: &T,
1987        dst: &TrackedCudaSlice<T>,
1988        stream: &CudaStream,
1989        context: &str,
1990    ) -> Result<()> {
1991        let bytes = std::mem::size_of::<T>();
1992        self.transfer_tracker
1993            .record_htod_launch_metadata(bytes as u64);
1994        unsafe {
1995            let res = cudarc::driver::sys::cuMemcpyHtoDAsync_v2(
1996                *dst.device_ptr(),
1997                src as *const T as *const c_void,
1998                bytes,
1999                stream.cu_stream(),
2000            );
2001            if res != cudarc::driver::sys::cudaError_enum::CUDA_SUCCESS {
2002                return Err(XlogError::Kernel(format!(
2003                    "{context}: launch metadata H2D failed: {res:?}"
2004                )));
2005            }
2006        }
2007        Ok(())
2008    }
2009
2010    /// Compute exclusive prefix sum of u8 mask, returns (prefix_sum_vec, total_count)
2011    ///
2012    /// This is useful for compaction operations where we need to know:
2013    /// 1. The output position for each input element (prefix sum)
2014    /// 2. The total number of elements that pass the mask (count)
2015    ///
2016    /// # Arguments
2017    /// * `mask` - A slice of u8 values (0 or non-zero)
2018    ///
2019    /// # Returns
2020    /// A tuple of:
2021    /// - `Vec<u32>` containing the exclusive prefix sum
2022    /// - `u32` containing the total count of non-zero mask elements
2023    ///
2024    /// # Example
2025    /// ```ignore
2026    /// let mask = vec![1u8, 0, 1, 1, 0, 1];
2027    /// let (prefix_sum, count) = provider.prefix_sum_mask(&mask)?;
2028    /// // prefix_sum = [0, 1, 1, 2, 3, 3]
2029    /// // count = 4
2030    /// ```
2031    ///
2032    /// # Note
2033    /// For small inputs (<=256 elements), a CPU scan is used for efficiency.
2034    /// For larger inputs, a three-phase multi-block GPU scan is used.
2035    ///
2036    /// # Errors
2037    /// Returns `XlogError::Kernel` if kernel execution fails
2038    pub fn exclusive_scan_u32_inplace(
2039        &self,
2040        data: &mut crate::memory::TrackedCudaSlice<u32>,
2041        n: u32,
2042    ) -> Result<()> {
2043        if n as usize > data.len() {
2044            return Err(XlogError::Kernel(format!(
2045                "exclusive_scan_u32_inplace: n={} exceeds slice len={}",
2046                n,
2047                data.len()
2048            )));
2049        }
2050        self.multiblock_scan_u32_inplace(data, n)
2051    }
2052
2053    fn multiblock_scan_u32_inplace(
2054        &self,
2055        data: &mut crate::memory::TrackedCudaSlice<u32>,
2056        n: u32,
2057    ) -> Result<()> {
2058        if n == 0 {
2059            return Ok(());
2060        }
2061
2062        let device = self.device.inner();
2063        let block_size = 256u32;
2064
2065        if n <= block_size {
2066            let phase2_fn = device
2067                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
2068                .ok_or_else(|| {
2069                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
2070                })?;
2071
2072            // SAFETY: multiblock_scan_phase2(uint32_t* block_sums, uint32_t num_blocks)
2073            unsafe {
2074                phase2_fn.clone().launch(
2075                    LaunchConfig {
2076                        grid_dim: (1, 1, 1),
2077                        block_dim: (block_size, 1, 1),
2078                        shared_mem_bytes: 0,
2079                    },
2080                    (&mut *data, n),
2081                )
2082            }
2083            .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase2 failed: {}", e)))?;
2084
2085            self.device.synchronize()?;
2086            return Ok(());
2087        }
2088
2089        let num_blocks = n.div_ceil(block_size);
2090        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
2091
2092        let phase1_u32_fn = device
2093            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
2094            .ok_or_else(|| {
2095                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
2096            })?;
2097
2098        // SAFETY: multiblock_scan_u32_phase1(uint32_t* data, uint32_t* block_sums, uint32_t n)
2099        unsafe {
2100            phase1_u32_fn.clone().launch(
2101                LaunchConfig {
2102                    grid_dim: (num_blocks, 1, 1),
2103                    block_dim: (block_size, 1, 1),
2104                    shared_mem_bytes: 0,
2105                },
2106                (&mut *data, &mut block_sums, n),
2107            )
2108        }
2109        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_u32_phase1 failed: {}", e)))?;
2110        self.device.synchronize()?;
2111
2112        if num_blocks > 1 {
2113            self.multiblock_scan_u32_inplace(&mut block_sums, num_blocks)?;
2114        }
2115
2116        let phase3_fn = device
2117            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
2118            .ok_or_else(|| {
2119                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
2120            })?;
2121
2122        // SAFETY: multiblock_scan_phase3(uint32_t* prefix_sum, const uint32_t* block_offsets, uint32_t n)
2123        unsafe {
2124            phase3_fn.clone().launch(
2125                LaunchConfig {
2126                    grid_dim: (num_blocks, 1, 1),
2127                    block_dim: (block_size, 1, 1),
2128                    shared_mem_bytes: 0,
2129                },
2130                (&mut *data, &block_sums, n),
2131            )
2132        }
2133        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;
2134
2135        self.device.synchronize()?;
2136        Ok(())
2137    }
2138
2139    /// Stream-aware variant of [`Self::multiblock_scan_u32_inplace`].
2140    ///
2141    /// Runs every kernel of the recursive scan on `cu_stream`
2142    /// (no `device.synchronize()`), and records each intermediate
2143    /// `block_sums` allocation against the runtime so that when
2144    /// the helper returns and the local drops, the runtime's
2145    /// deallocate can queue `cuStreamWaitEvent(alloc_stream,
2146    /// recorded_event)` BEFORE `cuMemFreeAsync` — the same
2147    /// cross-stream lifetime safety the LaunchRecorder gives
2148    /// caller-provided buffers.
2149    ///
2150    /// `data` is not recorded here: the caller already records
2151    /// its own write of `data` against the same launch_stream
2152    /// (typically via `LaunchRecorder::write` BEFORE preflight).
2153    pub(crate) fn multiblock_scan_u32_inplace_on_stream(
2154        &self,
2155        data: &mut crate::memory::TrackedCudaSlice<u32>,
2156        n: u32,
2157        cu_stream: &cudarc::driver::CudaStream,
2158        launch_stream: crate::device_runtime::StreamId,
2159        runtime: &crate::device_runtime::XlogDeviceRuntime,
2160    ) -> Result<()> {
2161        if n == 0 {
2162            return Ok(());
2163        }
2164        let device = self.device.inner();
2165        let block_size = 256u32;
2166
2167        if n <= block_size {
2168            let phase2_fn = device
2169                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
2170                .ok_or_else(|| {
2171                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
2172                })?;
2173            // SAFETY: kernel signature matches; data is mutated in place.
2174            unsafe {
2175                phase2_fn.clone().launch_on_stream(
2176                    cu_stream,
2177                    LaunchConfig {
2178                        grid_dim: (1, 1, 1),
2179                        block_dim: (block_size, 1, 1),
2180                        shared_mem_bytes: 0,
2181                    },
2182                    (&mut *data, n),
2183                )
2184            }
2185            .map_err(|e| {
2186                XlogError::Kernel(format!("multiblock_scan_phase2 (on_stream) failed: {}", e))
2187            })?;
2188            return Ok(());
2189        }
2190
2191        let num_blocks = n.div_ceil(block_size);
2192        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
2193        // Fence alloc-ready → launch_stream for block_sums
2194        // before phase1 kernel writes it. The alloc was queued
2195        // on the manager's default stream; without this wait,
2196        // a launch_stream-queued kernel can begin before
2197        // cuMemAllocAsync completes and read pool-recycled
2198        // bytes when the streams differ.
2199        runtime
2200            .prepare_first_use(
2201                &block_sums,
2202                launch_stream,
2203                crate::device_runtime::Access::Write,
2204            )
2205            .map_err(|e| {
2206                XlogError::Kernel(format!(
2207                    "multiblock_scan_u32_inplace_on_stream: prepare block_sums failed: {}",
2208                    e
2209                ))
2210            })?;
2211
2212        let phase1_u32_fn = device
2213            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
2214            .ok_or_else(|| {
2215                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
2216            })?;
2217        // SAFETY: kernel signature matches.
2218        unsafe {
2219            phase1_u32_fn.clone().launch_on_stream(
2220                cu_stream,
2221                LaunchConfig {
2222                    grid_dim: (num_blocks, 1, 1),
2223                    block_dim: (block_size, 1, 1),
2224                    shared_mem_bytes: 0,
2225                },
2226                (&mut *data, &mut block_sums, n),
2227            )
2228        }
2229        .map_err(|e| {
2230            XlogError::Kernel(format!(
2231                "multiblock_scan_u32_phase1 (on_stream) failed: {}",
2232                e
2233            ))
2234        })?;
2235
2236        if num_blocks > 1 {
2237            self.multiblock_scan_u32_inplace_on_stream(
2238                &mut block_sums,
2239                num_blocks,
2240                cu_stream,
2241                launch_stream,
2242                runtime,
2243            )?;
2244        }
2245
2246        let phase3_fn = device
2247            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
2248            .ok_or_else(|| {
2249                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
2250            })?;
2251        // SAFETY: kernel signature matches.
2252        unsafe {
2253            phase3_fn.clone().launch_on_stream(
2254                cu_stream,
2255                LaunchConfig {
2256                    grid_dim: (num_blocks, 1, 1),
2257                    block_dim: (block_size, 1, 1),
2258                    shared_mem_bytes: 0,
2259                },
2260                (&mut *data, &block_sums, n),
2261            )
2262        }
2263        .map_err(|e| {
2264            XlogError::Kernel(format!("multiblock_scan_phase3 (on_stream) failed: {}", e))
2265        })?;
2266
2267        // Record `block_sums` use on `launch_stream` BEFORE it
2268        // drops at end-of-scope. Without this, the runtime's
2269        // deallocate would queue `cuMemFreeAsync` on alloc_stream
2270        // without waiting for the launch_stream chain that's
2271        // still reading/writing block_sums to complete.
2272        if let Some(b) = block_sums.runtime_block() {
2273            runtime
2274                .finish_block_use(
2275                    crate::device_runtime::BlockId::from_block(b),
2276                    launch_stream,
2277                    crate::device_runtime::Access::Write,
2278                )
2279                .map_err(|e| {
2280                    XlogError::Kernel(format!(
2281                        "multiblock_scan_u32_inplace_on_stream: finish_block_use \
2282                         for intermediate block_sums failed: {}",
2283                        e
2284                    ))
2285                })?;
2286        } else {
2287            return Err(XlogError::Kernel(
2288                "multiblock_scan_u32_inplace_on_stream: intermediate block_sums has no \
2289                 runtime block — caller must use a runtime-backed manager"
2290                    .to_string(),
2291            ));
2292        }
2293        Ok(())
2294    }
2295
2296    /// Allocate every recursive `block_sums` buffer needed by
2297    /// [`Self::multiblock_scan_u32_inplace_on_stream_with_scratch`].
2298    pub(crate) fn multiblock_scan_u32_scratch_for_len(
2299        &self,
2300        mut n: u32,
2301    ) -> Result<MultiblockScanScratchU32> {
2302        let block_size = 256u32;
2303        let mut levels = Vec::new();
2304        while n > block_size {
2305            let num_blocks = n.div_ceil(block_size);
2306            levels.push(self.memory.alloc::<u32>(num_blocks as usize)?);
2307            n = num_blocks;
2308        }
2309        Ok(MultiblockScanScratchU32 { levels })
2310    }
2311
2312    /// Stream-aware u32 scan with caller-owned scratch.
2313    ///
2314    /// This is the CUDA Graph compatible counterpart to
2315    /// [`Self::multiblock_scan_u32_inplace_on_stream`]: all scratch buffers are
2316    /// supplied by the caller, so graph capture sees a stable scan topology and
2317    /// stable intermediate addresses.
2318    pub(crate) fn multiblock_scan_u32_inplace_on_stream_with_scratch(
2319        &self,
2320        data: &mut crate::memory::TrackedCudaSlice<u32>,
2321        n: u32,
2322        cu_stream: &cudarc::driver::CudaStream,
2323        scratch: &mut MultiblockScanScratchU32,
2324    ) -> Result<()> {
2325        self.multiblock_scan_u32_inplace_on_stream_with_scratch_levels(
2326            data,
2327            n,
2328            cu_stream,
2329            &mut scratch.levels,
2330        )
2331    }
2332
2333    fn multiblock_scan_u32_inplace_on_stream_with_scratch_levels(
2334        &self,
2335        data: &mut crate::memory::TrackedCudaSlice<u32>,
2336        n: u32,
2337        cu_stream: &cudarc::driver::CudaStream,
2338        scratch_levels: &mut [TrackedCudaSlice<u32>],
2339    ) -> Result<()> {
2340        if n == 0 {
2341            return Ok(());
2342        }
2343        let device = self.device.inner();
2344        let block_size = 256u32;
2345
2346        if n <= block_size {
2347            let phase2_fn = device
2348                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
2349                .ok_or_else(|| {
2350                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
2351                })?;
2352            // SAFETY: kernel signature matches; data is mutated in place.
2353            unsafe {
2354                phase2_fn.clone().launch_on_stream(
2355                    cu_stream,
2356                    LaunchConfig {
2357                        grid_dim: (1, 1, 1),
2358                        block_dim: (block_size, 1, 1),
2359                        shared_mem_bytes: 0,
2360                    },
2361                    (&mut *data, n),
2362                )
2363            }
2364            .map_err(|e| {
2365                XlogError::Kernel(format!(
2366                    "multiblock_scan_phase2 (graph scratch) failed: {}",
2367                    e
2368                ))
2369            })?;
2370            return Ok(());
2371        }
2372
2373        let num_blocks = n.div_ceil(block_size);
2374        let (block_sums, rest) = scratch_levels.split_first_mut().ok_or_else(|| {
2375            XlogError::Kernel(format!(
2376                "multiblock_scan_u32_inplace_on_stream_with_scratch: missing scratch level \
2377                 for n={n}, num_blocks={num_blocks}"
2378            ))
2379        })?;
2380        if block_sums.len() < num_blocks as usize {
2381            return Err(XlogError::Kernel(format!(
2382                "multiblock_scan_u32_inplace_on_stream_with_scratch: scratch level too small \
2383                 (have {}, need {})",
2384                block_sums.len(),
2385                num_blocks
2386            )));
2387        }
2388
2389        let phase1_u32_fn = device
2390            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
2391            .ok_or_else(|| {
2392                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
2393            })?;
2394        // SAFETY: kernel signature matches.
2395        unsafe {
2396            phase1_u32_fn.clone().launch_on_stream(
2397                cu_stream,
2398                LaunchConfig {
2399                    grid_dim: (num_blocks, 1, 1),
2400                    block_dim: (block_size, 1, 1),
2401                    shared_mem_bytes: 0,
2402                },
2403                (&mut *data, &mut *block_sums, n),
2404            )
2405        }
2406        .map_err(|e| {
2407            XlogError::Kernel(format!(
2408                "multiblock_scan_u32_phase1 (graph scratch) failed: {}",
2409                e
2410            ))
2411        })?;
2412
2413        if num_blocks > 1 {
2414            self.multiblock_scan_u32_inplace_on_stream_with_scratch_levels(
2415                block_sums, num_blocks, cu_stream, rest,
2416            )?;
2417        }
2418
2419        let phase3_fn = device
2420            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
2421            .ok_or_else(|| {
2422                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
2423            })?;
2424        // SAFETY: kernel signature matches.
2425        unsafe {
2426            phase3_fn.clone().launch_on_stream(
2427                cu_stream,
2428                LaunchConfig {
2429                    grid_dim: (num_blocks, 1, 1),
2430                    block_dim: (block_size, 1, 1),
2431                    shared_mem_bytes: 0,
2432                },
2433                (&mut *data, &*block_sums, n),
2434            )
2435        }
2436        .map_err(|e| {
2437            XlogError::Kernel(format!(
2438                "multiblock_scan_phase3 (graph scratch) failed: {}",
2439                e
2440            ))
2441        })?;
2442        Ok(())
2443    }
2444
2445    /// Stream-aware view-inplace variant of
2446    /// [`Self::multiblock_scan_u32_view_inplace`]. Same shape
2447    /// as [`Self::multiblock_scan_u32_inplace_on_stream`] but
2448    /// over a `CudaViewMut` (used by recorded radix sort
2449    /// digit loops that scan per-digit slices of the histogram
2450    /// in place). Records intermediate `block_sums` against
2451    /// the runtime before they drop at end-of-scope.
2452    pub(crate) fn multiblock_scan_u32_view_inplace_on_stream(
2453        &self,
2454        data: &mut CudaViewMut<'_, u32>,
2455        n: u32,
2456        cu_stream: &cudarc::driver::CudaStream,
2457        launch_stream: crate::device_runtime::StreamId,
2458        runtime: &crate::device_runtime::XlogDeviceRuntime,
2459    ) -> Result<()> {
2460        if n == 0 {
2461            return Ok(());
2462        }
2463        let device = self.device.inner();
2464        let block_size = 256u32;
2465
2466        if n <= block_size {
2467            let phase2_fn = device
2468                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
2469                .ok_or_else(|| {
2470                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
2471                })?;
2472            // SAFETY: phase2 kernel signature.
2473            unsafe {
2474                phase2_fn.clone().launch_on_stream(
2475                    cu_stream,
2476                    LaunchConfig {
2477                        grid_dim: (1, 1, 1),
2478                        block_dim: (block_size, 1, 1),
2479                        shared_mem_bytes: 0,
2480                    },
2481                    (data, n),
2482                )
2483            }
2484            .map_err(|e| {
2485                XlogError::Kernel(format!(
2486                    "multiblock_scan_phase2 (view on_stream) failed: {}",
2487                    e
2488                ))
2489            })?;
2490            return Ok(());
2491        }
2492
2493        let num_blocks = n.div_ceil(block_size);
2494        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
2495        // Fence alloc-ready → launch_stream for block_sums
2496        // before phase1 kernel writes it. See the inplace
2497        // variant for the full rationale.
2498        runtime
2499            .prepare_first_use(
2500                &block_sums,
2501                launch_stream,
2502                crate::device_runtime::Access::Write,
2503            )
2504            .map_err(|e| {
2505                XlogError::Kernel(format!(
2506                    "multiblock_scan_u32_view_inplace_on_stream: prepare block_sums failed: {}",
2507                    e
2508                ))
2509            })?;
2510
2511        let phase1_u32_fn = device
2512            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
2513            .ok_or_else(|| {
2514                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
2515            })?;
2516        // SAFETY: phase1 kernel signature.
2517        unsafe {
2518            phase1_u32_fn.clone().launch_on_stream(
2519                cu_stream,
2520                LaunchConfig {
2521                    grid_dim: (num_blocks, 1, 1),
2522                    block_dim: (block_size, 1, 1),
2523                    shared_mem_bytes: 0,
2524                },
2525                (&mut *data, &mut block_sums, n),
2526            )
2527        }
2528        .map_err(|e| {
2529            XlogError::Kernel(format!(
2530                "multiblock_scan_u32_phase1 (view on_stream) failed: {}",
2531                e
2532            ))
2533        })?;
2534
2535        if num_blocks > 1 {
2536            self.multiblock_scan_u32_inplace_on_stream(
2537                &mut block_sums,
2538                num_blocks,
2539                cu_stream,
2540                launch_stream,
2541                runtime,
2542            )?;
2543        }
2544
2545        let phase3_fn = device
2546            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
2547            .ok_or_else(|| {
2548                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
2549            })?;
2550        // SAFETY: phase3 kernel signature.
2551        unsafe {
2552            phase3_fn.clone().launch_on_stream(
2553                cu_stream,
2554                LaunchConfig {
2555                    grid_dim: (num_blocks, 1, 1),
2556                    block_dim: (block_size, 1, 1),
2557                    shared_mem_bytes: 0,
2558                },
2559                (&mut *data, &block_sums, n),
2560            )
2561        }
2562        .map_err(|e| {
2563            XlogError::Kernel(format!(
2564                "multiblock_scan_phase3 (view on_stream) failed: {}",
2565                e
2566            ))
2567        })?;
2568
2569        // Record block_sums use before end-of-scope drop.
2570        if let Some(b) = block_sums.runtime_block() {
2571            runtime
2572                .finish_block_use(
2573                    crate::device_runtime::BlockId::from_block(b),
2574                    launch_stream,
2575                    crate::device_runtime::Access::Write,
2576                )
2577                .map_err(|e| {
2578                    XlogError::Kernel(format!(
2579                        "multiblock_scan_u32_view_inplace_on_stream: finish_block_use \
2580                     for intermediate block_sums failed: {}",
2581                        e
2582                    ))
2583                })?;
2584        } else {
2585            return Err(XlogError::Kernel(
2586                "multiblock_scan_u32_view_inplace_on_stream: intermediate block_sums has no \
2587                 runtime block — caller must use a runtime-backed manager"
2588                    .to_string(),
2589            ));
2590        }
2591        Ok(())
2592    }
2593
2594    fn multiblock_scan_u32_view_inplace(
2595        &self,
2596        data: &mut CudaViewMut<'_, u32>,
2597        n: u32,
2598    ) -> Result<()> {
2599        if n == 0 {
2600            return Ok(());
2601        }
2602
2603        let device = self.device.inner();
2604        let block_size = 256u32;
2605
2606        if n <= block_size {
2607            let phase2_fn = device
2608                .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE2)
2609                .ok_or_else(|| {
2610                    XlogError::Kernel("Failed to get multiblock_scan_phase2 kernel".to_string())
2611                })?;
2612
2613            // SAFETY: multiblock_scan_phase2(uint32_t* block_sums, uint32_t num_blocks)
2614            unsafe {
2615                phase2_fn.clone().launch(
2616                    LaunchConfig {
2617                        grid_dim: (1, 1, 1),
2618                        block_dim: (block_size, 1, 1),
2619                        shared_mem_bytes: 0,
2620                    },
2621                    (data, n),
2622                )
2623            }
2624            .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase2 failed: {}", e)))?;
2625
2626            self.device.synchronize()?;
2627            return Ok(());
2628        }
2629
2630        let num_blocks = n.div_ceil(block_size);
2631        let mut block_sums = self.memory.alloc::<u32>(num_blocks as usize)?;
2632
2633        let phase1_u32_fn = device
2634            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_U32_PHASE1)
2635            .ok_or_else(|| {
2636                XlogError::Kernel("Failed to get multiblock_scan_u32_phase1 kernel".to_string())
2637            })?;
2638
2639        // SAFETY: multiblock_scan_u32_phase1(uint32_t* data, uint32_t* block_sums, uint32_t n)
2640        unsafe {
2641            phase1_u32_fn.clone().launch(
2642                LaunchConfig {
2643                    grid_dim: (num_blocks, 1, 1),
2644                    block_dim: (block_size, 1, 1),
2645                    shared_mem_bytes: 0,
2646                },
2647                (&mut *data, &mut block_sums, n),
2648            )
2649        }
2650        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_u32_phase1 failed: {}", e)))?;
2651        self.device.synchronize()?;
2652
2653        if num_blocks > 1 {
2654            self.multiblock_scan_u32_inplace(&mut block_sums, num_blocks)?;
2655        }
2656
2657        let phase3_fn = device
2658            .get_func(SCAN_MODULE, scan_kernels::MULTIBLOCK_SCAN_PHASE3)
2659            .ok_or_else(|| {
2660                XlogError::Kernel("Failed to get multiblock_scan_phase3 kernel".to_string())
2661            })?;
2662
2663        // SAFETY: multiblock_scan_phase3(uint32_t* prefix_sum, const uint32_t* block_offsets, uint32_t n)
2664        unsafe {
2665            phase3_fn.clone().launch(
2666                LaunchConfig {
2667                    grid_dim: (num_blocks, 1, 1),
2668                    block_dim: (block_size, 1, 1),
2669                    shared_mem_bytes: 0,
2670                },
2671                (&mut *data, &block_sums, n),
2672            )
2673        }
2674        .map_err(|e| XlogError::Kernel(format!("multiblock_scan_phase3 failed: {}", e)))?;
2675
2676        self.device.synchronize()?;
2677        Ok(())
2678    }
2679
2680    // ============== Internal Helper Methods ==============
2681
2682    /// Read a buffer's logical row count, using the host cache when available
2683    /// and falling back to a metadata-only device-to-host read when needed.
2684    pub fn device_row_count(&self, buffer: &CudaBuffer) -> Result<usize> {
2685        if let Some(n) = buffer.cached_row_count() {
2686            return Ok(n as usize);
2687        }
2688        let mut host_rows = [0u32];
2689        self.device
2690            .inner()
2691            .dtoh_sync_copy_into(buffer.num_rows_device(), &mut host_rows)
2692            .map_err(|e| XlogError::Kernel(format!("Failed to read row count: {}", e)))?;
2693        buffer.set_cached_row_count_if_unset(host_rows[0]);
2694        Ok(host_rows[0] as usize)
2695    }
2696
2697    /// Read and validate a buffer's logical row count for outward-facing APIs.
2698    ///
2699    /// This keeps exported/query-visible lengths tied to the device logical row
2700    /// count while still rejecting impossible metadata (`logical_rows > row_cap`).
2701    pub fn validated_logical_row_count(&self, buffer: &CudaBuffer) -> Result<usize> {
2702        let logical_rows = self.device_row_count(buffer)?;
2703        validate_logical_row_count(buffer.num_rows(), logical_rows)
2704    }
2705
2706    fn clone_device_row_count(&self, buffer: &CudaBuffer) -> Result<TrackedCudaSlice<u32>> {
2707        let mut d_num_rows = self.memory.alloc::<u32>(1)?;
2708        self.device
2709            .inner()
2710            .dtod_copy(buffer.num_rows_device(), &mut d_num_rows)
2711            .map_err(|e| XlogError::Kernel(format!("Failed to copy row count: {}", e)))?;
2712        Ok(d_num_rows)
2713    }
2714
2715    fn upload_device_row_count(&self, row_count: u32) -> Result<TrackedCudaSlice<u32>> {
2716        let mut d_num_rows = self.memory.alloc::<u32>(1)?;
2717        self.htod_launch_metadata_sync_copy_into(&[row_count], &mut d_num_rows)
2718            .map_err(|e| XlogError::Kernel(format!("Failed to upload row count: {}", e)))?;
2719        Ok(d_num_rows)
2720    }
2721
2722    fn buffer_from_columns_with_device_count(
2723        &self,
2724        columns: Vec<CudaColumn>,
2725        row_cap: u64,
2726        schema: Schema,
2727        src: &CudaBuffer,
2728    ) -> Result<CudaBuffer> {
2729        let d_num_rows = self.clone_device_row_count(src)?;
2730        Ok(CudaBuffer::from_columns(
2731            columns, row_cap, d_num_rows, schema,
2732        ))
2733    }
2734
2735    fn column_bytes_view<'a>(
2736        &self,
2737        col: &'a CudaColumn,
2738        num_bytes: usize,
2739    ) -> Result<RawCudaView<'a, u8>> {
2740        if col.num_bytes() < num_bytes {
2741            return Err(XlogError::Kernel(format!(
2742                "Column has {} bytes but {} required",
2743                col.num_bytes(),
2744                num_bytes
2745            )));
2746        }
2747        let ptr = *col.device_ptr();
2748        Ok(RawCudaView {
2749            ptr,
2750            len: num_bytes,
2751            stream: col.stream().clone(),
2752            source_block: col.runtime_block(),
2753            _marker: PhantomData,
2754        })
2755    }
2756
2757    fn bytes_as_u32_view<'a>(
2758        &self,
2759        bytes: &'a TrackedCudaSlice<u8>,
2760        num_elements: usize,
2761    ) -> Result<RawCudaView<'a, u32>> {
2762        let required_bytes = num_elements * std::mem::size_of::<u32>();
2763        if bytes.len() < required_bytes {
2764            return Err(XlogError::Kernel(format!(
2765                "Packed keys have {} bytes but {} required for {} u32 elements",
2766                bytes.len(),
2767                required_bytes,
2768                num_elements
2769            )));
2770        }
2771        let ptr = *bytes.device_ptr();
2772        if !(ptr as usize).is_multiple_of(std::mem::align_of::<u32>()) {
2773            return Err(XlogError::Kernel(
2774                "Packed keys device pointer is not u32-aligned".to_string(),
2775            ));
2776        }
2777        Ok(RawCudaView {
2778            ptr,
2779            len: num_elements,
2780            stream: bytes.stream().clone(),
2781            source_block: bytes.runtime_block(),
2782            _marker: PhantomData,
2783        })
2784    }
2785
2786    /// Reinterpret a `CudaBuffer` column as a `u32` slice for kernel access.
2787    fn column_as_u32_view<'a>(
2788        &self,
2789        col: &'a CudaColumn,
2790        num_elements: usize,
2791    ) -> Result<RawCudaView<'a, u32>> {
2792        let required_bytes = num_elements * std::mem::size_of::<u32>();
2793        if col.num_bytes() < required_bytes {
2794            return Err(XlogError::Kernel(format!(
2795                "Column has {} bytes but {} required for {} u32 elements",
2796                col.num_bytes(),
2797                required_bytes,
2798                num_elements
2799            )));
2800        }
2801        let ptr = *col.device_ptr();
2802        if !(ptr as usize).is_multiple_of(std::mem::align_of::<u32>()) {
2803            return Err(XlogError::Kernel(
2804                "Column device pointer is not u32-aligned".to_string(),
2805            ));
2806        }
2807        Ok(RawCudaView {
2808            ptr,
2809            len: num_elements,
2810            stream: col.stream().clone(),
2811            source_block: col.runtime_block(),
2812            _marker: PhantomData,
2813        })
2814    }
2815
2816    fn column_as_u64_view<'a>(
2817        &self,
2818        col: &'a CudaColumn,
2819        num_elements: usize,
2820    ) -> Result<RawCudaView<'a, u64>> {
2821        let required_bytes = num_elements * std::mem::size_of::<u64>();
2822        if col.num_bytes() < required_bytes {
2823            return Err(XlogError::Kernel(format!(
2824                "Column has {} bytes but {} required for {} u64 elements",
2825                col.num_bytes(),
2826                required_bytes,
2827                num_elements
2828            )));
2829        }
2830        let ptr = *col.device_ptr();
2831        if !(ptr as usize).is_multiple_of(std::mem::align_of::<u64>()) {
2832            return Err(XlogError::Kernel(
2833                "Column device pointer is not u64-aligned".to_string(),
2834            ));
2835        }
2836        Ok(RawCudaView {
2837            ptr,
2838            len: num_elements,
2839            stream: col.stream().clone(),
2840            source_block: col.runtime_block(),
2841            _marker: PhantomData,
2842        })
2843    }
2844
2845    /// Reinterpret a `CudaBuffer` column as an `f64` slice for kernel access.
2846    fn column_as_f64_view<'a>(
2847        &self,
2848        col: &'a CudaColumn,
2849        num_elements: usize,
2850    ) -> Result<RawCudaView<'a, f64>> {
2851        let required_bytes = num_elements * std::mem::size_of::<f64>();
2852        if col.num_bytes() < required_bytes {
2853            return Err(XlogError::Kernel(format!(
2854                "Column has {} bytes but {} required for {} f64 elements",
2855                col.num_bytes(),
2856                required_bytes,
2857                num_elements
2858            )));
2859        }
2860        let ptr = *col.device_ptr();
2861        if !(ptr as usize).is_multiple_of(std::mem::align_of::<f64>()) {
2862            return Err(XlogError::Kernel(
2863                "Column device pointer is not f64-aligned".to_string(),
2864            ));
2865        }
2866        Ok(RawCudaView {
2867            ptr,
2868            len: num_elements,
2869            stream: col.stream().clone(),
2870            source_block: col.runtime_block(),
2871            _marker: PhantomData,
2872        })
2873    }
2874
2875    /// Create an empty buffer with the given schema (all columns are empty slices)
2876    ///
2877    /// # Arguments
2878    /// * `schema` - The schema for the empty buffer
2879    ///
2880    /// # Returns
2881    /// A new CudaBuffer with zero rows
2882    ///
2883    /// # Errors
2884    /// Returns `XlogError::Kernel` if allocation fails
2885    pub fn create_empty_buffer(&self, schema: Schema) -> Result<CudaBuffer> {
2886        let mut columns = Vec::with_capacity(schema.arity());
2887        for _ in 0..schema.arity() {
2888            // Allocate zero-length column
2889            columns.push(self.memory.alloc::<u8>(0)?.into());
2890        }
2891        self.buffer_from_columns(columns, 0, schema)
2892    }
2893
2894    /// Create a zero-arity (nullary) relation buffer carrying `rows` unit tuples.
2895    ///
2896    /// A nullary relation holds exactly when it has at least one row; its single
2897    /// possible tuple is the empty tuple `()`. `create_buffer_from_slices` with no
2898    /// column slices routes to `create_empty_buffer` (0 rows), which represents the
2899    /// relation as *absent* — wrong for an asserted nullary fact. Nullary facts must
2900    /// use this path so presence is materialized as one row.
2901    pub fn create_zero_arity_buffer(&self, schema: Schema, rows: u32) -> Result<CudaBuffer> {
2902        debug_assert_eq!(
2903            schema.arity(),
2904            0,
2905            "create_zero_arity_buffer requires arity 0"
2906        );
2907        self.buffer_from_columns(Vec::new(), u64::from(rows), schema)
2908    }
2909
2910    pub(crate) fn buffer_from_columns(
2911        &self,
2912        columns: Vec<CudaColumn>,
2913        row_cap: u64,
2914        schema: Schema,
2915    ) -> Result<CudaBuffer> {
2916        let row_u32 = u32::try_from(row_cap)
2917            .map_err(|_| XlogError::Kernel(format!("Row capacity {} exceeds u32::MAX", row_cap)))?;
2918        let mut d_num_rows = self.memory.alloc::<u32>(1)?;
2919        self.htod_launch_metadata_sync_copy_into(&[row_u32], &mut d_num_rows)
2920            .map_err(|e| XlogError::Kernel(format!("Failed to set row count: {}", e)))?;
2921        Ok(CudaBuffer::from_columns_with_host_count(
2922            columns, row_cap, d_num_rows, schema, row_u32,
2923        ))
2924    }
2925
2926    /// Combine schemas from left and right buffers for join result
2927    fn combine_schemas(&self, left: &Schema, right: &Schema) -> Schema {
2928        let mut columns = left.columns.clone();
2929        columns.extend(right.columns.iter().cloned());
2930        let mut sort_labels = left.sort_labels().to_vec();
2931        sort_labels.extend(right.sort_labels().iter().cloned());
2932        Schema::new(columns)
2933            .with_sort_labels(sort_labels)
2934            .expect("combined schema sort labels match column arity")
2935    }
2936
2937    /// Check if two schemas have compatible types (same arity and column types)
2938    ///
2939    /// This ignores column names, which is useful for Datalog operations where
2940    /// projected relations may have different column names but the same types.
2941    fn schemas_type_compatible(&self, a: &Schema, b: &Schema) -> bool {
2942        if a.arity() != b.arity() {
2943            return false;
2944        }
2945        for i in 0..a.arity() {
2946            if a.column_type(i) != b.column_type(i) {
2947                return false;
2948            }
2949        }
2950        true
2951    }
2952}
2953
2954#[cfg(test)]
2955mod tests {
2956    use super::*;
2957    use crate::device_runtime::{
2958        AsyncCudaResource, DeviceMemoryResource, GlobalDeviceBudget, LoggingResource, NullSink,
2959        StreamPool, XlogDeviceRuntime,
2960    };
2961    use xlog_core::{AggOp, MemoryBudget, ScalarType};
2962
2963    fn has_cuda_device() -> bool {
2964        CudaDevice::new(0).is_ok()
2965    }
2966
    /// Checks the locator's search precedence for a compiled module:
    /// the cubin dir first, then the package `bin/kernels` dir, then the out dir.
    /// The same-named artifact is written to all three locations, and each
    /// higher-priority copy is deleted in turn to expose the next one.
    #[test]
    fn test_kernel_artifact_locator_precedence_order() {
        use super::kernel_paths::KernelArtifactLocator;
        use std::fs;
        use std::path::PathBuf;

        // Unique scratch root (pid + wall-clock nanos) so concurrent test runs
        // do not share directories.
        let root = std::env::temp_dir().join(format!(
            "xlog-kernel-paths-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("system clock before UNIX_EPOCH")
                .as_nanos()
        ));
        let cubin_dir = root.join("cubin");
        let package_dir = root.join("bin").join("kernels");
        let out_dir = root.join("out");
        fs::create_dir_all(&cubin_dir).expect("create cubin dir");
        fs::create_dir_all(&package_dir).expect("create package kernels dir");
        fs::create_dir_all(&out_dir).expect("create out dir");

        // Same `<name>.sm_<cc>.cubin` file name in every directory; only the
        // directory decides which one wins.
        let name = "xlog_join";
        let cc = 75;
        let cubin_path = cubin_dir.join(format!("{name}.sm_{cc}.cubin"));
        let package_path = package_dir.join(format!("{name}.sm_{cc}.cubin"));
        let out_path = out_dir.join(format!("{name}.sm_{cc}.cubin"));
        fs::write(&cubin_path, b"cubin").expect("write cubin file");
        fs::write(&package_path, b"package").expect("write package file");
        fs::write(&out_path, b"out").expect("write out file");

        let locator = KernelArtifactLocator::new(
            Some(cubin_dir.clone()),
            Some(package_dir.clone()),
            Some(out_dir.clone()),
        );

        // All three present: the cubin dir wins.
        let (path, is_cubin) = locator
            .resolve_module_path(name, cc)
            .expect("expected a kernel artifact");
        assert_eq!(path, cubin_path);
        assert!(is_cubin);

        // Cubin dir emptied: the package dir is next.
        fs::remove_file(&cubin_path).expect("remove cubin file");
        let (path, is_cubin) = locator
            .resolve_module_path(name, cc)
            .expect("expected package kernel artifact");
        assert_eq!(path, package_path);
        assert!(is_cubin);

        // Package dir emptied too: the out dir is the last file-based candidate.
        fs::remove_file(&package_path).expect("remove package file");
        let (path, is_cubin) = locator
            .resolve_module_path(name, cc)
            .expect("expected out dir kernel artifact");
        assert_eq!(path, out_path);
        assert!(is_cubin);

        // Best-effort cleanup. It only runs when every assertion above passed.
        let _ = fs::remove_dir_all(PathBuf::from(&root));
    }
3025
3026    #[test]
3027    fn test_module_resolution_finds_portable_ptx() {
3028        // Verify resolve_module_path finds portable PTX for all modules.
3029        // Uses a dummy cc (999) so cubin won't match — only portable PTX.
3030        for name in crate::kernel_manifest_data::KERNEL_CU_NAMES {
3031            let result = resolve_module_path(name, 999);
3032            assert!(
3033                result.is_some(),
3034                "resolve_module_path({name}, 999) should find portable PTX"
3035            );
3036            let (path, is_cubin) = result.unwrap();
3037            assert!(
3038                !is_cubin,
3039                "{name}: expected portable PTX fallback, got cubin"
3040            );
3041            assert!(
3042                path.to_str().unwrap().ends_with(".portable.ptx"),
3043                "{name}: path should end with .portable.ptx, got {:?}",
3044                path
3045            );
3046        }
3047    }
3048
    /// With a locator that has no search directories at all, every manifest
    /// module must resolve to exactly one source: the PTX embedded in the binary.
    #[test]
    fn test_module_resolution_falls_back_to_embedded_portable_ptx() {
        use super::kernel_paths::KernelArtifactLocator;

        // No cubin/package/out directories, so no file-based candidates exist.
        let locator = KernelArtifactLocator::new(None, None, None);
        for name in crate::kernel_manifest_data::KERNEL_CU_NAMES {
            let sources = resolve_module_sources_with_locator(name, 999, &locator);
            assert_eq!(
                sources.len(),
                1,
                "{name}: expected only embedded portable PTX fallback"
            );

            match &sources[0] {
                KernelModuleSource::EmbeddedPortablePtx { ptx } => {
                    // Sanity check that the embedded text is real PTX, not an empty stub.
                    assert!(
                        ptx.contains(".entry"),
                        "{name}: embedded PTX should contain CUDA entry points"
                    );
                }
                KernelModuleSource::File { path, .. } => {
                    panic!(
                        "{name}: expected embedded portable PTX fallback, got file {}",
                        path.display()
                    );
                }
            }
        }
    }
3078
3079    #[test]
3080    fn test_embedded_portable_ptx_manifest_matches_kernel_manifest() {
3081        let embedded_names: std::collections::BTreeSet<_> =
3082            crate::embedded_kernel_data::EMBEDDED_PORTABLE_PTX
3083                .iter()
3084                .map(|artifact| artifact.name)
3085                .collect();
3086        let manifest_names: std::collections::BTreeSet<_> =
3087            crate::kernel_manifest_data::KERNEL_CU_NAMES
3088                .iter()
3089                .copied()
3090                .collect();
3091
3092        assert_eq!(
3093            embedded_names, manifest_names,
3094            "embedded portable PTX table should cover every runtime kernel module"
3095        );
3096    }
3097
3098    #[test]
3099    fn test_kernel_provider_creation() {
3100        if !has_cuda_device() {
3101            eprintln!("Skipping test: no CUDA device available");
3102            return;
3103        }
3104
3105        let device = Arc::new(CudaDevice::new(0).expect("Failed to create device"));
3106        let budget = MemoryBudget::with_limit(1024 * 1024 * 1024); // 1 GB
3107        let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
3108
3109        let provider = CudaKernelProvider::new(device.clone(), memory.clone());
3110        assert!(
3111            provider.is_ok(),
3112            "Failed to create kernel provider: {:?}",
3113            provider.err()
3114        );
3115
3116        let provider = provider.unwrap();
3117        assert!(Arc::ptr_eq(provider.device(), &device));
3118        assert!(Arc::ptr_eq(provider.memory(), &memory));
3119    }
3120
3121    #[test]
3122    fn test_kernel_functions_accessible() {
3123        if !has_cuda_device() {
3124            eprintln!("Skipping test: no CUDA device available");
3125            return;
3126        }
3127
3128        let device = Arc::new(CudaDevice::new(0).expect("Failed to create device"));
3129        let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
3130        let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
3131
3132        let _provider =
3133            CudaKernelProvider::new(device.clone(), memory).expect("Failed to create provider");
3134
3135        // Verify all kernel functions can be retrieved
3136        let inner = device.inner();
3137
3138        // Join kernels
3139        let build_fn = inner.get_func(JOIN_MODULE, join_kernels::HASH_JOIN_BUILD);
3140        assert!(
3141            build_fn.is_some(),
3142            "hash_join_build function should be accessible"
3143        );
3144
3145        let probe_fn = inner.get_func(JOIN_MODULE, join_kernels::HASH_JOIN_PROBE);
3146        assert!(
3147            probe_fn.is_some(),
3148            "hash_join_probe function should be accessible"
3149        );
3150
3151        // Dedup kernels
3152        let mark_fn = inner.get_func(DEDUP_MODULE, dedup_kernels::MARK_DUPLICATES);
3153        assert!(
3154            mark_fn.is_some(),
3155            "mark_duplicates function should be accessible"
3156        );
3157
3158        let compact_fn = inner.get_func(DEDUP_MODULE, dedup_kernels::COMPACT_ROWS);
3159        assert!(
3160            compact_fn.is_some(),
3161            "compact_rows function should be accessible"
3162        );
3163
3164        // GroupBy kernels
3165        let boundaries_fn =
3166            inner.get_func(GROUPBY_MODULE, groupby_kernels::DETECT_GROUP_BOUNDARIES);
3167        assert!(
3168            boundaries_fn.is_some(),
3169            "detect_group_boundaries function should be accessible"
3170        );
3171
3172        let count_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_COUNT);
3173        assert!(
3174            count_fn.is_some(),
3175            "groupby_count function should be accessible"
3176        );
3177
3178        let sum_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_SUM);
3179        assert!(
3180            sum_fn.is_some(),
3181            "groupby_sum function should be accessible"
3182        );
3183
3184        let min_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MIN);
3185        assert!(
3186            min_fn.is_some(),
3187            "groupby_min function should be accessible"
3188        );
3189
3190        let max_fn = inner.get_func(GROUPBY_MODULE, groupby_kernels::GROUPBY_MAX);
3191        assert!(
3192            max_fn.is_some(),
3193            "groupby_max function should be accessible"
3194        );
3195
3196        // Circuit kernels (XGCF forward/backward)
3197        let xgcf_forward = inner.get_func(CIRCUIT_MODULE, "xgcf_forward_level");
3198        assert!(
3199            xgcf_forward.is_some(),
3200            "xgcf_forward_level function should be accessible"
3201        );
3202
3203        let xgcf_backward_propagate =
3204            inner.get_func(CIRCUIT_MODULE, "xgcf_backward_level_propagate");
3205        assert!(
3206            xgcf_backward_propagate.is_some(),
3207            "xgcf_backward_level_propagate function should be accessible"
3208        );
3209
3210        let xgcf_backward_decision_grad =
3211            inner.get_func(CIRCUIT_MODULE, "xgcf_backward_level_decision_grad");
3212        assert!(
3213            xgcf_backward_decision_grad.is_some(),
3214            "xgcf_backward_level_decision_grad function should be accessible"
3215        );
3216
3217        let xgcf_backward_lit_grad = inner.get_func(CIRCUIT_MODULE, "xgcf_backward_level_lit_grad");
3218        assert!(
3219            xgcf_backward_lit_grad.is_some(),
3220            "xgcf_backward_level_lit_grad function should be accessible"
3221        );
3222
3223        // Neural fast-path kernels (AD chain weight fill + gradient scatter)
3224        let neural_fill = inner.get_func("xlog_neural", "neural_fill_ad_chain_f32");
3225        assert!(
3226            neural_fill.is_some(),
3227            "neural_fill_ad_chain_f32 function should be accessible"
3228        );
3229        let neural_scatter = inner.get_func("xlog_neural", "neural_scatter_ad_chain_grads_f32");
3230        assert!(
3231            neural_scatter.is_some(),
3232            "neural_scatter_ad_chain_grads_f32 function should be accessible"
3233        );
3234    }
3235
3236    #[test]
3237    fn test_module_names_unique() {
3238        // Ensure module names don't collide
3239        assert_ne!(JOIN_MODULE, DEDUP_MODULE);
3240        assert_ne!(JOIN_MODULE, GROUPBY_MODULE);
3241        assert_ne!(DEDUP_MODULE, GROUPBY_MODULE);
3242    }
3243
3244    // Helper function to create test provider
3245    fn create_test_provider() -> Option<CudaKernelProvider> {
3246        if !has_cuda_device() {
3247            return None;
3248        }
3249        let device = Arc::new(CudaDevice::new(0).ok()?);
3250        let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
3251        let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
3252        CudaKernelProvider::new(device, memory).ok()
3253    }
3254
    /// Build a provider whose memory manager is backed by an `XlogDeviceRuntime`.
    ///
    /// The runtime's resource stack is layered outermost-first as
    /// `GlobalDeviceBudget` (1 GiB) -> `LoggingResource` (into a `NullSink`) ->
    /// `AsyncCudaResource`. The runtime is returned alongside the provider so
    /// tests can acquire and synchronize streams from its pool. Returns `None`
    /// when no CUDA device is available or provider construction fails.
    fn create_test_provider_with_runtime() -> Option<(CudaKernelProvider, Arc<XlogDeviceRuntime>)> {
        if !has_cuda_device() {
            return None;
        }
        let device = Arc::new(CudaDevice::new(0).ok()?);
        let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
        let sink = Arc::new(NullSink::new());
        // Innermost layer. The `0` is presumably the device ordinal; confirm
        // against `AsyncCudaResource::new`.
        let async_resource: Box<dyn DeviceMemoryResource + Send + Sync> = Box::new(
            AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool)),
        );
        // Logging layer; events go to a sink that discards them.
        let logging: Box<dyn DeviceMemoryResource + Send + Sync> =
            Box::new(LoggingResource::new(async_resource, sink));
        // Outermost layer: 1 GiB device budget.
        let budget: Box<dyn DeviceMemoryResource + Send + Sync> =
            Box::new(GlobalDeviceBudget::new(logging, 1024 * 1024 * 1024));
        let runtime = Arc::new(XlogDeviceRuntime::with_resource(
            Arc::clone(&device),
            0,
            pool,
            budget,
        ));
        // The memory manager gets its own 1 GiB `MemoryBudget` on top of the runtime.
        let memory = Arc::new(GpuMemoryManager::with_runtime(
            Arc::clone(&device),
            MemoryBudget::with_limit(1024 * 1024 * 1024),
            Arc::clone(&runtime),
        ));
        let provider = CudaKernelProvider::with_runtime(device, memory).ok()?;
        Some((provider, runtime))
    }
3283
    /// Builds a join index and runs an indexed inner join on a stream acquired
    /// from the runtime's pool. Only that stream is synchronized; then the index
    /// metadata and joined row count are checked.
    #[test]
    fn test_recorded_join_index_build_runs_on_runtime_stream() {
        let (provider, runtime) = match create_test_provider_with_runtime() {
            Some(fixture) => fixture,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };
        let stream = runtime.stream_pool().acquire().expect("recorded stream");
        // Identical single-column key sets on both sides.
        let left = create_test_buffer(&provider, &[1, 2, 3, 4], "key");
        let right = create_test_buffer(&provider, &[1, 2, 3, 4], "key");

        let index = provider
            .build_join_index_v2_recorded(&right, &[0], stream)
            .expect("recorded join-index build");
        // Same `stream` as the build, so the join is enqueued after it.
        let joined = provider
            .hash_join_v2_with_index_recorded(
                &left,
                &right,
                &[0],
                &[0],
                JoinType::Inner,
                &index,
                None,
                stream,
            )
            .expect("recorded indexed join consumes recorded build");
        runtime
            .stream_pool()
            .resolve(stream)
            .expect("stream resolves")
            .synchronize()
            .expect("recorded stream synchronized");

        // Keys 1..=4 each match exactly once, giving 4 joined rows.
        assert_eq!(index.right_num_rows(), 4);
        assert_eq!(index.right_keys(), &[0]);
        assert_eq!(provider.device_row_count(&joined).expect("joined rows"), 4);
    }
3323
3324    // Helper function to create a CudaBuffer with U32 data
3325    fn create_test_buffer(
3326        provider: &CudaKernelProvider,
3327        data: &[u32],
3328        col_name: &str,
3329    ) -> CudaBuffer {
3330        let schema = Schema::new(vec![(col_name.to_string(), ScalarType::U32)]);
3331        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
3332
3333        let mut col = provider.memory().alloc::<u8>(bytes.len()).expect("alloc");
3334        provider
3335            .device()
3336            .inner()
3337            .htod_sync_copy_into(&bytes, &mut col)
3338            .expect("htod");
3339
3340        provider
3341            .buffer_from_columns(vec![col.into()], data.len() as u64, schema)
3342            .expect("buffer")
3343    }
3344
3345    // Helper function to create an empty buffer with correct column count
3346    fn create_empty_test_buffer(provider: &CudaKernelProvider, schema: Schema) -> CudaBuffer {
3347        let mut columns = Vec::with_capacity(schema.arity());
3348        for _ in 0..schema.arity() {
3349            columns.push(provider.memory().alloc::<u8>(0).expect("alloc").into());
3350        }
3351        provider
3352            .buffer_from_columns(columns, 0, schema)
3353            .expect("buffer")
3354    }
3355
3356    // Helper function to read U32 data from CudaBuffer
3357    fn read_buffer_u32(provider: &CudaKernelProvider, buffer: &CudaBuffer, col: usize) -> Vec<u32> {
3358        if buffer.is_empty() || buffer.column(col).is_none() {
3359            return vec![];
3360        }
3361        let num_rows = buffer.num_rows() as usize;
3362        let mut bytes = vec![0u8; num_rows * 4];
3363        provider
3364            .device()
3365            .inner()
3366            .dtoh_sync_copy_into(buffer.column(col).unwrap(), &mut bytes)
3367            .expect("dtoh");
3368        bytes
3369            .chunks_exact(4)
3370            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
3371            .collect()
3372    }
3373
    /// Compacting a buffer whose row capacity (16) is larger than the mask
    /// (8 entries) must size the output from the mask, not the capacity. The
    /// output's device row count must equal the selected count.
    #[test]
    fn test_compact_device_mask_respects_mask_len_smaller_than_row_cap() {
        let provider = match create_test_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let schema = Schema::new(vec![("id".to_string(), ScalarType::U32)]);
        // Only used as the source of an 8-row device row count.
        let base = create_test_buffer(&provider, &[1, 2, 3, 4, 5, 6, 7, 8], "id");

        // A 16-row-capacity column whose device row count is copied from `base`.
        let row_cap = 16u64;
        let data: Vec<u32> = (0..row_cap as u32).collect();
        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
        let mut col = provider.memory().alloc::<u8>(bytes.len()).expect("alloc");
        provider
            .device()
            .inner()
            .htod_sync_copy_into(&bytes, &mut col)
            .expect("htod");
        let expanded = provider
            .buffer_from_columns_with_device_count(vec![col.into()], row_cap, schema, &base)
            .expect("buffer");

        // Mask shorter than the capacity. Prefix sum and selected count come
        // from the host-side `prefix_sum_mask`.
        let mask: Vec<u8> = vec![1, 0, 1, 0, 1, 0, 1, 0];
        let (prefix_sum, count) = provider.prefix_sum_mask(&mask).expect("prefix sum");

        let mut d_mask = provider.memory().alloc::<u8>(mask.len()).expect("alloc");
        provider
            .device()
            .inner()
            .htod_sync_copy_into(&mask, &mut d_mask)
            .expect("mask htod");

        let mut d_prefix = provider
            .memory()
            .alloc::<u32>(prefix_sum.len())
            .expect("alloc");
        provider
            .device()
            .inner()
            .htod_sync_copy_into(&prefix_sum, &mut d_prefix)
            .expect("prefix htod");

        let mut d_out_count = provider.memory().alloc::<u32>(1).expect("alloc");
        provider
            .device()
            .inner()
            .htod_sync_copy_into(&[count], &mut d_out_count)
            .expect("count htod");

        let compacted = provider
            .compact_buffer_by_device_mask_device_count(&expanded, &d_mask, &d_prefix, d_out_count)
            .expect("compact");

        // Capacity follows the mask length (8), not `row_cap` (16).
        assert_eq!(compacted.num_rows(), mask.len() as u64);
        let device_rows = provider.device_row_count(&compacted).expect("row count");
        assert_eq!(device_rows as u32, count);
    }
3435
    /// `clone_buffer` must give the clone a device-side row count equal to the
    /// source's. The count is read straight from `num_rows_device`.
    #[test]
    fn test_clone_buffer_preserves_device_count() {
        let provider = match create_test_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let schema = Schema::new(vec![("id".to_string(), ScalarType::U32)]);
        let ids: Vec<u32> = vec![10, 20, 30];
        let buffer = provider
            .create_buffer_from_slices(&[bytemuck::cast_slice(&ids)], schema)
            .unwrap();

        let cloned = provider.clone_buffer(&buffer).unwrap();

        // Direct device-to-host read, bypassing any host-side cached count.
        let mut host_count = [0u32];
        provider
            .device()
            .inner()
            .dtoh_sync_copy_into(cloned.num_rows_device(), &mut host_count)
            .unwrap();
        assert_eq!(host_count[0], 3);
    }
3462
    /// `clone_buffer` must propagate the host-side `cached_row_count` so
    /// downstream code can read the row count without a D2H round-trip.
    /// Without this propagation, buffers flowed through the relation store
    /// (`CompiledIlpProgram::put_relation` calls `clone_buffer` before
    /// storing) lose their host-visible count, forcing consumers to choose
    /// between an extra D2H (violating the native bounded exact-induction
    /// transfer-budget gates) and a hard error. This test pins the cache-propagation
    /// contract directly.
    #[test]
    fn test_clone_buffer_preserves_cached_row_count() {
        let provider = match create_test_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let schema = Schema::new(vec![("id".to_string(), ScalarType::U32)]);
        let ids: Vec<u32> = vec![7, 11, 13, 17];
        let source = provider
            .create_buffer_from_slices(&[bytemuck::cast_slice(&ids)], schema)
            .unwrap();
        // Source's cache is populated by the `create_buffer_from_*` path;
        // verify the precondition so a regression in that path shows up here
        // rather than silently passing the real assertion below.
        assert_eq!(
            source.cached_row_count(),
            Some(4),
            "source buffer should have its cached row count populated by \
             create_buffer_from_slices"
        );

        let cloned = provider.clone_buffer(&source).unwrap();

        // Inspect the host cache only; no device read is involved here.
        assert_eq!(
            cloned.cached_row_count(),
            Some(4),
            "clone_buffer must propagate cached_row_count from source to clone",
        );
    }
3504
3505    // ============== Hash Join Tests ==============
3506
3507    #[test]
3508    fn test_hash_join_empty_inputs() {
3509        let provider = match create_test_provider() {
3510            Some(p) => p,
3511            None => {
3512                eprintln!("Skipping test: no CUDA device available");
3513                return;
3514            }
3515        };
3516
3517        let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3518        let empty = create_empty_test_buffer(&provider, schema.clone());
3519
3520        // Join empty with empty
3521        let result = provider.hash_join(&empty, &empty, &[0], &[0]);
3522        assert!(result.is_ok());
3523        assert!(result.unwrap().is_empty());
3524    }
3525
3526    #[test]
3527    fn test_hash_join_validation() {
3528        let provider = match create_test_provider() {
3529            Some(p) => p,
3530            None => {
3531                eprintln!("Skipping test: no CUDA device available");
3532                return;
3533            }
3534        };
3535
3536        let left = create_test_buffer(&provider, &[1, 2, 3], "left_key");
3537        let right = create_test_buffer(&provider, &[2, 3, 4], "right_key");
3538
3539        // Empty key columns
3540        let result = provider.hash_join(&left, &right, &[], &[0]);
3541        assert!(result.is_err());
3542
3543        // Mismatched key lengths
3544        let result = provider.hash_join(&left, &right, &[0], &[0, 0]);
3545        assert!(result.is_err());
3546    }
3547
3548    // ============== Dedup Tests ==============
3549
3550    #[test]
3551    fn test_dedup_empty_input() {
3552        let provider = match create_test_provider() {
3553            Some(p) => p,
3554            None => {
3555                eprintln!("Skipping test: no CUDA device available");
3556                return;
3557            }
3558        };
3559
3560        let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3561        let empty = create_empty_test_buffer(&provider, schema);
3562
3563        let result = provider.dedup(&empty, &[0]);
3564        assert!(result.is_ok());
3565        assert!(result.unwrap().is_empty());
3566    }
3567
3568    #[test]
3569    fn test_dedup_validation() {
3570        let provider = match create_test_provider() {
3571            Some(p) => p,
3572            None => {
3573                eprintln!("Skipping test: no CUDA device available");
3574                return;
3575            }
3576        };
3577
3578        let buffer = create_test_buffer(&provider, &[1, 1, 2, 2, 3], "key");
3579
3580        // Empty key columns
3581        let result = provider.dedup(&buffer, &[]);
3582        assert!(result.is_err());
3583    }
3584
3585    #[test]
3586    fn test_dedup_with_duplicates() {
3587        let provider = match create_test_provider() {
3588            Some(p) => p,
3589            None => {
3590                eprintln!("Skipping test: no CUDA device available");
3591                return;
3592            }
3593        };
3594
3595        // Test dedup with duplicates: [3, 1, 2, 1, 3, 2]
3596        let buffer = create_test_buffer(&provider, &[3, 1, 2, 1, 3, 2], "key");
3597        let deduped = provider.dedup(&buffer, &[0]).unwrap();
3598
3599        let dedup_count = provider
3600            .device_row_count(&deduped)
3601            .expect("read dedup row count");
3602        assert_eq!(dedup_count, 3, "Should have 3 unique values");
3603
3604        let result = provider.download_column::<u32>(&deduped, 0).unwrap();
3605        // Result should be sorted and deduped
3606        assert_eq!(result, vec![1, 2, 3]);
3607    }
3608
3609    #[test]
3610    fn test_dedup_larger_input() {
3611        let provider = match create_test_provider() {
3612            Some(p) => p,
3613            None => {
3614                eprintln!("Skipping test: no CUDA device available");
3615                return;
3616            }
3617        };
3618
3619        // Create input with duplicates: 0..500 ++ 250..750 = 1000 elements, 750 unique
3620        let a: Vec<u32> = (0..500).collect();
3621        let b: Vec<u32> = (250..750).collect();
3622        let input: Vec<u32> = a.iter().chain(b.iter()).copied().collect();
3623
3624        let buffer = create_test_buffer(&provider, &input, "key");
3625        let deduped = provider.dedup(&buffer, &[0]).unwrap();
3626
3627        let dedup_count = provider
3628            .device_row_count(&deduped)
3629            .expect("read dedup row count");
3630        assert_eq!(dedup_count, 750, "Should have 750 unique values (0..750)");
3631
3632        // Verify output is sorted
3633        let result = provider.download_column::<u32>(&deduped, 0).unwrap();
3634        let is_sorted = result.windows(2).all(|w| w[0] <= w[1]);
3635        assert!(is_sorted, "Output should be sorted");
3636
3637        // Verify expected values
3638        let expected: Vec<u32> = (0..750).collect();
3639        assert_eq!(result, expected);
3640    }
3641
3642    // ============== Union Tests ==============
3643
3644    #[test]
3645    fn test_union_empty_inputs() {
3646        let provider = match create_test_provider() {
3647            Some(p) => p,
3648            None => {
3649                eprintln!("Skipping test: no CUDA device available");
3650                return;
3651            }
3652        };
3653
3654        let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3655        let empty = create_empty_test_buffer(&provider, schema.clone());
3656
3657        // Empty union empty
3658        let result = provider.union(&empty, &empty);
3659        assert!(result.is_ok());
3660        assert!(result.unwrap().is_empty());
3661
3662        // Non-empty union empty
3663        let a = create_test_buffer(&provider, &[1, 2, 3], "key");
3664        let empty2 = create_empty_test_buffer(&provider, schema);
3665        let result = provider.union(&a, &empty2);
3666        assert!(result.is_ok());
3667        let result = result.unwrap();
3668        assert_eq!(result.num_rows(), 3);
3669    }
3670
3671    #[test]
3672    fn test_union_schema_type_mismatch() {
3673        let provider = match create_test_provider() {
3674            Some(p) => p,
3675            None => {
3676                eprintln!("Skipping test: no CUDA device available");
3677                return;
3678            }
3679        };
3680
3681        let a = create_test_buffer(&provider, &[1, 2], "col_a");
3682        let b = create_test_buffer(&provider, &[3, 4], "col_b");
3683
3684        // Different column names but same types should succeed (Datalog union semantics)
3685        let result = provider.union(&a, &b);
3686        assert!(result.is_ok());
3687
3688        // Different arity should fail - create a 2-column buffer
3689        let two_col_schema = Schema::new(vec![
3690            ("x".to_string(), ScalarType::U32),
3691            ("y".to_string(), ScalarType::U32),
3692        ]);
3693        let c = provider
3694            .create_buffer_from_u32_columns(&[&[1, 2], &[3, 4]], two_col_schema)
3695            .unwrap();
3696        let result = provider.union(&a, &c);
3697        assert!(result.is_err());
3698    }
3699
3700    // ============== Diff Tests ==============
3701
3702    #[test]
3703    fn test_diff_empty_inputs() {
3704        let provider = match create_test_provider() {
3705            Some(p) => p,
3706            None => {
3707                eprintln!("Skipping test: no CUDA device available");
3708                return;
3709            }
3710        };
3711
3712        let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3713        let empty = create_empty_test_buffer(&provider, schema.clone());
3714
3715        // Empty diff empty
3716        let result = provider.diff(&empty, &empty);
3717        assert!(result.is_ok());
3718        assert!(result.unwrap().is_empty());
3719
3720        // Non-empty diff empty should return all of a
3721        let a = create_test_buffer(&provider, &[1, 2, 3], "key");
3722        let empty2 = create_empty_test_buffer(&provider, schema);
3723        let result = provider.diff(&a, &empty2);
3724        assert!(result.is_ok());
3725        let result = result.unwrap();
3726        assert_eq!(result.num_rows(), 3);
3727    }
3728
3729    #[test]
3730    fn test_diff_basic() {
3731        let provider = match create_test_provider() {
3732            Some(p) => p,
3733            None => {
3734                eprintln!("Skipping test: no CUDA device available");
3735                return;
3736            }
3737        };
3738
3739        let a = create_test_buffer(&provider, &[1, 2, 3, 4, 5], "key");
3740        let b = create_test_buffer(&provider, &[2, 4], "key");
3741
3742        let result = provider.diff(&a, &b);
3743        assert!(result.is_ok());
3744        let result = result.unwrap();
3745        assert_eq!(result.num_rows(), 3); // 1, 3, 5
3746
3747        let values = read_buffer_u32(&provider, &result, 0);
3748        assert_eq!(values, vec![1, 3, 5]);
3749    }
3750
3751    #[test]
3752    fn test_diff_all_filtered_out() {
3753        let provider = match create_test_provider() {
3754            Some(p) => p,
3755            None => {
3756                eprintln!("Skipping test: no CUDA device available");
3757                return;
3758            }
3759        };
3760
3761        let a = create_test_buffer(&provider, &[1, 2, 3], "key");
3762        let b = create_test_buffer(&provider, &[1, 2, 3, 4, 5], "key");
3763
3764        let result = provider.diff(&a, &b);
3765        assert!(result.is_ok());
3766        assert!(result.unwrap().is_empty());
3767    }
3768
3769    #[test]
3770    fn test_diff_schema_mismatch() {
3771        let provider = match create_test_provider() {
3772            Some(p) => p,
3773            None => {
3774                eprintln!("Skipping test: no CUDA device available");
3775                return;
3776            }
3777        };
3778
3779        // Different column names with same types should work (Datalog semantics)
3780        let a = create_test_buffer(&provider, &[1, 2], "col_a");
3781        let b = create_test_buffer(&provider, &[1, 2], "col_b");
3782        let result = provider.diff(&a, &b);
3783        assert!(
3784            result.is_ok(),
3785            "Same types with different names should succeed"
3786        );
3787
3788        // Create buffers with different arities (this should fail)
3789        let schema_2col = Schema::new(vec![
3790            ("c0".to_string(), ScalarType::U32),
3791            ("c1".to_string(), ScalarType::U32),
3792        ]);
3793
3794        let bytes_2col: Vec<u8> = [1u32, 2, 3, 4]
3795            .iter()
3796            .flat_map(|v| v.to_le_bytes())
3797            .collect();
3798        let mut col0 = provider
3799            .memory()
3800            .alloc::<u8>(bytes_2col.len() / 2)
3801            .expect("alloc");
3802        let mut col1 = provider
3803            .memory()
3804            .alloc::<u8>(bytes_2col.len() / 2)
3805            .expect("alloc");
3806        provider
3807            .device()
3808            .inner()
3809            .htod_sync_copy_into(&bytes_2col[..8], &mut col0)
3810            .expect("htod");
3811        provider
3812            .device()
3813            .inner()
3814            .htod_sync_copy_into(&bytes_2col[8..], &mut col1)
3815            .expect("htod");
3816        let buffer_2col = provider
3817            .buffer_from_columns(vec![col0.into(), col1.into()], 2, schema_2col)
3818            .expect("buffer");
3819
3820        let buffer_1col = create_test_buffer(&provider, &[1, 2], "c0");
3821
3822        let result = provider.diff(&buffer_2col, &buffer_1col);
3823        assert!(result.is_err(), "Different arities should fail");
3824    }
3825
3826    // ============== GroupBy Aggregation Tests ==============
3827
3828    #[test]
3829    fn test_groupby_empty_input() {
3830        let provider = match create_test_provider() {
3831            Some(p) => p,
3832            None => {
3833                eprintln!("Skipping test: no CUDA device available");
3834                return;
3835            }
3836        };
3837
3838        let schema = Schema::new(vec![("key".to_string(), ScalarType::U32)]);
3839        let empty = create_empty_test_buffer(&provider, schema);
3840
3841        let result = provider.groupby_agg(&empty, &[0], AggOp::Count, 0);
3842        assert!(result.is_ok());
3843        assert!(result.unwrap().is_empty());
3844    }
3845
3846    #[test]
3847    fn test_groupby_validation() {
3848        let provider = match create_test_provider() {
3849            Some(p) => p,
3850            None => {
3851                eprintln!("Skipping test: no CUDA device available");
3852                return;
3853            }
3854        };
3855
3856        let buffer = create_test_buffer(&provider, &[1, 1, 2, 2, 3], "key");
3857
3858        // Empty key columns
3859        let result = provider.groupby_agg(&buffer, &[], AggOp::Count, 0);
3860        assert!(result.is_err());
3861
3862        // Value column out of bounds
3863        let result = provider.groupby_agg(&buffer, &[0], AggOp::Count, 5);
3864        assert!(result.is_err());
3865    }
3866
3867    #[test]
3868    fn test_groupby_logsumexp() {
3869        let provider = match create_test_provider() {
3870            Some(p) => p,
3871            None => {
3872                eprintln!("Skipping test: no CUDA device available");
3873                return;
3874            }
3875        };
3876
3877        // Create buffer with U32 keys and F64 values
3878        // Group 0 (key=1): values 1.0, 2.0 -> logsumexp = log(e^1 + e^2) ≈ 2.31326
3879        // Group 1 (key=2): values 3.0, 4.0 -> logsumexp = log(e^3 + e^4) ≈ 4.31326
3880        let keys: Vec<u32> = vec![1, 1, 2, 2];
3881        let values: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
3882
3883        let schema = Schema::new(vec![
3884            ("key".to_string(), ScalarType::U32),
3885            ("value".to_string(), ScalarType::F64),
3886        ]);
3887
3888        // Create key column
3889        let key_bytes: Vec<u8> = keys.iter().flat_map(|v| v.to_le_bytes()).collect();
3890        let mut key_col = provider
3891            .memory()
3892            .alloc::<u8>(key_bytes.len())
3893            .expect("alloc key");
3894        provider
3895            .device()
3896            .inner()
3897            .htod_sync_copy_into(&key_bytes, &mut key_col)
3898            .expect("upload key");
3899
3900        // Create value column
3901        let val_bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
3902        let mut val_col = provider
3903            .memory()
3904            .alloc::<u8>(val_bytes.len())
3905            .expect("alloc val");
3906        provider
3907            .device()
3908            .inner()
3909            .htod_sync_copy_into(&val_bytes, &mut val_col)
3910            .expect("upload val");
3911
3912        let buffer = provider
3913            .buffer_from_columns(vec![key_col.into(), val_col.into()], 4, schema)
3914            .expect("buffer");
3915
3916        // Run LogSumExp aggregation grouped by key column (0), aggregating value column (1)
3917        let result = provider.groupby_agg(&buffer, &[0], AggOp::LogSumExp, 1);
3918        assert!(
3919            result.is_ok(),
3920            "groupby_agg with LogSumExp should succeed: {:?}",
3921            result.err()
3922        );
3923
3924        let result = result.unwrap();
3925        let group_count = provider
3926            .device_row_count(&result)
3927            .expect("read group count");
3928        assert_eq!(group_count, 2, "Should have 2 groups");
3929
3930        // Download results
3931        let result_values = provider
3932            .download_column::<f64>(&result, 1)
3933            .expect("download result");
3934
3935        // Expected values:
3936        // logsumexp(1.0, 2.0) = 2.0 + log(exp(1.0-2.0) + exp(2.0-2.0)) = 2.0 + log(e^-1 + 1) ≈ 2.31326
3937        // logsumexp(3.0, 4.0) = 4.0 + log(exp(3.0-4.0) + exp(4.0-4.0)) = 4.0 + log(e^-1 + 1) ≈ 4.31326
3938        let expected_0 = 2.0_f64 + ((-1.0_f64).exp() + 1.0_f64).ln(); // ≈ 2.31326
3939        let expected_1 = 4.0_f64 + ((-1.0_f64).exp() + 1.0_f64).ln(); // ≈ 4.31326
3940
3941        let tolerance = 1e-5;
3942        assert!(
3943            (result_values[0] - expected_0).abs() < tolerance,
3944            "Group 0 logsumexp mismatch: got {}, expected {}",
3945            result_values[0],
3946            expected_0
3947        );
3948        assert!(
3949            (result_values[1] - expected_1).abs() < tolerance,
3950            "Group 1 logsumexp mismatch: got {}, expected {}",
3951            result_values[1],
3952            expected_1
3953        );
3954    }
3955
3956    // ============== Schema Helper Tests ==============
3957
3958    #[test]
3959    fn test_combine_schemas() {
3960        let provider = match create_test_provider() {
3961            Some(p) => p,
3962            None => {
3963                eprintln!("Skipping test: no CUDA device available");
3964                return;
3965            }
3966        };
3967
3968        let left = Schema::new(vec![("a".to_string(), ScalarType::U32)]);
3969        let right = Schema::new(vec![("b".to_string(), ScalarType::U64)]);
3970
3971        let combined = provider.combine_schemas(&left, &right);
3972        assert_eq!(combined.arity(), 2);
3973        assert_eq!(combined.column_type(0), Some(ScalarType::U32));
3974        assert_eq!(combined.column_type(1), Some(ScalarType::U64));
3975    }
3976
3977    #[test]
3978    fn test_groupby_result_schema() {
3979        let provider = match create_test_provider() {
3980            Some(p) => p,
3981            None => {
3982                eprintln!("Skipping test: no CUDA device available");
3983                return;
3984            }
3985        };
3986
3987        let input = Schema::new(vec![
3988            ("key".to_string(), ScalarType::U32),
3989            ("value".to_string(), ScalarType::U32),
3990        ]);
3991
3992        // Count result schema (u64 to match predicate declarations)
3993        let count_schema =
3994            provider.groupby_multi_agg_result_schema(&input, &[0], &[(1, AggOp::Count)]);
3995        assert_eq!(count_schema.arity(), 2);
3996        assert_eq!(count_schema.column_type(1), Some(ScalarType::U64));
3997
3998        // Sum result schema
3999        let sum_schema = provider.groupby_multi_agg_result_schema(&input, &[0], &[(1, AggOp::Sum)]);
4000        assert_eq!(sum_schema.arity(), 2);
4001        assert_eq!(sum_schema.column_type(1), Some(ScalarType::U64));
4002
4003        // Min/Max result schema
4004        let min_schema = provider.groupby_multi_agg_result_schema(&input, &[0], &[(1, AggOp::Min)]);
4005        assert_eq!(min_schema.arity(), 2);
4006        assert_eq!(min_schema.column_type(1), Some(ScalarType::U32));
4007    }
4008
4009    #[test]
4010    fn test_groupby_multi_agg_sum_returns_u64_schema() {
4011        let provider = match create_test_provider() {
4012            Some(p) => p,
4013            None => {
4014                eprintln!("Skipping test: no CUDA device");
4015                return;
4016            }
4017        };
4018
4019        let schema = Schema::new(vec![
4020            ("key".to_string(), ScalarType::U32),
4021            ("val".to_string(), ScalarType::U32),
4022        ]);
4023
4024        let result_schema =
4025            provider.groupby_multi_agg_result_schema(&schema, &[0], &[(1, AggOp::Sum)]);
4026
4027        // Sum should return U64 to prevent overflow
4028        assert_eq!(
4029            result_schema.column_type(1),
4030            Some(ScalarType::U64),
4031            "Sum aggregation should return U64 type, not U32"
4032        );
4033    }
4034
4035    #[test]
4036    fn test_join_custom_max_output() {
4037        let provider = match create_test_provider() {
4038            Some(p) => p,
4039            None => {
4040                eprintln!("Skipping test: no CUDA device available");
4041                return;
4042            }
4043        };
4044
4045        // Create buffers that produce more than 10 results when joined
4046        // Left: [1, 1, 1, 1, 2, 2, 2, 2] - 4 copies of 1, 4 copies of 2
4047        // Right: [1, 1, 1, 2, 2, 2] - 3 copies of 1, 3 copies of 2
4048        // Join produces: 4*3 + 4*3 = 24 results
4049        let left = create_test_buffer(&provider, &[1, 1, 1, 1, 2, 2, 2, 2], "left_key");
4050        let right = create_test_buffer(&provider, &[1, 1, 1, 2, 2, 2], "right_key");
4051
4052        // Test with limit of 10 - should get at most 10
4053        let result_limited = provider
4054            .hash_join_v2_with_limit(&left, &right, &[0], &[0], JoinType::Inner, Some(10))
4055            .expect("join with limit should succeed");
4056        assert!(
4057            result_limited.num_rows() <= 10,
4058            "With limit 10, got {} rows but expected at most 10",
4059            result_limited.num_rows()
4060        );
4061
4062        // Test with None (default) - should get all 24 results
4063        let result_unlimited = provider
4064            .hash_join_v2_with_limit(&left, &right, &[0], &[0], JoinType::Inner, None)
4065            .expect("join without limit should succeed");
4066        assert_eq!(
4067            result_unlimited.num_rows(),
4068            24,
4069            "Without limit, expected 24 rows but got {}",
4070            result_unlimited.num_rows()
4071        );
4072
4073        // Test legacy API still works (backward compatibility)
4074        let result_legacy = provider
4075            .hash_join_v2(&left, &right, &[0], &[0], JoinType::Inner)
4076            .expect("legacy hash_join_v2 should succeed");
4077        assert_eq!(
4078            result_legacy.num_rows(),
4079            24,
4080            "Legacy API without limit, expected 24 rows but got {}",
4081            result_legacy.num_rows()
4082        );
4083    }
4084
4085    // ============== Arithmetic Operation Tests ==============
4086
4087    /// Helper to create a test provider for arithmetic tests
4088    fn create_arith_test_provider() -> Option<CudaKernelProvider> {
4089        if !has_cuda_device() {
4090            return None;
4091        }
4092        let device = Arc::new(CudaDevice::new(0).ok()?);
4093        let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
4094        let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
4095        CudaKernelProvider::new(device, memory).ok()
4096    }
4097
4098    /// Helper to create an i64 buffer for arithmetic tests
4099    fn create_i64_buffer(provider: &CudaKernelProvider, data: &[i64]) -> CudaBuffer {
4100        let schema = Schema::new(vec![("col".to_string(), ScalarType::I64)]);
4101        provider
4102            .create_buffer_from_slice::<i64>(data, schema)
4103            .unwrap()
4104    }
4105
4106    /// Helper to create an f64 buffer for arithmetic tests
4107    fn create_f64_buffer(provider: &CudaKernelProvider, data: &[f64]) -> CudaBuffer {
4108        let schema = Schema::new(vec![("col".to_string(), ScalarType::F64)]);
4109        provider
4110            .create_buffer_from_slice::<f64>(data, schema)
4111            .unwrap()
4112    }
4113
4114    #[test]
4115    fn test_add_columns_i64() {
4116        let Some(provider) = create_arith_test_provider() else {
4117            eprintln!("Skipping test: no CUDA device available");
4118            return;
4119        };
4120
4121        let a = create_i64_buffer(&provider, &[1, 2, 3, 4, 5]);
4122        let b = create_i64_buffer(&provider, &[10, 20, 30, 40, 50]);
4123
4124        let result = provider.add_columns(&a, &b).unwrap();
4125        let values = provider.download_column::<i64>(&result, 0).unwrap();
4126
4127        assert_eq!(values, vec![11, 22, 33, 44, 55]);
4128    }
4129
4130    #[test]
4131    fn test_sub_columns_i64() {
4132        let Some(provider) = create_arith_test_provider() else {
4133            eprintln!("Skipping test: no CUDA device available");
4134            return;
4135        };
4136
4137        let a = create_i64_buffer(&provider, &[10, 20, 30, 40, 50]);
4138        let b = create_i64_buffer(&provider, &[1, 2, 3, 4, 5]);
4139
4140        let result = provider.sub_columns(&a, &b).unwrap();
4141        let values = provider.download_column::<i64>(&result, 0).unwrap();
4142
4143        assert_eq!(values, vec![9, 18, 27, 36, 45]);
4144    }
4145
4146    #[test]
4147    fn test_mul_columns_i64() {
4148        let Some(provider) = create_arith_test_provider() else {
4149            eprintln!("Skipping test: no CUDA device available");
4150            return;
4151        };
4152
4153        let a = create_i64_buffer(&provider, &[2, 3, 4, 5, 6]);
4154        let b = create_i64_buffer(&provider, &[3, 4, 5, 6, 7]);
4155
4156        let result = provider.mul_columns(&a, &b).unwrap();
4157        let values = provider.download_column::<i64>(&result, 0).unwrap();
4158
4159        assert_eq!(values, vec![6, 12, 20, 30, 42]);
4160    }
4161
4162    #[test]
4163    fn test_div_columns_i64() {
4164        let Some(provider) = create_arith_test_provider() else {
4165            eprintln!("Skipping test: no CUDA device available");
4166            return;
4167        };
4168
4169        let a = create_i64_buffer(&provider, &[100, 200, 300, 400]);
4170        let b = create_i64_buffer(&provider, &[10, 20, 30, 40]);
4171
4172        let result = provider.div_columns(&a, &b).unwrap();
4173        let values = provider.download_column::<i64>(&result, 0).unwrap();
4174
4175        assert_eq!(values, vec![10, 10, 10, 10]);
4176    }
4177
4178    #[test]
4179    fn test_div_columns_by_zero() {
4180        let Some(provider) = create_arith_test_provider() else {
4181            eprintln!("Skipping test: no CUDA device available");
4182            return;
4183        };
4184
4185        let a = create_i64_buffer(&provider, &[10, 20, 30]);
4186        let b = create_i64_buffer(&provider, &[2, 0, 3]); // Note: division by zero
4187
4188        let result = provider.div_columns(&a, &b).unwrap();
4189        let values = provider.download_column::<i64>(&result, 0).unwrap();
4190
4191        // Division by zero returns i64::MAX
4192        assert_eq!(values, vec![5, i64::MAX, 10]);
4193    }
4194
4195    #[test]
4196    fn test_mod_columns_i64() {
4197        let Some(provider) = create_arith_test_provider() else {
4198            eprintln!("Skipping test: no CUDA device available");
4199            return;
4200        };
4201
4202        let a = create_i64_buffer(&provider, &[17, 23, 100, 7]);
4203        let b = create_i64_buffer(&provider, &[5, 7, 30, 3]);
4204
4205        let result = provider.mod_columns(&a, &b).unwrap();
4206        let values = provider.download_column::<i64>(&result, 0).unwrap();
4207
4208        assert_eq!(values, vec![2, 2, 10, 1]);
4209    }
4210
4211    #[test]
4212    fn test_mod_columns_by_zero() {
4213        let Some(provider) = create_arith_test_provider() else {
4214            eprintln!("Skipping test: no CUDA device available");
4215            return;
4216        };
4217
4218        let a = create_i64_buffer(&provider, &[10, 20]);
4219        let b = create_i64_buffer(&provider, &[3, 0]); // Note: mod by zero
4220
4221        let result = provider.mod_columns(&a, &b).unwrap();
4222        let values = provider.download_column::<i64>(&result, 0).unwrap();
4223
4224        // Mod by zero returns 0
4225        assert_eq!(values, vec![1, 0]);
4226    }
4227
4228    #[test]
4229    fn test_abs_column_i64() {
4230        let Some(provider) = create_arith_test_provider() else {
4231            eprintln!("Skipping test: no CUDA device available");
4232            return;
4233        };
4234
4235        let a = create_i64_buffer(&provider, &[-5, 10, -15, 20, 0]);
4236
4237        let result = provider.abs_column(&a).unwrap();
4238        let values = provider.download_column::<i64>(&result, 0).unwrap();
4239
4240        assert_eq!(values, vec![5, 10, 15, 20, 0]);
4241    }
4242
4243    #[test]
4244    fn test_min_columns_i64() {
4245        let Some(provider) = create_arith_test_provider() else {
4246            eprintln!("Skipping test: no CUDA device available");
4247            return;
4248        };
4249
4250        let a = create_i64_buffer(&provider, &[5, 10, 15, 20]);
4251        let b = create_i64_buffer(&provider, &[3, 12, 10, 25]);
4252
4253        let result = provider.min_columns(&a, &b).unwrap();
4254        let values = provider.download_column::<i64>(&result, 0).unwrap();
4255
4256        assert_eq!(values, vec![3, 10, 10, 20]);
4257    }
4258
4259    #[test]
4260    fn test_max_columns_i64() {
4261        let Some(provider) = create_arith_test_provider() else {
4262            eprintln!("Skipping test: no CUDA device available");
4263            return;
4264        };
4265
4266        let a = create_i64_buffer(&provider, &[5, 10, 15, 20]);
4267        let b = create_i64_buffer(&provider, &[3, 12, 10, 25]);
4268
4269        let result = provider.max_columns(&a, &b).unwrap();
4270        let values = provider.download_column::<i64>(&result, 0).unwrap();
4271
4272        assert_eq!(values, vec![5, 12, 15, 25]);
4273    }
4274
4275    #[test]
4276    fn test_add_columns_f64() {
4277        let Some(provider) = create_arith_test_provider() else {
4278            eprintln!("Skipping test: no CUDA device available");
4279            return;
4280        };
4281
4282        let a = create_f64_buffer(&provider, &[1.5, 2.5, 3.5]);
4283        let b = create_f64_buffer(&provider, &[0.5, 1.5, 2.5]);
4284
4285        let result = provider.add_columns(&a, &b).unwrap();
4286        let values = provider.download_column::<f64>(&result, 0).unwrap();
4287
4288        assert_eq!(values, vec![2.0, 4.0, 6.0]);
4289    }
4290
4291    #[test]
4292    fn test_mul_columns_f64() {
4293        let Some(provider) = create_arith_test_provider() else {
4294            eprintln!("Skipping test: no CUDA device available");
4295            return;
4296        };
4297
4298        let a = create_f64_buffer(&provider, &[2.0, 3.0, 4.0]);
4299        let b = create_f64_buffer(&provider, &[1.5, 2.0, 2.5]);
4300
4301        let result = provider.mul_columns(&a, &b).unwrap();
4302        let values = provider.download_column::<f64>(&result, 0).unwrap();
4303
4304        assert_eq!(values, vec![3.0, 6.0, 10.0]);
4305    }
4306
4307    #[test]
4308    fn test_div_columns_f64_by_zero() {
4309        let Some(provider) = create_arith_test_provider() else {
4310            eprintln!("Skipping test: no CUDA device available");
4311            return;
4312        };
4313
4314        let a = create_f64_buffer(&provider, &[1.0, -1.0, 0.0]);
4315        let b = create_f64_buffer(&provider, &[0.0, 0.0, 0.0]);
4316
4317        let result = provider.div_columns(&a, &b).unwrap();
4318        let values = provider.download_column::<f64>(&result, 0).unwrap();
4319
4320        // IEEE 754: 1.0/0.0 = Inf, -1.0/0.0 = -Inf, 0.0/0.0 = NaN
4321        assert!(values[0].is_infinite() && values[0].is_sign_positive());
4322        assert!(values[1].is_infinite() && values[1].is_sign_negative());
4323        assert!(values[2].is_nan());
4324    }
4325
4326    #[test]
4327    fn test_pow_columns() {
4328        let Some(provider) = create_arith_test_provider() else {
4329            eprintln!("Skipping test: no CUDA device available");
4330            return;
4331        };
4332
4333        let base = create_i64_buffer(&provider, &[2, 3, 4, 5]);
4334        let exp = create_i64_buffer(&provider, &[3, 2, 2, 1]);
4335
4336        let result = provider.pow_columns(&base, &exp).unwrap();
4337        let values = provider.download_column::<f64>(&result, 0).unwrap();
4338
4339        // pow always returns f64
4340        assert_eq!(values, vec![8.0, 9.0, 16.0, 5.0]);
4341    }
4342
4343    #[test]
4344    fn test_pow_columns_fractional_exp() {
4345        let Some(provider) = create_arith_test_provider() else {
4346            eprintln!("Skipping test: no CUDA device available");
4347            return;
4348        };
4349
4350        let base = create_f64_buffer(&provider, &[4.0, 9.0, 27.0]);
4351        let exp = create_f64_buffer(&provider, &[0.5, 0.5, 1.0 / 3.0]);
4352
4353        let result = provider.pow_columns(&base, &exp).unwrap();
4354        let values = provider.download_column::<f64>(&result, 0).unwrap();
4355
4356        // sqrt(4) = 2, sqrt(9) = 3, cbrt(27) = 3
4357        assert!((values[0] - 2.0).abs() < 1e-10);
4358        assert!((values[1] - 3.0).abs() < 1e-10);
4359        assert!((values[2] - 3.0).abs() < 1e-10);
4360    }
4361
4362    #[test]
4363    fn test_cast_i64_to_f64() {
4364        let Some(provider) = create_arith_test_provider() else {
4365            eprintln!("Skipping test: no CUDA device available");
4366            return;
4367        };
4368
4369        let a = create_i64_buffer(&provider, &[1, 2, 3, 4, 5]);
4370
4371        let result = provider.cast_column(&a, ScalarType::F64).unwrap();
4372        let values = provider.download_column::<f64>(&result, 0).unwrap();
4373
4374        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
4375    }
4376
4377    #[test]
4378    fn test_cast_f64_to_i64() {
4379        let Some(provider) = create_arith_test_provider() else {
4380            eprintln!("Skipping test: no CUDA device available");
4381            return;
4382        };
4383
4384        let a = create_f64_buffer(&provider, &[1.9, 2.1, 3.5, 4.0, 5.7]);
4385
4386        let result = provider.cast_column(&a, ScalarType::I64).unwrap();
4387        let values = provider.download_column::<i64>(&result, 0).unwrap();
4388
4389        // Truncation towards zero
4390        assert_eq!(values, vec![1, 2, 3, 4, 5]);
4391    }
4392
4393    #[test]
4394    fn test_cast_i64_to_i32() {
4395        let Some(provider) = create_arith_test_provider() else {
4396            eprintln!("Skipping test: no CUDA device available");
4397            return;
4398        };
4399
4400        let a = create_i64_buffer(&provider, &[1, 2, 3, 100, 200]);
4401
4402        let result = provider.cast_column(&a, ScalarType::I32).unwrap();
4403        let values = provider.download_column::<i32>(&result, 0).unwrap();
4404
4405        assert_eq!(values, vec![1, 2, 3, 100, 200]);
4406    }
4407
4408    #[test]
4409    fn test_arithmetic_row_count_mismatch() {
4410        let Some(provider) = create_arith_test_provider() else {
4411            eprintln!("Skipping test: no CUDA device available");
4412            return;
4413        };
4414
4415        let a = create_i64_buffer(&provider, &[1, 2, 3]);
4416        let b = create_i64_buffer(&provider, &[1, 2]); // Different size
4417
4418        let result = provider.add_columns(&a, &b);
4419        assert!(result.is_err());
4420        let err = result.err().unwrap();
4421        assert!(err.to_string().contains("Row count mismatch"));
4422    }
4423
4424    #[test]
4425    fn test_arithmetic_empty_buffers() {
4426        let Some(provider) = create_arith_test_provider() else {
4427            eprintln!("Skipping test: no CUDA device available");
4428            return;
4429        };
4430
4431        let a = create_i64_buffer(&provider, &[]);
4432        let b = create_i64_buffer(&provider, &[]);
4433
4434        let result = provider.add_columns(&a, &b).unwrap();
4435        let values = provider.download_column::<i64>(&result, 0).unwrap();
4436
4437        assert_eq!(values, Vec::<i64>::new());
4438    }
4439
4440    #[test]
4441    fn test_wrapping_arithmetic_overflow() {
4442        let Some(provider) = create_arith_test_provider() else {
4443            eprintln!("Skipping test: no CUDA device available");
4444            return;
4445        };
4446
4447        let a = create_i64_buffer(&provider, &[i64::MAX, i64::MIN]);
4448        let b = create_i64_buffer(&provider, &[1, -1]);
4449
4450        // Addition should wrap
4451        let add_result = provider.add_columns(&a, &b).unwrap();
4452        let add_values = provider.download_column::<i64>(&add_result, 0).unwrap();
4453        assert_eq!(add_values[0], i64::MIN); // MAX + 1 wraps to MIN
4454        assert_eq!(add_values[1], i64::MAX); // MIN - 1 wraps to MAX
4455    }
4456
4457    #[test]
4458    fn test_abs_column_f64() {
4459        let Some(provider) = create_arith_test_provider() else {
4460            eprintln!("Skipping test: no CUDA device available");
4461            return;
4462        };
4463
4464        let a = create_f64_buffer(&provider, &[-1.5, 2.5, -3.5, 0.0]);
4465
4466        let result = provider.abs_column(&a).unwrap();
4467        let values = provider.download_column::<f64>(&result, 0).unwrap();
4468
4469        assert_eq!(values, vec![1.5, 2.5, 3.5, 0.0]);
4470    }
4471
4472    #[test]
4473    fn test_min_max_columns_f64() {
4474        let Some(provider) = create_arith_test_provider() else {
4475            eprintln!("Skipping test: no CUDA device available");
4476            return;
4477        };
4478
4479        let a = create_f64_buffer(&provider, &[1.5, 5.0, 3.0]);
4480        let b = create_f64_buffer(&provider, &[2.0, 3.0, 4.0]);
4481
4482        let min_result = provider.min_columns(&a, &b).unwrap();
4483        let min_values = provider.download_column::<f64>(&min_result, 0).unwrap();
4484        assert_eq!(min_values, vec![1.5, 3.0, 3.0]);
4485
4486        let max_result = provider.max_columns(&a, &b).unwrap();
4487        let max_values = provider.download_column::<f64>(&max_result, 0).unwrap();
4488        assert_eq!(max_values, vec![2.0, 5.0, 4.0]);
4489    }
4490}