xlog_prob/
neural_fast_path.rs1use cudarc::driver::{CudaView, DeviceSlice};
7use xlog_core::{Result, XlogError};
8use xlog_cuda::memory::TrackedCudaSlice;
9use xlog_cuda::CudaKernelProvider;
10
/// Numeric tuning knobs for the neural fast path.
///
/// The type is `Copy`, so it can be passed around by value. It is marked
/// `#[non_exhaustive]`, which means code outside this crate cannot build it
/// with a struct literal and should start from `Default::default()` instead.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct NeuralFastPathConfig {
    /// Epsilon parameter (the `Default` impl uses `1e-7`).
    ///
    /// NOTE(review): presumably a numerical-stability tolerance; the exact
    /// meaning is defined by the code that consumes this config — confirm there.
    pub eps: f64,
    /// Minimum-`p` parameter (the `Default` impl uses `1e-12`).
    ///
    /// NOTE(review): presumably a lower bound applied to probabilities;
    /// confirm against the consumer of this config.
    pub min_p: f64,
}
23
24impl Default for NeuralFastPathConfig {
25 fn default() -> Self {
26 Self {
27 eps: 1e-7,
28 min_p: 1e-12,
29 }
30 }
31}
32
33pub struct GpuWeightSlots {
39 group_offsets_host: Vec<u32>,
40 group_offsets: TrackedCudaSlice<u32>, slot_cnf_var: TrackedCudaSlice<u32>, }
43
44impl GpuWeightSlots {
45 pub fn upload(provider: &CudaKernelProvider, groups: &[Vec<u32>]) -> Result<Self> {
49 if groups.len() > u32::MAX as usize {
50 return Err(XlogError::Compilation(
51 "Neural fast-path group count exceeds GPU u32 index space".to_string(),
52 ));
53 }
54 let offset_count = groups.len().checked_add(1).ok_or_else(|| {
55 XlogError::Compilation("Neural fast-path group offset count overflow".to_string())
56 })?;
57 let total_slots = groups.iter().try_fold(0usize, |acc, group| {
58 acc.checked_add(group.len()).ok_or_else(|| {
59 XlogError::Compilation("Neural fast-path slot count overflow".to_string())
60 })
61 })?;
62 if total_slots > u32::MAX as usize {
63 return Err(XlogError::Compilation(
64 "Neural fast-path slot count exceeds GPU u32 index space".to_string(),
65 ));
66 }
67
68 let mut offsets: Vec<u32> = Vec::with_capacity(offset_count);
69 offsets.push(0);
70
71 let mut flat: Vec<u32> = Vec::with_capacity(total_slots);
72 for g in groups {
73 flat.extend_from_slice(g);
74 let next_offset = u32::try_from(flat.len()).map_err(|_| {
75 XlogError::Compilation(
76 "Neural fast-path slot offset exceeds GPU u32 index space".to_string(),
77 )
78 })?;
79 offsets.push(next_offset);
80 }
81
82 let memory = provider.memory().clone();
83
84 let mut d_offsets = memory.alloc::<u32>(offsets.len())?;
85 provider
86 .htod_sync_copy_into_tracked(&offsets, &mut d_offsets)
87 .map_err(|e| {
88 XlogError::Kernel(format!("Failed to upload weight slot offsets: {}", e))
89 })?;
90
91 let mut d_vars = memory.alloc::<u32>(flat.len())?;
92 provider
93 .htod_sync_copy_into_tracked(&flat, &mut d_vars)
94 .map_err(|e| XlogError::Kernel(format!("Failed to upload weight slot vars: {}", e)))?;
95
96 Ok(Self {
97 group_offsets_host: offsets,
98 group_offsets: d_offsets,
99 slot_cnf_var: d_vars,
100 })
101 }
102
103 pub fn num_groups(&self) -> u32 {
104 debug_assert!(!self.group_offsets_host.is_empty());
105 (self.group_offsets_host.len() - 1) as u32
106 }
107
108 pub fn num_groups_usize(&self) -> usize {
109 debug_assert!(!self.group_offsets_host.is_empty());
110 self.group_offsets_host.len() - 1
111 }
112
113 pub fn total_slots(&self) -> u32 {
114 debug_assert!(!self.group_offsets_host.is_empty());
115 self.group_offsets_host[self.group_offsets_host.len() - 1]
116 }
117
118 pub fn group_offsets(&self) -> &TrackedCudaSlice<u32> {
119 &self.group_offsets
120 }
121
122 pub fn slot_cnf_var(&self) -> &TrackedCudaSlice<u32> {
123 &self.slot_cnf_var
124 }
125
126 pub fn group_slot_cnf_var(&self, group_idx: usize) -> Result<CudaView<'_, u32>> {
128 let start = *self
129 .group_offsets_host
130 .get(group_idx)
131 .ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
132 as usize;
133 let end = *self
134 .group_offsets_host
135 .get(group_idx + 1)
136 .ok_or_else(|| XlogError::Compilation("Group index out of bounds".to_string()))?
137 as usize;
138 if end < start || end > self.slot_cnf_var.len() {
139 return Err(XlogError::Compilation(
140 "Invalid group slot range in GpuWeightSlots".to_string(),
141 ));
142 }
143 Ok(self.slot_cnf_var.slice(start..end))
144 }
145}