Source file src/crypto/sha3/sha3_test.go

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sha3_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/internal/cryptotest"
    10  	. "crypto/sha3"
    11  	"encoding/hex"
    12  	"hash"
    13  	"io"
    14  	"math/rand"
    15  	"strings"
    16  	"testing"
    17  )
    18  
    19  const testString = "brekeccakkeccak koax koax"
    20  
    21  // testDigests contains functions returning hash.Hash instances
    22  // with output-length equal to the KAT length for SHA-3, Keccak
    23  // and SHAKE instances.
    24  var testDigests = map[string]func() *SHA3{
    25  	"SHA3-224": New224,
    26  	"SHA3-256": New256,
    27  	"SHA3-384": New384,
    28  	"SHA3-512": New512,
    29  }
    30  
    31  // testShakes contains functions that return *sha3.SHAKE instances for
    32  // with output-length equal to the KAT length.
    33  var testShakes = map[string]struct {
    34  	constructor  func(N []byte, S []byte) *SHAKE
    35  	defAlgoName  string
    36  	defCustomStr string
    37  }{
    38  	// NewCSHAKE without customization produces same result as SHAKE
    39  	"SHAKE128":  {NewCSHAKE128, "", ""},
    40  	"SHAKE256":  {NewCSHAKE256, "", ""},
    41  	"cSHAKE128": {NewCSHAKE128, "CSHAKE128", "CustomString"},
    42  	"cSHAKE256": {NewCSHAKE256, "CSHAKE256", "CustomString"},
    43  }
    44  
    45  func TestSHA3Hash(t *testing.T) {
    46  	cryptotest.TestAllImplementations(t, "sha3", func(t *testing.T) {
    47  		for name, f := range testDigests {
    48  			t.Run(name, func(t *testing.T) {
    49  				cryptotest.TestHash(t, func() hash.Hash { return f() })
    50  			})
    51  		}
    52  	})
    53  }
    54  
    55  // TestUnalignedWrite tests that writing data in an arbitrary pattern with
    56  // small input buffers.
    57  func TestUnalignedWrite(t *testing.T) {
    58  	cryptotest.TestAllImplementations(t, "sha3", testUnalignedWrite)
    59  }
    60  
    61  func testUnalignedWrite(t *testing.T) {
    62  	buf := sequentialBytes(0x10000)
    63  	for alg, df := range testDigests {
    64  		d := df()
    65  		d.Reset()
    66  		d.Write(buf)
    67  		want := d.Sum(nil)
    68  		d.Reset()
    69  		for i := 0; i < len(buf); {
    70  			// Cycle through offsets which make a 137 byte sequence.
    71  			// Because 137 is prime this sequence should exercise all corner cases.
    72  			offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1}
    73  			for _, j := range offsets {
    74  				if v := len(buf) - i; v < j {
    75  					j = v
    76  				}
    77  				d.Write(buf[i : i+j])
    78  				i += j
    79  			}
    80  		}
    81  		got := d.Sum(nil)
    82  		if !bytes.Equal(got, want) {
    83  			t.Errorf("Unaligned writes, alg=%s\ngot %q, want %q", alg, got, want)
    84  		}
    85  	}
    86  
    87  	// Same for SHAKE
    88  	for alg, df := range testShakes {
    89  		want := make([]byte, 16)
    90  		got := make([]byte, 16)
    91  		d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr))
    92  
    93  		d.Reset()
    94  		d.Write(buf)
    95  		d.Read(want)
    96  		d.Reset()
    97  		for i := 0; i < len(buf); {
    98  			// Cycle through offsets which make a 137 byte sequence.
    99  			// Because 137 is prime this sequence should exercise all corner cases.
   100  			offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1}
   101  			for _, j := range offsets {
   102  				if v := len(buf) - i; v < j {
   103  					j = v
   104  				}
   105  				d.Write(buf[i : i+j])
   106  				i += j
   107  			}
   108  		}
   109  		d.Read(got)
   110  		if !bytes.Equal(got, want) {
   111  			t.Errorf("Unaligned writes, alg=%s\ngot %q, want %q", alg, got, want)
   112  		}
   113  	}
   114  }
   115  
   116  // TestAppend checks that appending works when reallocation is necessary.
   117  func TestAppend(t *testing.T) {
   118  	cryptotest.TestAllImplementations(t, "sha3", testAppend)
   119  }
   120  
   121  func testAppend(t *testing.T) {
   122  	d := New224()
   123  
   124  	for capacity := 2; capacity <= 66; capacity += 64 {
   125  		// The first time around the loop, Sum will have to reallocate.
   126  		// The second time, it will not.
   127  		buf := make([]byte, 2, capacity)
   128  		d.Reset()
   129  		d.Write([]byte{0xcc})
   130  		buf = d.Sum(buf)
   131  		expected := "0000DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39"
   132  		if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
   133  			t.Errorf("got %s, want %s", got, expected)
   134  		}
   135  	}
   136  }
   137  
   138  // TestAppendNoRealloc tests that appending works when no reallocation is necessary.
   139  func TestAppendNoRealloc(t *testing.T) {
   140  	cryptotest.TestAllImplementations(t, "sha3", testAppendNoRealloc)
   141  }
   142  
   143  func testAppendNoRealloc(t *testing.T) {
   144  	buf := make([]byte, 1, 200)
   145  	d := New224()
   146  	d.Write([]byte{0xcc})
   147  	buf = d.Sum(buf)
   148  	expected := "00DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39"
   149  	if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
   150  		t.Errorf("got %s, want %s", got, expected)
   151  	}
   152  }
   153  
   154  // TestSqueezing checks that squeezing the full output a single time produces
   155  // the same output as repeatedly squeezing the instance.
   156  func TestSqueezing(t *testing.T) {
   157  	cryptotest.TestAllImplementations(t, "sha3", testSqueezing)
   158  }
   159  
   160  func testSqueezing(t *testing.T) {
   161  	for algo, v := range testShakes {
   162  		d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
   163  		d0.Write([]byte(testString))
   164  		ref := make([]byte, 32)
   165  		d0.Read(ref)
   166  
   167  		d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
   168  		d1.Write([]byte(testString))
   169  		var multiple []byte
   170  		for range ref {
   171  			d1.Read(make([]byte, 0))
   172  			one := make([]byte, 1)
   173  			d1.Read(one)
   174  			multiple = append(multiple, one...)
   175  		}
   176  		if !bytes.Equal(ref, multiple) {
   177  			t.Errorf("%s: squeezing %d bytes one at a time failed", algo, len(ref))
   178  		}
   179  	}
   180  }
   181  
   182  // sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing.
   183  //
   184  // The alignment of each slice is intentionally randomized to detect alignment
   185  // issues in the implementation. See https://golang.org/issue/37644.
   186  // Ideally, the compiler should fuzz the alignment itself.
   187  // (See https://golang.org/issue/35128.)
   188  func sequentialBytes(size int) []byte {
   189  	alignmentOffset := rand.Intn(8)
   190  	result := make([]byte, size+alignmentOffset)[alignmentOffset:]
   191  	for i := range result {
   192  		result[i] = byte(i)
   193  	}
   194  	return result
   195  }
   196  
   197  func TestReset(t *testing.T) {
   198  	cryptotest.TestAllImplementations(t, "sha3", testReset)
   199  }
   200  
   201  func testReset(t *testing.T) {
   202  	out1 := make([]byte, 32)
   203  	out2 := make([]byte, 32)
   204  
   205  	for _, v := range testShakes {
   206  		// Calculate hash for the first time
   207  		c := v.constructor(nil, []byte{0x99, 0x98})
   208  		c.Write(sequentialBytes(0x100))
   209  		c.Read(out1)
   210  
   211  		// Calculate hash again
   212  		c.Reset()
   213  		c.Write(sequentialBytes(0x100))
   214  		c.Read(out2)
   215  
   216  		if !bytes.Equal(out1, out2) {
   217  			t.Error("\nExpected:\n", out1, "\ngot:\n", out2)
   218  		}
   219  	}
   220  }
   221  
   222  var sinkSHA3 byte
   223  
   224  func TestAllocations(t *testing.T) {
   225  	cryptotest.SkipTestAllocations(t)
   226  	t.Run("New", func(t *testing.T) {
   227  		if allocs := testing.AllocsPerRun(10, func() {
   228  			h := New256()
   229  			b := []byte("ABC")
   230  			h.Write(b)
   231  			out := make([]byte, 0, 32)
   232  			out = h.Sum(out)
   233  			sinkSHA3 ^= out[0]
   234  		}); allocs > 0 {
   235  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   236  		}
   237  	})
   238  	t.Run("NewSHAKE", func(t *testing.T) {
   239  		if allocs := testing.AllocsPerRun(10, func() {
   240  			h := NewSHAKE128()
   241  			b := []byte("ABC")
   242  			h.Write(b)
   243  			out := make([]byte, 32)
   244  			h.Read(out)
   245  			sinkSHA3 ^= out[0]
   246  		}); allocs > 0 {
   247  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   248  		}
   249  	})
   250  	t.Run("Sum", func(t *testing.T) {
   251  		if allocs := testing.AllocsPerRun(10, func() {
   252  			b := []byte("ABC")
   253  			out := Sum256(b)
   254  			sinkSHA3 ^= out[0]
   255  		}); allocs > 0 {
   256  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   257  		}
   258  	})
   259  	t.Run("SumSHAKE", func(t *testing.T) {
   260  		if allocs := testing.AllocsPerRun(10, func() {
   261  			b := []byte("ABC")
   262  			out := SumSHAKE128(b, 10)
   263  			sinkSHA3 ^= out[0]
   264  		}); allocs > 0 {
   265  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   266  		}
   267  	})
   268  }
   269  
   270  func TestCSHAKEAccumulated(t *testing.T) {
   271  	// Generated with pycryptodome@3.20.0
   272  	//
   273  	//    from Crypto.Hash import cSHAKE128
   274  	//    rng = cSHAKE128.new()
   275  	//    acc = cSHAKE128.new()
   276  	//    for n in range(200):
   277  	//        N = rng.read(n)
   278  	//        for s in range(200):
   279  	//            S = rng.read(s)
   280  	//            c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
   281  	//            c.update(rng.read(100))
   282  	//            acc.update(c.read(200))
   283  	//            c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
   284  	//            c.update(rng.read(168))
   285  	//            acc.update(c.read(200))
   286  	//            c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
   287  	//            c.update(rng.read(200))
   288  	//            acc.update(c.read(200))
   289  	//    print(acc.read(32).hex())
   290  	//
   291  	// and with @noble/hashes@v1.5.0
   292  	//
   293  	//    import { bytesToHex } from "@noble/hashes/utils";
   294  	//    import { cshake128 } from "@noble/hashes/sha3-addons";
   295  	//    const rng = cshake128.create();
   296  	//    const acc = cshake128.create();
   297  	//    for (let n = 0; n < 200; n++) {
   298  	//        const N = rng.xof(n);
   299  	//        for (let s = 0; s < 200; s++) {
   300  	//            const S = rng.xof(s);
   301  	//            let c = cshake128.create({ NISTfn: N, personalization: S });
   302  	//            c.update(rng.xof(100));
   303  	//            acc.update(c.xof(200));
   304  	//            c = cshake128.create({ NISTfn: N, personalization: S });
   305  	//            c.update(rng.xof(168));
   306  	//            acc.update(c.xof(200));
   307  	//            c = cshake128.create({ NISTfn: N, personalization: S });
   308  	//            c.update(rng.xof(200));
   309  	//            acc.update(c.xof(200));
   310  	//        }
   311  	//    }
   312  	//    console.log(bytesToHex(acc.xof(32)));
   313  	//
   314  	cryptotest.TestAllImplementations(t, "sha3", func(t *testing.T) {
   315  		t.Run("cSHAKE128", func(t *testing.T) {
   316  			testCSHAKEAccumulated(t, NewCSHAKE128, (1600-256)/8,
   317  				"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
   318  		})
   319  		t.Run("cSHAKE256", func(t *testing.T) {
   320  			testCSHAKEAccumulated(t, NewCSHAKE256, (1600-512)/8,
   321  				"0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef")
   322  		})
   323  	})
   324  }
   325  
   326  func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *SHAKE, rate int64, exp string) {
   327  	rnd := newCSHAKE(nil, nil)
   328  	acc := newCSHAKE(nil, nil)
   329  	for n := 0; n < 200; n++ {
   330  		N := make([]byte, n)
   331  		rnd.Read(N)
   332  		for s := 0; s < 200; s++ {
   333  			S := make([]byte, s)
   334  			rnd.Read(S)
   335  
   336  			c := newCSHAKE(N, S)
   337  			io.CopyN(c, rnd, 100 /* < rate */)
   338  			io.CopyN(acc, c, 200)
   339  
   340  			c.Reset()
   341  			io.CopyN(c, rnd, rate)
   342  			io.CopyN(acc, c, 200)
   343  
   344  			c.Reset()
   345  			io.CopyN(c, rnd, 200 /* > rate */)
   346  			io.CopyN(acc, c, 200)
   347  		}
   348  	}
   349  	out := make([]byte, 32)
   350  	acc.Read(out)
   351  	if got := hex.EncodeToString(out); got != exp {
   352  		t.Errorf("got %s, want %s", got, exp)
   353  	}
   354  }
   355  
   356  func TestCSHAKELargeS(t *testing.T) {
   357  	cryptotest.TestAllImplementations(t, "sha3", testCSHAKELargeS)
   358  }
   359  
   360  func testCSHAKELargeS(t *testing.T) {
   361  	if testing.Short() {
   362  		t.Skip("skipping test in short mode.")
   363  	}
   364  
   365  	// See https://go.dev/issue/66232.
   366  	const s = (1<<32)/8 + 1000 // s * 8 > 2^32
   367  	S := make([]byte, s)
   368  	rnd := NewSHAKE128()
   369  	rnd.Read(S)
   370  	c := NewCSHAKE128(nil, S)
   371  	io.CopyN(c, rnd, 1000)
   372  	out := make([]byte, 32)
   373  	c.Read(out)
   374  
   375  	// Generated with pycryptodome@3.20.0
   376  	//
   377  	//    from Crypto.Hash import cSHAKE128
   378  	//    rng = cSHAKE128.new()
   379  	//    S = rng.read(536871912)
   380  	//    c = cSHAKE128.new(custom=S)
   381  	//    c.update(rng.read(1000))
   382  	//    print(c.read(32).hex())
   383  	//
   384  	exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0"
   385  	if got := hex.EncodeToString(out); got != exp {
   386  		t.Errorf("got %s, want %s", got, exp)
   387  	}
   388  }
   389  
   390  func TestMarshalUnmarshal(t *testing.T) {
   391  	cryptotest.TestAllImplementations(t, "sha3", func(t *testing.T) {
   392  		t.Run("SHA3-224", func(t *testing.T) { testMarshalUnmarshal(t, New224()) })
   393  		t.Run("SHA3-256", func(t *testing.T) { testMarshalUnmarshal(t, New256()) })
   394  		t.Run("SHA3-384", func(t *testing.T) { testMarshalUnmarshal(t, New384()) })
   395  		t.Run("SHA3-512", func(t *testing.T) { testMarshalUnmarshal(t, New512()) })
   396  		t.Run("SHAKE128", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewSHAKE128()) })
   397  		t.Run("SHAKE256", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewSHAKE256()) })
   398  		t.Run("cSHAKE128", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewCSHAKE128([]byte("N"), []byte("S"))) })
   399  		t.Run("cSHAKE256", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewCSHAKE256([]byte("N"), []byte("S"))) })
   400  	})
   401  }
   402  
   403  // TODO(filippo): move this to crypto/internal/cryptotest.
   404  func testMarshalUnmarshal(t *testing.T, h *SHA3) {
   405  	buf := make([]byte, 200)
   406  	rand.Read(buf)
   407  	n := rand.Intn(200)
   408  	h.Write(buf)
   409  	want := h.Sum(nil)
   410  	h.Reset()
   411  	h.Write(buf[:n])
   412  	b, err := h.MarshalBinary()
   413  	if err != nil {
   414  		t.Errorf("MarshalBinary: %v", err)
   415  	}
   416  	h.Write(bytes.Repeat([]byte{0}, 200))
   417  	if err := h.UnmarshalBinary(b); err != nil {
   418  		t.Errorf("UnmarshalBinary: %v", err)
   419  	}
   420  	h.Write(buf[n:])
   421  	got := h.Sum(nil)
   422  	if !bytes.Equal(got, want) {
   423  		t.Errorf("got %x, want %x", got, want)
   424  	}
   425  }
   426  
   427  // TODO(filippo): move this to crypto/internal/cryptotest.
   428  func testMarshalUnmarshalSHAKE(t *testing.T, h *SHAKE) {
   429  	buf := make([]byte, 200)
   430  	rand.Read(buf)
   431  	n := rand.Intn(200)
   432  	h.Write(buf)
   433  	want := make([]byte, 32)
   434  	h.Read(want)
   435  	h.Reset()
   436  	h.Write(buf[:n])
   437  	b, err := h.MarshalBinary()
   438  	if err != nil {
   439  		t.Errorf("MarshalBinary: %v", err)
   440  	}
   441  	h.Write(bytes.Repeat([]byte{0}, 200))
   442  	if err := h.UnmarshalBinary(b); err != nil {
   443  		t.Errorf("UnmarshalBinary: %v", err)
   444  	}
   445  	h.Write(buf[n:])
   446  	got := make([]byte, 32)
   447  	h.Read(got)
   448  	if !bytes.Equal(got, want) {
   449  		t.Errorf("got %x, want %x", got, want)
   450  	}
   451  }
   452  
   453  // benchmarkHash tests the speed to hash num buffers of buflen each.
   454  func benchmarkHash(b *testing.B, h hash.Hash, size, num int) {
   455  	b.StopTimer()
   456  	h.Reset()
   457  	data := sequentialBytes(size)
   458  	b.SetBytes(int64(size * num))
   459  	b.StartTimer()
   460  
   461  	var state []byte
   462  	for i := 0; i < b.N; i++ {
   463  		for j := 0; j < num; j++ {
   464  			h.Write(data)
   465  		}
   466  		state = h.Sum(state[:0])
   467  	}
   468  	b.StopTimer()
   469  	h.Reset()
   470  }
   471  
   472  // benchmarkShake is specialized to the Shake instances, which don't
   473  // require a copy on reading output.
   474  func benchmarkShake(b *testing.B, h *SHAKE, size, num int) {
   475  	b.StopTimer()
   476  	h.Reset()
   477  	data := sequentialBytes(size)
   478  	d := make([]byte, 32)
   479  
   480  	b.SetBytes(int64(size * num))
   481  	b.StartTimer()
   482  
   483  	for i := 0; i < b.N; i++ {
   484  		h.Reset()
   485  		for j := 0; j < num; j++ {
   486  			h.Write(data)
   487  		}
   488  		h.Read(d)
   489  	}
   490  }
   491  
// Fixed-output SHA-3: one 1350-byte write per digest (the _MTU suffix suggests
// an MTU-sized packet as the model input).
func BenchmarkSha3_512_MTU(b *testing.B) { benchmarkHash(b, New512(), 1350, 1) }
func BenchmarkSha3_384_MTU(b *testing.B) { benchmarkHash(b, New384(), 1350, 1) }
func BenchmarkSha3_256_MTU(b *testing.B) { benchmarkHash(b, New256(), 1350, 1) }
func BenchmarkSha3_224_MTU(b *testing.B) { benchmarkHash(b, New224(), 1350, 1) }

// SHAKE: one 1350-byte write, 1024 writes of 16 bytes, or 1024 writes of
// 1024 bytes (1 MiB total), each followed by a 32-byte Read.
func BenchmarkShake128_MTU(b *testing.B)  { benchmarkShake(b, NewSHAKE128(), 1350, 1) }
func BenchmarkShake256_MTU(b *testing.B)  { benchmarkShake(b, NewSHAKE256(), 1350, 1) }
func BenchmarkShake256_16x(b *testing.B)  { benchmarkShake(b, NewSHAKE256(), 16, 1024) }
func BenchmarkShake256_1MiB(b *testing.B) { benchmarkShake(b, NewSHAKE256(), 1024, 1024) }

// SHA3-512 over 1024 writes of 1024 bytes (1 MiB) per digest.
func BenchmarkSha3_512_1MiB(b *testing.B) { benchmarkHash(b, New512(), 1024, 1024) }
   503  

View as plain text