1use std::ffi::c_void;
4
5use cudarc::driver::{DeviceSlice, LaunchConfig};
6use xlog_core::{Result, XlogError};
7use xlog_cuda::memory::TrackedCudaSlice;
8use xlog_cuda::{circuit_kernels, AsKernelParam, CudaFunction, LaunchAsync, CIRCUIT_MODULE};
9
10use super::TestContext;
11
/// Host-side description of a small XGCF circuit, bundled with the reference
/// results the GPU kernels are expected to reproduce for it.
///
/// The structural vectors and weights are what `TinyXgcfDevice::upload` copies
/// to the device; the `expected_*` vectors are not uploaded there and serve as
/// the comparison baseline.
#[derive(Clone, Debug)]
pub struct TinyXgcfSpec {
    /// Node count; sizes the device `values` and `adj` buffers.
    pub num_nodes: usize,
    /// Variable count; gradient buffers are allocated with `num_vars + 1` slots.
    pub num_vars: usize,
    /// Index of the root node. Must be `< num_nodes` for the root download and
    /// for seeding the backward pass.
    pub root: u32,
    /// Per-node type tag. The generators in this file use
    /// 0 = CONST0, 1 = CONST1, 2 = LIT, 3 = AND, 4 = OR, 5 = DECISION.
    pub node_type: Vec<u8>,
    /// Offsets into `child_indices`. The generators in this file emit
    /// `num_nodes + 1` entries (CSR-style), so node `i` presumably owns
    /// `child_indices[child_offsets[i]..child_offsets[i + 1]]`.
    pub child_offsets: Vec<u32>,
    /// Flattened child node indices referenced through `child_offsets`.
    pub child_indices: Vec<u32>,
    /// Per-node signed literal. In the reference values of `tiny_xgcf_spec`,
    /// a positive `k` reads `var_log_true[k]` and `-k` reads `var_log_false[k]`;
    /// the generators store 0 for non-literal nodes.
    pub lit: Vec<i32>,
    /// Per-node decision variable (the generators store 0 for non-decision nodes).
    pub decision_var: Vec<u32>,
    /// Per-node child taken when the decision variable is false.
    pub decision_child_false: Vec<u32>,
    /// Per-node child taken when the decision variable is true.
    pub decision_child_true: Vec<u32>,
    /// Node indices concatenated level by level.
    pub level_nodes: Vec<u32>,
    /// `(offset, len)` range of each level inside `level_nodes`.
    pub levels: Vec<(u32, u32)>,
    /// Log-weight of each variable being true. The generators keep slot 0 as
    /// an unused `0.0`, so variable ids are effectively 1-based.
    pub var_log_true: Vec<f64>,
    /// Log-weight of each variable being false, laid out like `var_log_true`.
    pub var_log_false: Vec<f64>,
    /// Reference per-node forward values.
    pub expected_values: Vec<f64>,
    /// Reference gradient with respect to each variable's "true" weight.
    pub expected_grad_true: Vec<f64>,
    /// Reference gradient with respect to each variable's "false" weight.
    pub expected_grad_false: Vec<f64>,
}
32
/// Host copies of every buffer produced by one forward + backward run, as
/// returned by `run_tiny_xgcf_backward`.
#[derive(Clone, Debug)]
pub struct TinyXgcfRun {
    /// Per-node forward values downloaded from the device.
    pub values: Vec<f64>,
    /// Per-node adjoints downloaded after the backward passes.
    pub adj: Vec<f64>,
    /// Accumulated gradient for each variable's "true" weight.
    pub grad_true: Vec<f64>,
    /// Accumulated gradient for each variable's "false" weight.
    pub grad_false: Vec<f64>,
}
40
41pub struct TinyXgcfDevice {
46 pub num_nodes: usize,
47 pub num_vars: usize,
48 pub root: u32,
49 levels: Vec<(u32, u32)>,
50
51 forward_fn: CudaFunction,
53 backward_propagate_fn: CudaFunction,
54 backward_decision_grad_fn: CudaFunction,
55 backward_lit_grad_fn: CudaFunction,
56
57 d_node_type: TrackedCudaSlice<u8>,
59 d_child_offsets: TrackedCudaSlice<u32>,
60 d_child_indices: TrackedCudaSlice<u32>,
61 d_lit: TrackedCudaSlice<i32>,
62 d_decision_var: TrackedCudaSlice<u32>,
63 d_decision_child_false: TrackedCudaSlice<u32>,
64 d_decision_child_true: TrackedCudaSlice<u32>,
65 d_level_nodes: TrackedCudaSlice<u32>,
66 d_level_offsets: TrackedCudaSlice<u32>,
67
68 d_var_log_true: TrackedCudaSlice<f64>,
70 d_var_log_false: TrackedCudaSlice<f64>,
71
72 d_values: TrackedCudaSlice<f64>,
74 d_adj: TrackedCudaSlice<f64>,
75 d_grad_true: TrackedCudaSlice<f64>,
76 d_grad_false: TrackedCudaSlice<f64>,
77}
78
79impl TinyXgcfDevice {
80 pub fn upload(ctx: &TestContext, spec: &TinyXgcfSpec) -> Result<Self> {
81 let device = ctx.device.inner();
82
83 let forward_fn = device
84 .get_func(CIRCUIT_MODULE, circuit_kernels::XGCF_FORWARD_LEVEL)
85 .ok_or_else(|| {
86 XlogError::Kernel(format!(
87 "Kernel {} not found in {}",
88 circuit_kernels::XGCF_FORWARD_LEVEL,
89 CIRCUIT_MODULE
90 ))
91 })?;
92 let backward_propagate_fn = device
93 .get_func(
94 CIRCUIT_MODULE,
95 circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
96 )
97 .ok_or_else(|| {
98 XlogError::Kernel(format!(
99 "Kernel {} not found in {}",
100 circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
101 CIRCUIT_MODULE
102 ))
103 })?;
104 let backward_decision_grad_fn = device
105 .get_func(
106 CIRCUIT_MODULE,
107 circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
108 )
109 .ok_or_else(|| {
110 XlogError::Kernel(format!(
111 "Kernel {} not found in {}",
112 circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
113 CIRCUIT_MODULE
114 ))
115 })?;
116 let backward_lit_grad_fn = device
117 .get_func(
118 CIRCUIT_MODULE,
119 circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
120 )
121 .ok_or_else(|| {
122 XlogError::Kernel(format!(
123 "Kernel {} not found in {}",
124 circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
125 CIRCUIT_MODULE
126 ))
127 })?;
128
129 let mut d_node_type = ctx.memory.alloc::<u8>(spec.node_type.len())?;
130 ctx.htod_sync_copy_into(&spec.node_type, &mut d_node_type)
131 .map_err(|e| XlogError::Kernel(format!("Failed to upload node_type: {}", e)))?;
132
133 let mut d_child_offsets = ctx.memory.alloc::<u32>(spec.child_offsets.len())?;
134 ctx.htod_sync_copy_into(&spec.child_offsets, &mut d_child_offsets)
135 .map_err(|e| XlogError::Kernel(format!("Failed to upload child_offsets: {}", e)))?;
136
137 let mut d_child_indices = ctx.memory.alloc::<u32>(spec.child_indices.len())?;
138 ctx.htod_sync_copy_into(&spec.child_indices, &mut d_child_indices)
139 .map_err(|e| XlogError::Kernel(format!("Failed to upload child_indices: {}", e)))?;
140
141 let mut d_lit = ctx.memory.alloc::<i32>(spec.lit.len())?;
142 ctx.htod_sync_copy_into(&spec.lit, &mut d_lit)
143 .map_err(|e| XlogError::Kernel(format!("Failed to upload lit: {}", e)))?;
144
145 let mut d_decision_var = ctx.memory.alloc::<u32>(spec.decision_var.len())?;
146 ctx.htod_sync_copy_into(&spec.decision_var, &mut d_decision_var)
147 .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_var: {}", e)))?;
148
149 let mut d_decision_child_false =
150 ctx.memory.alloc::<u32>(spec.decision_child_false.len())?;
151 ctx.htod_sync_copy_into(&spec.decision_child_false, &mut d_decision_child_false)
152 .map_err(|e| {
153 XlogError::Kernel(format!("Failed to upload decision_child_false: {}", e))
154 })?;
155
156 let mut d_decision_child_true = ctx.memory.alloc::<u32>(spec.decision_child_true.len())?;
157 ctx.htod_sync_copy_into(&spec.decision_child_true, &mut d_decision_child_true)
158 .map_err(|e| {
159 XlogError::Kernel(format!("Failed to upload decision_child_true: {}", e))
160 })?;
161
162 let mut d_level_nodes = ctx.memory.alloc::<u32>(spec.level_nodes.len())?;
163 ctx.htod_sync_copy_into(&spec.level_nodes, &mut d_level_nodes)
164 .map_err(|e| XlogError::Kernel(format!("Failed to upload level_nodes: {}", e)))?;
165
166 if spec.levels.is_empty() {
168 return Err(XlogError::Kernel(
169 "TinyXgcfSpec requires non-empty levels".to_string(),
170 ));
171 }
172 let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
173 for &(offset, _len) in &spec.levels {
174 level_offsets.push(offset);
175 }
176 let (last_offset, last_len) = *spec.levels.last().unwrap();
177 level_offsets.push(last_offset + last_len);
178
179 if level_offsets[0] != 0 {
180 return Err(XlogError::Kernel(
181 "TinyXgcfSpec level_offsets must start at 0".to_string(),
182 ));
183 }
184 for (i, &(offset, len)) in spec.levels.iter().enumerate() {
185 let expected_next = offset + len;
186 if level_offsets[i] != offset || level_offsets[i + 1] != expected_next {
187 return Err(XlogError::Kernel(
188 "TinyXgcfSpec levels must be contiguous and match offsets".to_string(),
189 ));
190 }
191 }
192 let total = *level_offsets.last().unwrap() as usize;
193 if total != spec.level_nodes.len() {
194 return Err(XlogError::Kernel(format!(
195 "TinyXgcfSpec level_nodes len {} != level_offsets.last {}",
196 spec.level_nodes.len(),
197 total
198 )));
199 }
200
201 let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
202 ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
203 .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
204
205 let mut d_var_log_true = ctx.memory.alloc::<f64>(spec.var_log_true.len())?;
206 ctx.htod_sync_copy_into(&spec.var_log_true, &mut d_var_log_true)
207 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
208
209 let mut d_var_log_false = ctx.memory.alloc::<f64>(spec.var_log_false.len())?;
210 ctx.htod_sync_copy_into(&spec.var_log_false, &mut d_var_log_false)
211 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
212
213 let d_values = ctx.memory.alloc::<f64>(spec.num_nodes)?;
214 let d_adj = ctx.memory.alloc::<f64>(spec.num_nodes)?;
215 let d_grad_true = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
216 let d_grad_false = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
217
218 Ok(Self {
219 num_nodes: spec.num_nodes,
220 num_vars: spec.num_vars,
221 root: spec.root,
222 levels: spec.levels.clone(),
223 forward_fn,
224 backward_propagate_fn,
225 backward_decision_grad_fn,
226 backward_lit_grad_fn,
227 d_node_type,
228 d_child_offsets,
229 d_child_indices,
230 d_lit,
231 d_decision_var,
232 d_decision_child_false,
233 d_decision_child_true,
234 d_level_nodes,
235 d_level_offsets,
236 d_var_log_true,
237 d_var_log_false,
238 d_values,
239 d_adj,
240 d_grad_true,
241 d_grad_false,
242 })
243 }
244
245 fn launch_level_cached(
246 kernel: &CudaFunction,
247 num_level_nodes: u32,
248 params: &mut Vec<*mut c_void>,
249 ) -> Result<()> {
250 if num_level_nodes == 0 {
251 return Ok(());
252 }
253 let block_size = 256u32;
254 let num_blocks = (num_level_nodes + block_size - 1) / block_size;
255 let config = LaunchConfig {
256 grid_dim: (num_blocks, 1, 1),
257 block_dim: (block_size, 1, 1),
258 shared_mem_bytes: 0,
259 };
260 unsafe { kernel.clone().launch(config, params) }
262 .map_err(|e| XlogError::Kernel(format!("Failed to launch level kernel: {}", e)))?;
263 Ok(())
264 }
265
266 pub fn set_weights(
267 &mut self,
268 ctx: &TestContext,
269 log_true: &[f64],
270 log_false: &[f64],
271 ) -> Result<()> {
272 if log_true.len() != self.d_var_log_true.len()
273 || log_false.len() != self.d_var_log_false.len()
274 {
275 return Err(XlogError::Kernel(format!(
276 "Weight length mismatch: got (true={}, false={}), expected (true={}, false={})",
277 log_true.len(),
278 log_false.len(),
279 self.d_var_log_true.len(),
280 self.d_var_log_false.len()
281 )));
282 }
283 ctx.htod_sync_copy_into(log_true, &mut self.d_var_log_true)
284 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
285 ctx.htod_sync_copy_into(log_false, &mut self.d_var_log_false)
286 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
287 Ok(())
288 }
289
290 pub fn forward_launch(&mut self, _ctx: &TestContext) -> Result<()> {
292 for (level, &(_offset, len)) in self.levels.iter().enumerate() {
293 let level_u32 = level as u32;
294 let mut params: Vec<*mut c_void> = vec![
295 (&self.d_node_type).as_kernel_param(),
296 (&self.d_child_offsets).as_kernel_param(),
297 (&self.d_child_indices).as_kernel_param(),
298 (&self.d_lit).as_kernel_param(),
299 (&self.d_decision_var).as_kernel_param(),
300 (&self.d_decision_child_false).as_kernel_param(),
301 (&self.d_decision_child_true).as_kernel_param(),
302 (&self.d_level_nodes).as_kernel_param(),
303 (&self.d_level_offsets).as_kernel_param(),
304 level_u32.as_kernel_param(),
305 (&self.d_var_log_true).as_kernel_param(),
306 (&self.d_var_log_false).as_kernel_param(),
307 (&mut self.d_values).as_kernel_param(),
308 ];
309 Self::launch_level_cached(&self.forward_fn, len, &mut params)?;
310 }
311 Ok(())
312 }
313
314 pub fn forward_download_values(&mut self, ctx: &TestContext) -> Result<Vec<f64>> {
315 self.forward_launch(ctx)?;
316 ctx.sync_and_check()?;
317 ctx.dtoh_sync_copy(&self.d_values)
318 .map_err(|e| XlogError::Kernel(format!("Failed to download values: {}", e)))
319 }
320
321 pub fn forward_download_root(&mut self, ctx: &TestContext) -> Result<f64> {
322 self.forward_launch(ctx)?;
323 ctx.sync_and_check()?;
324 let root_idx: usize = self.root as usize;
325 if root_idx >= self.num_nodes {
326 return Err(XlogError::Kernel(format!(
327 "Root {} out of bounds for num_nodes {}",
328 self.root, self.num_nodes
329 )));
330 }
331 let root_view = self.d_values.slice(root_idx..(root_idx + 1));
332 let mut root_host = [0.0f64];
333 ctx.dtoh_sync_copy_into(&root_view, &mut root_host)
334 .map_err(|e| XlogError::Kernel(format!("Failed to download root value: {}", e)))?;
335 Ok(root_host[0])
336 }
337
338 pub fn backward_only_launch(&mut self, ctx: &TestContext) -> Result<()> {
340 let device = ctx.device.inner();
341 device
342 .memset_zeros(&mut self.d_adj)
343 .map_err(|e| XlogError::Kernel(format!("Failed to zero adj: {}", e)))?;
344 device
345 .memset_zeros(&mut self.d_grad_true)
346 .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_true: {}", e)))?;
347 device
348 .memset_zeros(&mut self.d_grad_false)
349 .map_err(|e| XlogError::Kernel(format!("Failed to zero grad_false: {}", e)))?;
350
351 let root_idx: usize = self.root as usize;
352 if root_idx >= self.num_nodes {
353 return Err(XlogError::Kernel(format!(
354 "Root {} out of bounds for num_nodes {}",
355 self.root, self.num_nodes
356 )));
357 }
358 let mut root_view = self.d_adj.slice_mut(root_idx..(root_idx + 1));
359 ctx.htod_sync_copy_into(&[1.0f64], &mut root_view)
360 .map_err(|e| XlogError::Kernel(format!("Failed to set root adjoint: {}", e)))?;
361
362 for (level, &(_offset, len)) in self.levels.iter().enumerate().rev() {
363 let level_u32 = level as u32;
364 let mut params: Vec<*mut c_void> = vec![
365 (&self.d_node_type).as_kernel_param(),
366 (&self.d_child_offsets).as_kernel_param(),
367 (&self.d_child_indices).as_kernel_param(),
368 (&self.d_decision_var).as_kernel_param(),
369 (&self.d_decision_child_false).as_kernel_param(),
370 (&self.d_decision_child_true).as_kernel_param(),
371 (&self.d_level_nodes).as_kernel_param(),
372 (&self.d_level_offsets).as_kernel_param(),
373 level_u32.as_kernel_param(),
374 (&self.d_var_log_true).as_kernel_param(),
375 (&self.d_var_log_false).as_kernel_param(),
376 (&self.d_values).as_kernel_param(),
377 (&mut self.d_adj).as_kernel_param(),
378 ];
379 Self::launch_level_cached(&self.backward_propagate_fn, len, &mut params)?;
380 }
381
382 for (level, &(_offset, len)) in self.levels.iter().enumerate().rev() {
383 let level_u32 = level as u32;
384 let mut params: Vec<*mut c_void> = vec![
385 (&self.d_node_type).as_kernel_param(),
386 (&self.d_decision_var).as_kernel_param(),
387 (&self.d_decision_child_false).as_kernel_param(),
388 (&self.d_decision_child_true).as_kernel_param(),
389 (&self.d_level_nodes).as_kernel_param(),
390 (&self.d_level_offsets).as_kernel_param(),
391 level_u32.as_kernel_param(),
392 (&self.d_var_log_true).as_kernel_param(),
393 (&self.d_var_log_false).as_kernel_param(),
394 (&self.d_values).as_kernel_param(),
395 (&self.d_adj).as_kernel_param(),
396 (&mut self.d_grad_true).as_kernel_param(),
397 (&mut self.d_grad_false).as_kernel_param(),
398 ];
399 Self::launch_level_cached(&self.backward_decision_grad_fn, len, &mut params)?;
400 }
401
402 for (level, &(_offset, len)) in self.levels.iter().enumerate().rev() {
403 let level_u32 = level as u32;
404 let mut params: Vec<*mut c_void> = vec![
405 (&self.d_node_type).as_kernel_param(),
406 (&self.d_lit).as_kernel_param(),
407 (&self.d_level_nodes).as_kernel_param(),
408 (&self.d_level_offsets).as_kernel_param(),
409 level_u32.as_kernel_param(),
410 (&self.d_adj).as_kernel_param(),
411 (&mut self.d_grad_true).as_kernel_param(),
412 (&mut self.d_grad_false).as_kernel_param(),
413 ];
414 Self::launch_level_cached(&self.backward_lit_grad_fn, len, &mut params)?;
415 }
416
417 Ok(())
418 }
419
420 pub fn forward_then_backward_launch(&mut self, ctx: &TestContext) -> Result<()> {
422 self.forward_launch(ctx)?;
423 self.backward_only_launch(ctx)
424 }
425}
426
/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// The larger operand is factored out so the exponentials never overflow.
/// Special values are handled explicitly:
/// * any NaN input yields NaN (`f64::max` would otherwise silently drop it);
/// * an infinite maximum is returned as-is, because subtracting it from
///   itself would produce `inf - inf = NaN` — this covers both
///   `(-inf, -inf) -> -inf` and `(+inf, x) -> +inf`.
fn logsumexp2(a: f64, b: f64) -> f64 {
    if a.is_nan() || b.is_nan() {
        return f64::NAN;
    }
    let m = a.max(b);
    if m.is_infinite() {
        return m;
    }
    m + ((a - m).exp() + (b - m).exp()).ln()
}
434
435pub fn tiny_xgcf_spec() -> TinyXgcfSpec {
437 const CONST0: u8 = 0;
438 const CONST1: u8 = 1;
439 const LIT: u8 = 2;
440 const AND: u8 = 3;
441 const OR: u8 = 4;
442 const DECISION: u8 = 5;
443
444 let num_nodes = 7;
453 let root = 6u32;
454
455 let node_type: Vec<u8> = vec![CONST1, LIT, LIT, AND, DECISION, CONST0, OR];
456 let lit: Vec<i32> = vec![0, 1, -2, 0, 0, 0, 0];
457 let decision_var: Vec<u32> = vec![0, 0, 0, 0, 3, 0, 0];
458 let decision_child_false: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0];
459 let decision_child_true: Vec<u32> = vec![0, 0, 0, 0, 3, 0, 0];
460
461 let child_offsets: Vec<u32> = vec![0, 0, 0, 0, 2, 2, 2, 4];
462 let child_indices: Vec<u32> = vec![1, 2, 4, 5];
463
464 let level_nodes: Vec<u32> = vec![0, 1, 2, 5, 3, 4, 6];
466 let levels: Vec<(u32, u32)> = vec![(0, 4), (4, 1), (5, 1), (6, 1)];
467
468 let num_vars = 3usize;
469 let var_log_true: Vec<f64> = vec![
470 0.0,
471 0.7f64.ln(), 0.2f64.ln(), 0.6f64.ln(), ];
475 let var_log_false: Vec<f64> = vec![
476 0.0,
477 0.3f64.ln(), 0.8f64.ln(), 0.4f64.ln(), ];
481
482 let v0 = 0.0;
483 let v1 = var_log_true[1];
484 let v2 = var_log_false[2];
485 let v3 = v1 + v2;
486 let v4 = logsumexp2(var_log_false[3] + v0, var_log_true[3] + v3);
487 let v5 = f64::NEG_INFINITY;
488 let v6 = logsumexp2(v4, v5);
489
490 let expected_values: Vec<f64> = vec![v0, v1, v2, v3, v4, v5, v6];
491
492 let p_false = (var_log_false[3] + v0 - v4).exp();
493 let p_true = (var_log_true[3] + v3 - v4).exp();
494
495 let mut expected_grad_true = vec![0.0f64; num_vars + 1];
496 let mut expected_grad_false = vec![0.0f64; num_vars + 1];
497 expected_grad_true[1] = p_true; expected_grad_false[2] = p_true; expected_grad_true[3] = p_true; expected_grad_false[3] = p_false;
501
502 TinyXgcfSpec {
503 num_nodes,
504 num_vars,
505 root,
506 node_type,
507 child_offsets,
508 child_indices,
509 lit,
510 decision_var,
511 decision_child_false,
512 decision_child_true,
513 level_nodes,
514 levels,
515 var_log_true,
516 var_log_false,
517 expected_values,
518 expected_grad_true,
519 expected_grad_false,
520 }
521}
522
523fn launch_level(
524 ctx: &TestContext,
525 kernel_name: &str,
526 num_level_nodes: u32,
527 params: &mut Vec<*mut c_void>,
528) -> Result<()> {
529 if num_level_nodes == 0 {
530 return Ok(());
531 }
532 let device = ctx.device.inner();
533 let kernel = device
534 .get_func(CIRCUIT_MODULE, kernel_name)
535 .ok_or_else(|| {
536 XlogError::Kernel(format!(
537 "Kernel {} not found in {}",
538 kernel_name, CIRCUIT_MODULE
539 ))
540 })?;
541
542 let block_size = 256u32;
543 let num_blocks = (num_level_nodes + block_size - 1) / block_size;
544 let config = LaunchConfig {
545 grid_dim: (num_blocks, 1, 1),
546 block_dim: (block_size, 1, 1),
547 shared_mem_bytes: 0,
548 };
549
550 unsafe { kernel.clone().launch(config, params) }
552 .map_err(|e| XlogError::Kernel(format!("Failed to launch {}: {}", kernel_name, e)))?;
553 Ok(())
554}
555
556pub fn run_tiny_xgcf_forward(ctx: &TestContext, spec: &TinyXgcfSpec) -> Result<Vec<f64>> {
557 let mut d_node_type = ctx.memory.alloc::<u8>(spec.node_type.len())?;
558 ctx.htod_sync_copy_into(&spec.node_type, &mut d_node_type)
559 .map_err(|e| XlogError::Kernel(format!("Failed to upload node_type: {}", e)))?;
560
561 let mut d_child_offsets = ctx.memory.alloc::<u32>(spec.child_offsets.len())?;
562 ctx.htod_sync_copy_into(&spec.child_offsets, &mut d_child_offsets)
563 .map_err(|e| XlogError::Kernel(format!("Failed to upload child_offsets: {}", e)))?;
564
565 let mut d_child_indices = ctx.memory.alloc::<u32>(spec.child_indices.len())?;
566 ctx.htod_sync_copy_into(&spec.child_indices, &mut d_child_indices)
567 .map_err(|e| XlogError::Kernel(format!("Failed to upload child_indices: {}", e)))?;
568
569 let mut d_lit = ctx.memory.alloc::<i32>(spec.lit.len())?;
570 ctx.htod_sync_copy_into(&spec.lit, &mut d_lit)
571 .map_err(|e| XlogError::Kernel(format!("Failed to upload lit: {}", e)))?;
572
573 let mut d_decision_var = ctx.memory.alloc::<u32>(spec.decision_var.len())?;
574 ctx.htod_sync_copy_into(&spec.decision_var, &mut d_decision_var)
575 .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_var: {}", e)))?;
576
577 let mut d_decision_child_false = ctx.memory.alloc::<u32>(spec.decision_child_false.len())?;
578 ctx.htod_sync_copy_into(&spec.decision_child_false, &mut d_decision_child_false)
579 .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_false: {}", e)))?;
580
581 let mut d_decision_child_true = ctx.memory.alloc::<u32>(spec.decision_child_true.len())?;
582 ctx.htod_sync_copy_into(&spec.decision_child_true, &mut d_decision_child_true)
583 .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_true: {}", e)))?;
584
585 let mut d_level_nodes = ctx.memory.alloc::<u32>(spec.level_nodes.len())?;
586 ctx.htod_sync_copy_into(&spec.level_nodes, &mut d_level_nodes)
587 .map_err(|e| XlogError::Kernel(format!("Failed to upload level_nodes: {}", e)))?;
588
589 let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
590 for &(offset, _len) in &spec.levels {
591 level_offsets.push(offset);
592 }
593 let (last_offset, last_len) = *spec
594 .levels
595 .last()
596 .ok_or_else(|| XlogError::Kernel("TinyXgcfSpec requires non-empty levels".to_string()))?;
597 level_offsets.push(last_offset + last_len);
598 if level_offsets[0] != 0 {
599 return Err(XlogError::Kernel(
600 "TinyXgcfSpec level_offsets must start at 0".to_string(),
601 ));
602 }
603 let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
604 ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
605 .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
606
607 let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
608 for &(offset, _len) in &spec.levels {
609 level_offsets.push(offset);
610 }
611 let (last_offset, last_len) = *spec
612 .levels
613 .last()
614 .ok_or_else(|| XlogError::Kernel("TinyXgcfSpec requires non-empty levels".to_string()))?;
615 level_offsets.push(last_offset + last_len);
616 if level_offsets[0] != 0 {
617 return Err(XlogError::Kernel(
618 "TinyXgcfSpec level_offsets must start at 0".to_string(),
619 ));
620 }
621 let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
622 ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
623 .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
624
625 let mut d_var_log_true = ctx.memory.alloc::<f64>(spec.var_log_true.len())?;
626 ctx.htod_sync_copy_into(&spec.var_log_true, &mut d_var_log_true)
627 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
628
629 let mut d_var_log_false = ctx.memory.alloc::<f64>(spec.var_log_false.len())?;
630 ctx.htod_sync_copy_into(&spec.var_log_false, &mut d_var_log_false)
631 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
632
633 let mut d_values = ctx.memory.alloc::<f64>(spec.num_nodes)?;
634 let init_values = vec![0.0f64; spec.num_nodes];
635 ctx.htod_sync_copy_into(&init_values, &mut d_values)
636 .map_err(|e| XlogError::Kernel(format!("Failed to init values: {}", e)))?;
637
638 for (level, &(_offset, len)) in spec.levels.iter().enumerate() {
639 let level_u32 = level as u32;
640 let mut params: Vec<*mut c_void> = vec![
641 (&d_node_type).as_kernel_param(),
642 (&d_child_offsets).as_kernel_param(),
643 (&d_child_indices).as_kernel_param(),
644 (&d_lit).as_kernel_param(),
645 (&d_decision_var).as_kernel_param(),
646 (&d_decision_child_false).as_kernel_param(),
647 (&d_decision_child_true).as_kernel_param(),
648 (&d_level_nodes).as_kernel_param(),
649 (&d_level_offsets).as_kernel_param(),
650 level_u32.as_kernel_param(),
651 (&d_var_log_true).as_kernel_param(),
652 (&d_var_log_false).as_kernel_param(),
653 (&mut d_values).as_kernel_param(),
654 ];
655 launch_level(ctx, circuit_kernels::XGCF_FORWARD_LEVEL, len, &mut params)?;
656 }
657
658 ctx.sync_and_check()?;
659
660 ctx.dtoh_sync_copy(&d_values)
661 .map_err(|e| XlogError::Kernel(format!("Failed to download values: {}", e)))
662}
663
664pub fn run_tiny_xgcf_backward(ctx: &TestContext, spec: &TinyXgcfSpec) -> Result<TinyXgcfRun> {
665 let mut d_node_type = ctx.memory.alloc::<u8>(spec.node_type.len())?;
666 ctx.htod_sync_copy_into(&spec.node_type, &mut d_node_type)
667 .map_err(|e| XlogError::Kernel(format!("Failed to upload node_type: {}", e)))?;
668
669 let mut d_child_offsets = ctx.memory.alloc::<u32>(spec.child_offsets.len())?;
670 ctx.htod_sync_copy_into(&spec.child_offsets, &mut d_child_offsets)
671 .map_err(|e| XlogError::Kernel(format!("Failed to upload child_offsets: {}", e)))?;
672
673 let mut d_child_indices = ctx.memory.alloc::<u32>(spec.child_indices.len())?;
674 ctx.htod_sync_copy_into(&spec.child_indices, &mut d_child_indices)
675 .map_err(|e| XlogError::Kernel(format!("Failed to upload child_indices: {}", e)))?;
676
677 let mut d_lit = ctx.memory.alloc::<i32>(spec.lit.len())?;
678 ctx.htod_sync_copy_into(&spec.lit, &mut d_lit)
679 .map_err(|e| XlogError::Kernel(format!("Failed to upload lit: {}", e)))?;
680
681 let mut d_decision_var = ctx.memory.alloc::<u32>(spec.decision_var.len())?;
682 ctx.htod_sync_copy_into(&spec.decision_var, &mut d_decision_var)
683 .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_var: {}", e)))?;
684
685 let mut d_decision_child_false = ctx.memory.alloc::<u32>(spec.decision_child_false.len())?;
686 ctx.htod_sync_copy_into(&spec.decision_child_false, &mut d_decision_child_false)
687 .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_false: {}", e)))?;
688
689 let mut d_decision_child_true = ctx.memory.alloc::<u32>(spec.decision_child_true.len())?;
690 ctx.htod_sync_copy_into(&spec.decision_child_true, &mut d_decision_child_true)
691 .map_err(|e| XlogError::Kernel(format!("Failed to upload decision_child_true: {}", e)))?;
692
693 let mut d_level_nodes = ctx.memory.alloc::<u32>(spec.level_nodes.len())?;
694 ctx.htod_sync_copy_into(&spec.level_nodes, &mut d_level_nodes)
695 .map_err(|e| XlogError::Kernel(format!("Failed to upload level_nodes: {}", e)))?;
696
697 let mut level_offsets: Vec<u32> = Vec::with_capacity(spec.levels.len() + 1);
698 for &(offset, _len) in &spec.levels {
699 level_offsets.push(offset);
700 }
701 let (last_offset, last_len) = *spec
702 .levels
703 .last()
704 .ok_or_else(|| XlogError::Kernel("TinyXgcfSpec requires non-empty levels".to_string()))?;
705 level_offsets.push(last_offset + last_len);
706 if level_offsets[0] != 0 {
707 return Err(XlogError::Kernel(
708 "TinyXgcfSpec level_offsets must start at 0".to_string(),
709 ));
710 }
711 let mut d_level_offsets = ctx.memory.alloc::<u32>(level_offsets.len())?;
712 ctx.htod_sync_copy_into(&level_offsets, &mut d_level_offsets)
713 .map_err(|e| XlogError::Kernel(format!("Failed to upload level_offsets: {}", e)))?;
714
715 let mut d_var_log_true = ctx.memory.alloc::<f64>(spec.var_log_true.len())?;
716 ctx.htod_sync_copy_into(&spec.var_log_true, &mut d_var_log_true)
717 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_true: {}", e)))?;
718
719 let mut d_var_log_false = ctx.memory.alloc::<f64>(spec.var_log_false.len())?;
720 ctx.htod_sync_copy_into(&spec.var_log_false, &mut d_var_log_false)
721 .map_err(|e| XlogError::Kernel(format!("Failed to upload var_log_false: {}", e)))?;
722
723 let mut d_values = ctx.memory.alloc::<f64>(spec.num_nodes)?;
724 let init_values = vec![0.0f64; spec.num_nodes];
725 ctx.htod_sync_copy_into(&init_values, &mut d_values)
726 .map_err(|e| XlogError::Kernel(format!("Failed to init values: {}", e)))?;
727
728 for (level, &(_offset, len)) in spec.levels.iter().enumerate() {
729 let level_u32 = level as u32;
730 let mut params: Vec<*mut c_void> = vec![
731 (&d_node_type).as_kernel_param(),
732 (&d_child_offsets).as_kernel_param(),
733 (&d_child_indices).as_kernel_param(),
734 (&d_lit).as_kernel_param(),
735 (&d_decision_var).as_kernel_param(),
736 (&d_decision_child_false).as_kernel_param(),
737 (&d_decision_child_true).as_kernel_param(),
738 (&d_level_nodes).as_kernel_param(),
739 (&d_level_offsets).as_kernel_param(),
740 level_u32.as_kernel_param(),
741 (&d_var_log_true).as_kernel_param(),
742 (&d_var_log_false).as_kernel_param(),
743 (&mut d_values).as_kernel_param(),
744 ];
745 launch_level(ctx, circuit_kernels::XGCF_FORWARD_LEVEL, len, &mut params)?;
746 }
747
748 let mut adj_init = vec![0.0f64; spec.num_nodes];
750 let root_idx: usize = spec.root as usize;
751 if root_idx >= adj_init.len() {
752 return Err(XlogError::Kernel(format!(
753 "Root {} out of bounds for num_nodes {}",
754 spec.root, spec.num_nodes
755 )));
756 }
757 adj_init[root_idx] = 1.0;
758 let mut d_adj = ctx.memory.alloc::<f64>(spec.num_nodes)?;
759 ctx.htod_sync_copy_into(&adj_init, &mut d_adj)
760 .map_err(|e| XlogError::Kernel(format!("Failed to init adj: {}", e)))?;
761
762 let mut d_grad_true = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
763 let mut d_grad_false = ctx.memory.alloc::<f64>(spec.num_vars + 1)?;
764 let grad_init = vec![0.0f64; spec.num_vars + 1];
765 ctx.htod_sync_copy_into(&grad_init, &mut d_grad_true)
766 .map_err(|e| XlogError::Kernel(format!("Failed to init grad_true: {}", e)))?;
767 ctx.htod_sync_copy_into(&grad_init, &mut d_grad_false)
768 .map_err(|e| XlogError::Kernel(format!("Failed to init grad_false: {}", e)))?;
769
770 for (level, &(_offset, len)) in spec.levels.iter().enumerate().rev() {
771 let level_u32 = level as u32;
772 let mut params: Vec<*mut c_void> = vec![
773 (&d_node_type).as_kernel_param(),
774 (&d_child_offsets).as_kernel_param(),
775 (&d_child_indices).as_kernel_param(),
776 (&d_decision_var).as_kernel_param(),
777 (&d_decision_child_false).as_kernel_param(),
778 (&d_decision_child_true).as_kernel_param(),
779 (&d_level_nodes).as_kernel_param(),
780 (&d_level_offsets).as_kernel_param(),
781 level_u32.as_kernel_param(),
782 (&d_var_log_true).as_kernel_param(),
783 (&d_var_log_false).as_kernel_param(),
784 (&d_values).as_kernel_param(),
785 (&mut d_adj).as_kernel_param(),
786 ];
787 launch_level(
788 ctx,
789 circuit_kernels::XGCF_BACKWARD_LEVEL_PROPAGATE,
790 len,
791 &mut params,
792 )?;
793 }
794
795 for (level, &(_offset, len)) in spec.levels.iter().enumerate().rev() {
796 let level_u32 = level as u32;
797 let mut params: Vec<*mut c_void> = vec![
798 (&d_node_type).as_kernel_param(),
799 (&d_decision_var).as_kernel_param(),
800 (&d_decision_child_false).as_kernel_param(),
801 (&d_decision_child_true).as_kernel_param(),
802 (&d_level_nodes).as_kernel_param(),
803 (&d_level_offsets).as_kernel_param(),
804 level_u32.as_kernel_param(),
805 (&d_var_log_true).as_kernel_param(),
806 (&d_var_log_false).as_kernel_param(),
807 (&d_values).as_kernel_param(),
808 (&d_adj).as_kernel_param(),
809 (&mut d_grad_true).as_kernel_param(),
810 (&mut d_grad_false).as_kernel_param(),
811 ];
812 launch_level(
813 ctx,
814 circuit_kernels::XGCF_BACKWARD_LEVEL_DECISION_GRAD,
815 len,
816 &mut params,
817 )?;
818 }
819
820 for (level, &(_offset, len)) in spec.levels.iter().enumerate().rev() {
821 let level_u32 = level as u32;
822 let mut params: Vec<*mut c_void> = vec![
823 (&d_node_type).as_kernel_param(),
824 (&d_lit).as_kernel_param(),
825 (&d_level_nodes).as_kernel_param(),
826 (&d_level_offsets).as_kernel_param(),
827 level_u32.as_kernel_param(),
828 (&d_adj).as_kernel_param(),
829 (&mut d_grad_true).as_kernel_param(),
830 (&mut d_grad_false).as_kernel_param(),
831 ];
832 launch_level(
833 ctx,
834 circuit_kernels::XGCF_BACKWARD_LEVEL_LIT_GRAD,
835 len,
836 &mut params,
837 )?;
838 }
839
840 ctx.sync_and_check()?;
841
842 let values = ctx
843 .dtoh_sync_copy(&d_values)
844 .map_err(|e| XlogError::Kernel(format!("Failed to download values: {}", e)))?;
845 let adj = ctx
846 .dtoh_sync_copy(&d_adj)
847 .map_err(|e| XlogError::Kernel(format!("Failed to download adj: {}", e)))?;
848 let grad_true = ctx
849 .dtoh_sync_copy(&d_grad_true)
850 .map_err(|e| XlogError::Kernel(format!("Failed to download grad_true: {}", e)))?;
851 let grad_false = ctx
852 .dtoh_sync_copy(&d_grad_false)
853 .map_err(|e| XlogError::Kernel(format!("Failed to download grad_false: {}", e)))?;
854
855 Ok(TinyXgcfRun {
856 values,
857 adj,
858 grad_true,
859 grad_false,
860 })
861}
862
863pub fn gen_single_lit_circuit(var: u32) -> TinyXgcfSpec {
865 const LIT: u8 = 2;
866
867 let num_nodes = 1;
868 let num_vars = var as usize;
869 let root = 0;
870
871 let node_type = vec![LIT];
872 let child_offsets = vec![0, 0];
873 let child_indices = vec![];
874 let lit = vec![var as i32];
875 let decision_var = vec![0];
876 let decision_child_false = vec![0];
877 let decision_child_true = vec![0];
878 let level_nodes = vec![0];
879 let levels = vec![(0, 1)];
880
881 let mut var_log_true = vec![0.0; num_vars + 1];
882 let mut var_log_false = vec![0.0; num_vars + 1];
883 var_log_true[var as usize] = 0.7_f64.ln();
884 var_log_false[var as usize] = 0.3_f64.ln();
885
886 let expected_values = vec![var_log_true[var as usize]];
887 let mut expected_grad_true = vec![0.0; num_vars + 1];
888 let expected_grad_false = vec![0.0; num_vars + 1];
889 expected_grad_true[var as usize] = 1.0;
890
891 TinyXgcfSpec {
892 num_nodes,
893 num_vars,
894 root,
895 node_type,
896 child_offsets,
897 child_indices,
898 lit,
899 decision_var,
900 decision_child_false,
901 decision_child_true,
902 level_nodes,
903 levels,
904 var_log_true,
905 var_log_false,
906 expected_values,
907 expected_grad_true,
908 expected_grad_false,
909 }
910}
911
912pub fn gen_and_circuit() -> TinyXgcfSpec {
914 const LIT: u8 = 2;
915 const AND: u8 = 3;
916
917 let num_nodes = 3;
918 let num_vars = 2;
919 let root = 2;
920
921 let node_type = vec![LIT, LIT, AND];
922 let child_offsets = vec![0, 0, 0, 2];
923 let child_indices = vec![0, 1];
924 let lit = vec![1, 2, 0];
925 let decision_var = vec![0, 0, 0];
926 let decision_child_false = vec![0, 0, 0];
927 let decision_child_true = vec![0, 0, 0];
928 let level_nodes = vec![0, 1, 2];
929 let levels = vec![(0, 2), (2, 1)];
930
931 let p1 = 0.7_f64;
932 let p2 = 0.6_f64;
933 let var_log_true = vec![0.0, p1.ln(), p2.ln()];
934 let var_log_false = vec![0.0, (1.0 - p1).ln(), (1.0 - p2).ln()];
935
936 let v0 = var_log_true[1];
937 let v1 = var_log_true[2];
938 let v2 = v0 + v1;
939 let expected_values = vec![v0, v1, v2];
940 let expected_grad_true = vec![0.0, 1.0, 1.0];
941 let expected_grad_false = vec![0.0, 0.0, 0.0];
942
943 TinyXgcfSpec {
944 num_nodes,
945 num_vars,
946 root,
947 node_type,
948 child_offsets,
949 child_indices,
950 lit,
951 decision_var,
952 decision_child_false,
953 decision_child_true,
954 level_nodes,
955 levels,
956 var_log_true,
957 var_log_false,
958 expected_values,
959 expected_grad_true,
960 expected_grad_false,
961 }
962}
963
964pub fn gen_or_circuit() -> TinyXgcfSpec {
966 const LIT: u8 = 2;
967 const OR: u8 = 4;
968
969 let num_nodes = 3;
970 let num_vars = 2;
971 let root = 2;
972
973 let node_type = vec![LIT, LIT, OR];
974 let child_offsets = vec![0, 0, 0, 2];
975 let child_indices = vec![0, 1];
976 let lit = vec![1, 2, 0];
977 let decision_var = vec![0, 0, 0];
978 let decision_child_false = vec![0, 0, 0];
979 let decision_child_true = vec![0, 0, 0];
980 let level_nodes = vec![0, 1, 2];
981 let levels = vec![(0, 2), (2, 1)];
982
983 let p1 = 0.7_f64;
984 let p2 = 0.6_f64;
985 let var_log_true = vec![0.0, p1.ln(), p2.ln()];
986 let var_log_false = vec![0.0, (1.0 - p1).ln(), (1.0 - p2).ln()];
987
988 let v0 = var_log_true[1];
989 let v1 = var_log_true[2];
990 let v2 = logsumexp2(v0, v1);
991 let expected_values = vec![v0, v1, v2];
992
993 let p_child0 = (v0 - v2).exp();
994 let p_child1 = (v1 - v2).exp();
995 let expected_grad_true = vec![0.0, p_child0, p_child1];
996 let expected_grad_false = vec![0.0, 0.0, 0.0];
997
998 TinyXgcfSpec {
999 num_nodes,
1000 num_vars,
1001 root,
1002 node_type,
1003 child_offsets,
1004 child_indices,
1005 lit,
1006 decision_var,
1007 decision_child_false,
1008 decision_child_true,
1009 level_nodes,
1010 levels,
1011 var_log_true,
1012 var_log_false,
1013 expected_values,
1014 expected_grad_true,
1015 expected_grad_false,
1016 }
1017}
1018
1019pub fn gen_decision_circuit() -> TinyXgcfSpec {
1021 const CONST1: u8 = 1;
1022 const LIT: u8 = 2;
1023 const DECISION: u8 = 5;
1024
1025 let num_nodes = 3;
1026 let num_vars = 2;
1027 let root = 2;
1028
1029 let node_type = vec![CONST1, LIT, DECISION];
1030 let child_offsets = vec![0, 0, 0, 0];
1031 let child_indices = vec![];
1032 let lit = vec![0, 1, 0];
1033 let decision_var = vec![0, 0, 2];
1034 let decision_child_false = vec![0, 0, 0];
1035 let decision_child_true = vec![0, 0, 1];
1036 let level_nodes = vec![0, 1, 2];
1037 let levels = vec![(0, 2), (2, 1)];
1038
1039 let p1 = 0.7_f64;
1040 let p2 = 0.6_f64;
1041 let var_log_true = vec![0.0, p1.ln(), p2.ln()];
1042 let var_log_false = vec![0.0, (1.0 - p1).ln(), (1.0 - p2).ln()];
1043
1044 let v0 = 0.0;
1045 let v1 = var_log_true[1];
1046 let v2 = logsumexp2(var_log_false[2] + v0, var_log_true[2] + v1);
1047 let expected_values = vec![v0, v1, v2];
1048
1049 let p_false = (var_log_false[2] + v0 - v2).exp();
1050 let p_true = (var_log_true[2] + v1 - v2).exp();
1051
1052 let expected_grad_true = vec![0.0, p_true, p_true];
1053 let expected_grad_false = vec![0.0, 0.0, p_false];
1054
1055 TinyXgcfSpec {
1056 num_nodes,
1057 num_vars,
1058 root,
1059 node_type,
1060 child_offsets,
1061 child_indices,
1062 lit,
1063 decision_var,
1064 decision_child_false,
1065 decision_child_true,
1066 level_nodes,
1067 levels,
1068 var_log_true,
1069 var_log_false,
1070 expected_values,
1071 expected_grad_true,
1072 expected_grad_false,
1073 }
1074}
1075
1076pub fn gen_large_or_circuit(num_vars: usize) -> TinyXgcfSpec {
1078 const LIT: u8 = 2;
1079 const OR: u8 = 4;
1080
1081 let num_nodes = num_vars + 1;
1082 let root = num_vars as u32;
1083
1084 let mut node_type = vec![LIT; num_vars];
1085 node_type.push(OR);
1086
1087 let mut child_offsets: Vec<u32> = (0..=num_vars).map(|_| 0).collect();
1088 child_offsets.push(num_vars as u32);
1089
1090 let child_indices: Vec<u32> = (0..num_vars as u32).collect();
1091
1092 let mut lit: Vec<i32> = (1..=num_vars as i32).collect();
1093 lit.push(0);
1094
1095 let decision_var = vec![0; num_nodes];
1096 let decision_child_false = vec![0; num_nodes];
1097 let decision_child_true = vec![0; num_nodes];
1098
1099 let mut level_nodes: Vec<u32> = (0..num_vars as u32).collect();
1100 level_nodes.push(root);
1101 let levels = vec![(0, num_vars as u32), (num_vars as u32, 1)];
1102
1103 let p = 0.5_f64;
1104 let mut var_log_true = vec![0.0; num_vars + 1];
1105 let mut var_log_false = vec![0.0; num_vars + 1];
1106 for i in 1..=num_vars {
1107 var_log_true[i] = p.ln();
1108 var_log_false[i] = (1.0 - p).ln();
1109 }
1110
1111 let lit_val = p.ln();
1112 let mut expected_values = vec![lit_val; num_vars];
1113 let or_val = lit_val + (num_vars as f64).ln();
1114 expected_values.push(or_val);
1115
1116 let grad_per_lit = 1.0 / num_vars as f64;
1117 let mut expected_grad_true = vec![0.0; num_vars + 1];
1118 for i in 1..=num_vars {
1119 expected_grad_true[i] = grad_per_lit;
1120 }
1121 let expected_grad_false = vec![0.0; num_vars + 1];
1122
1123 TinyXgcfSpec {
1124 num_nodes,
1125 num_vars,
1126 root,
1127 node_type,
1128 child_offsets,
1129 child_indices,
1130 lit,
1131 decision_var,
1132 decision_child_false,
1133 decision_child_true,
1134 level_nodes,
1135 levels,
1136 var_log_true,
1137 var_log_false,
1138 expected_values,
1139 expected_grad_true,
1140 expected_grad_false,
1141 }
1142}
1143
1144pub fn gen_deep_chain_circuit(depth: usize) -> TinyXgcfSpec {
1146 const LIT: u8 = 2;
1147 const AND: u8 = 3;
1148
1149 let num_nodes = depth + 1;
1150 let num_vars = 1;
1151 let root = depth as u32;
1152
1153 let mut node_type = vec![LIT];
1154 for _ in 0..depth {
1155 node_type.push(AND);
1156 }
1157
1158 let mut child_offsets: Vec<u32> = vec![0];
1159 let mut child_indices: Vec<u32> = vec![];
1160 for i in 0..depth {
1161 child_offsets.push(child_indices.len() as u32);
1162 child_indices.push(i as u32);
1163 }
1164 child_offsets.push(child_indices.len() as u32);
1165
1166 let mut lit = vec![1i32];
1167 lit.extend(vec![0i32; depth]);
1168
1169 let decision_var = vec![0; num_nodes];
1170 let decision_child_false = vec![0; num_nodes];
1171 let decision_child_true = vec![0; num_nodes];
1172
1173 let level_nodes: Vec<u32> = (0..num_nodes as u32).collect();
1174 let levels: Vec<(u32, u32)> = (0..num_nodes).map(|i| (i as u32, 1)).collect();
1175
1176 let p = 0.7_f64;
1177 let var_log_true = vec![0.0, p.ln()];
1178 let var_log_false = vec![0.0, (1.0 - p).ln()];
1179
1180 let lit_val = p.ln();
1181 let expected_values = vec![lit_val; num_nodes];
1182
1183 let expected_grad_true = vec![0.0, 1.0];
1184 let expected_grad_false = vec![0.0, 0.0];
1185
1186 TinyXgcfSpec {
1187 num_nodes,
1188 num_vars,
1189 root,
1190 node_type,
1191 child_offsets,
1192 child_indices,
1193 lit,
1194 decision_var,
1195 decision_child_false,
1196 decision_child_true,
1197 level_nodes,
1198 levels,
1199 var_log_true,
1200 var_log_false,
1201 expected_values,
1202 expected_grad_true,
1203 expected_grad_false,
1204 }
1205}
1206
1207pub fn numerical_gradient(
1209 ctx: &TestContext,
1210 spec: &TinyXgcfSpec,
1211 var: usize,
1212 eps: f64,
1213) -> xlog_core::Result<(f64, f64)> {
1214 let mut spec_plus = spec.clone();
1215 let mut spec_minus = spec.clone();
1216 spec_plus.var_log_true[var] += eps;
1217 spec_minus.var_log_true[var] -= eps;
1218
1219 let values_plus = run_tiny_xgcf_forward(ctx, &spec_plus)?;
1220 let values_minus = run_tiny_xgcf_forward(ctx, &spec_minus)?;
1221
1222 let grad_true =
1223 (values_plus[spec.root as usize] - values_minus[spec.root as usize]) / (2.0 * eps);
1224
1225 let mut spec_plus = spec.clone();
1226 let mut spec_minus = spec.clone();
1227 spec_plus.var_log_false[var] += eps;
1228 spec_minus.var_log_false[var] -= eps;
1229
1230 let values_plus = run_tiny_xgcf_forward(ctx, &spec_plus)?;
1231 let values_minus = run_tiny_xgcf_forward(ctx, &spec_minus)?;
1232
1233 let grad_false =
1234 (values_plus[spec.root as usize] - values_minus[spec.root as usize]) / (2.0 * eps);
1235
1236 Ok((grad_true, grad_false))
1237}