Skip to main content

xlog_cuda/device_runtime/
stream_pool.rs

1//! [`StreamPool`] — owned non-blocking CUDA streams indexed by
2//! [`StreamId`].
3//!
4//! The runtime hands out stable [`StreamId`]s to callers and resolves
5//! them to live `cudarc::driver::CudaStream` handles internally. The
6//! pool grows on demand: `acquire` returns a fresh non-blocking stream
7//! up to `max_streams`. Streams are never returned to a free-list —
8//! they stay alive for the runtime's lifetime so [`StreamId`] handles
9//! remain valid for correlated allocate/launch/deallocate sequences.
10//!
11//! # Failure semantics
12//!
13//! `acquire` returns [`Result`]. On capacity exhaustion or
14//! `cudarc::driver::CudaStream::fork` failure the call returns
15//! [`StreamPoolError`] rather than silently collapsing onto the
16//! default stream — that fall-back was a footgun: it broke
17//! stream-ordered isolation (a "non-default" allocation could end up
18//! on the legacy default stream) without surfacing the failure to
19//! the caller.
20
21use std::sync::Arc;
22use std::sync::Mutex;
23
24use cudarc::driver::CudaStream;
25
26use super::resource::StreamId;
27use crate::CudaDevice;
28
/// Default maximum stream count. The executor's typical concurrency
/// is 1 deterministic stream + a small handful of join/scan helpers,
/// so 16 leaves substantial headroom without burning device-state on
/// idle streams.
pub const DEFAULT_MAX_STREAMS: usize = 16;
/// Name of the environment variable that overrides the planned
/// per-stream pool budget. [`configured_pool_mb_per_stream`] trims the
/// value and parses it as a `u64` count of MiB.
pub const ENV_WCOJ_POOL_MB_PER_STREAM: &str = "XLOG_WCOJ_POOL_MB_PER_STREAM";
/// Per-stream pool budget, in MiB, used when the environment variable
/// is unset, does not parse as a `u64`, or is zero.
pub const DEFAULT_POOL_MB_PER_STREAM: u64 = 256;
36
37pub fn configured_pool_mb_per_stream() -> u64 {
38    std::env::var(ENV_WCOJ_POOL_MB_PER_STREAM)
39        .ok()
40        .and_then(|raw| raw.trim().parse::<u64>().ok())
41        .filter(|mb| *mb > 0)
42        .unwrap_or(DEFAULT_POOL_MB_PER_STREAM)
43}
44
45pub fn configured_pool_bytes_per_stream() -> u64 {
46    configured_pool_mb_per_stream().saturating_mul(1024 * 1024)
47}
48
49pub fn planned_pool_budget_bytes(arms: u64, streams: u64) -> u64 {
50    arms.saturating_mul(streams)
51        .saturating_mul(configured_pool_bytes_per_stream())
52}
53
/// Errors returned by [`StreamPool::acquire`]. Both variants are hard
/// failures; callers must not silently substitute [`StreamId::DEFAULT`].
///
/// The type is `Clone + PartialEq + Eq` so callers and tests can store
/// and compare errors directly instead of pattern-matching on `Debug`
/// output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamPoolError {
    /// Pool already holds `max` non-default streams. Caller should
    /// either reuse an existing acquired id or raise the pool cap via
    /// the runtime config.
    Capacity {
        /// The cap that was hit.
        max: usize,
    },
    /// `CudaStream::fork` returned an error. Carries the wrapped
    /// driver message.
    ForkFailed(String),
}
66
67impl std::fmt::Display for StreamPoolError {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        match self {
70            Self::Capacity { max } => {
71                write!(f, "stream pool at capacity (max={})", max)
72            }
73            Self::ForkFailed(msg) => {
74                write!(f, "stream fork failed: {}", msg)
75            }
76        }
77    }
78}
79
// Marker impl: no methods are overridden, so the trait's provided
// defaults apply. `ForkFailed` keeps the driver failure only as a
// message string, so there is no underlying error value to expose.
impl std::error::Error for StreamPoolError {}
81
/// Pool of owned non-blocking CUDA streams.
pub struct StreamPool {
    /// Device whose default stream is forked by `acquire` and returned
    /// by `resolve` for [`StreamId::DEFAULT`].
    device: Arc<CudaDevice>,
    /// Upper bound on the number of streams `acquire` will create.
    /// `new` raises a requested cap of 0 to 1.
    max_streams: usize,
    /// Planned per-stream pool budget in bytes, captured once when the
    /// pool is constructed.
    pool_bytes_per_stream: u64,
    /// Handles for the non-default streams only. A stream with id `n`
    /// (ids start at 1) lives at `streams[n - 1]`. The default stream
    /// is not stored here: `resolve` fetches it from `device` directly.
    streams: Mutex<Vec<Arc<CudaStream>>>,
}
93
94impl StreamPool {
95    /// Construct a pool bound to `device`, capped at `max_streams`.
96    pub fn new(device: Arc<CudaDevice>, max_streams: usize) -> Self {
97        Self {
98            device,
99            max_streams: max_streams.max(1),
100            pool_bytes_per_stream: configured_pool_bytes_per_stream(),
101            streams: Mutex::new(Vec::new()),
102        }
103    }
104
105    /// Construct with the default cap [`DEFAULT_MAX_STREAMS`].
106    pub fn with_defaults(device: Arc<CudaDevice>) -> Self {
107        Self::new(device, DEFAULT_MAX_STREAMS)
108    }
109
110    /// Acquire a non-default stream id, growing the pool up to
111    /// `max_streams`. Each successful call returns a distinct
112    /// [`StreamId`] backed by an owned non-blocking
113    /// `cudarc::driver::CudaStream` forked from the device's default
114    /// stream.
115    ///
116    /// # Errors
117    ///   * [`StreamPoolError::Capacity`] if the pool already holds
118    ///     `max_streams` non-default streams.
119    ///   * [`StreamPoolError::ForkFailed`] if the underlying
120    ///     `CudaStream::fork` call failed.
121    ///
122    /// Streams are never returned to a free-list; they remain valid
123    /// for the runtime's lifetime so previously returned [`StreamId`]
124    /// handles keep resolving.
125    pub fn acquire(&self) -> Result<StreamId, StreamPoolError> {
126        let mut streams = self.streams.lock().expect("stream pool poisoned");
127        if streams.len() >= self.max_streams {
128            return Err(StreamPoolError::Capacity {
129                max: self.max_streams,
130            });
131        }
132        match self.device.inner().stream().fork() {
133            Ok(handle) => {
134                streams.push(handle);
135                // Index 0 is reserved for DEFAULT; non-default
136                // streams start at id 1 and correspond to
137                // `streams[id - 1]`.
138                Ok(StreamId(streams.len() as u32))
139            }
140            Err(e) => Err(StreamPoolError::ForkFailed(e.to_string())),
141        }
142    }
143
144    /// Borrow the live `CudaStream` for `id`. Returns `None` if `id`
145    /// has never been issued by this pool. The default-stream slot
146    /// resolves to the device's default stream.
147    pub fn resolve(&self, id: StreamId) -> Option<Arc<CudaStream>> {
148        if id == StreamId::DEFAULT {
149            return Some(Arc::clone(self.device.inner().stream()));
150        }
151        let streams = self.streams.lock().expect("stream pool poisoned");
152        let idx = id.0 as usize;
153        if idx == 0 || idx > streams.len() {
154            return None;
155        }
156        Some(Arc::clone(&streams[idx - 1]))
157    }
158
159    /// Number of non-default streams currently in the pool.
160    pub fn non_default_len(&self) -> usize {
161        self.streams.lock().expect("stream pool poisoned").len()
162    }
163
164    /// Borrow the device handle. Test helpers use this to launch
165    /// kernels into the same device the pool was constructed on.
166    pub fn device(&self) -> &Arc<CudaDevice> {
167        &self.device
168    }
169
170    /// Maximum streams the pool will create on demand.
171    pub fn max_streams(&self) -> usize {
172        self.max_streams
173    }
174
175    /// Planned per-stream pool budget, in bytes, from
176    /// `XLOG_WCOJ_POOL_MB_PER_STREAM` or the 256 MiB default.
177    pub fn pool_bytes_per_stream(&self) -> u64 {
178        self.pool_bytes_per_stream
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the tests that mutate the process-wide environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Scoped override of `ENV_WCOJ_POOL_MB_PER_STREAM`.
    ///
    /// Holds `ENV_LOCK` for its whole lifetime and puts the previous
    /// value back on drop, so the environment is restored even when an
    /// assertion in the test body panics.
    struct PoolEnvOverride {
        previous: Option<String>,
        // Declared after `previous` and released only once `Drop::drop`
        // has restored the variable.
        _serial: std::sync::MutexGuard<'static, ()>,
    }

    impl PoolEnvOverride {
        /// Set the variable to `value`, or remove it when `value` is
        /// `None`, remembering what was there before.
        fn set(value: Option<&str>) -> Self {
            // The mutex guards no data, so a poison left behind by
            // another failing test carries no broken invariant; recover
            // instead of cascading the failure.
            let serial = ENV_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let previous = std::env::var(ENV_WCOJ_POOL_MB_PER_STREAM).ok();
            match value {
                Some(v) => std::env::set_var(ENV_WCOJ_POOL_MB_PER_STREAM, v),
                None => std::env::remove_var(ENV_WCOJ_POOL_MB_PER_STREAM),
            }
            Self {
                previous,
                _serial: serial,
            }
        }
    }

    impl Drop for PoolEnvOverride {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(v) => std::env::set_var(ENV_WCOJ_POOL_MB_PER_STREAM, v),
                None => std::env::remove_var(ENV_WCOJ_POOL_MB_PER_STREAM),
            }
        }
    }

    /// Open device 0, or `None` when that fails. The device-backed
    /// tests below return early (and therefore pass) in that case.
    fn try_device() -> Option<Arc<CudaDevice>> {
        CudaDevice::new(0).ok().map(Arc::new)
    }

    #[test]
    fn acquire_returns_distinct_non_default_ids() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 4);
        let a = pool.acquire().expect("first acquire");
        let b = pool.acquire().expect("second acquire");
        assert_ne!(a, StreamId::DEFAULT);
        assert_ne!(b, StreamId::DEFAULT);
        assert_ne!(a, b, "consecutive acquire calls must yield distinct ids");
        assert_eq!(pool.non_default_len(), 2);
    }

    #[test]
    fn acquire_returns_capacity_error_at_max() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 1);
        let _first = pool.acquire().expect("first acquire under cap");
        let err = pool.acquire();
        assert!(
            matches!(err, Err(StreamPoolError::Capacity { max: 1 })),
            "expected Capacity error once max_streams hit, got {:?}",
            err
        );
    }

    #[test]
    fn resolve_default_returns_device_default_stream() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::with_defaults(device);
        assert!(pool.resolve(StreamId::DEFAULT).is_some());
    }

    #[test]
    fn resolve_acquired_returns_owned_stream() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::new(device, 4);
        let id = pool.acquire().expect("acquire");
        assert_ne!(id, StreamId::DEFAULT);
        assert!(pool.resolve(id).is_some());
    }

    #[test]
    fn resolve_unknown_returns_none() {
        let Some(device) = try_device() else {
            return;
        };
        let pool = StreamPool::with_defaults(device);
        assert!(pool.resolve(StreamId(99)).is_none());
    }

    #[test]
    fn pool_mb_per_stream_env_overrides_default() {
        let _env = PoolEnvOverride::set(Some("128"));
        assert_eq!(configured_pool_mb_per_stream(), 128);
    }

    #[test]
    fn planned_pool_budget_uses_default_4_by_4_contract() {
        let _env = PoolEnvOverride::set(None);
        assert_eq!(configured_pool_mb_per_stream(), DEFAULT_POOL_MB_PER_STREAM);
        assert_eq!(
            planned_pool_budget_bytes(4, 4),
            4_u64 * 4 * 256 * 1024 * 1024
        );
    }
}