1
2
3
4
5 package mldsa
6
7 import (
8 "math/big"
9 "testing"
10 )
11
// interestingValue pairs one field element in two forms, so a test can pass
// the Montgomery form to the code under test and use the canonical form to
// compute the expected result.
type interestingValue struct {
	v uint32       // canonical value, as produced by fieldFromMontgomery(m)
	m fieldElement // the same element as passed to the Montgomery-domain functions
}
16
17
18
19
20 func interestingValues() []interestingValue {
21 if testing.Short() {
22 return []interestingValue{{v: q - 1, m: minusOne}}
23 }
24 var values []interestingValue
25 for _, v := range []uint32{
26 0,
27 1,
28 2,
29 3,
30 q - 3,
31 q - 2,
32 q - 1,
33 q / 2,
34 (q + 1) / 2,
35 } {
36 m, _ := fieldToMontgomery(v)
37 values = append(values, interestingValue{v: v, m: m})
38
39 values = append(values, interestingValue{
40 v: fieldFromMontgomery(fieldElement(v)), m: fieldElement(v)})
41 }
42 return values
43 }
44
45 func TestToFromMontgomery(t *testing.T) {
46 for a := range uint32(q) {
47 m, err := fieldToMontgomery(a)
48 if err != nil {
49 t.Fatalf("fieldToMontgomery(%d) returned error: %v", a, err)
50 }
51 exp := fieldElement((uint64(a) * R) % q)
52 if m != exp {
53 t.Fatalf("fieldToMontgomery(%d) = %d, expected %d", a, m, exp)
54 }
55 got := fieldFromMontgomery(m)
56 if got != a {
57 t.Fatalf("fieldFromMontgomery(fieldToMontgomery(%d)) = %d, expected %d", a, got, a)
58 }
59 }
60 }
61
62 func TestFieldAdd(t *testing.T) {
63 t.Parallel()
64 for _, a := range interestingValues() {
65 for b := range fieldElement(q) {
66 got := fieldAdd(a.m, b)
67 exp := (a.m + b) % q
68 if got != exp {
69 t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
70 }
71 }
72 }
73 }
74
75 func TestFieldSub(t *testing.T) {
76 t.Parallel()
77 for _, a := range interestingValues() {
78 for b := range fieldElement(q) {
79 got := fieldSub(a.m, b)
80 exp := (a.m + q - b) % q
81 if got != exp {
82 t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
83 }
84 }
85 }
86 }
87
88 func TestFieldSubToMontgomery(t *testing.T) {
89 t.Parallel()
90 for _, a := range interestingValues() {
91 for b := range uint32(q) {
92 got := fieldSubToMontgomery(a.v, b)
93 diff := (a.v + q - b) % q
94 exp := fieldElement((uint64(diff) * R) % q)
95 if got != exp {
96 t.Fatalf("fieldSubToMontgomery(%d, %d) = %d, expected %d", a.v, b, got, exp)
97 }
98 }
99 }
100 }
101
102 func TestFieldReduceOnce(t *testing.T) {
103 t.Parallel()
104 for a := range uint32(2 * q) {
105 got := fieldReduceOnce(a)
106 var exp uint32
107 if a < q {
108 exp = a
109 } else {
110 exp = a - q
111 }
112 if uint32(got) != exp {
113 t.Fatalf("fieldReduceOnce(%d) = %d, expected %d", a, got, exp)
114 }
115 }
116 }
117
118 func TestFieldMul(t *testing.T) {
119 t.Parallel()
120 for _, a := range interestingValues() {
121 for b := range fieldElement(q) {
122 got := fieldFromMontgomery(fieldMontgomeryMul(a.m, b))
123 exp := uint32((uint64(a.v) * uint64(fieldFromMontgomery(b))) % q)
124 if got != exp {
125 t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
126 }
127 }
128 }
129 }
130
131 func TestFieldToMontgomeryOverflow(t *testing.T) {
132
133 inputs := []uint32{
134 q,
135 q + 1,
136 q + 2,
137 1<<23 - 1,
138 1 << 23,
139 q + 1<<23,
140 q + 1<<31,
141 ^uint32(0),
142 }
143 for _, in := range inputs {
144 if _, err := fieldToMontgomery(in); err == nil {
145 t.Fatalf("fieldToMontgomery(%d) did not return an error", in)
146 }
147 }
148 }
149
150 func TestFieldMulSub(t *testing.T) {
151 for _, a := range interestingValues() {
152 for _, b := range interestingValues() {
153 for _, c := range interestingValues() {
154 got := fieldFromMontgomery(fieldMontgomeryMulSub(a.m, b.m, c.m))
155 exp := uint32((uint64(a.v) * (uint64(b.v) + q - uint64(c.v))) % q)
156 if got != exp {
157 t.Fatalf("%d * (%d - %d) = %d, expected %d", a.v, b.v, c.v, got, exp)
158 }
159 }
160 }
161 }
162 }
163
164 func TestFieldAddMul(t *testing.T) {
165 for _, a := range interestingValues() {
166 for _, b := range interestingValues() {
167 for _, c := range interestingValues() {
168 for _, d := range interestingValues() {
169 got := fieldFromMontgomery(fieldMontgomeryAddMul(a.m, b.m, c.m, d.m))
170 exp := uint32((uint64(a.v)*uint64(b.v) + uint64(c.v)*uint64(d.v)) % q)
171 if got != exp {
172 t.Fatalf("%d + %d * %d = %d, expected %d", a.v, b.v, c.v, got, exp)
173 }
174 }
175 }
176 }
177 }
178 }
179
// BitRev8 returns n with the order of its eight bits reversed, so bit 0
// becomes bit 7, bit 1 becomes bit 6, and so on.
func BitRev8(n uint8) uint8 {
	var rev uint8
	for i := 0; i < 8; i++ {
		// Move the lowest remaining bit of n onto the low end of rev.
		rev = rev<<1 | n&1
		n >>= 1
	}
	return rev
}
192
// CenteredMod returns the representative of x modulo m that is closest to
// zero: remainders up to m/2 (integer division) are returned as is, larger
// ones are shifted down by m and come back negative.
func CenteredMod(x, m uint32) int32 {
	rem := x % m
	if rem <= m/2 {
		return int32(rem)
	}
	return int32(rem) - int32(m)
}
200
201 func reduceModQ(x int32) uint32 {
202 x %= q
203 if x < 0 {
204 return uint32(x + q)
205 }
206 return uint32(x)
207 }
208
209 func TestCenteredMod(t *testing.T) {
210 for x := range uint32(q * 2) {
211 got := CenteredMod(uint32(x), q)
212 if reduceModQ(got) != (x % q) {
213 t.Fatalf("CenteredMod(%d) = %d, which is not congruent to %d mod %d", x, got, x, q)
214 }
215 }
216
217 for x := range uint32(q) {
218 r, _ := fieldToMontgomery(x)
219 got := fieldCenteredMod(r)
220 exp := CenteredMod(x, q)
221 if got != exp {
222 t.Fatalf("fieldCenteredMod(%d) = %d, expected %d", x, got, exp)
223 }
224 }
225 }
226
227 func TestInfinityNorm(t *testing.T) {
228 for x := range uint32(q) {
229 r, _ := fieldToMontgomery(x)
230 got := fieldInfinityNorm(r)
231 exp := CenteredMod(x, q)
232 if exp < 0 {
233 exp = -exp
234 }
235 if got != uint32(exp) {
236 t.Fatalf("fieldInfinityNorm(%d) = %d, expected %d", x, got, exp)
237 }
238 }
239 }
240
241 func TestConstants(t *testing.T) {
242 if fieldFromMontgomery(one) != 1 {
243 t.Errorf("one constant incorrect")
244 }
245 if fieldFromMontgomery(minusOne) != q-1 {
246 t.Errorf("minusOne constant incorrect")
247 }
248 if fieldInfinityNorm(one) != 1 {
249 t.Errorf("one infinity norm incorrect")
250 }
251 if fieldInfinityNorm(minusOne) != 1 {
252 t.Errorf("minusOne infinity norm incorrect")
253 }
254
255 if PublicKeySize44 != pubKeySize(params44) {
256 t.Errorf("PublicKeySize44 constant incorrect")
257 }
258 if PublicKeySize65 != pubKeySize(params65) {
259 t.Errorf("PublicKeySize65 constant incorrect")
260 }
261 if PublicKeySize87 != pubKeySize(params87) {
262 t.Errorf("PublicKeySize87 constant incorrect")
263 }
264 if SignatureSize44 != sigSize(params44) {
265 t.Errorf("SignatureSize44 constant incorrect")
266 }
267 if SignatureSize65 != sigSize(params65) {
268 t.Errorf("SignatureSize65 constant incorrect")
269 }
270 if SignatureSize87 != sigSize(params87) {
271 t.Errorf("SignatureSize87 constant incorrect")
272 }
273 }
274
275 func TestPower2Round(t *testing.T) {
276 t.Parallel()
277 for x := range uint32(q) {
278 rr, _ := fieldToMontgomery(x)
279 t1, t0 := power2Round(rr)
280
281 hi, err := fieldToMontgomery(uint32(t1) << 13)
282 if err != nil {
283 t.Fatalf("power2Round(%d): failed to convert high part to Montgomery: %v", x, err)
284 }
285 if r := fieldFromMontgomery(fieldAdd(hi, t0)); r != x {
286 t.Fatalf("power2Round(%d) = (%d, %d), which reconstructs to %d, expected %d", x, t1, t0, r, x)
287 }
288 }
289 }
290
291 func SpecDecompose(rr fieldElement, p parameters) (R1 uint32, R0 int32) {
292 r := fieldFromMontgomery(rr)
293 if (q-1)%p.γ2 != 0 {
294 panic("mldsa: internal error: unsupported denγ2")
295 }
296 γ2 := (q - 1) / uint32(p.γ2)
297 r0 := CenteredMod(r, 2*γ2)
298 diff := int32(r) - r0
299 if diff == q-1 {
300 r0 = r0 - 1
301 return 0, r0
302 } else {
303 if diff < 0 || uint32(diff)%γ2 != 0 {
304 panic("mldsa: internal error: invalid decomposition")
305 }
306 r1 := uint32(diff) / (2 * γ2)
307 return r1, r0
308 }
309 }
310
311 func TestDecompose(t *testing.T) {
312 t.Run("ML-DSA-44", func(t *testing.T) {
313 testDecompose(t, params44)
314 })
315 t.Run("ML-DSA-65,87", func(t *testing.T) {
316 testDecompose(t, params65)
317 })
318 }
319
320 func testDecompose(t *testing.T, p parameters) {
321 t.Parallel()
322 for x := range uint32(q) {
323 rr, _ := fieldToMontgomery(x)
324 r1, r0 := SpecDecompose(rr, p)
325
326
327
328 γ2 := (q - 1) / uint32(p.γ2)
329 reconstructed := reduceModQ(int32(r1*2*γ2) + r0)
330 if reconstructed != x {
331 t.Fatalf("SpecDecompose(%d) = (%d, %d), which reconstructs to %d, expected %d", x, r1, r0, reconstructed, x)
332 }
333
334 var gotR1 byte
335 var gotR0 int32
336 switch p.γ2 {
337 case 88:
338 gotR1, gotR0 = decompose88(rr)
339 if gotR1 > 43 {
340 t.Fatalf("decompose88(%d) returned r1 = %d, which is out of range", x, gotR1)
341 }
342 case 32:
343 gotR1, gotR0 = decompose32(rr)
344 if gotR1 > 15 {
345 t.Fatalf("decompose32(%d) returned r1 = %d, which is out of range", x, gotR1)
346 }
347 default:
348 t.Fatalf("unsupported denγ2: %d", p.γ2)
349 }
350 if uint32(gotR1) != r1 {
351 t.Fatalf("highBits(%d) = %d, expected %d", x, gotR1, r1)
352 }
353 if gotR0 != r0 {
354 t.Fatalf("lowBits(%d) = %d, expected %d", x, gotR0, r0)
355 }
356 }
357 }
358
359 func TestZetas(t *testing.T) {
360 ζ := big.NewInt(1753)
361 q := big.NewInt(q)
362 for k, zeta := range zetas {
363
364 exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev8(uint8(k)))), q)
365 got := fieldFromMontgomery(zeta)
366 if big.NewInt(int64(got)).Cmp(exp) != 0 {
367 t.Errorf("zetas[%d] = %v, expected %v", k, got, exp)
368 }
369 }
370 }
371
View as plain text