5#include "./eccvm_builder_types.hpp"
7namespace proof_system {
// Type aliases taken from the `Flavor` template parameter (its declaration is not visible in this excerpt).
11 using CycleGroup =
typename Flavor::CycleGroup;
12 using FF =
typename Flavor::FF;
// Two point representations of the cycle group:
// - `Element` is used below for group arithmetic (`+`, `.dbl()`).
// - `AffineElement` stores explicit (x, y) coordinates.
13 using Element =
typename CycleGroup::element;
14 using AffineElement =
typename CycleGroup::affine_element;
// Builder constants re-exported from `proof_system_eccvm` so that the code below can use them unqualified.
16 static constexpr size_t ADDITIONS_PER_ROW = proof_system_eccvm::ADDITIONS_PER_ROW;
17 static constexpr size_t NUM_SCALAR_BITS = proof_system_eccvm::NUM_SCALAR_BITS;
18 static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS;
22 uint32_t msm_size = 0;
23 uint32_t msm_count = 0;
24 uint32_t msm_round = 0;
25 bool msm_transition =
false;
27 bool q_double =
false;
33 AffineElement point{ 0, 0 };
35 FF collision_inverse = 0;
37 std::array<AddState, 4> add_state{
AddState{
false, 0, { 0, 0 }, 0, 0 },
38 AddState{
false, 0, { 0, 0 }, 0, 0 },
39 AddState{
false, 0, { 0, 0 }, 0, 0 },
40 AddState{
false, 0, { 0, 0 }, 0, 0 } };
56 static std::vector<MSMState>
compute_msm_state(
const std::vector<proof_system_eccvm::MSM<CycleGroup>>& msms,
57 std::array<std::vector<size_t>, 2>& point_table_read_counts,
58 const uint32_t total_number_of_muls)
68 const size_t table_rows =
static_cast<size_t>(total_number_of_muls) * 8;
69 point_table_read_counts[0].reserve(table_rows);
70 point_table_read_counts[1].reserve(table_rows);
71 for (
size_t i = 0; i < table_rows; ++i) {
72 point_table_read_counts[0].emplace_back(0);
73 point_table_read_counts[1].emplace_back(0);
75 const auto update_read_counts = [&](
const size_t pc,
const int slice) {
79 const size_t pc_delta = total_number_of_muls - pc;
80 const size_t pc_offset = pc_delta * 8;
81 bool slice_negative = slice < 0;
82 const int slice_row = (slice + 15) / 2;
84 const size_t column_index = slice_negative ? 1 : 0;
98 point_table_read_counts[column_index][pc_offset +
static_cast<size_t>(slice_row)]++;
100 point_table_read_counts[column_index][pc_offset + 15 -
static_cast<size_t>(slice_row)]++;
103 std::vector<MSMState> msm_state;
106 uint32_t pc = total_number_of_muls;
107 AffineElement accumulator = CycleGroup::affine_point_at_infinity;
109 for (
const auto& msm : msms) {
110 const size_t msm_size = msm.size();
112 const size_t rows_per_round = (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0);
113 static constexpr size_t num_rounds = NUM_SCALAR_BITS / WNAF_SLICE_BITS;
115 const auto add_points = [](
auto& P1,
auto& P2,
auto& lambda,
auto& collision_inverse,
bool predicate) {
116 lambda = predicate ? (P2.y - P1.y) / (P2.x - P1.x) : 0;
117 collision_inverse = predicate ? (P2.x - P1.x).invert() : 0;
118 auto x3 = predicate ? lambda * lambda - (P2.x + P1.x) : P1.x;
119 auto y3 = predicate ? lambda * (P1.x - x3) - P1.y : P1.y;
120 return AffineElement(x3, y3);
122 for (
size_t j = 0; j < num_rounds; ++j) {
123 for (
size_t k = 0; k < rows_per_round; ++k) {
125 const size_t points_per_row =
126 (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW;
127 const size_t idx = k * ADDITIONS_PER_ROW;
128 row.msm_transition = (j == 0) && (k == 0);
130 AffineElement acc(accumulator);
131 Element acc_expected = accumulator;
132 for (
size_t m = 0; m < ADDITIONS_PER_ROW; ++m) {
133 auto& add_state = row.add_state[m];
134 add_state.add = points_per_row > m;
135 int slice = add_state.add ? msm[idx + m].wnaf_slices[j] : 0;
144 add_state.slice = add_state.add ? (slice + 15) / 2 : 0;
145 add_state.point = add_state.add
146 ? msm[idx + m].precomputed_table[
static_cast<size_t>(add_state.slice)]
147 : AffineElement{ 0, 0 };
155 bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add);
157 auto& p1 = (m == 0) ? add_state.point : acc;
158 auto& p2 = (m == 0) ? acc : add_state.point;
160 acc_expected = add_predicate ? (acc_expected + add_state.point) : Element(p1);
162 update_read_counts(pc - idx - m, slice);
164 acc = add_points(p1, p2, add_state.lambda, add_state.collision_inverse, add_predicate);
165 ASSERT(acc == AffineElement(acc_expected));
168 row.q_double =
false;
170 row.msm_round =
static_cast<uint32_t
>(j);
171 row.msm_size =
static_cast<uint32_t
>(msm_size);
172 row.msm_count =
static_cast<uint32_t
>(idx);
173 row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
174 row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
177 msm_state.push_back(row);
179 if (j < num_rounds - 1) {
181 row.msm_transition =
false;
182 row.msm_round =
static_cast<uint32_t
>(j + 1);
183 row.msm_size =
static_cast<uint32_t
>(msm_size);
184 row.msm_count =
static_cast<uint32_t
>(0);
189 auto dx = accumulator.x;
190 auto dy = accumulator.y;
191 for (
size_t m = 0; m < 4; ++m) {
192 auto& add_state = row.add_state[m];
193 add_state.add =
false;
195 add_state.point = { 0, 0 };
196 add_state.collision_inverse = 0;
197 add_state.lambda = ((dx + dx + dx) * dx) / (dy + dy);
198 auto x3 = add_state.lambda.sqr() - dx - dx;
199 dy = add_state.lambda * (dx - x3) - dy;
203 row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
204 row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
205 accumulator = Element(accumulator).dbl().dbl().dbl().dbl();
207 msm_state.push_back(row);
209 for (
size_t k = 0; k < rows_per_round; ++k) {
212 const size_t points_per_row =
213 (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW;
214 const size_t idx = k * ADDITIONS_PER_ROW;
215 row.msm_transition =
false;
217 AffineElement acc(accumulator);
218 Element acc_expected = accumulator;
220 for (
size_t m = 0; m < 4; ++m) {
221 auto& add_state = row.add_state[m];
222 add_state.add = points_per_row > m;
223 add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0;
225 add_state.point = add_state.add
226 ? msm[idx + m].precomputed_table[
static_cast<size_t>(add_state.slice)]
227 : AffineElement{ 0, 0 };
228 bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew :
false;
230 update_read_counts(pc - idx - m, msm[idx + m].wnaf_skew ? -1 : -15);
233 acc, add_state.point, add_state.lambda, add_state.collision_inverse, add_predicate);
234 acc_expected = add_predicate ? (acc_expected + add_state.point) : acc_expected;
235 ASSERT(acc == AffineElement(acc_expected));
238 row.q_double =
false;
240 row.msm_round =
static_cast<uint32_t
>(j + 1);
241 row.msm_size =
static_cast<uint32_t
>(msm_size);
242 row.msm_count =
static_cast<uint32_t
>(idx);
244 row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
245 row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
249 msm_state.emplace_back(row);
253 pc -=
static_cast<uint32_t
>(msm_size);
255 Element expected = CycleGroup::point_at_infinity;
256 for (
size_t i = 0; i < msm.size(); ++i) {
257 expected += (Element(msm[i].base_point) * msm[i].scalar);
260 ASSERT(accumulator == AffineElement(expected));
265 final_row.msm_transition =
true;
266 final_row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
267 final_row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
268 final_row.msm_size = 0;
269 final_row.msm_count = 0;
270 final_row.q_add =
false;
271 final_row.q_double =
false;
272 final_row.q_skew =
false;
273 final_row.add_state = {
typename MSMState::AddState{
false, 0, AffineElement{ 0, 0 }, 0, 0 },
278 msm_state.emplace_back(final_row);
Definition: msm_builder.hpp:9
static std::vector< MSMState > compute_msm_state(const std::vector< proof_system_eccvm::MSM< CycleGroup > > &msms, std::array< std::vector< size_t >, 2 > &point_table_read_counts, const uint32_t total_number_of_muls)
Computes the row values for the Straus MSM columns of the ECCVM.
Definition: msm_builder.hpp:56
Definition: msm_builder.hpp:30
Definition: msm_builder.hpp:20