barretenberg
Loading...
Searching...
No Matches
safe_uint.hpp
1#pragma once
2#include "../bool/bool.hpp"
3#include "../circuit_builders/circuit_builders.hpp"
4#include "../circuit_builders/circuit_builders_fwd.hpp"
5#include "../field/field.hpp"
6#include "../witness/witness.hpp"
7#include "barretenberg/common/assert.hpp"
8#include <functional>
9
// The purpose of this class is to enable non-negative integer operations without any risk of overflowing (wrapping
// around) the field modulus. Despite the name, it is *not* a "safe" version of the uint class: operations here are
// ordinary integer operations, not arithmetic modulo 2^t for some t as they are in the uint class.
13
14namespace proof_system::plonk {
15namespace stdlib {
16
17template <typename Builder> class safe_uint_t {
18 private:
21 // this constructor is private since we only want the operators to be able to define a positive int without a range
22 // check.
23 safe_uint_t(field_ct const& value, uint256_t current_max, size_t safety)
24 : value(value)
25 , current_max(current_max)
26 {
27 ASSERT(safety == IS_UNSAFE);
28 if (current_max > MAX_VALUE) // For optimal efficiency this should only be checked while testing a circuit
29 {
30 throw_or_abort("exceeded modulus in safe_uint class");
31 }
32 }
33
34 public:
35 // The following constant should be small enough that any thing with this bitnum is smaller than the modulus
36 static constexpr size_t MAX_BIT_NUM = barretenberg::fr::modulus.get_msb();
37 static constexpr uint256_t MAX_VALUE = barretenberg::fr::modulus - 1;
38 static constexpr size_t IS_UNSAFE = 143; // weird constant to make it hard to use accidentally
39 // Make sure our uint256 values don't wrap - add_two function sums three of these
40 static_assert((uint512_t)MAX_VALUE * 3 < (uint512_t)1 << 256);
41 field_ct value;
42 uint256_t current_max;
43
45 : value(0)
46 , current_max(0)
47 {}
48
49 safe_uint_t(field_ct const& value, size_t bit_num, std::string const& description = "unknown")
50 : value(value)
51 {
52 ASSERT(bit_num <= MAX_BIT_NUM);
53 this->value.create_range_constraint(bit_num, format("safe_uint_t range constraint failure: ", description));
54 current_max = ((uint256_t)1 << bit_num) - 1;
55 }
56
57 // When initialzing a constant, we can set the max value to the constant itself (rather than the usually larger
58 // 2^n-1)
59 safe_uint_t(const barretenberg::fr& const_value)
60 : value(const_value)
61 , current_max(const_value)
62 {}
63
64 // When initialzing a constant, we can set the max value to the constant itself (rather than the usually larger
65 // 2^n-1)
66 safe_uint_t(const uint256_t& const_value)
67 : value(barretenberg::fr(const_value))
68 , current_max(barretenberg::fr(const_value))
69 {}
70 safe_uint_t(const unsigned int& const_value)
71 : value(barretenberg::fr(const_value))
72 , current_max(barretenberg::fr(const_value))
73 {}
74
75 safe_uint_t(const safe_uint_t& other)
76 : value(other.value)
77 , current_max(other.current_max)
78 {}
79
80 static safe_uint_t<Builder> create_constant_witness(Builder* parent_context, barretenberg::fr const& value)
81
82 {
83 witness_t<Builder> out(parent_context, value);
84 parent_context->assert_equal_constant(out.witness_index, value);
85 return safe_uint_t(value, uint256_t(value), IS_UNSAFE);
86 }
87
88 // We take advantage of the range constraint already being applied in the bool constructor and don't make a
89 // redundant one.
90 safe_uint_t(const bool_ct& other)
91 : value(other)
92 , current_max(1)
93 {}
94
95 explicit operator bool_ct() { return bool_ct(value); }
96 static safe_uint_t from_witness_index(Builder* parent_context, const uint32_t witness_index);
97
98 // Subtraction when you have a pre-determined bound on the difference size
99 safe_uint_t subtract(const safe_uint_t& other,
100 const size_t difference_bit_size,
101 std::string const& description = "") const;
102
103 safe_uint_t operator-(const safe_uint_t& other) const;
104
105 // division when you have a pre-determined bound on the sizes of the quotient and remainder
107 const safe_uint_t& other,
108 const size_t quotient_bit_size,
109 const size_t remainder_bit_size,
110 std::string const& description = "",
111 const std::function<std::pair<uint256_t, uint256_t>(uint256_t, uint256_t)>& get_quotient =
112 [](uint256_t val, uint256_t divisor) {
113 return std::make_pair((uint256_t)(val / (uint256_t)divisor), (uint256_t)(val % (uint256_t)divisor));
114 }) const;
115
116 // Potentially less efficient than divide function - bounds remainder and quotient by max of this
117 safe_uint_t operator/(const safe_uint_t& other) const;
118
119 safe_uint_t add_two(const safe_uint_t& add_a, const safe_uint_t& add_b) const
120 {
121 ASSERT(current_max + add_a.current_max + add_b.current_max <= MAX_VALUE && "Exceeded modulus in add_two");
122 auto new_val = value.add_two(add_a.value, add_b.value);
123 auto new_max = current_max + add_a.current_max + add_b.current_max;
124 return safe_uint_t(new_val, new_max, IS_UNSAFE);
125 }
126
127 safe_uint_t madd(const safe_uint_t& to_mul, const safe_uint_t& to_add) const
128 {
129 ASSERT((uint512_t)current_max * (uint512_t)to_mul.current_max + (uint512_t)to_add.current_max <= MAX_VALUE &&
130 "Exceeded modulus in madd");
131 auto new_val = value.madd(to_mul.value, to_add.value);
132 auto new_max = current_max * to_mul.current_max + to_add.current_max;
133 return safe_uint_t(new_val, new_max, IS_UNSAFE);
134 }
135
136 safe_uint_t& operator=(const safe_uint_t& other)
137 {
138 value = other.value;
139 current_max = other.current_max;
140 return *this;
141 }
142
143 safe_uint_t& operator=(safe_uint_t&& other)
144 {
145 value = other.value;
146 current_max = other.current_max;
147 return *this;
148 }
149
150 safe_uint_t operator+=(const safe_uint_t& other)
151 {
152 *this = *this + other;
153 return *this;
154 }
155
156 safe_uint_t operator*=(const safe_uint_t& other)
157 {
158 *this = *this * other;
159 return *this;
160 }
161
162 std::array<safe_uint_t<Builder>, 3> slice(const uint8_t msb, const uint8_t lsb) const;
163 void set_public() const { value.set_public(); }
164 operator field_ct() { return value; }
165 operator field_ct() const { return value; }
166 safe_uint_t operator+(const safe_uint_t& other) const;
167 safe_uint_t operator*(const safe_uint_t& other) const;
168 bool_ct operator==(const safe_uint_t& other) const;
169 bool_ct operator!=(const safe_uint_t& other) const;
170
178 safe_uint_t normalize() const;
179
180 barretenberg::fr get_value() const;
181
182 Builder* get_context() const { return value.context; }
183
188 bool_ct is_zero() const;
189
190 void assert_equal(const safe_uint_t& rhs, std::string const& msg = "safe_uint_t::assert_equal") const
191 {
192 this->value.assert_equal(rhs.value, msg);
193 }
194 void assert_is_not_zero(std::string const& msg = "safe_uint_t::assert_is_not_zero") const;
195 void assert_is_zero(std::string const& msg = "safe_uint_t::assert_is_zero") const;
196 bool is_constant() const { return value.is_constant(); }
197
198 static safe_uint_t conditional_assign(const bool_ct& predicate, const safe_uint_t& lhs, const safe_uint_t& rhs)
199 {
200 auto new_val = (lhs.value - rhs.value).madd(predicate, rhs.value);
201 auto new_max = lhs.current_max > rhs.current_max ? lhs.current_max : rhs.current_max;
202 return safe_uint_t(new_val, new_max, IS_UNSAFE);
203 }
204
205 uint32_t get_witness_index() const { return value.get_witness_index(); }
206};
207
208template <typename Builder> inline std::ostream& operator<<(std::ostream& os, safe_uint_t<Builder> const& v)
209{
210 return os << v.value;
211}
212
213EXTERN_STDLIB_TYPE(safe_uint_t);
214
215} // namespace stdlib
216} // namespace proof_system::plonk
Definition: uint256.hpp:25
Definition: standard_circuit_builder.hpp:12
Definition: field.hpp:10
void assert_equal(const field_t &rhs, std::string const &msg="field_t::assert_equal") const
Constrain that this field is equal to the given field.
Definition: field.cpp:749
field_t madd(const field_t &to_mul, const field_t &to_add) const
Definition: field.cpp:384
Definition: safe_uint.hpp:17
safe_uint_t operator/(const safe_uint_t &other) const
Potentially less efficient than divide function - bounds remainder and quotient by max of this.
Definition: safe_uint.cpp:143
safe_uint_t subtract(const safe_uint_t &other, const size_t difference_bit_size, std::string const &description="") const
Subtraction when you have a pre-determined bound on the difference size.
Definition: safe_uint.cpp:35
safe_uint_t operator-(const safe_uint_t &other) const
Subtraction on two safe_uint_t objects.
Definition: safe_uint.cpp:67
safe_uint_t divide(const safe_uint_t &other, const size_t quotient_bit_size, const size_t remainder_bit_size, std::string const &description="", const std::function< std::pair< uint256_t, uint256_t >(uint256_t, uint256_t)> &get_quotient=[](uint256_t val, uint256_t divisor) { return std::make_pair((uint256_t)(val/(uint256_t) divisor),(uint256_t)(val %(uint256_t) divisor));}) const
division when you have a pre-determined bound on the sizes of the quotient and remainder
Definition: safe_uint.cpp:104
safe_uint_t normalize() const
Definition: safe_uint.cpp:168
bool_ct is_zero() const
Definition: safe_uint.cpp:184
Definition: witness.hpp:10
Definition: widget.bench.cpp:13