Source file src/crypto/internal/fips140/mldsa/mldsa.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mldsa
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/internal/fips140"
    10  	"crypto/internal/fips140/drbg"
    11  	"crypto/internal/fips140/sha3"
    12  	"crypto/internal/fips140/subtle"
    13  	"crypto/internal/fips140deps/byteorder"
    14  	"errors"
    15  )
    16  
    17  type parameters struct {
    18  	k, l int // dimensions of A
    19  	η    int // bound for secret coefficients
    20  	γ1   int // log₂(γ₁), where [-γ₁+1, γ₁] is the bound of y
    21  	γ2   int // denominator of γ₂ = (q - 1) / γ2
    22  	λ    int // collison strength
    23  	τ    int // number of non-zero coefficients in challenge
    24  	ω    int // max number of hints in MakeHint
    25  }
    26  
    27  var (
    28  	params44 = parameters{k: 4, l: 4, η: 2, γ1: 17, γ2: 88, λ: 128, τ: 39, ω: 80}
    29  	params65 = parameters{k: 6, l: 5, η: 4, γ1: 19, γ2: 32, λ: 192, τ: 49, ω: 55}
    30  	params87 = parameters{k: 8, l: 7, η: 2, γ1: 19, γ2: 32, λ: 256, τ: 60, ω: 75}
    31  )
    32  
    33  func pubKeySize(p parameters) int {
    34  	// ρ + k × n × 10-bit coefficients of t₁
    35  	return 32 + p.k*n*10/8
    36  }
    37  
    38  func sigSize(p parameters) int {
    39  	// challenge + l × n × (γ₁+1)-bit coefficients of z + hint
    40  	return (p.λ / 4) + p.l*n*(p.γ1+1)/8 + p.ω + p.k
    41  }
    42  
    43  const (
    44  	PrivateKeySize = 32
    45  
    46  	PublicKeySize44 = 32 + 4*n*10/8
    47  	PublicKeySize65 = 32 + 6*n*10/8
    48  	PublicKeySize87 = 32 + 8*n*10/8
    49  
    50  	SignatureSize44 = 128/4 + 4*n*(17+1)/8 + 80 + 4
    51  	SignatureSize65 = 192/4 + 5*n*(19+1)/8 + 55 + 6
    52  	SignatureSize87 = 256/4 + 7*n*(19+1)/8 + 75 + 8
    53  )
    54  
    55  const maxK, maxL, maxλ, maxγ1 = 8, 7, 256, 19
    56  const maxPubKeySize = PublicKeySize87
    57  
    58  type PrivateKey struct {
    59  	seed [32]byte
    60  	pub  PublicKey
    61  	s1   [maxL]nttElement
    62  	s2   [maxK]nttElement
    63  	t0   [maxK]nttElement
    64  	k    [32]byte
    65  }
    66  
    67  func (priv *PrivateKey) Equal(x *PrivateKey) bool {
    68  	return priv.pub.p == x.pub.p && subtle.ConstantTimeCompare(priv.seed[:], x.seed[:]) == 1
    69  }
    70  
    71  func (priv *PrivateKey) Bytes() []byte {
    72  	seed := priv.seed
    73  	return seed[:]
    74  }
    75  
    76  func (priv *PrivateKey) PublicKey() *PublicKey {
    77  	// Note that this is likely to keep the entire PrivateKey reachable for
    78  	// the lifetime of the PublicKey, which may be undesirable.
    79  	return &priv.pub
    80  }
    81  
    82  type PublicKey struct {
    83  	raw [maxPubKeySize]byte
    84  	p   parameters
    85  	a   [maxK * maxL]nttElement
    86  	t1  [maxK]nttElement // NTT(t₁ ⋅ 2ᵈ)
    87  	tr  [64]byte         // public key hash
    88  }
    89  
    90  func (pub *PublicKey) Equal(x *PublicKey) bool {
    91  	size := pubKeySize(pub.p)
    92  	return pub.p == x.p && subtle.ConstantTimeCompare(pub.raw[:size], x.raw[:size]) == 1
    93  }
    94  
    95  func (pub *PublicKey) Bytes() []byte {
    96  	size := pubKeySize(pub.p)
    97  	return bytes.Clone(pub.raw[:size])
    98  }
    99  
   100  func (pub *PublicKey) Parameters() string {
   101  	switch pub.p {
   102  	case params44:
   103  		return "ML-DSA-44"
   104  	case params65:
   105  		return "ML-DSA-65"
   106  	case params87:
   107  		return "ML-DSA-87"
   108  	default:
   109  		panic("mldsa: internal error: unknown parameters")
   110  	}
   111  }
   112  
   113  func GenerateKey44() *PrivateKey {
   114  	fipsSelfTest()
   115  	fips140.RecordApproved()
   116  	var seed [32]byte
   117  	drbg.Read(seed[:])
   118  	priv := newPrivateKey(&seed, params44)
   119  	fipsPCT(priv)
   120  	return priv
   121  }
   122  
   123  func GenerateKey65() *PrivateKey {
   124  	fipsSelfTest()
   125  	fips140.RecordApproved()
   126  	var seed [32]byte
   127  	drbg.Read(seed[:])
   128  	priv := newPrivateKey(&seed, params65)
   129  	fipsPCT(priv)
   130  	return priv
   131  }
   132  
   133  func GenerateKey87() *PrivateKey {
   134  	fipsSelfTest()
   135  	fips140.RecordApproved()
   136  	var seed [32]byte
   137  	drbg.Read(seed[:])
   138  	priv := newPrivateKey(&seed, params87)
   139  	fipsPCT(priv)
   140  	return priv
   141  }
   142  
   143  var errInvalidSeedLength = errors.New("mldsa: invalid seed length")
   144  
   145  func NewPrivateKey44(seed []byte) (*PrivateKey, error) {
   146  	fipsSelfTest()
   147  	fips140.RecordApproved()
   148  	if len(seed) != 32 {
   149  		return nil, errInvalidSeedLength
   150  	}
   151  	return newPrivateKey((*[32]byte)(seed), params44), nil
   152  }
   153  
   154  func NewPrivateKey65(seed []byte) (*PrivateKey, error) {
   155  	fipsSelfTest()
   156  	fips140.RecordApproved()
   157  	if len(seed) != 32 {
   158  		return nil, errInvalidSeedLength
   159  	}
   160  	return newPrivateKey((*[32]byte)(seed), params65), nil
   161  }
   162  
   163  func NewPrivateKey87(seed []byte) (*PrivateKey, error) {
   164  	fipsSelfTest()
   165  	fips140.RecordApproved()
   166  	if len(seed) != 32 {
   167  		return nil, errInvalidSeedLength
   168  	}
   169  	return newPrivateKey((*[32]byte)(seed), params87), nil
   170  }
   171  
// newPrivateKey deterministically derives the full private key for parameter
// set p from the 32-byte seed, following ML-DSA.KeyGen_internal in FIPS 204.
//
// The seed, followed by the bytes k and l, is expanded with SHAKE256 into
// ρ (32 bytes), ρ′ (64 bytes) and K (32 bytes), in that order. Â is sampled
// from ρ, and s₁ and s₂ from ρ′. Then t = NTT⁻¹(Â ∘ ŝ₁ + ŝ₂) is split by
// Power2Round into t₁ and t₀. The private key keeps NTT forms of s₁, s₂ and t₀.
// The embedded public key keeps its encoding, its hash tr, Â and NTT(t₁ ⋅ 2ᵈ).
func newPrivateKey(seed *[32]byte, p parameters) *PrivateKey {
	k, l := p.k, p.l

	priv := &PrivateKey{pub: PublicKey{p: p}}
	priv.seed = *seed

	// The three outputs are consecutive reads from one SHAKE256 stream, so
	// the read order (ρ, then ρ′, then K) is significant.
	ξ := sha3.NewShake256()
	ξ.Write(seed[:])
	ξ.Write([]byte{byte(k), byte(l)})
	ρ, ρs := make([]byte, 32), make([]byte, 64)
	ξ.Read(ρ)
	ξ.Read(ρs)
	ξ.Read(priv.k[:])

	A := priv.pub.a[:k*l]
	computeMatrixA(A, ρ, p)

	// s₁ uses indices 0..l-1 and s₂ uses l..l+k-1, so each secret polynomial
	// is sampled from a distinct index.
	s1 := priv.s1[:l]
	for r := range l {
		s1[r] = ntt(sampleBoundedPoly(ρs, byte(r), p))
	}
	s2 := priv.s2[:k]
	for r := range k {
		s2[r] = ntt(sampleBoundedPoly(ρs, byte(l+r), p))
	}

	// ˆt = Â ∘ ŝ₁ + ŝ₂
	tHat := make([]nttElement, k, maxK)
	for i := range tHat {
		tHat[i] = s2[i]
		for j := range s1 {
			tHat[i] = polyAdd(tHat[i], nttMul(A[i*l+j], s1[j]))
		}
	}
	// t = NTT⁻¹(ˆt)
	t := make([]ringElement, k, maxK)
	for i := range tHat {
		t[i] = inverseNTT(tHat[i])
	}
	// (t₁, _) = Power2Round(t)
	// (_, ˆt₀) = NTT(Power2Round(t))
	t1, t0 := make([][n]uint16, k, maxK), priv.t0[:k]
	for i := range t {
		var w ringElement
		for j := range t[i] {
			t1[i][j], w[j] = power2Round(t[i][j])
		}
		t0[i] = ntt(w)
	}

	// The computations below (and their storage in the PrivateKey struct) are
	// not strictly necessary and could be deferred to PrivateKey.PublicKey().
	// That would require keeping or re-deriving ρ and t/t1, though.

	// pkEncode appends into priv.pub.raw's own backing array, so pk aliases
	// the stored public key encoding.
	pk := pkEncode(priv.pub.raw[:0], ρ, t1, p)
	priv.pub.tr = computePublicKeyHash(pk)
	computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ)

	return priv
}
   232  
   233  func computeMatrixA(A []nttElement, ρ []byte, p parameters) {
   234  	k, l := p.k, p.l
   235  	for r := range k {
   236  		for s := range l {
   237  			A[r*l+s] = sampleNTT(ρ, byte(s), byte(r))
   238  		}
   239  	}
   240  }
   241  
   242  func computePublicKeyHash(pk []byte) [64]byte {
   243  	H := sha3.NewShake256()
   244  	H.Write(pk)
   245  	var tr [64]byte
   246  	H.Read(tr[:])
   247  	return tr
   248  }
   249  
   250  func computeT1Hat(t1Hat []nttElement, t1 [][n]uint16) {
   251  	for i := range t1 {
   252  		var w ringElement
   253  		for j := range t1[i] {
   254  			// t₁ <= 2¹⁰ - 1
   255  			// t₁ ⋅ 2ᵈ <= 2ᵈ(2¹⁰ - 1) = 2²³ - 2¹³ < q = 2²³ - 2¹³ + 1
   256  			z, _ := fieldToMontgomery(uint32(t1[i][j]) << 13)
   257  			w[j] = z
   258  		}
   259  		t1Hat[i] = ntt(w)
   260  	}
   261  }
   262  
   263  func pkEncode(buf []byte, ρ []byte, t1 [][n]uint16, p parameters) []byte {
   264  	pk := append(buf, ρ...)
   265  	for _, w := range t1[:p.k] {
   266  		// Encode four at a time into 4 * 10 bits = 5 bytes.
   267  		for i := 0; i < n; i += 4 {
   268  			c0 := w[i]
   269  			c1 := w[i+1]
   270  			c2 := w[i+2]
   271  			c3 := w[i+3]
   272  			b0 := byte(c0 >> 0)
   273  			b1 := byte((c0 >> 8) | (c1 << 2))
   274  			b2 := byte((c1 >> 6) | (c2 << 4))
   275  			b3 := byte((c2 >> 4) | (c3 << 6))
   276  			b4 := byte(c3 >> 2)
   277  			pk = append(pk, b0, b1, b2, b3, b4)
   278  		}
   279  	}
   280  	return pk
   281  }
   282  
   283  func pkDecode(pk []byte, t1 [][n]uint16, p parameters) (ρ []byte, err error) {
   284  	if len(pk) != pubKeySize(p) {
   285  		return nil, errInvalidPublicKeyLength
   286  	}
   287  	ρ, pk = pk[:32], pk[32:]
   288  	for r := range t1 {
   289  		// Decode four at a time from 4 * 10 bits = 5 bytes.
   290  		for i := 0; i < n; i += 4 {
   291  			b0, b1, b2, b3, b4 := pk[0], pk[1], pk[2], pk[3], pk[4]
   292  			t1[r][i+0] = uint16(b0>>0) | uint16(b1&0b0000_0011)<<8
   293  			t1[r][i+1] = uint16(b1>>2) | uint16(b2&0b0000_1111)<<6
   294  			t1[r][i+2] = uint16(b2>>4) | uint16(b3&0b0011_1111)<<4
   295  			t1[r][i+3] = uint16(b3>>6) | uint16(b4&0b1111_1111)<<2
   296  			pk = pk[5:]
   297  		}
   298  	}
   299  	return ρ, nil
   300  }
   301  
   302  var errInvalidPublicKeyLength = errors.New("mldsa: invalid public key length")
   303  
   304  func NewPublicKey44(pk []byte) (*PublicKey, error) {
   305  	return newPublicKey(pk, params44)
   306  }
   307  
   308  func NewPublicKey65(pk []byte) (*PublicKey, error) {
   309  	return newPublicKey(pk, params65)
   310  }
   311  
   312  func NewPublicKey87(pk []byte) (*PublicKey, error) {
   313  	return newPublicKey(pk, params87)
   314  }
   315  
   316  func newPublicKey(pk []byte, p parameters) (*PublicKey, error) {
   317  	k, l := p.k, p.l
   318  
   319  	t1 := make([][n]uint16, k, maxK)
   320  	ρ, err := pkDecode(pk, t1, p)
   321  	if err != nil {
   322  		return nil, err
   323  	}
   324  
   325  	pub := &PublicKey{p: p}
   326  	copy(pub.raw[:], pk)
   327  	computeMatrixA(pub.a[:k*l], ρ, p)
   328  	pub.tr = computePublicKeyHash(pk)
   329  	computeT1Hat(pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ)
   330  
   331  	return pub, nil
   332  }
   333  
// Errors returned by the signing entry points for malformed inputs.
var (
	errContextTooLong    = errors.New("mldsa: context too long")            // context longer than 255 bytes
	errMessageHashLength = errors.New("mldsa: invalid message hash length") // external μ not 64 bytes
	errRandomLength      = errors.New("mldsa: invalid random length")       // caller randomness not 32 bytes
)
   339  
   340  func Sign(priv *PrivateKey, msg []byte, context string) ([]byte, error) {
   341  	fipsSelfTest()
   342  	fips140.RecordApproved()
   343  	var random [32]byte
   344  	drbg.Read(random[:])
   345  	μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  	return signInternal(priv, &μ, &random), nil
   350  }
   351  
   352  func SignDeterministic(priv *PrivateKey, msg []byte, context string) ([]byte, error) {
   353  	fipsSelfTest()
   354  	fips140.RecordApproved()
   355  	var random [32]byte
   356  	μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  	return signInternal(priv, &μ, &random), nil
   361  }
   362  
   363  func TestingOnlySignWithRandom(priv *PrivateKey, msg []byte, context string, random []byte) ([]byte, error) {
   364  	fipsSelfTest()
   365  	fips140.RecordApproved()
   366  	μ, err := computeMessageHash(priv.pub.tr[:], msg, context)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  	if len(random) != 32 {
   371  		return nil, errRandomLength
   372  	}
   373  	return signInternal(priv, &μ, (*[32]byte)(random)), nil
   374  }
   375  
   376  func SignExternalMu(priv *PrivateKey, μ []byte) ([]byte, error) {
   377  	fipsSelfTest()
   378  	fips140.RecordApproved()
   379  	var random [32]byte
   380  	drbg.Read(random[:])
   381  	if len(μ) != 64 {
   382  		return nil, errMessageHashLength
   383  	}
   384  	return signInternal(priv, (*[64]byte)(μ), &random), nil
   385  }
   386  
   387  func SignExternalMuDeterministic(priv *PrivateKey, μ []byte) ([]byte, error) {
   388  	fipsSelfTest()
   389  	fips140.RecordApproved()
   390  	var random [32]byte
   391  	if len(μ) != 64 {
   392  		return nil, errMessageHashLength
   393  	}
   394  	return signInternal(priv, (*[64]byte)(μ), &random), nil
   395  }
   396  
   397  func TestingOnlySignExternalMuWithRandom(priv *PrivateKey, μ []byte, random []byte) ([]byte, error) {
   398  	fipsSelfTest()
   399  	fips140.RecordApproved()
   400  	if len(μ) != 64 {
   401  		return nil, errMessageHashLength
   402  	}
   403  	if len(random) != 32 {
   404  		return nil, errRandomLength
   405  	}
   406  	return signInternal(priv, (*[64]byte)(μ), (*[32]byte)(random)), nil
   407  }
   408  
   409  func computeMessageHash(tr []byte, msg []byte, context string) ([64]byte, error) {
   410  	if len(context) > 255 {
   411  		return [64]byte{}, errContextTooLong
   412  	}
   413  	H := sha3.NewShake256()
   414  	H.Write(tr)
   415  	H.Write([]byte{0}) // ML-DSA / HashML-DSA domain separator
   416  	H.Write([]byte{byte(len(context))})
   417  	H.Write([]byte(context))
   418  	H.Write(msg)
   419  	var μ [64]byte
   420  	H.Read(μ[:])
   421  	return μ, nil
   422  }
   423  
// signInternal produces an encoded signature of the message representative μ
// using the 32 bytes of randomness in random (all-zero for deterministic
// signing). It follows ML-DSA.Sign_internal in FIPS 204: it derives a 64-byte
// nonce from K, random and μ, then retries candidate signatures until one
// passes every rejection check.
//
// If testingOnlyRejectionReason is set, it is called once per rejected
// candidate with the name of the failing check.
func signInternal(priv *PrivateKey, μ *[64]byte, random *[32]byte) []byte {
	p, k, l := priv.pub.p, priv.pub.p.k, priv.pub.p.l
	A, s1, s2, t0 := priv.pub.a[:k*l], priv.s1[:l], priv.s2[:k], priv.t0[:k]

	// Rejection bounds: γ1 − β for z, γ2 − β for LowBits(w − cs₂), and γ2
	// for ct₀.
	β := p.τ * p.η
	γ1 := uint32(1 << p.γ1)
	γ1β := γ1 - uint32(β)
	γ2 := (q - 1) / uint32(p.γ2)
	γ2β := γ2 - uint32(β)

	// nonce = SHAKE256(K ‖ random ‖ μ), 64 bytes.
	H := sha3.NewShake256()
	H.Write(priv.k[:])
	H.Write(random[:])
	H.Write(μ[:])
	nonce := make([]byte, 64)
	H.Read(nonce)

	// κ is incremented once per sampled polynomial of y and is never reset,
	// so every attempt uses fresh counter values.
	κ := 0
sign:
	for {
		// Main rejection sampling loop. Note that leaking rejected signatures
		// leaks information about the private key. However, as explained in
		// https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
		// Section 5.5, we are free to leak rejected ch values, as well as which
		// check causes the rejection and which coefficient failed the check
		// (but not the value or sign of the coefficient).

		// y[r] is unpacked from SHAKE256(nonce ‖ LE16(κ)).
		y := make([]ringElement, l, maxL)
		for r := range y {
			counter := make([]byte, 2)
			byteorder.LEPutUint16(counter, uint16(κ))
			κ++

			H.Reset()
			H.Write(nonce)
			H.Write(counter)
			v := make([]byte, (p.γ1+1)*n/8, (maxγ1+1)*n/8)
			H.Read(v)

			y[r] = bitUnpack(v, p)
		}

		// w = NTT⁻¹(Â ∘ NTT(y))
		yHat := make([]nttElement, l, maxL)
		for i := range y {
			yHat[i] = ntt(y[i])
		}
		w := make([]ringElement, k, maxK)
		for i := range w {
			var wHat nttElement
			for j := range l {
				wHat = polyAdd(wHat, nttMul(A[i*l+j], yHat[j]))
			}
			w[i] = inverseNTT(wHat)
		}

		// ch = SHAKE256(μ ‖ w1Encode(HighBits(w))), λ/4 bytes.
		H.Reset()
		H.Write(μ[:])
		for i := range w {
			w1Encode(H, highBits(w[i], p), p)
		}
		ch := make([]byte, p.λ/4, maxλ/4)
		H.Read(ch)

		// sampleInBall is not constant time, but see comment above about
		// leaking rejected ch values being acceptable.
		c := ntt(sampleInBall(ch, p))

		cs1 := make([]ringElement, l, maxL)
		for i := range cs1 {
			cs1[i] = inverseNTT(nttMul(c, s1[i]))
		}
		cs2 := make([]ringElement, k, maxK)
		for i := range cs2 {
			cs2[i] = inverseNTT(nttMul(c, s2[i]))
		}

		z := make([]ringElement, l, maxL)
		for i := range y {
			z[i] = polyAdd(y[i], cs1[i])

			// Reject if ||z||∞ ≥ γ1 − β
			if coefficientsExceedBound(z[i], γ1β) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("z")
				}
				continue sign
			}
		}

		for i := range w {
			r0 := polySub(w[i], cs2[i])

			// Reject if ||LowBits(r0)||∞ ≥ γ2 − β
			if lowBitsExceedBound(r0, γ2β, p) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("r0")
				}
				continue sign
			}
		}

		ct0 := make([]ringElement, k, maxK)
		for i := range ct0 {
			ct0[i] = inverseNTT(nttMul(c, t0[i]))

			// Reject if ||ct0||∞ ≥ γ2
			if coefficientsExceedBound(ct0[i], γ2) {
				if testingOnlyRejectionReason != nil {
					testingOnlyRejectionReason("ct0")
				}
				continue sign
			}
		}

		// makeHint (defined elsewhere) returns one polynomial of hint bits
		// and, presumably, the number of 1s in it; the total must not exceed
		// ω so that hintEncode can store every index.
		count1s := 0
		h := make([][n]byte, k, maxK)
		for i := range w {
			var count int
			h[i], count = makeHint(ct0[i], w[i], cs2[i], p)
			count1s += count
		}
		// Reject if number of hints > ω
		if count1s > p.ω {
			if testingOnlyRejectionReason != nil {
				testingOnlyRejectionReason("h")
			}
			continue sign
		}

		return sigEncode(ch, z, h, p)
	}
}
   557  
// testingOnlyRejectionReason is set in tests, to ensure that all rejection
// paths are covered. If not nil, it is called with a string describing the
// reason for rejection: "z", "r0", "ct0", or "h". It is called once per
// rejected candidate in signInternal. It must stay nil outside of tests.
var testingOnlyRejectionReason func(reason string)
   562  
   563  // w1Encode implements w1Encode from FIPS 204, writing directly into H.
   564  func w1Encode(H *sha3.SHAKE, w [n]byte, p parameters) {
   565  	switch p.γ2 {
   566  	case 32:
   567  		// Coefficients are <= (q − 1)/(2γ2) − 1 = 15, four bits each.
   568  		buf := make([]byte, 4*n/8)
   569  		for i := 0; i < n; i += 2 {
   570  			b0 := w[i]
   571  			b1 := w[i+1]
   572  			buf[i/2] = b0 | b1<<4
   573  		}
   574  		H.Write(buf)
   575  	case 88:
   576  		// Coefficients are <= (q − 1)/(2γ2) − 1 = 43, six bits each.
   577  		buf := make([]byte, 6*n/8)
   578  		for i := 0; i < n; i += 4 {
   579  			b0 := w[i]
   580  			b1 := w[i+1]
   581  			b2 := w[i+2]
   582  			b3 := w[i+3]
   583  			buf[3*i/4+0] = (b0 >> 0) | (b1 << 6)
   584  			buf[3*i/4+1] = (b1 >> 2) | (b2 << 4)
   585  			buf[3*i/4+2] = (b2 >> 4) | (b3 << 2)
   586  		}
   587  		H.Write(buf)
   588  	default:
   589  		panic("mldsa: internal error: unsupported γ2")
   590  	}
   591  }
   592  
   593  func coefficientsExceedBound(w ringElement, bound uint32) bool {
   594  	// If this function appears in profiles, it might be possible to deduplicate
   595  	// the work of fieldFromMontgomery inside fieldInfinityNorm with the
   596  	// subsequent encoding of w.
   597  	for i := range w {
   598  		if fieldInfinityNorm(w[i]) >= bound {
   599  			return true
   600  		}
   601  	}
   602  	return false
   603  }
   604  
   605  func lowBitsExceedBound(w ringElement, bound uint32, p parameters) bool {
   606  	switch p.γ2 {
   607  	case 32:
   608  		for i := range w {
   609  			_, r0 := decompose32(w[i])
   610  			if constantTimeAbs(r0) >= bound {
   611  				return true
   612  			}
   613  		}
   614  	case 88:
   615  		for i := range w {
   616  			_, r0 := decompose88(w[i])
   617  			if constantTimeAbs(r0) >= bound {
   618  				return true
   619  			}
   620  		}
   621  	default:
   622  		panic("mldsa: internal error: unsupported γ2")
   623  	}
   624  	return false
   625  }
   626  
// Errors returned by signature verification. Several share the same message so
// that callers only see a generic failure. The distinct variables still let
// code in this package tell which check failed (NOTE(review): presumably for
// tests — confirm).
var (
	errInvalidSignatureLength           = errors.New("mldsa: invalid signature length")
	errInvalidSignatureCoeffBounds      = errors.New("mldsa: invalid signature")          // ||z||∞ ≥ γ1 − β
	errInvalidSignatureChallenge        = errors.New("mldsa: invalid signature")          // recomputed challenge mismatch
	errInvalidSignatureHintLimits       = errors.New("mldsa: invalid signature encoding") // hint count decreasing or > ω
	errInvalidSignatureHintIndexOrder   = errors.New("mldsa: invalid signature encoding") // hint indices not increasing
	errInvalidSignatureHintExtraIndices = errors.New("mldsa: invalid signature encoding") // non-zero unused hint slot
)
   635  
   636  func Verify(pub *PublicKey, msg, sig []byte, context string) error {
   637  	fipsSelfTest()
   638  	fips140.RecordApproved()
   639  	μ, err := computeMessageHash(pub.tr[:], msg, context)
   640  	if err != nil {
   641  		return err
   642  	}
   643  	return verifyInternal(pub, &μ, sig)
   644  }
   645  
   646  func VerifyExternalMu(pub *PublicKey, μ []byte, sig []byte) error {
   647  	fipsSelfTest()
   648  	fips140.RecordApproved()
   649  	if len(μ) != 64 {
   650  		return errMessageHashLength
   651  	}
   652  	return verifyInternal(pub, (*[64]byte)(μ), sig)
   653  }
   654  
// verifyInternal checks sig against the message representative μ, following
// ML-DSA.Verify_internal in FIPS 204. It decodes sig into (ch, z, h) and
// recomputes w′ = NTT⁻¹(Â ∘ NTT(z) − NTT(c) ∘ NTT(t₁ ⋅ 2ᵈ)). It recovers w₁
// from w′ and the hints, then re-derives the challenge hash. The signature is
// accepted only if ||z||∞ < γ1 − β and the recomputed hash equals ch.
//
// Decoding errors come from sigDecode. The two final checks are independent;
// their order only affects which error is returned.
func verifyInternal(pub *PublicKey, μ *[64]byte, sig []byte) error {
	p, k, l := pub.p, pub.p.k, pub.p.l
	t1, A := pub.t1[:k], pub.a[:k*l]

	β := p.τ * p.η
	γ1 := uint32(1 << p.γ1)
	γ1β := γ1 - uint32(β)

	// h must start zeroed: hintDecode only sets bits.
	z := make([]ringElement, l, maxL)
	h := make([][n]byte, k, maxK)
	ch, err := sigDecode(sig, z, h, p)
	if err != nil {
		return err
	}

	c := ntt(sampleInBall(ch, p))

	// w = Â ∘ NTT(z) − NTT(c) ∘ NTT(t₁ ⋅ 2ᵈ)
	zHat := make([]nttElement, l, maxL)
	for i := range zHat {
		zHat[i] = ntt(z[i])
	}
	w := make([]ringElement, k, maxK)
	for i := range w {
		var wHat nttElement
		for j := range l {
			wHat = polyAdd(wHat, nttMul(A[i*l+j], zHat[j]))
		}
		wHat = polySub(wHat, nttMul(c, t1[i]))
		w[i] = inverseNTT(wHat)
	}

	// Use hints h to compute w₁ from w(approx).
	w1 := make([][n]byte, k, maxK)
	for i := range w {
		w1[i] = useHint(w[i], h[i], p)
	}

	// c̃′ = SHAKE256(μ ‖ w1Encode(w₁)), λ/4 bytes.
	H := sha3.NewShake256()
	H.Write(μ[:])
	for i := range w {
		w1Encode(H, w1[i], p)
	}
	computedCH := make([]byte, p.λ/4, maxλ/4)
	H.Read(computedCH)

	for i := range z {
		if coefficientsExceedBound(z[i], γ1β) {
			return errInvalidSignatureCoeffBounds
		}
	}

	// Both values are derived only from public inputs (signature, public
	// key, μ), so a variable-time comparison is acceptable here.
	if !bytes.Equal(ch, computedCH) {
		return errInvalidSignatureChallenge
	}

	return nil
}
   713  
   714  func sigEncode(ch []byte, z []ringElement, h [][n]byte, p parameters) []byte {
   715  	sig := make([]byte, 0, sigSize(p))
   716  	sig = append(sig, ch...)
   717  	for i := range z {
   718  		sig = bitPack(sig, z[i], p)
   719  	}
   720  	sig = hintEncode(sig, h, p)
   721  	return sig
   722  }
   723  
   724  func sigDecode(sig []byte, z []ringElement, h [][n]byte, p parameters) (ch []byte, err error) {
   725  	if len(sig) != sigSize(p) {
   726  		return nil, errInvalidSignatureLength
   727  	}
   728  	ch, sig = sig[:p.λ/4], sig[p.λ/4:]
   729  	for i := range z {
   730  		length := (p.γ1 + 1) * n / 8
   731  		z[i] = bitUnpack(sig[:length], p)
   732  		sig = sig[length:]
   733  	}
   734  	if err := hintDecode(sig, h, p); err != nil {
   735  		return nil, err
   736  	}
   737  	return ch, nil
   738  }
   739  
   740  func hintEncode(buf []byte, h [][n]byte, p parameters) []byte {
   741  	ω, k := p.ω, p.k
   742  	out, y := sliceForAppend(buf, ω+k)
   743  	var idx byte
   744  	for i := range k {
   745  		for j := range n {
   746  			if h[i][j] != 0 {
   747  				y[idx] = byte(j)
   748  				idx++
   749  			}
   750  		}
   751  		y[ω+i] = idx
   752  	}
   753  	return out
   754  }
   755  
// hintDecode parses the ω+k byte hint encoding y into h (HintBitUnpack in
// FIPS 204). It sets h[i][j] = 1 for every index j listed for polynomial i.
// It never clears h, so the caller must pass zeroed rows.
//
// Only the canonical encoding is accepted. It returns
// errInvalidSignatureHintLimits if a cumulative count y[ω+i] decreases or
// exceeds ω. It returns errInvalidSignatureHintIndexOrder if the indices of one
// polynomial are not strictly increasing, and errInvalidSignatureHintExtraIndices
// if an unused index slot is non-zero.
func hintDecode(y []byte, h [][n]byte, p parameters) error {
	ω, k := p.ω, p.k
	if len(y) != ω+k {
		// sigDecode checks the total signature length first, so this should
		// not be reachable from there.
		return errors.New("mldsa: internal error: invalid signature hint length")
	}
	// idx walks the index slots y[0:ω]; y[ω+i] says where polynomial i ends.
	var idx byte
	for i := range k {
		limit := y[ω+i]
		if limit < idx || limit > byte(ω) {
			return errInvalidSignatureHintLimits
		}
		first := idx
		for idx < limit {
			// Strict ordering within a polynomial rules out duplicate
			// indices and alternative encodings of the same h.
			if idx > first && y[idx-1] >= y[idx] {
				return errInvalidSignatureHintIndexOrder
			}
			h[i][y[idx]] = 1
			idx++
		}
	}
	// Padding after the last used slot must be zero.
	for i := idx; i < byte(ω); i++ {
		if y[i] != 0 {
			return errInvalidSignatureHintExtraIndices
		}
	}
	return nil
}
   783  

View as plain text