1use std::collections::{BTreeMap, BTreeSet};
4
5use xlog_core::symbol;
6
7use crate::ast::{AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Program, Rule, Term};
8
/// Origin category of a rule recorded in a [`RuleProvenance`].
///
/// The canonical lowercase spelling of each variant is returned by `as_str`.
/// That spelling is embedded in rule ids built by `rule_record`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuleSourceKind {
    /// Rule taken from the primary program passed to `rule_provenance`.
    Source,
    /// Rule (or bare predicate) coming from a generated program or from the
    /// `generated_predicates` list of `build_rule_provenance`.
    Generated,
    /// Not produced by this module; presumably rules obtained by rule mining — TODO confirm.
    Mined,
    /// Not produced by this module; presumably rules pulled in from another program — TODO confirm.
    Imported,
    /// Not produced by this module; presumably rules added while the engine runs — TODO confirm.
    RuntimeInjected,
}
23
24impl RuleSourceKind {
25 pub fn as_str(self) -> &'static str {
27 match self {
28 RuleSourceKind::Source => "source",
29 RuleSourceKind::Generated => "generated",
30 RuleSourceKind::Mined => "mined",
31 RuleSourceKind::Imported => "imported",
32 RuleSourceKind::RuntimeInjected => "runtime_injected",
33 }
34 }
35}
36
/// Provenance record describing where a single rule (or generated predicate)
/// came from and which relations support it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuleProvenance {
    /// Identifier of the record. In this module it has the form
    /// `rule:<kind>:<index>:<hash>` for AST rules (see `rule_record`) or
    /// `rule:generated:<index>:<predicate>` for predicate-only entries.
    pub rule_id: String,
    /// Rendered head atom such as `p(X, Y)`. Predicate-only entries store the
    /// bare predicate name.
    pub head: String,
    /// Which category of source produced the rule.
    pub source_kind: RuleSourceKind,
    /// Location hint. This module writes `rule_index:<n>` for AST rules and
    /// `None` for predicate-only generated entries.
    pub source_span: Option<String>,
    /// 16-hex-digit 64-bit FNV-1a hash of the rule's textual key, or of
    /// `generated:<predicate>` for predicate-only entries.
    pub generation_trace_hash: Option<String>,
    /// Relations the rule depends on. For AST rules these are the body
    /// predicates, deduplicated and sorted. For predicate-only entries this is
    /// the predicate itself.
    pub support_relation_ids: Vec<String>,
    /// Always left empty by this module; presumably populated by a later
    /// validation step — TODO confirm.
    pub counterexample_relation_ids: Vec<String>,
}
55
/// Explanation bundle for one query of a program: the rules that can derive
/// its predicate, the facts they draw on, and the negated conditions involved.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryProofTrace {
    /// Identifier of the form `query:source:<query index>:<rendered query atom>`.
    pub query_id: String,
    /// The query atom rendered as text, e.g. `p(X, a)`.
    pub query: String,
    /// Relation name `__xlog_query_<query index>`. Presumably this is where the
    /// engine stores the query's answers — TODO confirm.
    pub answer_relation: String,
    /// Ids of provenance records whose head predicate matches the query's
    /// predicate. Facts for that predicate are included too.
    pub rule_ids: Vec<String>,
    /// Facts (rendered with a trailing `.`) whose predicate appears in the body
    /// of a non-fact rule deriving the query predicate. The list is
    /// deduplicated and sorted.
    pub source_facts: Vec<String>,
    /// Negated body literals (`not p(...)`) of those deriving rules, in rule
    /// and body order, without deduplication.
    pub rejected_alternatives: Vec<String>,
}
72
73pub fn rule_provenance(
75 program: &Program,
76 generated_program: Option<&Program>,
77) -> Vec<RuleProvenance> {
78 let mut out = Vec::new();
79 let mut source_keys = BTreeSet::new();
80
81 for (idx, rule) in program.rules.iter().enumerate() {
82 source_keys.insert(rule_key(rule));
83 out.push(rule_record(idx, rule, RuleSourceKind::Source));
84 }
85
86 if let Some(generated) = generated_program {
87 let mut generated_idx = 0usize;
88 for rule in &generated.rules {
89 if source_keys.contains(&rule_key(rule)) {
90 continue;
91 }
92 out.push(rule_record(generated_idx, rule, RuleSourceKind::Generated));
93 generated_idx += 1;
94 }
95 }
96
97 out
98}
99
100pub fn build_rule_provenance(
102 program: &Program,
103 generated_predicates: &[String],
104) -> Vec<RuleProvenance> {
105 let mut out = rule_provenance(program, None);
106 for (idx, predicate) in generated_predicates.iter().enumerate() {
107 out.push(RuleProvenance {
108 rule_id: format!("rule:generated:{}:{}", idx, predicate),
109 head: predicate.clone(),
110 source_kind: RuleSourceKind::Generated,
111 source_span: None,
112 generation_trace_hash: Some(stable_hash(&format!("generated:{}", predicate))),
113 support_relation_ids: vec![predicate.clone()],
114 counterexample_relation_ids: Vec::new(),
115 });
116 }
117 out
118}
119
120pub fn query_proof_traces(
122 program: &Program,
123 provenance: &[RuleProvenance],
124) -> Vec<QueryProofTrace> {
125 let mut rule_ids_by_head: BTreeMap<String, Vec<String>> = BTreeMap::new();
126 for entry in provenance {
127 rule_ids_by_head
128 .entry(head_predicate(&entry.head).to_string())
129 .or_default()
130 .push(entry.rule_id.clone());
131 }
132
133 program
134 .queries
135 .iter()
136 .enumerate()
137 .map(|(idx, query)| {
138 let query_pred = query.atom.predicate.clone();
139 let deriving_rules: Vec<&Rule> = program
140 .rules
141 .iter()
142 .filter(|rule| !rule.is_fact() && rule.head.predicate == query_pred)
143 .collect();
144 let rule_ids = rule_ids_by_head
145 .get(&query_pred)
146 .cloned()
147 .unwrap_or_default();
148 let source_facts = source_facts_for_rules(program, &deriving_rules);
149 let rejected_alternatives = deriving_rules
150 .iter()
151 .flat_map(|rule| {
152 rule.body.iter().filter_map(|lit| match lit {
153 BodyLiteral::Negated(atom) => Some(format!("not {}", format_atom(atom))),
154 _ => None,
155 })
156 })
157 .collect::<Vec<_>>();
158
159 QueryProofTrace {
160 query_id: format!("query:source:{}:{}", idx, format_atom(&query.atom)),
161 query: format_atom(&query.atom),
162 answer_relation: format!("__xlog_query_{}", idx),
163 rule_ids,
164 source_facts,
165 rejected_alternatives,
166 }
167 })
168 .collect()
169}
170
171pub fn build_query_proof_traces(program: &Program) -> Vec<QueryProofTrace> {
173 let provenance = rule_provenance(program, None);
174 query_proof_traces(program, &provenance)
175}
176
177fn rule_record(idx: usize, rule: &Rule, source_kind: RuleSourceKind) -> RuleProvenance {
178 let head = format_atom(&rule.head);
179 let prefix = source_kind.as_str();
180 RuleProvenance {
181 rule_id: format!("rule:{}:{}:{}", prefix, idx, stable_hash(&rule_key(rule))),
182 head,
183 source_kind,
184 source_span: Some(format!("rule_index:{}", idx)),
185 generation_trace_hash: Some(stable_hash(&rule_key(rule))),
186 support_relation_ids: support_relation_ids(rule),
187 counterexample_relation_ids: Vec::new(),
188 }
189}
190
191fn support_relation_ids(rule: &Rule) -> Vec<String> {
192 rule.body_predicates()
193 .into_iter()
194 .map(str::to_string)
195 .collect::<BTreeSet<_>>()
196 .into_iter()
197 .collect()
198}
199
200fn source_facts_for_rules(program: &Program, rules: &[&Rule]) -> Vec<String> {
201 let wanted: BTreeSet<String> = rules
202 .iter()
203 .flat_map(|rule| {
204 rule.body
205 .iter()
206 .filter_map(|lit| lit.atom().map(|atom| atom.predicate.clone()))
207 })
208 .collect();
209
210 let mut facts = BTreeSet::new();
211 for fact in program.facts() {
212 if wanted.contains(&fact.head.predicate) {
213 facts.insert(format!("{}.", format_atom(&fact.head)));
214 }
215 }
216 facts.into_iter().collect()
217}
218
219fn rule_key(rule: &Rule) -> String {
220 let mut out = format_atom(&rule.head);
221 if !rule.body.is_empty() {
222 let body = rule
223 .body
224 .iter()
225 .map(format_body_literal)
226 .collect::<Vec<_>>()
227 .join(", ");
228 out.push_str(" :- ");
229 out.push_str(&body);
230 }
231 out
232}
233
/// Extracts the predicate name from a rendered head such as `p(X, Y)`.
///
/// Returns everything before the first `(`, or the whole input when it has no
/// parenthesis (e.g. a bare predicate name).
fn head_predicate(head: &str) -> &str {
    match head.find('(') {
        Some(open) => &head[..open],
        None => head,
    }
}
237
238pub fn format_atom(atom: &Atom) -> String {
240 let args = atom
241 .terms
242 .iter()
243 .map(format_term)
244 .collect::<Vec<_>>()
245 .join(", ");
246 format!("{}({})", atom.predicate, args)
247}
248
249fn format_body_literal(lit: &BodyLiteral) -> String {
250 match lit {
251 BodyLiteral::Positive(atom) => format_atom(atom),
252 BodyLiteral::Negated(atom) => format!("not {}", format_atom(atom)),
253 BodyLiteral::Epistemic(lit) => format_epistemic_literal(lit),
254 BodyLiteral::Comparison(comparison) => format!(
255 "{} {} {}",
256 format_term(&comparison.left),
257 format_comp_op(comparison.op),
258 format_term(&comparison.right)
259 ),
260 BodyLiteral::IsExpr(is_expr) => {
261 format!("{} is {}", is_expr.target, format_arith_expr(&is_expr.expr))
262 }
263 BodyLiteral::Univ(univ) => {
264 format!(
265 "{} =.. {}",
266 format_term(&univ.term),
267 format_term(&univ.parts)
268 )
269 }
270 }
271}
272
273fn format_epistemic_literal(lit: &crate::ast::EpistemicLiteral) -> String {
274 let op = match lit.op {
275 crate::ast::EpistemicOp::Know => "know",
276 crate::ast::EpistemicOp::Possible => "possible",
277 };
278 if lit.negated {
279 format!("not {} {}", op, format_atom(&lit.atom))
280 } else {
281 format!("{} {}", op, format_atom(&lit.atom))
282 }
283}
284
285fn format_term(term: &Term) -> String {
286 match term {
287 Term::Variable(name) => name.clone(),
288 Term::Anonymous => "_".to_string(),
289 Term::Integer(value) => value.to_string(),
290 Term::Float(value) => value.to_string(),
291 Term::String(value) => format!("\"{}\"", value),
292 Term::Symbol(id) => symbol::resolve(*id),
293 Term::List(items) => {
294 let values = items.iter().map(format_term).collect::<Vec<_>>().join(", ");
295 format!("[{}]", values)
296 }
297 Term::Cons { head, tail } => {
298 format!("[{} | {}]", format_term(head), format_term(tail))
299 }
300 Term::Compound { functor, args } => {
301 let values = args.iter().map(format_term).collect::<Vec<_>>().join(", ");
302 format!("{}({})", functor, values)
303 }
304 Term::PredRef(name) => name.clone(),
305 Term::Aggregate(agg) => format!("{}({})", format_agg_op(agg.op), agg.variable),
306 }
307}
308
309fn format_arith_expr(expr: &ArithExpr) -> String {
310 match expr {
311 ArithExpr::Variable(name) => name.clone(),
312 ArithExpr::Integer(value) => value.to_string(),
313 ArithExpr::Float(value) => value.to_string(),
314 ArithExpr::Add(left, right) => {
315 format!(
316 "({} + {})",
317 format_arith_expr(left),
318 format_arith_expr(right)
319 )
320 }
321 ArithExpr::Sub(left, right) => {
322 format!(
323 "({} - {})",
324 format_arith_expr(left),
325 format_arith_expr(right)
326 )
327 }
328 ArithExpr::Mul(left, right) => {
329 format!(
330 "({} * {})",
331 format_arith_expr(left),
332 format_arith_expr(right)
333 )
334 }
335 ArithExpr::Div(left, right) => {
336 format!(
337 "({} / {})",
338 format_arith_expr(left),
339 format_arith_expr(right)
340 )
341 }
342 ArithExpr::Mod(left, right) => {
343 format!(
344 "({} % {})",
345 format_arith_expr(left),
346 format_arith_expr(right)
347 )
348 }
349 ArithExpr::Abs(value) => format!("abs({})", format_arith_expr(value)),
350 ArithExpr::Min(left, right) => {
351 format!(
352 "min({}, {})",
353 format_arith_expr(left),
354 format_arith_expr(right)
355 )
356 }
357 ArithExpr::Max(left, right) => {
358 format!(
359 "max({}, {})",
360 format_arith_expr(left),
361 format_arith_expr(right)
362 )
363 }
364 ArithExpr::Pow(left, right) => {
365 format!(
366 "pow({}, {})",
367 format_arith_expr(left),
368 format_arith_expr(right)
369 )
370 }
371 ArithExpr::Cast(value, ty) => format!("cast({}, {:?})", format_arith_expr(value), ty),
372 ArithExpr::FuncCall { name, args } => {
373 let values = args
374 .iter()
375 .map(format_arith_expr)
376 .collect::<Vec<_>>()
377 .join(", ");
378 format!("{}({})", name, values)
379 }
380 ArithExpr::Conditional {
381 cond_left,
382 cond_op,
383 cond_right,
384 then_expr,
385 else_expr,
386 } => format!(
387 "if {} {} {} then {} else {}",
388 format_arith_expr(cond_left),
389 format_comp_op(*cond_op),
390 format_arith_expr(cond_right),
391 format_arith_expr(then_expr),
392 format_arith_expr(else_expr)
393 ),
394 }
395}
396
397fn format_comp_op(op: CompOp) -> &'static str {
398 match op {
399 CompOp::Eq => "==",
400 CompOp::Ne => "!=",
401 CompOp::Lt => "<",
402 CompOp::Le => "<=",
403 CompOp::Gt => ">",
404 CompOp::Ge => ">=",
405 }
406}
407
408fn format_agg_op(op: AggOp) -> &'static str {
409 match op {
410 AggOp::Count => "count",
411 AggOp::Sum => "sum",
412 AggOp::Min => "min",
413 AggOp::Max => "max",
414 AggOp::LogSumExp => "logsumexp",
415 }
416}
417
/// 64-bit FNV-1a hash of `value`'s UTF-8 bytes, as 16 lowercase hex digits.
///
/// Unlike `std`'s randomly seeded hashers, the output is the same on every run
/// and platform, which keeps rule ids reproducible.
fn stable_hash(value: &str) -> String {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = value
        .bytes()
        .fold(FNV_OFFSET_BASIS, |acc, byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });
    format!("{:016x}", digest)
}