/// Private sealing module: `Sealed` is reachable only from inside this crate,
/// so downstream crates cannot implement `GpuScalar` for their own types.
mod sealed {
    /// Marker supertrait of `GpuScalar`; implemented exactly for the
    /// primitive scalars listed in the `seal!` invocation below.
    pub trait Sealed {}

    // Emits one empty `impl Sealed for T {}` per listed type.
    macro_rules! seal {
        ($($ty:ty),*) => {
            $(impl Sealed for $ty {})*
        };
    }

    seal!(u8, u32, u64, i32, i64, f32, f64, bool);
}
28
29pub trait GpuScalar:
36 sealed::Sealed + crate::cuda_compat::KernelScalar + Copy + Send + 'static
37{
38 const BYTE_WIDTH: usize;
40
41 fn from_le_bytes(bytes: &[u8]) -> Self;
44
45 fn to_le_bytes_into(self, buf: &mut [u8]);
48
49 fn filter_compare_kernel() -> &'static str;
51
52 fn compare_col_kernel() -> &'static str;
54
55 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType];
57
58 fn filter_scan_phase1_kernel() -> Option<&'static str> {
61 None
62 }
63}
64
65impl GpuScalar for u8 {
66 const BYTE_WIDTH: usize = 1;
67 fn from_le_bytes(bytes: &[u8]) -> Self {
68 bytes[0]
69 }
70 fn to_le_bytes_into(self, buf: &mut [u8]) {
71 buf[0] = self;
72 }
73 fn filter_compare_kernel() -> &'static str {
74 "filter_compare_u8"
75 }
76 fn compare_col_kernel() -> &'static str {
77 "filter_compare_u8_col"
78 }
79 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
80 &[xlog_core::ScalarType::Bool]
81 }
82}
83
84impl GpuScalar for u32 {
85 const BYTE_WIDTH: usize = 4;
86 fn from_le_bytes(bytes: &[u8]) -> Self {
87 u32::from_le_bytes(bytes.try_into().unwrap())
88 }
89 fn to_le_bytes_into(self, buf: &mut [u8]) {
90 buf.copy_from_slice(&self.to_le_bytes());
91 }
92 fn filter_compare_kernel() -> &'static str {
93 "filter_compare_u32"
94 }
95 fn compare_col_kernel() -> &'static str {
96 "filter_compare_u32_col"
97 }
98 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
99 &[xlog_core::ScalarType::U32, xlog_core::ScalarType::Symbol]
100 }
101 fn filter_scan_phase1_kernel() -> Option<&'static str> {
102 Some("filter_compare_u32_scan_phase1")
103 }
104}
105
106impl GpuScalar for u64 {
107 const BYTE_WIDTH: usize = 8;
108 fn from_le_bytes(bytes: &[u8]) -> Self {
109 u64::from_le_bytes(bytes.try_into().unwrap())
110 }
111 fn to_le_bytes_into(self, buf: &mut [u8]) {
112 buf.copy_from_slice(&self.to_le_bytes());
113 }
114 fn filter_compare_kernel() -> &'static str {
115 "filter_compare_u64"
116 }
117 fn compare_col_kernel() -> &'static str {
118 "filter_compare_u64_col"
119 }
120 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
121 &[xlog_core::ScalarType::U64]
122 }
123}
124
125impl GpuScalar for i32 {
126 const BYTE_WIDTH: usize = 4;
127 fn from_le_bytes(bytes: &[u8]) -> Self {
128 i32::from_le_bytes(bytes.try_into().unwrap())
129 }
130 fn to_le_bytes_into(self, buf: &mut [u8]) {
131 buf.copy_from_slice(&self.to_le_bytes());
132 }
133 fn filter_compare_kernel() -> &'static str {
134 "filter_compare_i32"
135 }
136 fn compare_col_kernel() -> &'static str {
137 "filter_compare_i32_col"
138 }
139 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
140 &[xlog_core::ScalarType::I32]
141 }
142}
143
144impl GpuScalar for i64 {
145 const BYTE_WIDTH: usize = 8;
146 fn from_le_bytes(bytes: &[u8]) -> Self {
147 i64::from_le_bytes(bytes.try_into().unwrap())
148 }
149 fn to_le_bytes_into(self, buf: &mut [u8]) {
150 buf.copy_from_slice(&self.to_le_bytes());
151 }
152 fn filter_compare_kernel() -> &'static str {
153 "filter_compare_i64"
154 }
155 fn compare_col_kernel() -> &'static str {
156 "filter_compare_i64_col"
157 }
158 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
159 &[xlog_core::ScalarType::I64]
160 }
161}
162
163impl GpuScalar for f32 {
164 const BYTE_WIDTH: usize = 4;
165 fn from_le_bytes(bytes: &[u8]) -> Self {
166 f32::from_le_bytes(bytes.try_into().unwrap())
167 }
168 fn to_le_bytes_into(self, buf: &mut [u8]) {
169 buf.copy_from_slice(&self.to_le_bytes());
170 }
171 fn filter_compare_kernel() -> &'static str {
172 "filter_compare_f32"
173 }
174 fn compare_col_kernel() -> &'static str {
175 "filter_compare_f32_col"
176 }
177 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
178 &[xlog_core::ScalarType::F32]
179 }
180}
181
182impl GpuScalar for f64 {
183 const BYTE_WIDTH: usize = 8;
184 fn from_le_bytes(bytes: &[u8]) -> Self {
185 f64::from_le_bytes(bytes.try_into().unwrap())
186 }
187 fn to_le_bytes_into(self, buf: &mut [u8]) {
188 buf.copy_from_slice(&self.to_le_bytes());
189 }
190 fn filter_compare_kernel() -> &'static str {
191 "filter_compare_f64"
192 }
193 fn compare_col_kernel() -> &'static str {
194 "filter_compare_f64_col"
195 }
196 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
197 &[xlog_core::ScalarType::F64]
198 }
199 fn filter_scan_phase1_kernel() -> Option<&'static str> {
200 Some("filter_compare_f64_scan_phase1")
201 }
202}
203
204impl GpuScalar for bool {
208 const BYTE_WIDTH: usize = 1;
209
210 fn from_le_bytes(bytes: &[u8]) -> Self {
211 bytes[0] != 0
213 }
214
215 fn to_le_bytes_into(self, buf: &mut [u8]) {
216 buf[0] = if self { 1 } else { 0 };
218 }
219
220 fn filter_compare_kernel() -> &'static str {
222 "filter_compare_u8"
223 }
224 fn compare_col_kernel() -> &'static str {
225 "filter_compare_u8_col"
226 }
227 fn allowed_scalar_types() -> &'static [xlog_core::ScalarType] {
228 &[xlog_core::ScalarType::Bool]
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `val` into a `T::BYTE_WIDTH`-byte scratch buffer and decodes it back.
    fn encode_decode<T: GpuScalar>(val: T) -> T {
        let mut buf = vec![0u8; T::BYTE_WIDTH];
        val.to_le_bytes_into(&mut buf);
        T::from_le_bytes(&buf)
    }

    /// Asserts that `val` survives an encode/decode round trip (by `PartialEq`).
    fn roundtrip<T: GpuScalar + PartialEq + std::fmt::Debug>(val: T) {
        assert_eq!(encode_decode(val), val);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u8() {
        roundtrip(42u8);
        roundtrip(0u8);
        roundtrip(255u8);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u32() {
        roundtrip(0u32);
        roundtrip(42u32);
        roundtrip(u32::MAX);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_u64() {
        roundtrip(0u64);
        roundtrip(42u64);
        roundtrip(u64::MAX);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_i32() {
        roundtrip(0i32);
        roundtrip(-1i32);
        roundtrip(i32::MAX);
        roundtrip(i32::MIN);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_i64() {
        roundtrip(0i64);
        roundtrip(-1i64);
        roundtrip(i64::MAX);
        roundtrip(i64::MIN);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_f32() {
        roundtrip(0.0f32);
        roundtrip(-1.5f32);
        roundtrip(f32::INFINITY);
        roundtrip(f32::NEG_INFINITY);
        roundtrip(f32::MAX);
        roundtrip(f32::MIN_POSITIVE);
    }

    #[test]
    fn test_gpu_scalar_roundtrip_f64() {
        roundtrip(0.0f64);
        roundtrip(-1.5f64);
        roundtrip(f64::INFINITY);
        roundtrip(f64::NEG_INFINITY);
        roundtrip(f64::MAX);
        roundtrip(f64::MIN_POSITIVE);
    }

    #[test]
    fn test_float_special_values_bit_exact() {
        // NaN != NaN and -0.0 == 0.0 under PartialEq, so compare raw bits.
        assert_eq!(encode_decode(f32::NAN).to_bits(), f32::NAN.to_bits());
        assert_eq!(encode_decode(f64::NAN).to_bits(), f64::NAN.to_bits());
        assert_eq!(encode_decode(-0.0f32).to_bits(), (-0.0f32).to_bits());
        assert_eq!(encode_decode(-0.0f64).to_bits(), (-0.0f64).to_bits());
    }

    #[test]
    fn test_encoding_is_little_endian() {
        let mut b4 = [0u8; 4];
        0x0102_0304u32.to_le_bytes_into(&mut b4);
        assert_eq!(b4, [0x04, 0x03, 0x02, 0x01]);
        // Fully-qualified: u32's inherent `from_le_bytes` would shadow the trait method.
        assert_eq!(<u32 as GpuScalar>::from_le_bytes(&b4), 0x0102_0304);

        (-2i32).to_le_bytes_into(&mut b4);
        assert_eq!(b4, [0xFE, 0xFF, 0xFF, 0xFF]);

        1.0f32.to_le_bytes_into(&mut b4);
        assert_eq!(b4, [0x00, 0x00, 0x80, 0x3F]);

        let mut b8 = [0u8; 8];
        0x0102_0304_0506_0708u64.to_le_bytes_into(&mut b8);
        assert_eq!(b8, [0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]);
        assert_eq!(
            <u64 as GpuScalar>::from_le_bytes(&b8),
            0x0102_0304_0506_0708
        );
    }

    #[test]
    fn test_gpu_scalar_roundtrip_bool() {
        roundtrip(true);
        roundtrip(false);
    }

    #[test]
    fn test_bool_canonical_write() {
        let mut buf = [0xFFu8];
        false.to_le_bytes_into(&mut buf);
        assert_eq!(buf[0], 0x00, "false must write canonical 0x00");

        true.to_le_bytes_into(&mut buf);
        assert_eq!(buf[0], 0x01, "true must write canonical 0x01");
    }

    #[test]
    fn test_bool_lenient_read() {
        assert!(!bool::from_le_bytes(&[0x00]));
        assert!(bool::from_le_bytes(&[0x01]));
        assert!(bool::from_le_bytes(&[0x02]));
        assert!(bool::from_le_bytes(&[0xFF]));
    }

    #[test]
    fn test_byte_width_consistency() {
        assert_eq!(u8::BYTE_WIDTH, std::mem::size_of::<u8>());
        assert_eq!(u32::BYTE_WIDTH, std::mem::size_of::<u32>());
        assert_eq!(u64::BYTE_WIDTH, std::mem::size_of::<u64>());
        assert_eq!(i32::BYTE_WIDTH, std::mem::size_of::<i32>());
        assert_eq!(i64::BYTE_WIDTH, std::mem::size_of::<i64>());
        assert_eq!(f32::BYTE_WIDTH, std::mem::size_of::<f32>());
        assert_eq!(f64::BYTE_WIDTH, std::mem::size_of::<f64>());
        assert_eq!(bool::BYTE_WIDTH, std::mem::size_of::<bool>());
    }

    #[test]
    fn test_filter_kernel_names_non_empty() {
        assert!(!u8::filter_compare_kernel().is_empty());
        assert!(!u8::compare_col_kernel().is_empty());
        assert!(!u32::filter_compare_kernel().is_empty());
        assert!(!u32::compare_col_kernel().is_empty());
        assert!(!u64::filter_compare_kernel().is_empty());
        assert!(!u64::compare_col_kernel().is_empty());
        assert!(!i32::filter_compare_kernel().is_empty());
        assert!(!i32::compare_col_kernel().is_empty());
        assert!(!i64::filter_compare_kernel().is_empty());
        assert!(!i64::compare_col_kernel().is_empty());
        assert!(!f32::filter_compare_kernel().is_empty());
        assert!(!f32::compare_col_kernel().is_empty());
        assert!(!f64::filter_compare_kernel().is_empty());
        assert!(!f64::compare_col_kernel().is_empty());
        assert!(!bool::filter_compare_kernel().is_empty());
        assert!(!bool::compare_col_kernel().is_empty());
    }

    #[test]
    fn test_col_kernel_names_follow_suffix_convention() {
        let pairs: [(&str, &str); 8] = [
            (u8::filter_compare_kernel(), u8::compare_col_kernel()),
            (u32::filter_compare_kernel(), u32::compare_col_kernel()),
            (u64::filter_compare_kernel(), u64::compare_col_kernel()),
            (i32::filter_compare_kernel(), i32::compare_col_kernel()),
            (i64::filter_compare_kernel(), i64::compare_col_kernel()),
            (f32::filter_compare_kernel(), f32::compare_col_kernel()),
            (f64::filter_compare_kernel(), f64::compare_col_kernel()),
            (bool::filter_compare_kernel(), bool::compare_col_kernel()),
        ];
        for &(scalar_kernel, col_kernel) in pairs.iter() {
            assert_eq!(col_kernel, format!("{}_col", scalar_kernel));
        }
    }

    #[test]
    fn test_allowed_scalar_types_non_empty() {
        assert!(!u8::allowed_scalar_types().is_empty());
        assert!(!u32::allowed_scalar_types().is_empty());
        assert!(!u64::allowed_scalar_types().is_empty());
        assert!(!i32::allowed_scalar_types().is_empty());
        assert!(!i64::allowed_scalar_types().is_empty());
        assert!(!f32::allowed_scalar_types().is_empty());
        assert!(!f64::allowed_scalar_types().is_empty());
        assert!(!bool::allowed_scalar_types().is_empty());
    }

    #[test]
    fn test_fused_scan_kernel_only_u32_and_f64() {
        assert!(u32::filter_scan_phase1_kernel().is_some());
        assert!(f64::filter_scan_phase1_kernel().is_some());
        assert!(u8::filter_scan_phase1_kernel().is_none());
        assert!(u64::filter_scan_phase1_kernel().is_none());
        assert!(i32::filter_scan_phase1_kernel().is_none());
        assert!(i64::filter_scan_phase1_kernel().is_none());
        assert!(f32::filter_scan_phase1_kernel().is_none());
        assert!(bool::filter_scan_phase1_kernel().is_none());
    }

    #[test]
    fn test_fused_scan_kernel_names_extend_compare_kernel() {
        let u32_scan = u32::filter_scan_phase1_kernel().expect("u32 has a fused kernel");
        assert!(u32_scan.starts_with(u32::filter_compare_kernel()));
        let f64_scan = f64::filter_scan_phase1_kernel().expect("f64 has a fused kernel");
        assert!(f64_scan.starts_with(f64::filter_compare_kernel()));
    }

    #[test]
    fn test_bool_and_u8_share_gpu_kernels() {
        assert_eq!(u8::filter_compare_kernel(), bool::filter_compare_kernel());
        assert_eq!(u8::compare_col_kernel(), bool::compare_col_kernel());
    }

    #[test]
    fn test_u32_allowed_includes_symbol() {
        let allowed = u32::allowed_scalar_types();
        assert!(allowed.contains(&xlog_core::ScalarType::U32));
        assert!(allowed.contains(&xlog_core::ScalarType::Symbol));
    }
}