2#include "barretenberg/common/mem.hpp"
3#include "barretenberg/common/slab_allocator.hpp"
4#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp"
5#include "barretenberg/plonk/proof_system/public_inputs/public_inputs.hpp"
6#include "barretenberg/polynomials/iterate_over_domain.hpp"
7#include "barretenberg/polynomials/polynomial.hpp"
8#include "barretenberg/polynomials/polynomial_arithmetic.hpp"
9#include "barretenberg/transcript/transcript.hpp"
13template <
size_t program_w
idth,
bool idpolys, const
size_t num_roots_cut_out_of_vanishing_polynomial>
14ProverPermutationWidget<program_width, idpolys, num_roots_cut_out_of_vanishing_polynomial>::ProverPermutationWidget(
15 proving_key* input_key)
16 : ProverRandomWidget(input_key)
19template <
size_t program_w
idth,
bool idpolys, const
size_t num_roots_cut_out_of_vanishing_polynomial>
20ProverPermutationWidget<program_width, idpolys, num_roots_cut_out_of_vanishing_polynomial>::ProverPermutationWidget(
21 const ProverPermutationWidget& other)
22 : ProverRandomWidget(other)
25template <
size_t program_w
idth,
bool idpolys, const
size_t num_roots_cut_out_of_vanishing_polynomial>
26ProverPermutationWidget<program_width, idpolys, num_roots_cut_out_of_vanishing_polynomial>::ProverPermutationWidget(
27 ProverPermutationWidget&& other)
28 : ProverRandomWidget(other)
31template <
size_t program_w
idth,
bool idpolys, const
size_t num_roots_cut_out_of_vanishing_polynomial>
32ProverPermutationWidget<program_width, idpolys, num_roots_cut_out_of_vanishing_polynomial>& ProverPermutationWidget<
35 num_roots_cut_out_of_vanishing_polynomial>::operator=(
const ProverPermutationWidget& other)
37 ProverRandomWidget::operator=(other);
41template <
size_t program_w
idth,
bool idpolys, const
size_t num_roots_cut_out_of_vanishing_polynomial>
42ProverPermutationWidget<program_width, idpolys, num_roots_cut_out_of_vanishing_polynomial>& ProverPermutationWidget<
45 num_roots_cut_out_of_vanishing_polynomial>::operator=(ProverPermutationWidget&& other)
47 ProverRandomWidget::operator=(other);
57template <
size_t program_w
idth,
bool idpolys, const
size_t num_roots_cut_out_of_vanishing_polynomial>
61 if (round_number != 3) {
69 size_t num_accumulators = (program_width == 1) ? 3 : program_width * 2;
70 std::shared_ptr<void> accumulators_ptrs[num_accumulators];
71 fr* accumulators[num_accumulators];
73 for (
size_t k = 0; k < num_accumulators; ++k) {
74 accumulators_ptrs[k] =
get_mem_slab(key->circuit_size *
sizeof(
fr));
75 accumulators[k] = (
fr*)accumulators_ptrs[k].get();
81 std::array<std::shared_ptr<fr[]>, program_width> lagrange_base_wires_ptr;
82 std::array<std::shared_ptr<fr[]>, program_width> lagrange_base_sigmas_ptr;
83 [[maybe_unused]] std::array<std::shared_ptr<fr[]>, program_width> lagrange_base_ids_ptr;
85 std::array<fr*, program_width> lagrange_base_wires;
86 std::array<fr*, program_width> lagrange_base_sigmas;
87 [[maybe_unused]] std::array<fr*, program_width> lagrange_base_ids;
89 for (
size_t i = 0; i < program_width; ++i) {
90 lagrange_base_wires_ptr[i] = key->polynomial_store.get(
"w_" + std::to_string(i + 1) +
"_lagrange").data();
91 lagrange_base_wires[i] = lagrange_base_wires_ptr[i].get();
92 lagrange_base_sigmas_ptr[i] = key->polynomial_store.get(
"sigma_" + std::to_string(i + 1) +
"_lagrange").data();
93 lagrange_base_sigmas[i] = lagrange_base_sigmas_ptr[i].get();
97 if constexpr (idpolys) {
98 lagrange_base_ids_ptr[i] = key->polynomial_store.get(
"id_" + std::to_string(i + 1) +
"_lagrange").data();
99 lagrange_base_ids[i] = lagrange_base_ids_ptr[i].get();
151 parallel_for(key->small_domain.num_threads, [&](
size_t j) {
152 barretenberg::fr thread_root = key->small_domain.root.pow(
153 static_cast<uint64_t>(j * key->small_domain.thread_size));
154 [[maybe_unused]] barretenberg::fr cur_root_times_beta = thread_root * beta;
156 barretenberg::fr wire_plus_gamma;
157 size_t start = j * key->small_domain.thread_size;
158 size_t end = (j + 1) * key->small_domain.thread_size;
159 for (size_t i = start; i < end; ++i) {
160 wire_plus_gamma = gamma + lagrange_base_wires[0][i];
162 if constexpr (!idpolys) {
163 accumulators[0][i] = wire_plus_gamma + cur_root_times_beta;
165 if constexpr (idpolys) {
166 T0 = lagrange_base_ids[0][i] * beta;
167 accumulators[0][i] = T0 + wire_plus_gamma;
170 T0 = lagrange_base_sigmas[0][i] * beta;
171 accumulators[program_width][i] = T0 + wire_plus_gamma;
173 for (size_t k = 1; k < program_width; ++k) {
174 wire_plus_gamma = gamma + lagrange_base_wires[k][i];
176 if constexpr (idpolys) {
177 T0 = lagrange_base_ids[k][i] * beta;
179 T0 = fr::coset_generator(k - 1) * cur_root_times_beta;
182 accumulators[k][i] = T0 + wire_plus_gamma;
184 T0 = lagrange_base_sigmas[k][i] * beta;
185 accumulators[k + program_width][i] = T0 + wire_plus_gamma;
187 if constexpr (!idpolys)
188 cur_root_times_beta *= key->small_domain.root;
206 parallel_for(program_width * 2, [&](
size_t i) {
207 fr* coeffs = &accumulators[i][0];
208 for (
size_t j = 0; j < key->small_domain.size - 1; ++j) {
209 coeffs[j + 1] *= coeffs[j];
235 parallel_for(key->small_domain.num_threads, [&](
size_t j) {
236 const size_t start = j * key->small_domain.thread_size;
238 ((j + 1) * key->small_domain.thread_size) - ((j == key->small_domain.num_threads - 1) ? 1 : 0);
239 barretenberg::fr inversion_accumulator = fr::one();
240 constexpr size_t inversion_index = (program_width == 1) ? 2 : program_width * 2 - 1;
241 fr* inversion_coefficients = &accumulators[inversion_index][0];
242 for (size_t i = start; i < end; ++i) {
244 for (size_t k = 1; k < program_width; ++k) {
245 accumulators[0][i] *= accumulators[k][i];
246 accumulators[program_width][i] *= accumulators[program_width + k][i];
248 inversion_coefficients[i] = accumulators[0][i] * inversion_accumulator;
249 inversion_accumulator *= accumulators[program_width][i];
251 inversion_accumulator = inversion_accumulator.invert();
252 for (
size_t i = end - 1; i != start - 1; --i) {
256 accumulators[0][i] = inversion_accumulator * inversion_coefficients[i];
257 inversion_accumulator *= accumulators[program_width][i];
263 polynomial z_perm(key->circuit_size);
264 z_perm[0] = fr::one();
265 barretenberg::polynomial_arithmetic::copy_polynomial(
266 accumulators[0], &z_perm[1], key->circuit_size - 1, key->circuit_size - 1);
310 const size_t z_randomness = 3;
311 ASSERT(z_randomness < num_roots_cut_out_of_vanishing_polynomial);
312 for (
size_t k = 0; k < z_randomness; ++k) {
313 z_perm[(key->circuit_size - num_roots_cut_out_of_vanishing_polynomial) + 1 + k] = fr::random_element();
316 z_perm.ifft(key->small_domain);
320 work_queue::WorkType::SCALAR_MULTIPLICATION,
329 work_queue::WorkType::FFT,
336 key->polynomial_store.put(
"z_perm", std::move(z_perm));
339template <
size_t program_w
idth,
bool idpolys, const
size_t num_roots_cut_out_of_vanishing_polynomial>
343 const polynomial& z_perm_fft = key->polynomial_store.get(
"z_perm_fft");
351 key->quotient_polynomial_parts[0][key->circuit_size] = 0;
352 key->quotient_polynomial_parts[1][key->circuit_size] = 0;
353 key->quotient_polynomial_parts[2][key->circuit_size] = 0;
372 std::array<std::shared_ptr<fr[]>, program_width> wire_ffts_ptr;
373 std::array<std::shared_ptr<fr[]>, program_width> sigma_ffts_ptr;
374 [[maybe_unused]] std::array<std::shared_ptr<fr[]>, program_width> id_ffts_ptr;
376 std::array<fr*, program_width> wire_ffts;
377 std::array<fr*, program_width> sigma_ffts;
378 [[maybe_unused]] std::array<fr*, program_width> id_ffts;
380 for (
size_t i = 0; i < program_width; ++i) {
384 wire_ffts_ptr[i] = key->polynomial_store.get(
"w_" + std::to_string(i + 1) +
"_fft").data();
385 sigma_ffts_ptr[i] = key->polynomial_store.get(
"sigma_" + std::to_string(i + 1) +
"_fft").data();
386 wire_ffts[i] = wire_ffts_ptr[i].get();
387 sigma_ffts[i] = sigma_ffts_ptr[i].get();
392 if constexpr (idpolys) {
393 id_ffts_ptr[i] = key->polynomial_store.get(
"id_" + std::to_string(i + 1) +
"_fft").data();
394 id_ffts[i] = id_ffts_ptr[i].get();
399 const polynomial& l_start = key->polynomial_store.get(
"lagrange_1_fft");
402 std::vector<barretenberg::fr> public_inputs = many_from_buffer<fr>(transcript.
get_element(
"public_inputs"));
405 compute_public_input_delta<fr>(public_inputs, beta, gamma, key->small_domain.root);
407 const size_t block_mask = key->large_domain.size - 1;
409 parallel_for(key->large_domain.num_threads, [&](
size_t j) {
410 const size_t start = j * key->large_domain.thread_size;
411 const size_t end = (j + 1) * key->large_domain.thread_size;
418 barretenberg::fr cur_root_times_beta =
419 key->large_domain.root.pow(static_cast<uint64_t>(j * key->large_domain.thread_size));
420 cur_root_times_beta *= key->small_domain.generator;
421 cur_root_times_beta *= beta;
423 barretenberg::fr wire_plus_gamma;
425 barretenberg::fr denominator;
426 barretenberg::fr numerator;
427 for (size_t i = start; i < end; ++i) {
428 wire_plus_gamma = gamma + wire_ffts[0][i];
431 if constexpr (!idpolys)
434 numerator = cur_root_times_beta + wire_plus_gamma;
436 numerator = id_ffts[0][i] * beta + wire_plus_gamma;
440 denominator = sigma_ffts[0][i] * beta;
441 denominator += wire_plus_gamma;
443 for (size_t k = 1; k < program_width; ++k) {
444 wire_plus_gamma = gamma + wire_ffts[k][i];
445 if constexpr (!idpolys)
447 T0 = fr::coset_generator(k - 1) * cur_root_times_beta;
448 if constexpr (idpolys)
449 T0 = id_ffts[k][i] * beta;
451 T0 += wire_plus_gamma;
455 T0 = sigma_ffts[k][i] * beta;
456 T0 += wire_plus_gamma;
460 numerator *= z_perm_fft[i];
461 denominator *= z_perm_fft[(i + 4) & block_mask];
499 T0 = z_perm_fft[(i + 4) & block_mask] - public_input_delta;
511 T0 *= l_start[(i + 4 + 4 * num_roots_cut_out_of_vanishing_polynomial) & block_mask];
519 T0 = z_perm_fft[i] - fr(1);
525 T0 = numerator - denominator;
526 key->quotient_polynomial_parts[i >> key->small_domain.log2_size][i & (key->circuit_size - 1)] =
530 cur_root_times_beta *= key->large_domain.root;
533 return alpha_base.sqr().sqr();
538template <
typename Field,
typename Group,
typename Transcript, const
size_t num_roots_cut_out_of_vanishing_polynomial>
539VerifierPermutationWidget<Field, Group, Transcript, num_roots_cut_out_of_vanishing_polynomial>::
540 VerifierPermutationWidget()
561template <
typename Field,
typename Group,
typename Transcript, const
size_t num_roots_cut_out_of_vanishing_polynomial>
565 const Transcript& transcript,
566 Field& quotient_numerator_eval,
570 Field alpha_squared = alpha.sqr();
571 Field alpha_cubed = alpha_squared * alpha;
572 Field z = transcript.get_challenge_field_element(
"z");
573 Field beta = transcript.get_challenge_field_element(
"beta", 0);
574 Field gamma = transcript.get_challenge_field_element(
"beta", 1);
575 Field z_beta = z * beta;
579 std::vector<Field> wire_evaluations;
580 std::vector<Field> sigma_evaluations;
582 for (
size_t i = 0; i < key->program_width; ++i) {
583 std::string index = std::to_string(i + 1);
584 sigma_evaluations.emplace_back(transcript.get_field_element(
"sigma_" + index));
587 for (
size_t i = 0; i < key->program_width; ++i) {
588 wire_evaluations.emplace_back(transcript.get_field_element(
589 "w_" + std::to_string(
602 Field numerator = key->z_pow_n - Field(1);
604 numerator *= key->domain.domain_inverse;
605 Field l_start = numerator / (z - Field(1));
608 Field l_end_root = (num_roots_cut_out_of_vanishing_polynomial & 1) ? key->domain.root.sqr() : key->domain.root;
609 for (
size_t i = 0; i < num_roots_cut_out_of_vanishing_polynomial / 2; ++i) {
610 l_end_root *= key->domain.root.sqr();
612 Field l_end = numerator / ((z * l_end_root) - Field(1));
614 Field z_1_shifted_eval = transcript.get_field_element(
"z_perm_omega");
634 Field sigma_contribution = Field(1);
636 for (
size_t i = 0; i < key->program_width - 1; ++i) {
637 T0 = sigma_evaluations[i] * beta;
638 T1 = wire_evaluations[i] + gamma;
640 sigma_contribution *= T0;
643 T0 = wire_evaluations[key->program_width - 1] + gamma;
644 sigma_contribution *= T0;
645 sigma_contribution *= z_1_shifted_eval;
646 sigma_contribution *= alpha;
653 std::vector<Field> public_inputs = transcript.get_field_element_vector(
"public_inputs");
654 Field public_input_delta = compute_public_input_delta<Field>(public_inputs, beta, gamma, key->domain.root);
656 T1 = z_1_shifted_eval - public_input_delta;
664 T2 = l_start * alpha_cubed;
673 T1 -= sigma_contribution;
674 quotient_numerator_eval += T1;
683 sigma_contribution = Field(1);
684 for (
size_t i = 0; i < key->program_width - 1; ++i) {
685 T0 = sigma_evaluations[i] * beta;
686 T0 += wire_evaluations[i];
688 sigma_contribution *= T0;
690 sigma_contribution *= z_1_shifted_eval;
691 Field sigma_last_multiplicand = -(sigma_contribution * alpha);
692 sigma_last_multiplicand *= beta;
710 quotient_numerator_eval += (sigma_last_multiplicand * sigma_evaluations[key->program_width - 1]);
712 Field z_eval = transcript.get_field_element(
"z_perm");
725 Field id_contribution = Field(1);
726 for (
size_t i = 0; i < key->program_width; ++i) {
727 Field id_evaluation = transcript.get_field_element(
"id_" + std::to_string(i + 1));
728 T0 = id_evaluation * beta;
729 T0 += wire_evaluations[i];
731 id_contribution *= T0;
733 Field id_last_multiplicand = id_contribution * alpha;
734 T0 = l_start * alpha_cubed;
735 id_last_multiplicand += T0;
749 quotient_numerator_eval += (id_last_multiplicand * z_eval);
760 Field z_contribution = Field(1);
761 for (
size_t i = 0; i < key->program_width; ++i) {
762 Field coset_generator = (i == 0) ? Field(1) : Field::coset_generator(i - 1);
763 T0 = z_beta * coset_generator;
764 T0 += wire_evaluations[i];
766 z_contribution *= T0;
768 Field z_1_multiplicand = (z_contribution * alpha);
769 T0 = l_start * alpha_cubed;
770 z_1_multiplicand += T0;
773 quotient_numerator_eval += (z_1_multiplicand * z_eval);
775 return alpha_squared.sqr();
778template <
typename Field,
typename Group,
typename Transcript, const
size_t num_roots_cut_out_of_vanishing_polynomial>
781 const Field& alpha_base,
782 const Transcript& transcript)
784 Field alpha_step = transcript.get_challenge_field_element(
"alpha");
785 return alpha_base * alpha_step.sqr() * alpha_step;
Definition: affine_element.hpp:11
Definition: work_queue.hpp:11
Definition: transcript_wrappers.hpp:13
std::array< uint8_t, PRNG_OUTPUT_SIZE > get_challenge(const std::string &challenge_name, const size_t idx=0) const
Definition: transcript.cpp:308
std::vector< uint8_t > get_element(const std::string &element_name) const
Definition: transcript.cpp:392
std::shared_ptr< void > get_mem_slab(size_t size)
Definition: slab_allocator.cpp:214
Definition: widget.bench.cpp:13
BBERG_INLINE constexpr field sqr() const noexcept
Definition: field_impl.hpp:61