Source file src/crypto/hpke/hpke_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/ecdh"
    10  	"crypto/mlkem"
    11  	"crypto/mlkem/mlkemtest"
    12  	"crypto/sha3"
    13  	"encoding/hex"
    14  	"encoding/json"
    15  	"fmt"
    16  	"io"
    17  	"os"
    18  	"testing"
    19  )
    20  
    21  func Example() {
    22  	// In this example, we use MLKEM768-X25519 as the KEM, HKDF-SHA256 as the
    23  	// KDF, and AES-256-GCM as the AEAD to encrypt a single message from a
    24  	// sender to a recipient using the one-shot API.
    25  
    26  	kem, kdf, aead := MLKEM768X25519(), HKDFSHA256(), AES256GCM()
    27  
    28  	// Recipient side
    29  	var (
    30  		recipientPrivateKey PrivateKey
    31  		publicKeyBytes      []byte
    32  	)
    33  	{
    34  		k, err := kem.GenerateKey()
    35  		if err != nil {
    36  			panic(err)
    37  		}
    38  		recipientPrivateKey = k
    39  		publicKeyBytes = k.PublicKey().Bytes()
    40  	}
    41  
    42  	// Sender side
    43  	var ciphertext []byte
    44  	{
    45  		publicKey, err := kem.NewPublicKey(publicKeyBytes)
    46  		if err != nil {
    47  			panic(err)
    48  		}
    49  
    50  		message := []byte("|-()-|")
    51  		ct, err := Seal(publicKey, kdf, aead, []byte("example"), message)
    52  		if err != nil {
    53  			panic(err)
    54  		}
    55  
    56  		ciphertext = ct
    57  	}
    58  
    59  	// Recipient side
    60  	{
    61  		plaintext, err := Open(recipientPrivateKey, kdf, aead, []byte("example"), ciphertext)
    62  		if err != nil {
    63  			panic(err)
    64  		}
    65  		fmt.Printf("Decrypted message: %s\n", plaintext)
    66  	}
    67  
    68  	// Output:
    69  	// Decrypted message: |-()-|
    70  }
    71  
    72  func mustDecodeHex(t *testing.T, in string) []byte {
    73  	t.Helper()
    74  	b, err := hex.DecodeString(in)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	return b
    79  }
    80  
    81  func TestVectors(t *testing.T) {
    82  	t.Run("rfc9180", func(t *testing.T) {
    83  		testVectors(t, "rfc9180")
    84  	})
    85  	t.Run("hpke-pq", func(t *testing.T) {
    86  		testVectors(t, "hpke-pq")
    87  	})
    88  }
    89  
// testVectors runs the HPKE test vectors stored in testdata/<name>.json.
//
// Each base-mode vector with supported algorithms is checked end to end:
//   - the KDF, AEAD, and KEM are constructed from their IDs and report them back;
//   - the recipient public and private keys are parsed and re-serialized, and
//     the key pair is re-derived from ikmR;
//   - the sender's encapsulation is made deterministic with ikmE, so the
//     encapsulated key must equal enc;
//   - Seal/Open and Export results are compared against the vector, either
//     entry by entry or through accumulated SHAKE128 digests when the vector
//     provides them.
func testVectors(t *testing.T, name string) {
	vectorsJSON, err := os.ReadFile("testdata/" + name + ".json")
	if err != nil {
		t.Fatal(err)
	}
	// One entry of the vector file; the JSON tags name the file's keys.
	var vectors []struct {
		Mode        uint16 `json:"mode"`
		KEM         uint16 `json:"kem_id"`
		KDF         uint16 `json:"kdf_id"`
		AEAD        uint16 `json:"aead_id"`
		Info        string `json:"info"`
		IkmE        string `json:"ikmE"`
		IkmR        string `json:"ikmR"`
		SkRm        string `json:"skRm"`
		PkRm        string `json:"pkRm"`
		Enc         string `json:"enc"`
		Encryptions []struct {
			Aad   string `json:"aad"`
			Ct    string `json:"ct"`
			Nonce string `json:"nonce"`
			Pt    string `json:"pt"`
		} `json:"encryptions"`
		Exports []struct {
			Context string `json:"exporter_context"`
			L       int    `json:"L"`
			Value   string `json:"exported_value"`
		} `json:"exports"`

		// Instead of checking in a very large rfc9180.json, we computed
		// alternative accumulated values.
		AccEncryptions string `json:"encryptions_accumulated"`
		AccExports     string `json:"exports_accumulated"`
	}
	if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
		t.Fatal(err)
	}

	for _, vector := range vectors {
		name := fmt.Sprintf("mode %04x kem %04x kdf %04x aead %04x",
			vector.Mode, vector.KEM, vector.KDF, vector.AEAD)
		t.Run(name, func(t *testing.T) {
			// Skip modes and algorithm IDs that this package does not implement.
			if vector.Mode != 0 {
				t.Skip("only mode 0 (base) is supported")
			}
			if vector.KEM == 0x0021 {
				t.Skip("KEM 0x0021 (DHKEM(X448)) not supported")
			}
			if vector.KEM == 0x0040 {
				t.Skip("KEM 0x0040 (ML-KEM-512) not supported")
			}
			if vector.KDF == 0x0012 || vector.KDF == 0x0013 {
				t.Skipf("TurboSHAKE KDF not supported")
			}

			// Each constructor must return an algorithm that reports the
			// requested ID.
			kdf, err := NewKDF(vector.KDF)
			if err != nil {
				t.Fatal(err)
			}
			if kdf.ID() != vector.KDF {
				t.Errorf("unexpected KDF ID: got %04x, want %04x", kdf.ID(), vector.KDF)
			}

			aead, err := NewAEAD(vector.AEAD)
			if err != nil {
				t.Fatal(err)
			}
			if aead.ID() != vector.AEAD {
				t.Errorf("unexpected AEAD ID: got %04x, want %04x", aead.ID(), vector.AEAD)
			}

			kem, err := NewKEM(vector.KEM)
			if err != nil {
				t.Fatal(err)
			}
			if kem.ID() != vector.KEM {
				t.Errorf("unexpected KEM ID: got %04x, want %04x", kem.ID(), vector.KEM)
			}

			// The parsed recipient public key must belong to this KEM and
			// re-serialize to the same bytes.
			pubKeyBytes := mustDecodeHex(t, vector.PkRm)
			kemSender, err := kem.NewPublicKey(pubKeyBytes)
			if err != nil {
				t.Fatal(err)
			}
			if kemSender.KEM() != kem {
				t.Errorf("unexpected KEM from sender: got %04x, want %04x", kemSender.KEM().ID(), kem.ID())
			}
			if !bytes.Equal(kemSender.Bytes(), pubKeyBytes) {
				t.Errorf("unexpected KEM bytes: got %x, want %x", kemSender.Bytes(), pubKeyBytes)
			}

			// Install test hooks so the sender's encapsulation consumes ikmE
			// instead of fresh randomness. This must happen before NewSender
			// so that the resulting encapsulated key equals vector.Enc.
			ikmE := mustDecodeHex(t, vector.IkmE)
			setupDerandomizedEncap(t, ikmE, kemSender)

			info := mustDecodeHex(t, vector.Info)
			encap, sender, err := NewSender(kemSender, kdf, aead, info)
			if err != nil {
				t.Fatal(err)
			}
			if len(encap) != kem.encSize() {
				t.Errorf("unexpected encapsulated key size: got %d, want %d", len(encap), kem.encSize())
			}

			expectedEncap := mustDecodeHex(t, vector.Enc)
			if !bytes.Equal(encap, expectedEncap) {
				t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
			}

			// The parsed recipient private key must belong to this KEM,
			// re-serialize to skRm (except for X25519, see below), and yield
			// the vector's public key.
			privKeyBytes := mustDecodeHex(t, vector.SkRm)
			kemRecipient, err := kem.NewPrivateKey(privKeyBytes)
			if err != nil {
				t.Fatal(err)
			}
			if kemRecipient.KEM() != kem {
				t.Errorf("unexpected KEM from recipient: got %04x, want %04x", kemRecipient.KEM().ID(), kem.ID())
			}
			kemRecipientBytes, err := kemRecipient.Bytes()
			if err != nil {
				t.Fatal(err)
			}
			// X25519 serialized keys must be clamped, so the bytes might not match.
			if !bytes.Equal(kemRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
				t.Errorf("unexpected KEM bytes: got %x, want %x", kemRecipientBytes, privKeyBytes)
			}
			// For X25519, check a round trip instead: parsing the serialized
			// key again must give the same bytes and the same public key.
			if vector.KEM == DHKEM(ecdh.X25519()).ID() {
				kem2, err := kem.NewPrivateKey(kemRecipientBytes)
				if err != nil {
					t.Fatal(err)
				}
				kemRecipientBytes2, err := kem2.Bytes()
				if err != nil {
					t.Fatal(err)
				}
				if !bytes.Equal(kemRecipientBytes2, kemRecipientBytes) {
					t.Errorf("X25519 re-serialized key differs: got %x, want %x", kemRecipientBytes2, kemRecipientBytes)
				}
				if !bytes.Equal(kem2.PublicKey().Bytes(), pubKeyBytes) {
					t.Errorf("X25519 re-derived public key differs: got %x, want %x", kem2.PublicKey().Bytes(), pubKeyBytes)
				}
			}
			if !bytes.Equal(kemRecipient.PublicKey().Bytes(), pubKeyBytes) {
				t.Errorf("unexpected KEM sender bytes: got %x, want %x", kemRecipient.PublicKey().Bytes(), pubKeyBytes)
			}

			// Deriving the key pair from ikmR must reproduce skRm (modulo
			// X25519 clamping) and pkRm.
			ikm := mustDecodeHex(t, vector.IkmR)
			derivRecipient, err := kem.DeriveKeyPair(ikm)
			if err != nil {
				t.Fatal(err)
			}
			derivRecipientBytes, err := derivRecipient.Bytes()
			if err != nil {
				t.Fatal(err)
			}
			if !bytes.Equal(derivRecipientBytes, privKeyBytes) && vector.KEM != DHKEM(ecdh.X25519()).ID() {
				t.Errorf("unexpected KEM bytes from seed: got %x, want %x", derivRecipientBytes, privKeyBytes)
			}
			if !bytes.Equal(derivRecipient.PublicKey().Bytes(), pubKeyBytes) {
				t.Errorf("unexpected KEM sender bytes from seed: got %x, want %x", derivRecipient.PublicKey().Bytes(), pubKeyBytes)
			}

			recipient, err := NewRecipient(encap, kemRecipient, kdf, aead, info)
			if err != nil {
				t.Fatal(err)
			}

			// Encryption checks. The same sender and recipient are used for
			// every message, in order, so the sequence of Seal/Open calls
			// below must match the sequence used to produce the vectors.
			if aead != ExportOnly() && len(vector.AccEncryptions) != 0 {
				// Accumulated form: 1000 (aad, plaintext) pairs are drawn
				// from an unkeyed SHAKE128 stream, every ciphertext is
				// absorbed into a second SHAKE128, and its first 16 output
				// bytes are compared with the vector.
				source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
				for range 1000 {
					aad, plaintext := drawRandomInput(t, source), drawRandomInput(t, source)
					ciphertext, err := sender.Seal(aad, plaintext)
					if err != nil {
						t.Fatal(err)
					}
					sink.Write(ciphertext)
					got, err := recipient.Open(aad, ciphertext)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(got, plaintext) {
						t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
					}
				}
				encryptions := make([]byte, 16)
				sink.Read(encryptions)
				expectedEncryptions := mustDecodeHex(t, vector.AccEncryptions)
				if !bytes.Equal(encryptions, expectedEncryptions) {
					t.Errorf("unexpected accumulated encryptions, got: %x, want %x", encryptions, expectedEncryptions)
				}
			} else if aead != ExportOnly() {
				// Explicit form: each listed encryption must reproduce its
				// ciphertext and decrypt back to its plaintext.
				for _, enc := range vector.Encryptions {
					aad := mustDecodeHex(t, enc.Aad)
					plaintext := mustDecodeHex(t, enc.Pt)
					expectedCiphertext := mustDecodeHex(t, enc.Ct)

					ciphertext, err := sender.Seal(aad, plaintext)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(ciphertext, expectedCiphertext) {
						t.Errorf("unexpected ciphertext, got: %x, want %x", ciphertext, expectedCiphertext)
					}

					got, err := recipient.Open(aad, ciphertext)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(got, plaintext) {
						t.Errorf("unexpected plaintext: got %x want %x", got, plaintext)
					}
				}
			} else {
				// Export-only AEAD: both Seal and Open must be rejected.
				if _, err := sender.Seal(nil, nil); err == nil {
					t.Error("expected error from Seal with export-only AEAD")
				}
				if _, err := recipient.Open(nil, nil); err == nil {
					t.Error("expected error from Open with export-only AEAD")
				}
			}

			// Export checks.
			if len(vector.AccExports) != 0 {
				// Accumulated form: export lengths 0 through 999 with
				// contexts drawn from an unkeyed SHAKE128 stream, absorb each
				// sender value into a second SHAKE128, and compare its first
				// 16 output bytes with the vector. The recipient must export
				// the same value each time.
				source, sink := sha3.NewSHAKE128(), sha3.NewSHAKE128()
				for l := range 1000 {
					context := string(drawRandomInput(t, source))
					value, err := sender.Export(context, l)
					if err != nil {
						t.Fatal(err)
					}
					sink.Write(value)
					got, err := recipient.Export(context, l)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(got, value) {
						t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
					}
				}
				exports := make([]byte, 16)
				sink.Read(exports)
				expectedExports := mustDecodeHex(t, vector.AccExports)
				if !bytes.Equal(exports, expectedExports) {
					t.Errorf("unexpected accumulated exports, got: %x, want %x", exports, expectedExports)
				}
			} else {
				// Explicit form: each listed export must match on both sides.
				for _, exp := range vector.Exports {
					context := string(mustDecodeHex(t, exp.Context))
					expectedValue := mustDecodeHex(t, exp.Value)

					value, err := sender.Export(context, exp.L)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(value, expectedValue) {
						t.Errorf("unexpected exported value, got: %x, want %x", value, expectedValue)
					}

					got, err := recipient.Export(context, exp.L)
					if err != nil {
						t.Fatal(err)
					}
					if !bytes.Equal(got, value) {
						t.Errorf("recipient: unexpected exported secret: got %x want %x", got, value)
					}
				}
			}
		})
	}
}
   356  
   357  func drawRandomInput(t *testing.T, r io.Reader) []byte {
   358  	t.Helper()
   359  	l := make([]byte, 1)
   360  	if _, err := r.Read(l); err != nil {
   361  		t.Fatal(err)
   362  	}
   363  	n := int(l[0])
   364  	b := make([]byte, n)
   365  	if _, err := r.Read(b); err != nil {
   366  		t.Fatal(err)
   367  	}
   368  	return b
   369  }
   370  
   371  func setupDerandomizedEncap(t *testing.T, randBytes []byte, pk PublicKey) {
   372  	t.Cleanup(func() {
   373  		testingOnlyGenerateKey = nil
   374  		testingOnlyEncapsulate = nil
   375  	})
   376  	switch pk.KEM() {
   377  	case DHKEM(ecdh.P256()), DHKEM(ecdh.P384()), DHKEM(ecdh.P521()), DHKEM(ecdh.X25519()):
   378  		r, err := pk.KEM().DeriveKeyPair(randBytes)
   379  		if err != nil {
   380  			t.Fatal(err)
   381  		}
   382  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   383  			return r.(*dhKEMPrivateKey).priv.(*ecdh.PrivateKey)
   384  		}
   385  	case mlkem768:
   386  		pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey768)
   387  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   388  			ss, ct, err := mlkemtest.Encapsulate768(pq, randBytes)
   389  			if err != nil {
   390  				t.Fatal(err)
   391  			}
   392  			return ss, ct
   393  		}
   394  	case mlkem1024:
   395  		pq := pk.(*mlkemPublicKey).pq.(*mlkem.EncapsulationKey1024)
   396  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   397  			ss, ct, err := mlkemtest.Encapsulate1024(pq, randBytes)
   398  			if err != nil {
   399  				t.Fatal(err)
   400  			}
   401  			return ss, ct
   402  		}
   403  	case mlkem768X25519:
   404  		pqRand, tRand := randBytes[:32], randBytes[32:]
   405  		pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
   406  		k, err := ecdh.X25519().NewPrivateKey(tRand)
   407  		if err != nil {
   408  			t.Fatal(err)
   409  		}
   410  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   411  			return k
   412  		}
   413  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   414  			ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
   415  			if err != nil {
   416  				t.Fatal(err)
   417  			}
   418  			return ss, ct
   419  		}
   420  	case mlkem768P256:
   421  		// The rest of randBytes are the following candidates for rejection
   422  		// sampling, but they are never reached.
   423  		pqRand, tRand := randBytes[:32], randBytes[32:64]
   424  		pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey768)
   425  		k, err := ecdh.P256().NewPrivateKey(tRand)
   426  		if err != nil {
   427  			t.Fatal(err)
   428  		}
   429  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   430  			return k
   431  		}
   432  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   433  			ss, ct, err := mlkemtest.Encapsulate768(pq, pqRand)
   434  			if err != nil {
   435  				t.Fatal(err)
   436  			}
   437  			return ss, ct
   438  		}
   439  	case mlkem1024P384:
   440  		pqRand, tRand := randBytes[:32], randBytes[32:]
   441  		pq := pk.(*hybridPublicKey).pq.(*mlkem.EncapsulationKey1024)
   442  		k, err := ecdh.P384().NewPrivateKey(tRand)
   443  		if err != nil {
   444  			t.Fatal(err)
   445  		}
   446  		testingOnlyGenerateKey = func() *ecdh.PrivateKey {
   447  			return k
   448  		}
   449  		testingOnlyEncapsulate = func() ([]byte, []byte) {
   450  			ss, ct, err := mlkemtest.Encapsulate1024(pq, pqRand)
   451  			if err != nil {
   452  				t.Fatal(err)
   453  			}
   454  			return ss, ct
   455  		}
   456  	default:
   457  		t.Fatalf("unsupported KEM %04x", pk.KEM().ID())
   458  	}
   459  }
   460  
   461  func TestSingletons(t *testing.T) {
   462  	if HKDFSHA256() != HKDFSHA256() {
   463  		t.Error("HKDFSHA256() != HKDFSHA256()")
   464  	}
   465  	if HKDFSHA384() != HKDFSHA384() {
   466  		t.Error("HKDFSHA384() != HKDFSHA384()")
   467  	}
   468  	if HKDFSHA512() != HKDFSHA512() {
   469  		t.Error("HKDFSHA512() != HKDFSHA512()")
   470  	}
   471  	if AES128GCM() != AES128GCM() {
   472  		t.Error("AES128GCM() != AES128GCM()")
   473  	}
   474  	if AES256GCM() != AES256GCM() {
   475  		t.Error("AES256GCM() != AES256GCM()")
   476  	}
   477  	if ChaCha20Poly1305() != ChaCha20Poly1305() {
   478  		t.Error("ChaCha20Poly1305() != ChaCha20Poly1305()")
   479  	}
   480  	if ExportOnly() != ExportOnly() {
   481  		t.Error("ExportOnly() != ExportOnly()")
   482  	}
   483  	if DHKEM(ecdh.P256()) != DHKEM(ecdh.P256()) {
   484  		t.Error("DHKEM(P-256) != DHKEM(P-256)")
   485  	}
   486  	if DHKEM(ecdh.P384()) != DHKEM(ecdh.P384()) {
   487  		t.Error("DHKEM(P-384) != DHKEM(P-384)")
   488  	}
   489  	if DHKEM(ecdh.P521()) != DHKEM(ecdh.P521()) {
   490  		t.Error("DHKEM(P-521) != DHKEM(P-521)")
   491  	}
   492  	if DHKEM(ecdh.X25519()) != DHKEM(ecdh.X25519()) {
   493  		t.Error("DHKEM(X25519) != DHKEM(X25519)")
   494  	}
   495  	if MLKEM768() != MLKEM768() {
   496  		t.Error("MLKEM768() != MLKEM768()")
   497  	}
   498  	if MLKEM1024() != MLKEM1024() {
   499  		t.Error("MLKEM1024() != MLKEM1024()")
   500  	}
   501  	if MLKEM768X25519() != MLKEM768X25519() {
   502  		t.Error("MLKEM768X25519() != MLKEM768X25519()")
   503  	}
   504  	if MLKEM768P256() != MLKEM768P256() {
   505  		t.Error("MLKEM768P256() != MLKEM768P256()")
   506  	}
   507  	if MLKEM1024P384() != MLKEM1024P384() {
   508  		t.Error("MLKEM1024P384() != MLKEM1024P384()")
   509  	}
   510  }
   511  

View as plain text