Skip to main content

xlog_runtime/executor/
node_dispatch.rs

1//! RIR node dispatch and per-node execution handlers.
2
3use std::collections::HashMap;
4
5use xlog_core::{AggOp, RelId, Result, ScalarType, Schema, XlogError};
6use xlog_cuda::provider::NESTED_LOOP_TOTAL_THRESHOLD;
7use xlog_cuda::{CudaBuffer, JoinType as CudaJoinType};
8use xlog_ir::{JoinType, ProjectExpr, RirNode};
9
10use crate::ilp_registry::{read_device_row_count, IlpMask, IlpTagEntry, IlpTaggedResult};
11
12use super::join_cache::{estimate_join_index_bytes, JoinIndexKey};
13use super::Executor;
14
15/// Eligibility predicate for nested-loop join dispatch.
16///
17/// Returns `true` iff the join shape is admissible for the
18/// `nested_loop_join_v2_inner_u32_1key` provider entry point.
19/// The predicate is intentionally narrow for the nested-loop dispatch contract:
20///   * `JoinType::Inner` only (Semi / Anti / LeftOuter fall back
21///     to hash).
22///   * Exactly one key column on each side.
23///   * Both key columns share the same `ScalarType` AND that
24///     shared type is `U32` or `Symbol` (Symbol is `u32` at the
25///     byte level — same kernel applies). U32-on-Symbol or
26///     other type mismatches return `false`, mirroring
27///     `hash_join_v2`'s own type-mismatch rejection at
28///     `crates/xlog-cuda/src/provider/relational.rs:3567-3576`.
29///
30/// Out-of-bounds key indices yield `Schema::column_type(_) = None`,
31/// which fails the `matches!(...)` guard — falling back to hash
32/// without a separate bounds check.
33///
34/// Cheap O(1): no kernel launches, no row-count reads, no device-to-host transfer.
35/// The threshold check (`num_left * num_right <=
36/// NESTED_LOOP_TOTAL_THRESHOLD`) is performed at the dispatch
37/// site, not in this predicate.
38//
39fn eligible_for_nested_loop(
40    left: &CudaBuffer,
41    right: &CudaBuffer,
42    left_keys: &[usize],
43    right_keys: &[usize],
44    join_type: JoinType,
45) -> bool {
46    if join_type != JoinType::Inner {
47        return false;
48    }
49    if left_keys.len() != 1 || right_keys.len() != 1 {
50        return false;
51    }
52    let lt = left.schema().column_type(left_keys[0]);
53    let rt = right.schema().column_type(right_keys[0]);
54    lt == rt && matches!(lt, Some(ScalarType::U32) | Some(ScalarType::Symbol))
55}
56
57fn is_join_index_mismatch(err: &XlogError) -> bool {
58    matches!(
59        err,
60        XlogError::Kernel(msg)
61            if msg.contains("Join index row count does not match right relation")
62                || msg.contains("Join index key columns do not match requested right_keys")
63    )
64}
65
66impl Executor {
67    /// Execute a Scan node — looks up the relation by RelId and returns a clone.
68    pub(super) fn execute_scan(&mut self, rel: RelId) -> Result<CudaBuffer> {
69        let name = self
70            .get_rel_name(rel)
71            .ok_or_else(|| XlogError::Execution(format!("Unknown relation: RelId({})", rel.0)))?;
72
73        let buffer = self
74            .store
75            .get(name)
76            .ok_or_else(|| XlogError::Execution(format!("Relation not found: {}", name)))?;
77
78        self.stats.record_access(rel);
79        self.stats.update_cardinality(rel, buffer.num_rows());
80        self.stats.update_byte_size(rel, buffer.estimated_bytes());
81
82        self.clone_buffer(buffer)
83    }
84
85    /// Execute a single RIR node tree
86    ///
87    /// Recursively evaluates the node and its children, returning
88    /// the result as a GPU buffer.
89    ///
90    /// # Arguments
91    /// * `node` - The RIR node to execute
92    ///
93    /// # Returns
94    /// A CudaBuffer containing the result of the node execution
95    ///
96    /// # Errors
97    /// Returns an error if the node execution fails
98    pub fn execute_node(&mut self, node: &RirNode) -> Result<CudaBuffer> {
99        if !self.common_subexpression_enabled() || !Self::is_common_subexpression_cacheable(node) {
100            return self.execute_node_uncached(node);
101        }
102
103        let Some(key) = self.common_subexpression_key(node) else {
104            return self.execute_node_uncached(node);
105        };
106
107        if self.common_subexpression_cache.contains_key(&key) {
108            let cached = self
109                .common_subexpression_cache
110                .remove(&key)
111                .expect("cache key checked above");
112            let result = self.clone_buffer(&cached)?;
113            self.common_subexpression_cache.insert(key, cached);
114            self.common_subexpression_stats.hits =
115                self.common_subexpression_stats.hits.saturating_add(1);
116            return Ok(result);
117        }
118
119        self.common_subexpression_stats.misses =
120            self.common_subexpression_stats.misses.saturating_add(1);
121        let result = self.execute_node_uncached(node)?;
122        let cached = self.clone_buffer(&result)?;
123        self.common_subexpression_cache.insert(key, cached);
124        Ok(result)
125    }
126
127    fn execute_node_uncached(&mut self, node: &RirNode) -> Result<CudaBuffer> {
128        match node {
129            RirNode::Unit => {
130                // Materialize the relational "unit" ({()}) as a 0-arity buffer with one row.
131                let mut d_num_rows = self.provider.memory().alloc::<u32>(1)?;
132                self.provider
133                    .htod_launch_metadata_sync_copy_into(&[1u32], &mut d_num_rows)
134                    .map_err(|e| {
135                        XlogError::Kernel(format!("Failed to create unit row count: {}", e))
136                    })?;
137                Ok(CudaBuffer::from_columns(
138                    Vec::new(),
139                    1,
140                    d_num_rows,
141                    Schema::new(vec![]),
142                ))
143            }
144
145            RirNode::Scan { rel } => {
146                let start = self.profiler.start_op();
147                let result = self.execute_scan(*rel)?;
148                if let Some(start) = start {
149                    let mem = self.provider.memory().allocated_bytes();
150                    self.profiler
151                        .record_op("scan", 0, result.num_rows(), start, mem);
152                    self.profiler.record_peak_memory(mem);
153                }
154                Ok(result)
155            }
156
157            RirNode::Filter { input, predicate } => {
158                let input_buf = self.execute_node(input)?;
159                let input_rows = input_buf.num_rows();
160                let start = self.profiler.start_op();
161                let result = self.execute_filter(&input_buf, predicate)?;
162                if let Some(start) = start {
163                    let mem = self.provider.memory().allocated_bytes();
164                    self.profiler
165                        .record_op("filter", input_rows, result.num_rows(), start, mem);
166                    self.profiler.record_peak_memory(mem);
167                }
168                Ok(result)
169            }
170
171            RirNode::Project { input, columns } => {
172                let input_buf = self.execute_node(input)?;
173                let input_rows = input_buf.num_rows();
174                let start = self.profiler.start_op();
175                let result = self.execute_project(&input_buf, columns)?;
176                if let Some(start) = start {
177                    let mem = self.provider.memory().allocated_bytes();
178                    self.profiler
179                        .record_op("project", input_rows, result.num_rows(), start, mem);
180                    self.profiler.record_peak_memory(mem);
181                }
182                Ok(result)
183            }
184
185            RirNode::Join {
186                left,
187                right,
188                left_keys,
189                right_keys,
190                join_type,
191            } => {
192                let left_rel = match left.as_ref() {
193                    RirNode::Scan { rel } => Some(*rel),
194                    _ => None,
195                };
196                let right_rel = match right.as_ref() {
197                    RirNode::Scan { rel } => Some(*rel),
198                    _ => None,
199                };
200                let left_buf = self.execute_node(left)?;
201                let right_buf = self.execute_node(right)?;
202                let input_rows = left_buf.num_rows() + right_buf.num_rows();
203                let start = self.profiler.start_op();
204                let result = self.execute_join(
205                    &left_buf, &right_buf, left_keys, right_keys, *join_type, left_rel, right_rel,
206                )?;
207                if let Some(start) = start {
208                    let mem = self.provider.memory().allocated_bytes();
209                    self.profiler
210                        .record_op("join", input_rows, result.num_rows(), start, mem);
211                    self.profiler.record_peak_memory(mem);
212                }
213                Ok(result)
214            }
215
216            RirNode::GroupBy {
217                input,
218                key_cols,
219                aggs,
220            } => {
221                // Aggregate fusion: a count/sum/min/max-by-root over a
222                // promoted triangle dispatches the fused kernels and never
223                // materializes the join. Declines fall through to the
224                // standard materialize+groupby path below.
225                if let Some(fused) =
226                    self.try_dispatch_wcoj_groupby_root_agg(input, key_cols, aggs)?
227                {
228                    return Ok(fused);
229                }
230                let input_buf = self.execute_node(input)?;
231                let input_rows = input_buf.num_rows();
232                let start = self.profiler.start_op();
233                let result = self.execute_groupby(&input_buf, key_cols, aggs)?;
234                if let Some(start) = start {
235                    let mem = self.provider.memory().allocated_bytes();
236                    self.profiler
237                        .record_op("groupby", input_rows, result.num_rows(), start, mem);
238                    self.profiler.record_peak_memory(mem);
239                }
240                Ok(result)
241            }
242
243            RirNode::Union { inputs } => {
244                let mut buffers = Vec::with_capacity(inputs.len());
245                let mut input_rows = 0u64;
246                for input in inputs {
247                    let buf = self.execute_node(input)?;
248                    input_rows += buf.num_rows();
249                    buffers.push(buf);
250                }
251                let start = self.profiler.start_op();
252                let result = self.execute_union(&buffers)?;
253                if let Some(start) = start {
254                    let mem = self.provider.memory().allocated_bytes();
255                    self.profiler
256                        .record_op("union", input_rows, result.num_rows(), start, mem);
257                    self.profiler.record_peak_memory(mem);
258                }
259                Ok(result)
260            }
261
262            RirNode::Distinct { input, key_cols } => {
263                let input_buf = self.execute_node(input)?;
264                let input_rows = input_buf.num_rows();
265                let start = self.profiler.start_op();
266                let result = self.execute_distinct(&input_buf, key_cols)?;
267                if let Some(start) = start {
268                    let mem = self.provider.memory().allocated_bytes();
269                    self.profiler
270                        .record_op("dedup", input_rows, result.num_rows(), start, mem);
271                    self.profiler.record_peak_memory(mem);
272                }
273                Ok(result)
274            }
275
276            RirNode::Diff { left, right } => {
277                let left_buf = self.execute_node(left)?;
278                let right_buf = self.execute_node(right)?;
279                let input_rows = left_buf.num_rows() + right_buf.num_rows();
280                let start = self.profiler.start_op();
281                let result = self.execute_diff(&left_buf, &right_buf)?;
282                if let Some(start) = start {
283                    let mem = self.provider.memory().allocated_bytes();
284                    self.profiler
285                        .record_op("diff", input_rows, result.num_rows(), start, mem);
286                    self.profiler.record_peak_memory(mem);
287                }
288                Ok(result)
289            }
290
291            RirNode::Fixpoint {
292                scc_id,
293                base,
294                recursive,
295                delta_rel,
296                full_rel,
297            } => {
298                // Semi-naive fixpoint iteration
299                self.execute_fixpoint(*scc_id, base, recursive, *delta_rel, *full_rel)
300            }
301            RirNode::TensorMaskedJoin {
302                mask_name,
303                schema_size,
304                left_keys,
305                right_keys,
306                rel_index,
307                head_rel_name,
308                max_active_rules,
309                head_projection,
310                ..
311            } => self.execute_tensor_masked_join(
312                mask_name,
313                *schema_size,
314                left_keys,
315                right_keys,
316                rel_index,
317                head_rel_name,
318                *max_active_rules,
319                head_projection,
320            ),
321            // Defensive fallback descent for any
322            // `execute_node` caller that bypasses the WCOJ dispatch
323            // hook (probabilistic eval, neural store walks, etc.).
324            // The non-recursive arm in `recursive.rs` short-circuits
325            // dispatch-eligible bodies before reaching here; this
326            // arm is the safety net for everyone else.
327            RirNode::MultiWayJoin { fallback, .. } | RirNode::ChainJoin { fallback, .. } => {
328                self.execute_node(fallback)
329            }
330        }
331    }
332
333    /// Execute a Join node
334    ///
335    /// Delegates to the kernel provider's hash_join_v2 which supports all join types natively.
336    #[allow(clippy::too_many_arguments)]
337    fn execute_join(
338        &mut self,
339        left: &CudaBuffer,
340        right: &CudaBuffer,
341        left_keys: &[usize],
342        right_keys: &[usize],
343        join_type: JoinType,
344        left_rel: Option<RelId>,
345        right_rel: Option<RelId>,
346    ) -> Result<CudaBuffer> {
347        // Convert IR JoinType to CUDA JoinType (used by adaptive
348        // indexing and the hash fallback below).
349        let cuda_join_type = match join_type {
350            JoinType::Inner => CudaJoinType::Inner,
351            JoinType::Semi => CudaJoinType::Semi,
352            JoinType::Anti => CudaJoinType::Anti,
353            JoinType::LeftOuter => CudaJoinType::LeftOuter,
354        };
355
356        // Output buffer set by nested-loop dispatch,
357        // adaptive indexing, or the hash fallback. All three
358        // paths flow through the shared `record_join_result`
359        // feedback block at the end of this fn.
360        let mut out: Option<CudaBuffer> = None;
361
362        // Nested-loop dispatch precedes adaptive indexing
363        // and hash fallback. On predicate + threshold pass,
364        // route to `nested_loop_join_v2_inner_u32_1key` and
365        // bump the dispatch counter; do NOT early-return —
366        // leave the result in `out` so the shared feedback
367        // block observes it. Otherwise leave `out` unchanged.
368        //
369        // Threshold check uses logical row counts via
370        // `provider.device_row_count(...)` (NOT `row_cap`), with
371        // `checked_mul` fail-closed on overflow before comparing
372        // against the nested-loop total-row threshold.
373        if eligible_for_nested_loop(left, right, left_keys, right_keys, join_type) {
374            let num_left = self.provider.device_row_count(left)? as u64;
375            let num_right = self.provider.device_row_count(right)? as u64;
376            let in_threshold = num_left
377                .checked_mul(num_right)
378                .map(|p| p <= NESTED_LOOP_TOTAL_THRESHOLD)
379                .unwrap_or(false);
380            if in_threshold {
381                out = Some(self.provider.nested_loop_join_v2_inner_u32_1key(
382                    left,
383                    right,
384                    left_keys[0],
385                    right_keys[0],
386                )?);
387                self.nested_loop_dispatch_count += 1;
388            }
389        }
390
391        // Adaptive indexing: opportunistically reuse cached
392        // build-side hash tables when the right side is a base
393        // relation scan and has become "hot" in runtime
394        // statistics. Only runs if nested-loop dispatch did not
395        // dispatch.
396        if out.is_none() && self.config.resolved_persistent_hash_indexes() {
397            if let Some(build_rel) = right_rel {
398                let build_heat = self
399                    .stats
400                    .get_relation_stats(build_rel)
401                    .map(|s| s.heat)
402                    .unwrap_or(0.0);
403                let est_index_bytes = estimate_join_index_bytes(right, right_keys);
404                let budget_bytes = self.provider.memory().budget().device_bytes;
405                let remaining_bytes = self.provider.memory().remaining_bytes();
406
407                let should_index = self.join_index_cache.should_build(
408                    est_index_bytes,
409                    build_heat,
410                    remaining_bytes,
411                    budget_bytes,
412                );
413
414                if let Some(build_name) = self.get_rel_name(build_rel).map(|s| s.to_string()) {
415                    if let Some(version) = self.store.version(&build_name) {
416                        let key = JoinIndexKey::new(
417                            build_rel,
418                            version,
419                            right_keys.to_vec(),
420                            right.schema(),
421                            self.provider.device().ordinal() as u32,
422                        );
423
424                        let indexed_result = {
425                            self.join_index_cache.get(&key).map(|index| {
426                                self.provider.hash_join_v2_with_index(
427                                    left,
428                                    right,
429                                    left_keys,
430                                    right_keys,
431                                    cuda_join_type,
432                                    index,
433                                    None,
434                                )
435                            })
436                        };
437                        if let Some(indexed_result) = indexed_result {
438                            match indexed_result {
439                                Ok(joined) => out = Some(joined),
440                                Err(err) if is_join_index_mismatch(&err) => {
441                                    self.join_index_cache.remove_stale(&key);
442                                }
443                                Err(err) => return Err(err),
444                            }
445                        } else if should_index {
446                            let background_build = self
447                                .config
448                                .resolved_persistent_hash_index_background_build();
449                            if background_build {
450                                self.join_index_cache.record_background_build_request();
451                            }
452                            let build_result = if background_build {
453                                self.provider
454                                    .build_join_index_v2_background(right, right_keys)
455                            } else {
456                                self.provider.build_join_index_v2(right, right_keys)
457                            };
458                            match build_result {
459                                Ok(index) => {
460                                    if background_build {
461                                        self.join_index_cache.record_background_build_complete();
462                                        self.join_index_cache.insert(key, index);
463                                        self.join_index_cache.record_background_build_deferred();
464                                        if let Some(stats) =
465                                            self.stats.get_relation_stats_mut(build_rel)
466                                        {
467                                            stats.has_index = true;
468                                        }
469                                    } else {
470                                        match self.provider.hash_join_v2_with_index(
471                                            left,
472                                            right,
473                                            left_keys,
474                                            right_keys,
475                                            cuda_join_type,
476                                            &index,
477                                            None,
478                                        ) {
479                                            Ok(joined) => {
480                                                self.join_index_cache.insert(key, index);
481                                                if let Some(stats) =
482                                                    self.stats.get_relation_stats_mut(build_rel)
483                                                {
484                                                    stats.has_index = true;
485                                                }
486                                                out = Some(joined);
487                                            }
488                                            Err(err) if is_join_index_mismatch(&err) => {}
489                                            Err(err) => return Err(err),
490                                        }
491                                    }
492                                }
493                                Err(_) => {
494                                    // If indexing fails (e.g., memory pressure), fall back to normal join.
495                                }
496                            }
497                        }
498                    }
499                }
500            }
501        } // end adaptive-indexing gate
502
503        let out = match out {
504            Some(buf) => buf,
505            None => {
506                self.provider
507                    .hash_join_v2(left, right, left_keys, right_keys, cuda_join_type)?
508            }
509        };
510
511        if let (Some(l), Some(r)) = (left_rel, right_rel) {
512            let input_rows = left.num_rows().saturating_mul(right.num_rows());
513            self.record_adaptive_join_observation(
514                l,
515                r,
516                left_keys,
517                right_keys,
518                input_rows,
519                out.num_rows(),
520            );
521            self.stats.record_join_result(
522                l,
523                r,
524                left_keys.to_vec(),
525                right_keys.to_vec(),
526                input_rows,
527                out.num_rows(),
528            );
529        }
530
531        Ok(out)
532    }
533
534    /// Execute a GroupBy node
535    ///
536    /// Delegates to the kernel provider's groupby_multi_agg for multi-aggregation support.
537    fn execute_groupby(
538        &self,
539        input: &CudaBuffer,
540        key_cols: &[usize],
541        aggs: &[(usize, AggOp)],
542    ) -> Result<CudaBuffer> {
543        if aggs.is_empty() {
544            // No aggregations: just distinct on key columns
545            return self.provider.dedup(input, key_cols);
546        }
547
548        // Use multi-aggregation groupby
549        self.provider.groupby_multi_agg(input, key_cols, aggs)
550    }
551
552    /// Execute a Union node
553    ///
554    /// Combines multiple input buffers into one using GPU-native operation.
555    pub(super) fn execute_union(&self, inputs: &[CudaBuffer]) -> Result<CudaBuffer> {
556        if inputs.is_empty() {
557            return self.provider.create_empty_buffer(Schema::new(vec![]));
558        }
559
560        if inputs.len() == 1 {
561            return self.clone_buffer(&inputs[0]);
562        }
563
564        // Pairwise union using GPU-native operation
565        let mut result = self.clone_buffer(&inputs[0])?;
566        for input in inputs.iter().skip(1) {
567            result = self.provider.union_gpu(&result, input)?;
568        }
569
570        Ok(result)
571    }
572
573    /// Execute a Distinct node
574    ///
575    /// Removes duplicate rows based on key columns.
576    pub(super) fn execute_distinct(
577        &self,
578        input: &CudaBuffer,
579        key_cols: &[usize],
580    ) -> Result<CudaBuffer> {
581        self.provider.dedup(input, key_cols)
582    }
583
584    /// Execute a Diff node
585    ///
586    /// Returns rows in left that are not in right using GPU-native operation.
587    pub(super) fn execute_diff(&self, left: &CudaBuffer, right: &CudaBuffer) -> Result<CudaBuffer> {
588        self.provider.diff_gpu(left, right)
589    }
590
591    #[allow(clippy::too_many_arguments)]
592    fn execute_tensor_masked_join(
593        &mut self,
594        mask_name: &str,
595        schema_size: usize,
596        left_keys: &[usize],
597        right_keys: &[usize],
598        rel_index: &[(RelId, String)],
599        head_rel_name: &str,
600        max_active_rules: usize,
601        head_projection: &[usize],
602    ) -> Result<CudaBuffer> {
603        // No-op when no mask is registered. Return an empty buffer with
604        // the head relation's schema (not Schema::new(vec![])) to prevent
605        // schema corruption when execute_non_recursive_scc stores the result.
606        let ilp_mask = match self.ilp_registry.get_mask(mask_name) {
607            Some(mask) => mask,
608            None => {
609                self.ilp_last_result = Some(IlpTaggedResult {
610                    entries: Vec::new(),
611                });
612                // Fail hard if the head relation is missing from the store.
613                let schema = self
614                    .store
615                    .get(head_rel_name)
616                    .map(|buf| buf.schema().clone())
617                    .ok_or_else(|| {
618                        XlogError::Execution(format!(
619                            "TensorMaskedJoin: head relation '{}' not found in store \
620                         (was load_facts_into_store called?)",
621                            head_rel_name
622                        ))
623                    })?;
624                return self.provider.create_empty_buffer(schema);
625            }
626        };
627
628        let start = self.profiler.start_op();
629
630        let head_k = rel_index
631            .iter()
632            .position(|(_, name)| name == head_rel_name)
633            .ok_or_else(|| {
634                XlogError::Execution(format!(
635                    "TensorMaskedJoin: head relation '{}' not found in rel_index",
636                    head_rel_name
637                ))
638            })? as u32;
639
640        let mut tag_entries: Vec<IlpTagEntry> = Vec::new();
641        let mut process_rule = |i: u32,
642                                j: u32,
643                                k: u32,
644                                strict_candidate_idx: Option<usize>,
645                                strict_flags: Option<&CudaBuffer>|
646         -> Result<()> {
647            if k != head_k {
648                return Ok(());
649            }
650
651            let (_, left_name) = &rel_index[i as usize];
652            let (_, right_name) = &rel_index[j as usize];
653
654            let left_buf = match self.store.get(left_name) {
655                Some(buf) if buf.arity() > 0 => buf,
656                _ => return Ok(()),
657            };
658            let right_buf = match self.store.get(right_name) {
659                Some(buf) if buf.arity() > 0 => buf,
660                _ => return Ok(()),
661            };
662
663            // Skip arity-mismatched relations: the join keys are fixed by
664            // the learnable rule template, so the mapped relation must have
665            // enough columns for every key index. Relations with matching
666            // arity but different semantic column meanings will join without
667            // error; semantic correctness of the mask is the optimizer's
668            // responsibility.
669            let left_max_key = left_keys.iter().copied().max().unwrap_or(0);
670            let right_max_key = right_keys.iter().copied().max().unwrap_or(0);
671            if left_buf.arity() <= left_max_key || right_buf.arity() <= right_max_key {
672                return Ok(());
673            }
674
675            let joined = self.provider.hash_join_v2(
676                left_buf,
677                right_buf,
678                left_keys,
679                right_keys,
680                CudaJoinType::Inner,
681            )?;
682
683            // Project join result to head schema columns if projection is specified.
684            // The join produces [left_cols..., right_cols...] but the head may only
685            // need a subset (e.g. reach(X,Y) from b1(X,Z) join b2(Z,Y) needs cols 0,3).
686            let projected = if !head_projection.is_empty() && head_projection.len() < joined.arity()
687            {
688                let proj_exprs: Vec<ProjectExpr> = head_projection
689                    .iter()
690                    .map(|&col| ProjectExpr::Column(col))
691                    .collect();
692                self.execute_project(&joined, &proj_exprs)?
693            } else {
694                joined
695            };
696
697            let projected = if let (Some(candidate_idx), Some(active_flags)) =
698                (strict_candidate_idx, strict_flags)
699            {
700                self.provider.filter_buffer_by_candidate_flag(
701                    &projected,
702                    active_flags,
703                    candidate_idx,
704                )?
705            } else {
706                projected
707            };
708
709            // Use the public helper instead of the private device_row_count.
710            let num_rows = read_device_row_count(&self.provider, &projected)? as u32;
711
712            if num_rows > 0 {
713                tag_entries.push(IlpTagEntry {
714                    i,
715                    j,
716                    k,
717                    num_rows,
718                    buffer: Some(projected),
719                });
720            }
721            Ok(())
722        };
723
724        let active_rule_count = match ilp_mask {
725            IlpMask::Dense { hard, soft, .. } => {
726                let active_rules = self.provider.extract_active_rule_indices(
727                    hard,
728                    soft,
729                    schema_size,
730                    max_active_rules,
731                )?;
732                let count = active_rules.len() as u64;
733                for &(i, j, k) in &active_rules {
734                    process_rule(i, j, k, None, None)?;
735                }
736                count
737            }
738            IlpMask::Sparse { active_entries, .. } => {
739                let limit = max_active_rules.min(active_entries.len());
740                for &(i, j, k) in &active_entries[..limit] {
741                    process_rule(i, j, k, None, None)?;
742                }
743                limit as u64
744            }
745            IlpMask::SparseDevice {
746                candidate_order,
747                active_flags,
748                selected_count,
749                ..
750            } => {
751                if *selected_count > 0 {
752                    for (candidate_idx, &(i, j, k)) in candidate_order.iter().enumerate() {
753                        process_rule(i, j, k, Some(candidate_idx), Some(active_flags))?;
754                    }
755                }
756                (*selected_count).min(max_active_rules) as u64
757            }
758        };
759
760        // Union per-rule results by head relation index, borrowing buffers from tag_entries.
761        let mut bufs_by_k: HashMap<u32, Vec<&CudaBuffer>> = HashMap::new();
762        for entry in &tag_entries {
763            if let Some(ref buf) = entry.buffer {
764                bufs_by_k.entry(entry.k).or_default().push(buf);
765            }
766        }
767
768        for (k, buffers) in bufs_by_k {
769            let (_, target_name) = &rel_index[k as usize];
770
771            // Chain-union all buffers (union_gpu takes &CudaBuffer refs)
772            let union_buf = if buffers.len() == 1 {
773                // Single buffer: union with an empty buffer to produce a copy
774                let empty = self
775                    .provider
776                    .create_empty_buffer(buffers[0].schema().clone())?;
777                self.provider.union_gpu(buffers[0], &empty)?
778            } else {
779                let mut acc = self.provider.union_gpu(buffers[0], buffers[1])?;
780                for buf in &buffers[2..] {
781                    acc = self.provider.union_gpu(&acc, buf)?;
782                }
783                acc
784            };
785
786            // Diff against existing and merge
787            if let Some(existing) = self.store.get(target_name) {
788                let delta = self.provider.diff_gpu(&union_buf, existing)?;
789                if !delta.is_empty() {
790                    let merged = self.provider.union_gpu(existing, &delta)?;
791                    self.store_put(target_name, merged);
792                }
793            } else {
794                let key_cols: Vec<usize> = (0..union_buf.arity()).collect();
795                let deduped = self.provider.dedup(&union_buf, &key_cols)?;
796                self.store_put(target_name, deduped);
797            }
798        }
799
800        // Store tag entries with retained buffers.
801        self.ilp_last_result = Some(IlpTaggedResult {
802            entries: tag_entries,
803        });
804
805        if let Some(start) = start {
806            let mem = self.provider.memory().allocated_bytes();
807            self.profiler
808                .record_op("TensorMaskedJoin", 0, active_rule_count, start, mem);
809        }
810
811        // Return empty with head schema (results routed via store).
812        let schema = self
813            .store
814            .get(head_rel_name)
815            .map(|buf| buf.schema().clone())
816            .ok_or_else(|| {
817                XlogError::Execution(format!(
818                    "TensorMaskedJoin: head relation '{}' not found in store \
819                 (was load_facts_into_store called?)",
820                    head_rel_name
821                ))
822            })?;
823        self.provider.create_empty_buffer(schema)
824    }
825}