Source file src/crypto/tls/handshake_messages.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"strings"
    11  
    12  	"golang.org/x/crypto/cryptobyte"
    13  )
    14  
    15  // The marshalingFunction type is an adapter to allow the use of ordinary
    16  // functions as cryptobyte.MarshalingValue.
    17  type marshalingFunction func(b *cryptobyte.Builder) error
    18  
    19  func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
    20  	return f(b)
    21  }
    22  
    23  // addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
    24  // the length of the sequence is not the value specified, it produces an error.
    25  func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
    26  	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
    27  		if len(v) != n {
    28  			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
    29  		}
    30  		b.AddBytes(v)
    31  		return nil
    32  	}))
    33  }
    34  
    35  // addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
    36  func addUint64(b *cryptobyte.Builder, v uint64) {
    37  	b.AddUint32(uint32(v >> 32))
    38  	b.AddUint32(uint32(v))
    39  }
    40  
    41  // readUint64 decodes a big-endian, 64-bit value into out and advances over it.
    42  // It reports whether the read was successful.
    43  func readUint64(s *cryptobyte.String, out *uint64) bool {
    44  	var hi, lo uint32
    45  	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
    46  		return false
    47  	}
    48  	*out = uint64(hi)<<32 | uint64(lo)
    49  	return true
    50  }
    51  
    52  // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
    53  // []byte instead of a cryptobyte.String.
    54  func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    55  	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
    56  }
    57  
    58  // readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
    59  // []byte instead of a cryptobyte.String.
    60  func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    61  	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
    62  }
    63  
    64  // readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
    65  // []byte instead of a cryptobyte.String.
    66  func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
    67  	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
    68  }
    69  
// clientHelloMsg holds the contents of a ClientHello handshake message.
//
// original is the byte slice most recently passed to unmarshal; it is returned
// by originalBytes and preferred over a fresh marshal by
// marshalWithoutBinders.
//
// vers, random, sessionId, cipherSuites and compressionMethods are the fixed
// fields of the message body; marshal requires random to be exactly 32 bytes.
// Every other field mirrors one extension:
//
//   - A boolean field records that an extension is present (or, for
//     ocspStapling, that the status_request type was OCSP).
//   - A slice or string field carries the extension payload, and marshal emits
//     the extension only when the field is non-empty.
//   - secureRenegotiationSupported is also set by unmarshal when the cipher
//     suite list contains scsvRenegotiation.
//   - quicTransportParameters distinguishes nil (extension absent) from a
//     non-nil empty slice (extension present with zero-length data).
//   - pskIdentities and pskBinders together form the pre_shared_key
//     extension; marshal writes it only when pskIdentities is non-empty.
type clientHelloMsg struct {
	original                         []byte
	vers                             uint16
	random                           []byte
	sessionId                        []byte
	cipherSuites                     []uint16
	compressionMethods               []uint8
	serverName                       string
	ocspStapling                     bool
	supportedCurves                  []CurveID
	supportedPoints                  []uint8
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []SignatureScheme
	supportedSignatureAlgorithmsCert []SignatureScheme
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	extendedMasterSecret             bool
	alpnProtocols                    []string
	scts                             bool
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []keyShare
	earlyData                        bool
	pskModes                         []uint8
	pskIdentities                    []pskIdentity
	pskBinders                       [][]byte
	quicTransportParameters          []byte
}
    99  
// marshal encodes m as a ClientHello handshake message: a one-byte
// typeClientHello, a uint24 length, and the message body.
//
// The extensions are serialized first into their own builder, in the fixed
// order below, and are then appended to the body as a uint16-length-prefixed
// block. When no extension is emitted the extensions block is omitted
// entirely. An extension is emitted only when its corresponding field is set
// (non-empty slice or string, true flag, or non-nil quicTransportParameters).
//
// It returns an error if either builder reports one, for example when
// m.random is not exactly 32 bytes long.
func (m *clientHelloMsg) marshal() ([]byte, error) {
	var exts cryptobyte.Builder
	if len(m.serverName) > 0 {
		// RFC 6066, Section 3
		exts.AddUint16(extensionServerName)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8(0) // name_type = host_name
				exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes([]byte(m.serverName))
				})
			})
		})
	}
	if m.ocspStapling {
		// RFC 4366, Section 3.6
		exts.AddUint16(extensionStatusRequest)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8(1)  // status_type = ocsp
			exts.AddUint16(0) // empty responder_id_list
			exts.AddUint16(0) // empty request_extensions
		})
	}
	if len(m.supportedCurves) > 0 {
		// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
		exts.AddUint16(extensionSupportedCurves)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, curve := range m.supportedCurves {
					exts.AddUint16(uint16(curve))
				}
			})
		})
	}
	if len(m.supportedPoints) > 0 {
		// RFC 4492, Section 5.1.2
		exts.AddUint16(extensionSupportedPoints)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.supportedPoints)
			})
		})
	}
	if m.ticketSupported {
		// RFC 5077, Section 3.2
		exts.AddUint16(extensionSessionTicket)
		// The ticket is the entire extension body; it may be empty.
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddBytes(m.sessionTicket)
		})
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// RFC 5246, Section 7.4.1.4.1
		exts.AddUint16(extensionSignatureAlgorithms)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, sigAlgo := range m.supportedSignatureAlgorithms {
					exts.AddUint16(uint16(sigAlgo))
				}
			})
		})
	}
	if len(m.supportedSignatureAlgorithmsCert) > 0 {
		// RFC 8446, Section 4.2.3
		exts.AddUint16(extensionSignatureAlgorithmsCert)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
					exts.AddUint16(uint16(sigAlgo))
				}
			})
		})
	}
	if m.secureRenegotiationSupported {
		// RFC 5746, Section 3.2
		exts.AddUint16(extensionRenegotiationInfo)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.secureRenegotiation)
			})
		})
	}
	if m.extendedMasterSecret {
		// RFC 7627
		exts.AddUint16(extensionExtendedMasterSecret)
		exts.AddUint16(0) // empty extension_data
	}
	if len(m.alpnProtocols) > 0 {
		// RFC 7301, Section 3.1
		exts.AddUint16(extensionALPN)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, proto := range m.alpnProtocols {
					exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes([]byte(proto))
					})
				}
			})
		})
	}
	if m.scts {
		// RFC 6962, Section 3.3.1
		exts.AddUint16(extensionSCT)
		exts.AddUint16(0) // empty extension_data
	}
	if len(m.supportedVersions) > 0 {
		// RFC 8446, Section 4.2.1
		exts.AddUint16(extensionSupportedVersions)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, vers := range m.supportedVersions {
					exts.AddUint16(vers)
				}
			})
		})
	}
	if len(m.cookie) > 0 {
		// RFC 8446, Section 4.2.2
		exts.AddUint16(extensionCookie)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.cookie)
			})
		})
	}
	if len(m.keyShares) > 0 {
		// RFC 8446, Section 4.2.8
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, ks := range m.keyShares {
					exts.AddUint16(uint16(ks.group))
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(ks.data)
					})
				}
			})
		})
	}
	if m.earlyData {
		// RFC 8446, Section 4.2.10
		exts.AddUint16(extensionEarlyData)
		exts.AddUint16(0) // empty extension_data
	}
	if len(m.pskModes) > 0 {
		// RFC 8446, Section 4.2.9
		exts.AddUint16(extensionPSKModes)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.pskModes)
			})
		})
	}
	if m.quicTransportParameters != nil { // marshal zero-length parameters when present
		// RFC 9001, Section 8.2
		exts.AddUint16(extensionQUICTransportParameters)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddBytes(m.quicTransportParameters)
		})
	}
	if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
		// RFC 8446, Section 4.2.11
		exts.AddUint16(extensionPreSharedKey)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			// Identities list: each entry is a length-prefixed label
			// followed by its obfuscated ticket age.
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, psk := range m.pskIdentities {
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(psk.label)
					})
					exts.AddUint32(psk.obfuscatedTicketAge)
				}
			})
			// Binders list: written last, so that it forms the tail of the
			// message that marshalWithoutBinders trims off.
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, binder := range m.pskBinders {
					exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(binder)
					})
				}
			})
		})
	}
	extBytes, err := exts.Bytes()
	if err != nil {
		return nil, err
	}

	var b cryptobyte.Builder
	b.AddUint8(typeClientHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
			for _, suite := range m.cipherSuites {
				b.AddUint16(suite)
			}
		})
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.compressionMethods)
		})

		// Omit the extensions block, including its length prefix, when no
		// extension was written.
		if len(extBytes) > 0 {
			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
				b.AddBytes(extBytes)
			})
		}
	})

	return b.Bytes()
}
   311  
   312  // marshalWithoutBinders returns the ClientHello through the
   313  // PreSharedKeyExtension.identities field, according to RFC 8446, Section
   314  // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
   315  func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
   316  	bindersLen := 2 // uint16 length prefix
   317  	for _, binder := range m.pskBinders {
   318  		bindersLen += 1 // uint8 length prefix
   319  		bindersLen += len(binder)
   320  	}
   321  
   322  	var fullMessage []byte
   323  	if m.original != nil {
   324  		fullMessage = m.original
   325  	} else {
   326  		var err error
   327  		fullMessage, err = m.marshal()
   328  		if err != nil {
   329  			return nil, err
   330  		}
   331  	}
   332  	return fullMessage[:len(fullMessage)-bindersLen], nil
   333  }
   334  
   335  // updateBinders updates the m.pskBinders field. The supplied binders must have
   336  // the same length as the current m.pskBinders.
   337  func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
   338  	if len(pskBinders) != len(m.pskBinders) {
   339  		return errors.New("tls: internal error: pskBinders length mismatch")
   340  	}
   341  	for i := range m.pskBinders {
   342  		if len(pskBinders[i]) != len(m.pskBinders[i]) {
   343  			return errors.New("tls: internal error: pskBinders length mismatch")
   344  		}
   345  	}
   346  	m.pskBinders = pskBinders
   347  
   348  	return nil
   349  }
   350  
// unmarshal parses data as a ClientHello handshake message into m and reports
// whether parsing succeeded.
//
// m is reset first, so fields from a previous message do not survive, and
// data is kept in m.original. The four-byte handshake header (message type
// and uint24 length) is skipped without being checked. On failure m may be
// left partially populated.
//
// Parsing fails if the extensions block is followed by trailing bytes, if an
// extension type appears more than once, or if a recognized extension is
// malformed or carries trailing data. Unrecognized extensions are skipped.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	*m = clientHelloMsg{original: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) {
		return false
	}

	var cipherSuites cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
		return false
	}
	// cipherSuites is non-nil even when the list on the wire is empty.
	m.cipherSuites = []uint16{}
	m.secureRenegotiationSupported = false
	for !cipherSuites.Empty() {
		var suite uint16
		if !cipherSuites.ReadUint16(&suite) {
			return false
		}
		// The signaling value is recorded as renegotiation support but is
		// still kept in the cipher suite list.
		if suite == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
		m.cipherSuites = append(m.cipherSuites, suite)
	}

	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
		return false
	}

	if s.Empty() {
		// ClientHello is optionally followed by extension data
		return true
	}

	// The extensions block must account for every remaining byte.
	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// seenExts tracks extension types, known or not, to reject duplicates.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		case extensionServerName:
			// RFC 6066, Section 3
			var nameList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
				return false
			}
			for !nameList.Empty() {
				var nameType uint8
				var serverName cryptobyte.String
				if !nameList.ReadUint8(&nameType) ||
					!nameList.ReadUint16LengthPrefixed(&serverName) ||
					serverName.Empty() {
					return false
				}
				// Skip entries whose name_type is not host_name (0).
				if nameType != 0 {
					continue
				}
				if len(m.serverName) != 0 {
					// Multiple names of the same name_type are prohibited.
					return false
				}
				m.serverName = string(serverName)
				// An SNI value may not include a trailing dot.
				if strings.HasSuffix(m.serverName, ".") {
					return false
				}
			}
		case extensionStatusRequest:
			// RFC 4366, Section 3.6
			var statusType uint8
			var ignored cryptobyte.String
			// The two length-prefixed fields are parsed for well-formedness
			// only; their contents are discarded.
			if !extData.ReadUint8(&statusType) ||
				!extData.ReadUint16LengthPrefixed(&ignored) ||
				!extData.ReadUint16LengthPrefixed(&ignored) {
				return false
			}
			m.ocspStapling = statusType == statusTypeOCSP
		case extensionSupportedCurves:
			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
			var curves cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
				return false
			}
			for !curves.Empty() {
				var curve uint16
				if !curves.ReadUint16(&curve) {
					return false
				}
				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		case extensionSessionTicket:
			// RFC 5077, Section 3.2
			m.ticketSupported = true
			// The whole extension body is the ticket; it may be empty.
			extData.ReadBytes(&m.sessionTicket, len(extData))
		case extensionSignatureAlgorithms:
			// RFC 5246, Section 7.4.1.4.1
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithms = append(
					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
			}
		case extensionSignatureAlgorithmsCert:
			// RFC 8446, Section 4.2.3
			var sigAndAlgs cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
				return false
			}
			for !sigAndAlgs.Empty() {
				var sigAndAlg uint16
				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
					return false
				}
				m.supportedSignatureAlgorithmsCert = append(
					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
			}
		case extensionRenegotiationInfo:
			// RFC 5746, Section 3.2
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionExtendedMasterSecret:
			// RFC 7627
			m.extendedMasterSecret = true
		case extensionALPN:
			// RFC 7301, Section 3.1
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			for !protoList.Empty() {
				var proto cryptobyte.String
				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
					return false
				}
				m.alpnProtocols = append(m.alpnProtocols, string(proto))
			}
		case extensionSCT:
			// RFC 6962, Section 3.3.1
			m.scts = true
		case extensionSupportedVersions:
			// RFC 8446, Section 4.2.1
			var versList cryptobyte.String
			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
				return false
			}
			for !versList.Empty() {
				var vers uint16
				if !versList.ReadUint16(&vers) {
					return false
				}
				m.supportedVersions = append(m.supportedVersions, vers)
			}
		case extensionCookie:
			// RFC 8446, Section 4.2.2
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// RFC 8446, Section 4.2.8
			// Unlike most lists here, an empty list of shares is accepted.
			var clientShares cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&clientShares) {
				return false
			}
			for !clientShares.Empty() {
				var ks keyShare
				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
					len(ks.data) == 0 {
					return false
				}
				m.keyShares = append(m.keyShares, ks)
			}
		case extensionEarlyData:
			// RFC 8446, Section 4.2.10
			m.earlyData = true
		case extensionPSKModes:
			// RFC 8446, Section 4.2.9
			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
				return false
			}
		case extensionQUICTransportParameters:
			// The parameters are copied into a freshly allocated slice, which
			// is non-nil even when the extension body is empty.
			m.quicTransportParameters = make([]byte, len(extData))
			if !extData.CopyBytes(m.quicTransportParameters) {
				return false
			}
		case extensionPreSharedKey:
			// RFC 8446, Section 4.2.11
			if !extensions.Empty() {
				return false // pre_shared_key must be the last extension
			}
			var identities cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
				return false
			}
			for !identities.Empty() {
				var psk pskIdentity
				if !readUint16LengthPrefixed(&identities, &psk.label) ||
					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
					len(psk.label) == 0 {
					return false
				}
				m.pskIdentities = append(m.pskIdentities, psk)
			}
			var binders cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
				return false
			}
			for !binders.Empty() {
				var binder []byte
				if !readUint8LengthPrefixed(&binders, &binder) ||
					len(binder) == 0 {
					return false
				}
				m.pskBinders = append(m.pskBinders, binder)
			}
		default:
			// Ignore unknown extensions.
			// The continue also bypasses the trailing-data check below.
			continue
		}

		// A recognized extension must have been consumed completely.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
   609  
   610  func (m *clientHelloMsg) originalBytes() []byte {
   611  	return m.original
   612  }
   613  
// serverHelloMsg holds the contents of a ServerHello handshake message. The
// same struct also carries the HelloRetryRequest extensions, cookie and
// selectedGroup.
//
// original is the byte slice most recently passed to unmarshal and is
// returned by originalBytes.
//
// vers, random, sessionId, cipherSuite and compressionMethod are the fixed
// fields of the message body; marshal requires random to be exactly 32 bytes.
// The remaining fields mirror extensions:
//
//   - Boolean fields record that an extension with empty data is present.
//   - secureRenegotiationSupported controls whether secureRenegotiation is
//     written.
//   - supportedVersion, serverShare.group and selectedGroup are treated as
//     absent by marshal when zero.
//   - serverShare and selectedGroup both use the key_share extension type;
//     unmarshal tells them apart by the extension length.
//   - selectedIdentity is only written when selectedIdentityPresent is set,
//     so that an index of zero can be represented.
type serverHelloMsg struct {
	original                     []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	ocspStapling                 bool
	ticketSupported              bool
	secureRenegotiationSupported bool
	secureRenegotiation          []byte
	extendedMasterSecret         bool
	alpnProtocol                 string
	scts                         [][]byte
	supportedVersion             uint16
	serverShare                  keyShare
	selectedIdentityPresent      bool
	selectedIdentity             uint16
	supportedPoints              []uint8

	// HelloRetryRequest extensions
	cookie        []byte
	selectedGroup CurveID
}
   638  
// marshal encodes m as a ServerHello handshake message: a one-byte
// typeServerHello, a uint24 length, and the message body.
//
// Extensions are serialized into a separate builder first, in the fixed order
// below, and are then appended as a uint16-length-prefixed block, which is
// omitted when no extension is emitted. The HelloRetryRequest extensions
// carried by the struct (cookie and selectedGroup) are written here as well.
//
// It returns an error if either builder reports one, for example when
// m.random is not exactly 32 bytes long.
func (m *serverHelloMsg) marshal() ([]byte, error) {
	var exts cryptobyte.Builder
	if m.ocspStapling {
		exts.AddUint16(extensionStatusRequest)
		exts.AddUint16(0) // empty extension_data
	}
	if m.ticketSupported {
		exts.AddUint16(extensionSessionTicket)
		exts.AddUint16(0) // empty extension_data
	}
	if m.secureRenegotiationSupported {
		exts.AddUint16(extensionRenegotiationInfo)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.secureRenegotiation)
			})
		})
	}
	if m.extendedMasterSecret {
		exts.AddUint16(extensionExtendedMasterSecret)
		exts.AddUint16(0) // empty extension_data
	}
	if len(m.alpnProtocol) > 0 {
		exts.AddUint16(extensionALPN)
		// Written as a protocol list containing exactly one entry.
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
					exts.AddBytes([]byte(m.alpnProtocol))
				})
			})
		})
	}
	if len(m.scts) > 0 {
		exts.AddUint16(extensionSCT)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				for _, sct := range m.scts {
					exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
						exts.AddBytes(sct)
					})
				}
			})
		})
	}
	if m.supportedVersion != 0 {
		exts.AddUint16(extensionSupportedVersions)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(m.supportedVersion)
		})
	}
	if m.serverShare.group != 0 {
		// ServerHello form of key_share: group followed by key data.
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(uint16(m.serverShare.group))
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.serverShare.data)
			})
		})
	}
	if m.selectedIdentityPresent {
		exts.AddUint16(extensionPreSharedKey)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(m.selectedIdentity)
		})
	}

	if len(m.cookie) > 0 {
		exts.AddUint16(extensionCookie)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.cookie)
			})
		})
	}
	if m.selectedGroup != 0 {
		// HelloRetryRequest form of key_share: the group only.
		// NOTE(review): if m.serverShare.group is also non-zero, a second
		// key_share extension is emitted; callers presumably set only one of
		// the two.
		exts.AddUint16(extensionKeyShare)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint16(uint16(m.selectedGroup))
		})
	}
	if len(m.supportedPoints) > 0 {
		exts.AddUint16(extensionSupportedPoints)
		exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
			exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
				exts.AddBytes(m.supportedPoints)
			})
		})
	}

	extBytes, err := exts.Bytes()
	if err != nil {
		return nil, err
	}

	var b cryptobyte.Builder
	b.AddUint8(typeServerHello)
	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
		b.AddUint16(m.vers)
		addBytesWithLength(b, m.random, 32)
		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
			b.AddBytes(m.sessionId)
		})
		b.AddUint16(m.cipherSuite)
		b.AddUint8(m.compressionMethod)

		// Omit the extensions block, including its length prefix, when no
		// extension was written.
		if len(extBytes) > 0 {
			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
				b.AddBytes(extBytes)
			})
		}
	})

	return b.Bytes()
}
   753  
// unmarshal parses data as a ServerHello handshake message into m and reports
// whether parsing succeeded.
//
// m is reset first and data is kept in m.original. The four-byte handshake
// header (message type and uint24 length) is skipped without being checked.
// On failure m may be left partially populated.
//
// Parsing fails if the extensions block is followed by trailing bytes, if an
// extension type appears more than once, or if a recognized extension is
// malformed or carries trailing data. Unrecognized extensions are skipped.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	*m = serverHelloMsg{original: data}
	s := cryptobyte.String(data)

	if !s.Skip(4) || // message type and uint24 length field
		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
		!readUint8LengthPrefixed(&s, &m.sessionId) ||
		!s.ReadUint16(&m.cipherSuite) ||
		!s.ReadUint8(&m.compressionMethod) {
		return false
	}

	if s.Empty() {
		// ServerHello is optionally followed by extension data
		return true
	}

	// The extensions block must account for every remaining byte.
	var extensions cryptobyte.String
	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
		return false
	}

	// seenExts tracks extension types, known or not, to reject duplicates.
	seenExts := make(map[uint16]bool)
	for !extensions.Empty() {
		var extension uint16
		var extData cryptobyte.String
		if !extensions.ReadUint16(&extension) ||
			!extensions.ReadUint16LengthPrefixed(&extData) {
			return false
		}

		if seenExts[extension] {
			return false
		}
		seenExts[extension] = true

		switch extension {
		// The flag-only extensions below read nothing from extData, so the
		// trailing-data check after the switch requires them to be empty.
		case extensionStatusRequest:
			m.ocspStapling = true
		case extensionSessionTicket:
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
				return false
			}
			m.secureRenegotiationSupported = true
		case extensionExtendedMasterSecret:
			m.extendedMasterSecret = true
		case extensionALPN:
			var protoList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
				return false
			}
			// The list must contain exactly one non-empty protocol.
			var proto cryptobyte.String
			if !protoList.ReadUint8LengthPrefixed(&proto) ||
				proto.Empty() || !protoList.Empty() {
				return false
			}
			m.alpnProtocol = string(proto)
		case extensionSCT:
			var sctList cryptobyte.String
			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
				return false
			}
			for !sctList.Empty() {
				var sct []byte
				if !readUint16LengthPrefixed(&sctList, &sct) ||
					len(sct) == 0 {
					return false
				}
				m.scts = append(m.scts, sct)
			}
		case extensionSupportedVersions:
			if !extData.ReadUint16(&m.supportedVersion) {
				return false
			}
		case extensionCookie:
			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
				len(m.cookie) == 0 {
				return false
			}
		case extensionKeyShare:
			// This extension has different formats in SH and HRR, accept either
			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
			// A body of exactly two bytes is read as a bare group
			// (selectedGroup); anything else as a group plus key data
			// (serverShare).
			if len(extData) == 2 {
				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
					return false
				}
			} else {
				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
					return false
				}
			}
		case extensionPreSharedKey:
			m.selectedIdentityPresent = true
			if !extData.ReadUint16(&m.selectedIdentity) {
				return false
			}
		case extensionSupportedPoints:
			// RFC 4492, Section 5.1.2
			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
				len(m.supportedPoints) == 0 {
				return false
			}
		default:
			// Ignore unknown extensions.
			// The continue also bypasses the trailing-data check below.
			continue
		}

		// A recognized extension must have been consumed completely.
		if !extData.Empty() {
			return false
		}
	}

	return true
}
   871  
   872  func (m *serverHelloMsg) originalBytes() []byte {
   873  	return m.original
   874  }
   875  
// encryptedExtensionsMsg holds the extensions that this package reads from,
// and writes to, an EncryptedExtensions handshake message.
type encryptedExtensionsMsg struct {
	// alpnProtocol is the single protocol name carried in the ALPN extension.
	// The extension is not marshaled when the string is empty.
	alpnProtocol            string
	// quicTransportParameters is the body of the QUIC transport parameters
	// extension. nil means the extension is absent; a non-nil, zero-length
	// slice is marshaled as an empty extension.
	quicTransportParameters []byte
	// earlyData records whether the early_data extension (which has an empty
	// body here) is present.
	earlyData               bool
}
   881  
   882  func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
   883  	var b cryptobyte.Builder
   884  	b.AddUint8(typeEncryptedExtensions)
   885  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   886  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   887  			if len(m.alpnProtocol) > 0 {
   888  				b.AddUint16(extensionALPN)
   889  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   890  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   891  						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
   892  							b.AddBytes([]byte(m.alpnProtocol))
   893  						})
   894  					})
   895  				})
   896  			}
   897  			if m.quicTransportParameters != nil { // marshal zero-length parameters when present
   898  				// draft-ietf-quic-tls-32, Section 8.2
   899  				b.AddUint16(extensionQUICTransportParameters)
   900  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
   901  					b.AddBytes(m.quicTransportParameters)
   902  				})
   903  			}
   904  			if m.earlyData {
   905  				// RFC 8446, Section 4.2.10
   906  				b.AddUint16(extensionEarlyData)
   907  				b.AddUint16(0) // empty extension_data
   908  			}
   909  		})
   910  	})
   911  
   912  	return b.Bytes()
   913  }
   914  
   915  func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
   916  	*m = encryptedExtensionsMsg{}
   917  	s := cryptobyte.String(data)
   918  
   919  	var extensions cryptobyte.String
   920  	if !s.Skip(4) || // message type and uint24 length field
   921  		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   922  		return false
   923  	}
   924  
   925  	for !extensions.Empty() {
   926  		var extension uint16
   927  		var extData cryptobyte.String
   928  		if !extensions.ReadUint16(&extension) ||
   929  			!extensions.ReadUint16LengthPrefixed(&extData) {
   930  			return false
   931  		}
   932  
   933  		switch extension {
   934  		case extensionALPN:
   935  			var protoList cryptobyte.String
   936  			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
   937  				return false
   938  			}
   939  			var proto cryptobyte.String
   940  			if !protoList.ReadUint8LengthPrefixed(&proto) ||
   941  				proto.Empty() || !protoList.Empty() {
   942  				return false
   943  			}
   944  			m.alpnProtocol = string(proto)
   945  		case extensionQUICTransportParameters:
   946  			m.quicTransportParameters = make([]byte, len(extData))
   947  			if !extData.CopyBytes(m.quicTransportParameters) {
   948  				return false
   949  			}
   950  		case extensionEarlyData:
   951  			// RFC 8446, Section 4.2.10
   952  			m.earlyData = true
   953  		default:
   954  			// Ignore unknown extensions.
   955  			continue
   956  		}
   957  
   958  		if !extData.Empty() {
   959  			return false
   960  		}
   961  	}
   962  
   963  	return true
   964  }
   965  
// endOfEarlyDataMsg is a handshake message with no fields; it is marshaled as
// a four-byte header and nothing else.
type endOfEarlyDataMsg struct{}
   967  
   968  func (m *endOfEarlyDataMsg) marshal() ([]byte, error) {
   969  	x := make([]byte, 4)
   970  	x[0] = typeEndOfEarlyData
   971  	return x, nil
   972  }
   973  
   974  func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
   975  	return len(data) == 4
   976  }
   977  
// keyUpdateMsg is the in-memory form of a KeyUpdate handshake message.
type keyUpdateMsg struct {
	// updateRequested is carried on the wire as a single byte: 1 when true,
	// 0 when false.
	updateRequested bool
}
   981  
   982  func (m *keyUpdateMsg) marshal() ([]byte, error) {
   983  	var b cryptobyte.Builder
   984  	b.AddUint8(typeKeyUpdate)
   985  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
   986  		if m.updateRequested {
   987  			b.AddUint8(1)
   988  		} else {
   989  			b.AddUint8(0)
   990  		}
   991  	})
   992  
   993  	return b.Bytes()
   994  }
   995  
   996  func (m *keyUpdateMsg) unmarshal(data []byte) bool {
   997  	s := cryptobyte.String(data)
   998  
   999  	var updateRequested uint8
  1000  	if !s.Skip(4) || // message type and uint24 length field
  1001  		!s.ReadUint8(&updateRequested) || !s.Empty() {
  1002  		return false
  1003  	}
  1004  	switch updateRequested {
  1005  	case 0:
  1006  		m.updateRequested = false
  1007  	case 1:
  1008  		m.updateRequested = true
  1009  	default:
  1010  		return false
  1011  	}
  1012  	return true
  1013  }
  1014  
// newSessionTicketMsgTLS13 is the in-memory form of a TLS 1.3 NewSessionTicket
// handshake message. The fields are listed in wire order.
type newSessionTicketMsgTLS13 struct {
	// lifetime and ageAdd are each written as a big-endian uint32.
	lifetime     uint32
	ageAdd       uint32
	// nonce is written with a one-byte length prefix.
	nonce        []byte
	// label is written with a two-byte length prefix.
	label        []byte
	// maxEarlyData is the value of the early_data extension. The extension is
	// omitted when it is zero.
	maxEarlyData uint32
}
  1022  
  1023  func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
  1024  	var b cryptobyte.Builder
  1025  	b.AddUint8(typeNewSessionTicket)
  1026  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1027  		b.AddUint32(m.lifetime)
  1028  		b.AddUint32(m.ageAdd)
  1029  		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
  1030  			b.AddBytes(m.nonce)
  1031  		})
  1032  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1033  			b.AddBytes(m.label)
  1034  		})
  1035  
  1036  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1037  			if m.maxEarlyData > 0 {
  1038  				b.AddUint16(extensionEarlyData)
  1039  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1040  					b.AddUint32(m.maxEarlyData)
  1041  				})
  1042  			}
  1043  		})
  1044  	})
  1045  
  1046  	return b.Bytes()
  1047  }
  1048  
  1049  func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
  1050  	*m = newSessionTicketMsgTLS13{}
  1051  	s := cryptobyte.String(data)
  1052  
  1053  	var extensions cryptobyte.String
  1054  	if !s.Skip(4) || // message type and uint24 length field
  1055  		!s.ReadUint32(&m.lifetime) ||
  1056  		!s.ReadUint32(&m.ageAdd) ||
  1057  		!readUint8LengthPrefixed(&s, &m.nonce) ||
  1058  		!readUint16LengthPrefixed(&s, &m.label) ||
  1059  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1060  		!s.Empty() {
  1061  		return false
  1062  	}
  1063  
  1064  	for !extensions.Empty() {
  1065  		var extension uint16
  1066  		var extData cryptobyte.String
  1067  		if !extensions.ReadUint16(&extension) ||
  1068  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1069  			return false
  1070  		}
  1071  
  1072  		switch extension {
  1073  		case extensionEarlyData:
  1074  			if !extData.ReadUint32(&m.maxEarlyData) {
  1075  				return false
  1076  			}
  1077  		default:
  1078  			// Ignore unknown extensions.
  1079  			continue
  1080  		}
  1081  
  1082  		if !extData.Empty() {
  1083  			return false
  1084  		}
  1085  	}
  1086  
  1087  	return true
  1088  }
  1089  
// certificateRequestMsgTLS13 holds the extensions that this package reads
// from, and writes to, a TLS 1.3 CertificateRequest handshake message.
type certificateRequestMsgTLS13 struct {
	// ocspStapling and scts record the presence of the status_request and
	// signed certificate timestamp extensions; both are marshaled with an
	// empty body.
	ocspStapling                     bool
	scts                             bool
	// The remaining fields are the contents of the signature_algorithms,
	// signature_algorithms_cert and certificate_authorities extensions. Each
	// extension is omitted from the marshaled message when its list is empty.
	supportedSignatureAlgorithms     []SignatureScheme
	supportedSignatureAlgorithmsCert []SignatureScheme
	certificateAuthorities           [][]byte
}
  1097  
  1098  func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
  1099  	var b cryptobyte.Builder
  1100  	b.AddUint8(typeCertificateRequest)
  1101  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1102  		// certificate_request_context (SHALL be zero length unless used for
  1103  		// post-handshake authentication)
  1104  		b.AddUint8(0)
  1105  
  1106  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1107  			if m.ocspStapling {
  1108  				b.AddUint16(extensionStatusRequest)
  1109  				b.AddUint16(0) // empty extension_data
  1110  			}
  1111  			if m.scts {
  1112  				// RFC 8446, Section 4.4.2.1 makes no mention of
  1113  				// signed_certificate_timestamp in CertificateRequest, but
  1114  				// "Extensions in the Certificate message from the client MUST
  1115  				// correspond to extensions in the CertificateRequest message
  1116  				// from the server." and it appears in the table in Section 4.2.
  1117  				b.AddUint16(extensionSCT)
  1118  				b.AddUint16(0) // empty extension_data
  1119  			}
  1120  			if len(m.supportedSignatureAlgorithms) > 0 {
  1121  				b.AddUint16(extensionSignatureAlgorithms)
  1122  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1123  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1124  						for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1125  							b.AddUint16(uint16(sigAlgo))
  1126  						}
  1127  					})
  1128  				})
  1129  			}
  1130  			if len(m.supportedSignatureAlgorithmsCert) > 0 {
  1131  				b.AddUint16(extensionSignatureAlgorithmsCert)
  1132  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1133  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1134  						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
  1135  							b.AddUint16(uint16(sigAlgo))
  1136  						}
  1137  					})
  1138  				})
  1139  			}
  1140  			if len(m.certificateAuthorities) > 0 {
  1141  				b.AddUint16(extensionCertificateAuthorities)
  1142  				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1143  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1144  						for _, ca := range m.certificateAuthorities {
  1145  							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1146  								b.AddBytes(ca)
  1147  							})
  1148  						}
  1149  					})
  1150  				})
  1151  			}
  1152  		})
  1153  	})
  1154  
  1155  	return b.Bytes()
  1156  }
  1157  
  1158  func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
  1159  	*m = certificateRequestMsgTLS13{}
  1160  	s := cryptobyte.String(data)
  1161  
  1162  	var context, extensions cryptobyte.String
  1163  	if !s.Skip(4) || // message type and uint24 length field
  1164  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1165  		!s.ReadUint16LengthPrefixed(&extensions) ||
  1166  		!s.Empty() {
  1167  		return false
  1168  	}
  1169  
  1170  	for !extensions.Empty() {
  1171  		var extension uint16
  1172  		var extData cryptobyte.String
  1173  		if !extensions.ReadUint16(&extension) ||
  1174  			!extensions.ReadUint16LengthPrefixed(&extData) {
  1175  			return false
  1176  		}
  1177  
  1178  		switch extension {
  1179  		case extensionStatusRequest:
  1180  			m.ocspStapling = true
  1181  		case extensionSCT:
  1182  			m.scts = true
  1183  		case extensionSignatureAlgorithms:
  1184  			var sigAndAlgs cryptobyte.String
  1185  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1186  				return false
  1187  			}
  1188  			for !sigAndAlgs.Empty() {
  1189  				var sigAndAlg uint16
  1190  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1191  					return false
  1192  				}
  1193  				m.supportedSignatureAlgorithms = append(
  1194  					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
  1195  			}
  1196  		case extensionSignatureAlgorithmsCert:
  1197  			var sigAndAlgs cryptobyte.String
  1198  			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
  1199  				return false
  1200  			}
  1201  			for !sigAndAlgs.Empty() {
  1202  				var sigAndAlg uint16
  1203  				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
  1204  					return false
  1205  				}
  1206  				m.supportedSignatureAlgorithmsCert = append(
  1207  					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
  1208  			}
  1209  		case extensionCertificateAuthorities:
  1210  			var auths cryptobyte.String
  1211  			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
  1212  				return false
  1213  			}
  1214  			for !auths.Empty() {
  1215  				var ca []byte
  1216  				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
  1217  					return false
  1218  				}
  1219  				m.certificateAuthorities = append(m.certificateAuthorities, ca)
  1220  			}
  1221  		default:
  1222  			// Ignore unknown extensions.
  1223  			continue
  1224  		}
  1225  
  1226  		if !extData.Empty() {
  1227  			return false
  1228  		}
  1229  	}
  1230  
  1231  	return true
  1232  }
  1233  
// certificateMsg is the in-memory form of a Certificate handshake message in
// the encoding without per-certificate extensions.
type certificateMsg struct {
	// certificates holds the raw bytes of each certificate, in wire order.
	// Each entry is encoded with a three-byte length prefix.
	certificates [][]byte
}
  1237  
  1238  func (m *certificateMsg) marshal() ([]byte, error) {
  1239  	var i int
  1240  	for _, slice := range m.certificates {
  1241  		i += len(slice)
  1242  	}
  1243  
  1244  	length := 3 + 3*len(m.certificates) + i
  1245  	x := make([]byte, 4+length)
  1246  	x[0] = typeCertificate
  1247  	x[1] = uint8(length >> 16)
  1248  	x[2] = uint8(length >> 8)
  1249  	x[3] = uint8(length)
  1250  
  1251  	certificateOctets := length - 3
  1252  	x[4] = uint8(certificateOctets >> 16)
  1253  	x[5] = uint8(certificateOctets >> 8)
  1254  	x[6] = uint8(certificateOctets)
  1255  
  1256  	y := x[7:]
  1257  	for _, slice := range m.certificates {
  1258  		y[0] = uint8(len(slice) >> 16)
  1259  		y[1] = uint8(len(slice) >> 8)
  1260  		y[2] = uint8(len(slice))
  1261  		copy(y[3:], slice)
  1262  		y = y[3+len(slice):]
  1263  	}
  1264  
  1265  	return x, nil
  1266  }
  1267  
  1268  func (m *certificateMsg) unmarshal(data []byte) bool {
  1269  	if len(data) < 7 {
  1270  		return false
  1271  	}
  1272  
  1273  	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
  1274  	if uint32(len(data)) != certsLen+7 {
  1275  		return false
  1276  	}
  1277  
  1278  	numCerts := 0
  1279  	d := data[7:]
  1280  	for certsLen > 0 {
  1281  		if len(d) < 4 {
  1282  			return false
  1283  		}
  1284  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1285  		if uint32(len(d)) < 3+certLen {
  1286  			return false
  1287  		}
  1288  		d = d[3+certLen:]
  1289  		certsLen -= 3 + certLen
  1290  		numCerts++
  1291  	}
  1292  
  1293  	m.certificates = make([][]byte, numCerts)
  1294  	d = data[7:]
  1295  	for i := 0; i < numCerts; i++ {
  1296  		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
  1297  		m.certificates[i] = d[3 : 3+certLen]
  1298  		d = d[3+certLen:]
  1299  	}
  1300  
  1301  	return true
  1302  }
  1303  
// certificateMsgTLS13 is the in-memory form of a TLS 1.3 Certificate handshake
// message.
type certificateMsgTLS13 struct {
	certificate  Certificate
	// ocspStapling and scts select whether certificate.OCSPStaple and
	// certificate.SignedCertificateTimestamps are included when marshaling.
	// After a successful unmarshal they report whether the corresponding field
	// of certificate is non-nil.
	ocspStapling bool
	scts         bool
}
  1309  
  1310  func (m *certificateMsgTLS13) marshal() ([]byte, error) {
  1311  	var b cryptobyte.Builder
  1312  	b.AddUint8(typeCertificate)
  1313  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1314  		b.AddUint8(0) // certificate_request_context
  1315  
  1316  		certificate := m.certificate
  1317  		if !m.ocspStapling {
  1318  			certificate.OCSPStaple = nil
  1319  		}
  1320  		if !m.scts {
  1321  			certificate.SignedCertificateTimestamps = nil
  1322  		}
  1323  		marshalCertificate(b, certificate)
  1324  	})
  1325  
  1326  	return b.Bytes()
  1327  }
  1328  
  1329  func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
  1330  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1331  		for i, cert := range certificate.Certificate {
  1332  			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1333  				b.AddBytes(cert)
  1334  			})
  1335  			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1336  				if i > 0 {
  1337  					// This library only supports OCSP and SCT for leaf certificates.
  1338  					return
  1339  				}
  1340  				if certificate.OCSPStaple != nil {
  1341  					b.AddUint16(extensionStatusRequest)
  1342  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1343  						b.AddUint8(statusTypeOCSP)
  1344  						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1345  							b.AddBytes(certificate.OCSPStaple)
  1346  						})
  1347  					})
  1348  				}
  1349  				if certificate.SignedCertificateTimestamps != nil {
  1350  					b.AddUint16(extensionSCT)
  1351  					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1352  						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1353  							for _, sct := range certificate.SignedCertificateTimestamps {
  1354  								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1355  									b.AddBytes(sct)
  1356  								})
  1357  							}
  1358  						})
  1359  					})
  1360  				}
  1361  			})
  1362  		}
  1363  	})
  1364  }
  1365  
  1366  func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
  1367  	*m = certificateMsgTLS13{}
  1368  	s := cryptobyte.String(data)
  1369  
  1370  	var context cryptobyte.String
  1371  	if !s.Skip(4) || // message type and uint24 length field
  1372  		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
  1373  		!unmarshalCertificate(&s, &m.certificate) ||
  1374  		!s.Empty() {
  1375  		return false
  1376  	}
  1377  
  1378  	m.scts = m.certificate.SignedCertificateTimestamps != nil
  1379  	m.ocspStapling = m.certificate.OCSPStaple != nil
  1380  
  1381  	return true
  1382  }
  1383  
  1384  func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
  1385  	var certList cryptobyte.String
  1386  	if !s.ReadUint24LengthPrefixed(&certList) {
  1387  		return false
  1388  	}
  1389  	for !certList.Empty() {
  1390  		var cert []byte
  1391  		var extensions cryptobyte.String
  1392  		if !readUint24LengthPrefixed(&certList, &cert) ||
  1393  			!certList.ReadUint16LengthPrefixed(&extensions) {
  1394  			return false
  1395  		}
  1396  		certificate.Certificate = append(certificate.Certificate, cert)
  1397  		for !extensions.Empty() {
  1398  			var extension uint16
  1399  			var extData cryptobyte.String
  1400  			if !extensions.ReadUint16(&extension) ||
  1401  				!extensions.ReadUint16LengthPrefixed(&extData) {
  1402  				return false
  1403  			}
  1404  			if len(certificate.Certificate) > 1 {
  1405  				// This library only supports OCSP and SCT for leaf certificates.
  1406  				continue
  1407  			}
  1408  
  1409  			switch extension {
  1410  			case extensionStatusRequest:
  1411  				var statusType uint8
  1412  				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1413  					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
  1414  					len(certificate.OCSPStaple) == 0 {
  1415  					return false
  1416  				}
  1417  			case extensionSCT:
  1418  				var sctList cryptobyte.String
  1419  				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
  1420  					return false
  1421  				}
  1422  				for !sctList.Empty() {
  1423  					var sct []byte
  1424  					if !readUint16LengthPrefixed(&sctList, &sct) ||
  1425  						len(sct) == 0 {
  1426  						return false
  1427  					}
  1428  					certificate.SignedCertificateTimestamps = append(
  1429  						certificate.SignedCertificateTimestamps, sct)
  1430  				}
  1431  			default:
  1432  				// Ignore unknown extensions.
  1433  				continue
  1434  			}
  1435  
  1436  			if !extData.Empty() {
  1437  				return false
  1438  			}
  1439  		}
  1440  	}
  1441  	return true
  1442  }
  1443  
// serverKeyExchangeMsg is the in-memory form of a ServerKeyExchange handshake
// message.
type serverKeyExchangeMsg struct {
	// key is the entire message body, i.e. everything after the four-byte
	// handshake header. It is not interpreted here.
	key []byte
}
  1447  
  1448  func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
  1449  	length := len(m.key)
  1450  	x := make([]byte, length+4)
  1451  	x[0] = typeServerKeyExchange
  1452  	x[1] = uint8(length >> 16)
  1453  	x[2] = uint8(length >> 8)
  1454  	x[3] = uint8(length)
  1455  	copy(x[4:], m.key)
  1456  
  1457  	return x, nil
  1458  }
  1459  
  1460  func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
  1461  	if len(data) < 4 {
  1462  		return false
  1463  	}
  1464  	m.key = data[4:]
  1465  	return true
  1466  }
  1467  
// certificateStatusMsg is the in-memory form of a CertificateStatus handshake
// message.
type certificateStatusMsg struct {
	// response is encoded with a uint24 length prefix after a status type
	// byte that is always statusTypeOCSP.
	response []byte
}
  1471  
  1472  func (m *certificateStatusMsg) marshal() ([]byte, error) {
  1473  	var b cryptobyte.Builder
  1474  	b.AddUint8(typeCertificateStatus)
  1475  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1476  		b.AddUint8(statusTypeOCSP)
  1477  		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1478  			b.AddBytes(m.response)
  1479  		})
  1480  	})
  1481  
  1482  	return b.Bytes()
  1483  }
  1484  
  1485  func (m *certificateStatusMsg) unmarshal(data []byte) bool {
  1486  	s := cryptobyte.String(data)
  1487  
  1488  	var statusType uint8
  1489  	if !s.Skip(4) || // message type and uint24 length field
  1490  		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
  1491  		!readUint24LengthPrefixed(&s, &m.response) ||
  1492  		len(m.response) == 0 || !s.Empty() {
  1493  		return false
  1494  	}
  1495  	return true
  1496  }
  1497  
// serverHelloDoneMsg is a handshake message with no fields; it is marshaled as
// a four-byte header and nothing else.
type serverHelloDoneMsg struct{}
  1499  
  1500  func (m *serverHelloDoneMsg) marshal() ([]byte, error) {
  1501  	x := make([]byte, 4)
  1502  	x[0] = typeServerHelloDone
  1503  	return x, nil
  1504  }
  1505  
  1506  func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
  1507  	return len(data) == 4
  1508  }
  1509  
// clientKeyExchangeMsg is the in-memory form of a ClientKeyExchange handshake
// message.
type clientKeyExchangeMsg struct {
	// ciphertext is the entire message body, i.e. everything after the
	// four-byte handshake header. It is not interpreted here.
	ciphertext []byte
}
  1513  
  1514  func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
  1515  	length := len(m.ciphertext)
  1516  	x := make([]byte, length+4)
  1517  	x[0] = typeClientKeyExchange
  1518  	x[1] = uint8(length >> 16)
  1519  	x[2] = uint8(length >> 8)
  1520  	x[3] = uint8(length)
  1521  	copy(x[4:], m.ciphertext)
  1522  
  1523  	return x, nil
  1524  }
  1525  
  1526  func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
  1527  	if len(data) < 4 {
  1528  		return false
  1529  	}
  1530  	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
  1531  	if l != len(data)-4 {
  1532  		return false
  1533  	}
  1534  	m.ciphertext = data[4:]
  1535  	return true
  1536  }
  1537  
// finishedMsg is the in-memory form of a Finished handshake message.
type finishedMsg struct {
	// verifyData is the entire message body; it is written directly after
	// the uint24 handshake length with no further framing.
	verifyData []byte
}
  1541  
  1542  func (m *finishedMsg) marshal() ([]byte, error) {
  1543  	var b cryptobyte.Builder
  1544  	b.AddUint8(typeFinished)
  1545  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1546  		b.AddBytes(m.verifyData)
  1547  	})
  1548  
  1549  	return b.Bytes()
  1550  }
  1551  
  1552  func (m *finishedMsg) unmarshal(data []byte) bool {
  1553  	s := cryptobyte.String(data)
  1554  	return s.Skip(1) &&
  1555  		readUint24LengthPrefixed(&s, &m.verifyData) &&
  1556  		s.Empty()
  1557  }
  1558  
// certificateRequestMsg is the in-memory form of a CertificateRequest
// handshake message in the encoding that has no extensions.
type certificateRequestMsg struct {
	// hasSignatureAlgorithm indicates whether this message includes a list of
	// supported signature algorithms. This change was introduced with TLS 1.2.
	// It selects the wire format for both marshal and unmarshal, so it must be
	// set before unmarshal is called.
	hasSignatureAlgorithm bool

	// certificateTypes is encoded with a one-byte count; unmarshal rejects an
	// empty list.
	certificateTypes             []byte
	// supportedSignatureAlgorithms is only encoded and decoded when
	// hasSignatureAlgorithm is set.
	supportedSignatureAlgorithms []SignatureScheme
	// certificateAuthorities holds the raw authority names, each encoded with
	// a two-byte length prefix.
	certificateAuthorities       [][]byte
}
  1568  
  1569  func (m *certificateRequestMsg) marshal() ([]byte, error) {
  1570  	// See RFC 4346, Section 7.4.4.
  1571  	length := 1 + len(m.certificateTypes) + 2
  1572  	casLength := 0
  1573  	for _, ca := range m.certificateAuthorities {
  1574  		casLength += 2 + len(ca)
  1575  	}
  1576  	length += casLength
  1577  
  1578  	if m.hasSignatureAlgorithm {
  1579  		length += 2 + 2*len(m.supportedSignatureAlgorithms)
  1580  	}
  1581  
  1582  	x := make([]byte, 4+length)
  1583  	x[0] = typeCertificateRequest
  1584  	x[1] = uint8(length >> 16)
  1585  	x[2] = uint8(length >> 8)
  1586  	x[3] = uint8(length)
  1587  
  1588  	x[4] = uint8(len(m.certificateTypes))
  1589  
  1590  	copy(x[5:], m.certificateTypes)
  1591  	y := x[5+len(m.certificateTypes):]
  1592  
  1593  	if m.hasSignatureAlgorithm {
  1594  		n := len(m.supportedSignatureAlgorithms) * 2
  1595  		y[0] = uint8(n >> 8)
  1596  		y[1] = uint8(n)
  1597  		y = y[2:]
  1598  		for _, sigAlgo := range m.supportedSignatureAlgorithms {
  1599  			y[0] = uint8(sigAlgo >> 8)
  1600  			y[1] = uint8(sigAlgo)
  1601  			y = y[2:]
  1602  		}
  1603  	}
  1604  
  1605  	y[0] = uint8(casLength >> 8)
  1606  	y[1] = uint8(casLength)
  1607  	y = y[2:]
  1608  	for _, ca := range m.certificateAuthorities {
  1609  		y[0] = uint8(len(ca) >> 8)
  1610  		y[1] = uint8(len(ca))
  1611  		y = y[2:]
  1612  		copy(y, ca)
  1613  		y = y[len(ca):]
  1614  	}
  1615  
  1616  	return x, nil
  1617  }
  1618  
  1619  func (m *certificateRequestMsg) unmarshal(data []byte) bool {
  1620  	if len(data) < 5 {
  1621  		return false
  1622  	}
  1623  
  1624  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1625  	if uint32(len(data))-4 != length {
  1626  		return false
  1627  	}
  1628  
  1629  	numCertTypes := int(data[4])
  1630  	data = data[5:]
  1631  	if numCertTypes == 0 || len(data) <= numCertTypes {
  1632  		return false
  1633  	}
  1634  
  1635  	m.certificateTypes = make([]byte, numCertTypes)
  1636  	if copy(m.certificateTypes, data) != numCertTypes {
  1637  		return false
  1638  	}
  1639  
  1640  	data = data[numCertTypes:]
  1641  
  1642  	if m.hasSignatureAlgorithm {
  1643  		if len(data) < 2 {
  1644  			return false
  1645  		}
  1646  		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
  1647  		data = data[2:]
  1648  		if sigAndHashLen&1 != 0 {
  1649  			return false
  1650  		}
  1651  		if len(data) < int(sigAndHashLen) {
  1652  			return false
  1653  		}
  1654  		numSigAlgos := sigAndHashLen / 2
  1655  		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
  1656  		for i := range m.supportedSignatureAlgorithms {
  1657  			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
  1658  			data = data[2:]
  1659  		}
  1660  	}
  1661  
  1662  	if len(data) < 2 {
  1663  		return false
  1664  	}
  1665  	casLength := uint16(data[0])<<8 | uint16(data[1])
  1666  	data = data[2:]
  1667  	if len(data) < int(casLength) {
  1668  		return false
  1669  	}
  1670  	cas := make([]byte, casLength)
  1671  	copy(cas, data)
  1672  	data = data[casLength:]
  1673  
  1674  	m.certificateAuthorities = nil
  1675  	for len(cas) > 0 {
  1676  		if len(cas) < 2 {
  1677  			return false
  1678  		}
  1679  		caLen := uint16(cas[0])<<8 | uint16(cas[1])
  1680  		cas = cas[2:]
  1681  
  1682  		if len(cas) < int(caLen) {
  1683  			return false
  1684  		}
  1685  
  1686  		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
  1687  		cas = cas[caLen:]
  1688  	}
  1689  
  1690  	return len(data) == 0
  1691  }
  1692  
// certificateVerifyMsg is the in-memory form of a CertificateVerify handshake
// message.
type certificateVerifyMsg struct {
	// hasSignatureAlgorithm selects the wire format for both marshal and
	// unmarshal, so it must be set before unmarshal is called.
	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
	// signatureAlgorithm is only encoded and decoded when
	// hasSignatureAlgorithm is set.
	signatureAlgorithm    SignatureScheme
	// signature is encoded with a two-byte length prefix.
	signature             []byte
}
  1698  
  1699  func (m *certificateVerifyMsg) marshal() ([]byte, error) {
  1700  	var b cryptobyte.Builder
  1701  	b.AddUint8(typeCertificateVerify)
  1702  	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
  1703  		if m.hasSignatureAlgorithm {
  1704  			b.AddUint16(uint16(m.signatureAlgorithm))
  1705  		}
  1706  		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
  1707  			b.AddBytes(m.signature)
  1708  		})
  1709  	})
  1710  
  1711  	return b.Bytes()
  1712  }
  1713  
  1714  func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
  1715  	s := cryptobyte.String(data)
  1716  
  1717  	if !s.Skip(4) { // message type and uint24 length field
  1718  		return false
  1719  	}
  1720  	if m.hasSignatureAlgorithm {
  1721  		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
  1722  			return false
  1723  		}
  1724  	}
  1725  	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
  1726  }
  1727  
// newSessionTicketMsg is the in-memory form of a NewSessionTicket handshake
// message in the RFC 5077 encoding.
type newSessionTicketMsg struct {
	// ticket is encoded with a two-byte length prefix, after four bytes that
	// marshal always writes as zero and unmarshal does not read.
	ticket []byte
}
  1731  
  1732  func (m *newSessionTicketMsg) marshal() ([]byte, error) {
  1733  	// See RFC 5077, Section 3.3.
  1734  	ticketLen := len(m.ticket)
  1735  	length := 2 + 4 + ticketLen
  1736  	x := make([]byte, 4+length)
  1737  	x[0] = typeNewSessionTicket
  1738  	x[1] = uint8(length >> 16)
  1739  	x[2] = uint8(length >> 8)
  1740  	x[3] = uint8(length)
  1741  	x[8] = uint8(ticketLen >> 8)
  1742  	x[9] = uint8(ticketLen)
  1743  	copy(x[10:], m.ticket)
  1744  
  1745  	return x, nil
  1746  }
  1747  
  1748  func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
  1749  	if len(data) < 10 {
  1750  		return false
  1751  	}
  1752  
  1753  	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
  1754  	if uint32(len(data))-4 != length {
  1755  		return false
  1756  	}
  1757  
  1758  	ticketLen := int(data[8])<<8 + int(data[9])
  1759  	if len(data)-10 != ticketLen {
  1760  		return false
  1761  	}
  1762  
  1763  	m.ticket = data[10:]
  1764  
  1765  	return true
  1766  }
  1767  
// helloRequestMsg is a handshake message with no fields; it is marshaled as a
// four-byte header and nothing else.
type helloRequestMsg struct {
}
  1770  
  1771  func (*helloRequestMsg) marshal() ([]byte, error) {
  1772  	return []byte{typeHelloRequest, 0, 0, 0}, nil
  1773  }
  1774  
  1775  func (*helloRequestMsg) unmarshal(data []byte) bool {
  1776  	return len(data) == 4
  1777  }
  1778  
// transcriptHash is the minimal interface transcriptMsg needs from a running
// hash: something that message bytes can be written to.
type transcriptHash interface {
	Write([]byte) (int, error)
}
  1782  
  1783  // transcriptMsg is a helper used to hash messages which are not hashed when
  1784  // they are read from, or written to, the wire. This is typically the case for
  1785  // messages which are either not sent, or need to be hashed out of order from
  1786  // when they are read/written.
  1787  //
  1788  // For most messages, the message is marshalled using their marshal method,
  1789  // since their wire representation is idempotent. For clientHelloMsg and
  1790  // serverHelloMsg, we store the original wire representation of the message and
  1791  // use that for hashing, since unmarshal/marshal are not idempotent due to
  1792  // extension ordering and other malleable fields, which may cause differences
  1793  // between what was received and what we marshal.
  1794  func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
  1795  	if msgWithOrig, ok := msg.(handshakeMessageWithOriginalBytes); ok {
  1796  		if orig := msgWithOrig.originalBytes(); orig != nil {
  1797  			h.Write(msgWithOrig.originalBytes())
  1798  			return nil
  1799  		}
  1800  	}
  1801  
  1802  	data, err := msg.marshal()
  1803  	if err != nil {
  1804  		return err
  1805  	}
  1806  	h.Write(data)
  1807  	return nil
  1808  }
  1809  

View as plain text