1use std::marker::PhantomData;
10use std::sync::atomic::Ordering;
11
12use crate::{LaunchAsync, LaunchConfig};
13use xlog_core::{Result, ScalarType, XlogError};
14
15use super::{ilp_exact_kernels, RawCudaView, ILP_EXACT_MODULE};
16use crate::memory::{CudaBuffer, TrackedCudaSlice};
17
/// Threads per block (`block_dim.x`) for the exact scoring kernels; also the tile width
/// used when sizing dynamic shared memory for the chain variants.
const ILP_EXACT_BLOCK_SIZE: u32 = 256;
/// Number of `u32` words per row written by the top-k selection kernel; each row decodes
/// into one `IlpExactTopkCandidate`.
const ILP_EXACT_TOPK_FIELDS: usize = 9;
/// Environment variable that turns the shared-memory "chain" scoring kernels off when set
/// to `0`, `false`, `off`, or `no` (case-insensitive, surrounding whitespace ignored).
/// Any other value, or leaving it unset, keeps them eligible.
const ENV_ILP_EXACT_CHAIN_SMEM: &str = "XLOG_ILP_EXACT_CHAIN_SMEM";
/// Environment variable overriding the minimum row count of the largest candidate buffer
/// at which the chain kernels are selected; unparsable values fall back to the default.
const ENV_ILP_EXACT_CHAIN_SMEM_MIN_ROWS: &str = "XLOG_ILP_EXACT_CHAIN_SMEM_MIN_ROWS";
/// Default threshold used when `XLOG_ILP_EXACT_CHAIN_SMEM_MIN_ROWS` is unset or invalid.
const DEFAULT_ILP_EXACT_CHAIN_SMEM_MIN_ROWS: u32 = 256;
23
/// One row of the device-side top-k selection returned by `ilp_exact_score_topk`.
///
/// Decoded from `ILP_EXACT_TOPK_FIELDS` consecutive `u32` words; the field order below is
/// the word order in the kernel output. Rows whose `positives_covered` is zero are dropped
/// during decoding and never appear in results.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct IlpExactTopkCandidate {
    /// Word 0: grid-z group the score came from (the scoring grid has 4 of them);
    /// presumably the rule topology — confirm against the kernel.
    pub topology_idx: u32,
    /// Word 1: presumably the index into `candidate_buffers` used as the left relation.
    pub left_idx: u32,
    /// Word 2: presumably the index into `candidate_buffers` used as the right relation.
    pub right_idx: u32,
    /// Word 3: positive-example coverage count reported by the kernel (never zero here).
    pub positives_covered: u32,
    /// Word 4: negative-example coverage count reported by the kernel.
    pub negatives_covered: u32,
    /// Word 5: rank of this row within its topology group (tests show 0 for the best).
    pub local_rank: u32,
    /// Word 6: presumably the positive coverage of the next-ranked entry — confirm.
    pub next_positives_covered: u32,
    /// Word 7: presumably the negative coverage of the next-ranked entry — confirm.
    pub next_negatives_covered: u32,
    /// Word 8: presumably how many entries share this row's score — confirm.
    pub tie_class_size: u32,
}
36
/// Device-resident output of the exact scoring kernel, kept on the GPU so callers can
/// either reduce it on-device (top-k) or copy it back (tests).
struct IlpExactDeviceScores {
    /// Number of candidate buffers `C` that were scored.
    candidate_count: usize,
    /// Length of each coverage buffer: `4 * C * C` slots.
    #[cfg(test)]
    slot_count: usize,
    /// Per-slot positive coverage counts written by the scoring kernel.
    pos_covered: TrackedCudaSlice<u32>,
    /// Per-slot negative coverage counts written by the scoring kernel.
    neg_covered: TrackedCudaSlice<u32>,
}
44
/// Element encoding shared by both columns of a pair buffer fed to exact scoring.
///
/// `U32` and `Symbol` are both 4 bytes wide and are routed to the same `u32` kernels;
/// `U64` uses the 8-byte kernels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ExactPairLayout {
    U64,
    U32,
    Symbol,
}

impl ExactPairLayout {
    /// Width in bytes of a single column element in this layout.
    fn elem_size(self) -> usize {
        use std::mem::size_of;

        match self {
            Self::U32 | Self::Symbol => size_of::<u32>(),
            Self::U64 => size_of::<u64>(),
        }
    }
}
60
61fn ilp_exact_chain_smem_enabled() -> bool {
62 match std::env::var(ENV_ILP_EXACT_CHAIN_SMEM) {
63 Ok(value) => !matches!(
64 value.trim().to_ascii_lowercase().as_str(),
65 "0" | "false" | "off" | "no"
66 ),
67 Err(_) => true,
68 }
69}
70
71fn chain_smem_shared_bytes(layout: ExactPairLayout) -> u32 {
72 let block = ILP_EXACT_BLOCK_SIZE as usize;
73 let bytes = (2usize * block * layout.elem_size()) + (block * std::mem::size_of::<u32>());
74 u32::try_from(bytes).expect("chain smem byte count fits in u32")
75}
76
77fn ilp_exact_chain_smem_min_rows() -> u32 {
78 std::env::var(ENV_ILP_EXACT_CHAIN_SMEM_MIN_ROWS)
79 .ok()
80 .and_then(|value| value.trim().parse::<u32>().ok())
81 .unwrap_or(DEFAULT_ILP_EXACT_CHAIN_SMEM_MIN_ROWS)
82}
83
// Exact ILP scoring entry points on the CUDA kernel provider.
impl super::CudaKernelProvider {
    /// Test-only wrapper: runs exact scoring and copies both coverage arrays to the host.
    ///
    /// Returns `(pos_covered, neg_covered)`, each holding `4 * C * C` entries where
    /// `C = candidate_buffers.len()`. Each of the two host copies increments
    /// `d2h_transfer_count` by one.
    ///
    /// # Errors
    /// Propagates validation and launch errors from `ilp_exact_score_device` and errors
    /// from device synchronization; D2H copy failures are reported as `XlogError::Kernel`.
    #[cfg(test)]
    fn ilp_exact_score(
        &self,
        candidate_buffers: &[&CudaBuffer],
        positives: &CudaBuffer,
        negatives: &CudaBuffer,
    ) -> Result<(Vec<u32>, Vec<u32>)> {
        let scores = self.ilp_exact_score_device(candidate_buffers, positives, negatives)?;
        let device = self.device.inner();
        // `ilp_exact_score_device` does not synchronize after its launch; wait here before
        // reading the results back.
        self.device.synchronize()?;

        let mut pos_covered = vec![0u32; scores.slot_count];
        self.d2h_transfer_count.fetch_add(1, Ordering::Relaxed);
        device
            .dtoh_sync_copy_into(&scores.pos_covered, &mut pos_covered)
            .map_err(|e| XlogError::Kernel(format!("ilp_exact_score: dtoh pos_covered: {}", e)))?;

        let mut neg_covered = vec![0u32; scores.slot_count];
        self.d2h_transfer_count.fetch_add(1, Ordering::Relaxed);
        device
            .dtoh_sync_copy_into(&scores.neg_covered, &mut neg_covered)
            .map_err(|e| XlogError::Kernel(format!("ilp_exact_score: dtoh neg_covered: {}", e)))?;

        Ok((pos_covered, neg_covered))
    }

    /// Scores every candidate pair on the GPU and returns a compact top-k selection.
    ///
    /// The full `4 * C * C` coverage arrays never leave the device. The
    /// `ILP_EXACT_SELECT_TOPK` kernel reduces them to `k_per_topology` output rows for each
    /// of the 4 grid-z groups (presumably one per rule topology — confirm against the
    /// kernel), and those `4 * k_per_topology * ILP_EXACT_TOPK_FIELDS` words are fetched
    /// with a single D2H copy (one `d2h_transfer_count` increment).
    ///
    /// Output rows whose `positives_covered` word is zero are skipped, so the result can be
    /// shorter than `4 * k_per_topology`. Rows keep the order the kernel wrote them in.
    /// When `k_per_topology == 0` this returns an empty vector without launching anything.
    ///
    /// # Errors
    /// Propagates scoring/validation errors and allocation/synchronization errors; output
    /// size overflow, a missing kernel, a failed launch, or a failed D2H copy are reported
    /// as `XlogError::Kernel`.
    pub fn ilp_exact_score_topk(
        &self,
        candidate_buffers: &[&CudaBuffer],
        positives: &CudaBuffer,
        negatives: &CudaBuffer,
        k_per_topology: u32,
    ) -> Result<Vec<IlpExactTopkCandidate>> {
        if k_per_topology == 0 {
            return Ok(Vec::new());
        }

        let scores = self.ilp_exact_score_device(candidate_buffers, positives, negatives)?;
        // 4 groups x k rows, each row ILP_EXACT_TOPK_FIELDS u32 words.
        let out_rows = 4usize
            .checked_mul(k_per_topology as usize)
            .ok_or_else(|| XlogError::Kernel("ilp_exact_score_topk: output row overflow".into()))?;
        let out_words = out_rows.checked_mul(ILP_EXACT_TOPK_FIELDS).ok_or_else(|| {
            XlogError::Kernel("ilp_exact_score_topk: output word overflow".into())
        })?;
        // NOTE(review): selected_buf is not explicitly zeroed here, yet decoding below treats
        // a zero word 3 as "no row". This relies on the kernel (or the allocator) writing
        // zeros into unused rows — confirm.
        let mut selected_buf = self.memory.alloc::<u32>(out_words)?;
        let device = self.device.inner();
        let func = device
            .get_func(ILP_EXACT_MODULE, ilp_exact_kernels::ILP_EXACT_SELECT_TOPK)
            .ok_or_else(|| {
                XlogError::Kernel(format!(
                    "{} kernel not loaded",
                    ilp_exact_kernels::ILP_EXACT_SELECT_TOPK
                ))
            })?;

        // SAFETY (review): the argument tuple must match the parameter list of
        // ILP_EXACT_SELECT_TOPK in ILP_EXACT_MODULE; that cannot be checked from this file.
        // All referenced device buffers are owned locals or fields of `scores`, alive here.
        // The `as u32` cast cannot truncate: ilp_exact_score_device rejects C > u32::MAX.
        unsafe {
            func.clone().launch(
                LaunchConfig {
                    grid_dim: (4, 1, 1),
                    block_dim: (1, 1, 1),
                    shared_mem_bytes: 0,
                },
                (
                    &scores.pos_covered,
                    &scores.neg_covered,
                    scores.candidate_count as u32,
                    k_per_topology,
                    &mut selected_buf,
                ),
            )
        }
        .map_err(|e| XlogError::Kernel(format!("ilp_exact_select_topk launch: {}", e)))?;

        self.device.synchronize()?;
        let mut words = vec![0u32; out_words];
        self.d2h_transfer_count.fetch_add(1, Ordering::Relaxed);
        device
            .dtoh_sync_copy_into(&selected_buf, &mut words)
            .map_err(|e| {
                XlogError::Kernel(format!("ilp_exact_score_topk: dtoh selected: {}", e))
            })?;

        // Decode fixed-width rows; word 3 (positives_covered) == 0 marks a row to skip.
        let mut selected = Vec::new();
        for chunk in words.chunks_exact(ILP_EXACT_TOPK_FIELDS) {
            if chunk[3] == 0 {
                continue;
            }
            selected.push(IlpExactTopkCandidate {
                topology_idx: chunk[0],
                left_idx: chunk[1],
                right_idx: chunk[2],
                positives_covered: chunk[3],
                negatives_covered: chunk[4],
                local_rank: chunk[5],
                next_positives_covered: chunk[6],
                next_negatives_covered: chunk[7],
                tie_class_size: chunk[8],
            });
        }
        Ok(selected)
    }

    /// Validates inputs, packs candidates on the device, and launches the scoring kernel.
    ///
    /// Steps:
    /// 1. All buffers (positives, negatives, every candidate) must be arity-2 pair buffers
    ///    with one common layout (`U64`, or `U32`/`Symbol`) and a cached row count.
    /// 2. Candidate column 0 and column 1 data are concatenated, via device-to-device
    ///    copies, into two contiguous byte buffers. A host prefix-sum of row counts
    ///    (`C + 1` entries, starting at 0) is uploaded as `cand_offsets`.
    /// 3. Two `u32` output buffers of `4 * C * C` slots are allocated.
    /// 4. The 8-byte or 4-byte kernel is launched with grid `(C, C, 4)` and block
    ///    `ILP_EXACT_BLOCK_SIZE`. The shared-memory "chain" variant is used when enabled by
    ///    environment and the largest candidate has at least the configured minimum rows.
    ///
    /// This function does not synchronize the device after the launch; callers must
    /// synchronize before reading `pos_covered` / `neg_covered` on the host.
    ///
    /// # Errors
    /// `XlogError::Kernel` for:
    /// - an empty candidate list, or more than `u32::MAX` candidates;
    /// - layout or arity mismatches, or missing row counts or columns;
    /// - row-count or slot-count overflow;
    /// - copy or launch failures, or a missing kernel.
    ///
    /// Allocation errors are propagated.
    fn ilp_exact_score_device(
        &self,
        candidate_buffers: &[&CudaBuffer],
        positives: &CudaBuffer,
        negatives: &CudaBuffer,
    ) -> Result<IlpExactDeviceScores> {
        let c = candidate_buffers.len();
        if c == 0 {
            return Err(XlogError::Kernel(
                "ilp_exact_score: candidate list is empty (filter at the engine)".to_string(),
            ));
        }
        // C is passed to the kernel and used as grid dimensions, both u32.
        let c_u32 = u32::try_from(c).map_err(|_| {
            XlogError::Kernel(format!(
                "ilp_exact_score: candidate count {} exceeds u32::MAX",
                c
            ))
        })?;

        // The positives decide the layout; every other buffer must agree with it.
        let layout = validate_exact_pair_buffer(positives, "positives")?;
        require_exact_pair_layout(negatives, "negatives", layout)?;
        let pos_rows = cached_rows(positives, "positives")?;
        let neg_rows = cached_rows(negatives, "negatives")?;

        let mut cand_rows: Vec<u32> = Vec::with_capacity(c);
        for (i, buf) in candidate_buffers.iter().enumerate() {
            let label = format!("candidate[{}]", i);
            require_exact_pair_layout(buf, &label, layout)?;
            cand_rows.push(cached_rows(buf, &label)?);
        }

        // Exclusive prefix sum of candidate row counts: offsets[i]..offsets[i + 1] is the
        // row range of candidate i inside the concatenated buffers built below.
        let mut cand_offsets_host: Vec<u32> = Vec::with_capacity(c + 1);
        let mut running: u32 = 0;
        cand_offsets_host.push(0);
        for &r in &cand_rows {
            running = running.checked_add(r).ok_or_else(|| {
                XlogError::Kernel("ilp_exact_score: candidate row count overflow u32".to_string())
            })?;
            cand_offsets_host.push(running);
        }
        let total_rows = running as usize;
        let elem_size = layout.elem_size();
        let total_bytes = total_rows * elem_size;

        let device = self.device.inner();

        // Concatenate column 0 (and column 1) of every candidate into one contiguous byte
        // buffer each, in candidate order. Empty candidates contribute no bytes.
        let mut cand_arg0_buf = self.memory.alloc::<u8>(total_bytes)?;
        let mut cand_arg1_buf = self.memory.alloc::<u8>(total_bytes)?;
        if total_bytes > 0 {
            let mut byte_offset: usize = 0;
            for (i, buf) in candidate_buffers.iter().enumerate() {
                let rows = cand_rows[i] as usize;
                if rows == 0 {
                    continue;
                }
                let bytes = rows * elem_size;

                let src0 = buf.column(0).ok_or_else(|| {
                    XlogError::Kernel(format!("candidate[{}] missing column 0", i))
                })?;
                let src1 = buf.column(1).ok_or_else(|| {
                    XlogError::Kernel(format!("candidate[{}] missing column 1", i))
                })?;
                let src_view0 = self.column_bytes_view(src0, bytes)?;
                let src_view1 = self.column_bytes_view(src1, bytes)?;
                let mut dst0 = cand_arg0_buf.slice_mut(byte_offset..byte_offset + bytes);
                let mut dst1 = cand_arg1_buf.slice_mut(byte_offset..byte_offset + bytes);
                device.dtod_copy(&src_view0, &mut dst0).map_err(|e| {
                    XlogError::Kernel(format!(
                        "ilp_exact_score: d2d concat arg0 (candidate {}): {}",
                        i, e
                    ))
                })?;
                device.dtod_copy(&src_view1, &mut dst1).map_err(|e| {
                    XlogError::Kernel(format!(
                        "ilp_exact_score: d2d concat arg1 (candidate {}): {}",
                        i, e
                    ))
                })?;
                byte_offset += bytes;
            }
        }

        let mut cand_offsets_buf = self.memory.alloc::<u32>(c + 1)?;
        self.htod_sync_copy_into_tracked(&cand_offsets_host, &mut cand_offsets_buf)
            .map_err(|e| XlogError::Kernel(format!("ilp_exact_score: h2d cand_offsets: {}", e)))?;

        // One output slot per (grid-z group, left candidate, right candidate).
        let n_slots = 4usize
            .checked_mul(c)
            .and_then(|v| v.checked_mul(c))
            .ok_or_else(|| {
                XlogError::Kernel("ilp_exact_score: n_slots = 4 * C * C overflow".to_string())
            })?;
        let mut pos_covered_buf = self.memory.alloc::<u32>(n_slots)?;
        let mut neg_covered_buf = self.memory.alloc::<u32>(n_slots)?;
        let pos_col0 = positives
            .column(0)
            .ok_or_else(|| XlogError::Kernel("positives: missing column 0".to_string()))?;
        let pos_col1 = positives
            .column(1)
            .ok_or_else(|| XlogError::Kernel("positives: missing column 1".to_string()))?;
        let neg_col0 = negatives
            .column(0)
            .ok_or_else(|| XlogError::Kernel("negatives: missing column 0".to_string()))?;
        let neg_col1 = negatives
            .column(1)
            .ok_or_else(|| XlogError::Kernel("negatives: missing column 1".to_string()))?;

        // The chain variant is chosen only when the env switch allows it AND the largest
        // candidate meets the row threshold. `&&` short-circuits, so the threshold env var
        // is only read when the switch is on.
        let max_candidate_rows = cand_rows.iter().copied().max().unwrap_or(0);
        let chain_smem_enabled =
            ilp_exact_chain_smem_enabled() && max_candidate_rows >= ilp_exact_chain_smem_min_rows();
        let shared_mem_bytes = if chain_smem_enabled {
            chain_smem_shared_bytes(layout)
        } else {
            0
        };
        match layout {
            ExactPairLayout::U64 => {
                // Reinterpret the concatenated byte buffers as u64 element views. The views
                // carry no source_block, so they rely on the buffers above staying alive.
                let cand_arg0_view = RawCudaView::<u64> {
                    ptr: *cand_arg0_buf.device_ptr(),
                    len: total_rows,
                    stream: cand_arg0_buf.stream().clone(),
                    source_block: None,
                    _marker: PhantomData,
                };
                let cand_arg1_view = RawCudaView::<u64> {
                    ptr: *cand_arg1_buf.device_ptr(),
                    len: total_rows,
                    stream: cand_arg1_buf.stream().clone(),
                    source_block: None,
                    _marker: PhantomData,
                };
                let pos_arg0_view = self.column_as_u64_view(pos_col0, pos_rows as usize)?;
                let pos_arg1_view = self.column_as_u64_view(pos_col1, pos_rows as usize)?;
                let neg_arg0_view = self.column_as_u64_view(neg_col0, neg_rows as usize)?;
                let neg_arg1_view = self.column_as_u64_view(neg_col1, neg_rows as usize)?;
                let kernel_name = if chain_smem_enabled {
                    ilp_exact_kernels::ILP_EXACT_SCORE_CHAIN_SMEM
                } else {
                    ilp_exact_kernels::ILP_EXACT_SCORE
                };
                let func = device
                    .get_func(ILP_EXACT_MODULE, kernel_name)
                    .ok_or_else(|| {
                        XlogError::Kernel(format!("{} kernel not loaded", kernel_name))
                    })?;
                // SAFETY (review): the tuple order and types must match the u64 scoring
                // kernel signature in ILP_EXACT_MODULE, which is not verifiable from this
                // file. All views and buffers referenced are alive for this call.
                unsafe {
                    func.clone().launch(
                        LaunchConfig {
                            grid_dim: (c_u32, c_u32, 4),
                            block_dim: (ILP_EXACT_BLOCK_SIZE, 1, 1),
                            shared_mem_bytes,
                        },
                        (
                            &cand_arg0_view,
                            &cand_arg1_view,
                            &cand_offsets_buf,
                            c_u32,
                            &pos_arg0_view,
                            &pos_arg1_view,
                            pos_rows,
                            &neg_arg0_view,
                            &neg_arg1_view,
                            neg_rows,
                            &mut pos_covered_buf,
                            &mut neg_covered_buf,
                        ),
                    )
                }
                .map_err(|e| XlogError::Kernel(format!("ilp_exact_score launch: {}", e)))?;
            }
            ExactPairLayout::U32 | ExactPairLayout::Symbol => {
                // Same as the U64 arm, with 4-byte element views and the *_U32 kernels.
                let cand_arg0_view = RawCudaView::<u32> {
                    ptr: *cand_arg0_buf.device_ptr(),
                    len: total_rows,
                    stream: cand_arg0_buf.stream().clone(),
                    source_block: None,
                    _marker: PhantomData,
                };
                let cand_arg1_view = RawCudaView::<u32> {
                    ptr: *cand_arg1_buf.device_ptr(),
                    len: total_rows,
                    stream: cand_arg1_buf.stream().clone(),
                    source_block: None,
                    _marker: PhantomData,
                };
                let pos_arg0_view = self.column_as_u32_view(pos_col0, pos_rows as usize)?;
                let pos_arg1_view = self.column_as_u32_view(pos_col1, pos_rows as usize)?;
                let neg_arg0_view = self.column_as_u32_view(neg_col0, neg_rows as usize)?;
                let neg_arg1_view = self.column_as_u32_view(neg_col1, neg_rows as usize)?;
                let kernel_name = if chain_smem_enabled {
                    ilp_exact_kernels::ILP_EXACT_SCORE_CHAIN_SMEM_U32
                } else {
                    ilp_exact_kernels::ILP_EXACT_SCORE_U32
                };
                let func = device
                    .get_func(ILP_EXACT_MODULE, kernel_name)
                    .ok_or_else(|| {
                        XlogError::Kernel(format!("{} kernel not loaded", kernel_name))
                    })?;
                // SAFETY (review): the tuple order and types must match the u32 scoring
                // kernel signature in ILP_EXACT_MODULE, which is not verifiable from this
                // file. All views and buffers referenced are alive for this call.
                unsafe {
                    func.clone().launch(
                        LaunchConfig {
                            grid_dim: (c_u32, c_u32, 4),
                            block_dim: (ILP_EXACT_BLOCK_SIZE, 1, 1),
                            shared_mem_bytes,
                        },
                        (
                            &cand_arg0_view,
                            &cand_arg1_view,
                            &cand_offsets_buf,
                            c_u32,
                            &pos_arg0_view,
                            &pos_arg1_view,
                            pos_rows,
                            &neg_arg0_view,
                            &neg_arg1_view,
                            neg_rows,
                            &mut pos_covered_buf,
                            &mut neg_covered_buf,
                        ),
                    )
                }
                .map_err(|e| XlogError::Kernel(format!("ilp_exact_score_u32 launch: {}", e)))?;
            }
        }

        // NOTE(review): cand_arg0_buf, cand_arg1_buf and cand_offsets_buf are dropped when this
        // function returns, but the launch above is not synchronized here. Correctness relies
        // on their release being ordered after the kernel on its stream (or deferred by the
        // memory manager) — confirm.
        Ok(IlpExactDeviceScores {
            candidate_count: c,
            #[cfg(test)]
            slot_count: n_slots,
            pos_covered: pos_covered_buf,
            neg_covered: neg_covered_buf,
        })
    }
}
453
454fn validate_exact_pair_buffer(buf: &CudaBuffer, label: &str) -> Result<ExactPairLayout> {
455 if buf.arity() != 2 {
456 return Err(XlogError::Kernel(format!(
457 "ilp_exact_score: {} buffer arity = {}, expected 2",
458 label,
459 buf.arity(),
460 )));
461 }
462 let mut layout: Option<ExactPairLayout> = None;
463 for col_idx in 0..2 {
464 let t = buf.schema().column_type(col_idx).ok_or_else(|| {
465 XlogError::Kernel(format!(
466 "ilp_exact_score: {} buffer missing column {} type",
467 label, col_idx,
468 ))
469 })?;
470 let col_layout = match t {
471 ScalarType::U64 => ExactPairLayout::U64,
472 ScalarType::U32 => ExactPairLayout::U32,
473 ScalarType::Symbol => ExactPairLayout::Symbol,
474 _ => {
475 return Err(XlogError::Kernel(format!(
476 "ilp_exact_score: {} buffer column {} type = {:?}, expected U64, U32, or Symbol",
477 label, col_idx, t,
478 )));
479 }
480 };
481 if let Some(expected) = layout {
482 if expected != col_layout {
483 return Err(XlogError::Kernel(format!(
484 "ilp_exact_score: {} buffer column {} type mismatch: {:?} vs {:?}",
485 label, col_idx, expected, col_layout,
486 )));
487 }
488 } else {
489 layout = Some(col_layout);
490 }
491 }
492 Ok(layout.expect("arity 2 loop sets layout"))
493}
494
495fn require_exact_pair_layout(
496 buf: &CudaBuffer,
497 label: &str,
498 expected: ExactPairLayout,
499) -> Result<()> {
500 let actual = validate_exact_pair_buffer(buf, label)?;
501 if actual != expected {
502 return Err(XlogError::Kernel(format!(
503 "ilp_exact_score: {} buffer type mismatch: expected {:?}, got {:?}",
504 label, expected, actual,
505 )));
506 }
507 Ok(())
508}
509
510fn cached_rows(buf: &CudaBuffer, label: &str) -> Result<u32> {
511 buf.cached_row_count().ok_or_else(|| {
512 XlogError::Kernel(format!(
513 "ilp_exact_score: {} buffer has no cached row count \
514 (DLPack ingest and create_empty_buffer both populate it)",
515 label
516 ))
517 })
518}
519
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use xlog_core::{MemoryBudget, ScalarType, Schema};

    use crate::{CudaDevice, CudaKernelProvider, GpuMemoryManager};

    /// Builds a provider on device 0 with a 1 GiB budget, or `None` when CUDA is unavailable
    /// (tests then skip).
    fn make_provider() -> Option<CudaKernelProvider> {
        let device = Arc::new(CudaDevice::new(0).ok()?);
        let budget = MemoryBudget::with_limit(1024 * 1024 * 1024);
        let memory = Arc::new(GpuMemoryManager::new(device.clone(), budget));
        CudaKernelProvider::new(device, memory).ok()
    }

    /// Uploads two little-endian encoded columns of `rows` rows as a pair buffer with `schema`.
    ///
    /// Shared by the typed `pair_buffer*` helpers so the upload path exists once. A zero-row
    /// input goes through `create_empty_buffer`, which also populates the cached row count.
    fn pair_buffer_from_le_bytes(
        provider: &CudaKernelProvider,
        schema: Schema,
        rows: usize,
        arg0_bytes: Vec<u8>,
        arg1_bytes: Vec<u8>,
    ) -> crate::CudaBuffer {
        if rows == 0 {
            return provider
                .create_empty_buffer(schema)
                .expect("empty pair buffer");
        }
        let device = provider.device().inner();
        let mut col0 = provider
            .memory()
            .alloc::<u8>(arg0_bytes.len())
            .expect("alloc");
        let mut col1 = provider
            .memory()
            .alloc::<u8>(arg1_bytes.len())
            .expect("alloc");
        device
            .htod_sync_copy_into(&arg0_bytes, &mut col0)
            .expect("h2d arg0");
        device
            .htod_sync_copy_into(&arg1_bytes, &mut col1)
            .expect("h2d arg1");
        provider
            .buffer_from_columns(vec![col0.into(), col1.into()], rows as u64, schema)
            .expect("buffer_from_columns")
    }

    /// Builds a `(U64, U64)` pair buffer from parallel column slices.
    fn pair_buffer(provider: &CudaKernelProvider, arg0: &[u64], arg1: &[u64]) -> crate::CudaBuffer {
        assert_eq!(arg0.len(), arg1.len());
        let schema = Schema::new(vec![
            ("arg0".to_string(), ScalarType::U64),
            ("arg1".to_string(), ScalarType::U64),
        ]);
        pair_buffer_from_le_bytes(
            provider,
            schema,
            arg0.len(),
            arg0.iter().flat_map(|v| v.to_le_bytes()).collect(),
            arg1.iter().flat_map(|v| v.to_le_bytes()).collect(),
        )
    }

    /// Builds a 4-byte pair buffer whose columns are both `typ` (`U32` or `Symbol`).
    fn pair_buffer_u32(
        provider: &CudaKernelProvider,
        arg0: &[u32],
        arg1: &[u32],
        typ: ScalarType,
    ) -> crate::CudaBuffer {
        assert_eq!(arg0.len(), arg1.len());
        assert!(matches!(typ, ScalarType::U32 | ScalarType::Symbol));
        let schema = Schema::new(vec![("arg0".to_string(), typ), ("arg1".to_string(), typ)]);
        pair_buffer_from_le_bytes(
            provider,
            schema,
            arg0.len(),
            arg0.iter().flat_map(|v| v.to_le_bytes()).collect(),
            arg1.iter().flat_map(|v| v.to_le_bytes()).collect(),
        )
    }

    /// Builds an `(I32, I32)` pair buffer, an element type exact scoring must reject.
    fn pair_buffer_i32(
        provider: &CudaKernelProvider,
        arg0: &[i32],
        arg1: &[i32],
    ) -> crate::CudaBuffer {
        assert_eq!(arg0.len(), arg1.len());
        let schema = Schema::new(vec![
            ("arg0".to_string(), ScalarType::I32),
            ("arg1".to_string(), ScalarType::I32),
        ]);
        pair_buffer_from_le_bytes(
            provider,
            schema,
            arg0.len(),
            arg0.iter().flat_map(|v| v.to_le_bytes()).collect(),
            arg1.iter().flat_map(|v| v.to_le_bytes()).collect(),
        )
    }

    /// Checks exact scoring against a small fixture worked out by hand.
    ///
    /// Joining p_b = {(1,2),(2,3)} with p_c = {(2,4),(3,5),(4,6)} on the shared middle value
    /// yields {(1,4),(2,5)}. That covers both positives and the lone negative (7,8) not at all.
    /// With C = 2 there are 4 * 2 * 2 = 16 slots; only slot 1 is expected to be non-zero.
    /// Slot 1 is presumably topology 0, left = p_b, right = p_c, matching the top-k winner below.
    #[test]
    fn ilp_exact_score_matches_hand_computed_fixture() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);

        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
        let negatives = pair_buffer(&provider, &[7], &[8]);

        let (pos, neg) = provider
            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
            .expect("ilp_exact_score launch");

        let mut expected_pos = vec![0u32; 16];
        expected_pos[1] = 2;
        assert_eq!(
            pos, expected_pos,
            "positives coverage mismatch: expected {:?}, got {:?}",
            expected_pos, pos,
        );

        let expected_neg = vec![0u32; 16];
        assert_eq!(
            neg, expected_neg,
            "negatives coverage mismatch: expected {:?}, got {:?}",
            expected_neg, neg,
        );
    }

    /// On the same fixture, top-k selection runs on the device with exactly one D2H copy.
    /// It returns only the single row with non-zero positive coverage.
    #[test]
    fn ilp_exact_score_topk_reduces_on_device_to_compact_result() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);
        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
        let negatives = pair_buffer(&provider, &[7], &[8]);

        provider.reset_d2h_transfer_count();
        let selected = provider
            .ilp_exact_score_topk(&[&p_b, &p_c], &positives, &negatives, 2)
            .expect("ilp_exact_score_topk launch");

        assert_eq!(provider.d2h_transfer_count(), 1);
        assert_eq!(selected.len(), 1);
        let winner = selected[0];
        assert_eq!(winner.topology_idx, 0);
        assert_eq!(winner.left_idx, 0);
        assert_eq!(winner.right_idx, 1);
        assert_eq!(winner.positives_covered, 2);
        assert_eq!(winner.negatives_covered, 0);
        assert_eq!(winner.local_rank, 0);
        assert_eq!(winner.next_positives_covered, 0);
        assert_eq!(winner.next_negatives_covered, 0);
        assert_eq!(winner.tie_class_size, 1);
    }

    /// Checks that the rank, next-entry and tie-class diagnostics survive the top-k reduction.
    /// The fixture has several candidate pairs with equal coverage.
    #[test]
    fn ilp_exact_score_topk_preserves_rank_next_and_tie_diagnostics() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_all = pair_buffer(&provider, &[1, 2], &[1, 2]);
        let p_one = pair_buffer(&provider, &[1], &[1]);
        let p_two = pair_buffer(&provider, &[2], &[2]);
        let positives = pair_buffer(&provider, &[1, 2], &[1, 2]);
        let negatives = pair_buffer(&provider, &[9], &[9]);

        let selected = provider
            .ilp_exact_score_topk(&[&p_all, &p_one, &p_two], &positives, &negatives, 2)
            .expect("ilp_exact_score_topk launch");

        let star_rank0 = selected
            .iter()
            .find(|row| row.topology_idx == 1 && row.local_rank == 0)
            .expect("star rank 0");
        assert_eq!(star_rank0.left_idx, 0);
        assert_eq!(star_rank0.right_idx, 0);
        assert_eq!(star_rank0.positives_covered, 2);
        assert_eq!(star_rank0.negatives_covered, 0);
        assert_eq!(star_rank0.next_positives_covered, 1);
        assert_eq!(star_rank0.next_negatives_covered, 0);
        assert_eq!(star_rank0.tie_class_size, 1);

        let star_rank1 = selected
            .iter()
            .find(|row| row.topology_idx == 1 && row.local_rank == 1)
            .expect("star rank 1");
        assert_eq!(star_rank1.left_idx, 0);
        assert_eq!(star_rank1.right_idx, 1);
        assert_eq!(star_rank1.positives_covered, 1);
        assert_eq!(star_rank1.negatives_covered, 0);
        assert_eq!(star_rank1.next_positives_covered, 1);
        assert_eq!(star_rank1.next_negatives_covered, 0);
        assert_eq!(star_rank1.tie_class_size, 6);
    }

    /// Two runs on identical inputs must produce bit-identical coverage arrays.
    #[test]
    fn ilp_exact_score_is_deterministic_across_runs() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);
        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
        let negatives = pair_buffer(&provider, &[7], &[8]);

        let run_a = provider
            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
            .unwrap();
        let run_b = provider
            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
            .unwrap();
        assert_eq!(run_a.0, run_b.0, "pos coverage drifted across runs");
        assert_eq!(run_a.1, run_b.1, "neg coverage drifted across runs");
    }

    /// A zero-row negatives buffer is accepted and yields all-zero negative coverage.
    #[test]
    fn ilp_exact_score_handles_empty_negatives() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer(&provider, &[1, 2], &[2, 3]);
        let p_c = pair_buffer(&provider, &[2, 3, 4], &[4, 5, 6]);
        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
        let negatives = pair_buffer(&provider, &[], &[]);

        let (pos, neg) = provider
            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
            .unwrap();

        let mut expected_pos = vec![0u32; 16];
        expected_pos[1] = 2;
        assert_eq!(pos, expected_pos);
        assert_eq!(neg, vec![0u32; 16]);
    }

    /// The hand-computed fixture encoded as `U32` columns gives the same coverage.
    #[test]
    fn ilp_exact_score_accepts_u32_pair_buffers() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer_u32(&provider, &[1, 2], &[2, 3], ScalarType::U32);
        let p_c = pair_buffer_u32(&provider, &[2, 3, 4], &[4, 5, 6], ScalarType::U32);
        let positives = pair_buffer_u32(&provider, &[1, 2], &[4, 5], ScalarType::U32);
        let negatives = pair_buffer_u32(&provider, &[7], &[8], ScalarType::U32);

        let (pos, neg) = provider
            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
            .expect("U32 ilp_exact_score launch");

        let mut expected_pos = vec![0u32; 16];
        expected_pos[1] = 2;
        assert_eq!(pos, expected_pos);
        assert_eq!(neg, vec![0u32; 16]);
    }

    /// The hand-computed fixture encoded as `Symbol` columns gives the same coverage.
    #[test]
    fn ilp_exact_score_accepts_symbol_pair_buffers() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer_u32(&provider, &[1, 2], &[2, 3], ScalarType::Symbol);
        let p_c = pair_buffer_u32(&provider, &[2, 3, 4], &[4, 5, 6], ScalarType::Symbol);
        let positives = pair_buffer_u32(&provider, &[1, 2], &[4, 5], ScalarType::Symbol);
        let negatives = pair_buffer_u32(&provider, &[7], &[8], ScalarType::Symbol);

        let (pos, neg) = provider
            .ilp_exact_score(&[&p_b, &p_c], &positives, &negatives)
            .expect("Symbol ilp_exact_score launch");

        let mut expected_pos = vec![0u32; 16];
        expected_pos[1] = 2;
        assert_eq!(pos, expected_pos);
        assert_eq!(neg, vec![0u32; 16]);
    }

    /// A `U32` candidate scored against `U64` positives/negatives must be rejected.
    #[test]
    fn ilp_exact_score_rejects_mixed_pair_types() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer_u32(&provider, &[1, 2], &[2, 3], ScalarType::U32);
        let positives = pair_buffer(&provider, &[1, 2], &[4, 5]);
        let negatives = pair_buffer(&provider, &[7], &[8]);

        let err = provider
            .ilp_exact_score(&[&p_b], &positives, &negatives)
            .expect_err("mixed U64/U32 buffers must be rejected");
        assert!(
            err.to_string().contains("expected U64") || err.to_string().contains("type mismatch"),
            "unexpected error: {err}"
        );
    }

    /// `I32` pair buffers are not a supported layout and must be rejected.
    #[test]
    fn ilp_exact_score_rejects_unsupported_pair_types() {
        let provider = match make_provider() {
            Some(p) => p,
            None => {
                eprintln!("Skipping test: no CUDA device available");
                return;
            }
        };

        let p_b = pair_buffer_i32(&provider, &[1, 2], &[2, 3]);
        let positives = pair_buffer_i32(&provider, &[1, 2], &[4, 5]);
        let negatives = pair_buffer_i32(&provider, &[7], &[8]);

        let err = provider
            .ilp_exact_score(&[&p_b], &positives, &negatives)
            .expect_err("I32 pair buffers must be rejected");
        assert!(
            err.to_string().contains("expected U64, U32, or Symbol"),
            "unexpected error: {err}"
        );
    }
}
938}