Skip to main content

xlog_cuda/
multi_gpu_memory.rs

1//! Multi-GPU memory management
2//!
3//! Provides memory allocation across multiple GPU devices.
4
5use std::sync::Arc;
6
7use xlog_core::{MemoryBudget, Result, XlogError};
8
9use crate::memory::TrackedCudaSlice;
10use crate::{GpuDevicePool, GpuMemoryManager};
11
12/// Memory manager for multiple GPU devices
13pub struct MultiGpuMemoryManager {
14    pool: Arc<GpuDevicePool>,
15    managers: Vec<Arc<GpuMemoryManager>>,
16}
17
18impl MultiGpuMemoryManager {
19    /// Create a new multi-GPU memory manager
20    pub fn new(pool: Arc<GpuDevicePool>, budget_per_device: MemoryBudget) -> Result<Self> {
21        let mut managers = Vec::with_capacity(pool.device_count());
22
23        for device in pool.devices() {
24            let mgr = GpuMemoryManager::new(device.clone(), budget_per_device.clone());
25            managers.push(Arc::new(mgr));
26        }
27
28        Ok(Self { pool, managers })
29    }
30
31    /// Get the number of devices
32    pub fn device_count(&self) -> usize {
33        self.pool.device_count()
34    }
35
36    /// Allocate memory on a specific device
37    pub fn alloc_on_device<T: cudarc::driver::DeviceRepr>(
38        &self,
39        device_idx: usize,
40        len: usize,
41    ) -> Result<TrackedCudaSlice<T>> {
42        let mgr = self
43            .managers
44            .get(device_idx)
45            .ok_or_else(|| XlogError::Kernel(format!("Device {} not found", device_idx)))?;
46        mgr.alloc::<T>(len)
47    }
48
49    /// Allocate memory on the next device (round-robin)
50    pub fn alloc_next<T: cudarc::driver::DeviceRepr>(
51        &self,
52        len: usize,
53    ) -> Result<(usize, TrackedCudaSlice<T>)> {
54        let device_idx = self.pool.next_device_idx();
55        let slice = self.alloc_on_device::<T>(device_idx, len)?;
56        Ok((device_idx, slice))
57    }
58
59    /// Get the memory manager for a specific device
60    pub fn get_manager(&self, device_idx: usize) -> Option<&Arc<GpuMemoryManager>> {
61        self.managers.get(device_idx)
62    }
63
64    /// Get remaining bytes on a specific device
65    pub fn remaining_bytes(&self, device_idx: usize) -> u64 {
66        self.managers
67            .get(device_idx)
68            .map(|m| m.remaining_bytes())
69            .unwrap_or(0)
70    }
71
72    /// Get the device pool
73    pub fn pool(&self) -> &Arc<GpuDevicePool> {
74        &self.pool
75    }
76}