1use cudarc::driver::DeviceSlice;
11use pyo3::exceptions::{PyRuntimeError, PyValueError};
12use pyo3::prelude::*;
13use pyo3::types::{PyDict, PyList};
14
15use xlog_core::{ScalarType, Schema};
16use xlog_logic::ast::Term;
17use xlog_prob::exact::ExactDdnnfProgram;
18#[cfg(feature = "host-io")]
19use xlog_prob::exact::{ExactResultWithGrads, QueryProbability};
20use xlog_prob::mc::{McEvalConfig, McProgram, McSamplingMethod};
21use xlog_prob::neural_fast_path::GpuWeightSlots;
22
23use super::neural_registry::NeuralPredicateInfo;
24use super::{
25 dlpack_capsule_from_tensor, enforce_call_memory_limit, provider_memory_stats, types,
26 CompiledProgram, EpochStats, EvalResult, McDeviceEvalResult, TrainingHistory,
27};
28
29pub(crate) struct CachedCircuit {
38 pub(crate) program: ExactDdnnfProgram,
40
41 pub(crate) slots: GpuWeightSlots,
43
44 pub(crate) target_domain: Vec<String>,
46}
47
/// Identifies where the input of a neural group is taken from.
///
/// NOTE(review): the per-variant descriptions are inferred from the variant
/// names; confirm them against the code that builds and consumes this enum.
#[derive(Clone, Debug)]
pub(crate) enum InputSource {
    /// Presumably the index of a query argument.
    QueryArg(usize),
    /// Presumably the index of an implicitly allocated slot.
    ImplicitSlot(usize),
    /// Presumably a row index into a domain.
    DomainRow(usize),
    /// Carries no payload; presumably a constant placeholder input.
    ConstDummy,
}
61
62#[derive(Debug, Clone)]
63pub(crate) struct NeuralGroup {
64 pub(crate) info: NeuralPredicateInfo,
65 pub(crate) input_source: InputSource,
66 pub(crate) ground_const: Option<Term>,
71 #[cfg(feature = "host-io")]
72 pub(crate) output_var: Option<String>,
73}
74
/// A relation name paired with a list of head positions.
///
/// NOTE(review): presumably a non-probabilistic relation used to filter
/// candidates, with `arg_head_positions` mapping its arguments to positions in
/// the query head — confirm against the signature builder.
#[derive(Clone, Debug)]
pub(crate) struct HardFilter {
    /// Name of the relation.
    pub(crate) relation: String,
    /// Head positions associated with the relation's arguments.
    pub(crate) arg_head_positions: Vec<usize>,
}
90
/// Relation names plus a head domain, carried optionally by
/// [`QuerySignature::Targeted`].
///
/// NOTE(review): the exact join semantics are not visible in this file.
#[derive(Clone, Debug)]
pub(crate) struct JoinPlan {
    /// Names of the relations taking part in the plan.
    pub(crate) relations: Vec<String>,
    /// Names forming the head domain.
    pub(crate) head_domain: Vec<String>,
}
107
108#[derive(Debug, Clone)]
109pub(crate) enum QuerySignature {
110 Boolean {
111 groups: Vec<NeuralGroup>,
112 hard_filters: Vec<HardFilter>,
113 },
114 Targeted {
115 target_position: usize,
116 groups: Vec<NeuralGroup>,
117 hard_filters: Vec<HardFilter>,
118 join: Option<JoinPlan>,
122 },
123}
124
125impl QuerySignature {
126 pub(crate) fn groups(&self) -> &[NeuralGroup] {
127 match self {
128 QuerySignature::Boolean { groups, .. } | QuerySignature::Targeted { groups, .. } => {
129 groups
130 }
131 }
132 }
133
134 pub(crate) fn hard_filters(&self) -> &[HardFilter] {
135 match self {
136 QuerySignature::Boolean { hard_filters, .. }
137 | QuerySignature::Targeted { hard_filters, .. } => hard_filters,
138 }
139 }
140
141 pub(crate) fn join(&self) -> Option<&JoinPlan> {
143 match self {
144 QuerySignature::Targeted { join, .. } => join.as_ref(),
145 QuerySignature::Boolean { .. } => None,
146 }
147 }
148}
149
150pub(crate) enum CompiledProbProgram {
151 Exact(ExactDdnnfProgram),
152 Mc(McProgram),
153}
154
155impl CompiledProbProgram {
156 #[cfg(feature = "host-io")]
157 pub(crate) fn num_vars(&self) -> usize {
158 match self {
159 Self::Exact(p) => p.num_vars(),
160 Self::Mc(p) => p.num_vars(),
161 }
162 }
163}
164
/// Renders a ground atom as `predicate(arg0, arg1, ...)`.
///
/// Arguments are separated by `", "`; an atom without arguments renders as
/// `predicate()`. Per-argument rendering:
/// - `I64`: decimal integer,
/// - `F64`: the float reconstructed from its stored bit pattern,
/// - `Symbol`: `sym#` followed by the symbol's display form,
/// - `String`: the string contents verbatim (no quoting or escaping).
#[cfg(feature = "host-io")]
fn atom_to_string(atom: &xlog_prob::provenance::GroundAtom) -> String {
    use xlog_prob::provenance::Value;

    let rendered_args: Vec<String> = atom
        .args
        .iter()
        .map(|arg| match arg {
            Value::I64(v) => v.to_string(),
            Value::F64(bits) => f64::from_bits(*bits).to_string(),
            Value::Symbol(sym) => format!("sym#{}", sym),
            Value::String(v) => v.to_string(),
        })
        .collect();

    // Joining an empty list yields "", which gives the `predicate()` form.
    format!("{}({})", atom.predicate, rendered_args.join(", "))
}
194
195impl CompiledProgram {
200 pub(crate) fn parse_sampling_method(s: Option<String>) -> PyResult<Option<McSamplingMethod>> {
201 match s.as_deref() {
202 None => Ok(None),
203 Some("rejection") => Ok(Some(McSamplingMethod::Rejection)),
204 Some("evidence_clamping") => Ok(Some(McSamplingMethod::EvidenceClamping)),
205 Some(other) => Err(PyValueError::new_err(format!(
206 "Unknown sampling_method '{}'. Use 'rejection' or 'evidence_clamping'.",
207 other
208 ))),
209 }
210 }
211
212 pub(crate) fn evaluate_query_probability(&self, query: &str) -> PyResult<f64> {
214 let probs = self.evaluate_query_probabilities(&[query.to_string()])?;
215 probs
216 .into_iter()
217 .next()
218 .ok_or_else(|| PyRuntimeError::new_err("Query evaluation returned no results"))
219 }
220
221 pub(crate) fn evaluate_query_probabilities(&self, queries: &[String]) -> PyResult<Vec<f64>> {
223 #[cfg(not(feature = "host-io"))]
224 {
225 let _ = queries;
226 return Err(types::host_io_disabled_pyerr());
227 }
228
229 #[cfg(feature = "host-io")]
230 {
231 let mut source_with_queries = self._source.clone();
233 for query in queries {
234 source_with_queries.push_str(&format!("\nquery({}).", query));
235 }
236
237 let result: Vec<QueryProbability> = match self._prob_engine {
239 xlog_logic::ast::ProbEngine::ExactDdnnf => {
240 let program = ExactDdnnfProgram::compile_source_with_gpu(
241 &source_with_queries,
242 self._gpu_config,
243 )
244 .map_err(|e| types::gpu_err("Query compilation error", e))?;
245
246 program
247 .evaluate()
248 .map_err(|e| types::gpu_err("Query evaluation error", e))?
249 .query_probs
250 }
251 xlog_logic::ast::ProbEngine::Mc => {
252 let program =
253 McProgram::compile_source_with_gpu(&source_with_queries, self._gpu_config)
254 .map_err(|e| types::gpu_err("Query compilation error", e))?;
255
256 let cfg = McEvalConfig::default();
257 program
258 .evaluate(cfg)
259 .map_err(|e| types::gpu_err("Query evaluation error", e))?
260 .query_estimates
261 .into_iter()
262 .map(|e| QueryProbability {
263 atom: e.atom,
264 prob: e.prob,
265 log_prob: e.log_prob,
266 })
267 .collect()
268 }
269 };
270
271 let probs: Vec<f64> = result.iter().map(|qp| qp.prob).collect();
274
275 if probs.len() != queries.len() {
276 return Err(PyRuntimeError::new_err(format!(
277 "Expected {} query results, got {}",
278 queries.len(),
279 probs.len()
280 )));
281 }
282
283 Ok(probs)
284 }
285 }
286
/// Packs exact query probabilities into an [`EvalResult`] without gradients.
///
/// Probabilities and log-probabilities are each copied into a single-column
/// `F64` buffer through `self.output_provider` and exposed as DLPack capsules.
/// Atoms are rendered with `atom_to_string`. All gradient and MC-specific
/// fields are `None`, and `approx` is `false`.
///
/// # Errors
///
/// Propagates buffer creation, column extraction and capsule creation errors.
#[cfg(feature = "host-io")]
fn pack_result_probs(
    &self,
    py: Python<'_>,
    query_probs: Vec<QueryProbability>,
) -> PyResult<EvalResult> {
    let query_count = query_probs.len();
    let mut atoms: Vec<String> = Vec::with_capacity(query_count);
    let mut probs: Vec<f64> = Vec::with_capacity(query_count);
    let mut log_probs: Vec<f64> = Vec::with_capacity(query_count);

    for QueryProbability {
        atom,
        prob,
        log_prob,
    } in query_probs
    {
        atoms.push(atom_to_string(&atom));
        probs.push(prob);
        log_probs.push(log_prob);
    }

    let column_schema = Schema::new(vec![("col0".to_string(), ScalarType::F64)]);
    // Copies a slice of f64 values into a one-column buffer.
    let upload = |values: &[f64]| {
        self.output_provider
            .create_buffer_from_slice::<f64>(values, column_schema.clone())
            .map_err(types::xlog_err)
    };
    // Extracts the single column of a buffer as a DLPack tensor.
    let first_column = |buffer| {
        self.output_provider
            .to_dlpack_table(buffer)
            .column(0)
            .map_err(types::xlog_err)
    };

    let prob_buf = upload(&probs)?;
    let log_prob_buf = upload(&log_probs)?;

    let prob_tensor = first_column(prob_buf)?;
    let log_prob_tensor = first_column(log_prob_buf)?;

    Ok(EvalResult {
        atoms,
        prob: dlpack_capsule_from_tensor(py, prob_tensor)?,
        log_prob: dlpack_capsule_from_tensor(py, log_prob_tensor)?,
        num_vars: self.program.num_vars(),
        grad_true: None,
        grad_false: None,
        approx: false,
        stderr: None,
        ci_low: None,
        ci_high: None,
        samples: None,
        evidence_samples: None,
        seed: None,
        confidence: None,
        nonmonotone_semantics: None,
        nonmonotone_sccs: None,
        nonmonotone_cycles: None,
        nonmonotone_iteration_limit_hits: None,
        sampling_method: None,
        mc_engine: None,
    })
}
347
/// Packs an exact result with gradients into an [`EvalResult`].
///
/// For every query, its `grad_true` and `grad_false` vectors are each copied
/// into their own single-column `F64` buffer and wrapped in a DLPack capsule.
/// Probabilities and log-probabilities are packed the same way, one buffer
/// each for the whole batch. MC-specific fields are `None` and `approx` is
/// `false`.
///
/// # Errors
///
/// Propagates buffer creation, column extraction and capsule creation errors.
#[cfg(feature = "host-io")]
fn pack_result_with_grads(
    &self,
    py: Python<'_>,
    result: ExactResultWithGrads,
) -> PyResult<EvalResult> {
    let query_count = result.query_grads.len();
    let mut atoms: Vec<String> = Vec::with_capacity(query_count);
    let mut probs: Vec<f64> = Vec::with_capacity(query_count);
    let mut log_probs: Vec<f64> = Vec::with_capacity(query_count);
    let mut grad_true_caps: Vec<PyObject> = Vec::with_capacity(query_count);
    let mut grad_false_caps: Vec<PyObject> = Vec::with_capacity(query_count);

    let column_schema = Schema::new(vec![("col0".to_string(), ScalarType::F64)]);
    // Copies a slice of f64 values into a one-column buffer.
    let upload = |values: &[f64]| {
        self.output_provider
            .create_buffer_from_slice::<f64>(values, column_schema.clone())
            .map_err(types::xlog_err)
    };
    // Extracts the single column of a buffer as a DLPack tensor.
    let first_column = |buffer| {
        self.output_provider
            .to_dlpack_table(buffer)
            .column(0)
            .map_err(types::xlog_err)
    };

    let num_vars = self.program.num_vars();
    for q in result.query_grads {
        atoms.push(atom_to_string(&q.atom));
        probs.push(q.prob);
        log_probs.push(q.log_prob);

        let grad_true_buf = upload(&q.grad_true)?;
        let grad_false_buf = upload(&q.grad_false)?;

        let grad_true_tensor = first_column(grad_true_buf)?;
        let grad_false_tensor = first_column(grad_false_buf)?;

        grad_true_caps.push(dlpack_capsule_from_tensor(py, grad_true_tensor)?);
        grad_false_caps.push(dlpack_capsule_from_tensor(py, grad_false_tensor)?);
    }

    let prob_buf = upload(&probs)?;
    let log_prob_buf = upload(&log_probs)?;

    let prob_tensor = first_column(prob_buf)?;
    let log_prob_tensor = first_column(log_prob_buf)?;

    Ok(EvalResult {
        atoms,
        prob: dlpack_capsule_from_tensor(py, prob_tensor)?,
        log_prob: dlpack_capsule_from_tensor(py, log_prob_tensor)?,
        num_vars,
        grad_true: Some(grad_true_caps),
        grad_false: Some(grad_false_caps),
        approx: false,
        stderr: None,
        ci_low: None,
        ci_high: None,
        samples: None,
        evidence_samples: None,
        seed: None,
        confidence: None,
        nonmonotone_semantics: None,
        nonmonotone_sccs: None,
        nonmonotone_cycles: None,
        nonmonotone_iteration_limit_hits: None,
        sampling_method: None,
        mc_engine: None,
    })
}
436
/// Packs a Monte Carlo result into an [`EvalResult`].
///
/// Per-query probability, log-probability, standard error and confidence
/// interval bounds are each copied into a single-column `F64` buffer and
/// exposed as DLPack capsules. Run metadata (sample counts, seed, confidence,
/// non-monotone statistics, sampling method and engine name) is copied from
/// `result`. Gradient fields are `None` and `approx` is `true`.
///
/// # Errors
///
/// Propagates buffer creation, column extraction and capsule creation errors.
#[cfg(feature = "host-io")]
fn pack_result_mc(
    &self,
    py: Python<'_>,
    result: xlog_prob::mc::McResult,
) -> PyResult<EvalResult> {
    let query_count = result.query_estimates.len();
    let mut atoms: Vec<String> = Vec::with_capacity(query_count);
    let mut probs: Vec<f64> = Vec::with_capacity(query_count);
    let mut log_probs: Vec<f64> = Vec::with_capacity(query_count);
    let mut stderrs: Vec<f64> = Vec::with_capacity(query_count);
    let mut ci_lows: Vec<f64> = Vec::with_capacity(query_count);
    let mut ci_highs: Vec<f64> = Vec::with_capacity(query_count);

    for estimate in &result.query_estimates {
        atoms.push(atom_to_string(&estimate.atom));
        probs.push(estimate.prob);
        log_probs.push(estimate.log_prob);
        stderrs.push(estimate.stderr);
        ci_lows.push(estimate.ci_low);
        ci_highs.push(estimate.ci_high);
    }

    let column_schema = Schema::new(vec![("col0".to_string(), ScalarType::F64)]);
    // Copies a slice of f64 values into a one-column buffer.
    let upload = |values: &[f64]| {
        self.output_provider
            .create_buffer_from_slice::<f64>(values, column_schema.clone())
            .map_err(types::xlog_err)
    };
    // Extracts the single column of a buffer as a DLPack tensor.
    let first_column = |buffer| {
        self.output_provider
            .to_dlpack_table(buffer)
            .column(0)
            .map_err(types::xlog_err)
    };

    let prob_buf = upload(&probs)?;
    let log_prob_buf = upload(&log_probs)?;
    let stderr_buf = upload(&stderrs)?;
    let ci_low_buf = upload(&ci_lows)?;
    let ci_high_buf = upload(&ci_highs)?;

    let prob_tensor = first_column(prob_buf)?;
    let log_prob_tensor = first_column(log_prob_buf)?;
    let stderr_tensor = first_column(stderr_buf)?;
    let ci_low_tensor = first_column(ci_low_buf)?;
    let ci_high_tensor = first_column(ci_high_buf)?;

    let sampling_method_name = match result.sampling_method {
        McSamplingMethod::Rejection => "rejection",
        McSamplingMethod::EvidenceClamping => "evidence_clamping",
    };

    Ok(EvalResult {
        atoms,
        prob: dlpack_capsule_from_tensor(py, prob_tensor)?,
        log_prob: dlpack_capsule_from_tensor(py, log_prob_tensor)?,
        num_vars: self.program.num_vars(),
        grad_true: None,
        grad_false: None,
        approx: true,
        stderr: Some(dlpack_capsule_from_tensor(py, stderr_tensor)?),
        ci_low: Some(dlpack_capsule_from_tensor(py, ci_low_tensor)?),
        ci_high: Some(dlpack_capsule_from_tensor(py, ci_high_tensor)?),
        samples: Some(result.total_samples),
        evidence_samples: Some(result.evidence_samples),
        seed: Some(result.seed),
        confidence: Some(result.confidence),
        nonmonotone_semantics: Some(xlog_prob::mc::NONMONOTONE_SEMANTICS.to_string()),
        nonmonotone_sccs: Some(result.nonmonotone_sccs),
        nonmonotone_cycles: Some(result.nonmonotone_cycles),
        nonmonotone_iteration_limit_hits: Some(result.nonmonotone_iteration_limit_hits),
        sampling_method: Some(sampling_method_name.to_string()),
        mc_engine: Some(result.engine.as_str().to_string()),
    })
}
533}
534
535#[pymethods]
540impl CompiledProgram {
541 #[pyo3(signature = (return_grads=false, samples=None, seed=None, confidence=0.95, max_nonmonotone_iterations=1024, sampling_method=None, memory_mb=None, allow_cpu_oracle=false))]
542 pub fn evaluate(
543 &self,
544 _py: Python<'_>,
545 return_grads: bool,
546 samples: Option<usize>,
547 seed: Option<u64>,
548 confidence: f64,
549 max_nonmonotone_iterations: usize,
550 sampling_method: Option<String>,
551 memory_mb: Option<u64>,
552 allow_cpu_oracle: bool,
553 ) -> PyResult<EvalResult> {
554 enforce_call_memory_limit(&self.output_provider, memory_mb)?;
555 match &self.program {
556 CompiledProbProgram::Exact(_program) => {
557 if samples.is_some() || seed.is_some() {
558 return Err(PyValueError::new_err(
559 "samples/seed are only supported for prob_engine='mc'",
560 ));
561 }
562 #[cfg(feature = "host-io")]
563 {
564 if return_grads {
565 let result = _program
566 .evaluate_gpu_with_grads()
567 .map_err(types::xlog_err)?;
568 self.pack_result_with_grads(_py, result)
569 } else {
570 let result = _program.evaluate().map_err(types::xlog_err)?;
571 self.pack_result_probs(_py, result.query_probs)
572 }
573 }
574 #[cfg(not(feature = "host-io"))]
575 {
576 let _ = return_grads;
577 Err(types::host_io_disabled_pyerr())
578 }
579 }
580 CompiledProbProgram::Mc(_program) => {
581 if return_grads {
582 return Err(PyValueError::new_err(
583 "MC inference does not support gradients (return_grads must be false)",
584 ));
585 }
586
587 let mut cfg = McEvalConfig::default();
588 cfg.samples = samples.unwrap_or(10000);
589 cfg.seed = seed.unwrap_or(0);
590 cfg.confidence = confidence;
591 cfg.max_nonmonotone_iterations = max_nonmonotone_iterations;
592 cfg.sampling_method = Self::parse_sampling_method(sampling_method)?;
593 cfg.allow_cpu_oracle_fallback = allow_cpu_oracle;
597 #[cfg(feature = "host-io")]
598 {
599 let result = _program.evaluate(cfg).map_err(types::xlog_err)?;
600 self.pack_result_mc(_py, result)
601 }
602 #[cfg(not(feature = "host-io"))]
603 {
604 let _ = cfg;
605 Err(types::host_io_disabled_pyerr())
606 }
607 }
608 }
609 }
610
611 #[pyo3(signature = (samples=None, seed=None, confidence=0.95, max_nonmonotone_iterations=1024, sampling_method=None, memory_mb=None))]
616 pub fn evaluate_device(
617 &self,
618 py: Python<'_>,
619 samples: Option<usize>,
620 seed: Option<u64>,
621 confidence: f64,
622 max_nonmonotone_iterations: usize,
623 sampling_method: Option<String>,
624 memory_mb: Option<u64>,
625 ) -> PyResult<McDeviceEvalResult> {
626 enforce_call_memory_limit(&self.output_provider, memory_mb)?;
627 let (
628 query_counts,
629 evidence_count,
630 total_samples,
631 seed,
632 confidence,
633 nonmonotone_sccs,
634 nonmonotone_cycles,
635 nonmonotone_iteration_limit_hits,
636 sampling_method_val,
637 no_host,
638 ) = match &self.program {
639 CompiledProbProgram::Mc(program) => {
640 let mut cfg = McEvalConfig::default();
641 cfg.samples = samples.unwrap_or(10000);
642 cfg.seed = seed.unwrap_or(0);
643 cfg.confidence = confidence;
644 cfg.max_nonmonotone_iterations = max_nonmonotone_iterations;
645 cfg.sampling_method = Self::parse_sampling_method(sampling_method)?;
646
647 let result = program
648 .evaluate_gpu_device_with_provider(cfg, self.output_provider.clone())
649 .map_err(types::xlog_err)?;
650
651 (
652 result.query_counts,
653 result.evidence_count,
654 result.total_samples,
655 result.seed,
656 result.confidence,
657 result.nonmonotone_sccs,
658 result.nonmonotone_cycles,
659 result.nonmonotone_iteration_limit_hits,
660 result.sampling_method,
661 result.no_host,
662 )
663 }
664 _ => {
665 return Err(PyValueError::new_err(
666 "evaluate_device is only supported for prob_engine='mc'",
667 ))
668 }
669 };
670
671 let schema_i32 = Schema::new(vec![("col0".to_string(), ScalarType::I32)]);
674
675 let make_count_tensor =
676 |counts: xlog_cuda::memory::TrackedCudaSlice<u32>, rows: u64| -> PyResult<PyObject> {
677 let rows_u32 = u32::try_from(rows).map_err(|_| {
678 PyValueError::new_err(format!("Row count {} exceeds u32::MAX", rows))
679 })?;
680
681 let mut d_num_rows = self
682 .output_provider
683 .memory()
684 .alloc::<u32>(1)
685 .map_err(types::xlog_err)?;
686 self.output_provider
687 .device()
688 .inner()
689 .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
690 .map_err(types::xlog_err)?;
691
692 let buffer = xlog_cuda::CudaBuffer::from_columns(
693 vec![counts.into_bytes().into()],
694 rows,
695 d_num_rows,
696 schema_i32.clone(),
697 );
698 let tensor = self
699 .output_provider
700 .to_dlpack_table(buffer)
701 .column(0)
702 .map_err(types::xlog_err)?;
703 dlpack_capsule_from_tensor(py, tensor)
704 };
705
706 let query_rows = u64::try_from(query_counts.len())
707 .map_err(|_| PyValueError::new_err("query_counts length overflow"))?;
708 let query_counts_capsule = make_count_tensor(query_counts, query_rows)?;
709 let evidence_count_capsule = make_count_tensor(evidence_count, 1)?;
710 let resident_no_host_certified = no_host.is_no_host();
711
712 Ok(McDeviceEvalResult {
713 query_counts: query_counts_capsule,
714 evidence_count: evidence_count_capsule,
715 total_samples,
716 seed,
717 confidence,
718 nonmonotone_semantics: xlog_prob::mc::NONMONOTONE_SEMANTICS.to_string(),
719 nonmonotone_sccs,
720 nonmonotone_cycles,
721 nonmonotone_iteration_limit_hits,
722 sampling_method: match sampling_method_val {
723 McSamplingMethod::Rejection => "rejection".to_string(),
724 McSamplingMethod::EvidenceClamping => "evidence_clamping".to_string(),
725 },
726 resident_no_host_certified,
727 resident_no_host_policy_result: if resident_no_host_certified {
728 "certified".to_string()
729 } else {
730 "failed".to_string()
731 },
732 resident_no_host_tracked_dtoh_calls: no_host.tracked_dtoh_calls,
733 resident_no_host_tracked_htod_calls: no_host.tracked_htod_calls,
734 resident_no_host_host_loop_iterations: no_host.host_loop_iterations,
735 resident_no_host_per_sample_host_launches: no_host.per_sample_host_launches,
736 resident_no_host_untracked_metadata_reads: no_host.untracked_metadata_reads,
737 resident_no_host_engine_launches: no_host.engine_launches,
738 resident_no_host_host_fixpoint_iterations: no_host.host_fixpoint_iterations,
739 resident_no_host_per_operator_host_allocations: no_host.per_operator_host_allocations,
740 })
741 }
742
743 fn nll_loss(&self, query: &str) -> PyResult<f64> {
760 let prob = self.evaluate_query_probability(query)?;
761 Ok(types::nll_loss_value(prob))
762 }
763
764 fn nll_loss_batch(&self, queries: Vec<String>) -> PyResult<f64> {
777 if queries.is_empty() {
778 return Ok(0.0);
779 }
780
781 let probs = self.evaluate_query_probabilities(&queries)?;
782 Ok(probs.iter().map(|&p| types::nll_loss_value(p)).sum())
783 }
784
785 fn nll_loss_mean(&self, queries: Vec<String>) -> PyResult<f64> {
800 if queries.is_empty() {
801 return Err(PyValueError::new_err(
802 "Cannot compute mean NLL loss for empty query batch",
803 ));
804 }
805
806 let probs = self.evaluate_query_probabilities(&queries)?;
807 let sum: f64 = probs.iter().map(|&p| types::nll_loss_value(p)).sum();
808 Ok(sum / probs.len() as f64)
809 }
810
811 fn nll_loss_tensor(&self, py: Python<'_>, query: &str) -> PyResult<PyObject> {
822 let loss = self.nll_loss(query)?;
823 types::create_torch_tensor(py, loss)
824 }
825
826 fn nll_loss_batch_tensor(&self, py: Python<'_>, queries: Vec<String>) -> PyResult<PyObject> {
834 let loss = self.nll_loss_batch(queries)?;
835 types::create_torch_tensor(py, loss)
836 }
837
838 pub fn zero_grad(&self, py: Python<'_>) -> PyResult<()> {
847 for name in self.network_registry.names() {
848 if let Some(handle) = self.network_registry.get(name) {
849 if let Some(optimizer) = handle.optimizer() {
850 optimizer.call_method0(py, "zero_grad")?;
851 }
852 }
853 }
854 Ok(())
855 }
856
857 pub fn optimizer_step(&self, py: Python<'_>) -> PyResult<()> {
862 for name in self.network_registry.names() {
863 if let Some(handle) = self.network_registry.get(name) {
864 if let Some(optimizer) = handle.optimizer() {
865 optimizer.call_method0(py, "step")?;
866 }
867 }
868 }
869 Ok(())
870 }
871
872 pub fn clip_grad_norms(&self, py: Python<'_>, max_norm: f64) -> PyResult<()> {
876 let clip_fn = py.import("torch.nn.utils")?.getattr("clip_grad_norm_")?;
877 for name in self.network_registry.names() {
878 if let Some(handle) = self.network_registry.get(name) {
879 if let Some(module) = handle.module() {
880 let params = module.call_method0(py, "parameters")?;
881 clip_fn.call1((params, max_norm))?;
882 }
883 }
884 }
885 Ok(())
886 }
887
888 #[pyo3(signature = (network_name=None))]
897 fn scheduler_step(&self, py: Python<'_>, network_name: Option<&str>) -> PyResult<()> {
898 match network_name {
899 Some(name) => {
900 let handle = self.network_registry.get(name).ok_or_else(|| {
901 pyo3::exceptions::PyValueError::new_err(format!(
902 "No network registered with name '{name}'"
903 ))
904 })?;
905 if let Some(scheduler) = handle.scheduler() {
906 scheduler.call_method0(py, "step")?;
907 }
908 }
909 None => {
910 for name in self.network_registry.names() {
911 if let Some(handle) = self.network_registry.get(name) {
912 if let Some(scheduler) = handle.scheduler() {
913 scheduler.call_method0(py, "step")?;
914 }
915 }
916 }
917 }
918 }
919 Ok(())
920 }
921
922 fn get_lr(&self, py: Python<'_>, network_name: &str) -> PyResult<f64> {
929 let handle = self.network_registry.get(network_name).ok_or_else(|| {
930 pyo3::exceptions::PyValueError::new_err(format!(
931 "No network registered with name '{network_name}'"
932 ))
933 })?;
934 let optimizer = handle.optimizer().ok_or_else(|| {
935 pyo3::exceptions::PyValueError::new_err(format!(
936 "Network '{network_name}' has no optimizer"
937 ))
938 })?;
939 let param_groups = optimizer.getattr(py, "param_groups")?;
940 let group0 = param_groups.call_method1(py, "__getitem__", (0i32,))?;
941 let lr = group0.call_method1(py, "__getitem__", ("lr",))?;
942 lr.extract(py)
943 }
944
945 fn set_lr(&self, py: Python<'_>, network_name: &str, lr: f64) -> PyResult<()> {
953 let handle = self.network_registry.get(network_name).ok_or_else(|| {
954 pyo3::exceptions::PyValueError::new_err(format!(
955 "No network registered with name '{network_name}'"
956 ))
957 })?;
958 let optimizer = handle.optimizer().ok_or_else(|| {
959 pyo3::exceptions::PyValueError::new_err(format!(
960 "Network '{network_name}' has no optimizer"
961 ))
962 })?;
963 let param_groups = optimizer.getattr(py, "param_groups")?;
964 let num_groups: usize = param_groups.call_method0(py, "__len__")?.extract(py)?;
965 for i in 0..num_groups {
966 let group = param_groups.call_method1(py, "__getitem__", (i as i32,))?;
967 group.call_method(py, "__setitem__", ("lr", lr), None)?;
968 }
969 Ok(())
970 }
971
972 #[pyo3(signature = (queries, batch_size=32, max_grad_norm=None))]
990 fn train_epoch(
991 &mut self,
992 py: Python<'_>,
993 queries: Vec<String>,
994 batch_size: usize,
995 max_grad_norm: Option<f64>,
996 ) -> PyResult<EpochStats> {
997 let mut history = TrainingHistory::new();
998 self.train_epoch_internal(
999 py,
1000 &queries,
1001 batch_size,
1002 usize::MAX,
1003 max_grad_norm,
1004 &mut history,
1005 )
1006 }
1007
1008 pub fn evaluate_loss(&self, queries: Vec<String>) -> PyResult<f64> {
1018 if queries.is_empty() {
1019 return Ok(0.0);
1020 }
1021
1022 let probs = self.evaluate_query_probabilities(&queries)?;
1023 let total_loss: f64 = probs.iter().map(|&p| types::nll_loss_value(p)).sum();
1024 Ok(total_loss / queries.len() as f64)
1025 }
1026
1027 #[pyo3(signature = (queries, batch_size=32, max_grad_norm=None))]
1029 fn train_epoch_tensor(
1030 &mut self,
1031 py: Python<'_>,
1032 queries: Vec<String>,
1033 batch_size: usize,
1034 max_grad_norm: Option<f64>,
1035 ) -> PyResult<EpochStats> {
1036 let mut history = TrainingHistory::new();
1037 self.train_epoch_tensor_internal(
1038 py,
1039 &queries,
1040 batch_size,
1041 usize::MAX,
1042 max_grad_norm,
1043 &mut history,
1044 )
1045 }
1046
1047 fn warmup_breakdown(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
1054 let ptx_profile = self.output_provider.ptx_load_profile();
1055 let circuit_profile = self.last_compile_profile.as_ref();
1056
1057 if ptx_profile.is_none() && circuit_profile.is_none() {
1059 return Ok(None);
1060 }
1061
1062 let result = PyDict::new(py);
1063
1064 if let Some(ptx) = ptx_profile {
1065 let ptx_dict = PyDict::new(py);
1066 ptx_dict.set_item("total_sec", ptx.total_sec)?;
1067 ptx_dict.set_item("cubin_loaded", ptx.cubin_loaded)?;
1068 ptx_dict.set_item("ptx_fallback", ptx.ptx_fallback)?;
1069 let per_module = PyDict::new(py);
1070 for (name, sec) in &ptx.per_module_sec {
1071 per_module.set_item(name, *sec)?;
1072 }
1073 ptx_dict.set_item("per_module_sec", per_module)?;
1074 result.set_item("ptx", ptx_dict)?;
1075 }
1076
1077 if let Some(circuit) = circuit_profile {
1078 let circuit_dict = PyDict::new(py);
1079 circuit_dict.set_item("gpu_cache_hit", circuit.gpu_cache_hit)?;
1080 circuit_dict.set_item("disk_cache_hit", circuit.disk_cache_hit)?;
1081 circuit_dict.set_item("d4_compile_sec", circuit.d4_compile_sec)?;
1082 circuit_dict.set_item("verify_sec", circuit.verify_sec)?;
1083 circuit_dict.set_item("smooth_sec", circuit.smooth_sec)?;
1084 circuit_dict.set_item("cache_store_sec", circuit.cache_store_sec)?;
1085 circuit_dict.set_item("free_var_mask_sec", circuit.free_var_mask_sec)?;
1086 circuit_dict.set_item("cnf_hash_sec", circuit.cnf_hash_sec)?;
1087 result.set_item("circuit", circuit_dict)?;
1088 }
1089
1090 Ok(Some(result.into()))
1091 }
1092
1093 fn clear_circuit_cache(&mut self) {
1096 self.circuit_cache.clear();
1097 }
1098
1099 pub fn memory_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
1101 provider_memory_stats(py, &self.output_provider)
1102 }
1103
1104 pub fn rule_provenance(&self, py: Python<'_>) -> PyResult<PyObject> {
1105 let provenance = xlog_logic::rule_provenance(&self.ast, None);
1106 pack_rule_provenance(py, &provenance)
1107 }
1108
1109 pub fn proof_traces(&self, py: Python<'_>) -> PyResult<PyObject> {
1110 let provenance = xlog_logic::rule_provenance(&self.ast, None);
1111 let traces = xlog_logic::query_proof_traces(&self.ast, &provenance);
1112 pack_proof_traces(py, &traces)
1113 }
1114
1115 pub fn host_transfer_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
1116 let stats = self.output_provider.host_transfer_stats();
1117 let dict = PyDict::new(py);
1118 dict.set_item("dtoh_bytes", stats.dtoh_bytes)?;
1119 dict.set_item("htod_bytes", stats.htod_bytes)?;
1120 dict.set_item("dtoh_calls", stats.dtoh_calls)?;
1121 dict.set_item("htod_calls", stats.htod_calls)?;
1122 Ok(dict.into())
1123 }
1124
1125 pub fn reset_host_transfer_stats(&self) {
1126 self.output_provider.reset_host_transfer_stats()
1127 }
1128
1129 pub fn neural_hot_loop_diagnostics(&self, py: Python<'_>) -> PyResult<PyObject> {
1130 let transfers = self.output_provider.host_transfer_stats();
1131 let dict = PyDict::new(py);
1132 dict.set_item("post_load_dtoh_bytes", transfers.dtoh_bytes)?;
1133 dict.set_item("post_load_htod_bytes", transfers.htod_bytes)?;
1134 dict.set_item("post_load_dtoh_calls", transfers.dtoh_calls)?;
1135 dict.set_item("post_load_htod_calls", transfers.htod_calls)?;
1136 dict.set_item("control_plane_bytes_per_iteration", py.None())?;
1137 dict.set_item(
1138 "control_plane_status",
1139 "unavailable: per-iteration control-plane byte counter is not registered",
1140 )?;
1141 dict.set_item("scalar_sync_checks", py.None())?;
1142 dict.set_item(
1143 "scalar_sync_status",
1144 "unavailable: scalar synchronization counter is not registered",
1145 )?;
1146
1147 let cuda_graph = PyDict::new(py);
1148 cuda_graph.set_item(
1149 "csm_cuda_graph_captures",
1150 self.output_provider.csm_cuda_graph_captures(),
1151 )?;
1152 cuda_graph.set_item(
1153 "csm_cuda_graph_launches",
1154 self.output_provider.csm_cuda_graph_launches(),
1155 )?;
1156 cuda_graph.set_item(
1157 "csm_cuda_graph_fallbacks",
1158 self.output_provider.csm_cuda_graph_fallbacks(),
1159 )?;
1160 cuda_graph.set_item(
1161 "csm_cuda_graph_cache_hits",
1162 self.output_provider.csm_cuda_graph_cache_hits(),
1163 )?;
1164 dict.set_item("cuda_graph", cuda_graph)?;
1165
1166 let circuit_cache = PyDict::new(py);
1167 circuit_cache.set_item("circuit_cache_size", self.circuit_cache.len())?;
1168 circuit_cache.set_item("circuit_cache_hits", self.circuit_cache_hits)?;
1169 circuit_cache.set_item("circuit_cache_misses", self.circuit_cache_misses)?;
1170 circuit_cache.set_item("template_compile_count", self.template_compile_count)?;
1171 circuit_cache.set_item(
1172 "query_signature_cache_size",
1173 self.query_signature_cache.len(),
1174 )?;
1175 dict.set_item("circuit_cache", circuit_cache)?;
1176
1177 Ok(dict.into())
1178 }
1179
1180 pub fn cuda_graph_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
1181 let dict = PyDict::new(py);
1182 dict.set_item(
1183 "csm_cuda_graph_captures",
1184 self.output_provider.csm_cuda_graph_captures(),
1185 )?;
1186 dict.set_item(
1187 "csm_cuda_graph_launches",
1188 self.output_provider.csm_cuda_graph_launches(),
1189 )?;
1190 dict.set_item(
1191 "csm_cuda_graph_fallbacks",
1192 self.output_provider.csm_cuda_graph_fallbacks(),
1193 )?;
1194 dict.set_item(
1195 "csm_cuda_graph_cache_hits",
1196 self.output_provider.csm_cuda_graph_cache_hits(),
1197 )?;
1198 Ok(dict.into())
1199 }
1200}
1201
1202fn pack_rule_provenance(
1203 py: Python<'_>,
1204 entries: &[xlog_logic::RuleProvenance],
1205) -> PyResult<PyObject> {
1206 let list = PyList::empty(py);
1207 for entry in entries {
1208 let dict = PyDict::new(py);
1209 dict.set_item("rule_id", &entry.rule_id)?;
1210 dict.set_item("head", &entry.head)?;
1211 dict.set_item("source_kind", entry.source_kind.as_str())?;
1212 dict.set_item("source_span", entry.source_span.clone())?;
1213 dict.set_item("generation_trace_hash", entry.generation_trace_hash.clone())?;
1214 dict.set_item("support_relation_ids", entry.support_relation_ids.clone())?;
1215 dict.set_item(
1216 "counterexample_relation_ids",
1217 entry.counterexample_relation_ids.clone(),
1218 )?;
1219 list.append(dict)?;
1220 }
1221 Ok(list.into())
1222}
1223
1224fn pack_proof_traces(
1225 py: Python<'_>,
1226 entries: &[xlog_logic::QueryProofTrace],
1227) -> PyResult<PyObject> {
1228 let list = PyList::empty(py);
1229 for entry in entries {
1230 let dict = PyDict::new(py);
1231 dict.set_item("query_id", &entry.query_id)?;
1232 dict.set_item("query", &entry.query)?;
1233 dict.set_item("answer_relation", &entry.answer_relation)?;
1234 dict.set_item("rule_ids", entry.rule_ids.clone())?;
1235 dict.set_item("source_facts", entry.source_facts.clone())?;
1236 dict.set_item("rejected_alternatives", entry.rejected_alternatives.clone())?;
1237 list.append(dict)?;
1238 }
1239 Ok(list.into())
1240}