Skip to main content

xlog_cuda/
device_pool.rs

//! Multi-GPU device pool management.
//!
//! Provides a pool of CUDA devices with round-robin selection for distributing operations
//! across GPUs.
4
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7
8use xlog_core::{Result, XlogError};
9
10use crate::CudaDevice;
11
12/// Pool of CUDA devices for multi-GPU operations
13///
14/// Manages multiple devices and provides round-robin scheduling.
15pub struct GpuDevicePool {
16    /// Available CUDA devices
17    devices: Vec<Arc<CudaDevice>>,
18    /// Current device index for round-robin scheduling
19    current: AtomicUsize,
20}
21
22impl GpuDevicePool {
23    /// Create a new device pool with the specified number of devices
24    pub fn new(device_count: usize) -> Result<Self> {
25        if device_count == 0 {
26            return Err(XlogError::Kernel(
27                "Device pool requires at least one device".to_string(),
28            ));
29        }
30
31        // cudarc may panic on driver init failures in restricted containers; treat as a normal error.
32        let available = CudaDevice::count()?;
33
34        if device_count > available as usize {
35            return Err(XlogError::Kernel(format!(
36                "Requested {} devices but only {} available",
37                device_count, available
38            )));
39        }
40
41        let mut devices = Vec::with_capacity(device_count);
42        for ordinal in 0..device_count {
43            let device = CudaDevice::new(ordinal).map_err(|e| {
44                XlogError::Kernel(format!("Failed to create device {}: {}", ordinal, e))
45            })?;
46            devices.push(Arc::new(device));
47        }
48
49        Ok(Self {
50            devices,
51            current: AtomicUsize::new(0),
52        })
53    }
54
55    /// Get the number of devices in the pool
56    pub fn device_count(&self) -> usize {
57        self.devices.len()
58    }
59
60    /// Get a specific device by index
61    pub fn get_device(&self, idx: usize) -> Option<&Arc<CudaDevice>> {
62        self.devices.get(idx)
63    }
64
65    /// Get the next device index using round-robin scheduling
66    pub fn next_device_idx(&self) -> usize {
67        let idx = self.current.fetch_add(1, Ordering::SeqCst);
68        idx % self.devices.len()
69    }
70
71    /// Get the next device using round-robin scheduling
72    pub fn next_device(&self) -> &Arc<CudaDevice> {
73        let idx = self.next_device_idx();
74        &self.devices[idx]
75    }
76
77    /// Synchronize all devices
78    pub fn synchronize_all(&self) -> Result<()> {
79        for (i, device) in self.devices.iter().enumerate() {
80            device
81                .synchronize()
82                .map_err(|e| XlogError::Kernel(format!("Failed to sync device {}: {}", i, e)))?;
83        }
84        Ok(())
85    }
86
87    /// Get all devices
88    pub fn devices(&self) -> &[Arc<CudaDevice>] {
89        &self.devices
90    }
91}