fp/field/field_internal.rs
1// According to
2// https://doc.rust-lang.org/stable/rustc/lints/listing/warn-by-default.html#private-interfaces:
3//
4// "Having something private in primary interface guarantees that the item will be unusable from
5// outer modules due to type privacy."
6//
7// In our case, this is a feature. We want to be able to use the `FieldInternal` trait in this crate
8// and we also want it to be inaccessible from outside the crate.
9#![allow(private_interfaces)]
10
11use std::{hash::Hash, ops::Range};
12
13use super::element::{FieldElement, FieldElementContainer};
14use crate::{
15 constants::BITS_PER_LIMB,
16 limb::{Limb, LimbBitIndexPair},
17};
18
/// Generate a by-value binary operation from its in-place counterpart.
///
/// `normal_from_assign!(op, op_assign)` expands, inside a trait or impl for a field, to a method
/// `fn op(self, a, b) -> FieldElement<Self>`. That method moves `a` into a local accumulator,
/// applies `self.op_assign(&mut acc, b)` to it and returns the accumulator.
macro_rules! normal_from_assign {
    ($op:ident, $op_assign:ident) => {
        fn $op(self, a: FieldElement<Self>, b: FieldElement<Self>) -> FieldElement<Self> {
            let mut acc = a;
            self.$op_assign(&mut acc, b);
            acc
        }
    };
}
31
32/// Internal methods required for fields.
33///
34/// A field has several responsibilities. It must define:
35/// - what its elements "look like", i.e. how they are represented in memory;
36/// - how to perform finite field operations on those elements, namely addition, subtraction,
37/// multiplication, division (except by zero), and the Frobenius endomorphism;
38/// - how to pack and unpack elements into and from `Limb`s, so that `FqVector` can handle them.
39///
40/// We want a trait that makes all those definitions. However, we don't want to expose these
41/// implementation details to the outside world. Therefore, we define a public trait that defines
42/// public field methods (e.g. constructing the zero element) and an internal trait that takes care
43/// of the details. The latter trait is `FieldInternal`.
44///
45/// The fact that each field defines its own element type means that we can define a single struct
46/// that packages both a field and one of its elements, and this struct will be how we expose field
47/// operations to the outside world.
48#[allow(private_bounds)]
49pub trait FieldInternal:
50 std::fmt::Debug + Copy + PartialEq + Eq + Hash + Sized + crate::MaybeArbitrary<()> + 'static
51{
52 /// The internal representation of a field element.
53 type ElementContainer: FieldElementContainer;
54
55 /// Create a new field element. This is the method responsible for ensuring that the returned
56 /// value is in a consistent state. For example, for a prime field of characteristic `p`, this
57 /// function is responsible for ensuring that the `FieldElement` that is returned contains a
58 /// value in the range `0..p`.
59 fn el(self, value: Self::ElementContainer) -> FieldElement<Self>;
60
61 // # Field operations
62 // ## Mendatory methods
63
64 fn add_assign(self, a: &mut FieldElement<Self>, b: FieldElement<Self>);
65 fn mul_assign(self, a: &mut FieldElement<Self>, b: FieldElement<Self>);
66
67 fn neg(self, a: FieldElement<Self>) -> FieldElement<Self>;
68 fn inv(self, a: FieldElement<Self>) -> Option<FieldElement<Self>>;
69
70 fn frobenius(self, a: FieldElement<Self>) -> FieldElement<Self>;
71
72 // ## Default implementations
73
74 fn sub_assign(self, a: &mut FieldElement<Self>, b: FieldElement<Self>) {
75 self.add_assign(a, self.neg(b));
76 }
77
78 normal_from_assign!(add, add_assign);
79 normal_from_assign!(sub, sub_assign);
80 normal_from_assign!(mul, mul_assign);
81
82 fn div(self, a: FieldElement<Self>, b: FieldElement<Self>) -> Option<FieldElement<Self>> {
83 Some(self.mul(a, self.inv(b)?))
84 }
85
86 // # Limb operations
87
88 /// Encode a field element into a `Limb`. The limbs of an `FqVector<Self>` will consist of the
89 /// coordinates of the vector, packed together using this method. It is assumed that the output
90 /// value occupies at most `self.bit_length()` bits with the rest padded with zeros, and that
91 /// the limb is reduced.
92 ///
93 /// It is required that `self.encode(self.zero()) == 0` (whenever `Self` implements `Field`).
94 fn encode(self, element: FieldElement<Self>) -> Limb;
95
96 /// Decode a `Limb` into a field element. The argument will always contain a single encoded
97 /// field element, padded with zeros. This is the inverse of [`encode`](FieldInternal::encode).
98 fn decode(self, element: Limb) -> FieldElement<Self>;
99
100 /// Return the number of bits a `Self::Element` occupies in a limb.
101 fn bit_length(self) -> usize;
102
103 /// Fused multiply-add. Return the `Limb` whose `i`th entry is `limb_a[i] + coeff * limb_b[i]`.
104 /// Both `limb_a` and `limb_b` are assumed to be reduced, and the result does not have to be
105 /// reduced.
106 fn fma_limb(self, limb_a: Limb, limb_b: Limb, coeff: FieldElement<Self>) -> Limb;
107
108 /// Reduce a limb, i.e. make it "canonical". For example, in [`Fp`](super::Fp), this replaces
109 /// every entry by its value modulo p.
110 ///
111 /// Many functions assume that the input limbs are reduced, but it's useful to allow the
112 /// existence of non-reduced limbs for performance reasons. Some functions like `fma_limb` can
113 /// be very quick compared to the reduction step, so finishing a computation by reducing all
114 /// limbs in sequence may allow the compiler to play some tricks with, for example, loop
115 /// unrolling and SIMD.
116 fn reduce(self, limb: Limb) -> Limb;
117
118 /// If `l` is a limb of `Self::Element`s, then `l & F.bitmask()` is the value of the
119 /// first entry of `l`.
120 fn bitmask(self) -> Limb {
121 (1 << self.bit_length()) - 1
122 }
123
124 /// The number of `Self::Element`s that fit in a single limb.
125 fn entries_per_limb(self) -> usize {
126 BITS_PER_LIMB / self.bit_length()
127 }
128
129 fn limb_bit_index_pair(self, idx: usize) -> LimbBitIndexPair {
130 LimbBitIndexPair {
131 limb: idx / self.entries_per_limb(),
132 bit_index: (idx % self.entries_per_limb() * self.bit_length()),
133 }
134 }
135
136 /// Check whether or not a limb is reduced. This may potentially not be faster than calling
137 /// [`reduce`](FieldInternal::reduce) directly.
138 fn is_reduced(self, limb: Limb) -> bool {
139 limb == self.reduce(limb)
140 }
141
142 /// Given an interator of `FieldElement<Self>`s, pack all of them into a single limb in order.
143 /// It is assumed that the values of the iterator fit into a single limb. If this assumption is
144 /// violated, the result will be nonsense.
145 fn pack<T: Iterator<Item = FieldElement<Self>>>(self, entries: T) -> Limb {
146 let bit_length = self.bit_length();
147 let mut result: Limb = 0;
148 let mut shift = 0;
149 for entry in entries {
150 result += self.encode(entry) << shift;
151 shift += bit_length;
152 }
153 result
154 }
155
156 /// Give an iterator over the entries of `limb`.
157 fn unpack(self, limb: Limb) -> LimbIterator<Self> {
158 LimbIterator {
159 fq: self,
160 limb,
161 entries: self.entries_per_limb(),
162 bit_length: self.bit_length(),
163 bit_mask: self.bitmask(),
164 }
165 }
166
167 /// Return the number of limbs required to hold `dim` entries.
168 fn number(self, dim: usize) -> usize {
169 if dim == 0 {
170 0
171 } else {
172 self.limb_bit_index_pair(dim - 1).limb + 1
173 }
174 }
175
176 /// Return the `Range<usize>` starting at the index of the limb containing the `start`th entry, and
177 /// ending at the index of the limb containing the `end`th entry (including the latter).
178 fn range(self, start: usize, end: usize) -> Range<usize> {
179 let min = self.limb_bit_index_pair(start).limb;
180 let max = self.number(end);
181 min..max
182 }
183
184 /// Return either `Some(sum)` if no carries happen in the limb, or `None` if some carry does happen.
185 // TODO: maybe name this something clearer
186 fn truncate(self, sum: Limb) -> Option<Limb> {
187 if self.is_reduced(sum) {
188 Some(sum)
189 } else {
190 None
191 }
192 }
193}
194
195pub(crate) struct LimbIterator<F> {
196 fq: F,
197 limb: Limb,
198 entries: usize,
199 bit_length: usize,
200 bit_mask: Limb,
201}
202
203impl<F: FieldInternal> Iterator for LimbIterator<F> {
204 type Item = FieldElement<F>;
205
206 fn next(&mut self) -> Option<Self::Item> {
207 if self.entries == 0 {
208 return None;
209 }
210 self.entries -= 1;
211 let result = self.limb & self.bit_mask;
212 self.limb >>= self.bit_length;
213 Some(self.fq.decode(result))
214 }
215}