1use crate::types::{ScoredCandidate, Topology};
19use xlog_core::RelId;
20
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub struct ScoredPair {
26 pub topology: Topology,
27 pub left_rel_idx: RelId,
28 pub right_rel_idx: RelId,
29 pub positives_covered: u32,
30 pub negatives_covered: u32,
31}
32
33pub fn reduce_per_topology(
37 scored_pairs: &[ScoredPair],
38 head_rel_idx: RelId,
39 k_per_topology: u32,
40) -> Vec<ScoredCandidate> {
41 let mut result = Vec::new();
42 let k = k_per_topology as usize;
43
44 for topology in Topology::ALL {
45 let mut this_topo: Vec<ScoredPair> = scored_pairs
47 .iter()
48 .copied()
49 .filter(|p| p.topology == topology)
50 .collect();
51
52 this_topo.sort_by(|a, b| {
54 (
55 std::cmp::Reverse(a.positives_covered),
56 a.negatives_covered,
57 a.left_rel_idx.0,
58 a.right_rel_idx.0,
59 )
60 .cmp(&(
61 std::cmp::Reverse(b.positives_covered),
62 b.negatives_covered,
63 b.left_rel_idx.0,
64 b.right_rel_idx.0,
65 ))
66 });
67
68 let positives: Vec<ScoredPair> = this_topo
70 .iter()
71 .copied()
72 .filter(|p| p.positives_covered > 0)
73 .collect();
74
75 let kept_n = std::cmp::min(k, positives.len());
76 for (rank, pair) in positives.iter().take(kept_n).enumerate() {
77 let (next_pos, next_neg) = positives
79 .get(rank + 1)
80 .map(|nxt| (nxt.positives_covered, nxt.negatives_covered))
81 .unwrap_or((0, 0));
82
83 let tie_count = this_topo
85 .iter()
86 .filter(|s| {
87 s.positives_covered == pair.positives_covered
88 && s.negatives_covered == pair.negatives_covered
89 })
90 .count() as u32;
91
92 result.push(ScoredCandidate {
93 topology,
94 head_rel_idx,
95 left_rel_idx: pair.left_rel_idx,
96 right_rel_idx: pair.right_rel_idx,
97 positives_covered: pair.positives_covered,
98 negatives_covered: pair.negatives_covered,
99 local_rank: rank as u32,
100 next_positives_covered: next_pos,
101 next_negatives_covered: next_neg,
102 tie_class_size: tie_count,
103 });
104 }
105 }
106
107 result
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Head relation used by tests that do not care about its value.
    const HEAD: RelId = RelId(100);

    /// Builds a `ScoredPair` from `(topology, left, right, positives, negatives)`.
    fn pair(topo: Topology, l: u32, r: u32, pos: u32, neg: u32) -> ScoredPair {
        ScoredPair {
            topology: topo,
            left_rel_idx: RelId(l),
            right_rel_idx: RelId(r),
            positives_covered: pos,
            negatives_covered: neg,
        }
    }

    /// Shorthand for a `Topology::Chain` pair.
    fn chain(l: u32, r: u32, pos: u32, neg: u32) -> ScoredPair {
        pair(Topology::Chain, l, r, pos, neg)
    }

    #[test]
    fn empty_input_yields_empty_output() {
        assert!(reduce_per_topology(&[], HEAD, 2).is_empty());
    }

    #[test]
    fn zero_k_yields_empty_output() {
        let out = reduce_per_topology(&[chain(1, 2, 5, 0)], HEAD, 0);
        assert!(out.is_empty());
    }

    #[test]
    fn zero_coverage_pairs_are_excluded_from_kept() {
        let input = [chain(1, 2, 0, 0), chain(3, 4, 0, 0)];
        let out = reduce_per_topology(&input, HEAD, 2);
        assert!(out.is_empty());
    }

    #[test]
    fn single_positive_pair_has_tie_one_and_no_next() {
        let out = reduce_per_topology(&[chain(1, 2, 5, 0)], HEAD, 2);
        assert_eq!(out.len(), 1);

        let only = &out[0];
        assert_eq!(only.topology, Topology::Chain);
        assert_eq!(
            (only.left_rel_idx, only.right_rel_idx),
            (RelId(1), RelId(2))
        );
        assert_eq!((only.positives_covered, only.negatives_covered), (5, 0));
        assert_eq!(only.local_rank, 0);
        assert_eq!(
            (only.next_positives_covered, only.next_negatives_covered),
            (0, 0)
        );
        assert_eq!(only.tie_class_size, 1);
    }

    #[test]
    fn max_positives_wins_over_higher_negatives() {
        // More positives beats fewer negatives.
        let input = [chain(1, 2, 3, 0), chain(3, 4, 5, 2)];
        let out = reduce_per_topology(&input, HEAD, 1);
        assert_eq!(out.len(), 1);
        assert_eq!((out[0].positives_covered, out[0].negatives_covered), (5, 2));
    }

    #[test]
    fn min_negatives_breaks_positives_tie() {
        // Equal positives: the pair with fewer negatives ranks first.
        let input = [chain(1, 2, 5, 3), chain(3, 4, 5, 1)];
        let out = reduce_per_topology(&input, HEAD, 1);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].left_rel_idx, RelId(3));
        assert_eq!(out[0].negatives_covered, 1);
    }

    #[test]
    fn left_idx_breaks_pos_neg_tie() {
        // Equal coverage: the smaller left index ranks first.
        let input = [chain(5, 2, 3, 0), chain(1, 2, 3, 0)];
        let out = reduce_per_topology(&input, HEAD, 1);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].left_rel_idx, RelId(1));
    }

    #[test]
    fn right_idx_breaks_all_other_ties() {
        // Equal coverage and left index: the smaller right index ranks first.
        let input = [chain(1, 5, 3, 0), chain(1, 2, 3, 0)];
        let out = reduce_per_topology(&input, HEAD, 1);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].right_rel_idx, RelId(2));
    }

    #[test]
    fn top_k_truncation_preserves_order() {
        let input = [chain(1, 2, 3, 0), chain(1, 3, 5, 0), chain(1, 4, 4, 0)];
        let out = reduce_per_topology(&input, HEAD, 2);
        assert_eq!(out.len(), 2);
        assert_eq!((out[0].positives_covered, out[0].local_rank), (5, 0));
        assert_eq!((out[1].positives_covered, out[1].local_rank), (4, 1));
    }

    #[test]
    fn next_diagnostics_point_to_rank_plus_one_in_positive_list() {
        let input = [chain(1, 1, 5, 0), chain(2, 2, 4, 0), chain(3, 3, 3, 0)];
        let out = reduce_per_topology(&input, HEAD, 2);
        assert_eq!(out.len(), 2);
        // Rank 0 sees rank 1.
        assert_eq!(
            (out[0].next_positives_covered, out[0].next_negatives_covered),
            (4, 0)
        );
        // Rank 1 sees rank 2 even though rank 2 was cut by k = 2.
        assert_eq!(
            (out[1].next_positives_covered, out[1].next_negatives_covered),
            (3, 0)
        );
    }

    #[test]
    fn next_diagnostics_are_zero_when_no_next() {
        let out = reduce_per_topology(&[chain(1, 1, 5, 0)], HEAD, 2);
        assert_eq!(
            (out[0].next_positives_covered, out[0].next_negatives_covered),
            (0, 0)
        );
    }

    #[test]
    fn tie_class_size_counts_same_pos_neg_in_full_sorted_list() {
        let input = [
            chain(1, 1, 5, 0),
            chain(2, 2, 5, 0),
            chain(3, 3, 5, 0),
            chain(4, 4, 3, 0),
        ];
        let out = reduce_per_topology(&input, HEAD, 2);
        // Three pairs share (5, 0); the count includes the one dropped by k = 2.
        assert_eq!([out[0].tie_class_size, out[1].tie_class_size], [3, 3]);
    }

    #[test]
    fn topologies_output_in_all_order() {
        // Input order is deliberately not the expected output order.
        let input = [
            pair(Topology::Fanin, 1, 1, 1, 0),
            pair(Topology::Fanout, 1, 1, 1, 0),
            pair(Topology::Star, 1, 1, 1, 0),
            pair(Topology::Chain, 1, 1, 1, 0),
        ];
        let out = reduce_per_topology(&input, HEAD, 1);
        assert_eq!(out.len(), 4);

        let order: Vec<Topology> = out.iter().map(|c| c.topology).collect();
        assert_eq!(
            order,
            [
                Topology::Chain,
                Topology::Star,
                Topology::Fanout,
                Topology::Fanin,
            ]
        );
    }

    #[test]
    fn topology_filtering_is_per_topology() {
        let input = [
            pair(Topology::Chain, 1, 1, 3, 0),
            pair(Topology::Chain, 2, 2, 5, 0),
            pair(Topology::Star, 3, 3, 2, 0),
            pair(Topology::Star, 4, 4, 4, 0),
        ];
        let out = reduce_per_topology(&input, HEAD, 1);
        assert_eq!(out.len(), 2);
        // Each topology contributes its own best pair.
        assert_eq!(out[0].topology, Topology::Chain);
        assert_eq!(out[0].positives_covered, 5);
        assert_eq!(out[1].topology, Topology::Star);
        assert_eq!(out[1].positives_covered, 4);
    }

    #[test]
    fn k_larger_than_positives_returns_all_positives() {
        let input = [chain(1, 1, 5, 0), chain(2, 2, 3, 0)];
        let out = reduce_per_topology(&input, HEAD, 10);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn head_rel_idx_propagates_to_every_candidate() {
        let input = [
            pair(Topology::Chain, 1, 1, 5, 0),
            pair(Topology::Star, 2, 2, 3, 0),
        ];
        let head = RelId(42);
        let out = reduce_per_topology(&input, head, 1);
        assert!(out.iter().all(|c| c.head_rel_idx == head));
    }
}