1use std::collections::HashSet;
77
78use xlog_core::{RelId, Result, ScalarType, Schema};
79use xlog_cuda::device_runtime::StreamId;
80use xlog_cuda::provider::NESTED_LOOP_TOTAL_THRESHOLD;
81use xlog_cuda::wcoj_metadata::{Wcoj4CycleRootAggValue, WcojRootAggValue};
82use xlog_cuda::CudaBuffer;
83use xlog_cuda::JoinType as CudaJoinType;
84use xlog_ir::{
85 rir::{KCliqueVariableOrder, MultiwayPlan, ProjectExpr, VariableOrder},
86 CompiledRule, RirNode,
87};
88
89use super::Executor;
90
91#[cfg(feature = "wcoj-phase-timing")]
92use std::time::Instant;
93
94pub const ENV_USE_WCOJ_TRIANGLE_U32: &str = "XLOG_USE_WCOJ_TRIANGLE_U32";
98
99pub(super) fn wcoj_gate_enabled(config_override: Option<bool>) -> bool {
103 if let Some(v) = config_override {
104 return v;
105 }
106 std::env::var(ENV_USE_WCOJ_TRIANGLE_U32)
107 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
108 .unwrap_or(false)
109}
110
111pub const ENV_WCOJ_BLOCK_WORK_UNIT: &str = "XLOG_WCOJ_BLOCK_WORK_UNIT";
112pub(super) const WCOJ_BLOCK_WORK_UNIT_DEFAULT: u32 = 1024;
113pub(super) const WCOJ_BLOCK_WORK_UNIT_MAX: u32 = 8192;
114
115pub(super) fn wcoj_block_work_unit() -> u32 {
116 match std::env::var(ENV_WCOJ_BLOCK_WORK_UNIT) {
117 Ok(raw) => match raw.trim().parse::<u32>() {
118 Ok(v @ 1..=WCOJ_BLOCK_WORK_UNIT_MAX) => v,
119 Ok(v) => {
120 eprintln!(
121 "warning: {ENV_WCOJ_BLOCK_WORK_UNIT}={v} is outside 1..={WCOJ_BLOCK_WORK_UNIT_MAX}; \
122 using {WCOJ_BLOCK_WORK_UNIT_DEFAULT}"
123 );
124 WCOJ_BLOCK_WORK_UNIT_DEFAULT
125 }
126 Err(_) => {
127 eprintln!(
128 "warning: {ENV_WCOJ_BLOCK_WORK_UNIT}={raw:?} is not a u32; \
129 using {WCOJ_BLOCK_WORK_UNIT_DEFAULT}"
130 );
131 WCOJ_BLOCK_WORK_UNIT_DEFAULT
132 }
133 },
134 Err(_) => WCOJ_BLOCK_WORK_UNIT_DEFAULT,
135 }
136}
137
138pub(super) fn wcoj_adaptive_enabled(config_override: Option<bool>) -> bool {
139 config_override.unwrap_or(true)
140}
141
142pub const ENV_DISABLE_WCOJ_GROUPBY_FUSION: &str = "XLOG_DISABLE_WCOJ_GROUPBY_FUSION";
146
147pub(super) fn wcoj_groupby_fusion_disabled() -> bool {
148 std::env::var(ENV_DISABLE_WCOJ_GROUPBY_FUSION)
149 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
150 .unwrap_or(false)
151}
152
153pub const ENV_DISABLE_FREE_JOIN: &str = "XLOG_DISABLE_FREE_JOIN";
157
158pub(super) fn free_join_disabled() -> bool {
159 std::env::var(ENV_DISABLE_FREE_JOIN)
160 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
161 .unwrap_or(false)
162}
163
164pub const ENV_DISABLE_FACTORIZED_DELTA: &str = "XLOG_DISABLE_FACTORIZED_DELTA";
168
169pub(super) fn factorized_delta_disabled() -> bool {
170 std::env::var(ENV_DISABLE_FACTORIZED_DELTA)
171 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
172 .unwrap_or(false)
173}
174
175const FACTORIZED_DELTA_DEFAULT_MAX_DOMAIN: u32 = 1 << 14;
179
180fn factorized_delta_max_domain() -> u32 {
181 std::env::var("XLOG_FACTORIZED_DELTA_MAX_DOMAIN")
182 .ok()
183 .and_then(|v| v.parse::<u32>().ok())
184 .unwrap_or(FACTORIZED_DELTA_DEFAULT_MAX_DOMAIN)
185 .min(xlog_cuda::provider::FJ_DELTA_MAX_DOMAIN)
186}
187
/// Byte cap passed as `max_table_bytes` to the sparse factorized-delta kernel.
///
/// Reads `XLOG_FACTORIZED_DELTA_MAX_TABLE_BYTES`. Surrounding whitespace is
/// ignored, consistent with `XLOG_WCOJ_BLOCK_WORK_UNIT`. When the variable is
/// unset or not a valid `u64`, the cap defaults to half of `budget_bytes`.
/// The caller in this module passes the device memory budget.
fn factorized_delta_max_table_bytes(budget_bytes: u64) -> u64 {
    std::env::var("XLOG_FACTORIZED_DELTA_MAX_TABLE_BYTES")
        .ok()
        .and_then(|v| v.trim().parse::<u64>().ok())
        .unwrap_or(budget_bytes / 2)
}
198
/// Divisor used in the dense factorized-delta profitability check.
///
/// The dense kernel is declined when the estimated work is below
/// `n_words / divisor`. Reads `XLOG_FACTORIZED_DELTA_WORK_DIVISOR`, ignoring
/// surrounding whitespace for consistency with `XLOG_WCOJ_BLOCK_WORK_UNIT`.
/// Zero, unset, or unparsable values fall back to 8.
fn factorized_delta_work_divisor() -> u64 {
    std::env::var("XLOG_FACTORIZED_DELTA_WORK_DIVISOR")
        .ok()
        .and_then(|v| v.trim().parse::<u64>().ok())
        // Reject 0: the result is used as a divisor.
        .filter(|&v| v >= 1)
        .unwrap_or(8)
}
209
210#[derive(Default)]
216pub(super) struct FactorizedDeltaCtx {
217 domain_by_key: std::collections::HashMap<(String, RelId), Option<u32>>,
218 static_norm_cache: std::collections::HashMap<(RelId, usize), CudaBuffer>,
219}
220
221pub const ENV_WCOJ_STRICT: &str = "XLOG_WCOJ_STRICT";
228
229pub(super) fn wcoj_strict_errors_enabled() -> bool {
230 std::env::var(ENV_WCOJ_STRICT)
231 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
232 .unwrap_or(false)
233}
234
235pub(super) fn wcoj_decline_on_error(
241 counter: &mut u64,
242 stage: &str,
243 err: xlog_core::XlogError,
244) -> Result<Option<CudaBuffer>> {
245 *counter += 1;
246 if wcoj_strict_errors_enabled() {
247 return Err(err);
248 }
249 eprintln!("warning: WCOJ {stage} pipeline error; declining to binary-join fallback: {err}");
250 Ok(None)
251}
252
253pub const ENV_WCOJ_CHAIN_ENABLE: &str = "XLOG_WCOJ_CHAIN_ENABLE";
257
258pub(super) fn chain_dispatch_enabled() -> bool {
259 std::env::var(ENV_WCOJ_CHAIN_ENABLE)
260 .map(|v| !(v == "0" || v.eq_ignore_ascii_case("false")))
261 .unwrap_or(true)
262}
263
264pub const ENV_USE_WCOJ_4CYCLE: &str = "XLOG_USE_WCOJ_4CYCLE";
278
279pub const ENV_USE_WCOJ_4CYCLE_ADAPTIVE: &str = "XLOG_USE_WCOJ_4CYCLE_ADAPTIVE";
281
282pub const ENV_DISABLE_WCOJ_4CYCLE: &str = "XLOG_DISABLE_WCOJ_4CYCLE";
284
285pub(super) fn wcoj_4cycle_gate_enabled(config_override: Option<bool>) -> bool {
287 if let Some(v) = config_override {
288 return v;
289 }
290 std::env::var(ENV_USE_WCOJ_4CYCLE)
291 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
292 .unwrap_or(false)
293}
294
295pub(super) fn wcoj_4cycle_adaptive_enabled(config_override: Option<bool>) -> bool {
305 if let Some(v) = config_override {
306 return v;
307 }
308 std::env::var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE)
309 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
310 .unwrap_or(false)
311}
312
313pub(super) fn wcoj_4cycle_disabled(config_override: Option<bool>) -> bool {
315 if let Some(v) = config_override {
316 return v;
317 }
318 std::env::var(ENV_DISABLE_WCOJ_4CYCLE)
319 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
320 .unwrap_or(false)
321}
322
/// How a triangle dispatch was authorized.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DispatchMode {
    /// The force gate is on; the cost model is skipped.
    Force,
    /// Dispatch only if the WCOJ cost model approves.
    CostModel,
}
329
330pub(super) struct ChainRirMatch {
333 pub rel_left: RelId,
334 pub rel_right: RelId,
335 pub left_key: usize,
336 pub right_key: usize,
337 pub output_columns: Vec<ProjectExpr>,
338}
339
340pub(super) fn match_chain_join(body: &RirNode) -> Option<ChainRirMatch> {
344 let RirNode::ChainJoin {
345 left,
346 right,
347 left_key,
348 right_key,
349 output_columns,
350 ..
351 } = body
352 else {
353 return None;
354 };
355 if *left_key >= 2 || *right_key >= 2 {
356 return None;
357 }
358 let rel_left = scan_rel(left)?;
359 let rel_right = scan_rel(right)?;
360 Some(ChainRirMatch {
361 rel_left,
362 rel_right,
363 left_key: *left_key,
364 right_key: *right_key,
365 output_columns: output_columns.clone(),
366 })
367}
368
369pub(super) struct TriangleRirMatch {
372 pub rel_xy: RelId,
375 pub rel_yz: RelId,
377 pub rel_xz: RelId,
380}
381
382pub(super) fn match_multiway_triangle(body: &RirNode) -> Option<TriangleRirMatch> {
395 let RirNode::MultiWayJoin {
396 inputs,
397 slot_vars,
398 output_columns,
399 ..
400 } = body
401 else {
402 return None;
403 };
404 if inputs.len() != 3 {
405 return None;
406 }
407 if !slot_vars_match_canonical_triangle(slot_vars) {
408 return None;
409 }
410 if !output_columns_match_canonical_triangle(output_columns) {
411 return None;
412 }
413 let rel_xy = scan_rel(&inputs[0])?;
414 let rel_yz = scan_rel(&inputs[1])?;
415 let rel_xz = scan_rel(&inputs[2])?;
416 Some(TriangleRirMatch {
417 rel_xy,
418 rel_yz,
419 rel_xz,
420 })
421}
422
/// Checks that three binary atoms bind variables as `(a, b), (b, c), (a, c)`.
///
/// The three variables `a`, `b`, and `c` must be pairwise distinct, and every
/// slot must be bound (no `None`).
fn slot_vars_match_canonical_triangle(slot_vars: &[Vec<Option<u32>>]) -> bool {
    let [s0, s1, s2] = slot_vars else {
        return false;
    };
    // A slot qualifies only if it is exactly two bound variables.
    let bound_pair = |slot: &Vec<Option<u32>>| match slot.as_slice() {
        [Some(u), Some(v)] => Some((*u, *v)),
        _ => None,
    };
    let (Some((a, b)), Some((b1, c)), Some((a2, c2))) =
        (bound_pair(s0), bound_pair(s1), bound_pair(s2))
    else {
        return false;
    };
    a != b && b1 == b && c != a && c != b && a2 == a && c2 == c
}
447
448fn output_columns_match_canonical_triangle(cols: &[ProjectExpr]) -> bool {
460 if cols.len() != 3 {
461 return false;
462 }
463 let cols_pattern = (
464 matches!(cols[0], ProjectExpr::Column(0)),
465 matches!(cols[1], ProjectExpr::Column(1)) || matches!(cols[1], ProjectExpr::Column(2)),
466 matches!(cols[2], ProjectExpr::Column(3)),
467 );
468 cols_pattern == (true, true, true)
469}
470
471pub(super) struct FourCycleRirMatch {
479 pub rel_e1: RelId,
480 pub rel_e2: RelId,
481 pub rel_e3: RelId,
482 pub rel_e4: RelId,
483}
484
485pub(super) fn match_multiway_4cycle(body: &RirNode) -> Option<FourCycleRirMatch> {
493 let RirNode::MultiWayJoin {
494 inputs,
495 slot_vars,
496 output_columns,
497 ..
498 } = body
499 else {
500 return None;
501 };
502 if inputs.len() != 4 {
503 return None;
504 }
505 if !slot_vars_match_canonical_4cycle(slot_vars) {
506 return None;
507 }
508 if !output_columns_match_canonical_4cycle(output_columns) {
509 return None;
510 }
511 let rel_e1 = scan_rel(&inputs[0])?;
512 let rel_e2 = scan_rel(&inputs[1])?;
513 let rel_e3 = scan_rel(&inputs[2])?;
514 let rel_e4 = scan_rel(&inputs[3])?;
515 Some(FourCycleRirMatch {
516 rel_e1,
517 rel_e2,
518 rel_e3,
519 rel_e4,
520 })
521}
522
/// Checks that four binary atoms bind variables as
/// `(a, b), (b, c), (c, d), (d, a)`.
///
/// The four variables must be pairwise distinct, and every slot must be
/// bound (no `None`).
fn slot_vars_match_canonical_4cycle(slot_vars: &[Vec<Option<u32>>]) -> bool {
    if slot_vars.len() != 4 || slot_vars.iter().any(|s| s.len() != 2) {
        return false;
    }
    // Collect the (from, to) variable pair of each fully bound slot.
    let mut edges = [(0u32, 0u32); 4];
    for (edge, slot) in edges.iter_mut().zip(slot_vars) {
        match (slot[0], slot[1]) {
            (Some(u), Some(v)) => *edge = (u, v),
            _ => return false,
        }
    }
    let [(a, b), (b1, c), (c1, d), (d2, a2)] = edges;
    let distinct = a != b && c != a && c != b && d != a && d != b && d != c;
    let chained = b1 == b && c1 == c && d2 == d && a2 == a;
    distinct && chained
}
553
554fn output_columns_match_canonical_4cycle(cols: &[ProjectExpr]) -> bool {
564 if cols.len() != 4 {
565 return false;
566 }
567 let exact = |idx: usize, want: usize| matches!(cols[idx], ProjectExpr::Column(c) if c == want);
568 let default_layout = exact(0, 0) && exact(1, 1) && exact(2, 3) && exact(3, 5);
570 let alt_layout = exact(0, 5) && exact(1, 0) && exact(2, 1) && exact(3, 3);
572 default_layout || alt_layout
573}
574
575fn scan_rel(node: &RirNode) -> Option<RelId> {
580 match node {
581 RirNode::Scan { rel } => Some(*rel),
582 _ => None,
583 }
584}
585
586#[derive(Debug, Clone, Copy, PartialEq, Eq)]
590pub(super) enum WcojKeyWidth {
591 FourByte,
592 EightByte,
593}
594
595fn classify_two_col_wcoj_width(buf: &CudaBuffer) -> Option<WcojKeyWidth> {
609 if buf.arity() != 2 {
610 return None;
611 }
612 let c0 = buf.schema.column_type(0)?;
613 let c1 = buf.schema.column_type(1)?;
614 let w0 = scalar_wcoj_width(c0)?;
615 let w1 = scalar_wcoj_width(c1)?;
616 if w0 != w1 {
617 return None;
618 }
619 Some(w0)
620}
621
622fn scalar_wcoj_width(ty: xlog_core::ScalarType) -> Option<WcojKeyWidth> {
623 match ty {
624 xlog_core::ScalarType::U32 | xlog_core::ScalarType::Symbol => Some(WcojKeyWidth::FourByte),
625 xlog_core::ScalarType::U64 => Some(WcojKeyWidth::EightByte),
626 _ => None,
627 }
628}
629
630fn feedback_pair_from_var_order(
658 slot_rels: &[RelId],
659 var_order: Option<&VariableOrder>,
660) -> Option<(RelId, RelId, Vec<usize>, Vec<usize>)> {
661 if slot_rels.len() < 2 {
662 return None;
663 }
664 let Some(vo) = var_order else {
665 return Some((slot_rels[0], slot_rels[1], vec![1], vec![0]));
669 };
670 let leader_idx = vo.leader_idx as usize;
671 match slot_rels.len() {
672 3 => {
673 match leader_idx {
675 0 => Some((slot_rels[0], slot_rels[1], vec![1], vec![0])),
676 1 => {
677 Some((slot_rels[1], slot_rels[2], vec![1], vec![1]))
681 }
682 2 => {
683 Some((slot_rels[2], slot_rels[1], vec![1], vec![1]))
687 }
688 _ => None,
689 }
690 }
691 4 => {
692 if leader_idx >= 4 {
695 return None;
696 }
697 let slot1_input_idx = (leader_idx + 1) % 4;
698 Some((
699 slot_rels[leader_idx],
700 slot_rels[slot1_input_idx],
701 vec![1],
702 vec![0],
703 ))
704 }
705 _ => None,
706 }
707}
708
709fn perm_indices_from_kernel_output_cols(cols: &[ProjectExpr]) -> Result<Vec<usize>> {
710 let mut out = Vec::with_capacity(cols.len());
711 for c in cols {
712 match c {
713 ProjectExpr::Column(idx) => out.push(*idx),
714 other => {
715 return Err(xlog_core::XlogError::Kernel(format!(
716 "perm_indices_from_kernel_output_cols: \
717 kernel_output_cols must be ProjectExpr::Column(_), got {:?}",
718 other
719 )));
720 }
721 }
722 }
723 Ok(out)
724}
725
726fn build_triangle_head_schema(buf_xy: &CudaBuffer, buf_yz: &CudaBuffer) -> Result<Schema> {
731 let x_type = buf_xy.schema.column_type(0).ok_or_else(|| {
732 xlog_core::XlogError::Kernel("build_triangle_head_schema: e_xy.col0 type missing".into())
733 })?;
734 let y_type = buf_xy.schema.column_type(1).ok_or_else(|| {
735 xlog_core::XlogError::Kernel("build_triangle_head_schema: e_xy.col1 type missing".into())
736 })?;
737 let z_type = buf_yz.schema.column_type(1).ok_or_else(|| {
738 xlog_core::XlogError::Kernel("build_triangle_head_schema: e_yz.col1 type missing".into())
739 })?;
740 Schema::new(vec![
741 ("col0".to_string(), x_type),
742 ("col1".to_string(), y_type),
743 ("col2".to_string(), z_type),
744 ])
745 .with_sort_labels(vec![
746 buf_xy
747 .schema
748 .column_sort_label(0)
749 .unwrap_or("col0")
750 .to_string(),
751 buf_xy
752 .schema
753 .column_sort_label(1)
754 .unwrap_or("col1")
755 .to_string(),
756 buf_yz
757 .schema
758 .column_sort_label(1)
759 .unwrap_or("col2")
760 .to_string(),
761 ])
762 .map_err(xlog_core::XlogError::Kernel)
763}
764
765fn build_4cycle_head_schema(
768 buf_e1: &CudaBuffer,
769 buf_e2: &CudaBuffer,
770 buf_e3: &CudaBuffer,
771) -> Result<Schema> {
772 let w_type = buf_e1.schema.column_type(0).ok_or_else(|| {
776 xlog_core::XlogError::Kernel("build_4cycle_head_schema: e_wx.col0 type missing".into())
777 })?;
778 let x_type = buf_e1.schema.column_type(1).ok_or_else(|| {
779 xlog_core::XlogError::Kernel("build_4cycle_head_schema: e_wx.col1 type missing".into())
780 })?;
781 let y_type = buf_e2.schema.column_type(1).ok_or_else(|| {
782 xlog_core::XlogError::Kernel("build_4cycle_head_schema: e_xy.col1 type missing".into())
783 })?;
784 let z_type = buf_e3.schema.column_type(1).ok_or_else(|| {
785 xlog_core::XlogError::Kernel("build_4cycle_head_schema: e_yz.col1 type missing".into())
786 })?;
787 let _: ScalarType = w_type;
791 Schema::new(vec![
792 ("col0".to_string(), w_type),
793 ("col1".to_string(), x_type),
794 ("col2".to_string(), y_type),
795 ("col3".to_string(), z_type),
796 ])
797 .with_sort_labels(vec![
798 buf_e1
799 .schema
800 .column_sort_label(0)
801 .unwrap_or("col0")
802 .to_string(),
803 buf_e1
804 .schema
805 .column_sort_label(1)
806 .unwrap_or("col1")
807 .to_string(),
808 buf_e2
809 .schema
810 .column_sort_label(1)
811 .unwrap_or("col2")
812 .to_string(),
813 buf_e3
814 .schema
815 .column_sort_label(1)
816 .unwrap_or("col3")
817 .to_string(),
818 ])
819 .map_err(xlog_core::XlogError::Kernel)
820}
821
822impl Executor {
823 pub(super) fn try_dispatch_wcoj_triangle(
834 &mut self,
835 rule: &CompiledRule,
836 ) -> Result<Option<CudaBuffer>> {
837 self.try_dispatch_wcoj_triangle_on_body(&rule.body)
840 }
841
842 fn wcoj_output_rows(buf: &CudaBuffer) -> Option<u64> {
849 buf.cached_row_count().map(u64::from)
852 }
853
854 fn record_wcoj_feedback(
907 &mut self,
908 slot_rels: &[RelId],
909 var_order: Option<&VariableOrder>,
910 output_rows: Option<u64>,
911 ) {
912 if slot_rels.len() < 2 {
913 return;
914 }
915 let Some(out_rows) = output_rows else {
916 return;
917 };
918 let Some((rel_a, rel_b, left_keys, right_keys)) =
925 feedback_pair_from_var_order(slot_rels, var_order)
926 else {
927 return;
928 };
929 let card_a = self
930 .stats
931 .get_relation_stats(rel_a)
932 .map(|s| s.cardinality)
933 .filter(|c| *c > 0);
934 let card_b = self
935 .stats
936 .get_relation_stats(rel_b)
937 .map(|s| s.cardinality)
938 .filter(|c| *c > 0);
939 let (Some(a), Some(b)) = (card_a, card_b) else {
940 return;
941 };
942 let input_rows = a.saturating_mul(b);
943 self.stats
946 .record_join_result(rel_a, rel_b, left_keys, right_keys, input_rows, out_rows);
947 }
948
/// Attempts to evaluate a `MultiWayJoin` body with a single WCOJ triangle
/// pipeline.
///
/// Returns `Ok(Some(buf))` with the triangle output on success. Returns
/// `Ok(None)` whenever dispatch is declined:
/// - dispatch is disabled by config or env;
/// - the body is not the canonical triangle;
/// - a relation name or buffer is missing;
/// - the inputs do not share a 4- or 8-byte key width;
/// - no device runtime or launch stream is available;
/// - the cost model rejects the dispatch.
///
/// Pipeline errors go through `wcoj_decline_on_error`, which yields `Err`
/// only in strict mode. On success the dispatch counter is bumped and the
/// observed row count is fed back into join statistics.
pub(super) fn try_dispatch_wcoj_triangle_on_body(
    &mut self,
    body: &RirNode,
) -> Result<Option<CudaBuffer>> {
    #[cfg(feature = "wcoj-phase-timing")]
    let wall_start = Instant::now();
    // Hard kill switch: checked before every force/adaptive setting.
    if self.config.wcoj_triangle_dispatch_disabled.unwrap_or(false) {
        return Ok(None);
    }
    // Gate precedence: force-on (config, else env) > explicit config off >
    // adaptive cost model (on unless explicitly disabled).
    let force_override = self.config.wcoj_triangle_dispatch;
    let force_on = wcoj_gate_enabled(force_override);
    let mode = if force_on {
        DispatchMode::Force
    } else {
        let force_explicit_off = matches!(force_override, Some(false));
        if force_explicit_off {
            return Ok(None);
        }
        let adaptive_override = self.config.wcoj_triangle_dispatch_adaptive;
        if wcoj_adaptive_enabled(adaptive_override) {
            DispatchMode::CostModel
        } else {
            return Ok(None);
        }
    };

    let Some(matched) = match_multiway_triangle(body) else {
        return Ok(None);
    };

    // Resolve names, then materialized buffers; any miss declines instead of
    // erroring.
    let name_xy = match self.get_rel_name(matched.rel_xy) {
        Some(s) => s.to_string(),
        None => return Ok(None),
    };
    let name_yz = match self.get_rel_name(matched.rel_yz) {
        Some(s) => s.to_string(),
        None => return Ok(None),
    };
    let name_xz = match self.get_rel_name(matched.rel_xz) {
        Some(s) => s.to_string(),
        None => return Ok(None),
    };

    let buf_xy = match self.store.get(&name_xy) {
        Some(b) => b,
        None => return Ok(None),
    };
    let buf_yz = match self.store.get(&name_yz) {
        Some(b) => b,
        None => return Ok(None),
    };
    let buf_xz = match self.store.get(&name_xz) {
        Some(b) => b,
        None => return Ok(None),
    };
    // All three inputs must be binary relations with one common key width.
    let width = match (
        classify_two_col_wcoj_width(buf_xy),
        classify_two_col_wcoj_width(buf_yz),
        classify_two_col_wcoj_width(buf_xz),
    ) {
        (Some(a), Some(b), Some(c)) if a == b && b == c => a,
        _ => return Ok(None),
    };

    // No device runtime or no launch stream: decline.
    if self.provider.memory().runtime().is_none() {
        return Ok(None);
    }
    let launch_stream = match self.wcoj_dispatch_stream_or_init() {
        Some(s) => s,
        None => return Ok(None),
    };

    #[cfg(feature = "wcoj-phase-timing")]
    let mut classifier_ms: f32 = 0.0;
    // Force mode bypasses the cost model entirely.
    if mode == DispatchMode::CostModel {
        #[cfg(feature = "wcoj-phase-timing")]
        let cls_start = Instant::now();
        let model = super::wcoj_cost_model::build_wcoj_cost_model(&self.config);
        let slot_rels = [matched.rel_xy, matched.rel_yz, matched.rel_xz];
        let ctx = super::wcoj_cost_model::WcojDispatchCtx {
            stats: &self.stats,
            launch_stream,
            width,
            slot_rels: &slot_rels,
        };
        let dispatch = model.should_dispatch_triangle(&ctx);
        #[cfg(feature = "wcoj-phase-timing")]
        {
            classifier_ms = cls_start.elapsed().as_secs_f64() as f32 * 1000.0;
        }
        if !dispatch {
            return Ok(None);
        }
    }

    // Optional planner leader order. `None` makes the pipeline use the
    // canonical (xy, yz, xz) slot order.
    let var_order_opt: Option<&VariableOrder> = match body {
        RirNode::MultiWayJoin { var_order, .. } => var_order.as_ref(),
        _ => None,
    };

    // Per-slot layout build times (ms). NOTE(review): only the canonical
    // (no var_order) path of `run_wcoj_triangle_pipeline` fills these; with
    // a var_order they stay 0.0 in the reported timing.
    #[cfg(feature = "wcoj-phase-timing")]
    let mut layout_times: [f32; 3] = [0.0; 3];
    let dispatch_result = self.run_wcoj_triangle_pipeline(
        buf_xy,
        buf_yz,
        buf_xz,
        launch_stream,
        width,
        var_order_opt,
        #[cfg(feature = "wcoj-phase-timing")]
        &mut layout_times,
    );
    match dispatch_result {
        Ok(buf) => {
            // Statistics feedback is skipped when no row count is cached.
            let output_rows = Self::wcoj_output_rows(&buf);
            let slot_rels = [matched.rel_xy, matched.rel_yz, matched.rel_xz];
            self.record_wcoj_feedback(&slot_rels, var_order_opt, output_rows);
            self.wcoj_triangle_dispatch_count += 1;
            #[cfg(feature = "wcoj-phase-timing")]
            {
                let triangle_timing = self
                    .provider
                    .take_wcoj_triangle_phase_timing()
                    .unwrap_or_default();
                let wall_ms = wall_start.elapsed().as_secs_f64() as f32 * 1000.0;
                let timing = super::wcoj_phase_timing::WcojDispatchPhaseTiming::new(
                    classifier_ms,
                    layout_times[0],
                    layout_times[1],
                    layout_times[2],
                    triangle_timing,
                    wall_ms,
                );
                // A poisoned lock silently drops this sample.
                if let Ok(mut g) = self.last_wcoj_phase_timing.lock() {
                    *g = Some(timing);
                }
            }
            Ok(Some(buf))
        }
        Err(err) => wcoj_decline_on_error(&mut self.wcoj_error_decline_count, "triangle", err),
    }
}
1153
/// Builds WCOJ layouts for the three triangle inputs and launches the
/// triangle kernel on `launch_stream`.
///
/// When a `var_order` is given, this delegates to
/// `run_wcoj_triangle_pipeline_with_leader_order` and returns before
/// `layout_times_ms` is touched. Otherwise the inputs are used in canonical
/// `(xy, yz, xz)` order. 4-byte keys use the u32 `_hg_` kernel with
/// `wcoj_block_work_unit()`, and only that path calls
/// `record_wcoj_triangle_hg_dispatch`. 8-byte keys use the u64 kernel.
///
/// Under the `wcoj-phase-timing` feature, the wall time in ms of each
/// successful layout build is written to `layout_times_ms[slot]`.
#[allow(clippy::too_many_arguments)]
fn run_wcoj_triangle_pipeline(
    &self,
    buf_xy: &CudaBuffer,
    buf_yz: &CudaBuffer,
    buf_xz: &CudaBuffer,
    launch_stream: StreamId,
    width: WcojKeyWidth,
    var_order: Option<&VariableOrder>,
    #[cfg(feature = "wcoj-phase-timing")] layout_times_ms: &mut [f32; 3],
) -> Result<CudaBuffer> {
    if let Some(vo) = var_order {
        return self.run_wcoj_triangle_pipeline_with_leader_order(
            buf_xy,
            buf_yz,
            buf_xz,
            launch_stream,
            width,
            vo,
        );
    }
    // Runs one layout build and stores its elapsed time in `slot`. A failed
    // build returns early via `?` without recording a time.
    #[cfg(feature = "wcoj-phase-timing")]
    let mut time_layout =
        |f: &dyn Fn() -> Result<CudaBuffer>, slot: usize| -> Result<CudaBuffer> {
            let s = Instant::now();
            let r = f()?;
            layout_times_ms[slot] = s.elapsed().as_secs_f64() as f32 * 1000.0;
            Ok(r)
        };
    match width {
        WcojKeyWidth::FourByte => {
            #[cfg(feature = "wcoj-phase-timing")]
            let (layout_xy, layout_yz, layout_xz) = {
                let xy = time_layout(
                    &|| {
                        self.provider
                            .wcoj_layout_u32_recorded(buf_xy, launch_stream)
                    },
                    0,
                )?;
                let yz = time_layout(
                    &|| {
                        self.provider
                            .wcoj_layout_u32_recorded(buf_yz, launch_stream)
                    },
                    1,
                )?;
                let xz = time_layout(
                    &|| {
                        self.provider
                            .wcoj_layout_u32_recorded(buf_xz, launch_stream)
                    },
                    2,
                )?;
                (xy, yz, xz)
            };
            // Without the timing feature, the same layouts are built directly.
            #[cfg(not(feature = "wcoj-phase-timing"))]
            let layout_xy = self
                .provider
                .wcoj_layout_u32_recorded(buf_xy, launch_stream)?;
            #[cfg(not(feature = "wcoj-phase-timing"))]
            let layout_yz = self
                .provider
                .wcoj_layout_u32_recorded(buf_yz, launch_stream)?;
            #[cfg(not(feature = "wcoj-phase-timing"))]
            let layout_xz = self
                .provider
                .wcoj_layout_u32_recorded(buf_xz, launch_stream)?;
            let out = self.provider.wcoj_triangle_hg_u32_recorded(
                &layout_xy,
                &layout_yz,
                &layout_xz,
                wcoj_block_work_unit(),
                launch_stream,
            )?;
            self.provider.record_wcoj_triangle_hg_dispatch();
            Ok(out)
        }
        WcojKeyWidth::EightByte => {
            #[cfg(feature = "wcoj-phase-timing")]
            let (layout_xy, layout_yz, layout_xz) = {
                let xy = time_layout(
                    &|| {
                        self.provider
                            .wcoj_layout_u64_recorded(buf_xy, launch_stream)
                    },
                    0,
                )?;
                let yz = time_layout(
                    &|| {
                        self.provider
                            .wcoj_layout_u64_recorded(buf_yz, launch_stream)
                    },
                    1,
                )?;
                let xz = time_layout(
                    &|| {
                        self.provider
                            .wcoj_layout_u64_recorded(buf_xz, launch_stream)
                    },
                    2,
                )?;
                (xy, yz, xz)
            };
            #[cfg(not(feature = "wcoj-phase-timing"))]
            let layout_xy = self
                .provider
                .wcoj_layout_u64_recorded(buf_xy, launch_stream)?;
            #[cfg(not(feature = "wcoj-phase-timing"))]
            let layout_yz = self
                .provider
                .wcoj_layout_u64_recorded(buf_yz, launch_stream)?;
            #[cfg(not(feature = "wcoj-phase-timing"))]
            let layout_xz = self
                .provider
                .wcoj_layout_u64_recorded(buf_xz, launch_stream)?;
            // The u64 kernel takes no work-unit parameter and is not counted
            // as an `_hg_` dispatch.
            self.provider.wcoj_triangle_u64_recorded(
                &layout_xy,
                &layout_yz,
                &layout_xz,
                launch_stream,
            )
        }
    }
}
1294
1295 fn run_wcoj_triangle_pipeline_with_leader_order(
1315 &self,
1316 buf_xy: &CudaBuffer,
1317 buf_yz: &CudaBuffer,
1318 buf_xz: &CudaBuffer,
1319 launch_stream: StreamId,
1320 width: WcojKeyWidth,
1321 var_order: &VariableOrder,
1322 ) -> Result<CudaBuffer> {
1323 let canonical: [&CudaBuffer; 3] = [buf_xy, buf_yz, buf_xz];
1324 let slot_inputs = self.prepare_leader_inputs(&canonical, var_order, launch_stream)?;
1325 if slot_inputs.len() != 3 {
1326 return Err(xlog_core::XlogError::Kernel(
1327 "run_wcoj_triangle_pipeline_with_leader_order: prepare_leader_inputs must return 3 slots"
1328 .to_string(),
1329 ));
1330 }
1331
1332 let head_schema = build_triangle_head_schema(buf_xy, buf_yz)?;
1338 let perm = perm_indices_from_kernel_output_cols(&var_order.kernel_output_cols)?;
1339
1340 let kernel_out: CudaBuffer = match width {
1341 WcojKeyWidth::FourByte => {
1342 let l0 = self
1343 .provider
1344 .wcoj_layout_u32_recorded(&slot_inputs[0], launch_stream)?;
1345 let l1 = self
1346 .provider
1347 .wcoj_layout_u32_recorded(&slot_inputs[1], launch_stream)?;
1348 let l2 = self
1349 .provider
1350 .wcoj_layout_u32_recorded(&slot_inputs[2], launch_stream)?;
1351 let out = self.provider.wcoj_triangle_hg_u32_recorded(
1352 &l0,
1353 &l1,
1354 &l2,
1355 wcoj_block_work_unit(),
1356 launch_stream,
1357 )?;
1358 self.provider.record_wcoj_triangle_hg_dispatch();
1359 out
1360 }
1361 WcojKeyWidth::EightByte => {
1362 let l0 = self
1363 .provider
1364 .wcoj_layout_u64_recorded(&slot_inputs[0], launch_stream)?;
1365 let l1 = self
1366 .provider
1367 .wcoj_layout_u64_recorded(&slot_inputs[1], launch_stream)?;
1368 let l2 = self
1369 .provider
1370 .wcoj_layout_u64_recorded(&slot_inputs[2], launch_stream)?;
1371 self.provider
1372 .wcoj_triangle_u64_recorded(&l0, &l1, &l2, launch_stream)?
1373 }
1374 };
1375
1376 self.provider.wcoj_project_output_columns_recorded(
1377 &kernel_out,
1378 &perm,
1379 head_schema,
1380 launch_stream,
1381 )
1382 }
1383
1384 pub fn wcoj_triangle_dispatch_count(&self) -> u64 {
1390 self.wcoj_triangle_dispatch_count
1391 }
1392
1393 pub fn wcoj_error_decline_count(&self) -> u64 {
1400 self.wcoj_error_decline_count
1401 }
1402
1403 pub fn free_join_dispatch_count(&self) -> u64 {
1406 self.free_join_dispatch_count
1407 }
1408
1409 pub fn factorized_delta_dispatch_count(&self) -> u64 {
1412 self.factorized_delta_dispatch_count
1413 }
1414
1415 fn factorized_delta_normalize_static(
1419 &self,
1420 buf: &CudaBuffer,
1421 key_col: usize,
1422 launch_stream: StreamId,
1423 ) -> Result<CudaBuffer> {
1424 if key_col == 0 {
1425 return self.provider.wcoj_layout_u32_recorded(buf, launch_stream);
1426 }
1427 let ty = |i: usize| {
1428 buf.schema().column_type(i).ok_or_else(|| {
1429 xlog_core::XlogError::Execution(format!(
1430 "factorized-delta: static column {i} type missing"
1431 ))
1432 })
1433 };
1434 let swapped = Schema::new(vec![("k".to_string(), ty(1)?), ("v".to_string(), ty(0)?)]);
1435 let projected = self.provider.wcoj_project_output_columns_recorded(
1436 buf,
1437 &[1, 0],
1438 swapped,
1439 launch_stream,
1440 )?;
1441 self.provider
1442 .wcoj_layout_u32_recorded(&projected, launch_stream)
1443 }
1444
/// Attempts to compute a binary `ChainJoin` rule body with the factorized-delta
/// kernels. Presumably this runs once per semi-naive iteration; the kernels
/// are named `*_novel_*` and take the delta, the static relation, and the full
/// `head_pred` buffer — confirm at the call site.
///
/// Any failed precondition returns `Ok(None)` so the caller uses its normal
/// join path:
/// - the feature is not disabled via `XLOG_DISABLE_FACTORIZED_DELTA`;
/// - `node` is a `ChainJoin` of two scans, exactly one of which is `delta_rel`;
/// - both join keys are column 0 or 1;
/// - the output projects exactly the delta's non-key column and the static
///   side's non-key column, in either order;
/// - the delta, static, and `head_pred` buffers exist and are binary with
///   `U32`/`Symbol` columns;
/// - a device runtime and a launch stream are available;
/// - neither the delta nor the static input is empty.
///
/// Kernel choice: when the ID domain (max ID + 1 over all three buffers) is
/// at most `factorized_delta_max_domain()`, the dense kernel is used unless
/// the work estimate says it would not pay off. Otherwise the sparse kernel
/// runs with a table-byte cap and may itself return `None`.
///
/// `ctx` caches the domain decision per `(head_pred, static_rel)`. For static
/// relations not in `recursive_preds`, it also caches their normalized
/// layout. Provider errors go through `wcoj_decline_on_error`.
pub(super) fn try_dispatch_factorized_delta(
    &mut self,
    node: &RirNode,
    delta_rel: RelId,
    head_pred: &str,
    recursive_preds: &HashSet<String>,
    ctx: &mut FactorizedDeltaCtx,
) -> Result<Option<CudaBuffer>> {
    use xlog_cuda::provider::FjDeltaCols;

    if factorized_delta_disabled() {
        return Ok(None);
    }
    let RirNode::ChainJoin {
        left,
        right,
        left_key,
        right_key,
        output_columns,
        ..
    } = node
    else {
        return Ok(None);
    };
    let (RirNode::Scan { rel: left_rel }, RirNode::Scan { rel: right_rel }) =
        (left.as_ref(), right.as_ref())
    else {
        return Ok(None);
    };
    // Exactly one side must be the delta relation; delta self-joins and
    // joins without the delta both decline.
    let delta_on_left = match (*left_rel == delta_rel, *right_rel == delta_rel) {
        (true, false) => true,
        (false, true) => false,
        _ => return Ok(None),
    };
    let (delta_key, static_rel, static_key) = if delta_on_left {
        (*left_key, *right_rel, *right_key)
    } else {
        (*right_key, *left_rel, *left_key)
    };
    if delta_key > 1 || static_key > 1 {
        return Ok(None);
    }
    // In a binary relation the non-key column is the other one.
    let delta_carry = 1 - delta_key;
    let static_value = 1 - static_key;

    // Column offsets of each side in the join's combined column space (left
    // side at 0, right side at 2).
    let (delta_off, static_off) = if delta_on_left { (0, 2) } else { (2, 0) };
    let carry_global = delta_off + delta_carry;
    let value_global = static_off + static_value;
    let [ProjectExpr::Column(out0), ProjectExpr::Column(out1)] = output_columns.as_slice()
    else {
        return Ok(None);
    };
    // Output positions of the carried delta column and the static value
    // column. Any other projection declines.
    let (r_carry, r_value) = if (*out0, *out1) == (carry_global, value_global) {
        (0, 1)
    } else if (*out0, *out1) == (value_global, carry_global) {
        (1, 0)
    } else {
        return Ok(None);
    };

    // The kernels accept only binary relations with 4-byte columns.
    let binary_u32_class = |buf: &CudaBuffer| {
        buf.arity() == 2
            && (0..2).all(|i| {
                matches!(
                    buf.schema().column_type(i),
                    Some(ScalarType::U32) | Some(ScalarType::Symbol)
                )
            })
    };
    let Some(delta_name) = self.get_rel_name(delta_rel).map(str::to_string) else {
        return Ok(None);
    };
    let Some(static_name) = self.get_rel_name(static_rel).map(str::to_string) else {
        return Ok(None);
    };
    let Some(delta_buf) = self.store.get(&delta_name) else {
        return Ok(None);
    };
    let Some(static_buf) = self.store.get(&static_name) else {
        return Ok(None);
    };
    let Some(full_buf) = self.store.get(head_pred) else {
        return Ok(None);
    };
    if !binary_u32_class(delta_buf)
        || !binary_u32_class(static_buf)
        || !binary_u32_class(full_buf)
    {
        return Ok(None);
    }
    if self.provider.memory().runtime().is_none() {
        return Ok(None);
    }
    let Some(launch_stream) = self.wcoj_dispatch_stream_or_init() else {
        return Ok(None);
    };

    // The ID domain is computed on the first call for this (head, static)
    // key and cached in `ctx`. A cached `None` declines immediately.
    // NOTE(review): later calls do not re-scan the delta/full buffers, so
    // this assumes IDs cannot grow past the first observed max — confirm.
    let domain_key = (head_pred.to_string(), static_rel);
    let domain = match ctx.domain_by_key.get(&domain_key) {
        Some(Some(d)) => *d,
        Some(None) => return Ok(None),
        None => {
            let max_id = match self.provider.fj_delta_columns_max_u32(
                &[
                    (delta_buf, &[0, 1][..]),
                    (static_buf, &[0, 1][..]),
                    (full_buf, &[0, 1][..]),
                ],
                launch_stream,
            ) {
                Ok(m) => m,
                Err(err) => {
                    return wcoj_decline_on_error(
                        &mut self.wcoj_error_decline_count,
                        "factorized-delta",
                        err,
                    );
                }
            };
            // `max_id + 1` would overflow at u32::MAX, so that domain is unrepresentable.
            let decided = if max_id == u32::MAX {
                None
            } else {
                Some(max_id + 1)
            };
            ctx.domain_by_key.insert(domain_key, decided);
            match decided {
                Some(d) => d,
                None => return Ok(None),
            }
        }
    };

    let n_delta = u64::from(self.buffer_row_count(delta_buf)?);
    let n_static = u64::from(self.buffer_row_count(static_buf)?);
    if n_delta == 0 || n_static == 0 {
        return Ok(None);
    }

    // Dense profitability check: decline when the estimated join work is
    // small compared with the dense table size (ceil(domain / 32) 32-bit
    // words per ID, presumably a domain x domain bit matrix).
    let dense = domain <= factorized_delta_max_domain();
    if dense {
        let n_words = u64::from(domain.div_ceil(32)) * u64::from(domain);
        let work_est = n_delta.saturating_mul((n_static / u64::from(domain)).max(1));
        if work_est < n_words / factorized_delta_work_divisor() {
            return Ok(None);
        }
    }

    // A recursive static relation is re-normalized on every call, presumably
    // because its contents change between iterations. A non-recursive one is
    // normalized once and reused from `ctx.static_norm_cache`.
    let static_is_recursive = recursive_preds.contains(&static_name);
    let norm_owned;
    let static_norm: &CudaBuffer = if static_is_recursive {
        norm_owned =
            match self.factorized_delta_normalize_static(static_buf, static_key, launch_stream)
            {
                Ok(b) => b,
                Err(err) => {
                    return wcoj_decline_on_error(
                        &mut self.wcoj_error_decline_count,
                        "factorized-delta",
                        err,
                    );
                }
            };
        &norm_owned
    } else {
        match ctx.static_norm_cache.entry((static_rel, static_key)) {
            std::collections::hash_map::Entry::Occupied(e) => &*e.into_mut(),
            std::collections::hash_map::Entry::Vacant(v) => {
                let norm = match self.factorized_delta_normalize_static(
                    static_buf,
                    static_key,
                    launch_stream,
                ) {
                    Ok(b) => b,
                    Err(err) => {
                        return wcoj_decline_on_error(
                            &mut self.wcoj_error_decline_count,
                            "factorized-delta",
                            err,
                        );
                    }
                };
                &*v.insert(norm)
            }
        }
    };

    let cols = FjDeltaCols {
        delta_carry,
        delta_key,
        r_carry,
        r_value,
    };
    if dense {
        match self.provider.fj_delta_novel_u32_recorded(
            delta_buf,
            static_norm,
            full_buf,
            cols,
            domain,
            launch_stream,
        ) {
            Ok(novel) => {
                self.factorized_delta_dispatch_count += 1;
                Ok(Some(novel))
            }
            Err(err) => wcoj_decline_on_error(
                &mut self.wcoj_error_decline_count,
                "factorized-delta",
                err,
            ),
        }
    } else {
        // The sparse kernel's table is capped relative to the device memory budget.
        let max_table_bytes =
            factorized_delta_max_table_bytes(self.provider.memory().budget().device_bytes);
        match self.provider.fj_delta_sparse_novel_u32_recorded(
            delta_buf,
            static_norm,
            full_buf,
            cols,
            max_table_bytes,
            launch_stream,
        ) {
            Ok(Some(novel)) => {
                self.factorized_delta_dispatch_count += 1;
                Ok(Some(novel))
            }
            // The kernel declined (e.g. table cap); not counted as an error.
            Ok(None) => Ok(None),
            Err(err) => wcoj_decline_on_error(
                &mut self.wcoj_error_decline_count,
                "factorized-delta",
                err,
            ),
        }
    }
}
1729
1730 pub(super) fn try_dispatch_free_join(&mut self, node: &RirNode) -> Result<Option<CudaBuffer>> {
1746 use xlog_cuda::provider::{FjNode, FjPlan, FjSubAtom};
1747
1748 if free_join_disabled() {
1749 return Ok(None);
1750 }
1751 let RirNode::MultiWayJoin {
1752 inputs,
1753 slot_vars,
1754 output_columns,
1755 plan,
1756 ..
1757 } = node
1758 else {
1759 return Ok(None);
1760 };
1761 if !matches!(plan, Some(MultiwayPlan::FreeJoin)) {
1774 return Ok(None);
1775 }
1776 if inputs.len() < 3 {
1777 return Ok(None);
1778 }
1779 let mut bufs: Vec<&CudaBuffer> = Vec::with_capacity(inputs.len());
1784 let mut all_u32 = true;
1785 let mut all_u64 = true;
1786 for input in inputs {
1787 let RirNode::Scan { rel } = input else {
1788 return Ok(None);
1789 };
1790 let name = match self.get_rel_name(*rel) {
1791 Some(s) => s.to_string(),
1792 None => return Ok(None),
1793 };
1794 let Some(buf) = self.store.get(&name) else {
1795 return Ok(None);
1796 };
1797 for i in 0..buf.arity() {
1798 match buf.schema().column_type(i) {
1799 Some(ScalarType::U32 | ScalarType::Symbol) => all_u64 = false,
1800 Some(ScalarType::U64) => all_u32 = false,
1801 _ => return Ok(None),
1802 }
1803 }
1804 bufs.push(buf);
1805 }
1806 if !all_u32 && !all_u64 {
1807 return Ok(None);
1808 }
1809 {
1816 let slot_rels: Vec<RelId> = inputs
1817 .iter()
1818 .filter_map(|i| match i {
1819 RirNode::Scan { rel } => Some(*rel),
1820 _ => None,
1821 })
1822 .collect();
1823 let model = super::wcoj_cost_model::build_wcoj_cost_model(&self.config);
1824 let width = if all_u32 {
1825 WcojKeyWidth::FourByte
1826 } else {
1827 WcojKeyWidth::EightByte
1828 };
1829 let ctx = super::wcoj_cost_model::WcojDispatchCtx {
1830 stats: &self.stats,
1831 launch_stream: StreamId::DEFAULT,
1832 width,
1833 slot_rels: &slot_rels,
1834 };
1835 if model.factorized_loss_veto(&ctx) {
1836 return Ok(None);
1837 }
1838 }
1839 let mut class_to_var: Vec<u32> = Vec::new();
1841 let mut dense = |class: u32| -> usize {
1842 match class_to_var.iter().position(|c| *c == class) {
1843 Some(i) => i,
1844 None => {
1845 class_to_var.push(class);
1846 class_to_var.len() - 1
1847 }
1848 }
1849 };
1850 let mut atom_vars: Vec<Vec<usize>> = Vec::with_capacity(slot_vars.len());
1851 for (i, cols) in slot_vars.iter().enumerate() {
1852 if cols.len() != bufs[i].arity() {
1853 return Ok(None);
1854 }
1855 let mut vars = Vec::with_capacity(cols.len());
1856 for c in cols {
1857 let Some(class) = c else { return Ok(None) };
1858 vars.push(dense(*class));
1859 }
1860 atom_vars.push(vars);
1861 }
1862 let num_vars = class_to_var.len();
1863 let order: Vec<usize> = {
1878 let slot_rels: Vec<RelId> = inputs
1879 .iter()
1880 .filter_map(|i| match i {
1881 RirNode::Scan { rel } => Some(*rel),
1882 _ => None,
1883 })
1884 .collect();
1885 let cards: Vec<u64> = bufs.iter().map(|b| b.num_rows()).collect();
1886 let model = super::wcoj_cost_model::build_wcoj_cost_model(&self.config);
1887 let width = if all_u32 {
1888 WcojKeyWidth::FourByte
1889 } else {
1890 WcojKeyWidth::EightByte
1891 };
1892 let ctx = super::wcoj_cost_model::WcojDispatchCtx {
1893 stats: &self.stats,
1894 launch_stream: StreamId::DEFAULT,
1895 width,
1896 slot_rels: &slot_rels,
1897 };
1898 match model.plan_free_join_order(&ctx, &atom_vars, &cards) {
1899 super::wcoj_cost_model::FjOrderDecision::Decline => return Ok(None),
1900 super::wcoj_cost_model::FjOrderDecision::Reorder(o) => o,
1901 super::wcoj_cost_model::FjOrderDecision::KeepDefault => {
1902 (0..atom_vars.len()).collect()
1903 }
1904 }
1905 };
1906 let mut bound_at: Vec<Option<usize>> = vec![None; num_vars]; let mut nodes: Vec<FjNode> = Vec::new();
1913 for &i in &order {
1918 let vars = &atom_vars[i];
1919 let split = vars.iter().take_while(|v| bound_at[**v].is_some()).count();
1920 if vars[split..].iter().any(|v| bound_at[*v].is_some()) {
1921 return Ok(None);
1924 }
1925 if split > 0 {
1926 let probe = FjSubAtom {
1927 input_idx: i,
1928 var_positions: vars[..split].to_vec(),
1929 };
1930 if nodes.is_empty() {
1931 return Ok(None);
1932 }
1933 let target = vars[..split]
1934 .iter()
1935 .map(|v| bound_at[*v].expect("prefix vars are bound"))
1936 .max()
1937 .expect("split > 0");
1938 nodes[target].probes.push(probe);
1939 }
1940 if split < vars.len() {
1941 let cover_vars = vars[split..].to_vec();
1942 let mut seen = HashSet::new();
1943 if !cover_vars.iter().all(|v| seen.insert(*v)) {
1944 return Ok(None);
1945 }
1946 let k = nodes.len();
1947 for v in &cover_vars {
1948 bound_at[*v] = Some(k);
1949 }
1950 nodes.push(FjNode {
1951 cover: FjSubAtom {
1952 input_idx: i,
1953 var_positions: cover_vars,
1954 },
1955 probes: Vec::new(),
1956 });
1957 } else if nodes.is_empty() {
1958 return Ok(None);
1959 }
1960 }
1961 let mut col_to_var: Vec<usize> = Vec::new();
1964 for vars in &atom_vars {
1965 col_to_var.extend(vars.iter().copied());
1966 }
1967 let mut output_vars: Vec<usize> = Vec::with_capacity(output_columns.len());
1968 for oc in output_columns {
1969 let ProjectExpr::Column(c) = oc else {
1970 return Ok(None);
1971 };
1972 let Some(v) = col_to_var.get(*c) else {
1973 return Ok(None);
1974 };
1975 output_vars.push(*v);
1976 }
1977 let fj_plan = FjPlan {
1978 num_vars,
1979 nodes,
1980 output_vars,
1981 };
1982 if self.provider.memory().runtime().is_none() {
1983 return Ok(None);
1984 }
1985 let Some(launch_stream) = self.wcoj_dispatch_stream_or_init() else {
1986 return Ok(None);
1987 };
1988 let outcome = if all_u32 {
1989 self.provider
1990 .free_join_execute_u32_recorded(&bufs, &fj_plan, launch_stream)
1991 } else {
1992 self.provider
1993 .free_join_execute_u64_recorded(&bufs, &fj_plan, launch_stream)
1994 };
1995 match outcome {
1996 Ok(buf) => {
1997 self.free_join_dispatch_count += 1;
1998 Ok(Some(buf))
1999 }
2000 Err(err) => wcoj_decline_on_error(&mut self.wcoj_error_decline_count, "free-join", err),
2001 }
2002 }
2003
2004 fn try_dispatch_free_join_count(
2019 &mut self,
2020 node: &RirNode,
2021 group_cols: &[ProjectExpr],
2022 ) -> Result<Option<CudaBuffer>> {
2023 use xlog_cuda::provider::{FjNode, FjPlan, FjSubAtom};
2024
2025 if free_join_disabled() {
2026 return Ok(None);
2027 }
2028 let RirNode::MultiWayJoin {
2029 inputs,
2030 slot_vars,
2031 plan,
2032 ..
2033 } = node
2034 else {
2035 return Ok(None);
2036 };
2037 if !matches!(plan, Some(MultiwayPlan::FreeJoin)) {
2038 return Ok(None);
2039 }
2040 if inputs.len() < 3 {
2041 return Ok(None);
2042 }
2043 let mut bufs: Vec<&CudaBuffer> = Vec::with_capacity(inputs.len());
2044 for input in inputs {
2045 let RirNode::Scan { rel } = input else {
2046 return Ok(None);
2047 };
2048 let name = match self.get_rel_name(*rel) {
2049 Some(s) => s.to_string(),
2050 None => return Ok(None),
2051 };
2052 let Some(buf) = self.store.get(&name) else {
2053 return Ok(None);
2054 };
2055 let four_byte = (0..buf.arity()).all(|i| {
2056 matches!(
2057 buf.schema().column_type(i),
2058 Some(ScalarType::U32 | ScalarType::Symbol)
2059 )
2060 });
2061 if !four_byte {
2062 return Ok(None);
2063 }
2064 bufs.push(buf);
2065 }
2066 let mut class_to_var: Vec<u32> = Vec::new();
2069 let mut dense = |class: u32| -> usize {
2070 match class_to_var.iter().position(|c| *c == class) {
2071 Some(i) => i,
2072 None => {
2073 class_to_var.push(class);
2074 class_to_var.len() - 1
2075 }
2076 }
2077 };
2078 let mut atom_vars: Vec<Vec<usize>> = Vec::with_capacity(slot_vars.len());
2079 for (i, cols) in slot_vars.iter().enumerate() {
2080 if cols.len() != bufs[i].arity() {
2081 return Ok(None);
2082 }
2083 let mut vars = Vec::with_capacity(cols.len());
2084 for c in cols {
2085 let Some(class) = c else { return Ok(None) };
2086 vars.push(dense(*class));
2087 }
2088 atom_vars.push(vars);
2089 }
2090 let num_vars = class_to_var.len();
2091 let mut col_to_var: Vec<usize> = Vec::new();
2094 for vars in &atom_vars {
2095 col_to_var.extend(vars.iter().copied());
2096 }
2097 let Some(ProjectExpr::Column(key_col)) = group_cols.first() else {
2098 return Ok(None);
2099 };
2100 let Some(&group_var) = col_to_var.get(*key_col) else {
2101 return Ok(None);
2102 };
2103 let mut occurrences = vec![0usize; num_vars];
2107 for vars in &atom_vars {
2108 for &v in vars {
2109 occurrences[v] += 1;
2110 }
2111 }
2112 let mut bound_at: Vec<Option<usize>> = vec![None; num_vars];
2114 let mut nodes: Vec<FjNode> = Vec::new();
2115 for (i, vars) in atom_vars.iter().enumerate() {
2116 let split = vars.iter().take_while(|v| bound_at[**v].is_some()).count();
2117 if vars[split..].iter().any(|v| bound_at[*v].is_some()) {
2118 return Ok(None);
2120 }
2121 let mut keep_end = vars.len();
2122 while keep_end > split {
2123 let v = vars[keep_end - 1];
2124 if occurrences[v] == 1 && v != group_var {
2125 keep_end -= 1;
2126 } else {
2127 break;
2128 }
2129 }
2130 if split == 0 && keep_end == 0 {
2131 return Ok(None);
2135 }
2136 if split > 0 {
2137 let probe = FjSubAtom {
2138 input_idx: i,
2139 var_positions: vars[..split].to_vec(),
2140 };
2141 if nodes.is_empty() {
2142 return Ok(None);
2143 }
2144 let target = vars[..split]
2145 .iter()
2146 .map(|v| bound_at[*v].expect("prefix vars are bound"))
2147 .max()
2148 .expect("split > 0");
2149 nodes[target].probes.push(probe);
2150 }
2151 if split < keep_end {
2152 let cover_vars = vars[split..keep_end].to_vec();
2153 let mut seen = HashSet::new();
2154 if !cover_vars.iter().all(|v| seen.insert(*v)) {
2155 return Ok(None);
2156 }
2157 let k = nodes.len();
2158 for v in &cover_vars {
2159 bound_at[*v] = Some(k);
2160 }
2161 nodes.push(FjNode {
2162 cover: FjSubAtom {
2163 input_idx: i,
2164 var_positions: cover_vars,
2165 },
2166 probes: Vec::new(),
2167 });
2168 } else if nodes.is_empty() {
2169 return Ok(None);
2170 }
2171 }
2172 if bound_at[group_var].is_none() {
2173 return Ok(None);
2174 }
2175 let fj_plan = FjPlan {
2176 num_vars,
2177 nodes,
2178 output_vars: vec![group_var],
2179 };
2180 if self.provider.memory().runtime().is_none() {
2181 return Ok(None);
2182 }
2183 let Some(launch_stream) = self.wcoj_dispatch_stream_or_init() else {
2184 return Ok(None);
2185 };
2186 match self
2187 .provider
2188 .free_join_count_by_root_u32_recorded(&bufs, &fj_plan, launch_stream)
2189 {
2190 Ok(buf) => {
2191 self.free_join_dispatch_count += 1;
2192 self.wcoj_groupby_fusion_dispatch_count += 1;
2193 Ok(Some(buf))
2194 }
2195 Err(err) => {
2196 wcoj_decline_on_error(&mut self.wcoj_error_decline_count, "free-join-count", err)
2197 }
2198 }
2199 }
2200
2201 pub fn wcoj_groupby_fusion_dispatch_count(&self) -> u64 {
2205 self.wcoj_groupby_fusion_dispatch_count
2206 }
2207
2208 pub(super) fn try_dispatch_wcoj_groupby_root_agg(
2225 &mut self,
2226 input: &RirNode,
2227 key_cols: &[usize],
2228 aggs: &[(usize, xlog_core::AggOp)],
2229 ) -> Result<Option<CudaBuffer>> {
2230 use xlog_core::AggOp;
2231 if wcoj_groupby_fusion_disabled() {
2232 return Ok(None);
2233 }
2234 if key_cols != [0] {
2235 return Ok(None);
2236 }
2237 if aggs.len() != 1 {
2238 return Ok(None);
2239 }
2240 let (agg_col, agg_op) = aggs[0];
2241 if !matches!(agg_op, AggOp::Count | AggOp::Sum | AggOp::Min | AggOp::Max) {
2242 return Ok(None);
2243 }
2244 let RirNode::Project {
2245 input: multiway,
2246 columns,
2247 } = input
2248 else {
2249 return Ok(None);
2250 };
2251 if columns.is_empty() || !columns.iter().all(|c| matches!(c, ProjectExpr::Column(_))) {
2253 return Ok(None);
2254 }
2255 let key_is_col0 = matches!(columns[0], ProjectExpr::Column(0));
2260 let agg_value_col = if matches!(agg_op, AggOp::Count) {
2268 None
2269 } else {
2270 match columns.get(agg_col) {
2271 Some(ProjectExpr::Column(c)) if *c >= 1 => Some(*c),
2272 _ => return Ok(None),
2273 }
2274 };
2275 let Some(matched) = match_multiway_triangle(multiway) else {
2276 if key_is_col0 {
2280 if let Some(buf) =
2281 self.try_dispatch_wcoj_groupby_root_agg_4cycle(multiway, agg_op, agg_value_col)?
2282 {
2283 return Ok(Some(buf));
2284 }
2285 }
2286 if !matches!(agg_op, AggOp::Count) {
2291 return Ok(None);
2292 }
2293 if let Some(buf) =
2294 self.try_dispatch_wcoj_groupby_root_count_clique(multiway, columns)?
2295 {
2296 return Ok(Some(buf));
2297 }
2298 return self.try_dispatch_free_join_count(multiway, columns);
2302 };
2303 let agg_value = match agg_value_col {
2306 None => None,
2307 Some(1) => Some(WcojRootAggValue::Y),
2308 Some(2) => Some(WcojRootAggValue::Z),
2309 Some(_) => return Ok(None),
2310 };
2311 if !key_is_col0 {
2312 return Ok(None);
2313 }
2314 let name_xy = match self.get_rel_name(matched.rel_xy) {
2315 Some(s) => s.to_string(),
2316 None => return Ok(None),
2317 };
2318 let name_yz = match self.get_rel_name(matched.rel_yz) {
2319 Some(s) => s.to_string(),
2320 None => return Ok(None),
2321 };
2322 let name_xz = match self.get_rel_name(matched.rel_xz) {
2323 Some(s) => s.to_string(),
2324 None => return Ok(None),
2325 };
2326 let buf_xy = match self.store.get(&name_xy) {
2327 Some(b) => b,
2328 None => return Ok(None),
2329 };
2330 let buf_yz = match self.store.get(&name_yz) {
2331 Some(b) => b,
2332 None => return Ok(None),
2333 };
2334 let buf_xz = match self.store.get(&name_xz) {
2335 Some(b) => b,
2336 None => return Ok(None),
2337 };
2338 let width = match (
2339 classify_two_col_wcoj_width(buf_xy),
2340 classify_two_col_wcoj_width(buf_yz),
2341 classify_two_col_wcoj_width(buf_xz),
2342 ) {
2343 (
2344 Some(WcojKeyWidth::FourByte),
2345 Some(WcojKeyWidth::FourByte),
2346 Some(WcojKeyWidth::FourByte),
2347 ) => WcojKeyWidth::FourByte,
2348 (
2349 Some(WcojKeyWidth::EightByte),
2350 Some(WcojKeyWidth::EightByte),
2351 Some(WcojKeyWidth::EightByte),
2352 ) => WcojKeyWidth::EightByte,
2353 _ => return Ok(None),
2354 };
2355 {
2365 let slot_rels = [matched.rel_xy, matched.rel_yz, matched.rel_xz];
2366 let model = super::wcoj_cost_model::build_wcoj_cost_model(&self.config);
2367 let ctx = super::wcoj_cost_model::WcojDispatchCtx {
2368 stats: &self.stats,
2369 launch_stream: StreamId::DEFAULT,
2370 width,
2371 slot_rels: &slot_rels,
2372 };
2373 if model.factorized_loss_veto(&ctx) {
2374 return Ok(None);
2375 }
2376 }
2377 if matches!(width, WcojKeyWidth::FourByte) {
2384 match agg_value {
2385 Some(WcojRootAggValue::Y) => {
2386 if buf_xy.schema().column_type(1) != Some(xlog_core::ScalarType::U32) {
2387 return Ok(None);
2388 }
2389 }
2390 Some(WcojRootAggValue::Z) => {
2391 if buf_yz.schema().column_type(1) != Some(xlog_core::ScalarType::U32)
2392 || buf_xz.schema().column_type(1) != Some(xlog_core::ScalarType::U32)
2393 {
2394 return Ok(None);
2395 }
2396 }
2397 None => {}
2398 }
2399 }
2400 if self.provider.memory().runtime().is_none() {
2401 return Ok(None);
2402 }
2403 let Some(launch_stream) = self.wcoj_dispatch_stream_or_init() else {
2404 return Ok(None);
2405 };
2406 let result = match (agg_value, width) {
2407 (None, WcojKeyWidth::FourByte) => {
2408 self.provider.wcoj_triangle_groupby_root_count_u32_recorded(
2409 buf_xy,
2410 buf_yz,
2411 buf_xz,
2412 wcoj_block_work_unit(),
2413 launch_stream,
2414 )
2415 }
2416 (None, WcojKeyWidth::EightByte) => {
2417 self.provider.wcoj_triangle_groupby_root_count_u64_recorded(
2418 buf_xy,
2419 buf_yz,
2420 buf_xz,
2421 wcoj_block_work_unit(),
2422 launch_stream,
2423 )
2424 }
2425 (Some(value), WcojKeyWidth::FourByte) => {
2426 self.provider.wcoj_triangle_groupby_root_agg_u32_recorded(
2427 buf_xy,
2428 buf_yz,
2429 buf_xz,
2430 agg_op,
2431 value,
2432 wcoj_block_work_unit(),
2433 launch_stream,
2434 )
2435 }
2436 (Some(value), WcojKeyWidth::EightByte) => {
2439 self.provider.wcoj_triangle_groupby_root_agg_u64_recorded(
2440 buf_xy,
2441 buf_yz,
2442 buf_xz,
2443 agg_op,
2444 value,
2445 wcoj_block_work_unit(),
2446 launch_stream,
2447 )
2448 }
2449 };
2450 match result {
2451 Ok(buf) => {
2452 self.wcoj_groupby_fusion_dispatch_count += 1;
2453 Ok(Some(buf))
2454 }
2455 Err(err) => {
2456 wcoj_decline_on_error(&mut self.wcoj_error_decline_count, "groupby-fusion", err)
2457 }
2458 }
2459 }
2460
2461 fn try_dispatch_wcoj_groupby_root_agg_4cycle(
2493 &mut self,
2494 multiway: &RirNode,
2495 agg_op: xlog_core::AggOp,
2496 agg_value_col: Option<usize>,
2497 ) -> Result<Option<CudaBuffer>> {
2498 use xlog_core::AggOp;
2499 let Some(matched) = match_multiway_4cycle(multiway) else {
2500 return Ok(None);
2501 };
2502 let agg_value = match agg_value_col {
2505 None => None,
2506 Some(1) => Some(Wcoj4CycleRootAggValue::X),
2507 Some(2) => Some(Wcoj4CycleRootAggValue::Y),
2508 Some(3) => Some(Wcoj4CycleRootAggValue::Z),
2509 Some(_) => return Ok(None),
2510 };
2511 let name_e1 = match self.get_rel_name(matched.rel_e1) {
2512 Some(s) => s.to_string(),
2513 None => return Ok(None),
2514 };
2515 let name_e2 = match self.get_rel_name(matched.rel_e2) {
2516 Some(s) => s.to_string(),
2517 None => return Ok(None),
2518 };
2519 let name_e3 = match self.get_rel_name(matched.rel_e3) {
2520 Some(s) => s.to_string(),
2521 None => return Ok(None),
2522 };
2523 let name_e4 = match self.get_rel_name(matched.rel_e4) {
2524 Some(s) => s.to_string(),
2525 None => return Ok(None),
2526 };
2527 let buf_e1 = match self.store.get(&name_e1) {
2528 Some(b) => b,
2529 None => return Ok(None),
2530 };
2531 let buf_e2 = match self.store.get(&name_e2) {
2532 Some(b) => b,
2533 None => return Ok(None),
2534 };
2535 let buf_e3 = match self.store.get(&name_e3) {
2536 Some(b) => b,
2537 None => return Ok(None),
2538 };
2539 let buf_e4 = match self.store.get(&name_e4) {
2540 Some(b) => b,
2541 None => return Ok(None),
2542 };
2543 let width = match (
2544 classify_two_col_wcoj_width(buf_e1),
2545 classify_two_col_wcoj_width(buf_e2),
2546 classify_two_col_wcoj_width(buf_e3),
2547 classify_two_col_wcoj_width(buf_e4),
2548 ) {
2549 (
2550 Some(WcojKeyWidth::FourByte),
2551 Some(WcojKeyWidth::FourByte),
2552 Some(WcojKeyWidth::FourByte),
2553 Some(WcojKeyWidth::FourByte),
2554 ) => WcojKeyWidth::FourByte,
2555 (
2556 Some(WcojKeyWidth::EightByte),
2557 Some(WcojKeyWidth::EightByte),
2558 Some(WcojKeyWidth::EightByte),
2559 Some(WcojKeyWidth::EightByte),
2560 ) => WcojKeyWidth::EightByte,
2561 _ => return Ok(None),
2562 };
2563 if agg_value.is_some() && width != WcojKeyWidth::FourByte {
2566 return Ok(None);
2567 }
2568 let value_source = match agg_value {
2576 None => None,
2577 Some(Wcoj4CycleRootAggValue::X) => Some(buf_e1),
2578 Some(Wcoj4CycleRootAggValue::Y) => Some(buf_e2),
2579 Some(Wcoj4CycleRootAggValue::Z) => Some(buf_e3),
2580 };
2581 if let Some(src) = value_source {
2582 if src.schema().column_type(1) != Some(xlog_core::ScalarType::U32) {
2583 return Ok(None);
2584 }
2585 }
2586 if self.provider.memory().runtime().is_none() {
2587 return Ok(None);
2588 }
2589 let Some(launch_stream) = self.wcoj_dispatch_stream_or_init() else {
2590 return Ok(None);
2591 };
2592 debug_assert!(
2593 agg_value.is_some() || matches!(agg_op, AggOp::Count),
2594 "non-Count aggregates resolve a value column above"
2595 );
2596 let result = match (agg_value, width) {
2597 (None, WcojKeyWidth::FourByte) => {
2598 self.provider.wcoj_4cycle_groupby_root_count_u32_recorded(
2599 buf_e1,
2600 buf_e2,
2601 buf_e3,
2602 buf_e4,
2603 wcoj_block_work_unit(),
2604 launch_stream,
2605 )
2606 }
2607 (None, WcojKeyWidth::EightByte) => {
2611 self.provider.wcoj_4cycle_groupby_root_count_u64_recorded(
2612 buf_e1,
2613 buf_e2,
2614 buf_e3,
2615 buf_e4,
2616 wcoj_block_work_unit(),
2617 launch_stream,
2618 )
2619 }
2620 (Some(value), _) => self.provider.wcoj_4cycle_groupby_root_agg_u32_recorded(
2621 buf_e1,
2622 buf_e2,
2623 buf_e3,
2624 buf_e4,
2625 agg_op,
2626 value,
2627 wcoj_block_work_unit(),
2628 launch_stream,
2629 ),
2630 };
2631 match result {
2632 Ok(buf) => {
2633 self.wcoj_groupby_fusion_dispatch_count += 1;
2634 Ok(Some(buf))
2635 }
2636 Err(err) => wcoj_decline_on_error(
2637 &mut self.wcoj_error_decline_count,
2638 "groupby-fusion-4cycle",
2639 err,
2640 ),
2641 }
2642 }
2643
2644 pub fn wcoj_4cycle_dispatch_count(&self) -> u64 {
2649 self.wcoj_4cycle_dispatch_count
2650 }
2651
2652 pub fn chain_dispatch_count(&self) -> u64 {
2655 self.chain_dispatch_count
2656 }
2657
2658 pub fn nested_loop_dispatch_count(&self) -> u64 {
2665 self.nested_loop_dispatch_count
2666 }
2667
2668 pub(super) fn try_dispatch_chain_on_body(
2679 &mut self,
2680 body: &RirNode,
2681 ) -> Result<Option<CudaBuffer>> {
2682 if !chain_dispatch_enabled() {
2683 return Ok(None);
2684 }
2685 let Some(matched) = match_chain_join(body) else {
2686 return Ok(None);
2687 };
2688
2689 let name_left = match self.get_rel_name(matched.rel_left) {
2690 Some(s) => s.to_string(),
2691 None => return Ok(None),
2692 };
2693 let name_right = match self.get_rel_name(matched.rel_right) {
2694 Some(s) => s.to_string(),
2695 None => return Ok(None),
2696 };
2697 let left = match self.store.get(&name_left) {
2698 Some(buf) => buf,
2699 None => return Ok(None),
2700 };
2701 let right = match self.store.get(&name_right) {
2702 Some(buf) => buf,
2703 None => return Ok(None),
2704 };
2705
2706 let num_left = self.provider.device_row_count(left)? as u64;
2707 let num_right = self.provider.device_row_count(right)? as u64;
2708 let in_threshold = num_left
2709 .checked_mul(num_right)
2710 .map(|p| p <= NESTED_LOOP_TOTAL_THRESHOLD)
2711 .unwrap_or(false);
2712 let four_byte = matches!(
2713 classify_two_col_wcoj_width(left),
2714 Some(WcojKeyWidth::FourByte)
2715 ) && matches!(
2716 classify_two_col_wcoj_width(right),
2717 Some(WcojKeyWidth::FourByte)
2718 );
2719
2720 let mut used_nested_loop = false;
2721 let joined = if four_byte {
2722 let left_sorted = self
2723 .provider
2724 .is_sorted_ascending_u32(left, matched.left_key)
2725 .unwrap_or(false);
2726 let right_sorted = self
2727 .provider
2728 .is_sorted_ascending_u32(right, matched.right_key)
2729 .unwrap_or(false);
2730 if left_sorted && right_sorted {
2731 if in_threshold {
2732 self.provider.sort_merge_join_v2_inner_u32_1key(
2733 left,
2734 right,
2735 matched.left_key,
2736 matched.right_key,
2737 )
2738 } else {
2739 let capacity = usize::try_from(num_left.min(num_right)).unwrap_or(usize::MAX);
2740 self.provider.sort_merge_join_v2_inner_u32_1key_bounded(
2741 left,
2742 right,
2743 matched.left_key,
2744 matched.right_key,
2745 capacity,
2746 )
2747 }
2748 } else if in_threshold {
2749 used_nested_loop = true;
2750 self.provider.nested_loop_join_v2_inner_u32_1key(
2751 left,
2752 right,
2753 matched.left_key,
2754 matched.right_key,
2755 )
2756 } else {
2757 self.provider.hash_join_v2(
2758 left,
2759 right,
2760 &[matched.left_key],
2761 &[matched.right_key],
2762 CudaJoinType::Inner,
2763 )
2764 }
2765 } else {
2766 self.provider.hash_join_v2(
2767 left,
2768 right,
2769 &[matched.left_key],
2770 &[matched.right_key],
2771 CudaJoinType::Inner,
2772 )
2773 };
2774
2775 let joined = match joined {
2776 Ok(buf) => buf,
2777 Err(err) => {
2778 return wcoj_decline_on_error(&mut self.wcoj_error_decline_count, "chain-join", err)
2779 }
2780 };
2781 let projected = match self.execute_project(&joined, &matched.output_columns) {
2782 Ok(buf) => buf,
2783 Err(err) => {
2784 return wcoj_decline_on_error(
2785 &mut self.wcoj_error_decline_count,
2786 "chain-join-project",
2787 err,
2788 )
2789 }
2790 };
2791 self.stats.record_join_result(
2792 matched.rel_left,
2793 matched.rel_right,
2794 vec![matched.left_key],
2795 vec![matched.right_key],
2796 num_left.saturating_mul(num_right),
2797 joined.num_rows(),
2798 );
2799 if used_nested_loop {
2800 self.nested_loop_dispatch_count += 1;
2801 }
2802 self.chain_dispatch_count += 1;
2803 Ok(Some(projected))
2804 }
2805
2806 pub(super) fn try_dispatch_wcoj_4cycle(
2822 &mut self,
2823 rule: &CompiledRule,
2824 ) -> Result<Option<CudaBuffer>> {
2825 self.try_dispatch_wcoj_4cycle_on_body(&rule.body)
2828 }
2829
2830 pub(super) fn try_dispatch_wcoj_4cycle_on_body(
2835 &mut self,
2836 body: &RirNode,
2837 ) -> Result<Option<CudaBuffer>> {
2838 if wcoj_4cycle_disabled(self.config.wcoj_4cycle_dispatch_disabled) {
2840 return Ok(None);
2841 }
2842 let force_override = self.config.wcoj_4cycle_dispatch;
2844 let force_on = wcoj_4cycle_gate_enabled(force_override);
2845 let mode = if force_on {
2846 DispatchMode::Force
2847 } else {
2848 if matches!(force_override, Some(false)) {
2851 return Ok(None);
2852 }
2853 let adaptive_override = self.config.wcoj_4cycle_dispatch_adaptive;
2854 if wcoj_4cycle_adaptive_enabled(adaptive_override) {
2855 DispatchMode::CostModel
2856 } else {
2857 return Ok(None);
2858 }
2859 };
2860
2861 let Some(matched) = match_multiway_4cycle(body) else {
2863 return Ok(None);
2864 };
2865
2866 let name_e1 = match self.get_rel_name(matched.rel_e1) {
2868 Some(s) => s.to_string(),
2869 None => return Ok(None),
2870 };
2871 let name_e2 = match self.get_rel_name(matched.rel_e2) {
2872 Some(s) => s.to_string(),
2873 None => return Ok(None),
2874 };
2875 let name_e3 = match self.get_rel_name(matched.rel_e3) {
2876 Some(s) => s.to_string(),
2877 None => return Ok(None),
2878 };
2879 let name_e4 = match self.get_rel_name(matched.rel_e4) {
2880 Some(s) => s.to_string(),
2881 None => return Ok(None),
2882 };
2883
2884 let buf_e1 = match self.store.get(&name_e1) {
2887 Some(b) => b,
2888 None => return Ok(None),
2889 };
2890 let buf_e2 = match self.store.get(&name_e2) {
2891 Some(b) => b,
2892 None => return Ok(None),
2893 };
2894 let buf_e3 = match self.store.get(&name_e3) {
2895 Some(b) => b,
2896 None => return Ok(None),
2897 };
2898 let buf_e4 = match self.store.get(&name_e4) {
2899 Some(b) => b,
2900 None => return Ok(None),
2901 };
2902 let width = match (
2903 classify_two_col_wcoj_width(buf_e1),
2904 classify_two_col_wcoj_width(buf_e2),
2905 classify_two_col_wcoj_width(buf_e3),
2906 classify_two_col_wcoj_width(buf_e4),
2907 ) {
2908 (Some(a), Some(b), Some(c), Some(d)) if a == b && b == c && c == d => a,
2909 _ => return Ok(None),
2910 };
2911
2912 if self.provider.memory().runtime().is_none() {
2916 return Ok(None);
2917 }
2918 let launch_stream = match self.wcoj_dispatch_stream_or_init() {
2919 Some(s) => s,
2920 None => return Ok(None),
2921 };
2922
2923 if mode == DispatchMode::CostModel {
2926 let model = super::wcoj_cost_model::build_wcoj_cost_model(&self.config);
2928 let slot_rels = [
2929 matched.rel_e1,
2930 matched.rel_e2,
2931 matched.rel_e3,
2932 matched.rel_e4,
2933 ];
2934 let ctx = super::wcoj_cost_model::WcojDispatchCtx {
2935 stats: &self.stats,
2936 launch_stream,
2937 width,
2938 slot_rels: &slot_rels,
2939 };
2940 let dispatch = model.should_dispatch_4cycle(&ctx);
2941 if !dispatch {
2942 return Ok(None);
2943 }
2944 }
2945
2946 let var_order_opt: Option<&VariableOrder> = match body {
2949 RirNode::MultiWayJoin { var_order, .. } => var_order.as_ref(),
2950 _ => None,
2951 };
2952
2953 let dispatch_result = self.run_wcoj_4cycle_pipeline(
2956 buf_e1,
2957 buf_e2,
2958 buf_e3,
2959 buf_e4,
2960 launch_stream,
2961 width,
2962 var_order_opt,
2963 );
2964 match dispatch_result {
2965 Ok(buf) => {
2966 let output_rows = Self::wcoj_output_rows(&buf);
2980 let slot_rels = [
2981 matched.rel_e1,
2982 matched.rel_e2,
2983 matched.rel_e3,
2984 matched.rel_e4,
2985 ];
2986 self.record_wcoj_feedback(&slot_rels, var_order_opt, output_rows);
2987 self.wcoj_4cycle_dispatch_count += 1;
2988 Ok(Some(buf))
2989 }
2990 Err(err) => wcoj_decline_on_error(&mut self.wcoj_error_decline_count, "4-cycle", err),
2991 }
2992 }
2993
2994 #[allow(clippy::too_many_arguments)]
2996 fn run_wcoj_4cycle_pipeline(
2997 &self,
2998 buf_e1: &CudaBuffer,
2999 buf_e2: &CudaBuffer,
3000 buf_e3: &CudaBuffer,
3001 buf_e4: &CudaBuffer,
3002 launch_stream: StreamId,
3003 width: WcojKeyWidth,
3004 var_order: Option<&VariableOrder>,
3005 ) -> Result<CudaBuffer> {
3006 if let Some(vo) = var_order {
3007 return self.run_wcoj_4cycle_pipeline_with_leader_order(
3008 buf_e1,
3009 buf_e2,
3010 buf_e3,
3011 buf_e4,
3012 launch_stream,
3013 width,
3014 vo,
3015 );
3016 }
3017 match width {
3018 WcojKeyWidth::FourByte => {
3019 let layout_e1 = self
3020 .provider
3021 .wcoj_layout_u32_recorded(buf_e1, launch_stream)?;
3022 let layout_e2 = self
3023 .provider
3024 .wcoj_layout_u32_recorded(buf_e2, launch_stream)?;
3025 let layout_e3 = self
3026 .provider
3027 .wcoj_layout_u32_recorded(buf_e3, launch_stream)?;
3028 let layout_e4 = self
3029 .provider
3030 .wcoj_layout_u32_recorded(buf_e4, launch_stream)?;
3031 self.provider.wcoj_4cycle_u32_recorded(
3032 &layout_e1,
3033 &layout_e2,
3034 &layout_e3,
3035 &layout_e4,
3036 launch_stream,
3037 )
3038 }
3039 WcojKeyWidth::EightByte => {
3040 let layout_e1 = self
3041 .provider
3042 .wcoj_layout_u64_recorded(buf_e1, launch_stream)?;
3043 let layout_e2 = self
3044 .provider
3045 .wcoj_layout_u64_recorded(buf_e2, launch_stream)?;
3046 let layout_e3 = self
3047 .provider
3048 .wcoj_layout_u64_recorded(buf_e3, launch_stream)?;
3049 let layout_e4 = self
3050 .provider
3051 .wcoj_layout_u64_recorded(buf_e4, launch_stream)?;
3052 self.provider.wcoj_4cycle_u64_recorded(
3053 &layout_e1,
3054 &layout_e2,
3055 &layout_e3,
3056 &layout_e4,
3057 launch_stream,
3058 )
3059 }
3060 }
3061 }
3062
3063 #[allow(clippy::too_many_arguments)]
3069 fn run_wcoj_4cycle_pipeline_with_leader_order(
3070 &self,
3071 buf_e1: &CudaBuffer,
3072 buf_e2: &CudaBuffer,
3073 buf_e3: &CudaBuffer,
3074 buf_e4: &CudaBuffer,
3075 launch_stream: StreamId,
3076 width: WcojKeyWidth,
3077 var_order: &VariableOrder,
3078 ) -> Result<CudaBuffer> {
3079 let canonical: [&CudaBuffer; 4] = [buf_e1, buf_e2, buf_e3, buf_e4];
3080 let slot_inputs = self.prepare_leader_inputs(&canonical, var_order, launch_stream)?;
3081 if slot_inputs.len() != 4 {
3082 return Err(xlog_core::XlogError::Kernel(
3083 "run_wcoj_4cycle_pipeline_with_leader_order: prepare_leader_inputs must return 4 slots"
3084 .to_string(),
3085 ));
3086 }
3087
3088 let head_schema = build_4cycle_head_schema(buf_e1, buf_e2, buf_e3)?;
3089 let perm = perm_indices_from_kernel_output_cols(&var_order.kernel_output_cols)?;
3090
3091 let kernel_out: CudaBuffer = match width {
3092 WcojKeyWidth::FourByte => {
3093 let l0 = self
3094 .provider
3095 .wcoj_layout_u32_recorded(&slot_inputs[0], launch_stream)?;
3096 let l1 = self
3097 .provider
3098 .wcoj_layout_u32_recorded(&slot_inputs[1], launch_stream)?;
3099 let l2 = self
3100 .provider
3101 .wcoj_layout_u32_recorded(&slot_inputs[2], launch_stream)?;
3102 let l3 = self
3103 .provider
3104 .wcoj_layout_u32_recorded(&slot_inputs[3], launch_stream)?;
3105 self.provider
3106 .wcoj_4cycle_u32_recorded(&l0, &l1, &l2, &l3, launch_stream)?
3107 }
3108 WcojKeyWidth::EightByte => {
3109 let l0 = self
3110 .provider
3111 .wcoj_layout_u64_recorded(&slot_inputs[0], launch_stream)?;
3112 let l1 = self
3113 .provider
3114 .wcoj_layout_u64_recorded(&slot_inputs[1], launch_stream)?;
3115 let l2 = self
3116 .provider
3117 .wcoj_layout_u64_recorded(&slot_inputs[2], launch_stream)?;
3118 let l3 = self
3119 .provider
3120 .wcoj_layout_u64_recorded(&slot_inputs[3], launch_stream)?;
3121 self.provider
3122 .wcoj_4cycle_u64_recorded(&l0, &l1, &l2, &l3, launch_stream)?
3123 }
3124 };
3125
3126 self.provider.wcoj_project_output_columns_recorded(
3127 &kernel_out,
3128 &perm,
3129 head_schema,
3130 launch_stream,
3131 )
3132 }
3133
3134 pub fn prepare_leader_inputs(
3162 &self,
3163 canonical: &[&CudaBuffer],
3164 var_order: &VariableOrder,
3165 launch_stream: StreamId,
3166 ) -> Result<Vec<CudaBuffer>> {
3167 let n = canonical.len();
3168 if !(n == 3 || n == 4) {
3169 return Err(xlog_core::XlogError::Kernel(format!(
3170 "prepare_leader_inputs: canonical inputs must be 3 (triangle) or 4 (4-cycle), got {n}"
3171 )));
3172 }
3173 let leader_idx = var_order.leader_idx as usize;
3174 if leader_idx >= n {
3175 return Err(xlog_core::XlogError::Kernel(format!(
3176 "prepare_leader_inputs: leader_idx {leader_idx} out of range for arity {n}"
3177 )));
3178 }
3179 if var_order.lookup_perms.len() != n - 1 {
3180 return Err(xlog_core::XlogError::Kernel(format!(
3181 "prepare_leader_inputs: lookup_perms.len() = {} must equal {} (arity - 1)",
3182 var_order.lookup_perms.len(),
3183 n - 1
3184 )));
3185 }
3186 for (slot, lp) in var_order.lookup_perms.iter().enumerate() {
3187 let input_idx = lp.input_idx as usize;
3188 if input_idx >= n {
3189 return Err(xlog_core::XlogError::Kernel(format!(
3190 "prepare_leader_inputs: lookup_perms[{slot}].input_idx {input_idx} out of range for arity {n}"
3191 )));
3192 }
3193 }
3194 if n == 4 {
3196 for lp in &var_order.lookup_perms {
3197 if lp.swap_cols {
3198 return Err(xlog_core::XlogError::Kernel(
3199 "prepare_leader_inputs: 4-cycle does not support col-swaps".to_string(),
3200 ));
3201 }
3202 }
3203 }
3204
3205 let mut slots: Vec<CudaBuffer> = Vec::with_capacity(n);
3222 slots.push(self.clone_buffer_via_swap(canonical[leader_idx], launch_stream)?);
3224 for lp in &var_order.lookup_perms {
3225 let src = canonical[lp.input_idx as usize];
3226 let buf = if lp.swap_cols {
3227 self.provider
3228 .wcoj_project_2col_swap_recorded(src, launch_stream)?
3229 } else {
3230 self.clone_buffer_via_swap(src, launch_stream)?
3231 };
3232 slots.push(buf);
3233 }
3234 Ok(slots)
3235 }
3236
3237 fn clone_buffer_via_swap(
3243 &self,
3244 src: &CudaBuffer,
3245 launch_stream: StreamId,
3246 ) -> Result<CudaBuffer> {
3247 let once = self
3248 .provider
3249 .wcoj_project_2col_swap_recorded(src, launch_stream)?;
3250 self.provider
3251 .wcoj_project_2col_swap_recorded(&once, launch_stream)
3252 }
3253
3254 pub fn wcoj_dispatch_stream_or_init(&self) -> Option<StreamId> {
3271 if let Some(s) = self.wcoj_dispatch_stream.get() {
3272 return Some(*s);
3273 }
3274 let runtime = self.provider.memory().runtime()?;
3275 let stream = runtime.stream_pool().acquire().ok()?;
3276 let _ = self.wcoj_dispatch_stream.set(stream);
3277 self.wcoj_dispatch_stream.get().copied()
3278 }
3279}
3280
3281impl Executor {
3294 pub fn wcoj_clique5_dispatch_count(&self) -> u64 {
3299 self.wcoj_clique5_dispatch_count
3300 }
3301
3302 pub fn wcoj_clique6_dispatch_count(&self) -> u64 {
3306 self.wcoj_clique6_dispatch_count
3307 }
3308
3309 pub fn wcoj_clique7_dispatch_count(&self) -> u64 {
3313 self.wcoj_clique7_dispatch_count
3314 }
3315
3316 pub fn wcoj_clique8_dispatch_count(&self) -> u64 {
3320 self.wcoj_clique8_dispatch_count
3321 }
3322
3323 pub fn kclique_histogram_refresh_count(&self) -> u64 {
3326 self.kclique_histogram_refresh_count
3327 }
3328
3329 pub fn kclique_histogram_refresh_nanos(&self) -> u128 {
3332 self.kclique_histogram_refresh_nanos
3333 }
3334
3335 pub(super) fn try_dispatch_wcoj_clique5(
3338 &mut self,
3339 rule: &CompiledRule,
3340 ) -> Result<Option<CudaBuffer>> {
3341 self.try_dispatch_wcoj_clique5_on_body(&rule.body)
3342 }
3343
3344 pub(super) fn try_dispatch_wcoj_clique6(
3346 &mut self,
3347 rule: &CompiledRule,
3348 ) -> Result<Option<CudaBuffer>> {
3349 self.try_dispatch_wcoj_clique6_on_body(&rule.body)
3350 }
3351
3352 pub(super) fn try_dispatch_wcoj_clique7(
3354 &mut self,
3355 rule: &CompiledRule,
3356 ) -> Result<Option<CudaBuffer>> {
3357 self.try_dispatch_wcoj_clique7_on_body(&rule.body)
3358 }
3359
3360 pub(super) fn try_dispatch_wcoj_clique8(
3362 &mut self,
3363 rule: &CompiledRule,
3364 ) -> Result<Option<CudaBuffer>> {
3365 self.try_dispatch_wcoj_clique8_on_body(&rule.body)
3366 }
3367
3368 pub(super) fn try_dispatch_wcoj_clique5_on_body(
3370 &mut self,
3371 body: &RirNode,
3372 ) -> Result<Option<CudaBuffer>> {
3373 self.try_dispatch_wcoj_clique_k_on_body(body, 5)
3374 }
3375
3376 pub(super) fn try_dispatch_wcoj_clique6_on_body(
3378 &mut self,
3379 body: &RirNode,
3380 ) -> Result<Option<CudaBuffer>> {
3381 self.try_dispatch_wcoj_clique_k_on_body(body, 6)
3382 }
3383
3384 pub(super) fn try_dispatch_wcoj_clique7_on_body(
3386 &mut self,
3387 body: &RirNode,
3388 ) -> Result<Option<CudaBuffer>> {
3389 self.try_dispatch_wcoj_clique_k_on_body(body, 7)
3390 }
3391
3392 pub(super) fn try_dispatch_wcoj_clique8_on_body(
3394 &mut self,
3395 body: &RirNode,
3396 ) -> Result<Option<CudaBuffer>> {
3397 self.try_dispatch_wcoj_clique_k_on_body(body, 8)
3398 }
3399
3400 fn try_dispatch_wcoj_clique_k_on_body(
3404 &mut self,
3405 body: &RirNode,
3406 k: usize,
3407 ) -> Result<Option<CudaBuffer>> {
3408 let expected_edges = k * (k - 1) / 2;
3409 let RirNode::MultiWayJoin {
3411 inputs,
3412 plan,
3413 var_order,
3414 ..
3415 } = body
3416 else {
3417 return Ok(None);
3418 };
3419 if matches!(plan, Some(MultiwayPlan::PlannedHashRoute { .. })) {
3420 return Ok(None);
3421 }
3422 if inputs.len() != expected_edges {
3423 return Ok(None);
3424 }
3425 let kclique = match var_order.as_ref().and_then(|order| order.kclique.as_ref()) {
3426 Some(plan) if usize::from(plan.k) == k => plan,
3427 _ => return Ok(None),
3428 };
3429 let mut rel_ids: Vec<RelId> = Vec::with_capacity(expected_edges);
3431 for input in inputs {
3432 let RirNode::Scan { rel } = input else {
3433 return Ok(None);
3434 };
3435 rel_ids.push(*rel);
3436 }
3437 let mut raw_bufs: Vec<&CudaBuffer> = Vec::with_capacity(expected_edges);
3439 for rid in &rel_ids {
3440 let name = match self.rel_names.get(rid) {
3441 Some(n) => n.clone(),
3442 None => return Ok(None),
3443 };
3444 match self.store.get(&name) {
3445 Some(b) => raw_bufs.push(b),
3446 None => return Ok(None),
3447 }
3448 }
3449 let launch_stream = match self.wcoj_dispatch_stream_or_init() {
3451 Some(s) => s,
3452 None => return Ok(None),
3453 };
3454 let first_ty = match raw_bufs[0].schema.column_type(0) {
3458 Some(t) => t,
3459 None => return Ok(None),
3460 };
3461 let is_u64 = matches!(first_ty, xlog_core::ScalarType::U64);
3462 let is_4byte = matches!(
3463 first_ty,
3464 xlog_core::ScalarType::U32 | xlog_core::ScalarType::Symbol
3465 );
3466 if !is_u64 && !is_4byte {
3467 return Ok(None);
3468 }
3469 let Some(plan_params) = kclique_dispatch_params(kclique, k) else {
3470 return Ok(None);
3471 };
3472 let head_schema = match build_kclique_head_schema(&raw_bufs, k) {
3473 Some(schema) => schema,
3474 None => return Ok(None),
3475 };
3476 let output_perm = match kclique_output_perm(kclique, k) {
3477 Some(perm) => perm,
3478 None => return Ok(None),
3479 };
3480 let laid_out = match self.orient_and_layout_kclique_edges(
3486 &raw_bufs,
3487 &plan_params,
3488 is_u64,
3489 launch_stream,
3490 ) {
3491 Ok(bufs) => bufs,
3492 Err(err) => {
3493 return wcoj_decline_on_error(
3494 &mut self.wcoj_error_decline_count,
3495 "k-clique-layout",
3496 err,
3497 )
3498 }
3499 };
3500 let edge_refs: Vec<&CudaBuffer> = laid_out.iter().collect();
3503 let result = match (k, is_u64) {
3505 (5, false) => {
3506 let arr: &[&CudaBuffer; 10] = match edge_refs.as_slice().try_into() {
3507 Ok(a) => a,
3508 Err(_) => return Ok(None),
3509 };
3510 self.provider.wcoj_clique5_u32_recorded_planned(
3511 arr,
3512 plan_params.leader_edge_idx,
3513 &plan_params.edge_order,
3514 &plan_params.iteration_order,
3515 launch_stream,
3516 )
3517 }
3518 (5, true) => {
3519 let arr: &[&CudaBuffer; 10] = match edge_refs.as_slice().try_into() {
3520 Ok(a) => a,
3521 Err(_) => return Ok(None),
3522 };
3523 self.provider.wcoj_clique5_u64_recorded_planned(
3524 arr,
3525 plan_params.leader_edge_idx,
3526 &plan_params.edge_order,
3527 &plan_params.iteration_order,
3528 launch_stream,
3529 )
3530 }
3531 (6, false) => {
3532 let arr: &[&CudaBuffer; 15] = match edge_refs.as_slice().try_into() {
3533 Ok(a) => a,
3534 Err(_) => return Ok(None),
3535 };
3536 self.provider.wcoj_clique6_u32_recorded_planned(
3537 arr,
3538 plan_params.leader_edge_idx,
3539 &plan_params.edge_order,
3540 &plan_params.iteration_order,
3541 launch_stream,
3542 )
3543 }
3544 (6, true) => {
3545 let arr: &[&CudaBuffer; 15] = match edge_refs.as_slice().try_into() {
3546 Ok(a) => a,
3547 Err(_) => return Ok(None),
3548 };
3549 self.provider.wcoj_clique6_u64_recorded_planned(
3550 arr,
3551 plan_params.leader_edge_idx,
3552 &plan_params.edge_order,
3553 &plan_params.iteration_order,
3554 launch_stream,
3555 )
3556 }
3557 (7, false) => {
3558 let arr: &[&CudaBuffer; 21] = match edge_refs.as_slice().try_into() {
3559 Ok(a) => a,
3560 Err(_) => return Ok(None),
3561 };
3562 self.provider.wcoj_clique7_u32_recorded_planned(
3563 arr,
3564 plan_params.leader_edge_idx,
3565 &plan_params.edge_order,
3566 &plan_params.iteration_order,
3567 launch_stream,
3568 )
3569 }
3570 (7, true) => {
3571 let arr: &[&CudaBuffer; 21] = match edge_refs.as_slice().try_into() {
3572 Ok(a) => a,
3573 Err(_) => return Ok(None),
3574 };
3575 self.provider.wcoj_clique7_u64_recorded_planned(
3576 arr,
3577 plan_params.leader_edge_idx,
3578 &plan_params.edge_order,
3579 &plan_params.iteration_order,
3580 launch_stream,
3581 )
3582 }
3583 (8, false) => {
3584 let arr: &[&CudaBuffer; 28] = match edge_refs.as_slice().try_into() {
3585 Ok(a) => a,
3586 Err(_) => return Ok(None),
3587 };
3588 self.provider.wcoj_clique8_u32_recorded_planned(
3589 arr,
3590 plan_params.leader_edge_idx,
3591 &plan_params.edge_order,
3592 &plan_params.iteration_order,
3593 launch_stream,
3594 )
3595 }
3596 (8, true) => {
3597 let arr: &[&CudaBuffer; 28] = match edge_refs.as_slice().try_into() {
3598 Ok(a) => a,
3599 Err(_) => return Ok(None),
3600 };
3601 self.provider.wcoj_clique8_u64_recorded_planned(
3602 arr,
3603 plan_params.leader_edge_idx,
3604 &plan_params.edge_order,
3605 &plan_params.iteration_order,
3606 launch_stream,
3607 )
3608 }
3609 _ => return Ok(None),
3610 };
3611 match result {
3614 Ok(buf) => {
3615 let buf = if output_perm.iter().copied().eq(0..output_perm.len()) {
3616 buf
3617 } else {
3618 self.provider.wcoj_project_output_columns_recorded(
3619 &buf,
3620 &output_perm,
3621 head_schema,
3622 launch_stream,
3623 )?
3624 };
3625 match k {
3626 5 => self.wcoj_clique5_dispatch_count += 1,
3627 6 => self.wcoj_clique6_dispatch_count += 1,
3628 7 => self.wcoj_clique7_dispatch_count += 1,
3629 8 => self.wcoj_clique8_dispatch_count += 1,
3630 _ => {}
3631 }
3632 Ok(Some(buf))
3633 }
3634 Err(err) => wcoj_decline_on_error(&mut self.wcoj_error_decline_count, "k-clique", err),
3635 }
3636 }
3637
3638 fn orient_and_layout_kclique_edges(
3647 &self,
3648 raw_bufs: &[&CudaBuffer],
3649 plan_params: &KCliqueDispatchParams,
3650 is_u64: bool,
3651 launch_stream: StreamId,
3652 ) -> Result<Vec<CudaBuffer>> {
3653 let mut laid_out: Vec<CudaBuffer> = Vec::with_capacity(plan_params.edge_permutation.len());
3654 for (slot, &input_idx) in plan_params.edge_permutation.iter().enumerate() {
3655 let src = raw_bufs[input_idx];
3656 let swapped = if plan_params.swap_slots.contains(&slot) {
3657 Some(
3658 self.provider
3659 .wcoj_project_2col_swap_recorded(src, launch_stream)?,
3660 )
3661 } else {
3662 None
3663 };
3664 let oriented = swapped.as_ref().unwrap_or(src);
3665 let res = if plan_params.required_sort_slots.contains(&slot) {
3666 if is_u64 {
3667 self.provider
3668 .wcoj_layout_sort_u64_recorded(oriented, launch_stream)
3669 } else {
3670 self.provider
3671 .wcoj_layout_sort_u32_recorded(oriented, launch_stream)
3672 }
3673 } else if is_u64 {
3674 self.provider
3675 .wcoj_layout_u64_recorded(oriented, launch_stream)
3676 } else {
3677 self.provider
3678 .wcoj_layout_u32_recorded(oriented, launch_stream)
3679 };
3680 laid_out.push(res?);
3681 }
3682 Ok(laid_out)
3683 }
3684
3685 fn try_dispatch_wcoj_groupby_root_count_clique(
3702 &mut self,
3703 multiway: &RirNode,
3704 group_cols: &[ProjectExpr],
3705 ) -> Result<Option<CudaBuffer>> {
3706 let RirNode::MultiWayJoin {
3707 inputs,
3708 plan,
3709 var_order,
3710 ..
3711 } = multiway
3712 else {
3713 return Ok(None);
3714 };
3715 if matches!(plan, Some(MultiwayPlan::PlannedHashRoute { .. })) {
3716 return Ok(None);
3717 }
3718 let kclique = match var_order.as_ref().and_then(|order| order.kclique.as_ref()) {
3719 Some(plan) => plan,
3720 None => return Ok(None),
3721 };
3722 let k = usize::from(kclique.k);
3723 if !matches!(k, 5 | 6) {
3724 return Ok(None);
3725 }
3726 let expected_edges = k * (k - 1) / 2;
3727 if inputs.len() != expected_edges {
3728 return Ok(None);
3729 }
3730 let Some(ProjectExpr::Column(root_var)) = group_cols.first() else {
3732 return Ok(None);
3733 };
3734 let Some(positions) = live_kclique_variable_positions(kclique, k) else {
3735 return Ok(None);
3736 };
3737 if *root_var >= k || positions[*root_var] != 0 {
3738 return Ok(None);
3739 }
3740 let mut rel_ids: Vec<RelId> = Vec::with_capacity(expected_edges);
3742 for input in inputs {
3743 let RirNode::Scan { rel } = input else {
3744 return Ok(None);
3745 };
3746 rel_ids.push(*rel);
3747 }
3748 let mut raw_bufs: Vec<&CudaBuffer> = Vec::with_capacity(expected_edges);
3749 for rid in &rel_ids {
3750 let name = match self.rel_names.get(rid) {
3751 Some(n) => n.clone(),
3752 None => return Ok(None),
3753 };
3754 match self.store.get(&name) {
3755 Some(b) => raw_bufs.push(b),
3756 None => return Ok(None),
3757 }
3758 }
3759 for buf in &raw_bufs {
3760 if classify_two_col_wcoj_width(buf) != Some(WcojKeyWidth::FourByte) {
3761 return Ok(None);
3762 }
3763 }
3764 if self.provider.memory().runtime().is_none() {
3765 return Ok(None);
3766 }
3767 let Some(launch_stream) = self.wcoj_dispatch_stream_or_init() else {
3768 return Ok(None);
3769 };
3770 let Some(plan_params) = kclique_dispatch_params(kclique, k) else {
3771 return Ok(None);
3772 };
3773 let laid_out = match self.orient_and_layout_kclique_edges(
3774 &raw_bufs,
3775 &plan_params,
3776 false,
3777 launch_stream,
3778 ) {
3779 Ok(bufs) => bufs,
3780 Err(err) => {
3781 return wcoj_decline_on_error(
3782 &mut self.wcoj_error_decline_count,
3783 "groupby-fusion-clique-layout",
3784 err,
3785 )
3786 }
3787 };
3788 let edge_refs: Vec<&CudaBuffer> = laid_out.iter().collect();
3789 let result = match k {
3790 5 => {
3791 let arr: &[&CudaBuffer; 10] = match edge_refs.as_slice().try_into() {
3792 Ok(a) => a,
3793 Err(_) => return Ok(None),
3794 };
3795 self.provider
3796 .wcoj_clique5_groupby_root_count_u32_recorded_planned(
3797 arr,
3798 plan_params.leader_edge_idx,
3799 &plan_params.edge_order,
3800 &plan_params.iteration_order,
3801 launch_stream,
3802 )
3803 }
3804 _ => {
3805 let arr: &[&CudaBuffer; 15] = match edge_refs.as_slice().try_into() {
3806 Ok(a) => a,
3807 Err(_) => return Ok(None),
3808 };
3809 self.provider
3810 .wcoj_clique6_groupby_root_count_u32_recorded_planned(
3811 arr,
3812 plan_params.leader_edge_idx,
3813 &plan_params.edge_order,
3814 &plan_params.iteration_order,
3815 launch_stream,
3816 )
3817 }
3818 };
3819 match result {
3820 Ok(buf) => {
3821 self.wcoj_groupby_fusion_dispatch_count += 1;
3822 Ok(Some(buf))
3823 }
3824 Err(err) => wcoj_decline_on_error(
3825 &mut self.wcoj_error_decline_count,
3826 "groupby-fusion-clique",
3827 err,
3828 ),
3829 }
3830 }
3831}
3832
/// Kernel launch parameters derived from a planner-provided `KCliqueVariableOrder`
/// by `kclique_dispatch_params`.
#[derive(Debug)]
struct KCliqueDispatchParams {
    /// `edge_permutation[slot]` is the index into the raw edge inputs whose buffer is
    /// oriented and laid out into kernel slot `slot`.
    edge_permutation: Vec<usize>,
    /// Indexed by logical edge: a pair of iteration positions, numbered by
    /// `clique_edge_idx_runtime`. Each entry holds the kernel slot carrying that edge.
    edge_order: Vec<u8>,
    /// Variable iteration order passed to the kernel. `kclique_dispatch_params` always
    /// builds it as `0..k`.
    iteration_order: Vec<u8>,
    /// Kernel slot carrying the logical edge between iteration positions 0 and 1.
    leader_edge_idx: u32,
    /// Kernel slots whose two columns are swapped before layout.
    swap_slots: HashSet<usize>,
    /// Kernel slots laid out with the sorting layout (`wcoj_layout_sort_*`) instead of
    /// the plain layout.
    required_sort_slots: HashSet<usize>,
}
3842
3843fn kclique_dispatch_params(plan: &KCliqueVariableOrder, k: usize) -> Option<KCliqueDispatchParams> {
3844 let expected_edges = k * (k - 1) / 2;
3845 let edge_permutation = live_kclique_edge_permutation(plan, expected_edges)?;
3846 let positions = live_kclique_variable_positions(plan, k)?;
3847 let mut edge_order = vec![u8::MAX; expected_edges];
3848
3849 for (slot, &edge_idx) in edge_permutation.iter().enumerate() {
3850 let (left, right) = clique_edge_pair(edge_idx, k)?;
3851 let left_pos = positions[left];
3852 let right_pos = positions[right];
3853 let logical_edge =
3854 clique_edge_idx_runtime(left_pos.min(right_pos), left_pos.max(right_pos), k)?;
3855 edge_order[logical_edge] = u8::try_from(slot).ok()?;
3856 }
3857 if edge_order.contains(&u8::MAX) {
3858 return None;
3859 }
3860 let leader_edge_idx = u32::from(edge_order[clique_edge_idx_runtime(0, 1, k)?]);
3861 let iteration_order: Vec<u8> = (0..k)
3862 .map(|idx| u8::try_from(idx).ok())
3863 .collect::<Option<_>>()?;
3864
3865 let swap_slots: HashSet<usize> = plan
3866 .column_swaps
3867 .iter()
3868 .filter(|swap| swap.swap_cols)
3869 .map(|swap| usize::from(swap.edge_slot))
3870 .collect();
3871 if swap_slots.iter().any(|slot| *slot >= expected_edges) {
3872 return None;
3873 }
3874 let required_sort_slots: HashSet<usize> = plan
3875 .sorted_layout_requirements
3876 .edge_slots
3877 .iter()
3878 .copied()
3879 .map(usize::from)
3880 .collect();
3881 if required_sort_slots
3882 .iter()
3883 .any(|slot| *slot >= expected_edges)
3884 {
3885 return None;
3886 }
3887
3888 Some(KCliqueDispatchParams {
3889 edge_permutation,
3890 edge_order,
3891 iteration_order,
3892 leader_edge_idx,
3893 swap_slots,
3894 required_sort_slots,
3895 })
3896}
3897
3898fn live_kclique_edge_permutation(
3899 plan: &KCliqueVariableOrder,
3900 expected_edges: usize,
3901) -> Option<Vec<usize>> {
3902 let values: Vec<usize> = plan
3903 .edge_permutation
3904 .iter()
3905 .copied()
3906 .take_while(|value| *value != u8::MAX)
3907 .map(usize::from)
3908 .collect();
3909 if values.len() != expected_edges {
3910 return None;
3911 }
3912 let mut seen = vec![false; expected_edges];
3913 for &value in &values {
3914 if value >= expected_edges || seen[value] {
3915 return None;
3916 }
3917 seen[value] = true;
3918 }
3919 Some(values)
3920}
3921
3922fn live_kclique_variable_positions(plan: &KCliqueVariableOrder, k: usize) -> Option<Vec<usize>> {
3923 let mut positions = Vec::with_capacity(k);
3924 let mut seen = vec![false; k];
3925 for original_var in 0..k {
3926 let pos = usize::from(*plan.variable_positions.get(original_var)?);
3927 if pos >= k || seen[pos] {
3928 return None;
3929 }
3930 seen[pos] = true;
3931 positions.push(pos);
3932 }
3933 Some(positions)
3934}
3935
/// Index of the edge `(i, j)` in the row-major enumeration of the upper triangle of a
/// `k`-clique: `(0,1), (0,2), …, (0,k-1), (1,2), …, (k-2,k-1)`.
///
/// Returns `None` unless `i < j < k`.
fn clique_edge_idx_runtime(i: usize, j: usize, k: usize) -> Option<usize> {
    (i < j && j < k).then(|| {
        // Rows 0..i contribute (k-1) + (k-2) + … + (k-i) edges.
        let row_start = i * (k - 1) - i.saturating_sub(1) * i / 2;
        row_start + (j - i - 1)
    })
}
3942
/// Inverse of the row-major upper-triangle edge numbering. Returns the variable pair
/// `(i, j)` with `i < j < k` whose index is `edge_idx`, or `None` when
/// `edge_idx >= k * (k - 1) / 2`.
fn clique_edge_pair(edge_idx: usize, k: usize) -> Option<(usize, usize)> {
    (0..k)
        .flat_map(|i| ((i + 1)..k).map(move |j| (i, j)))
        .nth(edge_idx)
}
3955
3956fn build_kclique_head_schema(raw_bufs: &[&CudaBuffer], k: usize) -> Option<Schema> {
3957 let mut columns = Vec::with_capacity(k);
3958 for variable in 0..k {
3959 let (edge_idx, col_idx) = if variable == 0 {
3960 (clique_edge_idx_runtime(0, 1, k)?, 0)
3961 } else {
3962 (clique_edge_idx_runtime(0, variable, k)?, 1)
3963 };
3964 let ty = raw_bufs.get(edge_idx)?.schema.column_type(col_idx)?;
3965 columns.push((format!("col{}", variable), ty));
3966 }
3967 Some(Schema::new(columns))
3968}
3969
3970fn kclique_output_perm(plan: &KCliqueVariableOrder, k: usize) -> Option<Vec<usize>> {
3971 let positions = live_kclique_variable_positions(plan, k)?;
3972 Some(positions)
3973}
3974
3975#[cfg(test)]
3976mod tests {
3977 use std::sync::{Mutex, OnceLock};
3978
3979 use super::{
3980 chain_dispatch_enabled, match_chain_join, match_multiway_triangle, wcoj_adaptive_enabled,
3981 wcoj_gate_enabled, ENV_USE_WCOJ_TRIANGLE_U32, ENV_WCOJ_CHAIN_ENABLE,
3982 };
3983 use xlog_core::RelId;
3984 use xlog_ir::rir::ProjectExpr;
3985 use xlog_ir::RirNode;
3986
3987 fn canonical_multiway() -> RirNode {
3988 RirNode::MultiWayJoin {
3989 inputs: vec![
3990 RirNode::Scan { rel: RelId(1) },
3991 RirNode::Scan { rel: RelId(2) },
3992 RirNode::Scan { rel: RelId(3) },
3993 ],
3994 slot_vars: vec![
3995 vec![Some(0u32), Some(1)],
3996 vec![Some(1u32), Some(2)],
3997 vec![Some(0u32), Some(2)],
3998 ],
3999 output_columns: vec![
4000 ProjectExpr::Column(0),
4001 ProjectExpr::Column(1),
4002 ProjectExpr::Column(3),
4003 ],
4004 fallback: Box::new(RirNode::Unit),
4005 plan: None,
4006 var_order: None,
4007 }
4008 }
4009
4010 fn canonical_chain_join() -> RirNode {
4011 RirNode::ChainJoin {
4012 left: Box::new(RirNode::Scan { rel: RelId(1) }),
4013 right: Box::new(RirNode::Scan { rel: RelId(2) }),
4014 left_key: 1,
4015 right_key: 0,
4016 output_columns: vec![ProjectExpr::Column(0), ProjectExpr::Column(3)],
4017 fallback: Box::new(RirNode::Unit),
4018 }
4019 }
4020
4021 #[test]
4022 fn match_chain_returns_two_rels_and_keys() {
4023 let node = canonical_chain_join();
4024 let m = match_chain_join(&node).expect("must match canonical chain");
4025 assert_eq!(m.rel_left, RelId(1));
4026 assert_eq!(m.rel_right, RelId(2));
4027 assert_eq!(m.left_key, 1);
4028 assert_eq!(m.right_key, 0);
4029 assert_eq!(
4030 m.output_columns,
4031 vec![ProjectExpr::Column(0), ProjectExpr::Column(3)]
4032 );
4033 }
4034
4035 #[test]
4036 fn match_chain_rejects_non_scan_inputs() {
4037 let mut node = canonical_chain_join();
4038 if let RirNode::ChainJoin { left, .. } = &mut node {
4039 **left = RirNode::Unit;
4040 }
4041 assert!(match_chain_join(&node).is_none());
4042 }
4043
4044 #[test]
4045 fn match_chain_rejects_multiway_triangle() {
4046 let node = canonical_multiway();
4047 assert!(match_chain_join(&node).is_none());
4048 }
4049
4050 #[test]
4051 fn chain_dispatch_env_defaults_on_and_can_disable() {
4052 static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
4053 let _guard = ENV_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
4054 let old = std::env::var(ENV_WCOJ_CHAIN_ENABLE).ok();
4055 unsafe {
4058 std::env::remove_var(ENV_WCOJ_CHAIN_ENABLE);
4059 }
4060 assert!(chain_dispatch_enabled());
4061 unsafe {
4062 std::env::set_var(ENV_WCOJ_CHAIN_ENABLE, "0");
4063 }
4064 assert!(!chain_dispatch_enabled());
4065 unsafe {
4066 std::env::set_var(ENV_WCOJ_CHAIN_ENABLE, "false");
4067 }
4068 assert!(!chain_dispatch_enabled());
4069 unsafe {
4070 std::env::set_var(ENV_WCOJ_CHAIN_ENABLE, "1");
4071 }
4072 assert!(chain_dispatch_enabled());
4073 unsafe {
4074 match old {
4075 Some(v) => std::env::set_var(ENV_WCOJ_CHAIN_ENABLE, v),
4076 None => std::env::remove_var(ENV_WCOJ_CHAIN_ENABLE),
4077 }
4078 }
4079 }
4080
4081 #[test]
4082 fn match_canonical_returns_three_rels() {
4083 let node = canonical_multiway();
4084 let m = match_multiway_triangle(&node).expect("must match canonical triangle");
4085 assert_eq!(m.rel_xy, RelId(1));
4086 assert_eq!(m.rel_yz, RelId(2));
4087 assert_eq!(m.rel_xz, RelId(3));
4088 }
4089
4090 #[test]
4091 fn match_rejects_non_multiway_body() {
4092 let node = RirNode::Scan { rel: RelId(1) };
4093 assert!(match_multiway_triangle(&node).is_none());
4094 }
4095
4096 #[test]
4097 fn match_rejects_rotated_output_columns() {
4098 let mut node = canonical_multiway();
4099 if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
4100 *output_columns = vec![
4101 ProjectExpr::Column(1),
4102 ProjectExpr::Column(0),
4103 ProjectExpr::Column(3),
4104 ];
4105 }
4106 assert!(match_multiway_triangle(&node).is_none());
4107 }
4108
4109 #[test]
4114 fn match_accepts_z_shared_triangle_output_columns() {
4115 let mut node = canonical_multiway();
4116 if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
4117 *output_columns = vec![
4118 ProjectExpr::Column(0),
4119 ProjectExpr::Column(2),
4120 ProjectExpr::Column(3),
4121 ];
4122 }
4123 let m = match_multiway_triangle(&node)
4124 .expect("matcher must accept the Z-shared output-column layout");
4125 assert_eq!(m.rel_xy, RelId(1));
4126 assert_eq!(m.rel_yz, RelId(2));
4127 assert_eq!(m.rel_xz, RelId(3));
4128 }
4129
4130 #[test]
4133 fn match_rejects_invalid_triangle_output_columns() {
4134 let mut node = canonical_multiway();
4135 if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
4136 *output_columns = vec![
4137 ProjectExpr::Column(0),
4138 ProjectExpr::Column(3),
4139 ProjectExpr::Column(3),
4140 ];
4141 }
4142 assert!(match_multiway_triangle(&node).is_none());
4143 }
4144
4145 #[test]
4146 fn match_rejects_arity_mismatched_output_columns() {
4147 let mut node = canonical_multiway();
4148 if let RirNode::MultiWayJoin { output_columns, .. } = &mut node {
4149 *output_columns = vec![ProjectExpr::Column(0), ProjectExpr::Column(1)];
4150 }
4151 assert!(match_multiway_triangle(&node).is_none());
4152 }
4153
4154 #[test]
4155 fn match_rejects_malformed_slot_vars() {
4156 let mut node = canonical_multiway();
4158 if let RirNode::MultiWayJoin { slot_vars, .. } = &mut node {
4159 *slot_vars = vec![
4160 vec![Some(0u32), Some(1)],
4161 vec![Some(1u32), Some(2)],
4162 vec![Some(0u32), Some(1)],
4163 ];
4164 }
4165 assert!(match_multiway_triangle(&node).is_none());
4166 }
4167
4168 #[test]
4169 fn match_rejects_repeated_var_in_slot() {
4170 let mut node = canonical_multiway();
4171 if let RirNode::MultiWayJoin { slot_vars, .. } = &mut node {
4172 *slot_vars = vec![
4174 vec![Some(0u32), Some(0)],
4175 vec![Some(1u32), Some(2)],
4176 vec![Some(0u32), Some(2)],
4177 ];
4178 }
4179 assert!(match_multiway_triangle(&node).is_none());
4180 }
4181
4182 #[test]
4183 fn match_rejects_non_scan_input() {
4184 let mut node = canonical_multiway();
4185 if let RirNode::MultiWayJoin { inputs, .. } = &mut node {
4186 inputs[0] = RirNode::Unit;
4187 }
4188 assert!(match_multiway_triangle(&node).is_none());
4189 }
4190
4191 #[test]
4192 fn match_rejects_input_arity_mismatch() {
4193 let mut node = canonical_multiway();
4194 if let RirNode::MultiWayJoin { inputs, .. } = &mut node {
4195 inputs.pop();
4196 }
4197 assert!(match_multiway_triangle(&node).is_none());
4198 }
4199
4200 fn env_lock() -> &'static Mutex<()> {
4201 static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
4202 LOCK.get_or_init(|| Mutex::new(()))
4203 }
4204
4205 struct EnvSnapshot {
4206 force: Option<String>,
4207 }
4208
4209 impl EnvSnapshot {
4210 fn capture_and_clear() -> Self {
4211 let snapshot = Self {
4212 force: std::env::var(ENV_USE_WCOJ_TRIANGLE_U32).ok(),
4213 };
4214
4215 unsafe {
4218 std::env::remove_var(ENV_USE_WCOJ_TRIANGLE_U32);
4219 }
4220
4221 snapshot
4222 }
4223 }
4224
4225 impl Drop for EnvSnapshot {
4226 fn drop(&mut self) {
4227 unsafe {
4230 match self.force.take() {
4231 Some(v) => std::env::set_var(ENV_USE_WCOJ_TRIANGLE_U32, v),
4232 None => std::env::remove_var(ENV_USE_WCOJ_TRIANGLE_U32),
4233 }
4234 }
4235 }
4236 }
4237
4238 fn with_wcoj_env<R>(f: impl FnOnce() -> R) -> R {
4239 let _guard = env_lock().lock().expect("WCOJ env lock poisoned");
4240 let _snapshot = EnvSnapshot::capture_and_clear();
4241 f()
4242 }
4243
4244 fn set_env(name: &str, value: &str) {
4245 unsafe {
4248 std::env::set_var(name, value);
4249 }
4250 }
4251
4252 #[test]
4253 fn stats_gate_defaults_on_when_env_unset() {
4254 with_wcoj_env(|| {
4255 assert!(wcoj_adaptive_enabled(None));
4256 assert!(wcoj_adaptive_enabled(Some(true)));
4257 assert!(!wcoj_adaptive_enabled(Some(false)));
4258 });
4259 }
4260
4261 #[test]
4262 fn config_controls_stats_gate() {
4263 with_wcoj_env(|| {
4264 assert!(wcoj_adaptive_enabled(Some(true)));
4265 assert!(!wcoj_adaptive_enabled(Some(false)));
4266 });
4267 }
4268
4269 #[test]
4270 fn force_resolver_config_still_overrides_env() {
4271 with_wcoj_env(|| {
4272 set_env(ENV_USE_WCOJ_TRIANGLE_U32, "1");
4273 assert!(wcoj_gate_enabled(None));
4274 assert!(!wcoj_gate_enabled(Some(false)));
4275
4276 set_env(ENV_USE_WCOJ_TRIANGLE_U32, "0");
4277 assert!(!wcoj_gate_enabled(None));
4278 assert!(wcoj_gate_enabled(Some(true)));
4279 });
4280 }
4281
4282 #[test]
4287 fn error_decline_counts_and_falls_back_by_default() {
4288 with_wcoj_env(|| {
4289 let mut counter = 0u64;
4290 let err = xlog_core::XlogError::Kernel("synthetic layout failure".to_string());
4291 let out = super::wcoj_decline_on_error(&mut counter, "triangle", err)
4292 .expect("default mode must decline to the binary-join fallback, not error");
4293 assert!(out.is_none(), "decline must hand control to the fallback");
4294 assert_eq!(counter, 1, "every error decline must be counted");
4295 });
4296 }
4297
4298 #[test]
4299 fn error_decline_propagates_under_strict_env() {
4300 with_wcoj_env(|| {
4301 set_env(super::ENV_WCOJ_STRICT, "1");
4302 let mut counter = 0u64;
4303 let err = xlog_core::XlogError::Kernel("synthetic layout failure".to_string());
4304 let out = super::wcoj_decline_on_error(&mut counter, "triangle", err);
4305 unsafe {
4307 std::env::remove_var(super::ENV_WCOJ_STRICT);
4308 }
4309 match out {
4310 Err(err) => assert!(
4311 err.to_string().contains("synthetic layout failure"),
4312 "strict mode must surface the original error: {err}"
4313 ),
4314 Ok(_) => panic!("XLOG_WCOJ_STRICT=1 must propagate the pipeline error"),
4315 }
4316 assert_eq!(counter, 1, "strict mode still counts the decline");
4317 });
4318 }
4319
4320 use super::{
4325 match_multiway_4cycle, wcoj_4cycle_adaptive_enabled, wcoj_4cycle_disabled,
4326 wcoj_4cycle_gate_enabled, ENV_DISABLE_WCOJ_4CYCLE, ENV_USE_WCOJ_4CYCLE,
4327 ENV_USE_WCOJ_4CYCLE_ADAPTIVE,
4328 };
4329
4330 struct EnvSnapshot4Cycle {
4331 force: Option<String>,
4332 adaptive: Option<String>,
4333 disable: Option<String>,
4334 }
4335
4336 impl EnvSnapshot4Cycle {
4337 fn capture_and_clear() -> Self {
4338 let snap = Self {
4339 force: std::env::var(ENV_USE_WCOJ_4CYCLE).ok(),
4340 adaptive: std::env::var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE).ok(),
4341 disable: std::env::var(ENV_DISABLE_WCOJ_4CYCLE).ok(),
4342 };
4343 unsafe {
4345 std::env::remove_var(ENV_USE_WCOJ_4CYCLE);
4346 std::env::remove_var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE);
4347 std::env::remove_var(ENV_DISABLE_WCOJ_4CYCLE);
4348 }
4349 snap
4350 }
4351 }
4352
4353 impl Drop for EnvSnapshot4Cycle {
4354 fn drop(&mut self) {
4355 unsafe {
4357 match self.force.take() {
4358 Some(v) => std::env::set_var(ENV_USE_WCOJ_4CYCLE, v),
4359 None => std::env::remove_var(ENV_USE_WCOJ_4CYCLE),
4360 }
4361 match self.adaptive.take() {
4362 Some(v) => std::env::set_var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE, v),
4363 None => std::env::remove_var(ENV_USE_WCOJ_4CYCLE_ADAPTIVE),
4364 }
4365 match self.disable.take() {
4366 Some(v) => std::env::set_var(ENV_DISABLE_WCOJ_4CYCLE, v),
4367 None => std::env::remove_var(ENV_DISABLE_WCOJ_4CYCLE),
4368 }
4369 }
4370 }
4371 }
4372
4373 fn with_4cycle_env<R>(f: impl FnOnce() -> R) -> R {
4374 let _guard = env_lock().lock().expect("4-cycle env lock poisoned");
4375 let _snap = EnvSnapshot4Cycle::capture_and_clear();
4376 f()
4377 }
4378
4379 #[test]
4380 fn force_4cycle_resolver_defaults_off_when_env_unset() {
4381 with_4cycle_env(|| {
4382 assert!(!wcoj_4cycle_gate_enabled(None));
4383 assert!(wcoj_4cycle_gate_enabled(Some(true)));
4384 assert!(!wcoj_4cycle_gate_enabled(Some(false)));
4385 });
4386 }
4387
4388 #[test]
4389 fn force_4cycle_resolver_env_can_enable() {
4390 with_4cycle_env(|| {
4391 set_env(ENV_USE_WCOJ_4CYCLE, "1");
4392 assert!(wcoj_4cycle_gate_enabled(None));
4393 set_env(ENV_USE_WCOJ_4CYCLE, "true");
4394 assert!(wcoj_4cycle_gate_enabled(None));
4395 set_env(ENV_USE_WCOJ_4CYCLE, "0");
4396 assert!(!wcoj_4cycle_gate_enabled(None));
4397 });
4398 }
4399
4400 #[test]
4405 fn adaptive_4cycle_resolver_defaults_off_when_env_unset() {
4406 with_4cycle_env(|| {
4407 assert!(
4408 !wcoj_4cycle_adaptive_enabled(None),
4409 "4-cycle adaptive must be OPT-IN by default (unlike triangle's default-on)"
4410 );
4411 assert!(wcoj_4cycle_adaptive_enabled(Some(true)));
4412 assert!(!wcoj_4cycle_adaptive_enabled(Some(false)));
4413 });
4414 }
4415
/// With no config override, `ENV_USE_WCOJ_4CYCLE_ADAPTIVE` alone toggles
/// adaptive mode. The values `"1"` and `"true"` enable it, and `"0"` disables it.
#[test]
fn adaptive_4cycle_resolver_env_can_enable() {
    with_4cycle_env(|| {
        let steps = [("1", true), ("0", false), ("true", true)];
        for (raw, expected) in steps {
            set_env(ENV_USE_WCOJ_4CYCLE_ADAPTIVE, raw);
            assert_eq!(
                wcoj_4cycle_adaptive_enabled(None),
                expected,
                "{ENV_USE_WCOJ_4CYCLE_ADAPTIVE}={raw}"
            );
        }
    });
}
4427
/// The 4-cycle kill switch behaves as follows:
/// - it is off when the env is cleared;
/// - it turns on via `ENV_DISABLE_WCOJ_4CYCLE=1`;
/// - an explicit config override takes precedence over the env value.
#[test]
fn kill_4cycle_resolver_honors_env_and_config() {
    with_4cycle_env(|| {
        let resolve = wcoj_4cycle_disabled;

        // Nothing set anywhere: the kill switch stays off.
        assert!(!resolve(None));

        // Env asks to disable; an explicit `Some(false)` config overrides it.
        set_env(ENV_DISABLE_WCOJ_4CYCLE, "1");
        assert!(resolve(None));
        assert!(!resolve(Some(false)));

        // Env set to "0"; an explicit `Some(true)` config still reports disabled.
        set_env(ENV_DISABLE_WCOJ_4CYCLE, "0");
        assert!(resolve(Some(true)));
    });
}
4439
4440 fn canonical_4cycle_multiway() -> RirNode {
4441 RirNode::MultiWayJoin {
4442 inputs: vec![
4443 RirNode::Scan { rel: RelId(1) },
4444 RirNode::Scan { rel: RelId(2) },
4445 RirNode::Scan { rel: RelId(3) },
4446 RirNode::Scan { rel: RelId(4) },
4447 ],
4448 slot_vars: vec![
4449 vec![Some(0u32), Some(1)],
4450 vec![Some(1u32), Some(2)],
4451 vec![Some(2u32), Some(3)],
4452 vec![Some(3u32), Some(0)],
4453 ],
4454 output_columns: vec![
4455 ProjectExpr::Column(0),
4456 ProjectExpr::Column(1),
4457 ProjectExpr::Column(3),
4458 ProjectExpr::Column(5),
4459 ],
4460 fallback: Box::new(RirNode::Unit),
4461 plan: None,
4462 var_order: None,
4463 }
4464 }
4465
/// The canonical fixture matches, and the matcher reports the four edge
/// relations in input order (`RelId(1)` through `RelId(4)`).
#[test]
fn match_4cycle_canonical_returns_four_rels() {
    let node = canonical_4cycle_multiway();
    let Some(m) = match_multiway_4cycle(&node) else {
        panic!("must match canonical 4-cycle");
    };
    assert_eq!(
        [m.rel_e1, m.rel_e2, m.rel_e3, m.rel_e4],
        [RelId(1), RelId(2), RelId(3), RelId(4)]
    );
}
4475
/// A lone `Scan` is not a multi-way join and must not match.
#[test]
fn match_4cycle_rejects_non_multiway() {
    let lone_scan = RirNode::Scan { rel: RelId(1) };
    let verdict = match_multiway_4cycle(&lone_scan);
    assert!(verdict.is_none());
}
4480
/// A three-input triangle join (slot vars `(0,1)`, `(1,2)`, `(0,2)`) is not a
/// 4-cycle and must be rejected.
#[test]
fn match_4cycle_rejects_triangle_shape() {
    let inputs: Vec<RirNode> = (1..=3).map(|rel| RirNode::Scan { rel: RelId(rel) }).collect();
    let triangle = RirNode::MultiWayJoin {
        inputs,
        slot_vars: vec![
            vec![Some(0u32), Some(1)],
            vec![Some(1u32), Some(2)],
            vec![Some(0u32), Some(2)],
        ],
        output_columns: [0, 1, 3].into_iter().map(ProjectExpr::Column).collect(),
        fallback: Box::new(RirNode::Unit),
        plan: None,
        var_order: None,
    };
    assert!(match_multiway_4cycle(&triangle).is_none());
}
4506
/// Swapping the first two projected columns (`[0, 1, 3, 5]` -> `[1, 0, 3, 5]`)
/// produces a layout the matcher must reject.
///
/// NOTE(review): despite the name, this is a swap, not a rotation. A true
/// rotation of the canonical layout (`[5, 0, 1, 3]`) is the Alt-grouping layout
/// that `match_4cycle_accepts_alt_grouping_output_columns` expects to be
/// accepted. The resulting layout is also the same one spelled out in
/// `match_4cycle_rejects_invalid_output_columns`.
#[test]
fn match_4cycle_rejects_rotated_output_columns() {
    let mut node = canonical_4cycle_multiway();
    if let RirNode::MultiWayJoin { ref mut output_columns, .. } = node {
        output_columns.swap(0, 1);
    }
    assert!(match_multiway_4cycle(&node).is_none());
}
4515
/// The Alt-grouping projection `[5, 0, 1, 3]` (the canonical `[0, 1, 3, 5]`
/// rotated right by one) must still match. The reported edge relations stay in
/// input order.
#[test]
fn match_4cycle_accepts_alt_grouping_output_columns() {
    let mut node = canonical_4cycle_multiway();
    if let RirNode::MultiWayJoin { ref mut output_columns, .. } = node {
        *output_columns = [5, 0, 1, 3].into_iter().map(ProjectExpr::Column).collect();
    }
    let m = match_multiway_4cycle(&node)
        .expect("matcher must accept the Alt-grouping output-column layout");
    assert_eq!(
        [m.rel_e1, m.rel_e2, m.rel_e3, m.rel_e4],
        [RelId(1), RelId(2), RelId(3), RelId(4)]
    );
}
4541
/// The output-column layout `[1, 0, 3, 5]` is neither the canonical
/// `[0, 1, 3, 5]` nor the Alt-grouping `[5, 0, 1, 3]` layout and must be rejected.
///
/// NOTE(review): this is the same layout that
/// `match_4cycle_rejects_rotated_output_columns` builds via `swap(0, 1)`.
/// Consider covering a different invalid layout here.
#[test]
fn match_4cycle_rejects_invalid_output_columns() {
    let mut node = canonical_4cycle_multiway();
    if let RirNode::MultiWayJoin { ref mut output_columns, .. } = node {
        *output_columns = [1, 0, 3, 5].into_iter().map(ProjectExpr::Column).collect();
    }
    assert!(match_multiway_4cycle(&node).is_none());
}
4558
/// Dropping the last projected column leaves three output columns for a
/// four-relation cycle; the matcher must reject it.
#[test]
fn match_4cycle_rejects_arity_mismatched_output_columns() {
    let mut node = canonical_4cycle_multiway();
    if let RirNode::MultiWayJoin { ref mut output_columns, .. } = node {
        let _dropped = output_columns.pop();
    }
    assert!(match_multiway_4cycle(&node).is_none());
}
4567
/// Re-pointing the fourth input's second slot variable at `99` (instead of back
/// to var `0`) leaves the cycle open, so the matcher must reject it.
#[test]
fn match_4cycle_rejects_unclosed_cycle() {
    let mut node = canonical_4cycle_multiway();
    if let RirNode::MultiWayJoin { ref mut slot_vars, .. } = node {
        // Canonical slot_vars[3] is [Some(3), Some(0)]; this makes it [Some(3), Some(99)].
        slot_vars[3][1] = Some(99);
    }
    assert!(match_multiway_4cycle(&node).is_none());
}
4578
/// Replacing the first input's `Scan` with `Unit` must make the matcher reject
/// the node.
#[test]
fn match_4cycle_rejects_non_scan_input() {
    let mut node = canonical_4cycle_multiway();
    if let RirNode::MultiWayJoin { ref mut inputs, .. } = node {
        inputs[0] = RirNode::Unit;
    }
    let verdict = match_multiway_4cycle(&node);
    assert!(verdict.is_none());
}
4587
/// Appending a fifth `Scan` input, while `slot_vars` still describes four
/// inputs, must make the matcher reject the node.
#[test]
fn match_4cycle_rejects_input_arity_mismatch() {
    let mut node = canonical_4cycle_multiway();
    if let RirNode::MultiWayJoin { ref mut inputs, .. } = node {
        inputs.push(RirNode::Scan { rel: RelId(5) });
    }
    let verdict = match_multiway_4cycle(&node);
    assert!(verdict.is_none());
}
4596}