Source file src/crypto/sha3/sha3_test.go

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sha3_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/internal/cryptotest"
    10  	. "crypto/sha3"
    11  	"encoding/hex"
    12  	"hash"
    13  	"io"
    14  	"math/rand"
    15  	"strings"
    16  	"testing"
    17  )
    18  
    19  const testString = "brekeccakkeccak koax koax"
    20  
    21  // testDigests contains functions returning hash.Hash instances
    22  // with output-length equal to the KAT length for SHA-3, Keccak
    23  // and SHAKE instances.
    24  var testDigests = map[string]func() *SHA3{
    25  	"SHA3-224":  New224,
    26  	"SHA3-256":  New256,
    27  	"SHA3-384":  New384,
    28  	"SHA3-512":  New512,
    29  	"SHA3-Zero": func() *SHA3 { return &SHA3{} },
    30  }
    31  
    32  // testShakes contains functions that return *sha3.SHAKE instances for
    33  // with output-length equal to the KAT length.
    34  var testShakes = map[string]struct {
    35  	constructor  func(N []byte, S []byte) *SHAKE
    36  	defAlgoName  string
    37  	defCustomStr string
    38  }{
    39  	// NewCSHAKE without customization produces same result as SHAKE
    40  	"SHAKE128":   {NewCSHAKE128, "", ""},
    41  	"SHAKE256":   {NewCSHAKE256, "", ""},
    42  	"cSHAKE128":  {NewCSHAKE128, "CSHAKE128", "CustomString"},
    43  	"cSHAKE256":  {NewCSHAKE256, "CSHAKE256", "CustomString"},
    44  	"SHAKE-Zero": {func(N []byte, S []byte) *SHAKE { return &SHAKE{} }, "", ""},
    45  }
    46  
    47  func TestSHA3Hash(t *testing.T) {
    48  	cryptotest.TestAllImplementations(t, "sha3", func(t *testing.T) {
    49  		for name, f := range testDigests {
    50  			t.Run(name, func(t *testing.T) {
    51  				cryptotest.TestHash(t, func() hash.Hash { return f() })
    52  			})
    53  		}
    54  	})
    55  }
    56  
    57  // TestUnalignedWrite tests that writing data in an arbitrary pattern with
    58  // small input buffers.
    59  func TestUnalignedWrite(t *testing.T) {
    60  	cryptotest.TestAllImplementations(t, "sha3", testUnalignedWrite)
    61  }
    62  
    63  func testUnalignedWrite(t *testing.T) {
    64  	buf := sequentialBytes(0x10000)
    65  	for alg, df := range testDigests {
    66  		d := df()
    67  		d.Reset()
    68  		d.Write(buf)
    69  		want := d.Sum(nil)
    70  		d.Reset()
    71  		for i := 0; i < len(buf); {
    72  			// Cycle through offsets which make a 137 byte sequence.
    73  			// Because 137 is prime this sequence should exercise all corner cases.
    74  			offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1}
    75  			for _, j := range offsets {
    76  				if v := len(buf) - i; v < j {
    77  					j = v
    78  				}
    79  				d.Write(buf[i : i+j])
    80  				i += j
    81  			}
    82  		}
    83  		got := d.Sum(nil)
    84  		if !bytes.Equal(got, want) {
    85  			t.Errorf("Unaligned writes, alg=%s\ngot %q, want %q", alg, got, want)
    86  		}
    87  	}
    88  
    89  	// Same for SHAKE
    90  	for alg, df := range testShakes {
    91  		want := make([]byte, 16)
    92  		got := make([]byte, 16)
    93  		d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr))
    94  
    95  		d.Reset()
    96  		d.Write(buf)
    97  		d.Read(want)
    98  		d.Reset()
    99  		for i := 0; i < len(buf); {
   100  			// Cycle through offsets which make a 137 byte sequence.
   101  			// Because 137 is prime this sequence should exercise all corner cases.
   102  			offsets := [17]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1}
   103  			for _, j := range offsets {
   104  				if v := len(buf) - i; v < j {
   105  					j = v
   106  				}
   107  				d.Write(buf[i : i+j])
   108  				i += j
   109  			}
   110  		}
   111  		d.Read(got)
   112  		if !bytes.Equal(got, want) {
   113  			t.Errorf("Unaligned writes, alg=%s\ngot %q, want %q", alg, got, want)
   114  		}
   115  	}
   116  }
   117  
   118  // TestAppend checks that appending works when reallocation is necessary.
   119  func TestAppend(t *testing.T) {
   120  	cryptotest.TestAllImplementations(t, "sha3", testAppend)
   121  }
   122  
   123  func testAppend(t *testing.T) {
   124  	d := New224()
   125  
   126  	for capacity := 2; capacity <= 66; capacity += 64 {
   127  		// The first time around the loop, Sum will have to reallocate.
   128  		// The second time, it will not.
   129  		buf := make([]byte, 2, capacity)
   130  		d.Reset()
   131  		d.Write([]byte{0xcc})
   132  		buf = d.Sum(buf)
   133  		expected := "0000DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39"
   134  		if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
   135  			t.Errorf("got %s, want %s", got, expected)
   136  		}
   137  	}
   138  }
   139  
   140  // TestAppendNoRealloc tests that appending works when no reallocation is necessary.
   141  func TestAppendNoRealloc(t *testing.T) {
   142  	cryptotest.TestAllImplementations(t, "sha3", testAppendNoRealloc)
   143  }
   144  
   145  func testAppendNoRealloc(t *testing.T) {
   146  	buf := make([]byte, 1, 200)
   147  	d := New224()
   148  	d.Write([]byte{0xcc})
   149  	buf = d.Sum(buf)
   150  	expected := "00DF70ADC49B2E76EEE3A6931B93FA41841C3AF2CDF5B32A18B5478C39"
   151  	if got := strings.ToUpper(hex.EncodeToString(buf)); got != expected {
   152  		t.Errorf("got %s, want %s", got, expected)
   153  	}
   154  }
   155  
   156  // TestSqueezing checks that squeezing the full output a single time produces
   157  // the same output as repeatedly squeezing the instance.
   158  func TestSqueezing(t *testing.T) {
   159  	cryptotest.TestAllImplementations(t, "sha3", testSqueezing)
   160  }
   161  
   162  func testSqueezing(t *testing.T) {
   163  	for algo, v := range testShakes {
   164  		d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
   165  		d0.Write([]byte(testString))
   166  		ref := make([]byte, 32)
   167  		d0.Read(ref)
   168  
   169  		d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
   170  		d1.Write([]byte(testString))
   171  		var multiple []byte
   172  		for range ref {
   173  			d1.Read(make([]byte, 0))
   174  			one := make([]byte, 1)
   175  			d1.Read(one)
   176  			multiple = append(multiple, one...)
   177  		}
   178  		if !bytes.Equal(ref, multiple) {
   179  			t.Errorf("%s: squeezing %d bytes one at a time failed", algo, len(ref))
   180  		}
   181  	}
   182  }
   183  
   184  // sequentialBytes produces a buffer of size consecutive bytes 0x00, 0x01, ..., used for testing.
   185  //
   186  // The alignment of each slice is intentionally randomized to detect alignment
   187  // issues in the implementation. See https://golang.org/issue/37644.
   188  // Ideally, the compiler should fuzz the alignment itself.
   189  // (See https://golang.org/issue/35128.)
   190  func sequentialBytes(size int) []byte {
   191  	alignmentOffset := rand.Intn(8)
   192  	result := make([]byte, size+alignmentOffset)[alignmentOffset:]
   193  	for i := range result {
   194  		result[i] = byte(i)
   195  	}
   196  	return result
   197  }
   198  
   199  func TestReset(t *testing.T) {
   200  	cryptotest.TestAllImplementations(t, "sha3", testReset)
   201  }
   202  
   203  func testReset(t *testing.T) {
   204  	out1 := make([]byte, 32)
   205  	out2 := make([]byte, 32)
   206  
   207  	for _, v := range testShakes {
   208  		// Calculate hash for the first time
   209  		c := v.constructor(nil, []byte{0x99, 0x98})
   210  		c.Write(sequentialBytes(0x100))
   211  		c.Read(out1)
   212  
   213  		// Calculate hash again
   214  		c.Reset()
   215  		c.Write(sequentialBytes(0x100))
   216  		c.Read(out2)
   217  
   218  		if !bytes.Equal(out1, out2) {
   219  			t.Error("\nExpected:\n", out1, "\ngot:\n", out2)
   220  		}
   221  	}
   222  }
   223  
   224  var sinkSHA3 byte
   225  
   226  func TestAllocations(t *testing.T) {
   227  	cryptotest.SkipTestAllocations(t)
   228  	t.Run("New", func(t *testing.T) {
   229  		if allocs := testing.AllocsPerRun(10, func() {
   230  			h := New256()
   231  			b := []byte("ABC")
   232  			h.Write(b)
   233  			out := make([]byte, 0, 32)
   234  			out = h.Sum(out)
   235  			sinkSHA3 ^= out[0]
   236  		}); allocs > 0 {
   237  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   238  		}
   239  	})
   240  	t.Run("NewSHAKE", func(t *testing.T) {
   241  		if allocs := testing.AllocsPerRun(10, func() {
   242  			h := NewSHAKE128()
   243  			b := []byte("ABC")
   244  			h.Write(b)
   245  			out := make([]byte, 32)
   246  			h.Read(out)
   247  			sinkSHA3 ^= out[0]
   248  		}); allocs > 0 {
   249  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   250  		}
   251  	})
   252  	t.Run("Sum", func(t *testing.T) {
   253  		if allocs := testing.AllocsPerRun(10, func() {
   254  			b := []byte("ABC")
   255  			out := Sum256(b)
   256  			sinkSHA3 ^= out[0]
   257  		}); allocs > 0 {
   258  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   259  		}
   260  	})
   261  	t.Run("SumSHAKE", func(t *testing.T) {
   262  		if allocs := testing.AllocsPerRun(10, func() {
   263  			b := []byte("ABC")
   264  			out := SumSHAKE128(b, 10)
   265  			sinkSHA3 ^= out[0]
   266  		}); allocs > 0 {
   267  			t.Errorf("expected zero allocations, got %0.1f", allocs)
   268  		}
   269  	})
   270  }
   271  
   272  func TestCSHAKEAccumulated(t *testing.T) {
   273  	// Generated with pycryptodome@3.20.0
   274  	//
   275  	//    from Crypto.Hash import cSHAKE128
   276  	//    rng = cSHAKE128.new()
   277  	//    acc = cSHAKE128.new()
   278  	//    for n in range(200):
   279  	//        N = rng.read(n)
   280  	//        for s in range(200):
   281  	//            S = rng.read(s)
   282  	//            c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
   283  	//            c.update(rng.read(100))
   284  	//            acc.update(c.read(200))
   285  	//            c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
   286  	//            c.update(rng.read(168))
   287  	//            acc.update(c.read(200))
   288  	//            c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
   289  	//            c.update(rng.read(200))
   290  	//            acc.update(c.read(200))
   291  	//    print(acc.read(32).hex())
   292  	//
   293  	// and with @noble/hashes@v1.5.0
   294  	//
   295  	//    import { bytesToHex } from "@noble/hashes/utils";
   296  	//    import { cshake128 } from "@noble/hashes/sha3-addons";
   297  	//    const rng = cshake128.create();
   298  	//    const acc = cshake128.create();
   299  	//    for (let n = 0; n < 200; n++) {
   300  	//        const N = rng.xof(n);
   301  	//        for (let s = 0; s < 200; s++) {
   302  	//            const S = rng.xof(s);
   303  	//            let c = cshake128.create({ NISTfn: N, personalization: S });
   304  	//            c.update(rng.xof(100));
   305  	//            acc.update(c.xof(200));
   306  	//            c = cshake128.create({ NISTfn: N, personalization: S });
   307  	//            c.update(rng.xof(168));
   308  	//            acc.update(c.xof(200));
   309  	//            c = cshake128.create({ NISTfn: N, personalization: S });
   310  	//            c.update(rng.xof(200));
   311  	//            acc.update(c.xof(200));
   312  	//        }
   313  	//    }
   314  	//    console.log(bytesToHex(acc.xof(32)));
   315  	//
   316  	cryptotest.TestAllImplementations(t, "sha3", func(t *testing.T) {
   317  		t.Run("cSHAKE128", func(t *testing.T) {
   318  			testCSHAKEAccumulated(t, NewCSHAKE128, (1600-256)/8,
   319  				"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
   320  		})
   321  		t.Run("cSHAKE256", func(t *testing.T) {
   322  			testCSHAKEAccumulated(t, NewCSHAKE256, (1600-512)/8,
   323  				"0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef")
   324  		})
   325  	})
   326  }
   327  
   328  func testCSHAKEAccumulated(t *testing.T, newCSHAKE func(N, S []byte) *SHAKE, rate int64, exp string) {
   329  	rnd := newCSHAKE(nil, nil)
   330  	acc := newCSHAKE(nil, nil)
   331  	for n := 0; n < 200; n++ {
   332  		N := make([]byte, n)
   333  		rnd.Read(N)
   334  		for s := 0; s < 200; s++ {
   335  			S := make([]byte, s)
   336  			rnd.Read(S)
   337  
   338  			c := newCSHAKE(N, S)
   339  			io.CopyN(c, rnd, 100 /* < rate */)
   340  			io.CopyN(acc, c, 200)
   341  
   342  			c.Reset()
   343  			io.CopyN(c, rnd, rate)
   344  			io.CopyN(acc, c, 200)
   345  
   346  			c.Reset()
   347  			io.CopyN(c, rnd, 200 /* > rate */)
   348  			io.CopyN(acc, c, 200)
   349  		}
   350  	}
   351  	out := make([]byte, 32)
   352  	acc.Read(out)
   353  	if got := hex.EncodeToString(out); got != exp {
   354  		t.Errorf("got %s, want %s", got, exp)
   355  	}
   356  }
   357  
   358  func TestCSHAKELargeS(t *testing.T) {
   359  	cryptotest.TestAllImplementations(t, "sha3", testCSHAKELargeS)
   360  }
   361  
   362  func testCSHAKELargeS(t *testing.T) {
   363  	if testing.Short() {
   364  		t.Skip("skipping test in short mode.")
   365  	}
   366  
   367  	// See https://go.dev/issue/66232.
   368  	const s = (1<<32)/8 + 1000 // s * 8 > 2^32
   369  	S := make([]byte, s)
   370  	rnd := NewSHAKE128()
   371  	rnd.Read(S)
   372  	c := NewCSHAKE128(nil, S)
   373  	io.CopyN(c, rnd, 1000)
   374  	out := make([]byte, 32)
   375  	c.Read(out)
   376  
   377  	// Generated with pycryptodome@3.20.0
   378  	//
   379  	//    from Crypto.Hash import cSHAKE128
   380  	//    rng = cSHAKE128.new()
   381  	//    S = rng.read(536871912)
   382  	//    c = cSHAKE128.new(custom=S)
   383  	//    c.update(rng.read(1000))
   384  	//    print(c.read(32).hex())
   385  	//
   386  	exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0"
   387  	if got := hex.EncodeToString(out); got != exp {
   388  		t.Errorf("got %s, want %s", got, exp)
   389  	}
   390  }
   391  
   392  func TestMarshalUnmarshal(t *testing.T) {
   393  	cryptotest.TestAllImplementations(t, "sha3", func(t *testing.T) {
   394  		t.Run("SHA3-224", func(t *testing.T) { testMarshalUnmarshal(t, New224()) })
   395  		t.Run("SHA3-256", func(t *testing.T) { testMarshalUnmarshal(t, New256()) })
   396  		t.Run("SHA3-384", func(t *testing.T) { testMarshalUnmarshal(t, New384()) })
   397  		t.Run("SHA3-512", func(t *testing.T) { testMarshalUnmarshal(t, New512()) })
   398  		t.Run("SHAKE128", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewSHAKE128()) })
   399  		t.Run("SHAKE256", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewSHAKE256()) })
   400  		t.Run("cSHAKE128", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewCSHAKE128([]byte("N"), []byte("S"))) })
   401  		t.Run("cSHAKE256", func(t *testing.T) { testMarshalUnmarshalSHAKE(t, NewCSHAKE256([]byte("N"), []byte("S"))) })
   402  	})
   403  }
   404  
   405  // TODO(filippo): move this to crypto/internal/cryptotest.
   406  func testMarshalUnmarshal(t *testing.T, h *SHA3) {
   407  	buf := make([]byte, 200)
   408  	rand.Read(buf)
   409  	n := rand.Intn(200)
   410  	h.Write(buf)
   411  	want := h.Sum(nil)
   412  	h.Reset()
   413  	h.Write(buf[:n])
   414  	b, err := h.MarshalBinary()
   415  	if err != nil {
   416  		t.Errorf("MarshalBinary: %v", err)
   417  	}
   418  	h.Write(bytes.Repeat([]byte{0}, 200))
   419  	if err := h.UnmarshalBinary(b); err != nil {
   420  		t.Errorf("UnmarshalBinary: %v", err)
   421  	}
   422  	h.Write(buf[n:])
   423  	got := h.Sum(nil)
   424  	if !bytes.Equal(got, want) {
   425  		t.Errorf("got %x, want %x", got, want)
   426  	}
   427  }
   428  
   429  // TODO(filippo): move this to crypto/internal/cryptotest.
   430  func testMarshalUnmarshalSHAKE(t *testing.T, h *SHAKE) {
   431  	buf := make([]byte, 200)
   432  	rand.Read(buf)
   433  	n := rand.Intn(200)
   434  	h.Write(buf)
   435  	want := make([]byte, 32)
   436  	h.Read(want)
   437  	h.Reset()
   438  	h.Write(buf[:n])
   439  	b, err := h.MarshalBinary()
   440  	if err != nil {
   441  		t.Errorf("MarshalBinary: %v", err)
   442  	}
   443  	h.Write(bytes.Repeat([]byte{0}, 200))
   444  	if err := h.UnmarshalBinary(b); err != nil {
   445  		t.Errorf("UnmarshalBinary: %v", err)
   446  	}
   447  	h.Write(buf[n:])
   448  	got := make([]byte, 32)
   449  	h.Read(got)
   450  	if !bytes.Equal(got, want) {
   451  		t.Errorf("got %x, want %x", got, want)
   452  	}
   453  }
   454  
   455  // benchmarkHash tests the speed to hash num buffers of buflen each.
   456  func benchmarkHash(b *testing.B, h hash.Hash, size, num int) {
   457  	b.StopTimer()
   458  	h.Reset()
   459  	data := sequentialBytes(size)
   460  	b.SetBytes(int64(size * num))
   461  	b.StartTimer()
   462  
   463  	var state []byte
   464  	for i := 0; i < b.N; i++ {
   465  		for j := 0; j < num; j++ {
   466  			h.Write(data)
   467  		}
   468  		state = h.Sum(state[:0])
   469  	}
   470  	b.StopTimer()
   471  	h.Reset()
   472  }
   473  
   474  // benchmarkShake is specialized to the Shake instances, which don't
   475  // require a copy on reading output.
   476  func benchmarkShake(b *testing.B, h *SHAKE, size, num int) {
   477  	b.StopTimer()
   478  	h.Reset()
   479  	data := sequentialBytes(size)
   480  	d := make([]byte, 32)
   481  
   482  	b.SetBytes(int64(size * num))
   483  	b.StartTimer()
   484  
   485  	for i := 0; i < b.N; i++ {
   486  		h.Reset()
   487  		for j := 0; j < num; j++ {
   488  			h.Write(data)
   489  		}
   490  		h.Read(d)
   491  	}
   492  }
   493  
   494  func BenchmarkSha3_512_MTU(b *testing.B) { benchmarkHash(b, New512(), 1350, 1) }
   495  func BenchmarkSha3_384_MTU(b *testing.B) { benchmarkHash(b, New384(), 1350, 1) }
   496  func BenchmarkSha3_256_MTU(b *testing.B) { benchmarkHash(b, New256(), 1350, 1) }
   497  func BenchmarkSha3_224_MTU(b *testing.B) { benchmarkHash(b, New224(), 1350, 1) }
   498  
   499  func BenchmarkShake128_MTU(b *testing.B)  { benchmarkShake(b, NewSHAKE128(), 1350, 1) }
   500  func BenchmarkShake256_MTU(b *testing.B)  { benchmarkShake(b, NewSHAKE256(), 1350, 1) }
   501  func BenchmarkShake256_16x(b *testing.B)  { benchmarkShake(b, NewSHAKE256(), 16, 1024) }
   502  func BenchmarkShake256_1MiB(b *testing.B) { benchmarkShake(b, NewSHAKE256(), 1024, 1024) }
   503  
   504  func BenchmarkSha3_512_1MiB(b *testing.B) { benchmarkHash(b, New512(), 1024, 1024) }
   505  

View as plain text