Source file
src/crypto/tls/handshake_messages_test.go
1
2
3
4
5 package tls
6
7 import (
8 "bytes"
9 "crypto/x509"
10 "encoding/hex"
11 "math"
12 "math/rand"
13 "reflect"
14 "strings"
15 "testing"
16 "testing/quick"
17 "time"
18 )
19
// tests holds one zero-value instance of each message type that
// TestMarshalUnmarshal and TestFuzz exercise. The dynamic element types drive
// quick.Value generation in TestMarshalUnmarshal. Both tests also reuse these
// instances as unmarshal targets, so their contents change as the tests run.
//
// Order matters. TestMarshalUnmarshal skips its "no prefix may unmarshal"
// check for indices 0-2, so clientHelloMsg, serverHelloMsg and finishedMsg
// must remain the first three entries.
var tests = []handshakeMessage{
	&clientHelloMsg{},
	&serverHelloMsg{},
	&finishedMsg{},

	&certificateMsg{},
	&certificateRequestMsg{},
	// hasSignatureAlgorithm matches what certificateVerifyMsg.Generate always
	// sets. NOTE(review): presumably unmarshal only reads the signature
	// algorithm field when this flag is set on the target — confirm.
	&certificateVerifyMsg{
		hasSignatureAlgorithm: true,
	},
	&certificateStatusMsg{},
	&clientKeyExchangeMsg{},
	&newSessionTicketMsg{},
	&encryptedExtensionsMsg{},
	&endOfEarlyDataMsg{},
	&keyUpdateMsg{},
	&newSessionTicketMsgTLS13{},
	&certificateRequestMsgTLS13{},
	&certificateMsgTLS13{},
	&SessionState{},
}
41
42 func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
43 t.Helper()
44 b, err := msg.marshal()
45 if err != nil {
46 t.Fatal(err)
47 }
48 return b
49 }
50
51 func TestMarshalUnmarshal(t *testing.T) {
52 rand := rand.New(rand.NewSource(time.Now().UnixNano()))
53
54 for i, m := range tests {
55 ty := reflect.ValueOf(m).Type()
56 t.Run(ty.String(), func(t *testing.T) {
57 n := 100
58 if testing.Short() {
59 n = 5
60 }
61 for j := 0; j < n; j++ {
62 v, ok := quick.Value(ty, rand)
63 if !ok {
64 t.Errorf("#%d: failed to create value", i)
65 break
66 }
67
68 m1 := v.Interface().(handshakeMessage)
69 marshaled := mustMarshal(t, m1)
70 if !m.unmarshal(marshaled) {
71 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
72 break
73 }
74
75 if m, ok := m.(*SessionState); ok {
76 m.activeCertHandles = nil
77 }
78
79
80
81
82
83
84 switch t := m.(type) {
85 case *clientHelloMsg:
86 t.original = nil
87 case *serverHelloMsg:
88 t.original = nil
89 }
90
91 if !reflect.DeepEqual(m1, m) {
92 t.Errorf("#%d got:%#v want:%#v %x", i, m, m1, marshaled)
93 break
94 }
95
96 if i >= 3 {
97
98
99
100
101
102 for j := 0; j < len(marshaled); j++ {
103 if m.unmarshal(marshaled[0:j]) {
104 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
105 break
106 }
107 }
108 }
109 }
110 })
111 }
112 }
113
114 func TestFuzz(t *testing.T) {
115 rand := rand.New(rand.NewSource(0))
116 for _, m := range tests {
117 for j := 0; j < 1000; j++ {
118 len := rand.Intn(1000)
119 bytes := randomBytes(len, rand)
120
121 m.unmarshal(bytes)
122 }
123 }
124 }
125
// randomBytes returns a new slice of n bytes filled from rand. It panics if
// rand.Read returns an error. The math/rand documentation says this never
// happens for a *rand.Rand.
func randomBytes(n int, rand *rand.Rand) []byte {
	buf := make([]byte, n)
	_, err := rand.Read(buf)
	if err != nil {
		panic("rand.Read failed: " + err.Error())
	}
	return buf
}
133
134 func randomString(n int, rand *rand.Rand) string {
135 b := randomBytes(n, rand)
136 return string(b)
137 }
138
// Generate implements testing/quick.Generator so that quick.Value can build
// random *clientHelloMsg values for TestMarshalUnmarshal. The size hint is
// ignored.
//
// Values are drawn from rand in a fixed statement order. Several loops use
// `i < rand.Intn(k)` as their condition, which re-draws the bound before every
// iteration. The resulting element count is therefore not uniform over [0, k).
func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &clientHelloMsg{}
	m.vers = uint16(rand.Intn(65536))
	m.random = randomBytes(32, rand)
	m.sessionId = randomBytes(rand.Intn(32), rand)
	// Between 1 and 63 cipher suites. A draw equal to scsvRenegotiation is
	// bumped to the next value. NOTE(review): presumably because unmarshal
	// treats the SCSV specially instead of keeping it in cipherSuites —
	// confirm.
	m.cipherSuites = make([]uint16, rand.Intn(63)+1)
	for i := 0; i < len(m.cipherSuites); i++ {
		cs := uint16(rand.Int31())
		if cs == scsvRenegotiation {
			cs += 1
		}
		m.cipherSuites[i] = cs
	}
	m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
	if rand.Intn(10) > 5 {
		// Trailing dots are stripped from the random name. Presumably the
		// server_name handling rejects or normalizes them — TODO confirm.
		m.serverName = randomString(rand.Intn(255), rand)
		for strings.HasSuffix(m.serverName, ".") {
			m.serverName = m.serverName[:len(m.serverName)-1]
		}
	}
	m.ocspStapling = rand.Intn(10) > 5
	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)
	m.supportedCurves = make([]CurveID, rand.Intn(5)+1)
	for i := range m.supportedCurves {
		m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1)
	}
	if rand.Intn(10) > 5 {
		m.ticketSupported = true
		// Either a random ticket (possibly empty) or an explicitly empty,
		// non-nil slice. NOTE(review): the non-nil empty slice presumably
		// matches what unmarshal yields for an empty ticket, keeping DeepEqual
		// happy — confirm.
		if rand.Intn(10) > 5 {
			m.sessionTicket = randomBytes(rand.Intn(300), rand)
		} else {
			m.sessionTicket = make([]byte, 0)
		}
	}
	if rand.Intn(10) > 5 {
		m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
	}
	if rand.Intn(10) > 5 {
		m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
	}
	for i := 0; i < rand.Intn(5); i++ {
		m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
	}
	if rand.Intn(10) > 5 {
		m.scts = true
	}
	if rand.Intn(10) > 5 {
		m.secureRenegotiationSupported = true
		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.extendedMasterSecret = true
	}
	for i := 0; i < rand.Intn(5); i++ {
		m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1))
	}
	if rand.Intn(10) > 5 {
		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	}
	for i := 0; i < rand.Intn(5); i++ {
		var ks keyShare
		ks.group = CurveID(rand.Intn(30000) + 1)
		ks.data = randomBytes(rand.Intn(200)+1, rand)
		m.keyShares = append(m.keyShares, ks)
	}
	// pskModes is left nil when the draw is 0.
	switch rand.Intn(3) {
	case 1:
		m.pskModes = []uint8{pskModeDHE}
	case 2:
		m.pskModes = []uint8{pskModeDHE, pskModePlain}
	}
	// Each PSK identity is paired with exactly one binder of 32 to 81 bytes.
	for i := 0; i < rand.Intn(5); i++ {
		var psk pskIdentity
		psk.obfuscatedTicketAge = uint32(rand.Intn(500000))
		psk.label = randomBytes(rand.Intn(500)+1, rand)
		m.pskIdentities = append(m.pskIdentities, psk)
		m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand))
	}
	if rand.Intn(10) > 5 {
		m.quicTransportParameters = randomBytes(rand.Intn(500), rand)
	}
	if rand.Intn(10) > 5 {
		m.earlyData = true
	}

	return reflect.ValueOf(m)
}
226
// Generate implements testing/quick.Generator for *serverHelloMsg. It returns
// a ServerHello with randomly populated fields for TestMarshalUnmarshal. The
// size hint is ignored. Values are drawn from rand in a fixed statement order.
func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
	m := &serverHelloMsg{}
	m.vers = uint16(rand.Intn(65536))
	m.random = randomBytes(32, rand)
	m.sessionId = randomBytes(rand.Intn(32), rand)
	m.cipherSuite = uint16(rand.Int31())
	m.compressionMethod = uint8(rand.Intn(256))
	m.supportedPoints = randomBytes(rand.Intn(5)+1, rand)

	if rand.Intn(10) > 5 {
		m.ocspStapling = true
	}
	if rand.Intn(10) > 5 {
		m.ticketSupported = true
	}
	if rand.Intn(10) > 5 {
		m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
	}

	// 0 to 3 non-empty SCTs. The loop bound is re-drawn before each iteration.
	for i := 0; i < rand.Intn(4); i++ {
		m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
	}

	if rand.Intn(10) > 5 {
		m.secureRenegotiationSupported = true
		m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand)
	}
	if rand.Intn(10) > 5 {
		m.extendedMasterSecret = true
	}
	if rand.Intn(10) > 5 {
		m.supportedVersion = uint16(rand.Intn(0xffff) + 1)
	}
	if rand.Intn(10) > 5 {
		m.cookie = randomBytes(rand.Intn(500)+1, rand)
	}
	// At most one of serverShare and selectedGroup is set, never both.
	if rand.Intn(10) > 5 {
		// NOTE(review): this loop overwrites the same serverShare fields on
		// every iteration. When the re-drawn bound is 0, it sets nothing and
		// serverShare stays zero. A single assignment was probably intended;
		// changing it would alter the generated values, so confirm first.
		for i := 0; i < rand.Intn(5); i++ {
			m.serverShare.group = CurveID(rand.Intn(30000) + 1)
			m.serverShare.data = randomBytes(rand.Intn(200)+1, rand)
		}
	} else if rand.Intn(10) > 5 {
		m.selectedGroup = CurveID(rand.Intn(30000) + 1)
	}
	if rand.Intn(10) > 5 {
		m.selectedIdentityPresent = true
		m.selectedIdentity = uint16(rand.Intn(0xffff))
	}

	return reflect.ValueOf(m)
}
278
279 func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
280 m := &encryptedExtensionsMsg{}
281
282 if rand.Intn(10) > 5 {
283 m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
284 }
285 if rand.Intn(10) > 5 {
286 m.earlyData = true
287 }
288
289 return reflect.ValueOf(m)
290 }
291
292 func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
293 m := &certificateMsg{}
294 numCerts := rand.Intn(20)
295 m.certificates = make([][]byte, numCerts)
296 for i := 0; i < numCerts; i++ {
297 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
298 }
299 return reflect.ValueOf(m)
300 }
301
302 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
303 m := &certificateRequestMsg{}
304 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
305 for i := 0; i < rand.Intn(100); i++ {
306 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
307 }
308 return reflect.ValueOf(m)
309 }
310
311 func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value {
312 m := &certificateVerifyMsg{}
313 m.hasSignatureAlgorithm = true
314 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000))
315 m.signature = randomBytes(rand.Intn(15)+1, rand)
316 return reflect.ValueOf(m)
317 }
318
319 func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
320 m := &certificateStatusMsg{}
321 m.response = randomBytes(rand.Intn(10)+1, rand)
322 return reflect.ValueOf(m)
323 }
324
325 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
326 m := &clientKeyExchangeMsg{}
327 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
328 return reflect.ValueOf(m)
329 }
330
331 func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
332 m := &finishedMsg{}
333 m.verifyData = randomBytes(12, rand)
334 return reflect.ValueOf(m)
335 }
336
337 func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value {
338 m := &newSessionTicketMsg{}
339 m.ticket = randomBytes(rand.Intn(4), rand)
340 return reflect.ValueOf(m)
341 }
342
343 var sessionTestCerts []*x509.Certificate
344
345 func init() {
346 cert, err := x509.ParseCertificate(testRSACertificate)
347 if err != nil {
348 panic(err)
349 }
350 sessionTestCerts = append(sessionTestCerts, cert)
351 cert, err = x509.ParseCertificate(testRSACertificateIssuer)
352 if err != nil {
353 panic(err)
354 }
355 sessionTestCerts = append(sessionTestCerts, cert)
356 }
357
// Generate implements testing/quick.Generator for *SessionState. It returns a
// session with random fields for TestMarshalUnmarshal. The size hint is
// ignored. About 40% of sessions are TLS 1.3, since rand.Intn(10) > 5 holds
// for 6..9. The rest get a random version below VersionTLS13.
//
// Draw order matters. Conditions of the form `rand.Intn(10) > 5 && x`
// always consume a draw, because the left operand is evaluated first.
// `s.isClient || rand.Intn(10) > 5` consumes one only for server sessions.
func (*SessionState) Generate(rand *rand.Rand, size int) reflect.Value {
	s := &SessionState{}
	isTLS13 := rand.Intn(10) > 5
	if isTLS13 {
		s.version = VersionTLS13
	} else {
		s.version = uint16(rand.Intn(VersionTLS13))
	}
	s.isClient = rand.Intn(10) > 5
	s.cipherSuite = uint16(rand.Intn(math.MaxUint16))
	s.createdAt = uint64(rand.Int63())
	s.secret = randomBytes(rand.Intn(100)+1, rand)
	// Unlike the other loops here, the Extra count (0-2) is drawn once.
	for n, i := rand.Intn(3), 0; i < n; i++ {
		s.Extra = append(s.Extra, randomBytes(rand.Intn(100), rand))
	}
	if rand.Intn(10) > 5 {
		s.EarlyData = true
	}
	if rand.Intn(10) > 5 {
		s.extMasterSecret = true
	}
	// Client sessions always carry peer certificates. Server sessions
	// sometimes do. The chain is either both fixture certs or just the first.
	if s.isClient || rand.Intn(10) > 5 {
		if rand.Intn(10) > 5 {
			s.peerCertificates = sessionTestCerts
		} else {
			s.peerCertificates = sessionTestCerts[:1]
		}
	}
	// OCSP responses and SCTs are only generated alongside peer certificates.
	if rand.Intn(10) > 5 && s.peerCertificates != nil {
		s.ocspResponse = randomBytes(rand.Intn(100)+1, rand)
	}
	if rand.Intn(10) > 5 && s.peerCertificates != nil {
		for i := 0; i < rand.Intn(2)+1; i++ {
			s.scts = append(s.scts, randomBytes(rand.Intn(500)+1, rand))
		}
	}
	// Each verified chain reuses peerCertificates, either whole or as the
	// first certificate only.
	if len(s.peerCertificates) > 0 {
		for i := 0; i < rand.Intn(3); i++ {
			if rand.Intn(10) > 5 {
				s.verifiedChains = append(s.verifiedChains, s.peerCertificates)
			} else {
				s.verifiedChains = append(s.verifiedChains, s.peerCertificates[:1])
			}
		}
	}
	// An ALPN protocol is only generated for early-data sessions.
	if rand.Intn(10) > 5 && s.EarlyData {
		s.alpnProtocol = string(randomBytes(rand.Intn(10), rand))
	}
	// useBy and ageAdd are only populated for TLS 1.3 client sessions.
	if s.isClient {
		if isTLS13 {
			s.useBy = uint64(rand.Int63())
			s.ageAdd = uint32(rand.Int63() & math.MaxUint32)
		}
	}
	return reflect.ValueOf(s)
}
414
415 func (s *SessionState) marshal() ([]byte, error) { return s.Bytes() }
416 func (s *SessionState) unmarshal(b []byte) bool {
417 ss, err := ParseSessionState(b)
418 if err != nil {
419 return false
420 }
421 *s = *ss
422 return true
423 }
424
425 func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value {
426 m := &endOfEarlyDataMsg{}
427 return reflect.ValueOf(m)
428 }
429
430 func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
431 m := &keyUpdateMsg{}
432 m.updateRequested = rand.Intn(10) > 5
433 return reflect.ValueOf(m)
434 }
435
436 func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
437 m := &newSessionTicketMsgTLS13{}
438 m.lifetime = uint32(rand.Intn(500000))
439 m.ageAdd = uint32(rand.Intn(500000))
440 m.nonce = randomBytes(rand.Intn(100), rand)
441 m.label = randomBytes(rand.Intn(1000), rand)
442 if rand.Intn(10) > 5 {
443 m.maxEarlyData = uint32(rand.Intn(500000))
444 }
445 return reflect.ValueOf(m)
446 }
447
448 func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
449 m := &certificateRequestMsgTLS13{}
450 if rand.Intn(10) > 5 {
451 m.ocspStapling = true
452 }
453 if rand.Intn(10) > 5 {
454 m.scts = true
455 }
456 if rand.Intn(10) > 5 {
457 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms()
458 }
459 if rand.Intn(10) > 5 {
460 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms()
461 }
462 if rand.Intn(10) > 5 {
463 m.certificateAuthorities = make([][]byte, 3)
464 for i := 0; i < 3; i++ {
465 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand)
466 }
467 }
468 return reflect.ValueOf(m)
469 }
470
471 func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value {
472 m := &certificateMsgTLS13{}
473 for i := 0; i < rand.Intn(2)+1; i++ {
474 m.certificate.Certificate = append(
475 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand))
476 }
477 if rand.Intn(10) > 5 {
478 m.ocspStapling = true
479 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand)
480 }
481 if rand.Intn(10) > 5 {
482 m.scts = true
483 for i := 0; i < rand.Intn(2)+1; i++ {
484 m.certificate.SignedCertificateTimestamps = append(
485 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand))
486 }
487 }
488 return reflect.ValueOf(m)
489 }
490
// TestRejectEmptySCTList checks that a ServerHello whose SCT extension carries
// an empty SCT list fails to unmarshal. It marshals a valid ServerHello with a
// single SCT, splices the encoding to replace the list with an empty one,
// patches the affected length fields, and expects unmarshal to fail.
func TestRejectEmptySCTList(t *testing.T) {
	// A recognizable 4-byte SCT lets us find its offset in the encoding.
	// The random field is all zeros, so it cannot contain the 0x42 pattern.
	var random [32]byte
	sct := []byte{0x42, 0x42, 0x42, 0x42}
	serverHello := &serverHelloMsg{
		vers:   VersionTLS12,
		random: random[:],
		scts:   [][]byte{sct},
	}
	serverHelloBytes := mustMarshal(t, serverHello)

	// Sanity check: the unmodified encoding must parse.
	var serverHelloCopy serverHelloMsg
	if !serverHelloCopy.unmarshal(serverHelloBytes) {
		t.Fatal("Failed to unmarshal initial message")
	}

	// i is the offset of the SCT body. Assuming the extension layout of
	// RFC 6962 §3.3.1, the six bytes before it are the extension data length,
	// the SCT list length and the SCT's own length, two bytes each.
	i := bytes.Index(serverHelloBytes, sct)
	if i < 0 {
		t.Fatal("Cannot find SCT in ServerHello")
	}

	// Keep everything up to and including the extension type...
	var serverHelloEmptySCT []byte
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...)
	// ...then write an extension data length of 2 and an SCT list length of 0...
	serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...)
	// ...then append whatever followed the 4-byte SCT.
	serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...)

	// Rewrite the 24-bit handshake body length in bytes 1-3. It excludes the
	// 4-byte handshake header.
	serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16)
	serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8)
	serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4)

	// Rewrite the 16-bit extensions-block length at offsets 42-43. With no
	// session ID, the fixed prefix is 4 (header) + 2 (version) + 32 (random) +
	// 1 (session ID length) + 2 (cipher suite) + 1 (compression) = 42 bytes.
	// The extensions therefore start at 44.
	serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
	serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))

	if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
		t.Fatal("Unmarshaled ServerHello with empty SCT list")
	}
}
533
534 func TestRejectEmptySCT(t *testing.T) {
535
536
537
538 var random [32]byte
539 serverHello := &serverHelloMsg{
540 vers: VersionTLS12,
541 random: random[:],
542 scts: [][]byte{nil},
543 }
544 serverHelloBytes := mustMarshal(t, serverHello)
545
546 var serverHelloCopy serverHelloMsg
547 if serverHelloCopy.unmarshal(serverHelloBytes) {
548 t.Fatal("Unmarshaled ServerHello with zero-length SCT")
549 }
550 }
551
552 func TestRejectDuplicateExtensions(t *testing.T) {
553 clientHelloBytes, err := hex.DecodeString("010000440303000000000000000000000000000000000000000000000000000000000000000000000000001c0000000a000800000568656c6c6f0000000a000800000568656c6c6f")
554 if err != nil {
555 t.Fatalf("failed to decode test ClientHello: %s", err)
556 }
557 var clientHelloCopy clientHelloMsg
558 if clientHelloCopy.unmarshal(clientHelloBytes) {
559 t.Error("Unmarshaled ClientHello with duplicate extensions")
560 }
561
562 serverHelloBytes, err := hex.DecodeString("02000030030300000000000000000000000000000000000000000000000000000000000000000000000000080005000000050000")
563 if err != nil {
564 t.Fatalf("failed to decode test ServerHello: %s", err)
565 }
566 var serverHelloCopy serverHelloMsg
567 if serverHelloCopy.unmarshal(serverHelloBytes) {
568 t.Fatal("Unmarshaled ServerHello with duplicate extensions")
569 }
570 }
571
View as plain text