once/multiindexed/mod.rs
1use std::ops::{Index, IndexMut};
2
3pub use self::{
4 iter::{Iter, IterMut},
5 kdtrie::KdTrie,
6};
7use crate::std_or_loom::sync::atomic::{AtomicI32, Ordering, fence};
8
9mod iter;
10pub mod kdtrie;
11mod node;
12
13/// A multi-dimensional array that allows efficient storage and retrieval of values using
14/// K-dimensional integer coordinates.
15///
16/// `MultiIndexed<K, V>` provides a thread-safe and wait-free way to store values indexed by
17/// multi-dimensional coordinates. It is implemented using a K-dimensional trie structure that
18/// efficiently handles sparse data, where each level of the trie corresponds to one dimension of
19/// the coordinate space.
20///
21/// The `MultiIndexed` is created with a fixed number of dimensions `K` and can store values of any
22/// type `V`. Each dimension can have both positive and negative indices.
23///
24/// # Thread Safety
25///
26/// `MultiIndexed` is designed to be thread-safe and wait-free, allowing concurrent insertions and
27/// retrievals from multiple threads. This makes it suitable for parallel algorithms that need to
28/// build up a shared data structure.
29///
30/// # Memory Efficiency
31///
32/// The underlying trie structure allocates memory only for coordinates that are actually used,
33/// making it memory-efficient for sparse data. The implementation uses a series of
34/// [`TwoEndedGrove`](crate::TwoEndedGrove) instances to store the trie nodes, which themselves use
35/// a block-based allocation strategy.
36///
37/// # Performance Characteristics
38///
39/// - **Insertion**: O(K) time complexity,
40/// - **Retrieval**: O(K) time complexity,
41/// - **Memory Usage**: amortized O(N) space complexity, where N is the number of inserted elements
42///
43/// # Note
44///
45/// Inserting a value at a coordinate that is already occupied will panic. Values can be mutated
46/// in place through exclusive (`&mut`) references via [`get_mut`](MultiIndexed::get_mut) or
47/// [`IndexMut`]. However, while insertion only requires a shared reference, concurrent mutation is
48/// not supported.
49///
50/// For performance reasons, we do not allow `K = 0`.
51///
52/// # Examples
53///
54/// Correct usage:
55///
56/// ```
57/// use once::MultiIndexed;
58///
59/// // Create a 3-dimensional array
60/// let array = MultiIndexed::<3, i32>::new();
61///
62/// // Insert values at specific coordinates
63/// array.insert([1, 2, 3], 42);
64/// array.insert([5, 0, 2], 100);
65/// array.insert([-1, -2, 3], 200); // Negative coordinates are supported
66///
67/// // Retrieve values
68/// assert_eq!(array.get([1, 2, 3]), Some(&42));
69/// assert_eq!(array.get([5, 0, 2]), Some(&100));
70/// assert_eq!(array.get([-1, -2, 3]), Some(&200));
71/// assert_eq!(array.get([0, 0, 0]), None); // No value at these coordinates
72/// ```
73///
74/// Incorrect usage:
75///
76/// ```should_panic
77/// use once::MultiIndexed;
78///
79/// let array = MultiIndexed::<2, i32>::new();
80///
81/// array.insert([1, 2], 42);
82/// array.insert([1, 2], 43); // Panics because the value at [1, 2] is already set
83/// ```
84///
85/// ```compile_fail
86/// use once::MultiIndexed;
87///
88/// let incorrect = MultiIndexed::<0, ()>::new();
89/// ```
90pub struct MultiIndexed<const K: usize, V> {
91 trie: KdTrie<V>,
92 min_coords: [AtomicI32; K],
93 max_coords: [AtomicI32; K],
94}
95
96impl<const K: usize, V> MultiIndexed<K, V> {
97 const POSITIVE_DIMS: () = assert!(K > 0);
98
99 /// Creates a new empty `MultiIndexed` array with K dimensions.
100 ///
101 /// # Examples
102 ///
103 /// ```
104 /// use once::MultiIndexed;
105 ///
106 /// // Create a 2D array for strings
107 /// let array = MultiIndexed::<2, String>::new();
108 ///
109 /// // Create a 3D array for integers
110 /// let array3d = MultiIndexed::<3, i32>::new();
111 ///
112 /// // Create a 4D array for custom types
113 /// struct Point {
114 /// x: f64,
115 /// y: f64,
116 /// }
117 /// let array4d = MultiIndexed::<4, Point>::new();
118 /// ```
119 pub fn new() -> Self {
120 // Compile-time check
121 let () = Self::POSITIVE_DIMS;
122
123 Self {
124 trie: KdTrie::new(K),
125 min_coords: std::array::from_fn(|_| AtomicI32::new(i32::MAX)),
126 max_coords: std::array::from_fn(|_| AtomicI32::new(i32::MIN)),
127 }
128 }
129
130 /// Retrieves a reference to the value at the specified coordinates, if it exists.
131 ///
132 /// # Parameters
133 ///
134 /// * `coords`: An array of K integer coordinates
135 ///
136 /// # Returns
137 ///
138 /// * `Some(&V)` if a value exists at the specified coordinates
139 /// * `None` if no value exists at the specified coordinates
140 ///
141 /// # Examples
142 ///
143 /// ```
144 /// use once::MultiIndexed;
145 ///
146 /// // Basic retrieval in a 3D array
147 /// let array = MultiIndexed::<3, i32>::new();
148 /// array.insert([1, 2, 3], 42);
149 ///
150 /// assert_eq!(array.get([1, 2, 3]), Some(&42));
151 /// assert_eq!(array.get([0, 0, 0]), None);
152 ///
153 /// // Retrieval with negative coordinates
154 /// array.insert([-5, -10, 15], 100);
155 /// assert_eq!(array.get([-5, -10, 15]), Some(&100));
156 ///
157 /// // Retrieval in a 2D array
158 /// let array2d = MultiIndexed::<2, String>::new();
159 /// array2d.insert([0, 0], "Origin".to_string());
160 /// array2d.insert([10, -5], "Far point".to_string());
161 ///
162 /// assert_eq!(array2d.get([0, 0]), Some(&"Origin".to_string()));
163 /// assert_eq!(array2d.get([10, -5]), Some(&"Far point".to_string()));
164 /// assert_eq!(array2d.get([1, 1]), None);
165 ///
166 /// // Retrieval in a 1D array
167 /// let array1d = MultiIndexed::<1, f64>::new();
168 /// array1d.insert([0], 3.14);
169 /// array1d.insert([-10], -2.71);
170 ///
171 /// assert_eq!(array1d.get([0]), Some(&3.14));
172 /// assert_eq!(array1d.get([-10]), Some(&-2.71));
173 /// assert_eq!(array1d.get([5]), None);
174 /// ```
175 pub fn get(&self, coords: impl Into<[i32; K]>) -> Option<&V> {
176 self.trie.get(&coords.into())
177 }
178
179 /// Retrieves a mutable reference to the value at the specified coordinates, if it exists.
180 ///
181 /// This method can only be called if we have an exclusive reference to self. Having an
182 /// exclusive reference prevents concurrent access, so the atomic synchronization used by `get`
183 /// is unnecessary. This makes it safe to return a mutable reference to the stored value.
184 ///
185 /// # Parameters
186 ///
187 /// * `coords`: An array of K integer coordinates
188 ///
189 /// # Returns
190 ///
191 /// * `Some(&mut V)` if a value exists at the specified coordinates
192 /// * `None` if no value exists at the specified coordinates
193 ///
194 /// # Examples
195 ///
196 /// ```
197 /// use once::MultiIndexed;
198 ///
199 /// let mut array = MultiIndexed::<3, Vec<i32>>::new();
200 /// array.insert([1, 2, 3], vec![1, 2, 3]);
201 /// array.insert([4, 5, 6], vec![4, 5, 6]);
202 ///
203 /// // Modify the vectors in place
204 /// if let Some(vec) = array.get_mut([1, 2, 3]) {
205 /// vec.push(4);
206 /// vec.push(5);
207 /// }
208 ///
209 /// assert_eq!(array.get([1, 2, 3]), Some(&vec![1, 2, 3, 4, 5]));
210 /// ```
211 pub fn get_mut(&mut self, coords: impl Into<[i32; K]>) -> Option<&mut V> {
212 self.trie.get_mut(&coords.into())
213 }
214
215 /// Inserts a value at the specified coordinates.
216 ///
217 /// This operation is thread-safe and can be called from multiple threads. However, this method
218 /// panics if a value already exists at the specified coordinates. Therefore, it should only be
219 /// called at most once for any given set of coordinates.
220 ///
221 /// # Parameters
222 ///
223 /// * `coords`: An array of K integer coordinates
224 /// * `value`: The value to insert at the specified coordinates
225 ///
226 /// # Examples
227 ///
228 /// ```
229 /// use once::MultiIndexed;
230 ///
231 /// // Basic insertion in a 3D array
232 /// let array = MultiIndexed::<3, i32>::new();
233 /// array.insert([1, 2, 3], 42);
234 ///
235 /// // Insertion with negative coordinates
236 /// array.insert([-5, 0, 10], 100);
237 /// array.insert([0, -3, -7], 200);
238 ///
239 /// // Insertion in a 2D array
240 /// let array2d = MultiIndexed::<2, String>::new();
241 /// array2d.insert([0, 0], "Origin".to_string());
242 /// array2d.insert([10, -5], "Far point".to_string());
243 ///
244 /// // Insertion in a 1D array
245 /// let array1d = MultiIndexed::<1, f64>::new();
246 /// array1d.insert([0], 3.14);
247 /// array1d.insert([-10], -2.71);
248 /// ```
249 ///
250 /// # Panics
251 ///
252 /// This method will panic if a value already exists at the specified coordinates:
253 ///
254 /// ```should_panic
255 /// use once::MultiIndexed;
256 ///
257 /// let array = MultiIndexed::<2, i32>::new();
258 /// array.insert([1, 2], 42);
259 /// array.insert([1, 2], 43); // Panics
260 /// ```
261 pub fn insert(&self, coords: impl Into<[i32; K]>, value: V) {
262 let coords = coords.into();
263 // We update the bounds before inserting. The invariant we preserve is the property that
264 // every value lives within bounds, but we're ok if the bounds are not as tight as they
265 // could be.
266 self.update_bounds(&coords);
267 self.trie.insert(&coords, value);
268 }
269
270 pub fn try_insert(&self, coords: impl Into<[i32; K]>, value: V) -> Result<(), V> {
271 let coords = coords.into();
272 // We update the bounds before inserting. The invariant we preserve is the property that
273 // every value lives within bounds, but we're ok if the bounds are not as tight as they
274 // could be.
275 self.update_bounds(&coords);
276 self.trie.try_insert(&coords, value)?;
277 Ok(())
278 }
279
280 fn update_bounds(&self, coords: &[i32; K]) {
281 for (i, coord) in coords.iter().enumerate() {
282 self.min_coords[i].fetch_min(*coord, Ordering::Release);
283 self.max_coords[i].fetch_max(*coord, Ordering::Release);
284 }
285 }
286
287 /// Returns `true` if no values have been inserted.
288 ///
289 /// The bounds are loaded with `Relaxed` ordering, then a single `Acquire` fence synchronizes
290 /// with the `Release` stores in `update_bounds`. This is cheaper than per-load `Acquire` and is
291 /// sufficient: once the fence executes, all prior `Release` stores (across every dimension) are
292 /// visible.
293 pub fn is_empty(&self) -> bool {
294 let min_last = self.min_coords[K - 1].load(Ordering::Relaxed);
295 let max_last = self.max_coords[K - 1].load(Ordering::Relaxed);
296 fence(Ordering::Acquire);
297 min_last > max_last
298 }
299
300 /// Returns the per-dimension minimum coordinates, or `None` if empty.
301 ///
302 /// The `Acquire` fence in [`is_empty`](Self::is_empty) synchronizes with the `Release` stores
303 /// in `update_bounds`, so subsequent `Relaxed` loads on the remaining dimensions see consistent
304 /// values.
305 pub fn min_coords(&self) -> Option<[i32; K]> {
306 if self.is_empty() {
307 return None;
308 }
309 Some(std::array::from_fn(|i| {
310 self.min_coords[i].load(Ordering::Relaxed)
311 }))
312 }
313
314 /// Returns the per-dimension maximum coordinates, or `None` if empty.
315 ///
316 /// The `Acquire` fence in [`is_empty`](Self::is_empty) synchronizes with the `Release` stores
317 /// in `update_bounds`, so subsequent `Relaxed` loads on the remaining dimensions see consistent
318 /// values.
319 pub fn max_coords(&self) -> Option<[i32; K]> {
320 if self.is_empty() {
321 return None;
322 }
323 Some(std::array::from_fn(|i| {
324 self.max_coords[i].load(Ordering::Relaxed)
325 }))
326 }
327}
328
329impl<const K: usize, V> Default for MultiIndexed<K, V> {
330 fn default() -> Self {
331 Self::new()
332 }
333}
334
335impl<const K: usize, V: std::fmt::Debug> std::fmt::Debug for MultiIndexed<K, V> {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.debug_map().entries(self.iter()).finish()
338 }
339}
340
341impl<const K: usize, V> Clone for MultiIndexed<K, V>
342where
343 V: Clone,
344{
345 fn clone(&self) -> Self {
346 let mut min = [i32::MAX; K];
347 let mut max = [i32::MIN; K];
348 let trie = KdTrie::new(K);
349 for (coords, value) in self.iter() {
350 for i in 0..K {
351 if coords[i] < min[i] {
352 min[i] = coords[i];
353 }
354 if coords[i] > max[i] {
355 max[i] = coords[i];
356 }
357 }
358 trie.insert(&coords, value.clone());
359 }
360 Self {
361 trie,
362 min_coords: min.map(AtomicI32::new),
363 max_coords: max.map(AtomicI32::new),
364 }
365 }
366}
367
368impl<const K: usize, V> PartialEq for MultiIndexed<K, V>
369where
370 V: PartialEq,
371{
372 fn eq(&self, other: &Self) -> bool {
373 self.iter().eq(other.iter())
374 }
375}
376
377impl<const K: usize, V> Eq for MultiIndexed<K, V> where V: Eq {}
378
379impl<const K: usize, V: std::hash::Hash> std::hash::Hash for MultiIndexed<K, V> {
380 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
381 for (coords, value) in self.iter() {
382 coords.hash(state);
383 value.hash(state);
384 }
385 }
386}
387
388impl<const K: usize, V, I: Into<[i32; K]>> Index<I> for MultiIndexed<K, V> {
389 type Output = V;
390
391 fn index(&self, index: I) -> &Self::Output {
392 let coords = index.into();
393 self.get(coords)
394 .unwrap_or_else(|| panic!("no value at index {coords:?}"))
395 }
396}
397
398impl<const K: usize, V, I: Into<[i32; K]>> IndexMut<I> for MultiIndexed<K, V> {
399 fn index_mut(&mut self, index: I) -> &mut Self::Output {
400 let coords = index.into();
401 self.get_mut(coords)
402 .unwrap_or_else(|| panic!("no value at index {coords:?}"))
403 }
404}
405
#[cfg(test)]
mod tests {
    #![cfg_attr(miri, allow(dead_code))]
    use super::*;

    /// Generates every `K`-tuple of non-negative coordinates whose sum is exactly `n`.
    ///
    /// Tuples are produced in lexicographic order. `test_get_nth_diagonal` checks a concrete case.
    /// This is a plain unit-test helper: it lives in a `#[cfg(test)]` module, so it cannot carry a
    /// rustdoc example.
    pub fn get_nth_diagonal<const K: usize>(n: usize) -> Vec<[i32; K]> {
        let mut result = Vec::new();
        let mut tuple = [0; K];

        // Generate all tuples where the sum of coordinates equals n
        generate_tuples::<K>(&mut tuple, 0, n, &mut result);

        result
    }

    /// Fills `tuple[index..]` with every way of splitting `sum` into non-negative parts and pushes
    /// each completed tuple onto `result`.
    ///
    /// Positions before `index` keep the values the caller already chose. The same fixed-size
    /// buffer is reused for every tuple; only the finished copies are stored.
    fn generate_tuples<const K: usize>(
        tuple: &mut [i32; K],
        index: usize,
        sum: usize,
        result: &mut Vec<[i32; K]>,
    ) {
        if index == K - 1 {
            // The last element gets whatever is left to reach the sum
            tuple[index] = i32::try_from(sum).expect("diagonal coordinate must fit in an i32");
            result.push(*tuple);
            return;
        }

        for i in 0..=sum {
            tuple[index] = i32::try_from(i).expect("diagonal coordinate must fit in an i32");
            generate_tuples::<K>(tuple, index + 1, sum - i, result);
        }
    }

    /// Returns the first `n` coordinates, enumerated diagonal by diagonal (sum 0, 1, 2, ...).
    fn get_n_coords<const K: usize>(n: usize) -> Vec<[i32; K]> {
        (0..).flat_map(get_nth_diagonal).take(n).collect()
    }

    #[test]
    fn test_get_nth_diagonal() {
        let result = get_nth_diagonal::<3>(4);

        assert_eq!(result.len(), 15);
        for expected in [
            [0, 0, 4],
            [0, 1, 3],
            [0, 2, 2],
            [0, 3, 1],
            [0, 4, 0],
            [1, 0, 3],
            [1, 1, 2],
            [1, 2, 1],
            [1, 3, 0],
            [2, 0, 2],
            [2, 1, 1],
            [2, 2, 0],
            [3, 0, 1],
            [3, 1, 0],
            [4, 0, 0],
        ] {
            assert!(result.contains(&expected), "missing {expected:?}");
        }
    }

    #[test]
    fn test_basic() {
        let arr = MultiIndexed::new();

        arr.insert([1, 2, 3], 42);
        arr.insert([1, 2, 4], 43);
        arr.insert([1, 3, 3], 44);
        arr.insert([1, 3, 4], 45);

        assert_eq!(arr.get([1, 2, 3]), Some(&42));
        assert_eq!(arr.get([1, 2, 4]), Some(&43));
        assert_eq!(arr.get([1, 3, 3]), Some(&44));
        assert_eq!(arr.get([1, 3, 4]), Some(&45));
    }

    #[test]
    fn test_get_mut_basic() {
        let mut arr = MultiIndexed::<3, i32>::new();

        arr.insert([1, 2, 3], 42);
        arr.insert([4, 5, 6], 100);
        arr.insert([-1, -2, -3], 200);

        // Modify values using get_mut
        if let Some(value) = arr.get_mut([1, 2, 3]) {
            *value = 1000;
        }
        if let Some(value) = arr.get_mut([4, 5, 6]) {
            *value += 50;
        }
        if let Some(value) = arr.get_mut([-1, -2, -3]) {
            *value *= 2;
        }

        // Verify the modifications
        assert_eq!(arr.get([1, 2, 3]), Some(&1000));
        assert_eq!(arr.get([4, 5, 6]), Some(&150));
        assert_eq!(arr.get([-1, -2, -3]), Some(&400));

        // Verify that get_mut returns None for non-existent coordinates
        assert_eq!(arr.get_mut([0, 0, 0]), None);
    }

    // This is a bit too heavy for miri
    #[cfg_attr(not(miri), test)]
    fn test_large() {
        let arr = MultiIndexed::<8, _>::new();
        for (idx, coord) in get_n_coords(10_000).iter().enumerate() {
            arr.insert(*coord, idx);
        }
    }

    #[test]
    fn test_index() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);

        assert_eq!(arr[[1, 2]], 10);
        assert_eq!(arr[[3, 4]], 20);
    }

    #[test]
    fn test_index_mut() {
        let mut arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);

        arr[[1, 2]] += 1;
        arr[[3, 4]] += 2;

        assert_eq!(arr[[1, 2]], 11);
        assert_eq!(arr[[3, 4]], 22);
    }

    #[test]
    fn test_requires_drop() {
        use std::{
            sync::{
                Arc,
                atomic::{AtomicUsize, Ordering},
            },
            thread,
        };

        static ACTIVE_ALLOCS: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;

        impl DropCounter {
            fn new() -> Self {
                ACTIVE_ALLOCS.fetch_add(1, Ordering::Relaxed);
                Self
            }
        }

        impl Drop for DropCounter {
            fn drop(&mut self) {
                ACTIVE_ALLOCS.fetch_sub(1, Ordering::Relaxed);
            }
        }

        let v = Arc::new(MultiIndexed::<3, DropCounter>::new());
        assert_eq!(ACTIVE_ALLOCS.load(Ordering::Relaxed), 0);

        let num_threads = crate::test_utils::num_threads() as i32;
        let inserts_per_thread = crate::test_utils::values_per_thread() as i32;

        thread::scope(|s| {
            for thread_id in 0..num_threads {
                let v = Arc::clone(&v);
                s.spawn(move || {
                    for i in (-inserts_per_thread / 2)..(inserts_per_thread / 2) {
                        v.insert([thread_id, i, 4], DropCounter::new());
                    }
                });
            }
        });

        assert_eq!(
            ACTIVE_ALLOCS.load(Ordering::Relaxed),
            (num_threads * inserts_per_thread) as usize
        );

        drop(v);

        assert_eq!(ACTIVE_ALLOCS.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_debug() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);
        arr.insert([3, -4], 30);
        arr.insert([-5, 6], 40);

        expect_test::expect![[r#"
            {
                [
                    -5,
                    6,
                ]: 40,
                [
                    1,
                    2,
                ]: 10,
                [
                    3,
                    -4,
                ]: 30,
                [
                    3,
                    4,
                ]: 20,
            }
        "#]]
        .assert_debug_eq(&arr);
    }

    #[test]
    fn test_clone() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([3, 4], 20);

        let cloned_arr = arr.clone();

        assert_eq!(cloned_arr.get([1, 2]), Some(&10));
        assert_eq!(cloned_arr.get([3, 4]), Some(&20));
        assert_eq!(cloned_arr.get([5, 6]), None);
    }

    #[test]
    fn test_bounds_empty() {
        let arr = MultiIndexed::<2, i32>::new();
        assert!(arr.is_empty());
        assert_eq!(arr.min_coords(), None);
        assert_eq!(arr.max_coords(), None);
    }

    #[test]
    fn test_bounds_tracking() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        assert_eq!(arr.min_coords(), Some([1, 2]));
        assert_eq!(arr.max_coords(), Some([1, 2]));

        arr.insert([3, -4], 20);
        assert_eq!(arr.min_coords(), Some([1, -4]));
        assert_eq!(arr.max_coords(), Some([3, 2]));

        arr.insert([-5, 6], 30);
        assert_eq!(arr.min_coords(), Some([-5, -4]));
        assert_eq!(arr.max_coords(), Some([3, 6]));
    }

    #[test]
    fn test_bounds_clone() {
        let arr = MultiIndexed::<2, i32>::new();
        arr.insert([1, 2], 10);
        arr.insert([-3, 4], 20);

        let cloned = arr.clone();
        assert_eq!(cloned.min_coords(), Some([-3, 2]));
        assert_eq!(cloned.max_coords(), Some([1, 4]));
    }

    #[cfg(not(miri))]
    mod proptests {
        use std::collections::HashMap;

        use proptest::prelude::*;

        use super::*;

        // Return `max` such that the max length is twice the number of elements in the hypercube
        // (-max..=max)^K
        fn max_from_max_len<const K: usize>(max_len: usize) -> u32 {
            ((max_len as f32 / 2.0).powf(1.0 / K as f32) / 2.0).ceil() as u32
        }

        /// Generate a strategy for a single i32 coordinate.
        fn coord_strategy(max: u32) -> impl Strategy<Value = i32> {
            -(max as i32)..=max as i32
        }

        // Generate a strategy for arrays of i32 with a specific dimension
        fn coords_strategy<const K: usize>(max: u32) -> impl Strategy<Value = [i32; K]> {
            proptest::collection::vec(coord_strategy(max), K)
                .prop_map(|v| std::array::from_fn(|i| v[i]))
        }

        #[derive(Debug, Clone, Copy)]
        enum Operation<const K: usize> {
            Insert([i32; K], i32),
            Get([i32; K]),
            GetMut([i32; K]),
            Modify([i32; K], i32), // Add this value to the existing value (requires get_mut)
        }

        // Generate a strategy for a single operation (insert, get, get_mut, or modify)
        fn operation_strategy<const K: usize>(max: u32) -> impl Strategy<Value = Operation<K>> {
            coords_strategy::<K>(max).prop_flat_map(move |coords| {
                prop_oneof![
                    any::<i32>()
                        .prop_map(move |value| Operation::Insert(coords, value))
                        .boxed(),
                    Just(Operation::Get(coords)).boxed(),
                    Just(Operation::GetMut(coords)).boxed(),
                    any::<i32>()
                        .prop_map(move |delta| Operation::Modify(coords, delta))
                        .boxed(),
                ]
            })
        }

        // Generate a strategy for vectors of i32 coordinates
        fn coords_vec_strategy<const K: usize>(
            max_len: usize,
        ) -> impl Strategy<Value = Vec<[i32; K]>> {
            let size = max_from_max_len::<K>(max_len);
            proptest::collection::vec(coords_strategy::<K>(size), 1..=max_len)
        }

        // Generate a strategy for a list of operations (insert, get, get_mut, or modify)
        fn operations_strategy<const K: usize>(
            max_ops: usize,
        ) -> impl Strategy<Value = Vec<Operation<K>>> {
            let max = max_from_max_len::<K>(max_ops);
            proptest::collection::vec(operation_strategy::<K>(max), 1..=max_ops)
        }

        fn proptest_multiindexed_ops_kd<const K: usize>(ops: Vec<Operation<K>>) {
            let mut arr = MultiIndexed::<K, i32>::new();
            let mut reference = HashMap::new();

            for op in ops {
                match op {
                    Operation::Insert(coords, value) => {
                        // Only insert if the key doesn't exist yet (to avoid panics)
                        if let std::collections::hash_map::Entry::Vacant(e) =
                            reference.entry(coords)
                        {
                            arr.insert(coords, value);
                            e.insert(value);
                        } else {
                            // If the key already exists, test that try_insert returns an error
                            assert!(arr.try_insert(coords, value).is_err());
                        }
                    }
                    Operation::Get(coords) => {
                        // Check that get returns the same as our reference HashMap
                        let actual = arr.get(coords);
                        let expected = reference.get(&coords);
                        assert_eq!(actual, expected);
                    }
                    Operation::GetMut(coords) => {
                        // Check that get_mut returns the same as our reference HashMap
                        let actual = arr.get_mut(coords).map(|v| &*v);
                        let expected = reference.get(&coords);
                        assert_eq!(actual, expected);
                    }
                    Operation::Modify(coords, delta) => {
                        // Try to modify the value using get_mut
                        if let Some(value) = arr.get_mut(coords) {
                            *value = value.wrapping_add(delta);
                        }
                        // Also modify in reference
                        if let Some(value) = reference.get_mut(&coords) {
                            *value = value.wrapping_add(delta);
                        }
                        // Verify they match
                        assert_eq!(arr.get(coords), reference.get(&coords));
                    }
                }
            }

            // Final verification: all values should match
            for (coords, expected_value) in &reference {
                assert_eq!(arr.get(*coords), Some(expected_value));
            }
        }

        fn proptest_multiindexed_iter_kd<const K: usize>(coords: Vec<[i32; K]>) {
            let arr = MultiIndexed::<K, usize>::new();
            let mut tagged_coords = vec![];
            for (i, coord) in coords.iter().enumerate() {
                if arr.try_insert(*coord, i).is_ok() {
                    // Only insert if the coordinate was not already present
                    tagged_coords.push((*coord, i));
                };
            }

            let items: Vec<_> = arr.iter().map(|(coord, value)| (coord, *value)).collect();
            assert_eq!(items.len(), tagged_coords.len());

            tagged_coords.sort();
            assert_eq!(tagged_coords, items);
        }

        fn proptest_multiindexed_iter_mut_kd<const K: usize>(coords: Vec<[i32; K]>) {
            let mut arr = MultiIndexed::<K, usize>::new();
            let mut reference = HashMap::new();
            for (i, coord) in coords.iter().enumerate() {
                if arr.try_insert(*coord, i).is_ok() {
                    reference.insert(*coord, i);
                }
            }

            // Mutate via iter_mut: double every value
            for (_, v) in arr.iter_mut() {
                *v *= 2;
            }

            // Verify against reference
            for (coord, expected) in &reference {
                assert_eq!(arr.get(*coord), Some(&(expected * 2)));
            }

            // Verify iter_mut yields same coordinates as iter (after mutation)
            let mut iter_items: Vec<_> = arr.iter().map(|(c, &v)| (c, v)).collect();
            iter_items.sort();
            let mut reference_items: Vec<_> =
                reference.iter().map(|(c, v)| (*c, v * 2)).collect();
            reference_items.sort();
            assert_eq!(iter_items, reference_items);
        }

        const MAX_LEN: usize = 10_000;

        proptest! {
            #[test]
            fn proptest_multiindexed_ops_2d(ops in operations_strategy::<2>(MAX_LEN)) {
                proptest_multiindexed_ops_kd::<2>(ops);
            }

            #[test]
            fn proptest_multiindexed_ops_3d(ops in operations_strategy::<3>(MAX_LEN)) {
                proptest_multiindexed_ops_kd::<3>(ops);
            }

            #[test]
            fn proptest_multiindexed_iter_2d(coords in coords_vec_strategy::<2>(MAX_LEN)) {
                proptest_multiindexed_iter_kd::<2>(coords);
            }

            #[test]
            fn proptest_multiindexed_iter_3d(coords in coords_vec_strategy::<3>(MAX_LEN)) {
                proptest_multiindexed_iter_kd::<3>(coords);
            }

            #[test]
            fn proptest_multiindexed_iter_mut_2d(coords in coords_vec_strategy::<2>(MAX_LEN)) {
                proptest_multiindexed_iter_mut_kd::<2>(coords);
            }

            #[test]
            fn proptest_multiindexed_iter_mut_3d(coords in coords_vec_strategy::<3>(MAX_LEN)) {
                proptest_multiindexed_iter_mut_kd::<3>(coords);
            }
        }
    }

    #[cfg(loom)]
    mod loom_tests {
        use super::*;
        use crate::std_or_loom::{sync::Arc, thread};

        #[test]
        fn loom_concurrent_insert_get() {
            loom::model(|| {
                let arr = Arc::new(MultiIndexed::<2, i32>::new());

                // Thread 1: Insert values
                let arr1 = Arc::clone(&arr);
                let t1 = thread::spawn(move || {
                    arr1.insert([0, 0], 10);
                    arr1.insert([0, 1], 20);
                    arr1.insert([1, 0], 30);
                });

                // Thread 2: Insert different values
                let arr2 = Arc::clone(&arr);
                let t2 = thread::spawn(move || {
                    arr2.insert([1, 1], 40);
                    arr2.insert([2, 0], 50);
                    arr2.insert([0, 2], 60);
                });

                // Thread 3: Read values
                let arr3 = Arc::clone(&arr);
                let t3 = thread::spawn(move || {
                    // These may or may not be set yet
                    let _ = arr3.get([0, 0]);
                    let _ = arr3.get([1, 1]);
                    let _ = arr3.get([2, 2]); // This one is never set
                });

                t1.join().unwrap();
                t2.join().unwrap();
                t3.join().unwrap();

                // Verify final state
                assert_eq!(arr.get([0, 0]), Some(&10));
                assert_eq!(arr.get([0, 1]), Some(&20));
                assert_eq!(arr.get([1, 0]), Some(&30));
                assert_eq!(arr.get([1, 1]), Some(&40));
                assert_eq!(arr.get([2, 0]), Some(&50));
                assert_eq!(arr.get([0, 2]), Some(&60));
                assert_eq!(arr.get([2, 2]), None);
            });
        }

        #[test]
        fn loom_concurrent_with_negative_coords() {
            loom::model(|| {
                let arr = Arc::new(MultiIndexed::<3, i32>::new());

                // Thread 1: Insert values with negative coordinates
                let arr1 = Arc::clone(&arr);
                let t1 = thread::spawn(move || {
                    arr1.insert([-1, -2, -3], 10);
                    arr1.insert([-1, 0, 1], 20);
                    arr1.insert([0, -2, 3], 30);
                });

                // Thread 2: Insert values with mixed coordinates
                let arr2 = Arc::clone(&arr);
                let t2 = thread::spawn(move || {
                    arr2.insert([1, 1, 1], 40);
                    arr2.insert([2, -3, 0], 50);
                    arr2.insert([-5, 2, -1], 60);
                });

                // Thread 3: Read values
                let arr3 = Arc::clone(&arr);
                let t3 = thread::spawn(move || {
                    // These may or may not be set yet
                    let _ = arr3.get([-1, -2, -3]);
                    let _ = arr3.get([1, 1, 1]);
                    let _ = arr3.get([0, 0, 0]); // This one is never set
                });

                t1.join().unwrap();
                t2.join().unwrap();
                t3.join().unwrap();

                // Verify final state
                assert_eq!(arr.get([-1, -2, -3]), Some(&10));
                assert_eq!(arr.get([-1, 0, 1]), Some(&20));
                assert_eq!(arr.get([0, -2, 3]), Some(&30));
                assert_eq!(arr.get([1, 1, 1]), Some(&40));
                assert_eq!(arr.get([2, -3, 0]), Some(&50));
                assert_eq!(arr.get([-5, 2, -1]), Some(&60));
                assert_eq!(arr.get([0, 0, 0]), None);
            });
        }

        /// Asserts that the bounding box invariants hold:
        /// 1. If `!is_empty()`, bounds are consistent (min <= max per dimension).
        /// 2. Every retrievable entry fits within the bounding box.
        ///
        /// The trie must be read *before* the bounds, because the happens-before chain is:
        /// `update_bounds(Release) → trie.insert(Release) → trie.get(Acquire) → read bounds`.
        /// Reading bounds after an Acquire on the trie ensures we see the bounds update that
        /// preceded the insert.
        fn assert_bounds_invariants(arr: &MultiIndexed<2, i32>, coords: &[[i32; 2]]) {
            // Snapshot which entries are currently retrievable. The Acquire loads inside
            // `get` establish happens-before with the corresponding inserts, which in turn
            // happen-after their bounds updates.
            let present: Vec<[i32; 2]> = coords
                .iter()
                .copied()
                .filter(|c| arr.get(*c).is_some())
                .collect();

            if arr.is_empty() {
                return;
            }
            let min = arr.min_coords().unwrap();
            let max = arr.max_coords().unwrap();
            for i in 0..2 {
                assert!(
                    min[i] <= max[i],
                    "min[{i}] = {} > max[{i}] = {}",
                    min[i],
                    max[i]
                );
            }
            for c in &present {
                for i in 0..2 {
                    assert!(
                        min[i] <= c[i] && c[i] <= max[i],
                        "entry {c:?} outside bounding box [{min:?}, {max:?}]",
                    );
                }
            }
        }

        #[test]
        fn loom_bounds_concurrent_inserts() {
            loom::model(|| {
                let arr = Arc::new(MultiIndexed::<2, i32>::new());
                let coords = [[3, -1], [-2, 4]];

                let arr1 = Arc::clone(&arr);
                let t1 = thread::spawn(move || {
                    arr1.insert(coords[0], 10);
                });

                let arr2 = Arc::clone(&arr);
                let t2 = thread::spawn(move || {
                    arr2.insert(coords[1], 20);
                });

                let arr3 = Arc::clone(&arr);
                let t3 = thread::spawn(move || {
                    assert_bounds_invariants(&arr3, &coords);
                });

                t1.join().unwrap();
                t2.join().unwrap();
                t3.join().unwrap();

                assert_bounds_invariants(&arr, &coords);
                assert_eq!(arr.min_coords(), Some([-2, -1]));
                assert_eq!(arr.max_coords(), Some([3, 4]));
            });
        }

        #[test]
        fn loom_bounds_single_element() {
            loom::model(|| {
                let arr = Arc::new(MultiIndexed::<2, i32>::new());
                let coords = [[5, -3]];

                let arr1 = Arc::clone(&arr);
                let t1 = thread::spawn(move || {
                    arr1.insert(coords[0], 10);
                });

                let arr2 = Arc::clone(&arr);
                let t2 = thread::spawn(move || {
                    assert_bounds_invariants(&arr2, &coords);
                });

                t1.join().unwrap();
                t2.join().unwrap();

                assert_bounds_invariants(&arr, &coords);
                assert_eq!(arr.min_coords(), Some([5, -3]));
                assert_eq!(arr.max_coords(), Some([5, -3]));
            });
        }
    }
}