Source file src/crypto/internal/fips140/mlkem/field_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mlkem
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"math/big"
    11  	mathrand "math/rand/v2"
    12  	"strconv"
    13  	"testing"
    14  )
    15  
    16  func TestFieldReduce(t *testing.T) {
    17  	for a := uint32(0); a < 2*q*q; a++ {
    18  		got := fieldReduce(a)
    19  		exp := fieldElement(a % q)
    20  		if got != exp {
    21  			t.Fatalf("reduce(%d) = %d, expected %d", a, got, exp)
    22  		}
    23  	}
    24  }
    25  
    26  func TestFieldAdd(t *testing.T) {
    27  	for a := fieldElement(0); a < q; a++ {
    28  		for b := fieldElement(0); b < q; b++ {
    29  			got := fieldAdd(a, b)
    30  			exp := (a + b) % q
    31  			if got != exp {
    32  				t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp)
    33  			}
    34  		}
    35  	}
    36  }
    37  
    38  func TestFieldSub(t *testing.T) {
    39  	for a := fieldElement(0); a < q; a++ {
    40  		for b := fieldElement(0); b < q; b++ {
    41  			got := fieldSub(a, b)
    42  			exp := (a - b + q) % q
    43  			if got != exp {
    44  				t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp)
    45  			}
    46  		}
    47  	}
    48  }
    49  
    50  func TestFieldMul(t *testing.T) {
    51  	for a := fieldElement(0); a < q; a++ {
    52  		for b := fieldElement(0); b < q; b++ {
    53  			got := fieldMul(a, b)
    54  			exp := fieldElement((uint32(a) * uint32(b)) % q)
    55  			if got != exp {
    56  				t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp)
    57  			}
    58  		}
    59  	}
    60  }
    61  
    62  func TestDecompressCompress(t *testing.T) {
    63  	for _, bits := range []uint8{1, 4, 10} {
    64  		for a := uint16(0); a < 1<<bits; a++ {
    65  			f := decompress(a, bits)
    66  			if f >= q {
    67  				t.Fatalf("decompress(%d, %d) = %d >= q", a, bits, f)
    68  			}
    69  			got := compress(f, bits)
    70  			if got != a {
    71  				t.Fatalf("compress(decompress(%d, %d), %d) = %d", a, bits, bits, got)
    72  			}
    73  		}
    74  
    75  		for a := fieldElement(0); a < q; a++ {
    76  			c := compress(a, bits)
    77  			if c >= 1<<bits {
    78  				t.Fatalf("compress(%d, %d) = %d >= 2^bits", a, bits, c)
    79  			}
    80  			got := decompress(c, bits)
    81  			diff := min(a-got, got-a, a-got+q, got-a+q)
    82  			ceil := q / (1 << bits)
    83  			if diff > fieldElement(ceil) {
    84  				t.Fatalf("decompress(compress(%d, %d), %d) = %d (diff %d, max diff %d)",
    85  					a, bits, bits, got, diff, ceil)
    86  			}
    87  		}
    88  	}
    89  }
    90  
    91  func CompressRat(x fieldElement, d uint8) uint16 {
    92  	if x >= q {
    93  		panic("x out of range")
    94  	}
    95  	if d <= 0 || d >= 12 {
    96  		panic("d out of range")
    97  	}
    98  
    99  	precise := big.NewRat((1<<d)*int64(x), q) // (2ᵈ / q) * x == (2ᵈ * x) / q
   100  
   101  	// FloatString rounds halves away from 0, and our result should always be positive,
   102  	// so it should work as we expect. (There's no direct way to round a Rat.)
   103  	rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
   104  	if err != nil {
   105  		panic(err)
   106  	}
   107  
   108  	// If we rounded up, `rounded` may be equal to 2ᵈ, so we perform a final reduction.
   109  	return uint16(rounded % (1 << d))
   110  }
   111  
   112  func TestCompress(t *testing.T) {
   113  	for d := 1; d < 12; d++ {
   114  		for n := 0; n < q; n++ {
   115  			expected := CompressRat(fieldElement(n), uint8(d))
   116  			result := compress(fieldElement(n), uint8(d))
   117  			if result != expected {
   118  				t.Errorf("compress(%d, %d): got %d, expected %d", n, d, result, expected)
   119  			}
   120  		}
   121  	}
   122  }
   123  
   124  func DecompressRat(y uint16, d uint8) fieldElement {
   125  	if y >= 1<<d {
   126  		panic("y out of range")
   127  	}
   128  	if d <= 0 || d >= 12 {
   129  		panic("d out of range")
   130  	}
   131  
   132  	precise := big.NewRat(q*int64(y), 1<<d) // (q / 2ᵈ) * y  ==  (q * y) / 2ᵈ
   133  
   134  	// FloatString rounds halves away from 0, and our result should always be positive,
   135  	// so it should work as we expect. (There's no direct way to round a Rat.)
   136  	rounded, err := strconv.ParseInt(precise.FloatString(0), 10, 64)
   137  	if err != nil {
   138  		panic(err)
   139  	}
   140  
   141  	// If we rounded up, `rounded` may be equal to q, so we perform a final reduction.
   142  	return fieldElement(rounded % q)
   143  }
   144  
   145  func TestDecompress(t *testing.T) {
   146  	for d := 1; d < 12; d++ {
   147  		for n := 0; n < (1 << d); n++ {
   148  			expected := DecompressRat(uint16(n), uint8(d))
   149  			result := decompress(uint16(n), uint8(d))
   150  			if result != expected {
   151  				t.Errorf("decompress(%d, %d): got %d, expected %d", n, d, result, expected)
   152  			}
   153  		}
   154  	}
   155  }
   156  
   157  func randomRingElement() ringElement {
   158  	var r ringElement
   159  	for i := range r {
   160  		r[i] = fieldElement(mathrand.IntN(q))
   161  	}
   162  	return r
   163  }
   164  
   165  func TestEncodeDecode(t *testing.T) {
   166  	f := randomRingElement()
   167  	b := make([]byte, 12*n/8)
   168  	rand.Read(b)
   169  
   170  	// Compare ringCompressAndEncode to ringCompressAndEncodeN.
   171  	e1 := ringCompressAndEncode(nil, f, 10)
   172  	e2 := ringCompressAndEncode10(nil, f)
   173  	if !bytes.Equal(e1, e2) {
   174  		t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode10 = %x", e1, e2)
   175  	}
   176  	e1 = ringCompressAndEncode(nil, f, 4)
   177  	e2 = ringCompressAndEncode4(nil, f)
   178  	if !bytes.Equal(e1, e2) {
   179  		t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode4 = %x", e1, e2)
   180  	}
   181  	e1 = ringCompressAndEncode(nil, f, 1)
   182  	e2 = ringCompressAndEncode1(nil, f)
   183  	if !bytes.Equal(e1, e2) {
   184  		t.Errorf("ringCompressAndEncode = %x, ringCompressAndEncode1 = %x", e1, e2)
   185  	}
   186  
   187  	// Compare ringDecodeAndDecompress to ringDecodeAndDecompressN.
   188  	g1 := ringDecodeAndDecompress(b[:encodingSize10], 10)
   189  	g2 := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
   190  	if g1 != g2 {
   191  		t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress10 = %v", g1, g2)
   192  	}
   193  	g1 = ringDecodeAndDecompress(b[:encodingSize4], 4)
   194  	g2 = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
   195  	if g1 != g2 {
   196  		t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress4 = %v", g1, g2)
   197  	}
   198  	g1 = ringDecodeAndDecompress(b[:encodingSize1], 1)
   199  	g2 = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
   200  	if g1 != g2 {
   201  		t.Errorf("ringDecodeAndDecompress = %v, ringDecodeAndDecompress1 = %v", g1, g2)
   202  	}
   203  
   204  	// Round-trip ringCompressAndEncode and ringDecodeAndDecompress.
   205  	for d := 1; d < 12; d++ {
   206  		encodingSize := d * n / 8
   207  		g := ringDecodeAndDecompress(b[:encodingSize], uint8(d))
   208  		out := ringCompressAndEncode(nil, g, uint8(d))
   209  		if !bytes.Equal(out, b[:encodingSize]) {
   210  			t.Errorf("roundtrip failed for d = %d", d)
   211  		}
   212  	}
   213  
   214  	// Round-trip ringCompressAndEncodeN and ringDecodeAndDecompressN.
   215  	g := ringDecodeAndDecompress10((*[encodingSize10]byte)(b))
   216  	out := ringCompressAndEncode10(nil, g)
   217  	if !bytes.Equal(out, b[:encodingSize10]) {
   218  		t.Errorf("roundtrip failed for specialized 10")
   219  	}
   220  	g = ringDecodeAndDecompress4((*[encodingSize4]byte)(b))
   221  	out = ringCompressAndEncode4(nil, g)
   222  	if !bytes.Equal(out, b[:encodingSize4]) {
   223  		t.Errorf("roundtrip failed for specialized 4")
   224  	}
   225  	g = ringDecodeAndDecompress1((*[encodingSize1]byte)(b))
   226  	out = ringCompressAndEncode1(nil, g)
   227  	if !bytes.Equal(out, b[:encodingSize1]) {
   228  		t.Errorf("roundtrip failed for specialized 1")
   229  	}
   230  }
   231  
   // BitRev7 reverses the low seven bits of n, so bit i of the input becomes
   // bit 6-i of the result. It panics if n does not fit in seven bits.
   func BitRev7(n uint8) uint8 {
   	if n >= 1<<7 {
   		panic("not 7 bits")
   	}
   	var rev uint8
   	// Shift bits out of n from the bottom and into rev from the bottom, which
   	// reverses their order after seven steps.
   	for i := 0; i < 7; i++ {
   		rev = rev<<1 | n&1
   		n >>= 1
   	}
   	return rev
   }
   246  
   247  func TestZetas(t *testing.T) {
   248  	ζ := big.NewInt(17)
   249  	q := big.NewInt(q)
   250  	for k, zeta := range zetas {
   251  		// ζ^BitRev7(k) mod q
   252  		exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))), q)
   253  		if big.NewInt(int64(zeta)).Cmp(exp) != 0 {
   254  			t.Errorf("zetas[%d] = %v, expected %v", k, zeta, exp)
   255  		}
   256  	}
   257  }
   258  
   259  func TestGammas(t *testing.T) {
   260  	ζ := big.NewInt(17)
   261  	q := big.NewInt(q)
   262  	for k, gamma := range gammas {
   263  		// ζ^2BitRev7(i)+1
   264  		exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev7(uint8(k)))*2+1), q)
   265  		if big.NewInt(int64(gamma)).Cmp(exp) != 0 {
   266  			t.Errorf("gammas[%d] = %v, expected %v", k, gamma, exp)
   267  		}
   268  	}
   269  }
   270  

View as plain text