Skip to main content

xlog_cuda_tests/harness/
validators.rs

1//! CPU reference implementations for validating GPU results.
2
3/// CPU reference implementations for all kernel operations.
4pub mod reference {
5    use std::collections::HashMap;
6
7    /// Hash join returning (left_idx, right_idx) pairs.
8    pub fn hash_join_u32(left: &[u32], right: &[u32]) -> Vec<(usize, usize)> {
9        let mut result = Vec::new();
10        let right_map: HashMap<u32, Vec<usize>> =
11            right
12                .iter()
13                .enumerate()
14                .fold(HashMap::new(), |mut acc, (idx, val)| {
15                    acc.entry(*val).or_default().push(idx);
16                    acc
17                });
18
19        for (left_idx, left_val) in left.iter().enumerate() {
20            if let Some(right_indices) = right_map.get(left_val) {
21                for &right_idx in right_indices {
22                    result.push((left_idx, right_idx));
23                }
24            }
25        }
26        result
27    }
28
29    /// Multi-column hash join.
30    pub fn hash_join_multi(left_cols: &[&[u32]], right_cols: &[&[u32]]) -> Vec<(usize, usize)> {
31        if left_cols.is_empty() || right_cols.is_empty() {
32            return Vec::new();
33        }
34
35        let left_len = left_cols[0].len();
36        let right_len = right_cols[0].len();
37
38        // Build hash map on right side
39        let mut right_map: HashMap<Vec<u32>, Vec<usize>> = HashMap::new();
40        for i in 0..right_len {
41            let key: Vec<u32> = right_cols.iter().map(|col| col[i]).collect();
42            right_map.entry(key).or_default().push(i);
43        }
44
45        // Probe with left side
46        let mut result = Vec::new();
47        for i in 0..left_len {
48            let key: Vec<u32> = left_cols.iter().map(|col| col[i]).collect();
49            if let Some(right_indices) = right_map.get(&key) {
50                for &right_idx in right_indices {
51                    result.push((i, right_idx));
52                }
53            }
54        }
55        result
56    }
57
58    /// Semi join - returns left indices that have a match in right.
59    pub fn semi_join_u32(left: &[u32], right: &[u32]) -> Vec<usize> {
60        let right_set: std::collections::HashSet<u32> = right.iter().copied().collect();
61        left.iter()
62            .enumerate()
63            .filter(|(_, val)| right_set.contains(val))
64            .map(|(idx, _)| idx)
65            .collect()
66    }
67
68    /// Anti join - returns left indices that have NO match in right.
69    pub fn anti_join_u32(left: &[u32], right: &[u32]) -> Vec<usize> {
70        let right_set: std::collections::HashSet<u32> = right.iter().copied().collect();
71        left.iter()
72            .enumerate()
73            .filter(|(_, val)| !right_set.contains(val))
74            .map(|(idx, _)| idx)
75            .collect()
76    }
77
78    /// Filter by comparison operation.
79    pub fn filter_compare_u32(data: &[u32], op: CompareOp, val: u32) -> Vec<usize> {
80        data.iter()
81            .enumerate()
82            .filter(|(_, d)| op.apply(**d, val))
83            .map(|(idx, _)| idx)
84            .collect()
85    }
86
87    /// Filter by comparison operation for i64.
88    pub fn filter_compare_i64(data: &[i64], op: CompareOp, val: i64) -> Vec<usize> {
89        data.iter()
90            .enumerate()
91            .filter(|(_, d)| op.apply(**d, val))
92            .map(|(idx, _)| idx)
93            .collect()
94    }
95
96    /// Filter by comparison operation for f64.
97    pub fn filter_compare_f64(data: &[f64], op: CompareOp, val: f64) -> Vec<f64> {
98        data.iter()
99            .filter(|d| op.apply_f64(**d, val))
100            .copied()
101            .collect()
102    }
103
104    /// Compact by mask.
105    pub fn compact_by_mask<T: Copy>(data: &[T], mask: &[u8]) -> Vec<T> {
106        data.iter()
107            .zip(mask.iter())
108            .filter(|(_, m)| **m != 0)
109            .map(|(d, _)| *d)
110            .collect()
111    }
112
113    /// Stable radix sort returning (sorted_data, permutation).
114    pub fn radix_sort_u32(keys: &[u32]) -> (Vec<u32>, Vec<u32>) {
115        let mut indexed: Vec<(u32, u32)> = keys
116            .iter()
117            .enumerate()
118            .map(|(i, &k)| (k, i as u32))
119            .collect();
120
121        // Stable sort
122        indexed.sort_by_key(|(k, _)| *k);
123
124        let sorted: Vec<u32> = indexed.iter().map(|(k, _)| *k).collect();
125        let perm: Vec<u32> = indexed.iter().map(|(_, i)| *i).collect();
126
127        (sorted, perm)
128    }
129
130    /// Apply permutation to data.
131    pub fn apply_permutation<T: Copy>(data: &[T], perm: &[u32]) -> Vec<T> {
132        perm.iter().map(|&i| data[i as usize]).collect()
133    }
134
135    /// Inclusive prefix sum.
136    pub fn inclusive_scan(data: &[u32]) -> Vec<u32> {
137        let mut result = Vec::with_capacity(data.len());
138        let mut sum = 0u32;
139        for &val in data {
140            sum = sum.wrapping_add(val);
141            result.push(sum);
142        }
143        result
144    }
145
146    /// Exclusive prefix sum.
147    pub fn exclusive_scan(data: &[u32]) -> Vec<u32> {
148        let mut result = Vec::with_capacity(data.len());
149        let mut sum = 0u32;
150        for &val in data {
151            result.push(sum);
152            sum = sum.wrapping_add(val);
153        }
154        result
155    }
156
157    /// Group by count (assumes sorted input).
158    pub fn groupby_count_sorted(keys: &[u32]) -> Vec<(u32, u64)> {
159        if keys.is_empty() {
160            return Vec::new();
161        }
162
163        let mut result = Vec::new();
164        let mut current_key = keys[0];
165        let mut count = 1u64;
166
167        for &key in &keys[1..] {
168            if key == current_key {
169                count += 1;
170            } else {
171                result.push((current_key, count));
172                current_key = key;
173                count = 1;
174            }
175        }
176        result.push((current_key, count));
177        result
178    }
179
180    /// Group by sum (assumes sorted input).
181    pub fn groupby_sum_sorted(keys: &[u32], vals: &[u32]) -> Vec<(u32, u64)> {
182        if keys.is_empty() {
183            return Vec::new();
184        }
185
186        let mut result = Vec::new();
187        let mut current_key = keys[0];
188        let mut sum = vals[0] as u64;
189
190        for i in 1..keys.len() {
191            if keys[i] == current_key {
192                sum += vals[i] as u64;
193            } else {
194                result.push((current_key, sum));
195                current_key = keys[i];
196                sum = vals[i] as u64;
197            }
198        }
199        result.push((current_key, sum));
200        result
201    }
202
203    /// Group by min (assumes sorted input).
204    pub fn groupby_min_sorted(keys: &[u32], vals: &[u32]) -> Vec<(u32, u32)> {
205        if keys.is_empty() {
206            return Vec::new();
207        }
208
209        let mut result = Vec::new();
210        let mut current_key = keys[0];
211        let mut min_val = vals[0];
212
213        for i in 1..keys.len() {
214            if keys[i] == current_key {
215                min_val = min_val.min(vals[i]);
216            } else {
217                result.push((current_key, min_val));
218                current_key = keys[i];
219                min_val = vals[i];
220            }
221        }
222        result.push((current_key, min_val));
223        result
224    }
225
226    /// Group by max (assumes sorted input).
227    pub fn groupby_max_sorted(keys: &[u32], vals: &[u32]) -> Vec<(u32, u32)> {
228        if keys.is_empty() {
229            return Vec::new();
230        }
231
232        let mut result = Vec::new();
233        let mut current_key = keys[0];
234        let mut max_val = vals[0];
235
236        for i in 1..keys.len() {
237            if keys[i] == current_key {
238                max_val = max_val.max(vals[i]);
239            } else {
240                result.push((current_key, max_val));
241                current_key = keys[i];
242                max_val = vals[i];
243            }
244        }
245        result.push((current_key, max_val));
246        result
247    }
248
249    /// Dedup sorted data.
250    pub fn dedup_sorted<T: Eq + Copy>(data: &[T]) -> Vec<T> {
251        if data.is_empty() {
252            return Vec::new();
253        }
254
255        let mut result = vec![data[0]];
256        for &val in &data[1..] {
257            if val != *result.last().unwrap() {
258                result.push(val);
259            }
260        }
261        result
262    }
263
264    /// Mark duplicates in sorted data (true = unique, false = duplicate).
265    pub fn mark_duplicates<T: Eq>(sorted: &[T]) -> Vec<bool> {
266        if sorted.is_empty() {
267            return Vec::new();
268        }
269
270        let mut result = vec![true]; // First element is always unique
271        for i in 1..sorted.len() {
272            result.push(sorted[i] != sorted[i - 1]);
273        }
274        result
275    }
276
277    /// Sorted set union.
278    pub fn sorted_union<T: Ord + Copy>(a: &[T], b: &[T]) -> Vec<T> {
279        let mut result = Vec::new();
280        let mut i = 0;
281        let mut j = 0;
282
283        while i < a.len() && j < b.len() {
284            if a[i] < b[j] {
285                if result.last() != Some(&a[i]) {
286                    result.push(a[i]);
287                }
288                i += 1;
289            } else if a[i] > b[j] {
290                if result.last() != Some(&b[j]) {
291                    result.push(b[j]);
292                }
293                j += 1;
294            } else {
295                if result.last() != Some(&a[i]) {
296                    result.push(a[i]);
297                }
298                i += 1;
299                j += 1;
300            }
301        }
302
303        while i < a.len() {
304            if result.last() != Some(&a[i]) {
305                result.push(a[i]);
306            }
307            i += 1;
308        }
309
310        while j < b.len() {
311            if result.last() != Some(&b[j]) {
312                result.push(b[j]);
313            }
314            j += 1;
315        }
316
317        result
318    }
319
320    /// Sorted set difference (a - b).
321    pub fn sorted_diff<T: Ord + Copy>(a: &[T], b: &[T]) -> Vec<T> {
322        let mut result = Vec::new();
323        let mut i = 0;
324        let mut j = 0;
325
326        while i < a.len() && j < b.len() {
327            if a[i] < b[j] {
328                result.push(a[i]);
329                i += 1;
330            } else if a[i] > b[j] {
331                j += 1;
332            } else {
333                i += 1;
334                j += 1;
335            }
336        }
337
338        while i < a.len() {
339            result.push(a[i]);
340            i += 1;
341        }
342
343        result
344    }
345
346    /// Pack columns into row-major byte array.
347    pub fn pack_keys(cols: &[&[u8]], col_sizes: &[usize], num_rows: usize) -> Vec<u8> {
348        let row_size: usize = col_sizes.iter().sum();
349        let mut result = vec![0u8; row_size * num_rows];
350
351        for row in 0..num_rows {
352            let mut offset = 0;
353            for (col_idx, &col_size) in col_sizes.iter().enumerate() {
354                let src_start = row * col_size;
355                let src_end = src_start + col_size;
356                let dst_start = row * row_size + offset;
357                let dst_end = dst_start + col_size;
358                result[dst_start..dst_end].copy_from_slice(&cols[col_idx][src_start..src_end]);
359                offset += col_size;
360            }
361        }
362        result
363    }
364
365    /// FNV-1a hash.
366    pub fn hash_fnv1a(data: &[u8]) -> u32 {
367        const FNV_PRIME: u32 = 16777619;
368        const FNV_OFFSET: u32 = 2166136261;
369
370        let mut hash = FNV_OFFSET;
371        for &byte in data {
372            hash ^= byte as u32;
373            hash = hash.wrapping_mul(FNV_PRIME);
374        }
375        hash
376    }
377
378    /// Unpack a column from packed rows.
379    pub fn unpack_column(
380        packed: &[u8],
381        row_size: usize,
382        col_offset: usize,
383        col_size: usize,
384        num_rows: usize,
385    ) -> Vec<u8> {
386        let mut result = vec![0u8; col_size * num_rows];
387
388        for row in 0..num_rows {
389            let src_start = row * row_size + col_offset;
390            let src_end = src_start + col_size;
391            let dst_start = row * col_size;
392            let dst_end = dst_start + col_size;
393            result[dst_start..dst_end].copy_from_slice(&packed[src_start..src_end]);
394        }
395        result
396    }
397
398    /// Comparison operation enum.
399    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
400    pub enum CompareOp {
401        Eq,
402        Ne,
403        Lt,
404        Le,
405        Gt,
406        Ge,
407    }
408
409    impl CompareOp {
410        pub fn apply<T: PartialOrd>(&self, a: T, b: T) -> bool {
411            match self {
412                CompareOp::Eq => a == b,
413                CompareOp::Ne => a != b,
414                CompareOp::Lt => a < b,
415                CompareOp::Le => a <= b,
416                CompareOp::Gt => a > b,
417                CompareOp::Ge => a >= b,
418            }
419        }
420
421        pub fn apply_f64(&self, a: f64, b: f64) -> bool {
422            match self {
423                CompareOp::Eq => a == b,
424                CompareOp::Ne => a != b,
425                CompareOp::Lt => a < b,
426                CompareOp::Le => a <= b,
427                CompareOp::Gt => a > b,
428                CompareOp::Ge => a >= b,
429            }
430        }
431    }
432}
433
434/// Comparison utilities for GPU vs CPU result validation.
435pub mod compare {
436    /// Assert u32 slices are equal with detailed diff on failure.
437    pub fn assert_eq_u32(gpu: &[u32], cpu: &[u32], context: &str) {
438        if gpu.len() != cpu.len() {
439            panic!(
440                "{}: length mismatch: GPU={}, CPU={}",
441                context,
442                gpu.len(),
443                cpu.len()
444            );
445        }
446
447        let mut first_diff = None;
448        let mut diff_count = 0;
449
450        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
451            if g != c {
452                if first_diff.is_none() {
453                    first_diff = Some((i, *g, *c));
454                }
455                diff_count += 1;
456            }
457        }
458
459        if let Some((idx, gpu_val, cpu_val)) = first_diff {
460            panic!(
461                "{}: {} differences found. First at index {}: GPU={}, CPU={}",
462                context, diff_count, idx, gpu_val, cpu_val
463            );
464        }
465    }
466
467    /// Assert i64 slices are equal with detailed diff on failure.
468    pub fn assert_eq_i64(gpu: &[i64], cpu: &[i64], context: &str) {
469        if gpu.len() != cpu.len() {
470            panic!(
471                "{}: length mismatch: GPU={}, CPU={}",
472                context,
473                gpu.len(),
474                cpu.len()
475            );
476        }
477
478        let mut first_diff = None;
479        let mut diff_count = 0;
480
481        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
482            if g != c {
483                if first_diff.is_none() {
484                    first_diff = Some((i, *g, *c));
485                }
486                diff_count += 1;
487            }
488        }
489
490        if let Some((idx, gpu_val, cpu_val)) = first_diff {
491            panic!(
492                "{}: {} differences found. First at index {}: GPU={}, CPU={}",
493                context, diff_count, idx, gpu_val, cpu_val
494            );
495        }
496    }
497
498    /// Assert u64 slices are equal with detailed diff on failure.
499    pub fn assert_eq_u64(gpu: &[u64], cpu: &[u64], context: &str) {
500        if gpu.len() != cpu.len() {
501            panic!(
502                "{}: length mismatch: GPU={}, CPU={}",
503                context,
504                gpu.len(),
505                cpu.len()
506            );
507        }
508
509        let mut first_diff = None;
510        let mut diff_count = 0;
511
512        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
513            if g != c {
514                if first_diff.is_none() {
515                    first_diff = Some((i, *g, *c));
516                }
517                diff_count += 1;
518            }
519        }
520
521        if let Some((idx, gpu_val, cpu_val)) = first_diff {
522            panic!(
523                "{}: {} differences found. First at index {}: GPU={}, CPU={}",
524                context, diff_count, idx, gpu_val, cpu_val
525            );
526        }
527    }
528
529    /// Assert f64 slices are equal within ULP tolerance.
530    pub fn assert_eq_f64_ulp(gpu: &[f64], cpu: &[f64], max_ulp: u64, context: &str) {
531        if gpu.len() != cpu.len() {
532            panic!(
533                "{}: length mismatch: GPU={}, CPU={}",
534                context,
535                gpu.len(),
536                cpu.len()
537            );
538        }
539
540        let mut first_diff = None;
541        let mut diff_count = 0;
542
543        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
544            let ulp_diff = ulp_distance(*g, *c);
545            if ulp_diff > max_ulp {
546                if first_diff.is_none() {
547                    first_diff = Some((i, *g, *c, ulp_diff));
548                }
549                diff_count += 1;
550            }
551        }
552
553        if let Some((idx, gpu_val, cpu_val, ulp)) = first_diff {
554            panic!(
555                "{}: {} ULP violations (max={}). First at index {}: GPU={}, CPU={}, ULP={}",
556                context, diff_count, max_ulp, idx, gpu_val, cpu_val, ulp
557            );
558        }
559    }
560
561    /// Assert f64 slices are equal within relative tolerance.
562    pub fn assert_eq_f64_rel(gpu: &[f64], cpu: &[f64], rel_tol: f64, context: &str) {
563        if gpu.len() != cpu.len() {
564            panic!(
565                "{}: length mismatch: GPU={}, CPU={}",
566                context,
567                gpu.len(),
568                cpu.len()
569            );
570        }
571
572        let mut first_diff = None;
573        let mut diff_count = 0;
574
575        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
576            let rel_diff = if *c == 0.0 {
577                g.abs()
578            } else {
579                ((g - c) / c).abs()
580            };
581
582            if rel_diff > rel_tol && !g.is_nan() && !c.is_nan() {
583                if first_diff.is_none() {
584                    first_diff = Some((i, *g, *c, rel_diff));
585                }
586                diff_count += 1;
587            }
588        }
589
590        if let Some((idx, gpu_val, cpu_val, rel)) = first_diff {
591            panic!(
592                "{}: {} relative tolerance violations (max={}). First at index {}: GPU={}, CPU={}, rel_diff={}",
593                context, diff_count, rel_tol, idx, gpu_val, cpu_val, rel
594            );
595        }
596    }
597
598    /// Compute ULP distance between two f64 values.
599    fn ulp_distance(a: f64, b: f64) -> u64 {
600        if a.is_nan() || b.is_nan() {
601            return u64::MAX;
602        }
603        if a == b {
604            return 0;
605        }
606        if a.is_infinite() || b.is_infinite() {
607            return u64::MAX;
608        }
609
610        let a_bits = a.to_bits() as i64;
611        let b_bits = b.to_bits() as i64;
612
613        (a_bits - b_bits).unsigned_abs()
614    }
615
616    /// Assert sets are equal (order-independent).
617    pub fn assert_set_eq_u32(gpu: &[u32], cpu: &[u32], context: &str) {
618        let mut gpu_sorted = gpu.to_vec();
619        let mut cpu_sorted = cpu.to_vec();
620        gpu_sorted.sort();
621        cpu_sorted.sort();
622
623        if gpu_sorted != cpu_sorted {
624            panic!(
625                "{}: set mismatch. GPU (sorted): {:?}, CPU (sorted): {:?}",
626                context,
627                &gpu_sorted[..gpu_sorted.len().min(10)],
628                &cpu_sorted[..cpu_sorted.len().min(10)]
629            );
630        }
631    }
632
633    /// Assert permutation is valid.
634    pub fn assert_valid_permutation(perm: &[u32], len: usize, context: &str) {
635        if perm.len() != len {
636            panic!(
637                "{}: permutation length {} != expected {}",
638                context,
639                perm.len(),
640                len
641            );
642        }
643
644        let mut seen = vec![false; len];
645        for (i, &idx) in perm.iter().enumerate() {
646            if idx as usize >= len {
647                panic!(
648                    "{}: permutation index {} out of bounds at position {}",
649                    context, idx, i
650                );
651            }
652            if seen[idx as usize] {
653                panic!(
654                    "{}: duplicate permutation index {} at position {}",
655                    context, idx, i
656                );
657            }
658            seen[idx as usize] = true;
659        }
660    }
661
662    /// Assert sort is stable (equal keys maintain relative order).
663    pub fn assert_stable_sort(
664        original_keys: &[u32],
665        original_vals: &[u32],
666        sorted_keys: &[u32],
667        sorted_vals: &[u32],
668        context: &str,
669    ) {
670        // Group by key in original order
671        let mut key_to_vals: std::collections::HashMap<u32, Vec<u32>> =
672            std::collections::HashMap::new();
673        for (&k, &v) in original_keys.iter().zip(original_vals.iter()) {
674            key_to_vals.entry(k).or_default().push(v);
675        }
676
677        // Check sorted result maintains relative order within each key group
678        let mut key_to_idx: std::collections::HashMap<u32, usize> =
679            std::collections::HashMap::new();
680        for (&k, &v) in sorted_keys.iter().zip(sorted_vals.iter()) {
681            let idx = key_to_idx.entry(k).or_insert(0);
682            let expected_vals = key_to_vals.get(&k).unwrap();
683            if *idx >= expected_vals.len() {
684                panic!("{}: too many values for key {}", context, k);
685            }
686            if v != expected_vals[*idx] {
687                panic!(
688                    "{}: stability violation for key {}. Expected value {} but got {}",
689                    context, k, expected_vals[*idx], v
690                );
691            }
692            *idx += 1;
693        }
694    }
695
696    /// Check if result is sorted.
697    pub fn is_sorted_u32(data: &[u32]) -> bool {
698        data.windows(2).all(|w| w[0] <= w[1])
699    }
700
701    /// Check if result is sorted descending.
702    pub fn is_sorted_desc_u32(data: &[u32]) -> bool {
703        data.windows(2).all(|w| w[0] >= w[1])
704    }
705}