xlog_cuda/
multi_gpu_memory.rs1use std::sync::Arc;
6
7use xlog_core::{MemoryBudget, Result, XlogError};
8
9use crate::memory::TrackedCudaSlice;
10use crate::{GpuDevicePool, GpuMemoryManager};
11
12pub struct MultiGpuMemoryManager {
14 pool: Arc<GpuDevicePool>,
15 managers: Vec<Arc<GpuMemoryManager>>,
16}
17
18impl MultiGpuMemoryManager {
19 pub fn new(pool: Arc<GpuDevicePool>, budget_per_device: MemoryBudget) -> Result<Self> {
21 let mut managers = Vec::with_capacity(pool.device_count());
22
23 for device in pool.devices() {
24 let mgr = GpuMemoryManager::new(device.clone(), budget_per_device.clone());
25 managers.push(Arc::new(mgr));
26 }
27
28 Ok(Self { pool, managers })
29 }
30
31 pub fn device_count(&self) -> usize {
33 self.pool.device_count()
34 }
35
36 pub fn alloc_on_device<T: cudarc::driver::DeviceRepr>(
38 &self,
39 device_idx: usize,
40 len: usize,
41 ) -> Result<TrackedCudaSlice<T>> {
42 let mgr = self
43 .managers
44 .get(device_idx)
45 .ok_or_else(|| XlogError::Kernel(format!("Device {} not found", device_idx)))?;
46 mgr.alloc::<T>(len)
47 }
48
49 pub fn alloc_next<T: cudarc::driver::DeviceRepr>(
51 &self,
52 len: usize,
53 ) -> Result<(usize, TrackedCudaSlice<T>)> {
54 let device_idx = self.pool.next_device_idx();
55 let slice = self.alloc_on_device::<T>(device_idx, len)?;
56 Ok((device_idx, slice))
57 }
58
59 pub fn get_manager(&self, device_idx: usize) -> Option<&Arc<GpuMemoryManager>> {
61 self.managers.get(device_idx)
62 }
63
64 pub fn remaining_bytes(&self, device_idx: usize) -> u64 {
66 self.managers
67 .get(device_idx)
68 .map(|m| m.remaining_bytes())
69 .unwrap_or(0)
70 }
71
72 pub fn pool(&self) -> &Arc<GpuDevicePool> {
74 &self.pool
75 }
76}