barretenberg
Loading...
Searching...
No Matches
uint128_impl.hpp
1#ifdef __i386__
2#pragma once
3#include "../bitop/get_msb.hpp"
4#include "./uint128.hpp"
5#include "barretenberg/common/assert.hpp"
6namespace numeric {
7
8constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
9{
10 const uint32_t a_lo = a & 0xffffULL;
11 const uint32_t a_hi = a >> 16ULL;
12 const uint32_t b_lo = b & 0xffffULL;
13 const uint32_t b_hi = b >> 16ULL;
14
15 const uint32_t lo_lo = a_lo * b_lo;
16 const uint32_t hi_lo = a_hi * b_lo;
17 const uint32_t lo_hi = a_lo * b_hi;
18 const uint32_t hi_hi = a_hi * b_hi;
19
20 const uint32_t cross = (lo_lo >> 16) + (hi_lo & 0xffffULL) + lo_hi;
21
22 return { (cross << 16ULL) | (lo_lo & 0xffffULL), (hi_lo >> 16ULL) + (cross >> 16ULL) + hi_hi };
23}
24
25// compute a + b + carry, returning the carry
26constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
27{
28 const uint32_t sum = a + b;
29 const auto carry_temp = static_cast<uint32_t>(sum < a);
30 const uint32_t r = sum + carry_in;
31 const uint32_t carry_out = carry_temp + static_cast<unsigned int>(r < carry_in);
32 return { r, carry_out };
33}
34
35constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
36{
37 return a + b + carry_in;
38}
39
40constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
41{
42 const uint32_t t_1 = a - (borrow_in >> 31ULL);
43 const auto borrow_temp_1 = static_cast<uint32_t>(t_1 > a);
44 const uint32_t t_2 = t_1 - b;
45 const auto borrow_temp_2 = static_cast<uint32_t>(t_2 > t_1);
46
47 return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
48}
49
50constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
51{
52 return a - b - (borrow_in >> 31ULL);
53}
54
55// {r, carry_out} = a + carry_in + b * c
56constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
57 const uint32_t b,
58 const uint32_t c,
59 const uint32_t carry_in)
60{
61 std::pair<uint32_t, uint32_t> result = mul_wide(b, c);
62 result.first += a;
63 const auto overflow_c = static_cast<uint32_t>(result.first < a);
64 result.first += carry_in;
65 const auto overflow_carry = static_cast<uint32_t>(result.first < carry_in);
66 result.second += (overflow_c + overflow_carry);
67 return result;
68}
69
70constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
71 const uint32_t b,
72 const uint32_t c,
73 const uint32_t carry_in)
74{
75 return (b * c + a + carry_in);
76}
77
78constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
79{
80 if (*this == 0 || b == 0) {
81 return { 0, 0 };
82 }
83 if (b == 1) {
84 return { *this, 0 };
85 }
86 if (*this == b) {
87 return { 1, 0 };
88 }
89 if (b > *this) {
90 return { 0, *this };
91 }
92
93 uint128_t quotient = 0;
94 uint128_t remainder = *this;
95
96 uint64_t bit_difference = get_msb() - b.get_msb();
97
98 uint128_t divisor = b << bit_difference;
99 uint128_t accumulator = uint128_t(1) << bit_difference;
100
101 // if the divisor is bigger than the remainder, a and b have the same bit length
102 if (divisor > remainder) {
103 divisor >>= 1;
104 accumulator >>= 1;
105 }
106
107 // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
108 // and add to the quotient
109 while (remainder >= b) {
110
111 // we've shunted 'divisor' up to have the same bit length as our remainder.
112 // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
113 if (remainder >= divisor) {
114 remainder -= divisor;
115 // we can use OR here instead of +, as
116 // accumulator is always a nice power of two
117 quotient |= accumulator;
118 }
119 divisor >>= 1;
120 accumulator >>= 1;
121 }
122
123 return { quotient, remainder };
124}
125
126constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
127{
128 const auto [r0, t0] = mul_wide(data[0], other.data[0]);
129 const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
130 const auto [q1, t2] = mac(t1, data[0], other.data[2], 0);
131 const auto [q2, z0] = mac(t2, data[0], other.data[3], 0);
132
133 const auto [r1, t3] = mac(q0, data[1], other.data[0], 0);
134 const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
135 const auto [q4, t5] = mac(q2, data[1], other.data[2], t4);
136 const auto [q5, z1] = mac(z0, data[1], other.data[3], t5);
137
138 const auto [r2, t6] = mac(q3, data[2], other.data[0], 0);
139 const auto [q6, t7] = mac(q4, data[2], other.data[1], t6);
140 const auto [q7, t8] = mac(q5, data[2], other.data[2], t7);
141 const auto [q8, z2] = mac(z1, data[2], other.data[3], t8);
142
143 const auto [r3, t9] = mac(q6, data[3], other.data[0], 0);
144 const auto [r4, t10] = mac(q7, data[3], other.data[1], t9);
145 const auto [r5, t11] = mac(q8, data[3], other.data[2], t10);
146 const auto [r6, r7] = mac(z2, data[3], other.data[3], t11);
147
148 uint128_t lo(r0, r1, r2, r3);
149 uint128_t hi(r4, r5, r6, r7);
150 return { lo, hi };
151}
152
158constexpr uint128_t uint128_t::slice(const uint64_t start, const uint64_t end) const
159{
160 const uint64_t range = end - start;
161 const uint128_t mask = (range == 128) ? -uint128_t(1) : (uint128_t(1) << range) - 1;
162 return ((*this) >> start) & mask;
163}
164
165constexpr uint128_t uint128_t::pow(const uint128_t& exponent) const
166{
167 uint128_t accumulator{ data[0], data[1], data[2], data[3] };
168 uint128_t to_mul{ data[0], data[1], data[2], data[3] };
169 const uint64_t maximum_set_bit = exponent.get_msb();
170
171 for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
172 accumulator *= accumulator;
173 if (exponent.get_bit(static_cast<uint64_t>(i))) {
174 accumulator *= to_mul;
175 }
176 }
177 if (exponent == uint128_t(0)) {
178 accumulator = uint128_t(1);
179 } else if (*this == uint128_t(0)) {
180 accumulator = uint128_t(0);
181 }
182 return accumulator;
183}
184
185constexpr bool uint128_t::get_bit(const uint64_t bit_index) const
186{
187 ASSERT(bit_index < 128);
188 if (bit_index > 127) {
189 return false;
190 }
191 const auto idx = static_cast<size_t>(bit_index >> 5);
192 const size_t shift = bit_index & 31;
193 return static_cast<bool>((data[idx] >> shift) & 1);
194}
195
196constexpr uint64_t uint128_t::get_msb() const
197{
198 uint64_t idx = numeric::get_msb64(data[3]);
199 idx = (idx == 0 && data[3] == 0) ? numeric::get_msb64(data[2]) : idx + 32;
200 idx = (idx == 0 && data[2] == 0) ? numeric::get_msb64(data[1]) : idx + 32;
201 idx = (idx == 0 && data[1] == 0) ? numeric::get_msb64(data[0]) : idx + 32;
202 return idx;
203}
204
205constexpr uint128_t uint128_t::operator+(const uint128_t& other) const
206{
207 const auto [r0, t0] = addc(data[0], other.data[0], 0);
208 const auto [r1, t1] = addc(data[1], other.data[1], t0);
209 const auto [r2, t2] = addc(data[2], other.data[2], t1);
210 const auto r3 = addc_discard_hi(data[3], other.data[3], t2);
211 return { r0, r1, r2, r3 };
212};
213
214constexpr uint128_t uint128_t::operator-(const uint128_t& other) const
215{
216
217 const auto [r0, t0] = sbb(data[0], other.data[0], 0);
218 const auto [r1, t1] = sbb(data[1], other.data[1], t0);
219 const auto [r2, t2] = sbb(data[2], other.data[2], t1);
220 const auto r3 = sbb_discard_hi(data[3], other.data[3], t2);
221 return { r0, r1, r2, r3 };
222}
223
224constexpr uint128_t uint128_t::operator-() const
225{
226 return uint128_t(0) - *this;
227}
228
229constexpr uint128_t uint128_t::operator*(const uint128_t& other) const
230{
231 const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL);
232 const auto [q0, t1] = mac(0, data[0], other.data[1], t0);
233 const auto [q1, t2] = mac(0, data[0], other.data[2], t1);
234 const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2);
235
236 const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL);
237 const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
238 const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4);
239
240 const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL);
241 const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5);
242
243 const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL);
244
245 return { r0, r1, r2, r3 };
246}
247
248constexpr uint128_t uint128_t::operator/(const uint128_t& other) const
249{
250 return divmod(other).first;
251}
252
253constexpr uint128_t uint128_t::operator%(const uint128_t& other) const
254{
255 return divmod(other).second;
256}
257
258constexpr uint128_t uint128_t::operator&(const uint128_t& other) const
259{
260 return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] };
261}
262
263constexpr uint128_t uint128_t::operator^(const uint128_t& other) const
264{
265 return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] };
266}
267
268constexpr uint128_t uint128_t::operator|(const uint128_t& other) const
269{
270 return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] };
271}
272
273constexpr uint128_t uint128_t::operator~() const
274{
275 return { ~data[0], ~data[1], ~data[2], ~data[3] };
276}
277
278constexpr bool uint128_t::operator==(const uint128_t& other) const
279{
280 return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3];
281}
282
283constexpr bool uint128_t::operator!=(const uint128_t& other) const
284{
285 return !(*this == other);
286}
287
288constexpr bool uint128_t::operator!() const
289{
290 return *this == uint128_t(0ULL);
291}
292
293constexpr bool uint128_t::operator>(const uint128_t& other) const
294{
295 bool t0 = data[3] > other.data[3];
296 bool t1 = data[3] == other.data[3] && data[2] > other.data[2];
297 bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1];
298 bool t3 =
299 data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0];
300 return t0 || t1 || t2 || t3;
301}
302
303constexpr bool uint128_t::operator>=(const uint128_t& other) const
304{
305 return (*this > other) || (*this == other);
306}
307
308constexpr bool uint128_t::operator<(const uint128_t& other) const
309{
310 return other > *this;
311}
312
313constexpr bool uint128_t::operator<=(const uint128_t& other) const
314{
315 return (*this < other) || (*this == other);
316}
317
318constexpr uint128_t uint128_t::operator>>(const uint128_t& other) const
319{
320 uint32_t total_shift = other.data[0];
321
322 if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
323 return 0;
324 }
325
326 if (total_shift == 0) {
327 return *this;
328 }
329
330 uint32_t num_shifted_limbs = total_shift >> 5ULL;
331 uint32_t limb_shift = total_shift & 31ULL;
332
333 std::array<uint32_t, 4> shifted_limbs = { 0, 0, 0, 0 };
334
335 if (limb_shift == 0) {
336 shifted_limbs[0] = data[0];
337 shifted_limbs[1] = data[1];
338 shifted_limbs[2] = data[2];
339 shifted_limbs[3] = data[3];
340 } else {
341 uint32_t remainder_shift = 32ULL - limb_shift;
342
343 shifted_limbs[3] = data[3] >> limb_shift;
344
345 uint32_t remainder = (data[3]) << remainder_shift;
346
347 shifted_limbs[2] = (data[2] >> limb_shift) + remainder;
348
349 remainder = (data[2]) << remainder_shift;
350
351 shifted_limbs[1] = (data[1] >> limb_shift) + remainder;
352
353 remainder = (data[1]) << remainder_shift;
354
355 shifted_limbs[0] = (data[0] >> limb_shift) + remainder;
356 }
357 uint128_t result(0);
358
359 for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
360 result.data[i] = shifted_limbs[static_cast<size_t>(i + num_shifted_limbs)];
361 }
362
363 return result;
364}
365
366constexpr uint128_t uint128_t::operator<<(const uint128_t& other) const
367{
368 uint32_t total_shift = other.data[0];
369
370 if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
371 return 0;
372 }
373
374 if (total_shift == 0) {
375 return *this;
376 }
377 uint32_t num_shifted_limbs = total_shift >> 5ULL;
378 uint32_t limb_shift = total_shift & 31ULL;
379
380 std::array<uint32_t, 4> shifted_limbs{ 0, 0, 0, 0 };
381
382 if (limb_shift == 0) {
383 shifted_limbs[0] = data[0];
384 shifted_limbs[1] = data[1];
385 shifted_limbs[2] = data[2];
386 shifted_limbs[3] = data[3];
387 } else {
388 uint32_t remainder_shift = 32ULL - limb_shift;
389
390 shifted_limbs[0] = data[0] << limb_shift;
391
392 uint32_t remainder = data[0] >> remainder_shift;
393
394 shifted_limbs[1] = (data[1] << limb_shift) + remainder;
395
396 remainder = data[1] >> remainder_shift;
397
398 shifted_limbs[2] = (data[2] << limb_shift) + remainder;
399
400 remainder = data[2] >> remainder_shift;
401
402 shifted_limbs[3] = (data[3] << limb_shift) + remainder;
403 }
404 uint128_t result(0);
405
406 for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
407 result.data[static_cast<size_t>(i + num_shifted_limbs)] = shifted_limbs[i];
408 }
409
410 return result;
411}
412
413} // namespace numeric
414#endif
Definition: field2_declarations.hpp:6