Various possible simplifications and optimizations
We collect here some simplifications and optimizations we encountered that are possible in the codebase.
More efficient implementation for function Cmp
In the sdk's sdk/utils.go, the function Cmp is implemented as follows:
// Inspired by
// https://github.com/Consensys/gnark/blob/
// 429616e33c97ed21113dd87787c043e8fb43720c/frontend/cs/scs/api.go#L523
// To reduce constraints comsumption, use predefined number of variable's bits.
func Cmp(api frontend.API, i1, i2 frontend.Variable, nbBits int) frontend.Variable {
bi1 := bits.ToBinary(api, i1, bits.WithNbDigits(nbBits))
bi2 := bits.ToBinary(api, i2, bits.WithNbDigits(nbBits))
var res frontend.Variable
res = 0
for i := nbBits - 1; i >= 0; i-- {
iszeroi1 := api.IsZero(bi1[i])
iszeroi2 := api.IsZero(bi2[i])
i1i2 := api.And(bi1[i], iszeroi2)
i2i1 := api.And(bi2[i], iszeroi1)
n := api.Select(i2i1, -1, 0)
m := api.Select(i1i2, 1, n)
res = api.Select(api.IsZero(res), m, res)
}
return res
}
This function checks that i1 and i2 are smaller than 2^nbBits, and it constrains the return value to -1 if i1 < i2, to 0 if i1 == i2, and to 1 if i1 > i2.
The standard gnark function with the same purpose, referred to in the comment, works similarly, but it decomposes both i1 and i2 into as many bits as the bit width of the field over which the circuit variables are defined. In some cases (such as the 32- and 64-bit types of the sdk, for which this function is used), it may be known that i1 and i2 should be representable with fewer bits, so it is more efficient in witnesses and constraints to use fewer bits, as the above function does.
However, in those cases, there is a much more efficient way to implement this function. First, let r be the prime modulus of the field over which the circuit is defined, and n the bit width of r, so that 2^n > r >= 2^(n-1). Assume that a and b are two values that satisfy 0 <= a, b < 2^k for some k (in other words, a and b have bit width at most k). If a < b, we have 0 < b - a < 2^k and -2^k < a - b < 0. Assuming k <= n-2, we get the following results: 0 < (b - a) % r < 2^k, r - 2^k < (a - b) % r < r, and 2^k <= r - 2^k (the latter is equivalent to r >= 2^(k+1), and this holds as k+1 <= n-1). Hence, if a < b, then (a - b) % r is at least 2^k and so is not k bits wide. The upshot is that if we have constrained both a and b to be at most k bits wide, and a - b is also at most k bits wide, then we must have a >= b.
Checking that a circuit variable is at most k bits wide is more efficient in gnark than actually constructing the bit decomposition, and in the Cmp use case we do not actually need the bit decomposition. Thus, the following kind of implementation is more efficient:
func CmpNew(api frontend.API, i1, i2 frontend.Variable, nbBits int) frontend.Variable {
if nbBits > api.Compiler().Field().BitLen()-2 {
panic("CmpNew called with nbBits too large!")
}
rangeChecker := rangecheck.New(api)
rangeChecker.Check(i1, nbBits)
rangeChecker.Check(i2, nbBits)
results, err := api.Compiler().NewHint(CmpHint, 1, i1, i2)
if err != nil {
panic(err)
}
result := results[0]
// Enforce that result is -1, 0, or 1
result_sq := api.Mul(result, result)
api.AssertIsBoolean(result_sq)
// 1 if i1 < i2 according to hint, 0 else
first_smaller := api.IsZero(api.Add(result, 1))
// Select which is the bigger/smaller of i1 and i2, according to the hint
bigger := api.Select(first_smaller, i2, i1)
smaller := api.Select(first_smaller, i1, i2)
// Get the difference, which should be still nonnegative if the hint is correct.
diff := api.Sub(bigger, smaller)
rangeChecker.Check(diff, nbBits)
// If the hint said that i1 < i2 but in fact i1 > i2, then the above will fail, and vice versa.
// So what is left is to distinguish between i1 == i2 and i1 != i2.
equal := api.IsZero(diff)
result_zero := api.IsZero(result)
// equal and result_zero must be either both true or both false
api.AssertIsEqual(equal, result_zero)
return result
}
func CmpHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
results[0].SetInt64(int64(inputs[0].Cmp(inputs[1])))
return nil
}
The new function performs a range check on both inputs to ensure they are at most nbBits wide, and then uses the comparison result from a hint to ensure that the appropriate difference, i1 - i2 or i2 - i1, is at most nbBits wide as well, with some additional logic for distinguishing i1 == i2 from i1 != i2. The above is what we wrote down quickly to demonstrate this, so some improvements may be possible.
We compared the two implementations with this small circuit:
const bitnum = 64
// Here we define which circuit variables our circuit will have.
type ExampleCircuit struct {
A frontend.Variable // Private witness
B frontend.Variable // Private witness
}
// This function defines the constraints for our circuit
func (circuit *ExampleCircuit) Define(api frontend.API) error {
//CmpOld(api, circuit.A, circuit.B, bitnum)
CmpNew(api, circuit.A, circuit.B, bitnum)
//api.AssertIsDifferent(cmp_result, 2)
//api.AssertIsEqual(cmp_result_old, cmp_result_new)
return nil
}
For eight-bit-wide values, the old implementation takes 139 constraints, the new one 90. For 248-bit-wide values, the difference is much larger: the old implementation requires 4,459 constraints and the new one only 904, a reduction of about 80%.
The current, less efficient implementation of Cmp is used for the SDK's Uint32 (571 vs. 198 constraints, 65% saved) and Uint64 (1,147 vs. 318 constraints, 72% saved) types. For the Uint248 type, the standard gnark implementation of Cmp is used. Here the standard gnark Cmp takes 5,381 constraints, whereas our CmpNew takes 904, saving 83%.
More efficient implementation for function ABS
In the sdk repo's sdk/api_int248.go, the ABS function for the Int248 type provided by the SDK is implemented as follows:
// ABS returns the absolute value of a
func (api *Int248API) ABS(a Int248) Uint248 {
bs := api.ToBinary(a)
signBit := bs[247] // ToBinary returns little-endian bits, the last bit is sign
flipped := make([]frontend.Variable, len(bs))
for i, v := range bs {
flipped[i] = api.g.IsZero(v.Val)
}
absWhenOrigIsNeg := api.g.Add(1, api.g.FromBinary(flipped...))
abs := api.g.Select(signBit.Val, absWhenOrigIsNeg, a.Val)
return newU248(abs)
}
This implementation is correct but very inefficient. Converting a 248-bit circuit variable into its individual bits takes a significant number of additional circuit variables as well as constraints. From our testing, this function costs around 1,200 constraints. However, it is possible to implement it with far fewer constraints by not using bitwise logic. The following implementation only costs four constraints if the sign bit is cached, an improvement by a factor of over 300.
// ABS returns the absolute value of a
func (api *Int248API) ABSNew(a Int248) Uint248 {
a = api.ensureSignBit(a)
resultIfNonNeg := a.Val
resultIfNeg := api.g.Sub(new(big.Int).Lsh(big.NewInt(1), 248), a.Val)
result := api.g.Select(a.SignBit, resultIfNeg, resultIfNonNeg)
return newU248(result)
}
It works by computing the correct result in the negative case directly from the value, by subtracting it from 2^248. Then, select is used similarly to the current implementation.
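The arithmetic behind this can be checked outside of a circuit. The sketch below models a 248-bit two's-complement value with math/big; the helpers encode248 and absTwosComplement are hypothetical names introduced for illustration and are not part of the SDK.

```go
package main

import (
	"fmt"
	"math/big"
)

// encode248 maps a signed integer to its unsigned 248-bit two's-complement encoding.
func encode248(x int64) *big.Int {
	v := big.NewInt(x)
	if v.Sign() < 0 {
		v.Add(v, new(big.Int).Lsh(big.NewInt(1), 248))
	}
	return v
}

// absTwosComplement mirrors the select in ABSNew on plain integers:
// if the sign bit (bit 247) is set, the result is 2^248 - val, otherwise val.
func absTwosComplement(val *big.Int) *big.Int {
	if val.Bit(247) == 0 {
		return new(big.Int).Set(val)
	}
	twoTo248 := new(big.Int).Lsh(big.NewInt(1), 248)
	return twoTo248.Sub(twoTo248, val)
}

func main() {
	fmt.Println(absTwosComplement(encode248(-5))) // 5
	fmt.Println(absTwosComplement(encode248(7)))  // 7
}
```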
Unnecessary branch in Pad101Bytes
The function Pad101Bytes in the file zk-hash/keccak/periphery.go has an unnecessary branch. It is intended to implement the pad10*1 padding function for keccak, and it is implemented as follows:
func Pad101Bytes(data []byte) []byte {
miss := 136 - len(data)%136
if len(data)%136 == 0 {
miss = 136
}
data = append(data, 1)
for i := 0; i < miss-1; i++ {
data = append(data, 0)
}
data[len(data)-1] ^= 0x80
return data
}
The if branch does not actually do anything, however: when len(data)%136 == 0, miss is already 136. This branch can thus be removed.
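A sketch of the function with the branch removed is shown below; the name pad101BytesSimplified is used here only to distinguish it from the original.

```go
package main

import "fmt"

// pad101BytesSimplified is Pad101Bytes without the redundant branch.
func pad101BytesSimplified(data []byte) []byte {
	// len(data)%136 is in 0..135, so miss is always in 1..136.
	miss := 136 - len(data)%136
	data = append(data, 1)
	for i := 0; i < miss-1; i++ {
		data = append(data, 0)
	}
	data[len(data)-1] ^= 0x80
	return data
}

func main() {
	for _, n := range []int{0, 1, 135, 136, 137} {
		padded := pad101BytesSimplified(make([]byte, n))
		fmt.Println(n, len(padded))
	}
}
```

For the input lengths 0, 1, and 135 the padded length is 136, and for 136 and 137 it is 272, so the output is always a nonzero multiple of the 136-byte rate.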
Code duplication between IsGreaterThan and IsLessThan
Several types provided by the sdk offer IsGreaterThan and IsLessThan functions. These functions tend to be implemented independently but in a very parallel manner. Code duplication can be reduced by replacing the implementation of one of the two functions by a call to the other. For example, in the case of Int248, the implementation of IsGreaterThan is as follows, with IsLessThan implemented very similarly.
func (api *Int248API) IsGreaterThan(a, b Int248) Uint248 {
a = api.ensureSignBit(a)
b = api.ensureSignBit(b)
cmp := api.g.Cmp(a.Val, b.Val)
isGtAsUint := api.g.IsZero(api.g.Sub(cmp, 1))
isLt := api.g.Lookup2(
a.SignBit, b.SignBit,
isGtAsUint, // a, b both pos
0, // a neg, b pos
1, // a pos, b neg
isGtAsUint, // a, b both neg
)
return newU248(isLt)
}
This function could be replaced by the following implementation:
func (api *Int248API) IsGreaterThan(a, b Int248) Uint248 {
return api.IsLessThan(b, a)
}
Code duplication in ToBytes32
In the sdk repository's sdk/circuit_api.go, the function ToBytes32 constructs the limbs of the resulting Bytes32 directly:
case Uint521:
api.Uint521.AssertIsLessOrEqual(v, MaxBytes32)
bits := api.Uint521.ToBinary(v, 32*8)
lo := api.Uint248.FromBinary(bits[:numBitsPerVar]...)
hi := api.Uint248.FromBinary(bits[numBitsPerVar:256]...)
return Bytes32{Val: [2]variable{lo.Val, hi.Val}}
Instead, the FromBinary function for Bytes32 could be used.
Unnecessary constraint in Filter
The function Filter in the sdk repository's sdk/datastream.go contains the following lines:
valid := ds.api.isEqual(ds.toggles[i], 1)
newToggles[i] = api.Select(api.And(toggle, valid), 1, 0)
Here, ds.toggles are intended to be Boolean. Generally, it appears to be the assumption that the caller will have constrained ds.toggles[i] to be Boolean already. In that case, the isEqual constraints are not necessary, and it would be equivalent to do valid := ds.toggles[i].
However, even if we assume that ds.toggles[i] has not yet been constrained to be Boolean, the isEqual constraints are not necessary. This is because api.And will constrain both arguments to be Boolean as well. So the above two lines may be replaced by
newToggles[i] = api.Select(api.And(toggle, ds.toggles[i]), 1, 0)
Going further, api.Select will return 1 in this case if the selector is true (encoded by 1), otherwise 0. But as api.And returns a Boolean, that is, a value that is either 1 or 0, api.Select will just return the selector. Hence, the above can be further reduced to
newToggles[i] = api.And(toggle, ds.toggles[i])
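The equivalence of the original and the reduced expression on Boolean inputs can be confirmed with a small truth table. The following sketch models the gadgets on native ints; boolAnd, sel, isEqual, filterOriginal, and filterReduced are stand-ins written for illustration, not the gnark or sdk functions themselves.

```go
package main

import "fmt"

// boolAnd models api.And for inputs already known to be 0 or 1.
func boolAnd(a, b int) int { return a * b }

// sel models api.Select: ifTrue when cond is 1, ifFalse when cond is 0.
func sel(cond, ifTrue, ifFalse int) int {
	if cond == 1 {
		return ifTrue
	}
	return ifFalse
}

// isEqual models the equality gadget: 1 if a == b, else 0.
func isEqual(a, b int) int {
	if a == b {
		return 1
	}
	return 0
}

// filterOriginal follows the two lines currently in Filter.
func filterOriginal(toggle, prev int) int {
	valid := isEqual(prev, 1)
	return sel(boolAnd(toggle, valid), 1, 0)
}

// filterReduced is the single-And form.
func filterReduced(toggle, prev int) int {
	return boolAnd(toggle, prev)
}

func main() {
	for _, toggle := range []int{0, 1} {
		for _, prev := range []int{0, 1} {
			fmt.Println(toggle, prev, filterOriginal(toggle, prev), filterReduced(toggle, prev))
		}
	}
}
```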
Note that api.And constraining both arguments to be Boolean is a not fully documented feature of the current gnark implementation. The API documentation says
// Or returns a & b
// a and b must be 0 or 1
And(a, b Variable) Variable
which could be read to suggest the caller must ensure a and b are 0 or 1.
However, the two constraint systems implement And as follows:
// Or returns a & b
// a and b must be 0 or 1
func (builder *builder) And(a, b frontend.Variable) frontend.Variable {
builder.AssertIsBoolean(a)
builder.AssertIsBoolean(b)
res := builder.Mul(a, b)
builder.MarkBoolean(res)
return res
}
// And compute the AND between two frontend.Variables
func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable {
vars, _ := builder.toVariables(_a, _b)
a := vars[0]
b := vars[1]
builder.AssertIsBoolean(a)
builder.AssertIsBoolean(b)
res := builder.Mul(a, b)
builder.MarkBoolean(res)
return res
}
In both cases, the two arguments are constrained to be Boolean. The above snippets are from github.com/celer-network/[email protected]/frontend/cs/scs/api.go and github.com/celer-network/[email protected]/frontend/cs/r1cs/api.go.
Dead code in DefaultHostCircuit
In the sdk repository's sdk/host_circuit.go, the function DefaultHostCircuit is implemented as follows:
func DefaultHostCircuit(app AppCircuit) *HostCircuit {
maxReceipts, maxStorage, maxTxs := app.Allocate()
var inputCommits = make([]frontend.Variable, NumMaxDataPoints)
for i := 0; i < NumMaxDataPoints; i++ {
inputCommits[i] = 0
}
h := &HostCircuit{
Input: defaultCircuitInput(maxReceipts, maxStorage, maxTxs),
Guest: app,
}
return h
}
Note that the part
var inputCommits = make([]frontend.Variable, NumMaxDataPoints)
for i := 0; i < NumMaxDataPoints; i++ {
inputCommits[i] = 0
}
is actually never used and can be deleted.
Constraints added in assertInputUniqueness even when unneeded
In the sdk repository, the assertInputUniqueness function in sdk/host_circuit.go has a shouldCheck argument that is not a circuit variable but a native int, indicating whether the uniqueness check should be performed or not. However, constraints are added in either case:
func assertInputUniqueness(api frontend.API, in []frontend.Variable, shouldCheck int) {
multicommit.WithCommitment(api, func(api frontend.API, gamma frontend.Variable) error {
// [Some constraints added]
for i := 0; i < len(sorted)-1; i++ {
// [More constraints added]
isValid = api.Select(shouldCheck, isValid, 1)
api.AssertIsEqual(isValid, 1)
}
return nil
}, in...)
}
A prover can always satisfy the constraints that were omitted in the above snippet. What a prover cannot always arrange is that isValid is 1 before the reassignment shown above. However, if shouldCheck is 0, then isValid will be 1 after the reassignment, no matter what it was before. So in that case, the prover can always satisfy the constraints introduced by this function. When shouldCheck is 0, the constraints and associated witnesses introduced by assertInputUniqueness could thus be removed with no change to the validity of the statement being proven by the overall circuit. The function could be rearranged as follows:
func assertInputUniqueness(api frontend.API, in []frontend.Variable, shouldCheck int) {
if shouldCheck == 1 {
multicommit.WithCommitment(api, func(api frontend.API, gamma frontend.Variable) error {
// [Some constraints added]
for i := 0; i < len(sorted)-1; i++ {
// [More constraints added]
api.AssertIsEqual(isValid, 1)
}
return nil
}, in...)
}
}
This would save a significant number of constraints and witnesses (all that are introduced by this function) in the case that shouldCheck is not 1. Additionally, in the case that shouldCheck is 1, this will save the constraint and witness introduced by api.Select(shouldCheck, isValid, 1).
Smaller multiplex in Keccak256
The Keccak256 function in keccak/keccak256.go of the zk-hash repository uses a multiplex at the end of the function to select the final result:
selected := mux.Multiplex(api, roundIndex, 25, MAX_ROUNDS, transpose(outputStates))
for i := 0; i < 4; i++ {
out[i] = selected[i]
}
Here, the multiplex is done first over the full 25-word state, and then only the first four words are copied to the output. It should be more efficient to just do the multiplex over the first four words. This would be similar to how it is done in Keccak256Bits in keccak/keccak256_bits.go, where the end of the function looks like this:
selected := mux.Multiplex(api, roundIndex, 256, maxRounds, transpose2(states[1:]))
copy(out[:], selected[:256])
The implementation of multiplex allows calling it with a lower number of output components (in this case, 256) than the input components (in this case, 1,600). The same pattern could be used in Keccak256 in keccak256.go to save some constraints.
Simplified Flip
The function Flip in the zk-hash repository's utils/slice.go is used to reverse a slice. It is implemented as follows:
func Flip[T any](in []T) []T {
res := make([]T, len(in))
copy(res, in)
for i := 0; i < len(in)/2; i++ {
tmp := res[i]
res[i] = res[len(res)-1-i]
res[len(res)-1-i] = tmp
}
return res
}
Instead of reading the values to swap from res, they could be read from in, which removes the need to use the variable tmp and to copy in to res initially. However, to ensure that the middle component is populated correctly, the condition in the for loop should then be changed to i < (len(in) + 1) / 2:
func Flip[T any](in []T) []T {
res := make([]T, len(in))
for i := 0; i < (len(in) + 1) / 2; i++ {
res[i] = in[len(res)-1-i]
res[len(res)-1-i] = in[i]
}
return res
}
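The need for the changed loop bound shows up on odd-length inputs. The sketch below parameterizes the loop bound to make the difference visible; flipUpTo is a hypothetical helper written for this comparison, not a function of the zk-hash repository.

```go
package main

import "fmt"

// flipUpTo fills res from in using the copy-free loop body,
// iterating i over [0, bound).
func flipUpTo[T any](in []T, bound int) []T {
	res := make([]T, len(in))
	for i := 0; i < bound; i++ {
		res[i] = in[len(in)-1-i]
		res[len(in)-1-i] = in[i]
	}
	return res
}

func main() {
	in := []int{1, 2, 3, 4, 5}
	fmt.Println(flipUpTo(in, len(in)/2))     // [5 4 0 2 1]: middle element left at its zero value
	fmt.Println(flipUpTo(in, (len(in)+1)/2)) // [5 4 3 2 1]
}
```

For even lengths the two bounds coincide, so only odd-length slices are affected.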
Redundant Boolean check in decode
The zk-hash helper function decode, defined in mux/multiplexer.go, is implemented as follows:
// decodes the input selector num into a bit mask
// e.g. width 8 select 3 -> 00010000
func decode(api frontend.API, width int, input frontend.Variable) (output []frontend.Variable, outputSuccess frontend.Variable) {
outputSuccess = 0
for i := 0; i < width; i++ {
value := isEqual(api, i, input)
output = append(output, value)
outputSuccess = api.Add(outputSuccess, value)
}
api.AssertIsBoolean(outputSuccess)
return
}
In the loop, value will be 0 unless i is equal to input modulo the field modulus r over which the circuit is defined. As i ranges from 0 to width - 1, and this function cannot feasibly be instantiated with width larger than the field modulus (a 254-bit prime), we can conclude that there will be at most a single value i = input % r for which value will be nonzero, where it will be 1. It follows that outputSuccess must be either 0 or 1, and thus Boolean. The constraint introduced by api.AssertIsBoolean(outputSuccess) is thus redundant. The private function decode is also only used by Multiplex in the same file, where outputSuccess is constrained to be equal to 1 anyway.