Skip to main content

xlog_logic/
proof_trace.rs

1//! Differentiable proof trace records for proof-path training surfaces.
2
3use std::collections::BTreeMap;
4
/// Caller-supplied description of a single proof path to register.
///
/// A spec is consumed when it is inserted into a proof trace map: the
/// answer key, clause identifier and support atoms identify the proof, and
/// `initial_weight` seeds the weight of the stored trace.
#[derive(Debug, Clone, PartialEq)]
pub struct ProofTraceSpec {
    /// Stable key of the answer this proof derives, for example
    /// `root(case_1, primary_root)`.
    pub answer_key: String,
    /// Stable identifier of the symbolic clause used by the proof.
    pub clause_id: String,
    /// Atoms that support the proof path, in the order supplied by the caller.
    pub support_atoms: Vec<String>,
    /// Starting value for the symbolic clause weight.
    pub initial_weight: f64,
}
17
/// One differentiable proof trace as exposed to consumers of the trace map.
///
/// The identifying fields (`answer_key`, `clause_id`, `support_atoms`) are
/// fixed at creation; `weight` and `gradient` are the trainable state.
#[derive(Debug, Clone, PartialEq)]
pub struct ProofTrace {
    /// Stable proof identifier derived from the answer key, clause
    /// identifier, and support atoms.
    pub proof_id: u64,
    /// Stable key of the answer this proof derives.
    pub answer_key: String,
    /// Identifier of the symbolic clause used by the proof.
    pub clause_id: String,
    /// Atoms that support the proof path.
    pub support_atoms: Vec<String>,
    /// Current value of the symbolic clause weight.
    pub weight: f64,
    /// Gradient accumulated for the current training step.
    pub gradient: f64,
}
34
35/// Collection of differentiable proof traces keyed by stable proof ID.
36#[derive(Clone, Debug, Default)]
37pub struct DifferentiableProofTraceMap {
38    traces: BTreeMap<u64, ProofTrace>,
39}
40
41impl DifferentiableProofTraceMap {
42    /// Create an empty proof trace map.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Insert a trace and return its stable proof ID.
48    pub fn insert(&mut self, spec: ProofTraceSpec) -> u64 {
49        let mut proof_id = stable_proof_id(&spec.answer_key, &spec.clause_id, &spec.support_atoms);
50        while let Some(existing) = self.traces.get(&proof_id) {
51            if existing.answer_key == spec.answer_key
52                && existing.clause_id == spec.clause_id
53                && existing.support_atoms == spec.support_atoms
54            {
55                return proof_id;
56            }
57            proof_id = proof_id.wrapping_add(1);
58        }
59        self.traces.insert(
60            proof_id,
61            ProofTrace {
62                proof_id,
63                answer_key: spec.answer_key,
64                clause_id: spec.clause_id,
65                support_atoms: spec.support_atoms,
66                weight: spec.initial_weight,
67                gradient: 0.0,
68            },
69        );
70        proof_id
71    }
72
73    /// Get one exported trace by proof ID.
74    pub fn trace(&self, proof_id: u64) -> Option<&ProofTrace> {
75        self.traces.get(&proof_id)
76    }
77
78    /// Iterate over exported traces.
79    pub fn traces(&self) -> impl Iterator<Item = &ProofTrace> {
80        self.traces.values()
81    }
82
83    /// Accumulate binary logistic gradients grouped by answer key.
84    pub fn accumulate_binary_logistic_gradients(&mut self, targets: &[(String, f64)]) -> f64 {
85        for trace in self.traces.values_mut() {
86            trace.gradient = 0.0;
87        }
88
89        let mut loss = 0.0;
90        for (answer_key, target) in targets {
91            let score: f64 = self
92                .traces
93                .values()
94                .filter(|trace| &trace.answer_key == answer_key)
95                .map(|trace| trace.weight)
96                .sum();
97            let prediction = sigmoid(score);
98            let clamped = prediction.clamp(1e-12, 1.0 - 1e-12);
99            loss += -target * clamped.ln() - (1.0 - target) * (1.0 - clamped).ln();
100            let gradient = prediction - target;
101            for trace in self
102                .traces
103                .values_mut()
104                .filter(|trace| &trace.answer_key == answer_key)
105            {
106                trace.gradient += gradient;
107            }
108        }
109        loss
110    }
111
112    /// Apply accumulated gradients to trace weights.
113    pub fn apply_gradients(&mut self, learning_rate: f64) {
114        for trace in self.traces.values_mut() {
115            trace.weight -= learning_rate * trace.gradient;
116        }
117    }
118}
119
/// Logistic function `1 / (1 + e^(-value))`.
///
/// Saturates to `0.0` for very negative inputs (where `e^(-value)` overflows
/// to infinity) and to `1.0` for very positive inputs.
fn sigmoid(value: f64) -> f64 {
    let denominator = 1.0 + f64::exp(-value);
    // `recip()` is `1.0 / denominator`.
    denominator.recip()
}
123
/// Hash the identifying parts of a proof into a 64-bit ID.
///
/// Uses the FNV-1a scheme (xor a byte, then multiply by the FNV prime with
/// wrapping arithmetic) over the answer key, the clause identifier and each
/// support atom in order. After every part a `0xff` byte is mixed in as a
/// separator so that moving characters between adjacent parts changes the
/// byte stream being hashed; `0xff` never occurs in UTF-8 text.
fn stable_proof_id(answer_key: &str, clause_id: &str, support_atoms: &[String]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    const PART_SEPARATOR: u8 = 0xff;

    // One FNV-1a step: absorb a single byte into the running state.
    fn absorb(state: u64, byte: u8) -> u64 {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    }

    let leading_parts = [answer_key, clause_id];
    leading_parts
        .iter()
        .copied()
        .chain(support_atoms.iter().map(String::as_str))
        .fold(FNV_OFFSET_BASIS, |state, part| {
            let after_part = part.bytes().fold(state, absorb);
            absorb(after_part, PART_SEPARATOR)
        })
}