1use std::sync::Arc;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::provider::{weights_kernels, WEIGHTS_MODULE};
9use xlog_cuda::{CudaKernelProvider, LaunchAsync};
10
11use crate::compilation::gpu_cnf::GpuCnfVarTables;
12
13pub struct GpuWeights {
14 pub log_true: TrackedCudaSlice<f64>,
15 pub log_false: TrackedCudaSlice<f64>,
16}
17
18fn kernel_count_u32(context: &str, count: usize) -> Result<u32> {
19 u32::try_from(count)
20 .map_err(|_| XlogError::Compilation(format!("{context} exceeds GPU u32 index space")))
21}
22
23fn grid_for(count: u32, block: u32) -> Result<u32> {
24 if count == 0 {
25 return Ok(0);
26 }
27 if block == 0 {
28 return Err(XlogError::Compilation(
29 "GPU weight kernel block size must be nonzero".to_string(),
30 ));
31 }
32 let grid = (count as u64).div_ceil(block as u64);
33 let step = grid
34 .checked_mul(block as u64)
35 .ok_or_else(|| XlogError::Compilation("GPU weight grid-stride overflow".to_string()))?;
36 if step > u32::MAX as u64 {
37 return Err(XlogError::Compilation(
38 "GPU weight grid-stride step exceeds u32 index space".to_string(),
39 ));
40 }
41 u32::try_from(grid).map_err(|_| {
42 XlogError::Compilation("GPU weight kernel grid exceeds u32 index space".to_string())
43 })
44}
45
46fn checked_var_table_count(var_cap: u32) -> Result<u32> {
47 var_cap.checked_add(1).ok_or_else(|| {
48 XlogError::Compilation("GPU weight var_cap exceeds u32 table index space".to_string())
49 })
50}
51
52fn weights_len_for_var_cap(var_cap: u32) -> Result<usize> {
53 (var_cap as usize)
54 .checked_add(1)
55 .ok_or_else(|| XlogError::Compilation("weight table size overflow".to_string()))
56}
57
58fn query_weights_len_for_var_cap(var_cap: u32) -> Result<usize> {
59 (var_cap as usize)
60 .checked_add(1)
61 .ok_or_else(|| XlogError::Compilation("query var_cap overflow".to_string()))
62}
63
64fn evidence_len_for_var_cap(var_cap: u32) -> Result<usize> {
65 (var_cap as usize)
66 .checked_add(1)
67 .ok_or_else(|| XlogError::Compilation("evidence var_cap overflow".to_string()))
68}
69
70pub fn build_evidence_by_var_gpu(
71 node_var: &TrackedCudaSlice<u32>,
72 evidence_nodes: &TrackedCudaSlice<u32>,
73 evidence_vals: &TrackedCudaSlice<u8>,
74 var_cap: u32,
75 provider: &Arc<CudaKernelProvider>,
76) -> Result<TrackedCudaSlice<u8>> {
77 if evidence_nodes.len() != evidence_vals.len() {
78 return Err(XlogError::Compilation(format!(
79 "GPU evidence nodes len {} != vals len {}",
80 evidence_nodes.len(),
81 evidence_vals.len()
82 )));
83 }
84 let len = evidence_len_for_var_cap(var_cap)?;
85
86 let memory = provider.memory();
87 let device = provider.device().inner();
88 let mut evidence_by_var = memory.alloc::<u8>(len)?;
89 device
90 .memset_zeros(&mut evidence_by_var)
91 .map_err(|e| XlogError::Kernel(format!("Failed to zero evidence buffer: {}", e)))?;
92
93 let count = evidence_nodes.len();
94 if count == 0 {
95 return Ok(evidence_by_var);
96 }
97 let count_u32 = kernel_count_u32("GPU evidence node count", count)?;
98
99 let func = device
100 .get_func(
101 WEIGHTS_MODULE,
102 weights_kernels::WEIGHTS_SET_EVIDENCE_FROM_NODES,
103 )
104 .ok_or_else(|| {
105 XlogError::Kernel("weights_set_evidence_from_nodes kernel not found".to_string())
106 })?;
107 let block = 256u32;
108 let grid = grid_for(count_u32, block)?;
109 unsafe {
111 func.clone().launch(
112 LaunchConfig {
113 grid_dim: (grid.max(1), 1, 1),
114 block_dim: (block, 1, 1),
115 shared_mem_bytes: 0,
116 },
117 (
118 node_var,
119 evidence_nodes,
120 evidence_vals,
121 count_u32,
122 var_cap,
123 &mut evidence_by_var,
124 ),
125 )
126 }
127 .map_err(|e| XlogError::Kernel(format!("weights_set_evidence_from_nodes failed: {}", e)))?;
128 Ok(evidence_by_var)
130}
131
132pub fn map_nodes_to_vars_gpu(
133 node_var: &TrackedCudaSlice<u32>,
134 node_ids: &TrackedCudaSlice<u32>,
135 var_cap: u32,
136 provider: &Arc<CudaKernelProvider>,
137) -> Result<TrackedCudaSlice<u32>> {
138 let memory = provider.memory();
139 let device = provider.device().inner();
140 let mut out = memory.alloc::<u32>(node_ids.len())?;
141 let count = node_ids.len();
142 if count == 0 {
143 return Ok(out);
144 }
145 let count_u32 = kernel_count_u32("GPU node-to-var map count", count)?;
146
147 let func = device
148 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_MAP_NODES_TO_VARS)
149 .ok_or_else(|| {
150 XlogError::Kernel("weights_map_nodes_to_vars kernel not found".to_string())
151 })?;
152
153 let block = 256u32;
154 let grid = grid_for(count_u32, block)?;
155 unsafe {
157 func.clone().launch(
158 LaunchConfig {
159 grid_dim: (grid.max(1), 1, 1),
160 block_dim: (block, 1, 1),
161 shared_mem_bytes: 0,
162 },
163 (node_var, node_ids, count_u32, var_cap, &mut out),
164 )
165 }
166 .map_err(|e| XlogError::Kernel(format!("weights_map_nodes_to_vars failed: {}", e)))?;
167 Ok(out)
169}
170
171pub fn apply_query_vars_device(
172 provider: &Arc<CudaKernelProvider>,
173 query_vars: &TrackedCudaSlice<u32>,
174 var_cap: u32,
175 log_false: &mut TrackedCudaSlice<f64>,
176 saved: &mut TrackedCudaSlice<f64>,
177) -> Result<()> {
178 let count = query_vars.len();
179 if saved.len() < count {
180 return Err(XlogError::Compilation(format!(
181 "query restore buffer len {} < query vars len {}",
182 saved.len(),
183 count
184 )));
185 }
186 let weights_len = query_weights_len_for_var_cap(var_cap)?;
187 if log_false.len() < weights_len {
188 return Err(XlogError::Compilation(format!(
189 "log_false len {} < var_cap+1 {}",
190 log_false.len(),
191 weights_len
192 )));
193 }
194 if count == 0 {
195 return Ok(());
196 }
197 let count_u32 = kernel_count_u32("GPU query apply count", count)?;
198
199 let device = provider.device().inner();
200 let func = device
201 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_APPLY_QUERY_VARS)
202 .ok_or_else(|| {
203 XlogError::Kernel("weights_apply_query_vars kernel not found".to_string())
204 })?;
205
206 let block = 256u32;
207 let grid = grid_for(count_u32, block)?;
208 unsafe {
210 func.clone().launch(
211 LaunchConfig {
212 grid_dim: (grid.max(1), 1, 1),
213 block_dim: (block, 1, 1),
214 shared_mem_bytes: 0,
215 },
216 (query_vars, count_u32, var_cap, log_false, saved),
217 )
218 }
219 .map_err(|e| XlogError::Kernel(format!("weights_apply_query_vars failed: {}", e)))?;
220 Ok(())
222}
223
224pub fn restore_query_vars_device(
225 provider: &Arc<CudaKernelProvider>,
226 query_vars: &TrackedCudaSlice<u32>,
227 var_cap: u32,
228 log_false: &mut TrackedCudaSlice<f64>,
229 saved: &TrackedCudaSlice<f64>,
230) -> Result<()> {
231 let count = query_vars.len();
232 if saved.len() < count {
233 return Err(XlogError::Compilation(format!(
234 "query restore buffer len {} < query vars len {}",
235 saved.len(),
236 count
237 )));
238 }
239 let weights_len = query_weights_len_for_var_cap(var_cap)?;
240 if log_false.len() < weights_len {
241 return Err(XlogError::Compilation(format!(
242 "log_false len {} < var_cap+1 {}",
243 log_false.len(),
244 weights_len
245 )));
246 }
247 if count == 0 {
248 return Ok(());
249 }
250 let count_u32 = kernel_count_u32("GPU query restore count", count)?;
251
252 let device = provider.device().inner();
253 let func = device
254 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_RESTORE_QUERY_VARS)
255 .ok_or_else(|| {
256 XlogError::Kernel("weights_restore_query_vars kernel not found".to_string())
257 })?;
258
259 let block = 256u32;
260 let grid = grid_for(count_u32, block)?;
261 unsafe {
263 func.clone().launch(
264 LaunchConfig {
265 grid_dim: (grid.max(1), 1, 1),
266 block_dim: (block, 1, 1),
267 shared_mem_bytes: 0,
268 },
269 (query_vars, count_u32, var_cap, log_false, saved),
270 )
271 }
272 .map_err(|e| XlogError::Kernel(format!("weights_restore_query_vars failed: {}", e)))?;
273 Ok(())
275}
276
277pub fn build_weights_gpu(
278 vars: &GpuCnfVarTables,
279 leaf_probs: &TrackedCudaSlice<f64>,
280 choice_true: &TrackedCudaSlice<f64>,
281 choice_false: &TrackedCudaSlice<f64>,
282 evidence_by_var: &TrackedCudaSlice<u8>,
283 provider: &Arc<CudaKernelProvider>,
284) -> Result<GpuWeights> {
285 let var_cap = vars.max_var;
286 let weights_len = weights_len_for_var_cap(var_cap)?;
287
288 if vars.leaf_var.len() < leaf_probs.len() {
289 return Err(XlogError::Compilation(format!(
290 "leaf_probs len {} exceeds leaf_var len {}",
291 leaf_probs.len(),
292 vars.leaf_var.len()
293 )));
294 }
295 if vars.choice_var.len() < choice_true.len() {
296 return Err(XlogError::Compilation(format!(
297 "choice_true len {} exceeds choice_var len {}",
298 choice_true.len(),
299 vars.choice_var.len()
300 )));
301 }
302 if choice_true.len() != choice_false.len() {
303 return Err(XlogError::Compilation(format!(
304 "choice_true len {} != choice_false len {}",
305 choice_true.len(),
306 choice_false.len()
307 )));
308 }
309 if evidence_by_var.len() != weights_len {
310 return Err(XlogError::Compilation(format!(
311 "evidence_by_var len {} != weights len {}",
312 evidence_by_var.len(),
313 weights_len
314 )));
315 }
316
317 let memory = provider.memory();
318 let device = provider.device().inner();
319 let mut log_true = memory.alloc::<f64>(weights_len)?;
320 let mut log_false = memory.alloc::<f64>(weights_len)?;
321
322 device
324 .memset_zeros(&mut log_true)
325 .map_err(|e| XlogError::Kernel(format!("Failed to zero log_true weights: {}", e)))?;
326 device
327 .memset_zeros(&mut log_false)
328 .map_err(|e| XlogError::Kernel(format!("Failed to zero log_false weights: {}", e)))?;
329
330 let block = 256u32;
331
332 if !leaf_probs.is_empty() {
333 let leaf_count = kernel_count_u32("GPU leaf probability count", leaf_probs.len())?;
334 let func = device
335 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FILL_LEAF)
336 .ok_or_else(|| XlogError::Kernel("weights_fill_leaf kernel not found".to_string()))?;
337 let grid = grid_for(leaf_count, block)?;
338 unsafe {
340 func.clone().launch(
341 LaunchConfig {
342 grid_dim: (grid.max(1), 1, 1),
343 block_dim: (block, 1, 1),
344 shared_mem_bytes: 0,
345 },
346 (
347 &vars.leaf_var,
348 leaf_probs,
349 leaf_count,
350 var_cap,
351 &mut log_true,
352 &mut log_false,
353 ),
354 )
355 }
356 .map_err(|e| XlogError::Kernel(format!("weights_fill_leaf failed: {}", e)))?;
357 }
358
359 if !choice_true.is_empty() {
360 let choice_count = kernel_count_u32("GPU choice probability count", choice_true.len())?;
361 let func = device
362 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_FILL_CHOICE)
363 .ok_or_else(|| XlogError::Kernel("weights_fill_choice kernel not found".to_string()))?;
364 let grid = grid_for(choice_count, block)?;
365 unsafe {
367 func.clone().launch(
368 LaunchConfig {
369 grid_dim: (grid.max(1), 1, 1),
370 block_dim: (block, 1, 1),
371 shared_mem_bytes: 0,
372 },
373 (
374 &vars.choice_var,
375 choice_true,
376 choice_false,
377 choice_count,
378 var_cap,
379 &mut log_true,
380 &mut log_false,
381 ),
382 )
383 }
384 .map_err(|e| XlogError::Kernel(format!("weights_fill_choice failed: {}", e)))?;
385 }
386
387 if !evidence_by_var.is_empty() {
388 let var_table_count = checked_var_table_count(var_cap)?;
389 let func = device
390 .get_func(WEIGHTS_MODULE, weights_kernels::WEIGHTS_APPLY_EVIDENCE)
391 .ok_or_else(|| {
392 XlogError::Kernel("weights_apply_evidence kernel not found".to_string())
393 })?;
394 let grid = grid_for(var_table_count, block)?;
395 unsafe {
397 func.clone().launch(
398 LaunchConfig {
399 grid_dim: (grid.max(1), 1, 1),
400 block_dim: (block, 1, 1),
401 shared_mem_bytes: 0,
402 },
403 (evidence_by_var, var_cap, &mut log_true, &mut log_false),
404 )
405 }
406 .map_err(|e| XlogError::Kernel(format!("weights_apply_evidence failed: {}", e)))?;
407 }
408 Ok(GpuWeights {
410 log_true,
411 log_false,
412 })
413}
414
415#[allow(dead_code)] pub(crate) fn upload_weights_from_host(
417 provider: &Arc<CudaKernelProvider>,
418 weights: &[(f64, f64)],
419) -> Result<GpuWeights> {
420 let weights_len = weights.len();
421 let mut host_true: Vec<f64> = Vec::with_capacity(weights_len);
422 let mut host_false: Vec<f64> = Vec::with_capacity(weights_len);
423 for &(t, f) in weights {
424 host_true.push(t);
425 host_false.push(f);
426 }
427
428 let memory = provider.memory();
429 let mut log_true = memory.alloc::<f64>(weights_len)?;
430 let mut log_false = memory.alloc::<f64>(weights_len)?;
431 provider
432 .htod_sync_copy_into_tracked(&host_true, &mut log_true)
433 .map_err(|e| XlogError::Kernel(format!("Upload log_true weights failed: {}", e)))?;
434 provider
435 .htod_sync_copy_into_tracked(&host_false, &mut log_false)
436 .map_err(|e| XlogError::Kernel(format!("Upload log_false weights failed: {}", e)))?;
437
438 Ok(GpuWeights {
439 log_true,
440 log_false,
441 })
442}