1
2
3
4
5 package mldsa
6
7 import (
8 "crypto/internal/constanttime"
9 "crypto/internal/fips140/sha3"
10 "errors"
11 "math/bits"
12 )
13
// Field and Montgomery-arithmetic constants used throughout this file.
const (
	q        = 8380417    // the field modulus, 2²³ - 2¹³ + 1
	R        = 4294967296 // the Montgomery radix, 2³²
	RR       = 2365951    // R² mod q; Montgomery-multiplying by it converts into Montgomery form
	qNegInv  = 4236238847 // -q⁻¹ mod R, used by fieldMontgomeryReduce
	one      = 4193792    // R mod q, i.e. 1 in Montgomery form (equal to zetas[0])
	minusOne = 4186625    // q - one, i.e. -1 in Montgomery form
)
22
23
24
// fieldElement is an integer modulo q held in Montgomery form: values enter
// through fieldToMontgomery / fieldSubToMontgomery (which multiply by R) and
// leave through fieldFromMontgomery. The helpers in this file keep the stored
// value in [0, q).
type fieldElement uint32
26
// errUnreducedFieldElement is returned by fieldToMontgomery when its input is
// not strictly less than q.
var errUnreducedFieldElement = errors.New("mldsa: unreduced field element")
28
29
30
31 func fieldToMontgomery(a uint32) (fieldElement, error) {
32 if a >= q {
33 return 0, errUnreducedFieldElement
34 }
35
36 return fieldMontgomeryMul(fieldElement(a), RR), nil
37 }
38
39
40
41 func fieldSubToMontgomery(a, b uint32) fieldElement {
42 x := a - b + q
43 return fieldMontgomeryMul(fieldElement(x), RR)
44 }
45
46
47
48 func fieldFromMontgomery(a fieldElement) uint32 {
49
50 return uint32(fieldMontgomeryReduce(uint64(a)))
51 }
52
53
54
55 func fieldCenteredMod(r fieldElement) int32 {
56 x := int32(fieldFromMontgomery(r))
57
58 return constantTimeSelectLessOrEqual(x, q/2, x, x-q)
59 }
60
61
62
63 func fieldInfinityNorm(r fieldElement) uint32 {
64 x := int32(fieldFromMontgomery(r))
65
66
67 return uint32(constantTimeSelectLessOrEqual(x, q/2, x, q-x))
68 }
69
70
71 func fieldReduceOnce(a uint32) fieldElement {
72 x, b := bits.Sub64(uint64(a), uint64(q), 0)
73 return fieldElement(x + b*q)
74 }
75
76
77 func fieldAdd(a, b fieldElement) fieldElement {
78 x := uint32(a + b)
79 return fieldReduceOnce(x)
80 }
81
82
83 func fieldSub(a, b fieldElement) fieldElement {
84 x := uint32(a - b + q)
85 return fieldReduceOnce(x)
86 }
87
88
89 func fieldMontgomeryMul(a, b fieldElement) fieldElement {
90 x := uint64(a) * uint64(b)
91 return fieldMontgomeryReduce(x)
92 }
93
94
95 func fieldMontgomeryReduce(x uint64) fieldElement {
96 t := uint32(x) * qNegInv
97 u := (x + uint64(t)*q) >> 32
98 return fieldReduceOnce(uint32(u))
99 }
100
101
102
103 func fieldMontgomeryMulSub(a, b, c fieldElement) fieldElement {
104 x := uint64(a) * uint64(b-c+q)
105 return fieldMontgomeryReduce(x)
106 }
107
108
109
110 func fieldMontgomeryAddMul(a, b, c, d fieldElement) fieldElement {
111 x := uint64(a) * uint64(b)
112 x += uint64(c) * uint64(d)
113 return fieldMontgomeryReduce(x)
114 }
115
// n is the number of coefficients in a ringElement or nttElement.
const n = 256
117
118
// ringElement is a polynomial with n coefficients, each a fieldElement. It is
// the input type of ntt and the output type of inverseNTT.
type ringElement [n]fieldElement
120
121
122 func polyAdd[T ~[n]fieldElement](a, b T) (s T) {
123 for i := range s {
124 s[i] = fieldAdd(a[i], b[i])
125 }
126 return s
127 }
128
129
130 func polySub[T ~[n]fieldElement](a, b T) (s T) {
131 for i := range s {
132 s[i] = fieldSub(a[i], b[i])
133 }
134 return s
135 }
136
137
// nttElement is the NTT-domain representation of a polynomial, as produced by
// ntt and consumed by inverseNTT and nttMul.
type nttElement [n]fieldElement
139
140
141
// zetas holds the twiddle factors used by ntt and inverseNTT, already in
// Montgomery form (zetas[0] equals one). ntt reads entries 1..255 in
// increasing order and inverseNTT reads them in decreasing order; entry 0 is
// never read by either.
//
// NOTE(review): presumably these are the powers ζ^BitRev₈(k) mod q from
// FIPS 204 — confirm against the specification before editing any value.
var zetas = [256]fieldElement{4193792, 25847, 5771523, 7861508, 237124, 7602457, 7504169, 466468, 1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103, 2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868, 6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005, 2706023, 95776, 3077325, 3530437, 6718724, 4788269, 5842901, 3915439, 4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118, 6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596, 811944, 531354, 954230, 3881043, 3900724, 5823537, 2071892, 5582638, 4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196, 7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922, 3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370, 7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987, 5037034, 264944, 508951, 3097992, 44288, 7280319, 904516, 3958618, 4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969, 7063561, 189548, 4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330, 1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961, 2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955, 266997, 2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039, 900702, 1859098, 909542, 819034, 495491, 6767243, 8337157, 7857917, 7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579, 342297, 286988, 5942594, 4108315, 3437287, 5038140, 1735879, 203044, 2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974, 4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447, 7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775, 7100756, 1917081, 5834105, 7005614, 1500165, 777191, 2235880, 3406031, 7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136, 4603424, 6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531, 7173032, 5196991, 162844, 1616392, 3014001, 810149, 1652634, 4686184, 6581310, 5341501, 3523897, 3866901, 269760, 2213111, 7404533, 1717735, 472078, 7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524, 5441381, 6144432, 7959518, 6094090, 183443, 7403526, 1612842, 4834730, 7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782}
143
144
145
146
// ntt maps a ringElement to its nttElement representation.
//
// f is received by value, so the transform is computed in place on the
// local copy and the caller's polynomial is left untouched. Layers are
// processed with half-block lengths 128, 64, ..., 1; each half-block pair
// uses the next entry of zetas (m runs from 1 to 255).
//
// NOTE(review): this appears to implement the NTT algorithm of FIPS 204 —
// confirm the algorithm number against the specification.
func ntt(f ringElement) nttElement {
	var m uint8

	// Generic layers, len = 128 down to 8. (len shadows the builtin here.)
	for len := 128; len >= 8; len /= 2 {
		for start := 0; start < 256; start += 2 * len {
			m++
			zeta := zetas[m]

			// Re-slice the two halves of the block so that the indices used
			// below are relative to slices of known length len.
			f, flen := f[start:start+len], f[start+len:start+len+len]
			for j := 0; j < len; j += 2 {
				// (lo, hi) <- (lo + zeta*hi, lo - zeta*hi)
				t := fieldMontgomeryMul(zeta, flen[j])
				flen[j] = fieldSub(f[j], t)
				f[j] = fieldAdd(f[j], t)

				// Second element of the pair: the loop is unrolled by two.
				t = fieldMontgomeryMul(zeta, flen[j+1])
				flen[j+1] = fieldSub(f[j+1], t)
				f[j+1] = fieldAdd(f[j+1], t)
			}
		}
	}

	// The remaining layers (len = 4, 2 and 1) are written out explicitly.

	// len = 4
	for start := 0; start < 256; start += 8 {
		m++
		zeta := zetas[m]

		t := fieldMontgomeryMul(zeta, f[start+4])
		f[start+4] = fieldSub(f[start], t)
		f[start] = fieldAdd(f[start], t)

		t = fieldMontgomeryMul(zeta, f[start+5])
		f[start+5] = fieldSub(f[start+1], t)
		f[start+1] = fieldAdd(f[start+1], t)

		t = fieldMontgomeryMul(zeta, f[start+6])
		f[start+6] = fieldSub(f[start+2], t)
		f[start+2] = fieldAdd(f[start+2], t)

		t = fieldMontgomeryMul(zeta, f[start+7])
		f[start+7] = fieldSub(f[start+3], t)
		f[start+3] = fieldAdd(f[start+3], t)
	}
	// len = 2
	for start := 0; start < 256; start += 4 {
		m++
		zeta := zetas[m]

		t := fieldMontgomeryMul(zeta, f[start+2])
		f[start+2] = fieldSub(f[start], t)
		f[start] = fieldAdd(f[start], t)

		t = fieldMontgomeryMul(zeta, f[start+3])
		f[start+3] = fieldSub(f[start+1], t)
		f[start+1] = fieldAdd(f[start+1], t)
	}
	// len = 1
	for start := 0; start < 256; start += 2 {
		m++
		zeta := zetas[m]

		t := fieldMontgomeryMul(zeta, f[start+1])
		f[start+1] = fieldSub(f[start], t)
		f[start] = fieldAdd(f[start], t)
	}

	return nttElement(f)
}
214
215
216
217
// inverseNTT maps an nttElement back to the ringElement it represents.
//
// f is received by value, so the transform is computed in place on the
// local copy. Layers are processed with half-block lengths 1, 2, ..., 128,
// consuming zetas in decreasing order (m runs from 255 down to 1), and every
// coefficient is finally scaled by a constant.
//
// NOTE(review): this appears to implement the inverse NTT algorithm of
// FIPS 204 — confirm the algorithm number against the specification.
func inverseNTT(f nttElement) ringElement {
	var m uint8 = 255

	// The first layers (len = 1, 2 and 4) are written out explicitly.

	// len = 1: (lo, hi) <- (lo + hi, zeta*(hi - lo))
	for start := 0; start < 256; start += 2 {
		zeta := zetas[m]
		m--

		t := f[start]
		f[start] = fieldAdd(t, f[start+1])
		f[start+1] = fieldMontgomeryMulSub(zeta, f[start+1], t)
	}
	// len = 2
	for start := 0; start < 256; start += 4 {
		zeta := zetas[m]
		m--

		t := f[start]
		f[start] = fieldAdd(t, f[start+2])
		f[start+2] = fieldMontgomeryMulSub(zeta, f[start+2], t)

		t = f[start+1]
		f[start+1] = fieldAdd(t, f[start+3])
		f[start+3] = fieldMontgomeryMulSub(zeta, f[start+3], t)
	}
	// len = 4
	for start := 0; start < 256; start += 8 {
		zeta := zetas[m]
		m--

		t := f[start]
		f[start] = fieldAdd(t, f[start+4])
		f[start+4] = fieldMontgomeryMulSub(zeta, f[start+4], t)

		t = f[start+1]
		f[start+1] = fieldAdd(t, f[start+5])
		f[start+5] = fieldMontgomeryMulSub(zeta, f[start+5], t)

		t = f[start+2]
		f[start+2] = fieldAdd(t, f[start+6])
		f[start+6] = fieldMontgomeryMulSub(zeta, f[start+6], t)

		t = f[start+3]
		f[start+3] = fieldAdd(t, f[start+7])
		f[start+7] = fieldMontgomeryMulSub(zeta, f[start+7], t)
	}

	// Generic layers, len = 8 up to 128. (len shadows the builtin here.)
	for len := 8; len < 256; len *= 2 {
		for start := 0; start < 256; start += 2 * len {
			zeta := zetas[m]
			m--

			// Re-slice the two halves of the block so that the indices used
			// below are relative to slices of known length len.
			f, flen := f[start:start+len], f[start+len:start+len+len]
			for j := 0; j < len; j += 2 {
				t := f[j]
				f[j] = fieldAdd(t, flen[j])
				// zeta * (flen[j] - t), with the subtraction fused into the
				// multiplication.
				flen[j] = fieldMontgomeryMulSub(zeta, flen[j], t)

				// Second element of the pair: the loop is unrolled by two.
				t = f[j+1]
				f[j+1] = fieldAdd(t, flen[j+1])
				flen[j+1] = fieldMontgomeryMulSub(zeta, flen[j+1], t)
			}
		}
	}

	// Final scaling. 16382 = 2²⁴ mod q = 256⁻¹ * R mod q, so this Montgomery
	// multiplication divides every coefficient by 256.
	for i := range f {
		f[i] = fieldMontgomeryMul(f[i], 16382)
	}
	return ringElement(f)
}
289
290
291 func nttMul(a, b nttElement) (p nttElement) {
292 for i := range p {
293 p[i] = fieldMontgomeryMul(a[i], b[i])
294 }
295 return p
296 }
297
298
299
300
301 func sampleNTT(rho []byte, s, r byte) nttElement {
302 G := sha3.NewShake128()
303 G.Write(rho)
304 G.Write([]byte{s, r})
305
306 var a nttElement
307 var j int
308 var buf [168]byte
309 off := len(buf)
310 for j < n {
311 if off >= len(buf) {
312 G.Read(buf[:])
313 off = 0
314 }
315 v := uint32(buf[off]) | uint32(buf[off+1])<<8 | uint32(buf[off+2])<<16
316 off += 3
317 f, err := fieldToMontgomery(v & 0b01111111_11111111_11111111)
318 if err != nil {
319 continue
320 }
321 a[j] = f
322 j++
323 }
324 return a
325 }
326
327
328
329
330 func sampleBoundedPoly(rho []byte, r byte, p parameters) ringElement {
331 H := sha3.NewShake256()
332 H.Write(rho)
333 H.Write([]byte{r, 0})
334
335 var a ringElement
336 var j int
337 var buf [136]byte
338 off := len(buf)
339 for {
340 if off >= len(buf) {
341 H.Read(buf[:])
342 off = 0
343 }
344 z0 := buf[off] & 0x0F
345 z1 := buf[off] >> 4
346 off++
347 coeff, ok := coeffFromHalfByte(z0, p)
348 if ok {
349 a[j] = coeff
350 j++
351 }
352 if j >= len(a) {
353 break
354 }
355 coeff, ok = coeffFromHalfByte(z1, p)
356 if ok {
357 a[j] = coeff
358 j++
359 }
360 if j >= len(a) {
361 break
362 }
363 }
364 return a
365 }
366
367
368
369 func sampleInBall(rho []byte, p parameters) ringElement {
370 H := sha3.NewShake256()
371 H.Write(rho)
372 s := make([]byte, 8)
373 H.Read(s)
374
375 var c ringElement
376 for i := 256 - p.τ; i < 256; i++ {
377 j := make([]byte, 1)
378 H.Read(j)
379 for j[0] > byte(i) {
380 H.Read(j)
381 }
382 c[i] = c[j[0]]
383
384
385 bitIdx := i + p.τ - 256
386 bit := (s[bitIdx/8] >> (bitIdx % 8)) & 1
387 if bit == 0 {
388 c[j[0]] = one
389 } else {
390 c[j[0]] = minusOne
391 }
392 }
393
394 return c
395 }
396
397
398
399
400 func coeffFromHalfByte(b byte, p parameters) (fieldElement, bool) {
401 if b > 15 {
402 panic("internal error: half-byte out of range")
403 }
404 switch p.η {
405 case 2:
406
407
408
409
410
411
412
413
414
415
416
417
418 if b > 14 {
419 return 0, false
420 }
421
422
423 const barrettMultiplier = 0x3334
424 const barrettShift = 16
425 quotient := (uint32(b) * barrettMultiplier) >> barrettShift
426 remainder := uint32(b) - quotient*5
427 return fieldSubToMontgomery(2, remainder), true
428 case 4:
429
430
431
432
433
434
435
436
437 if b > 8 {
438 return 0, false
439 }
440 return fieldSubToMontgomery(4, uint32(b)), true
441 default:
442 panic("internal error: unsupported η")
443 }
444 }
445
446
447
448
449
450 func power2Round(r fieldElement) (hi uint16, lo fieldElement) {
451 rr := fieldFromMontgomery(r)
452
453
454 r1 := rr + 1<<12 - 1
455 r1 >>= 13
456
457
458 r0 := fieldSubToMontgomery(rr, r1<<13)
459 return uint16(r1), r0
460 }
461
462
463 func highBits(r ringElement, p parameters) [n]byte {
464 var w [n]byte
465 switch p.γ2 {
466 case 32:
467 for i := range n {
468 w[i] = highBits32(fieldFromMontgomery(r[i]))
469 }
470 case 88:
471 for i := range n {
472 w[i] = highBits88(fieldFromMontgomery(r[i]))
473 }
474 default:
475 panic("mldsa: internal error: unsupported γ2")
476 }
477 return w
478 }
479
480
481
482
483 func useHint(r ringElement, h [n]byte, p parameters) [n]byte {
484 var w [n]byte
485 switch p.γ2 {
486 case 32:
487 for i := range n {
488 w[i] = useHint32(r[i], h[i])
489 }
490 case 88:
491 for i := range n {
492 w[i] = useHint88(r[i], h[i])
493 }
494 default:
495 panic("mldsa: internal error: unsupported γ2")
496 }
497 return w
498 }
499
500
501 func makeHint(ct0, w, cs2 ringElement, p parameters) (h [n]byte, count1s int) {
502 switch p.γ2 {
503 case 32:
504 for i := range n {
505 h[i] = makeHint32(ct0[i], w[i], cs2[i])
506 count1s += int(h[i])
507 }
508 case 88:
509 for i := range n {
510 h[i] = makeHint88(ct0[i], w[i], cs2[i])
511 count1s += int(h[i])
512 }
513 default:
514 panic("mldsa: internal error: unsupported γ2")
515 }
516 return h, count1s
517 }
518
519
// highBits32 returns the high part of the plain value x for the parameter
// sets where 2γ2 = (q-1)/16. The result is in [0, 15]; values of x close to
// q wrap around to 0.
//
// The computation uses only additions, multiplications and shifts by
// constants, so it contains no data-dependent branch or division.
func highBits32(x uint32) byte {
	const (
		roundUp = 1<<7 - 1
		scale   = 1025
		half    = 1 << 21
		low4    = 0b1111
	)
	t := (x + roundUp) >> 7
	t = (t*scale + half) >> 22
	return byte(t & low4)
}
530
531
532
533
534 func decompose32(r fieldElement) (r1 byte, r0 int32) {
535 x := fieldFromMontgomery(r)
536 r1 = highBits32(x)
537
538
539 r0 = int32(x) - int32(r1)*2*(q-1)/32
540 r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0)
541
542 return r1, r0
543 }
544
545
546 func useHint32(r fieldElement, hint byte) byte {
547 const m = 16
548 r1, r0 := decompose32(r)
549 if hint == 1 {
550 if r0 > 0 {
551 r1 = (r1 + 1) % m
552 } else {
553
554
555 r1 = (r1 - 1) % m
556 }
557 }
558 return r1
559 }
560
561
562 func makeHint32(ct0, w, cs2 fieldElement) byte {
563
564 rPlusZ := fieldSub(w, cs2)
565 v1 := highBits32(fieldFromMontgomery(rPlusZ))
566
567 r1 := highBits32(fieldFromMontgomery(fieldAdd(rPlusZ, ct0)))
568
569 return byte(constanttime.ByteEq(v1, r1) ^ 1)
570 }
571
572
573 func highBits88(x uint32) byte {
574
575 r1 := (x + 127) >> 7
576 r1 = (r1*11275 + (1 << 23)) >> 24
577 r1 = constantTimeSelectEqual(r1, 44, 0, r1)
578 return byte(r1)
579 }
580
581
582
583
584 func decompose88(r fieldElement) (r1 byte, r0 int32) {
585 x := fieldFromMontgomery(r)
586 r1 = highBits88(x)
587
588
589 r0 = int32(x) - int32(r1)*2*(q-1)/88
590 r0 = constantTimeSelectLessOrEqual(q/2+1, r0, r0-q, r0)
591
592 return r1, r0
593 }
594
595
596 func useHint88(r fieldElement, hint byte) byte {
597 const m = 44
598 r1, r0 := decompose88(r)
599 if hint == 1 {
600 if r0 > 0 {
601
602 if r1 == m-1 {
603 r1 = 0
604 } else {
605 r1++
606 }
607 } else {
608
609 if r1 == 0 {
610 r1 = m - 1
611 } else {
612 r1--
613 }
614 }
615 }
616 return r1
617 }
618
619
620 func makeHint88(ct0, w, cs2 fieldElement) byte {
621
622 rPlusZ := fieldSub(w, cs2)
623 v1 := highBits88(fieldFromMontgomery(rPlusZ))
624 r1 := highBits88(fieldFromMontgomery(fieldAdd(rPlusZ, ct0)))
625 return byte(constanttime.ByteEq(v1, r1) ^ 1)
626 }
627
628
629
630
631
632
633 func bitPack(b []byte, r ringElement, p parameters) []byte {
634 switch p.γ1 {
635 case 17:
636 return bitPack18(b, r)
637 case 19:
638 return bitPack20(b, r)
639 default:
640 panic("mldsa: internal error: unsupported γ1")
641 }
642 }
643
644
645
646
647
648
649 func bitPack18(buf []byte, r ringElement) []byte {
650 out, v := sliceForAppend(buf, 18*n/8)
651 const b = 1 << 17
652 for i := 0; i < n; i += 4 {
653
654 w0 := b - fieldCenteredMod(r[i])
655 v[0] = byte(w0 << 0)
656 v[1] = byte(w0 >> 8)
657 v[2] = byte(w0 >> 16)
658 w1 := b - fieldCenteredMod(r[i+1])
659 v[2] |= byte(w1 << 2)
660 v[3] = byte(w1 >> 6)
661 v[4] = byte(w1 >> 14)
662 w2 := b - fieldCenteredMod(r[i+2])
663 v[4] |= byte(w2 << 4)
664 v[5] = byte(w2 >> 4)
665 v[6] = byte(w2 >> 12)
666 w3 := b - fieldCenteredMod(r[i+3])
667 v[6] |= byte(w3 << 6)
668 v[7] = byte(w3 >> 2)
669 v[8] = byte(w3 >> 10)
670 v = v[4*18/8:]
671 }
672 return out
673 }
674
675
676
677
678
679
680 func bitPack20(buf []byte, r ringElement) []byte {
681 out, v := sliceForAppend(buf, 20*n/8)
682 const b = 1 << 19
683 for i := 0; i < n; i += 2 {
684
685 w0 := b - fieldCenteredMod(r[i])
686 v[0] = byte(w0 << 0)
687 v[1] = byte(w0 >> 8)
688 v[2] = byte(w0 >> 16)
689 w1 := b - fieldCenteredMod(r[i+1])
690 v[2] |= byte(w1 << 4)
691 v[3] = byte(w1 >> 4)
692 v[4] = byte(w1 >> 12)
693 v = v[2*20/8:]
694 }
695 return out
696 }
697
698
699
700 func bitUnpack(v []byte, p parameters) ringElement {
701 switch p.γ1 {
702 case 17:
703 return bitUnpack18(v)
704 case 19:
705 return bitUnpack20(v)
706 default:
707 panic("mldsa: internal error: unsupported γ1")
708 }
709 }
710
711
712
713 func bitUnpack18(v []byte) ringElement {
714 if len(v) != 18*n/8 {
715 panic("mldsa: internal error: invalid bitUnpack18 input length")
716 }
717 const b = 1 << 17
718 const mask18 = 1<<18 - 1
719 var r ringElement
720 for i := 0; i < n; i += 4 {
721 w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16
722 r[i+0] = fieldSubToMontgomery(b, w0&mask18)
723 w1 := uint32(v[2])>>2 | uint32(v[3])<<6 | uint32(v[4])<<14
724 r[i+1] = fieldSubToMontgomery(b, w1&mask18)
725 w2 := uint32(v[4])>>4 | uint32(v[5])<<4 | uint32(v[6])<<12
726 r[i+2] = fieldSubToMontgomery(b, w2&mask18)
727 w3 := uint32(v[6])>>6 | uint32(v[7])<<2 | uint32(v[8])<<10
728 r[i+3] = fieldSubToMontgomery(b, w3&mask18)
729 v = v[4*18/8:]
730 }
731 return r
732 }
733
734
735
736 func bitUnpack20(v []byte) ringElement {
737 if len(v) != 20*n/8 {
738 panic("mldsa: internal error: invalid bitUnpack20 input length")
739 }
740 const b = 1 << 19
741 const mask20 = 1<<20 - 1
742 var r ringElement
743 for i := 0; i < n; i += 2 {
744 w0 := uint32(v[0]) | uint32(v[1])<<8 | uint32(v[2])<<16
745 r[i+0] = fieldSubToMontgomery(b, w0&mask20)
746 w1 := uint32(v[2])>>4 | uint32(v[3])<<4 | uint32(v[4])<<12
747 r[i+1] = fieldSubToMontgomery(b, w1&mask20)
748 v = v[2*20/8:]
749 }
750 return r
751 }
752
753
754
755
756
// sliceForAppend extends in by n bytes and returns the whole extended slice
// as head and the n new bytes as tail. If in has enough capacity, head
// shares its backing array; otherwise a new array is allocated and the
// contents of in are copied into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	prefix := len(in)
	total := prefix + n
	if total <= cap(in) {
		head = in[:total]
	} else {
		head = make([]byte, total)
		copy(head, in)
	}
	return head, head[prefix:]
}
767
768
769 func constantTimeSelectLessOrEqual(a, b, yes, no int32) int32 {
770 return int32(constanttime.Select(constanttime.LessOrEq(int(a), int(b)), int(yes), int(no)))
771 }
772
773
774 func constantTimeSelectEqual(a, b, yes, no uint32) uint32 {
775 return uint32(constanttime.Select(constanttime.Eq(int32(a), int32(b)), int(yes), int(no)))
776 }
777
778
779 func constantTimeAbs(x int32) uint32 {
780 return uint32(constantTimeSelectLessOrEqual(0, x, x, -x))
781 }
782
View as plain text