Skip to main content

xlog_induce/
reduce.rs

//! Deterministic per-topology top-K reduction and tie diagnostics.
//!
//! Behaviorally equivalent to the Python reference in
//! `crates/pyxlog/python/pyxlog/ilp/exact_induce.py`:
//!
//! 1. Within each topology, sort scored pairs by the lexicographic key
//!    `(-positives_covered, negatives_covered, left_idx, right_idx)`.
//! 2. Filter to pairs with `positives_covered > 0`, then keep the first
//!    `k_per_topology` of them.
//! 3. For each kept pair:
//!    - `local_rank` = 0-indexed position within the positive-filtered list.
//!    - `next_positives_covered` / `next_negatives_covered` = coverage of the
//!      next pair in the positive-filtered list, or `(0, 0)` if there is none.
//!    - `tie_class_size` = count of pairs in the FULL sorted list (including
//!      zero-coverage) sharing the same `(positives_covered, negatives_covered)`.
//! 4. The output lists candidates grouped by topology, in `Topology::ALL` order.
17
18use crate::types::{ScoredCandidate, Topology};
19use xlog_core::RelId;
20
21/// One scored `(topology, left, right)` triple produced by the scoring stage.
22///
23/// Passed into [`reduce_per_topology`] as a flat list; grouping happens inside.
24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub struct ScoredPair {
26    pub topology: Topology,
27    pub left_rel_idx: RelId,
28    pub right_rel_idx: RelId,
29    pub positives_covered: u32,
30    pub negatives_covered: u32,
31}
32
33/// Reduce a flat scored-pair list to the final ordered `ScoredCandidate` list.
34///
35/// Matches the Python reference comparator and diagnostics bit-for-bit.
36pub fn reduce_per_topology(
37    scored_pairs: &[ScoredPair],
38    head_rel_idx: RelId,
39    k_per_topology: u32,
40) -> Vec<ScoredCandidate> {
41    let mut result = Vec::new();
42    let k = k_per_topology as usize;
43
44    for topology in Topology::ALL {
45        // Pull this topology's pairs into a sortable vector.
46        let mut this_topo: Vec<ScoredPair> = scored_pairs
47            .iter()
48            .copied()
49            .filter(|p| p.topology == topology)
50            .collect();
51
52        // Lexicographic sort: max positives, min negatives, min L idx, min R idx.
53        this_topo.sort_by(|a, b| {
54            (
55                std::cmp::Reverse(a.positives_covered),
56                a.negatives_covered,
57                a.left_rel_idx.0,
58                a.right_rel_idx.0,
59            )
60                .cmp(&(
61                    std::cmp::Reverse(b.positives_covered),
62                    b.negatives_covered,
63                    b.left_rel_idx.0,
64                    b.right_rel_idx.0,
65                ))
66        });
67
68        // Positive-coverage filter preserves sorted order.
69        let positives: Vec<ScoredPair> = this_topo
70            .iter()
71            .copied()
72            .filter(|p| p.positives_covered > 0)
73            .collect();
74
75        let kept_n = std::cmp::min(k, positives.len());
76        for (rank, pair) in positives.iter().take(kept_n).enumerate() {
77            // Diagnostics: next candidate in the positive-filtered list.
78            let (next_pos, next_neg) = positives
79                .get(rank + 1)
80                .map(|nxt| (nxt.positives_covered, nxt.negatives_covered))
81                .unwrap_or((0, 0));
82
83            // Tie class counted over the FULL sorted list (including zero-coverage).
84            let tie_count = this_topo
85                .iter()
86                .filter(|s| {
87                    s.positives_covered == pair.positives_covered
88                        && s.negatives_covered == pair.negatives_covered
89                })
90                .count() as u32;
91
92            result.push(ScoredCandidate {
93                topology,
94                head_rel_idx,
95                left_rel_idx: pair.left_rel_idx,
96                right_rel_idx: pair.right_rel_idx,
97                positives_covered: pair.positives_covered,
98                negatives_covered: pair.negatives_covered,
99                local_rank: rank as u32,
100                next_positives_covered: next_pos,
101                next_negatives_covered: next_neg,
102                tie_class_size: tie_count,
103            });
104        }
105    }
106
107    result
108}
109
#[cfg(test)]
mod tests {
    //! Unit tests for [`reduce_per_topology`]: ranking-key order, top-K
    //! truncation, `next_*` / `tie_class_size` diagnostics, and per-topology
    //! grouping.

    use super::*;

    /// Head relation shared by every test that does not care about its value.
    const HEAD: RelId = RelId(100);

    /// Build a scored pair from raw index and coverage values.
    fn pair(topo: Topology, l: u32, r: u32, pos: u32, neg: u32) -> ScoredPair {
        ScoredPair {
            topology: topo,
            left_rel_idx: RelId(l),
            right_rel_idx: RelId(r),
            positives_covered: pos,
            negatives_covered: neg,
        }
    }

    /// Shorthand for a `Topology::Chain` pair.
    fn chain(l: u32, r: u32, pos: u32, neg: u32) -> ScoredPair {
        pair(Topology::Chain, l, r, pos, neg)
    }

    /// Run the reducer with the shared [`HEAD`] relation.
    fn reduce(pairs: &[ScoredPair], k: u32) -> Vec<ScoredCandidate> {
        reduce_per_topology(pairs, HEAD, k)
    }

    /// Project each candidate onto its `next_*` diagnostics.
    fn next_of(result: &[ScoredCandidate]) -> Vec<(u32, u32)> {
        result
            .iter()
            .map(|c| (c.next_positives_covered, c.next_negatives_covered))
            .collect()
    }

    #[test]
    fn empty_input_yields_empty_output() {
        assert!(reduce(&[], 2).is_empty());
    }

    #[test]
    fn zero_k_yields_empty_output() {
        assert!(reduce(&[chain(1, 2, 5, 0)], 0).is_empty());
    }

    #[test]
    fn zero_coverage_pairs_are_excluded_from_kept() {
        let result = reduce(&[chain(1, 2, 0, 0), chain(3, 4, 0, 0)], 2);
        assert!(result.is_empty());
    }

    #[test]
    fn single_positive_pair_has_tie_one_and_no_next() {
        let result = reduce(&[chain(1, 2, 5, 0)], 2);
        assert_eq!(result.len(), 1);
        let c = &result[0];
        assert_eq!(c.topology, Topology::Chain);
        assert_eq!(c.left_rel_idx, RelId(1));
        assert_eq!(c.right_rel_idx, RelId(2));
        assert_eq!(c.positives_covered, 5);
        assert_eq!(c.negatives_covered, 0);
        assert_eq!(c.local_rank, 0);
        assert_eq!(c.next_positives_covered, 0);
        assert_eq!(c.next_negatives_covered, 0);
        assert_eq!(c.tie_class_size, 1);
    }

    #[test]
    fn max_positives_wins_over_higher_negatives() {
        // (pos=5, neg=2) must outrank (pos=3, neg=0): positives dominate.
        let result = reduce(&[chain(1, 2, 3, 0), chain(3, 4, 5, 2)], 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].positives_covered, 5);
        assert_eq!(result[0].negatives_covered, 2);
    }

    #[test]
    fn min_negatives_breaks_positives_tie() {
        // Equal positives: the pair with fewer negatives ranks first.
        let result = reduce(&[chain(1, 2, 5, 3), chain(3, 4, 5, 1)], 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].left_rel_idx, RelId(3));
        assert_eq!(result[0].negatives_covered, 1);
    }

    #[test]
    fn left_idx_breaks_pos_neg_tie() {
        // Equal coverage: the lower left index ranks first.
        let result = reduce(&[chain(5, 2, 3, 0), chain(1, 2, 3, 0)], 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].left_rel_idx, RelId(1));
    }

    #[test]
    fn right_idx_breaks_all_other_ties() {
        // Equal coverage and left index: the lower right index ranks first.
        let result = reduce(&[chain(1, 5, 3, 0), chain(1, 2, 3, 0)], 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].right_rel_idx, RelId(2));
    }

    #[test]
    fn top_k_truncation_preserves_order() {
        // Three positive pairs with K=2: only the best two survive, in order.
        let result = reduce(
            &[chain(1, 2, 3, 0), chain(1, 3, 5, 0), chain(1, 4, 4, 0)],
            2,
        );
        let ranked: Vec<(u32, u32)> = result
            .iter()
            .map(|c| (c.positives_covered, c.local_rank))
            .collect();
        assert_eq!(ranked, [(5, 0), (4, 1)]);
    }

    #[test]
    fn next_diagnostics_point_to_rank_plus_one_in_positive_list() {
        // Three positives with K=2: rank 1 still sees the truncated rank 2.
        let result = reduce(
            &[chain(1, 1, 5, 0), chain(2, 2, 4, 0), chain(3, 3, 3, 0)],
            2,
        );
        assert_eq!(next_of(&result), [(4, 0), (3, 0)]);
    }

    #[test]
    fn next_diagnostics_are_zero_when_no_next() {
        // A lone positive pair has no successor, so next_* is (0, 0).
        let result = reduce(&[chain(1, 1, 5, 0)], 2);
        assert_eq!(result.len(), 1);
        assert_eq!(next_of(&result), [(0, 0)]);
    }

    #[test]
    fn next_diagnostics_ignore_zero_coverage_pairs() {
        // The zero-coverage pair is not in the positive-filtered list, so its
        // negatives (7) must not leak into the survivor's next_*.
        let result = reduce(&[chain(1, 1, 5, 0), chain(2, 2, 0, 7)], 2);
        assert_eq!(result.len(), 1);
        assert_eq!(next_of(&result), [(0, 0)]);
    }

    #[test]
    fn tie_class_size_counts_same_pos_neg_in_full_sorted_list() {
        // Three pairs share (pos=5, neg=0); the fourth has (pos=3, neg=0).
        let result = reduce(
            &[
                chain(1, 1, 5, 0),
                chain(2, 2, 5, 0),
                chain(3, 3, 5, 0),
                chain(4, 4, 3, 0),
            ],
            2,
        );
        let ties: Vec<u32> = result.iter().map(|c| c.tie_class_size).collect();
        assert_eq!(ties, [3, 3]);
    }

    #[test]
    fn tie_class_size_distinguishes_negatives() {
        // Same positives but different negatives belong to different classes.
        let result = reduce(
            &[chain(1, 1, 5, 0), chain(2, 2, 5, 1), chain(3, 3, 5, 0)],
            3,
        );
        let ties: Vec<(RelId, u32)> = result
            .iter()
            .map(|c| (c.left_rel_idx, c.tie_class_size))
            .collect();
        assert_eq!(ties, [(RelId(1), 2), (RelId(3), 2), (RelId(2), 1)]);
    }

    #[test]
    fn topologies_output_in_all_order() {
        // One positive per topology, supplied in reverse; output must follow
        // chain, star, fanout, fanin.
        let pairs = [
            pair(Topology::Fanin, 1, 1, 1, 0),
            pair(Topology::Fanout, 1, 1, 1, 0),
            pair(Topology::Star, 1, 1, 1, 0),
            pair(Topology::Chain, 1, 1, 1, 0),
        ];
        let order: Vec<Topology> = reduce(&pairs, 1).iter().map(|c| c.topology).collect();
        assert_eq!(
            order,
            [
                Topology::Chain,
                Topology::Star,
                Topology::Fanout,
                Topology::Fanin,
            ]
        );
    }

    #[test]
    fn topology_filtering_is_per_topology() {
        // Chain and star pairs are ranked independently of each other.
        let pairs = [
            pair(Topology::Chain, 1, 1, 3, 0),
            pair(Topology::Chain, 2, 2, 5, 0),
            pair(Topology::Star, 3, 3, 2, 0),
            pair(Topology::Star, 4, 4, 4, 0),
        ];
        let winners: Vec<(Topology, u32)> = reduce(&pairs, 1)
            .iter()
            .map(|c| (c.topology, c.positives_covered))
            .collect();
        assert_eq!(winners, [(Topology::Chain, 5), (Topology::Star, 4)]);
    }

    #[test]
    fn k_larger_than_positives_returns_all_positives() {
        let result = reduce(&[chain(1, 1, 5, 0), chain(2, 2, 3, 0)], 10);
        let ranks: Vec<u32> = result.iter().map(|c| c.local_rank).collect();
        assert_eq!(ranks, [0, 1]);
    }

    #[test]
    fn head_rel_idx_propagates_to_every_candidate() {
        let pairs = [
            pair(Topology::Chain, 1, 1, 5, 0),
            pair(Topology::Star, 2, 2, 3, 0),
        ];
        let head = RelId(42);
        let result = reduce_per_topology(&pairs, head, 1);
        // Guard against the loop below passing vacuously on an empty result.
        assert_eq!(result.len(), 2);
        for c in &result {
            assert_eq!(c.head_rel_idx, head);
        }
    }
}