Skip to main content

xlog_cuda/provider/
ilp.rs

1//! ILP (Inductive Logic Programming) kernel operations: credit/loss, COO fill, CSR histogram, reduce_sum.
2
3use std::marker::PhantomData;
4
5use crate::{DeviceSlice, LaunchAsync, LaunchConfig};
6use xlog_core::{Result, ScalarType, Schema, XlogError};
7
8use super::{ilp_credit_kernels, ilp_kernels, RawCudaView, ILP_CREDIT_MODULE, ILP_MODULE};
9use crate::memory::{CudaBuffer, CudaColumn, TrackedCudaSlice};
10
11impl super::CudaKernelProvider {
12    fn ilp_i32_view<'a>(
13        &self,
14        col: &'a CudaColumn,
15        num_elements: usize,
16    ) -> Result<RawCudaView<'a, i32>> {
17        let required_bytes = num_elements * std::mem::size_of::<i32>();
18        if col.num_bytes() < required_bytes {
19            return Err(XlogError::Kernel(format!(
20                "Column has {} bytes but {} required for {} i32 elements",
21                col.num_bytes(),
22                required_bytes,
23                num_elements
24            )));
25        }
26        let ptr = *col.device_ptr();
27        if !(ptr as usize).is_multiple_of(std::mem::align_of::<i32>()) {
28            return Err(XlogError::Kernel(
29                "Column device pointer is not i32-aligned".to_string(),
30            ));
31        }
32        Ok(RawCudaView {
33            ptr,
34            len: num_elements,
35            stream: col.stream().clone(),
36            source_block: None,
37            _marker: PhantomData,
38        })
39    }
40
41    fn ilp_i64_view<'a>(
42        &self,
43        col: &'a CudaColumn,
44        num_elements: usize,
45    ) -> Result<RawCudaView<'a, i64>> {
46        let required_bytes = num_elements * std::mem::size_of::<i64>();
47        if col.num_bytes() < required_bytes {
48            return Err(XlogError::Kernel(format!(
49                "Column has {} bytes but {} required for {} i64 elements",
50                col.num_bytes(),
51                required_bytes,
52                num_elements
53            )));
54        }
55        let ptr = *col.device_ptr();
56        if !(ptr as usize).is_multiple_of(std::mem::align_of::<i64>()) {
57            return Err(XlogError::Kernel(
58                "Column device pointer is not i64-aligned".to_string(),
59            ));
60        }
61        Ok(RawCudaView {
62            ptr,
63            len: num_elements,
64            stream: col.stream().clone(),
65            source_block: None,
66            _marker: PhantomData,
67        })
68    }
69
70    pub fn build_selected_id_mask(
71        &self,
72        ids_buf: &CudaBuffer,
73        candidate_count: usize,
74    ) -> Result<CudaBuffer> {
75        let selected_len = usize::try_from(ids_buf.num_rows())
76            .map_err(|_| XlogError::Kernel("selected id row count overflow".to_string()))?;
77        let candidate_count_u32 = u32::try_from(candidate_count).map_err(|_| {
78            XlogError::Kernel(format!(
79                "candidate count {} exceeds u32::MAX for strict sparse mask",
80                candidate_count
81            ))
82        })?;
83
84        let mut active_flags = self.memory.alloc::<u32>(candidate_count)?;
85        if candidate_count > 0 {
86            self.device
87                .inner()
88                .memset_zeros(&mut active_flags)
89                .map_err(|e| XlogError::Kernel(format!("zero strict sparse mask: {}", e)))?;
90        }
91
92        if selected_len > 0 {
93            let selected_len_u32 = u32::try_from(selected_len).map_err(|_| {
94                XlogError::Kernel(format!(
95                    "selected id count {} exceeds u32::MAX for strict sparse mask",
96                    selected_len
97                ))
98            })?;
99            let block_size = 256u32;
100            let grid_size = selected_len_u32.div_ceil(block_size);
101            let ids_col = ids_buf
102                .column(0)
103                .ok_or_else(|| XlogError::Kernel("selected id buffer has no column".to_string()))?;
104            match ids_buf.schema().column_type(0).ok_or_else(|| {
105                XlogError::Kernel("selected id buffer has no schema type".to_string())
106            })? {
107                ScalarType::U32 | ScalarType::Symbol => {
108                    let ids_view = self.column_as_u32_view(ids_col, selected_len)?;
109                    let func = self
110                        .device
111                        .inner()
112                        .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_U32)
113                        .ok_or_else(|| {
114                            XlogError::Kernel(
115                                "ilp_mark_selected_ids_u32 kernel not found".to_string(),
116                            )
117                        })?;
118                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
119                    unsafe {
120                        func.clone().launch(
121                            LaunchConfig {
122                                grid_dim: (grid_size, 1, 1),
123                                block_dim: (block_size, 1, 1),
124                                shared_mem_bytes: 0,
125                            },
126                            (
127                                &ids_view,
128                                selected_len_u32,
129                                candidate_count_u32,
130                                &mut active_flags,
131                            ),
132                        )
133                    }
134                    .map_err(|e| {
135                        XlogError::Kernel(format!(
136                            "strict sparse selected-id scatter failed: {}",
137                            e
138                        ))
139                    })?;
140                }
141                ScalarType::I32 => {
142                    let ids_view = self.ilp_i32_view(ids_col, selected_len)?;
143                    let func = self
144                        .device
145                        .inner()
146                        .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_I32)
147                        .ok_or_else(|| {
148                            XlogError::Kernel(
149                                "ilp_mark_selected_ids_i32 kernel not found".to_string(),
150                            )
151                        })?;
152                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
153                    unsafe {
154                        func.clone().launch(
155                            LaunchConfig {
156                                grid_dim: (grid_size, 1, 1),
157                                block_dim: (block_size, 1, 1),
158                                shared_mem_bytes: 0,
159                            },
160                            (
161                                &ids_view,
162                                selected_len_u32,
163                                candidate_count_u32,
164                                &mut active_flags,
165                            ),
166                        )
167                    }
168                    .map_err(|e| {
169                        XlogError::Kernel(format!(
170                            "strict sparse selected-id scatter failed: {}",
171                            e
172                        ))
173                    })?;
174                }
175                ScalarType::I64 => {
176                    let ids_view = self.ilp_i64_view(ids_col, selected_len)?;
177                    let func = self
178                        .device
179                        .inner()
180                        .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_I64)
181                        .ok_or_else(|| {
182                            XlogError::Kernel(
183                                "ilp_mark_selected_ids_i64 kernel not found".to_string(),
184                            )
185                        })?;
186                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
187                    unsafe {
188                        func.clone().launch(
189                            LaunchConfig {
190                                grid_dim: (grid_size, 1, 1),
191                                block_dim: (block_size, 1, 1),
192                                shared_mem_bytes: 0,
193                            },
194                            (
195                                &ids_view,
196                                selected_len_u32,
197                                candidate_count_u32,
198                                &mut active_flags,
199                            ),
200                        )
201                    }
202                    .map_err(|e| {
203                        XlogError::Kernel(format!(
204                            "strict sparse selected-id scatter failed: {}",
205                            e
206                        ))
207                    })?;
208                }
209                ScalarType::U64 => {
210                    let ids_view = self.column_as_u64_view(ids_col, selected_len)?;
211                    let func = self
212                        .device
213                        .inner()
214                        .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_U64)
215                        .ok_or_else(|| {
216                            XlogError::Kernel(
217                                "ilp_mark_selected_ids_u64 kernel not found".to_string(),
218                            )
219                        })?;
220                    // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
221                    unsafe {
222                        func.clone().launch(
223                            LaunchConfig {
224                                grid_dim: (grid_size, 1, 1),
225                                block_dim: (block_size, 1, 1),
226                                shared_mem_bytes: 0,
227                            },
228                            (
229                                &ids_view,
230                                selected_len_u32,
231                                candidate_count_u32,
232                                &mut active_flags,
233                            ),
234                        )
235                    }
236                    .map_err(|e| {
237                        XlogError::Kernel(format!(
238                            "strict sparse selected-id scatter failed: {}",
239                            e
240                        ))
241                    })?;
242                }
243                other => {
244                    return Err(XlogError::Kernel(format!(
245                        "selected candidate ids must be I32/I64/U32/U64, got {:?}",
246                        other
247                    )));
248                }
249            }
250
251            self.device
252                .synchronize()
253                .map_err(|e| XlogError::Kernel(format!("strict sparse scatter sync: {}", e)))?;
254        }
255
256        let d_num_rows = self.upload_device_row_count(candidate_count_u32)?;
257        Ok(CudaBuffer::from_columns_with_host_count(
258            vec![active_flags.into_bytes().into()],
259            candidate_count as u64,
260            d_num_rows,
261            Schema::new(vec![("active".to_string(), ScalarType::U32)]),
262            candidate_count_u32,
263        ))
264    }
265
266    pub fn validate_selected_ids(
267        &self,
268        ids_buf: &CudaBuffer,
269        candidate_count: usize,
270    ) -> Result<()> {
271        let selected_len = usize::try_from(ids_buf.num_rows())
272            .map_err(|_| XlogError::Kernel("selected id row count overflow".to_string()))?;
273        let candidate_count_u32 = u32::try_from(candidate_count).map_err(|_| {
274            XlogError::Kernel(format!(
275                "candidate count {} exceeds u32::MAX for strict sparse mask",
276                candidate_count
277            ))
278        })?;
279
280        if selected_len == 0 {
281            return Ok(());
282        }
283
284        let selected_len_u32 = u32::try_from(selected_len).map_err(|_| {
285            XlogError::Kernel(format!(
286                "selected id count {} exceeds u32::MAX for strict sparse mask",
287                selected_len
288            ))
289        })?;
290        let block_size = 256u32;
291        let grid_size = selected_len_u32.div_ceil(block_size);
292        let ids_col = ids_buf
293            .column(0)
294            .ok_or_else(|| XlogError::Kernel("selected id buffer has no column".to_string()))?;
295
296        let mut seen_flags = self.memory.alloc::<u32>(candidate_count)?;
297        if candidate_count > 0 {
298            self.device
299                .inner()
300                .memset_zeros(&mut seen_flags)
301                .map_err(|e| {
302                    XlogError::Kernel(format!("zero strict sparse validation flags: {}", e))
303                })?;
304        }
305
306        let mut error_code = self.memory.alloc::<u32>(1)?;
307        let mut error_pos = self.memory.alloc::<u32>(1)?;
308        self.device
309            .inner()
310            .memset_zeros(&mut error_code)
311            .map_err(|e| XlogError::Kernel(format!("zero strict sparse error code: {}", e)))?;
312        self.device
313            .inner()
314            .memset_zeros(&mut error_pos)
315            .map_err(|e| XlogError::Kernel(format!("zero strict sparse error pos: {}", e)))?;
316
317        match ids_buf
318            .schema()
319            .column_type(0)
320            .ok_or_else(|| XlogError::Kernel("selected id buffer has no schema type".to_string()))?
321        {
322            ScalarType::U32 | ScalarType::Symbol => {
323                let ids_view = self.column_as_u32_view(ids_col, selected_len)?;
324                let func = self
325                    .device
326                    .inner()
327                    .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_U32)
328                    .ok_or_else(|| {
329                        XlogError::Kernel(
330                            "ilp_validate_selected_ids_u32 kernel not found".to_string(),
331                        )
332                    })?;
333                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
334                unsafe {
335                    func.clone().launch(
336                        LaunchConfig {
337                            grid_dim: (grid_size, 1, 1),
338                            block_dim: (block_size, 1, 1),
339                            shared_mem_bytes: 0,
340                        },
341                        (
342                            &ids_view,
343                            selected_len_u32,
344                            candidate_count_u32,
345                            &mut seen_flags,
346                            &mut error_code,
347                            &mut error_pos,
348                        ),
349                    )
350                }
351                .map_err(|e| {
352                    XlogError::Kernel(format!(
353                        "strict sparse selected-id validation failed: {}",
354                        e
355                    ))
356                })?;
357            }
358            ScalarType::I32 => {
359                let ids_view = self.ilp_i32_view(ids_col, selected_len)?;
360                let func = self
361                    .device
362                    .inner()
363                    .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_I32)
364                    .ok_or_else(|| {
365                        XlogError::Kernel(
366                            "ilp_validate_selected_ids_i32 kernel not found".to_string(),
367                        )
368                    })?;
369                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
370                unsafe {
371                    func.clone().launch(
372                        LaunchConfig {
373                            grid_dim: (grid_size, 1, 1),
374                            block_dim: (block_size, 1, 1),
375                            shared_mem_bytes: 0,
376                        },
377                        (
378                            &ids_view,
379                            selected_len_u32,
380                            candidate_count_u32,
381                            &mut seen_flags,
382                            &mut error_code,
383                            &mut error_pos,
384                        ),
385                    )
386                }
387                .map_err(|e| {
388                    XlogError::Kernel(format!(
389                        "strict sparse selected-id validation failed: {}",
390                        e
391                    ))
392                })?;
393            }
394            ScalarType::I64 => {
395                let ids_view = self.ilp_i64_view(ids_col, selected_len)?;
396                let func = self
397                    .device
398                    .inner()
399                    .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_I64)
400                    .ok_or_else(|| {
401                        XlogError::Kernel(
402                            "ilp_validate_selected_ids_i64 kernel not found".to_string(),
403                        )
404                    })?;
405                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
406                unsafe {
407                    func.clone().launch(
408                        LaunchConfig {
409                            grid_dim: (grid_size, 1, 1),
410                            block_dim: (block_size, 1, 1),
411                            shared_mem_bytes: 0,
412                        },
413                        (
414                            &ids_view,
415                            selected_len_u32,
416                            candidate_count_u32,
417                            &mut seen_flags,
418                            &mut error_code,
419                            &mut error_pos,
420                        ),
421                    )
422                }
423                .map_err(|e| {
424                    XlogError::Kernel(format!(
425                        "strict sparse selected-id validation failed: {}",
426                        e
427                    ))
428                })?;
429            }
430            ScalarType::U64 => {
431                let ids_view = self.column_as_u64_view(ids_col, selected_len)?;
432                let func = self
433                    .device
434                    .inner()
435                    .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_U64)
436                    .ok_or_else(|| {
437                        XlogError::Kernel(
438                            "ilp_validate_selected_ids_u64 kernel not found".to_string(),
439                        )
440                    })?;
441                // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
442                unsafe {
443                    func.clone().launch(
444                        LaunchConfig {
445                            grid_dim: (grid_size, 1, 1),
446                            block_dim: (block_size, 1, 1),
447                            shared_mem_bytes: 0,
448                        },
449                        (
450                            &ids_view,
451                            selected_len_u32,
452                            candidate_count_u32,
453                            &mut seen_flags,
454                            &mut error_code,
455                            &mut error_pos,
456                        ),
457                    )
458                }
459                .map_err(|e| {
460                    XlogError::Kernel(format!(
461                        "strict sparse selected-id validation failed: {}",
462                        e
463                    ))
464                })?;
465            }
466            other => {
467                return Err(XlogError::Kernel(format!(
468                    "selected candidate ids must be I32/I64/U32/U64, got {:?}",
469                    other
470                )));
471            }
472        }
473
474        self.device
475            .synchronize()
476            .map_err(|e| XlogError::Kernel(format!("strict sparse validation sync: {}", e)))?;
477
478        let error_code_host = self.dtoh_scalar_untracked(&error_code, 0)?;
479        if error_code_host == 0 {
480            return Ok(());
481        }
482        let error_pos_host = self.dtoh_scalar_untracked(&error_pos, 0)?;
483        match error_code_host {
484            1 => Err(XlogError::Kernel(format!(
485                "selected candidate id out of range at position {}",
486                error_pos_host
487            ))),
488            2 => Err(XlogError::Kernel(format!(
489                "duplicate selected candidate id at position {}",
490                error_pos_host
491            ))),
492            code => Err(XlogError::Kernel(format!(
493                "strict sparse selected-id validation failed with error code {}",
494                code
495            ))),
496        }
497    }
498
499    pub fn filter_buffer_by_candidate_flag(
500        &self,
501        input: &CudaBuffer,
502        candidate_flags: &CudaBuffer,
503        candidate_idx: usize,
504    ) -> Result<CudaBuffer> {
505        if input.is_empty() {
506            return self.create_empty_buffer(input.schema().clone());
507        }
508        if candidate_idx >= candidate_flags.num_rows() as usize {
509            return Err(XlogError::Kernel(format!(
510                "candidate flag index {} out of range [0, {})",
511                candidate_idx,
512                candidate_flags.num_rows()
513            )));
514        }
515
516        let flag_col = candidate_flags
517            .column(0)
518            .ok_or_else(|| XlogError::Kernel("candidate flag buffer has no column".to_string()))?;
519        let flag_view = self.column_as_u32_view(flag_col, candidate_flags.num_rows() as usize)?;
520        let row_count = u32::try_from(input.num_rows()).map_err(|_| {
521            XlogError::Kernel(format!(
522                "strict sparse row count {} exceeds u32::MAX",
523                input.num_rows()
524            ))
525        })?;
526        let candidate_idx_u32 = u32::try_from(candidate_idx).map_err(|_| {
527            XlogError::Kernel(format!(
528                "candidate flag index {} exceeds u32::MAX",
529                candidate_idx
530            ))
531        })?;
532
533        let mut row_mask = self.memory.alloc::<u8>(row_count as usize)?;
534        let func = self
535            .device
536            .inner()
537            .get_func(ILP_MODULE, ilp_kernels::ILP_BROADCAST_CANDIDATE_FLAG)
538            .ok_or_else(|| {
539                XlogError::Kernel("ilp_broadcast_candidate_flag kernel not found".to_string())
540            })?;
541        let block_size = 256u32;
542        let grid_size = row_count.div_ceil(block_size);
543        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
544        unsafe {
545            func.clone().launch(
546                LaunchConfig {
547                    grid_dim: (grid_size, 1, 1),
548                    block_dim: (block_size, 1, 1),
549                    shared_mem_bytes: 0,
550                },
551                (&flag_view, candidate_idx_u32, row_count, &mut row_mask),
552            )
553        }
554        .map_err(|e| XlogError::Kernel(format!("strict sparse flag broadcast failed: {}", e)))?;
555
556        self.filter_by_device_mask(input, &row_mask)
557    }
558
559    // ─── ILP credit kernel launchers ───────────────────────────────────
560
561    /// Launch `ilp_coo_fill` kernel: writes `(compacted_fact_indices[i], cidx)`
562    /// pairs at `coo_fact[offset..]` and `coo_cand[offset..]`.
563    pub fn ilp_coo_fill_launch(
564        &self,
565        compacted_fact_indices: &TrackedCudaSlice<u32>,
566        cidx: u32,
567        count: u32,
568        offset: u32,
569        coo_fact: &mut TrackedCudaSlice<u32>,
570        coo_cand: &mut TrackedCudaSlice<u32>,
571    ) -> Result<()> {
572        if count == 0 {
573            return Ok(());
574        }
575        let func = self
576            .device
577            .inner()
578            .get_func(ILP_CREDIT_MODULE, ilp_credit_kernels::ILP_COO_FILL)
579            .ok_or_else(|| XlogError::Kernel("ilp_coo_fill kernel not found".to_string()))?;
580        let block_size = 256u32;
581        let grid_size = count.div_ceil(block_size);
582        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
583        unsafe {
584            func.clone().launch(
585                LaunchConfig {
586                    grid_dim: (grid_size, 1, 1),
587                    block_dim: (block_size, 1, 1),
588                    shared_mem_bytes: 0,
589                },
590                (
591                    compacted_fact_indices,
592                    cidx,
593                    count,
594                    offset,
595                    coo_fact,
596                    coo_cand,
597                ),
598            )
599        }
600        .map_err(|e| XlogError::Kernel(format!("ilp_coo_fill failed: {}", e)))?;
601        self.device.synchronize()?;
602        Ok(())
603    }
604
605    /// Launch `ilp_credit_forward_f32`: CSR credit gather + clamp + NLL loss.
606    /// Returns `(credit_out, loss_contrib)` device slices of length `num_facts`.
607    pub fn ilp_credit_forward_f32_launch(
608        &self,
609        row_offsets: &TrackedCudaSlice<u32>,
610        col_indices: &TrackedCudaSlice<u32>,
611        cand_probs: &CudaColumn, // raw byte column from CudaBuffer
612        is_positive: &TrackedCudaSlice<u8>,
613        num_facts: u32,
614        eps: f32,
615    ) -> Result<(TrackedCudaSlice<f32>, TrackedCudaSlice<f32>)> {
616        let mut credit_out = self.memory.alloc::<f32>(num_facts as usize)?;
617        let mut loss_contrib = self.memory.alloc::<f32>(num_facts as usize)?;
618        if num_facts == 0 {
619            return Ok((credit_out, loss_contrib));
620        }
621        let func = self
622            .device
623            .inner()
624            .get_func(
625                ILP_CREDIT_MODULE,
626                ilp_credit_kernels::ILP_CREDIT_FORWARD_F32,
627            )
628            .ok_or_else(|| {
629                XlogError::Kernel("ilp_credit_forward_f32 kernel not found".to_string())
630            })?;
631        let block_size = 256u32;
632        let grid_size = num_facts.div_ceil(block_size);
633        // reinterpret the u8 byte column as f32 for the kernel
634        let cand_view = RawCudaView::<f32> {
635            ptr: *cand_probs.device_ptr(),
636            len: cudarc::driver::DeviceSlice::len(cand_probs) / 4,
637            stream: cand_probs.stream().clone(),
638            source_block: None,
639            _marker: PhantomData,
640        };
641        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
642        unsafe {
643            func.clone().launch(
644                LaunchConfig {
645                    grid_dim: (grid_size, 1, 1),
646                    block_dim: (block_size, 1, 1),
647                    shared_mem_bytes: 0,
648                },
649                (
650                    row_offsets,
651                    col_indices,
652                    &cand_view,
653                    is_positive,
654                    num_facts,
655                    eps,
656                    &mut credit_out,
657                    &mut loss_contrib,
658                ),
659            )
660        }
661        .map_err(|e| XlogError::Kernel(format!("ilp_credit_forward_f32 failed: {}", e)))?;
662        self.device.synchronize()?;
663        Ok((credit_out, loss_contrib))
664    }
665
666    /// Launch `ilp_credit_forward_f64`: CSR credit gather + clamp + NLL loss.
667    /// Returns `(credit_out, loss_contrib)` device slices of length `num_facts`.
668    pub fn ilp_credit_forward_f64_launch(
669        &self,
670        row_offsets: &TrackedCudaSlice<u32>,
671        col_indices: &TrackedCudaSlice<u32>,
672        cand_probs: &CudaColumn, // raw byte column from CudaBuffer
673        is_positive: &TrackedCudaSlice<u8>,
674        num_facts: u32,
675        eps: f64,
676    ) -> Result<(TrackedCudaSlice<f64>, TrackedCudaSlice<f64>)> {
677        let mut credit_out = self.memory.alloc::<f64>(num_facts as usize)?;
678        let mut loss_contrib = self.memory.alloc::<f64>(num_facts as usize)?;
679        if num_facts == 0 {
680            return Ok((credit_out, loss_contrib));
681        }
682        let func = self
683            .device
684            .inner()
685            .get_func(
686                ILP_CREDIT_MODULE,
687                ilp_credit_kernels::ILP_CREDIT_FORWARD_F64,
688            )
689            .ok_or_else(|| {
690                XlogError::Kernel("ilp_credit_forward_f64 kernel not found".to_string())
691            })?;
692        let block_size = 256u32;
693        let grid_size = num_facts.div_ceil(block_size);
694        let cand_view = RawCudaView::<f64> {
695            ptr: *cand_probs.device_ptr(),
696            len: cudarc::driver::DeviceSlice::len(cand_probs) / 8,
697            stream: cand_probs.stream().clone(),
698            source_block: None,
699            _marker: PhantomData,
700        };
701        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
702        unsafe {
703            func.clone().launch(
704                LaunchConfig {
705                    grid_dim: (grid_size, 1, 1),
706                    block_dim: (block_size, 1, 1),
707                    shared_mem_bytes: 0,
708                },
709                (
710                    row_offsets,
711                    col_indices,
712                    &cand_view,
713                    is_positive,
714                    num_facts,
715                    eps,
716                    &mut credit_out,
717                    &mut loss_contrib,
718                ),
719            )
720        }
721        .map_err(|e| XlogError::Kernel(format!("ilp_credit_forward_f64 failed: {}", e)))?;
722        self.device.synchronize()?;
723        Ok((credit_out, loss_contrib))
724    }
725
726    /// Launch `ilp_credit_backward_f32`: gradient scatter via CSR + atomicAdd.
727    /// Returns `d_cand_probs` gradient of length `num_cands` (zeroed, then accumulated).
728    pub fn ilp_credit_backward_f32_launch(
729        &self,
730        row_offsets: &TrackedCudaSlice<u32>,
731        col_indices: &TrackedCudaSlice<u32>,
732        credit_out: &TrackedCudaSlice<f32>,
733        is_positive: &TrackedCudaSlice<u8>,
734        num_facts: u32,
735        num_cands: u32,
736    ) -> Result<TrackedCudaSlice<f32>> {
737        let mut d_grad = self.memory.alloc::<f32>(num_cands as usize)?;
738        self.device
739            .inner()
740            .memset_zeros(&mut d_grad)
741            .map_err(|e| XlogError::Kernel(format!("Failed to zero grad: {}", e)))?;
742        if num_facts == 0 {
743            return Ok(d_grad);
744        }
745        let func = self
746            .device
747            .inner()
748            .get_func(
749                ILP_CREDIT_MODULE,
750                ilp_credit_kernels::ILP_CREDIT_BACKWARD_F32,
751            )
752            .ok_or_else(|| {
753                XlogError::Kernel("ilp_credit_backward_f32 kernel not found".to_string())
754            })?;
755        let block_size = 256u32;
756        let grid_size = num_facts.div_ceil(block_size);
757        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
758        unsafe {
759            func.clone().launch(
760                LaunchConfig {
761                    grid_dim: (grid_size, 1, 1),
762                    block_dim: (block_size, 1, 1),
763                    shared_mem_bytes: 0,
764                },
765                (
766                    row_offsets,
767                    col_indices,
768                    credit_out,
769                    is_positive,
770                    num_facts,
771                    &mut d_grad,
772                ),
773            )
774        }
775        .map_err(|e| XlogError::Kernel(format!("ilp_credit_backward_f32 failed: {}", e)))?;
776        self.device.synchronize()?;
777        Ok(d_grad)
778    }
779
780    /// Launch `ilp_credit_backward_f64`: gradient scatter via CSR + atomicAdd.
781    /// Returns `d_cand_probs` gradient of length `num_cands` (zeroed, then accumulated).
782    pub fn ilp_credit_backward_f64_launch(
783        &self,
784        row_offsets: &TrackedCudaSlice<u32>,
785        col_indices: &TrackedCudaSlice<u32>,
786        credit_out: &TrackedCudaSlice<f64>,
787        is_positive: &TrackedCudaSlice<u8>,
788        num_facts: u32,
789        num_cands: u32,
790    ) -> Result<TrackedCudaSlice<f64>> {
791        let mut d_grad = self.memory.alloc::<f64>(num_cands as usize)?;
792        self.device
793            .inner()
794            .memset_zeros(&mut d_grad)
795            .map_err(|e| XlogError::Kernel(format!("Failed to zero grad: {}", e)))?;
796        if num_facts == 0 {
797            return Ok(d_grad);
798        }
799        let func = self
800            .device
801            .inner()
802            .get_func(
803                ILP_CREDIT_MODULE,
804                ilp_credit_kernels::ILP_CREDIT_BACKWARD_F64,
805            )
806            .ok_or_else(|| {
807                XlogError::Kernel("ilp_credit_backward_f64 kernel not found".to_string())
808            })?;
809        let block_size = 256u32;
810        let grid_size = num_facts.div_ceil(block_size);
811        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
812        unsafe {
813            func.clone().launch(
814                LaunchConfig {
815                    grid_dim: (grid_size, 1, 1),
816                    block_dim: (block_size, 1, 1),
817                    shared_mem_bytes: 0,
818                },
819                (
820                    row_offsets,
821                    col_indices,
822                    credit_out,
823                    is_positive,
824                    num_facts,
825                    &mut d_grad,
826                ),
827            )
828        }
829        .map_err(|e| XlogError::Kernel(format!("ilp_credit_backward_f64 failed: {}", e)))?;
830        self.device.synchronize()?;
831        Ok(d_grad)
832    }
833
834    /// GPU-side sum reduction (f32).
835    ///
836    /// Sums `n` elements of `input` on device and returns a single-element
837    /// device buffer containing the result.  The caller must zero the output
838    /// buffer *before* launching the kernel — this function handles that.
839    pub fn ilp_reduce_sum_f32_launch(
840        &self,
841        input: &TrackedCudaSlice<f32>,
842        n: u32,
843    ) -> Result<TrackedCudaSlice<f32>> {
844        let mut d_result = self.memory.alloc::<f32>(1)?;
845        self.device
846            .inner()
847            .memset_zeros(&mut d_result)
848            .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f32 zero result: {}", e)))?;
849
850        if n == 0 {
851            return Ok(d_result);
852        }
853
854        let func = self
855            .device
856            .inner()
857            .get_func(ILP_MODULE, ilp_kernels::ILP_REDUCE_SUM_F32)
858            .ok_or_else(|| XlogError::Kernel("ilp_reduce_sum_f32 not found".to_string()))?;
859        let block_size = 256u32;
860        let grid_size = n.div_ceil(block_size);
861        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
862        unsafe {
863            func.clone().launch(
864                LaunchConfig {
865                    grid_dim: (grid_size, 1, 1),
866                    block_dim: (block_size, 1, 1),
867                    shared_mem_bytes: 0,
868                },
869                (input, n, &mut d_result),
870            )
871        }
872        .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f32: {}", e)))?;
873        self.device.synchronize()?;
874        Ok(d_result)
875    }
876
877    /// GPU-side sum reduction (f64).
878    ///
879    /// Sums `n` elements of `input` on device and returns a single-element
880    /// device buffer containing the result.  Requires sm_60+ for double
881    /// atomicAdd (this project targets sm_75 baseline).
882    pub fn ilp_reduce_sum_f64_launch(
883        &self,
884        input: &TrackedCudaSlice<f64>,
885        n: u32,
886    ) -> Result<TrackedCudaSlice<f64>> {
887        let mut d_result = self.memory.alloc::<f64>(1)?;
888        self.device
889            .inner()
890            .memset_zeros(&mut d_result)
891            .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f64 zero result: {}", e)))?;
892
893        if n == 0 {
894            return Ok(d_result);
895        }
896
897        let func = self
898            .device
899            .inner()
900            .get_func(ILP_MODULE, ilp_kernels::ILP_REDUCE_SUM_F64)
901            .ok_or_else(|| XlogError::Kernel("ilp_reduce_sum_f64 not found".to_string()))?;
902        let block_size = 256u32;
903        let grid_size = n.div_ceil(block_size);
904        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
905        unsafe {
906            func.clone().launch(
907                LaunchConfig {
908                    grid_dim: (grid_size, 1, 1),
909                    block_dim: (block_size, 1, 1),
910                    shared_mem_bytes: 0,
911                },
912                (input, n, &mut d_result),
913            )
914        }
915        .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f64: {}", e)))?;
916        self.device.synchronize()?;
917        Ok(d_result)
918    }
919
920    /// Fill COO arrays from a device-side mask and prefix-sum.
921    ///
922    /// For each set bit in `mask`, writes the corresponding `fact_indices` entry
923    /// into `coo_fact` and `cand_value` into `coo_cand` at the position
924    /// determined by `d_offsets[offset_idx] + prefix_sum[tid]`.
925    ///
926    /// Parameters:
927    /// - `offset_idx`: index into `d_offsets` for the write base position
928    /// - `cand_value`: actual candidate index to write into `coo_cand`
929    ///
930    /// This keeps COO assembly fully on device, eliminating the mask D2H transfer.
931    #[allow(clippy::too_many_arguments)]
932    pub fn ilp_coo_fill_from_mask_launch(
933        &self,
934        mask: &TrackedCudaSlice<u8>,
935        prefix_sum: &TrackedCudaSlice<u32>,
936        fact_indices: &TrackedCudaSlice<u32>,
937        offset_idx: u32,
938        cand_value: u32,
939        num_query: u32,
940        d_offsets: &TrackedCudaSlice<u32>,
941        coo_fact: &mut TrackedCudaSlice<u32>,
942        coo_cand: &mut TrackedCudaSlice<u32>,
943    ) -> Result<()> {
944        if num_query == 0 {
945            return Ok(());
946        }
947        let func = self
948            .device()
949            .inner()
950            .get_func(ILP_MODULE, ilp_kernels::ILP_COO_FILL_FROM_MASK)
951            .ok_or_else(|| XlogError::Kernel("ilp_coo_fill_from_mask not found".to_string()))?;
952        let block_size = 256u32;
953        let grid_size = num_query.div_ceil(block_size);
954        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
955        unsafe {
956            func.clone().launch(
957                LaunchConfig {
958                    grid_dim: (grid_size, 1, 1),
959                    block_dim: (block_size, 1, 1),
960                    shared_mem_bytes: 0,
961                },
962                (
963                    mask,
964                    prefix_sum,
965                    fact_indices,
966                    offset_idx,
967                    cand_value,
968                    num_query,
969                    d_offsets,
970                    coo_fact,
971                    coo_cand,
972                ),
973            )
974        }
975        .map_err(|e| XlogError::Kernel(format!("ilp_coo_fill_from_mask: {}", e)))?;
976        self.device()
977            .inner()
978            .synchronize()
979            .map_err(|e| XlogError::Kernel(format!("ilp_coo_fill_from_mask sync: {}", e)))?;
980        Ok(())
981    }
982
983    /// Build a histogram of fact indices from sorted COO data.
984    ///
985    /// For each entry in `sorted_facts[0..nnz]`, atomically increments
986    /// the corresponding bin in the output histogram. The result is a
987    /// device-side count array of length `num_facts`, suitable for
988    /// prefix-sum to produce CSR `row_offsets`.
989    ///
990    /// The caller provides sorted fact indices; the histogram is
991    /// zero-initialized internally.
992    pub fn ilp_csr_histogram_launch(
993        &self,
994        sorted_facts: &TrackedCudaSlice<u32>,
995        nnz: u32,
996        num_facts: u32,
997    ) -> Result<TrackedCudaSlice<u32>> {
998        let mut d_hist = self.memory().alloc::<u32>(num_facts as usize)?;
999        self.device()
1000            .inner()
1001            .memset_zeros(&mut d_hist)
1002            .map_err(|e| XlogError::Kernel(format!("ilp_csr_histogram zero hist: {}", e)))?;
1003
1004        if nnz == 0 {
1005            return Ok(d_hist);
1006        }
1007
1008        let func = self
1009            .device()
1010            .inner()
1011            .get_func(ILP_MODULE, ilp_kernels::ILP_CSR_HISTOGRAM)
1012            .ok_or_else(|| XlogError::Kernel("ilp_csr_histogram kernel not found".to_string()))?;
1013
1014        let block_size = 256u32;
1015        let grid_size = nnz.div_ceil(block_size);
1016
1017        // SAFETY: kernel arguments match the PTX signature; device buffers were allocated with sufficient size
1018        unsafe {
1019            func.clone()
1020                .launch(
1021                    cudarc::driver::LaunchConfig {
1022                        grid_dim: (grid_size, 1, 1),
1023                        block_dim: (block_size, 1, 1),
1024                        shared_mem_bytes: 0,
1025                    },
1026                    (sorted_facts, nnz, num_facts, &mut d_hist),
1027                )
1028                .map_err(|e| XlogError::Kernel(format!("ilp_csr_histogram launch: {}", e)))?;
1029        }
1030
1031        self.device()
1032            .inner()
1033            .synchronize()
1034            .map_err(|e| XlogError::Kernel(format!("ilp_csr_histogram sync: {}", e)))?;
1035
1036        Ok(d_hist)
1037    }
1038}