Skip to main content

xlog_cuda/provider/
wcoj_project.rs

1//! Owned, recorded helpers for WCOJ variable-ordering column
2//! projection on `CudaBuffer`s.
3//!
4//! Two helpers, both with the same shape:
5//!
6//! * [`CudaKernelProvider::wcoj_project_2col_swap_recorded`] —
7//!   2-col-only swap helper used by the variable-ordering dispatcher when
8//!   triangle non-default leaders need a `(col1, col0)` rotated
9//!   lookup atom before [`CudaKernelProvider::wcoj_layout_u32_recorded`]
10//!   can sort it.
11//!
12//! * [`CudaKernelProvider::wcoj_project_output_columns_recorded`] —
13//!   N-col arbitrary-permutation helper that the variable-ordering dispatcher
14//!   applies to the kernel-direct output buffer to remap kernel
15//!   columns into the rule's head order.
16//!
17//! Both helpers:
18//!
19//! 1. Build owned `TrackedCudaSlice<u8>` columns sized to the
20//!    source's `row_cap × dtype.size_bytes()`.
21//! 2. Allocate a fresh `num_rows_device` device scalar.
//! 3. Use `LaunchRecorder::new_strict` + `preflight` to declare
//!    reads (source columns + source `num_rows_device`) and
//!    writes (new columns + new `num_rows_device`); the record is
//!    `commit`ted only after every copy in step 4 has been queued.
25//! 4. Issue per-column DtoD-async copies on the launch stream,
26//!    plus one DtoD-async copy for `num_rows_device`.
27//! 5. **Failure-drain**: any error after the first queued copy
28//!    must `cu_stream.synchronize()` before returning, because
29//!    partially-allocated owned buffers are about to drop and
30//!    an in-flight DtoD copy would race the runtime dealloc.
31//!    Mirrors `wcoj.rs ≈ 2140` (skew-score histogram
32//!    queued-result discipline).
33//! 6. Carry `cached_row_count` from `src` unchanged — logical
34//!    row count is invariant under column permutation.
35
36use cudarc::driver::sys;
37use xlog_core::{Result, Schema, XlogError};
38
39use crate::cuda_compat::DeviceSlice;
40use crate::device_runtime::StreamId;
41use crate::launch::LaunchRecorder;
42use crate::memory::{CudaColumn, TrackedCudaSlice};
43use crate::CudaBuffer;
44
45use super::CudaKernelProvider;
46
47impl CudaKernelProvider {
48    /// Produce an owned 2-col `CudaBuffer` whose columns are
49    /// `[src.col(1), src.col(0)]`. See module docs for the full
50    /// recorded / failure-drain contract.
51    ///
52    /// Used by the variable-ordering dispatcher when a triangle non-default
53    /// leader requires a col-swap before
54    /// [`Self::wcoj_layout_u32_recorded`] sorts the result. The
55    /// 4-cycle path is rotation-only and never invokes this
56    /// helper.
57    ///
58    /// # Errors
59    /// * `XlogError::Kernel` if the manager has no runtime, the
60    ///   stream doesn't resolve, the input isn't 2-col, or any
61    ///   queued DtoD copy fails. On any failure after the first
62    ///   queued copy, the launch stream is synchronized before
63    ///   the function returns.
64    pub fn wcoj_project_2col_swap_recorded(
65        &self,
66        src: &CudaBuffer,
67        launch_stream: StreamId,
68    ) -> Result<CudaBuffer> {
69        let runtime = self.memory().runtime().ok_or_else(|| {
70            XlogError::Kernel(
71                "wcoj_project_2col_swap_recorded requires a runtime-backed \
72                 GpuMemoryManager (constructed via with_runtime)"
73                    .to_string(),
74            )
75        })?;
76        let cu_stream = runtime
77            .stream_pool()
78            .resolve(launch_stream)
79            .ok_or_else(|| {
80                XlogError::Kernel(format!(
81                    "wcoj_project_2col_swap_recorded: launch_stream StreamId({}) does not resolve",
82                    launch_stream.0
83                ))
84            })?;
85
86        if src.arity() != 2 {
87            return Err(XlogError::Kernel(format!(
88                "wcoj_project_2col_swap_recorded: src must be 2-column, got arity {}",
89                src.arity()
90            )));
91        }
92
93        // Build the swapped schema. The (name, type) pairs are
94        // re-ordered; key_columns recomputes via Schema::new.
95        let swapped_schema = Schema::new(vec![
96            src.schema.columns[1].clone(),
97            src.schema.columns[0].clone(),
98        ])
99        .with_sort_labels(vec![
100            src.schema
101                .column_sort_label(1)
102                .unwrap_or("col1")
103                .to_string(),
104            src.schema
105                .column_sort_label(0)
106                .unwrap_or("col0")
107                .to_string(),
108        ])
109        .expect("swapped sort labels match schema arity");
110
111        // Empty buffer fast-path: arity=2 with row_cap=0 produces
112        // an empty owned buffer with the swapped schema. No DtoD
113        // copy is queued; failure-drain is moot.
114        if src.row_cap == 0 {
115            return self.create_empty_buffer(swapped_schema);
116        }
117
118        // Allocate the two output column buffers (sized in bytes
119        // from the source) plus a fresh num_rows_device scalar.
120        // SAFETY: src.column(0/1) are guaranteed Some (arity == 2);
121        // their byte length matches row_cap × dtype.size_bytes()
122        // by CudaBuffer construction invariants.
123        let bytes_col0 = src.column(0).expect("src.col0").len();
124        let bytes_col1 = src.column(1).expect("src.col1").len();
125        let new_col0: TrackedCudaSlice<u8> = self.memory.alloc::<u8>(bytes_col1)?;
126        let new_col1: TrackedCudaSlice<u8> = self.memory.alloc::<u8>(bytes_col0)?;
127        let new_num_rows: TrackedCudaSlice<u32> = self.memory.alloc::<u32>(1)?;
128
129        let mut rec = LaunchRecorder::new_strict(launch_stream);
130        rec.read_column(src.column(0).expect("src.col0"));
131        rec.read_column(src.column(1).expect("src.col1"));
132        rec.read(src.num_rows_device());
133        rec.write(&new_col0);
134        rec.write(&new_col1);
135        rec.write(&new_num_rows);
136        rec.preflight(runtime).map_err(|e| {
137            XlogError::Kernel(format!(
138                "wcoj_project_2col_swap_recorded: launch recorder preflight failed: {}",
139                e
140            ))
141        })?;
142
143        // Failure-drain: from here on, any error must synchronize
144        // the stream before returning so partially-issued copies
145        // drain before the runtime deallocs `new_col0` /
146        // `new_col1` / `new_num_rows` on Err drop.
147        let queued_result: Result<()> = (|| {
148            // SAFETY: all DtoD copies are between live device
149            // pointers on a stream the runtime owns. Sizes match
150            // the source columns exactly. cuMemcpyDtoDAsync_v2 is
151            // genuinely stream-asynchronous.
152            unsafe {
153                let res = sys::cuMemcpyDtoDAsync_v2(
154                    *new_col0.device_ptr(),
155                    *src.column(1).expect("src.col1").device_ptr(),
156                    bytes_col1,
157                    cu_stream.cu_stream(),
158                );
159                if res != sys::cudaError_enum::CUDA_SUCCESS {
160                    return Err(XlogError::Kernel(format!(
161                        "wcoj_project_2col_swap_recorded: dtod col1 → new_col0 failed: {:?}",
162                        res
163                    )));
164                }
165                let res = sys::cuMemcpyDtoDAsync_v2(
166                    *new_col1.device_ptr(),
167                    *src.column(0).expect("src.col0").device_ptr(),
168                    bytes_col0,
169                    cu_stream.cu_stream(),
170                );
171                if res != sys::cudaError_enum::CUDA_SUCCESS {
172                    return Err(XlogError::Kernel(format!(
173                        "wcoj_project_2col_swap_recorded: dtod col0 → new_col1 failed: {:?}",
174                        res
175                    )));
176                }
177                let res = sys::cuMemcpyDtoDAsync_v2(
178                    *new_num_rows.device_ptr(),
179                    *src.num_rows_device().device_ptr(),
180                    std::mem::size_of::<u32>(),
181                    cu_stream.cu_stream(),
182                );
183                if res != sys::cudaError_enum::CUDA_SUCCESS {
184                    return Err(XlogError::Kernel(format!(
185                        "wcoj_project_2col_swap_recorded: dtod num_rows_device failed: {:?}",
186                        res
187                    )));
188                }
189            }
190            Ok(())
191        })();
192
193        if let Err(e) = queued_result {
194            let _ = cu_stream.synchronize();
195            return Err(e);
196        }
197
198        rec.commit(runtime).map_err(|e| {
199            // commit happens after CUDA work is queued; if commit
200            // fails, drain so partially-issued copies finish
201            // before the buffers drop.
202            let _ = cu_stream.synchronize();
203            XlogError::Kernel(format!(
204                "wcoj_project_2col_swap_recorded: launch recorder commit failed: {}",
205                e
206            ))
207        })?;
208
209        // Build the output buffer. Carry cached_row_count from src
210        // (logical row count is invariant under column permutation).
211        let columns: Vec<CudaColumn> = vec![new_col0.into(), new_col1.into()];
212        let buf = match src.cached_row_count() {
213            Some(host_count) => CudaBuffer::from_columns_with_host_count(
214                columns,
215                src.row_cap,
216                new_num_rows,
217                swapped_schema,
218                host_count,
219            ),
220            None => CudaBuffer::from_columns(columns, src.row_cap, new_num_rows, swapped_schema),
221        };
222        Ok(buf)
223    }
224
225    /// Produce an owned N-col `CudaBuffer` with columns
226    /// reordered per `perm` and the schema replaced with
227    /// `head_schema`.
228    ///
229    /// `perm[i]` is the source column index that becomes output
230    /// column `i`. The dispatcher uses this post-kernel to remap
231    /// the kernel-direct output (in leader's `(a, b, c[, d])`
232    /// order) into the rule's head order.
233    ///
234    /// See module docs for the full recorded / failure-drain
235    /// contract.
236    ///
237    /// # Errors
238    /// * `XlogError::Kernel` if the manager has no runtime, the
239    ///   stream doesn't resolve, `perm.len() != head_schema.arity()`,
240    ///   any `perm` index is ≥ `src.arity()`, or any queued DtoD
241    ///   copy fails. Failure-drain on Err.
242    pub fn wcoj_project_output_columns_recorded(
243        &self,
244        src: &CudaBuffer,
245        perm: &[usize],
246        head_schema: Schema,
247        launch_stream: StreamId,
248    ) -> Result<CudaBuffer> {
249        let runtime = self.memory().runtime().ok_or_else(|| {
250            XlogError::Kernel(
251                "wcoj_project_output_columns_recorded requires a runtime-backed \
252                 GpuMemoryManager (constructed via with_runtime)"
253                    .to_string(),
254            )
255        })?;
256        let cu_stream = runtime
257            .stream_pool()
258            .resolve(launch_stream)
259            .ok_or_else(|| {
260                XlogError::Kernel(format!(
261                    "wcoj_project_output_columns_recorded: launch_stream StreamId({}) does not resolve",
262                    launch_stream.0
263                ))
264            })?;
265
266        if perm.len() != head_schema.arity() {
267            return Err(XlogError::Kernel(format!(
268                "wcoj_project_output_columns_recorded: perm len {} must equal head_schema arity {}",
269                perm.len(),
270                head_schema.arity()
271            )));
272        }
273        for (i, &p) in perm.iter().enumerate() {
274            if p >= src.arity() {
275                return Err(XlogError::Kernel(format!(
276                    "wcoj_project_output_columns_recorded: perm[{}] = {} out of bounds (src arity {})",
277                    i,
278                    p,
279                    src.arity()
280                )));
281            }
282        }
283
284        // Empty-output short-circuit. WCOJ legitimately produces
285        // empty output — the helper must NOT divide-by-zero or
286        // refuse to materialize. cached_row_count == Some(0) +
287        // num_rows_device == 0 are produced via create_empty_buffer
288        // (sized to head_schema; its num_rows_device starts at 0).
289        if src.row_cap == 0 {
290            return self.create_empty_buffer(head_schema);
291        }
292
293        // Allocate one new column buffer per output position +
294        // a fresh num_rows_device scalar.
295        let mut new_columns: Vec<TrackedCudaSlice<u8>> = Vec::with_capacity(perm.len());
296        for &p in perm {
297            let bytes = src.column(p).expect("src column").len();
298            new_columns.push(self.memory.alloc::<u8>(bytes)?);
299        }
300        let new_num_rows: TrackedCudaSlice<u32> = self.memory.alloc::<u32>(1)?;
301
302        let mut rec = LaunchRecorder::new_strict(launch_stream);
303        for i in 0..src.arity() {
304            rec.read_column(src.column(i).expect("src.col"));
305        }
306        rec.read(src.num_rows_device());
307        for c in &new_columns {
308            rec.write(c);
309        }
310        rec.write(&new_num_rows);
311        rec.preflight(runtime).map_err(|e| {
312            XlogError::Kernel(format!(
313                "wcoj_project_output_columns_recorded: launch recorder preflight failed: {}",
314                e
315            ))
316        })?;
317
318        let queued_result: Result<()> = (|| {
319            for (i, &p) in perm.iter().enumerate() {
320                let src_col = src.column(p).expect("src column");
321                let bytes = src_col.len();
322                // SAFETY: device pointer is live, len matches source
323                // column. Stream is owned by the runtime.
324                unsafe {
325                    let res = sys::cuMemcpyDtoDAsync_v2(
326                        *new_columns[i].device_ptr(),
327                        *src_col.device_ptr(),
328                        bytes,
329                        cu_stream.cu_stream(),
330                    );
331                    if res != sys::cudaError_enum::CUDA_SUCCESS {
332                        return Err(XlogError::Kernel(format!(
333                            "wcoj_project_output_columns_recorded: dtod perm[{}] = src col {} failed: {:?}",
334                            i, p, res
335                        )));
336                    }
337                }
338            }
339            // num_rows_device DtoD-copy.
340            unsafe {
341                let res = sys::cuMemcpyDtoDAsync_v2(
342                    *new_num_rows.device_ptr(),
343                    *src.num_rows_device().device_ptr(),
344                    std::mem::size_of::<u32>(),
345                    cu_stream.cu_stream(),
346                );
347                if res != sys::cudaError_enum::CUDA_SUCCESS {
348                    return Err(XlogError::Kernel(format!(
349                        "wcoj_project_output_columns_recorded: dtod num_rows_device failed: {:?}",
350                        res
351                    )));
352                }
353            }
354            Ok(())
355        })();
356
357        if let Err(e) = queued_result {
358            let _ = cu_stream.synchronize();
359            return Err(e);
360        }
361
362        rec.commit(runtime).map_err(|e| {
363            let _ = cu_stream.synchronize();
364            XlogError::Kernel(format!(
365                "wcoj_project_output_columns_recorded: launch recorder commit failed: {}",
366                e
367            ))
368        })?;
369
370        let columns: Vec<CudaColumn> = new_columns.into_iter().map(|c| c.into()).collect();
371        let buf = match src.cached_row_count() {
372            Some(host_count) => CudaBuffer::from_columns_with_host_count(
373                columns,
374                src.row_cap,
375                new_num_rows,
376                head_schema,
377                host_count,
378            ),
379            None => CudaBuffer::from_columns(columns, src.row_cap, new_num_rows, head_schema),
380        };
381        Ok(buf)
382    }
383}