Skip to main content

xlog_prob/mc/
resident.rs

1//! GPU-resident Datalog/MC execution engine.
2//!
3//! This module replaces the host-orchestrated per-sample MC loop with a single
4//! device megakernel (`mc_resident_engine`) that evaluates *all* Monte Carlo
5//! worlds to a device-side fixpoint and counts query/evidence satisfaction with
6//! **zero host interaction inside the measured region**: no host loop over
7//! samples, no per-sample host kernel sequencing, no host metadata reads
8//! (tracked or untracked).
9//!
10//! Supported fragment (bounded-domain positive Datalog), checked structurally —
11//! never by predicate name:
12//!   * predicate arity <= [`MAX_ARITY`];
13//!   * rule body length <= [`MAX_BODY`] positive literals (no negation,
14//!     comparison, arithmetic, epistemic, aggregate, or univ literals);
15//!   * <= [`MAX_VARS`] distinct variables per rule;
16//!   * all terms are variables or ground constants (no lists/compounds/functors);
17//!   * bounded Herbrand universe: `domain^arity` slots, with the total universe
18//!     and per-predicate domain capped so one world fits one CUDA block.
19//!   * positive joins are lowered to device rule records and evaluated from a
20//!     world-segmented sparse column arena (`slot`, `arg0`, `arg1`) with
21//!     device row counters and static device offsets; the dense bitset is only
22//!     a device-side membership/dedup index.
23//!   * worst-case sparse/WCOJ arena bounds are checked before device allocation
24//!     when `XLOG_MC_RESIDENT_MEMORY_BUDGET_BYTES` is set.
25//!
26//! Anything outside the fragment is rejected **before execution** with a typed
27//! [`ResidentRejection`] (fail-closed; no CPU fallback).
28
29use std::collections::BTreeMap;
30use std::sync::Arc;
31
32use cudarc::driver::LaunchConfig;
33use xlog_core::{Result, XlogError};
34use xlog_cuda::memory::TrackedCudaSlice;
35use xlog_cuda::provider::{mc_resident_kernels, MC_RESIDENT_MODULE};
36use xlog_cuda::{CudaKernelProvider, LaunchAsync};
37use xlog_logic::ast::{Atom, BodyLiteral, Term};
38
39use super::{McEvalConfig, McProgram, McSamplingMethod};
40use crate::provenance::{GroundAtom, Value};
41
/// Maximum supported predicate arity.
pub const MAX_ARITY: usize = 3;
/// Maximum supported rule body length (positive literals).
pub const MAX_BODY: usize = 3;
/// Maximum supported distinct variables per rule.
pub const MAX_VARS: usize = 8;
/// Maximum supported bounded-universe slot count (one world must fit one block's
/// working set comfortably).
pub const MAX_UNIVERSE: usize = 1 << 16;
/// Maximum supported domain size (distinct constants).
pub const MAX_DOMAIN: usize = 256;
/// u32 width of one device atom record: base, arity, arg0, arg1, arg2, stride0.
const ATOM_REC: usize = 6;
/// u32 width of one device rule record: n_body, n_vars, domain, head, body0..body2.
///
/// The three scalar words are followed by four atom records of [`ATOM_REC`]
/// words each: the head plus up to [`MAX_BODY`] body atoms. Body records a rule
/// does not use are left zero-filled by the rule lowering.
const RULE_REC: usize = 3 + 4 * ATOM_REC;
/// Device encoding flag marking an arg spec as a bound constant.
///
/// A constant arg is stored as `CONST_FLAG | domain_index`; an arg spec without
/// the flag holds a rule-local variable id instead.
const CONST_FLAG: u32 = 0x8000_0000;
/// Name of the environment variable carrying an optional memory budget in
/// bytes. It is parsed as a `u64`; when the variable is absent no budget applies.
const RESIDENT_BUDGET_ENV: &str = "XLOG_MC_RESIDENT_MEMORY_BUDGET_BYTES";
/// Name of the environment variable selecting how many blocks serve one world.
/// It is parsed as a `u32` that must be >= 1; when absent the value is 1.
const RESIDENT_BLOCKS_PER_WORLD_ENV: &str = "XLOG_MC_RESIDENT_BLOCKS_PER_WORLD";
61
/// Category of a fail-closed rejection issued by the resident engine when an
/// MC program falls outside the supported fragment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResidentRejectKind {
    /// The rule body contains a negated literal.
    Negation,
    /// The rule body contains an epistemic literal (`know`/`possible`).
    EpistemicLiteral,
    /// The rule body contains a non-relational literal (comparison,
    /// arithmetic, or univ).
    NonRelationalLiteral,
    /// Some predicate has arity above [`MAX_ARITY`].
    ArityTooHigh,
    /// Some rule body holds more than [`MAX_BODY`] literals.
    BodyTooLong,
    /// Some rule mentions more than [`MAX_VARS`] distinct variables.
    TooManyVars,
    /// Some term is neither a variable nor a ground constant
    /// (list/compound/functor/agg).
    UnboundedTerm,
    /// The number of distinct constants is above [`MAX_DOMAIN`].
    DomainTooLarge,
    /// The total slot count is above [`MAX_UNIVERSE`].
    UniverseTooLarge,
    /// The same predicate is used with two different arities.
    InconsistentArity,
    /// Annotated disjunctions are not yet supported by the resident engine.
    ///
    /// NOTE(review): the plan compiler in this file lowers annotated
    /// disjunctions and never emits this kind — confirm whether it is still
    /// produced elsewhere.
    AnnotatedDisjunctionUnsupported,
}

impl ResidentRejectKind {
    /// Stable snake_case tag for this kind, suitable for embedding in
    /// diagnostics and error messages.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Negation => "negation",
            Self::EpistemicLiteral => "epistemic_literal",
            Self::NonRelationalLiteral => "non_relational_literal",
            Self::ArityTooHigh => "arity_too_high",
            Self::BodyTooLong => "body_too_long",
            Self::TooManyVars => "too_many_vars",
            Self::UnboundedTerm => "unbounded_term",
            Self::DomainTooLarge => "domain_too_large",
            Self::UniverseTooLarge => "universe_too_large",
            Self::InconsistentArity => "inconsistent_arity",
            Self::AnnotatedDisjunctionUnsupported => "annotated_disjunction_unsupported",
        }
    }
}
108
109/// A typed fail-closed rejection: which rule was violated, the offending
110/// construct, and the surrounding context (for diagnostics).
111#[derive(Debug, Clone, PartialEq, Eq)]
112pub struct ResidentRejection {
113    pub kind: ResidentRejectKind,
114    pub construct: String,
115    pub context: String,
116}
117
118impl ResidentRejection {
119    fn err(
120        kind: ResidentRejectKind,
121        construct: impl Into<String>,
122        context: impl Into<String>,
123    ) -> Self {
124        ResidentRejection {
125            kind,
126            construct: construct.into(),
127            context: context.into(),
128        }
129    }
130
131    /// Convert to the engine's unified error type while preserving the kind +
132    /// construct + context in the message for callers that only see `XlogError`.
133    pub fn into_error(self) -> XlogError {
134        XlogError::Compilation(format!(
135            "resident MC engine rejected program [kind={}] construct=`{}` context=`{}`",
136            self.kind.as_str(),
137            self.construct,
138            self.context
139        ))
140    }
141}
142
143/// Canonical key for a ground constant, unifying `Term` literals and `Value`s so
144/// the same constant in a rule and in a fact map to one domain index.
145#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
146enum ConstKey {
147    Int(i64),
148    Sym(u32),
149    Str(String),
150    FloatBits(u64),
151}
152
153impl ConstKey {
154    fn from_value(v: &Value) -> ConstKey {
155        match v {
156            Value::I64(i) => ConstKey::Int(*i),
157            Value::Symbol(s) => ConstKey::Sym(*s),
158            Value::String(s) => ConstKey::Str(s.clone()),
159            Value::F64(bits) => ConstKey::FloatBits(*bits),
160        }
161    }
162
163    /// Map a rule term to either a constant key or a variable name.
164    fn from_term(t: &Term) -> std::result::Result<TermClass, ResidentRejection> {
165        match t {
166            Term::Variable(name) => Ok(TermClass::Var(name.clone())),
167            Term::Integer(i) => Ok(TermClass::Const(ConstKey::Int(*i))),
168            Term::Symbol(s) => Ok(TermClass::Const(ConstKey::Sym(*s))),
169            Term::String(s) => Ok(TermClass::Const(ConstKey::Str(s.clone()))),
170            Term::Float(f) => Ok(TermClass::Const(ConstKey::FloatBits(f.to_bits()))),
171            other => Err(ResidentRejection::err(
172                ResidentRejectKind::UnboundedTerm,
173                format!("{:?}", other),
174                "rule term must be a variable or ground constant",
175            )),
176        }
177    }
178}
179
180enum TermClass {
181    Var(String),
182    Const(ConstKey),
183}
184
/// Counters recording host interaction inside the measured engine region.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct McNoHostStats {
    /// Tracked data-plane host-to-device calls inside the measured region.
    pub tracked_htod_calls: u64,
    /// Tracked data-plane device-to-host calls inside the measured region.
    pub tracked_dtoh_calls: u64,
    /// Untracked control-plane metadata reads inside the measured region.
    pub untracked_metadata_reads: u64,
    /// Kernel launches the engine issued inside the measured region.
    pub engine_launches: u64,
    /// Host-side per-sample loop iterations inside the measured region
    /// (structurally zero: the engine has no host sample loop).
    pub host_loop_iterations: u64,
    /// Per-sample host launches inside the measured region
    /// (structurally zero: one global launch covers all worlds).
    pub per_sample_host_launches: u64,
    /// Host-side fixpoint-loop iterations inside the measured region
    /// (structurally zero: recursion converges device-side). Required by the
    /// WCOJ world-batched acceptance contract.
    pub host_fixpoint_iterations: u64,
    /// Device allocations issued inside the measured region (must be zero:
    /// every arena is pre-allocated before the region). Backed by the memory
    /// manager's `alloc_count` snapshot.
    pub per_operator_host_allocations: u64,
}

impl McNoHostStats {
    /// Whether the measured region ran with **no host interaction**: every
    /// counter except [`Self::engine_launches`] must be zero. The single
    /// global engine launch is allowed because it is not per-sample.
    pub fn is_no_host(&self) -> bool {
        let host_counters = [
            self.tracked_htod_calls,
            self.tracked_dtoh_calls,
            self.untracked_metadata_reads,
            self.host_loop_iterations,
            self.per_sample_host_launches,
            self.host_fixpoint_iterations,
            self.per_operator_host_allocations,
        ];
        host_counters.iter().all(|&count| count == 0)
    }
}
228
/// Device-resident result of a GPU-resident MC run. Counts stay on device; the
/// caller decides whether/when to download them (after the measured region).
pub struct McResidentResult {
    /// Device-resident counters, one per query atom (length `num_queries`).
    pub query_counts: TrackedCudaSlice<u32>,
    /// Device-resident evidence counter.
    /// NOTE(review): presumably the number of worlds satisfying all evidence —
    /// confirm against the kernel.
    pub evidence_count: TrackedCudaSlice<u32>,
    /// Per-world fixpoint iteration count (device-resident).
    pub iter_trace: TrackedCudaSlice<u32>,
    /// Final sparse row count per world (device-resident).
    pub sparse_final_row_counts: TrackedCudaSlice<u32>,
    /// World-segment offsets for the sparse arena, populated by the resident
    /// kernel (device-resident).
    pub sparse_offsets: TrackedCudaSlice<u32>,
    /// Per-world resident status flags, populated by the resident kernel:
    /// `[converged_flags; sparse_overflow_flags; block_participation; scratch...]`.
    pub resident_status_flags: TrackedCudaSlice<u32>,
    /// Number of samples (worlds) in the run.
    /// NOTE(review): presumably echoes the configured sample count — confirm.
    pub total_samples: usize,
    /// NOTE(review): presumably the RNG seed used for sampling — confirm.
    pub seed: u64,
    /// NOTE(review): presumably the confidence level taken from the eval
    /// config — confirm.
    pub confidence: f64,
    /// Sampling method associated with this run.
    pub sampling_method: McSamplingMethod,
    /// Number of query atoms (== `query_counts.len()`).
    pub num_queries: usize,
    /// No-host-interaction instrumentation for the measured region.
    pub no_host: McNoHostStats,
}
252
/// Compiled, device-uploadable plan for the resident engine.
#[derive(Debug, Clone)]
pub struct ResidentPlan {
    /// Total number of atom slots across all predicates.
    pub universe_size: u32,
    /// Number of distinct constants in the bounded domain.
    pub domain_size: u32,
    /// Fixpoint iteration cap (`universe_size + 1`, saturating).
    pub max_iters: u32,
    /// Slot id of every deterministic (EDB) fact.
    edb_slots: Vec<u32>,
    /// Slot id of each probabilistic fact; parallel to `pf_var`.
    pf_slot: Vec<u32>,
    /// Variable index of each probabilistic fact; parallel to `pf_slot`.
    pf_var: Vec<u32>,
    /// Concatenated fixed-width rule records (`RULE_REC` words per rule).
    rule_data: Vec<u32>,
    /// Number of rule records stored in `rule_data`.
    num_rules: u32,
    /// Slot id of each query atom.
    q_slot: Vec<u32>,
    /// Slot id of each evidence atom; parallel to `ev_expected`.
    ev_slot: Vec<u32>,
    /// Expected truth value of each evidence atom (1 = true, 0 = false).
    ev_expected: Vec<u8>,
    /// Concatenated annotated-disjunction records, each laid out as
    /// `[n_choices, n_dvars, dvar_0.., slot_0..]`.
    ad_data: Vec<u32>,
    /// Number of annotated-disjunction records stored in `ad_data`.
    num_ads: u32,
    /// Number of Bernoulli variables (== `bernoulli_probs.len()`).
    pub num_vars: usize,
    /// Bernoulli probabilities, copied from the source `McProgram`.
    bernoulli_probs: Vec<f32>,
}
272
273fn resident_memory_budget_bytes() -> Result<Option<u64>> {
274    match std::env::var(RESIDENT_BUDGET_ENV) {
275        Ok(raw) => raw.parse::<u64>().map(Some).map_err(|e| {
276            XlogError::Execution(format!("invalid {RESIDENT_BUDGET_ENV} value `{raw}`: {e}"))
277        }),
278        Err(std::env::VarError::NotPresent) => Ok(None),
279        Err(e) => Err(XlogError::Execution(format!(
280            "invalid {RESIDENT_BUDGET_ENV}: {e}"
281        ))),
282    }
283}
284
285fn resident_blocks_per_world() -> Result<u32> {
286    match std::env::var(RESIDENT_BLOCKS_PER_WORLD_ENV) {
287        Ok(raw) => {
288            let blocks = raw.parse::<u32>().map_err(|e| {
289                XlogError::Execution(format!(
290                    "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV} value `{raw}`: {e}"
291                ))
292            })?;
293            if blocks == 0 {
294                return Err(XlogError::Execution(format!(
295                    "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV} value `{raw}`: must be >= 1"
296                )));
297            }
298            Ok(blocks)
299        }
300        Err(std::env::VarError::NotPresent) => Ok(1),
301        Err(e) => Err(XlogError::Execution(format!(
302            "invalid {RESIDENT_BLOCKS_PER_WORLD_ENV}: {e}"
303        ))),
304    }
305}
306
/// Multiply two `u64`s, clamping to `u64::MAX` instead of overflowing.
fn sat_mul(a: u64, b: u64) -> u64 {
    match a.checked_mul(b) {
        Some(product) => product,
        None => u64::MAX,
    }
}
310
/// Raise `base` to the power `exp`, clamping to `u64::MAX` instead of
/// overflowing. `sat_pow(_, 0)` is 1, including for a zero base.
///
/// Delegates to the standard library's saturating exponentiation rather than
/// a hand-rolled square-and-multiply loop.
fn sat_pow(base: u64, exp: u32) -> u64 {
    base.saturating_pow(exp)
}
324
325fn estimate_resident_bound_bytes(plan: &ResidentPlan, num_worlds: u32) -> u64 {
326    let worlds = num_worlds.max(1) as u64;
327    let vars = plan.num_vars.max(1) as u64;
328    let universe = plan.universe_size.max(1) as u64;
329    let meta_words = plan
330        .edb_slots
331        .len()
332        .saturating_add(plan.pf_slot.len())
333        .saturating_add(plan.pf_var.len())
334        .saturating_add(plan.rule_data.len())
335        .saturating_add(plan.q_slot.len())
336        .saturating_add(plan.ev_slot.len())
337        .saturating_add(plan.ev_expected.len())
338        .saturating_add(plan.ad_data.len())
339        .saturating_add(18);
340
341    let sparse_cap = universe;
342    let setup_bytes = sat_mul(2, vars)
343        .saturating_add(sat_mul(worlds, vars))
344        .saturating_add(sat_mul(sat_mul(sat_mul(worlds, 2), universe), 4))
345        .saturating_add(sat_mul(sat_mul(sat_mul(worlds, 2), sparse_cap), 16))
346        .saturating_add(sat_mul(sat_mul(worlds, 2), 4))
347        .saturating_add(sat_mul(worlds, 4))
348        .saturating_add(sat_mul(worlds.saturating_add(1), 4))
349        .saturating_add(sat_mul(worlds.saturating_mul(4).saturating_add(1), 4))
350        .saturating_add(sat_mul(meta_words as u64, 4))
351        .saturating_add(sat_mul(plan.q_slot.len().max(1) as u64, 4))
352        .saturating_add(4)
353        .saturating_add(sat_mul(worlds, 4));
354
355    let mut sparse_join_bytes = 0u64;
356    for rule in plan.rule_data.chunks_exact(RULE_REC) {
357        let n_body = rule[0];
358        if n_body < 2 {
359            continue;
360        }
361        let n_vars = rule[1];
362        let assignments = sat_pow(plan.domain_size.max(1) as u64, n_vars);
363        let row_words = (n_body as u64).saturating_add(1);
364        sparse_join_bytes =
365            sparse_join_bytes.max(sat_mul(sat_mul(worlds, assignments), row_words * 4));
366    }
367
368    setup_bytes.saturating_add(sparse_join_bytes)
369}
370
/// Per-predicate layout inside the bounded universe.
struct PredInfo {
    /// Number of arguments the predicate takes.
    arity: usize,
    /// First slot id of the predicate's contiguous block of slots.
    base: u32,
}
375
376struct Universe {
377    domain: BTreeMap<ConstKey, u32>,
378    preds: BTreeMap<String, PredInfo>,
379    domain_size: u32,
380}
381
382impl Universe {
383    fn stride0(&self, arity: usize) -> u32 {
384        if arity >= 2 {
385            self.domain_size.pow((arity - 1) as u32)
386        } else {
387            1
388        }
389    }
390
391    fn arg_stride(&self, arity: usize, arg_idx: usize) -> u32 {
392        if arity <= arg_idx + 1 {
393            1
394        } else {
395            self.domain_size.pow((arity - arg_idx - 1) as u32)
396        }
397    }
398
399    /// Slot id of a ground atom (predicate + constant args).
400    fn ground_slot(&self, atom: &GroundAtom) -> std::result::Result<u32, ResidentRejection> {
401        let info = self.preds.get(&atom.predicate).ok_or_else(|| {
402            ResidentRejection::err(
403                ResidentRejectKind::InconsistentArity,
404                atom.predicate.clone(),
405                "ground atom references unknown predicate",
406            )
407        })?;
408        if atom.args.len() != info.arity {
409            return Err(ResidentRejection::err(
410                ResidentRejectKind::InconsistentArity,
411                atom.predicate.clone(),
412                format!("expected arity {} got {}", info.arity, atom.args.len()),
413            ));
414        }
415        let mut slot = info.base;
416        for (i, v) in atom.args.iter().enumerate() {
417            let key = ConstKey::from_value(v);
418            let idx = *self.domain.get(&key).ok_or_else(|| {
419                ResidentRejection::err(
420                    ResidentRejectKind::UnboundedTerm,
421                    format!("{:?}", v),
422                    "ground constant absent from bounded domain",
423                )
424            })?;
425            slot += idx * self.arg_stride(info.arity, i);
426        }
427        Ok(slot)
428    }
429}
430
431/// Compile an [`McProgram`] into a resident-engine plan, or return a typed
432/// fail-closed rejection for anything outside the supported fragment.
433pub fn compile_resident_plan(
434    mc: &McProgram,
435) -> std::result::Result<ResidentPlan, ResidentRejection> {
436    let program = &mc.program;
437
438    // --- 1. Gather predicate arities (consistency-checked). ---
439    let mut arities: BTreeMap<String, usize> = BTreeMap::new();
440    let mut note_pred = |pred: &str, arity: usize| -> std::result::Result<(), ResidentRejection> {
441        if arity > MAX_ARITY {
442            return Err(ResidentRejection::err(
443                ResidentRejectKind::ArityTooHigh,
444                pred.to_string(),
445                format!("arity {} exceeds max {}", arity, MAX_ARITY),
446            ));
447        }
448        match arities.get(pred) {
449            Some(&existing) if existing != arity => Err(ResidentRejection::err(
450                ResidentRejectKind::InconsistentArity,
451                pred.to_string(),
452                format!("arity {} vs {}", existing, arity),
453            )),
454            _ => {
455                arities.insert(pred.to_string(), arity);
456                Ok(())
457            }
458        }
459    };
460
461    // --- 2. Collect constants for the bounded domain. ---
462    let mut domain: BTreeMap<ConstKey, u32> = BTreeMap::new();
463    let mut note_const = |key: ConstKey, domain: &mut BTreeMap<ConstKey, u32>| {
464        let next = domain.len() as u32;
465        domain.entry(key).or_insert(next);
466    };
467
468    // Deterministic facts.
469    for fact in program.facts() {
470        note_pred(&fact.head.predicate, fact.head.terms.len())?;
471        for t in &fact.head.terms {
472            match ConstKey::from_term(t)? {
473                TermClass::Const(k) => note_const(k, &mut domain),
474                TermClass::Var(_) => {
475                    return Err(ResidentRejection::err(
476                        ResidentRejectKind::UnboundedTerm,
477                        fact.head.predicate.clone(),
478                        "fact head contains a variable",
479                    ))
480                }
481            }
482        }
483    }
484    // Prob facts (ground).
485    for pf in &mc.prob_facts {
486        note_pred(&pf.atom.predicate, pf.atom.args.len())?;
487        for v in &pf.atom.args {
488            note_const(ConstKey::from_value(v), &mut domain);
489        }
490    }
491    // Queries + evidence (ground).
492    for q in &mc.queries {
493        note_pred(&q.predicate, q.args.len())?;
494        for v in &q.args {
495            note_const(ConstKey::from_value(v), &mut domain);
496        }
497    }
498    for (e, _) in &mc.evidence {
499        note_pred(&e.predicate, e.args.len())?;
500        for v in &e.args {
501            note_const(ConstKey::from_value(v), &mut domain);
502        }
503    }
504    // Annotated-disjunction choice atoms (ground).
505    for ad in &mc.annotated_disjunctions {
506        for atom in &ad.choices {
507            note_pred(&atom.predicate, atom.args.len())?;
508            for v in &atom.args {
509                note_const(ConstKey::from_value(v), &mut domain);
510            }
511        }
512    }
513    // Rules: collect predicates + constants (variables ranged over the domain).
514    for rule in &program.rules {
515        if rule.is_fact() {
516            continue;
517        }
518        note_pred(&rule.head.predicate, rule.head.terms.len())?;
519        collect_atom_consts(&rule.head, &mut domain, &mut note_const)?;
520        if rule.body.len() > MAX_BODY {
521            return Err(ResidentRejection::err(
522                ResidentRejectKind::BodyTooLong,
523                rule.head.predicate.clone(),
524                format!("body length {} exceeds max {}", rule.body.len(), MAX_BODY),
525            ));
526        }
527        for lit in &rule.body {
528            let atom = classify_body_literal(lit, &rule.head.predicate)?;
529            note_pred(&atom.predicate, atom.terms.len())?;
530            collect_atom_consts(atom, &mut domain, &mut note_const)?;
531        }
532    }
533
534    if domain.len() > MAX_DOMAIN {
535        return Err(ResidentRejection::err(
536            ResidentRejectKind::DomainTooLarge,
537            format!("{} constants", domain.len()),
538            format!("domain exceeds max {}", MAX_DOMAIN),
539        ));
540    }
541    let domain_size = domain.len() as u32;
542
543    // --- 3. Assign predicate slot blocks (deterministic order). ---
544    let mut preds: BTreeMap<String, PredInfo> = BTreeMap::new();
545    let mut base: u64 = 0;
546    for (pred, &arity) in &arities {
547        let slot_count: u64 = if arity == 0 {
548            1
549        } else {
550            (domain_size as u64).pow(arity as u32)
551        };
552        preds.insert(
553            pred.clone(),
554            PredInfo {
555                arity,
556                base: base as u32,
557            },
558        );
559        base += slot_count;
560        if base > MAX_UNIVERSE as u64 {
561            return Err(ResidentRejection::err(
562                ResidentRejectKind::UniverseTooLarge,
563                format!("{} slots", base),
564                format!("universe exceeds max {}", MAX_UNIVERSE),
565            ));
566        }
567    }
568    let universe_size = base as u32;
569
570    let universe = Universe {
571        domain,
572        preds,
573        domain_size,
574    };
575
576    // --- 4. Lower EDB facts, prob facts, queries, evidence to slots. ---
577    let mut edb_slots = Vec::new();
578    for fact in program.facts() {
579        let ga = ground_atom_from_atom(&fact.head)?;
580        edb_slots.push(universe.ground_slot(&ga)?);
581    }
582    let mut pf_slot = Vec::new();
583    let mut pf_var = Vec::new();
584    for pf in &mc.prob_facts {
585        pf_slot.push(universe.ground_slot(&pf.atom)?);
586        pf_var.push(pf.var_idx as u32);
587    }
588    let mut q_slot = Vec::new();
589    for q in &mc.queries {
590        q_slot.push(universe.ground_slot(q)?);
591    }
592    let mut ev_slot = Vec::new();
593    let mut ev_expected = Vec::new();
594    for (e, v) in &mc.evidence {
595        ev_slot.push(universe.ground_slot(e)?);
596        ev_expected.push(if *v { 1u8 } else { 0u8 });
597    }
598
599    // --- 5. Lower rules to device records. ---
600    let mut rule_data = Vec::new();
601    let mut num_rules = 0u32;
602    for rule in &program.rules {
603        if rule.is_fact() {
604            continue;
605        }
606        let rec = lower_rule(rule, &universe)?;
607        rule_data.extend_from_slice(&rec);
608        num_rules += 1;
609    }
610
611    // --- 6. Lower annotated disjunctions to device records. ---
612    // Each record: [n_choices, n_dvars, dvar_0.., slot_0..]. The kernel walks the
613    // conditional-Bernoulli chain (first firing decision wins; residual = last
614    // choice or "none").
615    let mut ad_data: Vec<u32> = Vec::new();
616    let mut num_ads = 0u32;
617    for ad in &mc.annotated_disjunctions {
618        let n_choices = ad.choices.len() as u32;
619        let n_dvars = ad.decision_vars.len() as u32;
620        ad_data.push(n_choices);
621        ad_data.push(n_dvars);
622        for &dv in &ad.decision_vars {
623            ad_data.push(u32::try_from(dv).map_err(|_| {
624                ResidentRejection::err(
625                    ResidentRejectKind::UnboundedTerm,
626                    "decision_var",
627                    "AD decision var index exceeds u32",
628                )
629            })?);
630        }
631        for atom in &ad.choices {
632            ad_data.push(universe.ground_slot(atom)?);
633        }
634        num_ads += 1;
635    }
636
637    // Iteration cap: at most `universe_size` distinct atoms can be derived, so a
638    // monotone fixpoint must converge within that many passes. +1 for the final
639    // no-change confirmation pass.
640    let max_iters = universe_size.saturating_add(1).max(1);
641
642    Ok(ResidentPlan {
643        universe_size,
644        domain_size,
645        max_iters,
646        edb_slots,
647        pf_slot,
648        pf_var,
649        rule_data,
650        num_rules,
651        q_slot,
652        ev_slot,
653        ev_expected,
654        ad_data,
655        num_ads,
656        num_vars: mc.bernoulli_probs.len(),
657        bernoulli_probs: mc.bernoulli_probs.clone(),
658    })
659}
660
661fn collect_atom_consts<F: FnMut(ConstKey, &mut BTreeMap<ConstKey, u32>)>(
662    atom: &Atom,
663    domain: &mut BTreeMap<ConstKey, u32>,
664    note_const: &mut F,
665) -> std::result::Result<(), ResidentRejection> {
666    if atom.terms.len() > MAX_ARITY {
667        return Err(ResidentRejection::err(
668            ResidentRejectKind::ArityTooHigh,
669            atom.predicate.clone(),
670            format!("arity {} exceeds max {}", atom.terms.len(), MAX_ARITY),
671        ));
672    }
673    for t in &atom.terms {
674        if let TermClass::Const(k) = ConstKey::from_term(t)? {
675            note_const(k, domain);
676        }
677    }
678    Ok(())
679}
680
681/// Classify a body literal: only positive relational atoms are supported.
682fn classify_body_literal<'a>(
683    lit: &'a BodyLiteral,
684    rule_ctx: &str,
685) -> std::result::Result<&'a Atom, ResidentRejection> {
686    match lit {
687        BodyLiteral::Positive(a) => Ok(a),
688        BodyLiteral::Negated(a) => Err(ResidentRejection::err(
689            ResidentRejectKind::Negation,
690            a.predicate.clone(),
691            format!("negated literal in rule for `{}`", rule_ctx),
692        )),
693        BodyLiteral::Epistemic(l) => Err(ResidentRejection::err(
694            ResidentRejectKind::EpistemicLiteral,
695            l.atom.predicate.clone(),
696            format!("epistemic literal in rule for `{}`", rule_ctx),
697        )),
698        BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => {
699            Err(ResidentRejection::err(
700                ResidentRejectKind::NonRelationalLiteral,
701                "comparison/is/univ",
702                format!("non-relational literal in rule for `{}`", rule_ctx),
703            ))
704        }
705    }
706}
707
708/// Lower one rule to its fixed-width device record.
709fn lower_rule(
710    rule: &xlog_logic::ast::Rule,
711    universe: &Universe,
712) -> std::result::Result<Vec<u32>, ResidentRejection> {
713    // Assign variable ids by first occurrence across head + body.
714    let mut var_ids: BTreeMap<String, u32> = BTreeMap::new();
715    let assign_var = |name: &str,
716                      var_ids: &mut BTreeMap<String, u32>|
717     -> std::result::Result<u32, ResidentRejection> {
718        if let Some(&id) = var_ids.get(name) {
719            return Ok(id);
720        }
721        let id = var_ids.len() as u32;
722        if id as usize >= MAX_VARS {
723            return Err(ResidentRejection::err(
724                ResidentRejectKind::TooManyVars,
725                rule.head.predicate.clone(),
726                format!("more than {} distinct variables", MAX_VARS),
727            ));
728        }
729        var_ids.insert(name.to_string(), id);
730        Ok(id)
731    };
732
733    // First pass: assign var ids deterministically (head then body).
734    let body_atoms: Vec<&Atom> = {
735        let mut v = Vec::new();
736        for lit in &rule.body {
737            v.push(classify_body_literal(lit, &rule.head.predicate)?);
738        }
739        v
740    };
741    for t in &rule.head.terms {
742        if let TermClass::Var(name) = ConstKey::from_term(t)? {
743            assign_var(&name, &mut var_ids)?;
744        }
745    }
746    for atom in &body_atoms {
747        for t in &atom.terms {
748            if let TermClass::Var(name) = ConstKey::from_term(t)? {
749                assign_var(&name, &mut var_ids)?;
750            }
751        }
752    }
753    let n_vars = var_ids.len() as u32;
754
755    let encode_atom = |atom: &Atom| -> std::result::Result<[u32; ATOM_REC], ResidentRejection> {
756        let info = universe.preds.get(&atom.predicate).ok_or_else(|| {
757            ResidentRejection::err(
758                ResidentRejectKind::InconsistentArity,
759                atom.predicate.clone(),
760                "rule atom references unknown predicate",
761            )
762        })?;
763        let arity = info.arity as u32;
764        let mut rec = [0u32; ATOM_REC];
765        rec[0] = info.base;
766        rec[1] = arity;
767        rec[5] = universe.stride0(info.arity);
768        for (i, t) in atom.terms.iter().enumerate() {
769            let spec = match ConstKey::from_term(t)? {
770                TermClass::Var(name) => *var_ids.get(&name).expect("var assigned above"),
771                TermClass::Const(k) => {
772                    let idx = *universe.domain.get(&k).ok_or_else(|| {
773                        ResidentRejection::err(
774                            ResidentRejectKind::UnboundedTerm,
775                            format!("{:?}", k),
776                            "rule constant absent from bounded domain",
777                        )
778                    })?;
779                    CONST_FLAG | idx
780                }
781            };
782            rec[2 + i] = spec;
783        }
784        Ok(rec)
785    };
786
787    let mut rec = vec![0u32; RULE_REC];
788    rec[0] = body_atoms.len() as u32;
789    rec[1] = n_vars;
790    rec[2] = universe.domain_size;
791    let head_rec = encode_atom(&rule.head)?;
792    rec[3..3 + ATOM_REC].copy_from_slice(&head_rec);
793    for (bi, atom) in body_atoms.iter().enumerate() {
794        let a = encode_atom(atom)?;
795        let off = 3 + ATOM_REC + bi * ATOM_REC;
796        rec[off..off + ATOM_REC].copy_from_slice(&a);
797    }
798    Ok(rec)
799}
800
801fn ground_atom_from_atom(atom: &Atom) -> std::result::Result<GroundAtom, ResidentRejection> {
802    let mut args = Vec::with_capacity(atom.terms.len());
803    for t in &atom.terms {
804        let v = match t {
805            Term::Integer(i) => Value::I64(*i),
806            Term::Symbol(s) => Value::Symbol(*s),
807            Term::String(s) => Value::String(s.clone()),
808            Term::Float(f) => Value::F64(f.to_bits()),
809            other => {
810                return Err(ResidentRejection::err(
811                    ResidentRejectKind::UnboundedTerm,
812                    format!("{:?}", other),
813                    "fact term must be a ground constant",
814                ))
815            }
816        };
817        args.push(v);
818    }
819    Ok(GroundAtom {
820        predicate: atom.predicate.clone(),
821        args,
822    })
823}
824
825impl McProgram {
826    /// Evaluate this program with the GPU-resident engine. Returns device-resident
827    /// counts plus no-host instrumentation for the measured region. Fails closed
828    /// (typed [`ResidentRejection`] wrapped into [`XlogError`]) for programs
829    /// outside the supported fragment.
830    pub fn evaluate_resident_with_provider(
831        &self,
832        cfg: McEvalConfig,
833        provider: Arc<CudaKernelProvider>,
834    ) -> Result<McResidentResult> {
835        cfg.validate()?;
836        let plan = compile_resident_plan(self).map_err(ResidentRejection::into_error)?;
837        run_resident(&plan, &cfg, self, provider)
838    }
839
840    /// Convenience: evaluate with a fresh provider.
841    pub fn evaluate_resident(&self, cfg: McEvalConfig) -> Result<McResidentResult> {
842        let provider = Arc::new(self.provider()?);
843        self.evaluate_resident_with_provider(cfg, provider)
844    }
845}
846
847fn run_resident(
848    plan: &ResidentPlan,
849    cfg: &McEvalConfig,
850    mc: &McProgram,
851    provider: Arc<CudaKernelProvider>,
852) -> Result<McResidentResult> {
853    let (method, forcing) = mc.resolve_sampling_method(cfg.sampling_method)?;
854    let num_worlds = u32::try_from(cfg.samples)
855        .map_err(|_| XlogError::Execution("MC samples exceed u32::MAX".to_string()))?;
856    let blocks_per_world = resident_blocks_per_world()?;
857    let num_vars = plan.num_vars;
858
859    if let Some(budget_bytes) = resident_memory_budget_bytes()? {
860        let bound_bytes = estimate_resident_bound_bytes(plan, num_worlds);
861        if bound_bytes > budget_bytes {
862            return Err(XlogError::ResourceExhausted {
863                context: format!(
864                    "resident_resource_budget operator=sparse_wcoj bound_bytes={bound_bytes} budget_bytes={budget_bytes}"
865                ),
866                estimated_bytes: bound_bytes,
867                budget_bytes,
868            });
869        }
870    }
871
872    // ---------------- Static setup (BEFORE the measured region) ----------------
873    // All device allocations, the seeded sample matrix, force arrays, and every
874    // plan array are uploaded here. Allocation syncs; nothing below the measured
875    // marker may transfer or read back.
876    let dev = provider.device();
877
878    // Sample matrix (device-resident), forced for clamped evidence.
879    let mut d_force_mask = provider.memory().alloc::<u8>(num_vars.max(1))?;
880    let mut d_forced_value = provider.memory().alloc::<u8>(num_vars.max(1))?;
881    if method == McSamplingMethod::EvidenceClamping && num_vars > 0 {
882        provider.htod_sync_copy_into_tracked(&forcing.force_mask, &mut d_force_mask)?;
883        provider.htod_sync_copy_into_tracked(&forcing.forced_value, &mut d_forced_value)?;
884    } else {
885        dev.inner()
886            .memset_zeros(&mut d_force_mask)
887            .map_err(|e| XlogError::Kernel(format!("zero force_mask: {e}")))?;
888        dev.inner()
889            .memset_zeros(&mut d_forced_value)
890            .map_err(|e| XlogError::Kernel(format!("zero forced_value: {e}")))?;
891    }
892    let samples_device = if num_vars == 0 || cfg.samples == 0 {
893        provider.memory().alloc::<u8>(1)?
894    } else {
895        provider.sample_bernoulli_matrix_device(
896            &plan.bernoulli_probs,
897            cfg.samples,
898            cfg.seed,
899            &d_force_mask.slice(..),
900            &d_forced_value.slice(..),
901        )?
902    };
903
904    // Working dense membership index [num_worlds * 2 * U], pre-zeroed. This is a
905    // device-side dedup/index sidecar for the sparse columnar relation arena.
906    let u = plan.universe_size.max(1) as usize;
907    let rel_len = (num_worlds as usize)
908        .saturating_mul(u)
909        .saturating_mul(2)
910        .max(1);
911    let mut d_rel = provider.memory().alloc::<u32>(rel_len)?;
912    dev.inner()
913        .memset_zeros(&mut d_rel)
914        .map_err(|e| XlogError::Kernel(format!("zero rel: {e}")))?;
915
916    // Sparse world-segmented columnar relation arena. Capacity is the static
917    // universe bound per world/buffer; row counts and offsets remain device
918    // resident and are populated by the resident kernel.
919    let sparse_cap = u.max(1);
920    let sparse_len = (num_worlds as usize)
921        .saturating_mul(2)
922        .saturating_mul(sparse_cap)
923        .max(1);
924    // Four contiguous u32 columns: slot | arg0 | arg1 | arg2.
925    let mut d_sparse_columns = provider
926        .memory()
927        .alloc::<u32>(sparse_len.saturating_mul(4).max(1))?;
928    let mut d_sparse_counts = provider
929        .memory()
930        .alloc::<u32>((num_worlds as usize).saturating_mul(2).max(1))?;
931    let mut d_sparse_final_counts = provider
932        .memory()
933        .alloc::<u32>((num_worlds as usize).max(1))?;
934    let mut d_sparse_offsets = provider
935        .memory()
936        .alloc::<u32>((num_worlds as usize).saturating_add(1).max(1))?;
937    let mut d_resident_status_flags = provider.memory().alloc::<u32>(
938        (num_worlds as usize)
939            .saturating_mul(4)
940            .saturating_add(1)
941            .max(1),
942    )?;
943    dev.inner()
944        .memset_zeros(&mut d_sparse_columns)
945        .map_err(|e| XlogError::Kernel(format!("zero sparse_columns: {e}")))?;
946    dev.inner()
947        .memset_zeros(&mut d_sparse_counts)
948        .map_err(|e| XlogError::Kernel(format!("zero sparse_counts: {e}")))?;
949    dev.inner()
950        .memset_zeros(&mut d_sparse_final_counts)
951        .map_err(|e| XlogError::Kernel(format!("zero sparse_final_counts: {e}")))?;
952    dev.inner()
953        .memset_zeros(&mut d_sparse_offsets)
954        .map_err(|e| XlogError::Kernel(format!("zero sparse_offsets: {e}")))?;
955    dev.inner()
956        .memset_zeros(&mut d_resident_status_flags)
957        .map_err(|e| XlogError::Kernel(format!("zero resident_status_flags: {e}")))?;
958
959    // Pack every read-only plan array into one contiguous `meta` u32 buffer and
960    // record offsets in the `cfg` header. This keeps the kernel under the launch
961    // arg-arity limit (7 params) and uses one HtoD per buffer.
962    let q_count = plan.q_slot.len();
963    let ev_expected_u32: Vec<u32> = plan.ev_expected.iter().map(|&b| b as u32).collect();
964
965    let mut meta: Vec<u32> = Vec::new();
966    let push_meta = |data: &[u32], meta: &mut Vec<u32>| -> u32 {
967        let off = meta.len() as u32;
968        meta.extend_from_slice(data);
969        off
970    };
971    let edb_off = push_meta(&plan.edb_slots, &mut meta);
972    let pf_slot_off = push_meta(&plan.pf_slot, &mut meta);
973    let pf_var_off = push_meta(&plan.pf_var, &mut meta);
974    let rules_off = push_meta(&plan.rule_data, &mut meta);
975    let q_off = push_meta(&plan.q_slot, &mut meta);
976    let ev_slot_off = push_meta(&plan.ev_slot, &mut meta);
977    let ev_exp_off = push_meta(&ev_expected_u32, &mut meta);
978    let ad_off = push_meta(&plan.ad_data, &mut meta);
979
980    // cfg header, indices mirror the `CFG_*` defines in mc_resident.cu.
981    let cfg_host: [u32; 19] = [
982        num_worlds,
983        plan.universe_size,
984        num_vars as u32,
985        plan.max_iters,
986        edb_off,
987        plan.edb_slots.len() as u32,
988        pf_slot_off,
989        pf_var_off,
990        plan.pf_slot.len() as u32,
991        rules_off,
992        plan.num_rules,
993        q_off,
994        q_count as u32,
995        ev_slot_off,
996        ev_exp_off,
997        plan.ev_slot.len() as u32,
998        ad_off,
999        plan.num_ads,
1000        blocks_per_world,
1001    ];
1002
1003    let mut d_cfg = provider.memory().alloc::<u32>(cfg_host.len())?;
1004    provider.htod_sync_copy_into_tracked(&cfg_host, &mut d_cfg)?;
1005    let mut d_meta = provider.memory().alloc::<u32>(meta.len().max(1))?;
1006    if !meta.is_empty() {
1007        provider.htod_sync_copy_into_tracked(&meta, &mut d_meta)?;
1008    }
1009
1010    let mut d_query_counts = provider.memory().alloc::<u32>(q_count.max(1))?;
1011    dev.inner()
1012        .memset_zeros(&mut d_query_counts)
1013        .map_err(|e| XlogError::Kernel(format!("zero query_counts: {e}")))?;
1014    let mut d_evidence_count = provider.memory().alloc::<u32>(1)?;
1015    dev.inner()
1016        .memset_zeros(&mut d_evidence_count)
1017        .map_err(|e| XlogError::Kernel(format!("zero evidence_count: {e}")))?;
1018    let mut d_iter_trace = provider.memory().alloc::<u32>(num_worlds.max(1) as usize)?;
1019    dev.inner()
1020        .memset_zeros(&mut d_iter_trace)
1021        .map_err(|e| XlogError::Kernel(format!("zero iter_trace: {e}")))?;
1022
1023    let engine_fn = dev
1024        .inner()
1025        .get_func(MC_RESIDENT_MODULE, mc_resident_kernels::MC_RESIDENT_ENGINE)
1026        .ok_or_else(|| XlogError::Kernel("mc_resident_engine kernel not found".to_string()))?;
1027
1028    // Ensure all setup work is complete before we start measuring.
1029    dev.synchronize()?;
1030
1031    // ---------------- Measured region (ZERO host interaction) ----------------
1032    let pre = provider.host_transfer_stats();
1033    let pre_untracked = provider.untracked_metadata_dtoh_count();
1034    let pre_allocs = provider.memory().alloc_count();
1035    let mut engine_launches = 0u64;
1036
1037    let block_dim = 128u32;
1038    let grid_dim = num_worlds
1039        .max(1)
1040        .checked_mul(blocks_per_world)
1041        .ok_or_else(|| {
1042            XlogError::Execution(format!(
1043                "resident grid overflow: worlds={num_worlds} blocks_per_world={blocks_per_world}"
1044            ))
1045        })?;
1046    let launch_cfg = LaunchConfig {
1047        grid_dim: (grid_dim, 1, 1),
1048        block_dim: (block_dim, 1, 1),
1049        shared_mem_bytes: 0,
1050    };
1051    // SAFETY: argument list matches the `mc_resident_engine` PTX signature; every
1052    // device buffer was allocated with sufficient length above.
1053    unsafe {
1054        let args = (
1055            &d_cfg,
1056            &d_meta,
1057            &mut d_rel,
1058            &samples_device,
1059            &mut d_query_counts,
1060            &mut d_evidence_count,
1061            &mut d_iter_trace,
1062            &mut d_sparse_columns,
1063            &mut d_sparse_counts,
1064            &mut d_sparse_final_counts,
1065            &mut d_sparse_offsets,
1066            &mut d_resident_status_flags,
1067            sparse_cap as u32,
1068        );
1069        if blocks_per_world == 1 {
1070            engine_fn
1071                .launch(launch_cfg, args)
1072                .map_err(|e| XlogError::Kernel(format!("mc_resident_engine launch failed: {e}")))?;
1073        } else {
1074            engine_fn
1075                .launch_cooperative(launch_cfg, args)
1076                .map_err(|e| {
1077                    XlogError::Kernel(format!("mc_resident_engine cooperative launch failed: {e}"))
1078                })?;
1079        }
1080    }
1081    engine_launches += 1;
1082    dev.synchronize()?;
1083
1084    let post = provider.host_transfer_stats();
1085    let post_untracked = provider.untracked_metadata_dtoh_count();
1086    let post_allocs = provider.memory().alloc_count();
1087    // ---------------- End measured region ----------------
1088
1089    let no_host = McNoHostStats {
1090        tracked_htod_calls: post.htod_calls.saturating_sub(pre.htod_calls),
1091        tracked_dtoh_calls: post.dtoh_calls.saturating_sub(pre.dtoh_calls),
1092        untracked_metadata_reads: post_untracked.saturating_sub(pre_untracked),
1093        engine_launches,
1094        host_loop_iterations: 0,
1095        // The dense engine has no host fixpoint loop (convergence is device-side
1096        // inside the megakernel) and allocates every arena before the region.
1097        host_fixpoint_iterations: 0,
1098        per_operator_host_allocations: post_allocs.saturating_sub(pre_allocs),
1099        per_sample_host_launches: 0,
1100    };
1101
1102    Ok(McResidentResult {
1103        query_counts: d_query_counts,
1104        evidence_count: d_evidence_count,
1105        iter_trace: d_iter_trace,
1106        sparse_final_row_counts: d_sparse_final_counts,
1107        sparse_offsets: d_sparse_offsets,
1108        resident_status_flags: d_resident_status_flags,
1109        total_samples: cfg.samples,
1110        seed: cfg.seed,
1111        confidence: cfg.confidence,
1112        sampling_method: method,
1113        num_queries: q_count,
1114        no_host,
1115    })
1116}