once/multiindexed/
iter.rs

1use std::{marker::PhantomData, ops::Range};
2
3use super::{KdTrie, node::Node};
4use crate::MultiIndexed;
5
6// --- Iterator ---
7
/// One pending step of the depth-first walk over a [`KdTrie`].
///
/// A frame remembers which node is being scanned, which coordinate dimension that node
/// indexes, and which of its indices have not been looked at yet.
struct IterFrame<R> {
    /// Dimension indexed by `current_node`; the root is at depth 0 and each child is one deeper.
    depth: usize,
    /// Handle to the node whose slots are being scanned.
    current_node: R,
    /// Indices of `current_node` that are still to be visited, in ascending order.
    range: Range<i32>,
}
17
18/// A depth-first iterator over a [`KdTrie`], generic over:
19///
20/// - `R: NodeRef` — the node handle type, either shared (`&Node<V>`) or exclusive
21///   (`NodePtrMut<'_, V>`), determining whether values are yielded as `&V` or `&mut V`.
22/// - `C: Coordinates` — the coordinate accumulator, either `[i32; K]` (fixed-size, for
23///   [`MultiIndexed`]) or `Vec<i32>` (dynamic, for [`KdTrie`]).
24///
25/// The iterator walks the trie in lexicographic order of coordinates, yielding `(C, R::Value)` for
26/// each stored entry.
27struct KdIterator<R, C> {
28    dimensions: usize,
29    stack: Vec<IterFrame<R>>,
30    coordinates: C,
31}
32
33impl<R: NodeRef, C> KdIterator<R, C> {
34    fn new(dimensions: usize, root: R, coordinates: C) -> Self {
35        let root_range = unsafe { root.range(dimensions == 1) };
36        Self {
37            dimensions,
38            stack: vec![IterFrame {
39                depth: 0,
40                current_node: root,
41                range: root_range,
42            }],
43            coordinates,
44        }
45    }
46}
47
48impl<R: NodeRef, C: Coordinates> Iterator for KdIterator<R, C> {
49    type Item = (C, R::Value);
50
51    fn next(&mut self) -> Option<Self::Item> {
52        while let Some(IterFrame {
53            depth,
54            current_node,
55            mut range,
56        }) = self.stack.pop()
57        {
58            self.coordinates.truncate_to(depth);
59
60            // Find the next index in the current range that has a value
61            while let Some(idx) = range.next() {
62                if depth == self.dimensions - 1 {
63                    // This is a leaf node, check if there's a value at this index
64                    if let Some(value) = unsafe { current_node.value(idx) } {
65                        // Push back the remaining range for this node
66                        if !range.is_empty() {
67                            self.stack.push(IterFrame {
68                                depth,
69                                current_node,
70                                range,
71                            });
72                        }
73
74                        self.coordinates.set_coord(depth, idx);
75                        return Some((self.coordinates.get(), value));
76                    }
77                } else if let Some(child_node) = unsafe { current_node.child(idx) } {
78                    // This is an inner node, check if there's a child at this index
79
80                    // Push back the remaining range for this node
81                    if !range.is_empty() {
82                        self.stack.push(IterFrame {
83                            depth,
84                            current_node,
85                            range,
86                        });
87                    }
88
89                    // Add the current index to coordinates and push the child
90                    self.coordinates.set_coord(depth, idx);
91                    let child_range = unsafe { child_node.range(depth + 1 == self.dimensions - 1) };
92                    self.stack.push(IterFrame {
93                        depth: depth + 1,
94                        current_node: child_node,
95                        range: child_range,
96                    });
97
98                    // Go to the next iteration of the outer loop, which will process the child
99                    break;
100                }
101            }
102        }
103
104        None
105    }
106}
107
108impl<R: NodeRef, C: Coordinates> std::iter::FusedIterator for KdIterator<R, C> {}
109
110// --- NodeRef ---
111
/// Uniform access to a trie node through either a shared or an exclusive handle.
///
/// [`KdIterator`] is written once against this trait. The shared implementation for `&Node<V>`
/// backs `iter`. The exclusive implementation for [`NodePtrMut`], a `Copy` wrapper around
/// `*mut Node<V>`, backs `iter_mut`.
///
/// # Safety
///
/// An implementation must:
/// - forward `range`, `child`, and `value` to the underlying [`Node`] accessors while honouring
///   their preconditions (leaf accessors only on leaves, inner accessors only on inner nodes);
/// - never hand out two aliasing references when it hands out mutable ones.
unsafe trait NodeRef: Copy {
    /// What a leaf slot yields: a shared or an exclusive reference to the stored value.
    type Value;

    /// Returns the range of indices covered by this node.
    ///
    /// # Safety
    ///
    /// The caller must pass `is_leaf == true` exactly when this node is a leaf.
    unsafe fn range(self, is_leaf: bool) -> Range<i32>;

    /// Returns a handle to the child stored at `idx`, or `None` when that slot has no child.
    ///
    /// # Safety
    ///
    /// The caller must only call this on an inner node.
    unsafe fn child(self, idx: i32) -> Option<Self>;

    /// Returns the value stored at `idx`, or `None` when that slot is empty.
    ///
    /// # Safety
    ///
    /// The caller must only call this on a leaf node.
    unsafe fn value(self, idx: i32) -> Option<Self::Value>;
}
147
148/// Shared node reference. Yields `&V` values.
149unsafe impl<'a, V> NodeRef for &'a Node<V> {
150    type Value = &'a V;
151
152    unsafe fn range(self, is_leaf: bool) -> Range<i32> {
153        if is_leaf {
154            unsafe { self.leaf() }.range()
155        } else {
156            unsafe { self.inner() }.range()
157        }
158    }
159
160    unsafe fn child(self, idx: i32) -> Option<Self> {
161        unsafe { self.inner().get(idx) }
162    }
163
164    unsafe fn value(self, idx: i32) -> Option<Self::Value> {
165        unsafe { self.leaf().get(idx) }
166    }
167}
168
169/// A `Copy` wrapper around `*mut Node<V>` that serves as the exclusive counterpart to
170/// `&Node<V>` in the [`NodeRef`] trait.
171///
172/// The phantom lifetime `'a` ties the yielded `&'a mut V` references back to the original
173/// `&'a mut MultiIndexed` (or `&'a mut KdTrie`), ensuring soundness.
174///
175/// This is safe to use because:
176/// - It is only constructed from `&mut MultiIndexed` / `&mut KdTrie`, guaranteeing exclusive access
177///   to the entire tree.
178/// - The tree structure ensures that nodes at different positions are disjoint in memory.
179/// - The iterator yields each value at most once.
180struct NodePtrMut<'a, V>(*mut Node<V>, PhantomData<&'a mut V>);
181
182impl<V> Copy for NodePtrMut<'_, V> {}
183
184impl<V> Clone for NodePtrMut<'_, V> {
185    fn clone(&self) -> Self {
186        *self
187    }
188}
189
190/// Exclusive node reference. Yields `&mut V` values.
191unsafe impl<'a, V> NodeRef for NodePtrMut<'a, V> {
192    type Value = &'a mut V;
193
194    unsafe fn range(self, is_leaf: bool) -> Range<i32> {
195        if is_leaf {
196            unsafe { (*self.0).leaf() }.range()
197        } else {
198            unsafe { (*self.0).inner() }.range()
199        }
200    }
201
202    unsafe fn child(self, idx: i32) -> Option<Self> {
203        let child = unsafe { (*self.0).get_child_mut(idx) }?;
204        Some(Self(child as *mut Node<V>, PhantomData))
205    }
206
207    unsafe fn value(self, idx: i32) -> Option<Self::Value> {
208        unsafe { (*self.0).get_value_mut(idx) }
209    }
210}
211
212// --- Coordinates ---
213
/// Accumulates the coordinates of the current position during a trie walk.
///
/// The walker fills coordinates in one dimension at a time on the way down
/// ([`set_coord`](Coordinates::set_coord)). It discards them past a given depth on the way back
/// up ([`truncate_to`](Coordinates::truncate_to)). It copies them out whenever an entry is
/// yielded ([`get`](Coordinates::get)).
trait Coordinates {
    /// Records `value` as the coordinate for dimension `depth`.
    fn set_coord(&mut self, depth: usize, value: i32);

    /// Discards coordinates for dimension `depth` and deeper before backtracking. Implementations
    /// may leave them in place if they are always overwritten before being read.
    fn truncate_to(&mut self, depth: usize);

    /// Returns an owned copy of the coordinates accumulated so far.
    fn get(&self) -> Self;
}

/// Fixed-size accumulator used by [`MultiIndexed`].
///
/// Every slot always exists, so truncation does nothing. A stale entry at a deeper dimension is
/// overwritten by `set_coord` before the array is ever copied out.
impl<const K: usize> Coordinates for [i32; K] {
    fn set_coord(&mut self, depth: usize, value: i32) {
        let slot = &mut self[depth];
        *slot = value;
    }

    fn truncate_to(&mut self, _depth: usize) {
        // Intentionally empty: see the impl-level comment above.
    }

    fn get(&self) -> Self {
        *self
    }
}

/// Growable accumulator used by [`KdTrie`].
///
/// Coordinates must arrive strictly in dimension order. `set_coord` asserts that `depth` equals
/// the current length before pushing, and `truncate_to` shrinks the vector back to `depth`.
impl Coordinates for Vec<i32> {
    fn set_coord(&mut self, depth: usize, value: i32) {
        assert_eq!(self.len(), depth);
        self.push(value);
    }

    fn truncate_to(&mut self, depth: usize) {
        Vec::truncate(self, depth);
    }

    fn get(&self) -> Self {
        Vec::clone(self)
    }
}
265
266// --- Public API ---
267
268/// An iterator over the entries of a [`KdTrie`] or [`MultiIndexed`].
269pub struct Iter<'a, V, C>(KdIterator<&'a Node<V>, C>);
270
271impl<'a, V, C: Coordinates> Iterator for Iter<'a, V, C> {
272    type Item = (C, &'a V);
273
274    fn next(&mut self) -> Option<Self::Item> {
275        self.0.next()
276    }
277}
278
279impl<V, C: Coordinates> std::iter::FusedIterator for Iter<'_, V, C> {}
280
281/// A mutable iterator over the entries of a [`KdTrie`] or [`MultiIndexed`].
282pub struct IterMut<'a, V, C>(KdIterator<NodePtrMut<'a, V>, C>);
283
284impl<'a, V, C: Coordinates> Iterator for IterMut<'a, V, C> {
285    type Item = (C, &'a mut V);
286
287    fn next(&mut self) -> Option<Self::Item> {
288        self.0.next()
289    }
290}
291
292impl<V, C: Coordinates> std::iter::FusedIterator for IterMut<'_, V, C> {}
293
294impl<V> KdTrie<V> {
295    pub fn iter(&self) -> Iter<'_, V, Vec<i32>> {
296        let dimensions = self.dimensions();
297        Iter(KdIterator::new(
298            dimensions,
299            self.root(),
300            Vec::with_capacity(dimensions),
301        ))
302    }
303
304    pub fn iter_mut(&mut self) -> IterMut<'_, V, Vec<i32>> {
305        let dimensions = self.dimensions();
306        let root = NodePtrMut(self.root_mut() as *mut Node<V>, PhantomData);
307        IterMut(KdIterator::new(
308            dimensions,
309            root,
310            Vec::with_capacity(dimensions),
311        ))
312    }
313}
314
315impl<'a, V> IntoIterator for &'a KdTrie<V> {
316    type IntoIter = Iter<'a, V, Vec<i32>>;
317    type Item = (Vec<i32>, &'a V);
318
319    fn into_iter(self) -> Self::IntoIter {
320        self.iter()
321    }
322}
323
324impl<'a, V> IntoIterator for &'a mut KdTrie<V> {
325    type IntoIter = IterMut<'a, V, Vec<i32>>;
326    type Item = (Vec<i32>, &'a mut V);
327
328    fn into_iter(self) -> Self::IntoIter {
329        self.iter_mut()
330    }
331}
332
333impl<const K: usize, V> MultiIndexed<K, V> {
334    /// Returns an iterator over all coordinate-value pairs in the array.
335    ///
336    /// The iterator yields `([i32; K], &V)` tuples in lexicographic order of coordinates.
337    ///
338    /// # Examples
339    ///
340    /// ```
341    /// use once::MultiIndexed;
342    ///
343    /// let array = MultiIndexed::<2, i32>::new();
344    /// array.insert([3, 4], 10);
345    /// array.insert([1, 2], 20);
346    ///
347    /// let mut items: Vec<_> = array.iter().collect();
348    ///
349    /// assert_eq!(items, vec![([1, 2], &20), ([3, 4], &10)]);
350    /// ```
351    pub fn iter(&self) -> Iter<'_, V, [i32; K]> {
352        Iter(KdIterator::new(K, self.trie.root(), [0; K]))
353    }
354
355    /// Returns a mutable iterator over all coordinate-value pairs in the array.
356    ///
357    /// The iterator yields `([i32; K], &mut V)` tuples in lexicographic order of coordinates.
358    ///
359    /// # Examples
360    ///
361    /// ```
362    /// use once::MultiIndexed;
363    ///
364    /// let mut array = MultiIndexed::<2, i32>::new();
365    /// array.insert([1, 2], 10);
366    /// array.insert([3, 4], 20);
367    ///
368    /// for (_, v) in array.iter_mut() {
369    ///     *v *= 2;
370    /// }
371    ///
372    /// assert_eq!(array.get([1, 2]), Some(&20));
373    /// assert_eq!(array.get([3, 4]), Some(&40));
374    /// ```
375    pub fn iter_mut(&mut self) -> IterMut<'_, V, [i32; K]> {
376        let root = NodePtrMut(self.trie.root_mut() as *mut Node<V>, PhantomData);
377        IterMut(KdIterator::new(K, root, [0; K]))
378    }
379}
380
381impl<'a, const K: usize, V> IntoIterator for &'a MultiIndexed<K, V> {
382    type IntoIter = Iter<'a, V, [i32; K]>;
383    type Item = ([i32; K], &'a V);
384
385    fn into_iter(self) -> Self::IntoIter {
386        self.iter()
387    }
388}
389
390impl<'a, const K: usize, V> IntoIterator for &'a mut MultiIndexed<K, V> {
391    type IntoIter = IterMut<'a, V, [i32; K]>;
392    type Item = ([i32; K], &'a mut V);
393
394    fn into_iter(self) -> Self::IntoIter {
395        self.iter_mut()
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    // --- MultiIndexed iteration tests ---

    #[test]
    fn test_multiindexed_iter_empty() {
        let arr = MultiIndexed::<2, i32>::new();
        let items: Vec<_> = arr.iter().collect();
        assert!(items.is_empty());
    }

    #[test]
    fn test_multiindexed_iter_multiple_calls() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);

        let items1: Vec<_> = arr.iter().collect();
        let items2: Vec<_> = arr.iter().collect();
        assert_eq!(items1, items2);
    }

    /// Pins the exact lexicographic order, including negative coordinates in both dimensions.
    #[test]
    fn test_multiindexed_iter_lexicographic_order() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([3, -4], 1);
        arr.insert([-5, 6], 2);
        arr.insert([3, -7], 3);
        arr.insert([0, 0], 4);

        let items: Vec<_> = arr.iter().map(|(c, &v)| (c, v)).collect();
        assert_eq!(
            items,
            vec![([-5, 6], 2), ([0, 0], 4), ([3, -7], 3), ([3, -4], 1)]
        );
    }

    /// With a single dimension the root itself is the leaf (the `dimensions == 1` path).
    #[test]
    fn test_multiindexed_one_dimension() {
        let mut arr = MultiIndexed::<1, i32>::new();
        arr.insert([7], 70);
        arr.insert([-2], 20);
        arr.insert([0], 0);

        let items: Vec<_> = arr.iter().map(|(c, &v)| (c, v)).collect();
        assert_eq!(items, vec![([-2], 20), ([0], 0), ([7], 70)]);

        for (c, v) in arr.iter_mut() {
            *v += c[0];
        }
        assert_eq!(arr.get([-2]), Some(&18));
        assert_eq!(arr.get([0]), Some(&0));
        assert_eq!(arr.get([7]), Some(&77));
    }

    #[test]
    fn test_multiindexed_iter_is_fused() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([0, 1], 5);

        let mut it = arr.iter();
        assert_eq!(it.next(), Some(([0, 1], &5)));
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }

    #[test]
    fn test_multiindexed_iter_mut_empty() {
        let mut arr = MultiIndexed::<2, i32>::new();
        let items: Vec<_> = arr.iter_mut().collect();
        assert!(items.is_empty());
    }

    #[test]
    fn test_multiindexed_iter_mut_basic() {
        let mut arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);
        arr.insert([-5, 6], 30);

        for (_, v) in arr.iter_mut() {
            *v *= 3;
        }

        assert_eq!(arr.get([1, 2]), Some(&30));
        assert_eq!(arr.get([3, 4]), Some(&60));
        assert_eq!(arr.get([-5, 6]), Some(&90));
    }

    #[test]
    fn test_multiindexed_iter_and_iter_mut_agree() {
        let mut arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, -4], 20);
        arr.insert([-5, 6], 30);
        arr.insert([0, 0], 40);

        let immutable: Vec<_> = arr.iter().map(|(c, &v)| (c, v)).collect();
        let mutable: Vec<_> = arr.iter_mut().map(|(c, &mut v)| (c, v)).collect();
        assert_eq!(immutable, mutable);
    }

    #[test]
    fn test_multiindexed_iter_mut_no_aliasing() {
        let mut arr = MultiIndexed::<3, i32>::new();
        arr.insert([0, 0, 0], 10);
        arr.insert([0, 0, 1], 20);
        arr.insert([0, 1, 0], 30);
        arr.insert([1, 0, 0], 40);

        let mut it = arr.iter_mut();
        let (_, a) = it.next().unwrap();
        let (_, b) = it.next().unwrap();
        let (_, c) = it.next().unwrap();
        let (_, d) = it.next().unwrap();

        // Miri detects borrow-model violations if any of the references alias, even before the
        // writes below.
        *a += 1;
        *b += 2;
        *c += 3;
        *d += 4;
        drop(it);

        assert_eq!(arr.get([0, 0, 0]), Some(&11));
        assert_eq!(arr.get([0, 0, 1]), Some(&22));
        assert_eq!(arr.get([0, 1, 0]), Some(&33));
        assert_eq!(arr.get([1, 0, 0]), Some(&44));
    }

    #[test]
    fn test_multiindexed_into_iterator() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);

        let items: Vec<_> = (&arr).into_iter().collect();
        assert_eq!(items.len(), 2);

        let mut arr = arr;
        let items: Vec<_> = (&mut arr).into_iter().map(|(c, &mut v)| (c, v)).collect();
        assert_eq!(items.len(), 2);
    }

    // --- KdTrie iteration tests ---

    #[test]
    fn test_kdtrie_iter_empty() {
        let trie = KdTrie::<i32>::new(2);
        let items: Vec<_> = trie.iter().collect();
        assert_eq!(items, vec![]);
    }

    #[test]
    fn test_kdtrie_iter_multiple_calls() {
        let trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, 4], 20);

        let items1: Vec<_> = trie.iter().collect();
        let items2: Vec<_> = trie.iter().collect();

        assert_eq!(items1, items2);
    }

    /// Pins the exact order and exercises `Vec` coordinate truncation when the walk backtracks
    /// across sibling subtrees (a wrong truncation would trip the `set_coord` assertion).
    #[test]
    fn test_kdtrie_iter_lexicographic_order() {
        let trie = KdTrie::<i32>::new(3);
        trie.insert(&[2, 0, -1], 1);
        trie.insert(&[-1, 5, 5], 2);
        trie.insert(&[2, 0, -3], 3);
        trie.insert(&[0, -2, 4], 4);

        let items: Vec<_> = trie.iter().map(|(c, &v)| (c, v)).collect();
        assert_eq!(
            items,
            vec![
                (vec![-1, 5, 5], 2),
                (vec![0, -2, 4], 4),
                (vec![2, 0, -3], 3),
                (vec![2, 0, -1], 1),
            ]
        );
    }

    /// With a single dimension the root itself is the leaf (the `dimensions == 1` path).
    #[test]
    fn test_kdtrie_one_dimension() {
        let mut trie = KdTrie::<i32>::new(1);
        trie.insert(&[4], 40);
        trie.insert(&[-3], 30);

        let items: Vec<_> = trie.iter().map(|(c, &v)| (c, v)).collect();
        assert_eq!(items, vec![(vec![-3], 30), (vec![4], 40)]);

        for (c, v) in trie.iter_mut() {
            *v += c[0];
        }
        assert_eq!(trie.get(&[-3]), Some(&27));
        assert_eq!(trie.get(&[4]), Some(&44));
    }

    #[test]
    fn test_kdtrie_iter_mut_empty() {
        let mut trie = KdTrie::<i32>::new(2);
        let items: Vec<_> = trie.iter_mut().collect();
        assert_eq!(items, vec![]);
    }

    #[test]
    fn test_kdtrie_iter_mut_basic() {
        let mut trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, 4], 20);
        trie.insert(&[-5, 6], 30);

        for (_, v) in trie.iter_mut() {
            *v *= 3;
        }

        assert_eq!(trie.get(&[1, 2]), Some(&30));
        assert_eq!(trie.get(&[3, 4]), Some(&60));
        assert_eq!(trie.get(&[-5, 6]), Some(&90));
    }

    #[test]
    fn test_kdtrie_iter_and_iter_mut_agree() {
        let mut trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, -4], 20);
        trie.insert(&[-5, 6], 30);
        trie.insert(&[0, 0], 40);

        let immutable: Vec<_> = trie.iter().map(|(c, &v)| (c, v)).collect();
        let mutable: Vec<_> = trie.iter_mut().map(|(c, &mut v)| (c, v)).collect();

        assert_eq!(immutable, mutable);
    }

    #[test]
    fn test_kdtrie_iter_mut_no_aliasing() {
        let mut trie = KdTrie::<i32>::new(3);
        trie.insert(&[0, 0, 0], 10);
        trie.insert(&[0, 0, 1], 20);
        trie.insert(&[0, 1, 0], 30);
        trie.insert(&[1, 0, 0], 40);

        let mut it = trie.iter_mut();
        let (_, a) = it.next().unwrap();
        let (_, b) = it.next().unwrap();
        let (_, c) = it.next().unwrap();
        let (_, d) = it.next().unwrap();

        // Miri detects borrow-model violations if any of the references alias, even before the
        // writes below.
        *a += 1;
        *b += 2;
        *c += 3;
        *d += 4;
        drop(it);

        assert_eq!(trie.get(&[0, 0, 0]), Some(&11));
        assert_eq!(trie.get(&[0, 0, 1]), Some(&22));
        assert_eq!(trie.get(&[0, 1, 0]), Some(&33));
        assert_eq!(trie.get(&[1, 0, 0]), Some(&44));
    }

    #[test]
    fn test_kdtrie_into_iterator() {
        let trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, 4], 20);

        let items: Vec<_> = (&trie).into_iter().collect();
        assert_eq!(items.len(), 2);

        let mut trie = trie;
        let items: Vec<_> = (&mut trie).into_iter().map(|(c, &mut v)| (c, v)).collect();
        assert_eq!(items.len(), 2);
    }
}