Skip to main content

pyxlog/
program.rs

1// Remaining CompiledProgram methods and internal types.
2//
3// Contains: evaluate/evaluate_device, NLL loss helpers, training control
4// methods (zero_grad, optimizer_step, etc.), pack_result helpers, and the
5// internal types used by CompiledProgram's implementation (CachedCircuit,
6// QuerySignature, InputSource, NeuralGroup, CompiledProbProgram).
7//
8// The #[pyclass] struct definitions remain in lib.rs.
9
10use cudarc::driver::DeviceSlice;
11use pyo3::exceptions::{PyRuntimeError, PyValueError};
12use pyo3::prelude::*;
13use pyo3::types::{PyDict, PyList};
14
15use xlog_core::{ScalarType, Schema};
16use xlog_logic::ast::Term;
17use xlog_prob::exact::ExactDdnnfProgram;
18#[cfg(feature = "host-io")]
19use xlog_prob::exact::{ExactResultWithGrads, QueryProbability};
20use xlog_prob::mc::{McEvalConfig, McProgram, McSamplingMethod};
21use xlog_prob::neural_fast_path::GpuWeightSlots;
22
23use super::neural_registry::NeuralPredicateInfo;
24use super::{
25    dlpack_capsule_from_tensor, enforce_call_memory_limit, provider_memory_stats, types,
26    CompiledProgram, EpochStats, EvalResult, McDeviceEvalResult, TrainingHistory,
27};
28
29// =========================================================================
30// Internal types
31// =========================================================================
32
33/// A cached circuit for a specific query template.
34///
35/// The circuit structure is immutable - only weights change between queries.
36/// Weight slots map network outputs to circuit variables.
37pub(crate) struct CachedCircuit {
38    /// The compiled program containing the GPU circuit
39    pub(crate) program: ExactDdnnfProgram,
40
41    /// Device-resident mapping from neural output slots to CNF variable ids.
42    pub(crate) slots: GpuWeightSlots,
43
44    /// Ordered target domain for Targeted signatures. Empty for Boolean.
45    pub(crate) target_domain: Vec<String>,
46}
47
/// Where the input of a neural group is read from.
// NOTE(review): the meaning of the index carried by `QueryArg` and
// `ImplicitSlot` is not visible in this part of the file — confirm against the
// code that builds signatures before documenting those variants further.
#[derive(Debug, Clone)]
pub(crate) enum InputSource {
    QueryArg(usize),
    ImplicitSlot(usize),
    /// Stage-B real-domain grounding: the input is row `usize` of the
    /// join-domain tensor source (`nsr_domain`, the per-event feature batch).
    /// Per-event expansions of a neural predicate joined on an existential
    /// variable use this variant.
    DomainRow(usize),
    /// Stage-B real-domain grounding: the group does not depend on any input
    /// (the rule-weight guard expanded per head-key constant). The forward pass
    /// is fed a dummy row of the declared input width; only the network's
    /// parameters matter, not the input.
    ConstDummy,
}
61
62#[derive(Debug, Clone)]
63pub(crate) struct NeuralGroup {
64    pub(crate) info: NeuralPredicateInfo,
65    pub(crate) input_source: InputSource,
66    /// Stage-B real-domain grounding: when `Some(c)`, the template grounds this
67    /// group's neural atom at the real constant `c` (a domain event id or a
68    /// head-key edge id) instead of a synthetic placeholder, so the one neural
69    /// occurrence expands into one circuit leaf per domain constant.
70    pub(crate) ground_const: Option<Term>,
71    #[cfg(feature = "host-io")]
72    pub(crate) output_var: Option<String>,
73}
74
/// A HARD join condition: an ordinary-relation body atom of a trainable-rule
/// query.
///
/// It decides which query groundings can fire, but it carries no probability
/// mass and no gradient. The query probability is
/// `(hard conditions satisfiable?) x (neural prob)`; gradients flow only
/// through the neural predicates x sigma(w) and never through these fact atoms.
#[derive(Debug, Clone)]
pub(crate) struct HardFilter {
    /// Name of the ordinary relation (it must hold over the program's facts).
    pub(crate) relation: String,
    /// One entry per relation argument position: the query HEAD position whose
    /// value that argument must equal. Current scope: every relation argument
    /// is a head variable; hard conditions joining on existential (non-head)
    /// variables are a documented follow-up.
    pub(crate) arg_head_positions: Vec<usize>,
}
90
/// Stage-B existential join plan.
///
/// Describes how a neural predicate is grounded over the REAL join domain inside
/// the circuit, rather than having the join relation stripped out as a
/// pre-filter. The neural groups are already expanded per domain constant. The
/// plan records the ordinary relations whose ground facts must remain IN the
/// circuit (so provenance OR-aggregates
/// `OR_event(neural(event) ∧ join(event, head))` at each head binding) and the
/// head-key domain the query ranges over.
#[derive(Debug, Clone)]
pub(crate) struct JoinPlan {
    /// Ordinary relations that stay inside the circuit rule. Their ground facts
    /// are added to the template program (read from `self.ast`).
    pub(crate) relations: Vec<String>,
    /// Real head-key domain the query ranges over (e.g. edge ids). It uses the
    /// same sorted order as the per-edge guard group expansion and the emitted
    /// `prob_queries`, and serves as the `target_domain` of a join signature.
    pub(crate) head_domain: Vec<String>,
}
107
108#[derive(Debug, Clone)]
109pub(crate) enum QuerySignature {
110    Boolean {
111        groups: Vec<NeuralGroup>,
112        hard_filters: Vec<HardFilter>,
113    },
114    Targeted {
115        target_position: usize,
116        groups: Vec<NeuralGroup>,
117        hard_filters: Vec<HardFilter>,
118        /// `Some` for a Stage-B existential-join signature: the head target
119        /// position ranges over the real head-key domain (e.g. edges) and the
120        /// groups are real-domain-grounded. `None` for the ordinary targeted path.
121        join: Option<JoinPlan>,
122    },
123}
124
125impl QuerySignature {
126    pub(crate) fn groups(&self) -> &[NeuralGroup] {
127        match self {
128            QuerySignature::Boolean { groups, .. } | QuerySignature::Targeted { groups, .. } => {
129                groups
130            }
131        }
132    }
133
134    pub(crate) fn hard_filters(&self) -> &[HardFilter] {
135        match self {
136            QuerySignature::Boolean { hard_filters, .. }
137            | QuerySignature::Targeted { hard_filters, .. } => hard_filters,
138        }
139    }
140
141    /// The Stage-B join plan, if this is an existential-join signature.
142    pub(crate) fn join(&self) -> Option<&JoinPlan> {
143        match self {
144            QuerySignature::Targeted { join, .. } => join.as_ref(),
145            QuerySignature::Boolean { .. } => None,
146        }
147    }
148}
149
150pub(crate) enum CompiledProbProgram {
151    Exact(ExactDdnnfProgram),
152    Mc(McProgram),
153}
154
155impl CompiledProbProgram {
156    #[cfg(feature = "host-io")]
157    pub(crate) fn num_vars(&self) -> usize {
158        match self {
159            Self::Exact(p) => p.num_vars(),
160            Self::Mc(p) => p.num_vars(),
161        }
162    }
163}
164
165// =========================================================================
166// Helper functions
167// =========================================================================
168
/// Render a ground atom as `predicate(arg, arg, ...)`.
///
/// Integer and float arguments use their standard text form (floats are stored
/// as raw bits and decoded first), symbols are shown as `sym#<id>`, and string
/// arguments are inserted verbatim. An atom without arguments renders as
/// `predicate()`.
#[cfg(feature = "host-io")]
fn atom_to_string(atom: &xlog_prob::provenance::GroundAtom) -> String {
    use xlog_prob::provenance::Value;

    if atom.args.is_empty() {
        return format!("{}()", atom.predicate);
    }

    // Render every argument on its own, then glue them together with ", ".
    let rendered: Vec<String> = atom
        .args
        .iter()
        .map(|arg| match arg {
            Value::I64(v) => v.to_string(),
            Value::F64(bits) => f64::from_bits(*bits).to_string(),
            Value::Symbol(sym) => format!("sym#{}", sym),
            Value::String(v) => {
                let text: &str = v;
                text.to_owned()
            }
        })
        .collect();

    let mut out = String::new();
    out.push_str(&atom.predicate);
    out.push('(');
    out.push_str(&rendered.join(", "));
    out.push(')');
    out
}
194
195// =========================================================================
196// impl CompiledProgram — private helpers
197// =========================================================================
198
199impl CompiledProgram {
200    pub(crate) fn parse_sampling_method(s: Option<String>) -> PyResult<Option<McSamplingMethod>> {
201        match s.as_deref() {
202            None => Ok(None),
203            Some("rejection") => Ok(Some(McSamplingMethod::Rejection)),
204            Some("evidence_clamping") => Ok(Some(McSamplingMethod::EvidenceClamping)),
205            Some(other) => Err(PyValueError::new_err(format!(
206                "Unknown sampling_method '{}'. Use 'rejection' or 'evidence_clamping'.",
207                other
208            ))),
209        }
210    }
211
212    /// Evaluate probability of a single query by compiling a temporary program.
213    pub(crate) fn evaluate_query_probability(&self, query: &str) -> PyResult<f64> {
214        let probs = self.evaluate_query_probabilities(&[query.to_string()])?;
215        probs
216            .into_iter()
217            .next()
218            .ok_or_else(|| PyRuntimeError::new_err("Query evaluation returned no results"))
219    }
220
221    /// Evaluate probabilities for multiple queries by compiling a temporary program.
222    pub(crate) fn evaluate_query_probabilities(&self, queries: &[String]) -> PyResult<Vec<f64>> {
223        #[cfg(not(feature = "host-io"))]
224        {
225            let _ = queries;
226            return Err(types::host_io_disabled_pyerr());
227        }
228
229        #[cfg(feature = "host-io")]
230        {
231            // Build source with queries appended
232            let mut source_with_queries = self._source.clone();
233            for query in queries {
234                source_with_queries.push_str(&format!("\nquery({}).", query));
235            }
236
237            // Compile and evaluate the temporary program
238            let result: Vec<QueryProbability> = match self._prob_engine {
239                xlog_logic::ast::ProbEngine::ExactDdnnf => {
240                    let program = ExactDdnnfProgram::compile_source_with_gpu(
241                        &source_with_queries,
242                        self._gpu_config,
243                    )
244                    .map_err(|e| types::gpu_err("Query compilation error", e))?;
245
246                    program
247                        .evaluate()
248                        .map_err(|e| types::gpu_err("Query evaluation error", e))?
249                        .query_probs
250                }
251                xlog_logic::ast::ProbEngine::Mc => {
252                    let program =
253                        McProgram::compile_source_with_gpu(&source_with_queries, self._gpu_config)
254                            .map_err(|e| types::gpu_err("Query compilation error", e))?;
255
256                    let cfg = McEvalConfig::default();
257                    program
258                        .evaluate(cfg)
259                        .map_err(|e| types::gpu_err("Query evaluation error", e))?
260                        .query_estimates
261                        .into_iter()
262                        .map(|e| QueryProbability {
263                            atom: e.atom,
264                            prob: e.prob,
265                            log_prob: e.log_prob,
266                        })
267                        .collect()
268                }
269            };
270
271            // Extract probabilities in query order
272            // The results should be in the same order as queries were added
273            let probs: Vec<f64> = result.iter().map(|qp| qp.prob).collect();
274
275            if probs.len() != queries.len() {
276                return Err(PyRuntimeError::new_err(format!(
277                    "Expected {} query results, got {}",
278                    queries.len(),
279                    probs.len()
280                )));
281            }
282
283            Ok(probs)
284        }
285    }
286
287    #[cfg(feature = "host-io")]
288    fn pack_result_probs(
289        &self,
290        py: Python<'_>,
291        query_probs: Vec<QueryProbability>,
292    ) -> PyResult<EvalResult> {
293        let mut atoms: Vec<String> = Vec::with_capacity(query_probs.len());
294        let mut probs: Vec<f64> = Vec::with_capacity(query_probs.len());
295        let mut log_probs: Vec<f64> = Vec::with_capacity(query_probs.len());
296
297        for q in query_probs {
298            atoms.push(atom_to_string(&q.atom));
299            probs.push(q.prob);
300            log_probs.push(q.log_prob);
301        }
302
303        let schema = Schema::new(vec![("col0".to_string(), ScalarType::F64)]);
304        let prob_buf = self
305            .output_provider
306            .create_buffer_from_slice::<f64>(&probs, schema.clone())
307            .map_err(types::xlog_err)?;
308        let log_prob_buf = self
309            .output_provider
310            .create_buffer_from_slice::<f64>(&log_probs, schema)
311            .map_err(types::xlog_err)?;
312
313        let prob_tensor = self
314            .output_provider
315            .to_dlpack_table(prob_buf)
316            .column(0)
317            .map_err(types::xlog_err)?;
318        let log_prob_tensor = self
319            .output_provider
320            .to_dlpack_table(log_prob_buf)
321            .column(0)
322            .map_err(types::xlog_err)?;
323
324        Ok(EvalResult {
325            atoms,
326            prob: dlpack_capsule_from_tensor(py, prob_tensor)?,
327            log_prob: dlpack_capsule_from_tensor(py, log_prob_tensor)?,
328            num_vars: self.program.num_vars(),
329            grad_true: None,
330            grad_false: None,
331            approx: false,
332            stderr: None,
333            ci_low: None,
334            ci_high: None,
335            samples: None,
336            evidence_samples: None,
337            seed: None,
338            confidence: None,
339            nonmonotone_semantics: None,
340            nonmonotone_sccs: None,
341            nonmonotone_cycles: None,
342            nonmonotone_iteration_limit_hits: None,
343            sampling_method: None,
344            mc_engine: None,
345        })
346    }
347
348    #[cfg(feature = "host-io")]
349    fn pack_result_with_grads(
350        &self,
351        py: Python<'_>,
352        result: ExactResultWithGrads,
353    ) -> PyResult<EvalResult> {
354        let mut atoms: Vec<String> = Vec::with_capacity(result.query_grads.len());
355        let mut probs: Vec<f64> = Vec::with_capacity(result.query_grads.len());
356        let mut log_probs: Vec<f64> = Vec::with_capacity(result.query_grads.len());
357
358        let mut grad_true_caps: Vec<PyObject> = Vec::with_capacity(result.query_grads.len());
359        let mut grad_false_caps: Vec<PyObject> = Vec::with_capacity(result.query_grads.len());
360
361        let schema = Schema::new(vec![("col0".to_string(), ScalarType::F64)]);
362
363        let num_vars = self.program.num_vars();
364        for q in result.query_grads {
365            atoms.push(atom_to_string(&q.atom));
366            probs.push(q.prob);
367            log_probs.push(q.log_prob);
368
369            let grad_true_buf = self
370                .output_provider
371                .create_buffer_from_slice::<f64>(&q.grad_true, schema.clone())
372                .map_err(types::xlog_err)?;
373            let grad_false_buf = self
374                .output_provider
375                .create_buffer_from_slice::<f64>(&q.grad_false, schema.clone())
376                .map_err(types::xlog_err)?;
377
378            let grad_true_tensor = self
379                .output_provider
380                .to_dlpack_table(grad_true_buf)
381                .column(0)
382                .map_err(types::xlog_err)?;
383            let grad_false_tensor = self
384                .output_provider
385                .to_dlpack_table(grad_false_buf)
386                .column(0)
387                .map_err(types::xlog_err)?;
388
389            grad_true_caps.push(dlpack_capsule_from_tensor(py, grad_true_tensor)?);
390            grad_false_caps.push(dlpack_capsule_from_tensor(py, grad_false_tensor)?);
391        }
392
393        let prob_buf = self
394            .output_provider
395            .create_buffer_from_slice::<f64>(&probs, schema.clone())
396            .map_err(types::xlog_err)?;
397        let log_prob_buf = self
398            .output_provider
399            .create_buffer_from_slice::<f64>(&log_probs, schema)
400            .map_err(types::xlog_err)?;
401
402        let prob_tensor = self
403            .output_provider
404            .to_dlpack_table(prob_buf)
405            .column(0)
406            .map_err(types::xlog_err)?;
407        let log_prob_tensor = self
408            .output_provider
409            .to_dlpack_table(log_prob_buf)
410            .column(0)
411            .map_err(types::xlog_err)?;
412
413        Ok(EvalResult {
414            atoms,
415            prob: dlpack_capsule_from_tensor(py, prob_tensor)?,
416            log_prob: dlpack_capsule_from_tensor(py, log_prob_tensor)?,
417            num_vars,
418            grad_true: Some(grad_true_caps),
419            grad_false: Some(grad_false_caps),
420            approx: false,
421            stderr: None,
422            ci_low: None,
423            ci_high: None,
424            samples: None,
425            evidence_samples: None,
426            seed: None,
427            confidence: None,
428            nonmonotone_semantics: None,
429            nonmonotone_sccs: None,
430            nonmonotone_cycles: None,
431            nonmonotone_iteration_limit_hits: None,
432            sampling_method: None,
433            mc_engine: None,
434        })
435    }
436
437    #[cfg(feature = "host-io")]
438    fn pack_result_mc(
439        &self,
440        py: Python<'_>,
441        result: xlog_prob::mc::McResult,
442    ) -> PyResult<EvalResult> {
443        let mut atoms: Vec<String> = Vec::with_capacity(result.query_estimates.len());
444        let mut probs: Vec<f64> = Vec::with_capacity(result.query_estimates.len());
445        let mut log_probs: Vec<f64> = Vec::with_capacity(result.query_estimates.len());
446        let mut stderrs: Vec<f64> = Vec::with_capacity(result.query_estimates.len());
447        let mut ci_lows: Vec<f64> = Vec::with_capacity(result.query_estimates.len());
448        let mut ci_highs: Vec<f64> = Vec::with_capacity(result.query_estimates.len());
449
450        for q in &result.query_estimates {
451            atoms.push(atom_to_string(&q.atom));
452            probs.push(q.prob);
453            log_probs.push(q.log_prob);
454            stderrs.push(q.stderr);
455            ci_lows.push(q.ci_low);
456            ci_highs.push(q.ci_high);
457        }
458
459        let schema = Schema::new(vec![("col0".to_string(), ScalarType::F64)]);
460        let prob_buf = self
461            .output_provider
462            .create_buffer_from_slice::<f64>(&probs, schema.clone())
463            .map_err(types::xlog_err)?;
464        let log_prob_buf = self
465            .output_provider
466            .create_buffer_from_slice::<f64>(&log_probs, schema.clone())
467            .map_err(types::xlog_err)?;
468        let stderr_buf = self
469            .output_provider
470            .create_buffer_from_slice::<f64>(&stderrs, schema.clone())
471            .map_err(types::xlog_err)?;
472        let ci_low_buf = self
473            .output_provider
474            .create_buffer_from_slice::<f64>(&ci_lows, schema.clone())
475            .map_err(types::xlog_err)?;
476        let ci_high_buf = self
477            .output_provider
478            .create_buffer_from_slice::<f64>(&ci_highs, schema)
479            .map_err(types::xlog_err)?;
480
481        let prob_tensor = self
482            .output_provider
483            .to_dlpack_table(prob_buf)
484            .column(0)
485            .map_err(types::xlog_err)?;
486        let log_prob_tensor = self
487            .output_provider
488            .to_dlpack_table(log_prob_buf)
489            .column(0)
490            .map_err(types::xlog_err)?;
491        let stderr_tensor = self
492            .output_provider
493            .to_dlpack_table(stderr_buf)
494            .column(0)
495            .map_err(types::xlog_err)?;
496        let ci_low_tensor = self
497            .output_provider
498            .to_dlpack_table(ci_low_buf)
499            .column(0)
500            .map_err(types::xlog_err)?;
501        let ci_high_tensor = self
502            .output_provider
503            .to_dlpack_table(ci_high_buf)
504            .column(0)
505            .map_err(types::xlog_err)?;
506
507        Ok(EvalResult {
508            atoms,
509            prob: dlpack_capsule_from_tensor(py, prob_tensor)?,
510            log_prob: dlpack_capsule_from_tensor(py, log_prob_tensor)?,
511            num_vars: self.program.num_vars(),
512            grad_true: None,
513            grad_false: None,
514            approx: true,
515            stderr: Some(dlpack_capsule_from_tensor(py, stderr_tensor)?),
516            ci_low: Some(dlpack_capsule_from_tensor(py, ci_low_tensor)?),
517            ci_high: Some(dlpack_capsule_from_tensor(py, ci_high_tensor)?),
518            samples: Some(result.total_samples),
519            evidence_samples: Some(result.evidence_samples),
520            seed: Some(result.seed),
521            confidence: Some(result.confidence),
522            nonmonotone_semantics: Some(xlog_prob::mc::NONMONOTONE_SEMANTICS.to_string()),
523            nonmonotone_sccs: Some(result.nonmonotone_sccs),
524            nonmonotone_cycles: Some(result.nonmonotone_cycles),
525            nonmonotone_iteration_limit_hits: Some(result.nonmonotone_iteration_limit_hits),
526            sampling_method: Some(match result.sampling_method {
527                McSamplingMethod::Rejection => "rejection".to_string(),
528                McSamplingMethod::EvidenceClamping => "evidence_clamping".to_string(),
529            }),
530            mc_engine: Some(result.engine.as_str().to_string()),
531        })
532    }
533}
534
535// =========================================================================
536// #[pymethods] impl CompiledProgram — evaluate, NLL, training controls
537// =========================================================================
538
539#[pymethods]
540impl CompiledProgram {
541    #[pyo3(signature = (return_grads=false, samples=None, seed=None, confidence=0.95, max_nonmonotone_iterations=1024, sampling_method=None, memory_mb=None, allow_cpu_oracle=false))]
542    pub fn evaluate(
543        &self,
544        _py: Python<'_>,
545        return_grads: bool,
546        samples: Option<usize>,
547        seed: Option<u64>,
548        confidence: f64,
549        max_nonmonotone_iterations: usize,
550        sampling_method: Option<String>,
551        memory_mb: Option<u64>,
552        allow_cpu_oracle: bool,
553    ) -> PyResult<EvalResult> {
554        enforce_call_memory_limit(&self.output_provider, memory_mb)?;
555        match &self.program {
556            CompiledProbProgram::Exact(_program) => {
557                if samples.is_some() || seed.is_some() {
558                    return Err(PyValueError::new_err(
559                        "samples/seed are only supported for prob_engine='mc'",
560                    ));
561                }
562                #[cfg(feature = "host-io")]
563                {
564                    if return_grads {
565                        let result = _program
566                            .evaluate_gpu_with_grads()
567                            .map_err(types::xlog_err)?;
568                        self.pack_result_with_grads(_py, result)
569                    } else {
570                        let result = _program.evaluate().map_err(types::xlog_err)?;
571                        self.pack_result_probs(_py, result.query_probs)
572                    }
573                }
574                #[cfg(not(feature = "host-io"))]
575                {
576                    let _ = return_grads;
577                    Err(types::host_io_disabled_pyerr())
578                }
579            }
580            CompiledProbProgram::Mc(_program) => {
581                if return_grads {
582                    return Err(PyValueError::new_err(
583                        "MC inference does not support gradients (return_grads must be false)",
584                    ));
585                }
586
587                let mut cfg = McEvalConfig::default();
588                cfg.samples = samples.unwrap_or(10000);
589                cfg.seed = seed.unwrap_or(0);
590                cfg.confidence = confidence;
591                cfg.max_nonmonotone_iterations = max_nonmonotone_iterations;
592                cfg.sampling_method = Self::parse_sampling_method(sampling_method)?;
593                // Fail-closed contract: resident-rejected programs (negation,
594                // aggregates, ...) error unless the caller explicitly opts
595                // into the labeled CPU oracle.
596                cfg.allow_cpu_oracle_fallback = allow_cpu_oracle;
597                #[cfg(feature = "host-io")]
598                {
599                    let result = _program.evaluate(cfg).map_err(types::xlog_err)?;
600                    self.pack_result_mc(_py, result)
601                }
602                #[cfg(not(feature = "host-io"))]
603                {
604                    let _ = cfg;
605                    Err(types::host_io_disabled_pyerr())
606                }
607            }
608        }
609    }
610
611    /// Evaluate Monte Carlo programs and return device-only result counts via DLPack.
612    ///
613    /// This is the primary GPU-native API surface for MC inference. It never performs
614    /// device->host reads for result data (only returns device buffers).
615    #[pyo3(signature = (samples=None, seed=None, confidence=0.95, max_nonmonotone_iterations=1024, sampling_method=None, memory_mb=None))]
616    pub fn evaluate_device(
617        &self,
618        py: Python<'_>,
619        samples: Option<usize>,
620        seed: Option<u64>,
621        confidence: f64,
622        max_nonmonotone_iterations: usize,
623        sampling_method: Option<String>,
624        memory_mb: Option<u64>,
625    ) -> PyResult<McDeviceEvalResult> {
626        enforce_call_memory_limit(&self.output_provider, memory_mb)?;
627        let (
628            query_counts,
629            evidence_count,
630            total_samples,
631            seed,
632            confidence,
633            nonmonotone_sccs,
634            nonmonotone_cycles,
635            nonmonotone_iteration_limit_hits,
636            sampling_method_val,
637            no_host,
638        ) = match &self.program {
639            CompiledProbProgram::Mc(program) => {
640                let mut cfg = McEvalConfig::default();
641                cfg.samples = samples.unwrap_or(10000);
642                cfg.seed = seed.unwrap_or(0);
643                cfg.confidence = confidence;
644                cfg.max_nonmonotone_iterations = max_nonmonotone_iterations;
645                cfg.sampling_method = Self::parse_sampling_method(sampling_method)?;
646
647                let result = program
648                    .evaluate_gpu_device_with_provider(cfg, self.output_provider.clone())
649                    .map_err(types::xlog_err)?;
650
651                (
652                    result.query_counts,
653                    result.evidence_count,
654                    result.total_samples,
655                    result.seed,
656                    result.confidence,
657                    result.nonmonotone_sccs,
658                    result.nonmonotone_cycles,
659                    result.nonmonotone_iteration_limit_hits,
660                    result.sampling_method,
661                    result.no_host,
662                )
663            }
664            _ => {
665                return Err(PyValueError::new_err(
666                    "evaluate_device is only supported for prob_engine='mc'",
667                ))
668            }
669        };
670
671        // PyTorch does not support unsigned 32-bit types. Export as i32 (bitwise identical for
672        // counts < 2^31) for maximum DLPack consumer compatibility.
673        let schema_i32 = Schema::new(vec![("col0".to_string(), ScalarType::I32)]);
674
675        let make_count_tensor =
676            |counts: xlog_cuda::memory::TrackedCudaSlice<u32>, rows: u64| -> PyResult<PyObject> {
677                let rows_u32 = u32::try_from(rows).map_err(|_| {
678                    PyValueError::new_err(format!("Row count {} exceeds u32::MAX", rows))
679                })?;
680
681                let mut d_num_rows = self
682                    .output_provider
683                    .memory()
684                    .alloc::<u32>(1)
685                    .map_err(types::xlog_err)?;
686                self.output_provider
687                    .device()
688                    .inner()
689                    .htod_sync_copy_into(&[rows_u32], &mut d_num_rows)
690                    .map_err(types::xlog_err)?;
691
692                let buffer = xlog_cuda::CudaBuffer::from_columns(
693                    vec![counts.into_bytes().into()],
694                    rows,
695                    d_num_rows,
696                    schema_i32.clone(),
697                );
698                let tensor = self
699                    .output_provider
700                    .to_dlpack_table(buffer)
701                    .column(0)
702                    .map_err(types::xlog_err)?;
703                dlpack_capsule_from_tensor(py, tensor)
704            };
705
706        let query_rows = u64::try_from(query_counts.len())
707            .map_err(|_| PyValueError::new_err("query_counts length overflow"))?;
708        let query_counts_capsule = make_count_tensor(query_counts, query_rows)?;
709        let evidence_count_capsule = make_count_tensor(evidence_count, 1)?;
710        let resident_no_host_certified = no_host.is_no_host();
711
712        Ok(McDeviceEvalResult {
713            query_counts: query_counts_capsule,
714            evidence_count: evidence_count_capsule,
715            total_samples,
716            seed,
717            confidence,
718            nonmonotone_semantics: xlog_prob::mc::NONMONOTONE_SEMANTICS.to_string(),
719            nonmonotone_sccs,
720            nonmonotone_cycles,
721            nonmonotone_iteration_limit_hits,
722            sampling_method: match sampling_method_val {
723                McSamplingMethod::Rejection => "rejection".to_string(),
724                McSamplingMethod::EvidenceClamping => "evidence_clamping".to_string(),
725            },
726            resident_no_host_certified,
727            resident_no_host_policy_result: if resident_no_host_certified {
728                "certified".to_string()
729            } else {
730                "failed".to_string()
731            },
732            resident_no_host_tracked_dtoh_calls: no_host.tracked_dtoh_calls,
733            resident_no_host_tracked_htod_calls: no_host.tracked_htod_calls,
734            resident_no_host_host_loop_iterations: no_host.host_loop_iterations,
735            resident_no_host_per_sample_host_launches: no_host.per_sample_host_launches,
736            resident_no_host_untracked_metadata_reads: no_host.untracked_metadata_reads,
737            resident_no_host_engine_launches: no_host.engine_launches,
738            resident_no_host_host_fixpoint_iterations: no_host.host_fixpoint_iterations,
739            resident_no_host_per_operator_host_allocations: no_host.per_operator_host_allocations,
740        })
741    }
742
743    // =========================================================================
744    // NLL Loss Functions
745    // =========================================================================
746
747    /// Compute negative log-likelihood loss for a single query.
748    ///
749    /// NLL loss = -log(P(query))
750    ///
751    /// This is the fundamental training objective for neural-symbolic programs.
752    /// Lower loss means higher probability of the query being true.
753    ///
754    /// # Arguments
755    /// * `query` - Query atom as string, e.g., "digit(0, 5)" or "path(1, 3)"
756    ///
757    /// # Returns
758    /// The NLL loss value (always non-negative, 0 for certain facts)
759    fn nll_loss(&self, query: &str) -> PyResult<f64> {
760        let prob = self.evaluate_query_probability(query)?;
761        Ok(types::nll_loss_value(prob))
762    }
763
764    /// Compute sum of NLL losses for a batch of queries.
765    ///
766    /// Batch loss = Σ -log(P(query_i))
767    ///
768    /// More efficient than calling nll_loss repeatedly as all queries
769    /// are compiled and evaluated together.
770    ///
771    /// # Arguments
772    /// * `queries` - List of query atoms as strings
773    ///
774    /// # Returns
775    /// Sum of individual NLL losses (0.0 for empty batch)
776    fn nll_loss_batch(&self, queries: Vec<String>) -> PyResult<f64> {
777        if queries.is_empty() {
778            return Ok(0.0);
779        }
780
781        let probs = self.evaluate_query_probabilities(&queries)?;
782        Ok(probs.iter().map(|&p| types::nll_loss_value(p)).sum())
783    }
784
785    /// Compute mean NLL loss for a batch of queries.
786    ///
787    /// Mean loss = (1/n) Σ -log(P(query_i))
788    ///
789    /// Useful for comparing loss across batches of different sizes.
790    ///
791    /// # Arguments
792    /// * `queries` - List of query atoms as strings (must be non-empty)
793    ///
794    /// # Returns
795    /// Mean of individual NLL losses
796    ///
797    /// # Errors
798    /// Returns error if queries is empty
799    fn nll_loss_mean(&self, queries: Vec<String>) -> PyResult<f64> {
800        if queries.is_empty() {
801            return Err(PyValueError::new_err(
802                "Cannot compute mean NLL loss for empty query batch",
803            ));
804        }
805
806        let probs = self.evaluate_query_probabilities(&queries)?;
807        let sum: f64 = probs.iter().map(|&p| types::nll_loss_value(p)).sum();
808        Ok(sum / probs.len() as f64)
809    }
810
811    /// Compute NLL loss and return as PyTorch tensor.
812    ///
813    /// Returns a scalar tensor that can participate in autograd.
814    /// Use this when you need gradients to flow back through the loss.
815    ///
816    /// # Arguments
817    /// * `query` - Query atom as string
818    ///
819    /// # Returns
820    /// PyTorch scalar tensor containing the loss value
821    fn nll_loss_tensor(&self, py: Python<'_>, query: &str) -> PyResult<PyObject> {
822        let loss = self.nll_loss(query)?;
823        types::create_torch_tensor(py, loss)
824    }
825
826    /// Compute batch NLL loss and return as PyTorch tensor.
827    ///
828    /// # Arguments
829    /// * `queries` - List of query atoms as strings
830    ///
831    /// # Returns
832    /// PyTorch scalar tensor containing the sum of losses
833    fn nll_loss_batch_tensor(&self, py: Python<'_>, queries: Vec<String>) -> PyResult<PyObject> {
834        let loss = self.nll_loss_batch(queries)?;
835        types::create_torch_tensor(py, loss)
836    }
837
838    // =========================================================================
839    // Backward Pass / Training Methods
840    // =========================================================================
841
842    /// Zero gradients for all registered networks.
843    ///
844    /// This should be called at the start of each training iteration
845    /// to clear accumulated gradients from previous iterations.
846    pub fn zero_grad(&self, py: Python<'_>) -> PyResult<()> {
847        for name in self.network_registry.names() {
848            if let Some(handle) = self.network_registry.get(name) {
849                if let Some(optimizer) = handle.optimizer() {
850                    optimizer.call_method0(py, "zero_grad")?;
851                }
852            }
853        }
854        Ok(())
855    }
856
857    /// Perform optimizer step for all registered networks.
858    ///
859    /// This applies the accumulated gradients to update network parameters.
860    /// Should be called after forward_backward().
861    pub fn optimizer_step(&self, py: Python<'_>) -> PyResult<()> {
862        for name in self.network_registry.names() {
863            if let Some(handle) = self.network_registry.get(name) {
864                if let Some(optimizer) = handle.optimizer() {
865                    optimizer.call_method0(py, "step")?;
866                }
867            }
868        }
869        Ok(())
870    }
871
872    /// Clip gradient norms for all registered networks.
873    ///
874    /// Uses `torch.nn.utils.clip_grad_norm_`.
875    pub fn clip_grad_norms(&self, py: Python<'_>, max_norm: f64) -> PyResult<()> {
876        let clip_fn = py.import("torch.nn.utils")?.getattr("clip_grad_norm_")?;
877        for name in self.network_registry.names() {
878            if let Some(handle) = self.network_registry.get(name) {
879                if let Some(module) = handle.module() {
880                    let params = module.call_method0(py, "parameters")?;
881                    clip_fn.call1((params, max_norm))?;
882                }
883            }
884        }
885        Ok(())
886    }
887
888    /// Step the learning rate scheduler.
889    ///
890    /// PyTorch schedulers expect at least one optimizer step before the first
891    /// scheduler step. Call this after `optimizer_step()` (or after a training
892    /// path that performs an optimizer step internally).
893    ///
894    /// If `network_name` is provided, steps only that network's scheduler.
895    /// If `None` (default), steps all registered schedulers.
896    #[pyo3(signature = (network_name=None))]
897    fn scheduler_step(&self, py: Python<'_>, network_name: Option<&str>) -> PyResult<()> {
898        match network_name {
899            Some(name) => {
900                let handle = self.network_registry.get(name).ok_or_else(|| {
901                    pyo3::exceptions::PyValueError::new_err(format!(
902                        "No network registered with name '{name}'"
903                    ))
904                })?;
905                if let Some(scheduler) = handle.scheduler() {
906                    scheduler.call_method0(py, "step")?;
907                }
908            }
909            None => {
910                for name in self.network_registry.names() {
911                    if let Some(handle) = self.network_registry.get(name) {
912                        if let Some(scheduler) = handle.scheduler() {
913                            scheduler.call_method0(py, "step")?;
914                        }
915                    }
916                }
917            }
918        }
919        Ok(())
920    }
921
922    /// Get the current learning rate for a registered network.
923    ///
924    /// Reads `optimizer.param_groups[0]['lr']`.
925    ///
926    /// # Arguments
927    /// * `network_name` - Name used in register_network()
928    fn get_lr(&self, py: Python<'_>, network_name: &str) -> PyResult<f64> {
929        let handle = self.network_registry.get(network_name).ok_or_else(|| {
930            pyo3::exceptions::PyValueError::new_err(format!(
931                "No network registered with name '{network_name}'"
932            ))
933        })?;
934        let optimizer = handle.optimizer().ok_or_else(|| {
935            pyo3::exceptions::PyValueError::new_err(format!(
936                "Network '{network_name}' has no optimizer"
937            ))
938        })?;
939        let param_groups = optimizer.getattr(py, "param_groups")?;
940        let group0 = param_groups.call_method1(py, "__getitem__", (0i32,))?;
941        let lr = group0.call_method1(py, "__getitem__", ("lr",))?;
942        lr.extract(py)
943    }
944
945    /// Set the learning rate for a registered network.
946    ///
947    /// Writes to all `optimizer.param_groups[i]['lr']`.
948    ///
949    /// # Arguments
950    /// * `network_name` - Name used in register_network()
951    /// * `lr` - New learning rate value
952    fn set_lr(&self, py: Python<'_>, network_name: &str, lr: f64) -> PyResult<()> {
953        let handle = self.network_registry.get(network_name).ok_or_else(|| {
954            pyo3::exceptions::PyValueError::new_err(format!(
955                "No network registered with name '{network_name}'"
956            ))
957        })?;
958        let optimizer = handle.optimizer().ok_or_else(|| {
959            pyo3::exceptions::PyValueError::new_err(format!(
960                "Network '{network_name}' has no optimizer"
961            ))
962        })?;
963        let param_groups = optimizer.getattr(py, "param_groups")?;
964        let num_groups: usize = param_groups.call_method0(py, "__len__")?.extract(py)?;
965        for i in 0..num_groups {
966            let group = param_groups.call_method1(py, "__getitem__", (i as i32,))?;
967            group.call_method(py, "__setitem__", ("lr", lr), None)?;
968        }
969        Ok(())
970    }
971
972    // =========================================================================
973    // Training Methods
974    // =========================================================================
975
976    /// Train for one epoch over the given queries.
977    ///
978    /// This method:
979    /// 1. Processes queries in batches
980    /// 2. For each batch: zero_grad, forward_backward for each query, optimizer_step
981    /// 3. Returns statistics for the epoch
982    ///
983    /// # Arguments
984    /// * `queries` - List of query strings to train on
985    /// * `batch_size` - Number of queries per batch (default: 32)
986    ///
987    /// # Returns
988    /// EpochStats with avg_loss, num_batches, total_queries
989    #[pyo3(signature = (queries, batch_size=32, max_grad_norm=None))]
990    fn train_epoch(
991        &mut self,
992        py: Python<'_>,
993        queries: Vec<String>,
994        batch_size: usize,
995        max_grad_norm: Option<f64>,
996    ) -> PyResult<EpochStats> {
997        let mut history = TrainingHistory::new();
998        self.train_epoch_internal(
999            py,
1000            &queries,
1001            batch_size,
1002            usize::MAX,
1003            max_grad_norm,
1004            &mut history,
1005        )
1006    }
1007
1008    /// Evaluate mean NLL loss over queries without updating parameters.
1009    ///
1010    /// Useful for validation/test set evaluation.
1011    ///
1012    /// # Arguments
1013    /// * `queries` - List of query strings to evaluate
1014    ///
1015    /// # Returns
1016    /// Mean NLL loss over all queries
1017    pub fn evaluate_loss(&self, queries: Vec<String>) -> PyResult<f64> {
1018        if queries.is_empty() {
1019            return Ok(0.0);
1020        }
1021
1022        let probs = self.evaluate_query_probabilities(&queries)?;
1023        let total_loss: f64 = probs.iter().map(|&p| types::nll_loss_value(p)).sum();
1024        Ok(total_loss / queries.len() as f64)
1025    }
1026
1027    /// Train for one epoch with GPU-native loss accumulation (no per-query .item()).
1028    #[pyo3(signature = (queries, batch_size=32, max_grad_norm=None))]
1029    fn train_epoch_tensor(
1030        &mut self,
1031        py: Python<'_>,
1032        queries: Vec<String>,
1033        batch_size: usize,
1034        max_grad_norm: Option<f64>,
1035    ) -> PyResult<EpochStats> {
1036        let mut history = TrainingHistory::new();
1037        self.train_epoch_tensor_internal(
1038            py,
1039            &queries,
1040            batch_size,
1041            usize::MAX,
1042            max_grad_norm,
1043            &mut history,
1044        )
1045    }
1046
1047    /// Return warmup profiling data as a Python dict (or None if profiling disabled).
1048    ///
1049    /// When XLOG_WARMUP_PROFILE=1, returns a dict with:
1050    ///   - "ptx": PTX load timing breakdown
1051    ///   - "circuit": circuit compilation timing breakdown
1052    /// Returns None if profiling is not enabled or no data is available.
1053    fn warmup_breakdown(&self, py: Python<'_>) -> PyResult<Option<PyObject>> {
1054        let ptx_profile = self.output_provider.ptx_load_profile();
1055        let circuit_profile = self.last_compile_profile.as_ref();
1056
1057        // Return None if neither profile is available.
1058        if ptx_profile.is_none() && circuit_profile.is_none() {
1059            return Ok(None);
1060        }
1061
1062        let result = PyDict::new(py);
1063
1064        if let Some(ptx) = ptx_profile {
1065            let ptx_dict = PyDict::new(py);
1066            ptx_dict.set_item("total_sec", ptx.total_sec)?;
1067            ptx_dict.set_item("cubin_loaded", ptx.cubin_loaded)?;
1068            ptx_dict.set_item("ptx_fallback", ptx.ptx_fallback)?;
1069            let per_module = PyDict::new(py);
1070            for (name, sec) in &ptx.per_module_sec {
1071                per_module.set_item(name, *sec)?;
1072            }
1073            ptx_dict.set_item("per_module_sec", per_module)?;
1074            result.set_item("ptx", ptx_dict)?;
1075        }
1076
1077        if let Some(circuit) = circuit_profile {
1078            let circuit_dict = PyDict::new(py);
1079            circuit_dict.set_item("gpu_cache_hit", circuit.gpu_cache_hit)?;
1080            circuit_dict.set_item("disk_cache_hit", circuit.disk_cache_hit)?;
1081            circuit_dict.set_item("d4_compile_sec", circuit.d4_compile_sec)?;
1082            circuit_dict.set_item("verify_sec", circuit.verify_sec)?;
1083            circuit_dict.set_item("smooth_sec", circuit.smooth_sec)?;
1084            circuit_dict.set_item("cache_store_sec", circuit.cache_store_sec)?;
1085            circuit_dict.set_item("free_var_mask_sec", circuit.free_var_mask_sec)?;
1086            circuit_dict.set_item("cnf_hash_sec", circuit.cnf_hash_sec)?;
1087            result.set_item("circuit", circuit_dict)?;
1088        }
1089
1090        Ok(Some(result.into()))
1091    }
1092
1093    /// Clear the circuit template cache, forcing recompilation on next query.
1094    /// Used for cache ablation benchmarks.
1095    fn clear_circuit_cache(&mut self) {
1096        self.circuit_cache.clear();
1097    }
1098
1099    /// Return memory diagnostics including allocated_bytes and memory_limit_bytes.
1100    pub fn memory_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
1101        provider_memory_stats(py, &self.output_provider)
1102    }
1103
1104    pub fn rule_provenance(&self, py: Python<'_>) -> PyResult<PyObject> {
1105        let provenance = xlog_logic::rule_provenance(&self.ast, None);
1106        pack_rule_provenance(py, &provenance)
1107    }
1108
1109    pub fn proof_traces(&self, py: Python<'_>) -> PyResult<PyObject> {
1110        let provenance = xlog_logic::rule_provenance(&self.ast, None);
1111        let traces = xlog_logic::query_proof_traces(&self.ast, &provenance);
1112        pack_proof_traces(py, &traces)
1113    }
1114
1115    pub fn host_transfer_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
1116        let stats = self.output_provider.host_transfer_stats();
1117        let dict = PyDict::new(py);
1118        dict.set_item("dtoh_bytes", stats.dtoh_bytes)?;
1119        dict.set_item("htod_bytes", stats.htod_bytes)?;
1120        dict.set_item("dtoh_calls", stats.dtoh_calls)?;
1121        dict.set_item("htod_calls", stats.htod_calls)?;
1122        Ok(dict.into())
1123    }
1124
1125    pub fn reset_host_transfer_stats(&self) {
1126        self.output_provider.reset_host_transfer_stats()
1127    }
1128
1129    pub fn neural_hot_loop_diagnostics(&self, py: Python<'_>) -> PyResult<PyObject> {
1130        let transfers = self.output_provider.host_transfer_stats();
1131        let dict = PyDict::new(py);
1132        dict.set_item("post_load_dtoh_bytes", transfers.dtoh_bytes)?;
1133        dict.set_item("post_load_htod_bytes", transfers.htod_bytes)?;
1134        dict.set_item("post_load_dtoh_calls", transfers.dtoh_calls)?;
1135        dict.set_item("post_load_htod_calls", transfers.htod_calls)?;
1136        dict.set_item("control_plane_bytes_per_iteration", py.None())?;
1137        dict.set_item(
1138            "control_plane_status",
1139            "unavailable: per-iteration control-plane byte counter is not registered",
1140        )?;
1141        dict.set_item("scalar_sync_checks", py.None())?;
1142        dict.set_item(
1143            "scalar_sync_status",
1144            "unavailable: scalar synchronization counter is not registered",
1145        )?;
1146
1147        let cuda_graph = PyDict::new(py);
1148        cuda_graph.set_item(
1149            "csm_cuda_graph_captures",
1150            self.output_provider.csm_cuda_graph_captures(),
1151        )?;
1152        cuda_graph.set_item(
1153            "csm_cuda_graph_launches",
1154            self.output_provider.csm_cuda_graph_launches(),
1155        )?;
1156        cuda_graph.set_item(
1157            "csm_cuda_graph_fallbacks",
1158            self.output_provider.csm_cuda_graph_fallbacks(),
1159        )?;
1160        cuda_graph.set_item(
1161            "csm_cuda_graph_cache_hits",
1162            self.output_provider.csm_cuda_graph_cache_hits(),
1163        )?;
1164        dict.set_item("cuda_graph", cuda_graph)?;
1165
1166        let circuit_cache = PyDict::new(py);
1167        circuit_cache.set_item("circuit_cache_size", self.circuit_cache.len())?;
1168        circuit_cache.set_item("circuit_cache_hits", self.circuit_cache_hits)?;
1169        circuit_cache.set_item("circuit_cache_misses", self.circuit_cache_misses)?;
1170        circuit_cache.set_item("template_compile_count", self.template_compile_count)?;
1171        circuit_cache.set_item(
1172            "query_signature_cache_size",
1173            self.query_signature_cache.len(),
1174        )?;
1175        dict.set_item("circuit_cache", circuit_cache)?;
1176
1177        Ok(dict.into())
1178    }
1179
1180    pub fn cuda_graph_stats(&self, py: Python<'_>) -> PyResult<PyObject> {
1181        let dict = PyDict::new(py);
1182        dict.set_item(
1183            "csm_cuda_graph_captures",
1184            self.output_provider.csm_cuda_graph_captures(),
1185        )?;
1186        dict.set_item(
1187            "csm_cuda_graph_launches",
1188            self.output_provider.csm_cuda_graph_launches(),
1189        )?;
1190        dict.set_item(
1191            "csm_cuda_graph_fallbacks",
1192            self.output_provider.csm_cuda_graph_fallbacks(),
1193        )?;
1194        dict.set_item(
1195            "csm_cuda_graph_cache_hits",
1196            self.output_provider.csm_cuda_graph_cache_hits(),
1197        )?;
1198        Ok(dict.into())
1199    }
1200}
1201
1202fn pack_rule_provenance(
1203    py: Python<'_>,
1204    entries: &[xlog_logic::RuleProvenance],
1205) -> PyResult<PyObject> {
1206    let list = PyList::empty(py);
1207    for entry in entries {
1208        let dict = PyDict::new(py);
1209        dict.set_item("rule_id", &entry.rule_id)?;
1210        dict.set_item("head", &entry.head)?;
1211        dict.set_item("source_kind", entry.source_kind.as_str())?;
1212        dict.set_item("source_span", entry.source_span.clone())?;
1213        dict.set_item("generation_trace_hash", entry.generation_trace_hash.clone())?;
1214        dict.set_item("support_relation_ids", entry.support_relation_ids.clone())?;
1215        dict.set_item(
1216            "counterexample_relation_ids",
1217            entry.counterexample_relation_ids.clone(),
1218        )?;
1219        list.append(dict)?;
1220    }
1221    Ok(list.into())
1222}
1223
1224fn pack_proof_traces(
1225    py: Python<'_>,
1226    entries: &[xlog_logic::QueryProofTrace],
1227) -> PyResult<PyObject> {
1228    let list = PyList::empty(py);
1229    for entry in entries {
1230        let dict = PyDict::new(py);
1231        dict.set_item("query_id", &entry.query_id)?;
1232        dict.set_item("query", &entry.query)?;
1233        dict.set_item("answer_relation", &entry.answer_relation)?;
1234        dict.set_item("rule_ids", entry.rule_ids.clone())?;
1235        dict.set_item("source_facts", entry.source_facts.clone())?;
1236        dict.set_item("rejected_alternatives", entry.rejected_alternatives.clone())?;
1237        list.append(dict)?;
1238    }
1239    Ok(list.into())
1240}