1use arrow::array::{Array, DictionaryArray, StringArray, UInt32Array};
4use arrow::datatypes::UInt32Type;
5use std::collections::HashMap;
6use std::sync::{Arc, OnceLock, RwLock};
7
/// Process-wide symbol table shared by every function in this file.
///
/// The cell starts empty; `registry()` fills it on first access via
/// `get_or_init`, and all reads and writes go through the inner `RwLock`.
static REGISTRY: OnceLock<RwLock<SymbolRegistry>> = OnceLock::new();
9
/// Two-way mapping between interned strings and their numeric IDs.
struct SymbolRegistry {
    /// Forward lookup: symbol text to the ID it was assigned.
    to_id: HashMap<String, u32>,
    /// Reverse lookup: the entry at index `i` is the text whose ID is `i`.
    to_string: Vec<String>,
}

impl SymbolRegistry {
    /// Builds a registry that holds no symbols.
    fn new() -> Self {
        let to_id = HashMap::default();
        let to_string = Vec::default();
        Self { to_id, to_string }
    }
}
23
24fn registry() -> &'static RwLock<SymbolRegistry> {
25 REGISTRY.get_or_init(|| RwLock::new(SymbolRegistry::new()))
26}
27
28pub fn intern(s: &str) -> u32 {
31 {
33 let reg = registry().read().unwrap();
34 if let Some(&id) = reg.to_id.get(s) {
35 return id;
36 }
37 }
38 let mut reg = registry().write().unwrap();
40 if let Some(&id) = reg.to_id.get(s) {
42 return id;
43 }
44 let id = reg.to_string.len() as u32;
45 let owned = s.to_string();
46 reg.to_id.insert(owned.clone(), id);
47 reg.to_string.push(owned);
48 id
49}
50
51pub fn resolve(id: u32) -> String {
53 resolve_checked(id).expect("invalid symbol ID: this is a bug")
54}
55
56pub fn resolve_checked(id: u32) -> Option<String> {
58 let reg = registry().read().unwrap();
59 reg.to_string.get(id as usize).cloned()
60}
61
62pub fn clear() {
65 let mut reg = registry().write().unwrap();
66 reg.to_id.clear();
67 reg.to_string.clear();
68}
69
70pub fn count() -> usize {
72 registry().read().unwrap().to_string.len()
73}
74
75pub fn memory_usage() -> usize {
77 let reg = registry().read().unwrap();
78 let string_bytes: usize = reg.to_string.iter().map(|s| s.len()).sum();
79 let map_overhead =
80 reg.to_id.len() * (std::mem::size_of::<String>() + std::mem::size_of::<u32>());
81 string_bytes + map_overhead
82}
83
84pub fn to_arrow(ids: &[u32]) -> DictionaryArray<UInt32Type> {
86 use std::collections::HashSet;
87
88 let mut seen = HashSet::new();
90 let unique_ids: Vec<u32> = ids.iter().filter(|id| seen.insert(**id)).copied().collect();
91
92 let dict_strings: Vec<String> = unique_ids.iter().map(|&id| resolve(id)).collect();
94 let dictionary = StringArray::from(dict_strings);
95
96 let id_to_index: HashMap<u32, u32> = unique_ids
98 .iter()
99 .enumerate()
100 .map(|(i, &id)| (id, i as u32))
101 .collect();
102
103 let keys: Vec<u32> = ids.iter().map(|id| *id_to_index.get(id).unwrap()).collect();
104 let keys_array = UInt32Array::from(keys);
105
106 DictionaryArray::try_new(keys_array, Arc::new(dictionary)).unwrap()
107}
108
109pub fn from_arrow(arr: &DictionaryArray<UInt32Type>) -> Vec<u32> {
111 let dict = arr
112 .values()
113 .as_any()
114 .downcast_ref::<StringArray>()
115 .expect("dictionary values must be StringArray");
116
117 let dict_to_symbol: Vec<u32> = dict
119 .iter()
120 .map(|s| intern(s.expect("null not supported in symbols")))
121 .collect();
122
123 arr.keys()
125 .iter()
126 .map(|k| {
127 let idx = k.expect("null keys not supported") as usize;
128 dict_to_symbol[idx]
129 })
130 .collect()
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Empties the global registry so each test starts with ID 0 free.
    /// The tests are marked `#[serial]` because they all share that global.
    fn setup() {
        clear();
    }

    /// Fresh symbols receive consecutive IDs starting at zero.
    #[test]
    #[serial]
    fn test_intern_sequential() {
        setup();
        let names = ["foo", "bar", "baz"];
        for (expected_id, name) in (0u32..).zip(names) {
            assert_eq!(intern(name), expected_id);
        }
    }

    /// Interning the same text twice yields one ID and one stored symbol.
    #[test]
    #[serial]
    fn test_intern_idempotent() {
        setup();
        let first = intern("hello");
        let second = intern("hello");
        assert_eq!(first, second);
        assert_eq!(count(), 1);
    }

    /// An interned string resolves back to itself.
    #[test]
    #[serial]
    fn test_resolve_roundtrip() {
        setup();
        let world = intern("world");
        assert_eq!(resolve(world), "world");
    }

    /// Resolving an ID that was never assigned panics.
    #[test]
    #[serial]
    #[should_panic(expected = "invalid symbol ID")]
    fn test_resolve_invalid() {
        setup();
        let _ = resolve(9999);
    }

    /// Clearing drops all symbols and restarts ID assignment at zero.
    #[test]
    #[serial]
    fn test_clear() {
        setup();
        for name in ["a", "b"] {
            intern(name);
        }
        assert_eq!(count(), 2);

        clear();
        assert_eq!(count(), 0);
        assert_eq!(intern("a"), 0);
    }

    /// The empty string is a valid symbol.
    #[test]
    #[serial]
    fn test_empty_string() {
        setup();
        let empty = intern("");
        assert_eq!(resolve(empty), "");
    }

    /// Non-ASCII text survives the intern/resolve round trip.
    #[test]
    #[serial]
    fn test_unicode() {
        setup();
        for sample in ["日本語", "émoji🎉"] {
            let id = intern(sample);
            assert_eq!(resolve(id), sample);
        }
    }

    /// Ten threads interning disjoint names end up with 1000 distinct,
    /// correctly resolvable IDs.
    #[test]
    #[serial]
    fn test_concurrent_intern() {
        setup();
        use std::collections::HashSet;
        use std::thread;

        const THREADS: usize = 10;
        const PER_THREAD: usize = 100;

        let workers: Vec<_> = (0..THREADS)
            .map(|t| {
                thread::spawn(move || {
                    (0..PER_THREAD)
                        .map(|n| {
                            let name = format!("thread{}_{}", t, n);
                            let id = intern(&name);
                            (name, id)
                        })
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        let mut pairs = Vec::new();
        for worker in workers {
            pairs.extend(worker.join().unwrap());
        }

        for (name, id) in &pairs {
            assert_eq!(&resolve(*id), name);
        }

        assert_eq!(count(), THREADS * PER_THREAD);

        let distinct: HashSet<u32> = pairs.iter().map(|&(_, id)| id).collect();
        assert_eq!(distinct.len(), THREADS * PER_THREAD);
    }

    /// 100K symbols get sequential IDs, resolve correctly, and stay under
    /// the 10MB estimate; timings are printed for information only.
    #[test]
    #[serial]
    fn test_large_scale() {
        setup();
        use std::time::Instant;

        const TOTAL: u32 = 100_000;
        let name_of = |i: u32| format!("symbol_{:06}", i);

        let intern_clock = Instant::now();
        for i in 0..TOTAL {
            let name = name_of(i);
            assert_eq!(intern(&name), i);
        }
        let intern_time = intern_clock.elapsed();

        let resolve_clock = Instant::now();
        for i in 0..TOTAL {
            let expected = name_of(i);
            assert_eq!(resolve(i), expected);
        }
        let resolve_time = resolve_clock.elapsed();

        assert_eq!(count(), TOTAL as usize);

        println!(
            "100K intern: {:?}, 100K resolve: {:?}",
            intern_time, resolve_time
        );

        let mem = memory_usage();
        assert!(mem < 10_000_000, "memory usage {} exceeds 10MB", mem);
    }

    /// IDs survive a trip through the Arrow dictionary encoding unchanged.
    #[test]
    #[serial]
    fn test_arrow_roundtrip() {
        setup();
        let ids: Vec<u32> = ["apple", "banana", "apple", "cherry", "banana"]
            .iter()
            .map(|name| intern(name))
            .collect();

        let encoded = to_arrow(&ids);
        let back = from_arrow(&encoded);

        assert_eq!(ids, back);
    }
}