barretenberg
Loading...
Searching...
No Matches
msm_builder.hpp
1#pragma once
2
3#include <cstddef>
4
5#include "./eccvm_builder_types.hpp"
6
7namespace proof_system {
8
9template <typename Flavor> class ECCVMMSMMBuilder {
10 public:
11 using CycleGroup = typename Flavor::CycleGroup;
12 using FF = typename Flavor::FF;
13 using Element = typename CycleGroup::element;
14 using AffineElement = typename CycleGroup::affine_element;
15
16 static constexpr size_t ADDITIONS_PER_ROW = proof_system_eccvm::ADDITIONS_PER_ROW;
17 static constexpr size_t NUM_SCALAR_BITS = proof_system_eccvm::NUM_SCALAR_BITS;
18 static constexpr size_t WNAF_SLICE_BITS = proof_system_eccvm::WNAF_SLICE_BITS;
19
20 struct MSMState {
21 uint32_t pc = 0;
22 uint32_t msm_size = 0;
23 uint32_t msm_count = 0;
24 uint32_t msm_round = 0;
25 bool msm_transition = false;
26 bool q_add = false;
27 bool q_double = false;
28 bool q_skew = false;
29
30 struct AddState {
31 bool add = false;
32 int slice = 0;
33 AffineElement point{ 0, 0 };
34 FF lambda = 0;
35 FF collision_inverse = 0;
36 };
37 std::array<AddState, 4> add_state{ AddState{ false, 0, { 0, 0 }, 0, 0 },
38 AddState{ false, 0, { 0, 0 }, 0, 0 },
39 AddState{ false, 0, { 0, 0 }, 0, 0 },
40 AddState{ false, 0, { 0, 0 }, 0, 0 } };
41 FF accumulator_x = 0;
42 FF accumulator_y = 0;
43 };
44
56 static std::vector<MSMState> compute_msm_state(const std::vector<proof_system_eccvm::MSM<CycleGroup>>& msms,
57 std::array<std::vector<size_t>, 2>& point_table_read_counts,
58 const uint32_t total_number_of_muls)
59 {
60 // N.B. the following comments refer to a "point lookup table" frequently.
61 // To perform a scalar multiplicaiton of a point [P] by a scalar x, we compute multiples of [P] and store in a
62 // table: specifically: -15[P], -13[P], ..., -3[P], -[P], [P], 3[P], ..., 15[P] when we define our point lookup
63 // table, we have 2 write columns and 4 read columns when we perform a read on a given row, we need to increment
64 // the read count on the respective write column by 1 we can define the following struture: 1st write column =
65 // positive 2nd write column = negative the row number is a function of pc and slice value row = pc_delta *
66 // rows_per_point_table + some function of the slice value pc_delta = total_number_of_muls - pc
67 // std::vector<std::array<size_t, > point_table_read_counts;
68 const size_t table_rows = static_cast<size_t>(total_number_of_muls) * 8;
69 point_table_read_counts[0].reserve(table_rows);
70 point_table_read_counts[1].reserve(table_rows);
71 for (size_t i = 0; i < table_rows; ++i) {
72 point_table_read_counts[0].emplace_back(0);
73 point_table_read_counts[1].emplace_back(0);
74 }
75 const auto update_read_counts = [&](const size_t pc, const int slice) {
76 // When we compute our wnaf/point tables, we start with the point with the largest pc value.
77 // i.e. if we are reading a slice for point with a point counter value `pc`,
78 // its position in the wnaf/point table (relative to other points) will be `total_number_of_muls - pc`
79 const size_t pc_delta = total_number_of_muls - pc;
80 const size_t pc_offset = pc_delta * 8;
81 bool slice_negative = slice < 0;
82 const int slice_row = (slice + 15) / 2;
83
84 const size_t column_index = slice_negative ? 1 : 0;
85
97 if (slice_negative) {
98 point_table_read_counts[column_index][pc_offset + static_cast<size_t>(slice_row)]++;
99 } else {
100 point_table_read_counts[column_index][pc_offset + 15 - static_cast<size_t>(slice_row)]++;
101 }
102 };
103 std::vector<MSMState> msm_state;
104 // start with empty row (shiftable polynomials must have 0 as first coefficient)
105 msm_state.emplace_back(MSMState{});
106 uint32_t pc = total_number_of_muls;
107 AffineElement accumulator = CycleGroup::affine_point_at_infinity;
108
109 for (const auto& msm : msms) {
110 const size_t msm_size = msm.size();
111
112 const size_t rows_per_round = (msm_size / ADDITIONS_PER_ROW) + (msm_size % ADDITIONS_PER_ROW != 0 ? 1 : 0);
113 static constexpr size_t num_rounds = NUM_SCALAR_BITS / WNAF_SLICE_BITS;
114
115 const auto add_points = [](auto& P1, auto& P2, auto& lambda, auto& collision_inverse, bool predicate) {
116 lambda = predicate ? (P2.y - P1.y) / (P2.x - P1.x) : 0;
117 collision_inverse = predicate ? (P2.x - P1.x).invert() : 0;
118 auto x3 = predicate ? lambda * lambda - (P2.x + P1.x) : P1.x;
119 auto y3 = predicate ? lambda * (P1.x - x3) - P1.y : P1.y;
120 return AffineElement(x3, y3);
121 };
122 for (size_t j = 0; j < num_rounds; ++j) {
123 for (size_t k = 0; k < rows_per_round; ++k) {
124 MSMState row;
125 const size_t points_per_row =
126 (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW;
127 const size_t idx = k * ADDITIONS_PER_ROW;
128 row.msm_transition = (j == 0) && (k == 0);
129
130 AffineElement acc(accumulator);
131 Element acc_expected = accumulator;
132 for (size_t m = 0; m < ADDITIONS_PER_ROW; ++m) {
133 auto& add_state = row.add_state[m];
134 add_state.add = points_per_row > m;
135 int slice = add_state.add ? msm[idx + m].wnaf_slices[j] : 0;
136 // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row.
137 // if `row.add_state[m].add = 1`, this indicates that we want to add the `m`'th point in the MSM
138 // columns into the MSM accumulator
139 // `add_state.slice` = A 4-bit WNAF slice of the scalar multiplier associated with the point we
140 // are adding (the specific slice chosen depends on the value of msm_round) (WNAF =
141 // windowed-non-adjacent-form. Value range is `-15, -13, ..., 15`) If `add_state.add = 1`, we
142 // want `add_state.slice` to be the *compressed* form of the WNAF slice value. (compressed = no
143 // gaps in the value range. i.e. -15, -13, ..., 15 maps to 0, ... , 15)
144 add_state.slice = add_state.add ? (slice + 15) / 2 : 0;
145 add_state.point = add_state.add
146 ? msm[idx + m].precomputed_table[static_cast<size_t>(add_state.slice)]
147 : AffineElement{ 0, 0 };
148 // predicate logic:
149 // add_predicate should normally equal add_state.add
150 // However! if j == 0 AND k == 0 AND m == 0 this implies we are examing the 1st point addition
151 // of a new MSM In this case, we do NOT add the 1st point into the accumulator, instead we SET
152 // the accumulator to equal the 1st point. add_predicate is used to determine whether we add the
153 // output of a point addition into the accumulator, therefore if j == 0 AND k == 0 AND m == 0,
154 // add_predicate = 0 even if add_state.add = true
155 bool add_predicate = (m == 0 ? (j != 0 || k != 0) : add_state.add);
156
157 auto& p1 = (m == 0) ? add_state.point : acc;
158 auto& p2 = (m == 0) ? acc : add_state.point;
159
160 acc_expected = add_predicate ? (acc_expected + add_state.point) : Element(p1);
161 if (add_state.add) {
162 update_read_counts(pc - idx - m, slice);
163 }
164 acc = add_points(p1, p2, add_state.lambda, add_state.collision_inverse, add_predicate);
165 ASSERT(acc == AffineElement(acc_expected));
166 }
167 row.q_add = true;
168 row.q_double = false;
169 row.q_skew = false;
170 row.msm_round = static_cast<uint32_t>(j);
171 row.msm_size = static_cast<uint32_t>(msm_size);
172 row.msm_count = static_cast<uint32_t>(idx);
173 row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
174 row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
175 row.pc = pc;
176 accumulator = acc;
177 msm_state.push_back(row);
178 }
179 if (j < num_rounds - 1) {
180 MSMState row;
181 row.msm_transition = false;
182 row.msm_round = static_cast<uint32_t>(j + 1);
183 row.msm_size = static_cast<uint32_t>(msm_size);
184 row.msm_count = static_cast<uint32_t>(0);
185 row.q_add = false;
186 row.q_double = true;
187 row.q_skew = false;
188
189 auto dx = accumulator.x;
190 auto dy = accumulator.y;
191 for (size_t m = 0; m < 4; ++m) {
192 auto& add_state = row.add_state[m];
193 add_state.add = false;
194 add_state.slice = 0;
195 add_state.point = { 0, 0 };
196 add_state.collision_inverse = 0;
197 add_state.lambda = ((dx + dx + dx) * dx) / (dy + dy);
198 auto x3 = add_state.lambda.sqr() - dx - dx;
199 dy = add_state.lambda * (dx - x3) - dy;
200 dx = x3;
201 }
202
203 row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
204 row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
205 accumulator = Element(accumulator).dbl().dbl().dbl().dbl();
206 row.pc = pc;
207 msm_state.push_back(row);
208 } else {
209 for (size_t k = 0; k < rows_per_round; ++k) {
210 MSMState row;
211
212 const size_t points_per_row =
213 (k + 1) * ADDITIONS_PER_ROW > msm_size ? msm_size % ADDITIONS_PER_ROW : ADDITIONS_PER_ROW;
214 const size_t idx = k * ADDITIONS_PER_ROW;
215 row.msm_transition = false;
216
217 AffineElement acc(accumulator);
218 Element acc_expected = accumulator;
219
220 for (size_t m = 0; m < 4; ++m) {
221 auto& add_state = row.add_state[m];
222 add_state.add = points_per_row > m;
223 add_state.slice = add_state.add ? msm[idx + m].wnaf_skew ? 7 : 0 : 0;
224
225 add_state.point = add_state.add
226 ? msm[idx + m].precomputed_table[static_cast<size_t>(add_state.slice)]
227 : AffineElement{ 0, 0 };
228 bool add_predicate = add_state.add ? msm[idx + m].wnaf_skew : false;
229 if (add_state.add) {
230 update_read_counts(pc - idx - m, msm[idx + m].wnaf_skew ? -1 : -15);
231 }
232 acc = add_points(
233 acc, add_state.point, add_state.lambda, add_state.collision_inverse, add_predicate);
234 acc_expected = add_predicate ? (acc_expected + add_state.point) : acc_expected;
235 ASSERT(acc == AffineElement(acc_expected));
236 }
237 row.q_add = false;
238 row.q_double = false;
239 row.q_skew = true;
240 row.msm_round = static_cast<uint32_t>(j + 1);
241 row.msm_size = static_cast<uint32_t>(msm_size);
242 row.msm_count = static_cast<uint32_t>(idx);
243
244 row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
245 row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
246
247 row.pc = pc;
248 accumulator = acc;
249 msm_state.emplace_back(row);
250 }
251 }
252 }
253 pc -= static_cast<uint32_t>(msm_size);
254 // Validate our computed accumulator matches the real MSM result!
255 Element expected = CycleGroup::point_at_infinity;
256 for (size_t i = 0; i < msm.size(); ++i) {
257 expected += (Element(msm[i].base_point) * msm[i].scalar);
258 }
259 // Validate the accumulator is correct!
260 ASSERT(accumulator == AffineElement(expected));
261 }
262
263 MSMState final_row;
264 final_row.pc = pc;
265 final_row.msm_transition = true;
266 final_row.accumulator_x = accumulator.is_point_at_infinity() ? 0 : accumulator.x;
267 final_row.accumulator_y = accumulator.is_point_at_infinity() ? 0 : accumulator.y;
268 final_row.msm_size = 0;
269 final_row.msm_count = 0;
270 final_row.q_add = false;
271 final_row.q_double = false;
272 final_row.q_skew = false;
273 final_row.add_state = { typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
274 typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
275 typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
276 typename MSMState::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } };
277
278 msm_state.emplace_back(final_row);
279 return msm_state;
280 }
281};
282} // namespace proof_system
Definition: msm_builder.hpp:9
static std::vector< MSMState > compute_msm_state(const std::vector< proof_system_eccvm::MSM< CycleGroup > > &msms, std::array< std::vector< size_t >, 2 > &point_table_read_counts, const uint32_t total_number_of_muls)
Computes the row values for the Straus MSM columns of the ECCVM.
Definition: msm_builder.hpp:56
Definition: msm_builder.hpp:20