barretenberg
Loading...
Searching...
No Matches
uint256_impl.hpp
1#pragma once
2#include "../bitop/get_msb.hpp"
3#include "./uint256.hpp"
4#include "barretenberg/common/assert.hpp"
5namespace numeric {
6
7constexpr std::pair<uint64_t, uint64_t> uint256_t::mul_wide(const uint64_t a, const uint64_t b)
8{
9 const uint64_t a_lo = a & 0xffffffffULL;
10 const uint64_t a_hi = a >> 32ULL;
11 const uint64_t b_lo = b & 0xffffffffULL;
12 const uint64_t b_hi = b >> 32ULL;
13
14 const uint64_t lo_lo = a_lo * b_lo;
15 const uint64_t hi_lo = a_hi * b_lo;
16 const uint64_t lo_hi = a_lo * b_hi;
17 const uint64_t hi_hi = a_hi * b_hi;
18
19 const uint64_t cross = (lo_lo >> 32ULL) + (hi_lo & 0xffffffffULL) + lo_hi;
20
21 return { (cross << 32ULL) | (lo_lo & 0xffffffffULL), (hi_lo >> 32ULL) + (cross >> 32ULL) + hi_hi };
22}
23
24// compute a + b + carry, returning the carry
25constexpr std::pair<uint64_t, uint64_t> uint256_t::addc(const uint64_t a, const uint64_t b, const uint64_t carry_in)
26{
27 const uint64_t sum = a + b;
28 const auto carry_temp = static_cast<uint64_t>(sum < a);
29 const uint64_t r = sum + carry_in;
30 const uint64_t carry_out = carry_temp + static_cast<uint64_t>(r < carry_in);
31 return { r, carry_out };
32}
33
34constexpr uint64_t uint256_t::addc_discard_hi(const uint64_t a, const uint64_t b, const uint64_t carry_in)
35{
36 return a + b + carry_in;
37}
38
39constexpr std::pair<uint64_t, uint64_t> uint256_t::sbb(const uint64_t a, const uint64_t b, const uint64_t borrow_in)
40{
41 const uint64_t t_1 = a - (borrow_in >> 63ULL);
42 const auto borrow_temp_1 = static_cast<uint64_t>(t_1 > a);
43 const uint64_t t_2 = t_1 - b;
44 const auto borrow_temp_2 = static_cast<uint64_t>(t_2 > t_1);
45
46 return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
47}
48
49constexpr uint64_t uint256_t::sbb_discard_hi(const uint64_t a, const uint64_t b, const uint64_t borrow_in)
50{
51 return a - b - (borrow_in >> 63ULL);
52}
53
54// {r, carry_out} = a + carry_in + b * c
55constexpr std::pair<uint64_t, uint64_t> uint256_t::mac(const uint64_t a,
56 const uint64_t b,
57 const uint64_t c,
58 const uint64_t carry_in)
59{
60 std::pair<uint64_t, uint64_t> result = mul_wide(b, c);
61 result.first += a;
62 const auto overflow_c = static_cast<uint64_t>(result.first < a);
63 result.first += carry_in;
64 const auto overflow_carry = static_cast<uint64_t>(result.first < carry_in);
65 result.second += (overflow_c + overflow_carry);
66 return result;
67}
68
69constexpr uint64_t uint256_t::mac_discard_hi(const uint64_t a,
70 const uint64_t b,
71 const uint64_t c,
72 const uint64_t carry_in)
73{
74 return (b * c + a + carry_in);
75}
76
77constexpr std::pair<uint256_t, uint256_t> uint256_t::divmod(const uint256_t& b) const
78{
79 if (*this == 0 || b == 0) {
80 return { 0, 0 };
81 }
82 if (b == 1) {
83 return { *this, 0 };
84 }
85 if (*this == b) {
86 return { 1, 0 };
87 }
88 if (b > *this) {
89 return { 0, *this };
90 }
91
92 uint256_t quotient = 0;
93 uint256_t remainder = *this;
94
95 uint64_t bit_difference = get_msb() - b.get_msb();
96
97 uint256_t divisor = b << bit_difference;
98 uint256_t accumulator = uint256_t(1) << bit_difference;
99
100 // if the divisor is bigger than the remainder, a and b have the same bit length
101 if (divisor > remainder) {
102 divisor >>= 1;
103 accumulator >>= 1;
104 }
105
106 // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
107 // and add to the quotient
108 while (remainder >= b) {
109
110 // we've shunted 'divisor' up to have the same bit length as our remainder.
111 // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
112 if (remainder >= divisor) {
113 remainder -= divisor;
114 // we can use OR here instead of +, as
115 // accumulator is always a nice power of two
116 quotient |= accumulator;
117 }
118 divisor >>= 1;
119 accumulator >>= 1;
120 }
121
122 return { quotient, remainder };
123}
124
125constexpr std::pair<uint256_t, uint256_t> uint256_t::mul_extended(const uint256_t& other) const
126{
127 const auto [r0, t0] = mul_wide(data[0], other.data[0]);
128 const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
129 const auto [q1, t2] = mac(t1, data[0], other.data[2], 0);
130 const auto [q2, z0] = mac(t2, data[0], other.data[3], 0);
131
132 const auto [r1, t3] = mac(q0, data[1], other.data[0], 0);
133 const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
134 const auto [q4, t5] = mac(q2, data[1], other.data[2], t4);
135 const auto [q5, z1] = mac(z0, data[1], other.data[3], t5);
136
137 const auto [r2, t6] = mac(q3, data[2], other.data[0], 0);
138 const auto [q6, t7] = mac(q4, data[2], other.data[1], t6);
139 const auto [q7, t8] = mac(q5, data[2], other.data[2], t7);
140 const auto [q8, z2] = mac(z1, data[2], other.data[3], t8);
141
142 const auto [r3, t9] = mac(q6, data[3], other.data[0], 0);
143 const auto [r4, t10] = mac(q7, data[3], other.data[1], t9);
144 const auto [r5, t11] = mac(q8, data[3], other.data[2], t10);
145 const auto [r6, r7] = mac(z2, data[3], other.data[3], t11);
146
147 uint256_t lo(r0, r1, r2, r3);
148 uint256_t hi(r4, r5, r6, r7);
149 return { lo, hi };
150}
151
157constexpr uint256_t uint256_t::slice(const uint64_t start, const uint64_t end) const
158{
159 const uint64_t range = end - start;
160 const uint256_t mask = (range == 256) ? -uint256_t(1) : (uint256_t(1) << range) - 1;
161 return ((*this) >> start) & mask;
162}
163
164constexpr uint256_t uint256_t::pow(const uint256_t& exponent) const
165{
166 uint256_t accumulator{ data[0], data[1], data[2], data[3] };
167 uint256_t to_mul{ data[0], data[1], data[2], data[3] };
168 const uint64_t maximum_set_bit = exponent.get_msb();
169
170 for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
171 accumulator *= accumulator;
172 if (exponent.get_bit(static_cast<uint64_t>(i))) {
173 accumulator *= to_mul;
174 }
175 }
176 if (exponent == uint256_t(0)) {
177 accumulator = uint256_t(1);
178 } else if (*this == uint256_t(0)) {
179 accumulator = uint256_t(0);
180 }
181 return accumulator;
182}
183
184constexpr bool uint256_t::get_bit(const uint64_t bit_index) const
185{
186 ASSERT(bit_index < 256);
187 if (bit_index > 255) {
188 return static_cast<bool>(0);
189 }
190 const auto idx = static_cast<size_t>(bit_index >> 6);
191 const size_t shift = bit_index & 63;
192 return static_cast<bool>((data[idx] >> shift) & 1);
193}
194
195constexpr uint64_t uint256_t::get_msb() const
196{
197 uint64_t idx = numeric::get_msb(data[3]);
198 idx = (idx == 0 && data[3] == 0) ? numeric::get_msb(data[2]) : idx + 64;
199 idx = (idx == 0 && data[2] == 0) ? numeric::get_msb(data[1]) : idx + 64;
200 idx = (idx == 0 && data[1] == 0) ? numeric::get_msb(data[0]) : idx + 64;
201 return idx;
202}
203
204constexpr uint256_t uint256_t::operator+(const uint256_t& other) const
205{
206 const auto [r0, t0] = addc(data[0], other.data[0], 0);
207 const auto [r1, t1] = addc(data[1], other.data[1], t0);
208 const auto [r2, t2] = addc(data[2], other.data[2], t1);
209 const auto r3 = addc_discard_hi(data[3], other.data[3], t2);
210 return { r0, r1, r2, r3 };
211};
212
213constexpr uint256_t uint256_t::operator-(const uint256_t& other) const
214{
215
216 const auto [r0, t0] = sbb(data[0], other.data[0], 0);
217 const auto [r1, t1] = sbb(data[1], other.data[1], t0);
218 const auto [r2, t2] = sbb(data[2], other.data[2], t1);
219 const auto r3 = sbb_discard_hi(data[3], other.data[3], t2);
220 return { r0, r1, r2, r3 };
221}
222
223constexpr uint256_t uint256_t::operator-() const
224{
225 return uint256_t(0) - *this;
226}
227
228constexpr uint256_t uint256_t::operator*(const uint256_t& other) const
229{
230 const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL);
231 const auto [q0, t1] = mac(0, data[0], other.data[1], t0);
232 const auto [q1, t2] = mac(0, data[0], other.data[2], t1);
233 const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2);
234
235 const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL);
236 const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
237 const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4);
238
239 const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL);
240 const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5);
241
242 const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL);
243
244 return { r0, r1, r2, r3 };
245}
246
247constexpr uint256_t uint256_t::operator/(const uint256_t& other) const
248{
249 return divmod(other).first;
250}
251
252constexpr uint256_t uint256_t::operator%(const uint256_t& other) const
253{
254 return divmod(other).second;
255}
256
257constexpr uint256_t uint256_t::operator&(const uint256_t& other) const
258{
259 return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] };
260}
261
262constexpr uint256_t uint256_t::operator^(const uint256_t& other) const
263{
264 return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] };
265}
266
267constexpr uint256_t uint256_t::operator|(const uint256_t& other) const
268{
269 return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] };
270}
271
272constexpr uint256_t uint256_t::operator~() const
273{
274 return { ~data[0], ~data[1], ~data[2], ~data[3] };
275}
276
277constexpr bool uint256_t::operator==(const uint256_t& other) const
278{
279 return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3];
280}
281
282constexpr bool uint256_t::operator!=(const uint256_t& other) const
283{
284 return !(*this == other);
285}
286
287constexpr bool uint256_t::operator!() const
288{
289 return *this == uint256_t(0ULL);
290}
291
292constexpr bool uint256_t::operator>(const uint256_t& other) const
293{
294 bool t0 = data[3] > other.data[3];
295 bool t1 = data[3] == other.data[3] && data[2] > other.data[2];
296 bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1];
297 bool t3 =
298 data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0];
299 return t0 || t1 || t2 || t3;
300}
301
302constexpr bool uint256_t::operator>=(const uint256_t& other) const
303{
304 return (*this > other) || (*this == other);
305}
306
307constexpr bool uint256_t::operator<(const uint256_t& other) const
308{
309 return other > *this;
310}
311
312constexpr bool uint256_t::operator<=(const uint256_t& other) const
313{
314 return (*this < other) || (*this == other);
315}
316
317constexpr uint256_t uint256_t::operator>>(const uint256_t& other) const
318{
319 uint64_t total_shift = other.data[0];
320
321 if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
322 return 0;
323 }
324
325 if (total_shift == 0) {
326 return *this;
327 }
328
329 uint64_t num_shifted_limbs = total_shift >> 6ULL;
330 uint64_t limb_shift = total_shift & 63ULL;
331
332 std::array<uint64_t, 4> shifted_limbs = { 0, 0, 0, 0 };
333
334 if (limb_shift == 0) {
335 shifted_limbs[0] = data[0];
336 shifted_limbs[1] = data[1];
337 shifted_limbs[2] = data[2];
338 shifted_limbs[3] = data[3];
339 } else {
340 uint64_t remainder_shift = 64ULL - limb_shift;
341
342 shifted_limbs[3] = data[3] >> limb_shift;
343
344 uint64_t remainder = (data[3]) << remainder_shift;
345
346 shifted_limbs[2] = (data[2] >> limb_shift) + remainder;
347
348 remainder = (data[2]) << remainder_shift;
349
350 shifted_limbs[1] = (data[1] >> limb_shift) + remainder;
351
352 remainder = (data[1]) << remainder_shift;
353
354 shifted_limbs[0] = (data[0] >> limb_shift) + remainder;
355 }
356 uint256_t result(0);
357
358 for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
359 result.data[i] = shifted_limbs[static_cast<size_t>(i + num_shifted_limbs)];
360 }
361
362 return result;
363}
364
365constexpr uint256_t uint256_t::operator<<(const uint256_t& other) const
366{
367 uint64_t total_shift = other.data[0];
368
369 if (total_shift >= 256 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
370 return 0;
371 }
372
373 if (total_shift == 0) {
374 return *this;
375 }
376 uint64_t num_shifted_limbs = total_shift >> 6ULL;
377 uint64_t limb_shift = total_shift & 63ULL;
378
379 std::array<uint64_t, 4> shifted_limbs = { 0, 0, 0, 0 };
380
381 if (limb_shift == 0) {
382 shifted_limbs[0] = data[0];
383 shifted_limbs[1] = data[1];
384 shifted_limbs[2] = data[2];
385 shifted_limbs[3] = data[3];
386 } else {
387 uint64_t remainder_shift = 64ULL - limb_shift;
388
389 shifted_limbs[0] = data[0] << limb_shift;
390
391 uint64_t remainder = data[0] >> remainder_shift;
392
393 shifted_limbs[1] = (data[1] << limb_shift) + remainder;
394
395 remainder = data[1] >> remainder_shift;
396
397 shifted_limbs[2] = (data[2] << limb_shift) + remainder;
398
399 remainder = data[2] >> remainder_shift;
400
401 shifted_limbs[3] = (data[3] << limb_shift) + remainder;
402 }
403 uint256_t result(0);
404
405 for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
406 result.data[static_cast<size_t>(i + num_shifted_limbs)] = shifted_limbs[i];
407 }
408
409 return result;
410}
411
412} // namespace numeric
Definition: uint256.hpp:25
constexpr uint256_t slice(uint64_t start, uint64_t end) const
Definition: uint256_impl.hpp:157
Definition: field2_declarations.hpp:6