// xlog_cuda/wcoj_metadata.rs

1use std::collections::BTreeMap;
2
3use cudarc::driver::DeviceRepr;
4
5use crate::memory::TrackedCudaSlice;
6
/// Default work-unit size for HG work planning (2^10 = 1024).
///
/// NOTE(review): presumably the default for the `block_work_unit` field of the
/// HG work-plan structs in this module — confirm at the call sites.
pub const WCOJ_HG_BLOCK_WORK_UNIT_DEFAULT: u32 = 1 << 10;
8
/// Identifier of a candidate root variable used when planning WCOJ metadata.
pub type VertexId = u8;
11
/// Per-root heat distribution in compact form, as consumed by the K-clique planner.
pub type HeatDist = Vec<f64>;
14
15/// Metadata cached for one candidate root variable.
16#[derive(Debug, Clone, PartialEq)]
17pub struct RootMetadata {
18    /// Column permutation needed to expose this root as the leading key.
19    pub column_permutation: Vec<u8>,
20    /// Signature of the sorted layout used by this candidate root.
21    pub sorted_layout_signature: LayoutSignature,
22    /// Heavy-key heat distribution for this root.
23    pub heat_distribution: HeatDist,
24}
25
/// Stable identity for a sorted relation layout.
///
/// Derives `Hash` alongside `Eq` so a signature can be used directly as a key
/// in hash-based caches (e.g. `HashMap<LayoutSignature, _>`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LayoutSignature {
    /// Runtime relation identifier.
    pub relation_id: u32,
    /// Columns used as the sorted key prefix, in key order.
    pub key_columns: Vec<usize>,
    /// Logical rows in the sorted layout.
    pub row_count: u32,
}
36
37pub struct WcojRelationMetadata<K: DeviceRepr> {
38    pub unique_keys: TrackedCudaSlice<K>,
39    pub fan_out: TrackedCudaSlice<u32>,
40    pub prefix_sum: TrackedCudaSlice<u32>,
41    /// Per-candidate-root metadata cached for planner reuse.
42    pub per_candidate_root: BTreeMap<VertexId, RootMetadata>,
43    pub total: u64,
44    pub key_count: u32,
45    pub row_count: u32,
46}
47
/// Aggregate-fused triangle group-by-root sum/min/max selector.
///
/// Chooses which triangle output variable supplies the aggregate value for the
/// fused group-by-root sum/min/max kernels. The group key is always the
/// variable-order root `X`. The value must itself be a triangle output
/// variable (`Y` or `Z`) so the kernel can read it during traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WcojRootAggValue {
    /// Aggregate over `Y` (`e_xy.col1` of the root row).
    Y,
    /// Aggregate over `Z` (the matched intersection value).
    Z,
}
60
/// Aggregate-fused 4-cycle group-by-root sum/min/max selector.
///
/// Chooses which 4-cycle output variable supplies the aggregate value for the
/// fused group-by-root sum/min/max kernels. The group key is always the
/// variable-order root `W`. The value must itself be a 4-cycle output
/// variable (`X`, `Y` or `Z`) so the kernel can read it during traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Wcoj4CycleRootAggValue {
    /// Aggregate over `X` (`e1.col1` of the root row).
    X,
    /// Aggregate over `Y` (`e2.col1` of the resolved work item).
    Y,
    /// Aggregate over `Z` (`e3.col1` of the resolved work item).
    Z,
}
75
76pub struct WcojTriangleHgWorkPlanU32 {
77    pub xy_work_prefix: TrackedCudaSlice<u32>,
78    pub xy_yz_start: TrackedCudaSlice<u32>,
79    pub xy_yz_end: TrackedCudaSlice<u32>,
80    pub xy_xz_start: TrackedCudaSlice<u32>,
81    pub xy_xz_end: TrackedCudaSlice<u32>,
82    pub block_counts: TrackedCudaSlice<u32>,
83    pub block_offsets: TrackedCudaSlice<u32>,
84    pub scratch_x: TrackedCudaSlice<u32>,
85    pub scratch_y: TrackedCudaSlice<u32>,
86    pub scratch_z: TrackedCudaSlice<u32>,
87    pub total_work: u32,
88    pub block_work_unit: u32,
89    pub row_count: u32,
90}
91
92pub struct WcojTriangleHgCountPhaseU32 {
93    pub total_rows_device: TrackedCudaSlice<u32>,
94    pub total_rows: u32,
95}
96
97pub struct WcojTriangleHgWorkPlanU64 {
98    pub xy_work_prefix: TrackedCudaSlice<u32>,
99    pub xy_yz_start: TrackedCudaSlice<u32>,
100    pub xy_yz_end: TrackedCudaSlice<u32>,
101    pub xy_xz_start: TrackedCudaSlice<u32>,
102    pub xy_xz_end: TrackedCudaSlice<u32>,
103    pub block_counts: TrackedCudaSlice<u32>,
104    pub block_offsets: TrackedCudaSlice<u32>,
105    pub total_work: u32,
106    pub block_work_unit: u32,
107    pub row_count: u32,
108}
109
110pub struct WcojCycle4HgWorkPlanU32 {
111    pub e1_work_prefix: TrackedCudaSlice<u32>,
112    pub e2_work_prefix: TrackedCudaSlice<u32>,
113    pub e1_e2_start: TrackedCudaSlice<u32>,
114    pub e1_e2_end: TrackedCudaSlice<u32>,
115    pub block_counts: TrackedCudaSlice<u32>,
116    pub block_offsets: TrackedCudaSlice<u32>,
117    pub total_work: u32,
118    pub block_work_unit: u32,
119    pub row_count: u32,
120}
121
122pub struct WcojCycle4HgWorkPlanU64 {
123    pub e1_work_prefix: TrackedCudaSlice<u32>,
124    pub e2_work_prefix: TrackedCudaSlice<u32>,
125    pub e1_e2_start: TrackedCudaSlice<u32>,
126    pub e1_e2_end: TrackedCudaSlice<u32>,
127    pub block_counts: TrackedCudaSlice<u32>,
128    pub block_offsets: TrackedCudaSlice<u32>,
129    pub total_work: u32,
130    pub block_work_unit: u32,
131    pub row_count: u32,
132}
133
134impl<K: DeviceRepr> WcojRelationMetadata<K> {
135    pub fn metadata_bytes(&self) -> u64 {
136        let key_bytes = self.unique_keys.len() as u64 * std::mem::size_of::<K>() as u64;
137        let fan_out_bytes = self.fan_out.len() as u64 * std::mem::size_of::<u32>() as u64;
138        let prefix_bytes = self.prefix_sum.len() as u64 * std::mem::size_of::<u32>() as u64;
139        key_bytes
140            .saturating_add(fan_out_bytes)
141            .saturating_add(prefix_bytes)
142    }
143}