Verification of Poseidon constants
The zk-hash repository contains an in-circuit Poseidon implementation in poseidon/circuit/circuit.go, with associated constants in poseidon/circuit/constants.go. The implementation does not directly follow the definition of Poseidon from the main text of the Poseidon paper↗. Instead, it uses an optimized variant: in particular, the mixing between the non-S-box parts of the state in the partial rounds is moved outside of the partial rounds, and some other changes are made to the algorithm. This improves efficiency, as it reduces the number of operations that need to be performed in the circuit. Essentially, the formulas are rearranged to obtain terms that involve only constants, so that these terms can be precomputed. Such optimizations are also discussed in Appendix B of the paper.
The Poseidon paper's authors provide scripts↗ that can be used to calculate constants to be used in Poseidon implementations. These scripts provide constants that fit with the standard description used in the main text of the paper, not the different constants needed for the particular optimized algorithm used by the implementation in poseidon/circuit/circuit.go.
To verify that the implementation in zk-hash is equivalent to the paper authors' specification and constants, we wrote a script that computes, for each word of the state, the coefficients of its (affine) linear dependence on certain earlier state words. Doing this both for the standard algorithm and constants and for Brevis's algorithm and constants, we can check step by step that these coefficients are equal, which ultimately confirms that the two algorithms compute the same function.
To run our scripts, first clone the paper authors' repository at https://extgit.iaik.tugraz.at/krypto/hadeshash/-/tree/master↗. All further steps take place in the code directory of that repository.
The repository provides a generate_params_poseidon.sage script to generate parameters for Poseidon instances. However, Brevis uses a different number of partial rounds, obtained by rounding the suggested number up to the nearest multiple of $t$, the number of field elements on which the Poseidon permutation operates. Thus, the script can be used to obtain the suggested numbers of partial rounds and to check that the ones Brevis uses are correctly rounded up. However, the script needs to be patched to produce the constants needed for Brevis's instantiation of Poseidon. The following patch makes the necessary changes:
--- generate_params_poseidon.sage 2024-10-14 21:00:10.632136143 +0200
+++ generate_params_poseidon_fixed_r_p.sage 2024-10-17 23:21:55.858042766 +0200
@@ -2,8 +2,8 @@
import sys
from sage.rings.polynomial.polynomial_gf2x import GF2X_BuildIrred_list
-if len(sys.argv) < 8:
- print("Usage: <script> <field> <s_box> <field_size> <num_cells> <alpha> <security_level> <modulus_hex>")
+if len(sys.argv) < 9:
+ print("Usage: <script> <field> <s_box> <field_size> <num_cells> <alpha> <security_level> <modulus_hex> <R_P>")
print("field = 1 for GF(p)")
print("s_box = 0 for x^alpha, s_box = 1 for x^(-1)")
exit()
@@ -47,6 +47,8 @@
else:
print("Unknown field type, only 0 and 1 supported!")
exit()
+R_P_REQUESTED = int(sys.argv[8])
+
def sat_inequiv_alpha(p, t, R_F, R_P, alpha, M):
N = int(FIELD_SIZE * NUM_CELLS)
@@ -144,6 +146,9 @@
ROUND_NUMBERS = calc_final_numbers_fixed(PRIME_NUMBER, NUM_CELLS, ALPHA, SECURITY_LEVEL, True)
R_F_FIXED = ROUND_NUMBERS[0]
R_P_FIXED = ROUND_NUMBERS[1]
+assert R_P_FIXED <= R_P_REQUESTED
+R_P_FIXED = R_P_REQUESTED
+assert R_F_FIXED == 8
# R_F_FIXED = 8
# R_P_FIXED = 60
We saved the patched script as "generate_params_poseidon_fixed_r_p.sage".
Given the above changes, the following Python script, placed in the same directory, will use generate_params_poseidon_fixed_r_p.sage to generate the constants needed for the standard Poseidon algorithm. It will then read the file brevis.go to obtain the modified constants used by Brevis. This file, brevis.go, should be the file poseidon/circuit/constants.go from the zk-hash repo, copied into the directory. Finally, the script will carry out checks to ensure that the two Poseidon implementations are equivalent.
#!/usr/bin/env python3
import subprocess
from sage.all import matrix, vector, GF, block_matrix, identity_matrix, zero_matrix
# Field modulus handed to the parameter generator (a 254-bit prime; this is
# the BN254 / alt_bn128 scalar-field order). Every matrix and vector in the
# symbolic checks below is built over F = GF(prime).
prime = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001
F = GF(prime)
def generate_params():
    """Run the patched Hades parameter generator for every state width t.

    For t = 2..17, invokes ``generate_params_poseidon_fixed_r_p.sage`` with
    the field settings of ``prime`` and the number of partial rounds Brevis
    uses for that t. The generated parameter files are expected in the
    current directory, which is where get_paper_params() reads them from.

    Raises:
        subprocess.CalledProcessError: if any sage invocation exits with a
            non-zero status (previously such failures were ignored and the
            later checks would silently read stale or missing files).
    """
    # Partial-round counts R_P used by Brevis, indexed by t (t = 0, 1 unused).
    # Must stay in sync with the list in check().
    rounds = [
        None, None,
        56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68,
    ]
    for t in range(2, len(rounds)):
        # Positional arguments of the patched script:
        # field=1 (GF(p)), s_box=0 (x^alpha), field_size=254, num_cells=t,
        # alpha=5, security_level=128, modulus_hex, R_P.
        subprocess.run(
            ['sage', 'generate_params_poseidon_fixed_r_p.sage',
             '1', '0', '254', str(t), '5', '128',
             hex(prime),
             str(rounds[t])],
            check=True)
def get_paper_params():
    """Read the standard Poseidon parameters written by the sage generator.

    For each t in 2..17, parses ``poseidon_params_n254_t{t}_alpha5_M128.txt``
    from the current directory. Line 0 holds the ``R_F=`` / ``R_P=`` values,
    line 5 the list of round constants, and line 17 the MDS matrix. Both
    lists contain integer literals written as strings (e.g. ``'0x0ee9...'``).

    Returns:
        Tuple ``(round_constants, mds_matrices, r_ps)`` of dicts keyed by t:
        the flat list of int round constants, the MDS matrix as a list of
        lists of ints, and the number of partial rounds.

    Raises:
        AssertionError: if a file does not specify R_F == 8 (raised
            explicitly so the check also runs under ``python -O``).
    """
    import ast

    def parse_literal_list(line):
        # literal_eval accepts only Python literals, unlike eval(). strip()
        # is needed because the line may start with a space (print's
        # separator), which literal_eval rejects before Python 3.10.
        return ast.literal_eval(line.strip())

    def to_int(text):
        # Base 0 accepts the same integer literal syntax eval() did
        # ('0x...', decimal, underscores).
        return int(text, 0)

    round_constants = {}
    mds_matrices = {}
    r_ps = {}
    for t in range(2, 18):
        with open(f'poseidon_params_n254_t{t}_alpha5_M128.txt', 'r') as fh:
            lines = fh.read().split('\n')
        header = lines[0]
        r_f = int(header.split('R_F=')[1].split(', ')[0])
        r_p = int(header.split('R_P=')[1])
        if r_f != 8:
            raise AssertionError(f'expected R_F == 8 for t={t}, got {r_f}')
        constants = [to_int(entry) for entry in parse_literal_list(lines[5])]
        mds = [[to_int(entry) for entry in row]
               for row in parse_literal_list(lines[17])]
        r_ps[t] = r_p
        round_constants[t] = constants
        mds_matrices[t] = mds
    return round_constants, mds_matrices, r_ps
def extract_brevis_param_letter(data, letter):
    """Parse the per-t literals of Brevis' ``strPOSEIDON_<letter>`` Go function.

    Takes the source text between ``func strPOSEIDON_<letter>`` and the next
    ``func `` keyword. It then reads every ``if t == <n>`` branch and
    evaluates the first backtick-quoted string literal in that branch.

    Args:
        data: Contents of brevis.go.
        letter: Suffix of the Go function name ('C', 'S', 'M' or 'P').

    Returns:
        Dict mapping each t to the evaluated literal.
    """
    function_body = data.split('func strPOSEIDON_' + letter)[1]
    function_body = function_body.split('func ')[0]
    parsed = {}
    for branch in function_body.split('if t == ')[1:]:
        width = int(branch.split(' ')[0])
        literal = branch.split('`')[1]
        if letter == 'C' and width == 2:
            # The t == 2 literal of the C constants has one extra trailing
            # character that is dropped before evaluation (NOTE(review):
            # presumably a stray character in brevis.go -- confirm).
            literal = literal[:-1]
        # eval() on text copied from brevis.go; only run on trusted input.
        parsed[width] = eval(literal)
    return parsed
def get_brevis_params():
    """Load Brevis' optimized Poseidon constants from ``brevis.go``.

    Returns:
        Tuple ``(c, s, m, p)``. Each entry is the dict, keyed by t, parsed
        from the corresponding ``strPOSEIDON_C`` / ``_S`` / ``_M`` / ``_P``
        Go function.
    """
    with open('brevis.go', 'r') as go_file:
        go_source = go_file.read()
    return tuple(
        extract_brevis_param_letter(go_source, letter) for letter in 'CSMP'
    )
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* holds.

    A bare ``assert`` is removed when Python runs with ``-O``. The
    verification would then report success without checking anything.
    """
    if not condition:
        raise AssertionError(message)


def _constant_row_matrix(row, num_vars, t):
    """Return a num_vars x t matrix over F with *row* in row 0, zeros elsewhere.

    Row 0 holds the coefficients of the dummy "constant 1" variable, so
    adding this matrix to a symbolic state adds the constants in *row* to
    the corresponding state words.
    """
    return matrix(F, [list(row)] + [[0] * t for _ in range(num_vars - 1)])


def _initial_symbolic_state(t, r_P):
    """Return the symbolic state just before the last Mix of the initial full rounds.

    Rows index the variables: the constant 1, the initial state[1..t-1],
    the initial state[0], and the r_P post-sigma state[0] values of the
    partial rounds (1 + (t - 1) + (1 + r_P) rows in total). Column i holds
    the coefficients of state[i]; initially each word is its own variable.
    """
    state_0 = matrix(F, [0] + [0] * (t - 1) + [1] + [0] * r_P).transpose()
    state_positive = block_matrix(F, [
        [zero_matrix(F, 1, t - 1)],
        [identity_matrix(F, t - 1)],
        [zero_matrix(F, 1 + r_P, t - 1)],
    ])
    return block_matrix(F, [[state_0, state_positive]])


def _post_sigma_variable(t, r_P, round_index):
    """Return the column selecting the post-sigma state[0] variable of a partial round."""
    position = 1 + (t - 1) + 1 + round_index
    column = [0] * position + [1] + [0] * (r_P - 1 - round_index)
    return matrix(F, [column]).transpose()


def check_for_t(t, c, s, m, p, paper_c, paper_m, r_P):
    """Check that Brevis' optimized Poseidon matches the paper's for one t.

    Full rounds are compared constant by constant. The partial rounds are
    compared symbolically. Every value is tracked as an affine function of
    a fixed set of variables (see _initial_symbolic_state), and both
    implementations must produce the same coefficients for each pre-sigma
    state[0] and for the state after the partial rounds. The offset 4 used
    below is R_F / 2 with R_F == 8.

    Args:
        t: Number of field elements in the state.
        c: Brevis' round constants (flat list).
        s: Brevis' partial-round coefficients, 2t - 1 per partial round.
        m: Brevis' full-round Mix matrix (applied as ``row_vector * m``).
        p: Brevis' Mix matrix applied just before the partial rounds.
        paper_c: Round constants from the paper's generator (flat list).
        paper_m: MDS matrix from the paper's generator.
        r_P: Number of partial rounds.

    Raises:
        AssertionError: on the first mismatch (raised explicitly so the
            checks also run under ``python -O``).
    """
    m = matrix(F, m)
    p = matrix(F, p)
    paper_m = matrix(F, paper_m)

    # --- First four full rounds, except their final Mix. ---
    # Brevis multiplies row vectors by m, i.e. uses the paper's MDS
    # matrix transposed.
    _require(m.transpose() == paper_m, f'Mix matrix mismatch for {t=}')
    _require(c[:t] == paper_c[:t], f'round 0 constants differ for {t=}')
    m_inv = m.inverse()
    for i in range(1, 4):
        # Adding k*m^-1 before the Mix equals adding k after it:
        # (x + k*m^-1)*m = x*m + k.
        _require(
            vector(F, c[t * i:t * (i + 1)])
            == vector(F, paper_c[t * i:t * (i + 1)]) * m_inv,
            f'round {i} constants differ for {t=}')

    # --- Partial rounds, paper version. ---
    # Runs from just before the last Mix of the initial full rounds to just
    # after the round-key addition of the first full round that follows
    # the partial rounds.
    num_vars = 1 + (t - 1) + (1 + r_P)
    state = _initial_symbolic_state(t, r_P) * m
    paper_pre_sigma_state_0 = []
    for i in range(r_P):
        keys = paper_c[(4 + i) * t:(4 + i + 1) * t]
        state = state + _constant_row_matrix(keys, num_vars, t)
        paper_pre_sigma_state_0.append(state.column(0))
        # The pre-sigma state[0] is compared separately; the post-sigma
        # value is replaced by its own fresh variable.
        rest_state = state.submatrix(0, 1, num_vars, t - 1)
        state = block_matrix(
            F, [[_post_sigma_variable(t, r_P, i), rest_state]])
        state = state * m
    keys = paper_c[(4 + r_P) * t:(4 + r_P + 1) * t]
    paper_final_state = state + _constant_row_matrix(keys, num_vars, t)

    # --- Partial rounds, Brevis version (mirrors the quoted Go code). ---
    state = _initial_symbolic_state(t, r_P)
    # state = Ark(api, state, c, nRoundsF/2*t)
    state = state + _constant_row_matrix(c[4 * t:5 * t], num_vars, t)
    # state = Mix(api, state, p)
    state = state * p
    brevis_pre_sigma_state_0 = []
    for r in range(r_P):
        brevis_pre_sigma_state_0.append(state.column(0))
        rest_state = state.submatrix(0, 1, num_vars, t - 1)
        state = block_matrix(
            F, [[_post_sigma_variable(t, r_P, r), rest_state]])
        # state[0] = api.Add(state[0], c[(nRoundsF/2+1)*t+r])
        state = state + _constant_row_matrix(
            [c[5 * t + r]] + [0] * (t - 1), num_vars, t)
        base = (2 * t - 1) * r
        # newState0 = sum_j s[(t*2-1)*r+j] * state[j]
        new_state_0 = sum(
            (F(s[base + j]) * state.column(j) for j in range(t)),
            vector(F, [0] * num_vars))
        # state[k] = api.Add(state[k], api.Mul(state[0], s[(t*2-1)*r+t+k-1]))
        # This uses the old state[0], before it is replaced by newState0.
        new_states = [new_state_0] + [
            state.column(k) + F(s[base + t + k - 1]) * state.column(0)
            for k in range(1, t)]
        state = matrix(F, new_states).transpose()
    brevis_final_state = state

    _require(len(paper_pre_sigma_state_0) == len(brevis_pre_sigma_state_0),
             f'pre-sigma state counts differ for {t=}')
    _require(len(paper_pre_sigma_state_0) == r_P,
             f'unexpected number of pre-sigma states for {t=}')
    for i, (paper_pre, brevis_pre) in enumerate(
            zip(paper_pre_sigma_state_0, brevis_pre_sigma_state_0)):
        _require(paper_pre == brevis_pre,
                 f'pre-sigma state does not match for {i=}')
    _require(paper_final_state == brevis_final_state,
             f'state after the partial rounds differs for {t=}')

    # --- Remaining three full rounds: constants folded through m^-1. ---
    for i in range(1, 4):
        r = i - 1
        # state = Ark(api, state, c, (nRoundsF/2+1)*t+nRoundsP+r*t)
        brevis_keys = vector(F, c[t * (5 + r) + r_P:t * (5 + r + 1) + r_P])
        paper_keys = (vector(F, paper_c[t * (4 + r_P + i):t * (4 + r_P + i + 1)])
                      * m_inv)
        _require(brevis_keys == paper_keys,
                 f'constants of round {4 + r_P + i} differ for {t=}')
    print(f'All correct for {t=}')
def check():
    """Verify Brevis' Poseidon constants against the paper's for t = 2..17.

    Loads Brevis' constants from brevis.go and the paper's parameters from
    the generated files. It checks the partial-round counts and then runs
    check_for_t() for each t.

    Raises:
        AssertionError: if any check fails (raised explicitly rather than
            via ``assert`` so the checks also run under ``python -O``).
    """
    c, s, m, p = get_brevis_params()
    paper_c, paper_m, paper_r_p = get_paper_params()
    # Partial-round counts R_P used by Brevis, indexed by t (t = 0, 1 unused).
    # Must stay in sync with the list passed to the generator in
    # generate_params().
    rounds = [None, None,
              56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68]
    for t in range(2, len(rounds)):
        r_p = rounds[t]
        # Brevis rounds the recommended R_P up to a multiple of t.
        # - The lower bound (r_p >= recommended) is asserted inside the
        #   patched sage script.
        # - That script then uses r_p in place of the recommended value, so
        #   the upper bound (r_p < recommended + t) cannot be checked from
        #   the generated files.
        # - The multiple-of-t property can be checked here.
        if r_p % t != 0:
            raise AssertionError(f'R_P={r_p} is not a multiple of {t=}')
        if r_p != paper_r_p[t]:
            raise AssertionError(
                f'generated R_P={paper_r_p[t]} differs from requested '
                f'R_P={r_p} for {t=}')
        check_for_t(
            t, c[t], s[t], m[t], p[t],
            paper_c[t], paper_m[t], r_p)
if __name__ == '__main__':
    # Generate the paper's parameters first, then compare them with brevis.go.
    # Guarded so that importing this module does not launch sage.
    generate_params()
    check()