Source file src/crypto/tls/handshake_messages_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/x509"
    10  	"encoding/hex"
    11  	"math"
    12  	"math/rand"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  	"testing/quick"
    17  	"time"
    18  )
    19  
    20  var tests = []handshakeMessage{
    21  	&clientHelloMsg{},
    22  	&serverHelloMsg{},
    23  	&finishedMsg{},
    24  
    25  	&certificateMsg{},
    26  	&certificateRequestMsg{},
    27  	&certificateVerifyMsg{
    28  		hasSignatureAlgorithm: true,
    29  	},
    30  	&certificateStatusMsg{},
    31  	&clientKeyExchangeMsg{},
    32  	&newSessionTicketMsg{},
    33  	&encryptedExtensionsMsg{},
    34  	&endOfEarlyDataMsg{},
    35  	&keyUpdateMsg{},
    36  	&newSessionTicketMsgTLS13{},
    37  	&certificateRequestMsgTLS13{},
    38  	&certificateMsgTLS13{},
    39  	&SessionState{},
    40  }
    41  
    42  func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
    43  	t.Helper()
    44  	b, err := msg.marshal()
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  	return b
    49  }
    50  
    51  func TestMarshalUnmarshal(t *testing.T) {
    52  	rand := rand.New(rand.NewSource(time.Now().UnixNano()))
    53  
    54  	for i, m := range tests {
    55  		ty := reflect.ValueOf(m).Type()
    56  		t.Run(ty.String(), func(t *testing.T) {
    57  			n := 100
    58  			if testing.Short() {
    59  				n = 5
    60  			}
    61  			for j := 0; j < n; j++ {
    62  				v, ok := quick.Value(ty, rand)
    63  				if !ok {
    64  					t.Errorf("#%d: failed to create value", i)
    65  					break
    66  				}
    67  
    68  				m1 := v.Interface().(handshakeMessage)
    69  				marshaled := mustMarshal(t, m1)
    70  				if !m.unmarshal(marshaled) {
    71  					t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
    72  					break
    73  				}
    74  
    75  				if ch, ok := m.(*clientHelloMsg); ok {
    76  					// extensions is special cased, as it is only populated by the
    77  					// server-side of a handshake and is not expected to roundtrip
    78  					// through marshal + unmarshal.  m ends up with the list of
    79  					// extensions necessary to serialize the other fields of
    80  					// clientHelloMsg, so check that it is non-empty, then clear it.
    81  					if len(ch.extensions) == 0 {
    82  						t.Errorf("expected ch.extensions to be populated on unmarshal")
    83  					}
    84  					ch.extensions = nil
    85  				}
    86  
    87  				// clientHelloMsg and serverHelloMsg, when unmarshalled, store
    88  				// their original representation, for later use in the handshake
    89  				// transcript. In order to prevent DeepEqual from failing since
    90  				// we didn't create the original message via unmarshalling, nil
    91  				// the field.
    92  				switch t := m.(type) {
    93  				case *clientHelloMsg:
    94  					t.original = nil
    95  				case *serverHelloMsg:
    96  					t.original = nil
    97  				}
    98  
    99  				if !reflect.DeepEqual(m1, m) {
   100  					t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
   101  					break
   102  				}
   103  
   104  				if i >= 3 {
   105  					// The first three message types (ClientHello,
   106  					// ServerHello and Finished) are allowed to
   107  					// have parsable prefixes because the extension
   108  					// data is optional and the length of the
   109  					// Finished varies across versions.
   110  					for j := 0; j < len(marshaled); j++ {
   111  						if m.unmarshal(marshaled[0:j]) {
   112  							t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
   113  							break
   114  						}
   115  					}
   116  				}
   117  			}
   118  		})
   119  	}
   120  }
   121  
   122  func TestFuzz(t *testing.T) {
   123  	rand := rand.New(rand.NewSource(0))
   124  	for _, m := range tests {
   125  		for j := 0; j < 1000; j++ {
   126  			len := rand.Intn(1000)
   127  			bytes := randomBytes(len, rand)
   128  			// This just looks for crashes due to bounds errors etc.
   129  			m.unmarshal(bytes)
   130  		}
   131  	}
   132  }
   133  
   134  func randomBytes(n int, rand *rand.Rand) []byte {
   135  	r := make([]byte, n)
   136  	if _, err := rand.Read(r); err != nil {
   137  		panic("rand.Read failed: " + err.Error())
   138  	}
   139  	return r
   140  }
   141  
   142  func randomString(n int, rand *rand.Rand) string {
   143  	b := randomBytes(n, rand)
   144  	return string(b)
   145  }
   146  
   147  func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   148  	m := &clientHelloMsg{}
   149  	m.vers = uint16(rand.Intn(65536))
   150  	m.random = randomBytes(32, rand)
   151  	m.sessionId = randomBytes(rand.Intn(32), rand)
   152  	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
   153  	for i := 0; i < len(m.cipherSuites); i++ {
   154  		cs := uint16(rand.Int31())
   155  		if cs == scsvRenegotiation {
   156  			cs += 1
   157  		}
   158  		m.cipherSuites[i] = cs
   159  	}
   160  	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
   161  	if rand.Intn(10) > 5 {
   162  		m.serverName = randomString(rand.Intn(255), rand)
   163  		for strings.HasSuffix(m.serverName, ".") {
   164  			m.serverName = m.serverName[:len(m.serverName)-1]
   165  		}
   166  	}
   167  	m.ocspStapling = rand.Intn(10) > 5
   168  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   169  	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
   170  	for i := range m.supportedCurves {
   171  		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
   172  	}
   173  	if rand.Intn(10) > 5 {
   174  		m.ticketSupported = true
   175  		if rand.Intn(10) > 5 {
   176  			m.sessionTicket = randomBytes(rand.Intn(300), rand)
   177  		} else {
   178  			m.sessionTicket = make([]byte, 0)
   179  		}
   180  	}
   181  	if rand.Intn(10) > 5 {
   182  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms(VersionTLS12)
   183  	}
   184  	if rand.Intn(10) > 5 {
   185  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms(VersionTLS12)
   186  	}
   187  	for i := 0; i < rand.Intn(5); i++ {
   188  		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
   189  	}
   190  	if rand.Intn(10) > 5 {
   191  		m.scts = true
   192  	}
   193  	if rand.Intn(10) > 5 {
   194  		m.secureRenegotiationSupported = true
   195  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   196  	}
   197  	if rand.Intn(10) > 5 {
   198  		m.extendedMasterSecret = true
   199  	}
   200  	for i := 0; i < rand.Intn(5); i++ {
   201  		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
   202  	}
   203  	if rand.Intn(10) > 5 {
   204  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   205  	}
   206  	for i := 0; i < rand.Intn(5); i++ {
   207  		var ks keyShare
   208  		ks.group = CurveID(rand.Intn(30000) + 1)
   209  		ks.data = randomBytes(rand.Intn(200)+1, rand)
   210  		m.keyShares = append(m.keyShares, ks)
   211  	}
   212  	switch rand.Intn(3) {
   213  	case 1:
   214  		m.pskModes = []uint8{pskModeDHE}
   215  	case 2:
   216  		m.pskModes = []uint8{pskModeDHE, pskModePlain}
   217  	}
   218  	for i := 0; i < rand.Intn(5); i++ {
   219  		var psk pskIdentity
   220  		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
   221  		psk.label = randomBytes(rand.Intn(500)+1, rand)
   222  		m.pskIdentities = append(m.pskIdentities, psk)
   223  		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
   224  	}
   225  	if rand.Intn(10) > 5 {
   226  		m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
   227  	}
   228  	if rand.Intn(10) > 5 {
   229  		m.earlyData = true
   230  	}
   231  	if rand.Intn(10) > 5 {
   232  		m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
   233  	}
   234  
   235  	return reflect.ValueOf(m)
   236  }
   237  
   238  func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   239  	m := &serverHelloMsg{}
   240  	m.vers = uint16(rand.Intn(65536))
   241  	m.random = randomBytes(32, rand)
   242  	m.sessionId = randomBytes(rand.Intn(32), rand)
   243  	m.cipherSuite = uint16(rand.Int31())
   244  	m.compressionMethod = uint8(rand.Intn(256))
   245  	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
   246  
   247  	if rand.Intn(10) > 5 {
   248  		m.ocspStapling = true
   249  	}
   250  	if rand.Intn(10) > 5 {
   251  		m.ticketSupported = true
   252  	}
   253  	if rand.Intn(10) > 5 {
   254  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   255  	}
   256  
   257  	for i := 0; i < rand.Intn(4); i++ {
   258  		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
   259  	}
   260  
   261  	if rand.Intn(10) > 5 {
   262  		m.secureRenegotiationSupported = true
   263  		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
   264  	}
   265  	if rand.Intn(10) > 5 {
   266  		m.extendedMasterSecret = true
   267  	}
   268  	if rand.Intn(10) > 5 {
   269  		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
   270  	}
   271  	if rand.Intn(10) > 5 {
   272  		m.cookie = randomBytes(rand.Intn(500)+1, rand)
   273  	}
   274  	if rand.Intn(10) > 5 {
   275  		for i := 0; i < rand.Intn(5); i++ {
   276  			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
   277  			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
   278  		}
   279  	} else if rand.Intn(10) > 5 {
   280  		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
   281  	}
   282  	if rand.Intn(10) > 5 {
   283  		m.selectedIdentityPresent = true
   284  		m.selectedIdentity = uint16(rand.Intn(0xffff))
   285  	}
   286  	if rand.Intn(10) > 5 {
   287  		m.encryptedClientHello = randomBytes(rand.Intn(50)+1, rand)
   288  	}
   289  	if rand.Intn(10) > 5 {
   290  		m.serverNameAck = rand.Intn(2) == 1
   291  	}
   292  
   293  	return reflect.ValueOf(m)
   294  }
   295  
   296  func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   297  	m := &encryptedExtensionsMsg{}
   298  
   299  	if rand.Intn(10) > 5 {
   300  		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
   301  	}
   302  	if rand.Intn(10) > 5 {
   303  		m.earlyData = true
   304  	}
   305  
   306  	return reflect.ValueOf(m)
   307  }
   308  
   309  func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   310  	m := &certificateMsg{}
   311  	numCerts := rand.Intn(20)
   312  	m.certificates = make([][]byte, numCerts)
   313  	for i := 0; i < numCerts; i++ {
   314  		m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
   315  	}
   316  	return reflect.ValueOf(m)
   317  }
   318  
   319  func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   320  	m := &certificateRequestMsg{}
   321  	m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
   322  	for i := 0; i < rand.Intn(100); i++ {
   323  		m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
   324  	}
   325  	return reflect.ValueOf(m)
   326  }
   327  
   328  func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   329  	m := &certificateVerifyMsg{}
   330  	m.hasSignatureAlgorithm = true
   331  	m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
   332  	m.signature = randomBytes(rand.Intn(15)+1, rand)
   333  	return reflect.ValueOf(m)
   334  }
   335  
   336  func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   337  	m := &certificateStatusMsg{}
   338  	m.response = randomBytes(rand.Intn(10)+1, rand)
   339  	return reflect.ValueOf(m)
   340  }
   341  
   342  func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   343  	m := &clientKeyExchangeMsg{}
   344  	m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
   345  	return reflect.ValueOf(m)
   346  }
   347  
   348  func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   349  	m := &finishedMsg{}
   350  	m.verifyData = randomBytes(12, rand)
   351  	return reflect.ValueOf(m)
   352  }
   353  
   354  func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   355  	m := &newSessionTicketMsg{}
   356  	m.ticket = randomBytes(rand.Intn(4), rand)
   357  	return reflect.ValueOf(m)
   358  }
   359  
   360  var sessionTestCerts []*x509.Certificate
   361  
   362  func init() {
   363  	cert, err := x509.ParseCertificate(testRSACertificate)
   364  	if err != nil {
   365  		panic(err)
   366  	}
   367  	sessionTestCerts = append(sessionTestCerts, cert)
   368  	cert, err = x509.ParseCertificate(testRSACertificateIssuer)
   369  	if err != nil {
   370  		panic(err)
   371  	}
   372  	sessionTestCerts = append(sessionTestCerts, cert)
   373  }
   374  
   375  func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
   376  	s := &SessionState{}
   377  	isTLS13 := rand.Intn(10) > 5
   378  	if isTLS13 {
   379  		s.version = VersionTLS13
   380  	} else {
   381  		s.version = uint16(rand.Intn(VersionTLS13))
   382  	}
   383  	s.isClient = rand.Intn(10) > 5
   384  	s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
   385  	s.createdAt = uint64(rand.Int63())
   386  	s.secret = randomBytes(rand.Intn(100)+1, rand)
   387  	for n, i := rand.Intn(3), 0; i < n; i++ {
   388  		s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
   389  	}
   390  	if rand.Intn(10) > 5 {
   391  		s.EarlyData = true
   392  	}
   393  	if rand.Intn(10) > 5 {
   394  		s.extMasterSecret = true
   395  	}
   396  	if s.isClient || rand.Intn(10) > 5 {
   397  		if rand.Intn(10) > 5 {
   398  			s.peerCertificates = sessionTestCerts
   399  		} else {
   400  			s.peerCertificates = sessionTestCerts[:1]
   401  		}
   402  	}
   403  	if rand.Intn(10) > 5 && s.peerCertificates != nil {
   404  		s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
   405  	}
   406  	if rand.Intn(10) > 5 && s.peerCertificates != nil {
   407  		for i := 0; i < rand.Intn(2)+1; i++ {
   408  			s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
   409  		}
   410  	}
   411  	if len(s.peerCertificates) > 0 {
   412  		for i := 0; i < rand.Intn(3); i++ {
   413  			if rand.Intn(10) > 5 {
   414  				s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
   415  			} else {
   416  				s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
   417  			}
   418  		}
   419  	}
   420  	if rand.Intn(10) > 5 && s.EarlyData {
   421  		s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
   422  	}
   423  	if isTLS13 {
   424  		if s.isClient {
   425  			s.useBy = uint64(rand.Int63())
   426  			s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
   427  		}
   428  	} else {
   429  		s.curveID = CurveID(rand.Intn(30000) + 1)
   430  	}
   431  	return reflect.ValueOf(s)
   432  }
   433  
   434  func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
   435  func (s *SessionState) unmarshal(b []byte) bool {
   436  	ss, err := ParseSessionState(b)
   437  	if err != nil {
   438  		return false
   439  	}
   440  	*s = *ss
   441  	return true
   442  }
   443  
   444  func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   445  	m := &endOfEarlyDataMsg{}
   446  	return reflect.ValueOf(m)
   447  }
   448  
   449  func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
   450  	m := &keyUpdateMsg{}
   451  	m.updateRequested = rand.Intn(10) > 5
   452  	return reflect.ValueOf(m)
   453  }
   454  
   455  func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   456  	m := &newSessionTicketMsgTLS13{}
   457  	m.lifetime = uint32(rand.Intn(500000))
   458  	m.ageAdd = uint32(rand.Intn(500000))
   459  	m.nonce = randomBytes(rand.Intn(100), rand)
   460  	m.label = randomBytes(rand.Intn(1000), rand)
   461  	if rand.Intn(10) > 5 {
   462  		m.maxEarlyData = uint32(rand.Intn(500000))
   463  	}
   464  	return reflect.ValueOf(m)
   465  }
   466  
   467  func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   468  	m := &certificateRequestMsgTLS13{}
   469  	if rand.Intn(10) > 5 {
   470  		m.ocspStapling = true
   471  	}
   472  	if rand.Intn(10) > 5 {
   473  		m.scts = true
   474  	}
   475  	if rand.Intn(10) > 5 {
   476  		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms(VersionTLS12)
   477  	}
   478  	if rand.Intn(10) > 5 {
   479  		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms(VersionTLS12)
   480  	}
   481  	if rand.Intn(10) > 5 {
   482  		m.certificateAuthorities = make([][]byte, 3)
   483  		for i := 0; i < 3; i++ {
   484  			m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
   485  		}
   486  	}
   487  	return reflect.ValueOf(m)
   488  }
   489  
   490  func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
   491  	m := &certificateMsgTLS13{}
   492  	for i := 0; i < rand.Intn(2)+1; i++ {
   493  		m.certificate.Certificate = append(
   494  			m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
   495  	}
   496  	if rand.Intn(10) > 5 {
   497  		m.ocspStapling = true
   498  		m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
   499  	}
   500  	if rand.Intn(10) > 5 {
   501  		m.scts = true
   502  		for i := 0; i < rand.Intn(2)+1; i++ {
   503  			m.certificate.SignedCertificateTimestamps = append(
   504  				m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
   505  		}
   506  	}
   507  	return reflect.ValueOf(m)
   508  }
   509  
// TestRejectEmptySCTList checks that serverHelloMsg.unmarshal rejects a
// ServerHello whose SCT list is present but empty. It marshals a valid
// ServerHello carrying one 4-byte SCT and confirms it parses. It then
// splices the encoding to empty the SCT list, fixes up the enclosing length
// fields, and expects unmarshal to fail.
func TestRejectEmptySCTList(t *testing.T) {
	// RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid.

	var random [32]byte
	// A distinctive 4-byte SCT so it can be located in the encoding below.
	sct := []byte{0x42, 0x42, 0x42, 0x42}
	serverHello := &serverHelloMsg{
		vers:   VersionTLS12,
		random: random[:],
		scts:   [][]byte{sct},
	}
	serverHelloBytes := mustMarshal(t, serverHello)

	var serverHelloCopy serverHelloMsg
	if !serverHelloCopy.unmarshal(serverHelloBytes) {
		t.Fatal("Failed to unmarshal initial message")
	}

	// Change serverHelloBytes so that the SCT list is empty
	i := bytes.Index(serverHelloBytes, sct)
	if i < 0 {
		t.Fatal("Cannot find SCT in ServerHello")
	}

	// Keep everything up to the extension's length field. NOTE(review): the
	// -6 assumes the SCT is preceded by three 2-byte fields (extension
	// length, SCT list length, SCT length) — confirm against marshal.
	var serverHelloEmptySCT []byte
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
	// Append the extension length and SCT list length for an empty list.
	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
	// Resume right after the 4-byte SCT (len(sct) == 4).
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)

	// Update the handshake message length.
	// Bytes 1-3 receive a big-endian 24-bit count of everything after the
	// first 4 bytes.
	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)

	// Update the extensions length
	// NOTE(review): offset 42 presumably = 4-byte handshake header + 2-byte
	// version + 32-byte random + 1-byte (empty) session ID length + 2-byte
	// cipher suite + 1-byte compression method; extension data then starts
	// at 44 — confirm.
	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))

	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	}
}
   552  
   553  func TestRejectEmptySCT(t *testing.T) {
   554  	// Not only must the SCT list be non-empty, but the SCT elements must
   555  	// not be zero length.
   556  
   557  	var random [32]byte
   558  	serverHello := &serverHelloMsg{
   559  		vers:   VersionTLS12,
   560  		random: random[:],
   561  		scts:   [][]byte{nil},
   562  	}
   563  	serverHelloBytes := mustMarshal(t, serverHello)
   564  
   565  	var serverHelloCopy serverHelloMsg
   566  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   567  		t.Fatal("Unmarshaled ServerHello with zero-length SCT")
   568  	}
   569  }
   570  
   571  func TestRejectDuplicateExtensions(t *testing.T) {
   572  	clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
   573  	if err != nil {
   574  		t.Fatalf("failed to decode test ClientHello: %s", err)
   575  	}
   576  	var clientHelloCopy clientHelloMsg
   577  	if clientHelloCopy.unmarshal(clientHelloBytes) {
   578  		t.Error("Unmarshaled ClientHello with duplicate extensions")
   579  	}
   580  
   581  	serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
   582  	if err != nil {
   583  		t.Fatalf("failed to decode test ServerHello: %s", err)
   584  	}
   585  	var serverHelloCopy serverHelloMsg
   586  	if serverHelloCopy.unmarshal(serverHelloBytes) {
   587  		t.Fatal("Unmarshaled ServerHello with duplicate extensions")
   588  	}
   589  }
   590  

View as plain text