xlog_prob/compilation/
gpu_pir.rs1use std::sync::Arc;
6
7use cudarc::driver::DeviceSlice;
8use xlog_core::{Result, XlogError};
9use xlog_cuda::memory::TrackedCudaSlice;
10use xlog_cuda::CudaKernelProvider;
11
12use crate::pir::{PirGraph, PirNode, PirNodeId};
13
// Node-type tags stored in `GpuPirGraph::node_type`, one per node.
//
// All six tags share one visibility: `node_type` is a public field, so any
// code able to read it needs every tag to decode it.
// NOTE(review): device code presumably relies on these exact values — confirm
// before renumbering.

/// Tag written for `PirNode::Const`.
pub const PIR_CONST: u8 = 0;
/// Tag written for `PirNode::Lit`.
pub const PIR_LIT: u8 = 1;
/// Tag written for `PirNode::NegLit`.
pub const PIR_NEG_LIT: u8 = 2;
/// Tag written for `PirNode::And`.
pub const PIR_AND: u8 = 3;
/// Tag written for `PirNode::Or`.
pub const PIR_OR: u8 = 4;
/// Tag written for `PirNode::Decision`.
pub const PIR_DECISION: u8 = 5;
21
/// Flattened PIR graph held in tracked device buffers.
///
/// Built by [`GpuPirGraph::from_host`]. Every per-node array has one entry per
/// node, in the same order as the host graph's node list.
pub struct GpuPirGraph {
    /// One `PIR_*` tag per node.
    pub node_type: TrackedCudaSlice<u8>,
    /// `num_nodes + 1` running offsets into `children`: entries `i` and `i + 1`
    /// bracket the child list of node `i`. Starts at 0.
    pub child_offsets: TrackedCudaSlice<u32>,
    /// Concatenated child ids of the `And` / `Or` nodes. Other node kinds
    /// contribute no entries.
    pub children: TrackedCudaSlice<u32>,
    /// Leaf id for `Lit` / `NegLit` nodes, the constant converted with
    /// `u32::from` for `Const` nodes, and 0 for every other kind.
    pub leaf_id: TrackedCudaSlice<u32>,
    /// Decision variable for `Decision` nodes, 0 for every other kind.
    pub decision_var: TrackedCudaSlice<u32>,
    /// `child_false` id for `Decision` nodes, 0 for every other kind.
    pub decision_child_false: TrackedCudaSlice<u32>,
    /// `child_true` id for `Decision` nodes, 0 for every other kind.
    pub decision_child_true: TrackedCudaSlice<u32>,
}
32
/// Root node ids held in a tracked device buffer.
///
/// Built by [`GpuPirRoots::from_host`].
pub struct GpuPirRoots {
    /// One `u32` node id per root, in the order given to `from_host`.
    pub roots: TrackedCudaSlice<u32>,
}
37
38impl GpuPirGraph {
39 pub fn from_host(pir: &PirGraph, provider: &Arc<CudaKernelProvider>) -> Result<Self> {
44 let num_nodes = pir.len();
45 let num_nodes_u32 = u32::try_from(num_nodes).map_err(|_| {
46 XlogError::Compilation("GpuPirGraph::from_host: node count overflow".to_string())
47 })?;
48
49 let mut node_type: Vec<u8> = Vec::with_capacity(num_nodes);
50 let mut child_offsets: Vec<u32> = Vec::with_capacity(num_nodes + 1);
51 let mut children: Vec<u32> = Vec::new();
52 let mut leaf_id: Vec<u32> = Vec::with_capacity(num_nodes);
53 let mut decision_var: Vec<u32> = Vec::with_capacity(num_nodes);
54 let mut decision_child_false: Vec<u32> = Vec::with_capacity(num_nodes);
55 let mut decision_child_true: Vec<u32> = Vec::with_capacity(num_nodes);
56
57 child_offsets.push(0);
58
59 for (idx, node) in pir.nodes().iter().enumerate() {
60 let node_id = u32::try_from(idx).map_err(|_| {
61 XlogError::Compilation("GpuPirGraph::from_host: node id overflow".to_string())
62 })?;
63
64 match node {
65 PirNode::Const(value) => {
66 node_type.push(PIR_CONST);
67 leaf_id.push(u32::from(*value));
68 decision_var.push(0);
69 decision_child_false.push(0);
70 decision_child_true.push(0);
71 }
72 PirNode::Lit { leaf } => {
73 node_type.push(PIR_LIT);
74 leaf_id.push(leaf.as_u32());
75 decision_var.push(0);
76 decision_child_false.push(0);
77 decision_child_true.push(0);
78 }
79 PirNode::NegLit { leaf } => {
80 node_type.push(PIR_NEG_LIT);
81 leaf_id.push(leaf.as_u32());
82 decision_var.push(0);
83 decision_child_false.push(0);
84 decision_child_true.push(0);
85 }
86 PirNode::And { children: kids } => {
87 validate_children_sorted(node_id, kids, num_nodes_u32)?;
88 node_type.push(PIR_AND);
89 leaf_id.push(0);
90 decision_var.push(0);
91 decision_child_false.push(0);
92 decision_child_true.push(0);
93 for &child in kids {
94 children.push(child.as_u32());
95 }
96 }
97 PirNode::Or { children: kids } => {
98 validate_children_sorted(node_id, kids, num_nodes_u32)?;
99 node_type.push(PIR_OR);
100 leaf_id.push(0);
101 decision_var.push(0);
102 decision_child_false.push(0);
103 decision_child_true.push(0);
104 for &child in kids {
105 children.push(child.as_u32());
106 }
107 }
108 PirNode::Decision {
109 var,
110 child_false,
111 child_true,
112 } => {
113 validate_child_id(node_id, *child_false, num_nodes_u32)?;
114 validate_child_id(node_id, *child_true, num_nodes_u32)?;
115 node_type.push(PIR_DECISION);
116 leaf_id.push(0);
117 decision_var.push(var.as_u32());
118 decision_child_false.push(child_false.as_u32());
119 decision_child_true.push(child_true.as_u32());
120 }
121 }
122
123 let next_off = u32::try_from(children.len()).map_err(|_| {
124 XlogError::Compilation(
125 "GpuPirGraph::from_host: children count exceeds u32".to_string(),
126 )
127 })?;
128 child_offsets.push(next_off);
129 }
130
131 if child_offsets.len() != num_nodes + 1 {
132 return Err(XlogError::Compilation(
133 "GpuPirGraph::from_host: child_offsets length mismatch".to_string(),
134 ));
135 }
136
137 let memory = provider.memory();
138
139 let mut d_node_type = memory.alloc::<u8>(node_type.len())?;
140 let mut d_child_offsets = memory.alloc::<u32>(child_offsets.len())?;
141 let mut d_children = memory.alloc::<u32>(children.len())?;
142 let mut d_leaf_id = memory.alloc::<u32>(leaf_id.len())?;
143 let mut d_decision_var = memory.alloc::<u32>(decision_var.len())?;
144 let mut d_decision_child_false = memory.alloc::<u32>(decision_child_false.len())?;
145 let mut d_decision_child_true = memory.alloc::<u32>(decision_child_true.len())?;
146
147 provider
148 .htod_sync_copy_into_tracked(&node_type, &mut d_node_type)
149 .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload node_type: {}", e)))?;
150 provider
151 .htod_sync_copy_into_tracked(&child_offsets, &mut d_child_offsets)
152 .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload child_offsets: {}", e)))?;
153 provider
154 .htod_sync_copy_into_tracked(&children, &mut d_children)
155 .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload children: {}", e)))?;
156 provider
157 .htod_sync_copy_into_tracked(&leaf_id, &mut d_leaf_id)
158 .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload leaf_id: {}", e)))?;
159 provider
160 .htod_sync_copy_into_tracked(&decision_var, &mut d_decision_var)
161 .map_err(|e| XlogError::Kernel(format!("GpuPirGraph upload decision_var: {}", e)))?;
162 provider
163 .htod_sync_copy_into_tracked(&decision_child_false, &mut d_decision_child_false)
164 .map_err(|e| {
165 XlogError::Kernel(format!("GpuPirGraph upload decision_child_false: {}", e))
166 })?;
167 provider
168 .htod_sync_copy_into_tracked(&decision_child_true, &mut d_decision_child_true)
169 .map_err(|e| {
170 XlogError::Kernel(format!("GpuPirGraph upload decision_child_true: {}", e))
171 })?;
172
173 Ok(Self {
174 node_type: d_node_type,
175 child_offsets: d_child_offsets,
176 children: d_children,
177 leaf_id: d_leaf_id,
178 decision_var: d_decision_var,
179 decision_child_false: d_decision_child_false,
180 decision_child_true: d_decision_child_true,
181 })
182 }
183
184 pub fn num_nodes(&self) -> usize {
185 self.node_type.len()
186 }
187}
188
189impl GpuPirRoots {
190 pub fn from_host(roots: &[PirNodeId], provider: &Arc<CudaKernelProvider>) -> Result<Self> {
191 let mut host: Vec<u32> = Vec::with_capacity(roots.len());
192 for &r in roots {
193 host.push(r.as_u32());
194 }
195
196 let memory = provider.memory();
197 let mut d_roots = memory.alloc::<u32>(host.len())?;
198 provider
199 .htod_sync_copy_into_tracked(&host, &mut d_roots)
200 .map_err(|e| XlogError::Kernel(format!("GpuPirRoots upload: {}", e)))?;
201
202 Ok(Self { roots: d_roots })
203 }
204}
205
206fn validate_child_id(parent: u32, child: PirNodeId, num_nodes: u32) -> Result<()> {
207 let id = child.as_u32();
208 if id >= num_nodes {
209 return Err(XlogError::Compilation(format!(
210 "GpuPirGraph::from_host: child {:?} out of bounds for parent {}",
211 child, parent
212 )));
213 }
214 Ok(())
215}
216
217fn validate_children_sorted(parent: u32, children: &[PirNodeId], num_nodes: u32) -> Result<()> {
218 let mut prev: Option<u32> = None;
219 for &child in children {
220 let id = child.as_u32();
221 if id >= num_nodes {
222 return Err(XlogError::Compilation(format!(
223 "GpuPirGraph::from_host: child {:?} out of bounds for parent {}",
224 child, parent
225 )));
226 }
227 if let Some(p) = prev {
228 if id <= p {
229 return Err(XlogError::Compilation(format!(
230 "GpuPirGraph::from_host: children of {} must be sorted and unique",
231 parent
232 )));
233 }
234 }
235 prev = Some(id);
236 }
237 Ok(())
238}