1#![allow(missing_docs)] #![allow(
4 clippy::large_enum_variant,
5 clippy::needless_range_loop,
6 clippy::too_many_arguments,
7 clippy::type_complexity
8)]
9
10use std::collections::{HashMap, HashSet};
11use std::os::raw::{c_char, c_void};
12use std::sync::Arc;
13
14use pyo3::exceptions::{PyMemoryError, PyRuntimeError, PyValueError};
15use pyo3::prelude::*;
16use pyo3::types::{PyDict, PyList};
17
18use xlog_core::{MemoryBudget, Schema};
19use xlog_cuda::{
20 device_runtime::{
21 AsyncCudaResource, DeviceMemoryResource, GlobalDeviceBudget, StreamPool, XlogDeviceRuntime,
22 },
23 CudaBuffer, CudaDevice, CudaKernelProvider, DlpackManagedTensor, GpuMemoryManager,
24};
25#[cfg(feature = "arrow-device-import")]
26use xlog_cuda::{ArrowDeviceArray, ArrowDeviceArrayOwned};
27use xlog_gpu::logic as gpu_logic;
28use xlog_logic::ast::ProbEngine;
29use xlog_neural::{NetworkRegistry, TensorSourceRegistry};
30use xlog_prob::exact::GpuConfig;
31
32use xlog_core::RelId;
33use xlog_ir::ExecutionPlan;
34use xlog_logic::ast::Program as AstProgram;
35use xlog_runtime::{Executor, RelationStore};
36
37mod neural_registry;
38use neural_registry::NeuralPredicateRegistry;
39mod dlpack;
40mod ilp;
41mod ilp_exact;
42mod ilp_gpu;
43mod logic;
44mod neural;
45mod program;
46mod training;
47mod types;
48pub(crate) use program::{
49 CachedCircuit, CompiledProbProgram, HardFilter, InputSource, JoinPlan, NeuralGroup,
50 QuerySignature,
51};
52
/// PyCapsule name of a DLPack tensor that has not been consumed yet.
/// NUL-terminated so it can be passed straight to the CPython capsule API.
const DLPACK_CAPSULE_NAME: &[u8] = b"dltensor\0";
/// Name `dlpack_from_py` gives a DLPack capsule once it has taken over the tensor.
/// A capsule with this name no longer validates under `DLPACK_CAPSULE_NAME`, so the
/// capsule destructor in this file leaves it alone.
const USED_DLPACK_CAPSULE_NAME: &[u8] = b"used_dltensor\0";

/// PyCapsule name of an Arrow C device array that has not been consumed yet.
#[cfg(feature = "arrow-device-import")]
const ARROW_DEVICE_ARRAY_CAPSULE_NAME: &[u8] = b"arrow_device_array\0";
/// Name `arrow_device_from_py` gives an Arrow device array capsule after taking over
/// its array, so this file's capsule destructor skips it.
#[cfg(feature = "arrow-device-import")]
const USED_ARROW_DEVICE_ARRAY_CAPSULE_NAME: &[u8] = b"used_arrow_device_array\0";
60
61unsafe extern "C" fn dlpack_capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
62 if capsule.is_null() {
63 return;
64 }
65
66 let valid =
67 pyo3::ffi::PyCapsule_IsValid(capsule, DLPACK_CAPSULE_NAME.as_ptr() as *const c_char);
68 if valid == 0 {
69 return;
70 }
71
72 let ptr =
73 pyo3::ffi::PyCapsule_GetPointer(capsule, DLPACK_CAPSULE_NAME.as_ptr() as *const c_char);
74 if ptr.is_null() {
75 pyo3::ffi::PyErr_Clear();
76 return;
77 }
78
79 let managed = ptr as *mut xlog_cuda::DLManagedTensor;
80 drop(DlpackManagedTensor::from_raw(managed));
81}
82
83pub(crate) fn dlpack_capsule_from_tensor(
84 py: Python<'_>,
85 tensor: DlpackManagedTensor,
86) -> PyResult<PyObject> {
87 let raw = tensor.into_raw();
88 let ptr = raw as *mut c_void;
89 let capsule = unsafe {
91 pyo3::ffi::PyCapsule_New(
92 ptr,
93 DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
94 Some(dlpack_capsule_destructor),
95 )
96 };
97 if capsule.is_null() {
98 unsafe {
100 drop(DlpackManagedTensor::from_raw(raw));
101 }
102 return Err(PyRuntimeError::new_err("Failed to create DLPack capsule"));
103 }
104 let obj: Py<PyAny> = unsafe { Py::from_owned_ptr(py, capsule) };
106 Ok(obj)
107}
108
#[cfg(feature = "arrow-device-import")]
/// `PyCapsule` destructor installed by `arrow_device_capsule_from_device_array`.
///
/// A capsule still named `"arrow_device_array"` was never consumed. Its pointer is
/// reclaimed into an `ArrowDeviceArrayOwned` and dropped. A renamed (consumed) capsule
/// fails validation and is ignored, because another owner has taken the array.
unsafe extern "C" fn arrow_device_array_capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
    let live_name = ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char;

    if capsule.is_null() || pyo3::ffi::PyCapsule_IsValid(capsule, live_name) == 0 {
        return;
    }

    match pyo3::ffi::PyCapsule_GetPointer(capsule, live_name) {
        // GetPointer signals failure through a Python exception; swallow it so it
        // does not leak out of the destructor.
        payload if payload.is_null() => pyo3::ffi::PyErr_Clear(),
        // SAFETY: the capsule validated under "arrow_device_array", so `payload` is
        // the pointer stored when the capsule was created and is still owned here.
        payload => drop(ArrowDeviceArrayOwned::from_raw(
            payload as *mut ArrowDeviceArray,
        )),
    }
}
136
#[cfg(feature = "arrow-device-import")]
/// Wraps an owned Arrow C device array in a Python `PyCapsule` named
/// `"arrow_device_array"`.
///
/// Ownership of `device_array` moves into the capsule. The capsule's destructor
/// (`arrow_device_array_capsule_destructor`) drops it unless a consumer renamed the
/// capsule first.
///
/// # Errors
///
/// Returns `RuntimeError("Failed to create Arrow device array capsule")` if
/// `PyCapsule_New` fails. The exception `PyCapsule_New` raised is attached as the
/// error's `__cause__` instead of being left pending. The array is dropped before
/// returning, because ownership never reached a capsule.
pub(crate) fn arrow_device_capsule_from_device_array(
    py: Python<'_>,
    device_array: ArrowDeviceArrayOwned,
) -> PyResult<PyObject> {
    let raw = device_array.into_raw();
    let ptr = raw as *mut c_void;
    // SAFETY: `ptr` is the uniquely owned pointer just released by `into_raw`; the name
    // is a NUL-terminated static matching what the destructor validates.
    let capsule = unsafe {
        pyo3::ffi::PyCapsule_New(
            ptr,
            ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char,
            Some(arrow_device_array_capsule_destructor),
        )
    };
    if capsule.is_null() {
        // Clear the pending exception from PyCapsule_New before dropping the array
        // (its release callback may run arbitrary code) and before returning a
        // different error.
        let cause = pyo3::PyErr::fetch(py);
        // SAFETY: capsule creation failed, so ownership of `raw` never left this function.
        unsafe {
            drop(ArrowDeviceArrayOwned::from_raw(raw));
        }
        let err = PyRuntimeError::new_err("Failed to create Arrow device array capsule");
        err.set_cause(py, Some(cause));
        return Err(err);
    }
    // SAFETY: `capsule` is a non-null new reference returned by `PyCapsule_New`.
    let obj: Py<PyAny> = unsafe { Py::from_owned_ptr(py, capsule) };
    Ok(obj)
}
165
#[cfg(feature = "arrow-device-import")]
/// Takes ownership of the Arrow C device array held by an `"arrow_device_array"` capsule.
///
/// The capsule is renamed to `"used_arrow_device_array"` before ownership is taken. The
/// destructor installed by `arrow_device_capsule_from_device_array` then no longer
/// validates it and will not drop the array a second time.
///
/// NOTE(review): renaming the capsule is the DLPack consumption convention. The Arrow
/// PyCapsule interface instead expects consumers to move the struct out and null its
/// `release` callback. Capsules from other producers (e.g. pyarrow), whose destructors
/// look the pointer up by the original name, may not tolerate the rename. Confirm that
/// only capsules produced by this crate reach this function.
///
/// # Errors
/// - `ValueError` if `obj` is not an unconsumed `"arrow_device_array"` capsule.
/// - `RuntimeError` if the capsule pointer cannot be read or the capsule cannot be renamed.
pub(crate) fn arrow_device_from_py(obj: &Bound<'_, PyAny>) -> PyResult<ArrowDeviceArrayOwned> {
    let capsule = obj.as_ptr();
    let live_name = ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char;

    // SAFETY: `capsule` is borrowed from a live `Bound`; `live_name` is NUL-terminated.
    let is_live = unsafe { pyo3::ffi::PyCapsule_IsValid(capsule, live_name) } != 0;
    if !is_live {
        return Err(PyValueError::new_err(
            "Expected an Arrow device array capsule (arrow_device_array)",
        ));
    }

    // SAFETY: the capsule was just validated under `live_name`.
    let payload = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, live_name) };
    if payload.is_null() {
        return Err(PyRuntimeError::new_err(
            "Failed to get Arrow device array pointer",
        ));
    }

    // Mark the capsule consumed *before* taking ownership, so a failure here leaves
    // the capsule (and its destructor) responsible for the array.
    // SAFETY: `capsule` is a valid capsule; the new name is a NUL-terminated static.
    let renamed = unsafe {
        pyo3::ffi::PyCapsule_SetName(
            capsule,
            USED_ARROW_DEVICE_ARRAY_CAPSULE_NAME.as_ptr() as *const c_char,
        )
    } == 0;
    if !renamed {
        return Err(PyRuntimeError::new_err(
            "Failed to mark Arrow device array capsule as consumed",
        ));
    }

    // SAFETY: the capsule is now renamed, so its destructor will not free `payload`;
    // this call becomes the sole owner.
    let owned = unsafe { ArrowDeviceArrayOwned::from_raw(payload as *mut ArrowDeviceArray) };
    Ok(owned)
}
211
212pub(crate) fn provider_from_config(config: GpuConfig) -> xlog_core::Result<CudaKernelProvider> {
213 let device = Arc::new(CudaDevice::new(config.device_ordinal)?);
214 let stream_pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
215 let async_resource: Box<dyn DeviceMemoryResource + Send + Sync> =
216 Box::new(AsyncCudaResource::new(
217 Arc::clone(&device),
218 config.device_ordinal as u32,
219 Arc::clone(&stream_pool),
220 ));
221 let budget_limit = usize::try_from(config.memory_bytes).unwrap_or(usize::MAX);
222 let budgeted: Box<dyn DeviceMemoryResource + Send + Sync> =
223 Box::new(GlobalDeviceBudget::new(async_resource, budget_limit));
224 let runtime = Arc::new(XlogDeviceRuntime::with_resource(
225 Arc::clone(&device),
226 config.device_ordinal as u32,
227 stream_pool,
228 budgeted,
229 ));
230 let memory = Arc::new(GpuMemoryManager::with_runtime(
231 device.clone(),
232 MemoryBudget::with_limit(config.memory_bytes),
233 runtime,
234 ));
235 CudaKernelProvider::with_runtime(device, memory)
236}
237
238pub(crate) fn enforce_call_memory_limit(
239 provider: &Arc<CudaKernelProvider>,
240 memory_mb: Option<u64>,
241) -> PyResult<()> {
242 let Some(memory_mb) = memory_mb else {
243 return Ok(());
244 };
245 if memory_mb == 0 {
246 return Err(PyValueError::new_err("memory_mb must be > 0"));
247 }
248 let memory_limit_bytes = memory_mb.saturating_mul(1024 * 1024);
249 let allocated_bytes = provider.memory().allocated_bytes();
250 if allocated_bytes > memory_limit_bytes {
251 return Err(PyMemoryError::new_err(format!(
252 "per-call memory limit exceeded before evaluation: allocated_bytes={} memory_limit_bytes={}",
253 allocated_bytes, memory_limit_bytes
254 )));
255 }
256 Ok(())
257}
258
259pub(crate) fn provider_memory_stats(
260 py: Python<'_>,
261 provider: &Arc<CudaKernelProvider>,
262) -> PyResult<PyObject> {
263 let dict = PyDict::new(py);
264 let memory = provider.memory();
265 dict.set_item("allocated_bytes", memory.allocated_bytes())?;
266 dict.set_item("memory_limit_bytes", memory.budget().device_bytes)?;
267 dict.set_item("peak_memory_bytes", memory.allocated_bytes())?;
268 dict.set_item("status", "available")?;
269 Ok(dict.into())
270}
271
272#[allow(dead_code)]
273pub(crate) fn pack_rule_provenance(
274 py: Python<'_>,
275 entries: &[xlog_logic::RuleProvenance],
276) -> PyResult<PyObject> {
277 let list = PyList::empty(py);
278 for entry in entries {
279 let dict = PyDict::new(py);
280 dict.set_item("rule_id", &entry.rule_id)?;
281 dict.set_item("source_kind", entry.source_kind.as_str())?;
282 dict.set_item("head", &entry.head)?;
283 match &entry.source_span {
284 Some(source_span) => dict.set_item("source_span", source_span)?,
285 None => dict.set_item("source_span", py.None())?,
286 }
287 match &entry.generation_trace_hash {
288 Some(hash) => dict.set_item("generation_trace_hash", hash)?,
289 None => dict.set_item("generation_trace_hash", py.None())?,
290 }
291 dict.set_item("support_relation_ids", &entry.support_relation_ids)?;
292 dict.set_item(
293 "counterexample_relation_ids",
294 &entry.counterexample_relation_ids,
295 )?;
296 list.append(dict)?;
297 }
298 Ok(list.into())
299}
300
301#[allow(dead_code)]
302pub(crate) fn pack_query_proof_traces(
303 py: Python<'_>,
304 entries: &[xlog_logic::QueryProofTrace],
305) -> PyResult<PyObject> {
306 let list = PyList::empty(py);
307 for entry in entries {
308 let dict = PyDict::new(py);
309 dict.set_item("query_id", &entry.query_id)?;
310 dict.set_item("query", &entry.query)?;
311 dict.set_item("answer_relation", &entry.answer_relation)?;
312 dict.set_item("rule_ids", &entry.rule_ids)?;
313 dict.set_item("source_facts", &entry.source_facts)?;
314 dict.set_item("rejected_alternatives", &entry.rejected_alternatives)?;
315 list.append(dict)?;
316 }
317 Ok(list.into())
318}
319
320pub(crate) fn parse_prob_engine_override(s: &str) -> PyResult<ProbEngine> {
321 let v = s.trim().to_ascii_lowercase();
322 match v.as_str() {
323 "exact_ddnnf" | "exact" | "ddnnf" => Ok(ProbEngine::ExactDdnnf),
324 "mc" => Ok(ProbEngine::Mc),
325 other => Err(PyValueError::new_err(format!(
326 "Unknown prob_engine '{}'; expected 'exact_ddnnf' or 'mc'",
327 other
328 ))),
329 }
330}
331
332pub(crate) fn dlpack_from_py(obj: &Bound<'_, PyAny>) -> PyResult<DlpackManagedTensor> {
333 let py = obj.py();
334
335 let capsule_obj: Bound<'_, PyAny> = if unsafe {
337 pyo3::ffi::PyCapsule_IsValid(obj.as_ptr(), DLPACK_CAPSULE_NAME.as_ptr() as *const c_char)
338 } != 0
339 {
340 obj.clone()
341 } else if obj.hasattr("__dlpack__")? {
342 match obj.call_method0("__dlpack__") {
343 Ok(v) => v,
344 Err(err) => {
345 if err.is_instance_of::<pyo3::exceptions::PyTypeError>(py) {
346 obj.call_method1("__dlpack__", (py.None(),))?
347 } else {
348 return Err(err);
349 }
350 }
351 }
352 } else {
353 return Err(PyValueError::new_err(
354 "Expected a DLPack capsule or an object with __dlpack__",
355 ));
356 };
357
358 if unsafe {
360 pyo3::ffi::PyCapsule_IsValid(
361 capsule_obj.as_ptr(),
362 DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
363 )
364 } == 0
365 {
366 return Err(PyValueError::new_err("Invalid DLPack capsule"));
367 }
368
369 let ptr = unsafe {
371 pyo3::ffi::PyCapsule_GetPointer(
372 capsule_obj.as_ptr(),
373 DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
374 )
375 };
376 if ptr.is_null() {
377 return Err(PyRuntimeError::new_err("Failed to get DLPack pointer"));
378 }
379
380 let rc = unsafe {
382 pyo3::ffi::PyCapsule_SetName(
383 capsule_obj.as_ptr(),
384 USED_DLPACK_CAPSULE_NAME.as_ptr() as *const c_char,
385 )
386 };
387 if rc != 0 {
388 return Err(PyRuntimeError::new_err(
389 "Failed to mark DLPack capsule as consumed",
390 ));
391 }
392
393 Ok(unsafe { DlpackManagedTensor::from_raw(ptr as *mut xlog_cuda::DLManagedTensor) })
395}
396
397#[pyfunction]
398fn dlpack_is_cuda(obj: &Bound<'_, PyAny>) -> PyResult<bool> {
399 if unsafe {
402 pyo3::ffi::PyCapsule_IsValid(obj.as_ptr(), DLPACK_CAPSULE_NAME.as_ptr() as *const c_char)
403 } == 0
404 {
405 return Err(PyValueError::new_err(
406 "Expected a DLPack capsule (dltensor)",
407 ));
408 }
409
410 let ptr = unsafe {
412 pyo3::ffi::PyCapsule_GetPointer(obj.as_ptr(), DLPACK_CAPSULE_NAME.as_ptr() as *const c_char)
413 };
414 if ptr.is_null() {
415 return Err(PyRuntimeError::new_err("Failed to get DLPack pointer"));
416 }
417
418 let managed = unsafe { &*(ptr as *const xlog_cuda::DLManagedTensor) };
420 Ok(managed.dl_tensor.device.device_type == xlog_cuda::dlpack::K_DLCUDA)
421}
422
/// Python class `DifferentiableProofTraceMap`: a thin wrapper around
/// `xlog_logic::DifferentiableProofTraceMap`.
#[pyclass(name = "DifferentiableProofTraceMap")]
pub struct PyDifferentiableProofTraceMap {
    /// Wrapped Rust trace map; the `#[pymethods]` for this class delegate to it.
    inner: xlog_logic::DifferentiableProofTraceMap,
}
427
428fn pack_differentiable_proof_trace(
429 py: Python<'_>,
430 trace: &xlog_logic::ProofTrace,
431) -> PyResult<PyObject> {
432 let dict = PyDict::new(py);
433 dict.set_item("proof_id", trace.proof_id)?;
434 dict.set_item("answer_key", &trace.answer_key)?;
435 dict.set_item("clause_id", &trace.clause_id)?;
436 dict.set_item("support_atoms", &trace.support_atoms)?;
437 dict.set_item("weight", trace.weight)?;
438 dict.set_item("gradient", trace.gradient)?;
439 Ok(dict.into())
440}
441
442#[pymethods]
443impl PyDifferentiableProofTraceMap {
444 #[new]
445 fn new() -> Self {
446 Self {
447 inner: xlog_logic::DifferentiableProofTraceMap::new(),
448 }
449 }
450
451 fn insert(
452 &mut self,
453 answer_key: String,
454 clause_id: String,
455 support_atoms: Vec<String>,
456 initial_weight: f64,
457 ) -> PyResult<u64> {
458 if !initial_weight.is_finite() {
459 return Err(PyValueError::new_err(
460 "initial_weight must be a finite float",
461 ));
462 }
463 Ok(self.inner.insert(xlog_logic::ProofTraceSpec {
464 answer_key,
465 clause_id,
466 support_atoms,
467 initial_weight,
468 }))
469 }
470
471 fn trace(&self, py: Python<'_>, proof_id: u64) -> PyResult<Option<PyObject>> {
472 self.inner
473 .trace(proof_id)
474 .map(|trace| pack_differentiable_proof_trace(py, trace))
475 .transpose()
476 }
477
478 fn traces(&self, py: Python<'_>) -> PyResult<PyObject> {
479 let list = PyList::empty(py);
480 for trace in self.inner.traces() {
481 list.append(pack_differentiable_proof_trace(py, trace)?)?;
482 }
483 Ok(list.into())
484 }
485
486 fn accumulate_binary_logistic_gradients(
487 &mut self,
488 targets: Vec<(String, f64)>,
489 ) -> PyResult<f64> {
490 if targets.iter().any(|(_, target)| !target.is_finite()) {
491 return Err(PyValueError::new_err("targets must be finite floats"));
492 }
493 Ok(self.inner.accumulate_binary_logistic_gradients(&targets))
494 }
495
496 fn apply_gradients(&mut self, learning_rate: f64) -> PyResult<()> {
497 if !learning_rate.is_finite() || learning_rate < 0.0 {
498 return Err(PyValueError::new_err(
499 "learning_rate must be a finite non-negative float",
500 ));
501 }
502 self.inner.apply_gradients(learning_rate);
503 Ok(())
504 }
505}
506
/// Python class `Program`. This unit struct holds no state.
/// NOTE(review): its Python methods are presumably defined in another module of this
/// crate (e.g. `program`); confirm.
#[pyclass]
pub struct Program;
509
/// Python class `CompiledProgram`: a compiled probabilistic program together with
/// the registries, caches and counters kept alongside it.
#[pyclass]
pub struct CompiledProgram {
    /// The compiled probabilistic program (type re-exported from the `program` module).
    pub(crate) program: CompiledProbProgram,
    /// Shared CUDA kernel provider; presumably used for result tensors (TODO confirm).
    pub(crate) output_provider: Arc<CudaKernelProvider>,
    /// Registry of neural networks (`xlog_neural::NetworkRegistry`).
    pub(crate) network_registry: NetworkRegistry,
    /// Neural-predicate registry from this crate's `neural_registry` module.
    pub(crate) neural_registry: NeuralPredicateRegistry,
    /// Names of declared networks (presumably those declared by the program; confirm).
    pub(crate) declared_networks: HashSet<String>,
    /// Per-network boolean flag keyed by network name.
    /// NOTE(review): the meaning of the `bool` is not visible here; document it.
    pub(crate) declared_network_forms: HashMap<String, bool>,
    /// Registry of tensor sources (`xlog_neural::TensorSourceRegistry`).
    pub(crate) tensor_sources: TensorSourceRegistry,
    /// Optional domain source text (semantics not visible here; confirm).
    pub(crate) domain_source: Option<String>,
    /// Underscore-prefixed: presumably the original program source kept for reference.
    pub(crate) _source: String,
    /// Parsed program AST.
    pub(crate) ast: xlog_logic::ast::Program,
    /// GPU configuration (underscore-prefixed; presumably retained but not read).
    pub(crate) _gpu_config: GpuConfig,
    /// Selected probabilistic engine (underscore-prefixed).
    pub(crate) _prob_engine: ProbEngine,
    /// Query signatures keyed by a string (presumably the query text; confirm).
    pub(crate) query_signature_cache: HashMap<String, QuerySignature>,
    /// Compiled circuits keyed by a string cache key.
    pub(crate) circuit_cache: HashMap<String, CachedCircuit>,
    /// Counter of `circuit_cache` hits.
    pub(crate) circuit_cache_hits: usize,
    /// Counter of `circuit_cache` misses.
    pub(crate) circuit_cache_misses: usize,
    /// Counter of template compilations.
    pub(crate) template_compile_count: usize,
    /// Whether queries are batched (behavior lives in the code that reads this flag).
    pub(crate) batch_queries: bool,
    /// Profile of the most recent circuit compilation, if any.
    pub(crate) last_compile_profile: Option<xlog_prob::compilation::CircuitCompileProfile>,
}
552
/// Python class `LogicProgram`. This unit struct holds no state.
/// NOTE(review): its Python methods are presumably defined in the `logic` module; confirm.
#[pyclass]
pub struct LogicProgram;
555
/// Python class `CompiledLogicProgram`: a compiled GPU logic program paired with the
/// kernel provider it runs on.
#[pyclass]
pub struct CompiledLogicProgram {
    /// Compiled logic program (`xlog_gpu::logic::LogicProgram`).
    pub(crate) program: gpu_logic::LogicProgram,
    /// Shared CUDA kernel provider.
    pub(crate) provider: Arc<CudaKernelProvider>,
}
561
/// Python class `LogicRelationSession`: a logic program plus relation storage, optional
/// runtime state, registered change callbacks and per-relation generation counters.
#[pyclass]
pub struct LogicRelationSession {
    /// Compiled logic program.
    pub(crate) program: gpu_logic::LogicProgram,
    /// Shared CUDA kernel provider.
    pub(crate) provider: Arc<CudaKernelProvider>,
    /// Relation store for the session (presumably the input relations; confirm).
    pub(crate) relation_store: RelationStore,
    /// Optional second relation store (presumably evaluation output; confirm).
    pub(crate) evaluation_store: Option<RelationStore>,
    /// Optional session runtime from `xlog_gpu::logic`.
    pub(crate) session_runtime: Option<gpu_logic::LogicSessionRuntime>,
    /// Statistics from the most recent delta evaluation, if any.
    pub(crate) last_delta_stats: Option<LogicDeltaStats>,
    /// Registered relation-change callbacks.
    pub(crate) relation_callbacks: Vec<RelationChangeCallback>,
    /// Id to hand out to the next registered callback.
    pub(crate) next_relation_callback_id: u64,
    /// Generation counter per relation name.
    pub(crate) relation_generations: HashMap<String, u64>,
}
574
/// A Python callback registered on a `LogicRelationSession`, tagged with an id.
pub(crate) struct RelationChangeCallback {
    /// Identifier assigned at registration (presumably used to unregister; confirm).
    pub id: u64,
    /// The Python object to call (presumably a callable invoked on relation changes).
    pub callback: PyObject,
}
579
/// Statistics collected for one incremental (delta) evaluation of a logic session.
/// NOTE(review): field meanings below are inferred from names; confirm against the
/// code that fills this struct.
#[derive(Clone, Debug)]
pub(crate) struct LogicDeltaStats {
    /// Number of input deltas applied.
    pub input_delta_count: usize,
    /// Number of relations that changed.
    pub changed_relations: usize,
    /// Names of the relations that changed.
    pub changed_relation_names: Vec<String>,
    /// Inserted row count.
    pub insert_rows: u64,
    /// Deleted row count.
    pub delete_rows: u64,
    /// Whether any deletes were present.
    pub has_deletes: bool,
    /// Strongly connected components affected by the delta.
    pub affected_sccs: usize,
    /// Components fully recomputed.
    pub recomputed_sccs: usize,
    /// Components updated incrementally.
    pub incremental_sccs: usize,
    /// Insert rows after coalescing.
    pub coalesced_insert_rows: u64,
    /// Delete rows after coalescing.
    pub coalesced_delete_rows: u64,
    /// Rows whose insert and delete canceled out.
    pub canceled_rows: u64,
    /// Result of an equivalence check against full recomputation, when one was run.
    pub equivalent_to_full_recompute: Option<bool>,
    /// Telemetry from the delta planner.
    pub planner_telemetry: gpu_logic::DeltaPlannerTelemetry,
    /// Free-form debug trace lines.
    pub debug_trace: Vec<String>,
}
598
/// Python class `LogicQueryResult`: the answer to one logic query. All fields are
/// read-only attributes in Python.
#[pyclass]
pub struct LogicQueryResult {
    /// Name of the relation holding the answers.
    #[pyo3(get)]
    pub relation_name: String,
    /// Column names of the answer relation.
    #[pyo3(get)]
    pub columns: Vec<String>,
    /// One sort label per column (presumably parallel to `columns`; confirm).
    #[pyo3(get)]
    pub sort_labels: Vec<String>,
    /// Python tensor objects holding the answer data (layout not visible here).
    #[pyo3(get)]
    pub tensors: Vec<PyObject>,
    /// Number of answer rows.
    #[pyo3(get)]
    pub num_rows: usize,
    /// Truth value of the query (presumably `num_rows > 0` for boolean queries; confirm).
    #[pyo3(get)]
    pub is_true: bool,
}
614
/// Python class `LogicEvalResult`: the per-query results of one logic evaluation.
#[pyclass]
pub struct LogicEvalResult {
    /// One `LogicQueryResult` per evaluated query.
    #[pyo3(get)]
    pub queries: Vec<Py<LogicQueryResult>>,
}
620
/// Python class `IlpTaggedCreditDeviceResult`: tagged ILP credit data as Python objects.
/// NOTE(review): the fields look like a CSR-style layout (row offsets plus per-entry
/// arrays), probably device tensors. This is inferred from names only; confirm with the
/// producer in the `ilp` modules.
#[pyclass]
pub struct IlpTaggedCreditDeviceResult {
    /// Per-fact row offsets into the entry arrays.
    #[pyo3(get)]
    pub fact_row_offsets: PyObject,
    /// Entry indices.
    #[pyo3(get)]
    pub entry_indices: PyObject,
    /// First coordinate of each entry.
    #[pyo3(get)]
    pub entry_i: PyObject,
    /// Second coordinate of each entry.
    #[pyo3(get)]
    pub entry_j: PyObject,
    /// Third coordinate of each entry.
    #[pyo3(get)]
    pub entry_k: PyObject,
}
634
/// Python class `McDeviceEvalResult`: raw counts from a Monte-Carlo evaluation plus
/// sampling metadata and "resident no-host" diagnostics. All fields are read-only in Python.
/// NOTE(review): the counter semantics below are inferred from field names; confirm.
#[pyclass]
pub struct McDeviceEvalResult {
    /// Per-query counts (a Python object, presumably a tensor).
    #[pyo3(get)]
    pub query_counts: PyObject,
    /// Evidence count (a Python object, presumably a tensor).
    #[pyo3(get)]
    pub evidence_count: PyObject,
    /// Total number of samples drawn.
    #[pyo3(get)]
    pub total_samples: usize,
    /// Random seed used.
    #[pyo3(get)]
    pub seed: u64,
    /// Confidence level used for reporting.
    #[pyo3(get)]
    pub confidence: f64,
    /// Name of the semantics used for non-monotone programs.
    #[pyo3(get)]
    pub nonmonotone_semantics: String,
    /// Number of non-monotone SCCs.
    #[pyo3(get)]
    pub nonmonotone_sccs: usize,
    /// Number of non-monotone cycles.
    #[pyo3(get)]
    pub nonmonotone_cycles: usize,
    /// Number of times a non-monotone iteration limit was hit.
    #[pyo3(get)]
    pub nonmonotone_iteration_limit_hits: usize,
    /// Name of the sampling method used.
    #[pyo3(get)]
    pub sampling_method: String,
    /// Whether the run was certified as staying resident with no host round-trips.
    #[pyo3(get)]
    pub resident_no_host_certified: bool,
    /// Textual result of the resident/no-host policy check.
    #[pyo3(get)]
    pub resident_no_host_policy_result: String,
    /// Tracked device-to-host copy calls.
    #[pyo3(get)]
    pub resident_no_host_tracked_dtoh_calls: u64,
    /// Tracked host-to-device copy calls.
    #[pyo3(get)]
    pub resident_no_host_tracked_htod_calls: u64,
    /// Host-side loop iterations.
    #[pyo3(get)]
    pub resident_no_host_host_loop_iterations: u64,
    /// Host kernel launches issued per sample.
    #[pyo3(get)]
    pub resident_no_host_per_sample_host_launches: u64,
    /// Metadata reads not covered by tracking.
    #[pyo3(get)]
    pub resident_no_host_untracked_metadata_reads: u64,
    /// Engine kernel launches.
    #[pyo3(get)]
    pub resident_no_host_engine_launches: u64,
    /// Host-driven fixpoint iterations.
    #[pyo3(get)]
    pub resident_no_host_host_fixpoint_iterations: u64,
    /// Host allocations made per operator.
    #[pyo3(get)]
    pub resident_no_host_per_operator_host_allocations: u64,
}
680
/// Python class `EvalResult`: the result of a probabilistic evaluation.
/// The `Option` fields are `None` in Python when the producer does not set them
/// (presumably the sampling-related ones are only set for approximate runs; confirm).
#[pyclass]
pub struct EvalResult {
    /// Query atoms, one per result entry.
    #[pyo3(get)]
    pub atoms: Vec<String>,
    /// Probabilities (Python object, presumably a tensor).
    #[pyo3(get)]
    pub prob: PyObject,
    /// Log-probabilities (Python object, presumably a tensor).
    #[pyo3(get)]
    pub log_prob: PyObject,
    /// Number of variables.
    #[pyo3(get)]
    pub num_vars: usize,
    /// Gradients with respect to the "true" literal weights, if computed.
    #[pyo3(get)]
    pub grad_true: Option<Vec<PyObject>>,
    /// Gradients with respect to the "false" literal weights, if computed.
    #[pyo3(get)]
    pub grad_false: Option<Vec<PyObject>>,
    /// Whether the result is approximate.
    #[pyo3(get)]
    pub approx: bool,
    /// Standard error of the estimate, if available.
    #[pyo3(get)]
    pub stderr: Option<PyObject>,
    /// Lower confidence bound, if available.
    #[pyo3(get)]
    pub ci_low: Option<PyObject>,
    /// Upper confidence bound, if available.
    #[pyo3(get)]
    pub ci_high: Option<PyObject>,
    /// Number of samples drawn, if sampled.
    #[pyo3(get)]
    pub samples: Option<usize>,
    /// Number of samples consistent with the evidence, if sampled.
    #[pyo3(get)]
    pub evidence_samples: Option<usize>,
    /// Random seed, if sampled.
    #[pyo3(get)]
    pub seed: Option<u64>,
    /// Confidence level for the interval, if reported.
    #[pyo3(get)]
    pub confidence: Option<f64>,
    /// Semantics used for non-monotone programs, if reported.
    #[pyo3(get)]
    pub nonmonotone_semantics: Option<String>,
    /// Number of non-monotone SCCs, if reported.
    #[pyo3(get)]
    pub nonmonotone_sccs: Option<usize>,
    /// Number of non-monotone cycles, if reported.
    #[pyo3(get)]
    pub nonmonotone_cycles: Option<usize>,
    /// Number of non-monotone iteration-limit hits, if reported.
    #[pyo3(get)]
    pub nonmonotone_iteration_limit_hits: Option<usize>,
    /// Sampling method name, if sampled.
    #[pyo3(get)]
    pub sampling_method: Option<String>,
    /// Monte-Carlo engine name, if one was used.
    #[pyo3(get)]
    pub mc_engine: Option<String>,
}
727
/// Python class `EpochStats`: summary statistics for one training epoch.
#[pyclass]
#[derive(Clone)]
pub struct EpochStats {
    /// Average loss over the epoch.
    #[pyo3(get)]
    pub avg_loss: f64,
    /// Number of batches processed.
    #[pyo3(get)]
    pub num_batches: usize,
    /// Total number of queries processed.
    #[pyo3(get)]
    pub total_queries: usize,
}
746
/// Python class `TrainingHistory`: losses and timings recorded during training.
#[pyclass]
#[derive(Clone)]
pub struct TrainingHistory {
    /// Loss recorded per epoch.
    #[pyo3(get)]
    pub epoch_losses: Vec<f64>,
    /// Time recorded per epoch (units not visible here; presumably seconds, so confirm).
    #[pyo3(get)]
    pub epoch_times: Vec<f64>,
    /// Loss recorded per batch.
    #[pyo3(get)]
    pub batch_losses: Vec<f64>,
    /// Whether training stopped before the configured number of epochs.
    #[pyo3(get)]
    pub stopped_early: bool,
}
764
/// Python class `IlpProgramFactory`. This unit struct holds no state.
/// NOTE(review): its Python methods are presumably defined in the `ilp` modules; confirm.
#[pyclass]
pub struct IlpProgramFactory;
767
/// Python class `CompiledIlpProgram`: a compiled ILP (rule-learning) program with its
/// executor, plan, join metadata and candidate bookkeeping.
/// NOTE(review): field descriptions beyond their types are inferred from names; confirm.
#[pyclass]
pub struct CompiledIlpProgram {
    /// Source text of the fixed (non-learnable) part of the program.
    pub(crate) base_source: String,
    /// Source text of the learnable part (underscore-prefixed; presumably unread).
    pub(crate) _learnable_source: String,
    /// Parsed program AST.
    pub(crate) ast: AstProgram,
    /// Runtime executor.
    pub(crate) executor: Executor,
    /// Shared CUDA kernel provider.
    pub(crate) provider: Arc<CudaKernelProvider>,
    /// Execution plan.
    pub(crate) plan: ExecutionPlan,
    /// Pairs of relation id and relation name.
    pub(crate) rel_index: Vec<(RelId, String)>,
    /// Relation schemas keyed by relation name.
    pub(crate) schemas: HashMap<String, Schema>,
    /// Column indices used as keys on the left side of a join.
    pub(crate) left_keys: Vec<usize>,
    /// Column indices used as keys on the right side of a join.
    pub(crate) right_keys: Vec<usize>,
    /// Column indices projected into the head relation.
    pub(crate) head_projection: Vec<usize>,
    /// Size of the compiled schema.
    pub(crate) compiled_schema_size: usize,
    /// Name of the head relation.
    pub(crate) head_rel_name: String,
    /// Maximum number of simultaneously active rules.
    pub(crate) max_active_rules: usize,
    /// Optional map from a `(u32, u32, u32)` candidate triple to its `u32` index.
    pub(crate) candidate_map: Option<HashMap<(u32, u32, u32), u32>>,
    /// Optional ordered list of candidate triples.
    pub(crate) candidate_order: Option<Vec<(u32, u32, u32)>>,
    /// Device buffers that override relations, keyed by relation name.
    pub(crate) relation_overrides: HashMap<String, CudaBuffer>,
    /// Budget for COO chunking (units not visible here; confirm).
    pub(crate) coo_chunk_budget: u64,
    /// Whether to enforce zero device-to-host transfers.
    pub(crate) strict_zero_dtoh: bool,
}
795
/// Initializer for the native extension module, importable from Python as `_native`.
///
/// Sets `__version__` to this crate's `CARGO_PKG_VERSION`, then registers the Python
/// classes and free functions listed below. The Arrow device export/import functions
/// are registered only when the `arrow-device-import` feature is enabled.
#[pymodule]
#[pyo3(name = "_native")]
fn pyxlog(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    // Probabilistic and logic program classes and their result types.
    m.add_class::<Program>()?;
    m.add_class::<CompiledProgram>()?;
    m.add_class::<LogicProgram>()?;
    m.add_class::<CompiledLogicProgram>()?;
    m.add_class::<LogicRelationSession>()?;
    m.add_class::<LogicQueryResult>()?;
    m.add_class::<LogicEvalResult>()?;
    m.add_class::<McDeviceEvalResult>()?;
    m.add_class::<EvalResult>()?;
    m.add_class::<PyDifferentiableProofTraceMap>()?;
    // Training result types.
    m.add_class::<EpochStats>()?;
    m.add_class::<TrainingHistory>()?;
    // ILP classes.
    m.add_class::<IlpProgramFactory>()?;
    m.add_class::<CompiledIlpProgram>()?;
    m.add_class::<IlpTaggedCreditDeviceResult>()?;
    // Free functions.
    m.add_function(wrap_pyfunction!(training::train_model, m)?)?;
    m.add_function(wrap_pyfunction!(training::train_model_tensor, m)?)?;
    m.add_function(wrap_pyfunction!(dlpack::dlpack_roundtrip, m)?)?;
    m.add_function(wrap_pyfunction!(dlpack_is_cuda, m)?)?;
    #[cfg(feature = "arrow-device-import")]
    m.add_function(wrap_pyfunction!(dlpack::export_arrow_device, m)?)?;
    #[cfg(feature = "arrow-device-import")]
    m.add_function(wrap_pyfunction!(dlpack::import_arrow_device, m)?)?;
    Ok(())
}