1use std::collections::HashMap;
2use xlog_core::{RelId, ScalarType, Schema};
3use xlog_cuda::{CudaBuffer, JoinIndexV2};
4
/// Cache key identifying one join index.
///
/// Two keys are equal only when every field matches, so a change in any of
/// them yields a distinct cache entry.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct JoinIndexKey {
    /// Relation the index belongs to.
    pub(crate) rel: RelId,
    /// Version of the relation; a different version is a different key, so an
    /// index cached for an older version is never returned for a newer one.
    pub(crate) version: u64,
    /// Join key columns (presumably column positions within the relation's
    /// schema; confirm against callers).
    pub(crate) key_cols: Vec<usize>,
    /// Signature of the schema passed to [`JoinIndexKey::new`].
    pub(crate) schema: JoinIndexSchemaSignature,
    /// Device ordinal; indexes for different ordinals are distinct entries.
    pub(crate) device_ordinal: u32,
}
13
/// Hashable summary of a `Schema`, used as part of [`JoinIndexKey`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct JoinIndexSchemaSignature {
    /// Types of the schema's columns, in column order. Columns for which the
    /// schema reports no type are omitted.
    column_types: Vec<ScalarType>,
    /// Row size in bytes as reported by the schema.
    row_size_bytes: usize,
}
19
20impl JoinIndexSchemaSignature {
21 fn from_schema(schema: &Schema) -> Self {
22 Self {
23 column_types: (0..schema.arity())
24 .filter_map(|idx| schema.column_type(idx))
25 .collect(),
26 row_size_bytes: schema.row_size_bytes(),
27 }
28 }
29}
30
31impl JoinIndexKey {
32 pub(crate) fn new(
33 rel: RelId,
34 version: u64,
35 key_cols: Vec<usize>,
36 schema: &Schema,
37 device_ordinal: u32,
38 ) -> Self {
39 Self {
40 rel,
41 version,
42 key_cols,
43 schema: JoinIndexSchemaSignature::from_schema(schema),
44 device_ordinal,
45 }
46 }
47}
48
/// One cache slot: the stored payload plus its bookkeeping.
struct CachedJoinIndex {
    /// The cached index (or a test placeholder).
    index: CachedJoinIndexPayload,
    /// Bytes charged against the cache budget for this entry.
    bytes: u64,
    /// Value of the cache's logical clock when the entry was last inserted or
    /// looked up; the entry with the smallest value is evicted first.
    last_used: u64,
}
54
/// Payload held by a cache entry.
// NOTE(review): the allow presumably exists because `JoinIndexV2` is much
// larger than the unit `Placeholder` variant, which only exists in test
// builds — confirm before removing.
#[allow(clippy::large_enum_variant)]
enum CachedJoinIndexPayload {
    /// A usable join index.
    Ready(JoinIndexV2),
    /// Test-only stand-in that occupies a slot without holding an index;
    /// lookups that find it are counted as misses.
    #[cfg(test)]
    Placeholder,
}
61
/// Counters describing join index cache activity, plus a size snapshot.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct JoinIndexCacheStats {
    /// Number of cache lookups performed.
    pub lookups: u64,
    /// Lookups that returned a ready index.
    pub hits: u64,
    /// Lookups that returned nothing.
    pub misses: u64,
    /// Indexes stored in the cache via `insert`.
    pub builds: u64,
    /// Entries removed to make room within the byte budget.
    pub evictions: u64,
    /// Entries removed by invalidating a relation or clearing the cache.
    pub invalidations: u64,
    /// Entries removed through `remove_stale`.
    pub stale_rejections: u64,
    /// Background build requests recorded.
    pub background_build_requests: u64,
    /// Background build completions recorded.
    pub background_builds_completed: u64,
    /// Background build deferrals recorded.
    pub background_builds_deferred: u64,
    /// Number of entries in the cache when the snapshot was taken.
    pub entries: usize,
    /// Total bytes charged to cached entries when the snapshot was taken.
    pub total_bytes: u64,
}
90
/// Byte-budgeted cache of join indexes with least-recently-used eviction.
pub(crate) struct JoinIndexCache {
    /// Cached entries by key.
    entries: HashMap<JoinIndexKey, CachedJoinIndex>,
    /// Logical clock, advanced on every insert and on every lookup that finds
    /// an entry; used to order entries by recency.
    clock: u64,
    /// Sum of the `bytes` of all current entries.
    total_bytes: u64,
    /// Maximum total bytes the cache may hold; larger indexes are not cached.
    pub(crate) max_bytes: u64,
    /// Running activity counters.
    stats: JoinIndexCacheStats,
}
98
99pub(crate) fn estimate_join_index_bytes(right: &CudaBuffer, right_keys: &[usize]) -> u64 {
103 if right_keys.is_empty() {
104 return u64::MAX;
105 }
106
107 let mut key_bytes_per_row: u64 = 0;
108 for &k in right_keys {
109 let Some(ty) = right.schema().column_type(k) else {
110 return u64::MAX;
111 };
112 key_bytes_per_row = key_bytes_per_row.saturating_add(ty.size_bytes() as u64);
113 }
114
115 let num_rows = right.num_rows();
116 let packed_bytes = num_rows.saturating_mul(key_bytes_per_row);
117 let target = num_rows.saturating_mul(2).max(1024);
118 let num_buckets = target.next_power_of_two();
119
120 packed_bytes
122 .saturating_add(num_buckets.saturating_mul(8))
123 .saturating_add(num_rows.saturating_mul(12))
124}
125
126impl JoinIndexCache {
127 pub(crate) fn new(max_bytes: u64) -> Self {
128 Self {
129 entries: HashMap::new(),
130 clock: 0,
131 total_bytes: 0,
132 max_bytes,
133 stats: JoinIndexCacheStats::default(),
134 }
135 }
136
137 pub(crate) fn should_build(
142 &self,
143 est_index_bytes: u64,
144 build_heat: f32,
145 remaining_device_bytes: u64,
146 device_budget_bytes: u64,
147 ) -> bool {
148 let heat_threshold = if self.max_bytes > 0 && est_index_bytes > self.max_bytes / 2 {
149 0.6
150 } else {
151 0.3
152 };
153 let has_room =
154 remaining_device_bytes >= est_index_bytes.saturating_add(device_budget_bytes / 10);
155
156 build_heat >= heat_threshold && est_index_bytes <= self.max_bytes && has_room
157 }
158
159 pub(crate) fn clear(&mut self) {
160 let removed = self.entries.len() as u64;
161 self.entries.clear();
162 self.clock = 0;
163 self.total_bytes = 0;
164 self.stats.invalidations = self.stats.invalidations.saturating_add(removed);
165 }
166
167 pub(crate) fn get(&mut self, key: &JoinIndexKey) -> Option<&JoinIndexV2> {
168 self.stats.lookups = self.stats.lookups.saturating_add(1);
169 let Some(entry) = self.entries.get_mut(key) else {
170 self.stats.misses = self.stats.misses.saturating_add(1);
171 return None;
172 };
173 self.clock = self.clock.saturating_add(1);
174 entry.last_used = self.clock;
175 match &entry.index {
176 CachedJoinIndexPayload::Ready(index) => {
177 self.stats.hits = self.stats.hits.saturating_add(1);
178 Some(index)
179 }
180 #[cfg(test)]
181 CachedJoinIndexPayload::Placeholder => {
182 self.stats.misses = self.stats.misses.saturating_add(1);
183 None
184 }
185 }
186 }
187
188 pub(crate) fn insert(&mut self, key: JoinIndexKey, index: JoinIndexV2) {
189 let bytes = index.estimated_bytes();
190 if bytes > self.max_bytes {
191 return;
192 }
193
194 self.evict_until_fits(bytes);
195
196 self.clock = self.clock.saturating_add(1);
197 let last_used = self.clock;
198
199 if let Some(prev) = self.entries.remove(&key) {
200 self.total_bytes = self.total_bytes.saturating_sub(prev.bytes);
201 }
202
203 self.total_bytes = self.total_bytes.saturating_add(bytes);
204 self.entries.insert(
205 key,
206 CachedJoinIndex {
207 index: CachedJoinIndexPayload::Ready(index),
208 bytes,
209 last_used,
210 },
211 );
212 self.stats.builds = self.stats.builds.saturating_add(1);
213 }
214
215 pub(crate) fn remove(&mut self, key: &JoinIndexKey) {
216 if let Some(prev) = self.entries.remove(key) {
217 self.total_bytes = self.total_bytes.saturating_sub(prev.bytes);
218 }
219 }
220
221 pub(crate) fn remove_stale(&mut self, key: &JoinIndexKey) {
222 let before = self.entries.len();
223 self.remove(key);
224 if self.entries.len() < before {
225 self.stats.stale_rejections = self.stats.stale_rejections.saturating_add(1);
226 }
227 }
228
229 pub(crate) fn invalidate_rel(&mut self, rel: RelId) {
230 let keys: Vec<JoinIndexKey> = self
231 .entries
232 .keys()
233 .filter(|k| k.rel == rel)
234 .cloned()
235 .collect();
236 for key in keys {
237 if let Some(entry) = self.entries.remove(&key) {
238 self.total_bytes = self.total_bytes.saturating_sub(entry.bytes);
239 self.stats.invalidations = self.stats.invalidations.saturating_add(1);
240 }
241 }
242 }
243
244 pub(crate) fn evict_until_fits(&mut self, additional_bytes: u64) {
245 while !self.entries.is_empty()
246 && self.total_bytes.saturating_add(additional_bytes) > self.max_bytes
247 {
248 let mut oldest_key: Option<JoinIndexKey> = None;
249 let mut oldest_clock = u64::MAX;
250
251 for (k, v) in &self.entries {
252 if v.last_used < oldest_clock {
253 oldest_clock = v.last_used;
254 oldest_key = Some(k.clone());
255 }
256 }
257
258 let Some(key) = oldest_key else {
259 break;
260 };
261 if let Some(entry) = self.entries.remove(&key) {
262 self.total_bytes = self.total_bytes.saturating_sub(entry.bytes);
263 self.stats.evictions = self.stats.evictions.saturating_add(1);
264 } else {
265 break;
266 }
267 }
268 }
269
270 pub(crate) fn record_background_build_request(&mut self) {
271 self.stats.background_build_requests =
272 self.stats.background_build_requests.saturating_add(1);
273 }
274
275 pub(crate) fn record_background_build_complete(&mut self) {
276 self.stats.background_builds_completed =
277 self.stats.background_builds_completed.saturating_add(1);
278 }
279
280 pub(crate) fn record_background_build_deferred(&mut self) {
281 self.stats.background_builds_deferred =
282 self.stats.background_builds_deferred.saturating_add(1);
283 }
284
285 pub(crate) fn stats(&self) -> JoinIndexCacheStats {
286 let mut stats = self.stats.clone();
287 stats.entries = self.entries.len();
288 stats.total_bytes = self.total_bytes;
289 stats
290 }
291
292 #[cfg(test)]
293 fn insert_test_entry(&mut self, key: JoinIndexKey, bytes: u64) {
294 if bytes > self.max_bytes {
295 return;
296 }
297 self.evict_until_fits(bytes);
298 self.clock = self.clock.saturating_add(1);
299 self.total_bytes = self.total_bytes.saturating_add(bytes);
300 self.entries.insert(
301 key,
302 CachedJoinIndex {
303 index: CachedJoinIndexPayload::Placeholder,
304 bytes,
305 last_used: self.clock,
306 },
307 );
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::{ScalarType, Schema};

    /// Builds a `Schema` from `(name, type)` pairs.
    fn schema(cols: Vec<(&str, ScalarType)>) -> Schema {
        let columns = cols
            .into_iter()
            .map(|(name, ty)| (name.to_string(), ty))
            .collect();
        Schema::new(columns)
    }

    /// Changing the version, the schema, or the device ordinal must each
    /// produce a different key.
    #[test]
    fn persistent_key_includes_schema_generation_key_and_device() {
        let narrow = schema(vec![("k", ScalarType::U32)]);
        let wide = schema(vec![("k", ScalarType::U64)]);

        let base = JoinIndexKey::new(RelId(7), 3, vec![0], &narrow, 0);
        assert_eq!(base.rel, RelId(7));
        assert_eq!(base.version, 3);
        assert_eq!(base.key_cols, vec![0]);
        assert_eq!(base.device_ordinal, 0);

        let newer_version = JoinIndexKey::new(RelId(7), 4, vec![0], &narrow, 0);
        assert_ne!(
            base, newer_version,
            "generation/version must partition stale indexes"
        );

        let other_schema = JoinIndexKey::new(RelId(7), 3, vec![0], &wide, 0);
        assert_ne!(
            base, other_schema,
            "schema changes must partition indexes"
        );

        let other_device = JoinIndexKey::new(RelId(7), 3, vec![0], &narrow, 1);
        assert_ne!(
            base, other_device,
            "device ordinal must partition indexes"
        );
    }

    /// Two 60-byte entries cannot share a 100-byte budget: the older one is
    /// evicted and the eviction is counted.
    #[test]
    fn persistent_cache_budget_evicts_lru_and_records_stats() {
        let layout = schema(vec![("k", ScalarType::U32)]);
        let first = JoinIndexKey::new(RelId(1), 1, vec![0], &layout, 0);
        let second = JoinIndexKey::new(RelId(2), 1, vec![0], &layout, 0);

        let mut cache = JoinIndexCache::new(100);
        cache.insert_test_entry(first, 60);
        cache.insert_test_entry(second, 60);

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 1);
        assert_eq!(snapshot.total_bytes, 60);
        assert_eq!(snapshot.evictions, 1);
    }

    /// Invalidating a relation removes its entry, releases its bytes and
    /// bumps the invalidation counter.
    #[test]
    fn persistent_cache_invalidation_records_removed_entries() {
        let layout = schema(vec![("k", ScalarType::U32)]);
        let only_key = JoinIndexKey::new(RelId(1), 1, vec![0], &layout, 0);

        let mut cache = JoinIndexCache::new(100);
        cache.insert_test_entry(only_key, 32);
        cache.invalidate_rel(RelId(1));

        let snapshot = cache.stats();
        assert_eq!(snapshot.entries, 0);
        assert_eq!(snapshot.total_bytes, 0);
        assert_eq!(snapshot.invalidations, 1);
    }
}