Skip to main content

xlog_core/
symbol.rs

1//! Global symbol interning for reversible string-to-ID mapping.
2
3use arrow::array::{Array, DictionaryArray, StringArray, UInt32Array};
4use arrow::datatypes::UInt32Type;
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock, RwLock};
7
/// Two-way symbol table: `to_id` maps each interned string to its ID, and
/// `to_string` holds the same strings so that an ID can be used as an index
/// to get the string back.
#[derive(Default)]
struct SymbolRegistry {
    /// Forward lookup: interned string -> ID.
    to_id: HashMap<String, u32>,
    /// Reverse lookup: the string for ID `n` lives at position `n`.
    to_string: Vec<String>,
}

impl SymbolRegistry {
    /// Create a registry with no symbols in it.
    fn new() -> Self {
        Self::default()
    }
}

/// The single process-wide registry. It starts uninitialized and is filled in
/// by the first call to `registry()`.
static REGISTRY: OnceLock<RwLock<SymbolRegistry>> = OnceLock::new();

/// Return the global registry, creating an empty one on first use.
fn registry() -> &'static RwLock<SymbolRegistry> {
    fn empty() -> RwLock<SymbolRegistry> {
        RwLock::new(SymbolRegistry::new())
    }
    REGISTRY.get_or_init(empty)
}
27
28/// Intern a string, returning its unique ID.
29/// Thread-safe. Returns existing ID if already interned.
30pub fn intern(s: &str) -> u32 {
31    // Fast path: check if already interned (read lock)
32    {
33        let reg = registry().read().unwrap();
34        if let Some(&id) = reg.to_id.get(s) {
35            return id;
36        }
37    }
38    // Slow path: insert new (write lock)
39    let mut reg = registry().write().unwrap();
40    // Double-check after acquiring write lock
41    if let Some(&id) = reg.to_id.get(s) {
42        return id;
43    }
44    let id = reg.to_string.len() as u32;
45    let owned = s.to_string();
46    reg.to_id.insert(owned.clone(), id);
47    reg.to_string.push(owned);
48    id
49}
50
51/// Resolve an ID to its string. Panics if ID is invalid.
52pub fn resolve(id: u32) -> String {
53    resolve_checked(id).expect("invalid symbol ID: this is a bug")
54}
55
56/// Resolve an ID to its string if present.
57pub fn resolve_checked(id: u32) -> Option<String> {
58    let reg = registry().read().unwrap();
59    reg.to_string.get(id as usize).cloned()
60}
61
62/// Clear all symbols. For testing/REPL only.
63/// WARNING: Invalidates all existing symbol IDs.
64pub fn clear() {
65    let mut reg = registry().write().unwrap();
66    reg.to_id.clear();
67    reg.to_string.clear();
68}
69
70/// Number of interned symbols.
71pub fn count() -> usize {
72    registry().read().unwrap().to_string.len()
73}
74
75/// Estimated memory usage in bytes.
76pub fn memory_usage() -> usize {
77    let reg = registry().read().unwrap();
78    let string_bytes: usize = reg.to_string.iter().map(|s| s.len()).sum();
79    let map_overhead =
80        reg.to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<u32>());
81    string_bytes + map_overhead
82}
83
84/// Convert a column of symbol IDs to Arrow DictionaryArray.
85pub fn to_arrow(ids: &[u32]) -> DictionaryArray<UInt32Type> {
86    use std::collections::HashSet;
87
88    // Collect unique IDs preserving order
89    let mut seen = HashSet::new();
90    let unique_ids: Vec<u32> = ids.iter().filter(|id| seen.insert(**id)).copied().collect();
91
92    // Build string dictionary
93    let dict_strings: Vec<String> = unique_ids.iter().map(|&id| resolve(id)).collect();
94    let dictionary = StringArray::from(dict_strings);
95
96    // Map original IDs to dictionary indices
97    let id_to_index: HashMap<u32, u32> = unique_ids
98        .iter()
99        .enumerate()
100        .map(|(i, &id)| (id, i as u32))
101        .collect();
102
103    let keys: Vec<u32> = ids.iter().map(|id| *id_to_index.get(id).unwrap()).collect();
104    let keys_array = UInt32Array::from(keys);
105
106    DictionaryArray::try_new(keys_array, Arc::new(dictionary)).unwrap()
107}
108
109/// Convert Arrow DictionaryArray back to symbol IDs.
110pub fn from_arrow(arr: &DictionaryArray<UInt32Type>) -> Vec<u32> {
111    let dict = arr
112        .values()
113        .as_any()
114        .downcast_ref::<StringArray>()
115        .expect("dictionary values must be StringArray");
116
117    // Intern all dictionary values
118    let dict_to_symbol: Vec<u32> = dict
119        .iter()
120        .map(|s| intern(s.expect("null not supported in symbols")))
121        .collect();
122
123    // Map keys through dictionary
124    arr.keys()
125        .iter()
126        .map(|k| {
127            let idx = k.expect("null keys not supported") as usize;
128            dict_to_symbol[idx]
129        })
130        .collect()
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Reset the global registry. Every test calls this first because the
    /// registry is shared by the whole process; `#[serial]` keeps the tests
    /// from interleaving.
    fn setup() {
        clear();
    }

    #[test]
    #[serial]
    fn test_intern_sequential() {
        setup();
        // IDs are handed out in insertion order, starting at zero.
        for (expected, name) in (0u32..).zip(["foo", "bar", "baz"]) {
            assert_eq!(intern(name), expected);
        }
    }

    #[test]
    #[serial]
    fn test_intern_idempotent() {
        setup();
        let first = intern("hello");
        let second = intern("hello");
        assert_eq!(first, second);
        // Interning the same string twice must not create a second entry.
        assert_eq!(count(), 1);
    }

    #[test]
    #[serial]
    fn test_resolve_roundtrip() {
        setup();
        let id = intern("world");
        let text = resolve(id);
        assert_eq!(text, "world");
    }

    #[test]
    #[serial]
    #[should_panic(expected = "invalid symbol ID")]
    fn test_resolve_invalid() {
        setup();
        // Nothing has been interned, so this ID cannot exist.
        let _ = resolve(9999);
    }

    #[test]
    #[serial]
    fn test_clear() {
        setup();
        for name in ["a", "b"] {
            intern(name);
        }
        assert_eq!(count(), 2);

        clear();
        assert_eq!(count(), 0);
        // After clearing, numbering starts over.
        assert_eq!(intern("a"), 0);
    }

    #[test]
    #[serial]
    fn test_empty_string() {
        setup();
        let id = intern("");
        assert_eq!(resolve(id), "");
    }

    #[test]
    #[serial]
    fn test_unicode() {
        setup();
        for text in ["日本語", "émoji🎉"] {
            let id = intern(text);
            assert_eq!(resolve(id), text);
        }
    }

    #[test]
    #[serial]
    fn test_concurrent_intern() {
        use std::collections::HashSet;
        use std::thread;

        setup();

        // Start all ten workers before joining any of them so they really
        // run concurrently.
        let mut workers = Vec::with_capacity(10);
        for t in 0..10 {
            workers.push(thread::spawn(move || {
                (0..100)
                    .map(|j| {
                        let name = format!("thread{}_{}", t, j);
                        let id = intern(&name);
                        (name, id)
                    })
                    .collect::<Vec<(String, u32)>>()
            }));
        }

        let all_results: Vec<(String, u32)> = workers
            .into_iter()
            .flat_map(|worker| worker.join().unwrap())
            .collect();

        // Every ID must resolve back to the string that produced it.
        for (name, id) in &all_results {
            assert_eq!(resolve(*id), *name);
        }

        // 10 threads x 100 distinct strings each.
        assert_eq!(count(), 1000);

        // Distinct strings must never share an ID.
        let distinct: HashSet<u32> = all_results.iter().map(|&(_, id)| id).collect();
        assert_eq!(distinct.len(), 1000);
    }

    #[test]
    #[serial]
    fn test_large_scale() {
        use std::time::Instant;

        const TOTAL: u32 = 100_000;
        let label = |i: u32| format!("symbol_{:06}", i);

        setup();

        // Intern 100K unique symbols; each must receive the next dense ID.
        let intern_started = Instant::now();
        for i in 0..TOTAL {
            assert_eq!(intern(&label(i)), i);
        }
        let intern_time = intern_started.elapsed();

        // Every ID must resolve back to its string.
        let resolve_started = Instant::now();
        for i in 0..TOTAL {
            assert_eq!(resolve(i), label(i));
        }
        let resolve_time = resolve_started.elapsed();

        assert_eq!(count(), 100_000);

        // Timings are printed for information only, not asserted.
        println!(
            "100K intern: {:?}, 100K resolve: {:?}",
            intern_time, resolve_time
        );

        // Rough sanity bound: under 10MB for 100K symbols.
        let mem = memory_usage();
        assert!(mem < 10_000_000, "memory usage {} exceeds 10MB", mem);
    }

    #[test]
    #[serial]
    fn test_arrow_roundtrip() {
        setup();
        // Repeated names exercise the dictionary de-duplication.
        let ids: Vec<u32> = ["apple", "banana", "apple", "cherry", "banana"]
            .iter()
            .map(|name| intern(name))
            .collect();

        let encoded = to_arrow(&ids);
        let decoded = from_arrow(&encoded);

        assert_eq!(ids, decoded);
    }
}