1use std::{mem, ptr};
9
10use cudarc::driver::{sys, CudaStream};
11use xlog_core::{Result, XlogError};
12
/// Version tag for the node layout of captured CSM CUDA graphs.
///
/// Stored in every [`CsmCudaGraphKey`] built by `CsmCudaGraphKey::inner`. Bumping it makes
/// keys produced by the new layout compare unequal to keys produced by the old one.
pub const CSM_CUDA_GRAPH_NODE_LAYOUT_VERSION: u32 = 1;
14
15pub struct CapturedCudaGraph {
17 graph: sys::CUgraph,
18 exec: sys::CUgraphExec,
19}
20
21unsafe impl Send for CapturedCudaGraph {}
24unsafe impl Sync for CapturedCudaGraph {}
25
/// Crate-side mirror of the driver's `CUgraphNodeType`, one variant per node type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CudaGraphNodeKind {
    /// Kernel launch node.
    Kernel,
    /// Memory copy node.
    Memcpy,
    /// Memory set node.
    Memset,
    /// Host callback node.
    Host,
    /// Child graph node.
    Graph,
    /// Empty (dependency-only) node.
    Empty,
    /// Event wait node.
    WaitEvent,
    /// Event record node.
    EventRecord,
    /// External semaphore signal node.
    ExternalSemaphoresSignal,
    /// External semaphore wait node.
    ExternalSemaphoresWait,
    /// Memory allocation node.
    MemAlloc,
    /// Memory free node.
    MemFree,
    /// Batched memory operation node.
    BatchMemOp,
    /// Conditional node.
    Conditional,
}
43
44#[derive(Debug, Clone, Copy)]
45pub struct CudaGraphNode {
46 pub index: usize,
47 pub raw: sys::CUgraphNode,
48 pub kind: CudaGraphNodeKind,
49}
50
51unsafe impl Send for CudaGraphNode {}
52unsafe impl Sync for CudaGraphNode {}
53
/// Which join variant a cached CSM CUDA graph was captured for.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CsmCudaGraphJoinKind {
    /// Plain inner join (the variant built by `CsmCudaGraphKey::inner`).
    Inner,
    /// Indexed inner join.
    IndexedInner,
}
59
/// Shape of the recursive multi-block scan, as computed by `scan_topology_u32`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ScanTopology {
    /// Element count the topology was computed for.
    pub input_len: u32,
    /// Elements handled per block at every level.
    pub block_size: u32,
    /// Length of each successive level's scratch array, outermost first.
    pub scratch_lengths: Vec<u32>,
    /// Number of kernel nodes the scan is expected to contribute to the graph.
    pub kernel_node_count: usize,
}
67
68#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69pub struct CsmCudaGraphKey {
70 pub join_kind: CsmCudaGraphJoinKind,
71 pub key_arity: u8,
72 pub key_bytes: u32,
73 pub probe_capacity_class: u32,
74 pub output_capacity_class: u32,
75 pub scan_topology: ScanTopology,
76 pub node_layout_version: u32,
77}
78
79impl CsmCudaGraphKey {
80 pub fn inner(
81 key_arity: usize,
82 key_bytes: u32,
83 probe_capacity: u32,
84 output_capacity: u32,
85 ) -> Result<Self> {
86 let key_arity = u8::try_from(key_arity).map_err(|_| {
87 XlogError::Kernel(format!(
88 "CSM CUDA Graph key arity {} exceeds u8::MAX",
89 key_arity
90 ))
91 })?;
92 Ok(Self {
93 join_kind: CsmCudaGraphJoinKind::Inner,
94 key_arity,
95 key_bytes,
96 probe_capacity_class: graph_capacity_class_u32(probe_capacity),
97 output_capacity_class: graph_capacity_class_u32(output_capacity),
98 scan_topology: scan_topology_u32(probe_capacity),
99 node_layout_version: CSM_CUDA_GRAPH_NODE_LAYOUT_VERSION,
100 })
101 }
102}
103
/// Rounds `n` up to its capacity class: the next power of two, with both 0 and 1 mapping to 1.
///
/// Inputs above `2^31` have no `u32` power of two at or above them and saturate to `u32::MAX`.
pub fn graph_capacity_class_u32(n: u32) -> u32 {
    // `0u32.checked_next_power_of_two()` is `Some(1)`, so 0 and 1 need no special case.
    n.checked_next_power_of_two().unwrap_or(u32::MAX)
}
111
112pub fn scan_topology_u32(mut n: u32) -> ScanTopology {
113 let input_len = n;
114 let block_size = 256u32;
115 let mut scratch_lengths = Vec::new();
116 let mut kernel_node_count = if n == 0 { 0 } else { 1 };
117 while n > block_size {
118 let num_blocks = n.div_ceil(block_size);
119 scratch_lengths.push(num_blocks);
120 kernel_node_count += 2;
121 n = num_blocks;
122 }
123 ScanTopology {
124 input_len,
125 block_size,
126 scratch_lengths,
127 kernel_node_count,
128 }
129}
130
131impl CapturedCudaGraph {
132 pub fn capture_on_stream<F>(stream: &CudaStream, record: F) -> Result<Self>
135 where
136 F: FnOnce() -> Result<()>,
137 {
138 unsafe {
139 cuda_graph_check(
140 "cuStreamBeginCapture_v2",
141 sys::cuStreamBeginCapture_v2(
142 stream.cu_stream(),
143 sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
144 ),
145 )?;
146 }
147
148 let record_result = record();
149 let mut graph: sys::CUgraph = ptr::null_mut();
150 let end_result = unsafe {
151 cuda_graph_check(
152 "cuStreamEndCapture",
153 sys::cuStreamEndCapture(stream.cu_stream(), &mut graph),
154 )
155 };
156
157 if let Err(record_err) = record_result {
158 if end_result.is_ok() && !graph.is_null() {
159 unsafe {
160 let _ = sys::cuGraphDestroy(graph);
161 }
162 }
163 return Err(record_err);
164 }
165 end_result?;
166 if graph.is_null() {
167 return Err(XlogError::Kernel(
168 "cuStreamEndCapture returned a null CUDA graph".to_string(),
169 ));
170 }
171
172 let mut exec: sys::CUgraphExec = ptr::null_mut();
173 unsafe {
174 if let Err(err) = cuda_graph_check(
175 "cuGraphInstantiateWithFlags",
176 sys::cuGraphInstantiateWithFlags(&mut exec, graph, 0),
177 ) {
178 let _ = sys::cuGraphDestroy(graph);
179 return Err(err);
180 }
181 }
182 if exec.is_null() {
183 unsafe {
184 let _ = sys::cuGraphDestroy(graph);
185 }
186 return Err(XlogError::Kernel(
187 "cuGraphInstantiateWithFlags returned a null CUDA graph exec".to_string(),
188 ));
189 }
190
191 Ok(Self { graph, exec })
192 }
193
194 pub fn launch(&self, stream: &CudaStream) -> Result<()> {
196 unsafe {
197 cuda_graph_check(
198 "cuGraphLaunch",
199 sys::cuGraphLaunch(self.exec, stream.cu_stream()),
200 )
201 }
202 }
203
204 pub fn node_count(&self) -> Result<usize> {
207 let mut count = 0usize;
208 unsafe {
209 cuda_graph_check(
210 "cuGraphGetNodes(count)",
211 sys::cuGraphGetNodes(self.graph, ptr::null_mut(), &mut count),
212 )?;
213 }
214 Ok(count)
215 }
216
217 pub fn nodes(&self) -> Result<Vec<CudaGraphNode>> {
219 let count = self.node_count()?;
220 if count == 0 {
221 return Ok(Vec::new());
222 }
223 let mut raw_nodes = vec![ptr::null_mut(); count];
224 let mut count_again = count;
225 unsafe {
226 cuda_graph_check(
227 "cuGraphGetNodes(nodes)",
228 sys::cuGraphGetNodes(self.graph, raw_nodes.as_mut_ptr(), &mut count_again),
229 )?;
230 }
231 raw_nodes.truncate(count_again);
232
233 let mut nodes = Vec::with_capacity(raw_nodes.len());
234 for (index, raw) in raw_nodes.into_iter().enumerate() {
235 let mut ty = sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY;
236 unsafe {
237 cuda_graph_check("cuGraphNodeGetType", sys::cuGraphNodeGetType(raw, &mut ty))?;
238 }
239 nodes.push(CudaGraphNode {
240 index,
241 raw,
242 kind: CudaGraphNodeKind::from_sys(ty),
243 });
244 }
245 Ok(nodes)
246 }
247
248 pub fn kernel_node_params(&self, node: CudaGraphNode) -> Result<sys::CUDA_KERNEL_NODE_PARAMS> {
254 if node.kind != CudaGraphNodeKind::Kernel {
255 return Err(XlogError::Kernel(format!(
256 "kernel_node_params called for non-kernel graph node {:?}",
257 node.kind
258 )));
259 }
260 let mut params: sys::CUDA_KERNEL_NODE_PARAMS = unsafe { mem::zeroed() };
261 unsafe {
262 cuda_graph_check(
263 "cuGraphKernelNodeGetParams_v2",
264 sys::cuGraphKernelNodeGetParams_v2(node.raw, &mut params),
265 )?;
266 }
267 Ok(params)
268 }
269
270 pub unsafe fn set_kernel_node_params(
277 &self,
278 node: CudaGraphNode,
279 params: &sys::CUDA_KERNEL_NODE_PARAMS,
280 ) -> Result<()> {
281 if node.kind != CudaGraphNodeKind::Kernel {
282 return Err(XlogError::Kernel(format!(
283 "set_kernel_node_params called for non-kernel graph node {:?}",
284 node.kind
285 )));
286 }
287 cuda_graph_check(
288 "cuGraphExecKernelNodeSetParams_v2",
289 sys::cuGraphExecKernelNodeSetParams_v2(self.exec, node.raw, params),
290 )
291 }
292
293 pub fn memset_node_params(&self, node: CudaGraphNode) -> Result<sys::CUDA_MEMSET_NODE_PARAMS> {
295 if node.kind != CudaGraphNodeKind::Memset {
296 return Err(XlogError::Kernel(format!(
297 "memset_node_params called for non-memset graph node {:?}",
298 node.kind
299 )));
300 }
301 let mut params: sys::CUDA_MEMSET_NODE_PARAMS = unsafe { mem::zeroed() };
302 unsafe {
303 cuda_graph_check(
304 "cuGraphMemsetNodeGetParams",
305 sys::cuGraphMemsetNodeGetParams(node.raw, &mut params),
306 )?;
307 }
308 Ok(params)
309 }
310
311 pub fn set_memset_node_params(
313 &self,
314 node: CudaGraphNode,
315 params: &sys::CUDA_MEMSET_NODE_PARAMS,
316 stream: &CudaStream,
317 ) -> Result<()> {
318 if node.kind != CudaGraphNodeKind::Memset {
319 return Err(XlogError::Kernel(format!(
320 "set_memset_node_params called for non-memset graph node {:?}",
321 node.kind
322 )));
323 }
324 let ctx = stream_context(stream)?;
325 unsafe {
326 cuda_graph_check(
327 "cuGraphExecMemsetNodeSetParams",
328 sys::cuGraphExecMemsetNodeSetParams(self.exec, node.raw, params, ctx),
329 )
330 }
331 }
332
333 pub fn graph(&self) -> sys::CUgraph {
335 self.graph
336 }
337
338 pub fn exec(&self) -> sys::CUgraphExec {
340 self.exec
341 }
342}
343
344impl Drop for CapturedCudaGraph {
345 fn drop(&mut self) {
346 unsafe {
347 if !self.exec.is_null() {
348 let _ = sys::cuGraphExecDestroy(self.exec);
349 }
350 if !self.graph.is_null() {
351 let _ = sys::cuGraphDestroy(self.graph);
352 }
353 }
354 }
355}
356
357impl CudaGraphNodeKind {
358 fn from_sys(kind: sys::CUgraphNodeType) -> Self {
359 match kind {
360 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL => Self::Kernel,
361 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY => Self::Memcpy,
362 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET => Self::Memset,
363 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST => Self::Host,
364 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH => Self::Graph,
365 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY => Self::Empty,
366 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT => Self::WaitEvent,
367 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD => Self::EventRecord,
368 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL => {
369 Self::ExternalSemaphoresSignal
370 }
371 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT => Self::ExternalSemaphoresWait,
372 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC => Self::MemAlloc,
373 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE => Self::MemFree,
374 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_BATCH_MEM_OP => Self::BatchMemOp,
375 sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_CONDITIONAL => Self::Conditional,
376 }
377 }
378}
379
380fn cuda_graph_check(label: &str, code: sys::CUresult) -> Result<()> {
381 if code == sys::CUresult::CUDA_SUCCESS {
382 Ok(())
383 } else {
384 Err(XlogError::Kernel(format!("{label} failed: {code:?}")))
385 }
386}
387
388fn stream_context(stream: &CudaStream) -> Result<sys::CUcontext> {
389 let mut ctx = ptr::null_mut();
390 unsafe {
391 cuda_graph_check(
392 "cuStreamGetCtx",
393 sys::cuStreamGetCtx(stream.cu_stream(), &mut ctx),
394 )?;
395 }
396 if ctx.is_null() {
397 Err(XlogError::Kernel(
398 "cuStreamGetCtx returned a null CUDA context".to_string(),
399 ))
400 } else {
401 Ok(ctx)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the level structure of the recursive scan for empty, single-block and multi-level
    /// inputs.
    #[test]
    fn scan_topology_matches_recursive_multiblock_shape() {
        assert_eq!(
            scan_topology_u32(0),
            ScanTopology {
                input_len: 0,
                block_size: 256,
                scratch_lengths: vec![],
                kernel_node_count: 0,
            }
        );
        assert_eq!(scan_topology_u32(256).scratch_lengths, Vec::<u32>::new());
        assert_eq!(scan_topology_u32(256).kernel_node_count, 1);
        assert_eq!(scan_topology_u32(257).scratch_lengths, vec![2]);
        assert_eq!(scan_topology_u32(257).kernel_node_count, 3);
        assert_eq!(scan_topology_u32(65_537).scratch_lengths, vec![257, 2]);
        assert_eq!(scan_topology_u32(65_537).kernel_node_count, 5);
    }

    /// 65_536 = 256 * 256 still needs only one scratch level; a single element needs none.
    #[test]
    fn scan_topology_level_boundaries() {
        let full_two_level = scan_topology_u32(65_536);
        assert_eq!(full_two_level.scratch_lengths, vec![256]);
        assert_eq!(full_two_level.kernel_node_count, 3);

        let single = scan_topology_u32(1);
        assert_eq!(single.scratch_lengths, Vec::<u32>::new());
        assert_eq!(single.kernel_node_count, 1);
    }

    /// Capacity classes round up to powers of two and saturate past 2^31.
    #[test]
    fn capacity_class_rounds_up_and_saturates() {
        assert_eq!(graph_capacity_class_u32(0), 1);
        assert_eq!(graph_capacity_class_u32(1), 1);
        assert_eq!(graph_capacity_class_u32(2), 2);
        assert_eq!(graph_capacity_class_u32(3), 4);
        assert_eq!(graph_capacity_class_u32(1 << 31), 1 << 31);
        assert_eq!(graph_capacity_class_u32((1 << 31) + 1), u32::MAX);
        assert_eq!(graph_capacity_class_u32(u32::MAX), u32::MAX);
    }

    #[test]
    fn csm_key_uses_capacity_classes_and_layout_version() {
        let key = CsmCudaGraphKey::inner(2, 16, 257, 513).expect("key");
        assert_eq!(key.join_kind, CsmCudaGraphJoinKind::Inner);
        assert_eq!(key.key_arity, 2);
        assert_eq!(key.key_bytes, 16);
        assert_eq!(key.probe_capacity_class, 512);
        assert_eq!(key.output_capacity_class, 1024);
        assert_eq!(key.scan_topology.scratch_lengths, vec![2]);
        assert_eq!(key.node_layout_version, CSM_CUDA_GRAPH_NODE_LAYOUT_VERSION);
    }

    /// Arity must fit in a `u8`: 255 is accepted, 256 is rejected.
    #[test]
    fn csm_key_rejects_arity_above_u8_max() {
        let widest = CsmCudaGraphKey::inner(255, 8, 1, 1).expect("arity 255 fits in u8");
        assert_eq!(widest.key_arity, u8::MAX);
        assert!(CsmCudaGraphKey::inner(256, 8, 1, 1).is_err());
    }
}