Skip to main content

xlog_cuda/provider/
probabilistic.rs

1//! Probabilistic operations: Monte Carlo sampling (Bernoulli matrix).
2
3use crate::{CudaView, LaunchAsync, LaunchConfig};
4use xlog_core::{Result, XlogError};
5
6use super::{mc_sample_kernels, MC_SAMPLE_MODULE};
7use crate::memory::TrackedCudaSlice;
8
9impl super::CudaKernelProvider {
10    /// Sample independent Bernoulli variables on the GPU.
11    ///
12    /// Returns a row-major `(sample, var)` matrix as a flat `Vec<u8>` of length
13    /// `num_samples * probs.len()`, where each entry is 0/1.
14    pub fn sample_bernoulli_matrix(
15        &self,
16        probs: &[f32],
17        num_samples: usize,
18        seed: u64,
19        force_mask: &CudaView<u8>,
20        forced_value: &CudaView<u8>,
21    ) -> Result<Vec<u8>> {
22        if probs.is_empty() || num_samples == 0 {
23            return Ok(Vec::new());
24        }
25
26        let num_vars_u32: u32 = probs.len().try_into().map_err(|_| {
27            XlogError::Kernel(format!(
28                "sample_bernoulli_matrix: num_vars {} exceeds u32::MAX",
29                probs.len()
30            ))
31        })?;
32        let num_samples_u32: u32 = num_samples.try_into().map_err(|_| {
33            XlogError::Kernel(format!(
34                "sample_bernoulli_matrix: num_samples {} exceeds u32::MAX",
35                num_samples
36            ))
37        })?;
38
39        let total = probs.len().checked_mul(num_samples).ok_or_else(|| {
40            XlogError::Kernel("sample_bernoulli_matrix: size overflow".to_string())
41        })?;
42
43        let device = self.device.inner();
44
45        let mut d_probs = self.memory.alloc::<f32>(probs.len())?;
46        self.htod_sync_copy_into_tracked(probs, &mut d_probs)
47            .map_err(|e| XlogError::Kernel(format!("Failed to upload Bernoulli probs: {}", e)))?;
48
49        let mut d_out = self.memory.alloc::<u8>(total)?;
50
51        let block_size = 256u32;
52        let total_u32: u32 = total.try_into().map_err(|_| {
53            XlogError::Kernel(format!(
54                "sample_bernoulli_matrix: total {} exceeds u32::MAX",
55                total
56            ))
57        })?;
58        let num_blocks = total_u32.div_ceil(block_size);
59        let config = LaunchConfig {
60            grid_dim: (num_blocks, 1, 1),
61            block_dim: (block_size, 1, 1),
62            shared_mem_bytes: 0,
63        };
64
65        let kernel = device
66            .get_func(MC_SAMPLE_MODULE, mc_sample_kernels::MC_SAMPLE_BERNOULLI)
67            .ok_or_else(|| XlogError::Kernel("mc_sample_bernoulli kernel not found".to_string()))?;
68
69        // SAFETY: mc_sample_bernoulli(out, probs, force_mask, forced_value, num_vars, num_samples, seed)
70        unsafe {
71            kernel.clone().launch(
72                config,
73                (
74                    &mut d_out,
75                    &d_probs,
76                    force_mask,
77                    forced_value,
78                    num_vars_u32,
79                    num_samples_u32,
80                    seed,
81                ),
82            )
83        }
84        .map_err(|e| XlogError::Kernel(format!("Failed to launch mc_sample_bernoulli: {}", e)))?;
85
86        let mut host: Vec<u8> = vec![0u8; total];
87        device.dtoh_sync_copy_into(&d_out, &mut host).map_err(|e| {
88            XlogError::Kernel(format!("Failed to download Bernoulli samples: {}", e))
89        })?;
90
91        Ok(host)
92    }
93
94    /// Sample Bernoulli matrix on GPU and return device-resident output.
95    ///
96    /// Returns a row-major [num_samples][num_vars] matrix of 0/1 bytes on device.
97    pub fn sample_bernoulli_matrix_device(
98        &self,
99        probs: &[f32],
100        num_samples: usize,
101        seed: u64,
102        force_mask: &CudaView<u8>,
103        forced_value: &CudaView<u8>,
104    ) -> Result<TrackedCudaSlice<u8>> {
105        if probs.is_empty() || num_samples == 0 {
106            return self.memory.alloc::<u8>(0).map_err(|e| {
107                XlogError::Kernel(format!("Failed to allocate empty sample matrix: {}", e))
108            });
109        }
110
111        let num_vars_u32: u32 = probs.len().try_into().map_err(|_| {
112            XlogError::Kernel(format!(
113                "sample_bernoulli_matrix_device: num_vars {} exceeds u32::MAX",
114                probs.len()
115            ))
116        })?;
117        let num_samples_u32: u32 = num_samples.try_into().map_err(|_| {
118            XlogError::Kernel(format!(
119                "sample_bernoulli_matrix_device: num_samples {} exceeds u32::MAX",
120                num_samples
121            ))
122        })?;
123
124        let total = probs.len().saturating_mul(num_samples);
125        let device = self.device.inner();
126
127        let mut d_probs = self.memory.alloc::<f32>(probs.len())?;
128        self.htod_sync_copy_into_tracked(probs, &mut d_probs)
129            .map_err(|e| XlogError::Kernel(format!("Failed to upload Bernoulli probs: {}", e)))?;
130
131        let mut d_out = self.memory.alloc::<u8>(total)?;
132
133        let block_size = 256u32;
134        let total_u32: u32 = total.try_into().map_err(|_| {
135            XlogError::Kernel(format!(
136                "sample_bernoulli_matrix_device: total {} exceeds u32::MAX",
137                total
138            ))
139        })?;
140        let num_blocks = total_u32.div_ceil(block_size);
141        let config = LaunchConfig {
142            grid_dim: (num_blocks, 1, 1),
143            block_dim: (block_size, 1, 1),
144            shared_mem_bytes: 0,
145        };
146
147        let kernel = device
148            .get_func(MC_SAMPLE_MODULE, mc_sample_kernels::MC_SAMPLE_BERNOULLI)
149            .ok_or_else(|| XlogError::Kernel("mc_sample_bernoulli kernel not found".to_string()))?;
150
151        // SAFETY: mc_sample_bernoulli(out, probs, force_mask, forced_value, num_vars, num_samples, seed)
152        unsafe {
153            kernel.clone().launch(
154                config,
155                (
156                    &mut d_out,
157                    &d_probs,
158                    force_mask,
159                    forced_value,
160                    num_vars_u32,
161                    num_samples_u32,
162                    seed,
163                ),
164            )
165        }
166        .map_err(|e| XlogError::Kernel(format!("Failed to launch mc_sample_bernoulli: {}", e)))?;
167
168        Ok(d_out)
169    }
170}