barretenberg
Loading...
Searching...
No Matches
wnaf.hpp
1#pragma once
2#include "barretenberg/numeric/bitop/get_msb.hpp"
3#include <cstdint>
4#include <iostream>
5
6// NOLINTBEGIN(readability-implicit-bool-conversion)
7namespace barretenberg::wnaf {
/// Number of scalar bits that the non-parameterised WNAF routines in this header decompose.
constexpr size_t SCALAR_BITS{ 127 };

/// Number of windows of width `x` needed to cover SCALAR_BITS bits (division rounded up).
#define WNAF_SIZE(x) ((barretenberg::wnaf::SCALAR_BITS + (x) - 1) / (x)) // NOLINT(cppcoreguidelines-macro-usage)
11
/**
 * @brief Select the bucket width (in bits) to use for a given number of points.
 *
 * The thresholds below are listed in descending order; the width paired with the first threshold that does not
 * exceed `num_points` is returned. Inputs smaller than every threshold get a width of 1.
 *
 * @param num_points number of points
 * @return bucket width in bits (between 1 and 21)
 */
constexpr size_t get_optimal_bucket_width(const size_t num_points)
{
    constexpr size_t num_tiers = 14;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
    constexpr size_t min_points[num_tiers] = { 14617149, 1139094, 155975, 144834, 25067, 13926, 7659,
                                               2436,     376,     231,    97,     35,    10,    2 };
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
    constexpr size_t widths[num_tiers] = { 21, 18, 15, 14, 12, 11, 10, 9, 7, 6, 5, 4, 3, 2 };

    for (size_t tier = 0; tier < num_tiers; ++tier) {
        if (num_points >= min_points[tier]) {
            return widths[tier];
        }
    }
    return 1;
}
61constexpr size_t get_num_buckets(const size_t num_points)
62{
63 const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2);
64 return 1UL << bits_per_bucket;
65}
66
67constexpr size_t get_num_rounds(const size_t num_points)
68{
69 const size_t bits_per_bucket = get_optimal_bucket_width(num_points / 2);
70 return WNAF_SIZE(bits_per_bucket + 1);
71}
72
/**
 * @brief Extract `bits` consecutive bits, starting at `bit_position`, from a scalar stored as 64-bit limbs
 * (least-significant limb first). Width and position are compile-time constants.
 *
 * A window that lies inside one limb is read with a single shift; a window that crosses a limb boundary is stitched
 * together from the upper bits of the low limb and the lower bits of the next limb.
 *
 * @tparam bits         window width; 0 yields 0 without touching `scalar`
 * @tparam bit_position index of the window's least-significant bit
 * @param scalar        limbs of the scalar
 * @return the window, right-aligned
 */
template <size_t bits, size_t bit_position> inline uint64_t get_wnaf_bits_const(const uint64_t* scalar) noexcept
{
    if constexpr (bits == 0) {
        return 0ULL;
    } else {
        constexpr size_t first_limb = bit_position >> 6;
        constexpr size_t last_limb = (bit_position + bits - 1) >> 6;
        constexpr uint64_t offset_in_limb = bit_position % 64UL;
        constexpr uint64_t window_mask = (1UL << static_cast<uint64_t>(bits)) - 1UL;

        const uint64_t low_part = scalar[first_limb] >> offset_in_limb;
        if constexpr (first_limb != last_limb) {
            // The window spills into the next limb: move that limb's low bits up so they sit above `low_part`.
            constexpr uint64_t spill_shift = 64UL - offset_in_limb;
            const uint64_t high_part = scalar[last_limb] << spill_shift;
            return (low_part | high_part) & window_mask;
        } else {
            return low_part & window_mask;
        }
    }
}
106
/**
 * @brief Extract `bits` consecutive bits, starting at `bit_position`, from a scalar stored as 64-bit limbs
 * (least-significant limb first).
 *
 * The low limb is shifted right by `bit_position mod 64`. If the window crosses into the next limb, that limb is
 * shifted left by `64 - (bit_position mod 64)` so its low bits land directly above the bits taken from the low limb.
 * The limb-crossing case is selected with a mask rather than a branch.
 *
 * @param scalar       limbs of the scalar
 * @param bits         window width (expected to be below 64); 0 yields 0 without reading `scalar`
 * @param bit_position index of the window's least-significant bit
 * @return the window, right-aligned
 */
inline uint64_t get_wnaf_bits(const uint64_t* scalar, const uint64_t bits, const uint64_t bit_position) noexcept
{
    // An empty window has no bits. Returning early also avoids evaluating `bit_position + bits - 1` when both are
    // zero, which would wrap around and index far outside the scalar.
    if (bits == 0) {
        return 0;
    }

    const auto lo_limb_idx = static_cast<size_t>(bit_position >> 6);
    const auto hi_limb_idx = static_cast<size_t>((bit_position + bits - 1) >> 6);
    const uint64_t lo_shift = bit_position & 63UL;
    const uint64_t bit_mask = (uint64_t{ 1 } << bits) - 1;

    const uint64_t lo = scalar[lo_limb_idx] >> lo_shift;

    // Keep the shift amount in [0, 63]: when `bit_position` is a multiple of 64 the naive `64 - lo_shift` is 64, and
    // shifting a 64-bit value by 64 is undefined behaviour. In that case the window cannot cross a limb boundary, so
    // `hi` is masked out below and the shift amount chosen here does not affect the result.
    const uint64_t hi_shift = (64UL - lo_shift) & 63UL;
    const uint64_t hi = scalar[hi_limb_idx] << hi_shift;

    // All-ones when the window spans two limbs, zero otherwise.
    const uint64_t spans_two_limbs = 0ULL - static_cast<uint64_t>(lo_limb_idx != hi_limb_idx);
    const uint64_t hi_mask = bit_mask & spans_two_limbs;

    return (lo & bit_mask) | (hi & hi_mask);
}
134
135inline void fixed_wnaf_packed(
136 const uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const uint64_t point_index, const size_t wnaf_bits) noexcept
137{
138 skew_map = ((scalar[0] & 1) == 0);
139 uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
140 const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
141
142 for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
143 uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
144 uint64_t predicate = ((slice & 1UL) == 0UL);
145 wnaf[(wnaf_entries - round_i)] =
146 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
147 (point_index);
148 previous = slice + predicate;
149 }
150 size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1));
151 uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
152 uint64_t predicate = ((slice & 1UL) == 0UL);
153
154 wnaf[1] = ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
155 (point_index);
156 wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
157}
158
159inline void fixed_wnaf(const uint64_t* scalar,
160 uint64_t* wnaf,
161 bool& skew_map,
162 const uint64_t point_index,
163 const uint64_t num_points,
164 const size_t wnaf_bits) noexcept
165{
166 skew_map = ((scalar[0] & 1) == 0);
167 uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
168 const size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
169
170 for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
171 uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
172 uint64_t predicate = ((slice & 1UL) == 0UL);
173 wnaf[(wnaf_entries - round_i) * num_points] =
174 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
175 (point_index);
176 previous = slice + predicate;
177 }
178 size_t final_bits = SCALAR_BITS - (wnaf_bits * (wnaf_entries - 1));
179 uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
180 uint64_t predicate = ((slice & 1UL) == 0UL);
181
182 wnaf[num_points] =
183 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
184 (point_index);
185 wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
186}
187
215inline uint64_t get_num_scalar_bits(const uint64_t* scalar)
216{
217 const uint64_t msb_1 = numeric::get_msb(scalar[1]);
218 const uint64_t msb_0 = numeric::get_msb(scalar[0]);
219
220 const uint64_t scalar_1_mask = (0ULL - (scalar[1] > 0));
221 const uint64_t scalar_0_mask = (0ULL - (scalar[0] > 0)) & ~scalar_1_mask;
222
223 const uint64_t msb = (scalar_1_mask & (msb_1 + 64)) | (scalar_0_mask & (msb_0));
224 return msb;
225}
226
255inline void fixed_wnaf_with_counts(const uint64_t* scalar,
256 uint64_t* wnaf,
257 bool& skew_map,
258 uint64_t* wnaf_round_counts,
259 const uint64_t point_index,
260 const uint64_t num_points,
261 const size_t wnaf_bits) noexcept
262{
263 const size_t max_wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
264 if ((scalar[0] | scalar[1]) == 0ULL) {
265 skew_map = false;
266 for (size_t round_i = 0; round_i < max_wnaf_entries; ++round_i) {
267 wnaf[(round_i)*num_points] = 0xffffffffffffffffULL;
268 }
269 return;
270 }
271 const auto current_scalar_bits = static_cast<size_t>(get_num_scalar_bits(scalar) + 1);
272 skew_map = ((scalar[0] & 1) == 0);
273 uint64_t previous = get_wnaf_bits(scalar, wnaf_bits, 0) + static_cast<uint64_t>(skew_map);
274 const auto wnaf_entries = static_cast<size_t>((current_scalar_bits + wnaf_bits - 1) / wnaf_bits);
275
276 if (wnaf_entries == 1) {
277 wnaf[(max_wnaf_entries - 1) * num_points] = (previous >> 1UL) | (point_index);
278 ++wnaf_round_counts[max_wnaf_entries - 1];
279 for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) {
280 wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL;
281 }
282 return;
283 }
284
285 // If there are several windows
286 for (size_t round_i = 1; round_i < wnaf_entries - 1; ++round_i) {
287
288 // Get a bit slice
289 uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
290
291 // Get the predicate (last bit is zero)
292 uint64_t predicate = ((slice & 1UL) == 0UL);
293
294 // Update round count
295 ++wnaf_round_counts[max_wnaf_entries - round_i];
296
297 // Calculate entry value
298 // If the last bit of current slice is 1, we simply put the previous value with the point index
299 // If the last bit of the current slice is 0, we negate everything, so that we subtract from the WNAF form and
300 // make it 0
301 wnaf[(max_wnaf_entries - round_i) * num_points] =
302 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
303 (point_index);
304
305 // Update the previous value to the next windows
306 previous = slice + predicate;
307 }
308 // The final iteration for top bits
309 auto final_bits = static_cast<size_t>(current_scalar_bits - (wnaf_bits * (wnaf_entries - 1)));
310 uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
311 uint64_t predicate = ((slice & 1UL) == 0UL);
312
313 ++wnaf_round_counts[(max_wnaf_entries - wnaf_entries + 1)];
314 wnaf[((max_wnaf_entries - wnaf_entries + 1) * num_points)] =
315 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
316 (point_index);
317
318 // Saving top bits
319 ++wnaf_round_counts[max_wnaf_entries - wnaf_entries];
320 wnaf[(max_wnaf_entries - wnaf_entries) * num_points] = ((slice + predicate) >> 1UL) | (point_index);
321
322 // Fill all unused slots with -1
323 for (size_t j = wnaf_entries; j < max_wnaf_entries; ++j) {
324 wnaf[(max_wnaf_entries - 1 - j) * num_points] = 0xffffffffffffffffULL;
325 }
326}
327
328template <size_t num_points, size_t wnaf_bits, size_t round_i>
329inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept
330{
331 constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
332 constexpr auto log2_num_points = static_cast<size_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));
333
334 if constexpr (round_i < wnaf_entries - 1) {
335 uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
336 uint64_t predicate = ((slice & 1UL) == 0UL);
337 wnaf[(wnaf_entries - round_i) << log2_num_points] =
338 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
339 (point_index << 32UL);
340 wnaf_round<num_points, wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
341 } else {
342 constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits;
343 uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
344 // uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
345 uint64_t predicate = ((slice & 1UL) == 0UL);
346 wnaf[num_points] =
347 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
348 (point_index << 32UL);
349 wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
350 }
351}
352
353template <size_t scalar_bits, size_t num_points, size_t wnaf_bits, size_t round_i>
354inline void wnaf_round(uint64_t* scalar, uint64_t* wnaf, const uint64_t point_index, const uint64_t previous) noexcept
355{
356 constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits;
357 constexpr auto log2_num_points = static_cast<uint64_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));
358
359 if constexpr (round_i < wnaf_entries - 1) {
360 uint64_t slice = get_wnaf_bits_const<wnaf_bits, round_i * wnaf_bits>(scalar);
361 uint64_t predicate = ((slice & 1UL) == 0UL);
362 wnaf[(wnaf_entries - round_i) << log2_num_points] =
363 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
364 (point_index << 32UL);
365 wnaf_round<scalar_bits, num_points, wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
366 } else {
367 constexpr size_t final_bits = ((scalar_bits / wnaf_bits) * wnaf_bits == scalar_bits)
368 ? wnaf_bits
369 : scalar_bits - (scalar_bits / wnaf_bits) * wnaf_bits;
370 uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
371 uint64_t predicate = ((slice & 1UL) == 0UL);
372 wnaf[num_points] =
373 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
374 (point_index << 32UL);
375 wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
376 }
377}
378
379template <size_t wnaf_bits, size_t round_i>
380inline void wnaf_round_packed(const uint64_t* scalar,
381 uint64_t* wnaf,
382 const uint64_t point_index,
383 const uint64_t previous) noexcept
384{
385 constexpr size_t wnaf_entries = (SCALAR_BITS + wnaf_bits - 1) / wnaf_bits;
386
387 if constexpr (round_i < wnaf_entries - 1) {
388 uint64_t slice = get_wnaf_bits(scalar, wnaf_bits, round_i * wnaf_bits);
389 // uint64_t slice = get_wnaf_bits_const<wnaf_bits, round_i * wnaf_bits>(scalar);
390 uint64_t predicate = ((slice & 1UL) == 0UL);
391 wnaf[(wnaf_entries - round_i)] =
392 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
393 (point_index);
394 wnaf_round_packed<wnaf_bits, round_i + 1>(scalar, wnaf, point_index, slice + predicate);
395 } else {
396 constexpr size_t final_bits = SCALAR_BITS - (SCALAR_BITS / wnaf_bits) * wnaf_bits;
397 uint64_t slice = get_wnaf_bits(scalar, final_bits, (wnaf_entries - 1) * wnaf_bits);
398 // uint64_t slice = get_wnaf_bits_const<final_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
399 uint64_t predicate = ((slice & 1UL) == 0UL);
400 wnaf[1] =
401 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
402 (point_index);
403
404 wnaf[0] = ((slice + predicate) >> 1UL) | (point_index);
405 }
406}
407
408template <size_t num_points, size_t wnaf_bits>
409inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept
410{
411 skew_map = ((scalar[0] & 1) == 0);
412 uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + static_cast<uint64_t>(skew_map);
413 wnaf_round<num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
414}
415
416template <size_t num_bits, size_t num_points, size_t wnaf_bits>
417inline void fixed_wnaf(uint64_t* scalar, uint64_t* wnaf, bool& skew_map, const size_t point_index) noexcept
418{
419 skew_map = ((scalar[0] & 1) == 0);
420 uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + static_cast<uint64_t>(skew_map);
421 wnaf_round<num_bits, num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
422}
423
424template <size_t scalar_bits, size_t num_points, size_t wnaf_bits, size_t round_i>
425inline void wnaf_round_with_restricted_first_slice(uint64_t* scalar,
426 uint64_t* wnaf,
427 const uint64_t point_index,
428 const uint64_t previous) noexcept
429{
430 constexpr size_t wnaf_entries = (scalar_bits + wnaf_bits - 1) / wnaf_bits;
431 constexpr auto log2_num_points = static_cast<uint64_t>(numeric::get_msb(static_cast<uint32_t>(num_points)));
432 constexpr size_t bits_in_first_slice = scalar_bits % wnaf_bits;
433 if constexpr (round_i == 1) {
434 uint64_t slice = get_wnaf_bits_const<wnaf_bits, (round_i - 1) * wnaf_bits + bits_in_first_slice>(scalar);
435 uint64_t predicate = ((slice & 1UL) == 0UL);
436
437 wnaf[(wnaf_entries - round_i) << log2_num_points] =
438 ((((previous - (predicate << (bits_in_first_slice /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) |
439 (predicate << 31UL)) |
440 (point_index << 32UL);
441 if (round_i == 1) {
442 std::cerr << "writing value " << std::hex << wnaf[(wnaf_entries - round_i) << log2_num_points] << std::dec
443 << " at index " << ((wnaf_entries - round_i) << log2_num_points) << std::endl;
444 }
445 wnaf_round_with_restricted_first_slice<scalar_bits, num_points, wnaf_bits, round_i + 1>(
446 scalar, wnaf, point_index, slice + predicate);
447
448 } else if constexpr (round_i < wnaf_entries - 1) {
449 uint64_t slice = get_wnaf_bits_const<wnaf_bits, (round_i - 1) * wnaf_bits + bits_in_first_slice>(scalar);
450 uint64_t predicate = ((slice & 1UL) == 0UL);
451 wnaf[(wnaf_entries - round_i) << log2_num_points] =
452 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
453 (point_index << 32UL);
454 wnaf_round_with_restricted_first_slice<scalar_bits, num_points, wnaf_bits, round_i + 1>(
455 scalar, wnaf, point_index, slice + predicate);
456 } else {
457 uint64_t slice = get_wnaf_bits_const<wnaf_bits, (wnaf_entries - 1) * wnaf_bits>(scalar);
458 uint64_t predicate = ((slice & 1UL) == 0UL);
459 wnaf[num_points] =
460 ((((previous - (predicate << (wnaf_bits /*+ 1*/))) ^ (0UL - predicate)) >> 1UL) | (predicate << 31UL)) |
461 (point_index << 32UL);
462 wnaf[0] = ((slice + predicate) >> 1UL) | (point_index << 32UL);
463 }
464}
465
466template <size_t num_bits, size_t num_points, size_t wnaf_bits>
467inline void fixed_wnaf_with_restricted_first_slice(uint64_t* scalar,
468 uint64_t* wnaf,
469 bool& skew_map,
470 const size_t point_index) noexcept
471{
472 constexpr size_t bits_in_first_slice = num_bits % wnaf_bits;
473 std::cerr << "bits in first slice = " << bits_in_first_slice << std::endl;
474 skew_map = ((scalar[0] & 1) == 0);
475 uint64_t previous = get_wnaf_bits_const<bits_in_first_slice, 0>(scalar) + static_cast<uint64_t>(skew_map);
476 std::cerr << "previous = " << previous << std::endl;
477 wnaf_round_with_restricted_first_slice<num_bits, num_points, wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
478}
479
480// template <size_t wnaf_bits>
481// inline void fixed_wnaf_packed(const uint64_t* scalar,
482// uint64_t* wnaf,
483// bool& skew_map,
484// const uint64_t point_index) noexcept
485// {
486// skew_map = ((scalar[0] & 1) == 0);
487// uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + (uint64_t)skew_map;
488// wnaf_round_packed<wnaf_bits, 1UL>(scalar, wnaf, point_index, previous);
489// }
490
491// template <size_t wnaf_bits>
492// inline constexpr std::array<uint32_t, WNAF_SIZE(wnaf_bits)> fixed_wnaf(const uint64_t *scalar) const noexcept
493// {
494// bool skew_map = ((scalar[0] * 1) == 0);
495// uint64_t previous = get_wnaf_bits_const<wnaf_bits, 0>(scalar) + (uint64_t)skew_map;
496// std::array<uint32_t, WNAF_SIZE(wnaf_bits)> result;
497// }
498} // namespace barretenberg::wnaf
499
500// NOLINTEND(readability-implicit-bool-conversion)