1use xlog_core::{Result, ScalarType, XlogError};
24
25use super::fj_delta::FjDeltaCols;
26use super::{wcoj_kernels, CudaKernelProvider, WCOJ_MODULE};
27use crate::device_runtime::StreamId;
28use crate::launch::LaunchRecorder;
29use crate::CudaBuffer;
30use crate::{LaunchAsync, LaunchConfig};
31
/// Threads per block (x dimension) used by every kernel launch in this file;
/// grid sizes are derived as `work_items.div_ceil(BLOCK_SIZE)`.
const BLOCK_SIZE: u32 = 256;
33
34fn require_binary_u32_class(buf: &CudaBuffer, name: &str, ctx: &str) -> Result<()> {
35 if buf.arity() != 2 {
36 return Err(XlogError::Kernel(format!(
37 "{ctx}: {name} must be arity-2, got {}",
38 buf.arity()
39 )));
40 }
41 for idx in 0..2 {
42 match buf.schema().column_type(idx) {
43 Some(ScalarType::U32) | Some(ScalarType::Symbol) => {}
44 other => {
45 return Err(XlogError::Kernel(format!(
46 "{ctx}: {name} column {idx} must be U32/Symbol, got {other:?}"
47 )));
48 }
49 }
50 }
51 Ok(())
52}
53
54impl CudaKernelProvider {
55 pub fn fj_delta_sparse_novel_u32_recorded(
67 &self,
68 delta: &CudaBuffer,
69 edge: &CudaBuffer,
70 full_r: &CudaBuffer,
71 cols: FjDeltaCols,
72 max_table_bytes: u64,
73 launch_stream: StreamId,
74 ) -> Result<Option<CudaBuffer>> {
75 let ctx = "fj_delta_sparse_novel_u32_recorded";
76 let runtime = self.memory().runtime().ok_or_else(|| {
77 XlogError::Kernel(format!(
78 "{ctx} requires a runtime-backed GpuMemoryManager (with_runtime)"
79 ))
80 })?;
81 let cu_stream = runtime
82 .stream_pool()
83 .resolve(launch_stream)
84 .ok_or_else(|| {
85 XlogError::Kernel(format!(
86 "{ctx}: launch_stream StreamId({}) does not resolve",
87 launch_stream.0
88 ))
89 })?;
90
91 require_binary_u32_class(delta, "delta", ctx)?;
92 require_binary_u32_class(edge, "edge", ctx)?;
93 require_binary_u32_class(full_r, "full_r", ctx)?;
94
95 let row_count = |buf: &CudaBuffer| -> Result<u32> {
96 match buf.cached_row_count() {
97 Some(c) => Ok(c),
98 None => self.dtoh_scalar_untracked::<u32>(buf.num_rows_device(), 0),
99 }
100 };
101 let n_delta = row_count(delta)?;
102 let n_edge = row_count(edge)?;
103 let n_r = row_count(full_r)?;
104
105 let out_schema = full_r.schema().clone();
106 if n_delta == 0 || n_edge == 0 {
107 return Ok(Some(self.create_empty_buffer(out_schema)?));
108 }
109
110 let delta_x = delta.column(cols.delta_carry).ok_or_else(|| {
111 XlogError::Kernel(format!("{ctx}: delta column {} missing", cols.delta_carry))
112 })?;
113 let delta_y = delta.column(cols.delta_key).ok_or_else(|| {
114 XlogError::Kernel(format!("{ctx}: delta column {} missing", cols.delta_key))
115 })?;
116 let edge_y = edge
117 .column(0)
118 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: edge column 0 missing")))?;
119 let edge_z = edge
120 .column(1)
121 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: edge column 1 missing")))?;
122 let r_x = full_r.column(cols.r_carry).ok_or_else(|| {
123 XlogError::Kernel(format!("{ctx}: full_r column {} missing", cols.r_carry))
124 })?;
125 let r_z = full_r.column(cols.r_value).ok_or_else(|| {
126 XlogError::Kernel(format!("{ctx}: full_r column {} missing", cols.r_value))
127 })?;
128 let delta_y_v = self.column_as_u32_view(delta_y, n_delta as usize)?;
129 let edge_y_v = self.column_as_u32_view(edge_y, n_edge as usize)?;
130
131 let range_lo = self.memory().alloc::<u32>(n_delta as usize)?;
134 let mut wp = self.memory().alloc::<u32>(n_delta as usize + 1)?;
135 {
136 let mut rec = LaunchRecorder::new_strict(launch_stream);
137 rec.read_column(delta_y);
138 rec.read_column(edge_y);
139 rec.write(&range_lo);
140 rec.write(&wp);
141 rec.preflight(runtime)
142 .map_err(|e| XlogError::Kernel(format!("{ctx}: range preflight: {e}")))?;
143 let kernel = self
144 .device()
145 .inner()
146 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_RANGE_U32)
147 .ok_or_else(|| XlogError::Kernel("fj_delta_range_u32 not found".to_string()))?;
148 let grid = n_delta.div_ceil(BLOCK_SIZE);
149 unsafe {
151 kernel
152 .clone()
153 .launch_on_stream(
154 &cu_stream,
155 LaunchConfig {
156 grid_dim: (grid, 1, 1),
157 block_dim: (BLOCK_SIZE, 1, 1),
158 shared_mem_bytes: 0,
159 },
160 (&delta_y_v, n_delta, &edge_y_v, n_edge, &range_lo, &mut wp),
161 )
162 .map_err(|e| XlogError::Kernel(format!("fj_delta_range_u32 launch: {e}")))?;
163 }
164 self.multiblock_scan_u32_inplace_on_stream(
165 &mut wp,
166 n_delta + 1,
167 &cu_stream,
168 launch_stream,
169 runtime,
170 )?;
171 rec.commit(runtime)
172 .map_err(|e| XlogError::Kernel(format!("{ctx}: range commit: {e}")))?;
173 }
174 cu_stream
175 .synchronize()
176 .map_err(|e| XlogError::Kernel(format!("{ctx}: range sync: {e}")))?;
177 let total_work = u64::from(self.dtoh_scalar_untracked::<u32>(&wp, n_delta as usize)?);
178 if total_work == 0 {
179 return Ok(Some(self.create_empty_buffer(out_schema)?));
180 }
181 if total_work > u64::from(u32::MAX - 1) {
182 return Err(XlogError::Kernel(format!(
183 "{ctx}: candidate work {total_work} exceeds u32 work-index space"
184 )));
185 }
186 let total_work = total_work as u32;
187
188 const EST_BITS: u32 = 1 << 26;
197 const EST_WORDS: u32 = EST_BITS / 32;
198 let est_bit_mask = EST_BITS - 1;
199 let mut est = self.memory().alloc::<u32>(EST_WORDS as usize)?;
200 self.device()
201 .inner()
202 .memset_zeros(&mut est)
203 .map_err(|e| XlogError::Kernel(format!("{ctx}: zero estimator: {e}")))?;
204 let mut est_counts = self.memory().alloc::<u32>(EST_WORDS as usize + 1)?;
205 {
206 let mut rec = LaunchRecorder::new_strict(launch_stream);
207 rec.read_column(delta_x);
208 rec.read_column(edge_z);
209 rec.read(&range_lo);
210 rec.read(&wp);
211 rec.read_write(&est);
212 rec.write(&est_counts);
213 rec.preflight(runtime)
214 .map_err(|e| XlogError::Kernel(format!("{ctx}: estimate preflight: {e}")))?;
215 let delta_x_v = self.column_as_u32_view(delta_x, n_delta as usize)?;
216 let edge_z_v = self.column_as_u32_view(edge_z, n_edge as usize)?;
217 let estimate = self
218 .device()
219 .inner()
220 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_SPARSE_ESTIMATE)
221 .ok_or_else(|| {
222 XlogError::Kernel("fj_delta_sparse_estimate not found".to_string())
223 })?;
224 let grid = total_work.div_ceil(BLOCK_SIZE);
225 unsafe {
228 estimate
229 .clone()
230 .launch_on_stream(
231 &cu_stream,
232 LaunchConfig {
233 grid_dim: (grid, 1, 1),
234 block_dim: (BLOCK_SIZE, 1, 1),
235 shared_mem_bytes: 0,
236 },
237 (
238 &delta_x_v,
239 n_delta,
240 &range_lo,
241 &wp,
242 total_work,
243 &edge_z_v,
244 &mut est,
245 est_bit_mask,
246 ),
247 )
248 .map_err(|e| {
249 XlogError::Kernel(format!("fj_delta_sparse_estimate launch: {e}"))
250 })?;
251 }
252 let popcount = self
255 .device()
256 .inner()
257 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_POPCOUNT)
258 .ok_or_else(|| XlogError::Kernel("fj_delta_popcount not found".to_string()))?;
259 let pgrid = EST_WORDS.div_ceil(BLOCK_SIZE);
260 unsafe {
262 popcount
263 .clone()
264 .launch_on_stream(
265 &cu_stream,
266 LaunchConfig {
267 grid_dim: (pgrid, 1, 1),
268 block_dim: (BLOCK_SIZE, 1, 1),
269 shared_mem_bytes: 0,
270 },
271 (&est, EST_WORDS, &mut est_counts),
272 )
273 .map_err(|e| XlogError::Kernel(format!("estimator popcount launch: {e}")))?;
274 }
275 self.multiblock_scan_u32_inplace_on_stream(
276 &mut est_counts,
277 EST_WORDS + 1,
278 &cu_stream,
279 launch_stream,
280 runtime,
281 )?;
282 rec.commit(runtime)
283 .map_err(|e| XlogError::Kernel(format!("{ctx}: estimate commit: {e}")))?;
284 }
285 cu_stream
286 .synchronize()
287 .map_err(|e| XlogError::Kernel(format!("{ctx}: estimate sync: {e}")))?;
288 let distinct_est = self.dtoh_scalar_untracked::<u32>(&est_counts, EST_WORDS as usize)?;
289 drop(est);
292 drop(est_counts);
293
294 let est_margined = (u64::from(distinct_est) * 3) / 2;
300 let upper = u64::from(n_r) + est_margined + 1;
301 let want = upper
302 .checked_mul(2)
303 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: table size overflow")))?;
304 let mut cap: u64 = 1;
305 while cap < want {
306 cap <<= 1;
307 }
308 if cap > u64::from(u32::MAX) {
309 return Err(XlogError::Kernel(format!(
310 "{ctx}: hash table capacity {cap} exceeds u32 slot space (workload too large \
311 for the spike's single-table sizing)"
312 )));
313 }
314 if max_table_bytes != 0 {
318 let table_bytes = u64::from(cap).saturating_mul(8 + 1 + 4).saturating_add(4);
319 if table_bytes > max_table_bytes {
320 return Ok(None);
321 }
322 }
323 let cap = cap as u32;
324 let mask = cap - 1;
325
326 let mut table = self.memory().alloc::<u64>(cap as usize)?;
327 let mut is_r = self.memory().alloc::<u8>(cap as usize)?;
328 self.device()
329 .inner()
330 .memset_zeros(&mut table)
331 .map_err(|e| XlogError::Kernel(format!("{ctx}: zero table: {e}")))?;
332 self.device()
333 .inner()
334 .memset_zeros(&mut is_r)
335 .map_err(|e| XlogError::Kernel(format!("{ctx}: zero is_r: {e}")))?;
336 let mut overflow = self.memory().alloc::<u32>(1)?;
337 self.device()
338 .inner()
339 .memset_zeros(&mut overflow)
340 .map_err(|e| XlogError::Kernel(format!("{ctx}: zero overflow: {e}")))?;
341
342 {
344 let mut rec = LaunchRecorder::new_strict(launch_stream);
345 rec.read_column(delta_x);
346 rec.read_column(edge_z);
347 rec.read(&range_lo);
348 rec.read(&wp);
349 rec.read_write(&table);
350 rec.read_write(&is_r);
351 rec.write(&overflow);
352 if n_r > 0 {
353 rec.read_column(r_x);
354 rec.read_column(r_z);
355 }
356 rec.preflight(runtime)
357 .map_err(|e| XlogError::Kernel(format!("{ctx}: insert preflight: {e}")))?;
358
359 if n_r > 0 {
360 let r_x_v = self.column_as_u32_view(r_x, n_r as usize)?;
361 let r_z_v = self.column_as_u32_view(r_z, n_r as usize)?;
362 let load_r = self
363 .device()
364 .inner()
365 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_SPARSE_LOAD_R)
366 .ok_or_else(|| {
367 XlogError::Kernel("fj_delta_sparse_load_r not found".to_string())
368 })?;
369 let grid = n_r.div_ceil(BLOCK_SIZE);
370 unsafe {
372 load_r
373 .clone()
374 .launch_on_stream(
375 &cu_stream,
376 LaunchConfig {
377 grid_dim: (grid, 1, 1),
378 block_dim: (BLOCK_SIZE, 1, 1),
379 shared_mem_bytes: 0,
380 },
381 (
382 &r_x_v,
383 &r_z_v,
384 n_r,
385 &mut table,
386 &mut is_r,
387 mask,
388 &mut overflow,
389 ),
390 )
391 .map_err(|e| {
392 XlogError::Kernel(format!("fj_delta_sparse_load_r launch: {e}"))
393 })?;
394 }
395 }
396
397 let delta_x_v = self.column_as_u32_view(delta_x, n_delta as usize)?;
398 let edge_z_v = self.column_as_u32_view(edge_z, n_edge as usize)?;
399 let insert = self
400 .device()
401 .inner()
402 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_SPARSE_INSERT_CANDIDATES)
403 .ok_or_else(|| {
404 XlogError::Kernel("fj_delta_sparse_insert_candidates not found".to_string())
405 })?;
406 let grid = total_work.div_ceil(BLOCK_SIZE);
407 unsafe {
410 insert
411 .clone()
412 .launch_on_stream(
413 &cu_stream,
414 LaunchConfig {
415 grid_dim: (grid, 1, 1),
416 block_dim: (BLOCK_SIZE, 1, 1),
417 shared_mem_bytes: 0,
418 },
419 (
420 &delta_x_v,
421 n_delta,
422 &range_lo,
423 &wp,
424 total_work,
425 &edge_z_v,
426 &mut table,
427 mask,
428 &mut overflow,
429 ),
430 )
431 .map_err(|e| {
432 XlogError::Kernel(format!("fj_delta_sparse_insert_candidates launch: {e}"))
433 })?;
434 }
435 rec.commit(runtime)
436 .map_err(|e| XlogError::Kernel(format!("{ctx}: insert commit: {e}")))?;
437 }
438 cu_stream
439 .synchronize()
440 .map_err(|e| XlogError::Kernel(format!("{ctx}: insert sync: {e}")))?;
441
442 if self.dtoh_scalar_untracked::<u32>(&overflow, 0)? != 0 {
447 return Ok(None);
448 }
449
450 let mut counts = self.memory().alloc::<u32>(cap as usize + 1)?;
452 {
453 let mut rec = LaunchRecorder::new_strict(launch_stream);
454 rec.read(&table);
455 rec.read(&is_r);
456 rec.write(&counts);
457 rec.preflight(runtime)
458 .map_err(|e| XlogError::Kernel(format!("{ctx}: mark preflight: {e}")))?;
459 let mark = self
460 .device()
461 .inner()
462 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_SPARSE_MARK)
463 .ok_or_else(|| XlogError::Kernel("fj_delta_sparse_mark not found".to_string()))?;
464 let grid = cap.div_ceil(BLOCK_SIZE);
465 unsafe {
467 mark.clone()
468 .launch_on_stream(
469 &cu_stream,
470 LaunchConfig {
471 grid_dim: (grid, 1, 1),
472 block_dim: (BLOCK_SIZE, 1, 1),
473 shared_mem_bytes: 0,
474 },
475 (&table, &is_r, cap, &mut counts),
476 )
477 .map_err(|e| XlogError::Kernel(format!("fj_delta_sparse_mark launch: {e}")))?;
478 }
479 self.multiblock_scan_u32_inplace_on_stream(
480 &mut counts,
481 cap + 1,
482 &cu_stream,
483 launch_stream,
484 runtime,
485 )?;
486 rec.commit(runtime)
487 .map_err(|e| XlogError::Kernel(format!("{ctx}: mark commit: {e}")))?;
488 }
489 cu_stream
490 .synchronize()
491 .map_err(|e| XlogError::Kernel(format!("{ctx}: mark sync: {e}")))?;
492 let total_novel = self.dtoh_scalar_untracked::<u32>(&counts, cap as usize)?;
493 if total_novel == 0 {
494 return Ok(Some(self.create_empty_buffer(out_schema)?));
495 }
496
497 let out_x = self.memory().alloc::<u32>(total_novel as usize)?;
498 let out_z = self.memory().alloc::<u32>(total_novel as usize)?;
499 {
500 let mut rec = LaunchRecorder::new_strict(launch_stream);
501 rec.read(&table);
502 rec.read(&is_r);
503 rec.read(&counts);
504 rec.write(&out_x);
505 rec.write(&out_z);
506 rec.preflight(runtime)
507 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit preflight: {e}")))?;
508 let emit = self
509 .device()
510 .inner()
511 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_SPARSE_EMIT)
512 .ok_or_else(|| XlogError::Kernel("fj_delta_sparse_emit not found".to_string()))?;
513 let grid = cap.div_ceil(BLOCK_SIZE);
514 unsafe {
516 emit.clone()
517 .launch_on_stream(
518 &cu_stream,
519 LaunchConfig {
520 grid_dim: (grid, 1, 1),
521 block_dim: (BLOCK_SIZE, 1, 1),
522 shared_mem_bytes: 0,
523 },
524 (&table, &is_r, &counts, cap, &out_x, &out_z),
525 )
526 .map_err(|e| XlogError::Kernel(format!("fj_delta_sparse_emit launch: {e}")))?;
527 }
528 rec.commit(runtime)
529 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit commit: {e}")))?;
530 }
531 cu_stream
532 .synchronize()
533 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit sync: {e}")))?;
534
535 let d_nr = self.memory().alloc::<u32>(1)?;
536 self.htod_launch_metadata_async_copy_one(
537 &total_novel,
538 &d_nr,
539 &cu_stream,
540 &format!("{ctx}: result num_rows"),
541 )?;
542 let columns = if cols.r_carry == 0 {
543 vec![out_x.into_bytes().into(), out_z.into_bytes().into()]
544 } else {
545 vec![out_z.into_bytes().into(), out_x.into_bytes().into()]
546 };
547 Ok(Some(CudaBuffer::from_columns_with_host_count(
548 columns,
549 u64::from(total_novel),
550 d_nr,
551 out_schema,
552 total_novel,
553 )))
554 }
555}