1use std::collections::HashMap as StdHashMap;
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use pyo3::exceptions::{PyRuntimeError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::{PyDict, PySequence};
12
13use xlog_core::{symbol, RelId, ScalarType, Schema};
14use xlog_cuda::{CudaKernelProvider, JoinType};
15use xlog_ir::{ExecutionPlan, RirNode};
16use xlog_logic::ast::{Program as AstProgram, Term, TypeRef};
17use xlog_prob::exact::GpuConfig;
18use xlog_runtime::ilp_registry::IlpMask;
19use xlog_runtime::{read_device_row_count, Executor};
20
21use xlog_cuda::type_seam::GpuScalar;
22
23use super::{
24 dlpack_capsule_from_tensor, dlpack_from_py, ilp_gpu, provider_from_config, types,
25 CompiledIlpProgram, IlpProgramFactory, IlpTaggedCreditDeviceResult,
26};
27
/// Example rows for a single relation, as consumed by
/// `compute_ilp_loss_grad_gpu_relations`.
///
/// Values of this type are returned by `collect_relation_example_groups`
/// (defined outside this chunk).
struct RelationExampleGroup {
    /// Relation name; matched against the names in `rel_index`.
    relation: String,
    /// Device buffer with the example rows; used as the query side of
    /// `membership_mask_device`, keyed on all of its columns.
    query_buf: xlog_cuda::CudaBuffer,
    /// Number of example rows contributed by this group
    /// (presumably equal to the row count of `query_buf` — TODO confirm).
    num_rows: u32,
}
37
38fn type_ref_name(typ: &TypeRef) -> String {
39 match typ {
40 TypeRef::Scalar(scalar) => types::scalar_type_name(scalar),
41 TypeRef::Domain(name) => name.clone(),
42 TypeRef::List(inner) => format!("list<{}>", type_ref_name(inner)),
43 TypeRef::Term => "term".to_string(),
44 TypeRef::Compound => "compound".to_string(),
45 TypeRef::PredRef => "predref".to_string(),
46 }
47}
48
49fn collect_dlpack_columns(
50 dlpack_columns: &Bound<'_, PyAny>,
51 err_msg: &str,
52) -> PyResult<Vec<xlog_cuda::DlpackManagedTensor>> {
53 let seq = dlpack_columns
54 .downcast::<PySequence>()
55 .map_err(|_| PyValueError::new_err(err_msg.to_string()))?;
56 let mut tensors = Vec::with_capacity(seq.len()?);
57 for item in seq.try_iter()? {
58 tensors.push(dlpack_from_py(&item?)?);
59 }
60 Ok(tensors)
61}
62
63fn build_zero_typed<T>(
66 provider: &CudaKernelProvider,
67 py: Python<'_>,
68 num_cands: u32,
69 scalar_type: ScalarType,
70) -> PyResult<(PyObject, PyObject)>
71where
72 T: GpuScalar + cudarc::driver::ValidAsZeroBits,
73{
74 let mut d_grad = provider
75 .memory()
76 .alloc::<T>(num_cands as usize)
77 .map_err(|e| types::gpu_err("alloc grad", e))?;
78 if num_cands > 0 {
79 provider
80 .device()
81 .inner()
82 .memset_zeros(&mut d_grad)
83 .map_err(|e| types::gpu_err("zero grad", e))?;
84 }
85 let mut d_loss = provider
86 .memory()
87 .alloc::<T>(1)
88 .map_err(|e| types::gpu_err("alloc loss", e))?;
89 provider
90 .device()
91 .inner()
92 .memset_zeros(&mut d_loss)
93 .map_err(|e| types::gpu_err("zero loss", e))?;
94 ilp_gpu::export_loss_grad_device(provider, py, d_loss, d_grad, num_cands, scalar_type)
95}
96
97fn export_device_bool_tensor(
98 provider: &CudaKernelProvider,
99 py: Python<'_>,
100 values: xlog_cuda::memory::TrackedCudaSlice<u8>,
101 rows: usize,
102) -> PyResult<PyObject> {
103 let rows_u32 = u32::try_from(rows)
104 .map_err(|_| PyValueError::new_err(format!("Row count {} exceeds u32::MAX", rows)))?;
105
106 let mut d_num_rows = provider.memory().alloc::<u32>(1).map_err(types::xlog_err)?;
107 provider
108 .device()
109 .inner()
110 .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
111 .map_err(types::xlog_err)?;
112
113 let buffer = xlog_cuda::CudaBuffer::from_columns(
114 vec![values.into_bytes().into()],
115 rows as u64,
116 d_num_rows,
117 Schema::new(vec![("col0".to_string(), ScalarType::Bool)]),
118 );
119 let tensor = provider
120 .to_dlpack_table(buffer)
121 .column(0)
122 .map_err(types::xlog_err)?;
123 dlpack_capsule_from_tensor(py, tensor)
124}
125
126fn export_device_u32_tensor_as_i32(
127 provider: &CudaKernelProvider,
128 py: Python<'_>,
129 values: xlog_cuda::memory::TrackedCudaSlice<u32>,
130 rows: usize,
131) -> PyResult<PyObject> {
132 let rows_u32 = u32::try_from(rows)
133 .map_err(|_| PyValueError::new_err(format!("Row count {} exceeds u32::MAX", rows)))?;
134
135 let mut d_num_rows = provider.memory().alloc::<u32>(1).map_err(types::xlog_err)?;
136 provider
137 .device()
138 .inner()
139 .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
140 .map_err(types::xlog_err)?;
141
142 let buffer = xlog_cuda::CudaBuffer::from_columns(
144 vec![values.into_bytes().into()],
145 rows as u64,
146 d_num_rows,
147 Schema::new(vec![("col0".to_string(), ScalarType::I32)]),
148 );
149 let tensor = provider
150 .to_dlpack_table(buffer)
151 .column(0)
152 .map_err(types::xlog_err)?;
153 dlpack_capsule_from_tensor(py, tensor)
154}
155
156fn empty_tagged_credit_device_result(
157 provider: &CudaKernelProvider,
158 py: Python<'_>,
159 num_facts: usize,
160) -> PyResult<IlpTaggedCreditDeviceResult> {
161 let mut d_row_offsets = provider
162 .memory()
163 .alloc::<u32>(num_facts + 1)
164 .map_err(types::xlog_err)?;
165 provider
166 .device()
167 .inner()
168 .memset_zeros(&mut d_row_offsets)
169 .map_err(types::xlog_err)?;
170 let d_empty_indices = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
171 let d_empty_i = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
172 let d_empty_j = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
173 let d_empty_k = provider.memory().alloc::<u32>(0).map_err(types::xlog_err)?;
174
175 Ok(IlpTaggedCreditDeviceResult {
176 fact_row_offsets: export_device_u32_tensor_as_i32(
177 provider,
178 py,
179 d_row_offsets,
180 num_facts + 1,
181 )?,
182 entry_indices: export_device_u32_tensor_as_i32(provider, py, d_empty_indices, 0)?,
183 entry_i: export_device_u32_tensor_as_i32(provider, py, d_empty_i, 0)?,
184 entry_j: export_device_u32_tensor_as_i32(provider, py, d_empty_j, 0)?,
185 entry_k: export_device_u32_tensor_as_i32(provider, py, d_empty_k, 0)?,
186 })
187}
188
189fn push_term_bytes(out: &mut Vec<u8>, term: &Term, typ: ScalarType) -> xlog_core::Result<()> {
190 use xlog_core::XlogError;
191 match (typ, term) {
192 (ScalarType::U32, Term::Integer(v)) => {
193 let v = u32::try_from(*v)
194 .map_err(|_| XlogError::Execution(format!("u32 out of range: {}", v)))?;
195 out.extend_from_slice(&v.to_le_bytes());
196 }
197 (ScalarType::U64, Term::Integer(v)) => {
198 let v = u64::try_from(*v)
199 .map_err(|_| XlogError::Execution(format!("u64 out of range: {}", v)))?;
200 out.extend_from_slice(&v.to_le_bytes());
201 }
202 (ScalarType::I32, Term::Integer(v)) => {
203 let v = i32::try_from(*v)
204 .map_err(|_| XlogError::Execution(format!("i32 out of range: {}", v)))?;
205 out.extend_from_slice(&v.to_le_bytes());
206 }
207 (ScalarType::I64, Term::Integer(v)) => {
208 out.extend_from_slice(&v.to_le_bytes());
209 }
210 (ScalarType::F32, Term::Float(v)) => {
211 out.extend_from_slice(&(*v as f32).to_le_bytes());
212 }
213 (ScalarType::F64, Term::Float(v)) => {
214 out.extend_from_slice(&v.to_le_bytes());
215 }
216 (ScalarType::F32, Term::Integer(v)) => {
217 out.extend_from_slice(&(*v as f32).to_le_bytes());
218 }
219 (ScalarType::F64, Term::Integer(v)) => {
220 out.extend_from_slice(&(*v as f64).to_le_bytes());
221 }
222 (ScalarType::Bool, Term::Integer(v)) => {
223 let b = match *v {
224 0 => 0u8,
225 1 => 1u8,
226 other => {
227 return Err(XlogError::Execution(format!(
228 "bool expects 0/1, got {}",
229 other
230 )));
231 }
232 };
233 out.push(b);
234 }
235 (ScalarType::Bool, Term::Symbol(id)) => {
236 let s = symbol::resolve(*id);
237 if s == "true" || s == "false" {
238 out.push(if s == "true" { 1u8 } else { 0u8 });
239 } else {
240 return Err(XlogError::Execution(format!(
241 "Expected boolean symbol, got '{}'",
242 s
243 )));
244 }
245 }
246 (ScalarType::Symbol, Term::String(s)) => {
247 out.extend_from_slice(&symbol::intern(s).to_le_bytes());
248 }
249 (ScalarType::Symbol, Term::Symbol(id)) => {
250 out.extend_from_slice(&id.to_le_bytes());
251 }
252 (_, Term::Variable(v)) => {
253 return Err(XlogError::Execution(format!(
254 "Fact cannot contain variable {}",
255 v
256 )));
257 }
258 (_, Term::Anonymous) => {
259 return Err(XlogError::Execution(
260 "Fact cannot contain anonymous wildcard '_'".into(),
261 ));
262 }
263 (_, Term::Aggregate(_)) => {
264 return Err(XlogError::Execution("Fact cannot contain aggregate".into()));
265 }
266 (expected, got) => {
267 return Err(XlogError::Execution(format!(
268 "Type mismatch: expected {:?}, got {:?}",
269 expected, got
270 )));
271 }
272 }
273 Ok(())
274}
275
276pub(crate) fn pack_i64_columns_typed(
280 relation: &str,
281 facts: &[Vec<i64>],
282 schema: &Schema,
283) -> PyResult<Vec<Vec<u8>>> {
284 let arity = schema.arity();
285 for (idx, fact) in facts.iter().enumerate() {
286 if fact.len() != arity {
287 return Err(PyValueError::new_err(format!(
288 "Relation '{}': fact {} has {} values, expected {}",
289 relation,
290 idx,
291 fact.len(),
292 arity,
293 )));
294 }
295 }
296
297 let mut columns: Vec<Vec<u8>> = (0..arity)
298 .map(|col_idx| {
299 let elem_size = schema
300 .column_type(col_idx)
301 .map(|t| t.size_bytes())
302 .unwrap_or(4);
303 Vec::with_capacity(facts.len() * elem_size)
304 })
305 .collect();
306
307 for fact in facts {
308 for (col_idx, &val) in fact.iter().enumerate() {
309 let col_type = schema.column_type(col_idx);
310 let col = &mut columns[col_idx];
311 match col_type {
312 Some(ScalarType::U32) => {
313 let v = u32::try_from(val).map_err(|_| {
314 PyValueError::new_err(format!(
315 "Relation '{}' column {} (U32): value {} out of range [0, {}]",
316 relation,
317 col_idx,
318 val,
319 u32::MAX,
320 ))
321 })?;
322 col.extend_from_slice(&v.to_le_bytes());
323 }
324 Some(ScalarType::I32) => {
325 let v = i32::try_from(val).map_err(|_| {
326 PyValueError::new_err(format!(
327 "Relation '{}' column {} (I32): value {} out of range [{}, {}]",
328 relation,
329 col_idx,
330 val,
331 i32::MIN,
332 i32::MAX,
333 ))
334 })?;
335 col.extend_from_slice(&v.to_le_bytes());
336 }
337 Some(ScalarType::U64) => {
338 let v = u64::try_from(val).map_err(|_| PyValueError::new_err(format!(
339 "Relation '{}' column {} (U64): value {} is negative; U64 requires non-negative values",
340 relation, col_idx, val,
341 )))?;
342 col.extend_from_slice(&v.to_le_bytes());
343 }
344 Some(ScalarType::I64) => {
345 col.extend_from_slice(&val.to_le_bytes());
346 }
347 Some(ScalarType::Bool) => match val {
348 0 => col.push(0u8),
349 1 => col.push(1u8),
350 _ => {
351 return Err(PyValueError::new_err(format!(
352 "Relation '{}' column {} (Bool): value {} not in {{0, 1}}",
353 relation, col_idx, val,
354 )))
355 }
356 },
357 Some(ScalarType::Symbol) => {
358 let v = u32::try_from(val).map_err(|_| {
359 PyValueError::new_err(format!(
360 "Relation '{}' column {} (Symbol): value {} out of range [0, {}]",
361 relation,
362 col_idx,
363 val,
364 u32::MAX,
365 ))
366 })?;
367 col.extend_from_slice(&v.to_le_bytes());
368 }
369 Some(ScalarType::F32) => {
370 return Err(PyValueError::new_err(format!(
371 "Relation '{}' column {} (F32): float columns not supported in batch APIs",
372 relation, col_idx,
373 )));
374 }
375 Some(ScalarType::F64) => {
376 return Err(PyValueError::new_err(format!(
377 "Relation '{}' column {} (F64): float columns not supported in batch APIs",
378 relation, col_idx,
379 )));
380 }
381 None => {
382 return Err(PyValueError::new_err(format!(
383 "Relation '{}' column {}: no type in schema",
384 relation, col_idx,
385 )));
386 }
387 }
388 }
389 }
390
391 Ok(columns)
392}
393
394pub(crate) fn load_facts_into_store(
395 ast: &AstProgram,
396 provider: &CudaKernelProvider,
397 executor: &mut Executor,
398 schemas: &HashMap<String, Schema>,
399) -> xlog_core::Result<()> {
400 use xlog_core::XlogError;
401 let mut rows_by_pred: HashMap<&str, Vec<&[Term]>> = HashMap::new();
402 for fact in ast.facts() {
403 rows_by_pred
404 .entry(fact.head.predicate.as_str())
405 .or_default()
406 .push(&fact.head.terms);
407 }
408
409 for (pred, rows) in rows_by_pred {
410 let schema = schemas.get(pred).ok_or_else(|| {
411 XlogError::Execution(format!("Missing schema for fact predicate {}", pred))
412 })?;
413
414 if rows.iter().any(|r| r.len() != schema.arity()) {
415 return Err(XlogError::Execution(format!(
416 "Fact arity mismatch for {} (expected {})",
417 pred,
418 schema.arity()
419 )));
420 }
421
422 let mut columns: Vec<Vec<u8>> = vec![Vec::new(); schema.arity()];
423 for row in &rows {
424 for (col_idx, term) in row.iter().enumerate() {
425 let typ = schema.column_type(col_idx).ok_or_else(|| {
426 XlogError::Execution(format!("Missing type for col {}", col_idx))
427 })?;
428 push_term_bytes(&mut columns[col_idx], term, typ)?;
429 }
430 }
431
432 let slices: Vec<&[u8]> = columns.iter().map(|c| c.as_slice()).collect();
433 let fact_buf = provider.create_buffer_from_slices(&slices, schema.clone())?;
434
435 let existing = executor.store().get(pred).ok_or_else(|| {
436 XlogError::Execution(format!(
437 "Missing base relation {} while loading facts",
438 pred
439 ))
440 })?;
441 let merged = provider.union(existing, &fact_buf)?;
442 executor.store_mut().put(pred, merged);
443 }
444 Ok(())
445}
446
/// Fields copied out of a `RirNode::TensorMaskedJoin` node by `walk_tmj`.
///
/// `extract_tmj_meta_for_mask` returns a value with empty vectors, a zero
/// `schema_size`, and an empty name when the plan contains no matching node.
struct TmjMeta {
    /// Clone of the node's `left_keys`.
    left_keys: Vec<usize>,
    /// Clone of the node's `right_keys`.
    right_keys: Vec<usize>,
    /// Clone of the node's `head_projection`.
    head_projection: Vec<usize>,
    /// Copy of the node's `schema_size`.
    schema_size: usize,
    /// Clone of the node's `head_rel_name`.
    head_rel_name: String,
}
455
456fn walk_tmj(node: &RirNode, target_mask: Option<&str>) -> Option<TmjMeta> {
457 match node {
458 RirNode::TensorMaskedJoin {
459 mask_name,
460 left_keys,
461 right_keys,
462 head_projection,
463 schema_size,
464 head_rel_name,
465 ..
466 } => {
467 if target_mask.is_none() || target_mask == Some(mask_name.as_str()) {
468 Some(TmjMeta {
469 left_keys: left_keys.clone(),
470 right_keys: right_keys.clone(),
471 head_projection: head_projection.clone(),
472 schema_size: *schema_size,
473 head_rel_name: head_rel_name.clone(),
474 })
475 } else {
476 None
477 }
478 }
479 RirNode::Fixpoint {
480 base, recursive, ..
481 } => walk_tmj(base, target_mask).or_else(|| walk_tmj(recursive, target_mask)),
482 RirNode::Union { inputs } => inputs.iter().find_map(|n| walk_tmj(n, target_mask)),
483 RirNode::Filter { input, .. }
484 | RirNode::Project { input, .. }
485 | RirNode::Distinct { input, .. }
486 | RirNode::GroupBy { input, .. } => walk_tmj(input, target_mask),
487 RirNode::Join { left, right, .. } | RirNode::Diff { left, right } => {
488 walk_tmj(left, target_mask).or_else(|| walk_tmj(right, target_mask))
489 }
490 RirNode::MultiWayJoin { fallback, .. } => walk_tmj(fallback, target_mask),
495 RirNode::ChainJoin { fallback, .. } => walk_tmj(fallback, target_mask),
496 _ => None,
497 }
498}
499
/// Returns the metadata of the first `TensorMaskedJoin` found in `plan`,
/// regardless of its mask name.
///
/// Delegates to `extract_tmj_meta_for_mask` with no mask filter, so an empty
/// `TmjMeta` is returned when the plan has no such node.
fn extract_tmj_meta(plan: &ExecutionPlan) -> TmjMeta {
    extract_tmj_meta_for_mask(plan, None)
}
503
504fn extract_tmj_meta_for_mask(plan: &ExecutionPlan, mask_name: Option<&str>) -> TmjMeta {
505 for scc_rules in &plan.rules_by_scc {
506 for rule in scc_rules {
507 if let Some(meta) = walk_tmj(&rule.body, mask_name) {
508 return meta;
509 }
510 }
511 }
512 TmjMeta {
513 left_keys: vec![],
514 right_keys: vec![],
515 head_projection: vec![],
516 schema_size: 0,
517 head_rel_name: String::new(),
518 }
519}
520
/// Returns `source` without the lines that begin (after leading whitespace)
/// with `learnable(`.
///
/// Remaining lines are re-joined with `\n`; line terminators from the input
/// (including a trailing newline and `\r\n`) are not preserved.
fn strip_learnable_declarations(source: &str) -> String {
    let mut kept = String::with_capacity(source.len());
    // Tracks whether a separator is needed; an explicit flag is required
    // because the first kept line may itself be empty.
    let mut wrote_line = false;
    for line in source.lines() {
        if line.trim_start().starts_with("learnable(") {
            continue;
        }
        if wrote_line {
            kept.push('\n');
        }
        kept.push_str(line);
        wrote_line = true;
    }
    kept
}
528
/// Returns only the lines of `source` that begin (after leading whitespace)
/// with `learnable(`, joined with `\n`.
///
/// Each kept line retains its original leading whitespace; input line
/// terminators are not preserved.
fn extract_learnable_declarations(source: &str) -> String {
    let mut declarations = String::new();
    // Explicit flag so the separator logic does not depend on line contents.
    let mut wrote_line = false;
    for line in source.lines() {
        if !line.trim_start().starts_with("learnable(") {
            continue;
        }
        if wrote_line {
            declarations.push('\n');
        }
        declarations.push_str(line);
        wrote_line = true;
    }
    declarations
}
536
537#[pymethods]
542impl IlpProgramFactory {
543 #[staticmethod]
544 #[pyo3(signature = (source, device=0, memory_mb=512, max_active_rules=None))]
545 pub fn compile(
546 source: &str,
547 device: usize,
548 memory_mb: u64,
549 max_active_rules: Option<usize>,
550 ) -> PyResult<CompiledIlpProgram> {
551 if let Some(max) = max_active_rules {
553 if !(16..=128).contains(&max) {
554 return Err(PyValueError::new_err(format!(
555 "max_active_rules must be between 16 and 128, got {}",
556 max
557 )));
558 }
559 }
560
561 let ast = xlog_logic::parse_program(source).map_err(types::val_err)?;
562
563 let base_source = strip_learnable_declarations(source);
564 let learnable_source = extract_learnable_declarations(source);
565
566 let mut compiler = xlog_logic::Compiler::new();
567 if let Some(max) = max_active_rules {
568 compiler.set_max_active_rules(max);
569 }
570 let plan = compiler.compile_program(&ast).map_err(types::xlog_err)?;
571
572 let mut rel_index: Vec<(RelId, String)> = compiler
573 .rel_ids()
574 .iter()
575 .map(|(name, id)| (*id, name.clone()))
576 .collect();
577 rel_index.sort_by_key(|(id, _)| id.0);
578 let schemas = compiler.schemas().clone();
579
580 let mut config = GpuConfig::default();
581 config.device_ordinal = device;
582 config.memory_bytes = memory_mb * 1024 * 1024;
583 let provider = Arc::new(provider_from_config(config).map_err(types::xlog_err)?);
584
585 let mut executor = Executor::new(provider.clone());
586
587 for (name, rel_id) in compiler.rel_ids() {
588 executor.register_relation(*rel_id, name);
589 }
590
591 for (name, schema) in &schemas {
592 let empty = provider
593 .create_empty_buffer(schema.clone())
594 .map_err(types::xlog_err)?;
595 executor.store_mut().put(name, empty);
596 }
597
598 load_facts_into_store(&ast, &provider, &mut executor, &schemas).map_err(types::xlog_err)?;
599
600 executor.execute_plan(&plan).map_err(types::xlog_err)?;
601
602 let tmj = extract_tmj_meta(&plan);
603
604 let active_rules = max_active_rules.unwrap_or(32);
605
606 Ok(CompiledIlpProgram {
607 base_source,
608 _learnable_source: learnable_source,
609 ast,
610 executor,
611 provider,
612 plan,
613 rel_index,
614 schemas,
615 left_keys: tmj.left_keys,
616 right_keys: tmj.right_keys,
617 head_projection: tmj.head_projection,
618 compiled_schema_size: tmj.schema_size,
619 head_rel_name: tmj.head_rel_name,
620 max_active_rules: active_rules,
621 candidate_map: None,
622 candidate_order: None,
623 relation_overrides: HashMap::new(),
624 coo_chunk_budget: 16 * 1024 * 1024,
625 strict_zero_dtoh: false,
626 })
627 }
628}
629
630#[pymethods]
635impl CompiledIlpProgram {
636 pub fn set_candidate_map(&mut self, candidates: Vec<(u32, u32, u32)>) -> PyResult<()> {
638 let mut map = HashMap::with_capacity(candidates.len());
639 for (cidx, &(i, j, k)) in candidates.iter().enumerate() {
640 map.insert((i, j, k), cidx as u32);
641 }
642 self.candidate_map = Some(map);
643 self.candidate_order = Some(candidates);
644 Ok(())
645 }
646
647 pub fn candidate_map_len(&self) -> usize {
649 self.candidate_map.as_ref().map_or(0, |m| m.len())
650 }
651
652 pub fn debug_ilp_mask_kind(&self, name: String) -> Option<String> {
653 self.executor.ilp_registry().get_mask(&name).map(|mask| {
654 match mask {
655 IlpMask::Dense { .. } => "dense",
656 IlpMask::Sparse { .. } => "sparse_host",
657 IlpMask::SparseDevice { .. } => "sparse_device",
658 }
659 .to_string()
660 })
661 }
662
/// Sets the byte budget used to decide whether COO construction is chunked.
///
/// The loss/gradient methods compare this value with 8 bytes per
/// upper-bound COO entry and switch to `ilp_gpu::build_coo_chunked` when the
/// estimate exceeds it.
pub fn set_coo_chunk_budget(&mut self, bytes: u64) {
    self.coo_chunk_budget = bytes;
}
668
669 #[allow(deprecated)]
671 pub fn set_coo_memory_cap(&mut self, bytes: u64) {
672 self.coo_chunk_budget = bytes;
673 }
674
/// Enables or disables strict zero-DtoH mode.
///
/// When enabled and the ILP registry reports a sparse device mask, the
/// loss/gradient methods call `evaluate_ilp_plan` before reading the last
/// ILP result.
pub fn set_strict_zero_dtoh(&mut self, strict: bool) {
    self.strict_zero_dtoh = strict;
}
681
682 pub fn put_relation(
688 &mut self,
689 name: String,
690 dlpack_columns: &Bound<'_, PyAny>,
691 ) -> PyResult<()> {
692 if name.starts_with("__") {
693 return Err(PyValueError::new_err(format!(
694 "Relation {} is internal and cannot be stored in a compiled ILP program",
695 name
696 )));
697 }
698 let schema = self.schemas.get(&name).ok_or_else(|| {
699 PyValueError::new_err(format!(
700 "Unknown relation {} (not present in compiled schemas)",
701 name
702 ))
703 })?;
704 let tensors = collect_dlpack_columns(
705 dlpack_columns,
706 &format!("Relation {} must be a sequence of DLPack columns", name),
707 )?;
708 let buffer = self
709 .provider
710 .from_dlpack_tensors_with_schema(schema.clone(), tensors)
711 .map_err(types::xlog_err)?;
712 let live_buffer = self
713 .provider
714 .clone_buffer(&buffer)
715 .map_err(types::xlog_err)?;
716 self.relation_overrides.insert(name.clone(), buffer);
717 self.executor.put_relation(&name, live_buffer);
718 Ok(())
719 }
720
721 pub fn compute_ilp_loss_grad_gpu<'py>(
733 &mut self,
734 py: Python<'py>,
735 positives: Vec<(String, Vec<i64>)>,
736 negatives: Vec<(String, Vec<i64>)>,
737 cand_probs_obj: &Bound<'py, PyAny>,
738 ) -> PyResult<(PyObject, PyObject)> {
739 let candidate_map = self.candidate_map.clone().ok_or_else(|| {
741 PyRuntimeError::new_err(
742 "candidate_map not set — call set_candidate_map() before compute_ilp_loss_grad_gpu()",
743 )
744 })?;
745 let num_cands = candidate_map.len() as u32;
746
747 let managed = dlpack_from_py(cand_probs_obj)?;
748 let cand_buf = self
749 .provider
750 .from_dlpack_tensors(vec![managed])
751 .map_err(|e| types::gpu_err("DLPack import", e))?;
752
753 let cand_schema = cand_buf.schema().clone();
755 let cand_dtype = cand_schema
756 .column_type(0)
757 .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column type"))?;
758 let is_f64 = match cand_dtype {
759 ScalarType::F32 => false,
760 ScalarType::F64 => true,
761 other => {
762 return Err(PyValueError::new_err(format!(
763 "cand_probs must be F32 or F64, got {:?}",
764 other
765 )));
766 }
767 };
768
769 let cand_rows = read_device_row_count(&self.provider, &cand_buf)
770 .map_err(|e| types::gpu_err("row count", e))?;
771 if cand_rows != num_cands as usize {
772 return Err(PyValueError::new_err(format!(
773 "cand_probs length ({}) != candidate_map length ({})",
774 cand_rows, num_cands
775 )));
776 }
777
778 let num_pos = positives.len();
780 let num_neg = negatives.len();
781 let num_facts = (num_pos + num_neg) as u32;
782
783 struct FactInfo {
785 relation: String,
786 values: Vec<i64>,
787 is_positive: bool,
788 }
789 let mut all_facts: Vec<FactInfo> = Vec::with_capacity(num_pos + num_neg);
790 for (rel, vals) in positives {
791 all_facts.push(FactInfo {
792 relation: rel,
793 values: vals,
794 is_positive: true,
795 });
796 }
797 for (rel, vals) in negatives {
798 all_facts.push(FactInfo {
799 relation: rel,
800 values: vals,
801 is_positive: false,
802 });
803 }
804
805 if num_facts == 0 {
807 return self.build_zero_loss_grad(py, num_cands, is_f64);
808 }
809
810 let is_positive_host: Vec<u8> = all_facts
812 .iter()
813 .map(|f| if f.is_positive { 1u8 } else { 0u8 })
814 .collect();
815
816 let mut groups: HashMap<String, Vec<(u32, Vec<i64>)>> = HashMap::new();
818 for (global_idx, fact) in all_facts.iter().enumerate() {
819 groups
820 .entry(fact.relation.clone())
821 .or_default()
822 .push((global_idx as u32, fact.values.clone()));
823 }
824
825 if self.strict_zero_dtoh && self.executor.ilp_registry().has_sparse_device_mask() {
826 self.evaluate_ilp_plan(py)?;
827 }
828
829 let tagged = self
831 .executor
832 .ilp_last_result()
833 .ok_or_else(|| PyRuntimeError::new_err("No ILP result — call evaluate() first"))?;
834
835 let mut tasks: Vec<ilp_gpu::CooTask> = Vec::new();
848 let mut fact_indices_buffers: Vec<xlog_cuda::memory::TrackedCudaSlice<u32>> = Vec::new();
849
850 for (relation, facts_with_idx) in &groups {
851 let k_idx = self
852 .rel_index
853 .iter()
854 .position(|(_, name)| name == relation)
855 .ok_or_else(|| {
856 PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
857 })? as u32;
858
859 let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
860 .entries
861 .iter()
862 .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
863 .collect();
864
865 if relevant_entries.is_empty() {
866 continue;
867 }
868
869 let first_buf = relevant_entries[0]
870 .buffer
871 .as_ref()
872 .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
873 let arity = first_buf.arity();
874 if arity == 0 {
875 continue;
876 }
877 let schema = first_buf.schema().clone();
878
879 let fact_values: Vec<Vec<i64>> =
880 facts_with_idx.iter().map(|(_, v)| v.clone()).collect();
881 let col_bytes = pack_i64_columns_typed(relation, &fact_values, &schema)?;
882 let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
883 let query_buf = self
884 .provider
885 .create_buffer_from_slices(&col_slices, schema)
886 .map_err(|e| types::gpu_err("create_buffer", e))?;
887
888 let keys: Vec<usize> = (0..arity).collect();
889 let num_query = fact_values.len() as u32;
890
891 let global_indices: Vec<u32> = facts_with_idx.iter().map(|(idx, _)| *idx).collect();
894 let mut d_fact_indices = self
895 .provider
896 .memory()
897 .alloc::<u32>(num_query as usize)
898 .map_err(|e| types::gpu_err("alloc fact_indices", e))?;
899 self.provider
900 .device()
901 .inner()
902 .htod_sync_copy_into(&global_indices, &mut d_fact_indices)
903 .map_err(|e| types::gpu_err("htod fact_indices", e))?;
904 let fi_idx = fact_indices_buffers.len();
905 fact_indices_buffers.push(d_fact_indices);
906
907 for entry in &relevant_entries {
908 let cidx = match candidate_map.get(&(entry.i, entry.j, entry.k)) {
909 Some(c) => *c,
910 None => continue,
911 };
912
913 let entry_buf = entry.buffer.as_ref().ok_or_else(|| {
914 PyRuntimeError::new_err("internal: filtered entry has no buffer")
915 })?;
916 let d_mask = self
917 .provider
918 .membership_mask_device(&query_buf, entry_buf, &keys, &keys)
919 .map_err(|e| types::gpu_err("membership_mask", e))?;
920
921 tasks.push(ilp_gpu::CooTask {
922 cidx,
923 num_query,
924 d_mask,
925 fact_indices_idx: fi_idx,
926 });
927 }
928 }
929
930 let num_tasks = tasks.len();
931 let upper_bound: u32 = tasks.iter().map(|t| t.num_query).sum();
932
933 if upper_bound == 0 || num_tasks == 0 {
934 return self.build_loss_grad_empty_coo(
935 py,
936 &is_positive_host,
937 num_facts,
938 num_cands,
939 is_f64,
940 );
941 }
942
943 let coo_bytes = (upper_bound as u64) * 8;
946 let needs_chunking = coo_bytes > self.coo_chunk_budget;
947
948 let mut d_is_positive = self
950 .provider
951 .memory()
952 .alloc::<u8>(num_facts as usize)
953 .map_err(|e| types::gpu_err("alloc is_positive", e))?;
954 self.provider
955 .device()
956 .inner()
957 .htod_sync_copy_into(&is_positive_host, &mut d_is_positive)
958 .map_err(|e| types::gpu_err("htod is_positive", e))?;
959
960 let cand_col = cand_buf
961 .column(0)
962 .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column"))?;
963
964 let (mut d_coo_facts, mut d_coo_cands, actual_nnz) = if !needs_chunking {
966 ilp_gpu::build_coo_single(
967 &self.provider,
968 &tasks,
969 &fact_indices_buffers,
970 num_facts,
971 num_cands,
972 upper_bound,
973 )?
974 } else {
975 match ilp_gpu::build_coo_chunked(
976 &self.provider,
977 &tasks,
978 &fact_indices_buffers,
979 num_facts,
980 self.coo_chunk_budget,
981 )? {
982 Some(result) => result,
983 None => {
984 return self.build_loss_grad_empty_coo(
985 py,
986 &is_positive_host,
987 num_facts,
988 num_cands,
989 is_f64,
990 );
991 }
992 }
993 };
994
995 let d_row_offsets = ilp_gpu::sort_and_build_csr(
997 &self.provider,
998 &mut d_coo_facts,
999 &mut d_coo_cands,
1000 actual_nnz,
1001 num_facts,
1002 )?;
1003
1004 ilp_gpu::forward_backward_reduce(
1006 &self.provider,
1007 py,
1008 &d_row_offsets,
1009 &d_coo_cands,
1010 cand_col,
1011 &d_is_positive,
1012 num_facts,
1013 num_cands,
1014 is_f64,
1015 )
1016 }
1017
1018 pub fn compute_ilp_loss_grad_gpu_relations<'py>(
1025 &mut self,
1026 py: Python<'py>,
1027 positives_by_relation: &Bound<'py, PyAny>,
1028 negatives_by_relation: &Bound<'py, PyAny>,
1029 cand_probs_obj: &Bound<'py, PyAny>,
1030 ) -> PyResult<(PyObject, PyObject)> {
1031 let candidate_map = self.candidate_map.clone().ok_or_else(|| {
1032 PyRuntimeError::new_err(
1033 "candidate_map not set — call set_candidate_map() before compute_ilp_loss_grad_gpu_relations()",
1034 )
1035 })?;
1036 let num_cands = candidate_map.len() as u32;
1037
1038 let managed = dlpack_from_py(cand_probs_obj)?;
1039 let cand_buf = self
1040 .provider
1041 .from_dlpack_tensors(vec![managed])
1042 .map_err(|e| types::gpu_err("DLPack import", e))?;
1043
1044 let cand_schema = cand_buf.schema().clone();
1045 let cand_dtype = cand_schema
1046 .column_type(0)
1047 .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column type"))?;
1048 let is_f64 = match cand_dtype {
1049 ScalarType::F32 => false,
1050 ScalarType::F64 => true,
1051 other => {
1052 return Err(PyValueError::new_err(format!(
1053 "cand_probs must be F32 or F64, got {:?}",
1054 other
1055 )));
1056 }
1057 };
1058
1059 let cand_rows = read_device_row_count(&self.provider, &cand_buf)
1060 .map_err(|e| types::gpu_err("row count", e))?;
1061 if cand_rows != num_cands as usize {
1062 return Err(PyValueError::new_err(format!(
1063 "cand_probs length ({}) != candidate_map length ({})",
1064 cand_rows, num_cands
1065 )));
1066 }
1067
1068 let positive_groups = self.collect_relation_example_groups(
1069 positives_by_relation,
1070 "positives_by_relation must be a dict[str, sequence[dlpack]]",
1071 )?;
1072 let negative_groups = self.collect_relation_example_groups(
1073 negatives_by_relation,
1074 "negatives_by_relation must be a dict[str, sequence[dlpack]]",
1075 )?;
1076
1077 let num_pos: u32 = positive_groups.iter().map(|g| g.num_rows).sum();
1078 let num_neg: u32 = negative_groups.iter().map(|g| g.num_rows).sum();
1079 let num_facts = num_pos + num_neg;
1080
1081 if num_facts == 0 {
1082 return self.build_zero_loss_grad(py, num_cands, is_f64);
1083 }
1084
1085 let mut is_positive_host = Vec::with_capacity(num_facts as usize);
1086 is_positive_host.extend(std::iter::repeat_n(1u8, num_pos as usize));
1087 is_positive_host.extend(std::iter::repeat_n(0u8, num_neg as usize));
1088
1089 if self.strict_zero_dtoh && self.executor.ilp_registry().has_sparse_device_mask() {
1090 self.evaluate_ilp_plan(py)?;
1091 }
1092
1093 let tagged = self
1094 .executor
1095 .ilp_last_result()
1096 .ok_or_else(|| PyRuntimeError::new_err("No ILP result — call evaluate() first"))?;
1097
1098 let mut tasks: Vec<ilp_gpu::CooTask> = Vec::new();
1099 let mut fact_indices_buffers: Vec<xlog_cuda::memory::TrackedCudaSlice<u32>> = Vec::new();
1100 let mut next_global_idx: u32 = 0;
1101
1102 for group in positive_groups.iter().chain(negative_groups.iter()) {
1103 if group.num_rows == 0 {
1104 continue;
1105 }
1106 let k_idx = self
1107 .rel_index
1108 .iter()
1109 .position(|(_, name)| name == &group.relation)
1110 .ok_or_else(|| {
1111 PyValueError::new_err(format!(
1112 "Relation '{}' not in ILP schema",
1113 group.relation
1114 ))
1115 })? as u32;
1116
1117 let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
1118 .entries
1119 .iter()
1120 .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
1121 .collect();
1122 if relevant_entries.is_empty() {
1123 next_global_idx = next_global_idx.saturating_add(group.num_rows);
1124 continue;
1125 }
1126
1127 let arity = group.query_buf.arity();
1128 if arity == 0 {
1129 next_global_idx = next_global_idx.saturating_add(group.num_rows);
1130 continue;
1131 }
1132 let keys: Vec<usize> = (0..arity).collect();
1133
1134 let global_indices: Vec<u32> =
1135 (next_global_idx..next_global_idx + group.num_rows).collect();
1136 let mut d_fact_indices = self
1137 .provider
1138 .memory()
1139 .alloc::<u32>(group.num_rows as usize)
1140 .map_err(|e| types::gpu_err("alloc fact_indices", e))?;
1141 self.provider
1142 .device()
1143 .inner()
1144 .htod_sync_copy_into(&global_indices, &mut d_fact_indices)
1145 .map_err(|e| types::gpu_err("htod fact_indices", e))?;
1146 let fi_idx = fact_indices_buffers.len();
1147 fact_indices_buffers.push(d_fact_indices);
1148
1149 for entry in &relevant_entries {
1150 let cidx = match candidate_map.get(&(entry.i, entry.j, entry.k)) {
1151 Some(c) => *c,
1152 None => continue,
1153 };
1154 let entry_buf = entry.buffer.as_ref().ok_or_else(|| {
1155 PyRuntimeError::new_err("internal: filtered entry has no buffer")
1156 })?;
1157 let d_mask = self
1158 .provider
1159 .membership_mask_device(&group.query_buf, entry_buf, &keys, &keys)
1160 .map_err(|e| types::gpu_err("membership_mask", e))?;
1161 tasks.push(ilp_gpu::CooTask {
1162 cidx,
1163 num_query: group.num_rows,
1164 d_mask,
1165 fact_indices_idx: fi_idx,
1166 });
1167 }
1168
1169 next_global_idx = next_global_idx.saturating_add(group.num_rows);
1170 }
1171
1172 let num_tasks = tasks.len();
1173 let upper_bound: u32 = tasks.iter().map(|t| t.num_query).sum();
1174 if upper_bound == 0 || num_tasks == 0 {
1175 return self.build_loss_grad_empty_coo(
1176 py,
1177 &is_positive_host,
1178 num_facts,
1179 num_cands,
1180 is_f64,
1181 );
1182 }
1183
1184 let coo_bytes = (upper_bound as u64) * 8;
1185 let needs_chunking = coo_bytes > self.coo_chunk_budget;
1186
1187 let mut d_is_positive = self
1188 .provider
1189 .memory()
1190 .alloc::<u8>(num_facts as usize)
1191 .map_err(|e| types::gpu_err("alloc is_positive", e))?;
1192 self.provider
1193 .device()
1194 .inner()
1195 .htod_sync_copy_into(&is_positive_host, &mut d_is_positive)
1196 .map_err(|e| types::gpu_err("htod is_positive", e))?;
1197
1198 let cand_col = cand_buf
1199 .column(0)
1200 .ok_or_else(|| PyRuntimeError::new_err("cand_probs has no column"))?;
1201
1202 let (mut d_coo_facts, mut d_coo_cands, actual_nnz) = if !needs_chunking {
1203 ilp_gpu::build_coo_single(
1204 &self.provider,
1205 &tasks,
1206 &fact_indices_buffers,
1207 num_facts,
1208 num_cands,
1209 upper_bound,
1210 )?
1211 } else {
1212 match ilp_gpu::build_coo_chunked(
1213 &self.provider,
1214 &tasks,
1215 &fact_indices_buffers,
1216 num_facts,
1217 self.coo_chunk_budget,
1218 )? {
1219 Some(result) => result,
1220 None => {
1221 return self.build_loss_grad_empty_coo(
1222 py,
1223 &is_positive_host,
1224 num_facts,
1225 num_cands,
1226 is_f64,
1227 );
1228 }
1229 }
1230 };
1231
1232 let d_row_offsets = ilp_gpu::sort_and_build_csr(
1233 &self.provider,
1234 &mut d_coo_facts,
1235 &mut d_coo_cands,
1236 actual_nnz,
1237 num_facts,
1238 )?;
1239
1240 ilp_gpu::forward_backward_reduce(
1241 &self.provider,
1242 py,
1243 &d_row_offsets,
1244 &d_coo_cands,
1245 cand_col,
1246 &d_is_positive,
1247 num_facts,
1248 num_cands,
1249 is_f64,
1250 )
1251 }
1252
1253 pub fn set_rule_mask(
1254 &mut self,
1255 name: String,
1256 mask_hard_flat: &Bound<'_, PyAny>,
1257 mask_soft_flat: &Bound<'_, PyAny>,
1258 schema_size: usize,
1259 ) -> PyResult<()> {
1260 if self.compiled_schema_size > 0 && schema_size != self.compiled_schema_size {
1261 return Err(PyValueError::new_err(format!(
1262 "schema_size mismatch: mask has N={} but compiled program expects N={}",
1263 schema_size, self.compiled_schema_size,
1264 )));
1265 }
1266
1267 let hard_dmt = dlpack_from_py(mask_hard_flat)?;
1268 let soft_dmt = dlpack_from_py(mask_soft_flat)?;
1269
1270 let hard_buf = self
1271 .provider
1272 .from_dlpack_tensors(vec![hard_dmt])
1273 .map_err(types::xlog_err)?;
1274 let soft_buf = self
1275 .provider
1276 .from_dlpack_tensors(vec![soft_dmt])
1277 .map_err(types::xlog_err)?;
1278
1279 self.executor
1280 .ilp_registry_mut()
1281 .insert_mask(name, hard_buf, soft_buf, schema_size);
1282 Ok(())
1283 }
1284
1285 #[pyo3(signature = (name, candidate_ids, soft_probs_dlpack, budget, allow_recursive=false))]
1295 pub fn set_rule_mask_sparse(
1296 &mut self,
1297 name: String,
1298 candidate_ids: Vec<u32>,
1299 soft_probs_dlpack: &Bound<'_, PyAny>,
1300 budget: usize,
1301 allow_recursive: bool,
1302 ) -> PyResult<()> {
1303 if self.strict_zero_dtoh {
1304 return Err(PyRuntimeError::new_err(
1305 "strict_zero_dtoh forbids legacy set_rule_mask_sparse; use set_rule_mask_sparse_selected instead",
1306 ));
1307 }
1308
1309 let tmj = extract_tmj_meta_for_mask(&self.plan, Some(&name));
1310 let n = tmj.schema_size;
1311 if n == 0 {
1312 return Err(PyValueError::new_err(format!(
1313 "no learnable mask '{}' found",
1314 name
1315 )));
1316 }
1317 if self.compiled_schema_size > 0 && n != self.compiled_schema_size {
1318 return Err(PyValueError::new_err(format!(
1319 "schema_size mismatch for '{}': plan N={} compiled N={}",
1320 name, n, self.compiled_schema_size
1321 )));
1322 }
1323
1324 let expected_c = self.expected_candidate_count(&name, allow_recursive)?;
1325 if candidate_ids.len() != expected_c {
1326 return Err(PyValueError::new_err(format!(
1327 "candidate_ids length {} != expected candidate count {}",
1328 candidate_ids.len(),
1329 expected_c
1330 )));
1331 }
1332 for (idx, &cid) in candidate_ids.iter().enumerate() {
1333 if cid != idx as u32 {
1334 return Err(PyValueError::new_err(format!(
1335 "candidate_ids must be [0..{}), got id {} at position {}",
1336 expected_c, cid, idx
1337 )));
1338 }
1339 }
1340
1341 let soft_dmt = dlpack_from_py(soft_probs_dlpack)?;
1343 let soft_buf = self
1344 .provider
1345 .from_dlpack_tensors(vec![soft_dmt])
1346 .map_err(types::xlog_err)?;
1347
1348 let soft_probs = self
1350 .provider
1351 .download_column_untracked::<f64>(&soft_buf, 0)
1352 .map_err(types::xlog_err)?;
1353
1354 if soft_probs.len() != candidate_ids.len() {
1355 return Err(PyValueError::new_err(format!(
1356 "soft_probs tensor length {} != candidate_ids length {}",
1357 soft_probs.len(),
1358 candidate_ids.len()
1359 )));
1360 }
1361
1362 let candidate_triples = self.candidate_triples_for_mask(&name, allow_recursive)?;
1363 if candidate_triples.len() != expected_c {
1364 return Err(PyRuntimeError::new_err(format!(
1365 "internal candidate count mismatch: triples={} expected={}",
1366 candidate_triples.len(),
1367 expected_c
1368 )));
1369 }
1370
1371 let active_soft: Vec<f32> = soft_probs.iter().map(|&v| v as f32).collect();
1373
1374 self.executor
1375 .ilp_registry_mut()
1376 .insert_mask_from_sparse(name, n, &candidate_triples, &active_soft, budget)
1377 .map_err(types::xlog_err)
1378 }
1379
1380 #[pyo3(signature = (name, selected_candidate_ids, selected_soft_probs_dlpack, allow_recursive=false))]
1387 pub fn set_rule_mask_sparse_selected(
1388 &mut self,
1389 name: String,
1390 selected_candidate_ids: Vec<u32>,
1391 selected_soft_probs_dlpack: &Bound<'_, PyAny>,
1392 allow_recursive: bool,
1393 ) -> PyResult<()> {
1394 if self.strict_zero_dtoh {
1395 return Err(PyRuntimeError::new_err(
1396 "strict_zero_dtoh forbids set_rule_mask_sparse_selected; use explicit compatibility export instead",
1397 ));
1398 }
1399 let tmj = extract_tmj_meta_for_mask(&self.plan, Some(&name));
1400 let n = tmj.schema_size;
1401 if n == 0 {
1402 return Err(PyValueError::new_err(format!(
1403 "no learnable mask '{}' found",
1404 name
1405 )));
1406 }
1407 if self.compiled_schema_size > 0 && n != self.compiled_schema_size {
1408 return Err(PyValueError::new_err(format!(
1409 "schema_size mismatch for '{}': plan N={} compiled N={}",
1410 name, n, self.compiled_schema_size
1411 )));
1412 }
1413
1414 let candidate_triples = self.candidate_triples_for_mask(&name, allow_recursive)?;
1415 let expected_c = candidate_triples.len();
1416
1417 let soft_dmt = dlpack_from_py(selected_soft_probs_dlpack)?;
1418 let soft_buf = self
1419 .provider
1420 .from_dlpack_tensors(vec![soft_dmt])
1421 .map_err(types::xlog_err)?;
1422 let selected_len = usize::try_from(soft_buf.num_rows())
1423 .map_err(|_| PyValueError::new_err("selected soft_probs length overflow"))?;
1424
1425 if selected_len != selected_candidate_ids.len() {
1426 return Err(PyValueError::new_err(format!(
1427 "selected soft_probs length {} != selected_candidate_ids length {}",
1428 selected_len,
1429 selected_candidate_ids.len()
1430 )));
1431 }
1432
1433 let mut seen = HashSet::with_capacity(selected_candidate_ids.len());
1434 let mut selected_entries = Vec::with_capacity(selected_candidate_ids.len());
1435 for (pos, &cid) in selected_candidate_ids.iter().enumerate() {
1436 let idx = usize::try_from(cid).map_err(|_| {
1437 PyValueError::new_err(format!("candidate id {} overflows usize", cid))
1438 })?;
1439 if idx >= expected_c {
1440 return Err(PyValueError::new_err(format!(
1441 "selected candidate id {} out of range [0, {}) at position {}",
1442 cid, expected_c, pos
1443 )));
1444 }
1445 if !seen.insert(cid) {
1446 return Err(PyValueError::new_err(format!(
1447 "duplicate selected candidate id {} at position {}",
1448 cid, pos
1449 )));
1450 }
1451 selected_entries.push(candidate_triples[idx]);
1452 }
1453
1454 self.executor
1455 .ilp_registry_mut()
1456 .insert_selected_mask(name, n, &selected_entries);
1457 Ok(())
1458 }
1459
    /// Variant of selected sparse mask registration in which the candidate ids
    /// are supplied as a Python object (presumably a DLPack-exportable tensor,
    /// judging by the parameter name) instead of a host `Vec<u32>`.
    ///
    /// This is a thin forwarder: all arguments are passed unchanged to
    /// `set_rule_mask_sparse_selected_device_impl`, with its trailing boolean
    /// fixed to `true`. Errors from the helper are returned as-is.
    #[pyo3(signature = (name, selected_candidate_ids_dlpack, selected_soft_probs_dlpack, allow_recursive=false))]
    pub fn set_rule_mask_sparse_selected_device(
        &mut self,
        name: String,
        selected_candidate_ids_dlpack: &Bound<'_, PyAny>,
        selected_soft_probs_dlpack: &Bound<'_, PyAny>,
        allow_recursive: bool,
    ) -> PyResult<()> {
        self.set_rule_mask_sparse_selected_device_impl(
            name,
            selected_candidate_ids_dlpack,
            selected_soft_probs_dlpack,
            allow_recursive,
            // NOTE(review): the meaning of this flag is defined by the impl helper,
            // which is outside this view — confirm before changing it.
            true,
        )
    }
1479
1480 pub fn evaluate(&mut self, py: Python<'_>) -> PyResult<()> {
1481 if self.strict_zero_dtoh && self.executor.ilp_registry().has_sparse_device_mask() {
1482 return Err(PyRuntimeError::new_err(
1483 "SparseDevice evaluate() is incompatible with strict_zero_dtoh; \
1484use train_only(..., strict_gpu_native=True) or export an explicit compatibility mask first",
1485 ));
1486 }
1487 self.evaluate_ilp_plan(py)
1488 }
1489
1490 pub fn reset_runtime(&mut self, py: Python<'_>) -> PyResult<()> {
1501 let _result: xlog_core::Result<()> = py.allow_threads(|| {
1502 self.executor.reset_for_ilp();
1504
1505 for (name, schema) in &self.schemas {
1507 let empty = self.provider.create_empty_buffer(schema.clone())?;
1508 self.executor.store_mut().put(name, empty);
1509 }
1510
1511 load_facts_into_store(&self.ast, &self.provider, &mut self.executor, &self.schemas)?;
1513
1514 self.apply_relation_overrides()?;
1516
1517 self.executor.execute_plan(&self.plan)?;
1519
1520 Ok(())
1521 });
1522 _result.map_err(types::xlog_err)?;
1523
1524 self.provider.reset_d2h_transfer_count();
1526
1527 Ok(())
1528 }
1529
1530 pub fn get_tagged_results(&self) -> PyResult<Vec<(u32, u32, u32, u32)>> {
1531 self.ensure_host_semantic_compat(
1532 "get_tagged_results()",
1533 "disable strict_zero_dtoh for host materialization",
1534 )?;
1535 match self.executor.ilp_last_result() {
1536 Some(result) => Ok(result
1537 .entries
1538 .iter()
1539 .map(|e| (e.i, e.j, e.k, e.num_rows))
1540 .collect()),
1541 None => Ok(Vec::new()),
1542 }
1543 }
1544
1545 pub fn fact_exists(&self, relation: &str, values: Vec<i64>) -> PyResult<bool> {
1546 self.ensure_host_semantic_compat("fact_exists()", "batch_fact_membership_device(...)")?;
1547 let buf =
1548 self.executor.store().get(relation).ok_or_else(|| {
1549 PyValueError::new_err(format!("Relation '{}' not found", relation))
1550 })?;
1551
1552 Self::fact_exists_in_buffer(&self.provider, buf, &values).map_err(types::xlog_err)
1553 }
1554
1555 #[pyo3(signature = (rel_name))]
1557 pub fn relation_facts(&self, rel_name: String) -> PyResult<Vec<Vec<i64>>> {
1558 self.ensure_host_semantic_compat(
1559 "relation_facts()",
1560 "disable strict_zero_dtoh for host materialization",
1561 )?;
1562 let buf =
1563 self.executor.store().get(&rel_name).ok_or_else(|| {
1564 PyValueError::new_err(format!("Relation '{}' not found", rel_name))
1565 })?;
1566
1567 let num_rows =
1568 read_device_row_count(&self.provider, buf).map_err(types::xlog_err)? as usize;
1569 if num_rows == 0 {
1570 return Ok(Vec::new());
1571 }
1572
1573 let schema = buf.schema();
1575 let mut columns: Vec<Vec<i64>> = Vec::new();
1576 for col_idx in 0..buf.arity() {
1577 let col_type = schema.column_type(col_idx).ok_or_else(|| {
1578 PyRuntimeError::new_err(format!("Column {} type not found in schema", col_idx))
1579 })?;
1580 let col_i64: Vec<i64> = match col_type {
1581 ScalarType::I64 => self
1582 .provider
1583 .download_column::<i64>(buf, col_idx)
1584 .map_err(types::xlog_err)?,
1585 ScalarType::I32 => self
1586 .provider
1587 .download_column::<i32>(buf, col_idx)
1588 .map_err(types::xlog_err)?
1589 .into_iter()
1590 .map(|v| v as i64)
1591 .collect(),
1592 ScalarType::U32 | ScalarType::Symbol => self
1593 .provider
1594 .download_column::<u32>(buf, col_idx)
1595 .map_err(types::xlog_err)?
1596 .into_iter()
1597 .map(|v| v as i64)
1598 .collect(),
1599 ScalarType::U64 => self
1600 .provider
1601 .download_column::<u64>(buf, col_idx)
1602 .map_err(types::xlog_err)?
1603 .into_iter()
1604 .map(|v| v as i64)
1605 .collect(),
1606 ScalarType::Bool => self
1607 .provider
1608 .download_column::<bool>(buf, col_idx)
1609 .map_err(types::xlog_err)?
1610 .into_iter()
1611 .map(|v| if v { 1i64 } else { 0i64 })
1612 .collect(),
1613 ScalarType::F32 | ScalarType::F64 => {
1614 return Err(PyRuntimeError::new_err(format!(
1615 "relation_facts does not support float column type {:?}",
1616 col_type
1617 )));
1618 }
1619 };
1620 columns.push(col_i64);
1621 }
1622
1623 let mut result = Vec::with_capacity(num_rows);
1624 for r in 0..num_rows {
1625 let mut row = Vec::with_capacity(buf.arity());
1626 for c in 0..buf.arity() {
1627 row.push(columns[c][r]);
1628 }
1629 result.push(row);
1630 }
1631 Ok(result)
1632 }
1633
1634 #[pyo3(signature = (head_rel, exclude, max_n))]
1639 pub fn sample_false_positives(
1640 &self,
1641 head_rel: String,
1642 exclude: Vec<(String, Vec<i64>)>,
1643 max_n: usize,
1644 ) -> PyResult<Vec<Vec<i64>>> {
1645 self.ensure_host_semantic_compat(
1646 "sample_false_positives()",
1647 "disable strict_zero_dtoh for host materialization",
1648 )?;
1649 let exclude_set: HashSet<Vec<i64>> = exclude
1651 .into_iter()
1652 .filter(|(rel, _)| rel == &head_rel)
1653 .map(|(_, vals)| vals)
1654 .collect();
1655
1656 let buf =
1658 self.executor.store().get(&head_rel).ok_or_else(|| {
1659 PyValueError::new_err(format!("Relation '{}' not found", head_rel))
1660 })?;
1661
1662 let num_rows =
1663 read_device_row_count(&self.provider, buf).map_err(types::xlog_err)? as usize;
1664 if num_rows == 0 {
1665 return Ok(Vec::new());
1666 }
1667
1668 let schema = buf.schema();
1669 let mut columns: Vec<Vec<i64>> = Vec::new();
1670 for col_idx in 0..buf.arity() {
1671 let col_type = schema.column_type(col_idx).ok_or_else(|| {
1672 PyRuntimeError::new_err(format!("Column {} type not found in schema", col_idx))
1673 })?;
1674 let col_i64: Vec<i64> = match col_type {
1675 ScalarType::I64 => self
1676 .provider
1677 .download_column::<i64>(buf, col_idx)
1678 .map_err(types::xlog_err)?,
1679 ScalarType::I32 => self
1680 .provider
1681 .download_column::<i32>(buf, col_idx)
1682 .map_err(types::xlog_err)?
1683 .into_iter()
1684 .map(|v| v as i64)
1685 .collect(),
1686 ScalarType::U32 | ScalarType::Symbol => self
1687 .provider
1688 .download_column::<u32>(buf, col_idx)
1689 .map_err(types::xlog_err)?
1690 .into_iter()
1691 .map(|v| v as i64)
1692 .collect(),
1693 ScalarType::U64 => self
1694 .provider
1695 .download_column::<u64>(buf, col_idx)
1696 .map_err(types::xlog_err)?
1697 .into_iter()
1698 .map(|v| v as i64)
1699 .collect(),
1700 ScalarType::Bool => self
1701 .provider
1702 .download_column::<bool>(buf, col_idx)
1703 .map_err(types::xlog_err)?
1704 .into_iter()
1705 .map(|v| if v { 1i64 } else { 0i64 })
1706 .collect(),
1707 ScalarType::F32 | ScalarType::F64 => {
1708 return Err(PyRuntimeError::new_err(format!(
1709 "sample_false_positives does not support float column type {:?}",
1710 col_type
1711 )));
1712 }
1713 };
1714 columns.push(col_i64);
1715 }
1716
1717 let mut result = Vec::with_capacity(max_n.min(num_rows));
1719 for r in 0..num_rows {
1720 if result.len() >= max_n {
1721 break;
1722 }
1723 let mut row = Vec::with_capacity(buf.arity());
1724 for c in 0..buf.arity() {
1725 row.push(columns[c][r]);
1726 }
1727 if !exclude_set.contains(&row) {
1728 result.push(row);
1729 }
1730 }
1731 Ok(result)
1732 }
1733
1734 pub fn tagged_entries_containing_fact(
1735 &self,
1736 relation: &str,
1737 values: Vec<i64>,
1738 ) -> PyResult<Vec<(u32, u32, u32)>> {
1739 self.ensure_host_semantic_compat(
1740 "tagged_entries_containing_fact()",
1741 "batch_tagged_credit_device(...)",
1742 )?;
1743 let k_idx = self
1744 .rel_index
1745 .iter()
1746 .position(|(_, name)| name == relation)
1747 .ok_or_else(|| {
1748 PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
1749 })? as u32;
1750
1751 let tagged = match self.executor.ilp_last_result() {
1752 Some(t) => t,
1753 None => return Ok(Vec::new()),
1754 };
1755
1756 let mut result = Vec::new();
1757 for entry in &tagged.entries {
1758 if entry.k != k_idx || entry.num_rows == 0 {
1759 continue;
1760 }
1761
1762 let (_, left_name) = &self.rel_index[entry.i as usize];
1763 let (_, right_name) = &self.rel_index[entry.j as usize];
1764
1765 let left_buf = match self.executor.store().get(left_name) {
1766 Some(buf) if buf.arity() > 0 => buf,
1767 _ => continue,
1768 };
1769 let right_buf = match self.executor.store().get(right_name) {
1770 Some(buf) if buf.arity() > 0 => buf,
1771 _ => continue,
1772 };
1773
1774 let left_max = self.left_keys.iter().copied().max().unwrap_or(0);
1776 let right_max = self.right_keys.iter().copied().max().unwrap_or(0);
1777 if left_buf.arity() <= left_max || right_buf.arity() <= right_max {
1778 continue;
1779 }
1780
1781 let joined = self
1782 .provider
1783 .hash_join_v2(
1784 left_buf,
1785 right_buf,
1786 &self.left_keys,
1787 &self.right_keys,
1788 JoinType::Inner,
1789 )
1790 .map_err(types::xlog_err)?;
1791
1792 let found = if !self.head_projection.is_empty()
1795 && self.head_projection.len() == values.len()
1796 {
1797 Self::fact_exists_projected(&self.provider, &joined, &values, &self.head_projection)
1798 .map_err(types::xlog_err)?
1799 } else {
1800 Self::fact_exists_in_buffer(&self.provider, &joined, &values)
1801 .map_err(types::xlog_err)?
1802 };
1803
1804 if found {
1805 result.push((entry.i, entry.j, entry.k));
1806 }
1807 }
1808 Ok(result)
1809 }
1810
    /// Returns the number of entries in `rel_index`.
    pub fn ilp_schema_size(&self) -> usize {
        self.rel_index.len()
    }
1814
1815 pub fn ilp_relation_names(&self) -> Vec<String> {
1816 self.rel_index
1817 .iter()
1818 .map(|(_, name)| name.clone())
1819 .collect()
1820 }
1821
1822 pub fn relation_type_annotations(&self) -> Vec<(String, Vec<String>)> {
1828 self.ast
1829 .predicates
1830 .iter()
1831 .map(|pred| {
1832 let types = pred.types.iter().map(type_ref_name).collect();
1833 (pred.name.clone(), types)
1834 })
1835 .collect()
1836 }
1837
1838 #[pyo3(signature = (mask_name, allow_recursive=false))]
1850 fn valid_candidates(
1851 &self,
1852 py: Python<'_>,
1853 mask_name: String,
1854 allow_recursive: bool,
1855 ) -> PyResult<Vec<StdHashMap<String, PyObject>>> {
1856 let candidates = self.candidate_triples_for_mask(&mask_name, allow_recursive)?;
1857
1858 let result: Vec<StdHashMap<String, PyObject>> = candidates
1859 .iter()
1860 .enumerate()
1861 .map(
1862 |(id, &(i, j, k))| -> PyResult<StdHashMap<String, PyObject>> {
1863 let mut d = StdHashMap::new();
1864 d.insert("id".into(), id.into_pyobject(py)?.into_any().unbind());
1865 d.insert("i".into(), i.into_pyobject(py)?.into_any().unbind());
1866 d.insert("j".into(), j.into_pyobject(py)?.into_any().unbind());
1867 d.insert("k".into(), k.into_pyobject(py)?.into_any().unbind());
1868 d.insert(
1869 "left_name".into(),
1870 self.rel_index[i as usize]
1871 .1
1872 .clone()
1873 .into_pyobject(py)?
1874 .into_any()
1875 .unbind(),
1876 );
1877 d.insert(
1878 "right_name".into(),
1879 self.rel_index[j as usize]
1880 .1
1881 .clone()
1882 .into_pyobject(py)?
1883 .into_any()
1884 .unbind(),
1885 );
1886 d.insert(
1887 "head_name".into(),
1888 self.rel_index[k as usize]
1889 .1
1890 .clone()
1891 .into_pyobject(py)?
1892 .into_any()
1893 .unbind(),
1894 );
1895 Ok(d)
1896 },
1897 )
1898 .collect::<PyResult<_>>()?;
1899
1900 Ok(result)
1901 }
1902
    /// Appends `rule_source` to the base program, recompiles it, re-runs it on a
    /// freshly reset executor, and adopts the new program as this object's state.
    ///
    /// Parsing and compilation happen before the executor is touched, so a
    /// syntax or compile error leaves the object unchanged.
    ///
    /// # Errors
    /// Parse errors are converted with `types::val_err`; compile, buffer, fact
    /// loading, override, and execution errors with `types::xlog_err`.
    // NOTE(review): the executor is reset and repopulated before the `self.*`
    // fields are replaced, so a failure after `reset_for_mc()` appears to leave
    // the executor populated for the new program while `plan`/`ast`/`schemas`
    // still describe the old one — confirm whether callers rely on atomicity.
    pub fn commit_induced_rule(&mut self, rule_source: &str) -> PyResult<()> {
        // New program text = existing base source + newline + induced rule.
        let new_base = format!("{}\n{}", self.base_source, rule_source);

        let ast = xlog_logic::parse_program(&new_base).map_err(types::val_err)?;
        let mut compiler = xlog_logic::Compiler::new();
        compiler.set_max_active_rules(self.max_active_rules);
        let plan = compiler.compile_program(&ast).map_err(types::xlog_err)?;
        let schemas = compiler.schemas().clone();

        // From here on the executor is mutated in place.
        self.executor.reset_for_mc();
        for (name, rel_id) in compiler.rel_ids() {
            self.executor.register_relation(*rel_id, name);
        }
        // Every relation of the new program starts out empty.
        for (name, schema) in &schemas {
            let empty = self
                .provider
                .create_empty_buffer(schema.clone())
                .map_err(types::xlog_err)?;
            self.executor.store_mut().put(name, empty);
        }
        load_facts_into_store(&ast, &self.provider, &mut self.executor, &schemas)
            .map_err(types::xlog_err)?;
        self.apply_relation_overrides().map_err(types::xlog_err)?;
        self.executor.execute_plan(&plan).map_err(types::xlog_err)?;

        // Execution succeeded: adopt the new program and its derived metadata.
        // NOTE(review): `rel_index` is not refreshed here — confirm that is intended.
        self.base_source = new_base;
        self.ast = ast;
        let tmj = extract_tmj_meta(&plan);
        self.left_keys = tmj.left_keys;
        self.right_keys = tmj.right_keys;
        self.head_projection = tmj.head_projection;
        self.compiled_schema_size = tmj.schema_size;
        self.head_rel_name = tmj.head_rel_name;
        self.plan = plan;
        self.schemas = schemas;
        Ok(())
    }
1940
1941 pub fn batch_fact_membership_device(
1946 &self,
1947 py: Python<'_>,
1948 relation: &str,
1949 facts: Vec<Vec<i64>>,
1950 ) -> PyResult<PyObject> {
1951 let buf =
1952 self.executor.store().get(relation).ok_or_else(|| {
1953 PyValueError::new_err(format!("Relation '{}' not found", relation))
1954 })?;
1955
1956 if facts.is_empty() {
1957 let empty = self
1958 .provider
1959 .memory()
1960 .alloc::<u8>(0)
1961 .map_err(types::xlog_err)?;
1962 return export_device_bool_tensor(&self.provider, py, empty, 0);
1963 }
1964 if buf.arity() == 0 {
1965 let mut zeros = self
1966 .provider
1967 .memory()
1968 .alloc::<u8>(facts.len())
1969 .map_err(types::xlog_err)?;
1970 self.provider
1971 .device()
1972 .inner()
1973 .memset_zeros(&mut zeros)
1974 .map_err(types::xlog_err)?;
1975 return export_device_bool_tensor(&self.provider, py, zeros, facts.len());
1976 }
1977
1978 let col_bytes = pack_i64_columns_typed(relation, &facts, buf.schema())?;
1979 let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
1980 let query_buf = self
1981 .provider
1982 .create_buffer_from_slices(&col_slices, buf.schema().clone())
1983 .map_err(types::xlog_err)?;
1984
1985 let keys: Vec<usize> = (0..buf.arity()).collect();
1986 let mask = self
1987 .provider
1988 .membership_mask_device(&query_buf, buf, &keys, &keys)
1989 .map_err(types::xlog_err)?;
1990 export_device_bool_tensor(&self.provider, py, mask, facts.len())
1991 }
1992
1993 pub fn batch_fact_membership(
1994 &self,
1995 relation: &str,
1996 facts: Vec<Vec<i64>>,
1997 ) -> PyResult<Vec<bool>> {
1998 self.ensure_host_semantic_compat(
1999 "batch_fact_membership()",
2000 "batch_fact_membership_device(...)",
2001 )?;
2002 if facts.is_empty() {
2003 return Ok(Vec::new());
2004 }
2005
2006 let buf =
2007 self.executor.store().get(relation).ok_or_else(|| {
2008 PyValueError::new_err(format!("Relation '{}' not found", relation))
2009 })?;
2010
2011 let arity = buf.arity();
2012 if arity == 0 {
2013 return Ok(vec![false; facts.len()]);
2014 }
2015
2016 let col_bytes = pack_i64_columns_typed(relation, &facts, buf.schema())?;
2018 let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
2019 let query_buf = self
2020 .provider
2021 .create_buffer_from_slices(&col_slices, buf.schema().clone())
2022 .map_err(types::xlog_err)?;
2023
2024 let keys: Vec<usize> = (0..arity).collect();
2026
2027 self.provider
2028 .membership_mask(&query_buf, buf, &keys, &keys)
2029 .map_err(types::xlog_err)
2030 }
2031
2032 pub fn batch_tagged_credit(
2038 &self,
2039 relation: &str,
2040 facts: Vec<Vec<i64>>,
2041 ) -> PyResult<Vec<Vec<(u32, u32, u32)>>> {
2042 self.ensure_host_semantic_compat(
2043 "batch_tagged_credit()",
2044 "batch_tagged_credit_device(...)",
2045 )?;
2046 if facts.is_empty() {
2047 return Ok(Vec::new());
2048 }
2049
2050 let k_idx = self
2052 .rel_index
2053 .iter()
2054 .position(|(_, name)| name == relation)
2055 .ok_or_else(|| {
2056 PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
2057 })? as u32;
2058
2059 let tagged = match self.executor.ilp_last_result() {
2060 Some(t) => t,
2061 None => return Ok(vec![Vec::new(); facts.len()]),
2062 };
2063
2064 let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
2066 .entries
2067 .iter()
2068 .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
2069 .collect();
2070
2071 if relevant_entries.is_empty() {
2072 return Ok(vec![Vec::new(); facts.len()]);
2073 }
2074
2075 let first_buf = relevant_entries[0]
2077 .buffer
2078 .as_ref()
2079 .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
2080 let arity = first_buf.arity();
2081 if arity == 0 {
2082 return Ok(vec![Vec::new(); facts.len()]);
2083 }
2084
2085 let schema = first_buf.schema().clone();
2087 let col_bytes = pack_i64_columns_typed(relation, &facts, &schema)?;
2088 let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
2089 let query_buf = self
2090 .provider
2091 .create_buffer_from_slices(&col_slices, schema)
2092 .map_err(types::xlog_err)?;
2093
2094 let keys: Vec<usize> = (0..arity).collect();
2095
2096 let mut per_fact_credits: Vec<Vec<(u32, u32, u32)>> = vec![Vec::new(); facts.len()];
2098
2099 for entry in &relevant_entries {
2100 let entry_buf = entry
2101 .buffer
2102 .as_ref()
2103 .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
2104 let mask = self
2105 .provider
2106 .membership_mask(&query_buf, entry_buf, &keys, &keys)
2107 .map_err(types::xlog_err)?;
2108
2109 for (fact_idx, &found) in mask.iter().enumerate() {
2110 if found {
2111 per_fact_credits[fact_idx].push((entry.i, entry.j, entry.k));
2112 }
2113 }
2114 }
2115
2116 Ok(per_fact_credits)
2117 }
2118
    /// Counterpart of `batch_tagged_credit` that builds its result through the
    /// `ilp_gpu` COO/CSR helpers and returns it as exported tensors instead of
    /// host lists.
    ///
    /// The result describes, per fact, which relevant tagged entries contain it:
    /// * `fact_row_offsets` — row offsets from `sort_and_build_csr`, exported
    ///   with `facts.len() + 1` elements.
    /// * `entry_indices` — the candidate column of the COO data; the values
    ///   index into `entry_i` / `entry_j` / `entry_k`.
    /// * `entry_i`, `entry_j`, `entry_k` — the tag components of each relevant
    ///   entry, in the order the entries were filtered.
    ///
    /// When `facts` is empty, there is no last ILP result, no entry is relevant,
    /// or the first relevant buffer has arity 0, the result comes from
    /// `empty_tagged_credit_device_result`.
    ///
    /// # Errors
    /// `ValueError` when `relation` is not in the ILP schema or when the fact or
    /// entry counts (or their product) do not fit in `u32`; packing errors are
    /// propagated; GPU failures are converted with `types::gpu_err` /
    /// `types::xlog_err`.
    pub fn batch_tagged_credit_device(
        &self,
        py: Python<'_>,
        relation: &str,
        facts: Vec<Vec<i64>>,
    ) -> PyResult<IlpTaggedCreditDeviceResult> {
        if facts.is_empty() {
            return empty_tagged_credit_device_result(&self.provider, py, 0);
        }

        // Head index of `relation` within the ILP schema.
        let k_idx = self
            .rel_index
            .iter()
            .position(|(_, name)| name == relation)
            .ok_or_else(|| {
                PyValueError::new_err(format!("Relation '{}' not in ILP schema", relation))
            })? as u32;

        let tagged = match self.executor.ilp_last_result() {
            Some(t) => t,
            None => return empty_tagged_credit_device_result(&self.provider, py, facts.len()),
        };

        // Only entries for this head that have rows and carry a buffer can give credit.
        let relevant_entries: Vec<&xlog_runtime::ilp_registry::IlpTagEntry> = tagged
            .entries
            .iter()
            .filter(|e| e.k == k_idx && e.num_rows > 0 && e.buffer.is_some())
            .collect();

        if relevant_entries.is_empty() {
            return empty_tagged_credit_device_result(&self.provider, py, facts.len());
        }

        // The first relevant buffer supplies the schema used to pack the query facts.
        let first_buf = relevant_entries[0]
            .buffer
            .as_ref()
            .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
        let arity = first_buf.arity();
        if arity == 0 {
            return empty_tagged_credit_device_result(&self.provider, py, facts.len());
        }

        let schema = first_buf.schema().clone();
        let col_bytes = pack_i64_columns_typed(relation, &facts, &schema)?;
        let col_slices: Vec<&[u8]> = col_bytes.iter().map(|c| c.as_slice()).collect();
        let query_buf = self
            .provider
            .create_buffer_from_slices(&col_slices, schema)
            .map_err(types::xlog_err)?;

        // Membership is decided on the full tuple, i.e. every column is a key.
        let keys: Vec<usize> = (0..arity).collect();
        let num_facts = u32::try_from(facts.len())
            .map_err(|_| PyValueError::new_err("facts length exceeds u32::MAX"))?;
        let num_entries = u32::try_from(relevant_entries.len())
            .map_err(|_| PyValueError::new_err("entry count exceeds u32::MAX"))?;
        // Worst case: every fact is contained in every relevant entry.
        let upper_bound = num_facts
            .checked_mul(num_entries)
            .ok_or_else(|| PyValueError::new_err("credit upper bound overflow"))?;

        // Identity mapping 0..num_facts, uploaded once and shared by all tasks.
        let fact_indices_host: Vec<u32> = (0..num_facts).collect();
        let mut d_fact_indices = self
            .provider
            .memory()
            .alloc::<u32>(num_facts as usize)
            .map_err(|e| types::gpu_err("alloc fact_indices", e))?;
        self.provider
            .device()
            .inner()
            .htod_sync_copy_into(&fact_indices_host, &mut d_fact_indices)
            .map_err(|e| types::gpu_err("htod fact_indices", e))?;

        let mut entry_i_host = Vec::with_capacity(relevant_entries.len());
        let mut entry_j_host = Vec::with_capacity(relevant_entries.len());
        let mut entry_k_host = Vec::with_capacity(relevant_entries.len());
        let mut tasks = Vec::with_capacity(relevant_entries.len());

        // One task per relevant entry; `cidx` is the entry's position in
        // `relevant_entries`, which is also its index into the entry_* arrays.
        for (entry_idx, entry) in relevant_entries.iter().enumerate() {
            let entry_buf = entry
                .buffer
                .as_ref()
                .ok_or_else(|| PyRuntimeError::new_err("internal: filtered entry has no buffer"))?;
            let d_mask = self
                .provider
                .membership_mask_device(&query_buf, entry_buf, &keys, &keys)
                .map_err(|e| types::gpu_err("membership_mask", e))?;
            tasks.push(ilp_gpu::CooTask {
                d_mask,
                // All tasks share the single fact-index buffer passed below.
                fact_indices_idx: 0,
                cidx: entry_idx as u32,
                num_query: num_facts,
            });
            entry_i_host.push(entry.i);
            entry_j_host.push(entry.j);
            entry_k_host.push(entry.k);
        }

        let (mut d_coo_facts, mut d_coo_cands, actual_nnz) = ilp_gpu::build_coo_single(
            &self.provider,
            &tasks,
            &[d_fact_indices],
            num_facts,
            num_entries,
            upper_bound,
        )?;
        let d_row_offsets = ilp_gpu::sort_and_build_csr(
            &self.provider,
            &mut d_coo_facts,
            &mut d_coo_cands,
            actual_nnz,
            num_facts,
        )?;

        // Upload the tag components so the result can be decoded from the exported tensors.
        let mut d_entry_i = self
            .provider
            .memory()
            .alloc::<u32>(entry_i_host.len())
            .map_err(|e| types::gpu_err("alloc entry_i", e))?;
        let mut d_entry_j = self
            .provider
            .memory()
            .alloc::<u32>(entry_j_host.len())
            .map_err(|e| types::gpu_err("alloc entry_j", e))?;
        let mut d_entry_k = self
            .provider
            .memory()
            .alloc::<u32>(entry_k_host.len())
            .map_err(|e| types::gpu_err("alloc entry_k", e))?;
        self.provider
            .device()
            .inner()
            .htod_sync_copy_into(&entry_i_host, &mut d_entry_i)
            .map_err(|e| types::gpu_err("htod entry_i", e))?;
        self.provider
            .device()
            .inner()
            .htod_sync_copy_into(&entry_j_host, &mut d_entry_j)
            .map_err(|e| types::gpu_err("htod entry_j", e))?;
        self.provider
            .device()
            .inner()
            .htod_sync_copy_into(&entry_k_host, &mut d_entry_k)
            .map_err(|e| types::gpu_err("htod entry_k", e))?;

        Ok(IlpTaggedCreditDeviceResult {
            fact_row_offsets: export_device_u32_tensor_as_i32(
                &self.provider,
                py,
                d_row_offsets,
                facts.len() + 1,
            )?,
            // NOTE(review): exported with `upper_bound` elements rather than
            // `actual_nnz`; presumably consumers bound reads by `fact_row_offsets`
            // — confirm.
            entry_indices: export_device_u32_tensor_as_i32(
                &self.provider,
                py,
                d_coo_cands,
                upper_bound as usize,
            )?,
            entry_i: export_device_u32_tensor_as_i32(
                &self.provider,
                py,
                d_entry_i,
                entry_i_host.len(),
            )?,
            entry_j: export_device_u32_tensor_as_i32(
                &self.provider,
                py,
                d_entry_j,
                entry_j_host.len(),
            )?,
            entry_k: export_device_u32_tensor_as_i32(
                &self.provider,
                py,
                d_entry_k,
                entry_k_host.len(),
            )?,
        })
    }
2304
    /// Returns the provider's current D2H transfer count.
    pub fn d2h_transfer_count(&self) -> u64 {
        self.provider.d2h_transfer_count()
    }
2308
    /// Resets the provider's D2H transfer count.
    pub fn reset_d2h_transfer_count(&self) {
        self.provider.reset_d2h_transfer_count()
    }
2312
    /// Returns the provider's host-transfer statistics as a Python dict with the
    /// keys `dtoh_bytes`, `htod_bytes`, `dtoh_calls`, and `htod_calls`.
    ///
    /// # Errors
    /// Propagates any failure from inserting a value into the dict.
    pub fn host_transfer_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
        let stats = self.provider.host_transfer_stats();
        let dict = PyDict::new(py);
        dict.set_item("dtoh_bytes", stats.dtoh_bytes)?;
        dict.set_item("htod_bytes", stats.htod_bytes)?;
        dict.set_item("dtoh_calls", stats.dtoh_calls)?;
        dict.set_item("htod_calls", stats.htod_calls)?;
        Ok(dict.into())
    }
2322
    /// Resets the provider's host-transfer statistics.
    pub fn reset_host_transfer_stats(&self) {
        self.provider.reset_host_transfer_stats()
    }
2326}
2327
2328impl CompiledIlpProgram {
2333 fn collect_relation_example_groups(
2334 &self,
2335 relations_obj: &Bound<'_, PyAny>,
2336 err_msg: &str,
2337 ) -> PyResult<Vec<RelationExampleGroup>> {
2338 let dict = relations_obj
2339 .downcast::<PyDict>()
2340 .map_err(|_| PyValueError::new_err(err_msg.to_string()))?;
2341 let mut groups = Vec::with_capacity(dict.len());
2342 for (name_obj, columns_obj) in dict.iter() {
2343 let relation: String = name_obj.extract()?;
2344 let schema = self.schemas.get(&relation).ok_or_else(|| {
2345 PyValueError::new_err(format!(
2346 "Unknown relation {} (not present in compiled schemas)",
2347 relation
2348 ))
2349 })?;
2350 let tensors = collect_dlpack_columns(
2351 &columns_obj,
2352 &format!("Relation {} must be a sequence of DLPack columns", relation),
2353 )?;
2354 let query_buf = self
2355 .provider
2356 .from_dlpack_tensors_with_schema(schema.clone(), tensors)
2357 .map_err(types::xlog_err)?;
2358 let num_rows = u32::try_from(query_buf.num_rows())
2359 .map_err(|_| PyValueError::new_err("relation row count exceeds u32::MAX"))?;
2360 groups.push(RelationExampleGroup {
2361 relation,
2362 query_buf,
2363 num_rows,
2364 });
2365 }
2366 Ok(groups)
2367 }
2368
2369 fn apply_relation_overrides(&mut self) -> xlog_core::Result<()> {
2370 let relation_names: Vec<String> = self.relation_overrides.keys().cloned().collect();
2371 for name in relation_names {
2372 let stored = self
2373 .relation_overrides
2374 .get(&name)
2375 .expect("relation override disappeared during runtime reset");
2376 let live = self.provider.clone_buffer(stored)?;
2377 self.executor.put_relation(&name, live);
2378 }
2379 Ok(())
2380 }
2381
2382 fn evaluate_ilp_plan(&mut self, py: Python<'_>) -> PyResult<()> {
2383 let result: xlog_core::Result<xlog_cuda::CudaBuffer> = py.allow_threads(|| {
2384 self.executor.reset_for_mc();
2385 for (name, schema) in &self.schemas {
2386 let empty = self.provider.create_empty_buffer(schema.clone())?;
2387 self.executor.store_mut().put(name, empty);
2388 }
2389 load_facts_into_store(&self.ast, &self.provider, &mut self.executor, &self.schemas)?;
2390 self.apply_relation_overrides()?;
2391 self.executor.execute_plan(&self.plan)
2392 });
2393 result.map_err(types::xlog_err)?;
2394 Ok(())
2395 }
2396
2397 fn set_rule_mask_sparse_selected_device_impl(
2398 &mut self,
2399 name: String,
2400 selected_candidate_ids_dlpack: &Bound<'_, PyAny>,
2401 selected_soft_probs_dlpack: &Bound<'_, PyAny>,
2402 allow_recursive: bool,
2403 validate_ids: bool,
2404 ) -> PyResult<()> {
2405 let tmj = extract_tmj_meta_for_mask(&self.plan, Some(&name));
2406 let n = tmj.schema_size;
2407 if n == 0 {
2408 return Err(PyValueError::new_err(format!(
2409 "no learnable mask '{}' found",
2410 name
2411 )));
2412 }
2413 if self.compiled_schema_size > 0 && n != self.compiled_schema_size {
2414 return Err(PyValueError::new_err(format!(
2415 "schema_size mismatch for '{}': plan N={} compiled N={}",
2416 name, n, self.compiled_schema_size
2417 )));
2418 }
2419
2420 let _ = allow_recursive;
2421 let candidate_order = self.candidate_order.as_ref().ok_or_else(|| {
2422 PyRuntimeError::new_err(
2423 "candidate order not set — call set_candidate_map() before strict sparse mask updates",
2424 )
2425 })?;
2426 let expected_c = candidate_order.len();
2427
2428 let ids_dmt = dlpack_from_py(selected_candidate_ids_dlpack)?;
2429 let ids_buf = self
2430 .provider
2431 .from_dlpack_tensors(vec![ids_dmt])
2432 .map_err(types::xlog_err)?;
2433 let soft_dmt = dlpack_from_py(selected_soft_probs_dlpack)?;
2434 let soft_buf = self
2435 .provider
2436 .from_dlpack_tensors(vec![soft_dmt])
2437 .map_err(types::xlog_err)?;
2438
2439 let selected_len = usize::try_from(ids_buf.num_rows())
2440 .map_err(|_| PyValueError::new_err("selected candidate ids length overflow"))?;
2441 let soft_len = usize::try_from(soft_buf.num_rows())
2442 .map_err(|_| PyValueError::new_err("selected soft_probs length overflow"))?;
2443 if soft_len != selected_len {
2444 return Err(PyValueError::new_err(format!(
2445 "selected soft_probs length {} != selected candidate ids length {}",
2446 soft_len, selected_len
2447 )));
2448 }
2449
2450 if validate_ids {
2451 self.provider
2452 .validate_selected_ids(&ids_buf, expected_c)
2453 .map_err(types::val_err)?;
2454 }
2455
2456 let active_flags = self
2457 .provider
2458 .build_selected_id_mask(&ids_buf, expected_c)
2459 .map_err(types::xlog_err)?;
2460
2461 self.executor
2462 .ilp_registry_mut()
2463 .insert_selected_mask_device(
2464 name,
2465 n,
2466 candidate_order.clone(),
2467 active_flags,
2468 selected_len,
2469 );
2470 Ok(())
2471 }
2472
2473 fn ensure_host_semantic_compat(&self, api: &str, device_hint: &str) -> PyResult<()> {
2474 if self.strict_zero_dtoh {
2475 return Err(PyRuntimeError::new_err(format!(
2476 "strict_zero_dtoh forbids {}; use {} instead",
2477 api, device_hint
2478 )));
2479 }
2480 Ok(())
2481 }
2482
2483 fn build_zero_loss_grad(
2487 &self,
2488 py: Python<'_>,
2489 num_cands: u32,
2490 is_f64: bool,
2491 ) -> PyResult<(PyObject, PyObject)> {
2492 if is_f64 {
2493 build_zero_typed::<f64>(&self.provider, py, num_cands, ScalarType::F64)
2494 } else {
2495 build_zero_typed::<f32>(&self.provider, py, num_cands, ScalarType::F32)
2496 }
2497 }
2498
2499 fn build_loss_grad_empty_coo(
2503 &self,
2504 py: Python<'_>,
2505 is_positive_host: &[u8],
2506 num_facts: u32,
2507 num_cands: u32,
2508 is_f64: bool,
2509 ) -> PyResult<(PyObject, PyObject)> {
2510 let row_offsets = vec![0u32; (num_facts + 1) as usize];
2512 let mut d_row_offsets = self
2513 .provider
2514 .memory()
2515 .alloc::<u32>((num_facts + 1) as usize)
2516 .map_err(|e| types::gpu_err("alloc", e))?;
2517 self.provider
2518 .device()
2519 .inner()
2520 .htod_sync_copy_into(&row_offsets, &mut d_row_offsets)
2521 .map_err(|e| types::gpu_err("htod", e))?;
2522
2523 let d_col_indices = self
2525 .provider
2526 .memory()
2527 .alloc::<u32>(0)
2528 .map_err(|e| types::gpu_err("alloc", e))?;
2529
2530 let mut d_is_positive = self
2531 .provider
2532 .memory()
2533 .alloc::<u8>(num_facts as usize)
2534 .map_err(|e| types::gpu_err("alloc", e))?;
2535 self.provider
2536 .device()
2537 .inner()
2538 .htod_sync_copy_into(is_positive_host, &mut d_is_positive)
2539 .map_err(|e| types::gpu_err("htod", e))?;
2540
2541 let dummy_col: xlog_cuda::CudaColumn = if is_f64 {
2546 self.provider
2547 .memory()
2548 .alloc::<f64>(num_cands.max(1) as usize)
2549 .map_err(|e| types::gpu_err("alloc dummy", e))?
2550 .into_bytes()
2551 .into()
2552 } else {
2553 self.provider
2554 .memory()
2555 .alloc::<f32>(num_cands.max(1) as usize)
2556 .map_err(|e| types::gpu_err("alloc dummy", e))?
2557 .into_bytes()
2558 .into()
2559 };
2560
2561 ilp_gpu::forward_backward_reduce(
2562 &self.provider,
2563 py,
2564 &d_row_offsets,
2565 &d_col_indices,
2566 &dummy_col,
2567 &d_is_positive,
2568 num_facts,
2569 num_cands,
2570 is_f64,
2571 )
2572 }
2573
2574 fn candidate_triples_for_mask(
2577 &self,
2578 mask_name: &str,
2579 allow_recursive: bool,
2580 ) -> PyResult<Vec<(u32, u32, u32)>> {
2581 let tmj = extract_tmj_meta_for_mask(&self.plan, Some(mask_name));
2582 let n = tmj.schema_size;
2583 if n == 0 {
2584 return Err(PyValueError::new_err(format!(
2585 "no learnable mask '{}' found in compiled program",
2586 mask_name
2587 )));
2588 }
2589 let head_name = &tmj.head_rel_name;
2590 let k_head = self
2591 .rel_index
2592 .iter()
2593 .position(|(_, name)| name == head_name)
2594 .ok_or_else(|| {
2595 PyValueError::new_err(format!(
2596 "head relation '{}' not in rel_index for mask '{}'",
2597 head_name, mask_name
2598 ))
2599 })? as u32;
2600
2601 let has_tuples: Vec<bool> = self
2603 .rel_index
2604 .iter()
2605 .map(|(_, name)| {
2606 self.executor
2607 .store()
2608 .get(name)
2609 .map(|buf| buf.num_rows() > 0)
2610 .unwrap_or(false)
2611 })
2612 .collect();
2613
2614 let mut triples: Vec<(u32, u32, u32)> = Vec::new();
2615 for i in 0..n as u32 {
2616 for j in 0..n as u32 {
2617 let k = k_head;
2618
2619 if !has_tuples[i as usize] && !has_tuples[j as usize] {
2621 continue;
2622 }
2623
2624 if !allow_recursive && (i == k || j == k) && !has_tuples[k as usize] {
2627 continue;
2628 }
2629
2630 triples.push((i, j, k));
2631 }
2632 }
2633 triples.sort_by_key(|&(i, j, k)| (k, i, j));
2634 Ok(triples)
2635 }
2636
2637 fn expected_candidate_count(&self, mask_name: &str, allow_recursive: bool) -> PyResult<usize> {
2638 Ok(self
2639 .candidate_triples_for_mask(mask_name, allow_recursive)?
2640 .len())
2641 }
2642
2643 fn fact_exists_in_buffer(
2644 provider: &CudaKernelProvider,
2645 buf: &xlog_cuda::CudaBuffer,
2646 values: &[i64],
2647 ) -> xlog_core::Result<bool> {
2648 use xlog_core::XlogError;
2649 let num_rows = read_device_row_count(provider, buf)? as usize;
2650 if num_rows == 0 {
2651 return Ok(false);
2652 }
2653 if values.len() != buf.arity() {
2654 return Ok(false);
2655 }
2656
2657 let schema = buf.schema();
2658 let mut columns: Vec<Vec<i64>> = Vec::new();
2659 for col_idx in 0..buf.arity() {
2660 let col_type = schema.column_type(col_idx).ok_or_else(|| {
2661 XlogError::Kernel(format!("Column {} type not found in schema", col_idx))
2662 })?;
2663 let col_i64: Vec<i64> = match col_type {
2664 ScalarType::I64 => provider.download_column::<i64>(buf, col_idx)?,
2665 ScalarType::I32 => provider
2666 .download_column::<i32>(buf, col_idx)?
2667 .into_iter()
2668 .map(|v| v as i64)
2669 .collect(),
2670 ScalarType::U32 | ScalarType::Symbol => provider
2671 .download_column::<u32>(buf, col_idx)?
2672 .into_iter()
2673 .map(|v| v as i64)
2674 .collect(),
2675 ScalarType::U64 => {
2676 let col_u64 = provider.download_column::<u64>(buf, col_idx)?;
2677 col_u64.into_iter().map(|v| v as i64).collect()
2678 }
2679 ScalarType::Bool => provider
2680 .download_column::<bool>(buf, col_idx)?
2681 .into_iter()
2682 .map(|v| if v { 1i64 } else { 0i64 })
2683 .collect(),
2684 ScalarType::F32 | ScalarType::F64 => {
2685 return Err(XlogError::Kernel(format!(
2686 "fact_exists does not support float column type {:?}",
2687 col_type
2688 )));
2689 }
2690 };
2691 columns.push(col_i64);
2692 }
2693
2694 for row in 0..num_rows {
2695 let mut matches = true;
2696 for (col_idx, val) in values.iter().enumerate() {
2697 if columns[col_idx][row] != *val {
2698 matches = false;
2699 break;
2700 }
2701 }
2702 if matches {
2703 return Ok(true);
2704 }
2705 }
2706 Ok(false)
2707 }
2708
2709 fn fact_exists_projected(
2713 provider: &CudaKernelProvider,
2714 buf: &xlog_cuda::CudaBuffer,
2715 values: &[i64],
2716 projection: &[usize],
2717 ) -> xlog_core::Result<bool> {
2718 use xlog_core::XlogError;
2719 let num_rows = read_device_row_count(provider, buf)? as usize;
2720 if num_rows == 0 {
2721 return Ok(false);
2722 }
2723
2724 let schema = buf.schema();
2725 let mut columns: Vec<Vec<i64>> = Vec::new();
2726 for &col_idx in projection {
2727 if col_idx >= buf.arity() {
2728 return Ok(false);
2729 }
2730 let col_type = schema.column_type(col_idx).ok_or_else(|| {
2731 XlogError::Kernel(format!("Column {} type not found in schema", col_idx))
2732 })?;
2733 let col_i64: Vec<i64> = match col_type {
2734 ScalarType::I64 => provider.download_column::<i64>(buf, col_idx)?,
2735 ScalarType::I32 => provider
2736 .download_column::<i32>(buf, col_idx)?
2737 .into_iter()
2738 .map(|v| v as i64)
2739 .collect(),
2740 ScalarType::U32 | ScalarType::Symbol => provider
2741 .download_column::<u32>(buf, col_idx)?
2742 .into_iter()
2743 .map(|v| v as i64)
2744 .collect(),
2745 ScalarType::U64 => provider
2746 .download_column::<u64>(buf, col_idx)?
2747 .into_iter()
2748 .map(|v| v as i64)
2749 .collect(),
2750 ScalarType::Bool => provider
2751 .download_column::<bool>(buf, col_idx)?
2752 .into_iter()
2753 .map(|v| if v { 1i64 } else { 0i64 })
2754 .collect(),
2755 ScalarType::F32 | ScalarType::F64 => {
2756 return Err(XlogError::Kernel(format!(
2757 "fact_exists does not support float column type {:?}",
2758 col_type
2759 )));
2760 }
2761 };
2762 columns.push(col_i64);
2763 }
2764
2765 for row in 0..num_rows {
2766 let mut matches = true;
2767 for (i, val) in values.iter().enumerate() {
2768 if columns[i][row] != *val {
2769 matches = false;
2770 break;
2771 }
2772 }
2773 if matches {
2774 return Ok(true);
2775 }
2776 }
2777 Ok(false)
2778 }
2779}