1use std::marker::PhantomData;
4
5use crate::{DeviceSlice, LaunchAsync, LaunchConfig};
6use xlog_core::{Result, ScalarType, Schema, XlogError};
7
8use super::{ilp_credit_kernels, ilp_kernels, RawCudaView, ILP_CREDIT_MODULE, ILP_MODULE};
9use crate::memory::{CudaBuffer, CudaColumn, TrackedCudaSlice};
10
11impl super::CudaKernelProvider {
12 fn ilp_i32_view<'a>(
13 &self,
14 col: &'a CudaColumn,
15 num_elements: usize,
16 ) -> Result<RawCudaView<'a, i32>> {
17 let required_bytes = num_elements * std::mem::size_of::<i32>();
18 if col.num_bytes() < required_bytes {
19 return Err(XlogError::Kernel(format!(
20 "Column has {} bytes but {} required for {} i32 elements",
21 col.num_bytes(),
22 required_bytes,
23 num_elements
24 )));
25 }
26 let ptr = *col.device_ptr();
27 if !(ptr as usize).is_multiple_of(std::mem::align_of::<i32>()) {
28 return Err(XlogError::Kernel(
29 "Column device pointer is not i32-aligned".to_string(),
30 ));
31 }
32 Ok(RawCudaView {
33 ptr,
34 len: num_elements,
35 stream: col.stream().clone(),
36 source_block: None,
37 _marker: PhantomData,
38 })
39 }
40
41 fn ilp_i64_view<'a>(
42 &self,
43 col: &'a CudaColumn,
44 num_elements: usize,
45 ) -> Result<RawCudaView<'a, i64>> {
46 let required_bytes = num_elements * std::mem::size_of::<i64>();
47 if col.num_bytes() < required_bytes {
48 return Err(XlogError::Kernel(format!(
49 "Column has {} bytes but {} required for {} i64 elements",
50 col.num_bytes(),
51 required_bytes,
52 num_elements
53 )));
54 }
55 let ptr = *col.device_ptr();
56 if !(ptr as usize).is_multiple_of(std::mem::align_of::<i64>()) {
57 return Err(XlogError::Kernel(
58 "Column device pointer is not i64-aligned".to_string(),
59 ));
60 }
61 Ok(RawCudaView {
62 ptr,
63 len: num_elements,
64 stream: col.stream().clone(),
65 source_block: None,
66 _marker: PhantomData,
67 })
68 }
69
70 pub fn build_selected_id_mask(
71 &self,
72 ids_buf: &CudaBuffer,
73 candidate_count: usize,
74 ) -> Result<CudaBuffer> {
75 let selected_len = usize::try_from(ids_buf.num_rows())
76 .map_err(|_| XlogError::Kernel("selected id row count overflow".to_string()))?;
77 let candidate_count_u32 = u32::try_from(candidate_count).map_err(|_| {
78 XlogError::Kernel(format!(
79 "candidate count {} exceeds u32::MAX for strict sparse mask",
80 candidate_count
81 ))
82 })?;
83
84 let mut active_flags = self.memory.alloc::<u32>(candidate_count)?;
85 if candidate_count > 0 {
86 self.device
87 .inner()
88 .memset_zeros(&mut active_flags)
89 .map_err(|e| XlogError::Kernel(format!("zero strict sparse mask: {}", e)))?;
90 }
91
92 if selected_len > 0 {
93 let selected_len_u32 = u32::try_from(selected_len).map_err(|_| {
94 XlogError::Kernel(format!(
95 "selected id count {} exceeds u32::MAX for strict sparse mask",
96 selected_len
97 ))
98 })?;
99 let block_size = 256u32;
100 let grid_size = selected_len_u32.div_ceil(block_size);
101 let ids_col = ids_buf
102 .column(0)
103 .ok_or_else(|| XlogError::Kernel("selected id buffer has no column".to_string()))?;
104 match ids_buf.schema().column_type(0).ok_or_else(|| {
105 XlogError::Kernel("selected id buffer has no schema type".to_string())
106 })? {
107 ScalarType::U32 | ScalarType::Symbol => {
108 let ids_view = self.column_as_u32_view(ids_col, selected_len)?;
109 let func = self
110 .device
111 .inner()
112 .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_U32)
113 .ok_or_else(|| {
114 XlogError::Kernel(
115 "ilp_mark_selected_ids_u32 kernel not found".to_string(),
116 )
117 })?;
118 unsafe {
120 func.clone().launch(
121 LaunchConfig {
122 grid_dim: (grid_size, 1, 1),
123 block_dim: (block_size, 1, 1),
124 shared_mem_bytes: 0,
125 },
126 (
127 &ids_view,
128 selected_len_u32,
129 candidate_count_u32,
130 &mut active_flags,
131 ),
132 )
133 }
134 .map_err(|e| {
135 XlogError::Kernel(format!(
136 "strict sparse selected-id scatter failed: {}",
137 e
138 ))
139 })?;
140 }
141 ScalarType::I32 => {
142 let ids_view = self.ilp_i32_view(ids_col, selected_len)?;
143 let func = self
144 .device
145 .inner()
146 .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_I32)
147 .ok_or_else(|| {
148 XlogError::Kernel(
149 "ilp_mark_selected_ids_i32 kernel not found".to_string(),
150 )
151 })?;
152 unsafe {
154 func.clone().launch(
155 LaunchConfig {
156 grid_dim: (grid_size, 1, 1),
157 block_dim: (block_size, 1, 1),
158 shared_mem_bytes: 0,
159 },
160 (
161 &ids_view,
162 selected_len_u32,
163 candidate_count_u32,
164 &mut active_flags,
165 ),
166 )
167 }
168 .map_err(|e| {
169 XlogError::Kernel(format!(
170 "strict sparse selected-id scatter failed: {}",
171 e
172 ))
173 })?;
174 }
175 ScalarType::I64 => {
176 let ids_view = self.ilp_i64_view(ids_col, selected_len)?;
177 let func = self
178 .device
179 .inner()
180 .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_I64)
181 .ok_or_else(|| {
182 XlogError::Kernel(
183 "ilp_mark_selected_ids_i64 kernel not found".to_string(),
184 )
185 })?;
186 unsafe {
188 func.clone().launch(
189 LaunchConfig {
190 grid_dim: (grid_size, 1, 1),
191 block_dim: (block_size, 1, 1),
192 shared_mem_bytes: 0,
193 },
194 (
195 &ids_view,
196 selected_len_u32,
197 candidate_count_u32,
198 &mut active_flags,
199 ),
200 )
201 }
202 .map_err(|e| {
203 XlogError::Kernel(format!(
204 "strict sparse selected-id scatter failed: {}",
205 e
206 ))
207 })?;
208 }
209 ScalarType::U64 => {
210 let ids_view = self.column_as_u64_view(ids_col, selected_len)?;
211 let func = self
212 .device
213 .inner()
214 .get_func(ILP_MODULE, ilp_kernels::ILP_MARK_SELECTED_IDS_U64)
215 .ok_or_else(|| {
216 XlogError::Kernel(
217 "ilp_mark_selected_ids_u64 kernel not found".to_string(),
218 )
219 })?;
220 unsafe {
222 func.clone().launch(
223 LaunchConfig {
224 grid_dim: (grid_size, 1, 1),
225 block_dim: (block_size, 1, 1),
226 shared_mem_bytes: 0,
227 },
228 (
229 &ids_view,
230 selected_len_u32,
231 candidate_count_u32,
232 &mut active_flags,
233 ),
234 )
235 }
236 .map_err(|e| {
237 XlogError::Kernel(format!(
238 "strict sparse selected-id scatter failed: {}",
239 e
240 ))
241 })?;
242 }
243 other => {
244 return Err(XlogError::Kernel(format!(
245 "selected candidate ids must be I32/I64/U32/U64, got {:?}",
246 other
247 )));
248 }
249 }
250
251 self.device
252 .synchronize()
253 .map_err(|e| XlogError::Kernel(format!("strict sparse scatter sync: {}", e)))?;
254 }
255
256 let d_num_rows = self.upload_device_row_count(candidate_count_u32)?;
257 Ok(CudaBuffer::from_columns_with_host_count(
258 vec![active_flags.into_bytes().into()],
259 candidate_count as u64,
260 d_num_rows,
261 Schema::new(vec![("active".to_string(), ScalarType::U32)]),
262 candidate_count_u32,
263 ))
264 }
265
266 pub fn validate_selected_ids(
267 &self,
268 ids_buf: &CudaBuffer,
269 candidate_count: usize,
270 ) -> Result<()> {
271 let selected_len = usize::try_from(ids_buf.num_rows())
272 .map_err(|_| XlogError::Kernel("selected id row count overflow".to_string()))?;
273 let candidate_count_u32 = u32::try_from(candidate_count).map_err(|_| {
274 XlogError::Kernel(format!(
275 "candidate count {} exceeds u32::MAX for strict sparse mask",
276 candidate_count
277 ))
278 })?;
279
280 if selected_len == 0 {
281 return Ok(());
282 }
283
284 let selected_len_u32 = u32::try_from(selected_len).map_err(|_| {
285 XlogError::Kernel(format!(
286 "selected id count {} exceeds u32::MAX for strict sparse mask",
287 selected_len
288 ))
289 })?;
290 let block_size = 256u32;
291 let grid_size = selected_len_u32.div_ceil(block_size);
292 let ids_col = ids_buf
293 .column(0)
294 .ok_or_else(|| XlogError::Kernel("selected id buffer has no column".to_string()))?;
295
296 let mut seen_flags = self.memory.alloc::<u32>(candidate_count)?;
297 if candidate_count > 0 {
298 self.device
299 .inner()
300 .memset_zeros(&mut seen_flags)
301 .map_err(|e| {
302 XlogError::Kernel(format!("zero strict sparse validation flags: {}", e))
303 })?;
304 }
305
306 let mut error_code = self.memory.alloc::<u32>(1)?;
307 let mut error_pos = self.memory.alloc::<u32>(1)?;
308 self.device
309 .inner()
310 .memset_zeros(&mut error_code)
311 .map_err(|e| XlogError::Kernel(format!("zero strict sparse error code: {}", e)))?;
312 self.device
313 .inner()
314 .memset_zeros(&mut error_pos)
315 .map_err(|e| XlogError::Kernel(format!("zero strict sparse error pos: {}", e)))?;
316
317 match ids_buf
318 .schema()
319 .column_type(0)
320 .ok_or_else(|| XlogError::Kernel("selected id buffer has no schema type".to_string()))?
321 {
322 ScalarType::U32 | ScalarType::Symbol => {
323 let ids_view = self.column_as_u32_view(ids_col, selected_len)?;
324 let func = self
325 .device
326 .inner()
327 .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_U32)
328 .ok_or_else(|| {
329 XlogError::Kernel(
330 "ilp_validate_selected_ids_u32 kernel not found".to_string(),
331 )
332 })?;
333 unsafe {
335 func.clone().launch(
336 LaunchConfig {
337 grid_dim: (grid_size, 1, 1),
338 block_dim: (block_size, 1, 1),
339 shared_mem_bytes: 0,
340 },
341 (
342 &ids_view,
343 selected_len_u32,
344 candidate_count_u32,
345 &mut seen_flags,
346 &mut error_code,
347 &mut error_pos,
348 ),
349 )
350 }
351 .map_err(|e| {
352 XlogError::Kernel(format!(
353 "strict sparse selected-id validation failed: {}",
354 e
355 ))
356 })?;
357 }
358 ScalarType::I32 => {
359 let ids_view = self.ilp_i32_view(ids_col, selected_len)?;
360 let func = self
361 .device
362 .inner()
363 .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_I32)
364 .ok_or_else(|| {
365 XlogError::Kernel(
366 "ilp_validate_selected_ids_i32 kernel not found".to_string(),
367 )
368 })?;
369 unsafe {
371 func.clone().launch(
372 LaunchConfig {
373 grid_dim: (grid_size, 1, 1),
374 block_dim: (block_size, 1, 1),
375 shared_mem_bytes: 0,
376 },
377 (
378 &ids_view,
379 selected_len_u32,
380 candidate_count_u32,
381 &mut seen_flags,
382 &mut error_code,
383 &mut error_pos,
384 ),
385 )
386 }
387 .map_err(|e| {
388 XlogError::Kernel(format!(
389 "strict sparse selected-id validation failed: {}",
390 e
391 ))
392 })?;
393 }
394 ScalarType::I64 => {
395 let ids_view = self.ilp_i64_view(ids_col, selected_len)?;
396 let func = self
397 .device
398 .inner()
399 .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_I64)
400 .ok_or_else(|| {
401 XlogError::Kernel(
402 "ilp_validate_selected_ids_i64 kernel not found".to_string(),
403 )
404 })?;
405 unsafe {
407 func.clone().launch(
408 LaunchConfig {
409 grid_dim: (grid_size, 1, 1),
410 block_dim: (block_size, 1, 1),
411 shared_mem_bytes: 0,
412 },
413 (
414 &ids_view,
415 selected_len_u32,
416 candidate_count_u32,
417 &mut seen_flags,
418 &mut error_code,
419 &mut error_pos,
420 ),
421 )
422 }
423 .map_err(|e| {
424 XlogError::Kernel(format!(
425 "strict sparse selected-id validation failed: {}",
426 e
427 ))
428 })?;
429 }
430 ScalarType::U64 => {
431 let ids_view = self.column_as_u64_view(ids_col, selected_len)?;
432 let func = self
433 .device
434 .inner()
435 .get_func(ILP_MODULE, ilp_kernels::ILP_VALIDATE_SELECTED_IDS_U64)
436 .ok_or_else(|| {
437 XlogError::Kernel(
438 "ilp_validate_selected_ids_u64 kernel not found".to_string(),
439 )
440 })?;
441 unsafe {
443 func.clone().launch(
444 LaunchConfig {
445 grid_dim: (grid_size, 1, 1),
446 block_dim: (block_size, 1, 1),
447 shared_mem_bytes: 0,
448 },
449 (
450 &ids_view,
451 selected_len_u32,
452 candidate_count_u32,
453 &mut seen_flags,
454 &mut error_code,
455 &mut error_pos,
456 ),
457 )
458 }
459 .map_err(|e| {
460 XlogError::Kernel(format!(
461 "strict sparse selected-id validation failed: {}",
462 e
463 ))
464 })?;
465 }
466 other => {
467 return Err(XlogError::Kernel(format!(
468 "selected candidate ids must be I32/I64/U32/U64, got {:?}",
469 other
470 )));
471 }
472 }
473
474 self.device
475 .synchronize()
476 .map_err(|e| XlogError::Kernel(format!("strict sparse validation sync: {}", e)))?;
477
478 let error_code_host = self.dtoh_scalar_untracked(&error_code, 0)?;
479 if error_code_host == 0 {
480 return Ok(());
481 }
482 let error_pos_host = self.dtoh_scalar_untracked(&error_pos, 0)?;
483 match error_code_host {
484 1 => Err(XlogError::Kernel(format!(
485 "selected candidate id out of range at position {}",
486 error_pos_host
487 ))),
488 2 => Err(XlogError::Kernel(format!(
489 "duplicate selected candidate id at position {}",
490 error_pos_host
491 ))),
492 code => Err(XlogError::Kernel(format!(
493 "strict sparse selected-id validation failed with error code {}",
494 code
495 ))),
496 }
497 }
498
499 pub fn filter_buffer_by_candidate_flag(
500 &self,
501 input: &CudaBuffer,
502 candidate_flags: &CudaBuffer,
503 candidate_idx: usize,
504 ) -> Result<CudaBuffer> {
505 if input.is_empty() {
506 return self.create_empty_buffer(input.schema().clone());
507 }
508 if candidate_idx >= candidate_flags.num_rows() as usize {
509 return Err(XlogError::Kernel(format!(
510 "candidate flag index {} out of range [0, {})",
511 candidate_idx,
512 candidate_flags.num_rows()
513 )));
514 }
515
516 let flag_col = candidate_flags
517 .column(0)
518 .ok_or_else(|| XlogError::Kernel("candidate flag buffer has no column".to_string()))?;
519 let flag_view = self.column_as_u32_view(flag_col, candidate_flags.num_rows() as usize)?;
520 let row_count = u32::try_from(input.num_rows()).map_err(|_| {
521 XlogError::Kernel(format!(
522 "strict sparse row count {} exceeds u32::MAX",
523 input.num_rows()
524 ))
525 })?;
526 let candidate_idx_u32 = u32::try_from(candidate_idx).map_err(|_| {
527 XlogError::Kernel(format!(
528 "candidate flag index {} exceeds u32::MAX",
529 candidate_idx
530 ))
531 })?;
532
533 let mut row_mask = self.memory.alloc::<u8>(row_count as usize)?;
534 let func = self
535 .device
536 .inner()
537 .get_func(ILP_MODULE, ilp_kernels::ILP_BROADCAST_CANDIDATE_FLAG)
538 .ok_or_else(|| {
539 XlogError::Kernel("ilp_broadcast_candidate_flag kernel not found".to_string())
540 })?;
541 let block_size = 256u32;
542 let grid_size = row_count.div_ceil(block_size);
543 unsafe {
545 func.clone().launch(
546 LaunchConfig {
547 grid_dim: (grid_size, 1, 1),
548 block_dim: (block_size, 1, 1),
549 shared_mem_bytes: 0,
550 },
551 (&flag_view, candidate_idx_u32, row_count, &mut row_mask),
552 )
553 }
554 .map_err(|e| XlogError::Kernel(format!("strict sparse flag broadcast failed: {}", e)))?;
555
556 self.filter_by_device_mask(input, &row_mask)
557 }
558
559 pub fn ilp_coo_fill_launch(
564 &self,
565 compacted_fact_indices: &TrackedCudaSlice<u32>,
566 cidx: u32,
567 count: u32,
568 offset: u32,
569 coo_fact: &mut TrackedCudaSlice<u32>,
570 coo_cand: &mut TrackedCudaSlice<u32>,
571 ) -> Result<()> {
572 if count == 0 {
573 return Ok(());
574 }
575 let func = self
576 .device
577 .inner()
578 .get_func(ILP_CREDIT_MODULE, ilp_credit_kernels::ILP_COO_FILL)
579 .ok_or_else(|| XlogError::Kernel("ilp_coo_fill kernel not found".to_string()))?;
580 let block_size = 256u32;
581 let grid_size = count.div_ceil(block_size);
582 unsafe {
584 func.clone().launch(
585 LaunchConfig {
586 grid_dim: (grid_size, 1, 1),
587 block_dim: (block_size, 1, 1),
588 shared_mem_bytes: 0,
589 },
590 (
591 compacted_fact_indices,
592 cidx,
593 count,
594 offset,
595 coo_fact,
596 coo_cand,
597 ),
598 )
599 }
600 .map_err(|e| XlogError::Kernel(format!("ilp_coo_fill failed: {}", e)))?;
601 self.device.synchronize()?;
602 Ok(())
603 }
604
605 pub fn ilp_credit_forward_f32_launch(
608 &self,
609 row_offsets: &TrackedCudaSlice<u32>,
610 col_indices: &TrackedCudaSlice<u32>,
611 cand_probs: &CudaColumn, is_positive: &TrackedCudaSlice<u8>,
613 num_facts: u32,
614 eps: f32,
615 ) -> Result<(TrackedCudaSlice<f32>, TrackedCudaSlice<f32>)> {
616 let mut credit_out = self.memory.alloc::<f32>(num_facts as usize)?;
617 let mut loss_contrib = self.memory.alloc::<f32>(num_facts as usize)?;
618 if num_facts == 0 {
619 return Ok((credit_out, loss_contrib));
620 }
621 let func = self
622 .device
623 .inner()
624 .get_func(
625 ILP_CREDIT_MODULE,
626 ilp_credit_kernels::ILP_CREDIT_FORWARD_F32,
627 )
628 .ok_or_else(|| {
629 XlogError::Kernel("ilp_credit_forward_f32 kernel not found".to_string())
630 })?;
631 let block_size = 256u32;
632 let grid_size = num_facts.div_ceil(block_size);
633 let cand_view = RawCudaView::<f32> {
635 ptr: *cand_probs.device_ptr(),
636 len: cudarc::driver::DeviceSlice::len(cand_probs) / 4,
637 stream: cand_probs.stream().clone(),
638 source_block: None,
639 _marker: PhantomData,
640 };
641 unsafe {
643 func.clone().launch(
644 LaunchConfig {
645 grid_dim: (grid_size, 1, 1),
646 block_dim: (block_size, 1, 1),
647 shared_mem_bytes: 0,
648 },
649 (
650 row_offsets,
651 col_indices,
652 &cand_view,
653 is_positive,
654 num_facts,
655 eps,
656 &mut credit_out,
657 &mut loss_contrib,
658 ),
659 )
660 }
661 .map_err(|e| XlogError::Kernel(format!("ilp_credit_forward_f32 failed: {}", e)))?;
662 self.device.synchronize()?;
663 Ok((credit_out, loss_contrib))
664 }
665
666 pub fn ilp_credit_forward_f64_launch(
669 &self,
670 row_offsets: &TrackedCudaSlice<u32>,
671 col_indices: &TrackedCudaSlice<u32>,
672 cand_probs: &CudaColumn, is_positive: &TrackedCudaSlice<u8>,
674 num_facts: u32,
675 eps: f64,
676 ) -> Result<(TrackedCudaSlice<f64>, TrackedCudaSlice<f64>)> {
677 let mut credit_out = self.memory.alloc::<f64>(num_facts as usize)?;
678 let mut loss_contrib = self.memory.alloc::<f64>(num_facts as usize)?;
679 if num_facts == 0 {
680 return Ok((credit_out, loss_contrib));
681 }
682 let func = self
683 .device
684 .inner()
685 .get_func(
686 ILP_CREDIT_MODULE,
687 ilp_credit_kernels::ILP_CREDIT_FORWARD_F64,
688 )
689 .ok_or_else(|| {
690 XlogError::Kernel("ilp_credit_forward_f64 kernel not found".to_string())
691 })?;
692 let block_size = 256u32;
693 let grid_size = num_facts.div_ceil(block_size);
694 let cand_view = RawCudaView::<f64> {
695 ptr: *cand_probs.device_ptr(),
696 len: cudarc::driver::DeviceSlice::len(cand_probs) / 8,
697 stream: cand_probs.stream().clone(),
698 source_block: None,
699 _marker: PhantomData,
700 };
701 unsafe {
703 func.clone().launch(
704 LaunchConfig {
705 grid_dim: (grid_size, 1, 1),
706 block_dim: (block_size, 1, 1),
707 shared_mem_bytes: 0,
708 },
709 (
710 row_offsets,
711 col_indices,
712 &cand_view,
713 is_positive,
714 num_facts,
715 eps,
716 &mut credit_out,
717 &mut loss_contrib,
718 ),
719 )
720 }
721 .map_err(|e| XlogError::Kernel(format!("ilp_credit_forward_f64 failed: {}", e)))?;
722 self.device.synchronize()?;
723 Ok((credit_out, loss_contrib))
724 }
725
726 pub fn ilp_credit_backward_f32_launch(
729 &self,
730 row_offsets: &TrackedCudaSlice<u32>,
731 col_indices: &TrackedCudaSlice<u32>,
732 credit_out: &TrackedCudaSlice<f32>,
733 is_positive: &TrackedCudaSlice<u8>,
734 num_facts: u32,
735 num_cands: u32,
736 ) -> Result<TrackedCudaSlice<f32>> {
737 let mut d_grad = self.memory.alloc::<f32>(num_cands as usize)?;
738 self.device
739 .inner()
740 .memset_zeros(&mut d_grad)
741 .map_err(|e| XlogError::Kernel(format!("Failed to zero grad: {}", e)))?;
742 if num_facts == 0 {
743 return Ok(d_grad);
744 }
745 let func = self
746 .device
747 .inner()
748 .get_func(
749 ILP_CREDIT_MODULE,
750 ilp_credit_kernels::ILP_CREDIT_BACKWARD_F32,
751 )
752 .ok_or_else(|| {
753 XlogError::Kernel("ilp_credit_backward_f32 kernel not found".to_string())
754 })?;
755 let block_size = 256u32;
756 let grid_size = num_facts.div_ceil(block_size);
757 unsafe {
759 func.clone().launch(
760 LaunchConfig {
761 grid_dim: (grid_size, 1, 1),
762 block_dim: (block_size, 1, 1),
763 shared_mem_bytes: 0,
764 },
765 (
766 row_offsets,
767 col_indices,
768 credit_out,
769 is_positive,
770 num_facts,
771 &mut d_grad,
772 ),
773 )
774 }
775 .map_err(|e| XlogError::Kernel(format!("ilp_credit_backward_f32 failed: {}", e)))?;
776 self.device.synchronize()?;
777 Ok(d_grad)
778 }
779
780 pub fn ilp_credit_backward_f64_launch(
783 &self,
784 row_offsets: &TrackedCudaSlice<u32>,
785 col_indices: &TrackedCudaSlice<u32>,
786 credit_out: &TrackedCudaSlice<f64>,
787 is_positive: &TrackedCudaSlice<u8>,
788 num_facts: u32,
789 num_cands: u32,
790 ) -> Result<TrackedCudaSlice<f64>> {
791 let mut d_grad = self.memory.alloc::<f64>(num_cands as usize)?;
792 self.device
793 .inner()
794 .memset_zeros(&mut d_grad)
795 .map_err(|e| XlogError::Kernel(format!("Failed to zero grad: {}", e)))?;
796 if num_facts == 0 {
797 return Ok(d_grad);
798 }
799 let func = self
800 .device
801 .inner()
802 .get_func(
803 ILP_CREDIT_MODULE,
804 ilp_credit_kernels::ILP_CREDIT_BACKWARD_F64,
805 )
806 .ok_or_else(|| {
807 XlogError::Kernel("ilp_credit_backward_f64 kernel not found".to_string())
808 })?;
809 let block_size = 256u32;
810 let grid_size = num_facts.div_ceil(block_size);
811 unsafe {
813 func.clone().launch(
814 LaunchConfig {
815 grid_dim: (grid_size, 1, 1),
816 block_dim: (block_size, 1, 1),
817 shared_mem_bytes: 0,
818 },
819 (
820 row_offsets,
821 col_indices,
822 credit_out,
823 is_positive,
824 num_facts,
825 &mut d_grad,
826 ),
827 )
828 }
829 .map_err(|e| XlogError::Kernel(format!("ilp_credit_backward_f64 failed: {}", e)))?;
830 self.device.synchronize()?;
831 Ok(d_grad)
832 }
833
834 pub fn ilp_reduce_sum_f32_launch(
840 &self,
841 input: &TrackedCudaSlice<f32>,
842 n: u32,
843 ) -> Result<TrackedCudaSlice<f32>> {
844 let mut d_result = self.memory.alloc::<f32>(1)?;
845 self.device
846 .inner()
847 .memset_zeros(&mut d_result)
848 .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f32 zero result: {}", e)))?;
849
850 if n == 0 {
851 return Ok(d_result);
852 }
853
854 let func = self
855 .device
856 .inner()
857 .get_func(ILP_MODULE, ilp_kernels::ILP_REDUCE_SUM_F32)
858 .ok_or_else(|| XlogError::Kernel("ilp_reduce_sum_f32 not found".to_string()))?;
859 let block_size = 256u32;
860 let grid_size = n.div_ceil(block_size);
861 unsafe {
863 func.clone().launch(
864 LaunchConfig {
865 grid_dim: (grid_size, 1, 1),
866 block_dim: (block_size, 1, 1),
867 shared_mem_bytes: 0,
868 },
869 (input, n, &mut d_result),
870 )
871 }
872 .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f32: {}", e)))?;
873 self.device.synchronize()?;
874 Ok(d_result)
875 }
876
877 pub fn ilp_reduce_sum_f64_launch(
883 &self,
884 input: &TrackedCudaSlice<f64>,
885 n: u32,
886 ) -> Result<TrackedCudaSlice<f64>> {
887 let mut d_result = self.memory.alloc::<f64>(1)?;
888 self.device
889 .inner()
890 .memset_zeros(&mut d_result)
891 .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f64 zero result: {}", e)))?;
892
893 if n == 0 {
894 return Ok(d_result);
895 }
896
897 let func = self
898 .device
899 .inner()
900 .get_func(ILP_MODULE, ilp_kernels::ILP_REDUCE_SUM_F64)
901 .ok_or_else(|| XlogError::Kernel("ilp_reduce_sum_f64 not found".to_string()))?;
902 let block_size = 256u32;
903 let grid_size = n.div_ceil(block_size);
904 unsafe {
906 func.clone().launch(
907 LaunchConfig {
908 grid_dim: (grid_size, 1, 1),
909 block_dim: (block_size, 1, 1),
910 shared_mem_bytes: 0,
911 },
912 (input, n, &mut d_result),
913 )
914 }
915 .map_err(|e| XlogError::Kernel(format!("ilp_reduce_sum_f64: {}", e)))?;
916 self.device.synchronize()?;
917 Ok(d_result)
918 }
919
920 #[allow(clippy::too_many_arguments)]
932 pub fn ilp_coo_fill_from_mask_launch(
933 &self,
934 mask: &TrackedCudaSlice<u8>,
935 prefix_sum: &TrackedCudaSlice<u32>,
936 fact_indices: &TrackedCudaSlice<u32>,
937 offset_idx: u32,
938 cand_value: u32,
939 num_query: u32,
940 d_offsets: &TrackedCudaSlice<u32>,
941 coo_fact: &mut TrackedCudaSlice<u32>,
942 coo_cand: &mut TrackedCudaSlice<u32>,
943 ) -> Result<()> {
944 if num_query == 0 {
945 return Ok(());
946 }
947 let func = self
948 .device()
949 .inner()
950 .get_func(ILP_MODULE, ilp_kernels::ILP_COO_FILL_FROM_MASK)
951 .ok_or_else(|| XlogError::Kernel("ilp_coo_fill_from_mask not found".to_string()))?;
952 let block_size = 256u32;
953 let grid_size = num_query.div_ceil(block_size);
954 unsafe {
956 func.clone().launch(
957 LaunchConfig {
958 grid_dim: (grid_size, 1, 1),
959 block_dim: (block_size, 1, 1),
960 shared_mem_bytes: 0,
961 },
962 (
963 mask,
964 prefix_sum,
965 fact_indices,
966 offset_idx,
967 cand_value,
968 num_query,
969 d_offsets,
970 coo_fact,
971 coo_cand,
972 ),
973 )
974 }
975 .map_err(|e| XlogError::Kernel(format!("ilp_coo_fill_from_mask: {}", e)))?;
976 self.device()
977 .inner()
978 .synchronize()
979 .map_err(|e| XlogError::Kernel(format!("ilp_coo_fill_from_mask sync: {}", e)))?;
980 Ok(())
981 }
982
983 pub fn ilp_csr_histogram_launch(
993 &self,
994 sorted_facts: &TrackedCudaSlice<u32>,
995 nnz: u32,
996 num_facts: u32,
997 ) -> Result<TrackedCudaSlice<u32>> {
998 let mut d_hist = self.memory().alloc::<u32>(num_facts as usize)?;
999 self.device()
1000 .inner()
1001 .memset_zeros(&mut d_hist)
1002 .map_err(|e| XlogError::Kernel(format!("ilp_csr_histogram zero hist: {}", e)))?;
1003
1004 if nnz == 0 {
1005 return Ok(d_hist);
1006 }
1007
1008 let func = self
1009 .device()
1010 .inner()
1011 .get_func(ILP_MODULE, ilp_kernels::ILP_CSR_HISTOGRAM)
1012 .ok_or_else(|| XlogError::Kernel("ilp_csr_histogram kernel not found".to_string()))?;
1013
1014 let block_size = 256u32;
1015 let grid_size = nnz.div_ceil(block_size);
1016
1017 unsafe {
1019 func.clone()
1020 .launch(
1021 cudarc::driver::LaunchConfig {
1022 grid_dim: (grid_size, 1, 1),
1023 block_dim: (block_size, 1, 1),
1024 shared_mem_bytes: 0,
1025 },
1026 (sorted_facts, nnz, num_facts, &mut d_hist),
1027 )
1028 .map_err(|e| XlogError::Kernel(format!("ilp_csr_histogram launch: {}", e)))?;
1029 }
1030
1031 self.device()
1032 .inner()
1033 .synchronize()
1034 .map_err(|e| XlogError::Kernel(format!("ilp_csr_histogram sync: {}", e)))?;
1035
1036 Ok(d_hist)
1037 }
1038}