1
2
3
4
5 package tls
6
7 import (
8 "context"
9 "errors"
10 "fmt"
11 "net"
12 )
13
14
15
// QUICEncryptionLevel identifies the encryption level at which QUIC
// handshake data and secrets are exchanged.
type QUICEncryptionLevel int

// The encryption levels, numbered from zero in declaration order.
const (
	QUICEncryptionLevelInitial = QUICEncryptionLevel(iota)
	QUICEncryptionLevelEarly
	QUICEncryptionLevelHandshake
	QUICEncryptionLevelApplication
)

// String returns the short name of a known level ("Initial", "Early",
// "Handshake", "Application"). Any other value is rendered as
// "QUICEncryptionLevel(N)" where N is its integer value.
func (l QUICEncryptionLevel) String() string {
	names := [...]string{
		QUICEncryptionLevelInitial:     "Initial",
		QUICEncryptionLevelEarly:       "Early",
		QUICEncryptionLevelHandshake:   "Handshake",
		QUICEncryptionLevelApplication: "Application",
	}
	if l < 0 || int(l) >= len(names) {
		// Not one of the declared levels.
		return fmt.Sprintf("QUICEncryptionLevel(%v)", int(l))
	}
	return names[l]
}
39
40
41
42
43
// QUICConn wraps a TLS Conn and drives it through the QUIC-specific methods
// declared on it (Start, HandleData, NextEvent, Close, and so on).
//
// NOTE(review): the methods here do no locking among themselves, so they are
// presumably not safe for concurrent use — confirm.
type QUICConn struct {
	// conn is the underlying TLS connection.
	conn *Conn

	// sessionTicketSent is set by SendSessionTicket so that a second call
	// can be rejected.
	sessionTicketSent bool
}
49
50
// QUICConfig holds the settings used by QUICClient and QUICServer to create
// a QUICConn.
type QUICConfig struct {
	// TLSConfig is the TLS configuration handed to Client or Server when
	// the connection is created.
	TLSConfig *Config

	// EnableSessionEvents is copied into the connection's QUIC state when
	// the connection is created.
	//
	// NOTE(review): presumably this turns on the QUICStoreSession and
	// QUICResumeSession events; the code that consults the flag is not in
	// this file — confirm.
	EnableSessionEvents bool

	// ClientHelloInfoConn is copied into the connection's QUIC state when
	// the connection is created.
	// NOTE(review): presumably exposed to server callbacks as the ClientHelloInfo connection — confirm.
	ClientHelloInfoConn net.Conn
}
64
65
// QUICEventKind identifies the type of a QUICEvent.
type QUICEventKind int

const (
	// QUICNoEvent is returned by NextEvent when no event is available.
	QUICNoEvent QUICEventKind = iota

	// QUICSetReadSecret and QUICSetWriteSecret carry the read and write
	// secrets for an encryption level.
	// QUICEvent.Level, QUICEvent.Suite, and QUICEvent.Data are set.
	//
	//
	//
	//
	QUICSetReadSecret
	QUICSetWriteSecret

	// QUICWriteData carries handshake bytes to send to the peer.
	// QUICEvent.Level and QUICEvent.Data are set; consecutive writes at the
	// same level are coalesced into one event.
	QUICWriteData

	// QUICTransportParameters carries transport parameters in QUICEvent.Data.
	// NOTE(review): presumably the peer's parameters — confirm against callers.
	QUICTransportParameters

	// QUICTransportParametersRequired is queued when the handshake needs the
	// local transport parameters and none have been set yet. The handshake
	// then waits until QUICConn.SetTransportParameters provides them.
	//
	//
	//
	//
	//
	QUICTransportParametersRequired

	// QUICRejectedEarlyData carries no additional fields.
	// NOTE(review): presumably signals that the server rejected 0-RTT data;
	// the emitting call site is not in this file — confirm.
	//
	//
	QUICRejectedEarlyData

	// QUICHandshakeDone carries no additional fields.
	// NOTE(review): presumably queued once the TLS handshake completes — confirm.
	QUICHandshakeDone

	// QUICResumeSession carries the session being resumed in
	// QUICEvent.SessionState.
	//
	// After queuing this event the handshake waits until NextEvent has been
	// called with no events left to return.
	//
	//
	//
	//
	QUICResumeSession

	// QUICStoreSession carries a session in QUICEvent.SessionState.
	// The application may store it with QUICConn.StoreSession.
	//
	//
	//
	//
	//
	QUICStoreSession

	// QUICErrorEvent reports the connection's handshake error in
	// QUICEvent.Err. NextEvent returns it at most once; afterwards it
	// returns QUICNoEvent.
	//
	QUICErrorEvent
)
130
131
132
133
134
// QUICEvent is an event returned by QUICConn.NextEvent.
//
// Kind selects the type of event; which of the remaining fields are
// populated depends on the kind.
type QUICEvent struct {
	Kind QUICEventKind

	// Level is set for QUICSetReadSecret, QUICSetWriteSecret, and
	// QUICWriteData.
	Level QUICEncryptionLevel

	// Data is set for QUICTransportParameters, QUICSetReadSecret,
	// QUICSetWriteSecret, and QUICWriteData.
	// NOTE(review): presumably only valid until the next NextEvent call — confirm.
	Data []byte

	// Suite is set for QUICSetReadSecret and QUICSetWriteSecret.
	//
	Suite uint16

	// SessionState is set for QUICResumeSession and QUICStoreSession.
	//
	SessionState *SessionState

	// Err is set for QUICErrorEvent.
	//
	//
	Err error
}
155
// quicState is the QUIC-specific state attached to a Conn by newQUICConn.
type quicState struct {
	// events holds the queued events; nextEvent is the index of the next
	// one NextEvent will return.
	events    []QUICEvent
	nextEvent int

	// eventArr is the initial backing array for events (newQUICConn sets
	// events to eventArr[:0]), so queuing up to eight events does not grow
	// the slice.
	//
	//
	//
	eventArr [8]QUICEvent

	// started is set by Start.
	started bool
	// signalc: the handshake goroutine sends on it while waiting to proceed;
	// QUICConn methods receive from it to wake the handshake.
	signalc chan struct{}
	// blockedc: the handshake goroutine sends on it to report that it is
	// blocked; QUICConn methods treat a closed channel as "handshake
	// goroutine finished".
	blockedc chan struct{}
	// ctx and cancel are not assigned in this file. Close treats a nil ctx
	// as "never started" and otherwise calls cancel.
	ctx    context.Context
	cancel context.CancelFunc

	// waitingForDrain is set by quicResumeSession and cleared by NextEvent
	// once every queued event has been returned.
	waitingForDrain bool
	// errorReturned records that NextEvent has already returned the
	// QUICErrorEvent.
	errorReturned bool

	// readbuf is filled by HandleData and moved into the connection's
	// handshake buffer by quicWaitForSignal (or by HandleData itself once
	// the handshake goroutine has finished).
	//
	//
	readbuf []byte

	// transportParams is set by SetTransportParameters and read by
	// quicGetTransportParameters.
	transportParams []byte

	// Copied from the QUICConfig by newQUICConn.
	enableSessionEvents bool
	clientHelloInfoConn net.Conn
}
185
186
187
188 func QUICClient(config *QUICConfig) *QUICConn {
189 return newQUICConn(Client(nil, config.TLSConfig), config)
190 }
191
192
193
194 func QUICServer(config *QUICConfig) *QUICConn {
195 return newQUICConn(Server(nil, config.TLSConfig), config)
196 }
197
198 func newQUICConn(conn *Conn, config *QUICConfig) *QUICConn {
199 conn.quic = &quicState{
200 signalc: make(chan struct{}),
201 blockedc: make(chan struct{}),
202 enableSessionEvents: config.EnableSessionEvents,
203 clientHelloInfoConn: config.ClientHelloInfoConn,
204 }
205 conn.quic.events = conn.quic.eventArr[:0]
206 return &QUICConn{
207 conn: conn,
208 }
209 }
210
211
212
213
214
215 func (q *QUICConn) Start(ctx context.Context) error {
216 if q.conn.quic.started {
217 return quicError(errors.New("tls: Start called more than once"))
218 }
219 q.conn.quic.started = true
220 go q.conn.HandshakeContext(ctx)
221 if _, ok := <-q.conn.quic.blockedc; !ok {
222 return q.conn.handshakeErr
223 }
224 return nil
225 }
226
227
228
// NextEvent returns the next queued event, or an event with Kind QUICNoEvent
// when none is available.
//
// If the connection has a handshake error, any queued events are dropped and
// a single QUICErrorEvent carrying that error is returned; every later call
// returns QUICNoEvent.
func (q *QUICConn) NextEvent() QUICEvent {
	qs := q.conn.quic
	if last := qs.nextEvent - 1; last >= 0 && len(qs.events[last].Data) > 0 {
		// Overwrite the first byte of the previously returned event's data
		// so that callers wrongly retaining it are more likely to notice.
		// NOTE(review): the slot of a returned event is reset to QUICEvent{} below, so this branch looks unreachable — confirm.
		qs.events[last].Data[0] = 0
	}
	if qs.nextEvent >= len(qs.events) && qs.waitingForDrain {
		// Every event has been consumed and the handshake is waiting for
		// that: wake it, then wait until it blocks again or finishes.
		qs.waitingForDrain = false
		<-qs.signalc
		<-qs.blockedc
	}
	if err := q.conn.handshakeErr; err != nil {
		if qs.errorReturned {
			return QUICEvent{Kind: QUICNoEvent}
		}
		// Report the error exactly once and discard whatever was queued.
		qs.errorReturned = true
		qs.events = nil
		qs.nextEvent = 0
		return QUICEvent{Kind: QUICErrorEvent, Err: q.conn.handshakeErr}
	}
	if qs.nextEvent >= len(qs.events) {
		// Queue exhausted: reset it, keeping the backing storage.
		qs.events = qs.events[:0]
		qs.nextEvent = 0
		return QUICEvent{Kind: QUICNoEvent}
	}
	e := qs.events[qs.nextEvent]
	qs.events[qs.nextEvent] = QUICEvent{} // drop the queue's references to the event's data
	qs.nextEvent++
	return e
}
260
261
// Close cancels the handshake context, wakes the handshake goroutine, waits
// for blockedc to be closed, and returns the connection's handshake error.
// It returns nil without doing anything when the QUIC state has no context.
func (q *QUICConn) Close() error {
	if q.conn.quic.ctx == nil {
		return nil // no handshake context: nothing to cancel
	}
	q.conn.quic.cancel()
	// Wake the handshake goroutine so that it observes the cancellation.
	<-q.conn.quic.signalc
	for range q.conn.quic.blockedc {
		// Drain until blockedc is closed.
		// NOTE(review): presumably closed when the handshake goroutine returns — confirm.
	}
	return q.conn.handshakeErr
}
273
274
275
276 func (q *QUICConn) HandleData(level QUICEncryptionLevel, data []byte) error {
277 c := q.conn
278 if c.in.level != level {
279 return quicError(c.in.setErrorLocked(errors.New("tls: handshake data received at wrong level")))
280 }
281 c.quic.readbuf = data
282 <-c.quic.signalc
283 _, ok := <-c.quic.blockedc
284 if ok {
285
286 return nil
287 }
288
289 c.handshakeMutex.Lock()
290 defer c.handshakeMutex.Unlock()
291 c.hand.Write(c.quic.readbuf)
292 c.quic.readbuf = nil
293 for q.conn.hand.Len() >= 4 && q.conn.handshakeErr == nil {
294 b := q.conn.hand.Bytes()
295 n := int(b[1])<<16 | int(b[2])<<8 | int(b[3])
296 if n > maxHandshake {
297 q.conn.handshakeErr = fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake)
298 break
299 }
300 if len(b) < 4+n {
301 return nil
302 }
303 if err := q.conn.handlePostHandshakeMessage(); err != nil {
304 q.conn.handshakeErr = err
305 }
306 }
307 if q.conn.handshakeErr != nil {
308 return quicError(q.conn.handshakeErr)
309 }
310 return nil
311 }
312
// QUICSessionTicketOptions holds the options for QUICConn.SendSessionTicket.
type QUICSessionTicketOptions struct {
	// EarlyData is passed through to the ticket-sending code.
	// NOTE(review): presumably controls whether the ticket may be used for 0-RTT — confirm.
	EarlyData bool
	// Extra is passed through to the ticket-sending code unchanged.
	Extra [][]byte
}
318
319
320
321
322 func (q *QUICConn) SendSessionTicket(opts QUICSessionTicketOptions) error {
323 c := q.conn
324 if c.config.SessionTicketsDisabled {
325 return nil
326 }
327 if !c.isHandshakeComplete.Load() {
328 return quicError(errors.New("tls: SendSessionTicket called before handshake completed"))
329 }
330 if c.isClient {
331 return quicError(errors.New("tls: SendSessionTicket called on the client"))
332 }
333 if q.sessionTicketSent {
334 return quicError(errors.New("tls: SendSessionTicket called multiple times"))
335 }
336 q.sessionTicketSent = true
337 return quicError(c.sendSessionTicket(opts.EarlyData, opts.Extra))
338 }
339
340
341
342
343
344 func (q *QUICConn) StoreSession(session *SessionState) error {
345 c := q.conn
346 if !c.isClient {
347 return quicError(errors.New("tls: StoreSessionTicket called on the server"))
348 }
349 cacheKey := c.clientSessionCacheKey()
350 if cacheKey == "" {
351 return nil
352 }
353 cs := &ClientSessionState{session: session}
354 c.config.ClientSessionCache.Put(cacheKey, cs)
355 return nil
356 }
357
358
359 func (q *QUICConn) ConnectionState() ConnectionState {
360 return q.conn.ConnectionState()
361 }
362
363
364
365
366
367 func (q *QUICConn) SetTransportParameters(params []byte) {
368 if params == nil {
369 params = []byte{}
370 }
371 q.conn.quic.transportParams = params
372 if q.conn.quic.started {
373 <-q.conn.quic.signalc
374 <-q.conn.quic.blockedc
375 }
376 }
377
378
379
380 func quicError(err error) error {
381 if err == nil {
382 return nil
383 }
384 if _, ok := errors.AsType[AlertError](err); ok {
385 return err
386 }
387 a, ok := errors.AsType[alert](err)
388 if !ok {
389 a = alertInternalError
390 }
391
392
393 return fmt.Errorf("%w%.0w", err, AlertError(a))
394 }
395
396 func (c *Conn) quicReadHandshakeBytes(n int) error {
397 for c.hand.Len() < n {
398 if err := c.quicWaitForSignal(); err != nil {
399 return err
400 }
401 }
402 return nil
403 }
404
405 func (c *Conn) quicSetReadSecret(level QUICEncryptionLevel, suite uint16, secret []byte) error {
406
407
408
409
410 if c.hand.Len() != 0 {
411 c.sendAlert(alertUnexpectedMessage)
412 return errors.New("tls: handshake buffer not empty before setting read traffic secret")
413 }
414 c.quic.events = append(c.quic.events, QUICEvent{
415 Kind: QUICSetReadSecret,
416 Level: level,
417 Suite: suite,
418 Data: secret,
419 })
420 return nil
421 }
422
423 func (c *Conn) quicSetWriteSecret(level QUICEncryptionLevel, suite uint16, secret []byte) {
424 c.quic.events = append(c.quic.events, QUICEvent{
425 Kind: QUICSetWriteSecret,
426 Level: level,
427 Suite: suite,
428 Data: secret,
429 })
430 }
431
432 func (c *Conn) quicWriteCryptoData(level QUICEncryptionLevel, data []byte) {
433 var last *QUICEvent
434 if len(c.quic.events) > 0 {
435 last = &c.quic.events[len(c.quic.events)-1]
436 }
437 if last == nil || last.Kind != QUICWriteData || last.Level != level {
438 c.quic.events = append(c.quic.events, QUICEvent{
439 Kind: QUICWriteData,
440 Level: level,
441 })
442 last = &c.quic.events[len(c.quic.events)-1]
443 }
444 last.Data = append(last.Data, data...)
445 }
446
447 func (c *Conn) quicResumeSession(session *SessionState) error {
448 c.quic.events = append(c.quic.events, QUICEvent{
449 Kind: QUICResumeSession,
450 SessionState: session,
451 })
452 c.quic.waitingForDrain = true
453 for c.quic.waitingForDrain {
454 if err := c.quicWaitForSignal(); err != nil {
455 return err
456 }
457 }
458 return nil
459 }
460
461 func (c *Conn) quicStoreSession(session *SessionState) {
462 c.quic.events = append(c.quic.events, QUICEvent{
463 Kind: QUICStoreSession,
464 SessionState: session,
465 })
466 }
467
468 func (c *Conn) quicSetTransportParameters(params []byte) {
469 c.quic.events = append(c.quic.events, QUICEvent{
470 Kind: QUICTransportParameters,
471 Data: params,
472 })
473 }
474
475 func (c *Conn) quicGetTransportParameters() ([]byte, error) {
476 if c.quic.transportParams == nil {
477 c.quic.events = append(c.quic.events, QUICEvent{
478 Kind: QUICTransportParametersRequired,
479 })
480 }
481 for c.quic.transportParams == nil {
482 if err := c.quicWaitForSignal(); err != nil {
483 return nil, err
484 }
485 }
486 return c.quic.transportParams, nil
487 }
488
489 func (c *Conn) quicHandshakeComplete() {
490 c.quic.events = append(c.quic.events, QUICEvent{
491 Kind: QUICHandshakeDone,
492 })
493 }
494
495 func (c *Conn) quicRejectedEarlyData() {
496 c.quic.events = append(c.quic.events, QUICEvent{
497 Kind: QUICRejectedEarlyData,
498 })
499 }
500
501
502
503
504
505
// quicWaitForSignal tells the QUICConn that the handshake is blocked and
// waits until the QUICConn signals that it may proceed.
//
// It must be called with handshakeMutex held: the mutex is released while
// waiting and re-acquired before returning. If the handshake context has
// been canceled by the time the signal arrives, it returns the result of
// sending a close_notify alert; otherwise it moves readbuf into the
// handshake buffer and returns nil.
func (c *Conn) quicWaitForSignal() error {
	// Release the handshake mutex for the duration of the wait; the
	// deferred Lock restores it on every return path.
	c.handshakeMutex.Unlock()
	defer c.handshakeMutex.Lock()
	// Report that the handshake is blocked. QUICConn methods receive from
	// blockedc before returning to their caller, so this send completes
	// when one of them is waiting for it.
	c.quic.blockedc <- struct{}{}
	// Wait to be woken. The QUICConn side receives from signalc (rather
	// than sending) when it has supplied data or parameters, or when
	// closing.
	c.quic.signalc <- struct{}{}
	if c.quic.ctx.Err() != nil {
		// The handshake context is done: abort instead of reading input.
		return c.sendAlertLocked(alertCloseNotify)
	}
	c.hand.Write(c.quic.readbuf)
	c.quic.readbuf = nil
	return nil
}
527
View as plain text