1use std::ffi::c_void;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::provider::{
9 arith_kernels, d4_kernels, filter_kernels, ARITH_MODULE, D4_MODULE, FILTER_MODULE,
10};
11use xlog_cuda::{circuit_kernels, AsKernelParam, CudaKernelProvider, LaunchAsync, CIRCUIT_MODULE};
12
13use crate::compilation::gpu_d4::exclusive_scan_u32_inplace;
14use crate::xgcf::{Xgcf, XgcfNodeType};
15
16pub struct GpuCircuitBuilder {
21 pub node_type: TrackedCudaSlice<u8>,
22 pub child_offsets: TrackedCudaSlice<u32>,
23 pub child_indices: TrackedCudaSlice<u32>,
24 pub lit: TrackedCudaSlice<i32>,
25 pub decision_var: TrackedCudaSlice<u32>,
26 pub decision_child_false: TrackedCudaSlice<u32>,
27 pub decision_child_true: TrackedCudaSlice<u32>,
28}
29
30pub struct GpuCircuitLayout {
32 pub num_nodes: u32,
33 pub num_edges: u32,
34 pub num_levels: u32,
35 pub level_offsets: TrackedCudaSlice<u32>,
36 pub level_nodes: TrackedCudaSlice<u32>,
37 pub root: u32,
38 pub max_var: u32,
39 pub num_nodes_device: Option<TrackedCudaSlice<u32>>,
40 pub num_edges_device: Option<TrackedCudaSlice<u32>>,
41}
42
43pub struct GpuXgcf {
44 node_type: TrackedCudaSlice<u8>,
45 child_offsets: TrackedCudaSlice<u32>,
46 child_indices: TrackedCudaSlice<u32>,
47 lit: TrackedCudaSlice<i32>,
48 decision_var: TrackedCudaSlice<u32>,
49 decision_child_false: TrackedCudaSlice<u32>,
50 decision_child_true: TrackedCudaSlice<u32>,
51 level_nodes: TrackedCudaSlice<u32>,
52 level_offsets: TrackedCudaSlice<u32>,
53 level_offsets_host: Option<Vec<u32>>,
56 node_cap: u32,
57 edge_cap: u32,
58 num_levels: u32,
59 root: u32,
60 max_var: u32,
61 meta_num_nodes: TrackedCudaSlice<u32>,
62 meta_num_edges: TrackedCudaSlice<u32>,
63 var_log_true: TrackedCudaSlice<f64>,
64 var_log_false: TrackedCudaSlice<f64>,
65 values: TrackedCudaSlice<f64>,
66 adj: TrackedCudaSlice<f64>,
67 grad_true: TrackedCudaSlice<f64>,
68 grad_false: TrackedCudaSlice<f64>,
69 free_var_mask: Option<TrackedCudaSlice<u8>>,
70}
71
72fn checked_gpu_u32_len(context: &str, len: usize) -> Result<u32> {
73 u32::try_from(len)
74 .map_err(|_| XlogError::Compilation(format!("{context} exceeds u32::MAX: {len}")))
75}
76
77fn checked_gpu_len_add_one(context: &str, len: usize) -> Result<usize> {
78 len.checked_add(1)
79 .ok_or_else(|| XlogError::Compilation(format!("{context} length overflow")))
80}
81
82fn checked_gpu_launch_blocks(context: &str, item_count: usize, block_size: u32) -> Result<u32> {
83 let item_count = u32::try_from(item_count).map_err(|_| {
84 XlogError::Kernel(format!(
85 "{context} launch item count exceeds u32::MAX: {item_count}"
86 ))
87 })?;
88 item_count
89 .checked_add(block_size - 1)
90 .map(|rounded| rounded / block_size)
91 .ok_or_else(|| XlogError::Kernel(format!("{context} launch grid overflow")))
92}
93
94fn checked_host_level_width(level_offsets: &[u32], level: usize) -> Result<usize> {
95 let start = level_offsets[level];
96 let end = level_offsets[level + 1];
97 if end < start {
98 return Err(XlogError::Compilation(format!(
99 "XGCF invariant violation: level_offsets decrease at level {} ({} > {})",
100 level, start, end
101 )));
102 }
103 Ok((end - start) as usize)
104}
105
106fn validate_xgcf_for_gpu_upload(circuit: &Xgcf) -> Result<(u32, u32, u32)> {
107 let n = circuit.node_type.len();
108 if n == 0 {
109 return Err(XlogError::Compilation(
110 "GPU XGCF upload requires at least one node".to_string(),
111 ));
112 }
113 let node_count = checked_gpu_u32_len("GPU XGCF node count", n)?;
114 let child_offsets_len = checked_gpu_len_add_one("GPU XGCF child_offsets", n)?;
115 if circuit.child_offsets.len() != child_offsets_len {
116 return Err(XlogError::Compilation(format!(
117 "XGCF invariant violation: child_offsets len {} != num_nodes+1 ({})",
118 circuit.child_offsets.len(),
119 child_offsets_len
120 )));
121 }
122 if circuit.lit.len() != n
123 || circuit.decision_var.len() != n
124 || circuit.decision_child_false.len() != n
125 || circuit.decision_child_true.len() != n
126 {
127 return Err(XlogError::Compilation(
128 "XGCF invariant violation: per-node arrays length mismatch".to_string(),
129 ));
130 }
131
132 let edge_count = checked_gpu_u32_len("GPU XGCF edge count", circuit.child_indices.len())?;
133 let mut previous_offset = 0u32;
134 for (idx, &offset) in circuit.child_offsets.iter().enumerate() {
135 if offset < previous_offset {
136 return Err(XlogError::Compilation(format!(
137 "XGCF invariant violation: child_offsets decrease at index {} ({} > {})",
138 idx, previous_offset, offset
139 )));
140 }
141 if offset > edge_count {
142 return Err(XlogError::Compilation(format!(
143 "XGCF invariant violation: child_offsets[{}] {} exceeds child_indices len {}",
144 idx, offset, edge_count
145 )));
146 }
147 previous_offset = offset;
148 }
149 if previous_offset != edge_count {
150 return Err(XlogError::Compilation(format!(
151 "XGCF invariant violation: final child offset {} != child_indices len {}",
152 previous_offset, edge_count
153 )));
154 }
155 for (edge, &child) in circuit.child_indices.iter().enumerate() {
156 if child >= node_count {
157 return Err(XlogError::Compilation(format!(
158 "XGCF invariant violation: child_indices[{}] {} out of bounds (num_nodes={})",
159 edge, child, node_count
160 )));
161 }
162 }
163
164 for (idx, &ty) in circuit.node_type.iter().enumerate() {
165 match ty {
166 XgcfNodeType::Const0 | XgcfNodeType::Const1 => {}
167 XgcfNodeType::Lit => {
168 if circuit.lit[idx] == 0 {
169 return Err(XlogError::Compilation(format!(
170 "XGCF invariant violation: LIT node {} has lit=0",
171 idx
172 )));
173 }
174 }
175 XgcfNodeType::And | XgcfNodeType::Or => {
176 if circuit.child_offsets[idx] == circuit.child_offsets[idx + 1] {
177 return Err(XlogError::Compilation(format!(
178 "XGCF invariant violation: {:?} node {} has no children",
179 ty, idx
180 )));
181 }
182 }
183 XgcfNodeType::Decision => {
184 if circuit.decision_var[idx] == 0 {
185 return Err(XlogError::Compilation(format!(
186 "XGCF invariant violation: DECISION node {} has var=0",
187 idx
188 )));
189 }
190 if circuit.decision_child_false[idx] >= node_count {
191 return Err(XlogError::Compilation(format!(
192 "XGCF invariant violation: DECISION node {} false child {} out of bounds",
193 idx, circuit.decision_child_false[idx]
194 )));
195 }
196 if circuit.decision_child_true[idx] >= node_count {
197 return Err(XlogError::Compilation(format!(
198 "XGCF invariant violation: DECISION node {} true child {} out of bounds",
199 idx, circuit.decision_child_true[idx]
200 )));
201 }
202 }
203 }
204 }
205
206 if circuit.level_offsets.is_empty() || circuit.level_offsets[0] != 0 {
207 return Err(XlogError::Compilation(
208 "XGCF invariant violation: level_offsets must start at 0".to_string(),
209 ));
210 }
211 let level_nodes_len =
212 checked_gpu_u32_len("GPU XGCF level_nodes len", circuit.level_nodes.len())?;
213 let mut previous_level_offset = 0u32;
214 for (idx, &offset) in circuit.level_offsets.iter().enumerate() {
215 if offset < previous_level_offset {
216 return Err(XlogError::Compilation(format!(
217 "XGCF invariant violation: level_offsets decrease at index {} ({} > {})",
218 idx, previous_level_offset, offset
219 )));
220 }
221 if offset > level_nodes_len {
222 return Err(XlogError::Compilation(format!(
223 "XGCF invariant violation: level_offsets[{}] {} exceeds level_nodes len {}",
224 idx, offset, level_nodes_len
225 )));
226 }
227 previous_level_offset = offset;
228 }
229 if previous_level_offset != level_nodes_len {
230 return Err(XlogError::Compilation(format!(
231 "XGCF invariant violation: level_offsets last {} != level_nodes.len {}",
232 previous_level_offset, level_nodes_len
233 )));
234 }
235 for (idx, &node) in circuit.level_nodes.iter().enumerate() {
236 if node >= node_count {
237 return Err(XlogError::Compilation(format!(
238 "XGCF invariant violation: level_nodes[{}] {} out of bounds (num_nodes={})",
239 idx, node, node_count
240 )));
241 }
242 }
243 let num_levels_usize = circuit.level_offsets.len() - 1;
244 let num_levels = checked_gpu_u32_len("GPU XGCF level count", num_levels_usize)?;
245 if num_levels == 0 {
246 return Err(XlogError::Compilation(
247 "GPU XGCF upload requires at least one level".to_string(),
248 ));
249 }
250
251 if circuit.roots.len() != 1 {
252 return Err(XlogError::Compilation(format!(
253 "GPU XGCF eval expects exactly 1 root, got {}",
254 circuit.roots.len()
255 )));
256 }
257 if circuit.roots[0] >= node_count {
258 return Err(XlogError::Compilation(format!(
259 "XGCF invariant violation: root {} out of bounds (num_nodes={})",
260 circuit.roots[0], node_count
261 )));
262 }
263
264 Ok((node_count, edge_count, num_levels))
265}
266
267impl GpuXgcf {
268 pub fn from_device(
269 builder: GpuCircuitBuilder,
270 layout: GpuCircuitLayout,
271 provider: &CudaKernelProvider,
272 ) -> Result<GpuXgcf> {
273 if layout.num_nodes == 0 {
274 return Err(XlogError::Compilation(
275 "GpuXgcf::from_device requires num_nodes > 0".to_string(),
276 ));
277 }
278 if layout.root >= layout.num_nodes {
279 return Err(XlogError::Compilation(format!(
280 "GpuXgcf::from_device: root {} out of bounds (num_nodes={})",
281 layout.root, layout.num_nodes
282 )));
283 }
284 if layout.num_levels == 0 {
285 return Err(XlogError::Compilation(
286 "GpuXgcf::from_device requires num_levels > 0".to_string(),
287 ));
288 }
289
290 let num_nodes = layout.num_nodes as usize;
291 let num_edges = layout.num_edges as usize;
292 let node_cap = builder.node_type.len();
293 if num_nodes == 0 || num_nodes > node_cap {
294 return Err(XlogError::Compilation(
295 "GpuXgcf::from_device: num_nodes out of bounds".to_string(),
296 ));
297 }
298 let child_offsets_len =
299 checked_gpu_len_add_one("GpuXgcf::from_device child_offsets", node_cap)?;
300 if builder.child_offsets.len() != child_offsets_len
301 || builder.lit.len() != node_cap
302 || builder.decision_var.len() != node_cap
303 || builder.decision_child_false.len() != node_cap
304 || builder.decision_child_true.len() != node_cap
305 {
306 return Err(XlogError::Compilation(
307 "GpuXgcf::from_device: circuit buffer length mismatch".to_string(),
308 ));
309 }
310 if num_edges > builder.child_indices.len() {
311 return Err(XlogError::Compilation(
312 "GpuXgcf::from_device: num_edges out of bounds".to_string(),
313 ));
314 }
315
316 let num_levels = layout.num_levels as usize;
317 let level_offsets_len =
318 checked_gpu_len_add_one("GpuXgcf::from_device level_offsets", num_levels)?;
319 if layout.level_offsets.len() != level_offsets_len {
320 return Err(XlogError::Compilation(format!(
321 "GpuXgcf::from_device: level_offsets len {} != num_levels+1 ({})",
322 layout.level_offsets.len(),
323 level_offsets_len
324 )));
325 }
326 if layout.level_nodes.len() < num_nodes {
327 return Err(XlogError::Compilation(format!(
328 "GpuXgcf::from_device: level_nodes len {} < num_nodes ({})",
329 layout.level_nodes.len(),
330 num_nodes
331 )));
332 }
333
334 let memory = provider.memory();
335
336 let weights_len = (layout.max_var as usize) + 1;
337 let var_log_true = memory.alloc::<f64>(weights_len)?;
338 let var_log_false = memory.alloc::<f64>(weights_len)?;
339 let values = memory.alloc::<f64>(num_nodes)?;
340 let adj = memory.alloc::<f64>(num_nodes)?;
341 let grad_true = memory.alloc::<f64>(weights_len)?;
342 let grad_false = memory.alloc::<f64>(weights_len)?;
343
344 let meta_num_nodes = match layout.num_nodes_device {
345 Some(meta) => meta,
346 None => {
347 let mut meta = memory.alloc::<u32>(1)?;
348 provider
349 .htod_launch_metadata_sync_copy_into(&[layout.num_nodes], &mut meta)
350 .map_err(|e| {
351 XlogError::Kernel(format!("Failed to upload num_nodes meta: {}", e))
352 })?;
353 meta
354 }
355 };
356 let meta_num_edges = match layout.num_edges_device {
357 Some(meta) => meta,
358 None => {
359 let mut meta = memory.alloc::<u32>(1)?;
360 provider
361 .htod_launch_metadata_sync_copy_into(&[layout.num_edges], &mut meta)
362 .map_err(|e| {
363 XlogError::Kernel(format!("Failed to upload num_edges meta: {}", e))
364 })?;
365 meta
366 }
367 };
368
369 Ok(Self {
370 node_type: builder.node_type,
371 child_offsets: builder.child_offsets,
372 child_indices: builder.child_indices,
373 lit: builder.lit,
374 decision_var: builder.decision_var,
375 decision_child_false: builder.decision_child_false,
376 decision_child_true: builder.decision_child_true,
377 level_nodes: layout.level_nodes,
378 level_offsets: layout.level_offsets,
379 level_offsets_host: None,
380 node_cap: layout.num_nodes,
381 edge_cap: layout.num_edges,
382 num_levels: layout.num_levels,
383 root: layout.root,
384 max_var: layout.max_var,
385 meta_num_nodes,
386 meta_num_edges,
387 var_log_true,
388 var_log_false,
389 values,
390 adj,
391 grad_true,
392 grad_false,
393 free_var_mask: None,
394 })
395 }
396
397 pub fn smooth_random_vars_device(
402 &self,
403 provider: &CudaKernelProvider,
404 random_var_list: &TrackedCudaSlice<u32>,
405 random_var_count: u32,
406 smooth_node_cap: u32,
407 smooth_edge_cap: u32,
408 ) -> Result<GpuXgcf> {
409 if smooth_node_cap == 0 || smooth_edge_cap == 0 {
410 return Err(XlogError::Compilation(
411 "GPU smoothing requires non-zero node/edge caps".to_string(),
412 ));
413 }
414
415 let num_nodes = self.node_cap;
416 if num_nodes == 0 {
417 return Err(XlogError::Compilation(
418 "GPU smoothing: num_nodes must be > 0".to_string(),
419 ));
420 }
421 if self.child_offsets.len() < (num_nodes as usize + 1) {
422 return Err(XlogError::Compilation(
423 "GPU smoothing: child_offsets len mismatch".to_string(),
424 ));
425 }
426 let num_edges = self.edge_cap;
427 if num_edges == 0 {
428 return Err(XlogError::Compilation(
429 "GPU smoothing: num_edges must be > 0".to_string(),
430 ));
431 }
432
433 let list_len = u32::try_from(random_var_list.len()).map_err(|_| {
434 XlogError::Compilation("GPU smoothing: random var list len exceeds u32".to_string())
435 })?;
436 let num_random_vars = random_var_count;
437 if num_random_vars > list_len {
438 return Err(XlogError::Compilation(format!(
439 "GPU smoothing: random var count {} exceeds list len {}",
440 num_random_vars, list_len
441 )));
442 }
443
444 let base_node = 2u32.checked_add(num_random_vars).ok_or_else(|| {
445 XlogError::Compilation("GPU smoothing: base node overflow".to_string())
446 })?;
447 let base_nodes = (base_node as u64)
448 .checked_add(num_nodes as u64)
449 .ok_or_else(|| {
450 XlogError::Compilation("GPU smoothing: base node overflow".to_string())
451 })?;
452 if base_nodes > smooth_node_cap as u64 {
453 return Err(XlogError::Compilation(format!(
454 "GPU smoothing: base nodes {} exceed smooth_node_cap {}",
455 base_nodes, smooth_node_cap
456 )));
457 }
458
459 let words_per_support = num_random_vars.div_ceil(32).max(1);
460
461 let support_len = (num_nodes as u64)
462 .checked_mul(words_per_support as u64)
463 .and_then(|v| usize::try_from(v).ok())
464 .ok_or_else(|| {
465 XlogError::Compilation("GPU smoothing: support size overflow".to_string())
466 })?;
467
468 let dec_entries = (num_nodes as u64)
469 .checked_mul(2)
470 .and_then(|v| usize::try_from(v).ok())
471 .ok_or_else(|| {
472 XlogError::Compilation("GPU smoothing: decision array overflow".to_string())
473 })?;
474 let dec_entries_u32 = u32::try_from(dec_entries).map_err(|_| {
475 XlogError::Compilation("GPU smoothing: decision entries exceed u32".to_string())
476 })?;
477
478 let device = provider.device().inner();
479 let memory = provider.memory();
480 let block_size: u32 = 256;
481
482 let map_len = (self.max_var as usize)
483 .checked_add(1)
484 .ok_or_else(|| XlogError::Compilation("GPU smoothing: max_var overflow".to_string()))?;
485 let map_len_u32 = u32::try_from(map_len).map_err(|_| {
486 XlogError::Compilation("GPU smoothing: random map len exceeds u32".to_string())
487 })?;
488 let mut d_random_map = memory.alloc::<u32>(map_len)?;
489 if map_len > 0 {
490 let fill_const = device
491 .get_func(FILTER_MODULE, filter_kernels::FILL_U32_CONST)
492 .ok_or_else(|| XlogError::Kernel("fill_u32_const kernel not found".to_string()))?;
493 let grid = map_len_u32.div_ceil(block_size);
494 unsafe {
496 fill_const.clone().launch(
497 LaunchConfig {
498 grid_dim: (grid, 1, 1),
499 block_dim: (block_size, 1, 1),
500 shared_mem_bytes: 0,
501 },
502 (&mut d_random_map, map_len_u32, u32::MAX),
503 )
504 }
505 .map_err(|e| XlogError::Kernel(format!("fill_u32_const failed: {}", e)))?;
506 }
507 if num_random_vars > 0 {
508 let map_kernel = device
509 .get_func(FILTER_MODULE, filter_kernels::RANDOM_VAR_TO_BIT_FROM_LIST)
510 .ok_or_else(|| {
511 XlogError::Kernel("random_var_to_bit_from_list kernel not found".to_string())
512 })?;
513 let grid = num_random_vars.div_ceil(block_size);
514 unsafe {
516 map_kernel.clone().launch(
517 LaunchConfig {
518 grid_dim: (grid, 1, 1),
519 block_dim: (block_size, 1, 1),
520 shared_mem_bytes: 0,
521 },
522 (
523 random_var_list,
524 num_random_vars,
525 map_len_u32,
526 &mut d_random_map,
527 ),
528 )
529 }
530 .map_err(|e| XlogError::Kernel(format!("random_var_to_bit_from_list failed: {}", e)))?;
531 }
532
533 let mut support = memory.alloc::<u32>(support_len)?;
534 device
535 .memset_zeros(&mut support)
536 .map_err(|e| XlogError::Kernel(format!("Failed to zero support: {}", e)))?;
537
538 let support_kernel = device
539 .get_func(D4_MODULE, d4_kernels::D4_SUPPORT_LEVEL)
540 .ok_or_else(|| XlogError::Kernel("d4_support_level kernel not found".to_string()))?;
541
542 let num_levels = self.num_levels as usize;
543 let random_map_len = map_len_u32;
544 for level in 0..num_levels {
545 let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
546 Some(off) => checked_host_level_width(off, level)?,
547 None => self.level_nodes.len(),
548 };
549 if num_level_nodes == 0 {
550 continue;
551 }
552 let num_blocks =
553 checked_gpu_launch_blocks("d4_support_level", num_level_nodes, block_size)?;
554 let config = LaunchConfig {
555 grid_dim: (num_blocks, 1, 1),
556 block_dim: (block_size, 1, 1),
557 shared_mem_bytes: 0,
558 };
559 let level_u32 = level as u32;
560 let mut params: Vec<*mut c_void> = vec![
561 (&self.node_type).as_kernel_param(),
562 (&self.child_offsets).as_kernel_param(),
563 (&self.child_indices).as_kernel_param(),
564 (&self.lit).as_kernel_param(),
565 (&self.decision_var).as_kernel_param(),
566 (&self.decision_child_false).as_kernel_param(),
567 (&self.decision_child_true).as_kernel_param(),
568 (&self.level_nodes).as_kernel_param(),
569 (&self.level_offsets).as_kernel_param(),
570 level_u32.as_kernel_param(),
571 (&d_random_map).as_kernel_param(),
572 random_map_len.as_kernel_param(),
573 words_per_support.as_kernel_param(),
574 (&support).as_kernel_param(),
575 ];
576 unsafe { support_kernel.clone().launch(config, &mut params) }
578 .map_err(|e| XlogError::Kernel(format!("d4_support_level failed: {}", e)))?;
579 }
580
581 if num_random_vars > 0 {
582 let root_kernel = device
583 .get_func(D4_MODULE, d4_kernels::D4_SUPPORT_SET_ROOT_BITS)
584 .ok_or_else(|| {
585 XlogError::Kernel("d4_support_set_root_bits kernel not found".to_string())
586 })?;
587 let num_words = num_random_vars.div_ceil(32);
588 let grid = num_words.div_ceil(block_size);
589 unsafe {
591 root_kernel.clone().launch(
592 LaunchConfig {
593 grid_dim: (grid, 1, 1),
594 block_dim: (block_size, 1, 1),
595 shared_mem_bytes: 0,
596 },
597 (self.root, num_random_vars, words_per_support, &mut support),
598 )
599 }
600 .map_err(|e| XlogError::Kernel(format!("d4_support_set_root_bits failed: {}", e)))?;
601 }
602
603 let mut wrap_prefix_or = memory.alloc::<u32>(num_edges as usize)?;
604 let mut wrap_missing_or = memory.alloc::<u32>(num_edges as usize)?;
605 let mut wrap_prefix_dec = memory.alloc::<u32>(dec_entries)?;
606 let mut wrap_missing_dec = memory.alloc::<u32>(dec_entries)?;
607
608 device
609 .memset_zeros(&mut wrap_prefix_or)
610 .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_prefix_or: {}", e)))?;
611 device
612 .memset_zeros(&mut wrap_missing_or)
613 .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_missing_or: {}", e)))?;
614 device
615 .memset_zeros(&mut wrap_prefix_dec)
616 .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_prefix_dec: {}", e)))?;
617 device
618 .memset_zeros(&mut wrap_missing_dec)
619 .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_missing_dec: {}", e)))?;
620
621 let mut out_edge_counts = memory.alloc::<u32>(smooth_node_cap as usize)?;
622 device
623 .memset_zeros(&mut out_edge_counts)
624 .map_err(|e| XlogError::Kernel(format!("Failed to zero edge_counts: {}", e)))?;
625
626 let count_kernel = device
627 .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_COUNT)
628 .ok_or_else(|| XlogError::Kernel("d4_smooth_count kernel not found".to_string()))?;
629 let num_blocks = num_nodes.div_ceil(block_size);
630 let mut params: Vec<*mut c_void> = vec![
631 (&self.node_type).as_kernel_param(),
632 (&self.child_offsets).as_kernel_param(),
633 (&self.child_indices).as_kernel_param(),
634 (&self.decision_var).as_kernel_param(),
635 (&self.decision_child_false).as_kernel_param(),
636 (&self.decision_child_true).as_kernel_param(),
637 (&self.meta_num_nodes).as_kernel_param(),
638 (&support).as_kernel_param(),
639 words_per_support.as_kernel_param(),
640 (&d_random_map).as_kernel_param(),
641 random_map_len.as_kernel_param(),
642 (&wrap_prefix_or).as_kernel_param(),
643 (&wrap_missing_or).as_kernel_param(),
644 (&wrap_prefix_dec).as_kernel_param(),
645 (&wrap_missing_dec).as_kernel_param(),
646 (&out_edge_counts).as_kernel_param(),
647 base_node.as_kernel_param(),
648 smooth_node_cap.as_kernel_param(),
649 ];
650 unsafe {
652 count_kernel.clone().launch(
653 LaunchConfig {
654 grid_dim: (num_blocks, 1, 1),
655 block_dim: (block_size, 1, 1),
656 shared_mem_bytes: 0,
657 },
658 &mut params,
659 )
660 }
661 .map_err(|e| XlogError::Kernel(format!("d4_smooth_count failed: {}", e)))?;
662
663 exclusive_scan_u32_inplace(provider, &mut wrap_prefix_or, num_edges)?;
664 exclusive_scan_u32_inplace(provider, &mut wrap_prefix_dec, dec_entries_u32)?;
665
666 let mut wrap_counts = memory.alloc::<u32>(3)?;
667 device
668 .memset_zeros(&mut wrap_counts)
669 .map_err(|e| XlogError::Kernel(format!("Failed to zero wrap_counts: {}", e)))?;
670
671 let counts_kernel = device
672 .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_WRAPPER_COUNTS)
673 .ok_or_else(|| {
674 XlogError::Kernel("d4_smooth_wrapper_counts kernel not found".to_string())
675 })?;
676 unsafe {
678 counts_kernel.clone().launch(
679 LaunchConfig {
680 grid_dim: (1, 1, 1),
681 block_dim: (1, 1, 1),
682 shared_mem_bytes: 0,
683 },
684 (
685 &wrap_prefix_or,
686 &wrap_missing_or,
687 num_edges,
688 &wrap_prefix_dec,
689 &wrap_missing_dec,
690 dec_entries_u32,
691 base_node,
692 &self.meta_num_nodes,
693 u32::MAX,
694 &mut wrap_counts,
695 ),
696 )
697 }
698 .map_err(|e| XlogError::Kernel(format!("d4_smooth_wrapper_counts failed: {}", e)))?;
699
700 let wrap_or_kernel = device
701 .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_WRAPPER_EDGE_COUNTS_OR)
702 .ok_or_else(|| {
703 XlogError::Kernel("d4_smooth_wrapper_edge_counts_or kernel not found".to_string())
704 })?;
705 if num_edges > 0 {
706 let num_blocks = num_edges.div_ceil(block_size);
707 unsafe {
709 wrap_or_kernel.clone().launch(
710 LaunchConfig {
711 grid_dim: (num_blocks, 1, 1),
712 block_dim: (block_size, 1, 1),
713 shared_mem_bytes: 0,
714 },
715 (
716 &wrap_prefix_or,
717 &wrap_missing_or,
718 num_edges,
719 base_node,
720 &self.meta_num_nodes,
721 smooth_node_cap,
722 &mut out_edge_counts,
723 ),
724 )
725 }
726 .map_err(|e| {
727 XlogError::Kernel(format!("d4_smooth_wrapper_edge_counts_or failed: {}", e))
728 })?;
729 }
730
731 let wrap_dec_kernel = device
732 .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_WRAPPER_EDGE_COUNTS_DEC)
733 .ok_or_else(|| {
734 XlogError::Kernel("d4_smooth_wrapper_edge_counts_dec kernel not found".to_string())
735 })?;
736 if dec_entries > 0 {
737 let num_blocks = dec_entries_u32.div_ceil(block_size);
738 unsafe {
740 wrap_dec_kernel.clone().launch(
741 LaunchConfig {
742 grid_dim: (num_blocks, 1, 1),
743 block_dim: (block_size, 1, 1),
744 shared_mem_bytes: 0,
745 },
746 (
747 &wrap_prefix_dec,
748 &wrap_missing_dec,
749 dec_entries_u32,
750 base_node,
751 &self.meta_num_nodes,
752 &wrap_counts,
753 smooth_node_cap,
754 &mut out_edge_counts,
755 ),
756 )
757 }
758 .map_err(|e| {
759 XlogError::Kernel(format!("d4_smooth_wrapper_edge_counts_dec failed: {}", e))
760 })?;
761 }
762
763 let mut out_child_offsets = memory.alloc::<u32>((smooth_node_cap as usize) + 1)?;
764 device
765 .memset_zeros(&mut out_child_offsets)
766 .map_err(|e| XlogError::Kernel(format!("Failed to zero child_offsets: {}", e)))?;
767 if smooth_node_cap > 0 {
768 device
769 .dtod_copy(
770 &out_edge_counts,
771 &mut out_child_offsets.slice_mut(0..smooth_node_cap as usize),
772 )
773 .map_err(|e| XlogError::Kernel(format!("Failed to copy edge_counts: {}", e)))?;
774 }
775 let child_scan_len = smooth_node_cap.checked_add(1).ok_or_else(|| {
776 XlogError::Compilation("GPU smoothing: child offset scan overflow".to_string())
777 })?;
778 exclusive_scan_u32_inplace(provider, &mut out_child_offsets, child_scan_len)?;
779
780 let edge_cap_check = device
781 .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_CHECK_EDGE_CAP)
782 .ok_or_else(|| {
783 XlogError::Kernel("d4_smooth_check_edge_cap kernel not found".to_string())
784 })?;
785 let mut meta_num_nodes = memory.alloc::<u32>(1)?;
786 let mut meta_num_edges = memory.alloc::<u32>(1)?;
787 device
788 .memset_zeros(&mut meta_num_nodes)
789 .map_err(|e| XlogError::Kernel(format!("Failed to zero smooth num_nodes: {}", e)))?;
790 device
791 .memset_zeros(&mut meta_num_edges)
792 .map_err(|e| XlogError::Kernel(format!("Failed to zero smooth num_edges: {}", e)))?;
793 unsafe {
795 edge_cap_check.clone().launch(
796 LaunchConfig {
797 grid_dim: (1, 1, 1),
798 block_dim: (1, 1, 1),
799 shared_mem_bytes: 0,
800 },
801 (
802 &out_child_offsets,
803 smooth_node_cap,
804 smooth_edge_cap,
805 &wrap_counts,
806 &mut meta_num_nodes,
807 &mut meta_num_edges,
808 ),
809 )
810 }
811 .map_err(|e| XlogError::Kernel(format!("d4_smooth_check_edge_cap failed: {}", e)))?;
812
813 let mut out_node_type = memory.alloc::<u8>(smooth_node_cap as usize)?;
814 let mut out_child_indices = memory.alloc::<u32>(smooth_edge_cap as usize)?;
815 let mut out_lit = memory.alloc::<i32>(smooth_node_cap as usize)?;
816 let mut out_decision_var = memory.alloc::<u32>(smooth_node_cap as usize)?;
817 let mut out_decision_child_false = memory.alloc::<u32>(smooth_node_cap as usize)?;
818 let mut out_decision_child_true = memory.alloc::<u32>(smooth_node_cap as usize)?;
819 let mut out_node_level = memory.alloc::<u32>(smooth_node_cap as usize)?;
820
821 device
822 .memset_zeros(&mut out_node_type)
823 .map_err(|e| XlogError::Kernel(format!("Failed to zero node_type: {}", e)))?;
824 device
825 .memset_zeros(&mut out_child_indices)
826 .map_err(|e| XlogError::Kernel(format!("Failed to zero child_indices: {}", e)))?;
827 device
828 .memset_zeros(&mut out_lit)
829 .map_err(|e| XlogError::Kernel(format!("Failed to zero lit: {}", e)))?;
830 device
831 .memset_zeros(&mut out_decision_var)
832 .map_err(|e| XlogError::Kernel(format!("Failed to zero decision_var: {}", e)))?;
833 device
834 .memset_zeros(&mut out_decision_child_false)
835 .map_err(|e| {
836 XlogError::Kernel(format!("Failed to zero decision_child_false: {}", e))
837 })?;
838 device
839 .memset_zeros(&mut out_decision_child_true)
840 .map_err(|e| XlogError::Kernel(format!("Failed to zero decision_child_true: {}", e)))?;
841 device
842 .memset_zeros(&mut out_node_level)
843 .map_err(|e| XlogError::Kernel(format!("Failed to zero node_level: {}", e)))?;
844
845 let init_kernel = device
846 .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_INIT_NODES)
847 .ok_or_else(|| {
848 XlogError::Kernel("d4_smooth_init_nodes kernel not found".to_string())
849 })?;
850 let init_blocks = checked_gpu_launch_blocks(
851 "d4_smooth_init_nodes",
852 num_random_vars.max(1) as usize,
853 block_size,
854 )?;
855 unsafe {
857 init_kernel.clone().launch(
858 LaunchConfig {
859 grid_dim: (init_blocks, 1, 1),
860 block_dim: (block_size, 1, 1),
861 shared_mem_bytes: 0,
862 },
863 (
864 random_var_list,
865 num_random_vars,
866 smooth_node_cap,
867 &mut out_node_type,
868 &mut out_lit,
869 &mut out_decision_var,
870 &mut out_decision_child_false,
871 &mut out_decision_child_true,
872 &mut out_node_level,
873 ),
874 )
875 }
876 .map_err(|e| XlogError::Kernel(format!("d4_smooth_init_nodes failed: {}", e)))?;
877
878 let num_levels_out = self
879 .num_levels
880 .checked_mul(2)
881 .and_then(|levels| levels.checked_add(4))
882 .ok_or_else(|| {
883 XlogError::Compilation("GPU smoothing output level count overflow".to_string())
884 })?;
885 let num_levels_out_usize = num_levels_out as usize;
886 let level_offsets_len =
887 checked_gpu_len_add_one("GPU smoothing level offsets", num_levels_out_usize)?;
888
889 let emit_kernel = device
890 .get_func(D4_MODULE, d4_kernels::D4_SMOOTH_EMIT_LEVEL)
891 .ok_or_else(|| {
892 XlogError::Kernel("d4_smooth_emit_level kernel not found".to_string())
893 })?;
894 for level in 0..num_levels {
895 let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
896 Some(off) => checked_host_level_width(off, level)?,
897 None => self.level_nodes.len(),
898 };
899 if num_level_nodes == 0 {
900 continue;
901 }
902 let num_blocks =
903 checked_gpu_launch_blocks("xgcf_smooth_forward", num_level_nodes, block_size)?;
904 let level_u32 = level as u32;
905 let mut params: Vec<*mut c_void> = vec![
906 (&self.node_type).as_kernel_param(),
907 (&self.child_offsets).as_kernel_param(),
908 (&self.child_indices).as_kernel_param(),
909 (&self.lit).as_kernel_param(),
910 (&self.decision_var).as_kernel_param(),
911 (&self.decision_child_false).as_kernel_param(),
912 (&self.decision_child_true).as_kernel_param(),
913 (&self.level_nodes).as_kernel_param(),
914 (&self.level_offsets).as_kernel_param(),
915 level_u32.as_kernel_param(),
916 (&support).as_kernel_param(),
917 words_per_support.as_kernel_param(),
918 (&wrap_prefix_or).as_kernel_param(),
919 (&wrap_missing_or).as_kernel_param(),
920 (&wrap_prefix_dec).as_kernel_param(),
921 (&wrap_missing_dec).as_kernel_param(),
922 base_node.as_kernel_param(),
923 (&self.meta_num_nodes).as_kernel_param(),
924 (&wrap_counts).as_kernel_param(),
925 num_random_vars.as_kernel_param(),
926 num_levels_out.as_kernel_param(),
927 (&out_node_type).as_kernel_param(),
928 (&out_child_offsets).as_kernel_param(),
929 (&out_child_indices).as_kernel_param(),
930 (&out_lit).as_kernel_param(),
931 (&out_decision_var).as_kernel_param(),
932 (&out_decision_child_false).as_kernel_param(),
933 (&out_decision_child_true).as_kernel_param(),
934 (&out_node_level).as_kernel_param(),
935 ];
936 unsafe {
938 emit_kernel.clone().launch(
939 LaunchConfig {
940 grid_dim: (num_blocks, 1, 1),
941 block_dim: (block_size, 1, 1),
942 shared_mem_bytes: 0,
943 },
944 &mut params,
945 )
946 }
947 .map_err(|e| XlogError::Kernel(format!("d4_smooth_emit_level failed: {}", e)))?;
948 }
949
950 let mut level_counts = memory.alloc::<u32>(num_levels_out_usize)?;
951 let mut level_offsets = memory.alloc::<u32>(level_offsets_len)?;
952 let mut level_cursors = memory.alloc::<u32>(num_levels_out_usize)?;
953 let mut level_nodes = memory.alloc::<u32>(smooth_node_cap as usize)?;
954
955 device
956 .memset_zeros(&mut level_counts)
957 .map_err(|e| XlogError::Kernel(format!("Failed to zero level_counts: {}", e)))?;
958 device
959 .memset_zeros(&mut level_offsets)
960 .map_err(|e| XlogError::Kernel(format!("Failed to zero level_offsets: {}", e)))?;
961 device
962 .memset_zeros(&mut level_cursors)
963 .map_err(|e| XlogError::Kernel(format!("Failed to zero level_cursors: {}", e)))?;
964 device
965 .memset_zeros(&mut level_nodes)
966 .map_err(|e| XlogError::Kernel(format!("Failed to zero level_nodes: {}", e)))?;
967
968 let mut compile_needed = memory.alloc::<u32>(1)?;
969 provider
970 .htod_launch_metadata_sync_copy_into(&[1u32], &mut compile_needed)
971 .map_err(|e| XlogError::Kernel(format!("Failed to upload compile_needed: {}", e)))?;
972
973 let levelize_counts = device
974 .get_func(D4_MODULE, d4_kernels::D4_LEVELIZE_COUNTS)
975 .ok_or_else(|| XlogError::Kernel("d4_levelize_counts kernel not found".to_string()))?;
976 let num_blocks =
977 checked_gpu_launch_blocks("d4_smooth_levelize", smooth_node_cap as usize, block_size)?;
978 unsafe {
980 levelize_counts.clone().launch(
981 LaunchConfig {
982 grid_dim: (num_blocks, 1, 1),
983 block_dim: (block_size, 1, 1),
984 shared_mem_bytes: 0,
985 },
986 (
987 &compile_needed,
988 &out_node_level,
989 &meta_num_nodes,
990 num_levels_out,
991 &mut level_counts,
992 ),
993 )
994 }
995 .map_err(|e| XlogError::Kernel(format!("d4_levelize_counts failed: {}", e)))?;
996
997 device
998 .dtod_copy(
999 &level_counts,
1000 &mut level_offsets.slice_mut(0..num_levels_out_usize),
1001 )
1002 .map_err(|e| XlogError::Kernel(format!("Failed to copy level_counts: {}", e)))?;
1003 let level_scan_len = num_levels_out.checked_add(1).ok_or_else(|| {
1004 XlogError::Compilation("GPU smoothing: level offset scan overflow".to_string())
1005 })?;
1006 exclusive_scan_u32_inplace(provider, &mut level_offsets, level_scan_len)?;
1007
1008 let levelize_emit = device
1009 .get_func(D4_MODULE, d4_kernels::D4_LEVELIZE_EMIT)
1010 .ok_or_else(|| XlogError::Kernel("d4_levelize_emit kernel not found".to_string()))?;
1011 unsafe {
1013 levelize_emit.clone().launch(
1014 LaunchConfig {
1015 grid_dim: (num_blocks, 1, 1),
1016 block_dim: (block_size, 1, 1),
1017 shared_mem_bytes: 0,
1018 },
1019 (
1020 &compile_needed,
1021 &out_node_level,
1022 &meta_num_nodes,
1023 num_levels_out,
1024 &level_offsets,
1025 &mut level_cursors,
1026 &mut level_nodes,
1027 ),
1028 )
1029 }
1030 .map_err(|e| XlogError::Kernel(format!("d4_levelize_emit failed: {}", e)))?;
1031
1032 let builder = GpuCircuitBuilder {
1034 node_type: out_node_type,
1035 child_offsets: out_child_offsets,
1036 child_indices: out_child_indices,
1037 lit: out_lit,
1038 decision_var: out_decision_var,
1039 decision_child_false: out_decision_child_false,
1040 decision_child_true: out_decision_child_true,
1041 };
1042 let layout = GpuCircuitLayout {
1043 num_nodes: smooth_node_cap,
1044 num_edges: smooth_edge_cap,
1045 num_levels: num_levels_out,
1046 level_offsets,
1047 level_nodes,
1048 root: base_node + self.root,
1049 max_var: self.max_var,
1050 num_nodes_device: Some(meta_num_nodes),
1051 num_edges_device: Some(meta_num_edges),
1052 };
1053
1054 GpuXgcf::from_device(builder, layout, provider)
1055 }
1056
1057 pub fn upload(provider: &CudaKernelProvider, circuit: &Xgcf) -> Result<Self> {
1058 let (node_cap, edge_cap, num_levels) = validate_xgcf_for_gpu_upload(circuit)?;
1059
1060 let memory = provider.memory().clone();
1061
1062 let n = circuit.node_type.len();
1063 let mut host_node_type: Vec<u8> = Vec::with_capacity(n);
1064 for &ty in &circuit.node_type {
1065 host_node_type.push(ty as u8);
1066 }
1067
1068 let mut max_var: u32 = 0;
1069 for (&ty, &lit) in circuit.node_type.iter().zip(circuit.lit.iter()) {
1070 if ty == XgcfNodeType::Lit && lit != 0 {
1071 max_var = max_var.max(lit.unsigned_abs());
1072 }
1073 }
1074 for &var in &circuit.decision_var {
1075 max_var = max_var.max(var);
1076 }
1077
1078 let mut d_node_type = memory.alloc::<u8>(n)?;
1079 provider
1080 .htod_sync_copy_into_tracked(&host_node_type, &mut d_node_type)
1081 .map_err(|e| XlogError::Kernel(format!("Failed to upload circuit node_type: {}", e)))?;
1082
1083 let mut d_child_offsets = memory.alloc::<u32>(circuit.child_offsets.len())?;
1084 provider
1085 .htod_sync_copy_into_tracked(&circuit.child_offsets, &mut d_child_offsets)
1086 .map_err(|e| {
1087 XlogError::Kernel(format!("Failed to upload circuit child_offsets: {}", e))
1088 })?;
1089
1090 let mut d_child_indices = memory.alloc::<u32>(circuit.child_indices.len())?;
1091 provider
1092 .htod_sync_copy_into_tracked(&circuit.child_indices, &mut d_child_indices)
1093 .map_err(|e| {
1094 XlogError::Kernel(format!("Failed to upload circuit child_indices: {}", e))
1095 })?;
1096
1097 let mut d_lit = memory.alloc::<i32>(circuit.lit.len())?;
1098 provider
1099 .htod_sync_copy_into_tracked(&circuit.lit, &mut d_lit)
1100 .map_err(|e| XlogError::Kernel(format!("Failed to upload circuit lit: {}", e)))?;
1101
1102 let mut d_decision_var = memory.alloc::<u32>(circuit.decision_var.len())?;
1103 provider
1104 .htod_sync_copy_into_tracked(&circuit.decision_var, &mut d_decision_var)
1105 .map_err(|e| {
1106 XlogError::Kernel(format!("Failed to upload circuit decision_var: {}", e))
1107 })?;
1108
1109 let mut d_decision_child_false = memory.alloc::<u32>(circuit.decision_child_false.len())?;
1110 provider
1111 .htod_sync_copy_into_tracked(&circuit.decision_child_false, &mut d_decision_child_false)
1112 .map_err(|e| {
1113 XlogError::Kernel(format!(
1114 "Failed to upload circuit decision_child_false: {}",
1115 e
1116 ))
1117 })?;
1118
1119 let mut d_decision_child_true = memory.alloc::<u32>(circuit.decision_child_true.len())?;
1120 provider
1121 .htod_sync_copy_into_tracked(&circuit.decision_child_true, &mut d_decision_child_true)
1122 .map_err(|e| {
1123 XlogError::Kernel(format!(
1124 "Failed to upload circuit decision_child_true: {}",
1125 e
1126 ))
1127 })?;
1128
1129 let mut d_level_nodes = memory.alloc::<u32>(circuit.level_nodes.len())?;
1130 provider
1131 .htod_sync_copy_into_tracked(&circuit.level_nodes, &mut d_level_nodes)
1132 .map_err(|e| {
1133 XlogError::Kernel(format!("Failed to upload circuit level_nodes: {}", e))
1134 })?;
1135
1136 let mut d_level_offsets = memory.alloc::<u32>(circuit.level_offsets.len())?;
1137 provider
1138 .htod_sync_copy_into_tracked(&circuit.level_offsets, &mut d_level_offsets)
1139 .map_err(|e| {
1140 XlogError::Kernel(format!("Failed to upload circuit level_offsets: {}", e))
1141 })?;
1142
1143 let weights_len = (max_var as usize) + 1;
1144 let var_log_true = memory.alloc::<f64>(weights_len)?;
1145 let var_log_false = memory.alloc::<f64>(weights_len)?;
1146 let values = memory.alloc::<f64>(n)?;
1147 let adj = memory.alloc::<f64>(n)?;
1148 let grad_true = memory.alloc::<f64>(weights_len)?;
1149 let grad_false = memory.alloc::<f64>(weights_len)?;
1150 let mut meta_num_nodes = memory.alloc::<u32>(1)?;
1151 provider
1152 .htod_launch_metadata_sync_copy_into(&[node_cap], &mut meta_num_nodes)
1153 .map_err(|e| XlogError::Kernel(format!("Failed to upload num_nodes meta: {}", e)))?;
1154 let mut meta_num_edges = memory.alloc::<u32>(1)?;
1155 provider
1156 .htod_launch_metadata_sync_copy_into(&[edge_cap], &mut meta_num_edges)
1157 .map_err(|e| XlogError::Kernel(format!("Failed to upload num_edges meta: {}", e)))?;
1158
1159 Ok(Self {
1160 node_type: d_node_type,
1161 child_offsets: d_child_offsets,
1162 child_indices: d_child_indices,
1163 lit: d_lit,
1164 decision_var: d_decision_var,
1165 decision_child_false: d_decision_child_false,
1166 decision_child_true: d_decision_child_true,
1167 level_nodes: d_level_nodes,
1168 level_offsets: d_level_offsets,
1169 level_offsets_host: Some(circuit.level_offsets.clone()),
1170 node_cap,
1171 edge_cap,
1172 num_levels,
1173 root: circuit.roots[0],
1174 max_var,
1175 meta_num_nodes,
1176 meta_num_edges,
1177 var_log_true,
1178 var_log_false,
1179 values,
1180 adj,
1181 grad_true,
1182 grad_false,
1183 free_var_mask: None,
1184 })
1185 }
1186
/// Returns the stored `max_var` value; the weight tables are sized `max_var + 1`
/// when the circuit is created through `upload`.
pub fn max_var(&self) -> u32 {
    self.max_var
}
1190
/// Returns the index of the root node, i.e. the slot of `values` that holds the
/// circuit's result after a forward pass.
pub fn root(&self) -> u32 {
    self.root
}
1195
/// Returns the host-side node capacity (`node_cap`) widened to `usize`.
pub fn num_nodes(&self) -> usize {
    self.node_cap as usize
}
1200
/// Returns the host-side edge capacity (`edge_cap`) widened to `usize`.
pub fn num_edges(&self) -> usize {
    self.edge_cap as usize
}
1205
/// Returns the number of levels the per-level evaluation kernels iterate over.
pub fn num_levels(&self) -> u32 {
    self.num_levels
}
1210
/// Borrows the device-resident node-count metadata buffer (`meta_num_nodes`).
pub fn num_nodes_device(&self) -> &TrackedCudaSlice<u32> {
    &self.meta_num_nodes
}
1215
/// Borrows the device-resident edge-count metadata buffer (`meta_num_edges`).
pub fn num_edges_device(&self) -> &TrackedCudaSlice<u32> {
    &self.meta_num_edges
}
1220
/// Borrows the device buffer holding the `level_nodes` array.
pub fn level_nodes(&self) -> &TrackedCudaSlice<u32> {
    &self.level_nodes
}
1225
/// Borrows the device buffer holding the `level_offsets` array.
pub fn level_offsets(&self) -> &TrackedCudaSlice<u32> {
    &self.level_offsets
}
1230
/// Borrows the device buffer of per-node type tags (one `u8` per node).
pub fn node_type(&self) -> &TrackedCudaSlice<u8> {
    &self.node_type
}
1235
/// Borrows the device buffer holding the `child_offsets` array.
pub fn child_offsets(&self) -> &TrackedCudaSlice<u32> {
    &self.child_offsets
}
1240
/// Borrows the device buffer holding the `child_indices` array.
pub fn child_indices(&self) -> &TrackedCudaSlice<u32> {
    &self.child_indices
}
1245
/// Borrows the device buffer of per-node signed literals.
pub fn lit(&self) -> &TrackedCudaSlice<i32> {
    &self.lit
}
1250
/// Borrows the device buffer holding the `decision_var` array.
pub fn decision_var(&self) -> &TrackedCudaSlice<u32> {
    &self.decision_var
}
1255
/// Borrows the device buffer holding the `decision_child_false` array.
pub fn decision_child_false(&self) -> &TrackedCudaSlice<u32> {
    &self.decision_child_false
}
1259
/// Borrows the device buffer holding the `decision_child_true` array.
pub fn decision_child_true(&self) -> &TrackedCudaSlice<u32> {
    &self.decision_child_true
}
1263
/// Borrows the device buffer of per-node values written by the forward pass.
pub fn values(&self) -> &TrackedCudaSlice<f64> {
    &self.values
}
1268
/// Borrows the device gradient buffer `grad_true`, filled by `eval_grads_inplace`.
pub fn grad_true(&self) -> &TrackedCudaSlice<f64> {
    &self.grad_true
}
1273
/// Borrows the device gradient buffer `grad_false`, filled by `eval_grads_inplace`.
pub fn grad_false(&self) -> &TrackedCudaSlice<f64> {
    &self.grad_false
}
1278
/// Borrows the device weight table `var_log_true` (first component of each weight
/// pair uploaded by `set_base_weights`).
pub fn var_log_true(&self) -> &TrackedCudaSlice<f64> {
    &self.var_log_true
}
1283
/// Borrows the device weight table `var_log_false` (second component of each weight
/// pair uploaded by `set_base_weights`).
pub fn var_log_false(&self) -> &TrackedCudaSlice<f64> {
    &self.var_log_false
}
1288
/// Mutably borrows the device weight table `var_log_true` so it can be overwritten
/// in place.
#[allow(dead_code)]
pub(crate) fn var_log_true_mut(&mut self) -> &mut TrackedCudaSlice<f64> {
    &mut self.var_log_true
}
1294
/// Mutably borrows the device weight table `var_log_false` so it can be overwritten
/// in place.
#[allow(dead_code)]
pub(crate) fn var_log_false_mut(&mut self) -> &mut TrackedCudaSlice<f64> {
    &mut self.var_log_false
}
1300
/// Mutably borrows both weight tables at once, as `(var_log_true, var_log_false)`.
///
/// Returning the pair from a single call lets a caller hold both mutable borrows
/// simultaneously, which two separate `&mut self` accessor calls would not allow.
pub fn var_log_weights_mut(
    &mut self,
) -> (&mut TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) {
    (&mut self.var_log_true, &mut self.var_log_false)
}
1310
1311 pub fn set_free_var_mask_device(&mut self, mask: TrackedCudaSlice<u8>) -> Result<()> {
1313 if mask.len() != self.var_log_true.len() {
1314 return Err(XlogError::Compilation(format!(
1315 "GPU free-var mask len {} != weights len {}",
1316 mask.len(),
1317 self.var_log_true.len()
1318 )));
1319 }
1320 self.free_var_mask = Some(mask);
1321 Ok(())
1322 }
1323
1324 #[allow(dead_code)] pub(crate) fn set_free_var_mask_from_host(
1327 &mut self,
1328 provider: &CudaKernelProvider,
1329 mask: &[u8],
1330 ) -> Result<()> {
1331 if mask.len() != self.var_log_true.len() {
1332 return Err(XlogError::Compilation(format!(
1333 "GPU free-var mask len {} != weights len {}",
1334 mask.len(),
1335 self.var_log_true.len()
1336 )));
1337 }
1338 let memory = provider.memory();
1339 let mut d_mask = memory.alloc::<u8>(mask.len())?;
1340 provider
1341 .htod_sync_copy_into_tracked(mask, &mut d_mask)
1342 .map_err(|e| XlogError::Kernel(format!("Failed to upload free_var_mask: {}", e)))?;
1343 self.free_var_mask = Some(d_mask);
1344 Ok(())
1345 }
1346
1347 pub fn set_base_weights(
1352 &mut self,
1353 provider: &CudaKernelProvider,
1354 var_log_weights: &[(f64, f64)],
1355 ) -> Result<()> {
1356 let weights_len = (self.max_var as usize) + 1;
1357 if var_log_weights.len() < weights_len {
1358 return Err(XlogError::Compilation(format!(
1359 "GPU XGCF weights init expects weight table len >= {}, got {}",
1360 weights_len,
1361 var_log_weights.len()
1362 )));
1363 }
1364
1365 let mut host_true: Vec<f64> = Vec::with_capacity(weights_len);
1366 let mut host_false: Vec<f64> = Vec::with_capacity(weights_len);
1367 for &(t, f) in &var_log_weights[..weights_len] {
1368 host_true.push(t);
1369 host_false.push(f);
1370 }
1371
1372 provider
1373 .htod_sync_copy_into_tracked(&host_true, &mut self.var_log_true)
1374 .map_err(|e| XlogError::Kernel(format!("Failed to upload log_true weights: {}", e)))?;
1375 provider
1376 .htod_sync_copy_into_tracked(&host_false, &mut self.var_log_false)
1377 .map_err(|e| XlogError::Kernel(format!("Failed to upload log_false weights: {}", e)))?;
1378
1379 Ok(())
1380 }
1381
1382 pub fn eval_log_wmc_device_inplace(
1386 &mut self,
1387 provider: &CudaKernelProvider,
1388 out_log_z: &mut TrackedCudaSlice<f64>,
1389 ) -> Result<()> {
1390 if out_log_z.len() != 1 {
1391 return Err(XlogError::Compilation(format!(
1392 "GPU device logZ output len {} != 1",
1393 out_log_z.len()
1394 )));
1395 }
1396
1397 let device = provider.device().inner();
1398 let func = device
1399 .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FORWARD_LEVEL)
1400 .ok_or_else(|| XlogError::Kernel("xgcf_forward_level kernel not found".to_string()))?;
1401
1402 let block_size: u32 = 256;
1403 let num_levels: usize = self.num_levels as usize;
1404 for level in 0..num_levels {
1405 let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
1406 Some(off) => checked_host_level_width(off, level)?,
1407 None => self.level_nodes.len(),
1408 };
1409 if num_level_nodes == 0 {
1410 continue;
1411 }
1412
1413 let num_blocks =
1414 checked_gpu_launch_blocks("xgcf_forward_level", num_level_nodes, block_size)?;
1415 let config = LaunchConfig {
1416 grid_dim: (num_blocks, 1, 1),
1417 block_dim: (block_size, 1, 1),
1418 shared_mem_bytes: 0,
1419 };
1420 let level_u32: u32 = level as u32;
1421
1422 let mut params: Vec<*mut c_void> = vec![
1423 (&self.node_type).as_kernel_param(),
1424 (&self.child_offsets).as_kernel_param(),
1425 (&self.child_indices).as_kernel_param(),
1426 (&self.lit).as_kernel_param(),
1427 (&self.decision_var).as_kernel_param(),
1428 (&self.decision_child_false).as_kernel_param(),
1429 (&self.decision_child_true).as_kernel_param(),
1430 (&self.level_nodes).as_kernel_param(),
1431 (&self.level_offsets).as_kernel_param(),
1432 level_u32.as_kernel_param(),
1433 (&self.var_log_true).as_kernel_param(),
1434 (&self.var_log_false).as_kernel_param(),
1435 (&self.values).as_kernel_param(),
1436 ];
1437
1438 unsafe { func.clone().launch(config, &mut params) }
1440 .map_err(|e| XlogError::Kernel(format!("xgcf_forward_level failed: {}", e)))?;
1441 }
1442
1443 self.apply_free_var_correction(provider, true, false)?;
1444
1445 let root_idx = self.root as usize;
1446 let root_view = self.values.slice(root_idx..(root_idx + 1));
1447 device
1448 .dtod_copy(&root_view, out_log_z)
1449 .map_err(|e| XlogError::Kernel(format!("Failed to copy device logZ: {}", e)))?;
1450
1451 Ok(())
1454 }
1455
/// Uploads `var_log_weights` and evaluates the circuit, writing the root value into
/// the caller-provided single-element device buffer `out_log_z`.
///
/// This is `set_base_weights` followed by `eval_log_wmc_device_inplace`; errors from
/// either step are returned unchanged.
pub fn eval_log_wmc_device_into(
    &mut self,
    provider: &CudaKernelProvider,
    var_log_weights: &[(f64, f64)],
    out_log_z: &mut TrackedCudaSlice<f64>,
) -> Result<()> {
    self.set_base_weights(provider, var_log_weights)?;
    self.eval_log_wmc_device_inplace(provider, out_log_z)
}
1466
1467 pub fn eval_log_wmc_device(
1469 &mut self,
1470 provider: &CudaKernelProvider,
1471 var_log_weights: &[(f64, f64)],
1472 ) -> Result<TrackedCudaSlice<f64>> {
1473 let memory = provider.memory();
1474 let mut out_log_z = memory.alloc::<f64>(1)?;
1475 self.eval_log_wmc_device_into(provider, var_log_weights, &mut out_log_z)?;
1476 Ok(out_log_z)
1477 }
1478
1479 fn apply_free_var_correction(
1480 &mut self,
1481 provider: &CudaKernelProvider,
1482 apply_log_z: bool,
1483 apply_grads: bool,
1484 ) -> Result<()> {
1485 let Some(mask) = self.free_var_mask.as_ref() else {
1486 return Ok(());
1487 };
1488
1489 if mask.len() != self.var_log_true.len() {
1490 return Err(XlogError::Compilation(format!(
1491 "GPU free-var mask len {} != weights len {}",
1492 mask.len(),
1493 self.var_log_true.len()
1494 )));
1495 }
1496
1497 let n = u32::try_from(mask.len())
1498 .map_err(|_| XlogError::Compilation("GPU free-var mask length overflow".to_string()))?;
1499 if n == 0 {
1500 return Ok(());
1501 }
1502
1503 let device = provider.device().inner();
1504 let block_dim = 256u32;
1505 let grid_dim = n.div_ceil(block_dim);
1506
1507 if apply_grads {
1508 let apply_grad = device
1509 .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FREE_VAR_APPLY_GRAD)
1510 .ok_or_else(|| {
1511 XlogError::Kernel("xgcf_free_var_apply_grad kernel not found".to_string())
1512 })?;
1513 unsafe {
1515 apply_grad.clone().launch(
1516 LaunchConfig {
1517 grid_dim: (grid_dim, 1, 1),
1518 block_dim: (block_dim, 1, 1),
1519 shared_mem_bytes: 0,
1520 },
1521 (
1522 mask,
1523 &self.var_log_true,
1524 &self.var_log_false,
1525 n,
1526 &mut self.grad_true,
1527 &mut self.grad_false,
1528 ),
1529 )
1530 }
1531 .map_err(|e| XlogError::Kernel(format!("xgcf_free_var_apply_grad failed: {}", e)))?;
1532 }
1533
1534 if apply_log_z {
1535 let reduce_stage = device
1536 .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FREE_VAR_REDUCE_STAGE)
1537 .ok_or_else(|| {
1538 XlogError::Kernel("xgcf_free_var_reduce_stage kernel not found".to_string())
1539 })?;
1540 let add_scalar = device
1541 .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_ADD_SCALAR)
1542 .ok_or_else(|| XlogError::Kernel("xgcf_add_scalar kernel not found".to_string()))?;
1543
1544 let memory = provider.memory();
1545 let mut buf_a = memory.alloc::<f64>(mask.len())?;
1546 let mut buf_b = memory.alloc::<f64>(mask.len())?;
1547
1548 let mut stage_n = n;
1549 let mut stage0 = true;
1550 let mut output_is_a = true;
1551 loop {
1552 let out_len = stage_n.div_ceil(2);
1553 let stage_grid = out_len.div_ceil(block_dim);
1554
1555 let (in_buf, out_buf): (&TrackedCudaSlice<f64>, &mut TrackedCudaSlice<f64>) =
1556 if output_is_a {
1557 (&buf_b, &mut buf_a)
1558 } else {
1559 (&buf_a, &mut buf_b)
1560 };
1561 let mode = if stage0 { 0u32 } else { 1u32 };
1562
1563 unsafe {
1565 reduce_stage.clone().launch(
1566 LaunchConfig {
1567 grid_dim: (stage_grid, 1, 1),
1568 block_dim: (block_dim, 1, 1),
1569 shared_mem_bytes: 0,
1570 },
1571 (
1572 mask,
1573 &self.var_log_true,
1574 &self.var_log_false,
1575 in_buf,
1576 stage_n,
1577 mode,
1578 out_buf,
1579 ),
1580 )
1581 }
1582 .map_err(|e| {
1583 XlogError::Kernel(format!("xgcf_free_var_reduce_stage failed: {}", e))
1584 })?;
1585
1586 if out_len == 1 {
1587 let result_buf = if output_is_a { &buf_a } else { &buf_b };
1588 unsafe {
1590 add_scalar.clone().launch(
1591 LaunchConfig {
1592 grid_dim: (1, 1, 1),
1593 block_dim: (1, 1, 1),
1594 shared_mem_bytes: 0,
1595 },
1596 (&mut self.values, self.root, result_buf),
1597 )
1598 }
1599 .map_err(|e| XlogError::Kernel(format!("xgcf_add_scalar failed: {}", e)))?;
1600 break;
1601 }
1602
1603 stage_n = out_len;
1604 stage0 = false;
1605 output_is_a = !output_is_a;
1606 }
1607 }
1608
1609 Ok(())
1610 }
1611
/// Runs a forward pass followed by a backward pass, leaving per-node values in
/// `values` and per-variable gradients in `grad_true` / `grad_false` on the device.
///
/// Uses whatever weights are currently stored in `var_log_true` / `var_log_false`.
/// Steps, in order:
///
/// 1. forward: one `xgcf_forward_level` launch per non-empty level, lowest level first;
/// 2. zero `adj`, `grad_true` and `grad_false`, then write `1.0` into the root's
///    adjoint slot;
/// 3. backward: for each non-empty level, highest first, launch the propagate,
///    decision-gradient and literal-gradient kernels;
/// 4. apply the free-variable correction to both the root value and the gradients.
///
/// # Errors
///
/// Returns `XlogError::Kernel` when a kernel cannot be found, a launch fails or a
/// buffer cannot be zeroed, and propagates errors from the level-width and
/// launch-size helpers.
pub fn eval_grads_inplace(&mut self, provider: &CudaKernelProvider) -> Result<()> {
    let device = provider.device().inner();

    // ---- Forward pass -------------------------------------------------------
    let func = device
        .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FORWARD_LEVEL)
        .ok_or_else(|| XlogError::Kernel("xgcf_forward_level kernel not found".to_string()))?;

    let block_size: u32 = 256;
    let num_levels: usize = self.num_levels as usize;
    for level in 0..num_levels {
        // Launch width: exact level width when host offsets are known, otherwise
        // the size of the whole `level_nodes` array.
        let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
            Some(off) => checked_host_level_width(off, level)?,
            None => self.level_nodes.len(),
        };
        if num_level_nodes == 0 {
            continue;
        }

        let num_blocks =
            checked_gpu_launch_blocks("xgcf_forward_level", num_level_nodes, block_size)?;
        let config = LaunchConfig {
            grid_dim: (num_blocks, 1, 1),
            block_dim: (block_size, 1, 1),
            shared_mem_bytes: 0,
        };
        let level_u32: u32 = level as u32;

        // Positional kernel arguments; the order must stay in sync with the kernel.
        let mut params: Vec<*mut c_void> = vec![
            (&self.node_type).as_kernel_param(),
            (&self.child_offsets).as_kernel_param(),
            (&self.child_indices).as_kernel_param(),
            (&self.lit).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            level_u32.as_kernel_param(),
            (&self.var_log_true).as_kernel_param(),
            (&self.var_log_false).as_kernel_param(),
            (&self.values).as_kernel_param(),
        ];

        // NOTE(review): soundness relies on `params` matching the kernel signature
        // and on every referenced buffer outliving the launch — confirm against the
        // CUDA source.
        unsafe { func.clone().launch(config, &mut params) }
            .map_err(|e| XlogError::Kernel(format!("xgcf_forward_level failed: {}", e)))?;
    }

    // ---- Reset backward-pass state -----------------------------------------
    device
        .memset_zeros(&mut self.adj)
        .map_err(|e| XlogError::Kernel(format!("Failed to zero adj buffer: {}", e)))?;
    device
        .memset_zeros(&mut self.grad_true)
        .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_true buffer: {}", e)))?;
    device
        .memset_zeros(&mut self.grad_false)
        .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_false buffer: {}", e)))?;

    // Seed the backward pass: the root's adjoint slot is filled on the device by a
    // single-thread kernel launch.
    let root_idx = self.root as usize;
    let mut root_adj_view = self.adj.slice_mut(root_idx..(root_idx + 1));
    let fill_const = device
        .get_func(ARITH_MODULE, arith_kernels::ARITH_FILL_CONST_F64)
        .ok_or_else(|| {
            XlogError::Kernel("arith_fill_const_f64 kernel not found".to_string())
        })?;
    unsafe {
        // Presumably (value, element count, destination) — verify against the kernel.
        fill_const.clone().launch(
            LaunchConfig {
                grid_dim: (1, 1, 1),
                block_dim: (1, 1, 1),
                shared_mem_bytes: 0,
            },
            (1.0_f64, 1u32, &mut root_adj_view),
        )
    }
    .map_err(|e| XlogError::Kernel(format!("arith_fill_const_f64 failed: {}", e)))?;

    // ---- Backward pass ------------------------------------------------------
    // Look up all three kernels first so a missing kernel is reported before any
    // backward launch happens.
    let propagate = device
        .get_func(
            CIRCUIT_MODULE,
            circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
        )
        .ok_or_else(|| {
            XlogError::Kernel("xgcf_backward_level_propagate kernel not found".to_string())
        })?;
    let decision_grad = device
        .get_func(
            CIRCUIT_MODULE,
            circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
        )
        .ok_or_else(|| {
            XlogError::Kernel("xgcf_backward_level_decision_grad kernel not found".to_string())
        })?;
    let lit_grad = device
        .get_func(
            CIRCUIT_MODULE,
            circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
        )
        .ok_or_else(|| {
            XlogError::Kernel("xgcf_backward_level_lit_grad kernel not found".to_string())
        })?;

    let num_levels: usize = self.num_levels as usize;
    // Levels are visited in the reverse of the forward order.
    for level in (0..num_levels).rev() {
        let num_level_nodes: usize = match self.level_offsets_host.as_ref() {
            Some(off) => checked_host_level_width(off, level)?,
            None => self.level_nodes.len(),
        };
        if num_level_nodes == 0 {
            continue;
        }

        let num_blocks =
            checked_gpu_launch_blocks("xgcf_backward_level", num_level_nodes, block_size)?;
        // The same launch configuration is reused for all three kernels of a level.
        let config = LaunchConfig {
            grid_dim: (num_blocks, 1, 1),
            block_dim: (block_size, 1, 1),
            shared_mem_bytes: 0,
        };
        let level_u32: u32 = level as u32;

        // Kernel 1: propagate. Receives the structure, weights, values and `adj`.
        let mut params: Vec<*mut c_void> = vec![
            (&self.node_type).as_kernel_param(),
            (&self.child_offsets).as_kernel_param(),
            (&self.child_indices).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            level_u32.as_kernel_param(),
            (&self.var_log_true).as_kernel_param(),
            (&self.var_log_false).as_kernel_param(),
            (&self.values).as_kernel_param(),
            (&self.adj).as_kernel_param(),
        ];

        unsafe { propagate.clone().launch(config, &mut params) }.map_err(|e| {
            XlogError::Kernel(format!("xgcf_backward_level_propagate failed: {}", e))
        })?;

        // Kernel 2: decision gradient. Additionally receives both gradient buffers.
        let mut params: Vec<*mut c_void> = vec![
            (&self.node_type).as_kernel_param(),
            (&self.decision_var).as_kernel_param(),
            (&self.decision_child_false).as_kernel_param(),
            (&self.decision_child_true).as_kernel_param(),
            (&self.level_nodes).as_kernel_param(),
            (&self.level_offsets).as_kernel_param(),
            level_u32.as_kernel_param(),
            (&self.var_log_true).as_kernel_param(),
            (&self.var_log_false).as_kernel_param(),
            (&self.values).as_kernel_param(),
            (&self.adj).as_kernel_param(),
            (&self.grad_true).as_kernel_param(),
            (&self.grad_false).as_kernel_param(),
        ];

        unsafe { decision_grad.clone().launch(config, &mut params) }.map_err(|e| {
            XlogError::Kernel(format!("xgcf_backward_level_decision_grad failed: {}", e))
        })?;

        // Kernel 3: literal gradient. Uses the tuple launch form instead of a raw
        // parameter vector.
        unsafe {
            lit_grad.clone().launch(
                config,
                (
                    &self.node_type,
                    &self.lit,
                    &self.level_nodes,
                    &self.level_offsets,
                    level_u32,
                    &self.adj,
                    &self.grad_true,
                    &self.grad_false,
                ),
            )
        }
        .map_err(|e| {
            XlogError::Kernel(format!("xgcf_backward_level_lit_grad failed: {}", e))
        })?;
    }

    // Correct both the root value and the gradients for free variables.
    self.apply_free_var_correction(provider, true, true)?;
    Ok(())
}
1810
/// Evaluates the circuit for `var_log_weights` and returns the root value on the host.
///
/// The result is produced in a temporary single-element device buffer and then
/// copied back synchronously.
///
/// # Errors
///
/// Propagates allocation and evaluation errors, and returns `XlogError::Kernel` when
/// the device-to-host copy fails.
#[cfg(feature = "host-io")]
pub fn eval_log_wmc(
    &mut self,
    provider: &CudaKernelProvider,
    var_log_weights: &[(f64, f64)],
) -> Result<f64> {
    let device = provider.device().inner();

    let mut log_z_device = provider.memory().alloc::<f64>(1)?;
    self.eval_log_wmc_device_into(provider, var_log_weights, &mut log_z_device)?;

    let mut log_z_host = [0.0_f64];
    device
        .dtoh_sync_copy_into(&log_z_device, &mut log_z_host)
        .map_err(|e| XlogError::Kernel(format!("Failed to read circuit root value: {}", e)))?;

    let [log_z] = log_z_host;
    Ok(log_z)
}
1827
/// Evaluates the circuit and its gradients for `var_log_weights`, returning
/// `(root value, grad_true, grad_false)` on the host.
///
/// The weights are uploaded, `eval_grads_inplace` is run, and the root value and both
/// gradient tables are copied back synchronously. Each returned gradient vector has
/// `max_var + 1` entries.
///
/// # Errors
///
/// Propagates upload and evaluation errors, and returns `XlogError::Kernel` when a
/// device-to-host copy fails.
#[cfg(feature = "host-io")]
pub fn eval_log_wmc_and_grads(
    &mut self,
    provider: &CudaKernelProvider,
    var_log_weights: &[(f64, f64)],
) -> Result<(f64, Vec<f64>, Vec<f64>)> {
    self.set_base_weights(provider, var_log_weights)?;
    self.eval_grads_inplace(provider)?;

    let device = provider.device().inner();
    let table_len = (self.max_var as usize) + 1;

    // Root value first, then the two gradient tables.
    let root = self.root as usize;
    let root_value = self.values.slice(root..(root + 1));
    let mut log_z_host = [0.0_f64];
    device
        .dtoh_sync_copy_into(&root_value, &mut log_z_host)
        .map_err(|e| XlogError::Kernel(format!("Failed to read circuit root value: {}", e)))?;

    let mut grads_true: Vec<f64> = vec![0.0; table_len];
    device
        .dtoh_sync_copy_into(&self.grad_true, &mut grads_true)
        .map_err(|e| XlogError::Kernel(format!("Failed to download grad_true: {}", e)))?;

    let mut grads_false: Vec<f64> = vec![0.0; table_len];
    device
        .dtoh_sync_copy_into(&self.grad_false, &mut grads_false)
        .map_err(|e| XlogError::Kernel(format!("Failed to download grad_false: {}", e)))?;

    let [log_z] = log_z_host;
    Ok((log_z, grads_true, grads_false))
}
1859}