1
2
3
4
5 package mldsa
6
7 import (
8 "bytes"
9 "crypto/internal/fips140"
10 "crypto/internal/fips140/drbg"
11 "crypto/internal/fips140/sha3"
12 "crypto/internal/fips140/subtle"
13 "crypto/internal/fips140deps/byteorder"
14 "errors"
15 )
16
// parameters describes one ML-DSA parameter set.
type parameters struct {
	k  int // rows of the matrix A (length of s₂, t, w and h)
	l  int // columns of the matrix A (length of s₁, y and z)
	η  int // bound on the coefficients of the secret vectors s₁ and s₂
	γ1 int // log₂ of γ₁, the bound on the masking vector y (γ₁ = 1 << γ1)
	γ2 int // divisor d such that γ₂ = (q − 1) / d
	λ  int // collision strength in bits; the challenge seed is λ/4 bytes
	τ  int // challenge weight; β = τ·η bounds ‖c·s‖∞
	ω  int // maximum number of hint bits set in a signature
}
26
// The three supported ML-DSA parameter sets. PublicKey.Parameters maps them
// to the names "ML-DSA-44", "ML-DSA-65" and "ML-DSA-87" respectively. The
// values correspond to FIPS 204, Table 1; γ1 is stored as log₂(γ₁) and γ2 as
// the divisor of q − 1.
var (
	params44 = parameters{k: 4, l: 4, η: 2, γ1: 17, γ2: 88, λ: 128, τ: 39, ω: 80}
	params65 = parameters{k: 6, l: 5, η: 4, γ1: 19, γ2: 32, λ: 192, τ: 49, ω: 55}
	params87 = parameters{k: 8, l: 7, η: 2, γ1: 19, γ2: 32, λ: 256, τ: 60, ω: 75}
)
32
33 func pubKeySize(p parameters) int {
34
35 return 32 + p.k*n*10/8
36 }
37
38 func sigSize(p parameters) int {
39
40 return (p.λ / 4) + p.l*n*(p.γ1+1)/8 + p.ω + p.k
41 }
42
const (
	// PrivateKeySize is the size in bytes of an encoded private key, which is
	// just its 32-byte seed (see PrivateKey.Bytes).
	PrivateKeySize = 32

	// PublicKeySize44, PublicKeySize65 and PublicKeySize87 are the encoded
	// public key sizes in bytes: a 32-byte seed plus k polynomials of n
	// coefficients at 10 bits each, for k = 4, 6 and 8. They equal pubKeySize
	// for the corresponding parameter set.
	PublicKeySize44 = 32 + 4*n*10/8
	PublicKeySize65 = 32 + 6*n*10/8
	PublicKeySize87 = 32 + 8*n*10/8

	// SignatureSize44, SignatureSize65 and SignatureSize87 are the signature
	// sizes in bytes: λ/4 bytes of challenge seed, l polynomials of n
	// coefficients at γ1+1 bits each, and ω+k bytes of hint encoding. They
	// equal sigSize for the corresponding parameter set.
	SignatureSize44 = 128/4 + 4*n*(17+1)/8 + 80 + 4
	SignatureSize65 = 192/4 + 5*n*(19+1)/8 + 55 + 6
	SignatureSize87 = 256/4 + 7*n*(19+1)/8 + 75 + 8
)
54
// Upper bounds over all supported parameter sets (they are the ML-DSA-87
// values of k, l, λ and γ1). They size the fixed arrays in PrivateKey and
// PublicKey and the capacities of scratch slices.
const maxK, maxL, maxλ, maxγ1 = 8, 7, 256, 19

// maxPubKeySize is the largest encoded public key size; it sizes PublicKey.raw.
const maxPubKeySize = PublicKeySize87
57
// PrivateKey is an ML-DSA private key. It is fully determined by its 32-byte
// seed; every other field is derived from the seed by newPrivateKey and cached
// for signing.
type PrivateKey struct {
	seed [32]byte         // seed the key was derived from; returned by Bytes
	pub  PublicKey        // corresponding public key; returned by PublicKey
	s1   [maxL]nttElement // NTT of the secret vector s₁; first l entries used
	s2   [maxK]nttElement // NTT of the secret vector s₂; first k entries used
	t0   [maxK]nttElement // NTT of t₀, the low part of t; first k entries used
	k    [32]byte         // signing key K, hashed into each signature's nonce
}
66
// Equal reports whether priv and x use the same parameter set and the same
// seed. The seed comparison runs in constant time.
func (priv *PrivateKey) Equal(x *PrivateKey) bool {
	if priv.pub.p != x.pub.p {
		return false
	}
	return subtle.ConstantTimeCompare(priv.seed[:], x.seed[:]) == 1
}
70
71 func (priv *PrivateKey) Bytes() []byte {
72 seed := priv.seed
73 return seed[:]
74 }
75
76 func (priv *PrivateKey) PublicKey() *PublicKey {
77
78
79 return &priv.pub
80 }
81
// PublicKey is an ML-DSA public key, stored both in encoded form and with the
// values that verification needs precomputed.
type PublicKey struct {
	raw [maxPubKeySize]byte     // encoded key; only the first pubKeySize(p) bytes are used
	p   parameters              // parameter set
	a   [maxK * maxL]nttElement // matrix Â from ρ, row-major with l columns; first k*l entries used
	t1  [maxK]nttElement        // NTT of t₁·2¹³ (see computeT1Hat); first k entries used
	tr  [64]byte                // 64-byte SHAKE256 hash of the encoded key (computePublicKeyHash)
}
89
// Equal reports whether pub and x use the same parameter set and have the same
// encoding. The encoding comparison runs in constant time.
func (pub *PublicKey) Equal(x *PublicKey) bool {
	if pub.p != x.p {
		return false
	}
	end := pubKeySize(pub.p)
	return subtle.ConstantTimeCompare(pub.raw[:end], x.raw[:end]) == 1
}
94
95 func (pub *PublicKey) Bytes() []byte {
96 size := pubKeySize(pub.p)
97 return bytes.Clone(pub.raw[:size])
98 }
99
100 func (pub *PublicKey) Parameters() string {
101 switch pub.p {
102 case params44:
103 return "ML-DSA-44"
104 case params65:
105 return "ML-DSA-65"
106 case params87:
107 return "ML-DSA-87"
108 default:
109 panic("mldsa: internal error: unknown parameters")
110 }
111 }
112
113 func GenerateKey44() *PrivateKey {
114 fipsSelfTest()
115 fips140.RecordApproved()
116 var seed [32]byte
117 drbg.Read(seed[:])
118 priv := newPrivateKey(&seed, params44)
119 fipsPCT(priv)
120 return priv
121 }
122
123 func GenerateKey65() *PrivateKey {
124 fipsSelfTest()
125 fips140.RecordApproved()
126 var seed [32]byte
127 drbg.Read(seed[:])
128 priv := newPrivateKey(&seed, params65)
129 fipsPCT(priv)
130 return priv
131 }
132
133 func GenerateKey87() *PrivateKey {
134 fipsSelfTest()
135 fips140.RecordApproved()
136 var seed [32]byte
137 drbg.Read(seed[:])
138 priv := newPrivateKey(&seed, params87)
139 fipsPCT(priv)
140 return priv
141 }
142
// errInvalidSeedLength is returned by NewPrivateKey44, NewPrivateKey65 and
// NewPrivateKey87 when the seed is not exactly 32 bytes long.
var errInvalidSeedLength = errors.New("mldsa: invalid seed length")
144
// NewPrivateKey44 deterministically derives an ML-DSA-44 private key from a
// 32-byte seed, such as one returned by PrivateKey.Bytes. It returns
// errInvalidSeedLength if seed has any other length.
func NewPrivateKey44(seed []byte) (*PrivateKey, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len(seed) == 32 {
		return newPrivateKey((*[32]byte)(seed), params44), nil
	}
	return nil, errInvalidSeedLength
}
153
// NewPrivateKey65 deterministically derives an ML-DSA-65 private key from a
// 32-byte seed, such as one returned by PrivateKey.Bytes. It returns
// errInvalidSeedLength if seed has any other length.
func NewPrivateKey65(seed []byte) (*PrivateKey, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len(seed) == 32 {
		return newPrivateKey((*[32]byte)(seed), params65), nil
	}
	return nil, errInvalidSeedLength
}
162
// NewPrivateKey87 deterministically derives an ML-DSA-87 private key from a
// 32-byte seed, such as one returned by PrivateKey.Bytes. It returns
// errInvalidSeedLength if seed has any other length.
func NewPrivateKey87(seed []byte) (*PrivateKey, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len(seed) == 32 {
		return newPrivateKey((*[32]byte)(seed), params87), nil
	}
	return nil, errInvalidSeedLength
}
171
// newPrivateKey deterministically derives the ML-DSA private key for parameter
// set p from a 32-byte seed. It also fills in the embedded public key: its
// encoding, the expanded matrix Â, tr, and the NTT of t₁·2¹³.
func newPrivateKey(seed *[32]byte, p parameters) *PrivateKey {
	k, l := p.k, p.l

	priv := &PrivateKey{pub: PublicKey{p: p}}
	priv.seed = *seed

	// (ρ, ρ′, K) ← SHAKE256(seed ‖ byte(k) ‖ byte(l)), read consecutively as
	// 32, 64 and 32 bytes. ρ is public; ρ′ (ρs) and K (priv.k) are secret.
	ξ := sha3.NewShake256()
	ξ.Write(seed[:])
	ξ.Write([]byte{byte(k), byte(l)})
	ρ, ρs := make([]byte, 32), make([]byte, 64)
	ξ.Read(ρ)
	ξ.Read(ρs)
	ξ.Read(priv.k[:])

	// Â is expanded from ρ and stored in the public key, row-major with l
	// columns.
	A := priv.pub.a[:k*l]
	computeMatrixA(A, ρ, p)

	// s₁ and s₂ are sampled from ρ′ using nonces 0…l−1 and l…l+k−1, and are
	// stored in NTT form.
	s1 := priv.s1[:l]
	for r := range l {
		s1[r] = ntt(sampleBoundedPoly(ρs, byte(r), p))
	}
	s2 := priv.s2[:k]
	for r := range k {
		s2[r] = ntt(sampleBoundedPoly(ρs, byte(l+r), p))
	}

	// t̂ = Â ∘ ŝ₁ + ŝ₂
	tHat := make([]nttElement, k, maxK)
	for i := range tHat {
		tHat[i] = s2[i]
		for j := range s1 {
			tHat[i] = polyAdd(tHat[i], nttMul(A[i*l+j], s1[j]))
		}
	}

	// t = NTT⁻¹(t̂)
	t := make([]ringElement, k, maxK)
	for i := range tHat {
		t[i] = inverseNTT(tHat[i])
	}

	// (t₁, t₀) = Power2Round(t), applied coefficient-wise. t₁ goes into the
	// public key encoding; t₀ is kept in NTT form in the private key.
	t1, t0 := make([][n]uint16, k, maxK), priv.t0[:k]
	for i := range t {
		var w ringElement
		for j := range t[i] {
			t1[i][j], w[j] = power2Round(t[i][j])
		}
		t0[i] = ntt(w)
	}

	// Encode pk = ρ ‖ t₁ into priv.pub.raw (pkEncode appends onto its
	// zero-length prefix), then cache tr = SHAKE256(pk) and the NTT of t₁·2¹³
	// used by verification.
	pk := pkEncode(priv.pub.raw[:0], ρ, t1, p)
	priv.pub.tr = computePublicKeyHash(pk)
	computeT1Hat(priv.pub.t1[:k], t1)

	return priv
}
232
233 func computeMatrixA(A []nttElement, ρ []byte, p parameters) {
234 k, l := p.k, p.l
235 for r := range k {
236 for s := range l {
237 A[r*l+s] = sampleNTT(ρ, byte(s), byte(r))
238 }
239 }
240 }
241
242 func computePublicKeyHash(pk []byte) [64]byte {
243 H := sha3.NewShake256()
244 H.Write(pk)
245 var tr [64]byte
246 H.Read(tr[:])
247 return tr
248 }
249
250 func computeT1Hat(t1Hat []nttElement, t1 [][n]uint16) {
251 for i := range t1 {
252 var w ringElement
253 for j := range t1[i] {
254
255
256 z, _ := fieldToMontgomery(uint32(t1[i][j]) << 13)
257 w[j] = z
258 }
259 t1Hat[i] = ntt(w)
260 }
261 }
262
263 func pkEncode(buf []byte, ρ []byte, t1 [][n]uint16, p parameters) []byte {
264 pk := append(buf, ρ...)
265 for _, w := range t1[:p.k] {
266
267 for i := 0; i < n; i += 4 {
268 c0 := w[i]
269 c1 := w[i+1]
270 c2 := w[i+2]
271 c3 := w[i+3]
272 b0 := byte(c0 >> 0)
273 b1 := byte((c0 >> 8) | (c1 << 2))
274 b2 := byte((c1 >> 6) | (c2 << 4))
275 b3 := byte((c2 >> 4) | (c3 << 6))
276 b4 := byte(c3 >> 2)
277 pk = append(pk, b0, b1, b2, b3, b4)
278 }
279 }
280 return pk
281 }
282
283 func pkDecode(pk []byte, t1 [][n]uint16, p parameters) (ρ []byte, err error) {
284 if len(pk) != pubKeySize(p) {
285 return nil, errInvalidPublicKeyLength
286 }
287 ρ, pk = pk[:32], pk[32:]
288 for r := range t1 {
289
290 for i := 0; i < n; i += 4 {
291 b0, b1, b2, b3, b4 := pk[0], pk[1], pk[2], pk[3], pk[4]
292 t1[r][i+0] = uint16(b0>>0) | uint16(b1&0b0000_0011)<<8
293 t1[r][i+1] = uint16(b1>>2) | uint16(b2&0b0000_1111)<<6
294 t1[r][i+2] = uint16(b2>>4) | uint16(b3&0b0011_1111)<<4
295 t1[r][i+3] = uint16(b3>>6) | uint16(b4&0b1111_1111)<<2
296 pk = pk[5:]
297 }
298 }
299 return ρ, nil
300 }
301
// errInvalidPublicKeyLength is returned by pkDecode (and therefore by the
// NewPublicKey constructors) when the encoding is not pubKeySize bytes long.
var errInvalidPublicKeyLength = errors.New("mldsa: invalid public key length")
303
304 func NewPublicKey44(pk []byte) (*PublicKey, error) {
305 return newPublicKey(pk, params44)
306 }
307
308 func NewPublicKey65(pk []byte) (*PublicKey, error) {
309 return newPublicKey(pk, params65)
310 }
311
312 func NewPublicKey87(pk []byte) (*PublicKey, error) {
313 return newPublicKey(pk, params87)
314 }
315
316 func newPublicKey(pk []byte, p parameters) (*PublicKey, error) {
317 k, l := p.k, p.l
318
319 t1 := make([][n]uint16, k, maxK)
320 ρ, err := pkDecode(pk, t1, p)
321 if err != nil {
322 return nil, err
323 }
324
325 pub := &PublicKey{p: p}
326 copy(pub.raw[:], pk)
327 computeMatrixA(pub.a[:k*l], ρ, p)
328 pub.tr = computePublicKeyHash(pk)
329 computeT1Hat(pub.t1[:k], t1)
330
331 return pub, nil
332 }
333
// Errors returned by the signing and verification entry points for malformed
// caller inputs.
var (
	errContextTooLong    = errors.New("mldsa: context too long")            // context longer than 255 bytes
	errMessageHashLength = errors.New("mldsa: invalid message hash length") // external μ not exactly 64 bytes
	errRandomLength      = errors.New("mldsa: invalid random length")       // testing randomness not exactly 32 bytes
)
339
// Sign returns a hedged (randomized) ML-DSA signature of msg under priv, bound
// to the given context string. The per-signature randomness comes from the
// FIPS DRBG.
//
// context must be at most 255 bytes; otherwise errContextTooLong is returned.
// The context is validated, as part of computing μ, before any randomness is
// drawn, so rejected calls do not consume DRBG output.
func Sign(priv *PrivateKey, msg []byte, context string) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
	if err != nil {
		return nil, err
	}
	var random [32]byte
	drbg.Read(random[:])
	return signInternal(priv, &μ, &random), nil
}
351
// SignDeterministic returns the deterministic ML-DSA signature of msg under
// priv, bound to context. It uses an all-zero 32-byte randomness input, so the
// same inputs always yield the same signature. context must be at most 255
// bytes.
func SignDeterministic(priv *PrivateKey, msg []byte, context string) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
	if err != nil {
		return nil, err
	}
	var zeroRandom [32]byte
	return signInternal(priv, &μ, &zeroRandom), nil
}
362
// TestingOnlySignWithRandom signs msg like Sign, but uses the caller-supplied
// 32-byte random value instead of DRBG output. It is intended for tests
// (e.g. known-answer vectors). A context error is reported before a random
// length error.
func TestingOnlySignWithRandom(priv *PrivateKey, msg []byte, context string, random []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
	switch {
	case err != nil:
		return nil, err
	case len(random) != 32:
		return nil, errRandomLength
	}
	return signInternal(priv, &μ, (*[32]byte)(random)), nil
}
375
// SignExternalMu returns a hedged ML-DSA signature over a caller-supplied
// 64-byte message representative μ, which is expected to be computed as
// computeMessageHash does. The randomness comes from the FIPS DRBG.
//
// It returns errMessageHashLength if μ is not exactly 64 bytes. The length is
// checked before any randomness is drawn, so rejected calls do not consume
// DRBG output. This matches the validate-first order of
// SignExternalMuDeterministic and TestingOnlySignExternalMuWithRandom.
func SignExternalMu(priv *PrivateKey, μ []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len(μ) != 64 {
		return nil, errMessageHashLength
	}
	var random [32]byte
	drbg.Read(random[:])
	return signInternal(priv, (*[64]byte)(μ), &random), nil
}
386
// SignExternalMuDeterministic returns the deterministic ML-DSA signature over
// a caller-supplied 64-byte μ, using an all-zero randomness input. It returns
// errMessageHashLength if μ is not exactly 64 bytes.
func SignExternalMuDeterministic(priv *PrivateKey, μ []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	if len(μ) != 64 {
		return nil, errMessageHashLength
	}
	var zeroRandom [32]byte
	return signInternal(priv, (*[64]byte)(μ), &zeroRandom), nil
}
396
// TestingOnlySignExternalMuWithRandom signs a caller-supplied 64-byte μ using a
// caller-supplied 32-byte random value; it is intended for tests. A μ length
// error is reported before a random length error.
func TestingOnlySignExternalMuWithRandom(priv *PrivateKey, μ []byte, random []byte) ([]byte, error) {
	fipsSelfTest()
	fips140.RecordApproved()
	switch {
	case len(μ) != 64:
		return nil, errMessageHashLength
	case len(random) != 32:
		return nil, errRandomLength
	}
	return signInternal(priv, (*[64]byte)(μ), (*[32]byte)(random)), nil
}
408
// computeMessageHash returns the 64-byte message representative
// μ = SHAKE256(tr ‖ 0x00 ‖ byte(len(context)) ‖ context ‖ msg). The leading
// zero byte is the pure ML-DSA (non-prehash) domain separator of FIPS 204. It
// returns errContextTooLong, with a zero μ, if context exceeds 255 bytes.
func computeMessageHash(tr []byte, msg []byte, context string) ([64]byte, error) {
	var μ [64]byte
	if len(context) > 255 {
		return μ, errContextTooLong
	}
	shake := sha3.NewShake256()
	shake.Write(tr)
	// SHAKE absorbs a byte stream, so writing both header bytes at once is
	// equivalent to two single-byte writes.
	shake.Write([]byte{0, byte(len(context))})
	shake.Write([]byte(context))
	shake.Write(msg)
	shake.Read(μ[:])
	return μ, nil
}
423
// signInternal produces an ML-DSA signature over the 64-byte message
// representative μ, following the rejection-sampling loop of FIPS 204
// ML-DSA.Sign_internal. random is the 32-byte rnd input. The exported wrappers
// pass all zeros for the deterministic variants and DRBG output for the hedged
// ones.
//
// Each attempt samples a masking vector y, derives the challenge c from the
// high bits of w = A·y, and sets z = y + c·s₁. The attempt is discarded, and a
// fresh y is sampled, if any of these hold:
//   - ‖z‖∞ ≥ γ₁ − β
//   - ‖LowBits(w − c·s₂)‖∞ ≥ γ₂ − β
//   - ‖c·t₀‖∞ ≥ γ₂
//   - the hint vector h has more than ω ones
//
// NOTE(review): rejections exit early and data-dependently, so which check
// rejected an attempt, and the rejected ch values, are observable through
// timing. This relies on the argument (CRYSTALS-Dilithium round-3
// specification, Section 5.5) that leaking rejected challenges and the failing
// check is acceptable, but not the value or sign of the failing coefficient.
// Keep that in mind before restructuring the checks.
func signInternal(priv *PrivateKey, μ *[64]byte, random *[32]byte) []byte {
	p, k, l := priv.pub.p, priv.pub.p.k, priv.pub.p.l
	A, s1, s2, t0 := priv.pub.a[:k*l], priv.s1[:l], priv.s2[:k], priv.t0[:k]

	β := p.τ * p.η
	γ1 := uint32(1 << p.γ1)
	γ1β := γ1 - uint32(β)
	γ2 := (q - 1) / uint32(p.γ2)
	γ2β := γ2 - uint32(β)

	// ρ″ = SHAKE256(K ‖ rnd ‖ μ), a 64-byte seed for all masking vectors.
	H := sha3.NewShake256()
	H.Write(priv.k[:])
	H.Write(random[:])
	H.Write(μ[:])
	nonce := make([]byte, 64)
	H.Read(nonce)

	// κ counts the masking polynomials sampled so far. It is never reset, so
	// every attempt (and every polynomial within it) gets a distinct counter.
	κ := 0
sign:
	for {
		// y ← ExpandMask(ρ″, κ): polynomial r is unpacked from the first
		// (γ1+1)·n/8 bytes of SHAKE256(ρ″ ‖ LE16(κ)), and κ advances by one
		// per polynomial.
		y := make([]ringElement, l, maxL)
		for r := range y {
			counter := make([]byte, 2)
			byteorder.LEPutUint16(counter, uint16(κ))
			κ++

			H.Reset()
			H.Write(nonce)
			H.Write(counter)
			v := make([]byte, (p.γ1+1)*n/8, (maxγ1+1)*n/8)
			H.Read(v)

			y[r] = bitUnpack(v, p)
		}

		// w = NTT⁻¹(Â ∘ NTT(y))
		yHat := make([]nttElement, l, maxL)
		for i := range y {
			yHat[i] = ntt(y[i])
		}
		w := make([]ringElement, k, maxK)
		for i := range w {
			var wHat nttElement
			for j := range l {
				wHat = polyAdd(wHat, nttMul(A[i*l+j], yHat[j]))
			}
			w[i] = inverseNTT(wHat)
		}

		// c̃ = SHAKE256(μ ‖ w1Encode(HighBits(w))), truncated to λ/4 bytes.
		H.Reset()
		H.Write(μ[:])
		for i := range w {
			w1Encode(H, highBits(w[i], p), p)
		}
		ch := make([]byte, p.λ/4, maxλ/4)
		H.Read(ch)

		// c = NTT(SampleInBall(c̃)). Per the NOTE above, leaking information
		// about rejected c̃ values is assumed acceptable.
		c := ntt(sampleInBall(ch, p))

		cs1 := make([]ringElement, l, maxL)
		for i := range cs1 {
			cs1[i] = inverseNTT(nttMul(c, s1[i]))
		}
		cs2 := make([]ringElement, k, maxK)
		for i := range cs2 {
			cs2[i] = inverseNTT(nttMul(c, s2[i]))
		}

		z := make([]ringElement, l, maxL)
		for i := range y {
			z[i] = polyAdd(y[i], cs1[i])

			// Reject if ‖z‖∞ ≥ γ₁ − β.
			if coefficientsExceedBound(z[i], γ1β) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("z")
				}
				continue sign
			}
		}

		for i := range w {
			r0 := polySub(w[i], cs2[i])

			// Reject if ‖LowBits(w − c·s₂)‖∞ ≥ γ₂ − β.
			if lowBitsExceedBound(r0, γ2β, p) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("r0")
				}
				continue sign
			}
		}

		ct0 := make([]ringElement, k, maxK)
		for i := range ct0 {
			ct0[i] = inverseNTT(nttMul(c, t0[i]))

			// Reject if ‖c·t₀‖∞ ≥ γ₂.
			if coefficientsExceedBound(ct0[i], γ2) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("ct0")
				}
				continue sign
			}
		}

		// h = MakeHint(−c·t₀, w − c·s₂ + c·t₀), computed per polynomial by
		// makeHint, which also returns the number of ones.
		count1s := 0
		h := make([][n]byte, k, maxK)
		for i := range w {
			var count int
			h[i], count = makeHint(ct0[i], w[i], cs2[i], p)
			count1s += count
		}

		// Reject if the hint has more than ω ones; hintEncode relies on this
		// bound.
		if count1s > p.ω {
			if testingOnlyRejectionReason != nil {
				testingOnlyRejectionReason("h")
			}
			continue sign
		}

		return sigEncode(ch, z, h, p)
	}
}
557
558
559
560
// testingOnlyRejectionReason is a testing hook. When non-nil (it defaults to
// nil), signInternal calls it with the name of the check that rejected a
// signing attempt: "z", "r0", "ct0" or "h".
var testingOnlyRejectionReason func(reason string)
562
563
564 func w1Encode(H *sha3.SHAKE, w [n]byte, p parameters) {
565 switch p.γ2 {
566 case 32:
567
568 buf := make([]byte, 4*n/8)
569 for i := 0; i < n; i += 2 {
570 b0 := w[i]
571 b1 := w[i+1]
572 buf[i/2] = b0 | b1<<4
573 }
574 H.Write(buf)
575 case 88:
576
577 buf := make([]byte, 6*n/8)
578 for i := 0; i < n; i += 4 {
579 b0 := w[i]
580 b1 := w[i+1]
581 b2 := w[i+2]
582 b3 := w[i+3]
583 buf[3*i/4+0] = (b0 >> 0) | (b1 << 6)
584 buf[3*i/4+1] = (b1 >> 2) | (b2 << 4)
585 buf[3*i/4+2] = (b2 >> 4) | (b3 << 2)
586 }
587 H.Write(buf)
588 default:
589 panic("mldsa: internal error: unsupported γ2")
590 }
591 }
592
593 func coefficientsExceedBound(w ringElement, bound uint32) bool {
594
595
596
597 for i := range w {
598 if fieldInfinityNorm(w[i]) >= bound {
599 return true
600 }
601 }
602 return false
603 }
604
605 func lowBitsExceedBound(w ringElement, bound uint32, p parameters) bool {
606 switch p.γ2 {
607 case 32:
608 for i := range w {
609 _, r0 := decompose32(w[i])
610 if constantTimeAbs(r0) >= bound {
611 return true
612 }
613 }
614 case 88:
615 for i := range w {
616 _, r0 := decompose88(w[i])
617 if constantTimeAbs(r0) >= bound {
618 return true
619 }
620 }
621 default:
622 panic("mldsa: internal error: unsupported γ2")
623 }
624 return false
625 }
626
// Errors returned by signature verification. Several share the same message
// text, so callers only ever see a generic "invalid signature" or "invalid
// signature encoding", while the distinct values remain distinguishable by
// identity.
var (
	errInvalidSignatureLength           = errors.New("mldsa: invalid signature length")   // wrong total length
	errInvalidSignatureCoeffBounds      = errors.New("mldsa: invalid signature")          // ‖z‖∞ ≥ γ₁ − β
	errInvalidSignatureChallenge        = errors.New("mldsa: invalid signature")          // recomputed challenge differs
	errInvalidSignatureHintLimits       = errors.New("mldsa: invalid signature encoding") // hint counts decrease or exceed ω
	errInvalidSignatureHintIndexOrder   = errors.New("mldsa: invalid signature encoding") // hint indices not strictly increasing
	errInvalidSignatureHintExtraIndices = errors.New("mldsa: invalid signature encoding") // unused hint slots not zero
)
635
// Verify checks that sig is a valid ML-DSA signature of msg under pub for the
// given context string. It returns nil on success, errContextTooLong if context
// exceeds 255 bytes, or a verification error.
func Verify(pub *PublicKey, msg, sig []byte, context string) error {
	fipsSelfTest()
	fips140.RecordApproved()
	μ, err := computeMessageHash(pub.tr[:], msg, context)
	if err == nil {
		err = verifyInternal(pub, &μ, sig)
	}
	return err
}
645
// VerifyExternalMu checks sig against a caller-supplied 64-byte message
// representative μ. It returns errMessageHashLength if μ is not exactly 64
// bytes.
func VerifyExternalMu(pub *PublicKey, μ []byte, sig []byte) error {
	fipsSelfTest()
	fips140.RecordApproved()
	if len(μ) == 64 {
		return verifyInternal(pub, (*[64]byte)(μ), sig)
	}
	return errMessageHashLength
}
654
// verifyInternal checks sig against the 64-byte message representative μ,
// following FIPS 204 ML-DSA.Verify_internal. It returns nil for a valid
// signature. Otherwise it returns the error from sigDecode (length or hint
// encoding), errInvalidSignatureCoeffBounds if ‖z‖∞ ≥ γ₁ − β, or
// errInvalidSignatureChallenge if the recomputed challenge does not match.
func verifyInternal(pub *PublicKey, μ *[64]byte, sig []byte) error {
	p, k, l := pub.p, pub.p.k, pub.p.l
	t1, A := pub.t1[:k], pub.a[:k*l]

	β := p.τ * p.η
	γ1 := uint32(1 << p.γ1)
	γ1β := γ1 - uint32(β)

	// h must start zeroed: hintDecode only ever sets entries to 1.
	z := make([]ringElement, l, maxL)
	h := make([][n]byte, k, maxK)
	ch, err := sigDecode(sig, z, h, p)
	if err != nil {
		return err
	}

	c := ntt(sampleInBall(ch, p))

	// w′ = NTT⁻¹(Â ∘ NTT(z) − NTT(c) ∘ NTT(t₁·2ᵈ)); pub.t1 already holds
	// NTT(t₁·2ᵈ) (see computeT1Hat).
	zHat := make([]nttElement, l, maxL)
	for i := range zHat {
		zHat[i] = ntt(z[i])
	}
	w := make([]ringElement, k, maxK)
	for i := range w {
		var wHat nttElement
		for j := range l {
			wHat = polyAdd(wHat, nttMul(A[i*l+j], zHat[j]))
		}
		wHat = polySub(wHat, nttMul(c, t1[i]))
		w[i] = inverseNTT(wHat)
	}

	// w₁′ = UseHint(h, w′), applied per polynomial.
	w1 := make([][n]byte, k, maxK)
	for i := range w {
		w1[i] = useHint(w[i], h[i], p)
	}

	// c̃′ = SHAKE256(μ ‖ w1Encode(w₁′)), truncated to λ/4 bytes.
	H := sha3.NewShake256()
	H.Write(μ[:])
	for i := range w {
		w1Encode(H, w1[i], p)
	}
	computedCH := make([]byte, p.λ/4, maxλ/4)
	H.Read(computedCH)

	// Reject if ‖z‖∞ ≥ γ₁ − β.
	for i := range z {
		if coefficientsExceedBound(z[i], γ1β) {
			return errInvalidSignatureCoeffBounds
		}
	}

	// Both challenge values are derived from public inputs, so a
	// variable-time comparison is used here.
	if !bytes.Equal(ch, computedCH) {
		return errInvalidSignatureChallenge
	}

	return nil
}
713
714 func sigEncode(ch []byte, z []ringElement, h [][n]byte, p parameters) []byte {
715 sig := make([]byte, 0, sigSize(p))
716 sig = append(sig, ch...)
717 for i := range z {
718 sig = bitPack(sig, z[i], p)
719 }
720 sig = hintEncode(sig, h, p)
721 return sig
722 }
723
724 func sigDecode(sig []byte, z []ringElement, h [][n]byte, p parameters) (ch []byte, err error) {
725 if len(sig) != sigSize(p) {
726 return nil, errInvalidSignatureLength
727 }
728 ch, sig = sig[:p.λ/4], sig[p.λ/4:]
729 for i := range z {
730 length := (p.γ1 + 1) * n / 8
731 z[i] = bitUnpack(sig[:length], p)
732 sig = sig[length:]
733 }
734 if err := hintDecode(sig, h, p); err != nil {
735 return nil, err
736 }
737 return ch, nil
738 }
739
740 func hintEncode(buf []byte, h [][n]byte, p parameters) []byte {
741 ω, k := p.ω, p.k
742 out, y := sliceForAppend(buf, ω+k)
743 var idx byte
744 for i := range k {
745 for j := range n {
746 if h[i][j] != 0 {
747 y[idx] = byte(j)
748 idx++
749 }
750 }
751 y[ω+i] = idx
752 }
753 return out
754 }
755
// hintDecode parses the ω+k byte hint encoding y into h, setting h[i][j] = 1
// for every listed index. It never clears entries, so h must start zeroed.
//
// Only the canonical encoding is accepted, as required by FIPS 204
// HintBitUnpack:
//   - the running counts y[ω+i] must be non-decreasing and at most ω
//     (errInvalidSignatureHintLimits);
//   - indices within one polynomial must be strictly increasing
//     (errInvalidSignatureHintIndexOrder);
//   - every unused index slot y[count:ω] must be zero
//     (errInvalidSignatureHintExtraIndices).
func hintDecode(y []byte, h [][n]byte, p parameters) error {
	ω, k := p.ω, p.k
	// sigDecode has already validated the total signature length, so a
	// mismatch here indicates an internal bug rather than bad input.
	if len(y) != ω+k {
		return errors.New("mldsa: internal error: invalid signature hint length")
	}
	var idx byte
	for i := range k {
		limit := y[ω+i]
		if limit < idx || limit > byte(ω) {
			return errInvalidSignatureHintLimits
		}
		first := idx
		for idx < limit {
			// Ordering is only checked within a polynomial; indices restart
			// at each new polynomial.
			if idx > first && y[idx-1] >= y[idx] {
				return errInvalidSignatureHintIndexOrder
			}
			h[i][y[idx]] = 1
			idx++
		}
	}
	for i := idx; i < byte(ω); i++ {
		if y[i] != 0 {
			return errInvalidSignatureHintExtraIndices
		}
	}
	return nil
}
783
View as plain text