1use std::{marker::PhantomData, ops::Range};
2
3use super::{KdTrie, node::Node};
4use crate::MultiIndexed;
5
/// One suspended level of the depth-first traversal performed by [`KdIterator`].
struct IterFrame<R> {
    /// The node whose indices are still being scanned.
    current_node: R,
    /// Trie level of `current_node`; the root frame starts at depth 0.
    depth: usize,
    /// Indices of `current_node` that have not been visited yet.
    range: Range<i32>,
}
17
18struct KdIterator<R, C> {
28 dimensions: usize,
29 stack: Vec<IterFrame<R>>,
30 coordinates: C,
31}
32
33impl<R: NodeRef, C> KdIterator<R, C> {
34 fn new(dimensions: usize, root: R, coordinates: C) -> Self {
35 let root_range = unsafe { root.range(dimensions == 1) };
36 Self {
37 dimensions,
38 stack: vec![IterFrame {
39 depth: 0,
40 current_node: root,
41 range: root_range,
42 }],
43 coordinates,
44 }
45 }
46}
47
/// Depth-first traversal shared by [`Iter`] and [`IterMut`].
impl<R: NodeRef, C: Coordinates> Iterator for KdIterator<R, C> {
    type Item = (C, R::Value);

    /// Advances to the next populated leaf slot and returns a snapshot of its
    /// coordinates together with its value.
    ///
    /// Each frame's `range` only moves forward. A child frame is pushed at most
    /// once per `(parent, idx)`. So every `(node, idx)` slot is visited at most
    /// once. That is what lets `IterMut` hand out `&mut V`s that all stay alive
    /// for the whole iteration. This also assumes distinct indices lead to
    /// distinct nodes/values, which is the `NodeRef` implementor's
    /// responsibility.
    fn next(&mut self) -> Option<Self::Item> {
        // Resume the most recently suspended frame (LIFO => depth-first).
        while let Some(IterFrame {
            depth,
            current_node,
            mut range,
        }) = self.stack.pop()
        {
            // Drop coordinates recorded for this level and anything deeper;
            // they belong to the previously explored branch.
            self.coordinates.truncate_to(depth);

            // Indices come out of `Range` in increasing order.
            while let Some(idx) = range.next() {
                if depth == self.dimensions - 1 {
                    // Leaf level: `current_node` stores values.
                    // SAFETY: nodes at depth `dimensions - 1` are leaves.
                    if let Some(value) = unsafe { current_node.value(idx) } {
                        // Suspend this frame only if indices remain, so
                        // exhausted frames never linger on the stack.
                        if !range.is_empty() {
                            self.stack.push(IterFrame {
                                depth,
                                current_node,
                                range,
                            });
                        }

                        self.coordinates.set_coord(depth, idx);
                        return Some((self.coordinates.get(), value));
                    }
                // SAFETY: nodes above the leaf depth are inner nodes.
                } else if let Some(child_node) = unsafe { current_node.child(idx) } {
                    // Suspend the parent *before* pushing the child so the
                    // parent resumes only after the child's subtree is done.
                    if !range.is_empty() {
                        self.stack.push(IterFrame {
                            depth,
                            current_node,
                            range,
                        });
                    }

                    // Record this level's coordinate before descending.
                    self.coordinates.set_coord(depth, idx);
                    // SAFETY: the child lives at `depth + 1`, which is the leaf
                    // level exactly when `depth + 1 == dimensions - 1`.
                    let child_range = unsafe { child_node.range(depth + 1 == self.dimensions - 1) };
                    self.stack.push(IterFrame {
                        depth: depth + 1,
                        current_node: child_node,
                        range: child_range,
                    });

                    // Descend: the outer loop pops the child frame next.
                    break;
                }
                // Empty slot: try the next index of this node.
            }
            // Range exhausted without descending: the frame is simply dropped.
        }

        None
    }
}

/// Once the stack is empty, `next` keeps returning `None`.
impl<R: NodeRef, C: Coordinates> std::iter::FusedIterator for KdIterator<R, C> {}
109
/// Copyable handle to a trie node, so one traversal can serve both shared
/// ([`Iter`]) and exclusive ([`IterMut`]) iteration.
///
/// # Safety
///
/// NOTE(review): implementors presumably must guarantee that distinct
/// `(node, idx)` pairs never produce aliasing values or children. [`KdIterator`]
/// keeps every yielded `Value` alive for the whole iteration. Confirm against
/// the intended contract.
unsafe trait NodeRef: Copy {
    /// What a populated leaf slot yields.
    type Value;

    /// Returns the indices to probe in this node. Slots inside the range may
    /// still be empty.
    ///
    /// # Safety
    ///
    /// `is_leaf` must correctly state whether this node is a leaf. The
    /// implementations select the leaf or inner view based on it.
    unsafe fn range(self, is_leaf: bool) -> Range<i32>;

    /// Returns the value stored at `idx`, if any.
    ///
    /// # Safety
    ///
    /// This node must be a leaf.
    unsafe fn value(self, idx: i32) -> Option<Self::Value>;

    /// Returns the child stored at `idx`, if any.
    ///
    /// # Safety
    ///
    /// This node must be an inner (non-leaf) node.
    unsafe fn child(self, idx: i32) -> Option<Self>;
}
147
148unsafe impl<'a, V> NodeRef for &'a Node<V> {
150 type Value = &'a V;
151
152 unsafe fn range(self, is_leaf: bool) -> Range<i32> {
153 if is_leaf {
154 unsafe { self.leaf() }.range()
155 } else {
156 unsafe { self.inner() }.range()
157 }
158 }
159
160 unsafe fn child(self, idx: i32) -> Option<Self> {
161 unsafe { self.inner().get(idx) }
162 }
163
164 unsafe fn value(self, idx: i32) -> Option<Self::Value> {
165 unsafe { self.leaf().get(idx) }
166 }
167}
168
169struct NodePtrMut<'a, V>(*mut Node<V>, PhantomData<&'a mut V>);
181
182impl<V> Copy for NodePtrMut<'_, V> {}
183
184impl<V> Clone for NodePtrMut<'_, V> {
185 fn clone(&self) -> Self {
186 *self
187 }
188}
189
190unsafe impl<'a, V> NodeRef for NodePtrMut<'a, V> {
192 type Value = &'a mut V;
193
194 unsafe fn range(self, is_leaf: bool) -> Range<i32> {
195 if is_leaf {
196 unsafe { (*self.0).leaf() }.range()
197 } else {
198 unsafe { (*self.0).inner() }.range()
199 }
200 }
201
202 unsafe fn child(self, idx: i32) -> Option<Self> {
203 let child = unsafe { (*self.0).get_child_mut(idx) }?;
204 Some(Self(child as *mut Node<V>, PhantomData))
205 }
206
207 unsafe fn value(self, idx: i32) -> Option<Self::Value> {
208 unsafe { (*self.0).get_value_mut(idx) }
209 }
210}
211
/// Buffer recording the coordinates of the root-to-node path being walked.
trait Coordinates {
    /// Forgets coordinates for level `depth` and deeper.
    ///
    /// An implementation may keep stale values instead, as long as
    /// `set_coord` overwrites them before the next `get`.
    fn truncate_to(&mut self, depth: usize);

    /// Records `value` as the coordinate of level `depth`.
    fn set_coord(&mut self, depth: usize, value: i32);

    /// Returns an owned snapshot of the recorded coordinates.
    fn get(&self) -> Self;
}
230
231impl<const K: usize> Coordinates for [i32; K] {
236 fn set_coord(&mut self, depth: usize, value: i32) {
237 self[depth] = value;
238 }
239
240 fn truncate_to(&mut self, _depth: usize) {}
241
242 fn get(&self) -> Self {
243 *self
244 }
245}
246
247impl Coordinates for Vec<i32> {
252 fn set_coord(&mut self, depth: usize, value: i32) {
253 assert_eq!(self.len(), depth);
254 self.push(value);
255 }
256
257 fn truncate_to(&mut self, depth: usize) {
258 self.truncate(depth);
259 }
260
261 fn get(&self) -> Self {
262 self.clone()
263 }
264}
265
266pub struct Iter<'a, V, C>(KdIterator<&'a Node<V>, C>);
270
271impl<'a, V, C: Coordinates> Iterator for Iter<'a, V, C> {
272 type Item = (C, &'a V);
273
274 fn next(&mut self) -> Option<Self::Item> {
275 self.0.next()
276 }
277}
278
279impl<V, C: Coordinates> std::iter::FusedIterator for Iter<'_, V, C> {}
280
281pub struct IterMut<'a, V, C>(KdIterator<NodePtrMut<'a, V>, C>);
283
284impl<'a, V, C: Coordinates> Iterator for IterMut<'a, V, C> {
285 type Item = (C, &'a mut V);
286
287 fn next(&mut self) -> Option<Self::Item> {
288 self.0.next()
289 }
290}
291
292impl<V, C: Coordinates> std::iter::FusedIterator for IterMut<'_, V, C> {}
293
294impl<V> KdTrie<V> {
295 pub fn iter(&self) -> Iter<'_, V, Vec<i32>> {
296 let dimensions = self.dimensions();
297 Iter(KdIterator::new(
298 dimensions,
299 self.root(),
300 Vec::with_capacity(dimensions),
301 ))
302 }
303
304 pub fn iter_mut(&mut self) -> IterMut<'_, V, Vec<i32>> {
305 let dimensions = self.dimensions();
306 let root = NodePtrMut(self.root_mut() as *mut Node<V>, PhantomData);
307 IterMut(KdIterator::new(
308 dimensions,
309 root,
310 Vec::with_capacity(dimensions),
311 ))
312 }
313}
314
315impl<'a, V> IntoIterator for &'a KdTrie<V> {
316 type IntoIter = Iter<'a, V, Vec<i32>>;
317 type Item = (Vec<i32>, &'a V);
318
319 fn into_iter(self) -> Self::IntoIter {
320 self.iter()
321 }
322}
323
324impl<'a, V> IntoIterator for &'a mut KdTrie<V> {
325 type IntoIter = IterMut<'a, V, Vec<i32>>;
326 type Item = (Vec<i32>, &'a mut V);
327
328 fn into_iter(self) -> Self::IntoIter {
329 self.iter_mut()
330 }
331}
332
333impl<const K: usize, V> MultiIndexed<K, V> {
334 pub fn iter(&self) -> Iter<'_, V, [i32; K]> {
352 Iter(KdIterator::new(K, self.trie.root(), [0; K]))
353 }
354
355 pub fn iter_mut(&mut self) -> IterMut<'_, V, [i32; K]> {
376 let root = NodePtrMut(self.trie.root_mut() as *mut Node<V>, PhantomData);
377 IterMut(KdIterator::new(K, root, [0; K]))
378 }
379}
380
381impl<'a, const K: usize, V> IntoIterator for &'a MultiIndexed<K, V> {
382 type IntoIter = Iter<'a, V, [i32; K]>;
383 type Item = ([i32; K], &'a V);
384
385 fn into_iter(self) -> Self::IntoIter {
386 self.iter()
387 }
388}
389
390impl<'a, const K: usize, V> IntoIterator for &'a mut MultiIndexed<K, V> {
391 type IntoIter = IterMut<'a, V, [i32; K]>;
392 type Item = ([i32; K], &'a mut V);
393
394 fn into_iter(self) -> Self::IntoIter {
395 self.iter_mut()
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    // ----- MultiIndexed -----

    #[test]
    fn test_multiindexed_iter_empty() {
        let arr = MultiIndexed::<2, i32>::new();
        let items: Vec<_> = arr.iter().collect();
        assert!(items.is_empty());
    }

    #[test]
    fn test_multiindexed_iter_multiple_calls() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);

        let items1: Vec<_> = arr.iter().collect();
        let items2: Vec<_> = arr.iter().collect();
        assert_eq!(items1, items2);
    }

    /// The yielded coordinates must be exactly the inserted keys (checked
    /// order-independently).
    #[test]
    fn test_multiindexed_iter_yields_inserted_entries() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, -4], 20);
        arr.insert([-5, 6], 30);

        let mut items: Vec<_> = arr.iter().map(|(c, &v)| (c, v)).collect();
        items.sort_unstable();
        assert_eq!(items, vec![([-5, 6], 30), ([1, 2], 10), ([3, -4], 20)]);
    }

    /// One dimension: the root node is itself the leaf level.
    #[test]
    fn test_multiindexed_iter_one_dimension() {
        let arr = MultiIndexed::<1, i32>::new();
        arr.insert([-2], 1);
        arr.insert([7], 2);

        let mut items: Vec<_> = arr.iter().map(|(c, &v)| (c, v)).collect();
        items.sort_unstable();
        assert_eq!(items, vec![([-2], 1), ([7], 2)]);
    }

    #[test]
    fn test_multiindexed_iter_mut_empty() {
        let mut arr = MultiIndexed::<2, i32>::new();
        let items: Vec<_> = arr.iter_mut().collect();
        assert!(items.is_empty());
    }

    #[test]
    fn test_multiindexed_iter_mut_basic() {
        let mut arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);
        arr.insert([-5, 6], 30);

        for (_, v) in arr.iter_mut() {
            *v *= 3;
        }

        assert_eq!(arr.get([1, 2]), Some(&30));
        assert_eq!(arr.get([3, 4]), Some(&60));
        assert_eq!(arr.get([-5, 6]), Some(&90));
    }

    #[test]
    fn test_multiindexed_iter_mut_one_dimension() {
        let mut arr = MultiIndexed::<1, i32>::new();
        arr.insert([4], 1);
        arr.insert([-3], 2);

        for (_, v) in arr.iter_mut() {
            *v += 100;
        }

        assert_eq!(arr.get([4]), Some(&101));
        assert_eq!(arr.get([-3]), Some(&102));
    }

    #[test]
    fn test_multiindexed_iter_and_iter_mut_agree() {
        let mut arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, -4], 20);
        arr.insert([-5, 6], 30);
        arr.insert([0, 0], 40);

        let immutable: Vec<_> = arr.iter().map(|(c, &v)| (c, v)).collect();
        let mutable: Vec<_> = arr.iter_mut().map(|(c, &mut v)| (c, v)).collect();
        assert_eq!(immutable, mutable);
    }

    #[test]
    fn test_multiindexed_iter_mut_no_aliasing() {
        let mut arr = MultiIndexed::<3, i32>::new();
        arr.insert([0, 0, 0], 10);
        arr.insert([0, 0, 1], 20);
        arr.insert([0, 1, 0], 30);
        arr.insert([1, 0, 0], 40);

        let mut it = arr.iter_mut();
        let (_, a) = it.next().unwrap();
        let (_, b) = it.next().unwrap();
        let (_, c) = it.next().unwrap();
        let (_, d) = it.next().unwrap();

        // All four mutable references are alive at once; writing through each
        // must not disturb the others.
        *a += 1;
        *b += 2;
        *c += 3;
        *d += 4;
        drop(it);

        assert_eq!(arr.get([0, 0, 0]), Some(&11));
        assert_eq!(arr.get([0, 0, 1]), Some(&22));
        assert_eq!(arr.get([0, 1, 0]), Some(&33));
        assert_eq!(arr.get([1, 0, 0]), Some(&44));
    }

    #[test]
    fn test_multiindexed_into_iterator() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);

        let items: Vec<_> = (&arr).into_iter().collect();
        assert_eq!(items.len(), 2);

        let mut arr = arr;
        let items: Vec<_> = (&mut arr).into_iter().map(|(c, &mut v)| (c, v)).collect();
        assert_eq!(items.len(), 2);
    }

    // ----- KdTrie -----

    #[test]
    fn test_kdtrie_iter_empty() {
        let trie = KdTrie::<i32>::new(2);
        let items: Vec<_> = trie.iter().collect();
        assert_eq!(items, vec![]);
    }

    #[test]
    fn test_kdtrie_iter_multiple_calls() {
        let trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, 4], 20);

        let items1: Vec<_> = trie.iter().collect();
        let items2: Vec<_> = trie.iter().collect();

        assert_eq!(items1, items2);
    }

    /// The yielded coordinates must be exactly the inserted keys (checked
    /// order-independently).
    #[test]
    fn test_kdtrie_iter_yields_inserted_entries() {
        let trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, -4], 20);
        trie.insert(&[-5, 6], 30);

        let mut items: Vec<_> = trie.iter().map(|(c, &v)| (c, v)).collect();
        items.sort_unstable();
        assert_eq!(
            items,
            vec![(vec![-5, 6], 30), (vec![1, 2], 10), (vec![3, -4], 20)]
        );
    }

    /// One dimension: the root node is itself the leaf level.
    #[test]
    fn test_kdtrie_iter_one_dimension() {
        let trie = KdTrie::<i32>::new(1);
        trie.insert(&[-2], 1);
        trie.insert(&[7], 2);

        let mut items: Vec<_> = trie.iter().map(|(c, &v)| (c, v)).collect();
        items.sort_unstable();
        assert_eq!(items, vec![(vec![-2], 1), (vec![7], 2)]);
    }

    #[test]
    fn test_kdtrie_iter_mut_empty() {
        let mut trie = KdTrie::<i32>::new(2);
        let items: Vec<_> = trie.iter_mut().collect();
        assert_eq!(items, vec![]);
    }

    #[test]
    fn test_kdtrie_iter_mut_basic() {
        let mut trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, 4], 20);
        trie.insert(&[-5, 6], 30);

        for (_, v) in trie.iter_mut() {
            *v *= 3;
        }

        assert_eq!(trie.get(&[1, 2]), Some(&30));
        assert_eq!(trie.get(&[3, 4]), Some(&60));
        assert_eq!(trie.get(&[-5, 6]), Some(&90));
    }

    #[test]
    fn test_kdtrie_iter_mut_one_dimension() {
        let mut trie = KdTrie::<i32>::new(1);
        trie.insert(&[4], 1);
        trie.insert(&[-3], 2);

        for (_, v) in trie.iter_mut() {
            *v += 100;
        }

        assert_eq!(trie.get(&[4]), Some(&101));
        assert_eq!(trie.get(&[-3]), Some(&102));
    }

    #[test]
    fn test_kdtrie_iter_and_iter_mut_agree() {
        let mut trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, -4], 20);
        trie.insert(&[-5, 6], 30);
        trie.insert(&[0, 0], 40);

        let immutable: Vec<_> = trie.iter().map(|(c, &v)| (c, v)).collect();
        let mutable: Vec<_> = trie.iter_mut().map(|(c, &mut v)| (c, v)).collect();

        assert_eq!(immutable, mutable);
    }

    #[test]
    fn test_kdtrie_iter_mut_no_aliasing() {
        let mut trie = KdTrie::<i32>::new(3);
        trie.insert(&[0, 0, 0], 10);
        trie.insert(&[0, 0, 1], 20);
        trie.insert(&[0, 1, 0], 30);
        trie.insert(&[1, 0, 0], 40);

        let mut it = trie.iter_mut();
        let (_, a) = it.next().unwrap();
        let (_, b) = it.next().unwrap();
        let (_, c) = it.next().unwrap();
        let (_, d) = it.next().unwrap();

        // All four mutable references are alive at once; writing through each
        // must not disturb the others.
        *a += 1;
        *b += 2;
        *c += 3;
        *d += 4;
        drop(it);

        assert_eq!(trie.get(&[0, 0, 0]), Some(&11));
        assert_eq!(trie.get(&[0, 0, 1]), Some(&22));
        assert_eq!(trie.get(&[0, 1, 0]), Some(&33));
        assert_eq!(trie.get(&[1, 0, 0]), Some(&44));
    }

    #[test]
    fn test_kdtrie_into_iterator() {
        let trie = KdTrie::<i32>::new(2);
        trie.insert(&[1, 2], 10);
        trie.insert(&[3, 4], 20);

        let items: Vec<_> = (&trie).into_iter().collect();
        assert_eq!(items.len(), 2);

        let mut trie = trie;
        let items: Vec<_> = (&mut trie).into_iter().map(|(c, &mut v)| (c, v)).collect();
        assert_eq!(items.len(), 2);
    }
}