Skip to main content

xlog_cuda/provider/
ilp_exact.rs

1//! Launcher for the native bounded exact-induction scoring kernel.
2//!
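//! Drives the kernels in `kernels/ilp_exact.cu`: the `ilp_exact_score` family
//! scores all `(topology, L, R)` triples for a single `induce_exact` call in
//! one launch, and a top-K select kernel reduces those counts per topology on
//! the device so only the compact selected rows are copied to host (the full
//! count arrays are copied back only by the test-only export).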
6//!
7//! Design: `docs/plans/2026-04-17-m8-ilp-exact-kernel-design.md`.
8
9use std::marker::PhantomData;
10use std::sync::atomic::Ordering;
11
12use crate::{LaunchAsync, LaunchConfig};
13use xlog_core::{Result, ScalarType, XlogError};
14
15use super::{ilp_exact_kernels, RawCudaView, ILP_EXACT_MODULE};
16use crate::memory::{CudaBuffer, TrackedCudaSlice};
17
/// Threads per block (`block_dim.x`) for the scoring kernels; also sizes the
/// chain-smem dynamic shared-memory request.
const ILP_EXACT_BLOCK_SIZE: u32 = 256;

/// Number of `u32` words per row written by the top-K select kernel, decoded
/// in the field order of [`IlpExactTopkCandidate`].
const ILP_EXACT_TOPK_FIELDS: usize = 9;

/// Env var that disables the shared-memory chain kernel when set to `0`,
/// `false`, `off`, or `no` (trimmed, case-insensitive).
const ENV_ILP_EXACT_CHAIN_SMEM: &str = "XLOG_ILP_EXACT_CHAIN_SMEM";

/// Env var overriding the minimum largest-candidate row count at which the
/// shared-memory chain kernel is selected.
const ENV_ILP_EXACT_CHAIN_SMEM_MIN_ROWS: &str = "XLOG_ILP_EXACT_CHAIN_SMEM_MIN_ROWS";

/// Fallback for [`ENV_ILP_EXACT_CHAIN_SMEM_MIN_ROWS`] when it is unset or
/// does not parse as a `u32`.
const DEFAULT_ILP_EXACT_CHAIN_SMEM_MIN_ROWS: u32 = 256;
23
/// One row of the per-topology top-K selection returned by
/// `ilp_exact_score_topk`.
///
/// Fields are decoded, in declaration order, from consecutive `u32` words
/// written by the device-side select kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IlpExactTopkCandidate {
    /// Topology index: `chain=0, star=1, fanout=2, fanin=3`.
    pub topology_idx: u32,
    /// Index of the left relation within the caller's `candidate_buffers`.
    pub left_idx: u32,
    /// Index of the right relation within the caller's `candidate_buffers`.
    pub right_idx: u32,
    /// Number of positive examples covered by this `(topology, L, R)` triple.
    pub positives_covered: u32,
    /// Number of negative examples covered by this `(topology, L, R)` triple.
    pub negatives_covered: u32,
    /// Rank of this row within its topology (0 = best), as written by the
    /// select kernel.
    pub local_rank: u32,
    /// Positive coverage reported for the next-ranked row. NOTE(review): the
    /// exact definition lives in `kernels/ilp_exact.cu`; confirm there.
    pub next_positives_covered: u32,
    /// Negative coverage reported for the next-ranked row (see note above
    /// `next_positives_covered`).
    pub next_negatives_covered: u32,
    /// Size of the score-tie class this row belongs to, as reported by the
    /// select kernel.
    pub tie_class_size: u32,
}
36
/// Device-resident outputs of one `ilp_exact_score_device` call.
///
/// Both count arrays use the slot ordering
/// `slot = topology * (C * C) + L * C + R`.
struct IlpExactDeviceScores {
    /// Number of candidate relations `C` that were scored.
    candidate_count: usize,
    /// Length of each count array (`4 * C * C`); only the test-only host
    /// export reads it.
    #[cfg(test)]
    slot_count: usize,
    /// Per-slot positive coverage counts.
    pos_covered: TrackedCudaSlice<u32>,
    /// Per-slot negative coverage counts (same slot ordering as `pos_covered`).
    neg_covered: TrackedCudaSlice<u32>,
}
44
/// Element type shared by both columns of every pair buffer in one exact
/// scoring call.
///
/// `U32` and `Symbol` have the same 4-byte width and are dispatched to the
/// same kernels. They stay distinct variants so that buffers mixing the two
/// schemas are rejected by the layout checks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ExactPairLayout {
    U64,
    U32,
    Symbol,
}

impl ExactPairLayout {
    /// Width in bytes of one column element for this layout.
    fn elem_size(self) -> usize {
        use std::mem::size_of;

        match self {
            Self::U64 => size_of::<u64>(),
            Self::U32 => size_of::<u32>(),
            Self::Symbol => size_of::<u32>(),
        }
    }
}
60
61fn ilp_exact_chain_smem_enabled() -> bool {
62    match std::env::var(ENV_ILP_EXACT_CHAIN_SMEM) {
63        Ok(value) => !matches!(
64            value.trim().to_ascii_lowercase().as_str(),
65            "0" | "false" | "off" | "no"
66        ),
67        Err(_) => true,
68    }
69}
70
71fn chain_smem_shared_bytes(layout: ExactPairLayout) -> u32 {
72    let block = ILP_EXACT_BLOCK_SIZE as usize;
73    let bytes = (2usize * block * layout.elem_size()) + (block * std::mem::size_of::<u32>());
74    u32::try_from(bytes).expect("chain smem byte count fits in u32")
75}
76
77fn ilp_exact_chain_smem_min_rows() -> u32 {
78    std::env::var(ENV_ILP_EXACT_CHAIN_SMEM_MIN_ROWS)
79        .ok()
80        .and_then(|value| value.trim().parse::<u32>().ok())
81        .unwrap_or(DEFAULT_ILP_EXACT_CHAIN_SMEM_MIN_ROWS)
82}
83
84impl super::CudaKernelProvider {
85    /// Test-only full-score export for validating the scoring kernels.
86    ///
87    /// Returns `(pos_covered, neg_covered)`, each of length `4 * C * C`
88    /// where `C = candidate_buffers.len()`. Slot ordering:
89    /// `slot = topology * (C * C) + L * C + R`, with topology indices
90    /// `chain=0, star=1, fanout=2, fanin=3`.
91    ///
92    /// Host-side contract:
93    ///   * All buffers must be arity 2 with one matching pair type: `U64`,
94    ///     `U32`, or `Symbol`.
95    ///   * `cached_row_count()` must be populated on every buffer (DLPack
96    ///     ingest and `create_empty_buffer` both guarantee this).
97    ///   * `negatives` is always a valid buffer — the caller constructs
98    ///     an empty pair buffer matching the positive pair type when there are
99    ///     no negatives.
100    ///
101    /// D2H budget: **2** counter-tracked transfers (one per count array).
102    /// Setup H2D / D2D copies are not D2H-counted.
103    #[cfg(test)]
104    fn ilp_exact_score(
105        &self,
106        candidate_buffers: &[&CudaBuffer],
107        positives: &CudaBuffer,
108        negatives: &CudaBuffer,
109    ) -> Result<(Vec<u32>, Vec<u32>)> {
110        let scores = self.ilp_exact_score_device(candidate_buffers, positives, negatives)?;
111        let device = self.device.inner();
112        self.device.synchronize()?;
113
114        let mut pos_covered = vec![0u32; scores.slot_count];
115        self.d2h_transfer_count.fetch_add(1, Ordering::Relaxed);
116        device
117            .dtoh_sync_copy_into(&scores.pos_covered, &mut pos_covered)
118            .map_err(|e| XlogError::Kernel(format!("ilp_exact_score: dtoh pos_covered: {}", e)))?;
119
120        let mut neg_covered = vec![0u32; scores.slot_count];
121        self.d2h_transfer_count.fetch_add(1, Ordering::Relaxed);
122        device
123            .dtoh_sync_copy_into(&scores.neg_covered, &mut neg_covered)
124            .map_err(|e| XlogError::Kernel(format!("ilp_exact_score: dtoh neg_covered: {}", e)))?;
125
126        Ok((pos_covered, neg_covered))
127    }
128
129    /// Score on GPU, reduce per-topology top-K on GPU, and transfer only the
130    /// compact selected rows back to host.
131    pub fn ilp_exact_score_topk(
132        &self,
133        candidate_buffers: &[&CudaBuffer],
134        positives: &CudaBuffer,
135        negatives: &CudaBuffer,
136        k_per_topology: u32,
137    ) -> Result<Vec<IlpExactTopkCandidate>> {
138        if k_per_topology == 0 {
139            return Ok(Vec::new());
140        }
141
142        let scores = self.ilp_exact_score_device(candidate_buffers, positives, negatives)?;
143        let out_rows = 4usize
144            .checked_mul(k_per_topology as usize)
145            .ok_or_else(|| XlogError::Kernel("ilp_exact_score_topk: output row overflow".into()))?;
146        let out_words = out_rows.checked_mul(ILP_EXACT_TOPK_FIELDS).ok_or_else(|| {
147            XlogError::Kernel("ilp_exact_score_topk: output word overflow".into())
148        })?;
149        let mut selected_buf = self.memory.alloc::<u32>(out_words)?;
150        let device = self.device.inner();
151        let func = device
152            .get_func(ILP_EXACT_MODULE, ilp_exact_kernels::ILP_EXACT_SELECT_TOPK)
153            .ok_or_else(|| {
154                XlogError::Kernel(format!(
155                    "{} kernel not loaded",
156                    ilp_exact_kernels::ILP_EXACT_SELECT_TOPK
157                ))
158            })?;
159
160        unsafe {
161            func.clone().launch(
162                LaunchConfig {
163                    grid_dim: (4, 1, 1),
164                    block_dim: (1, 1, 1),
165                    shared_mem_bytes: 0,
166                },
167                (
168                    &scores.pos_covered,
169                    &scores.neg_covered,
170                    scores.candidate_count as u32,
171                    k_per_topology,
172                    &mut selected_buf,
173                ),
174            )
175        }
176        .map_err(|e| XlogError::Kernel(format!("ilp_exact_select_topk launch: {}", e)))?;
177
178        self.device.synchronize()?;
179        let mut words = vec![0u32; out_words];
180        self.d2h_transfer_count.fetch_add(1, Ordering::Relaxed);
181        device
182            .dtoh_sync_copy_into(&selected_buf, &mut words)
183            .map_err(|e| {
184                XlogError::Kernel(format!("ilp_exact_score_topk: dtoh selected: {}", e))
185            })?;
186
187        let mut selected = Vec::new();
188        for chunk in words.chunks_exact(ILP_EXACT_TOPK_FIELDS) {
189            if chunk[3] == 0 {
190                continue;
191            }
192            selected.push(IlpExactTopkCandidate {
193                topology_idx: chunk[0],
194                left_idx: chunk[1],
195                right_idx: chunk[2],
196                positives_covered: chunk[3],
197                negatives_covered: chunk[4],
198                local_rank: chunk[5],
199                next_positives_covered: chunk[6],
200                next_negatives_covered: chunk[7],
201                tie_class_size: chunk[8],
202            });
203        }
204        Ok(selected)
205    }
206
207    fn ilp_exact_score_device(
208        &self,
209        candidate_buffers: &[&CudaBuffer],
210        positives: &CudaBuffer,
211        negatives: &CudaBuffer,
212    ) -> Result<IlpExactDeviceScores> {
213        let c = candidate_buffers.len();
214        if c == 0 {
215            return Err(XlogError::Kernel(
216                "ilp_exact_score: candidate list is empty (filter at the engine)".to_string(),
217            ));
218        }
219        let c_u32 = u32::try_from(c).map_err(|_| {
220            XlogError::Kernel(format!(
221                "ilp_exact_score: candidate count {} exceeds u32::MAX",
222                c
223            ))
224        })?;
225
226        // ── Validate shapes and gather host-side row counts ────────────────
227        let layout = validate_exact_pair_buffer(positives, "positives")?;
228        require_exact_pair_layout(negatives, "negatives", layout)?;
229        let pos_rows = cached_rows(positives, "positives")?;
230        let neg_rows = cached_rows(negatives, "negatives")?;
231
232        let mut cand_rows: Vec<u32> = Vec::with_capacity(c);
233        for (i, buf) in candidate_buffers.iter().enumerate() {
234            let label = format!("candidate[{}]", i);
235            require_exact_pair_layout(buf, &label, layout)?;
236            cand_rows.push(cached_rows(buf, &label)?);
237        }
238
239        // ── Exclusive prefix sum of row counts (cand_offsets, length C+1) ─
240        let mut cand_offsets_host: Vec<u32> = Vec::with_capacity(c + 1);
241        let mut running: u32 = 0;
242        cand_offsets_host.push(0);
243        for &r in &cand_rows {
244            running = running.checked_add(r).ok_or_else(|| {
245                XlogError::Kernel("ilp_exact_score: candidate row count overflow u32".to_string())
246            })?;
247            cand_offsets_host.push(running);
248        }
249        let total_rows = running as usize;
250        let elem_size = layout.elem_size();
251        let total_bytes = total_rows * elem_size;
252
253        let device = self.device.inner();
254
255        // ── Concatenate candidate columns via D2D copies ──────────────────
256        // Setup-phase D→D; neither counted by the D2H gate nor by the
257        // transfer tracker as a host-to-device round trip.
258        let mut cand_arg0_buf = self.memory.alloc::<u8>(total_bytes)?;
259        let mut cand_arg1_buf = self.memory.alloc::<u8>(total_bytes)?;
260        if total_bytes > 0 {
261            let mut byte_offset: usize = 0;
262            for (i, buf) in candidate_buffers.iter().enumerate() {
263                let rows = cand_rows[i] as usize;
264                if rows == 0 {
265                    continue;
266                }
267                let bytes = rows * elem_size;
268
269                let src0 = buf.column(0).ok_or_else(|| {
270                    XlogError::Kernel(format!("candidate[{}] missing column 0", i))
271                })?;
272                let src1 = buf.column(1).ok_or_else(|| {
273                    XlogError::Kernel(format!("candidate[{}] missing column 1", i))
274                })?;
275                let src_view0 = self.column_bytes_view(src0, bytes)?;
276                let src_view1 = self.column_bytes_view(src1, bytes)?;
277                let mut dst0 = cand_arg0_buf.slice_mut(byte_offset..byte_offset + bytes);
278                let mut dst1 = cand_arg1_buf.slice_mut(byte_offset..byte_offset + bytes);
279                device.dtod_copy(&src_view0, &mut dst0).map_err(|e| {
280                    XlogError::Kernel(format!(
281                        "ilp_exact_score: d2d concat arg0 (candidate {}): {}",
282                        i, e
283                    ))
284                })?;
285                device.dtod_copy(&src_view1, &mut dst1).map_err(|e| {
286                    XlogError::Kernel(format!(
287                        "ilp_exact_score: d2d concat arg1 (candidate {}): {}",
288                        i, e
289                    ))
290                })?;
291                byte_offset += bytes;
292            }
293        }
294
295        // ── Upload cand_offsets (H→D, not D2H-counted) ────────────────────
296        let mut cand_offsets_buf = self.memory.alloc::<u32>(c + 1)?;
297        self.htod_sync_copy_into_tracked(&cand_offsets_host, &mut cand_offsets_buf)
298            .map_err(|e| XlogError::Kernel(format!("ilp_exact_score: h2d cand_offsets: {}", e)))?;
299
300        // ── Alloc output count arrays ─────────────────────────────────────
301        let n_slots = 4usize
302            .checked_mul(c)
303            .and_then(|v| v.checked_mul(c))
304            .ok_or_else(|| {
305                XlogError::Kernel("ilp_exact_score: n_slots = 4 * C * C overflow".to_string())
306            })?;
307        let mut pos_covered_buf = self.memory.alloc::<u32>(n_slots)?;
308        let mut neg_covered_buf = self.memory.alloc::<u32>(n_slots)?;
309        // Kernel writes every slot exactly once — no zero-init required.
310
311        let pos_col0 = positives
312            .column(0)
313            .ok_or_else(|| XlogError::Kernel("positives: missing column 0".to_string()))?;
314        let pos_col1 = positives
315            .column(1)
316            .ok_or_else(|| XlogError::Kernel("positives: missing column 1".to_string()))?;
317        let neg_col0 = negatives
318            .column(0)
319            .ok_or_else(|| XlogError::Kernel("negatives: missing column 0".to_string()))?;
320        let neg_col1 = negatives
321            .column(1)
322            .ok_or_else(|| XlogError::Kernel("negatives: missing column 1".to_string()))?;
323
324        // ── Launch ────────────────────────────────────────────────────────
325        let max_candidate_rows = cand_rows.iter().copied().max().unwrap_or(0);
326        let chain_smem_enabled =
327            ilp_exact_chain_smem_enabled() && max_candidate_rows >= ilp_exact_chain_smem_min_rows();
328        let shared_mem_bytes = if chain_smem_enabled {
329            chain_smem_shared_bytes(layout)
330        } else {
331            0
332        };
333        match layout {
334            ExactPairLayout::U64 => {
335                let cand_arg0_view = RawCudaView::<u64> {
336                    ptr: *cand_arg0_buf.device_ptr(),
337                    len: total_rows,
338                    stream: cand_arg0_buf.stream().clone(),
339                    source_block: None,
340                    _marker: PhantomData,
341                };
342                let cand_arg1_view = RawCudaView::<u64> {
343                    ptr: *cand_arg1_buf.device_ptr(),
344                    len: total_rows,
345                    stream: cand_arg1_buf.stream().clone(),
346                    source_block: None,
347                    _marker: PhantomData,
348                };
349                let pos_arg0_view = self.column_as_u64_view(pos_col0, pos_rows as usize)?;
350                let pos_arg1_view = self.column_as_u64_view(pos_col1, pos_rows as usize)?;
351                let neg_arg0_view = self.column_as_u64_view(neg_col0, neg_rows as usize)?;
352                let neg_arg1_view = self.column_as_u64_view(neg_col1, neg_rows as usize)?;
353                let kernel_name = if chain_smem_enabled {
354                    ilp_exact_kernels::ILP_EXACT_SCORE_CHAIN_SMEM
355                } else {
356                    ilp_exact_kernels::ILP_EXACT_SCORE
357                };
358                let func = device
359                    .get_func(ILP_EXACT_MODULE, kernel_name)
360                    .ok_or_else(|| {
361                        XlogError::Kernel(format!("{} kernel not loaded", kernel_name))
362                    })?;
363                unsafe {
364                    func.clone().launch(
365                        LaunchConfig {
366                            grid_dim: (c_u32, c_u32, 4),
367                            block_dim: (ILP_EXACT_BLOCK_SIZE, 1, 1),
368                            shared_mem_bytes,
369                        },
370                        (
371                            &cand_arg0_view,
372                            &cand_arg1_view,
373                            &cand_offsets_buf,
374                            c_u32,
375                            &pos_arg0_view,
376                            &pos_arg1_view,
377                            pos_rows,
378                            &neg_arg0_view,
379                            &neg_arg1_view,
380                            neg_rows,
381                            &mut pos_covered_buf,
382                            &mut neg_covered_buf,
383                        ),
384                    )
385                }
386                .map_err(|e| XlogError::Kernel(format!("ilp_exact_score launch: {}", e)))?;
387            }
388            ExactPairLayout::U32 | ExactPairLayout::Symbol => {
389                let cand_arg0_view = RawCudaView::<u32> {
390                    ptr: *cand_arg0_buf.device_ptr(),
391                    len: total_rows,
392                    stream: cand_arg0_buf.stream().clone(),
393                    source_block: None,
394                    _marker: PhantomData,
395                };
396                let cand_arg1_view = RawCudaView::<u32> {
397                    ptr: *cand_arg1_buf.device_ptr(),
398                    len: total_rows,
399                    stream: cand_arg1_buf.stream().clone(),
400                    source_block: None,
401                    _marker: PhantomData,
402                };
403                let pos_arg0_view = self.column_as_u32_view(pos_col0, pos_rows as usize)?;
404                let pos_arg1_view = self.column_as_u32_view(pos_col1, pos_rows as usize)?;
405                let neg_arg0_view = self.column_as_u32_view(neg_col0, neg_rows as usize)?;
406                let neg_arg1_view = self.column_as_u32_view(neg_col1, neg_rows as usize)?;
407                let kernel_name = if chain_smem_enabled {
408                    ilp_exact_kernels::ILP_EXACT_SCORE_CHAIN_SMEM_U32
409                } else {
410                    ilp_exact_kernels::ILP_EXACT_SCORE_U32
411                };
412                let func = device
413                    .get_func(ILP_EXACT_MODULE, kernel_name)
414                    .ok_or_else(|| {
415                        XlogError::Kernel(format!("{} kernel not loaded", kernel_name))
416                    })?;
417                unsafe {
418                    func.clone().launch(
419                        LaunchConfig {
420                            grid_dim: (c_u32, c_u32, 4),
421                            block_dim: (ILP_EXACT_BLOCK_SIZE, 1, 1),
422                            shared_mem_bytes,
423                        },
424                        (
425                            &cand_arg0_view,
426                            &cand_arg1_view,
427                            &cand_offsets_buf,
428                            c_u32,
429                            &pos_arg0_view,
430                            &pos_arg1_view,
431                            pos_rows,
432                            &neg_arg0_view,
433                            &neg_arg1_view,
434                            neg_rows,
435                            &mut pos_covered_buf,
436                            &mut neg_covered_buf,
437                        ),
438                    )
439                }
440                .map_err(|e| XlogError::Kernel(format!("ilp_exact_score_u32 launch: {}", e)))?;
441            }
442        }
443
444        Ok(IlpExactDeviceScores {
445            candidate_count: c,
446            #[cfg(test)]
447            slot_count: n_slots,
448            pos_covered: pos_covered_buf,
449            neg_covered: neg_covered_buf,
450        })
451    }
452}
453
454fn validate_exact_pair_buffer(buf: &CudaBuffer, label: &str) -> Result<ExactPairLayout> {
455    if buf.arity() != 2 {
456        return Err(XlogError::Kernel(format!(
457            "ilp_exact_score: {} buffer arity = {}, expected 2",
458            label,
459            buf.arity(),
460        )));
461    }
462    let mut layout: Option<ExactPairLayout> = None;
463    for col_idx in 0..2 {
464        let t = buf.schema().column_type(col_idx).ok_or_else(|| {
465            XlogError::Kernel(format!(
466                "ilp_exact_score: {} buffer missing column {} type",
467                label, col_idx,
468            ))
469        })?;
470        let col_layout = match t {
471            ScalarType::U64 => ExactPairLayout::U64,
472            ScalarType::U32 => ExactPairLayout::U32,
473            ScalarType::Symbol => ExactPairLayout::Symbol,
474            _ => {
475                return Err(XlogError::Kernel(format!(
476                    "ilp_exact_score: {} buffer column {} type = {:?}, expected U64, U32, or Symbol",
477                    label, col_idx, t,
478                )));
479            }
480        };
481        if let Some(expected) = layout {
482            if expected != col_layout {
483                return Err(XlogError::Kernel(format!(
484                    "ilp_exact_score: {} buffer column {} type mismatch: {:?} vs {:?}",
485                    label, col_idx, expected, col_layout,
486                )));
487            }
488        } else {
489            layout = Some(col_layout);
490        }
491    }
492    Ok(layout.expect("arity 2 loop sets layout"))
493}
494
495fn require_exact_pair_layout(
496    buf: &CudaBuffer,
497    label: &str,
498    expected: ExactPairLayout,
499) -> Result<()> {
500    let actual = validate_exact_pair_buffer(buf, label)?;
501    if actual != expected {
502        return Err(XlogError::Kernel(format!(
503            "ilp_exact_score: {} buffer type mismatch: expected {:?}, got {:?}",
504            label, expected, actual,
505        )));
506    }
507    Ok(())
508}
509
510fn cached_rows(buf: &CudaBuffer, label: &str) -> Result<u32> {
511    buf.cached_row_count().ok_or_else(|| {
512        XlogError::Kernel(format!(
513            "ilp_exact_score: {} buffer has no cached row count \
514             (DLPack ingest and create_empty_buffer both populate it)",
515            label
516        ))
517    })
518}
519
520#[cfg(test)]
521mod tests {
522    //! CUDA-gated correctness tests for the ilp_exact launcher.
523    //!
524    //! Pinned to a hand-computed fixture so the kernel's coverage arithmetic
525    //! can be verified without relying on the Python backend as oracle. The
526    //! fixture uses C=2 candidate relations so the expected flat output
527    //! (4 × C × C = 16 slots per count array) is tractable to enumerate.
528
529    use std::sync::Arc;
530
531    use xlog_core::{MemoryBudget, ScalarType, Schema};
532
533    use crate::{CudaDevice, CudaKernelProvider, GpuMemoryManager};
534
535    fn make_provider() -> Option<CudaKernelProvider> {
536        let device = Arc::new(CudaDevice::new(0).ok()?);
537        let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
538        let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
539        CudaKernelProvider::new(device, memory).ok()
540    }
541
542    /// Build a `(u64, u64)` pair buffer from parallel host-side column arrays.
543    /// Uses `create_buffer_from_slice` per column then recombines, relying on
544    /// the provider's buffer-from-columns path to set the cached row count.
545    fn pair_buffer(provider: &CudaKernelProvider, arg0: &[u64], arg1: &[u64]) -> crate::CudaBuffer {
546        assert_eq!(arg0.len(), arg1.len());
547        let schema = Schema::new(vec![
548            ("arg0".to_string(), ScalarType::U64),
549            ("arg1".to_string(), ScalarType::U64),
550        ]);
551        if arg0.is_empty() {
552            return provider
553                .create_empty_buffer(schema)
554                .expect("empty pair buffer");
555        }
556        // Pack both columns as a single 2-column buffer by constructing
557        // byte-columns manually — mirrors what `from_dlpack_tensors_with_schema`
558        // does for the in-process launcher tests.
559        let device = provider.device().inner();
560        let arg0_bytes: Vec<u8> = arg0.iter().flat_map(|v| v.to_le_bytes()).collect();
561        let arg1_bytes: Vec<u8> = arg1.iter().flat_map(|v| v.to_le_bytes()).collect();
562        let mut col0 = provider
563            .memory()
564            .alloc::<u8>(arg0_bytes.len())
565            .expect("alloc");
566        let mut col1 = provider
567            .memory()
568            .alloc::<u8>(arg1_bytes.len())
569            .expect("alloc");
570        device
571            .htod_sync_copy_into(&arg0_bytes, &mut col0)
572            .expect("h2d arg0");
573        device
574            .htod_sync_copy_into(&arg1_bytes, &mut col1)
575            .expect("h2d arg1");
576        provider
577            .buffer_from_columns(vec![col0.into(), col1.into()], arg0.len() as u64, schema)
578            .expect("buffer_from_columns")
579    }
580
581    fn pair_buffer_u32(
582        provider: &CudaKernelProvider,
583        arg0: &[u32],
584        arg1: &[u32],
585        typ: ScalarType,
586    ) -> crate::CudaBuffer {
587        assert_eq!(arg0.len(), arg1.len());
588        assert!(matches!(typ, ScalarType::U32 | ScalarType::Symbol));
589        let schema = Schema::new(vec![("arg0".to_string(), typ), ("arg1".to_string(), typ)]);
590        if arg0.is_empty() {
591            return provider
592                .create_empty_buffer(schema)
593                .expect("empty pair buffer");
594        }
595        let device = provider.device().inner();
596        let arg0_bytes: Vec<u8> = arg0.iter().flat_map(|v| v.to_le_bytes()).collect();
597        let arg1_bytes: Vec<u8> = arg1.iter().flat_map(|v| v.to_le_bytes()).collect();
598        let mut col0 = provider
599            .memory()
600            .alloc::<u8>(arg0_bytes.len())
601            .expect("alloc");
602        let mut col1 = provider
603            .memory()
604            .alloc::<u8>(arg1_bytes.len())
605            .expect("alloc");
606        device
607            .htod_sync_copy_into(&arg0_bytes, &mut col0)
608            .expect("h2d arg0");
609        device
610            .htod_sync_copy_into(&arg1_bytes, &mut col1)
611            .expect("h2d arg1");
612        provider
613            .buffer_from_columns(vec![col0.into(), col1.into()], arg0.len() as u64, schema)
614            .expect("buffer_from_columns")
615    }
616
617    fn pair_buffer_i32(
618        provider: &CudaKernelProvider,
619        arg0: &[i32],
620        arg1: &[i32],
621    ) -> crate::CudaBuffer {
622        assert_eq!(arg0.len(), arg1.len());
623        let schema = Schema::new(vec![
624            ("arg0".to_string(), ScalarType::I32),
625            ("arg1".to_string(), ScalarType::I32),
626        ]);
627        if arg0.is_empty() {
628            return provider
629                .create_empty_buffer(schema)
630                .expect("empty pair buffer");
631        }
632        let device = provider.device().inner();
633        let arg0_bytes: Vec<u8> = arg0.iter().flat_map(|v| v.to_le_bytes()).collect();
634        let arg1_bytes: Vec<u8> = arg1.iter().flat_map(|v| v.to_le_bytes()).collect();
635        let mut col0 = provider
636            .memory()
637            .alloc::<u8>(arg0_bytes.len())
638            .expect("alloc");
639        let mut col1 = provider
640            .memory()
641            .alloc::<u8>(arg1_bytes.len())
642            .expect("alloc");
643        device
644            .htod_sync_copy_into(&arg0_bytes, &mut col0)
645            .expect("h2d arg0");
646        device
647            .htod_sync_copy_into(&arg1_bytes, &mut col1)
648            .expect("h2d arg1");
649        provider
650            .buffer_from_columns(vec![col0.into(), col1.into()], arg0.len() as u64, schema)
651            .expect("buffer_from_columns")
652    }
653
654    /// Hand-computed coverage for C=2 candidates {p_B, p_C} against positives
655    /// `{(1,4), (2,5)}` and negatives `{(7,8)}`. The only non-zero coverage
656    /// is `chain(p_B, p_C) = 2` (both positives covered via chain joins
657    /// z=2 and z=3). Everything else is zero by direct enumeration of the
658    /// four topology templates — see
659    /// `docs/plans/2026-04-17-m8-ilp-exact-kernel-design.md` for the
660    /// templates. Also exercises the negative-scoring path with one negative
661    /// that no topology-L-R combination covers.
662    #[test]
663    fn ilp_exact_score_matches_hand_computed_fixture() {
664        let provider = match make_provider() {
665            Some(p) => p,
666            None => {
667                eprintln!("Skipping test: no CUDA device available");
668                return;
669            }
670        };
671
672        // Candidate relations.
673        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
674        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);
675
676        // Positives: {(1,4), (2,5)}. Negatives: {(7,8)}.
677        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
678        let negatives = pair_buffer(&provider, &[7], &[8]);
679
680        let (pos, neg) = provider
681            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
682            .expect("ilp_exact_score launch");
683
684        // Slot layout: topology * C² + L * C + R, with C=2.
685        //   topology: chain=0, star=1, fanout=2, fanin=3.
686        //   L/R: p_B=0, p_C=1.
687        // Only chain(p_B=0, p_C=1) → slot 0*4 + 0*2 + 1 = 1 is non-zero.
688        let mut expected_pos = vec![0u32; 16];
689        expected_pos[1] = 2;
690        assert_eq!(
691            pos, expected_pos,
692            "positives coverage mismatch: expected {:?}, got {:?}",
693            expected_pos, pos,
694        );
695
696        // All negatives coverage slots are zero: no (L, R, topology) covers (7, 8).
697        let expected_neg = vec![0u32; 16];
698        assert_eq!(
699            neg, expected_neg,
700            "negatives coverage mismatch: expected {:?}, got {:?}",
701            expected_neg, neg,
702        );
703    }
704
705    #[test]
706    fn ilp_exact_score_topk_reduces_on_device_to_compact_result() {
707        let provider = match make_provider() {
708            Some(p) => p,
709            None => {
710                eprintln!("Skipping test: no CUDA device available");
711                return;
712            }
713        };
714
715        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
716        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);
717        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
718        let negatives = pair_buffer(&provider, &[7], &[8]);
719
720        provider.reset_d2h_transfer_count();
721        let selected = provider
722            .ilp_exact_score_topk(&[&p_b, &p_c], &positives, &negatives, 2)
723            .expect("ilp_exact_score_topk launch");
724
725        assert_eq!(provider.d2h_transfer_count(), 1);
726        assert_eq!(selected.len(), 1);
727        let winner = selected[0];
728        assert_eq!(winner.topology_idx, 0);
729        assert_eq!(winner.left_idx, 0);
730        assert_eq!(winner.right_idx, 1);
731        assert_eq!(winner.positives_covered, 2);
732        assert_eq!(winner.negatives_covered, 0);
733        assert_eq!(winner.local_rank, 0);
734        assert_eq!(winner.next_positives_covered, 0);
735        assert_eq!(winner.next_negatives_covered, 0);
736        assert_eq!(winner.tie_class_size, 1);
737    }
738
739    #[test]
740    fn ilp_exact_score_topk_preserves_rank_next_and_tie_diagnostics() {
741        let provider = match make_provider() {
742            Some(p) => p,
743            None => {
744                eprintln!("Skipping test: no CUDA device available");
745                return;
746            }
747        };
748
749        let p_all = pair_buffer(&provider, &[1, 2], &[1, 2]);
750        let p_one = pair_buffer(&provider, &[1], &[1]);
751        let p_two = pair_buffer(&provider, &[2], &[2]);
752        let positives = pair_buffer(&provider, &[1, 2], &[1, 2]);
753        let negatives = pair_buffer(&provider, &[9], &[9]);
754
755        let selected = provider
756            .ilp_exact_score_topk(&[&p_all, &p_one, &p_two], &positives, &negatives, 2)
757            .expect("ilp_exact_score_topk launch");
758
759        let star_rank0 = selected
760            .iter()
761            .find(|row| row.topology_idx == 1 && row.local_rank == 0)
762            .expect("star rank 0");
763        assert_eq!(star_rank0.left_idx, 0);
764        assert_eq!(star_rank0.right_idx, 0);
765        assert_eq!(star_rank0.positives_covered, 2);
766        assert_eq!(star_rank0.negatives_covered, 0);
767        assert_eq!(star_rank0.next_positives_covered, 1);
768        assert_eq!(star_rank0.next_negatives_covered, 0);
769        assert_eq!(star_rank0.tie_class_size, 1);
770
771        let star_rank1 = selected
772            .iter()
773            .find(|row| row.topology_idx == 1 && row.local_rank == 1)
774            .expect("star rank 1");
775        assert_eq!(star_rank1.left_idx, 0);
776        assert_eq!(star_rank1.right_idx, 1);
777        assert_eq!(star_rank1.positives_covered, 1);
778        assert_eq!(star_rank1.negatives_covered, 0);
779        assert_eq!(star_rank1.next_positives_covered, 1);
780        assert_eq!(star_rank1.next_negatives_covered, 0);
781        assert_eq!(star_rank1.tie_class_size, 6);
782    }
783
784    /// Determinism: the same inputs produce identical outputs on repeat runs.
785    /// The kernel relies on integer counts + each block owning one unique
786    /// output slot, so determinism is structural — no associativity or
787    /// floating-point ordering concerns. Still worth pinning as a regression
788    /// guard in case a future change swaps in atomics or shared state.
789    #[test]
790    fn ilp_exact_score_is_deterministic_across_runs() {
791        let provider = match make_provider() {
792            Some(p) => p,
793            None => {
794                eprintln!("Skipping test: no CUDA device available");
795                return;
796            }
797        };
798
799        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
800        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);
801        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
802        let negatives = pair_buffer(&provider, &[7], &[8]);
803
804        let run_a = provider
805            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
806            .unwrap();
807        let run_b = provider
808            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
809            .unwrap();
810        assert_eq!(run_a.0, run_b.0, "pos coverage drifted across runs");
811        assert_eq!(run_a.1, run_b.1, "neg coverage drifted across runs");
812    }
813
814    /// Empty negatives: when the caller supplies a zero-row negatives buffer
815    /// (the engine's normal treatment of `None`), the kernel must not
816    /// dereference the negative pointers and must leave all `neg_covered`
817    /// slots at zero.
818    #[test]
819    fn ilp_exact_score_handles_empty_negatives() {
820        let provider = match make_provider() {
821            Some(p) => p,
822            None => {
823                eprintln!("Skipping test: no CUDA device available");
824                return;
825            }
826        };
827
828        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
829        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);
830        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
831        let negatives = pair_buffer(&provider, &[], &[]);
832
833        let (pos, neg) = provider
834            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
835            .unwrap();
836
837        let mut expected_pos = vec![0u32; 16];
838        expected_pos[1] = 2;
839        assert_eq!(pos, expected_pos);
840        assert_eq!(neg, vec![0u32; 16]);
841    }
842
843    #[test]
844    fn ilp_exact_score_accepts_u32_pair_buffers() {
845        let provider = match make_provider() {
846            Some(p) => p,
847            None => {
848                eprintln!("Skipping test: no CUDA device available");
849                return;
850            }
851        };
852
853        let p_b = pair_buffer_u32(&provider, &[1, 2], &[2, 3], ScalarType::U32);
854        let p_c = pair_buffer_u32(&provider, &[2, 3, 4], &[4, 5, 6], ScalarType::U32);
855        let positives = pair_buffer_u32(&provider, &[1, 2], &[4, 5], ScalarType::U32);
856        let negatives = pair_buffer_u32(&provider, &[7], &[8], ScalarType::U32);
857
858        let (pos, neg) = provider
859            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
860            .expect("U32 ilp_exact_score launch");
861
862        let mut expected_pos = vec![0u32; 16];
863        expected_pos[1] = 2;
864        assert_eq!(pos, expected_pos);
865        assert_eq!(neg, vec![0u32; 16]);
866    }
867
868    #[test]
869    fn ilp_exact_score_accepts_symbol_pair_buffers() {
870        let provider = match make_provider() {
871            Some(p) => p,
872            None => {
873                eprintln!("Skipping test: no CUDA device available");
874                return;
875            }
876        };
877
878        let p_b = pair_buffer_u32(&provider, &[1, 2], &[2, 3], ScalarType::Symbol);
879        let p_c = pair_buffer_u32(&provider, &[2, 3, 4], &[4, 5, 6], ScalarType::Symbol);
880        let positives = pair_buffer_u32(&provider, &[1, 2], &[4, 5], ScalarType::Symbol);
881        let negatives = pair_buffer_u32(&provider, &[7], &[8], ScalarType::Symbol);
882
883        let (pos, neg) = provider
884            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
885            .expect("Symbol ilp_exact_score launch");
886
887        let mut expected_pos = vec![0u32; 16];
888        expected_pos[1] = 2;
889        assert_eq!(pos, expected_pos);
890        assert_eq!(neg, vec![0u32; 16]);
891    }
892
893    #[test]
894    fn ilp_exact_score_rejects_mixed_pair_types() {
895        let provider = match make_provider() {
896            Some(p) => p,
897            None => {
898                eprintln!("Skipping test: no CUDA device available");
899                return;
900            }
901        };
902
903        let p_b = pair_buffer_u32(&provider, &[1, 2], &[2, 3], ScalarType::U32);
904        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
905        let negatives = pair_buffer(&provider, &[7], &[8]);
906
907        let err = provider
908            .ilp_exact_score(&[&p_b], &positives, &negatives)
909            .expect_err("mixed U64/U32 buffers must be rejected");
910        assert!(
911            err.to_string().contains("expected U64") || err.to_string().contains("type mismatch"),
912            "unexpected error: {err}"
913        );
914    }
915
916    #[test]
917    fn ilp_exact_score_rejects_unsupported_pair_types() {
918        let provider = match make_provider() {
919            Some(p) => p,
920            None => {
921                eprintln!("Skipping test: no CUDA device available");
922                return;
923            }
924        };
925
926        let p_b = pair_buffer_i32(&provider, &[1, 2], &[2, 3]);
927        let positives = pair_buffer_i32(&provider, &[1, 2], &[4, 5]);
928        let negatives = pair_buffer_i32(&provider, &[7], &[8]);
929
930        let err = provider
931            .ilp_exact_score(&[&p_b], &positives, &negatives)
932            .expect_err("I32 pair buffers must be rejected");
933        assert!(
934            err.to_string().contains("expected U64, U32, or Symbol"),
935            "unexpected error: {err}"
936        );
937    }
938}