1use std::collections::HashMap;
4use xlog_core::XlogError;
5use xlog_cuda::{CudaBuffer, CudaKernelProvider};
6
7pub fn read_device_row_count(
9 provider: &CudaKernelProvider,
10 buffer: &CudaBuffer,
11) -> Result<usize, XlogError> {
12 provider.device_row_count(buffer)
13}
14
/// Collection of [`IlpMask`] values addressed by name.
///
/// Masks are added through the `insert_*` methods and looked up with
/// `get_mask`; the map itself is private.
pub struct IlpRegistry {
    // Mask storage keyed by the `name` passed to the `insert_*` methods.
    masks: HashMap<String, IlpMask>,
}
19
/// A single mask held by [`IlpRegistry`], in one of three storage forms.
///
/// Every variant carries a `schema_size`, which `IlpMask::schema_size`
/// returns regardless of the variant.
// NOTE(review): the reason for this allow is not recorded here; presumably the
// variants differ considerably in size — confirm before removing it.
#[allow(clippy::large_enum_variant)]
pub enum IlpMask {
    /// Mask stored as two `CudaBuffer`s.
    Dense {
        // NOTE(review): contents/layout of the `hard` buffer are not visible in this file.
        hard: CudaBuffer,
        // NOTE(review): contents/layout of the `soft` buffer are not visible in this file.
        soft: CudaBuffer,
        // Value reported by `IlpMask::schema_size`.
        schema_size: usize,
    },
    /// Mask stored as an explicit `Vec` of `(u32, u32, u32)` triples.
    Sparse {
        // The triples that make up the mask, in the order they were supplied/selected.
        active_entries: Vec<(u32, u32, u32)>,
        // Value reported by `IlpMask::schema_size`.
        schema_size: usize,
    },
    /// Mask stored as a `Vec` of candidate triples plus a `CudaBuffer` of flags.
    SparseDevice {
        // Candidate triples as supplied by the caller.
        candidate_order: Vec<(u32, u32, u32)>,
        // Presumably one flag per `candidate_order` entry — TODO confirm against the producer.
        active_flags: CudaBuffer,
        // Presumably the number of candidates flagged as selected — TODO confirm.
        selected_count: usize,
        // Value reported by `IlpMask::schema_size`.
        schema_size: usize,
    },
}
51
52impl IlpMask {
53 pub fn schema_size(&self) -> usize {
55 match self {
56 IlpMask::Dense { schema_size, .. } => *schema_size,
57 IlpMask::Sparse { schema_size, .. } => *schema_size,
58 IlpMask::SparseDevice { schema_size, .. } => *schema_size,
59 }
60 }
61}
62
/// A list of [`IlpTagEntry`] values.
///
/// NOTE(review): nothing in this file constructs or consumes this type, so
/// the meaning of the list as a whole should be confirmed at the call sites.
pub struct IlpTaggedResult {
    /// The individual tagged entries.
    pub entries: Vec<IlpTagEntry>,
}
69
/// One element of an [`IlpTaggedResult`].
///
/// NOTE(review): presumably `(i, j, k)` uses the same triple convention as the
/// `(u32, u32, u32)` entries stored in `IlpMask` — confirm against the producer.
pub struct IlpTagEntry {
    /// First component of the entry's tag.
    pub i: u32,
    /// Second component of the entry's tag.
    pub j: u32,
    /// Third component of the entry's tag.
    pub k: u32,
    /// Presumably the number of rows associated with `buffer` — TODO confirm.
    pub num_rows: u32,
    /// Optional `CudaBuffer` attached to this entry; `None` when there is none.
    pub buffer: Option<CudaBuffer>,
}
84
85impl IlpRegistry {
86 pub fn new() -> Self {
88 Self {
89 masks: HashMap::new(),
90 }
91 }
92
93 pub fn clear(&mut self) {
95 self.masks.clear();
96 }
97
98 pub fn insert_mask(
100 &mut self,
101 name: String,
102 hard: CudaBuffer,
103 soft: CudaBuffer,
104 schema_size: usize,
105 ) {
106 self.masks.insert(
107 name,
108 IlpMask::Dense {
109 hard,
110 soft,
111 schema_size,
112 },
113 );
114 }
115
116 pub fn insert_mask_from_sparse(
121 &mut self,
122 name: String,
123 schema_size: usize,
124 active_ijk: &[(u32, u32, u32)],
125 active_soft: &[f32],
126 budget: usize,
127 ) -> Result<(), XlogError> {
128 if active_ijk.len() != active_soft.len() {
129 return Err(XlogError::Execution(format!(
130 "active_ijk length {} != active_soft length {}",
131 active_ijk.len(),
132 active_soft.len()
133 )));
134 }
135
136 let mut ranked: Vec<(usize, f32)> = active_soft.iter().copied().enumerate().collect();
139 ranked.retain(|(_, soft)| *soft > 0.0);
140 ranked.sort_by(|a, b| {
141 b.1.partial_cmp(&a.1)
142 .unwrap_or(std::cmp::Ordering::Equal)
143 .then(a.0.cmp(&b.0))
144 });
145 ranked.truncate(budget.min(ranked.len()));
146
147 let entries: Vec<(u32, u32, u32)> =
148 ranked.iter().map(|&(idx, _)| active_ijk[idx]).collect();
149
150 self.masks.insert(
151 name,
152 IlpMask::Sparse {
153 active_entries: entries,
154 schema_size,
155 },
156 );
157 Ok(())
158 }
159
160 pub fn insert_selected_mask(
162 &mut self,
163 name: String,
164 schema_size: usize,
165 active_entries: &[(u32, u32, u32)],
166 ) {
167 self.masks.insert(
168 name,
169 IlpMask::Sparse {
170 active_entries: active_entries.to_vec(),
171 schema_size,
172 },
173 );
174 }
175
176 pub fn insert_selected_mask_device(
178 &mut self,
179 name: String,
180 schema_size: usize,
181 candidate_order: Vec<(u32, u32, u32)>,
182 active_flags: CudaBuffer,
183 selected_count: usize,
184 ) {
185 self.masks.insert(
186 name,
187 IlpMask::SparseDevice {
188 candidate_order,
189 active_flags,
190 selected_count,
191 schema_size,
192 },
193 );
194 }
195
196 pub fn get_mask(&self, name: &str) -> Option<&IlpMask> {
198 self.masks.get(name)
199 }
200
201 pub fn has_sparse_device_mask(&self) -> bool {
203 self.masks
204 .values()
205 .any(|mask| matches!(mask, IlpMask::SparseDevice { .. }))
206 }
207}
208
209impl Default for IlpRegistry {
210 fn default() -> Self {
211 Self::new()
212 }
213}