1use std::collections::HashMap;
4
5use xlog_core::{AggOp, RelId, Result, ScalarType, Schema, XlogError};
6use xlog_cuda::provider::NESTED_LOOP_TOTAL_THRESHOLD;
7use xlog_cuda::{CudaBuffer, JoinType as CudaJoinType};
8use xlog_ir::{JoinType, ProjectExpr, RirNode};
9
10use crate::ilp_registry::{read_device_row_count, IlpMask, IlpTagEntry, IlpTaggedResult};
11
12use super::join_cache::{estimate_join_index_bytes, JoinIndexKey};
13use super::Executor;
14
15fn eligible_for_nested_loop(
40 left: &CudaBuffer,
41 right: &CudaBuffer,
42 left_keys: &[usize],
43 right_keys: &[usize],
44 join_type: JoinType,
45) -> bool {
46 if join_type != JoinType::Inner {
47 return false;
48 }
49 if left_keys.len() != 1 || right_keys.len() != 1 {
50 return false;
51 }
52 let lt = left.schema().column_type(left_keys[0]);
53 let rt = right.schema().column_type(right_keys[0]);
54 lt == rt && matches!(lt, Some(ScalarType::U32) | Some(ScalarType::Symbol))
55}
56
57fn is_join_index_mismatch(err: &XlogError) -> bool {
58 matches!(
59 err,
60 XlogError::Kernel(msg)
61 if msg.contains("Join index row count does not match right relation")
62 || msg.contains("Join index key columns do not match requested right_keys")
63 )
64}
65
66impl Executor {
67 pub(super) fn execute_scan(&mut self, rel: RelId) -> Result<CudaBuffer> {
69 let name = self
70 .get_rel_name(rel)
71 .ok_or_else(|| XlogError::Execution(format!("Unknown relation: RelId({})", rel.0)))?;
72
73 let buffer = self
74 .store
75 .get(name)
76 .ok_or_else(|| XlogError::Execution(format!("Relation not found: {}", name)))?;
77
78 self.stats.record_access(rel);
79 self.stats.update_cardinality(rel, buffer.num_rows());
80 self.stats.update_byte_size(rel, buffer.estimated_bytes());
81
82 self.clone_buffer(buffer)
83 }
84
85 pub fn execute_node(&mut self, node: &RirNode) -> Result<CudaBuffer> {
99 if !self.common_subexpression_enabled() || !Self::is_common_subexpression_cacheable(node) {
100 return self.execute_node_uncached(node);
101 }
102
103 let Some(key) = self.common_subexpression_key(node) else {
104 return self.execute_node_uncached(node);
105 };
106
107 if self.common_subexpression_cache.contains_key(&key) {
108 let cached = self
109 .common_subexpression_cache
110 .remove(&key)
111 .expect("cache key checked above");
112 let result = self.clone_buffer(&cached)?;
113 self.common_subexpression_cache.insert(key, cached);
114 self.common_subexpression_stats.hits =
115 self.common_subexpression_stats.hits.saturating_add(1);
116 return Ok(result);
117 }
118
119 self.common_subexpression_stats.misses =
120 self.common_subexpression_stats.misses.saturating_add(1);
121 let result = self.execute_node_uncached(node)?;
122 let cached = self.clone_buffer(&result)?;
123 self.common_subexpression_cache.insert(key, cached);
124 Ok(result)
125 }
126
127 fn execute_node_uncached(&mut self, node: &RirNode) -> Result<CudaBuffer> {
128 match node {
129 RirNode::Unit => {
130 let mut d_num_rows = self.provider.memory().alloc::<u32>(1)?;
132 self.provider
133 .htod_launch_metadata_sync_copy_into(&[1u32], &mut d_num_rows)
134 .map_err(|e| {
135 XlogError::Kernel(format!("Failed to create unit row count: {}", e))
136 })?;
137 Ok(CudaBuffer::from_columns(
138 Vec::new(),
139 1,
140 d_num_rows,
141 Schema::new(vec![]),
142 ))
143 }
144
145 RirNode::Scan { rel } => {
146 let start = self.profiler.start_op();
147 let result = self.execute_scan(*rel)?;
148 if let Some(start) = start {
149 let mem = self.provider.memory().allocated_bytes();
150 self.profiler
151 .record_op("scan", 0, result.num_rows(), start, mem);
152 self.profiler.record_peak_memory(mem);
153 }
154 Ok(result)
155 }
156
157 RirNode::Filter { input, predicate } => {
158 let input_buf = self.execute_node(input)?;
159 let input_rows = input_buf.num_rows();
160 let start = self.profiler.start_op();
161 let result = self.execute_filter(&input_buf, predicate)?;
162 if let Some(start) = start {
163 let mem = self.provider.memory().allocated_bytes();
164 self.profiler
165 .record_op("filter", input_rows, result.num_rows(), start, mem);
166 self.profiler.record_peak_memory(mem);
167 }
168 Ok(result)
169 }
170
171 RirNode::Project { input, columns } => {
172 let input_buf = self.execute_node(input)?;
173 let input_rows = input_buf.num_rows();
174 let start = self.profiler.start_op();
175 let result = self.execute_project(&input_buf, columns)?;
176 if let Some(start) = start {
177 let mem = self.provider.memory().allocated_bytes();
178 self.profiler
179 .record_op("project", input_rows, result.num_rows(), start, mem);
180 self.profiler.record_peak_memory(mem);
181 }
182 Ok(result)
183 }
184
185 RirNode::Join {
186 left,
187 right,
188 left_keys,
189 right_keys,
190 join_type,
191 } => {
192 let left_rel = match left.as_ref() {
193 RirNode::Scan { rel } => Some(*rel),
194 _ => None,
195 };
196 let right_rel = match right.as_ref() {
197 RirNode::Scan { rel } => Some(*rel),
198 _ => None,
199 };
200 let left_buf = self.execute_node(left)?;
201 let right_buf = self.execute_node(right)?;
202 let input_rows = left_buf.num_rows() + right_buf.num_rows();
203 let start = self.profiler.start_op();
204 let result = self.execute_join(
205 &left_buf, &right_buf, left_keys, right_keys, *join_type, left_rel, right_rel,
206 )?;
207 if let Some(start) = start {
208 let mem = self.provider.memory().allocated_bytes();
209 self.profiler
210 .record_op("join", input_rows, result.num_rows(), start, mem);
211 self.profiler.record_peak_memory(mem);
212 }
213 Ok(result)
214 }
215
216 RirNode::GroupBy {
217 input,
218 key_cols,
219 aggs,
220 } => {
221 if let Some(fused) =
226 self.try_dispatch_wcoj_groupby_root_agg(input, key_cols, aggs)?
227 {
228 return Ok(fused);
229 }
230 let input_buf = self.execute_node(input)?;
231 let input_rows = input_buf.num_rows();
232 let start = self.profiler.start_op();
233 let result = self.execute_groupby(&input_buf, key_cols, aggs)?;
234 if let Some(start) = start {
235 let mem = self.provider.memory().allocated_bytes();
236 self.profiler
237 .record_op("groupby", input_rows, result.num_rows(), start, mem);
238 self.profiler.record_peak_memory(mem);
239 }
240 Ok(result)
241 }
242
243 RirNode::Union { inputs } => {
244 let mut buffers = Vec::with_capacity(inputs.len());
245 let mut input_rows = 0u64;
246 for input in inputs {
247 let buf = self.execute_node(input)?;
248 input_rows += buf.num_rows();
249 buffers.push(buf);
250 }
251 let start = self.profiler.start_op();
252 let result = self.execute_union(&buffers)?;
253 if let Some(start) = start {
254 let mem = self.provider.memory().allocated_bytes();
255 self.profiler
256 .record_op("union", input_rows, result.num_rows(), start, mem);
257 self.profiler.record_peak_memory(mem);
258 }
259 Ok(result)
260 }
261
262 RirNode::Distinct { input, key_cols } => {
263 let input_buf = self.execute_node(input)?;
264 let input_rows = input_buf.num_rows();
265 let start = self.profiler.start_op();
266 let result = self.execute_distinct(&input_buf, key_cols)?;
267 if let Some(start) = start {
268 let mem = self.provider.memory().allocated_bytes();
269 self.profiler
270 .record_op("dedup", input_rows, result.num_rows(), start, mem);
271 self.profiler.record_peak_memory(mem);
272 }
273 Ok(result)
274 }
275
276 RirNode::Diff { left, right } => {
277 let left_buf = self.execute_node(left)?;
278 let right_buf = self.execute_node(right)?;
279 let input_rows = left_buf.num_rows() + right_buf.num_rows();
280 let start = self.profiler.start_op();
281 let result = self.execute_diff(&left_buf, &right_buf)?;
282 if let Some(start) = start {
283 let mem = self.provider.memory().allocated_bytes();
284 self.profiler
285 .record_op("diff", input_rows, result.num_rows(), start, mem);
286 self.profiler.record_peak_memory(mem);
287 }
288 Ok(result)
289 }
290
291 RirNode::Fixpoint {
292 scc_id,
293 base,
294 recursive,
295 delta_rel,
296 full_rel,
297 } => {
298 self.execute_fixpoint(*scc_id, base, recursive, *delta_rel, *full_rel)
300 }
301 RirNode::TensorMaskedJoin {
302 mask_name,
303 schema_size,
304 left_keys,
305 right_keys,
306 rel_index,
307 head_rel_name,
308 max_active_rules,
309 head_projection,
310 ..
311 } => self.execute_tensor_masked_join(
312 mask_name,
313 *schema_size,
314 left_keys,
315 right_keys,
316 rel_index,
317 head_rel_name,
318 *max_active_rules,
319 head_projection,
320 ),
321 RirNode::MultiWayJoin { fallback, .. } | RirNode::ChainJoin { fallback, .. } => {
328 self.execute_node(fallback)
329 }
330 }
331 }
332
333 #[allow(clippy::too_many_arguments)]
337 fn execute_join(
338 &mut self,
339 left: &CudaBuffer,
340 right: &CudaBuffer,
341 left_keys: &[usize],
342 right_keys: &[usize],
343 join_type: JoinType,
344 left_rel: Option<RelId>,
345 right_rel: Option<RelId>,
346 ) -> Result<CudaBuffer> {
347 let cuda_join_type = match join_type {
350 JoinType::Inner => CudaJoinType::Inner,
351 JoinType::Semi => CudaJoinType::Semi,
352 JoinType::Anti => CudaJoinType::Anti,
353 JoinType::LeftOuter => CudaJoinType::LeftOuter,
354 };
355
356 let mut out: Option<CudaBuffer> = None;
361
362 if eligible_for_nested_loop(left, right, left_keys, right_keys, join_type) {
374 let num_left = self.provider.device_row_count(left)? as u64;
375 let num_right = self.provider.device_row_count(right)? as u64;
376 let in_threshold = num_left
377 .checked_mul(num_right)
378 .map(|p| p <= NESTED_LOOP_TOTAL_THRESHOLD)
379 .unwrap_or(false);
380 if in_threshold {
381 out = Some(self.provider.nested_loop_join_v2_inner_u32_1key(
382 left,
383 right,
384 left_keys[0],
385 right_keys[0],
386 )?);
387 self.nested_loop_dispatch_count += 1;
388 }
389 }
390
391 if out.is_none() && self.config.resolved_persistent_hash_indexes() {
397 if let Some(build_rel) = right_rel {
398 let build_heat = self
399 .stats
400 .get_relation_stats(build_rel)
401 .map(|s| s.heat)
402 .unwrap_or(0.0);
403 let est_index_bytes = estimate_join_index_bytes(right, right_keys);
404 let budget_bytes = self.provider.memory().budget().device_bytes;
405 let remaining_bytes = self.provider.memory().remaining_bytes();
406
407 let should_index = self.join_index_cache.should_build(
408 est_index_bytes,
409 build_heat,
410 remaining_bytes,
411 budget_bytes,
412 );
413
414 if let Some(build_name) = self.get_rel_name(build_rel).map(|s| s.to_string()) {
415 if let Some(version) = self.store.version(&build_name) {
416 let key = JoinIndexKey::new(
417 build_rel,
418 version,
419 right_keys.to_vec(),
420 right.schema(),
421 self.provider.device().ordinal() as u32,
422 );
423
424 let indexed_result = {
425 self.join_index_cache.get(&key).map(|index| {
426 self.provider.hash_join_v2_with_index(
427 left,
428 right,
429 left_keys,
430 right_keys,
431 cuda_join_type,
432 index,
433 None,
434 )
435 })
436 };
437 if let Some(indexed_result) = indexed_result {
438 match indexed_result {
439 Ok(joined) => out = Some(joined),
440 Err(err) if is_join_index_mismatch(&err) => {
441 self.join_index_cache.remove_stale(&key);
442 }
443 Err(err) => return Err(err),
444 }
445 } else if should_index {
446 let background_build = self
447 .config
448 .resolved_persistent_hash_index_background_build();
449 if background_build {
450 self.join_index_cache.record_background_build_request();
451 }
452 let build_result = if background_build {
453 self.provider
454 .build_join_index_v2_background(right, right_keys)
455 } else {
456 self.provider.build_join_index_v2(right, right_keys)
457 };
458 match build_result {
459 Ok(index) => {
460 if background_build {
461 self.join_index_cache.record_background_build_complete();
462 self.join_index_cache.insert(key, index);
463 self.join_index_cache.record_background_build_deferred();
464 if let Some(stats) =
465 self.stats.get_relation_stats_mut(build_rel)
466 {
467 stats.has_index = true;
468 }
469 } else {
470 match self.provider.hash_join_v2_with_index(
471 left,
472 right,
473 left_keys,
474 right_keys,
475 cuda_join_type,
476 &index,
477 None,
478 ) {
479 Ok(joined) => {
480 self.join_index_cache.insert(key, index);
481 if let Some(stats) =
482 self.stats.get_relation_stats_mut(build_rel)
483 {
484 stats.has_index = true;
485 }
486 out = Some(joined);
487 }
488 Err(err) if is_join_index_mismatch(&err) => {}
489 Err(err) => return Err(err),
490 }
491 }
492 }
493 Err(_) => {
494 }
496 }
497 }
498 }
499 }
500 }
501 } let out = match out {
504 Some(buf) => buf,
505 None => {
506 self.provider
507 .hash_join_v2(left, right, left_keys, right_keys, cuda_join_type)?
508 }
509 };
510
511 if let (Some(l), Some(r)) = (left_rel, right_rel) {
512 let input_rows = left.num_rows().saturating_mul(right.num_rows());
513 self.record_adaptive_join_observation(
514 l,
515 r,
516 left_keys,
517 right_keys,
518 input_rows,
519 out.num_rows(),
520 );
521 self.stats.record_join_result(
522 l,
523 r,
524 left_keys.to_vec(),
525 right_keys.to_vec(),
526 input_rows,
527 out.num_rows(),
528 );
529 }
530
531 Ok(out)
532 }
533
534 fn execute_groupby(
538 &self,
539 input: &CudaBuffer,
540 key_cols: &[usize],
541 aggs: &[(usize, AggOp)],
542 ) -> Result<CudaBuffer> {
543 if aggs.is_empty() {
544 return self.provider.dedup(input, key_cols);
546 }
547
548 self.provider.groupby_multi_agg(input, key_cols, aggs)
550 }
551
552 pub(super) fn execute_union(&self, inputs: &[CudaBuffer]) -> Result<CudaBuffer> {
556 if inputs.is_empty() {
557 return self.provider.create_empty_buffer(Schema::new(vec![]));
558 }
559
560 if inputs.len() == 1 {
561 return self.clone_buffer(&inputs[0]);
562 }
563
564 let mut result = self.clone_buffer(&inputs[0])?;
566 for input in inputs.iter().skip(1) {
567 result = self.provider.union_gpu(&result, input)?;
568 }
569
570 Ok(result)
571 }
572
573 pub(super) fn execute_distinct(
577 &self,
578 input: &CudaBuffer,
579 key_cols: &[usize],
580 ) -> Result<CudaBuffer> {
581 self.provider.dedup(input, key_cols)
582 }
583
584 pub(super) fn execute_diff(&self, left: &CudaBuffer, right: &CudaBuffer) -> Result<CudaBuffer> {
588 self.provider.diff_gpu(left, right)
589 }
590
591 #[allow(clippy::too_many_arguments)]
592 fn execute_tensor_masked_join(
593 &mut self,
594 mask_name: &str,
595 schema_size: usize,
596 left_keys: &[usize],
597 right_keys: &[usize],
598 rel_index: &[(RelId, String)],
599 head_rel_name: &str,
600 max_active_rules: usize,
601 head_projection: &[usize],
602 ) -> Result<CudaBuffer> {
603 let ilp_mask = match self.ilp_registry.get_mask(mask_name) {
607 Some(mask) => mask,
608 None => {
609 self.ilp_last_result = Some(IlpTaggedResult {
610 entries: Vec::new(),
611 });
612 let schema = self
614 .store
615 .get(head_rel_name)
616 .map(|buf| buf.schema().clone())
617 .ok_or_else(|| {
618 XlogError::Execution(format!(
619 "TensorMaskedJoin: head relation '{}' not found in store \
620 (was load_facts_into_store called?)",
621 head_rel_name
622 ))
623 })?;
624 return self.provider.create_empty_buffer(schema);
625 }
626 };
627
628 let start = self.profiler.start_op();
629
630 let head_k = rel_index
631 .iter()
632 .position(|(_, name)| name == head_rel_name)
633 .ok_or_else(|| {
634 XlogError::Execution(format!(
635 "TensorMaskedJoin: head relation '{}' not found in rel_index",
636 head_rel_name
637 ))
638 })? as u32;
639
640 let mut tag_entries: Vec<IlpTagEntry> = Vec::new();
641 let mut process_rule = |i: u32,
642 j: u32,
643 k: u32,
644 strict_candidate_idx: Option<usize>,
645 strict_flags: Option<&CudaBuffer>|
646 -> Result<()> {
647 if k != head_k {
648 return Ok(());
649 }
650
651 let (_, left_name) = &rel_index[i as usize];
652 let (_, right_name) = &rel_index[j as usize];
653
654 let left_buf = match self.store.get(left_name) {
655 Some(buf) if buf.arity() > 0 => buf,
656 _ => return Ok(()),
657 };
658 let right_buf = match self.store.get(right_name) {
659 Some(buf) if buf.arity() > 0 => buf,
660 _ => return Ok(()),
661 };
662
663 let left_max_key = left_keys.iter().copied().max().unwrap_or(0);
670 let right_max_key = right_keys.iter().copied().max().unwrap_or(0);
671 if left_buf.arity() <= left_max_key || right_buf.arity() <= right_max_key {
672 return Ok(());
673 }
674
675 let joined = self.provider.hash_join_v2(
676 left_buf,
677 right_buf,
678 left_keys,
679 right_keys,
680 CudaJoinType::Inner,
681 )?;
682
683 let projected = if !head_projection.is_empty() && head_projection.len() < joined.arity()
687 {
688 let proj_exprs: Vec<ProjectExpr> = head_projection
689 .iter()
690 .map(|&col| ProjectExpr::Column(col))
691 .collect();
692 self.execute_project(&joined, &proj_exprs)?
693 } else {
694 joined
695 };
696
697 let projected = if let (Some(candidate_idx), Some(active_flags)) =
698 (strict_candidate_idx, strict_flags)
699 {
700 self.provider.filter_buffer_by_candidate_flag(
701 &projected,
702 active_flags,
703 candidate_idx,
704 )?
705 } else {
706 projected
707 };
708
709 let num_rows = read_device_row_count(&self.provider, &projected)? as u32;
711
712 if num_rows > 0 {
713 tag_entries.push(IlpTagEntry {
714 i,
715 j,
716 k,
717 num_rows,
718 buffer: Some(projected),
719 });
720 }
721 Ok(())
722 };
723
724 let active_rule_count = match ilp_mask {
725 IlpMask::Dense { hard, soft, .. } => {
726 let active_rules = self.provider.extract_active_rule_indices(
727 hard,
728 soft,
729 schema_size,
730 max_active_rules,
731 )?;
732 let count = active_rules.len() as u64;
733 for &(i, j, k) in &active_rules {
734 process_rule(i, j, k, None, None)?;
735 }
736 count
737 }
738 IlpMask::Sparse { active_entries, .. } => {
739 let limit = max_active_rules.min(active_entries.len());
740 for &(i, j, k) in &active_entries[..limit] {
741 process_rule(i, j, k, None, None)?;
742 }
743 limit as u64
744 }
745 IlpMask::SparseDevice {
746 candidate_order,
747 active_flags,
748 selected_count,
749 ..
750 } => {
751 if *selected_count > 0 {
752 for (candidate_idx, &(i, j, k)) in candidate_order.iter().enumerate() {
753 process_rule(i, j, k, Some(candidate_idx), Some(active_flags))?;
754 }
755 }
756 (*selected_count).min(max_active_rules) as u64
757 }
758 };
759
760 let mut bufs_by_k: HashMap<u32, Vec<&CudaBuffer>> = HashMap::new();
762 for entry in &tag_entries {
763 if let Some(ref buf) = entry.buffer {
764 bufs_by_k.entry(entry.k).or_default().push(buf);
765 }
766 }
767
768 for (k, buffers) in bufs_by_k {
769 let (_, target_name) = &rel_index[k as usize];
770
771 let union_buf = if buffers.len() == 1 {
773 let empty = self
775 .provider
776 .create_empty_buffer(buffers[0].schema().clone())?;
777 self.provider.union_gpu(buffers[0], &empty)?
778 } else {
779 let mut acc = self.provider.union_gpu(buffers[0], buffers[1])?;
780 for buf in &buffers[2..] {
781 acc = self.provider.union_gpu(&acc, buf)?;
782 }
783 acc
784 };
785
786 if let Some(existing) = self.store.get(target_name) {
788 let delta = self.provider.diff_gpu(&union_buf, existing)?;
789 if !delta.is_empty() {
790 let merged = self.provider.union_gpu(existing, &delta)?;
791 self.store_put(target_name, merged);
792 }
793 } else {
794 let key_cols: Vec<usize> = (0..union_buf.arity()).collect();
795 let deduped = self.provider.dedup(&union_buf, &key_cols)?;
796 self.store_put(target_name, deduped);
797 }
798 }
799
800 self.ilp_last_result = Some(IlpTaggedResult {
802 entries: tag_entries,
803 });
804
805 if let Some(start) = start {
806 let mem = self.provider.memory().allocated_bytes();
807 self.profiler
808 .record_op("TensorMaskedJoin", 0, active_rule_count, start, mem);
809 }
810
811 let schema = self
813 .store
814 .get(head_rel_name)
815 .map(|buf| buf.schema().clone())
816 .ok_or_else(|| {
817 XlogError::Execution(format!(
818 "TensorMaskedJoin: head relation '{}' not found in store \
819 (was load_facts_into_store called?)",
820 head_rel_name
821 ))
822 })?;
823 self.provider.create_empty_buffer(schema)
824 }
825}