xlog_logic/
proof_trace.rs1use std::collections::BTreeMap;
/// Input describing a proof trace to register with
/// `DifferentiableProofTraceMap::insert`.
///
/// The answer key, clause id and support atoms together form the trace's
/// identity; `initial_weight` only seeds the weight of a newly created trace.
#[derive(Debug, Clone, PartialEq)]
pub struct ProofTraceSpec {
    /// Answer this proof supports; traces sharing a key are scored together.
    pub answer_key: String,
    /// Clause identifier; part of the trace's identity.
    pub clause_id: String,
    /// Supporting atoms; compared as an ordered sequence, so order matters.
    pub support_atoms: Vec<String>,
    /// Starting weight, used only when the trace is not already registered.
    pub initial_weight: f64,
}
/// A registered proof trace together with its trainable state.
#[derive(Debug, Clone, PartialEq)]
pub struct ProofTrace {
    /// Id assigned by `DifferentiableProofTraceMap::insert`.
    pub proof_id: u64,
    /// Answer this proof supports; traces sharing a key are scored together.
    pub answer_key: String,
    /// Clause identifier; part of the trace's identity.
    pub clause_id: String,
    /// Supporting atoms; compared as an ordered sequence, so order matters.
    pub support_atoms: Vec<String>,
    /// Trainable weight; summed per answer key to form that answer's score.
    pub weight: f64,
    /// Gradient of the loss with respect to `weight` from the most recent
    /// gradient accumulation (0.0 for a freshly inserted trace).
    pub gradient: f64,
}
34
35#[derive(Clone, Debug, Default)]
37pub struct DifferentiableProofTraceMap {
38 traces: BTreeMap<u64, ProofTrace>,
39}
40
41impl DifferentiableProofTraceMap {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn insert(&mut self, spec: ProofTraceSpec) -> u64 {
49 let mut proof_id = stable_proof_id(&spec.answer_key, &spec.clause_id, &spec.support_atoms);
50 while let Some(existing) = self.traces.get(&proof_id) {
51 if existing.answer_key == spec.answer_key
52 && existing.clause_id == spec.clause_id
53 && existing.support_atoms == spec.support_atoms
54 {
55 return proof_id;
56 }
57 proof_id = proof_id.wrapping_add(1);
58 }
59 self.traces.insert(
60 proof_id,
61 ProofTrace {
62 proof_id,
63 answer_key: spec.answer_key,
64 clause_id: spec.clause_id,
65 support_atoms: spec.support_atoms,
66 weight: spec.initial_weight,
67 gradient: 0.0,
68 },
69 );
70 proof_id
71 }
72
73 pub fn trace(&self, proof_id: u64) -> Option<&ProofTrace> {
75 self.traces.get(&proof_id)
76 }
77
78 pub fn traces(&self) -> impl Iterator<Item = &ProofTrace> {
80 self.traces.values()
81 }
82
83 pub fn accumulate_binary_logistic_gradients(&mut self, targets: &[(String, f64)]) -> f64 {
85 for trace in self.traces.values_mut() {
86 trace.gradient = 0.0;
87 }
88
89 let mut loss = 0.0;
90 for (answer_key, target) in targets {
91 let score: f64 = self
92 .traces
93 .values()
94 .filter(|trace| &trace.answer_key == answer_key)
95 .map(|trace| trace.weight)
96 .sum();
97 let prediction = sigmoid(score);
98 let clamped = prediction.clamp(1e-12, 1.0 - 1e-12);
99 loss += -target * clamped.ln() - (1.0 - target) * (1.0 - clamped).ln();
100 let gradient = prediction - target;
101 for trace in self
102 .traces
103 .values_mut()
104 .filter(|trace| &trace.answer_key == answer_key)
105 {
106 trace.gradient += gradient;
107 }
108 }
109 loss
110 }
111
112 pub fn apply_gradients(&mut self, learning_rate: f64) {
114 for trace in self.traces.values_mut() {
115 trace.weight -= learning_rate * trace.gradient;
116 }
117 }
118}
/// Logistic function `1 / (1 + e^-x)`.
///
/// Maps finite scores into `[0, 1]`; the result is exactly 0.0 or 1.0 once
/// `e^-x` overflows to infinity or underflows to zero. NaN propagates.
fn sigmoid(value: f64) -> f64 {
    let denominator = 1.0 + f64::exp(-value);
    denominator.recip()
}
/// Hashes a proof's identity with 64-bit FNV-1a.
///
/// The hashed byte stream is the answer key, then the clause id, then each
/// support atom, with every part followed by a `0xff` terminator byte. `0xff`
/// never occurs in UTF-8, so different part lists always produce different byte
/// streams (e.g. `("ab", "")` vs `("a", "b")`); the 64-bit hash itself can
/// still collide.
fn stable_proof_id(answer_key: &str, clause_id: &str, support_atoms: &[String]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    // Never valid in UTF-8, so it cannot be mistaken for a byte of any part.
    const PART_TERMINATOR: u8 = 0xff;

    let leading_parts = [answer_key, clause_id];
    leading_parts
        .iter()
        .copied()
        .chain(support_atoms.iter().map(String::as_str))
        .flat_map(|part| part.bytes().chain(std::iter::once(PART_TERMINATOR)))
        .fold(FNV_OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}