1mod avx;
2mod avx2;
3mod avx512;
4mod sse2;
5
6use crate::{
7 blas::block::{MatrixBlock, MatrixBlockSlice},
8 limb::Limb,
9};
10
/// Generates `LIMBS_PER_SIMD` and an `add_simd` function for one instruction set.
///
/// The expansion refers to `Limb`, `SimdLimb`, `load`, `store` and `xor` by bare
/// name, so the invoking module must have all of them in scope. `$arch` is the
/// feature string forwarded to `#[target_feature(enable = ...)]`.
macro_rules! add_simd_arch {
    ($arch:tt) => {
        /// Number of limbs covered by one `SimdLimb` vector.
        const LIMBS_PER_SIMD: usize =
            std::mem::size_of::<SimdLimb>() / crate::constants::BYTES_PER_LIMB;

        /// XORs `source` into `target` limb by limb, starting at index `min_limb`.
        ///
        /// Limbs below `min_limb` are left untouched. Only indices that are valid in
        /// *both* slices are processed: if `source` is shorter than `target`, or
        /// `min_limb` lies past the end, the excess is skipped rather than being
        /// accessed out of bounds.
        #[target_feature(enable = $arch)]
        pub(super) fn add_simd(target: &mut [Limb], source: &[Limb], min_limb: usize) {
            // Clamp to the range that exists in both slices so that neither the
            // subtraction below nor any pointer access can leave the buffers.
            let max_limb = target.len().min(source.len());
            if min_limb >= max_limb {
                return;
            }
            let target = &mut target[min_limb..max_limb];
            let source = &source[min_limb..max_limb];

            // Split into a prefix made of whole SIMD vectors and a scalar tail.
            let simd_len = target.len() - target.len() % LIMBS_PER_SIMD;
            let (target_simd, target_tail) = target.split_at_mut(simd_len);
            let (source_simd, source_tail) = source.split_at(simd_len);

            let chunk_pairs = target_simd
                .chunks_exact_mut(LIMBS_PER_SIMD)
                .zip(source_simd.chunks_exact(LIMBS_PER_SIMD));
            for (target_chunk, source_chunk) in chunk_pairs {
                // SAFETY: both chunks hold exactly `LIMBS_PER_SIMD` in-bounds limbs, so
                // each pointer addresses one full vector's worth of limbs. As in the
                // previous pointer-arithmetic version, `load`/`store` are assumed to
                // accept limb-aligned addresses.
                unsafe {
                    let target_ptr = target_chunk.as_mut_ptr();
                    let sum = xor(load(target_ptr), load(source_chunk.as_ptr()));
                    store(target_ptr, sum);
                }
            }

            // Fewer than `LIMBS_PER_SIMD` limbs remain; handle them one at a time.
            for (target_limb, source_limb) in target_tail.iter_mut().zip(source_tail) {
                *target_limb = *target_limb ^ *source_limb;
            }
        }
    };
}
39
40use add_simd_arch;
41
42pub(super) fn add_simd(target: &mut [Limb], source: &[Limb], min_limb: usize) {
43 if is_x86_feature_detected!("avx512f") {
44 unsafe { avx512::add_simd(target, source, min_limb) }
45 } else if is_x86_feature_detected!("avx2") {
46 unsafe { avx2::add_simd(target, source, min_limb) }
47 } else if is_x86_feature_detected!("avx") {
48 unsafe { avx::add_simd(target, source, min_limb) }
49 } else if is_x86_feature_detected!("sse2") {
50 unsafe { sse2::add_simd(target, source, min_limb) }
51 } else {
52 super::generic::add_simd(target, source, min_limb)
53 }
54}
55
56pub(super) fn gather_block_simd(slice: MatrixBlockSlice) -> MatrixBlock {
57 if is_x86_feature_detected!("avx512f") {
58 unsafe { avx512::gather_simd(slice) }
59 } else {
60 super::generic::gather_block_simd(slice)
61 }
62}
63
64pub(super) fn gemm_block_simd(a: MatrixBlock, b: MatrixBlock, c: &mut MatrixBlock) {
65 if is_x86_feature_detected!("avx512f") {
66 unsafe { avx512::gemm_block_simd(a, b, c) }
67 } else {
68 super::generic::gemm_block_simd(a, b, c)
69 }
70}
71
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    proptest! {
        /// Checks that the AVX-512 GEMM kernel produces the same block as the generic
        /// kernel for arbitrary inputs. The comparison is skipped (the case passes)
        /// on machines where `avx512f` is not detected.
        #[test]
        fn test_gemm_block_avx512(
            a: MatrixBlock,
            b: MatrixBlock,
            mut c: MatrixBlock,
        ) {
            if is_x86_feature_detected!("avx512f") {
                // Copy the accumulator before either kernel mutates it so that both
                // start from the same state.
                let mut avx512_result = c;
                crate::simd::generic::gemm_block_simd(a, b, &mut c);
                unsafe { super::avx512::gemm_block_simd(a, b, &mut avx512_result) };
                prop_assert_eq!(c, avx512_result);
            }
        }
    }
}