xlog_cuda/provider/
probabilistic.rs1use crate::{CudaView, LaunchAsync, LaunchConfig};
4use xlog_core::{Result, XlogError};
5
6use super::{mc_sample_kernels, MC_SAMPLE_MODULE};
7use crate::memory::TrackedCudaSlice;
8
9impl super::CudaKernelProvider {
10 pub fn sample_bernoulli_matrix(
15 &self,
16 probs: &[f32],
17 num_samples: usize,
18 seed: u64,
19 force_mask: &CudaView<u8>,
20 forced_value: &CudaView<u8>,
21 ) -> Result<Vec<u8>> {
22 if probs.is_empty() || num_samples == 0 {
23 return Ok(Vec::new());
24 }
25
26 let num_vars_u32: u32 = probs.len().try_into().map_err(|_| {
27 XlogError::Kernel(format!(
28 "sample_bernoulli_matrix: num_vars {} exceeds u32::MAX",
29 probs.len()
30 ))
31 })?;
32 let num_samples_u32: u32 = num_samples.try_into().map_err(|_| {
33 XlogError::Kernel(format!(
34 "sample_bernoulli_matrix: num_samples {} exceeds u32::MAX",
35 num_samples
36 ))
37 })?;
38
39 let total = probs.len().checked_mul(num_samples).ok_or_else(|| {
40 XlogError::Kernel("sample_bernoulli_matrix: size overflow".to_string())
41 })?;
42
43 let device = self.device.inner();
44
45 let mut d_probs = self.memory.alloc::<f32>(probs.len())?;
46 self.htod_sync_copy_into_tracked(probs, &mut d_probs)
47 .map_err(|e| XlogError::Kernel(format!("Failed to upload Bernoulli probs: {}", e)))?;
48
49 let mut d_out = self.memory.alloc::<u8>(total)?;
50
51 let block_size = 256u32;
52 let total_u32: u32 = total.try_into().map_err(|_| {
53 XlogError::Kernel(format!(
54 "sample_bernoulli_matrix: total {} exceeds u32::MAX",
55 total
56 ))
57 })?;
58 let num_blocks = total_u32.div_ceil(block_size);
59 let config = LaunchConfig {
60 grid_dim: (num_blocks, 1, 1),
61 block_dim: (block_size, 1, 1),
62 shared_mem_bytes: 0,
63 };
64
65 let kernel = device
66 .get_func(MC_SAMPLE_MODULE, mc_sample_kernels::MC_SAMPLE_BERNOULLI)
67 .ok_or_else(|| XlogError::Kernel("mc_sample_bernoulli kernel not found".to_string()))?;
68
69 unsafe {
71 kernel.clone().launch(
72 config,
73 (
74 &mut d_out,
75 &d_probs,
76 force_mask,
77 forced_value,
78 num_vars_u32,
79 num_samples_u32,
80 seed,
81 ),
82 )
83 }
84 .map_err(|e| XlogError::Kernel(format!("Failed to launch mc_sample_bernoulli: {}", e)))?;
85
86 let mut host: Vec<u8> = vec![0u8; total];
87 device.dtoh_sync_copy_into(&d_out, &mut host).map_err(|e| {
88 XlogError::Kernel(format!("Failed to download Bernoulli samples: {}", e))
89 })?;
90
91 Ok(host)
92 }
93
94 pub fn sample_bernoulli_matrix_device(
98 &self,
99 probs: &[f32],
100 num_samples: usize,
101 seed: u64,
102 force_mask: &CudaView<u8>,
103 forced_value: &CudaView<u8>,
104 ) -> Result<TrackedCudaSlice<u8>> {
105 if probs.is_empty() || num_samples == 0 {
106 return self.memory.alloc::<u8>(0).map_err(|e| {
107 XlogError::Kernel(format!("Failed to allocate empty sample matrix: {}", e))
108 });
109 }
110
111 let num_vars_u32: u32 = probs.len().try_into().map_err(|_| {
112 XlogError::Kernel(format!(
113 "sample_bernoulli_matrix_device: num_vars {} exceeds u32::MAX",
114 probs.len()
115 ))
116 })?;
117 let num_samples_u32: u32 = num_samples.try_into().map_err(|_| {
118 XlogError::Kernel(format!(
119 "sample_bernoulli_matrix_device: num_samples {} exceeds u32::MAX",
120 num_samples
121 ))
122 })?;
123
124 let total = probs.len().saturating_mul(num_samples);
125 let device = self.device.inner();
126
127 let mut d_probs = self.memory.alloc::<f32>(probs.len())?;
128 self.htod_sync_copy_into_tracked(probs, &mut d_probs)
129 .map_err(|e| XlogError::Kernel(format!("Failed to upload Bernoulli probs: {}", e)))?;
130
131 let mut d_out = self.memory.alloc::<u8>(total)?;
132
133 let block_size = 256u32;
134 let total_u32: u32 = total.try_into().map_err(|_| {
135 XlogError::Kernel(format!(
136 "sample_bernoulli_matrix_device: total {} exceeds u32::MAX",
137 total
138 ))
139 })?;
140 let num_blocks = total_u32.div_ceil(block_size);
141 let config = LaunchConfig {
142 grid_dim: (num_blocks, 1, 1),
143 block_dim: (block_size, 1, 1),
144 shared_mem_bytes: 0,
145 };
146
147 let kernel = device
148 .get_func(MC_SAMPLE_MODULE, mc_sample_kernels::MC_SAMPLE_BERNOULLI)
149 .ok_or_else(|| XlogError::Kernel("mc_sample_bernoulli kernel not found".to_string()))?;
150
151 unsafe {
153 kernel.clone().launch(
154 config,
155 (
156 &mut d_out,
157 &d_probs,
158 force_mask,
159 forced_value,
160 num_vars_u32,
161 num_samples_u32,
162 seed,
163 ),
164 )
165 }
166 .map_err(|e| XlogError::Kernel(format!("Failed to launch mc_sample_bernoulli: {}", e)))?;
167
168 Ok(d_out)
169 }
170}