Source file src/crypto/internal/hpke/hpke_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/hex"
    10  	"encoding/json"
    11  	"os"
    12  	"strconv"
    13  	"strings"
    14  	"testing"
    15  
    16  	"crypto/ecdh"
    17  	_ "crypto/sha256"
    18  	_ "crypto/sha512"
    19  )
    20  
    21  func mustDecodeHex(t *testing.T, in string) []byte {
    22  	t.Helper()
    23  	b, err := hex.DecodeString(in)
    24  	if err != nil {
    25  		t.Fatal(err)
    26  	}
    27  	return b
    28  }
    29  
    30  func parseVectorSetup(vector string) map[string]string {
    31  	vals := map[string]string{}
    32  	for _, l := range strings.Split(vector, "\n") {
    33  		fields := strings.Split(l, ": ")
    34  		vals[fields[0]] = fields[1]
    35  	}
    36  	return vals
    37  }
    38  
    39  func parseVectorEncryptions(vector string) []map[string]string {
    40  	vals := []map[string]string{}
    41  	for _, section := range strings.Split(vector, "\n\n") {
    42  		e := map[string]string{}
    43  		for _, l := range strings.Split(section, "\n") {
    44  			fields := strings.Split(l, ": ")
    45  			e[fields[0]] = fields[1]
    46  		}
    47  		vals = append(vals, e)
    48  	}
    49  	return vals
    50  }
    51  
    52  func TestRFC9180Vectors(t *testing.T) {
    53  	vectorsJSON, err := os.ReadFile("testdata/rfc9180-vectors.json")
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  
    58  	var vectors []struct {
    59  		Name        string
    60  		Setup       string
    61  		Encryptions string
    62  	}
    63  	if err := json.Unmarshal(vectorsJSON, &vectors); err != nil {
    64  		t.Fatal(err)
    65  	}
    66  
    67  	for _, vector := range vectors {
    68  		t.Run(vector.Name, func(t *testing.T) {
    69  			setup := parseVectorSetup(vector.Setup)
    70  
    71  			kemID, err := strconv.Atoi(setup["kem_id"])
    72  			if err != nil {
    73  				t.Fatal(err)
    74  			}
    75  			if _, ok := SupportedKEMs[uint16(kemID)]; !ok {
    76  				t.Skip("unsupported KEM")
    77  			}
    78  			kdfID, err := strconv.Atoi(setup["kdf_id"])
    79  			if err != nil {
    80  				t.Fatal(err)
    81  			}
    82  			if _, ok := SupportedKDFs[uint16(kdfID)]; !ok {
    83  				t.Skip("unsupported KDF")
    84  			}
    85  			aeadID, err := strconv.Atoi(setup["aead_id"])
    86  			if err != nil {
    87  				t.Fatal(err)
    88  			}
    89  			if _, ok := SupportedAEADs[uint16(aeadID)]; !ok {
    90  				t.Skip("unsupported AEAD")
    91  			}
    92  
    93  			info := mustDecodeHex(t, setup["info"])
    94  			pubKeyBytes := mustDecodeHex(t, setup["pkRm"])
    95  			pub, err := ParseHPKEPublicKey(uint16(kemID), pubKeyBytes)
    96  			if err != nil {
    97  				t.Fatal(err)
    98  			}
    99  
   100  			ephemeralPrivKey := mustDecodeHex(t, setup["skEm"])
   101  
   102  			testingOnlyGenerateKey = func() (*ecdh.PrivateKey, error) {
   103  				return SupportedKEMs[uint16(kemID)].curve.NewPrivateKey(ephemeralPrivKey)
   104  			}
   105  			t.Cleanup(func() { testingOnlyGenerateKey = nil })
   106  
   107  			encap, sender, err := SetupSender(
   108  				uint16(kemID),
   109  				uint16(kdfID),
   110  				uint16(aeadID),
   111  				pub,
   112  				info,
   113  			)
   114  			if err != nil {
   115  				t.Fatal(err)
   116  			}
   117  
   118  			expectedEncap := mustDecodeHex(t, setup["enc"])
   119  			if !bytes.Equal(encap, expectedEncap) {
   120  				t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap)
   121  			}
   122  
   123  			privKeyBytes := mustDecodeHex(t, setup["skRm"])
   124  			priv, err := ParseHPKEPrivateKey(uint16(kemID), privKeyBytes)
   125  			if err != nil {
   126  				t.Fatal(err)
   127  			}
   128  
   129  			receipient, err := SetupReceipient(
   130  				uint16(kemID),
   131  				uint16(kdfID),
   132  				uint16(aeadID),
   133  				priv,
   134  				info,
   135  				encap,
   136  			)
   137  			if err != nil {
   138  				t.Fatal(err)
   139  			}
   140  
   141  			for _, ctx := range []*context{sender.context, receipient.context} {
   142  				expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"])
   143  				if !bytes.Equal(ctx.sharedSecret, expectedSharedSecret) {
   144  					t.Errorf("unexpected shared secret, got: %x, want %x", ctx.sharedSecret, expectedSharedSecret)
   145  				}
   146  				expectedKey := mustDecodeHex(t, setup["key"])
   147  				if !bytes.Equal(ctx.key, expectedKey) {
   148  					t.Errorf("unexpected key, got: %x, want %x", ctx.key, expectedKey)
   149  				}
   150  				expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"])
   151  				if !bytes.Equal(ctx.baseNonce, expectedBaseNonce) {
   152  					t.Errorf("unexpected base nonce, got: %x, want %x", ctx.baseNonce, expectedBaseNonce)
   153  				}
   154  				expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"])
   155  				if !bytes.Equal(ctx.exporterSecret, expectedExporterSecret) {
   156  					t.Errorf("unexpected exporter secret, got: %x, want %x", ctx.exporterSecret, expectedExporterSecret)
   157  				}
   158  			}
   159  
   160  			for _, enc := range parseVectorEncryptions(vector.Encryptions) {
   161  				t.Run("seq num "+enc["sequence number"], func(t *testing.T) {
   162  					seqNum, err := strconv.Atoi(enc["sequence number"])
   163  					if err != nil {
   164  						t.Fatal(err)
   165  					}
   166  					sender.seqNum = uint128{lo: uint64(seqNum)}
   167  					receipient.seqNum = uint128{lo: uint64(seqNum)}
   168  					expectedNonce := mustDecodeHex(t, enc["nonce"])
   169  					computedNonce := sender.nextNonce()
   170  					if !bytes.Equal(computedNonce, expectedNonce) {
   171  						t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce)
   172  					}
   173  
   174  					expectedCiphertext := mustDecodeHex(t, enc["ct"])
   175  					ciphertext, err := sender.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"]))
   176  					if err != nil {
   177  						t.Fatal(err)
   178  					}
   179  					if !bytes.Equal(ciphertext, expectedCiphertext) {
   180  						t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext)
   181  					}
   182  
   183  					expectedPlaintext := mustDecodeHex(t, enc["pt"])
   184  					plaintext, err := receipient.Open(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["ct"]))
   185  					if err != nil {
   186  						t.Fatal(err)
   187  					}
   188  					if !bytes.Equal(plaintext, expectedPlaintext) {
   189  						t.Errorf("unexpected plaintext: got %x want %x", plaintext, expectedPlaintext)
   190  					}
   191  				})
   192  			}
   193  		})
   194  	}
   195  }
   196  

View as plain text