barretenberg
Loading...
Searching...
No Matches
uintx_impl.hpp
1#pragma once
2#include "./uintx.hpp"
3#include "barretenberg/common/assert.hpp"
4
5namespace numeric {
6template <class base_uint>
7constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod(const uintx& b) const
8{
9 ASSERT(b != 0);
10 if (*this == 0) {
11 return { uintx(0), uintx(0) };
12 }
13 if (b == 1) {
14 return { *this, uintx(0) };
15 }
16 if (*this == b) {
17 return { uintx(1), uintx(0) };
18 }
19 if (b > *this) {
20 return { uintx(0), *this };
21 }
22
23 uintx quotient(0);
24 uintx remainder = *this;
25
26 uint64_t bit_difference = get_msb() - b.get_msb();
27
28 uintx divisor = b << bit_difference;
29 uintx accumulator = uintx(1) << bit_difference;
30
31 // if the divisor is bigger than the remainder, a and b have the same bit length
32 if (divisor > remainder) {
33 divisor >>= 1;
34 accumulator >>= 1;
35 }
36
37 // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
38 // and add to the quotient
39 while (remainder >= b) {
40
41 // we've shunted 'divisor' up to have the same bit length as our remainder.
42 // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
43 if (remainder >= divisor) {
44 remainder -= divisor;
45 // we can use OR here instead of +, as
46 // accumulator is always a nice power of two
47 quotient |= accumulator;
48 }
49 divisor >>= 1;
50 accumulator >>= 1;
51 }
52
53 return std::make_pair(quotient, remainder);
54}
55
65template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::unsafe_invmod(const uintx& modulus) const
66{
67
68 uintx t1 = 0;
69 uintx t2 = 1;
70 uintx r2 = (*this > modulus) ? *this % modulus : *this;
71 uintx r1 = modulus;
72 uintx q = 0;
73 while (r2 != 0) {
74 q = r1 / r2;
75 uintx temp_t1 = t1;
76 uintx temp_r1 = r1;
77 t1 = t2;
78 t2 = temp_t1 - q * t2;
79 r1 = r2;
80 r2 = temp_r1 - q * r2;
81 }
82
83 if (t1 > modulus) {
84 return modulus + t1;
85 }
86 return t1;
87}
88
97template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::invmod(const uintx& modulus) const
98{
99 ASSERT((*this) != 0);
100 if (modulus == 0) {
101 return 0;
102 }
103 if (modulus.get_msb() >= (2 * base_uint::length() - 1)) {
104 uintx<uintx<base_uint>> a_expanded(*this);
105 uintx<uintx<base_uint>> modulus_expanded(modulus);
106 return a_expanded.unsafe_invmod(modulus_expanded).lo;
107 }
108 return this->unsafe_invmod(modulus);
109}
110
116template <class base_uint>
117constexpr uintx<base_uint> uintx<base_uint>::slice(const uint64_t start, const uint64_t end) const
118{
119 const uint64_t range = end - start;
120 const uintx mask = range == base_uint::length() ? -uintx(1) : (uintx(1) << range) - 1;
121 return ((*this) >> start) & mask;
122}
123
124template <class base_uint> constexpr bool uintx<base_uint>::get_bit(const uint64_t bit_index) const
125{
126 if (bit_index >= base_uint::length()) {
127 return hi.get_bit(bit_index - base_uint::length());
128 }
129 return lo.get_bit(bit_index);
130}
131
132template <class base_uint> constexpr uint64_t uintx<base_uint>::get_msb() const
133{
134 uint64_t hi_idx = hi.get_msb();
135 uint64_t lo_idx = lo.get_msb();
136 return (hi_idx || (hi > base_uint(0))) ? (hi_idx + base_uint::length()) : lo_idx;
137}
138
139template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator+(const uintx& other) const
140{
141 base_uint res_lo = lo + other.lo;
142 bool carry = res_lo < lo;
143 base_uint res_hi = hi + other.hi + ((carry) ? base_uint(1) : base_uint(0));
144 return { res_lo, res_hi };
145};
146
147template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator-(const uintx& other) const
148{
149 base_uint res_lo = lo - other.lo;
150 bool borrow = res_lo > lo;
151 base_uint res_hi = hi - other.hi - ((borrow) ? base_uint(1) : base_uint(0));
152 return { res_lo, res_hi };
153}
154
155template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator-() const
156{
157 return uintx(0) - *this;
158}
159
160template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator*(const uintx& other) const
161{
162 const auto lolo = lo.mul_extended(other.lo);
163 const auto lohi = lo.mul_extended(other.hi);
164 const auto hilo = hi.mul_extended(other.lo);
165
166 base_uint top = lolo.second + hilo.first + lohi.first;
167 base_uint bottom = lolo.first;
168 return { bottom, top };
169}
170
171template <class base_uint>
172constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::mul_extended(const uintx& other) const
173{
174 const auto lolo = lo.mul_extended(other.lo);
175 const auto lohi = lo.mul_extended(other.hi);
176 const auto hilo = hi.mul_extended(other.lo);
177 const auto hihi = hi.mul_extended(other.hi);
178
179 base_uint t0 = lolo.first;
180 base_uint t1 = lolo.second;
181 base_uint t2 = hilo.second;
182 base_uint t3 = hihi.second;
183 base_uint t2_carry(0);
184 base_uint t3_carry(0);
185 t1 += hilo.first;
186 t2_carry += (t1 < hilo.first ? base_uint(1) : base_uint(0));
187 t1 += lohi.first;
188 t2_carry += (t1 < lohi.first ? base_uint(1) : base_uint(0));
189 t2 += lohi.second;
190 t3_carry += (t2 < lohi.second ? base_uint(1) : base_uint(0));
191 t2 += hihi.first;
192 t3_carry += (t2 < hihi.first ? base_uint(1) : base_uint(0));
193 t2 += t2_carry;
194 t3_carry += (t2 < t2_carry ? base_uint(1) : base_uint(0));
195 t3 += t3_carry;
196 return { uintx(t0, t1), uintx(t2, t3) };
197}
198
199template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator/(const uintx& other) const
200{
201 return divmod(other).first;
202}
203
204template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator%(const uintx& other) const
205{
206 return divmod(other).second;
207}
208// 0x2af0296feca4188a80fd373ebe3c64da87a232934abb3a99f9c4cd59e6758a65
209// 0x1182c6cdb54193b51ca27c1932b95c82bebac691e3996e5ec5e1d4395f3023e3
210template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator&(const uintx& other) const
211{
212 return { lo & other.lo, hi & other.hi };
213}
214
215template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator^(const uintx& other) const
216{
217 return { lo ^ other.lo, hi ^ other.hi };
218}
219
220template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator|(const uintx& other) const
221{
222 return { lo | other.lo, hi | other.hi };
223}
224
225template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator~() const
226{
227 return { ~lo, ~hi };
228}
229
230template <class base_uint> constexpr bool uintx<base_uint>::operator==(const uintx& other) const
231{
232 return ((lo == other.lo) && (hi == other.hi));
233}
234
235template <class base_uint> constexpr bool uintx<base_uint>::operator!=(const uintx& other) const
236{
237 return !(*this == other);
238}
239
240template <class base_uint> constexpr bool uintx<base_uint>::operator!() const
241{
242 return *this == uintx(0ULL);
243}
244
245template <class base_uint> constexpr bool uintx<base_uint>::operator>(const uintx& other) const
246{
247 bool hi_gt = hi > other.hi;
248 bool lo_gt = lo > other.lo;
249
250 bool gt = (hi_gt) || (lo_gt && (hi == other.hi));
251 return gt;
252}
253
254template <class base_uint> constexpr bool uintx<base_uint>::operator>=(const uintx& other) const
255{
256 return (*this > other) || (*this == other);
257}
258
259template <class base_uint> constexpr bool uintx<base_uint>::operator<(const uintx& other) const
260{
261 return other > *this;
262}
263
264template <class base_uint> constexpr bool uintx<base_uint>::operator<=(const uintx& other) const
265{
266 return (*this < other) || (*this == other);
267}
268
269template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator>>(const uint64_t other) const
270{
271 const uint64_t total_shift = other;
272 if (total_shift >= length()) {
273 return uintx(0);
274 }
275 if (total_shift == 0) {
276 return *this;
277 }
278 const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb());
279
280 const uint64_t limb_shift = total_shift & static_cast<uint64_t>(base_uint::length() - 1);
281
282 std::array<base_uint, 2> shifted_limbs = { 0, 0 };
283 if (limb_shift == 0) {
284 shifted_limbs[0] = lo;
285 shifted_limbs[1] = hi;
286 } else {
287 const uint64_t remainder_shift = static_cast<uint64_t>(base_uint::length()) - limb_shift;
288
289 shifted_limbs[1] = hi >> limb_shift;
290
291 base_uint remainder = (hi) << remainder_shift;
292
293 shifted_limbs[0] = (lo >> limb_shift) + remainder;
294 }
295 uintx result(0);
296 if (num_shifted_limbs == 0) {
297 result.hi = shifted_limbs[1];
298 result.lo = shifted_limbs[0];
299 } else {
300 result.lo = shifted_limbs[1];
301 }
302 return result;
303}
304
305template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator<<(const uint64_t other) const
306{
307 const uint64_t total_shift = other;
308 if (total_shift >= length()) {
309 return uintx(0);
310 }
311 if (total_shift == 0) {
312 return *this;
313 }
314 const uint64_t num_shifted_limbs = total_shift >> (base_uint(base_uint::length()).get_msb());
315 const uint64_t limb_shift = total_shift & static_cast<uint64_t>(base_uint::length() - 1);
316
317 std::array<base_uint, 2> shifted_limbs = { 0, 0 };
318 if (limb_shift == 0) {
319 shifted_limbs[0] = lo;
320 shifted_limbs[1] = hi;
321 } else {
322 const uint64_t remainder_shift = static_cast<uint64_t>(base_uint::length()) - limb_shift;
323
324 shifted_limbs[0] = lo << limb_shift;
325
326 base_uint remainder = lo >> remainder_shift;
327
328 shifted_limbs[1] = (hi << limb_shift) + remainder;
329 }
330 uintx result(0);
331 if (num_shifted_limbs == 0) {
332 result.hi = shifted_limbs[1];
333 result.lo = shifted_limbs[0];
334 } else {
335 result.hi = shifted_limbs[0];
336 }
337 return result;
338}
339} // namespace numeric
Definition: uintx.hpp:23
constexpr uintx unsafe_invmod(const uintx &modulus) const
Definition: uintx_impl.hpp:65
constexpr uintx invmod(const uintx &modulus) const
Definition: uintx_impl.hpp:97
constexpr uintx slice(uint64_t start, uint64_t end) const
Definition: uintx_impl.hpp:117
Definition: field2_declarations.hpp:6