1use std::collections::BTreeMap;
30use std::sync::Arc;
31
32use cudarc::driver::LaunchConfig;
33use xlog_core::{Result, XlogError};
34use xlog_cuda::memory::TrackedCudaSlice;
35use xlog_cuda::provider::{mc_resident_kernels, MC_RESIDENT_MODULE};
36use xlog_cuda::{CudaKernelProvider, LaunchAsync};
37use xlog_logic::ast::{Atom, BodyLiteral, Term};
38
39use super::{McEvalConfig, McProgram, McSamplingMethod};
40use crate::provenance::{GroundAtom, Value};
41
/// Largest predicate arity the resident engine can encode.
pub const MAX_ARITY: usize = 3;
/// Largest number of body literals allowed in one rule.
pub const MAX_BODY: usize = 3;
/// Largest number of distinct variables allowed in one rule.
pub const MAX_VARS: usize = 8;
/// Largest total number of ground-atom slots across all predicates.
pub const MAX_UNIVERSE: usize = 65_536;
/// Largest number of distinct ground constants in the bounded domain.
pub const MAX_DOMAIN: usize = 256;
/// Words per encoded atom: base slot, arity, three argument specs, and the
/// stride of argument 0.
const ATOM_REC: usize = 6;
/// Words per encoded rule: body length, variable count, domain size, then one
/// head atom record followed by `MAX_BODY` body atom records.
const RULE_REC: usize = 3 + (1 + MAX_BODY) * ATOM_REC;
/// High bit of an argument spec; when set, the low bits are a domain index
/// (a constant) instead of a variable id.
const CONST_FLAG: u32 = 1u32 << 31;
/// Optional device-memory budget, in bytes, checked before a resident run.
const RESIDENT_BUDGET_ENV: &str = "XLOG_MC_RESIDENT_MEMORY_BUDGET_BYTES";
/// Optional number of thread blocks per sampled world (defaults to 1).
const RESIDENT_BLOCKS_PER_WORLD_ENV: &str = "XLOG_MC_RESIDENT_BLOCKS_PER_WORLD";
61
/// Category of construct that made [`compile_resident_plan`] reject a program.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResidentRejectKind {
    /// A negated body literal.
    Negation,
    /// An epistemic body literal.
    EpistemicLiteral,
    /// A comparison, `is`, or univ body literal.
    NonRelationalLiteral,
    /// An atom with more than [`MAX_ARITY`] arguments.
    ArityTooHigh,
    /// A rule body with more than [`MAX_BODY`] literals.
    BodyTooLong,
    /// A rule with more than [`MAX_VARS`] distinct variables.
    TooManyVars,
    /// A term that cannot be mapped into the bounded constant domain. This covers:
    /// - a rule term that is neither a variable nor a ground constant;
    /// - a variable or non-constant term in a fact;
    /// - a constant missing from the domain;
    /// - a decision-variable index that does not fit in `u32`.
    UnboundedTerm,
    /// More than [`MAX_DOMAIN`] distinct ground constants.
    DomainTooLarge,
    /// Predicate slot blocks exceed [`MAX_UNIVERSE`] slots in total.
    UniverseTooLarge,
    /// A predicate used with two different arities, an arity mismatch on a
    /// ground atom, or an atom whose predicate has no slot block.
    InconsistentArity,
    /// Annotated disjunctions the engine cannot handle.
    /// NOTE(review): `compile_resident_plan` lowers ADs and never produces
    /// this kind; confirm whether it is still reachable elsewhere.
    AnnotatedDisjunctionUnsupported,
}
88
89impl ResidentRejectKind {
90 pub fn as_str(self) -> &'static str {
91 match self {
92 ResidentRejectKind::Negation => "negation",
93 ResidentRejectKind::EpistemicLiteral => "epistemic_literal",
94 ResidentRejectKind::NonRelationalLiteral => "non_relational_literal",
95 ResidentRejectKind::ArityTooHigh => "arity_too_high",
96 ResidentRejectKind::BodyTooLong => "body_too_long",
97 ResidentRejectKind::TooManyVars => "too_many_vars",
98 ResidentRejectKind::UnboundedTerm => "unbounded_term",
99 ResidentRejectKind::DomainTooLarge => "domain_too_large",
100 ResidentRejectKind::UniverseTooLarge => "universe_too_large",
101 ResidentRejectKind::InconsistentArity => "inconsistent_arity",
102 ResidentRejectKind::AnnotatedDisjunctionUnsupported => {
103 "annotated_disjunction_unsupported"
104 }
105 }
106 }
107}
108
/// Reason [`compile_resident_plan`] refused a program.
///
/// Holds the rejection category, the offending construct (a predicate name,
/// a debug-printed term, or a short description), and free-form context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResidentRejection {
    /// Rejection category.
    pub kind: ResidentRejectKind,
    /// The construct that triggered the rejection.
    pub construct: String,
    /// Detail about the check that failed.
    pub context: String,
}
117
118impl ResidentRejection {
119 fn err(
120 kind: ResidentRejectKind,
121 construct: impl Into<String>,
122 context: impl Into<String>,
123 ) -> Self {
124 ResidentRejection {
125 kind,
126 construct: construct.into(),
127 context: context.into(),
128 }
129 }
130
131 pub fn into_error(self) -> XlogError {
134 XlogError::Compilation(format!(
135 "resident MC engine rejected program [kind={}] construct=`{}` context=`{}`",
136 self.kind.as_str(),
137 self.construct,
138 self.context
139 ))
140 }
141}
142
/// Totally ordered key for a ground constant in the bounded domain.
///
/// Used as the key of the domain `BTreeMap`. Floats are keyed by their bit
/// pattern, so bit-distinct values (e.g. `0.0` and `-0.0`) are distinct
/// constants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum ConstKey {
    /// Integer constant.
    Int(i64),
    /// Symbol constant. NOTE(review): presumably an interned symbol id; confirm.
    Sym(u32),
    /// String constant.
    Str(String),
    /// Float constant, stored as its `f64` bit pattern.
    FloatBits(u64),
}
152
153impl ConstKey {
154 fn from_value(v: &Value) -> ConstKey {
155 match v {
156 Value::I64(i) => ConstKey::Int(*i),
157 Value::Symbol(s) => ConstKey::Sym(*s),
158 Value::String(s) => ConstKey::Str(s.clone()),
159 Value::F64(bits) => ConstKey::FloatBits(*bits),
160 }
161 }
162
163 fn from_term(t: &Term) -> std::result::Result<TermClass, ResidentRejection> {
165 match t {
166 Term::Variable(name) => Ok(TermClass::Var(name.clone())),
167 Term::Integer(i) => Ok(TermClass::Const(ConstKey::Int(*i))),
168 Term::Symbol(s) => Ok(TermClass::Const(ConstKey::Sym(*s))),
169 Term::String(s) => Ok(TermClass::Const(ConstKey::Str(s.clone()))),
170 Term::Float(f) => Ok(TermClass::Const(ConstKey::FloatBits(f.to_bits()))),
171 other => Err(ResidentRejection::err(
172 ResidentRejectKind::UnboundedTerm,
173 format!("{:?}", other),
174 "rule term must be a variable or ground constant",
175 )),
176 }
177 }
178}
179
/// Classification of a rule term: a named variable or a ground constant.
enum TermClass {
    /// Variable, identified by its name.
    Var(String),
    /// Ground constant.
    Const(ConstKey),
}
184
/// Host-interaction counters captured around the resident engine launch.
///
/// `run_resident` fills the transfer/allocation fields as deltas of the
/// provider's counters. The first snapshot is taken just before the kernel
/// launch; the second just after the post-launch synchronize.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct McNoHostStats {
    /// Tracked host-to-device copies in the launch window.
    pub tracked_htod_calls: u64,
    /// Tracked device-to-host copies in the launch window.
    pub tracked_dtoh_calls: u64,
    /// Untracked metadata device-to-host reads in the launch window.
    pub untracked_metadata_reads: u64,
    /// Engine kernel launches (`run_resident` launches exactly one).
    pub engine_launches: u64,
    /// Host-side loop iterations driving the engine; `run_resident` reports 0.
    pub host_loop_iterations: u64,
    /// Host-issued per-sample kernel launches; `run_resident` reports 0.
    pub per_sample_host_launches: u64,
    /// Host-driven fixpoint iterations; `run_resident` reports 0.
    pub host_fixpoint_iterations: u64,
    /// Allocations made through `provider.memory()` in the launch window.
    pub per_operator_host_allocations: u64,
}
211
212impl McNoHostStats {
213 pub fn is_no_host(&self) -> bool {
219 self.tracked_htod_calls == 0
220 && self.tracked_dtoh_calls == 0
221 && self.untracked_metadata_reads == 0
222 && self.host_loop_iterations == 0
223 && self.per_sample_host_launches == 0
224 && self.host_fixpoint_iterations == 0
225 && self.per_operator_host_allocations == 0
226 }
227}
228
/// Output of one resident MC run.
///
/// The `TrackedCudaSlice` buffers are left on the device; `run_resident`
/// performs no device-to-host copy.
pub struct McResidentResult {
    /// One u32 per query (at least one element), zeroed before launch and
    /// written by the kernel.
    /// NOTE(review): presumably per-query hit counts; confirm against the kernel.
    pub query_counts: TrackedCudaSlice<u32>,
    /// A single u32, zeroed before launch.
    /// NOTE(review): presumably the number of evidence-consistent worlds; confirm.
    pub evidence_count: TrackedCudaSlice<u32>,
    /// One u32 per world (at least one), zeroed before launch.
    pub iter_trace: TrackedCudaSlice<u32>,
    /// One u32 per world (at least one), from the sparse-join scratch area.
    pub sparse_final_row_counts: TrackedCudaSlice<u32>,
    /// `num_worlds + 1` u32 words (at least one), zeroed before launch.
    pub sparse_offsets: TrackedCudaSlice<u32>,
    /// `4 * num_worlds + 1` u32 words, zeroed before launch.
    pub resident_status_flags: TrackedCudaSlice<u32>,
    /// Number of samples requested (`cfg.samples`).
    pub total_samples: usize,
    /// RNG seed used for sampling (`cfg.seed`).
    pub seed: u64,
    /// Confidence level copied from the config (`cfg.confidence`).
    pub confidence: f64,
    /// Sampling method actually used, after resolution.
    pub sampling_method: McSamplingMethod,
    /// Number of queries (length of the meaningful prefix of `query_counts`).
    pub num_queries: usize,
    /// Host-interaction counters measured around the launch.
    pub no_host: McNoHostStats,
}
252
/// Bounded, slot-addressed lowering of an [`McProgram`].
///
/// Produced by [`compile_resident_plan`] and uploaded by the resident engine.
#[derive(Debug, Clone)]
pub struct ResidentPlan {
    /// Total number of ground-atom slots across all predicates.
    pub universe_size: u32,
    /// Number of distinct ground constants.
    pub domain_size: u32,
    /// Iteration cap passed to the kernel (`universe_size + 1`, saturating).
    pub max_iters: u32,
    /// Slot of each deterministic fact.
    edb_slots: Vec<u32>,
    /// Slot of each probabilistic fact.
    pf_slot: Vec<u32>,
    /// Bernoulli variable index of each probabilistic fact, parallel to `pf_slot`.
    pf_var: Vec<u32>,
    /// One `RULE_REC`-word record per non-fact rule.
    rule_data: Vec<u32>,
    /// Number of records in `rule_data`.
    num_rules: u32,
    /// Slot of each query atom.
    q_slot: Vec<u32>,
    /// Slot of each evidence atom.
    ev_slot: Vec<u32>,
    /// Expected truth value (1 or 0) of each evidence atom, parallel to `ev_slot`.
    ev_expected: Vec<u8>,
    /// Per annotated disjunction, in order: choice count, decision-var count,
    /// the decision-var indices, then the slot of each choice.
    ad_data: Vec<u32>,
    /// Number of annotated disjunctions encoded in `ad_data`.
    num_ads: u32,
    /// Number of Bernoulli variables (`bernoulli_probs.len()`).
    pub num_vars: usize,
    /// Per-variable Bernoulli probabilities handed to the device sampler.
    bernoulli_probs: Vec<f32>,
}
272
273fn resident_memory_budget_bytes() -> Result<Option<u64>> {
274 match std::env::var(RESIDENT_BUDGET_ENV) {
275 Ok(raw) => raw.parse::<u64>().map(Some).map_err(|e| {
276 XlogError::Execution(format!("invalid {RESIDENT_BUDGET_ENV} value `{raw}`: {e}"))
277 }),
278 Err(std::env::VarError::NotPresent) => Ok(None),
279 Err(e) => Err(XlogError::Execution(format!(
280 "invalid {RESIDENT_BUDGET_ENV}: {e}"
281 ))),
282 }
283}
284
285fn resident_blocks_per_world() -> Result<u32> {
286 match std::env::var(RESIDENT_BLOCKS_PER_WORLD_ENV) {
287 Ok(raw) => {
288 let blocks = raw.parse::<u32>().map_err(|e| {
289 XlogError::Execution(format!(
290 "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV} value `{raw}`: {e}"
291 ))
292 })?;
293 if blocks == 0 {
294 return Err(XlogError::Execution(format!(
295 "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV} value `{raw}`: must be >= 1"
296 )));
297 }
298 Ok(blocks)
299 }
300 Err(std::env::VarError::NotPresent) => Ok(1),
301 Err(e) => Err(XlogError::Execution(format!(
302 "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV}: {e}"
303 ))),
304 }
305}
306
/// Multiplies two `u64`s, clamping to `u64::MAX` instead of overflowing.
fn sat_mul(a: u64, b: u64) -> u64 {
    u64::saturating_mul(a, b)
}
310
/// Computes `base^exp`, clamping to `u64::MAX` instead of overflowing.
///
/// Delegates to the standard library rather than hand-rolling exponentiation
/// by squaring; the results are identical (`0^0 == 1`).
fn sat_pow(base: u64, exp: u32) -> u64 {
    base.saturating_pow(exp)
}
324
325fn estimate_resident_bound_bytes(plan: &ResidentPlan, num_worlds: u32) -> u64 {
326 let worlds = num_worlds.max(1) as u64;
327 let vars = plan.num_vars.max(1) as u64;
328 let universe = plan.universe_size.max(1) as u64;
329 let meta_words = plan
330 .edb_slots
331 .len()
332 .saturating_add(plan.pf_slot.len())
333 .saturating_add(plan.pf_var.len())
334 .saturating_add(plan.rule_data.len())
335 .saturating_add(plan.q_slot.len())
336 .saturating_add(plan.ev_slot.len())
337 .saturating_add(plan.ev_expected.len())
338 .saturating_add(plan.ad_data.len())
339 .saturating_add(18);
340
341 let sparse_cap = universe;
342 let setup_bytes = sat_mul(2, vars)
343 .saturating_add(sat_mul(worlds, vars))
344 .saturating_add(sat_mul(sat_mul(sat_mul(worlds, 2), universe), 4))
345 .saturating_add(sat_mul(sat_mul(sat_mul(worlds, 2), sparse_cap), 16))
346 .saturating_add(sat_mul(sat_mul(worlds, 2), 4))
347 .saturating_add(sat_mul(worlds, 4))
348 .saturating_add(sat_mul(worlds.saturating_add(1), 4))
349 .saturating_add(sat_mul(worlds.saturating_mul(4).saturating_add(1), 4))
350 .saturating_add(sat_mul(meta_words as u64, 4))
351 .saturating_add(sat_mul(plan.q_slot.len().max(1) as u64, 4))
352 .saturating_add(4)
353 .saturating_add(sat_mul(worlds, 4));
354
355 let mut sparse_join_bytes = 0u64;
356 for rule in plan.rule_data.chunks_exact(RULE_REC) {
357 let n_body = rule[0];
358 if n_body < 2 {
359 continue;
360 }
361 let n_vars = rule[1];
362 let assignments = sat_pow(plan.domain_size.max(1) as u64, n_vars);
363 let row_words = (n_body as u64).saturating_add(1);
364 sparse_join_bytes =
365 sparse_join_bytes.max(sat_mul(sat_mul(worlds, assignments), row_words * 4));
366 }
367
368 setup_bytes.saturating_add(sparse_join_bytes)
369}
370
/// Slot layout of one predicate inside the universe.
struct PredInfo {
    /// Number of arguments.
    arity: usize,
    /// First slot of the predicate's contiguous block of
    /// `domain_size^arity` slots (one slot for arity 0).
    base: u32,
}
375
/// Bounded constant domain plus the slot layout of every predicate.
struct Universe {
    /// Dense index of each ground constant, assigned in first-seen order.
    domain: BTreeMap<ConstKey, u32>,
    /// Arity and base slot of each predicate, keyed by predicate name.
    preds: BTreeMap<String, PredInfo>,
    /// Number of constants in `domain`; the per-argument radix of slot addresses.
    domain_size: u32,
}
381
382impl Universe {
383 fn stride0(&self, arity: usize) -> u32 {
384 if arity >= 2 {
385 self.domain_size.pow((arity - 1) as u32)
386 } else {
387 1
388 }
389 }
390
391 fn arg_stride(&self, arity: usize, arg_idx: usize) -> u32 {
392 if arity <= arg_idx + 1 {
393 1
394 } else {
395 self.domain_size.pow((arity - arg_idx - 1) as u32)
396 }
397 }
398
399 fn ground_slot(&self, atom: &GroundAtom) -> std::result::Result<u32, ResidentRejection> {
401 let info = self.preds.get(&atom.predicate).ok_or_else(|| {
402 ResidentRejection::err(
403 ResidentRejectKind::InconsistentArity,
404 atom.predicate.clone(),
405 "ground atom references unknown predicate",
406 )
407 })?;
408 if atom.args.len() != info.arity {
409 return Err(ResidentRejection::err(
410 ResidentRejectKind::InconsistentArity,
411 atom.predicate.clone(),
412 format!("expected arity {} got {}", info.arity, atom.args.len()),
413 ));
414 }
415 let mut slot = info.base;
416 for (i, v) in atom.args.iter().enumerate() {
417 let key = ConstKey::from_value(v);
418 let idx = *self.domain.get(&key).ok_or_else(|| {
419 ResidentRejection::err(
420 ResidentRejectKind::UnboundedTerm,
421 format!("{:?}", v),
422 "ground constant absent from bounded domain",
423 )
424 })?;
425 slot += idx * self.arg_stride(info.arity, i);
426 }
427 Ok(slot)
428 }
429}
430
431pub fn compile_resident_plan(
434 mc: &McProgram,
435) -> std::result::Result<ResidentPlan, ResidentRejection> {
436 let program = &mc.program;
437
438 let mut arities: BTreeMap<String, usize> = BTreeMap::new();
440 let mut note_pred = |pred: &str, arity: usize| -> std::result::Result<(), ResidentRejection> {
441 if arity > MAX_ARITY {
442 return Err(ResidentRejection::err(
443 ResidentRejectKind::ArityTooHigh,
444 pred.to_string(),
445 format!("arity {} exceeds max {}", arity, MAX_ARITY),
446 ));
447 }
448 match arities.get(pred) {
449 Some(&existing) if existing != arity => Err(ResidentRejection::err(
450 ResidentRejectKind::InconsistentArity,
451 pred.to_string(),
452 format!("arity {} vs {}", existing, arity),
453 )),
454 _ => {
455 arities.insert(pred.to_string(), arity);
456 Ok(())
457 }
458 }
459 };
460
461 let mut domain: BTreeMap<ConstKey, u32> = BTreeMap::new();
463 let mut note_const = |key: ConstKey, domain: &mut BTreeMap<ConstKey, u32>| {
464 let next = domain.len() as u32;
465 domain.entry(key).or_insert(next);
466 };
467
468 for fact in program.facts() {
470 note_pred(&fact.head.predicate, fact.head.terms.len())?;
471 for t in &fact.head.terms {
472 match ConstKey::from_term(t)? {
473 TermClass::Const(k) => note_const(k, &mut domain),
474 TermClass::Var(_) => {
475 return Err(ResidentRejection::err(
476 ResidentRejectKind::UnboundedTerm,
477 fact.head.predicate.clone(),
478 "fact head contains a variable",
479 ))
480 }
481 }
482 }
483 }
484 for pf in &mc.prob_facts {
486 note_pred(&pf.atom.predicate, pf.atom.args.len())?;
487 for v in &pf.atom.args {
488 note_const(ConstKey::from_value(v), &mut domain);
489 }
490 }
491 for q in &mc.queries {
493 note_pred(&q.predicate, q.args.len())?;
494 for v in &q.args {
495 note_const(ConstKey::from_value(v), &mut domain);
496 }
497 }
498 for (e, _) in &mc.evidence {
499 note_pred(&e.predicate, e.args.len())?;
500 for v in &e.args {
501 note_const(ConstKey::from_value(v), &mut domain);
502 }
503 }
504 for ad in &mc.annotated_disjunctions {
506 for atom in &ad.choices {
507 note_pred(&atom.predicate, atom.args.len())?;
508 for v in &atom.args {
509 note_const(ConstKey::from_value(v), &mut domain);
510 }
511 }
512 }
513 for rule in &program.rules {
515 if rule.is_fact() {
516 continue;
517 }
518 note_pred(&rule.head.predicate, rule.head.terms.len())?;
519 collect_atom_consts(&rule.head, &mut domain, &mut note_const)?;
520 if rule.body.len() > MAX_BODY {
521 return Err(ResidentRejection::err(
522 ResidentRejectKind::BodyTooLong,
523 rule.head.predicate.clone(),
524 format!("body length {} exceeds max {}", rule.body.len(), MAX_BODY),
525 ));
526 }
527 for lit in &rule.body {
528 let atom = classify_body_literal(lit, &rule.head.predicate)?;
529 note_pred(&atom.predicate, atom.terms.len())?;
530 collect_atom_consts(atom, &mut domain, &mut note_const)?;
531 }
532 }
533
534 if domain.len() > MAX_DOMAIN {
535 return Err(ResidentRejection::err(
536 ResidentRejectKind::DomainTooLarge,
537 format!("{} constants", domain.len()),
538 format!("domain exceeds max {}", MAX_DOMAIN),
539 ));
540 }
541 let domain_size = domain.len() as u32;
542
543 let mut preds: BTreeMap<String, PredInfo> = BTreeMap::new();
545 let mut base: u64 = 0;
546 for (pred, &arity) in &arities {
547 let slot_count: u64 = if arity == 0 {
548 1
549 } else {
550 (domain_size as u64).pow(arity as u32)
551 };
552 preds.insert(
553 pred.clone(),
554 PredInfo {
555 arity,
556 base: base as u32,
557 },
558 );
559 base += slot_count;
560 if base > MAX_UNIVERSE as u64 {
561 return Err(ResidentRejection::err(
562 ResidentRejectKind::UniverseTooLarge,
563 format!("{} slots", base),
564 format!("universe exceeds max {}", MAX_UNIVERSE),
565 ));
566 }
567 }
568 let universe_size = base as u32;
569
570 let universe = Universe {
571 domain,
572 preds,
573 domain_size,
574 };
575
576 let mut edb_slots = Vec::new();
578 for fact in program.facts() {
579 let ga = ground_atom_from_atom(&fact.head)?;
580 edb_slots.push(universe.ground_slot(&ga)?);
581 }
582 let mut pf_slot = Vec::new();
583 let mut pf_var = Vec::new();
584 for pf in &mc.prob_facts {
585 pf_slot.push(universe.ground_slot(&pf.atom)?);
586 pf_var.push(pf.var_idx as u32);
587 }
588 let mut q_slot = Vec::new();
589 for q in &mc.queries {
590 q_slot.push(universe.ground_slot(q)?);
591 }
592 let mut ev_slot = Vec::new();
593 let mut ev_expected = Vec::new();
594 for (e, v) in &mc.evidence {
595 ev_slot.push(universe.ground_slot(e)?);
596 ev_expected.push(if *v { 1u8 } else { 0u8 });
597 }
598
599 let mut rule_data = Vec::new();
601 let mut num_rules = 0u32;
602 for rule in &program.rules {
603 if rule.is_fact() {
604 continue;
605 }
606 let rec = lower_rule(rule, &universe)?;
607 rule_data.extend_from_slice(&rec);
608 num_rules += 1;
609 }
610
611 let mut ad_data: Vec<u32> = Vec::new();
616 let mut num_ads = 0u32;
617 for ad in &mc.annotated_disjunctions {
618 let n_choices = ad.choices.len() as u32;
619 let n_dvars = ad.decision_vars.len() as u32;
620 ad_data.push(n_choices);
621 ad_data.push(n_dvars);
622 for &dv in &ad.decision_vars {
623 ad_data.push(u32::try_from(dv).map_err(|_| {
624 ResidentRejection::err(
625 ResidentRejectKind::UnboundedTerm,
626 "decision_var",
627 "AD decision var index exceeds u32",
628 )
629 })?);
630 }
631 for atom in &ad.choices {
632 ad_data.push(universe.ground_slot(atom)?);
633 }
634 num_ads += 1;
635 }
636
637 let max_iters = universe_size.saturating_add(1).max(1);
641
642 Ok(ResidentPlan {
643 universe_size,
644 domain_size,
645 max_iters,
646 edb_slots,
647 pf_slot,
648 pf_var,
649 rule_data,
650 num_rules,
651 q_slot,
652 ev_slot,
653 ev_expected,
654 ad_data,
655 num_ads,
656 num_vars: mc.bernoulli_probs.len(),
657 bernoulli_probs: mc.bernoulli_probs.clone(),
658 })
659}
660
661fn collect_atom_consts<F: FnMut(ConstKey, &mut BTreeMap<ConstKey, u32>)>(
662 atom: &Atom,
663 domain: &mut BTreeMap<ConstKey, u32>,
664 note_const: &mut F,
665) -> std::result::Result<(), ResidentRejection> {
666 if atom.terms.len() > MAX_ARITY {
667 return Err(ResidentRejection::err(
668 ResidentRejectKind::ArityTooHigh,
669 atom.predicate.clone(),
670 format!("arity {} exceeds max {}", atom.terms.len(), MAX_ARITY),
671 ));
672 }
673 for t in &atom.terms {
674 if let TermClass::Const(k) = ConstKey::from_term(t)? {
675 note_const(k, domain);
676 }
677 }
678 Ok(())
679}
680
681fn classify_body_literal<'a>(
683 lit: &'a BodyLiteral,
684 rule_ctx: &str,
685) -> std::result::Result<&'a Atom, ResidentRejection> {
686 match lit {
687 BodyLiteral::Positive(a) => Ok(a),
688 BodyLiteral::Negated(a) => Err(ResidentRejection::err(
689 ResidentRejectKind::Negation,
690 a.predicate.clone(),
691 format!("negated literal in rule for `{}`", rule_ctx),
692 )),
693 BodyLiteral::Epistemic(l) => Err(ResidentRejection::err(
694 ResidentRejectKind::EpistemicLiteral,
695 l.atom.predicate.clone(),
696 format!("epistemic literal in rule for `{}`", rule_ctx),
697 )),
698 BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => {
699 Err(ResidentRejection::err(
700 ResidentRejectKind::NonRelationalLiteral,
701 "comparison/is/univ",
702 format!("non-relational literal in rule for `{}`", rule_ctx),
703 ))
704 }
705 }
706}
707
708fn lower_rule(
710 rule: &xlog_logic::ast::Rule,
711 universe: &Universe,
712) -> std::result::Result<Vec<u32>, ResidentRejection> {
713 let mut var_ids: BTreeMap<String, u32> = BTreeMap::new();
715 let assign_var = |name: &str,
716 var_ids: &mut BTreeMap<String, u32>|
717 -> std::result::Result<u32, ResidentRejection> {
718 if let Some(&id) = var_ids.get(name) {
719 return Ok(id);
720 }
721 let id = var_ids.len() as u32;
722 if id as usize >= MAX_VARS {
723 return Err(ResidentRejection::err(
724 ResidentRejectKind::TooManyVars,
725 rule.head.predicate.clone(),
726 format!("more than {} distinct variables", MAX_VARS),
727 ));
728 }
729 var_ids.insert(name.to_string(), id);
730 Ok(id)
731 };
732
733 let body_atoms: Vec<&Atom> = {
735 let mut v = Vec::new();
736 for lit in &rule.body {
737 v.push(classify_body_literal(lit, &rule.head.predicate)?);
738 }
739 v
740 };
741 for t in &rule.head.terms {
742 if let TermClass::Var(name) = ConstKey::from_term(t)? {
743 assign_var(&name, &mut var_ids)?;
744 }
745 }
746 for atom in &body_atoms {
747 for t in &atom.terms {
748 if let TermClass::Var(name) = ConstKey::from_term(t)? {
749 assign_var(&name, &mut var_ids)?;
750 }
751 }
752 }
753 let n_vars = var_ids.len() as u32;
754
755 let encode_atom = |atom: &Atom| -> std::result::Result<[u32; ATOM_REC], ResidentRejection> {
756 let info = universe.preds.get(&atom.predicate).ok_or_else(|| {
757 ResidentRejection::err(
758 ResidentRejectKind::InconsistentArity,
759 atom.predicate.clone(),
760 "rule atom references unknown predicate",
761 )
762 })?;
763 let arity = info.arity as u32;
764 let mut rec = [0u32; ATOM_REC];
765 rec[0] = info.base;
766 rec[1] = arity;
767 rec[5] = universe.stride0(info.arity);
768 for (i, t) in atom.terms.iter().enumerate() {
769 let spec = match ConstKey::from_term(t)? {
770 TermClass::Var(name) => *var_ids.get(&name).expect("var assigned above"),
771 TermClass::Const(k) => {
772 let idx = *universe.domain.get(&k).ok_or_else(|| {
773 ResidentRejection::err(
774 ResidentRejectKind::UnboundedTerm,
775 format!("{:?}", k),
776 "rule constant absent from bounded domain",
777 )
778 })?;
779 CONST_FLAG | idx
780 }
781 };
782 rec[2 + i] = spec;
783 }
784 Ok(rec)
785 };
786
787 let mut rec = vec![0u32; RULE_REC];
788 rec[0] = body_atoms.len() as u32;
789 rec[1] = n_vars;
790 rec[2] = universe.domain_size;
791 let head_rec = encode_atom(&rule.head)?;
792 rec[3..3 + ATOM_REC].copy_from_slice(&head_rec);
793 for (bi, atom) in body_atoms.iter().enumerate() {
794 let a = encode_atom(atom)?;
795 let off = 3 + ATOM_REC + bi * ATOM_REC;
796 rec[off..off + ATOM_REC].copy_from_slice(&a);
797 }
798 Ok(rec)
799}
800
801fn ground_atom_from_atom(atom: &Atom) -> std::result::Result<GroundAtom, ResidentRejection> {
802 let mut args = Vec::with_capacity(atom.terms.len());
803 for t in &atom.terms {
804 let v = match t {
805 Term::Integer(i) => Value::I64(*i),
806 Term::Symbol(s) => Value::Symbol(*s),
807 Term::String(s) => Value::String(s.clone()),
808 Term::Float(f) => Value::F64(f.to_bits()),
809 other => {
810 return Err(ResidentRejection::err(
811 ResidentRejectKind::UnboundedTerm,
812 format!("{:?}", other),
813 "fact term must be a ground constant",
814 ))
815 }
816 };
817 args.push(v);
818 }
819 Ok(GroundAtom {
820 predicate: atom.predicate.clone(),
821 args,
822 })
823}
824
825impl McProgram {
826 pub fn evaluate_resident_with_provider(
831 &self,
832 cfg: McEvalConfig,
833 provider: Arc<CudaKernelProvider>,
834 ) -> Result<McResidentResult> {
835 cfg.validate()?;
836 let plan = compile_resident_plan(self).map_err(ResidentRejection::into_error)?;
837 run_resident(&plan, &cfg, self, provider)
838 }
839
840 pub fn evaluate_resident(&self, cfg: McEvalConfig) -> Result<McResidentResult> {
842 let provider = Arc::new(self.provider()?);
843 self.evaluate_resident_with_provider(cfg, provider)
844 }
845}
846
/// Allocates device buffers for `plan`, uploads its metadata, and runs the
/// `mc_resident_engine` kernel once for all `cfg.samples` worlds.
///
/// Steps, in order:
/// 1. Resolve the sampling method (and evidence forcing) and read the
///    `XLOG_MC_RESIDENT_*` environment overrides.
/// 2. If a memory budget is configured, reject the run when
///    `estimate_resident_bound_bytes` exceeds it.
/// 3. Sample the Bernoulli matrix on the device, allocate and zero the working
///    and output buffers, and upload the packed metadata and launch config.
/// 4. Snapshot the provider's transfer/allocation counters, then launch the
///    kernel (cooperatively when more than one block per world is requested).
/// 5. Synchronize and record the counter deltas in [`McNoHostStats`].
///
/// Result buffers stay on the device; nothing is copied back here.
///
/// # Errors
/// Fails on any of:
/// - sampling-method resolution errors;
/// - `cfg.samples > u32::MAX`;
/// - invalid environment overrides;
/// - a budget overrun (`XlogError::ResourceExhausted`);
/// - allocation, copy, or memset failures;
/// - a missing kernel;
/// - grid-size overflow;
/// - a launch failure.
fn run_resident(
    plan: &ResidentPlan,
    cfg: &McEvalConfig,
    mc: &McProgram,
    provider: Arc<CudaKernelProvider>,
) -> Result<McResidentResult> {
    let (method, forcing) = mc.resolve_sampling_method(cfg.sampling_method)?;
    let num_worlds = u32::try_from(cfg.samples)
        .map_err(|_| XlogError::Execution("MC samples exceed u32::MAX".to_string()))?;
    let blocks_per_world = resident_blocks_per_world()?;
    let num_vars = plan.num_vars;

    // Optional pre-flight check against XLOG_MC_RESIDENT_MEMORY_BUDGET_BYTES.
    if let Some(budget_bytes) = resident_memory_budget_bytes()? {
        let bound_bytes = estimate_resident_bound_bytes(plan, num_worlds);
        if bound_bytes > budget_bytes {
            return Err(XlogError::ResourceExhausted {
                context: format!(
                    "resident_resource_budget operator=sparse_wcoj bound_bytes={bound_bytes} budget_bytes={budget_bytes}"
                ),
                estimated_bytes: bound_bytes,
                budget_bytes,
            });
        }
    }

    let dev = provider.device();

    // Per-variable forcing buffers for the sampler (at least one byte each).
    // They are copied from `forcing` under evidence clamping, otherwise zeroed.
    let mut d_force_mask = provider.memory().alloc::<u8>(num_vars.max(1))?;
    let mut d_forced_value = provider.memory().alloc::<u8>(num_vars.max(1))?;
    if method == McSamplingMethod::EvidenceClamping && num_vars > 0 {
        provider.htod_sync_copy_into_tracked(&forcing.force_mask, &mut d_force_mask)?;
        provider.htod_sync_copy_into_tracked(&forcing.forced_value, &mut d_forced_value)?;
    } else {
        dev.inner()
            .memset_zeros(&mut d_force_mask)
            .map_err(|e| XlogError::Kernel(format!("zero force_mask: {e}")))?;
        dev.inner()
            .memset_zeros(&mut d_forced_value)
            .map_err(|e| XlogError::Kernel(format!("zero forced_value: {e}")))?;
    }
    // NOTE(review): the 1-byte placeholder used when there is nothing to
    // sample is allocated but never zeroed. That is only safe if the kernel
    // never reads it in that case; confirm against the kernel source.
    let samples_device = if num_vars == 0 || cfg.samples == 0 {
        provider.memory().alloc::<u8>(1)?
    } else {
        provider.sample_bernoulli_matrix_device(
            &plan.bernoulli_probs,
            cfg.samples,
            cfg.seed,
            &d_force_mask.slice(..),
            &d_forced_value.slice(..),
        )?
    };

    // Relation state: two u32 words per universe slot per world, zeroed.
    let u = plan.universe_size.max(1) as usize;
    let rel_len = (num_worlds as usize)
        .saturating_mul(u)
        .saturating_mul(2)
        .max(1);
    let mut d_rel = provider.memory().alloc::<u32>(rel_len)?;
    dev.inner()
        .memset_zeros(&mut d_rel)
        .map_err(|e| XlogError::Kernel(format!("zero rel: {e}")))?;

    // Sparse scratch, with capacity for one universe's worth of entries per
    // buffer. Column storage holds 4 u32 words per entry, 2 * sparse_cap
    // entries per world. The remaining buffers are per-world counters,
    // offsets, and status flags; all are zeroed below.
    let sparse_cap = u.max(1);
    let sparse_len = (num_worlds as usize)
        .saturating_mul(2)
        .saturating_mul(sparse_cap)
        .max(1);
    let mut d_sparse_columns = provider
        .memory()
        .alloc::<u32>(sparse_len.saturating_mul(4).max(1))?;
    let mut d_sparse_counts = provider
        .memory()
        .alloc::<u32>((num_worlds as usize).saturating_mul(2).max(1))?;
    let mut d_sparse_final_counts = provider
        .memory()
        .alloc::<u32>((num_worlds as usize).max(1))?;
    let mut d_sparse_offsets = provider
        .memory()
        .alloc::<u32>((num_worlds as usize).saturating_add(1).max(1))?;
    let mut d_resident_status_flags = provider.memory().alloc::<u32>(
        (num_worlds as usize)
            .saturating_mul(4)
            .saturating_add(1)
            .max(1),
    )?;
    dev.inner()
        .memset_zeros(&mut d_sparse_columns)
        .map_err(|e| XlogError::Kernel(format!("zero sparse_columns: {e}")))?;
    dev.inner()
        .memset_zeros(&mut d_sparse_counts)
        .map_err(|e| XlogError::Kernel(format!("zero sparse_counts: {e}")))?;
    dev.inner()
        .memset_zeros(&mut d_sparse_final_counts)
        .map_err(|e| XlogError::Kernel(format!("zero sparse_final_counts: {e}")))?;
    dev.inner()
        .memset_zeros(&mut d_sparse_offsets)
        .map_err(|e| XlogError::Kernel(format!("zero sparse_offsets: {e}")))?;
    dev.inner()
        .memset_zeros(&mut d_resident_status_flags)
        .map_err(|e| XlogError::Kernel(format!("zero resident_status_flags: {e}")))?;

    let q_count = plan.q_slot.len();
    // The metadata buffer is u32-only, so evidence polarities are widened.
    let ev_expected_u32: Vec<u32> = plan.ev_expected.iter().map(|&b| b as u32).collect();

    // Pack every plan array into one u32 buffer; each `*_off` is the word
    // offset of that array inside `meta`.
    let mut meta: Vec<u32> = Vec::new();
    let push_meta = |data: &[u32], meta: &mut Vec<u32>| -> u32 {
        let off = meta.len() as u32;
        meta.extend_from_slice(data);
        off
    };
    let edb_off = push_meta(&plan.edb_slots, &mut meta);
    let pf_slot_off = push_meta(&plan.pf_slot, &mut meta);
    let pf_var_off = push_meta(&plan.pf_var, &mut meta);
    let rules_off = push_meta(&plan.rule_data, &mut meta);
    let q_off = push_meta(&plan.q_slot, &mut meta);
    let ev_slot_off = push_meta(&plan.ev_slot, &mut meta);
    let ev_exp_off = push_meta(&ev_expected_u32, &mut meta);
    let ad_off = push_meta(&plan.ad_data, &mut meta);

    // Launch configuration read by the kernel. The field order must match the
    // kernel's expectations (defined outside this file).
    // NOTE(review): `estimate_resident_bound_bytes` budgets a fixed word count
    // for this array; keep the two in sync.
    let cfg_host: [u32; 19] = [
        num_worlds,
        plan.universe_size,
        num_vars as u32,
        plan.max_iters,
        edb_off,
        plan.edb_slots.len() as u32,
        pf_slot_off,
        pf_var_off,
        plan.pf_slot.len() as u32,
        rules_off,
        plan.num_rules,
        q_off,
        q_count as u32,
        ev_slot_off,
        ev_exp_off,
        plan.ev_slot.len() as u32,
        ad_off,
        plan.num_ads,
        blocks_per_world,
    ];

    let mut d_cfg = provider.memory().alloc::<u32>(cfg_host.len())?;
    provider.htod_sync_copy_into_tracked(&cfg_host, &mut d_cfg)?;
    let mut d_meta = provider.memory().alloc::<u32>(meta.len().max(1))?;
    if !meta.is_empty() {
        provider.htod_sync_copy_into_tracked(&meta, &mut d_meta)?;
    }

    // Output buffers, zeroed before launch.
    let mut d_query_counts = provider.memory().alloc::<u32>(q_count.max(1))?;
    dev.inner()
        .memset_zeros(&mut d_query_counts)
        .map_err(|e| XlogError::Kernel(format!("zero query_counts: {e}")))?;
    let mut d_evidence_count = provider.memory().alloc::<u32>(1)?;
    dev.inner()
        .memset_zeros(&mut d_evidence_count)
        .map_err(|e| XlogError::Kernel(format!("zero evidence_count: {e}")))?;
    let mut d_iter_trace = provider.memory().alloc::<u32>(num_worlds.max(1) as usize)?;
    dev.inner()
        .memset_zeros(&mut d_iter_trace)
        .map_err(|e| XlogError::Kernel(format!("zero iter_trace: {e}")))?;

    let engine_fn = dev
        .inner()
        .get_func(MC_RESIDENT_MODULE, mc_resident_kernels::MC_RESIDENT_ENGINE)
        .ok_or_else(|| XlogError::Kernel("mc_resident_engine kernel not found".to_string()))?;

    // Finish all setup work before the counter snapshot below, so the
    // recorded deltas cover only the launch window.
    dev.synchronize()?;

    let pre = provider.host_transfer_stats();
    let pre_untracked = provider.untracked_metadata_dtoh_count();
    let pre_allocs = provider.memory().alloc_count();
    let mut engine_launches = 0u64;

    // Grid: `blocks_per_world` blocks of 128 threads for every world
    // (at least one world's worth even when no samples were requested).
    let block_dim = 128u32;
    let grid_dim = num_worlds
        .max(1)
        .checked_mul(blocks_per_world)
        .ok_or_else(|| {
            XlogError::Execution(format!(
                "resident grid overflow: worlds={num_worlds} blocks_per_world={blocks_per_world}"
            ))
        })?;
    let launch_cfg = LaunchConfig {
        grid_dim: (grid_dim, 1, 1),
        block_dim: (block_dim, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY (review): soundness relies on this argument tuple matching the
    // `mc_resident_engine` kernel signature (order, element types) and on each
    // buffer being at least as large as the kernel indexes. Both are defined
    // outside this file; confirm against the kernel source.
    unsafe {
        let args = (
            &d_cfg,
            &d_meta,
            &mut d_rel,
            &samples_device,
            &mut d_query_counts,
            &mut d_evidence_count,
            &mut d_iter_trace,
            &mut d_sparse_columns,
            &mut d_sparse_counts,
            &mut d_sparse_final_counts,
            &mut d_sparse_offsets,
            &mut d_resident_status_flags,
            sparse_cap as u32,
        );
        // Multiple blocks per world need a cooperative launch.
        if blocks_per_world == 1 {
            engine_fn
                .launch(launch_cfg, args)
                .map_err(|e| XlogError::Kernel(format!("mc_resident_engine launch failed: {e}")))?;
        } else {
            engine_fn
                .launch_cooperative(launch_cfg, args)
                .map_err(|e| {
                    XlogError::Kernel(format!("mc_resident_engine cooperative launch failed: {e}"))
                })?;
        }
    }
    engine_launches += 1;
    dev.synchronize()?;

    // Counter deltas over the launch window. This function has no host loop,
    // host fixpoint, or per-sample launch, so those fields are reported as 0.
    let post = provider.host_transfer_stats();
    let post_untracked = provider.untracked_metadata_dtoh_count();
    let post_allocs = provider.memory().alloc_count();
    let no_host = McNoHostStats {
        tracked_htod_calls: post.htod_calls.saturating_sub(pre.htod_calls),
        tracked_dtoh_calls: post.dtoh_calls.saturating_sub(pre.dtoh_calls),
        untracked_metadata_reads: post_untracked.saturating_sub(pre_untracked),
        engine_launches,
        host_loop_iterations: 0,
        host_fixpoint_iterations: 0,
        per_operator_host_allocations: post_allocs.saturating_sub(pre_allocs),
        per_sample_host_launches: 0,
    };

    Ok(McResidentResult {
        query_counts: d_query_counts,
        evidence_count: d_evidence_count,
        iter_trace: d_iter_trace,
        sparse_final_row_counts: d_sparse_final_counts,
        sparse_offsets: d_sparse_offsets,
        resident_status_flags: d_resident_status_flags,
        total_samples: cfg.samples,
        seed: cfg.seed,
        confidence: cfg.confidence,
        sampling_method: method,
        num_queries: q_count,
        no_host,
    })
}