fp/simd/x86_64/
mod.rs

1mod avx;
2mod avx2;
3mod avx512;
4mod sse2;
5
6use crate::{
7    blas::block::{MatrixBlock, MatrixBlockSlice},
8    limb::Limb,
9};
10
/// Generates a `LIMBS_PER_SIMD` constant and an `add_simd` function for one
/// instruction-set module.
///
/// `$arch` is the feature name handed to `#[target_feature(enable = ...)]`.
/// The names `SimdLimb`, `Limb`, `load`, `store` and `xor` are resolved at the
/// invocation site, so the invoking module must have them in scope.
macro_rules! add_simd_arch {
    ($arch:tt) => {
        /// Number of limbs covered by one `SimdLimb` value.
        const LIMBS_PER_SIMD: usize =
            std::mem::size_of::<SimdLimb>() / crate::constants::BYTES_PER_LIMB;

        /// XORs `source[i]` into `target[i]` for every `i` in
        /// `min_limb..target.len()`.
        ///
        /// The range is processed `LIMBS_PER_SIMD` limbs at a time through
        /// `load` / `xor` / `store`; the leftover limbs at the end are handled
        /// one by one.
        ///
        /// # Panics
        ///
        /// Panics if `min_limb > target.len()`, or if the range is non-empty
        /// and `source` holds fewer limbs than `target`. Without these checks
        /// the raw-pointer accesses below would run out of bounds.
        #[target_feature(enable = $arch)]
        pub(super) fn add_simd(target: &mut [Limb], source: &[Limb], min_limb: usize) {
            let max_limb = target.len();
            assert!(
                min_limb <= max_limb,
                "add_simd: min_limb ({}) exceeds target length ({})",
                min_limb,
                max_limb
            );
            if min_limb == max_limb {
                // Empty range: nothing is read or written.
                return;
            }
            assert!(
                source.len() >= max_limb,
                "add_simd: source has {} limbs but target has {}",
                source.len(),
                max_limb
            );

            let target = target.as_mut_ptr();
            let source = source.as_ptr();
            let chunks = (max_limb - min_limb) / LIMBS_PER_SIMD;
            for i in 0..chunks {
                let offset = min_limb + LIMBS_PER_SIMD * i;
                // SAFETY: `offset + LIMBS_PER_SIMD <= max_limb` by the
                // definition of `chunks`, and both slices were checked above to
                // hold at least `max_limb` limbs, so each full-width access
                // stays in bounds.
                // NOTE(review): assumes `load`/`store` place no alignment
                // requirement beyond that of `Limb` — confirm per module.
                unsafe {
                    let mut target_chunk = load(target.add(offset));
                    let source_chunk = load(source.add(offset));
                    target_chunk = xor(target_chunk, source_chunk);
                    store(target.add(offset), target_chunk);
                }
            }
            // Tail: the limbs that do not fill a whole SIMD chunk.
            for i in (min_limb + LIMBS_PER_SIMD * chunks)..max_limb {
                // SAFETY: `i < max_limb`, which is within both slices per the
                // assertions above.
                unsafe {
                    *target.add(i) = *target.add(i) ^ *source.add(i);
                }
            }
        }
    };
}
39
40use add_simd_arch;
41
42pub(super) fn add_simd(target: &mut [Limb], source: &[Limb], min_limb: usize) {
43    if is_x86_feature_detected!("avx512f") {
44        unsafe { avx512::add_simd(target, source, min_limb) }
45    } else if is_x86_feature_detected!("avx2") {
46        unsafe { avx2::add_simd(target, source, min_limb) }
47    } else if is_x86_feature_detected!("avx") {
48        unsafe { avx::add_simd(target, source, min_limb) }
49    } else if is_x86_feature_detected!("sse2") {
50        unsafe { sse2::add_simd(target, source, min_limb) }
51    } else {
52        super::generic::add_simd(target, source, min_limb)
53    }
54}
55
56pub(super) fn gather_block_simd(slice: MatrixBlockSlice) -> MatrixBlock {
57    if is_x86_feature_detected!("avx512f") {
58        unsafe { avx512::gather_simd(slice) }
59    } else {
60        super::generic::gather_block_simd(slice)
61    }
62}
63
64pub(super) fn gemm_block_simd(a: MatrixBlock, b: MatrixBlock, c: &mut MatrixBlock) {
65    if is_x86_feature_detected!("avx512f") {
66        unsafe { avx512::gemm_block_simd(a, b, c) }
67    } else {
68        super::generic::gemm_block_simd(a, b, c)
69    }
70}
71
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    proptest! {
        /// The AVX-512 GEMM kernel must produce the same block as the generic
        /// kernel for arbitrary inputs. On CPUs without AVX-512F the case
        /// passes without running either kernel.
        #[test]
        fn test_gemm_block_avx512(
            a: MatrixBlock,
            b: MatrixBlock,
            mut c: MatrixBlock,
        ) {
            if is_x86_feature_detected!("avx512f") {
                // Both kernels start from the same initial value of `c`.
                let mut simd_result = c;
                crate::simd::generic::gemm_block_simd(a, b, &mut c);
                // SAFETY: guarded by the runtime `avx512f` check above.
                unsafe { super::avx512::gemm_block_simd(a, b, &mut simd_result) };
                prop_assert_eq!(c, simd_result);
            }
        }
    }
}