Skip to main content

xlog_cuda/
cuda_graph.rs

//! CUDA Graph RAII helpers for production graph capture and replay.
//!
//! This module intentionally stays close to the CUDA driver API. The bounded
//! CSM CUDA Graph path needs explicit ownership of graph lifetimes and a node
//! inventory before it can safely update graph-exec parameters for runtime
//! pointers and capacity classes.
7
8use std::{mem, ptr};
9
10use cudarc::driver::{sys, CudaStream};
11use xlog_core::{Result, XlogError};
12
/// Version tag for the node layout of captured CSM CUDA Graphs.
///
/// [`CsmCudaGraphKey::inner`] copies this value into every key, and the key
/// derives `Eq`/`Hash`. So keys carrying a different layout version never
/// compare equal. The name suggests it should be bumped whenever the captured
/// node layout changes.
pub const CSM_CUDA_GRAPH_NODE_LAYOUT_VERSION: u32 = 1;
14
15/// Instantiated CUDA Graph with owned graph + exec handles.
16pub struct CapturedCudaGraph {
17    graph: sys::CUgraph,
18    exec: sys::CUgraphExec,
19}
20
21// CUDA graph handles are context-owned driver handles. xlog stores them behind
22// provider-level synchronization when caching graph executions.
23unsafe impl Send for CapturedCudaGraph {}
24unsafe impl Sync for CapturedCudaGraph {}
25
/// Crate-level mirror of the driver's `CUgraphNodeType`.
///
/// Each variant corresponds one-to-one with a `CU_GRAPH_NODE_TYPE_*` value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CudaGraphNodeKind {
    /// `CU_GRAPH_NODE_TYPE_KERNEL`
    Kernel,
    /// `CU_GRAPH_NODE_TYPE_MEMCPY`
    Memcpy,
    /// `CU_GRAPH_NODE_TYPE_MEMSET`
    Memset,
    /// `CU_GRAPH_NODE_TYPE_HOST`
    Host,
    /// `CU_GRAPH_NODE_TYPE_GRAPH` (child graph)
    Graph,
    /// `CU_GRAPH_NODE_TYPE_EMPTY`
    Empty,
    /// `CU_GRAPH_NODE_TYPE_WAIT_EVENT`
    WaitEvent,
    /// `CU_GRAPH_NODE_TYPE_EVENT_RECORD`
    EventRecord,
    /// `CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL`
    ExternalSemaphoresSignal,
    /// `CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT`
    ExternalSemaphoresWait,
    /// `CU_GRAPH_NODE_TYPE_MEM_ALLOC`
    MemAlloc,
    /// `CU_GRAPH_NODE_TYPE_MEM_FREE`
    MemFree,
    /// `CU_GRAPH_NODE_TYPE_BATCH_MEM_OP`
    BatchMemOp,
    /// `CU_GRAPH_NODE_TYPE_CONDITIONAL`
    Conditional,
}
43
44#[derive(Debug, Clone, Copy)]
45pub struct CudaGraphNode {
46    pub index: usize,
47    pub raw: sys::CUgraphNode,
48    pub kind: CudaGraphNodeKind,
49}
50
51unsafe impl Send for CudaGraphNode {}
52unsafe impl Sync for CudaGraphNode {}
53
/// Join variant a CSM CUDA Graph was captured for; part of [`CsmCudaGraphKey`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CsmCudaGraphJoinKind {
    /// Plain inner join; the variant produced by [`CsmCudaGraphKey::inner`].
    Inner,
    /// Indexed inner join. Not constructed in this module.
    /// NOTE(review): confirm its semantics against the callers that build it.
    IndexedInner,
}
59
/// Shape of the recursive multi-block scan used by the CSM graph, as
/// computed by [`scan_topology_u32`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ScanTopology {
    /// Element count the topology was computed for.
    pub input_len: u32,
    /// Elements handled per block at every level (256 in `scan_topology_u32`).
    pub block_size: u32,
    /// Block count of each additional level, outermost first; empty when the
    /// input fits in a single block.
    pub scratch_lengths: Vec<u32>,
    /// Kernel nodes the scan contributes: 0 for empty input, otherwise one
    /// plus two per entry in `scratch_lengths`.
    pub kernel_node_count: usize,
}
67
68#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69pub struct CsmCudaGraphKey {
70    pub join_kind: CsmCudaGraphJoinKind,
71    pub key_arity: u8,
72    pub key_bytes: u32,
73    pub probe_capacity_class: u32,
74    pub output_capacity_class: u32,
75    pub scan_topology: ScanTopology,
76    pub node_layout_version: u32,
77}
78
79impl CsmCudaGraphKey {
80    pub fn inner(
81        key_arity: usize,
82        key_bytes: u32,
83        probe_capacity: u32,
84        output_capacity: u32,
85    ) -> Result<Self> {
86        let key_arity = u8::try_from(key_arity).map_err(|_| {
87            XlogError::Kernel(format!(
88                "CSM CUDA Graph key arity {} exceeds u8::MAX",
89                key_arity
90            ))
91        })?;
92        Ok(Self {
93            join_kind: CsmCudaGraphJoinKind::Inner,
94            key_arity,
95            key_bytes,
96            probe_capacity_class: graph_capacity_class_u32(probe_capacity),
97            output_capacity_class: graph_capacity_class_u32(output_capacity),
98            scan_topology: scan_topology_u32(probe_capacity),
99            node_layout_version: CSM_CUDA_GRAPH_NODE_LAYOUT_VERSION,
100        })
101    }
102}
103
/// Round `n` up to its graph capacity class.
///
/// The class is the next power of two, with a floor of 1 (so 0 and 1 map to 1).
/// Inputs above `2^31`, whose next power of two does not fit in a `u32`,
/// saturate to `u32::MAX`. That value is not a power of two and acts as an
/// "oversized" sentinel class.
pub fn graph_capacity_class_u32(n: u32) -> u32 {
    match n {
        0 | 1 => 1,
        _ => n.checked_next_power_of_two().unwrap_or(u32::MAX),
    }
}
111
112pub fn scan_topology_u32(mut n: u32) -> ScanTopology {
113    let input_len = n;
114    let block_size = 256u32;
115    let mut scratch_lengths = Vec::new();
116    let mut kernel_node_count = if n == 0 { 0 } else { 1 };
117    while n > block_size {
118        let num_blocks = n.div_ceil(block_size);
119        scratch_lengths.push(num_blocks);
120        kernel_node_count += 2;
121        n = num_blocks;
122    }
123    ScanTopology {
124        input_len,
125        block_size,
126        scratch_lengths,
127        kernel_node_count,
128    }
129}
130
131impl CapturedCudaGraph {
132    /// Capture work submitted by `record` on `stream`, instantiate it, and take
133    /// ownership of the resulting graph handles.
134    pub fn capture_on_stream<F>(stream: &CudaStream, record: F) -> Result<Self>
135    where
136        F: FnOnce() -> Result<()>,
137    {
138        unsafe {
139            cuda_graph_check(
140                "cuStreamBeginCapture_v2",
141                sys::cuStreamBeginCapture_v2(
142                    stream.cu_stream(),
143                    sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
144                ),
145            )?;
146        }
147
148        let record_result = record();
149        let mut graph: sys::CUgraph = ptr::null_mut();
150        let end_result = unsafe {
151            cuda_graph_check(
152                "cuStreamEndCapture",
153                sys::cuStreamEndCapture(stream.cu_stream(), &mut graph),
154            )
155        };
156
157        if let Err(record_err) = record_result {
158            if end_result.is_ok() && !graph.is_null() {
159                unsafe {
160                    let _ = sys::cuGraphDestroy(graph);
161                }
162            }
163            return Err(record_err);
164        }
165        end_result?;
166        if graph.is_null() {
167            return Err(XlogError::Kernel(
168                "cuStreamEndCapture returned a null CUDA graph".to_string(),
169            ));
170        }
171
172        let mut exec: sys::CUgraphExec = ptr::null_mut();
173        unsafe {
174            if let Err(err) = cuda_graph_check(
175                "cuGraphInstantiateWithFlags",
176                sys::cuGraphInstantiateWithFlags(&mut exec, graph, 0),
177            ) {
178                let _ = sys::cuGraphDestroy(graph);
179                return Err(err);
180            }
181        }
182        if exec.is_null() {
183            unsafe {
184                let _ = sys::cuGraphDestroy(graph);
185            }
186            return Err(XlogError::Kernel(
187                "cuGraphInstantiateWithFlags returned a null CUDA graph exec".to_string(),
188            ));
189        }
190
191        Ok(Self { graph, exec })
192    }
193
194    /// Replay the instantiated graph on `stream`.
195    pub fn launch(&self, stream: &CudaStream) -> Result<()> {
196        unsafe {
197            cuda_graph_check(
198                "cuGraphLaunch",
199                sys::cuGraphLaunch(self.exec, stream.cu_stream()),
200            )
201        }
202    }
203
204    /// Number of nodes in the captured graph. Used by bounded CSM CUDA Graph
205    /// cache-key and node-inventory certs to prove topology stability.
206    pub fn node_count(&self) -> Result<usize> {
207        let mut count = 0usize;
208        unsafe {
209            cuda_graph_check(
210                "cuGraphGetNodes(count)",
211                sys::cuGraphGetNodes(self.graph, ptr::null_mut(), &mut count),
212            )?;
213        }
214        Ok(count)
215    }
216
217    /// Return graph nodes in CUDA's captured graph order with their node type.
218    pub fn nodes(&self) -> Result<Vec<CudaGraphNode>> {
219        let count = self.node_count()?;
220        if count == 0 {
221            return Ok(Vec::new());
222        }
223        let mut raw_nodes = vec![ptr::null_mut(); count];
224        let mut count_again = count;
225        unsafe {
226            cuda_graph_check(
227                "cuGraphGetNodes(nodes)",
228                sys::cuGraphGetNodes(self.graph, raw_nodes.as_mut_ptr(), &mut count_again),
229            )?;
230        }
231        raw_nodes.truncate(count_again);
232
233        let mut nodes = Vec::with_capacity(raw_nodes.len());
234        for (index, raw) in raw_nodes.into_iter().enumerate() {
235            let mut ty = sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY;
236            unsafe {
237                cuda_graph_check("cuGraphNodeGetType", sys::cuGraphNodeGetType(raw, &mut ty))?;
238            }
239            nodes.push(CudaGraphNode {
240                index,
241                raw,
242                kind: CudaGraphNodeKind::from_sys(ty),
243            });
244        }
245        Ok(nodes)
246    }
247
248    /// Read CUDA's raw kernel-node params for inventory/update code.
249    ///
250    /// The returned `kernelParams` pointer is CUDA-owned capture metadata. Treat
251    /// it as read-only unless constructing a fresh params object for
252    /// [`Self::set_kernel_node_params`].
253    pub fn kernel_node_params(&self, node: CudaGraphNode) -> Result<sys::CUDA_KERNEL_NODE_PARAMS> {
254        if node.kind != CudaGraphNodeKind::Kernel {
255            return Err(XlogError::Kernel(format!(
256                "kernel_node_params called for non-kernel graph node {:?}",
257                node.kind
258            )));
259        }
260        let mut params: sys::CUDA_KERNEL_NODE_PARAMS = unsafe { mem::zeroed() };
261        unsafe {
262            cuda_graph_check(
263                "cuGraphKernelNodeGetParams_v2",
264                sys::cuGraphKernelNodeGetParams_v2(node.raw, &mut params),
265            )?;
266        }
267        Ok(params)
268    }
269
270    /// Update a kernel node in the instantiated graph.
271    ///
272    /// # Safety
273    /// CUDA requires the replacement params to be topology-compatible with the
274    /// captured node. The caller must keep every pointed-to kernel argument
275    /// alive until CUDA has consumed the update and launched work that uses it.
276    pub unsafe fn set_kernel_node_params(
277        &self,
278        node: CudaGraphNode,
279        params: &sys::CUDA_KERNEL_NODE_PARAMS,
280    ) -> Result<()> {
281        if node.kind != CudaGraphNodeKind::Kernel {
282            return Err(XlogError::Kernel(format!(
283                "set_kernel_node_params called for non-kernel graph node {:?}",
284                node.kind
285            )));
286        }
287        cuda_graph_check(
288            "cuGraphExecKernelNodeSetParams_v2",
289            sys::cuGraphExecKernelNodeSetParams_v2(self.exec, node.raw, params),
290        )
291    }
292
293    /// Read CUDA's raw memset-node params for inventory/update code.
294    pub fn memset_node_params(&self, node: CudaGraphNode) -> Result<sys::CUDA_MEMSET_NODE_PARAMS> {
295        if node.kind != CudaGraphNodeKind::Memset {
296            return Err(XlogError::Kernel(format!(
297                "memset_node_params called for non-memset graph node {:?}",
298                node.kind
299            )));
300        }
301        let mut params: sys::CUDA_MEMSET_NODE_PARAMS = unsafe { mem::zeroed() };
302        unsafe {
303            cuda_graph_check(
304                "cuGraphMemsetNodeGetParams",
305                sys::cuGraphMemsetNodeGetParams(node.raw, &mut params),
306            )?;
307        }
308        Ok(params)
309    }
310
311    /// Update a memset node in the instantiated graph.
312    pub fn set_memset_node_params(
313        &self,
314        node: CudaGraphNode,
315        params: &sys::CUDA_MEMSET_NODE_PARAMS,
316        stream: &CudaStream,
317    ) -> Result<()> {
318        if node.kind != CudaGraphNodeKind::Memset {
319            return Err(XlogError::Kernel(format!(
320                "set_memset_node_params called for non-memset graph node {:?}",
321                node.kind
322            )));
323        }
324        let ctx = stream_context(stream)?;
325        unsafe {
326            cuda_graph_check(
327                "cuGraphExecMemsetNodeSetParams",
328                sys::cuGraphExecMemsetNodeSetParams(self.exec, node.raw, params, ctx),
329            )
330        }
331    }
332
333    /// Raw graph handle for low-level node inventory/update code.
334    pub fn graph(&self) -> sys::CUgraph {
335        self.graph
336    }
337
338    /// Raw instantiated graph handle for low-level graph-exec update code.
339    pub fn exec(&self) -> sys::CUgraphExec {
340        self.exec
341    }
342}
343
344impl Drop for CapturedCudaGraph {
345    fn drop(&mut self) {
346        unsafe {
347            if !self.exec.is_null() {
348                let _ = sys::cuGraphExecDestroy(self.exec);
349            }
350            if !self.graph.is_null() {
351                let _ = sys::cuGraphDestroy(self.graph);
352            }
353        }
354    }
355}
356
357impl CudaGraphNodeKind {
358    fn from_sys(kind: sys::CUgraphNodeType) -> Self {
359        match kind {
360            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL => Self::Kernel,
361            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY => Self::Memcpy,
362            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET => Self::Memset,
363            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST => Self::Host,
364            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH => Self::Graph,
365            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY => Self::Empty,
366            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT => Self::WaitEvent,
367            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD => Self::EventRecord,
368            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL => {
369                Self::ExternalSemaphoresSignal
370            }
371            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT => Self::ExternalSemaphoresWait,
372            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC => Self::MemAlloc,
373            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE => Self::MemFree,
374            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_BATCH_MEM_OP => Self::BatchMemOp,
375            sys::CUgraphNodeType::CU_GRAPH_NODE_TYPE_CONDITIONAL => Self::Conditional,
376        }
377    }
378}
379
380fn cuda_graph_check(label: &str, code: sys::CUresult) -> Result<()> {
381    if code == sys::CUresult::CUDA_SUCCESS {
382        Ok(())
383    } else {
384        Err(XlogError::Kernel(format!("{label} failed: {code:?}")))
385    }
386}
387
388fn stream_context(stream: &CudaStream) -> Result<sys::CUcontext> {
389    let mut ctx = ptr::null_mut();
390    unsafe {
391        cuda_graph_check(
392            "cuStreamGetCtx",
393            sys::cuStreamGetCtx(stream.cu_stream(), &mut ctx),
394        )?;
395    }
396    if ctx.is_null() {
397        Err(XlogError::Kernel(
398            "cuStreamGetCtx returned a null CUDA context".to_string(),
399        ))
400    } else {
401        Ok(ctx)
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scan_topology_matches_recursive_multiblock_shape() {
        assert_eq!(
            scan_topology_u32(0),
            ScanTopology {
                input_len: 0,
                block_size: 256,
                scratch_lengths: vec![],
                kernel_node_count: 0,
            }
        );
        assert_eq!(scan_topology_u32(256).scratch_lengths, Vec::<u32>::new());
        assert_eq!(scan_topology_u32(256).kernel_node_count, 1);
        assert_eq!(scan_topology_u32(257).scratch_lengths, vec![2]);
        assert_eq!(scan_topology_u32(257).kernel_node_count, 3);
        assert_eq!(scan_topology_u32(65_537).scratch_lengths, vec![257, 2]);
        assert_eq!(scan_topology_u32(65_537).kernel_node_count, 5);
    }

    #[test]
    fn scan_topology_single_element_is_one_kernel() {
        let topo = scan_topology_u32(1);
        assert_eq!(topo.input_len, 1);
        assert_eq!(topo.block_size, 256);
        assert!(topo.scratch_lengths.is_empty());
        assert_eq!(topo.kernel_node_count, 1);
    }

    #[test]
    fn scan_topology_handles_u32_max_without_overflow() {
        // ceil(u32::MAX / 256) = 16_777_216, then 65_536, then 256 (fits one block).
        let topo = scan_topology_u32(u32::MAX);
        assert_eq!(topo.scratch_lengths, vec![16_777_216, 65_536, 256]);
        assert_eq!(topo.kernel_node_count, 7);
    }

    #[test]
    fn capacity_class_rounds_up_and_saturates() {
        assert_eq!(graph_capacity_class_u32(0), 1);
        assert_eq!(graph_capacity_class_u32(1), 1);
        assert_eq!(graph_capacity_class_u32(2), 2);
        assert_eq!(graph_capacity_class_u32(3), 4);
        assert_eq!(graph_capacity_class_u32(1 << 31), 1 << 31);
        // Next power of two would overflow u32: saturate to the sentinel class.
        assert_eq!(graph_capacity_class_u32((1 << 31) + 1), u32::MAX);
        assert_eq!(graph_capacity_class_u32(u32::MAX), u32::MAX);
    }

    #[test]
    fn csm_key_uses_capacity_classes_and_layout_version() {
        let key = CsmCudaGraphKey::inner(2, 16, 257, 513).expect("key");
        assert_eq!(key.join_kind, CsmCudaGraphJoinKind::Inner);
        assert_eq!(key.key_arity, 2);
        assert_eq!(key.key_bytes, 16);
        assert_eq!(key.probe_capacity_class, 512);
        assert_eq!(key.output_capacity_class, 1024);
        assert_eq!(key.scan_topology.scratch_lengths, vec![2]);
        assert_eq!(key.node_layout_version, CSM_CUDA_GRAPH_NODE_LAYOUT_VERSION);
    }

    #[test]
    fn csm_key_rejects_arity_above_u8_max() {
        assert!(CsmCudaGraphKey::inner(usize::from(u8::MAX), 16, 1, 1).is_ok());
        assert!(CsmCudaGraphKey::inner(usize::from(u8::MAX) + 1, 16, 1, 1).is_err());
    }

    #[test]
    fn csm_keys_with_identical_inputs_are_equal() {
        let a = CsmCudaGraphKey::inner(1, 8, 1000, 2000).expect("key");
        let b = CsmCudaGraphKey::inner(1, 8, 1000, 2000).expect("key");
        assert_eq!(a, b);
    }
}