Skip to main content

xlog_cuda/
type_seam.rs

//! GpuScalar — marker trait for Rust scalar types that round-trip through GPU column storage.
//!
//! The trait is `pub` because external crates call turbofish generics bounded by it
//! (e.g. `provider.download_column::<u32>()`), and Rust's `private_bounds` lint
//! requires trait bounds on pub functions to be pub. However, the trait is **sealed**:
//! external crates cannot add new implementations.
//!
//! # Bool encoding
//!
//! Write encoding (H2D): canonical `0x00` = false, `0x01` = true.
//! Read decoding (D2H): `0x00` = false, any nonzero byte = true.
//!
//! The asymmetry is intentional: we always write canonical values, but tolerate
//! non-canonical GPU output during reads to match existing provider behavior.
15
/// Sealing module: `Sealed` lives in a private module, so crates outside `xlog-cuda`
/// cannot name it and therefore cannot implement `GpuScalar`.
mod sealed {
    /// Supertrait of `GpuScalar`; implemented only for the scalar types listed below.
    pub trait Sealed {}

    /// Emits an empty `impl Sealed for T {}` for each listed type.
    macro_rules! seal {
        ($($ty:ty),+ $(,)?) => {
            $(impl Sealed for $ty {})+
        };
    }

    seal!(u8, u32, u64, i32, i64, f32, f64, bool);
}
28
29/// Marker trait: a Rust scalar type that can round-trip through GPU column storage.
30///
31/// Requires `cudarc::driver::DeviceRepr` + known byte width + little-endian serialization.
32///
33/// This trait is **sealed** — it cannot be implemented outside `xlog-cuda`.
34/// The fixed set of implementations covers all GPU-compatible scalar types.
35pub trait GpuScalar:
36    sealed::Sealed + crate::cuda_compat::KernelScalar + Copy + Send + 'static
37{
38    /// Size in bytes of this scalar type.
39    const BYTE_WIDTH: usize;
40
41    /// Deserialize from a little-endian byte slice.
42    /// The slice length must equal `BYTE_WIDTH`.
43    fn from_le_bytes(bytes: &[u8]) -> Self;
44
45    /// Serialize into a little-endian byte buffer.
46    /// The buffer length must equal `BYTE_WIDTH`.
47    fn to_le_bytes_into(self, buf: &mut [u8]);
48
49    /// Kernel function name for const-compare mask generation.
50    fn filter_compare_kernel() -> &'static str;
51
52    /// Kernel function name for column-column comparison mask.
53    fn compare_col_kernel() -> &'static str;
54
55    /// ScalarType variants accepted for this type in filter/compare operations.
56    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType];
57
58    /// Optional fused compare+scan kernel (phase 1). Only u32 and f64 have optimized
59    /// fused-scan paths. Returns None for types using the mask+compact path.
60    fn filter_scan_phase1_kernel() -> Option<&'static str> {
61        None
62    }
63}
64
65impl GpuScalar for u8 {
66    const BYTE_WIDTH: usize = 1;
67    fn from_le_bytes(bytes: &[u8]) -> Self {
68        bytes[0]
69    }
70    fn to_le_bytes_into(self, buf: &mut [u8]) {
71        buf[0] = self;
72    }
73    fn filter_compare_kernel() -> &'static str {
74        "filter_compare_u8"
75    }
76    fn compare_col_kernel() -> &'static str {
77        "filter_compare_u8_col"
78    }
79    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
80        &[xlog_core::ScalarType::Bool]
81    }
82}
83
84impl GpuScalar for u32 {
85    const BYTE_WIDTH: usize = 4;
86    fn from_le_bytes(bytes: &[u8]) -> Self {
87        u32::from_le_bytes(bytes.try_into().unwrap())
88    }
89    fn to_le_bytes_into(self, buf: &mut [u8]) {
90        buf.copy_from_slice(&self.to_le_bytes());
91    }
92    fn filter_compare_kernel() -> &'static str {
93        "filter_compare_u32"
94    }
95    fn compare_col_kernel() -> &'static str {
96        "filter_compare_u32_col"
97    }
98    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
99        &[xlog_core::ScalarType::U32, xlog_core::ScalarType::Symbol]
100    }
101    fn filter_scan_phase1_kernel() -> Option<&'static str> {
102        Some("filter_compare_u32_scan_phase1")
103    }
104}
105
106impl GpuScalar for u64 {
107    const BYTE_WIDTH: usize = 8;
108    fn from_le_bytes(bytes: &[u8]) -> Self {
109        u64::from_le_bytes(bytes.try_into().unwrap())
110    }
111    fn to_le_bytes_into(self, buf: &mut [u8]) {
112        buf.copy_from_slice(&self.to_le_bytes());
113    }
114    fn filter_compare_kernel() -> &'static str {
115        "filter_compare_u64"
116    }
117    fn compare_col_kernel() -> &'static str {
118        "filter_compare_u64_col"
119    }
120    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
121        &[xlog_core::ScalarType::U64]
122    }
123}
124
125impl GpuScalar for i32 {
126    const BYTE_WIDTH: usize = 4;
127    fn from_le_bytes(bytes: &[u8]) -> Self {
128        i32::from_le_bytes(bytes.try_into().unwrap())
129    }
130    fn to_le_bytes_into(self, buf: &mut [u8]) {
131        buf.copy_from_slice(&self.to_le_bytes());
132    }
133    fn filter_compare_kernel() -> &'static str {
134        "filter_compare_i32"
135    }
136    fn compare_col_kernel() -> &'static str {
137        "filter_compare_i32_col"
138    }
139    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
140        &[xlog_core::ScalarType::I32]
141    }
142}
143
144impl GpuScalar for i64 {
145    const BYTE_WIDTH: usize = 8;
146    fn from_le_bytes(bytes: &[u8]) -> Self {
147        i64::from_le_bytes(bytes.try_into().unwrap())
148    }
149    fn to_le_bytes_into(self, buf: &mut [u8]) {
150        buf.copy_from_slice(&self.to_le_bytes());
151    }
152    fn filter_compare_kernel() -> &'static str {
153        "filter_compare_i64"
154    }
155    fn compare_col_kernel() -> &'static str {
156        "filter_compare_i64_col"
157    }
158    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
159        &[xlog_core::ScalarType::I64]
160    }
161}
162
163impl GpuScalar for f32 {
164    const BYTE_WIDTH: usize = 4;
165    fn from_le_bytes(bytes: &[u8]) -> Self {
166        f32::from_le_bytes(bytes.try_into().unwrap())
167    }
168    fn to_le_bytes_into(self, buf: &mut [u8]) {
169        buf.copy_from_slice(&self.to_le_bytes());
170    }
171    fn filter_compare_kernel() -> &'static str {
172        "filter_compare_f32"
173    }
174    fn compare_col_kernel() -> &'static str {
175        "filter_compare_f32_col"
176    }
177    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
178        &[xlog_core::ScalarType::F32]
179    }
180}
181
182impl GpuScalar for f64 {
183    const BYTE_WIDTH: usize = 8;
184    fn from_le_bytes(bytes: &[u8]) -> Self {
185        f64::from_le_bytes(bytes.try_into().unwrap())
186    }
187    fn to_le_bytes_into(self, buf: &mut [u8]) {
188        buf.copy_from_slice(&self.to_le_bytes());
189    }
190    fn filter_compare_kernel() -> &'static str {
191        "filter_compare_f64"
192    }
193    fn compare_col_kernel() -> &'static str {
194        "filter_compare_f64_col"
195    }
196    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
197        &[xlog_core::ScalarType::F64]
198    }
199    fn filter_scan_phase1_kernel() -> Option<&'static str> {
200        Some("filter_compare_f64_scan_phase1")
201    }
202}
203
204/// Bool encoding:
205/// - Write (H2D): `0x00` = false, `0x01` = true (canonical).
206/// - Read (D2H): `0x00` = false, nonzero = true (lenient, matches the D2H bool decoding path in provider/transfer.rs).
207impl GpuScalar for bool {
208    const BYTE_WIDTH: usize = 1;
209
210    fn from_le_bytes(bytes: &[u8]) -> Self {
211        // Lenient read: any nonzero byte is true (matches existing D2H behavior).
212        bytes[0] != 0
213    }
214
215    fn to_le_bytes_into(self, buf: &mut [u8]) {
216        // Canonical write: 0x00 or 0x01.
217        buf[0] = if self { 1 } else { 0 };
218    }
219
220    // Bool uses the u8 kernel on the GPU side.
221    fn filter_compare_kernel() -> &'static str {
222        "filter_compare_u8"
223    }
224    fn compare_col_kernel() -> &'static str {
225        "filter_compare_u8_col"
226    }
227    fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
228        &[xlog_core::ScalarType::Bool]
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: serialize `val` into a `BYTE_WIDTH` buffer and assert it decodes back unchanged.
    fn roundtrip<T: GpuScalar + PartialEq + std::fmt::Debug>(val: T) {
        let mut buf = vec![0u8; T::BYTE_WIDTH];
        val.to_le_bytes_into(&mut buf);
        let recovered = T::from_le_bytes(&buf);
        assert_eq!(recovered, val);
    }

    /// Helper: serialize `val` and return the produced bytes, for exact layout checks.
    fn encode<T: GpuScalar>(val: T) -> Vec<u8> {
        let mut buf = vec![0u8; T::BYTE_WIDTH];
        val.to_le_bytes_into(&mut buf);
        buf
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u8() {
        roundtrip(42u8);
        roundtrip(0u8);
        roundtrip(255u8);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u32() {
        roundtrip(0u32);
        roundtrip(42u32);
        roundtrip(u32::MAX);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u64() {
        roundtrip(0u64);
        roundtrip(42u64);
        roundtrip(u64::MAX);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_i32() {
        roundtrip(0i32);
        roundtrip(-1i32);
        roundtrip(i32::MAX);
        roundtrip(i32::MIN);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_i64() {
        roundtrip(0i64);
        roundtrip(-1i64);
        roundtrip(i64::MAX);
        roundtrip(i64::MIN);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_f32() {
        roundtrip(0.0f32);
        roundtrip(-1.5f32);
        roundtrip(f32::INFINITY);
        roundtrip(f32::NEG_INFINITY);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_f64() {
        roundtrip(0.0f64);
        roundtrip(-1.5f64);
        roundtrip(f64::INFINITY);
        roundtrip(f64::NEG_INFINITY);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_bool() {
        roundtrip(true);
        roundtrip(false);
    }

    #[test]
    fn test_le_byte_layout() {
        // `roundtrip` would still pass if both directions used the wrong byte order,
        // so pin the exact little-endian layout for every multi-byte type.
        assert_eq!(encode(0x0102_0304u32), [0x04u8, 0x03, 0x02, 0x01]);
        assert_eq!(<u32 as GpuScalar>::from_le_bytes(&[0x04, 0x03, 0x02, 0x01]), 0x0102_0304);

        assert_eq!(
            encode(0x0102_0304_0506_0708u64),
            [0x08u8, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]
        );
        assert_eq!(encode(-2i32), [0xFEu8, 0xFF, 0xFF, 0xFF]);
        assert_eq!(
            encode(-2i64),
            [0xFEu8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]
        );
        assert_eq!(encode(1.0f32), [0x00u8, 0x00, 0x80, 0x3F]);
        assert_eq!(
            encode(1.0f64),
            [0x00u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F]
        );
    }

    #[test]
    fn test_float_roundtrip_is_bit_exact() {
        // NaN != NaN and -0.0 == 0.0, so `roundtrip` cannot check these; compare bits.
        for v in [f32::NAN, -0.0f32, f32::MIN_POSITIVE] {
            let back = <f32 as GpuScalar>::from_le_bytes(&encode(v));
            assert_eq!(back.to_bits(), v.to_bits());
        }
        for v in [f64::NAN, -0.0f64, f64::MIN_POSITIVE] {
            let back = <f64 as GpuScalar>::from_le_bytes(&encode(v));
            assert_eq!(back.to_bits(), v.to_bits());
        }
    }

    #[test]
    #[should_panic]
    fn test_from_le_bytes_panics_on_short_slice() {
        // Multi-byte impls require `bytes.len() == BYTE_WIDTH`.
        let _ = <u32 as GpuScalar>::from_le_bytes(&[0x00, 0x01, 0x02]);
    }

    #[test]
    #[should_panic]
    fn test_to_le_bytes_into_panics_on_wrong_buffer_len() {
        // Multi-byte impls require `buf.len() == BYTE_WIDTH`.
        let mut buf = [0u8; 4];
        7u64.to_le_bytes_into(&mut buf);
    }

    #[test]
    fn test_bool_canonical_write() {
        let mut buf = [0xFFu8];
        false.to_le_bytes_into(&mut buf);
        assert_eq!(buf[0], 0x00, "false must write canonical 0x00");

        true.to_le_bytes_into(&mut buf);
        assert_eq!(buf[0], 0x01, "true must write canonical 0x01");
    }

    #[test]
    fn test_bool_lenient_read() {
        // Any nonzero byte reads as true, matching the D2H bool decoding path in
        // provider/transfer.rs.
        assert!(!bool::from_le_bytes(&[0x00]));
        assert!(bool::from_le_bytes(&[0x01]));
        assert!(bool::from_le_bytes(&[0x02]));
        assert!(bool::from_le_bytes(&[0xFF]));
    }

    #[test]
    fn test_byte_width_consistency() {
        assert_eq!(u8::BYTE_WIDTH, std::mem::size_of::<u8>());
        assert_eq!(u32::BYTE_WIDTH, std::mem::size_of::<u32>());
        assert_eq!(u64::BYTE_WIDTH, std::mem::size_of::<u64>());
        assert_eq!(i32::BYTE_WIDTH, std::mem::size_of::<i32>());
        assert_eq!(i64::BYTE_WIDTH, std::mem::size_of::<i64>());
        assert_eq!(f32::BYTE_WIDTH, std::mem::size_of::<f32>());
        assert_eq!(f64::BYTE_WIDTH, std::mem::size_of::<f64>());
        assert_eq!(bool::BYTE_WIDTH, std::mem::size_of::<bool>());
    }

    #[test]
    fn test_filter_kernel_names_non_empty() {
        // Every GpuScalar impl must return non-empty kernel names.
        assert!(!u8::filter_compare_kernel().is_empty());
        assert!(!u8::compare_col_kernel().is_empty());
        assert!(!u32::filter_compare_kernel().is_empty());
        assert!(!u32::compare_col_kernel().is_empty());
        assert!(!u64::filter_compare_kernel().is_empty());
        assert!(!u64::compare_col_kernel().is_empty());
        assert!(!i32::filter_compare_kernel().is_empty());
        assert!(!i32::compare_col_kernel().is_empty());
        assert!(!i64::filter_compare_kernel().is_empty());
        assert!(!i64::compare_col_kernel().is_empty());
        assert!(!f32::filter_compare_kernel().is_empty());
        assert!(!f32::compare_col_kernel().is_empty());
        assert!(!f64::filter_compare_kernel().is_empty());
        assert!(!f64::compare_col_kernel().is_empty());
        assert!(!bool::filter_compare_kernel().is_empty());
        assert!(!bool::compare_col_kernel().is_empty());
    }

    #[test]
    fn test_allowed_scalar_types_non_empty() {
        assert!(!u8::allowed_scalar_types().is_empty());
        assert!(!u32::allowed_scalar_types().is_empty());
        assert!(!u64::allowed_scalar_types().is_empty());
        assert!(!i32::allowed_scalar_types().is_empty());
        assert!(!i64::allowed_scalar_types().is_empty());
        assert!(!f32::allowed_scalar_types().is_empty());
        assert!(!f64::allowed_scalar_types().is_empty());
        assert!(!bool::allowed_scalar_types().is_empty());
    }

    #[test]
    fn test_fused_scan_kernel_only_u32_and_f64() {
        // Only u32 and f64 have fused-scan phase1 kernels.
        assert!(u32::filter_scan_phase1_kernel().is_some());
        assert!(f64::filter_scan_phase1_kernel().is_some());
        // All others return None.
        assert!(u8::filter_scan_phase1_kernel().is_none());
        assert!(u64::filter_scan_phase1_kernel().is_none());
        assert!(i32::filter_scan_phase1_kernel().is_none());
        assert!(i64::filter_scan_phase1_kernel().is_none());
        assert!(f32::filter_scan_phase1_kernel().is_none());
        assert!(bool::filter_scan_phase1_kernel().is_none());
    }

    #[test]
    fn test_bool_and_u8_share_gpu_kernels() {
        // Bool is stored as u8 on the GPU, so both types share the same kernels.
        assert_eq!(u8::filter_compare_kernel(), bool::filter_compare_kernel());
        assert_eq!(u8::compare_col_kernel(), bool::compare_col_kernel());
    }

    #[test]
    fn test_u32_allowed_includes_symbol() {
        // u32 filter must accept both U32 and Symbol columns.
        let allowed = u32::allowed_scalar_types();
        assert!(allowed.contains(&xlog_core::ScalarType::U32));
        assert!(allowed.contains(&xlog_core::ScalarType::Symbol));
    }
}