Source file src/crypto/internal/hpke/hpke.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"crypto"
     9  	"crypto/aes"
    10  	"crypto/cipher"
    11  	"crypto/ecdh"
    12  	"crypto/internal/fips140/hkdf"
    13  	"crypto/rand"
    14  	"errors"
    15  	"internal/byteorder"
    16  	"math/bits"
    17  
    18  	"golang.org/x/crypto/chacha20poly1305"
    19  )
    20  
    21  // testingOnlyGenerateKey is only used during testing, to provide
    22  // a fixed test key to use when checking the RFC 9180 vectors.
    23  var testingOnlyGenerateKey func() (*ecdh.PrivateKey, error)
    24  
    25  type hkdfKDF struct {
    26  	hash crypto.Hash
    27  }
    28  
    29  func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte {
    30  	labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey))
    31  	labeledIKM = append(labeledIKM, []byte("HPKE-v1")...)
    32  	labeledIKM = append(labeledIKM, sid...)
    33  	labeledIKM = append(labeledIKM, label...)
    34  	labeledIKM = append(labeledIKM, inputKey...)
    35  	return hkdf.Extract(kdf.hash.New, labeledIKM, salt)
    36  }
    37  
    38  func (kdf *hkdfKDF) LabeledExpand(suiteID []byte, randomKey []byte, label string, info []byte, length uint16) []byte {
    39  	labeledInfo := make([]byte, 0, 2+7+len(suiteID)+len(label)+len(info))
    40  	labeledInfo = byteorder.BEAppendUint16(labeledInfo, length)
    41  	labeledInfo = append(labeledInfo, []byte("HPKE-v1")...)
    42  	labeledInfo = append(labeledInfo, suiteID...)
    43  	labeledInfo = append(labeledInfo, label...)
    44  	labeledInfo = append(labeledInfo, info...)
    45  	return hkdf.Expand(kdf.hash.New, randomKey, string(labeledInfo), int(length))
    46  }
    47  
    48  // dhKEM implements the KEM specified in RFC 9180, Section 4.1.
    49  type dhKEM struct {
    50  	dh  ecdh.Curve
    51  	kdf hkdfKDF
    52  
    53  	suiteID []byte
    54  	nSecret uint16
    55  }
    56  
    57  type KemID uint16
    58  
    59  const DHKEM_X25519_HKDF_SHA256 = 0x0020
    60  
    61  var SupportedKEMs = map[uint16]struct {
    62  	curve   ecdh.Curve
    63  	hash    crypto.Hash
    64  	nSecret uint16
    65  }{
    66  	// RFC 9180 Section 7.1
    67  	DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32},
    68  }
    69  
    70  func newDHKem(kemID uint16) (*dhKEM, error) {
    71  	suite, ok := SupportedKEMs[kemID]
    72  	if !ok {
    73  		return nil, errors.New("unsupported suite ID")
    74  	}
    75  	return &dhKEM{
    76  		dh:      suite.curve,
    77  		kdf:     hkdfKDF{suite.hash},
    78  		suiteID: byteorder.BEAppendUint16([]byte("KEM"), kemID),
    79  		nSecret: suite.nSecret,
    80  	}, nil
    81  }
    82  
    83  func (dh *dhKEM) ExtractAndExpand(dhKey, kemContext []byte) []byte {
    84  	eaePRK := dh.kdf.LabeledExtract(dh.suiteID[:], nil, "eae_prk", dhKey)
    85  	return dh.kdf.LabeledExpand(dh.suiteID[:], eaePRK, "shared_secret", kemContext, dh.nSecret)
    86  }
    87  
    88  func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encapPub []byte, err error) {
    89  	var privEph *ecdh.PrivateKey
    90  	if testingOnlyGenerateKey != nil {
    91  		privEph, err = testingOnlyGenerateKey()
    92  	} else {
    93  		privEph, err = dh.dh.GenerateKey(rand.Reader)
    94  	}
    95  	if err != nil {
    96  		return nil, nil, err
    97  	}
    98  	dhVal, err := privEph.ECDH(pubRecipient)
    99  	if err != nil {
   100  		return nil, nil, err
   101  	}
   102  	encPubEph := privEph.PublicKey().Bytes()
   103  
   104  	encPubRecip := pubRecipient.Bytes()
   105  	kemContext := append(encPubEph, encPubRecip...)
   106  
   107  	return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil
   108  }
   109  
   110  func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) {
   111  	pubEph, err := dh.dh.NewPublicKey(encPubEph)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	dhVal, err := secRecipient.ECDH(pubEph)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...)
   120  
   121  	return dh.ExtractAndExpand(dhVal, kemContext), nil
   122  }
   123  
   124  type context struct {
   125  	aead cipher.AEAD
   126  
   127  	sharedSecret []byte
   128  
   129  	suiteID []byte
   130  
   131  	key            []byte
   132  	baseNonce      []byte
   133  	exporterSecret []byte
   134  
   135  	seqNum uint128
   136  }
   137  
   138  type Sender struct {
   139  	*context
   140  }
   141  
   142  type Receipient struct {
   143  	*context
   144  }
   145  
   146  var aesGCMNew = func(key []byte) (cipher.AEAD, error) {
   147  	block, err := aes.NewCipher(key)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	return cipher.NewGCM(block)
   152  }
   153  
   154  type AEADID uint16
   155  
   156  const (
   157  	AEAD_AES_128_GCM      = 0x0001
   158  	AEAD_AES_256_GCM      = 0x0002
   159  	AEAD_ChaCha20Poly1305 = 0x0003
   160  )
   161  
   162  var SupportedAEADs = map[uint16]struct {
   163  	keySize   int
   164  	nonceSize int
   165  	aead      func([]byte) (cipher.AEAD, error)
   166  }{
   167  	// RFC 9180, Section 7.3
   168  	AEAD_AES_128_GCM:      {keySize: 16, nonceSize: 12, aead: aesGCMNew},
   169  	AEAD_AES_256_GCM:      {keySize: 32, nonceSize: 12, aead: aesGCMNew},
   170  	AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New},
   171  }
   172  
   173  type KDFID uint16
   174  
   175  const KDF_HKDF_SHA256 = 0x0001
   176  
   177  var SupportedKDFs = map[uint16]func() *hkdfKDF{
   178  	// RFC 9180, Section 7.2
   179  	KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} },
   180  }
   181  
   182  func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) {
   183  	sid := suiteID(kemID, kdfID, aeadID)
   184  
   185  	kdfInit, ok := SupportedKDFs[kdfID]
   186  	if !ok {
   187  		return nil, errors.New("unsupported KDF id")
   188  	}
   189  	kdf := kdfInit()
   190  
   191  	aeadInfo, ok := SupportedAEADs[aeadID]
   192  	if !ok {
   193  		return nil, errors.New("unsupported AEAD id")
   194  	}
   195  
   196  	pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil)
   197  	infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info)
   198  	ksContext := append([]byte{0}, pskIDHash...)
   199  	ksContext = append(ksContext, infoHash...)
   200  
   201  	secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil)
   202  
   203  	key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */)
   204  	baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */)
   205  	exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/)
   206  
   207  	aead, err := aeadInfo.aead(key)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  
   212  	return &context{
   213  		aead:           aead,
   214  		sharedSecret:   sharedSecret,
   215  		suiteID:        sid,
   216  		key:            key,
   217  		baseNonce:      baseNonce,
   218  		exporterSecret: exporterSecret,
   219  	}, nil
   220  }
   221  
   222  func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) {
   223  	kem, err := newDHKem(kemID)
   224  	if err != nil {
   225  		return nil, nil, err
   226  	}
   227  	sharedSecret, encapsulatedKey, err := kem.Encap(pub)
   228  	if err != nil {
   229  		return nil, nil, err
   230  	}
   231  
   232  	context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
   233  	if err != nil {
   234  		return nil, nil, err
   235  	}
   236  
   237  	return encapsulatedKey, &Sender{context}, nil
   238  }
   239  
   240  func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) {
   241  	kem, err := newDHKem(kemID)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	sharedSecret, err := kem.Decap(encPubEph, priv)
   246  	if err != nil {
   247  		return nil, err
   248  	}
   249  
   250  	context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	return &Receipient{context}, nil
   256  }
   257  
   258  func (ctx *context) nextNonce() []byte {
   259  	nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():]
   260  	for i := range ctx.baseNonce {
   261  		nonce[i] ^= ctx.baseNonce[i]
   262  	}
   263  	return nonce
   264  }
   265  
   266  func (ctx *context) incrementNonce() {
   267  	// Message limit is, according to the RFC, 2^95+1, which
   268  	// is somewhat confusing, but we do as we're told.
   269  	if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 {
   270  		panic("message limit reached")
   271  	}
   272  	ctx.seqNum = ctx.seqNum.addOne()
   273  }
   274  
   275  func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) {
   276  	ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad)
   277  	s.incrementNonce()
   278  	return ciphertext, nil
   279  }
   280  
   281  func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) {
   282  	plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  	r.incrementNonce()
   287  	return plaintext, nil
   288  }
   289  
   290  func suiteID(kemID, kdfID, aeadID uint16) []byte {
   291  	suiteID := make([]byte, 0, 4+2+2+2)
   292  	suiteID = append(suiteID, []byte("HPKE")...)
   293  	suiteID = byteorder.BEAppendUint16(suiteID, kemID)
   294  	suiteID = byteorder.BEAppendUint16(suiteID, kdfID)
   295  	suiteID = byteorder.BEAppendUint16(suiteID, aeadID)
   296  	return suiteID
   297  }
   298  
   299  func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) {
   300  	kemInfo, ok := SupportedKEMs[kemID]
   301  	if !ok {
   302  		return nil, errors.New("unsupported KEM id")
   303  	}
   304  	return kemInfo.curve.NewPublicKey(bytes)
   305  }
   306  
   307  func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) {
   308  	kemInfo, ok := SupportedKEMs[kemID]
   309  	if !ok {
   310  		return nil, errors.New("unsupported KEM id")
   311  	}
   312  	return kemInfo.curve.NewPrivateKey(bytes)
   313  }
   314  
   315  type uint128 struct {
   316  	hi, lo uint64
   317  }
   318  
   319  func (u uint128) addOne() uint128 {
   320  	lo, carry := bits.Add64(u.lo, 1, 0)
   321  	return uint128{u.hi + carry, lo}
   322  }
   323  
   324  func (u uint128) bitLen() int {
   325  	return bits.Len64(u.hi) + bits.Len64(u.lo)
   326  }
   327  
   328  func (u uint128) bytes() []byte {
   329  	b := make([]byte, 16)
   330  	byteorder.BEPutUint64(b[0:], u.hi)
   331  	byteorder.BEPutUint64(b[8:], u.lo)
   332  	return b
   333  }
   334  

View as plain text