1use xlog_core::{Result, ScalarType, XlogError};
36
37use super::{wcoj_kernels, CudaKernelProvider, WCOJ_MODULE};
38use crate::device_runtime::StreamId;
39use crate::launch::LaunchRecorder;
40use crate::CudaBuffer;
41use crate::{LaunchAsync, LaunchConfig};
42
/// Threads per block (x dimension) used as `block_dim` by the kernel launches
/// below; grid sizes are derived with `div_ceil(BLOCK_SIZE)`.
const BLOCK_SIZE: u32 = 256;

/// Largest `domain` value accepted by `fj_delta_novel_u32_recorded`
/// (`1 << 16` = 65 536). A `domain` of zero or above this bound is rejected
/// with a "dense-domain spike bound" error.
pub const FJ_DELTA_MAX_DOMAIN: u32 = 1 << 16;
48
/// Column-role assignment for the two arity-2 inputs (`delta` and `full_r`)
/// of `fj_delta_novel_u32_recorded`.
///
/// Each field is a column index into its buffer. `validate` requires every
/// index to be 0 or 1 and the two indices of one buffer to differ.
#[derive(Debug, Clone, Copy)]
pub struct FjDeltaCols {
    /// Index of the `delta` column fetched as `delta_x` and handed to the
    /// `FJ_DELTA_MARK_U32` kernel.
    pub delta_carry: usize,
    /// Index of the `delta` column fetched as `delta_y` and handed to the
    /// `FJ_DELTA_RANGE_U32` kernel together with `edge` column 0.
    pub delta_key: usize,
    /// Index of the `full_r` column fetched as `r_x` for the
    /// `FJ_DELTA_SUBTRACT_U32` kernel. It also selects the output layout:
    /// when it is 0 the result columns are `[out_x, out_z]`, otherwise
    /// `[out_z, out_x]`.
    pub r_carry: usize,
    /// Index of the `full_r` column fetched as `r_z` for the
    /// `FJ_DELTA_SUBTRACT_U32` kernel.
    pub r_value: usize,
}
66
67impl FjDeltaCols {
68 pub const CANONICAL: Self = Self {
70 delta_carry: 0,
71 delta_key: 1,
72 r_carry: 0,
73 r_value: 1,
74 };
75
76 fn validate(&self, ctx: &str) -> Result<()> {
77 let ok = |a: usize, b: usize| a < 2 && b < 2 && a != b;
78 if !ok(self.delta_carry, self.delta_key) || !ok(self.r_carry, self.r_value) {
79 return Err(XlogError::Kernel(format!(
80 "{ctx}: invalid column roles {self:?} (arity-2 indices, pairwise distinct)"
81 )));
82 }
83 Ok(())
84 }
85}
86
87fn require_binary_u32_class(buf: &CudaBuffer, name: &str, ctx: &str) -> Result<()> {
88 if buf.arity() != 2 {
89 return Err(XlogError::Kernel(format!(
90 "{ctx}: {name} must be arity-2, got {}",
91 buf.arity()
92 )));
93 }
94 for idx in 0..2 {
95 match buf.schema().column_type(idx) {
96 Some(ScalarType::U32) | Some(ScalarType::Symbol) => {}
97 other => {
98 return Err(XlogError::Kernel(format!(
99 "{ctx}: {name} column {idx} must be U32/Symbol, got {other:?}"
100 )));
101 }
102 }
103 }
104 Ok(())
105}
106
107impl CudaKernelProvider {
108 pub fn fj_delta_novel_u32_recorded(
118 &self,
119 delta: &CudaBuffer,
120 edge: &CudaBuffer,
121 full_r: &CudaBuffer,
122 cols: FjDeltaCols,
123 domain: u32,
124 launch_stream: StreamId,
125 ) -> Result<CudaBuffer> {
126 let ctx = "fj_delta_novel_u32_recorded";
127 let runtime = self.memory().runtime().ok_or_else(|| {
128 XlogError::Kernel(format!(
129 "{ctx} requires a runtime-backed GpuMemoryManager \
130 (constructed via with_runtime)"
131 ))
132 })?;
133 let cu_stream = runtime
134 .stream_pool()
135 .resolve(launch_stream)
136 .ok_or_else(|| {
137 XlogError::Kernel(format!(
138 "{ctx}: launch_stream StreamId({}) does not resolve",
139 launch_stream.0
140 ))
141 })?;
142
143 cols.validate(ctx)?;
144 require_binary_u32_class(delta, "delta", ctx)?;
145 require_binary_u32_class(edge, "edge", ctx)?;
146 require_binary_u32_class(full_r, "full_r", ctx)?;
147 if domain == 0 || domain > FJ_DELTA_MAX_DOMAIN {
148 return Err(XlogError::Kernel(format!(
149 "{ctx}: domain {domain} outside (0, {FJ_DELTA_MAX_DOMAIN}] \
150 (dense-domain spike bound)"
151 )));
152 }
153
154 let row_count = |buf: &CudaBuffer| -> Result<u32> {
155 match buf.cached_row_count() {
156 Some(c) => Ok(c),
157 None => self.dtoh_scalar_untracked::<u32>(buf.num_rows_device(), 0),
158 }
159 };
160 let n_delta = row_count(delta)?;
161 let n_edge = row_count(edge)?;
162 let n_r = row_count(full_r)?;
163
164 let out_schema = full_r.schema().clone();
165 if n_delta == 0 || n_edge == 0 {
166 return self.create_empty_buffer(out_schema);
167 }
168
169 let words_per_row = domain.div_ceil(32);
170 let n_words = u64::from(domain) * u64::from(words_per_row);
171 let n_words = n_words as u32;
173
174 let delta_x = delta.column(cols.delta_carry).ok_or_else(|| {
175 XlogError::Kernel(format!("{ctx}: delta column {} missing", cols.delta_carry))
176 })?;
177 let delta_y = delta.column(cols.delta_key).ok_or_else(|| {
178 XlogError::Kernel(format!("{ctx}: delta column {} missing", cols.delta_key))
179 })?;
180 let edge_y = edge
181 .column(0)
182 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: edge column 0 missing")))?;
183 let edge_z = edge
184 .column(1)
185 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: edge column 1 missing")))?;
186 let r_x = full_r.column(cols.r_carry).ok_or_else(|| {
187 XlogError::Kernel(format!("{ctx}: full_r column {} missing", cols.r_carry))
188 })?;
189 let r_z = full_r.column(cols.r_value).ok_or_else(|| {
190 XlogError::Kernel(format!("{ctx}: full_r column {} missing", cols.r_value))
191 })?;
192 let delta_x_v = self.column_as_u32_view(delta_x, n_delta as usize)?;
193 let delta_y_v = self.column_as_u32_view(delta_y, n_delta as usize)?;
194 let edge_y_v = self.column_as_u32_view(edge_y, n_edge as usize)?;
195 let edge_z_v = self.column_as_u32_view(edge_z, n_edge as usize)?;
196
197 let range_lo = self.memory().alloc::<u32>(n_delta as usize)?;
199 let mut wp = self.memory().alloc::<u32>(n_delta as usize + 1)?;
200 {
201 let mut rec = LaunchRecorder::new_strict(launch_stream);
202 rec.read_column(delta_y);
203 rec.read_column(edge_y);
204 rec.write(&range_lo);
205 rec.write(&wp);
206 rec.preflight(runtime)
207 .map_err(|e| XlogError::Kernel(format!("{ctx}: range preflight failed: {e}")))?;
208 let kernel = self
209 .device()
210 .inner()
211 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_RANGE_U32)
212 .ok_or_else(|| {
213 XlogError::Kernel("fj_delta_range_u32 kernel not found".to_string())
214 })?;
215 let grid = n_delta.div_ceil(BLOCK_SIZE);
216 unsafe {
220 kernel
221 .clone()
222 .launch_on_stream(
223 &cu_stream,
224 LaunchConfig {
225 grid_dim: (grid, 1, 1),
226 block_dim: (BLOCK_SIZE, 1, 1),
227 shared_mem_bytes: 0,
228 },
229 (&delta_y_v, n_delta, &edge_y_v, n_edge, &range_lo, &mut wp),
230 )
231 .map_err(|e| {
232 XlogError::Kernel(format!("fj_delta_range_u32 launch failed: {e}"))
233 })?;
234 }
235 self.multiblock_scan_u32_inplace_on_stream(
236 &mut wp,
237 n_delta + 1,
238 &cu_stream,
239 launch_stream,
240 runtime,
241 )?;
242 rec.commit(runtime)
243 .map_err(|e| XlogError::Kernel(format!("{ctx}: range commit failed: {e}")))?;
244 }
245 cu_stream
246 .synchronize()
247 .map_err(|e| XlogError::Kernel(format!("{ctx}: range sync failed: {e}")))?;
248 let total_work = u64::from(self.dtoh_scalar_untracked::<u32>(&wp, n_delta as usize)?);
249 if total_work == 0 {
250 return self.create_empty_buffer(out_schema);
251 }
252 if total_work > u64::from(u32::MAX - 1) {
253 return Err(XlogError::Kernel(format!(
254 "{ctx}: candidate work {total_work} exceeds the u32 work-index space"
255 )));
256 }
257 let total_work = total_work as u32;
258
259 let mut bitmap = self.memory().alloc::<u32>(n_words as usize)?;
261 self.device()
262 .inner()
263 .memset_zeros(&mut bitmap)
264 .map_err(|e| XlogError::Kernel(format!("{ctx}: zero bitmap failed: {e}")))?;
265 let mut error_flag = self.memory().alloc::<u32>(1)?;
266 self.device()
267 .inner()
268 .memset_zeros(&mut error_flag)
269 .map_err(|e| XlogError::Kernel(format!("{ctx}: zero error flag failed: {e}")))?;
270 {
271 let mut rec = LaunchRecorder::new_strict(launch_stream);
272 rec.read_column(delta_x);
273 rec.read_column(edge_z);
274 rec.read(&range_lo);
275 rec.read(&wp);
276 rec.read_write(&bitmap);
277 rec.write(&error_flag);
278 if n_r > 0 {
279 rec.read_column(r_x);
280 rec.read_column(r_z);
281 }
282 rec.preflight(runtime)
283 .map_err(|e| XlogError::Kernel(format!("{ctx}: mark preflight failed: {e}")))?;
284 let mark = self
285 .device()
286 .inner()
287 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_MARK_U32)
288 .ok_or_else(|| {
289 XlogError::Kernel("fj_delta_mark_u32 kernel not found".to_string())
290 })?;
291 let grid = total_work.div_ceil(BLOCK_SIZE);
292 unsafe {
296 mark.clone()
297 .launch_on_stream(
298 &cu_stream,
299 LaunchConfig {
300 grid_dim: (grid, 1, 1),
301 block_dim: (BLOCK_SIZE, 1, 1),
302 shared_mem_bytes: 0,
303 },
304 (
305 &delta_x_v,
306 n_delta,
307 &range_lo,
308 &wp,
309 total_work,
310 &edge_z_v,
311 &mut bitmap,
312 words_per_row,
313 domain,
314 &mut error_flag,
315 ),
316 )
317 .map_err(|e| {
318 XlogError::Kernel(format!("fj_delta_mark_u32 launch failed: {e}"))
319 })?;
320 }
321 if n_r > 0 {
322 let r_x_v = self.column_as_u32_view(r_x, n_r as usize)?;
323 let r_z_v = self.column_as_u32_view(r_z, n_r as usize)?;
324 let subtract = self
325 .device()
326 .inner()
327 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_SUBTRACT_U32)
328 .ok_or_else(|| {
329 XlogError::Kernel("fj_delta_subtract_u32 kernel not found".to_string())
330 })?;
331 let grid = n_r.div_ceil(BLOCK_SIZE);
332 unsafe {
336 subtract
337 .clone()
338 .launch_on_stream(
339 &cu_stream,
340 LaunchConfig {
341 grid_dim: (grid, 1, 1),
342 block_dim: (BLOCK_SIZE, 1, 1),
343 shared_mem_bytes: 0,
344 },
345 (
346 &r_x_v,
347 &r_z_v,
348 n_r,
349 &mut bitmap,
350 words_per_row,
351 domain,
352 &mut error_flag,
353 ),
354 )
355 .map_err(|e| {
356 XlogError::Kernel(format!("fj_delta_subtract_u32 launch failed: {e}"))
357 })?;
358 }
359 }
360 rec.commit(runtime)
361 .map_err(|e| XlogError::Kernel(format!("{ctx}: mark commit failed: {e}")))?;
362 }
363 cu_stream
364 .synchronize()
365 .map_err(|e| XlogError::Kernel(format!("{ctx}: mark sync failed: {e}")))?;
366 if self.dtoh_scalar_untracked::<u32>(&error_flag, 0)? != 0 {
367 return Err(XlogError::Kernel(format!(
368 "{ctx}: id outside domain {domain} (fail-closed; raise domain or \
369 renumber the fixture)"
370 )));
371 }
372
373 let mut counts = self.memory().alloc::<u32>(n_words as usize + 1)?;
375 {
376 let mut rec = LaunchRecorder::new_strict(launch_stream);
377 rec.read(&bitmap);
378 rec.write(&counts);
379 rec.preflight(runtime)
380 .map_err(|e| XlogError::Kernel(format!("{ctx}: count preflight failed: {e}")))?;
381 let popcount = self
382 .device()
383 .inner()
384 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_POPCOUNT)
385 .ok_or_else(|| {
386 XlogError::Kernel("fj_delta_popcount kernel not found".to_string())
387 })?;
388 let grid = n_words.div_ceil(BLOCK_SIZE);
389 unsafe {
391 popcount
392 .clone()
393 .launch_on_stream(
394 &cu_stream,
395 LaunchConfig {
396 grid_dim: (grid, 1, 1),
397 block_dim: (BLOCK_SIZE, 1, 1),
398 shared_mem_bytes: 0,
399 },
400 (&bitmap, n_words, &mut counts),
401 )
402 .map_err(|e| {
403 XlogError::Kernel(format!("fj_delta_popcount launch failed: {e}"))
404 })?;
405 }
406 self.multiblock_scan_u32_inplace_on_stream(
407 &mut counts,
408 n_words + 1,
409 &cu_stream,
410 launch_stream,
411 runtime,
412 )?;
413 rec.commit(runtime)
414 .map_err(|e| XlogError::Kernel(format!("{ctx}: count commit failed: {e}")))?;
415 }
416 cu_stream
417 .synchronize()
418 .map_err(|e| XlogError::Kernel(format!("{ctx}: count sync failed: {e}")))?;
419 let total_novel = self.dtoh_scalar_untracked::<u32>(&counts, n_words as usize)?;
420 if total_novel == 0 {
421 return self.create_empty_buffer(out_schema);
422 }
423
424 let out_x = self.memory().alloc::<u32>(total_novel as usize)?;
425 let out_z = self.memory().alloc::<u32>(total_novel as usize)?;
426 {
427 let mut rec = LaunchRecorder::new_strict(launch_stream);
428 rec.read(&bitmap);
429 rec.read(&counts);
430 rec.write(&out_x);
431 rec.write(&out_z);
432 rec.preflight(runtime)
433 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit preflight failed: {e}")))?;
434 let emit = self
435 .device()
436 .inner()
437 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_EMIT_U32)
438 .ok_or_else(|| {
439 XlogError::Kernel("fj_delta_emit_u32 kernel not found".to_string())
440 })?;
441 let grid = n_words.div_ceil(BLOCK_SIZE);
442 unsafe {
445 emit.clone()
446 .launch_on_stream(
447 &cu_stream,
448 LaunchConfig {
449 grid_dim: (grid, 1, 1),
450 block_dim: (BLOCK_SIZE, 1, 1),
451 shared_mem_bytes: 0,
452 },
453 (&bitmap, words_per_row, n_words, &counts, &out_x, &out_z),
454 )
455 .map_err(|e| {
456 XlogError::Kernel(format!("fj_delta_emit_u32 launch failed: {e}"))
457 })?;
458 }
459 rec.commit(runtime)
460 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit commit failed: {e}")))?;
461 }
462 cu_stream
463 .synchronize()
464 .map_err(|e| XlogError::Kernel(format!("{ctx}: emit sync failed: {e}")))?;
465
466 let d_nr = self.memory().alloc::<u32>(1)?;
467 self.htod_launch_metadata_async_copy_one(
468 &total_novel,
469 &d_nr,
470 &cu_stream,
471 &format!("{ctx}: result num_rows"),
472 )?;
473 let columns = if cols.r_carry == 0 {
476 vec![out_x.into_bytes().into(), out_z.into_bytes().into()]
477 } else {
478 vec![out_z.into_bytes().into(), out_x.into_bytes().into()]
479 };
480 Ok(CudaBuffer::from_columns_with_host_count(
481 columns,
482 u64::from(total_novel),
483 d_nr,
484 out_schema,
485 total_novel,
486 ))
487 }
488
489 pub fn fj_delta_columns_max_u32(
494 &self,
495 inputs: &[(&CudaBuffer, &[usize])],
496 launch_stream: StreamId,
497 ) -> Result<u32> {
498 let ctx = "fj_delta_columns_max_u32";
499 let runtime = self.memory().runtime().ok_or_else(|| {
500 XlogError::Kernel(format!("{ctx} requires a runtime-backed GpuMemoryManager"))
501 })?;
502 let cu_stream = runtime
503 .stream_pool()
504 .resolve(launch_stream)
505 .ok_or_else(|| {
506 XlogError::Kernel(format!(
507 "{ctx}: launch_stream StreamId({}) does not resolve",
508 launch_stream.0
509 ))
510 })?;
511 let mut d_max = self.memory().alloc::<u32>(1)?;
512 self.device()
513 .inner()
514 .memset_zeros(&mut d_max)
515 .map_err(|e| XlogError::Kernel(format!("{ctx}: zero max cell failed: {e}")))?;
516 let kernel = self
517 .device()
518 .inner()
519 .get_func(WCOJ_MODULE, wcoj_kernels::FJ_DELTA_MAX_U32)
520 .ok_or_else(|| XlogError::Kernel("fj_delta_max_u32 kernel not found".to_string()))?;
521 for (buf, col_idxs) in inputs {
522 let n = match buf.cached_row_count() {
523 Some(c) => c,
524 None => self.dtoh_scalar_untracked::<u32>(buf.num_rows_device(), 0)?,
525 };
526 if n == 0 {
527 continue;
528 }
529 for &idx in *col_idxs {
530 let col = buf
531 .column(idx)
532 .ok_or_else(|| XlogError::Kernel(format!("{ctx}: column {idx} missing")))?;
533 let view = self.column_as_u32_view(col, n as usize)?;
534 let mut rec = LaunchRecorder::new_strict(launch_stream);
535 rec.read_column(col);
536 rec.read_write(&d_max);
537 rec.preflight(runtime)
538 .map_err(|e| XlogError::Kernel(format!("{ctx}: preflight failed: {e}")))?;
539 let grid = n.div_ceil(BLOCK_SIZE);
540 unsafe {
543 kernel
544 .clone()
545 .launch_on_stream(
546 &cu_stream,
547 LaunchConfig {
548 grid_dim: (grid, 1, 1),
549 block_dim: (BLOCK_SIZE, 1, 1),
550 shared_mem_bytes: 0,
551 },
552 (&view, n, &mut d_max),
553 )
554 .map_err(|e| {
555 XlogError::Kernel(format!("fj_delta_max_u32 launch failed: {e}"))
556 })?;
557 }
558 rec.commit(runtime)
559 .map_err(|e| XlogError::Kernel(format!("{ctx}: commit failed: {e}")))?;
560 }
561 }
562 cu_stream
563 .synchronize()
564 .map_err(|e| XlogError::Kernel(format!("{ctx}: sync failed: {e}")))?;
565 self.dtoh_scalar_untracked::<u32>(&d_max, 0)
566 }
567}