1use std::collections::HashMap;
8use std::sync::Arc;
9
10use xlog_core::Schema;
11use xlog_cuda::{CudaBuffer, CudaKernelProvider};
12
13pub struct RelationStore {
45 provider: Arc<CudaKernelProvider>,
47 relations: HashMap<String, VersionedCudaBuffer>,
49}
50
51struct VersionedCudaBuffer {
52 buffer: CudaBuffer,
53 version: u64,
54}
55
56impl RelationStore {
57 pub fn new(provider: Arc<CudaKernelProvider>) -> Self {
59 Self {
60 provider,
61 relations: HashMap::new(),
62 }
63 }
64
65 pub fn get(&self, name: &str) -> Option<&CudaBuffer> {
73 self.relations.get(name).map(|e| &e.buffer)
74 }
75
76 pub fn get_mut(&mut self, name: &str) -> Option<&mut CudaBuffer> {
84 self.relations.get_mut(name).map(|e| {
85 e.version = e.version.saturating_add(1);
88 &mut e.buffer
89 })
90 }
91
92 pub fn get_with_version(&self, name: &str) -> Option<(&CudaBuffer, u64)> {
94 self.relations.get(name).map(|e| (&e.buffer, e.version))
95 }
96
97 pub fn version(&self, name: &str) -> Option<u64> {
99 self.relations.get(name).map(|e| e.version)
100 }
101
102 pub fn put(&mut self, name: &str, buffer: CudaBuffer) {
110 let version = self
111 .relations
112 .get(name)
113 .map(|e| e.version.saturating_add(1))
114 .unwrap_or(1);
115 self.relations
116 .insert(name.to_string(), VersionedCudaBuffer { buffer, version });
117 }
118
119 pub fn get_or_insert_empty(
132 &mut self,
133 name: &str,
134 schema: &Schema,
135 ) -> xlog_core::Result<&CudaBuffer> {
136 if !self.relations.contains_key(name) {
137 let buffer = self.provider.create_empty_buffer(schema.clone())?;
138 self.relations
139 .insert(name.to_string(), VersionedCudaBuffer { buffer, version: 1 });
140 }
141 Ok(&self
142 .relations
143 .get(name)
144 .expect("Relation must exist after insertion")
145 .buffer)
146 }
147
148 pub fn get_or_insert_empty_mut(
161 &mut self,
162 name: &str,
163 schema: &Schema,
164 ) -> xlog_core::Result<&mut CudaBuffer> {
165 if !self.relations.contains_key(name) {
166 let buffer = self.provider.create_empty_buffer(schema.clone())?;
167 self.relations
168 .insert(name.to_string(), VersionedCudaBuffer { buffer, version: 1 });
169 }
170 let entry = self
171 .relations
172 .get_mut(name)
173 .expect("Relation must exist after insertion");
174 entry.version = entry.version.saturating_add(1);
175 Ok(&mut entry.buffer)
176 }
177
178 pub fn contains(&self, name: &str) -> bool {
186 self.relations.contains_key(name)
187 }
188
189 pub fn remove(&mut self, name: &str) -> Option<CudaBuffer> {
197 self.relations.remove(name).map(|e| e.buffer)
198 }
199
200 pub fn clear(&mut self) {
205 self.relations.clear();
206 }
207
208 pub fn len(&self) -> usize {
210 self.relations.len()
211 }
212
213 pub fn is_empty(&self) -> bool {
215 self.relations.is_empty()
216 }
217
218 pub fn names(&self) -> impl Iterator<Item = &str> {
220 self.relations.keys().map(|s| s.as_str())
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use xlog_core::{MemoryBudget, ScalarType};
    use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};

    /// Builds a kernel provider on device 0 with a `MemoryBudget` limit of
    /// 1024^3 (presumably bytes).
    ///
    /// Returns `None` after logging the reason when CUDA is unavailable, so the
    /// tests become no-ops on GPU-less machines.
    fn setup_provider() -> Option<Arc<CudaKernelProvider>> {
        let device = match CudaDevice::new(0) {
            Ok(d) => Arc::new(d),
            Err(e) => {
                eprintln!("Skipping: CUDA runtime unavailable: {}", e);
                return None;
            }
        };
        let memory = Arc::new(GpuMemoryManager::new(
            device.clone(),
            MemoryBudget::with_limit(1024 * 1024 * 1024),
        ));
        match CudaKernelProvider::new(device, memory) {
            Ok(provider) => Some(Arc::new(provider)),
            Err(e) => {
                eprintln!("Skipping: CUDA kernel provider unavailable: {}", e);
                None
            }
        }
    }

    /// Returns an empty store plus a handle to the provider backing it.
    fn setup_store() -> Option<(RelationStore, Arc<CudaKernelProvider>)> {
        let provider = setup_provider()?;
        let store = RelationStore::new(provider.clone());
        Some((store, provider))
    }

    /// Two-column schema (`a: U32`, `b: U64`) used by the schema-aware tests.
    fn test_schema() -> Schema {
        Schema::new(vec![
            ("a".to_string(), ScalarType::U32),
            ("b".to_string(), ScalarType::U64),
        ])
    }

    /// Allocates an empty buffer with a zero-column schema.
    fn empty_nullary(provider: &CudaKernelProvider) -> CudaBuffer {
        provider
            .create_empty_buffer(Schema::new(vec![]))
            .expect("empty")
    }

    /// Copies the buffer's device-side row counter back to the host.
    fn device_row_count(provider: &CudaKernelProvider, buffer: &CudaBuffer) -> u32 {
        let mut host_rows = [0u32];
        provider
            .device()
            .inner()
            .dtoh_sync_copy_into(buffer.num_rows_device(), &mut host_rows)
            .expect("dtoh row count");
        host_rows[0]
    }

    /// Builds a buffer with `schema` whose device row count is `rows`.
    ///
    /// Column data, if any, is zero-filled.
    fn make_buffer(provider: &CudaKernelProvider, schema: Schema, rows: usize) -> CudaBuffer {
        if rows == 0 {
            return provider.create_empty_buffer(schema).expect("empty buffer");
        }
        if schema.arity() == 0 {
            // No columns to upload: only the device-side row counter needs setting.
            let rows_u32 = u32::try_from(rows).expect("row count fits u32");
            let mut d_num_rows = provider.memory().alloc::<u32>(1).expect("alloc");
            provider
                .device()
                .inner()
                .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
                .expect("htod row count");
            return CudaBuffer::from_columns(Vec::new(), u64::from(rows_u32), d_num_rows, schema);
        }
        let columns: Vec<Vec<u8>> = (0..schema.arity())
            .map(|col_idx| {
                let size = schema
                    .column_type(col_idx)
                    .map(|t| t.size_bytes())
                    .unwrap_or(4);
                vec![0u8; rows * size]
            })
            .collect();
        let slices: Vec<&[u8]> = columns.iter().map(Vec::as_slice).collect();
        provider
            .create_buffer_from_slices(&slices, schema)
            .expect("buffer")
    }

    #[test]
    fn test_new_store_is_empty() {
        let Some((store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_put_and_get() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        store.put("test_rel", empty_nullary(&provider));

        assert!(store.contains("test_rel"));
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
        assert!(store.get("test_rel").is_some());
    }

    #[test]
    fn test_get_nonexistent() {
        let Some((store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.get("nonexistent").is_none());
    }

    #[test]
    fn test_contains() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        assert!(!store.contains("test"));

        store.put("test", empty_nullary(&provider));

        assert!(store.contains("test"));
        assert!(!store.contains("other"));
    }

    #[test]
    fn test_remove() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        store.put("test", empty_nullary(&provider));
        assert!(store.contains("test"));

        let removed = store.remove("test");
        assert!(removed.is_some());
        assert!(!store.contains("test"));
        assert!(store.is_empty());
    }

    #[test]
    fn test_remove_nonexistent() {
        let Some((mut store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.remove("nonexistent").is_none());
    }

    #[test]
    fn test_clear() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        for name in ["rel1", "rel2", "rel3"] {
            store.put(name, empty_nullary(&provider));
        }
        assert_eq!(store.len(), 3);

        store.clear();

        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_get_or_insert_empty_existing() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();
        store.put("existing", make_buffer(&provider, schema.clone(), 100));

        let result = store.get_or_insert_empty("existing", &schema).unwrap();
        assert_eq!(device_row_count(&provider, result), 100);
        assert_eq!(result.schema(), &schema);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_get_or_insert_empty_nonexistent() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();
        assert!(store.is_empty());

        let result = store.get_or_insert_empty("nonexistent", &schema).unwrap();
        assert_eq!(device_row_count(&provider, result), 0);
        assert_eq!(result.schema(), &schema);
        assert!(result.is_empty());

        assert!(store.contains("nonexistent"));
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_get_mut() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        store.put("test", make_buffer(&provider, Schema::new(vec![]), 10));

        {
            let buf_mut = store.get_mut("test").unwrap();
            buf_mut.row_cap = 50;
            provider
                .device()
                .inner()
                .htod_sync_copy_into(&[50u32], &mut buf_mut.d_num_rows)
                .expect("htod row count");
        }

        assert_eq!(device_row_count(&provider, store.get("test").unwrap()), 50);
    }

    #[test]
    fn test_get_mut_nonexistent() {
        let Some((mut store, _provider)) = setup_store() else {
            return;
        };
        assert!(store.get_mut("nonexistent").is_none());
    }

    #[test]
    fn test_get_or_insert_empty_mut() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();

        {
            let buf_mut = store.get_or_insert_empty_mut("new_rel", &schema).unwrap();
            assert_eq!(device_row_count(&provider, buf_mut), 0);
            // NOTE(review): this sets row_cap/num_rows to 42 on a freshly created
            // empty two-column buffer without growing its columns. It only checks
            // that writes through the &mut persist; confirm that CudaBuffer's Drop
            // and memory accounting do not rely on row_cap matching the
            // allocation.
            buf_mut.row_cap = 42;
            provider
                .device()
                .inner()
                .htod_sync_copy_into(&[42u32], &mut buf_mut.d_num_rows)
                .expect("htod row count");
        }

        assert!(store.contains("new_rel"));
        assert_eq!(
            device_row_count(&provider, store.get("new_rel").unwrap()),
            42
        );
    }

    #[test]
    fn test_put_replaces_existing() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        let buffer1 = make_buffer(&provider, Schema::new(vec![]), 10);
        let buffer2 = make_buffer(&provider, Schema::new(vec![]), 20);

        store.put("test", buffer1);
        assert_eq!(device_row_count(&provider, store.get("test").unwrap()), 10);

        store.put("test", buffer2);
        assert_eq!(device_row_count(&provider, store.get("test").unwrap()), 20);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_names_iterator() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        for name in ["alpha", "beta", "gamma"] {
            store.put(name, empty_nullary(&provider));
        }

        // Iteration order is unspecified, so compare sorted.
        let mut names: Vec<&str> = store.names().collect();
        names.sort();

        assert_eq!(names, vec!["alpha", "beta", "gamma"]);
    }

    #[test]
    fn test_multiple_operations() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        for name in ["a", "b", "c"] {
            store.put(name, empty_nullary(&provider));
        }
        assert_eq!(store.len(), 3);

        store.remove("b");
        assert_eq!(store.len(), 2);
        assert!(!store.contains("b"));

        store.put("a", make_buffer(&provider, Schema::new(vec![]), 50));
        assert_eq!(store.len(), 2);
        assert_eq!(device_row_count(&provider, store.get("a").unwrap()), 50);

        store.clear();
        assert!(store.is_empty());
    }

    #[test]
    fn test_put_assigns_increasing_versions() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        assert_eq!(store.version("rel"), None);

        store.put("rel", empty_nullary(&provider));
        assert_eq!(store.version("rel"), Some(1));

        store.put("rel", empty_nullary(&provider));
        assert_eq!(store.version("rel"), Some(2));
    }

    #[test]
    fn test_mutable_access_bumps_version() {
        let Some((mut store, provider)) = setup_store() else {
            return;
        };
        store.put("rel", empty_nullary(&provider));

        // Shared access leaves the version alone.
        assert!(store.get("rel").is_some());
        assert_eq!(store.version("rel"), Some(1));

        // Any mutable borrow counts as a modification.
        assert!(store.get_mut("rel").is_some());
        assert_eq!(store.version("rel"), Some(2));

        let (_, version) = store.get_with_version("rel").expect("relation exists");
        assert_eq!(version, 2);
    }

    #[test]
    fn test_get_or_insert_empty_versions() {
        let Some((mut store, _provider)) = setup_store() else {
            return;
        };
        let schema = test_schema();

        store.get_or_insert_empty("fresh", &schema).expect("insert");
        assert_eq!(store.version("fresh"), Some(1));

        // Re-fetching an existing relation through the shared accessor is not a change.
        store.get_or_insert_empty("fresh", &schema).expect("existing");
        assert_eq!(store.version("fresh"), Some(1));

        store.get_or_insert_empty_mut("fresh", &schema).expect("existing");
        assert_eq!(store.version("fresh"), Some(2));

        // Creating through the mutable accessor counts the insertion and the borrow.
        store.get_or_insert_empty_mut("fresh_mut", &schema).expect("insert");
        assert_eq!(store.version("fresh_mut"), Some(2));
    }
}