Skip to main content

pyxlog/
ilp.rs

1// ---------------------------------------------------------------------------
2// ILP (Inductive Logic Programming) Python bindings
3// ---------------------------------------------------------------------------
4
5use std::collections::HashMap as StdHashMap;
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use pyo3::exceptions::{PyRuntimeError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::{PyDict, PySequence};
12
13use xlog_core::{symbol, RelId, ScalarType, Schema};
14use xlog_cuda::{CudaKernelProvider, JoinType};
15use xlog_ir::{ExecutionPlan, RirNode};
16use xlog_logic::ast::{Program as AstProgram, Term, TypeRef};
17use xlog_prob::exact::GpuConfig;
18use xlog_runtime::ilp_registry::IlpMask;
19use xlog_runtime::{read_device_row_count, Executor};
20
21use xlog_cuda::type_seam::GpuScalar;
22
23use super::{
24    dlpack_capsule_from_tensor, dlpack_from_py, ilp_gpu, provider_from_config, types,
25    CompiledIlpProgram, IlpProgramFactory, IlpTaggedCreditDeviceResult,
26};
27
28// ---------------------------------------------------------------------------
29// Helper functions
30// ---------------------------------------------------------------------------
31
/// Bundles device-resident rows for one relation (inferred from the name:
/// presumably the example rows of that relation — confirm against its users).
// NOTE(review): not referenced in the visible part of this file; presumably
// consumed further down (e.g. by the ILP loss path) — confirm before changing.
struct RelationExampleGroup {
    /// Name of the relation these rows belong to.
    relation: String,
    /// Device buffer holding the rows; layout presumably follows the
    /// relation's schema — TODO confirm.
    query_buf: xlog_cuda::CudaBuffer,
    /// Row count associated with `query_buf` (assumed to match the buffer's
    /// row count — confirm where this struct is constructed).
    num_rows: u32,
}
37
38fn type_ref_name(typ: &TypeRef) -> String {
39    match typ {
40        TypeRef::Scalar(scalar) => types::scalar_type_name(scalar),
41        TypeRef::Domain(name) => name.clone(),
42        TypeRef::List(inner) => format!("list<{}>", type_ref_name(inner)),
43        TypeRef::Term => "term".to_string(),
44        TypeRef::Compound => "compound".to_string(),
45        TypeRef::PredRef => "predref".to_string(),
46    }
47}
48
49fn collect_dlpack_columns(
50    dlpack_columns: &Bound<'_, PyAny>,
51    err_msg: &str,
52) -> PyResult<Vec<xlog_cuda::DlpackManagedTensor>> {
53    let seq = dlpack_columns
54        .downcast::<PySequence>()
55        .map_err(|_| PyValueError::new_err(err_msg.to_string()))?;
56    let mut tensors = Vec::with_capacity(seq.len()?);
57    for item in seq.try_iter()? {
58        tensors.push(dlpack_from_py(&item?)?);
59    }
60    Ok(tensors)
61}
62
63/// Allocate zero-initialized loss (scalar) and grad (num_cands) on GPU and
64/// export both as DLPack capsules. Shared by both f32 and f64 paths.
65fn build_zero_typed<T>(
66    provider: &CudaKernelProvider,
67    py: Python<'_>,
68    num_cands: u32,
69    scalar_type: ScalarType,
70) -> PyResult<(PyObject, PyObject)>
71where
72    T: GpuScalar + cudarc::driver::ValidAsZeroBits,
73{
74    let mut d_grad = provider
75        .memory()
76        .alloc::<T>(num_cands as usize)
77        .map_err(|e| types::gpu_err("alloc grad", e))?;
78    if num_cands > 0 {
79        provider
80            .device()
81            .inner()
82            .memset_zeros(&mut d_grad)
83            .map_err(|e| types::gpu_err("zero grad", e))?;
84    }
85    let mut d_loss = provider
86        .memory()
87        .alloc::<T>(1)
88        .map_err(|e| types::gpu_err("alloc loss", e))?;
89    provider
90        .device()
91        .inner()
92        .memset_zeros(&mut d_loss)
93        .map_err(|e| types::gpu_err("zero loss", e))?;
94    ilp_gpu::export_loss_grad_device(provider, py, d_loss, d_grad, num_cands, scalar_type)
95}
96
97fn export_device_bool_tensor(
98    provider: &CudaKernelProvider,
99    py: Python<'_>,
100    values: xlog_cuda::memory::TrackedCudaSlice<u8>,
101    rows: usize,
102) -> PyResult<PyObject> {
103    let rows_u32 = u32::try_from(rows)
104        .map_err(|_| PyValueError::new_err(format!("Row count {} exceeds u32::MAX", rows)))?;
105
106    let mut d_num_rows = provider.memory().alloc::<u32>(1).map_err(types::xlog_err)?;
107    provider
108        .device()
109        .inner()
110        .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
111        .map_err(types::xlog_err)?;
112
113    let buffer = xlog_cuda::CudaBuffer::from_columns(
114        vec![values.into_bytes().into()],
115        rows as u64,
116        d_num_rows,
117        Schema::new(vec![("col0".to_string(), ScalarType::Bool)]),
118    );
119    let tensor = provider
120        .to_dlpack_table(buffer)
121        .column(0)
122        .map_err(types::xlog_err)?;
123    dlpack_capsule_from_tensor(py, tensor)
124}
125
126fn export_device_u32_tensor_as_i32(
127    provider: &CudaKernelProvider,
128    py: Python<'_>,
129    values: xlog_cuda::memory::TrackedCudaSlice<u32>,
130    rows: usize,
131) -> PyResult<PyObject> {
132    let rows_u32 = u32::try_from(rows)
133        .map_err(|_| PyValueError::new_err(format!("Row count {} exceeds u32::MAX", rows)))?;
134
135    let mut d_num_rows = provider.memory().alloc::<u32>(1).map_err(types::xlog_err)?;
136    provider
137        .device()
138        .inner()
139        .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
140        .map_err(types::xlog_err)?;
141
142    // PyTorch does not support unsigned 32-bit DLPack tensors; export as i32.
143    let buffer = xlog_cuda::CudaBuffer::from_columns(
144        vec![values.into_bytes().into()],
145        rows as u64,
146        d_num_rows,
147        Schema::new(vec![("col0".to_string(), ScalarType::I32)]),
148    );
149    let tensor = provider
150        .to_dlpack_table(buffer)
151        .column(0)
152        .map_err(types::xlog_err)?;
153    dlpack_capsule_from_tensor(py, tensor)
154}
155
156fn empty_tagged_credit_device_result(
157    provider: &CudaKernelProvider,
158    py: Python<'_>,
159    num_facts: usize,
160) -> PyResult<IlpTaggedCreditDeviceResult> {
161    let mut d_row_offsets = provider
162        .memory()
163        .alloc::<u32>(num_facts + 1)
164        .map_err(types::xlog_err)?;
165    provider
166        .device()
167        .inner()
168        .memset_zeros(&mut d_row_offsets)
169        .map_err(types::xlog_err)?;
170    let d_empty_indices = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
171    let d_empty_i = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
172    let d_empty_j = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
173    let d_empty_k = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
174
175    Ok(IlpTaggedCreditDeviceResult {
176        fact_row_offsets: export_device_u32_tensor_as_i32(
177            provider,
178            py,
179            d_row_offsets,
180            num_facts + 1,
181        )?,
182        entry_indices: export_device_u32_tensor_as_i32(provider, py, d_empty_indices, 0)?,
183        entry_i: export_device_u32_tensor_as_i32(provider, py, d_empty_i, 0)?,
184        entry_j: export_device_u32_tensor_as_i32(provider, py, d_empty_j, 0)?,
185        entry_k: export_device_u32_tensor_as_i32(provider, py, d_empty_k, 0)?,
186    })
187}
188
189fn push_term_bytes(out: &mut Vec<u8>, term: &Term, typ: ScalarType) -> xlog_core::Result<()> {
190    use xlog_core::XlogError;
191    match (typ, term) {
192        (ScalarType::U32, Term::Integer(v)) => {
193            let v = u32::try_from(*v)
194                .map_err(|_| XlogError::Execution(format!("u32 out of range: {}", v)))?;
195            out.extend_from_slice(&v.to_le_bytes());
196        }
197        (ScalarType::U64, Term::Integer(v)) => {
198            let v = u64::try_from(*v)
199                .map_err(|_| XlogError::Execution(format!("u64 out of range: {}", v)))?;
200            out.extend_from_slice(&v.to_le_bytes());
201        }
202        (ScalarType::I32, Term::Integer(v)) => {
203            let v = i32::try_from(*v)
204                .map_err(|_| XlogError::Execution(format!("i32 out of range: {}", v)))?;
205            out.extend_from_slice(&v.to_le_bytes());
206        }
207        (ScalarType::I64, Term::Integer(v)) => {
208            out.extend_from_slice(&v.to_le_bytes());
209        }
210        (ScalarType::F32, Term::Float(v)) => {
211            out.extend_from_slice(&(*v as f32).to_le_bytes());
212        }
213        (ScalarType::F64, Term::Float(v)) => {
214            out.extend_from_slice(&v.to_le_bytes());
215        }
216        (ScalarType::F32, Term::Integer(v)) => {
217            out.extend_from_slice(&(*v as f32).to_le_bytes());
218        }
219        (ScalarType::F64, Term::Integer(v)) => {
220            out.extend_from_slice(&(*v as f64).to_le_bytes());
221        }
222        (ScalarType::Bool, Term::Integer(v)) => {
223            let b = match *v {
224                0 => 0u8,
225                1 => 1u8,
226                other => {
227                    return Err(XlogError::Execution(format!(
228                        "bool expects 0/1, got {}",
229                        other
230                    )));
231                }
232            };
233            out.push(b);
234        }
235        (ScalarType::Bool, Term::Symbol(id)) => {
236            let s = symbol::resolve(*id);
237            if s == "true" || s == "false" {
238                out.push(if s == "true" { 1u8 } else { 0u8 });
239            } else {
240                return Err(XlogError::Execution(format!(
241                    "Expected boolean symbol, got '{}'",
242                    s
243                )));
244            }
245        }
246        (ScalarType::Symbol, Term::String(s)) => {
247            out.extend_from_slice(&symbol::intern(s).to_le_bytes());
248        }
249        (ScalarType::Symbol, Term::Symbol(id)) => {
250            out.extend_from_slice(&id.to_le_bytes());
251        }
252        (_, Term::Variable(v)) => {
253            return Err(XlogError::Execution(format!(
254                "Fact cannot contain variable {}",
255                v
256            )));
257        }
258        (_, Term::Anonymous) => {
259            return Err(XlogError::Execution(
260                "Fact cannot contain anonymous wildcard '_'".into(),
261            ));
262        }
263        (_, Term::Aggregate(_)) => {
264            return Err(XlogError::Execution("Fact cannot contain aggregate".into()));
265        }
266        (expected, got) => {
267            return Err(XlogError::Execution(format!(
268                "Type mismatch: expected {:?}, got {:?}",
269                expected, got
270            )));
271        }
272    }
273    Ok(())
274}
275
276/// Pack `i64` fact values into typed byte columns according to schema.
277/// Returns one `Vec<u8>` per column with correctly-encoded LE bytes.
278/// Rejects F32/F64 columns (not supported in batch APIs).
279pub(crate) fn pack_i64_columns_typed(
280    relation: &str,
281    facts: &[Vec<i64>],
282    schema: &Schema,
283) -> PyResult<Vec<Vec<u8>>> {
284    let arity = schema.arity();
285    for (idx, fact) in facts.iter().enumerate() {
286        if fact.len() != arity {
287            return Err(PyValueError::new_err(format!(
288                "Relation '{}': fact {} has {} values, expected {}",
289                relation,
290                idx,
291                fact.len(),
292                arity,
293            )));
294        }
295    }
296
297    let mut columns: Vec<Vec<u8>> = (0..arity)
298        .map(|col_idx| {
299            let elem_size = schema
300                .column_type(col_idx)
301                .map(|t| t.size_bytes())
302                .unwrap_or(4);
303            Vec::with_capacity(facts.len() * elem_size)
304        })
305        .collect();
306
307    for fact in facts {
308        for (col_idx, &val) in fact.iter().enumerate() {
309            let col_type = schema.column_type(col_idx);
310            let col = &mut columns[col_idx];
311            match col_type {
312                Some(ScalarType::U32) => {
313                    let v = u32::try_from(val).map_err(|_| {
314                        PyValueError::new_err(format!(
315                            "Relation '{}' column {} (U32): value {} out of range [0, {}]",
316                            relation,
317                            col_idx,
318                            val,
319                            u32::MAX,
320                        ))
321                    })?;
322                    col.extend_from_slice(&v.to_le_bytes());
323                }
324                Some(ScalarType::I32) => {
325                    let v = i32::try_from(val).map_err(|_| {
326                        PyValueError::new_err(format!(
327                            "Relation '{}' column {} (I32): value {} out of range [{}, {}]",
328                            relation,
329                            col_idx,
330                            val,
331                            i32::MIN,
332                            i32::MAX,
333                        ))
334                    })?;
335                    col.extend_from_slice(&v.to_le_bytes());
336                }
337                Some(ScalarType::U64) => {
338                    let v = u64::try_from(val).map_err(|_| PyValueError::new_err(format!(
339                        "Relation '{}' column {} (U64): value {} is negative; U64 requires non-negative values",
340                        relation, col_idx, val,
341                    )))?;
342                    col.extend_from_slice(&v.to_le_bytes());
343                }
344                Some(ScalarType::I64) => {
345                    col.extend_from_slice(&val.to_le_bytes());
346                }
347                Some(ScalarType::Bool) => match val {
348                    0 => col.push(0u8),
349                    1 => col.push(1u8),
350                    _ => {
351                        return Err(PyValueError::new_err(format!(
352                            "Relation '{}' column {} (Bool): value {} not in {{0, 1}}",
353                            relation, col_idx, val,
354                        )))
355                    }
356                },
357                Some(ScalarType::Symbol) => {
358                    let v = u32::try_from(val).map_err(|_| {
359                        PyValueError::new_err(format!(
360                            "Relation '{}' column {} (Symbol): value {} out of range [0, {}]",
361                            relation,
362                            col_idx,
363                            val,
364                            u32::MAX,
365                        ))
366                    })?;
367                    col.extend_from_slice(&v.to_le_bytes());
368                }
369                Some(ScalarType::F32) => {
370                    return Err(PyValueError::new_err(format!(
371                        "Relation '{}' column {} (F32): float columns not supported in batch APIs",
372                        relation, col_idx,
373                    )));
374                }
375                Some(ScalarType::F64) => {
376                    return Err(PyValueError::new_err(format!(
377                        "Relation '{}' column {} (F64): float columns not supported in batch APIs",
378                        relation, col_idx,
379                    )));
380                }
381                None => {
382                    return Err(PyValueError::new_err(format!(
383                        "Relation '{}' column {}: no type in schema",
384                        relation, col_idx,
385                    )));
386                }
387            }
388        }
389    }
390
391    Ok(columns)
392}
393
394pub(crate) fn load_facts_into_store(
395    ast: &AstProgram,
396    provider: &CudaKernelProvider,
397    executor: &mut Executor,
398    schemas: &HashMap<String, Schema>,
399) -> xlog_core::Result<()> {
400    use xlog_core::XlogError;
401    let mut rows_by_pred: HashMap<&str, Vec<&[Term]>> = HashMap::new();
402    for fact in ast.facts() {
403        rows_by_pred
404            .entry(fact.head.predicate.as_str())
405            .or_default()
406            .push(&fact.head.terms);
407    }
408
409    for (pred, rows) in rows_by_pred {
410        let schema = schemas.get(pred).ok_or_else(|| {
411            XlogError::Execution(format!("Missing schema for fact predicate {}", pred))
412        })?;
413
414        if rows.iter().any(|r| r.len() != schema.arity()) {
415            return Err(XlogError::Execution(format!(
416                "Fact arity mismatch for {} (expected {})",
417                pred,
418                schema.arity()
419            )));
420        }
421
422        let mut columns: Vec<Vec<u8>> = vec![Vec::new(); schema.arity()];
423        for row in &rows {
424            for (col_idx, term) in row.iter().enumerate() {
425                let typ = schema.column_type(col_idx).ok_or_else(|| {
426                    XlogError::Execution(format!("Missing type for col {}", col_idx))
427                })?;
428                push_term_bytes(&mut columns[col_idx], term, typ)?;
429            }
430        }
431
432        let slices: Vec<&[u8]> = columns.iter().map(|c| c.as_slice()).collect();
433        let fact_buf = provider.create_buffer_from_slices(&slices, schema.clone())?;
434
435        let existing = executor.store().get(pred).ok_or_else(|| {
436            XlogError::Execution(format!(
437                "Missing base relation {} while loading facts",
438                pred
439            ))
440        })?;
441        let merged = provider.union(existing, &fact_buf)?;
442        executor.store_mut().put(pred, merged);
443    }
444    Ok(())
445}
446
/// Extracted TensorMaskedJoin metadata from the execution plan.
///
/// Each field is copied from the like-named field of the first matching
/// `RirNode::TensorMaskedJoin`. `extract_tmj_meta_for_mask` returns an
/// all-empty value (equal to `TmjMeta::default()`) when no such node is found.
#[derive(Debug, Clone, Default)]
struct TmjMeta {
    /// Copied from the node's `left_keys`.
    left_keys: Vec<usize>,
    /// Copied from the node's `right_keys`.
    right_keys: Vec<usize>,
    /// Copied from the node's `head_projection`.
    head_projection: Vec<usize>,
    /// Copied from the node's `schema_size`.
    schema_size: usize,
    /// Copied from the node's `head_rel_name`.
    head_rel_name: String,
}
455
456fn walk_tmj(node: &RirNode, target_mask: Option<&str>) -> Option<TmjMeta> {
457    match node {
458        RirNode::TensorMaskedJoin {
459            mask_name,
460            left_keys,
461            right_keys,
462            head_projection,
463            schema_size,
464            head_rel_name,
465            ..
466        } => {
467            if target_mask.is_none() || target_mask == Some(mask_name.as_str()) {
468                Some(TmjMeta {
469                    left_keys: left_keys.clone(),
470                    right_keys: right_keys.clone(),
471                    head_projection: head_projection.clone(),
472                    schema_size: *schema_size,
473                    head_rel_name: head_rel_name.clone(),
474                })
475            } else {
476                None
477            }
478        }
479        RirNode::Fixpoint {
480            base, recursive, ..
481        } => walk_tmj(base, target_mask).or_else(|| walk_tmj(recursive, target_mask)),
482        RirNode::Union { inputs } => inputs.iter().find_map(|n| walk_tmj(n, target_mask)),
483        RirNode::Filter { input, .. }
484        | RirNode::Project { input, .. }
485        | RirNode::Distinct { input, .. }
486        | RirNode::GroupBy { input, .. } => walk_tmj(input, target_mask),
487        RirNode::Join { left, right, .. } | RirNode::Diff { left, right } => {
488            walk_tmj(left, target_mask).or_else(|| walk_tmj(right, target_mask))
489        }
490        // Descend into the fallback so a TMJ wrapped beneath a promoted
491        // MultiWayJoin is still discoverable. The promoter does not currently
492        // wrap TMJ-bearing trees, but the explicit arm documents the contract
493        // instead of relying on the catch-all.
494        RirNode::MultiWayJoin { fallback, .. } => walk_tmj(fallback, target_mask),
495        RirNode::ChainJoin { fallback, .. } => walk_tmj(fallback, target_mask),
496        _ => None,
497    }
498}
499
500fn extract_tmj_meta(plan: &ExecutionPlan) -> TmjMeta {
501    extract_tmj_meta_for_mask(plan, None)
502}
503
504fn extract_tmj_meta_for_mask(plan: &ExecutionPlan, mask_name: Option<&str>) -> TmjMeta {
505    for scc_rules in &plan.rules_by_scc {
506        for rule in scc_rules {
507            if let Some(meta) = walk_tmj(&rule.body, mask_name) {
508                return meta;
509            }
510        }
511    }
512    TmjMeta {
513        left_keys: vec![],
514        right_keys: vec![],
515        head_projection: vec![],
516        schema_size: 0,
517        head_rel_name: String::new(),
518    }
519}
520
/// Returns `true` for a learnable-declaration line, i.e. one whose first
/// non-whitespace text is `learnable(`.
fn is_learnable_line(line: &str) -> bool {
    line.trim_start().starts_with("learnable(")
}

/// Returns `source` with every learnable-declaration line removed.
///
/// Lines are split with `str::lines`, so `\n` and `\r\n` endings are both
/// accepted. The remaining lines are rejoined with `\n`, with no trailing
/// newline.
fn strip_learnable_declarations(source: &str) -> String {
    let kept: Vec<&str> = source
        .lines()
        .filter(|line| !is_learnable_line(line))
        .collect();
    kept.join("\n")
}

/// Returns only the learnable-declaration lines of `source`.
///
/// Each line is returned verbatim, leading whitespace included. Lines are
/// joined with `\n`, with no trailing newline.
fn extract_learnable_declarations(source: &str) -> String {
    let kept: Vec<&str> = source
        .lines()
        .filter(|line| is_learnable_line(line))
        .collect();
    kept.join("\n")
}
536
537// ---------------------------------------------------------------------------
538// IlpProgramFactory
539// ---------------------------------------------------------------------------
540
541#[pymethods]
542impl IlpProgramFactory {
543    #[staticmethod]
544    #[pyo3(signature = (source, device=0, memory_mb=512, max_active_rules=None))]
545    pub fn compile(
546        source: &str,
547        device: usize,
548        memory_mb: u64,
549        max_active_rules: Option<usize>,
550    ) -> PyResult<CompiledIlpProgram> {
551        // Validate max_active_rules range
552        if let Some(max) = max_active_rules {
553            if !(16..=128).contains(&max) {
554                return Err(PyValueError::new_err(format!(
555                    "max_active_rules must be between 16 and 128, got {}",
556                    max
557                )));
558            }
559        }
560
561        let ast = xlog_logic::parse_program(source).map_err(types::val_err)?;
562
563        let base_source = strip_learnable_declarations(source);
564        let learnable_source = extract_learnable_declarations(source);
565
566        let mut compiler = xlog_logic::Compiler::new();
567        if let Some(max) = max_active_rules {
568            compiler.set_max_active_rules(max);
569        }
570        let plan = compiler.compile_program(&ast).map_err(types::xlog_err)?;
571
572        let mut rel_index: Vec<(RelId, String)> = compiler
573            .rel_ids()
574            .iter()
575            .map(|(name, id)| (*id, name.clone()))
576            .collect();
577        rel_index.sort_by_key(|(id, _)| id.0);
578        let schemas = compiler.schemas().clone();
579
580        let mut config = GpuConfig::default();
581        config.device_ordinal = device;
582        config.memory_bytes = memory_mb * 1024 * 1024;
583        let provider = Arc::new(provider_from_config(config).map_err(types::xlog_err)?);
584
585        let mut executor = Executor::new(provider.clone());
586
587        for (name, rel_id) in compiler.rel_ids() {
588            executor.register_relation(*rel_id, name);
589        }
590
591        for (name, schema) in &schemas {
592            let empty = provider
593                .create_empty_buffer(schema.clone())
594                .map_err(types::xlog_err)?;
595            executor.store_mut().put(name, empty);
596        }
597
598        load_facts_into_store(&ast, &provider, &mut executor, &schemas).map_err(types::xlog_err)?;
599
600        executor.execute_plan(&plan).map_err(types::xlog_err)?;
601
602        let tmj = extract_tmj_meta(&plan);
603
604        let active_rules = max_active_rules.unwrap_or(32);
605
606        Ok(CompiledIlpProgram {
607            base_source,
608            _learnable_source: learnable_source,
609            ast,
610            executor,
611            provider,
612            plan,
613            rel_index,
614            schemas,
615            left_keys: tmj.left_keys,
616            right_keys: tmj.right_keys,
617            head_projection: tmj.head_projection,
618            compiled_schema_size: tmj.schema_size,
619            head_rel_name: tmj.head_rel_name,
620            max_active_rules: active_rules,
621            candidate_map: None,
622            candidate_order: None,
623            relation_overrides: HashMap::new(),
624            coo_chunk_budget: 16 * 1024 * 1024,
625            strict_zero_dtoh: false,
626        })
627    }
628}
629
630// ---------------------------------------------------------------------------
631// CompiledIlpProgram — #[pymethods] block
632// ---------------------------------------------------------------------------
633
634#[pymethods]
635impl CompiledIlpProgram {
636    /// Upload candidate (i,j,k) -> index mapping. Called once per attempt.
637    pub fn set_candidate_map(&mut self, candidates: Vec<(u32, u32, u32)>) -> PyResult<()> {
638        let mut map = HashMap::with_capacity(candidates.len());
639        for (cidx, &(i, j, k)) in candidates.iter().enumerate() {
640            map.insert((i, j, k), cidx as u32);
641        }
642        self.candidate_map = Some(map);
643        self.candidate_order = Some(candidates);
644        Ok(())
645    }
646
647    /// Length of current candidate map (0 if not set).
648    pub fn candidate_map_len(&self) -> usize {
649        self.candidate_map.as_ref().map_or(0, |m| m.len())
650    }
651
652    pub fn debug_ilp_mask_kind(&self, name: String) -> Option<String> {
653        self.executor.ilp_registry().get_mask(&name).map(|mask| {
654            match mask {
655                IlpMask::Dense { .. } => "dense",
656                IlpMask::Sparse { .. } => "sparse_host",
657                IlpMask::SparseDevice { .. } => "sparse_device",
658            }
659            .to_string()
660        })
661    }
662
663    /// Set the per-chunk temp allocation budget in bytes. The final merged
664    /// COO buffer is exact-NNZ sized and may exceed this budget. Default: 16 MB.
665    pub fn set_coo_chunk_budget(&mut self, bytes: u64) {
666        self.coo_chunk_budget = bytes;
667    }
668
669    /// Deprecated: use `set_coo_chunk_budget`. Kept for one release cycle.
670    #[allow(deprecated)]
671    pub fn set_coo_memory_cap(&mut self, bytes: u64) {
672        self.coo_chunk_budget = bytes;
673    }
674
675    /// Enable strict zero-D2H mode. When true, raises RuntimeError instead
676    /// of falling back to the chunked COO path (which uses D2H transfers).
677    /// Use for zero-D2H benchmarks and CI gates.
678    pub fn set_strict_zero_dtoh(&mut self, strict: bool) {
679        self.strict_zero_dtoh = strict;
680    }
681
682    /// Upload a named relation as DLPack tensor columns into the ILP program store.
683    ///
684    /// This enables tensor-native ILP data upload: compile a schema-only source
685    /// (predicate declarations only, no facts), then upload GPU tensor data via
686    /// DLPack without host materialization.
687    pub fn put_relation(
688        &mut self,
689        name: String,
690        dlpack_columns: &Bound<'_, PyAny>,
691    ) -> PyResult<()> {
692        if name.starts_with("__") {
693            return Err(PyValueError::new_err(format!(
694                "Relation {} is internal and cannot be stored in a compiled ILP program",
695                name
696            )));
697        }
698        let schema = self.schemas.get(&name).ok_or_else(|| {
699            PyValueError::new_err(format!(
700                "Unknown relation {} (not present in compiled schemas)",
701                name
702            ))
703        })?;
704        let tensors = collect_dlpack_columns(
705            dlpack_columns,
706            &format!("Relation {} must be a sequence of DLPack columns", name),
707        )?;
708        let buffer = self
709            .provider
710            .from_dlpack_tensors_with_schema(schema.clone(), tensors)
711            .map_err(types::xlog_err)?;
712        let live_buffer = self
713            .provider
714            .clone_buffer(&buffer)
715            .map_err(types::xlog_err)?;
716        self.relation_overrides.insert(name.clone(), buffer);
717        self.executor.put_relation(&name, live_buffer);
718        Ok(())
719    }
720
721    /// GPU-resident ILP loss + gradient computation.
722    ///
723    /// Builds a sparse CSR structure from retained per-entry membership masks
724    /// and launches forward (credit gather + NLL loss) and backward (gradient
725    /// scatter) CUDA kernels.  Returns `(loss_dlpack, grad_dlpack)` where both
726    /// are GPU-resident tensors exported via DLPack.
727    ///
728    /// # Arguments
729    /// * `positives` — list of `(relation, [col_values])` positive examples
730    /// * `negatives` — list of `(relation, [col_values])` negative examples
731    /// * `cand_probs_obj` — DLPack/PyTorch tensor of candidate probabilities on GPU
732    pub fn compute_ilp_loss_grad_gpu<'py>(
733        &mut self,
734        py: Python<'py>,
735        positives: Vec<(String, Vec<i64>)>,
736        negatives: Vec<(String, Vec<i64>)>,
737        cand_probs_obj: &Bound<'py, PyAny>,
738    ) -> PyResult<(PyObject, PyObject)> {
739        // Validate inputs and import candidate probabilities via DLPack.
740        let candidate_map = self.candidate_map.clone().ok_or_else(|| {
741            PyRuntimeError::new_err(
742                "candidate_map not set — call set_candidate_map() before compute_ilp_loss_grad_gpu()",
743            )
744        })?;
745        let num_cands = candidate_map.len() as u32;
746
747        let managed = dlpack_from_py(cand_probs_obj)?;
748        let cand_buf = self
749            .provider
750            .from_dlpack_tensors(vec![managed])
751            .map_err(|e| types::gpu_err("DLPack import", e))?;
752
753        // Determine dtype from imported buffer
754        let cand_schema = cand_buf.schema().clone();
755        let cand_dtype = cand_schema
756            .column_type(0)
757            .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column type"))?;
758        let is_f64 = match cand_dtype {
759            ScalarType::F32 => false,
760            ScalarType::F64 => true,
761            other => {
762                return Err(PyValueError::new_err(format!(
763                    "cand_probs must be F32 or F64, got {:?}",
764                    other
765                )));
766            }
767        };
768
769        let cand_rows = read_device_row_count(&self.provider, &cand_buf)
770            .map_err(|e| types::gpu_err("row count", e))?;
771        if cand_rows != num_cands as usize {
772            return Err(PyValueError::new_err(format!(
773                "cand_probs length ({}) != candidate_map length ({})",
774                cand_rows, num_cands
775            )));
776        }
777
778        // Build the fact list and group examples by relation.
779        let num_pos = positives.len();
780        let num_neg = negatives.len();
781        let num_facts = (num_pos + num_neg) as u32;
782
783        // Build all_facts: (relation, values, is_positive, global_fact_idx)
784        struct FactInfo {
785            relation: String,
786            values: Vec<i64>,
787            is_positive: bool,
788        }
789        let mut all_facts: Vec<FactInfo> = Vec::with_capacity(num_pos + num_neg);
790        for (rel, vals) in positives {
791            all_facts.push(FactInfo {
792                relation: rel,
793                values: vals,
794                is_positive: true,
795            });
796        }
797        for (rel, vals) in negatives {
798            all_facts.push(FactInfo {
799                relation: rel,
800                values: vals,
801                is_positive: false,
802            });
803        }
804
805        // Handle empty facts edge case: return zero loss + zero grad
806        if num_facts == 0 {
807            return self.build_zero_loss_grad(py, num_cands, is_f64);
808        }
809
810        // Build is_positive array for upload
811        let is_positive_host: Vec<u8> = all_facts
812            .iter()
813            .map(|f| if f.is_positive { 1u8 } else { 0u8 })
814            .collect();
815
816        // Group facts by relation, preserving global fact_idx
817        let mut groups: HashMap<String, Vec<(u32, Vec<i64>)>> = HashMap::new();
818        for (global_idx, fact) in all_facts.iter().enumerate() {
819            groups
820                .entry(fact.relation.clone())
821                .or_default()
822                .push((global_idx as u32, fact.values.clone()));
823        }
824
825        if self.strict_zero_dtoh && self.executor.ilp_registry().has_sparse_device_mask() {
826            self.evaluate_ilp_plan(py)?;
827        }
828
829        // Get ILP tagged result
830        let tagged = self
831            .executor
832            .ilp_last_result()
833            .ok_or_else(|| PyRuntimeError::new_err("No ILP result — call evaluate() first"))?;
834
835        // Build the COO structure on the device without D2H reads.
836        //
837        // Two-pass approach over (relation, candidate) tasks:
838        //   Pass 1: compute GPU membership masks
839        //   Pass 2: scatter COO entries at host-computed offsets via device kernel
840        //
841        // The key insight is that each task's num_query is known on the host
842        // (it equals the number of facts in that relation group), so we can
843        // compute COO write offsets entirely on the host without any D2H reads.
844        // We over-allocate COO arrays at upper_bound (sum of all num_query)
845        // and fill sentinel values for unused slots.
846
847        let mut tasks: Vec<ilp_gpu::CooTask> = Vec::new();
848        let mut fact_indices_buffers: Vec<xlog_cuda::memory::TrackedCudaSlice<u32>> = Vec::new();
849
850        for (relation, facts_with_idx) in &groups {
851            let k_idx = self
852                .rel_index
853                .iter()
854                .position(|(_, name)| name == relation)
855                .ok_or_else(|| {
856                    PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
857                })? as u32;
858
859            let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
860                .entries
861                .iter()
862                .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
863                .collect();
864
865            if relevant_entries.is_empty() {
866                continue;
867            }
868
869            let first_buf = relevant_entries[0]
870                .buffer
871                .as_ref()
872                .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
873            let arity = first_buf.arity();
874            if arity == 0 {
875                continue;
876            }
877            let schema = first_buf.schema().clone();
878
879            let fact_values: Vec<Vec<i64>> =
880                facts_with_idx.iter().map(|(_, v)| v.clone()).collect();
881            let col_bytes = pack_i64_columns_typed(relation, &fact_values, &schema)?;
882            let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
883            let query_buf = self
884                .provider
885                .create_buffer_from_slices(&col_slices, schema)
886                .map_err(|e| types::gpu_err("create_buffer", e))?;
887
888            let keys: Vec<usize> = (0..arity).collect();
889            let num_query = fact_values.len() as u32;
890
891            // Upload global fact indices for this relation group (H2D, allowed).
892            // Shared across all entries in this relation group via index.
893            let global_indices: Vec<u32> = facts_with_idx.iter().map(|(idx, _)| *idx).collect();
894            let mut d_fact_indices = self
895                .provider
896                .memory()
897                .alloc::<u32>(num_query as usize)
898                .map_err(|e| types::gpu_err("alloc fact_indices", e))?;
899            self.provider
900                .device()
901                .inner()
902                .htod_sync_copy_into(&global_indices, &mut d_fact_indices)
903                .map_err(|e| types::gpu_err("htod fact_indices", e))?;
904            let fi_idx = fact_indices_buffers.len();
905            fact_indices_buffers.push(d_fact_indices);
906
907            for entry in &relevant_entries {
908                let cidx = match candidate_map.get(&(entry.i, entry.j, entry.k)) {
909                    Some(c) => *c,
910                    None => continue,
911                };
912
913                let entry_buf = entry.buffer.as_ref().ok_or_else(|| {
914                    PyRuntimeError::new_err("internal: filtered entry has no buffer")
915                })?;
916                let d_mask = self
917                    .provider
918                    .membership_mask_device(&query_buf, entry_buf, &keys, &keys)
919                    .map_err(|e| types::gpu_err("membership_mask", e))?;
920
921                tasks.push(ilp_gpu::CooTask {
922                    cidx,
923                    num_query,
924                    d_mask,
925                    fact_indices_idx: fi_idx,
926                });
927            }
928        }
929
930        let num_tasks = tasks.len();
931        let upper_bound: u32 = tasks.iter().map(|t| t.num_query).sum();
932
933        if upper_bound == 0 || num_tasks == 0 {
934            return self.build_loss_grad_empty_coo(
935                py,
936                &is_positive_host,
937                num_facts,
938                num_cands,
939                is_f64,
940            );
941        }
942
943        // Check if COO allocation would exceed memory cap.
944        // Each COO entry uses 8 bytes (4 for fact_idx + 4 for cand_idx).
945        let coo_bytes = (upper_bound as u64) * 8;
946        let needs_chunking = coo_bytes > self.coo_chunk_budget;
947
948        // Upload is_positive once (shared across all paths, H2D allowed)
949        let mut d_is_positive = self
950            .provider
951            .memory()
952            .alloc::<u8>(num_facts as usize)
953            .map_err(|e| types::gpu_err("alloc is_positive", e))?;
954        self.provider
955            .device()
956            .inner()
957            .htod_sync_copy_into(&is_positive_host, &mut d_is_positive)
958            .map_err(|e| types::gpu_err("htod is_positive", e))?;
959
960        let cand_col = cand_buf
961            .column(0)
962            .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column"))?;
963
964        // Construct COO entries.
965        let (mut d_coo_facts, mut d_coo_cands, actual_nnz) = if !needs_chunking {
966            ilp_gpu::build_coo_single(
967                &self.provider,
968                &tasks,
969                &fact_indices_buffers,
970                num_facts,
971                num_cands,
972                upper_bound,
973            )?
974        } else {
975            match ilp_gpu::build_coo_chunked(
976                &self.provider,
977                &tasks,
978                &fact_indices_buffers,
979                num_facts,
980                self.coo_chunk_budget,
981            )? {
982                Some(result) => result,
983                None => {
984                    return self.build_loss_grad_empty_coo(
985                        py,
986                        &is_positive_host,
987                        num_facts,
988                        num_cands,
989                        is_f64,
990                    );
991                }
992            }
993        };
994
995        // Sort COO entries and build CSR row offsets on the device.
996        let d_row_offsets = ilp_gpu::sort_and_build_csr(
997            &self.provider,
998            &mut d_coo_facts,
999            &mut d_coo_cands,
1000            actual_nnz,
1001            num_facts,
1002        )?;
1003
1004        // Run forward, backward, and device-side reduction kernels.
1005        ilp_gpu::forward_backward_reduce(
1006            &self.provider,
1007            py,
1008            &d_row_offsets,
1009            &d_coo_cands,
1010            cand_col,
1011            &d_is_positive,
1012            num_facts,
1013            num_cands,
1014            is_f64,
1015        )
1016    }
1017
1018    /// GPU-resident ILP loss + gradient computation from relation-native
1019    /// positive/negative examples stored as DLPack column groups.
1020    ///
1021    /// `positives_by_relation` and `negatives_by_relation` must be Python dicts
1022    /// mapping relation name -> sequence of DLPack columns with schema-matching
1023    /// dtypes. No host fact tuple materialization occurs on this path.
1024    pub fn compute_ilp_loss_grad_gpu_relations<'py>(
1025        &mut self,
1026        py: Python<'py>,
1027        positives_by_relation: &Bound<'py, PyAny>,
1028        negatives_by_relation: &Bound<'py, PyAny>,
1029        cand_probs_obj: &Bound<'py, PyAny>,
1030    ) -> PyResult<(PyObject, PyObject)> {
1031        let candidate_map = self.candidate_map.clone().ok_or_else(|| {
1032            PyRuntimeError::new_err(
1033                "candidate_map not set — call set_candidate_map() before compute_ilp_loss_grad_gpu_relations()",
1034            )
1035        })?;
1036        let num_cands = candidate_map.len() as u32;
1037
1038        let managed = dlpack_from_py(cand_probs_obj)?;
1039        let cand_buf = self
1040            .provider
1041            .from_dlpack_tensors(vec![managed])
1042            .map_err(|e| types::gpu_err("DLPack import", e))?;
1043
1044        let cand_schema = cand_buf.schema().clone();
1045        let cand_dtype = cand_schema
1046            .column_type(0)
1047            .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column type"))?;
1048        let is_f64 = match cand_dtype {
1049            ScalarType::F32 => false,
1050            ScalarType::F64 => true,
1051            other => {
1052                return Err(PyValueError::new_err(format!(
1053                    "cand_probs must be F32 or F64, got {:?}",
1054                    other
1055                )));
1056            }
1057        };
1058
1059        let cand_rows = read_device_row_count(&self.provider, &cand_buf)
1060            .map_err(|e| types::gpu_err("row count", e))?;
1061        if cand_rows != num_cands as usize {
1062            return Err(PyValueError::new_err(format!(
1063                "cand_probs length ({}) != candidate_map length ({})",
1064                cand_rows, num_cands
1065            )));
1066        }
1067
1068        let positive_groups = self.collect_relation_example_groups(
1069            positives_by_relation,
1070            "positives_by_relation must be a dict[str, sequence[dlpack]]",
1071        )?;
1072        let negative_groups = self.collect_relation_example_groups(
1073            negatives_by_relation,
1074            "negatives_by_relation must be a dict[str, sequence[dlpack]]",
1075        )?;
1076
1077        let num_pos: u32 = positive_groups.iter().map(|g| g.num_rows).sum();
1078        let num_neg: u32 = negative_groups.iter().map(|g| g.num_rows).sum();
1079        let num_facts = num_pos + num_neg;
1080
1081        if num_facts == 0 {
1082            return self.build_zero_loss_grad(py, num_cands, is_f64);
1083        }
1084
1085        let mut is_positive_host = Vec::with_capacity(num_facts as usize);
1086        is_positive_host.extend(std::iter::repeat_n(1u8, num_pos as usize));
1087        is_positive_host.extend(std::iter::repeat_n(0u8, num_neg as usize));
1088
1089        if self.strict_zero_dtoh && self.executor.ilp_registry().has_sparse_device_mask() {
1090            self.evaluate_ilp_plan(py)?;
1091        }
1092
1093        let tagged = self
1094            .executor
1095            .ilp_last_result()
1096            .ok_or_else(|| PyRuntimeError::new_err("No ILP result — call evaluate() first"))?;
1097
1098        let mut tasks: Vec<ilp_gpu::CooTask> = Vec::new();
1099        let mut fact_indices_buffers: Vec<xlog_cuda::memory::TrackedCudaSlice<u32>> = Vec::new();
1100        let mut next_global_idx: u32 = 0;
1101
1102        for group in positive_groups.iter().chain(negative_groups.iter()) {
1103            if group.num_rows == 0 {
1104                continue;
1105            }
1106            let k_idx = self
1107                .rel_index
1108                .iter()
1109                .position(|(_, name)| name == &group.relation)
1110                .ok_or_else(|| {
1111                    PyValueError::new_err(format!(
1112                        "Relation '{}' not in ILP schema",
1113                        group.relation
1114                    ))
1115                })? as u32;
1116
1117            let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
1118                .entries
1119                .iter()
1120                .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
1121                .collect();
1122            if relevant_entries.is_empty() {
1123                next_global_idx = next_global_idx.saturating_add(group.num_rows);
1124                continue;
1125            }
1126
1127            let arity = group.query_buf.arity();
1128            if arity == 0 {
1129                next_global_idx = next_global_idx.saturating_add(group.num_rows);
1130                continue;
1131            }
1132            let keys: Vec<usize> = (0..arity).collect();
1133
1134            let global_indices: Vec<u32> =
1135                (next_global_idx..next_global_idx + group.num_rows).collect();
1136            let mut d_fact_indices = self
1137                .provider
1138                .memory()
1139                .alloc::<u32>(group.num_rows as usize)
1140                .map_err(|e| types::gpu_err("alloc fact_indices", e))?;
1141            self.provider
1142                .device()
1143                .inner()
1144                .htod_sync_copy_into(&global_indices, &mut d_fact_indices)
1145                .map_err(|e| types::gpu_err("htod fact_indices", e))?;
1146            let fi_idx = fact_indices_buffers.len();
1147            fact_indices_buffers.push(d_fact_indices);
1148
1149            for entry in &relevant_entries {
1150                let cidx = match candidate_map.get(&(entry.i, entry.j, entry.k)) {
1151                    Some(c) => *c,
1152                    None => continue,
1153                };
1154                let entry_buf = entry.buffer.as_ref().ok_or_else(|| {
1155                    PyRuntimeError::new_err("internal: filtered entry has no buffer")
1156                })?;
1157                let d_mask = self
1158                    .provider
1159                    .membership_mask_device(&group.query_buf, entry_buf, &keys, &keys)
1160                    .map_err(|e| types::gpu_err("membership_mask", e))?;
1161                tasks.push(ilp_gpu::CooTask {
1162                    cidx,
1163                    num_query: group.num_rows,
1164                    d_mask,
1165                    fact_indices_idx: fi_idx,
1166                });
1167            }
1168
1169            next_global_idx = next_global_idx.saturating_add(group.num_rows);
1170        }
1171
1172        let num_tasks = tasks.len();
1173        let upper_bound: u32 = tasks.iter().map(|t| t.num_query).sum();
1174        if upper_bound == 0 || num_tasks == 0 {
1175            return self.build_loss_grad_empty_coo(
1176                py,
1177                &is_positive_host,
1178                num_facts,
1179                num_cands,
1180                is_f64,
1181            );
1182        }
1183
1184        let coo_bytes = (upper_bound as u64) * 8;
1185        let needs_chunking = coo_bytes > self.coo_chunk_budget;
1186
1187        let mut d_is_positive = self
1188            .provider
1189            .memory()
1190            .alloc::<u8>(num_facts as usize)
1191            .map_err(|e| types::gpu_err("alloc is_positive", e))?;
1192        self.provider
1193            .device()
1194            .inner()
1195            .htod_sync_copy_into(&is_positive_host, &mut d_is_positive)
1196            .map_err(|e| types::gpu_err("htod is_positive", e))?;
1197
1198        let cand_col = cand_buf
1199            .column(0)
1200            .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column"))?;
1201
1202        let (mut d_coo_facts, mut d_coo_cands, actual_nnz) = if !needs_chunking {
1203            ilp_gpu::build_coo_single(
1204                &self.provider,
1205                &tasks,
1206                &fact_indices_buffers,
1207                num_facts,
1208                num_cands,
1209                upper_bound,
1210            )?
1211        } else {
1212            match ilp_gpu::build_coo_chunked(
1213                &self.provider,
1214                &tasks,
1215                &fact_indices_buffers,
1216                num_facts,
1217                self.coo_chunk_budget,
1218            )? {
1219                Some(result) => result,
1220                None => {
1221                    return self.build_loss_grad_empty_coo(
1222                        py,
1223                        &is_positive_host,
1224                        num_facts,
1225                        num_cands,
1226                        is_f64,
1227                    );
1228                }
1229            }
1230        };
1231
1232        let d_row_offsets = ilp_gpu::sort_and_build_csr(
1233            &self.provider,
1234            &mut d_coo_facts,
1235            &mut d_coo_cands,
1236            actual_nnz,
1237            num_facts,
1238        )?;
1239
1240        ilp_gpu::forward_backward_reduce(
1241            &self.provider,
1242            py,
1243            &d_row_offsets,
1244            &d_coo_cands,
1245            cand_col,
1246            &d_is_positive,
1247            num_facts,
1248            num_cands,
1249            is_f64,
1250        )
1251    }
1252
1253    pub fn set_rule_mask(
1254        &mut self,
1255        name: String,
1256        mask_hard_flat: &Bound<'_, PyAny>,
1257        mask_soft_flat: &Bound<'_, PyAny>,
1258        schema_size: usize,
1259    ) -> PyResult<()> {
1260        if self.compiled_schema_size > 0 && schema_size != self.compiled_schema_size {
1261            return Err(PyValueError::new_err(format!(
1262                "schema_size mismatch: mask has N={} but compiled program expects N={}",
1263                schema_size, self.compiled_schema_size,
1264            )));
1265        }
1266
1267        let hard_dmt = dlpack_from_py(mask_hard_flat)?;
1268        let soft_dmt = dlpack_from_py(mask_soft_flat)?;
1269
1270        let hard_buf = self
1271            .provider
1272            .from_dlpack_tensors(vec![hard_dmt])
1273            .map_err(types::xlog_err)?;
1274        let soft_buf = self
1275            .provider
1276            .from_dlpack_tensors(vec![soft_dmt])
1277            .map_err(types::xlog_err)?;
1278
1279        self.executor
1280            .ilp_registry_mut()
1281            .insert_mask(name, hard_buf, soft_buf, schema_size);
1282        Ok(())
1283    }
1284
1285    /// Sparse mask API: candidate IDs + DLPack soft probabilities (GPU tensor).
1286    ///
1287    /// `candidate_ids` must be exactly `[0..C)` where C is the candidate count
1288    /// for this mask under the provided recursion policy.
1289    ///
1290    /// `soft_probs_dlpack` is a DLPack capsule (CUDA f64 tensor) passed from PyTorch.
1291    /// Rust imports it zero-copy, downloads values (not counted by the D2H counter),
1292    /// performs deterministic top-k (desc soft value, then lower id), and
1293    /// stores a sparse IlpMask (no dense N^3 materialization).
1294    #[pyo3(signature = (name, candidate_ids, soft_probs_dlpack, budget, allow_recursive=false))]
1295    pub fn set_rule_mask_sparse(
1296        &mut self,
1297        name: String,
1298        candidate_ids: Vec<u32>,
1299        soft_probs_dlpack: &Bound<'_, PyAny>,
1300        budget: usize,
1301        allow_recursive: bool,
1302    ) -> PyResult<()> {
1303        if self.strict_zero_dtoh {
1304            return Err(PyRuntimeError::new_err(
1305                "strict_zero_dtoh forbids legacy set_rule_mask_sparse; use set_rule_mask_sparse_selected instead",
1306            ));
1307        }
1308
1309        let tmj = extract_tmj_meta_for_mask(&self.plan, Some(&name));
1310        let n = tmj.schema_size;
1311        if n == 0 {
1312            return Err(PyValueError::new_err(format!(
1313                "no learnable mask '{}' found",
1314                name
1315            )));
1316        }
1317        if self.compiled_schema_size > 0 && n != self.compiled_schema_size {
1318            return Err(PyValueError::new_err(format!(
1319                "schema_size mismatch for '{}': plan N={} compiled N={}",
1320                name, n, self.compiled_schema_size
1321            )));
1322        }
1323
1324        let expected_c = self.expected_candidate_count(&name, allow_recursive)?;
1325        if candidate_ids.len() != expected_c {
1326            return Err(PyValueError::new_err(format!(
1327                "candidate_ids length {} != expected candidate count {}",
1328                candidate_ids.len(),
1329                expected_c
1330            )));
1331        }
1332        for (idx, &cid) in candidate_ids.iter().enumerate() {
1333            if cid != idx as u32 {
1334                return Err(PyValueError::new_err(format!(
1335                    "candidate_ids must be [0..{}), got id {} at position {}",
1336                    expected_c, cid, idx
1337                )));
1338            }
1339        }
1340
1341        // Import DLPack tensor (zero-copy, stays on GPU)
1342        let soft_dmt = dlpack_from_py(soft_probs_dlpack)?;
1343        let soft_buf = self
1344            .provider
1345            .from_dlpack_tensors(vec![soft_dmt])
1346            .map_err(types::xlog_err)?;
1347
1348        // Download f64 values (control-plane, NOT tracked by D2H counter)
1349        let soft_probs = self
1350            .provider
1351            .download_column_untracked::<f64>(&soft_buf, 0)
1352            .map_err(types::xlog_err)?;
1353
1354        if soft_probs.len() != candidate_ids.len() {
1355            return Err(PyValueError::new_err(format!(
1356                "soft_probs tensor length {} != candidate_ids length {}",
1357                soft_probs.len(),
1358                candidate_ids.len()
1359            )));
1360        }
1361
1362        let candidate_triples = self.candidate_triples_for_mask(&name, allow_recursive)?;
1363        if candidate_triples.len() != expected_c {
1364            return Err(PyRuntimeError::new_err(format!(
1365                "internal candidate count mismatch: triples={} expected={}",
1366                candidate_triples.len(),
1367                expected_c
1368            )));
1369        }
1370
1371        // Convert to f32 for top-k ranking in insert_mask_from_sparse
1372        let active_soft: Vec<f32> = soft_probs.iter().map(|&v| v as f32).collect();
1373
1374        self.executor
1375            .ilp_registry_mut()
1376            .insert_mask_from_sparse(name, n, &candidate_triples, &active_soft, budget)
1377            .map_err(types::xlog_err)
1378    }
1379
1380    /// Preferred sparse mask API: caller preselects candidate IDs and passes
1381    /// only the selected subset plus aligned soft probabilities.
1382    ///
1383    /// This path avoids the full-vector soft-probability download in
1384    /// `set_rule_mask_sparse(...)`. The selected IDs are mapped directly to the
1385    /// existing candidate triples and stored in sparse-mask order.
1386    #[pyo3(signature = (name, selected_candidate_ids, selected_soft_probs_dlpack, allow_recursive=false))]
1387    pub fn set_rule_mask_sparse_selected(
1388        &mut self,
1389        name: String,
1390        selected_candidate_ids: Vec<u32>,
1391        selected_soft_probs_dlpack: &Bound<'_, PyAny>,
1392        allow_recursive: bool,
1393    ) -> PyResult<()> {
1394        if self.strict_zero_dtoh {
1395            return Err(PyRuntimeError::new_err(
1396                "strict_zero_dtoh forbids set_rule_mask_sparse_selected; use explicit compatibility export instead",
1397            ));
1398        }
1399        let tmj = extract_tmj_meta_for_mask(&self.plan, Some(&name));
1400        let n = tmj.schema_size;
1401        if n == 0 {
1402            return Err(PyValueError::new_err(format!(
1403                "no learnable mask '{}' found",
1404                name
1405            )));
1406        }
1407        if self.compiled_schema_size > 0 && n != self.compiled_schema_size {
1408            return Err(PyValueError::new_err(format!(
1409                "schema_size mismatch for '{}': plan N={} compiled N={}",
1410                name, n, self.compiled_schema_size
1411            )));
1412        }
1413
1414        let candidate_triples = self.candidate_triples_for_mask(&name, allow_recursive)?;
1415        let expected_c = candidate_triples.len();
1416
1417        let soft_dmt = dlpack_from_py(selected_soft_probs_dlpack)?;
1418        let soft_buf = self
1419            .provider
1420            .from_dlpack_tensors(vec![soft_dmt])
1421            .map_err(types::xlog_err)?;
1422        let selected_len = usize::try_from(soft_buf.num_rows())
1423            .map_err(|_| PyValueError::new_err("selected soft_probs length overflow"))?;
1424
1425        if selected_len != selected_candidate_ids.len() {
1426            return Err(PyValueError::new_err(format!(
1427                "selected soft_probs length {} != selected_candidate_ids length {}",
1428                selected_len,
1429                selected_candidate_ids.len()
1430            )));
1431        }
1432
1433        let mut seen = HashSet::with_capacity(selected_candidate_ids.len());
1434        let mut selected_entries = Vec::with_capacity(selected_candidate_ids.len());
1435        for (pos, &cid) in selected_candidate_ids.iter().enumerate() {
1436            let idx = usize::try_from(cid).map_err(|_| {
1437                PyValueError::new_err(format!("candidate id {} overflows usize", cid))
1438            })?;
1439            if idx >= expected_c {
1440                return Err(PyValueError::new_err(format!(
1441                    "selected candidate id {} out of range [0, {}) at position {}",
1442                    cid, expected_c, pos
1443                )));
1444            }
1445            if !seen.insert(cid) {
1446                return Err(PyValueError::new_err(format!(
1447                    "duplicate selected candidate id {} at position {}",
1448                    cid, pos
1449                )));
1450            }
1451            selected_entries.push(candidate_triples[idx]);
1452        }
1453
1454        self.executor
1455            .ilp_registry_mut()
1456            .insert_selected_mask(name, n, &selected_entries);
1457        Ok(())
1458    }
1459
1460    /// Strict sparse mask API: selected candidate IDs stay as a device tensor on
1461    /// the Python side. Rust resolves them against the fixed candidate order
1462    /// uploaded via `set_candidate_map(...)`.
1463    #[pyo3(signature = (name, selected_candidate_ids_dlpack, selected_soft_probs_dlpack, allow_recursive=false))]
1464    pub fn set_rule_mask_sparse_selected_device(
1465        &mut self,
1466        name: String,
1467        selected_candidate_ids_dlpack: &Bound<'_, PyAny>,
1468        selected_soft_probs_dlpack: &Bound<'_, PyAny>,
1469        allow_recursive: bool,
1470    ) -> PyResult<()> {
1471        self.set_rule_mask_sparse_selected_device_impl(
1472            name,
1473            selected_candidate_ids_dlpack,
1474            selected_soft_probs_dlpack,
1475            allow_recursive,
1476            true,
1477        )
1478    }
1479
1480    pub fn evaluate(&mut self, py: Python<'_>) -> PyResult<()> {
1481        if self.strict_zero_dtoh && self.executor.ilp_registry().has_sparse_device_mask() {
1482            return Err(PyRuntimeError::new_err(
1483                "SparseDevice evaluate() is incompatible with strict_zero_dtoh; \
1484use train_only(..., strict_gpu_native=True) or export an explicit compatibility mask first",
1485            ));
1486        }
1487        self.evaluate_ilp_plan(py)
1488    }
1489
1490    /// Reset mutable runtime state for ILP attempt reuse.
1491    ///
1492    /// Clears ILP registry (masks/tagged results), executor store,
1493    /// join index cache, stats, and profiler. Then re-registers schemas
1494    /// with empty buffers, reloads base facts from AST, and re-executes
1495    /// the plan. Preserves all immutable compile artifacts (AST, plan,
1496    /// schemas, rel_index, provider, TMJ metadata, max_active_rules).
1497    ///
1498    /// After reset, the program is in the same state as a fresh compile()
1499    /// with the same source — ready for set_rule_mask / evaluate cycles.
1500    pub fn reset_runtime(&mut self, py: Python<'_>) -> PyResult<()> {
1501        let _result: xlog_core::Result<()> = py.allow_threads(|| {
1502            // 1. Clear all mutable state (ILP registry, store, caches, stats)
1503            self.executor.reset_for_ilp();
1504
1505            // 2. Re-register schemas with empty buffers
1506            for (name, schema) in &self.schemas {
1507                let empty = self.provider.create_empty_buffer(schema.clone())?;
1508                self.executor.store_mut().put(name, empty);
1509            }
1510
1511            // 3. Reload base facts from preserved AST
1512            load_facts_into_store(&self.ast, &self.provider, &mut self.executor, &self.schemas)?;
1513
1514            // 3b. Reapply persistent relation uploads on top of AST/base facts.
1515            self.apply_relation_overrides()?;
1516
1517            // 4. Re-execute plan (populates derived relations)
1518            self.executor.execute_plan(&self.plan)?;
1519
1520            Ok(())
1521        });
1522        _result.map_err(types::xlog_err)?;
1523
1524        // 5. Reset D2H transfer counter
1525        self.provider.reset_d2h_transfer_count();
1526
1527        Ok(())
1528    }
1529
1530    pub fn get_tagged_results(&self) -> PyResult<Vec<(u32, u32, u32, u32)>> {
1531        self.ensure_host_semantic_compat(
1532            "get_tagged_results()",
1533            "disable strict_zero_dtoh for host materialization",
1534        )?;
1535        match self.executor.ilp_last_result() {
1536            Some(result) => Ok(result
1537                .entries
1538                .iter()
1539                .map(|e| (e.i, e.j, e.k, e.num_rows))
1540                .collect()),
1541            None => Ok(Vec::new()),
1542        }
1543    }
1544
1545    pub fn fact_exists(&self, relation: &str, values: Vec<i64>) -> PyResult<bool> {
1546        self.ensure_host_semantic_compat("fact_exists()", "batch_fact_membership_device(...)")?;
1547        let buf =
1548            self.executor.store().get(relation).ok_or_else(|| {
1549                PyValueError::new_err(format!("Relation '{}' not found", relation))
1550            })?;
1551
1552        Self::fact_exists_in_buffer(&self.provider, buf, &values).map_err(types::xlog_err)
1553    }
1554
1555    /// Return all facts in the named relation as a list of int lists.
1556    #[pyo3(signature = (rel_name))]
1557    pub fn relation_facts(&self, rel_name: String) -> PyResult<Vec<Vec<i64>>> {
1558        self.ensure_host_semantic_compat(
1559            "relation_facts()",
1560            "disable strict_zero_dtoh for host materialization",
1561        )?;
1562        let buf =
1563            self.executor.store().get(&rel_name).ok_or_else(|| {
1564                PyValueError::new_err(format!("Relation '{}' not found", rel_name))
1565            })?;
1566
1567        let num_rows =
1568            read_device_row_count(&self.provider, buf).map_err(types::xlog_err)? as usize;
1569        if num_rows == 0 {
1570            return Ok(Vec::new());
1571        }
1572
1573        // Download all columns (reuse fact_exists_in_buffer pattern)
1574        let schema = buf.schema();
1575        let mut columns: Vec<Vec<i64>> = Vec::new();
1576        for col_idx in 0..buf.arity() {
1577            let col_type = schema.column_type(col_idx).ok_or_else(|| {
1578                PyRuntimeError::new_err(format!("Column {} type not found in schema", col_idx))
1579            })?;
1580            let col_i64: Vec<i64> = match col_type {
1581                ScalarType::I64 => self
1582                    .provider
1583                    .download_column::<i64>(buf, col_idx)
1584                    .map_err(types::xlog_err)?,
1585                ScalarType::I32 => self
1586                    .provider
1587                    .download_column::<i32>(buf, col_idx)
1588                    .map_err(types::xlog_err)?
1589                    .into_iter()
1590                    .map(|v| v as i64)
1591                    .collect(),
1592                ScalarType::U32 | ScalarType::Symbol => self
1593                    .provider
1594                    .download_column::<u32>(buf, col_idx)
1595                    .map_err(types::xlog_err)?
1596                    .into_iter()
1597                    .map(|v| v as i64)
1598                    .collect(),
1599                ScalarType::U64 => self
1600                    .provider
1601                    .download_column::<u64>(buf, col_idx)
1602                    .map_err(types::xlog_err)?
1603                    .into_iter()
1604                    .map(|v| v as i64)
1605                    .collect(),
1606                ScalarType::Bool => self
1607                    .provider
1608                    .download_column::<bool>(buf, col_idx)
1609                    .map_err(types::xlog_err)?
1610                    .into_iter()
1611                    .map(|v| if v { 1i64 } else { 0i64 })
1612                    .collect(),
1613                ScalarType::F32 | ScalarType::F64 => {
1614                    return Err(PyRuntimeError::new_err(format!(
1615                        "relation_facts does not support float column type {:?}",
1616                        col_type
1617                    )));
1618                }
1619            };
1620            columns.push(col_i64);
1621        }
1622
1623        let mut result = Vec::with_capacity(num_rows);
1624        for r in 0..num_rows {
1625            let mut row = Vec::with_capacity(buf.arity());
1626            for c in 0..buf.arity() {
1627                row.push(columns[c][r]);
1628            }
1629            result.push(row);
1630        }
1631        Ok(result)
1632    }
1633
1634    /// Sample up to `max_n` derived facts for `head_rel` that are NOT in `exclude`.
1635    ///
1636    /// Returns `list[list[int]]` — each inner list is a tuple of column values.
1637    /// Uses the same column-download pattern as `relation_facts`.
1638    #[pyo3(signature = (head_rel, exclude, max_n))]
1639    pub fn sample_false_positives(
1640        &self,
1641        head_rel: String,
1642        exclude: Vec<(String, Vec<i64>)>,
1643        max_n: usize,
1644    ) -> PyResult<Vec<Vec<i64>>> {
1645        self.ensure_host_semantic_compat(
1646            "sample_false_positives()",
1647            "disable strict_zero_dtoh for host materialization",
1648        )?;
1649        // Build exclude set: only consider tuples for the requested relation
1650        let exclude_set: HashSet<Vec<i64>> = exclude
1651            .into_iter()
1652            .filter(|(rel, _)| rel == &head_rel)
1653            .map(|(_, vals)| vals)
1654            .collect();
1655
1656        // Download all facts using the same pattern as relation_facts
1657        let buf =
1658            self.executor.store().get(&head_rel).ok_or_else(|| {
1659                PyValueError::new_err(format!("Relation '{}' not found", head_rel))
1660            })?;
1661
1662        let num_rows =
1663            read_device_row_count(&self.provider, buf).map_err(types::xlog_err)? as usize;
1664        if num_rows == 0 {
1665            return Ok(Vec::new());
1666        }
1667
1668        let schema = buf.schema();
1669        let mut columns: Vec<Vec<i64>> = Vec::new();
1670        for col_idx in 0..buf.arity() {
1671            let col_type = schema.column_type(col_idx).ok_or_else(|| {
1672                PyRuntimeError::new_err(format!("Column {} type not found in schema", col_idx))
1673            })?;
1674            let col_i64: Vec<i64> = match col_type {
1675                ScalarType::I64 => self
1676                    .provider
1677                    .download_column::<i64>(buf, col_idx)
1678                    .map_err(types::xlog_err)?,
1679                ScalarType::I32 => self
1680                    .provider
1681                    .download_column::<i32>(buf, col_idx)
1682                    .map_err(types::xlog_err)?
1683                    .into_iter()
1684                    .map(|v| v as i64)
1685                    .collect(),
1686                ScalarType::U32 | ScalarType::Symbol => self
1687                    .provider
1688                    .download_column::<u32>(buf, col_idx)
1689                    .map_err(types::xlog_err)?
1690                    .into_iter()
1691                    .map(|v| v as i64)
1692                    .collect(),
1693                ScalarType::U64 => self
1694                    .provider
1695                    .download_column::<u64>(buf, col_idx)
1696                    .map_err(types::xlog_err)?
1697                    .into_iter()
1698                    .map(|v| v as i64)
1699                    .collect(),
1700                ScalarType::Bool => self
1701                    .provider
1702                    .download_column::<bool>(buf, col_idx)
1703                    .map_err(types::xlog_err)?
1704                    .into_iter()
1705                    .map(|v| if v { 1i64 } else { 0i64 })
1706                    .collect(),
1707                ScalarType::F32 | ScalarType::F64 => {
1708                    return Err(PyRuntimeError::new_err(format!(
1709                        "sample_false_positives does not support float column type {:?}",
1710                        col_type
1711                    )));
1712                }
1713            };
1714            columns.push(col_i64);
1715        }
1716
1717        // Filter out excluded tuples and cap at max_n
1718        let mut result = Vec::with_capacity(max_n.min(num_rows));
1719        for r in 0..num_rows {
1720            if result.len() >= max_n {
1721                break;
1722            }
1723            let mut row = Vec::with_capacity(buf.arity());
1724            for c in 0..buf.arity() {
1725                row.push(columns[c][r]);
1726            }
1727            if !exclude_set.contains(&row) {
1728                result.push(row);
1729            }
1730        }
1731        Ok(result)
1732    }
1733
1734    pub fn tagged_entries_containing_fact(
1735        &self,
1736        relation: &str,
1737        values: Vec<i64>,
1738    ) -> PyResult<Vec<(u32, u32, u32)>> {
1739        self.ensure_host_semantic_compat(
1740            "tagged_entries_containing_fact()",
1741            "batch_tagged_credit_device(...)",
1742        )?;
1743        let k_idx = self
1744            .rel_index
1745            .iter()
1746            .position(|(_, name)| name == relation)
1747            .ok_or_else(|| {
1748                PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
1749            })? as u32;
1750
1751        let tagged = match self.executor.ilp_last_result() {
1752            Some(t) => t,
1753            None => return Ok(Vec::new()),
1754        };
1755
1756        let mut result = Vec::new();
1757        for entry in &tagged.entries {
1758            if entry.k != k_idx || entry.num_rows == 0 {
1759                continue;
1760            }
1761
1762            let (_, left_name) = &self.rel_index[entry.i as usize];
1763            let (_, right_name) = &self.rel_index[entry.j as usize];
1764
1765            let left_buf = match self.executor.store().get(left_name) {
1766                Some(buf) if buf.arity() > 0 => buf,
1767                _ => continue,
1768            };
1769            let right_buf = match self.executor.store().get(right_name) {
1770                Some(buf) if buf.arity() > 0 => buf,
1771                _ => continue,
1772            };
1773
1774            // Arity guard: same as executor (skip if join keys exceed columns)
1775            let left_max = self.left_keys.iter().copied().max().unwrap_or(0);
1776            let right_max = self.right_keys.iter().copied().max().unwrap_or(0);
1777            if left_buf.arity() <= left_max || right_buf.arity() <= right_max {
1778                continue;
1779            }
1780
1781            let joined = self
1782                .provider
1783                .hash_join_v2(
1784                    left_buf,
1785                    right_buf,
1786                    &self.left_keys,
1787                    &self.right_keys,
1788                    JoinType::Inner,
1789                )
1790                .map_err(types::xlog_err)?;
1791
1792            // Apply head_projection: check projected columns against values,
1793            // not the raw join output (which has more columns than the head).
1794            let found = if !self.head_projection.is_empty()
1795                && self.head_projection.len() == values.len()
1796            {
1797                Self::fact_exists_projected(&self.provider, &joined, &values, &self.head_projection)
1798                    .map_err(types::xlog_err)?
1799            } else {
1800                Self::fact_exists_in_buffer(&self.provider, &joined, &values)
1801                    .map_err(types::xlog_err)?
1802            };
1803
1804            if found {
1805                result.push((entry.i, entry.j, entry.k));
1806            }
1807        }
1808        Ok(result)
1809    }
1810
1811    pub fn ilp_schema_size(&self) -> usize {
1812        self.rel_index.len()
1813    }
1814
1815    pub fn ilp_relation_names(&self) -> Vec<String> {
1816        self.rel_index
1817            .iter()
1818            .map(|(_, name)| name.clone())
1819            .collect()
1820    }
1821
1822    /// Return declared predicate types from source `pred` declarations.
1823    ///
1824    /// Output is a list of `(name, types)` tuples so callers can
1825    /// deterministically inspect whether metadata is available for
1826    /// relations used during promotion.
1827    pub fn relation_type_annotations(&self) -> Vec<(String, Vec<String>)> {
1828        self.ast
1829            .predicates
1830            .iter()
1831            .map(|pred| {
1832                let types = pred.types.iter().map(type_ref_name).collect();
1833                (pred.name.clone(), types)
1834            })
1835            .collect()
1836    }
1837
1838    /// Return the set of valid (i,j,k) candidates for the given learnable mask.
1839    ///
1840    /// Pruning rules:
1841    /// - k must be the head relation for this mask
1842    /// - At least one of (i,j) must have nonzero tuples in the store
1843    /// - Template+template body pairs (both have zero tuples) are pruned
1844    /// - If allow_recursive is false: i==k_head or j==k_head are pruned
1845    ///   (unless head already has base facts)
1846    ///
1847    /// Returns list of dicts: [{id, i, j, k, left_name, right_name, head_name}]
1848    /// IDs assigned 0..C-1 after sorting by (k, i, j) ascending.
1849    #[pyo3(signature = (mask_name, allow_recursive=false))]
1850    fn valid_candidates(
1851        &self,
1852        py: Python<'_>,
1853        mask_name: String,
1854        allow_recursive: bool,
1855    ) -> PyResult<Vec<StdHashMap<String, PyObject>>> {
1856        let candidates = self.candidate_triples_for_mask(&mask_name, allow_recursive)?;
1857
1858        let result: Vec<StdHashMap<String, PyObject>> = candidates
1859            .iter()
1860            .enumerate()
1861            .map(
1862                |(id, &(i, j, k))| -> PyResult<StdHashMap<String, PyObject>> {
1863                    let mut d = StdHashMap::new();
1864                    d.insert("id".into(), id.into_pyobject(py)?.into_any().unbind());
1865                    d.insert("i".into(), i.into_pyobject(py)?.into_any().unbind());
1866                    d.insert("j".into(), j.into_pyobject(py)?.into_any().unbind());
1867                    d.insert("k".into(), k.into_pyobject(py)?.into_any().unbind());
1868                    d.insert(
1869                        "left_name".into(),
1870                        self.rel_index[i as usize]
1871                            .1
1872                            .clone()
1873                            .into_pyobject(py)?
1874                            .into_any()
1875                            .unbind(),
1876                    );
1877                    d.insert(
1878                        "right_name".into(),
1879                        self.rel_index[j as usize]
1880                            .1
1881                            .clone()
1882                            .into_pyobject(py)?
1883                            .into_any()
1884                            .unbind(),
1885                    );
1886                    d.insert(
1887                        "head_name".into(),
1888                        self.rel_index[k as usize]
1889                            .1
1890                            .clone()
1891                            .into_pyobject(py)?
1892                            .into_any()
1893                            .unbind(),
1894                    );
1895                    Ok(d)
1896                },
1897            )
1898            .collect::<PyResult<_>>()?;
1899
1900        Ok(result)
1901    }
1902
1903    pub fn commit_induced_rule(&mut self, rule_source: &str) -> PyResult<()> {
1904        let new_base = format!("{}\n{}", self.base_source, rule_source);
1905
1906        let ast = xlog_logic::parse_program(&new_base).map_err(types::val_err)?;
1907        let mut compiler = xlog_logic::Compiler::new();
1908        compiler.set_max_active_rules(self.max_active_rules);
1909        let plan = compiler.compile_program(&ast).map_err(types::xlog_err)?;
1910        let schemas = compiler.schemas().clone();
1911
1912        self.executor.reset_for_mc();
1913        for (name, rel_id) in compiler.rel_ids() {
1914            self.executor.register_relation(*rel_id, name);
1915        }
1916        for (name, schema) in &schemas {
1917            let empty = self
1918                .provider
1919                .create_empty_buffer(schema.clone())
1920                .map_err(types::xlog_err)?;
1921            self.executor.store_mut().put(name, empty);
1922        }
1923        load_facts_into_store(&ast, &self.provider, &mut self.executor, &schemas)
1924            .map_err(types::xlog_err)?;
1925        self.apply_relation_overrides().map_err(types::xlog_err)?;
1926        self.executor.execute_plan(&plan).map_err(types::xlog_err)?;
1927
1928        self.base_source = new_base;
1929        self.ast = ast;
1930        let tmj = extract_tmj_meta(&plan);
1931        self.left_keys = tmj.left_keys;
1932        self.right_keys = tmj.right_keys;
1933        self.head_projection = tmj.head_projection;
1934        self.compiled_schema_size = tmj.schema_size;
1935        self.head_rel_name = tmj.head_rel_name;
1936        self.plan = plan;
1937        self.schemas = schemas;
1938        Ok(())
1939    }
1940
1941    /// GPU-side batch fact membership check.
1942    /// Uploads `facts` (list of value-lists) to a temporary CudaBuffer,
1943    /// semi-joins against the named relation, returns per-fact boolean mask.
1944    /// Zero download_column_* calls — only downloads the u8 mask.
1945    pub fn batch_fact_membership_device(
1946        &self,
1947        py: Python<'_>,
1948        relation: &str,
1949        facts: Vec<Vec<i64>>,
1950    ) -> PyResult<PyObject> {
1951        let buf =
1952            self.executor.store().get(relation).ok_or_else(|| {
1953                PyValueError::new_err(format!("Relation '{}' not found", relation))
1954            })?;
1955
1956        if facts.is_empty() {
1957            let empty = self
1958                .provider
1959                .memory()
1960                .alloc::<u8>(0)
1961                .map_err(types::xlog_err)?;
1962            return export_device_bool_tensor(&self.provider, py, empty, 0);
1963        }
1964        if buf.arity() == 0 {
1965            let mut zeros = self
1966                .provider
1967                .memory()
1968                .alloc::<u8>(facts.len())
1969                .map_err(types::xlog_err)?;
1970            self.provider
1971                .device()
1972                .inner()
1973                .memset_zeros(&mut zeros)
1974                .map_err(types::xlog_err)?;
1975            return export_device_bool_tensor(&self.provider, py, zeros, facts.len());
1976        }
1977
1978        let col_bytes = pack_i64_columns_typed(relation, &facts, buf.schema())?;
1979        let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
1980        let query_buf = self
1981            .provider
1982            .create_buffer_from_slices(&col_slices, buf.schema().clone())
1983            .map_err(types::xlog_err)?;
1984
1985        let keys: Vec<usize> = (0..buf.arity()).collect();
1986        let mask = self
1987            .provider
1988            .membership_mask_device(&query_buf, buf, &keys, &keys)
1989            .map_err(types::xlog_err)?;
1990        export_device_bool_tensor(&self.provider, py, mask, facts.len())
1991    }
1992
1993    pub fn batch_fact_membership(
1994        &self,
1995        relation: &str,
1996        facts: Vec<Vec<i64>>,
1997    ) -> PyResult<Vec<bool>> {
1998        self.ensure_host_semantic_compat(
1999            "batch_fact_membership()",
2000            "batch_fact_membership_device(...)",
2001        )?;
2002        if facts.is_empty() {
2003            return Ok(Vec::new());
2004        }
2005
2006        let buf =
2007            self.executor.store().get(relation).ok_or_else(|| {
2008                PyValueError::new_err(format!("Relation '{}' not found", relation))
2009            })?;
2010
2011        let arity = buf.arity();
2012        if arity == 0 {
2013            return Ok(vec![false; facts.len()]);
2014        }
2015
2016        // Schema-aware typed upload
2017        let col_bytes = pack_i64_columns_typed(relation, &facts, buf.schema())?;
2018        let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
2019        let query_buf = self
2020            .provider
2021            .create_buffer_from_slices(&col_slices, buf.schema().clone())
2022            .map_err(types::xlog_err)?;
2023
2024        // All columns are keys (full-tuple match)
2025        let keys: Vec<usize> = (0..arity).collect();
2026
2027        self.provider
2028            .membership_mask(&query_buf, buf, &keys, &keys)
2029            .map_err(types::xlog_err)
2030    }
2031
2032    /// GPU-side batch credit assignment.
2033    ///
2034    /// For each fact in `facts`, returns the list of (i,j,k) entries whose
2035    /// join result contains that fact. Uses membership_mask against retained
2036    /// per-entry buffers — zero download_column_* calls.
2037    pub fn batch_tagged_credit(
2038        &self,
2039        relation: &str,
2040        facts: Vec<Vec<i64>>,
2041    ) -> PyResult<Vec<Vec<(u32, u32, u32)>>> {
2042        self.ensure_host_semantic_compat(
2043            "batch_tagged_credit()",
2044            "batch_tagged_credit_device(...)",
2045        )?;
2046        if facts.is_empty() {
2047            return Ok(Vec::new());
2048        }
2049
2050        // Find k index for this relation
2051        let k_idx = self
2052            .rel_index
2053            .iter()
2054            .position(|(_, name)| name == relation)
2055            .ok_or_else(|| {
2056                PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
2057            })? as u32;
2058
2059        let tagged = match self.executor.ilp_last_result() {
2060            Some(t) => t,
2061            None => return Ok(vec![Vec::new(); facts.len()]),
2062        };
2063
2064        // Filter entries to those matching target relation k, with retained buffers
2065        let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
2066            .entries
2067            .iter()
2068            .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
2069            .collect();
2070
2071        if relevant_entries.is_empty() {
2072            return Ok(vec![Vec::new(); facts.len()]);
2073        }
2074
2075        // Determine arity from the first entry's buffer
2076        let first_buf = relevant_entries[0]
2077            .buffer
2078            .as_ref()
2079            .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
2080        let arity = first_buf.arity();
2081        if arity == 0 {
2082            return Ok(vec![Vec::new(); facts.len()]);
2083        }
2084
2085        // Schema-aware typed upload
2086        let schema = first_buf.schema().clone();
2087        let col_bytes = pack_i64_columns_typed(relation, &facts, &schema)?;
2088        let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
2089        let query_buf = self
2090            .provider
2091            .create_buffer_from_slices(&col_slices, schema)
2092            .map_err(types::xlog_err)?;
2093
2094        let keys: Vec<usize> = (0..arity).collect();
2095
2096        // For each relevant entry, compute membership mask against query facts
2097        let mut per_fact_credits: Vec<Vec<(u32, u32, u32)>> = vec![Vec::new(); facts.len()];
2098
2099        for entry in &relevant_entries {
2100            let entry_buf = entry
2101                .buffer
2102                .as_ref()
2103                .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
2104            let mask = self
2105                .provider
2106                .membership_mask(&query_buf, entry_buf, &keys, &keys)
2107                .map_err(types::xlog_err)?;
2108
2109            for (fact_idx, &found) in mask.iter().enumerate() {
2110                if found {
2111                    per_fact_credits[fact_idx].push((entry.i, entry.j, entry.k));
2112                }
2113            }
2114        }
2115
2116        Ok(per_fact_credits)
2117    }
2118
2119    /// GPU-side batch credit assignment.
2120    ///
2121    /// Returns a CSR-style device representation:
2122    /// - `fact_row_offsets`: len = num_facts + 1
2123    /// - `entry_indices`: COO candidate indices, sorted by fact row
2124    /// - `entry_i/j/k`: metadata arrays indexed by `entry_indices`
2125    ///
2126    /// Zero DTOH calls on the query path. Uses the non-chunked COO builder
2127    /// to avoid reading device-side nnz metadata back to the host.
2128    pub fn batch_tagged_credit_device(
2129        &self,
2130        py: Python<'_>,
2131        relation: &str,
2132        facts: Vec<Vec<i64>>,
2133    ) -> PyResult<IlpTaggedCreditDeviceResult> {
2134        if facts.is_empty() {
2135            return empty_tagged_credit_device_result(&self.provider, py, 0);
2136        }
2137
2138        let k_idx = self
2139            .rel_index
2140            .iter()
2141            .position(|(_, name)| name == relation)
2142            .ok_or_else(|| {
2143                PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
2144            })? as u32;
2145
2146        let tagged = match self.executor.ilp_last_result() {
2147            Some(t) => t,
2148            None => return empty_tagged_credit_device_result(&self.provider, py, facts.len()),
2149        };
2150
2151        let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
2152            .entries
2153            .iter()
2154            .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
2155            .collect();
2156
2157        if relevant_entries.is_empty() {
2158            return empty_tagged_credit_device_result(&self.provider, py, facts.len());
2159        }
2160
2161        let first_buf = relevant_entries[0]
2162            .buffer
2163            .as_ref()
2164            .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
2165        let arity = first_buf.arity();
2166        if arity == 0 {
2167            return empty_tagged_credit_device_result(&self.provider, py, facts.len());
2168        }
2169
2170        let schema = first_buf.schema().clone();
2171        let col_bytes = pack_i64_columns_typed(relation, &facts, &schema)?;
2172        let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
2173        let query_buf = self
2174            .provider
2175            .create_buffer_from_slices(&col_slices, schema)
2176            .map_err(types::xlog_err)?;
2177
2178        let keys: Vec<usize> = (0..arity).collect();
2179        let num_facts = u32::try_from(facts.len())
2180            .map_err(|_| PyValueError::new_err("facts length exceeds u32::MAX"))?;
2181        let num_entries = u32::try_from(relevant_entries.len())
2182            .map_err(|_| PyValueError::new_err("entry count exceeds u32::MAX"))?;
2183        let upper_bound = num_facts
2184            .checked_mul(num_entries)
2185            .ok_or_else(|| PyValueError::new_err("credit upper bound overflow"))?;
2186
2187        let fact_indices_host: Vec<u32> = (0..num_facts).collect();
2188        let mut d_fact_indices = self
2189            .provider
2190            .memory()
2191            .alloc::<u32>(num_facts as usize)
2192            .map_err(|e| types::gpu_err("alloc fact_indices", e))?;
2193        self.provider
2194            .device()
2195            .inner()
2196            .htod_sync_copy_into(&fact_indices_host, &mut d_fact_indices)
2197            .map_err(|e| types::gpu_err("htod fact_indices", e))?;
2198
2199        let mut entry_i_host = Vec::with_capacity(relevant_entries.len());
2200        let mut entry_j_host = Vec::with_capacity(relevant_entries.len());
2201        let mut entry_k_host = Vec::with_capacity(relevant_entries.len());
2202        let mut tasks = Vec::with_capacity(relevant_entries.len());
2203
2204        for (entry_idx, entry) in relevant_entries.iter().enumerate() {
2205            let entry_buf = entry
2206                .buffer
2207                .as_ref()
2208                .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
2209            let d_mask = self
2210                .provider
2211                .membership_mask_device(&query_buf, entry_buf, &keys, &keys)
2212                .map_err(|e| types::gpu_err("membership_mask", e))?;
2213            tasks.push(ilp_gpu::CooTask {
2214                d_mask,
2215                fact_indices_idx: 0,
2216                cidx: entry_idx as u32,
2217                num_query: num_facts,
2218            });
2219            entry_i_host.push(entry.i);
2220            entry_j_host.push(entry.j);
2221            entry_k_host.push(entry.k);
2222        }
2223
2224        let (mut d_coo_facts, mut d_coo_cands, actual_nnz) = ilp_gpu::build_coo_single(
2225            &self.provider,
2226            &tasks,
2227            &[d_fact_indices],
2228            num_facts,
2229            num_entries,
2230            upper_bound,
2231        )?;
2232        let d_row_offsets = ilp_gpu::sort_and_build_csr(
2233            &self.provider,
2234            &mut d_coo_facts,
2235            &mut d_coo_cands,
2236            actual_nnz,
2237            num_facts,
2238        )?;
2239
2240        let mut d_entry_i = self
2241            .provider
2242            .memory()
2243            .alloc::<u32>(entry_i_host.len())
2244            .map_err(|e| types::gpu_err("alloc entry_i", e))?;
2245        let mut d_entry_j = self
2246            .provider
2247            .memory()
2248            .alloc::<u32>(entry_j_host.len())
2249            .map_err(|e| types::gpu_err("alloc entry_j", e))?;
2250        let mut d_entry_k = self
2251            .provider
2252            .memory()
2253            .alloc::<u32>(entry_k_host.len())
2254            .map_err(|e| types::gpu_err("alloc entry_k", e))?;
2255        self.provider
2256            .device()
2257            .inner()
2258            .htod_sync_copy_into(&entry_i_host, &mut d_entry_i)
2259            .map_err(|e| types::gpu_err("htod entry_i", e))?;
2260        self.provider
2261            .device()
2262            .inner()
2263            .htod_sync_copy_into(&entry_j_host, &mut d_entry_j)
2264            .map_err(|e| types::gpu_err("htod entry_j", e))?;
2265        self.provider
2266            .device()
2267            .inner()
2268            .htod_sync_copy_into(&entry_k_host, &mut d_entry_k)
2269            .map_err(|e| types::gpu_err("htod entry_k", e))?;
2270
2271        Ok(IlpTaggedCreditDeviceResult {
2272            fact_row_offsets: export_device_u32_tensor_as_i32(
2273                &self.provider,
2274                py,
2275                d_row_offsets,
2276                facts.len() + 1,
2277            )?,
2278            entry_indices: export_device_u32_tensor_as_i32(
2279                &self.provider,
2280                py,
2281                d_coo_cands,
2282                upper_bound as usize,
2283            )?,
2284            entry_i: export_device_u32_tensor_as_i32(
2285                &self.provider,
2286                py,
2287                d_entry_i,
2288                entry_i_host.len(),
2289            )?,
2290            entry_j: export_device_u32_tensor_as_i32(
2291                &self.provider,
2292                py,
2293                d_entry_j,
2294                entry_j_host.len(),
2295            )?,
2296            entry_k: export_device_u32_tensor_as_i32(
2297                &self.provider,
2298                py,
2299                d_entry_k,
2300                entry_k_host.len(),
2301            )?,
2302        })
2303    }
2304
2305    pub fn d2h_transfer_count(&self) -> u64 {
2306        self.provider.d2h_transfer_count()
2307    }
2308
2309    pub fn reset_d2h_transfer_count(&self) {
2310        self.provider.reset_d2h_transfer_count()
2311    }
2312
2313    pub fn host_transfer_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
2314        let stats = self.provider.host_transfer_stats();
2315        let dict = PyDict::new(py);
2316        dict.set_item("dtoh_bytes", stats.dtoh_bytes)?;
2317        dict.set_item("htod_bytes", stats.htod_bytes)?;
2318        dict.set_item("dtoh_calls", stats.dtoh_calls)?;
2319        dict.set_item("htod_calls", stats.htod_calls)?;
2320        Ok(dict.into())
2321    }
2322
2323    pub fn reset_host_transfer_stats(&self) {
2324        self.provider.reset_host_transfer_stats()
2325    }
2326}
2327
2328// ---------------------------------------------------------------------------
2329// CompiledIlpProgram — plain impl block (GPU export helpers + internal)
2330// ---------------------------------------------------------------------------
2331
2332impl CompiledIlpProgram {
2333    fn collect_relation_example_groups(
2334        &self,
2335        relations_obj: &Bound<'_, PyAny>,
2336        err_msg: &str,
2337    ) -> PyResult<Vec<RelationExampleGroup>> {
2338        let dict = relations_obj
2339            .downcast::<PyDict>()
2340            .map_err(|_| PyValueError::new_err(err_msg.to_string()))?;
2341        let mut groups = Vec::with_capacity(dict.len());
2342        for (name_obj, columns_obj) in dict.iter() {
2343            let relation: String = name_obj.extract()?;
2344            let schema = self.schemas.get(&relation).ok_or_else(|| {
2345                PyValueError::new_err(format!(
2346                    "Unknown relation {} (not present in compiled schemas)",
2347                    relation
2348                ))
2349            })?;
2350            let tensors = collect_dlpack_columns(
2351                &columns_obj,
2352                &format!("Relation {} must be a sequence of DLPack columns", relation),
2353            )?;
2354            let query_buf = self
2355                .provider
2356                .from_dlpack_tensors_with_schema(schema.clone(), tensors)
2357                .map_err(types::xlog_err)?;
2358            let num_rows = u32::try_from(query_buf.num_rows())
2359                .map_err(|_| PyValueError::new_err("relation row count exceeds u32::MAX"))?;
2360            groups.push(RelationExampleGroup {
2361                relation,
2362                query_buf,
2363                num_rows,
2364            });
2365        }
2366        Ok(groups)
2367    }
2368
2369    fn apply_relation_overrides(&mut self) -> xlog_core::Result<()> {
2370        let relation_names: Vec<String> = self.relation_overrides.keys().cloned().collect();
2371        for name in relation_names {
2372            let stored = self
2373                .relation_overrides
2374                .get(&name)
2375                .expect("relation override disappeared during runtime reset");
2376            let live = self.provider.clone_buffer(stored)?;
2377            self.executor.put_relation(&name, live);
2378        }
2379        Ok(())
2380    }
2381
2382    fn evaluate_ilp_plan(&mut self, py: Python<'_>) -> PyResult<()> {
2383        let result: xlog_core::Result<xlog_cuda::CudaBuffer> = py.allow_threads(|| {
2384            self.executor.reset_for_mc();
2385            for (name, schema) in &self.schemas {
2386                let empty = self.provider.create_empty_buffer(schema.clone())?;
2387                self.executor.store_mut().put(name, empty);
2388            }
2389            load_facts_into_store(&self.ast, &self.provider, &mut self.executor, &self.schemas)?;
2390            self.apply_relation_overrides()?;
2391            self.executor.execute_plan(&self.plan)
2392        });
2393        result.map_err(types::xlog_err)?;
2394        Ok(())
2395    }
2396
2397    fn set_rule_mask_sparse_selected_device_impl(
2398        &mut self,
2399        name: String,
2400        selected_candidate_ids_dlpack: &Bound<'_, PyAny>,
2401        selected_soft_probs_dlpack: &Bound<'_, PyAny>,
2402        allow_recursive: bool,
2403        validate_ids: bool,
2404    ) -> PyResult<()> {
2405        let tmj = extract_tmj_meta_for_mask(&self.plan, Some(&name));
2406        let n = tmj.schema_size;
2407        if n == 0 {
2408            return Err(PyValueError::new_err(format!(
2409                "no learnable mask '{}' found",
2410                name
2411            )));
2412        }
2413        if self.compiled_schema_size > 0 && n != self.compiled_schema_size {
2414            return Err(PyValueError::new_err(format!(
2415                "schema_size mismatch for '{}': plan N={} compiled N={}",
2416                name, n, self.compiled_schema_size
2417            )));
2418        }
2419
2420        let _ = allow_recursive;
2421        let candidate_order = self.candidate_order.as_ref().ok_or_else(|| {
2422            PyRuntimeError::new_err(
2423                "candidate order not set — call set_candidate_map() before strict sparse mask updates",
2424            )
2425        })?;
2426        let expected_c = candidate_order.len();
2427
2428        let ids_dmt = dlpack_from_py(selected_candidate_ids_dlpack)?;
2429        let ids_buf = self
2430            .provider
2431            .from_dlpack_tensors(vec![ids_dmt])
2432            .map_err(types::xlog_err)?;
2433        let soft_dmt = dlpack_from_py(selected_soft_probs_dlpack)?;
2434        let soft_buf = self
2435            .provider
2436            .from_dlpack_tensors(vec![soft_dmt])
2437            .map_err(types::xlog_err)?;
2438
2439        let selected_len = usize::try_from(ids_buf.num_rows())
2440            .map_err(|_| PyValueError::new_err("selected candidate ids length overflow"))?;
2441        let soft_len = usize::try_from(soft_buf.num_rows())
2442            .map_err(|_| PyValueError::new_err("selected soft_probs length overflow"))?;
2443        if soft_len != selected_len {
2444            return Err(PyValueError::new_err(format!(
2445                "selected soft_probs length {} != selected candidate ids length {}",
2446                soft_len, selected_len
2447            )));
2448        }
2449
2450        if validate_ids {
2451            self.provider
2452                .validate_selected_ids(&ids_buf, expected_c)
2453                .map_err(types::val_err)?;
2454        }
2455
2456        let active_flags = self
2457            .provider
2458            .build_selected_id_mask(&ids_buf, expected_c)
2459            .map_err(types::xlog_err)?;
2460
2461        self.executor
2462            .ilp_registry_mut()
2463            .insert_selected_mask_device(
2464                name,
2465                n,
2466                candidate_order.clone(),
2467                active_flags,
2468                selected_len,
2469            );
2470        Ok(())
2471    }
2472
2473    fn ensure_host_semantic_compat(&self, api: &str, device_hint: &str) -> PyResult<()> {
2474        if self.strict_zero_dtoh {
2475            return Err(PyRuntimeError::new_err(format!(
2476                "strict_zero_dtoh forbids {}; use {} instead",
2477                api, device_hint
2478            )));
2479        }
2480        Ok(())
2481    }
2482
2483    // ─── GPU loss/grad export helpers ──────────────────────────────────
2484
2485    /// Build zero loss (scalar) + zero grad (num_cands) on GPU, export via DLPack.
2486    fn build_zero_loss_grad(
2487        &self,
2488        py: Python<'_>,
2489        num_cands: u32,
2490        is_f64: bool,
2491    ) -> PyResult<(PyObject, PyObject)> {
2492        if is_f64 {
2493            build_zero_typed::<f64>(&self.provider, py, num_cands, ScalarType::F64)
2494        } else {
2495            build_zero_typed::<f32>(&self.provider, py, num_cands, ScalarType::F32)
2496        }
2497    }
2498
2499    /// Handle the case where COO is empty but we have facts.
2500    /// Facts with no covering entries get: positive -> -log(eps), negative -> -log(1.0) = 0.
2501    /// We still run the kernels with an empty CSR for correctness.
2502    fn build_loss_grad_empty_coo(
2503        &self,
2504        py: Python<'_>,
2505        is_positive_host: &[u8],
2506        num_facts: u32,
2507        num_cands: u32,
2508        is_f64: bool,
2509    ) -> PyResult<(PyObject, PyObject)> {
2510        // Build CSR with all-zero row_offsets (every row has 0 non-zeros)
2511        let row_offsets = vec![0u32; (num_facts + 1) as usize];
2512        let mut d_row_offsets = self
2513            .provider
2514            .memory()
2515            .alloc::<u32>((num_facts + 1) as usize)
2516            .map_err(|e| types::gpu_err("alloc", e))?;
2517        self.provider
2518            .device()
2519            .inner()
2520            .htod_sync_copy_into(&row_offsets, &mut d_row_offsets)
2521            .map_err(|e| types::gpu_err("htod", e))?;
2522
2523        // Empty col_indices
2524        let d_col_indices = self
2525            .provider
2526            .memory()
2527            .alloc::<u32>(0)
2528            .map_err(|e| types::gpu_err("alloc", e))?;
2529
2530        let mut d_is_positive = self
2531            .provider
2532            .memory()
2533            .alloc::<u8>(num_facts as usize)
2534            .map_err(|e| types::gpu_err("alloc", e))?;
2535        self.provider
2536            .device()
2537            .inner()
2538            .htod_sync_copy_into(is_positive_host, &mut d_is_positive)
2539            .map_err(|e| types::gpu_err("htod", e))?;
2540
2541        // Build a single-column CudaBuffer with 0 elements to represent an empty cand_probs
2542        // for the kernel launch (won't be read since row ranges are all empty).
2543        // We need a dummy CudaColumn. Use the actual cand count = num_cands.
2544        // Since COO is empty the kernel won't access cand_probs, but we need to pass something.
2545        let dummy_col: xlog_cuda::CudaColumn = if is_f64 {
2546            self.provider
2547                .memory()
2548                .alloc::<f64>(num_cands.max(1) as usize)
2549                .map_err(|e| types::gpu_err("alloc dummy", e))?
2550                .into_bytes()
2551                .into()
2552        } else {
2553            self.provider
2554                .memory()
2555                .alloc::<f32>(num_cands.max(1) as usize)
2556                .map_err(|e| types::gpu_err("alloc dummy", e))?
2557                .into_bytes()
2558                .into()
2559        };
2560
2561        ilp_gpu::forward_backward_reduce(
2562            &self.provider,
2563            py,
2564            &d_row_offsets,
2565            &d_col_indices,
2566            &dummy_col,
2567            &d_is_positive,
2568            num_facts,
2569            num_cands,
2570            is_f64,
2571        )
2572    }
2573
2574    /// Returns sorted (i,j,k) candidate triples for the given learnable mask.
2575    /// Pruning logic must stay aligned with `valid_candidates`.
2576    fn candidate_triples_for_mask(
2577        &self,
2578        mask_name: &str,
2579        allow_recursive: bool,
2580    ) -> PyResult<Vec<(u32, u32, u32)>> {
2581        let tmj = extract_tmj_meta_for_mask(&self.plan, Some(mask_name));
2582        let n = tmj.schema_size;
2583        if n == 0 {
2584            return Err(PyValueError::new_err(format!(
2585                "no learnable mask '{}' found in compiled program",
2586                mask_name
2587            )));
2588        }
2589        let head_name = &tmj.head_rel_name;
2590        let k_head = self
2591            .rel_index
2592            .iter()
2593            .position(|(_, name)| name == head_name)
2594            .ok_or_else(|| {
2595                PyValueError::new_err(format!(
2596                    "head relation '{}' not in rel_index for mask '{}'",
2597                    head_name, mask_name
2598                ))
2599            })? as u32;
2600
2601        // Identify which relations currently have nonzero tuples in store.
2602        let has_tuples: Vec<bool> = self
2603            .rel_index
2604            .iter()
2605            .map(|(_, name)| {
2606                self.executor
2607                    .store()
2608                    .get(name)
2609                    .map(|buf| buf.num_rows() > 0)
2610                    .unwrap_or(false)
2611            })
2612            .collect();
2613
2614        let mut triples: Vec<(u32, u32, u32)> = Vec::new();
2615        for i in 0..n as u32 {
2616            for j in 0..n as u32 {
2617                let k = k_head;
2618
2619                // Prune template+template (both no tuples).
2620                if !has_tuples[i as usize] && !has_tuples[j as usize] {
2621                    continue;
2622                }
2623
2624                // Keep behavior aligned with existing alpha candidate pruning:
2625                // recursive body refs are allowed only if head already has tuples.
2626                if !allow_recursive && (i == k || j == k) && !has_tuples[k as usize] {
2627                    continue;
2628                }
2629
2630                triples.push((i, j, k));
2631            }
2632        }
2633        triples.sort_by_key(|&(i, j, k)| (k, i, j));
2634        Ok(triples)
2635    }
2636
2637    fn expected_candidate_count(&self, mask_name: &str, allow_recursive: bool) -> PyResult<usize> {
2638        Ok(self
2639            .candidate_triples_for_mask(mask_name, allow_recursive)?
2640            .len())
2641    }
2642
2643    fn fact_exists_in_buffer(
2644        provider: &CudaKernelProvider,
2645        buf: &xlog_cuda::CudaBuffer,
2646        values: &[i64],
2647    ) -> xlog_core::Result<bool> {
2648        use xlog_core::XlogError;
2649        let num_rows = read_device_row_count(provider, buf)? as usize;
2650        if num_rows == 0 {
2651            return Ok(false);
2652        }
2653        if values.len() != buf.arity() {
2654            return Ok(false);
2655        }
2656
2657        let schema = buf.schema();
2658        let mut columns: Vec<Vec<i64>> = Vec::new();
2659        for col_idx in 0..buf.arity() {
2660            let col_type = schema.column_type(col_idx).ok_or_else(|| {
2661                XlogError::Kernel(format!("Column {} type not found in schema", col_idx))
2662            })?;
2663            let col_i64: Vec<i64> = match col_type {
2664                ScalarType::I64 => provider.download_column::<i64>(buf, col_idx)?,
2665                ScalarType::I32 => provider
2666                    .download_column::<i32>(buf, col_idx)?
2667                    .into_iter()
2668                    .map(|v| v as i64)
2669                    .collect(),
2670                ScalarType::U32 | ScalarType::Symbol => provider
2671                    .download_column::<u32>(buf, col_idx)?
2672                    .into_iter()
2673                    .map(|v| v as i64)
2674                    .collect(),
2675                ScalarType::U64 => {
2676                    let col_u64 = provider.download_column::<u64>(buf, col_idx)?;
2677                    col_u64.into_iter().map(|v| v as i64).collect()
2678                }
2679                ScalarType::Bool => provider
2680                    .download_column::<bool>(buf, col_idx)?
2681                    .into_iter()
2682                    .map(|v| if v { 1i64 } else { 0i64 })
2683                    .collect(),
2684                ScalarType::F32 | ScalarType::F64 => {
2685                    return Err(XlogError::Kernel(format!(
2686                        "fact_exists does not support float column type {:?}",
2687                        col_type
2688                    )));
2689                }
2690            };
2691            columns.push(col_i64);
2692        }
2693
2694        for row in 0..num_rows {
2695            let mut matches = true;
2696            for (col_idx, val) in values.iter().enumerate() {
2697                if columns[col_idx][row] != *val {
2698                    matches = false;
2699                    break;
2700                }
2701            }
2702            if matches {
2703                return Ok(true);
2704            }
2705        }
2706        Ok(false)
2707    }
2708
2709    /// Like fact_exists_in_buffer but checks only the projected columns.
2710    /// `projection[i]` is the column index in `buf` that corresponds to
2711    /// `values[i]` in the head relation.
2712    fn fact_exists_projected(
2713        provider: &CudaKernelProvider,
2714        buf: &xlog_cuda::CudaBuffer,
2715        values: &[i64],
2716        projection: &[usize],
2717    ) -> xlog_core::Result<bool> {
2718        use xlog_core::XlogError;
2719        let num_rows = read_device_row_count(provider, buf)? as usize;
2720        if num_rows == 0 {
2721            return Ok(false);
2722        }
2723
2724        let schema = buf.schema();
2725        let mut columns: Vec<Vec<i64>> = Vec::new();
2726        for &col_idx in projection {
2727            if col_idx >= buf.arity() {
2728                return Ok(false);
2729            }
2730            let col_type = schema.column_type(col_idx).ok_or_else(|| {
2731                XlogError::Kernel(format!("Column {} type not found in schema", col_idx))
2732            })?;
2733            let col_i64: Vec<i64> = match col_type {
2734                ScalarType::I64 => provider.download_column::<i64>(buf, col_idx)?,
2735                ScalarType::I32 => provider
2736                    .download_column::<i32>(buf, col_idx)?
2737                    .into_iter()
2738                    .map(|v| v as i64)
2739                    .collect(),
2740                ScalarType::U32 | ScalarType::Symbol => provider
2741                    .download_column::<u32>(buf, col_idx)?
2742                    .into_iter()
2743                    .map(|v| v as i64)
2744                    .collect(),
2745                ScalarType::U64 => provider
2746                    .download_column::<u64>(buf, col_idx)?
2747                    .into_iter()
2748                    .map(|v| v as i64)
2749                    .collect(),
2750                ScalarType::Bool => provider
2751                    .download_column::<bool>(buf, col_idx)?
2752                    .into_iter()
2753                    .map(|v| if v { 1i64 } else { 0i64 })
2754                    .collect(),
2755                ScalarType::F32 | ScalarType::F64 => {
2756                    return Err(XlogError::Kernel(format!(
2757                        "fact_exists does not support float column type {:?}",
2758                        col_type
2759                    )));
2760                }
2761            };
2762            columns.push(col_i64);
2763        }
2764
2765        for row in 0..num_rows {
2766            let mut matches = true;
2767            for (i, val) in values.iter().enumerate() {
2768                if columns[i][row] != *val {
2769                    matches = false;
2770                    break;
2771                }
2772            }
2773            if matches {
2774                return Ok(true);
2775            }
2776        }
2777        Ok(false)
2778    }
2779}