Source file src/crypto/internal/fips140/mldsa/field_test.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mldsa
     6  
     7  import (
     8  	"math/big"
     9  	"testing"
    10  )
    11  
// interestingValue is a single test operand, kept in both of the forms the
// field tests need so that neither has to be recomputed in the inner loops.
type interestingValue struct {
	// v is the operand in plain (non-Montgomery) form.
	v uint32
	// m is the Montgomery representation paired with v.
	m fieldElement
}
    16  
    17  // q is large enough that we can't exhaustively test all q × q inputs, so when
    18  // we have two inputs  we test [0, q) on one side and a set of interesting
    19  // values on the other side.
    20  func interestingValues() []interestingValue {
    21  	if testing.Short() {
    22  		return []interestingValue{{v: q - 1, m: minusOne}}
    23  	}
    24  	var values []interestingValue
    25  	for _, v := range []uint32{
    26  		0,
    27  		1,
    28  		2,
    29  		3,
    30  		q - 3,
    31  		q - 2,
    32  		q - 1,
    33  		q / 2,
    34  		(q + 1) / 2,
    35  	} {
    36  		m, _ := fieldToMontgomery(v)
    37  		values = append(values, interestingValue{v: v, m: m})
    38  		// Also test values that have an interesting Montgomery representation.
    39  		values = append(values, interestingValue{
    40  			v: fieldFromMontgomery(fieldElement(v)), m: fieldElement(v)})
    41  	}
    42  	return values
    43  }
    44  
    45  func TestToFromMontgomery(t *testing.T) {
    46  	for a := range uint32(q) {
    47  		m, err := fieldToMontgomery(a)
    48  		if err != nil {
    49  			t.Fatalf("fieldToMontgomery(%d) returned error: %v", a, err)
    50  		}
    51  		exp := fieldElement((uint64(a) * R) % q)
    52  		if m != exp {
    53  			t.Fatalf("fieldToMontgomery(%d) = %d, expected %d", a, m, exp)
    54  		}
    55  		got := fieldFromMontgomery(m)
    56  		if got != a {
    57  			t.Fatalf("fieldFromMontgomery(fieldToMontgomery(%d)) = %d, expected %d", a, got, a)
    58  		}
    59  	}
    60  }
    61  
    62  func TestFieldAdd(t *testing.T) {
    63  	t.Parallel()
    64  	for _, a := range interestingValues() {
    65  		for b := range fieldElement(q) {
    66  			got := fieldAdd(a.m, b)
    67  			exp := (a.m + b) % q
    68  			if got != exp {
    69  				t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
    70  			}
    71  		}
    72  	}
    73  }
    74  
    75  func TestFieldSub(t *testing.T) {
    76  	t.Parallel()
    77  	for _, a := range interestingValues() {
    78  		for b := range fieldElement(q) {
    79  			got := fieldSub(a.m, b)
    80  			exp := (a.m + q - b) % q
    81  			if got != exp {
    82  				t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
    83  			}
    84  		}
    85  	}
    86  }
    87  
    88  func TestFieldSubToMontgomery(t *testing.T) {
    89  	t.Parallel()
    90  	for _, a := range interestingValues() {
    91  		for b := range uint32(q) {
    92  			got := fieldSubToMontgomery(a.v, b)
    93  			diff := (a.v + q - b) % q
    94  			exp := fieldElement((uint64(diff) * R) % q)
    95  			if got != exp {
    96  				t.Fatalf("fieldSubToMontgomery(%d, %d) = %d, expected %d", a.v, b, got, exp)
    97  			}
    98  		}
    99  	}
   100  }
   101  
   102  func TestFieldReduceOnce(t *testing.T) {
   103  	t.Parallel()
   104  	for a := range uint32(2 * q) {
   105  		got := fieldReduceOnce(a)
   106  		var exp uint32
   107  		if a < q {
   108  			exp = a
   109  		} else {
   110  			exp = a - q
   111  		}
   112  		if uint32(got) != exp {
   113  			t.Fatalf("fieldReduceOnce(%d) = %d, expected %d", a, got, exp)
   114  		}
   115  	}
   116  }
   117  
   118  func TestFieldMul(t *testing.T) {
   119  	t.Parallel()
   120  	for _, a := range interestingValues() {
   121  		for b := range fieldElement(q) {
   122  			got := fieldFromMontgomery(fieldMontgomeryMul(a.m, b))
   123  			exp := uint32((uint64(a.v) * uint64(fieldFromMontgomery(b))) % q)
   124  			if got != exp {
   125  				t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
   126  			}
   127  		}
   128  	}
   129  }
   130  
   131  func TestFieldToMontgomeryOverflow(t *testing.T) {
   132  	// fieldToMontgomery should reject inputs ≥ q.
   133  	inputs := []uint32{
   134  		q,
   135  		q + 1,
   136  		q + 2,
   137  		1<<23 - 1,
   138  		1 << 23,
   139  		q + 1<<23,
   140  		q + 1<<31,
   141  		^uint32(0),
   142  	}
   143  	for _, in := range inputs {
   144  		if _, err := fieldToMontgomery(in); err == nil {
   145  			t.Fatalf("fieldToMontgomery(%d) did not return an error", in)
   146  		}
   147  	}
   148  }
   149  
   150  func TestFieldMulSub(t *testing.T) {
   151  	for _, a := range interestingValues() {
   152  		for _, b := range interestingValues() {
   153  			for _, c := range interestingValues() {
   154  				got := fieldFromMontgomery(fieldMontgomeryMulSub(a.m, b.m, c.m))
   155  				exp := uint32((uint64(a.v) * (uint64(b.v) + q - uint64(c.v))) % q)
   156  				if got != exp {
   157  					t.Fatalf("%d * (%d - %d) = %d, expected %d", a.v, b.v, c.v, got, exp)
   158  				}
   159  			}
   160  		}
   161  	}
   162  }
   163  
   164  func TestFieldAddMul(t *testing.T) {
   165  	for _, a := range interestingValues() {
   166  		for _, b := range interestingValues() {
   167  			for _, c := range interestingValues() {
   168  				for _, d := range interestingValues() {
   169  					got := fieldFromMontgomery(fieldMontgomeryAddMul(a.m, b.m, c.m, d.m))
   170  					exp := uint32((uint64(a.v)*uint64(b.v) + uint64(c.v)*uint64(d.v)) % q)
   171  					if got != exp {
   172  						t.Fatalf("%d + %d * %d = %d, expected %d", a.v, b.v, c.v, got, exp)
   173  					}
   174  				}
   175  			}
   176  		}
   177  	}
   178  }
   179  
   180  func BitRev8(n uint8) uint8 {
   181  	var r uint8
   182  	r |= n >> 7 & 0b0000_0001
   183  	r |= n >> 5 & 0b0000_0010
   184  	r |= n >> 3 & 0b0000_0100
   185  	r |= n >> 1 & 0b0000_1000
   186  	r |= n << 1 & 0b0001_0000
   187  	r |= n << 3 & 0b0010_0000
   188  	r |= n << 5 & 0b0100_0000
   189  	r |= n << 7 & 0b1000_0000
   190  	return r
   191  }
   192  
   193  func CenteredMod(x, m uint32) int32 {
   194  	x = x % m
   195  	if x > m/2 {
   196  		return int32(x) - int32(m)
   197  	}
   198  	return int32(x)
   199  }
   200  
   201  func reduceModQ(x int32) uint32 {
   202  	x %= q
   203  	if x < 0 {
   204  		return uint32(x + q)
   205  	}
   206  	return uint32(x)
   207  }
   208  
   209  func TestCenteredMod(t *testing.T) {
   210  	for x := range uint32(q * 2) {
   211  		got := CenteredMod(uint32(x), q)
   212  		if reduceModQ(got) != (x % q) {
   213  			t.Fatalf("CenteredMod(%d) = %d, which is not congruent to %d mod %d", x, got, x, q)
   214  		}
   215  	}
   216  
   217  	for x := range uint32(q) {
   218  		r, _ := fieldToMontgomery(x)
   219  		got := fieldCenteredMod(r)
   220  		exp := CenteredMod(x, q)
   221  		if got != exp {
   222  			t.Fatalf("fieldCenteredMod(%d) = %d, expected %d", x, got, exp)
   223  		}
   224  	}
   225  }
   226  
   227  func TestInfinityNorm(t *testing.T) {
   228  	for x := range uint32(q) {
   229  		r, _ := fieldToMontgomery(x)
   230  		got := fieldInfinityNorm(r)
   231  		exp := CenteredMod(x, q)
   232  		if exp < 0 {
   233  			exp = -exp
   234  		}
   235  		if got != uint32(exp) {
   236  			t.Fatalf("fieldInfinityNorm(%d) = %d, expected %d", x, got, exp)
   237  		}
   238  	}
   239  }
   240  
   241  func TestConstants(t *testing.T) {
   242  	if fieldFromMontgomery(one) != 1 {
   243  		t.Errorf("one constant incorrect")
   244  	}
   245  	if fieldFromMontgomery(minusOne) != q-1 {
   246  		t.Errorf("minusOne constant incorrect")
   247  	}
   248  	if fieldInfinityNorm(one) != 1 {
   249  		t.Errorf("one infinity norm incorrect")
   250  	}
   251  	if fieldInfinityNorm(minusOne) != 1 {
   252  		t.Errorf("minusOne infinity norm incorrect")
   253  	}
   254  
   255  	if PublicKeySize44 != pubKeySize(params44) {
   256  		t.Errorf("PublicKeySize44 constant incorrect")
   257  	}
   258  	if PublicKeySize65 != pubKeySize(params65) {
   259  		t.Errorf("PublicKeySize65 constant incorrect")
   260  	}
   261  	if PublicKeySize87 != pubKeySize(params87) {
   262  		t.Errorf("PublicKeySize87 constant incorrect")
   263  	}
   264  	if SignatureSize44 != sigSize(params44) {
   265  		t.Errorf("SignatureSize44 constant incorrect")
   266  	}
   267  	if SignatureSize65 != sigSize(params65) {
   268  		t.Errorf("SignatureSize65 constant incorrect")
   269  	}
   270  	if SignatureSize87 != sigSize(params87) {
   271  		t.Errorf("SignatureSize87 constant incorrect")
   272  	}
   273  }
   274  
   275  func TestPower2Round(t *testing.T) {
   276  	t.Parallel()
   277  	for x := range uint32(q) {
   278  		rr, _ := fieldToMontgomery(x)
   279  		t1, t0 := power2Round(rr)
   280  
   281  		hi, err := fieldToMontgomery(uint32(t1) << 13)
   282  		if err != nil {
   283  			t.Fatalf("power2Round(%d): failed to convert high part to Montgomery: %v", x, err)
   284  		}
   285  		if r := fieldFromMontgomery(fieldAdd(hi, t0)); r != x {
   286  			t.Fatalf("power2Round(%d) = (%d, %d), which reconstructs to %d, expected %d", x, t1, t0, r, x)
   287  		}
   288  	}
   289  }
   290  
   291  func SpecDecompose(rr fieldElement, p parameters) (R1 uint32, R0 int32) {
   292  	r := fieldFromMontgomery(rr)
   293  	if (q-1)%p.γ2 != 0 {
   294  		panic("mldsa: internal error: unsupported denγ2")
   295  	}
   296  	γ2 := (q - 1) / uint32(p.γ2)
   297  	r0 := CenteredMod(r, 2*γ2)
   298  	diff := int32(r) - r0
   299  	if diff == q-1 {
   300  		r0 = r0 - 1
   301  		return 0, r0
   302  	} else {
   303  		if diff < 0 || uint32(diff)%γ2 != 0 {
   304  			panic("mldsa: internal error: invalid decomposition")
   305  		}
   306  		r1 := uint32(diff) / (2 * γ2)
   307  		return r1, r0
   308  	}
   309  }
   310  
   311  func TestDecompose(t *testing.T) {
   312  	t.Run("ML-DSA-44", func(t *testing.T) {
   313  		testDecompose(t, params44)
   314  	})
   315  	t.Run("ML-DSA-65,87", func(t *testing.T) {
   316  		testDecompose(t, params65)
   317  	})
   318  }
   319  
   320  func testDecompose(t *testing.T, p parameters) {
   321  	t.Parallel()
   322  	for x := range uint32(q) {
   323  		rr, _ := fieldToMontgomery(x)
   324  		r1, r0 := SpecDecompose(rr, p)
   325  
   326  		// Check that SpecDecompose is correct.
   327  		// r ≡ r1 * (2 * γ2) + r0 mod q
   328  		γ2 := (q - 1) / uint32(p.γ2)
   329  		reconstructed := reduceModQ(int32(r1*2*γ2) + r0)
   330  		if reconstructed != x {
   331  			t.Fatalf("SpecDecompose(%d) = (%d, %d), which reconstructs to %d, expected %d", x, r1, r0, reconstructed, x)
   332  		}
   333  
   334  		var gotR1 byte
   335  		var gotR0 int32
   336  		switch p.γ2 {
   337  		case 88:
   338  			gotR1, gotR0 = decompose88(rr)
   339  			if gotR1 > 43 {
   340  				t.Fatalf("decompose88(%d) returned r1 = %d, which is out of range", x, gotR1)
   341  			}
   342  		case 32:
   343  			gotR1, gotR0 = decompose32(rr)
   344  			if gotR1 > 15 {
   345  				t.Fatalf("decompose32(%d) returned r1 = %d, which is out of range", x, gotR1)
   346  			}
   347  		default:
   348  			t.Fatalf("unsupported denγ2: %d", p.γ2)
   349  		}
   350  		if uint32(gotR1) != r1 {
   351  			t.Fatalf("highBits(%d) = %d, expected %d", x, gotR1, r1)
   352  		}
   353  		if gotR0 != r0 {
   354  			t.Fatalf("lowBits(%d) = %d, expected %d", x, gotR0, r0)
   355  		}
   356  	}
   357  }
   358  
   359  func TestZetas(t *testing.T) {
   360  	ζ := big.NewInt(1753)
   361  	q := big.NewInt(q)
   362  	for k, zeta := range zetas {
   363  		// ζ^BitRev₈(k) mod q
   364  		exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev8(uint8(k)))), q)
   365  		got := fieldFromMontgomery(zeta)
   366  		if big.NewInt(int64(got)).Cmp(exp) != 0 {
   367  			t.Errorf("zetas[%d] = %v, expected %v", k, got, exp)
   368  		}
   369  	}
   370  }
   371  

View as plain text