Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "internal/synctest"
26 "io"
27 "log"
28 mrand "math/rand"
29 "net"
30 . "net/http"
31 "net/http/httptest"
32 "net/http/httptrace"
33 "net/http/httputil"
34 "net/http/internal/testcert"
35 "net/textproto"
36 "net/url"
37 "os"
38 "reflect"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "testing"
46 "testing/iotest"
47 "time"
48
49 "golang.org/x/net/http/httpguts"
50 )
51
52
53
54
55
56 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
57 if r.FormValue("close") == "true" {
58 w.Header().Set("Connection", "close")
59 }
60 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
61 w.Write([]byte(r.RemoteAddr))
62
63
64
65 if c, ok := ResponseWriterConnForTesting(w); ok {
66 fmt.Fprintf(w, ", %T %p", c, c)
67 }
68 })
69
70
71 type testCloseConn struct {
72 net.Conn
73 set *testConnSet
74 }
75
76 func (c *testCloseConn) Close() error {
77 c.set.remove(c)
78 return c.Conn.Close()
79 }
80
81
82
// testConnSet records every connection handed out by a makeTestDial dialer
// and whether each one has been closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex        // guards closed and list
	closed map[net.Conn]bool // conn -> whether Close has been called
	list   []net.Conn        // every inserted conn, in insertion order
}

// insert registers c as a new connection that is still open.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.list = append(tcs.list, c)
	tcs.closed[c] = false
}

// remove marks c as closed. The conn deliberately stays in list so that
// check can still iterate over it.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}
102
103
104 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
105 connSet := &testConnSet{
106 t: t,
107 closed: make(map[net.Conn]bool),
108 }
109 dial := func(n, addr string) (net.Conn, error) {
110 c, err := net.Dial(n, addr)
111 if err != nil {
112 return nil, err
113 }
114 tc := &testCloseConn{c, connSet}
115 connSet.insert(tc)
116 return tc, nil
117 }
118 return connSet, dial
119 }
120
121 func (tcs *testConnSet) check(t *testing.T) {
122 tcs.mu.Lock()
123 defer tcs.mu.Unlock()
124 for i := 4; i >= 0; i-- {
125 for i, c := range tcs.list {
126 if tcs.closed[c] {
127 continue
128 }
129 if i != 0 {
130
131
132 tcs.mu.Unlock()
133 time.Sleep(50 * time.Millisecond)
134 tcs.mu.Lock()
135 continue
136 }
137 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
138 }
139 }
140 }
141
142 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
143 func testReuseRequest(t *testing.T, mode testMode) {
144 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
145 w.Write([]byte("{}"))
146 })).ts
147
148 c := ts.Client()
149 req, _ := NewRequest("GET", ts.URL, nil)
150 res, err := c.Do(req)
151 if err != nil {
152 t.Fatal(err)
153 }
154 err = res.Body.Close()
155 if err != nil {
156 t.Fatal(err)
157 }
158
159 res, err = c.Do(req)
160 if err != nil {
161 t.Fatal(err)
162 }
163 err = res.Body.Close()
164 if err != nil {
165 t.Fatal(err)
166 }
167 }
168
169
170
171 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
172 func testTransportKeepAlives(t *testing.T, mode testMode) {
173 ts := newClientServerTest(t, mode, hostPortHandler).ts
174
175 c := ts.Client()
176 for _, disableKeepAlive := range []bool{false, true} {
177 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
178 fetch := func(n int) string {
179 res, err := c.Get(ts.URL)
180 if err != nil {
181 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
182 }
183 body, err := io.ReadAll(res.Body)
184 if err != nil {
185 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
186 }
187 return string(body)
188 }
189
190 body1 := fetch(1)
191 body2 := fetch(2)
192
193 bodiesDiffer := body1 != body2
194 if bodiesDiffer != disableKeepAlive {
195 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
196 disableKeepAlive, bodiesDiffer, body1, body2)
197 }
198 }
199 }
200
201 func TestTransportConnectionCloseOnResponse(t *testing.T) {
202 run(t, testTransportConnectionCloseOnResponse)
203 }
204 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
205 ts := newClientServerTest(t, mode, hostPortHandler).ts
206
207 connSet, testDial := makeTestDial(t)
208
209 c := ts.Client()
210 tr := c.Transport.(*Transport)
211 tr.Dial = testDial
212
213 for _, connectionClose := range []bool{false, true} {
214 fetch := func(n int) string {
215 req := new(Request)
216 var err error
217 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
218 if err != nil {
219 t.Fatalf("URL parse error: %v", err)
220 }
221 req.Method = "GET"
222 req.Proto = "HTTP/1.1"
223 req.ProtoMajor = 1
224 req.ProtoMinor = 1
225
226 res, err := c.Do(req)
227 if err != nil {
228 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
229 }
230 defer res.Body.Close()
231 body, err := io.ReadAll(res.Body)
232 if err != nil {
233 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
234 }
235 return string(body)
236 }
237
238 body1 := fetch(1)
239 body2 := fetch(2)
240 bodiesDiffer := body1 != body2
241 if bodiesDiffer != connectionClose {
242 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
243 connectionClose, bodiesDiffer, body1, body2)
244 }
245
246 tr.CloseIdleConnections()
247 }
248
249 connSet.check(t)
250 }
251
252
253
254
255
256
257
258 func TestTransportConnectionCloseOnRequest(t *testing.T) {
259 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
260 }
261 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
262 ts := newClientServerTest(t, mode, hostPortHandler).ts
263
264 connSet, testDial := makeTestDial(t)
265
266 c := ts.Client()
267 tr := c.Transport.(*Transport)
268 tr.Dial = testDial
269 for _, reqClose := range []bool{false, true} {
270 fetch := func(n int) string {
271 req := new(Request)
272 var err error
273 req.URL, err = url.Parse(ts.URL)
274 if err != nil {
275 t.Fatalf("URL parse error: %v", err)
276 }
277 req.Method = "GET"
278 req.Proto = "HTTP/1.1"
279 req.ProtoMajor = 1
280 req.ProtoMinor = 1
281 req.Close = reqClose
282
283 res, err := c.Do(req)
284 if err != nil {
285 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
286 }
287 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
288 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
289 reqClose, got, !reqClose)
290 }
291 body, err := io.ReadAll(res.Body)
292 if err != nil {
293 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
294 }
295 return string(body)
296 }
297
298 body1 := fetch(1)
299 body2 := fetch(2)
300
301 got := 1
302 if body1 != body2 {
303 got++
304 }
305 want := 1
306 if reqClose {
307 want = 2
308 }
309 if got != want {
310 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
311 reqClose, got, want, body1, body2)
312 }
313
314 tr.CloseIdleConnections()
315 }
316
317 connSet.check(t)
318 }
319
320
321
322
323 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
324 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
325 }
326 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
327 ts := newClientServerTest(t, mode, hostPortHandler).ts
328
329 c := ts.Client()
330 c.Transport.(*Transport).DisableKeepAlives = true
331
332 res, err := c.Get(ts.URL)
333 if err != nil {
334 t.Fatal(err)
335 }
336 res.Body.Close()
337 if res.Header.Get("X-Saw-Close") != "true" {
338 t.Errorf("handler didn't see Connection: close ")
339 }
340 }
341
342
343
344 func TestTransportRespectRequestWantsClose(t *testing.T) {
345 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
346 }
347 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
348 tests := []struct {
349 disableKeepAlives bool
350 close bool
351 }{
352 {disableKeepAlives: false, close: false},
353 {disableKeepAlives: false, close: true},
354 {disableKeepAlives: true, close: false},
355 {disableKeepAlives: true, close: true},
356 }
357
358 for _, tc := range tests {
359 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
360 func(t *testing.T) {
361 ts := newClientServerTest(t, mode, hostPortHandler).ts
362
363 c := ts.Client()
364 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
365 req, err := NewRequest("GET", ts.URL, nil)
366 if err != nil {
367 t.Fatal(err)
368 }
369 count := 0
370 trace := &httptrace.ClientTrace{
371 WroteHeaderField: func(key string, field []string) {
372 if key != "Connection" {
373 return
374 }
375 if httpguts.HeaderValuesContainsToken(field, "close") {
376 count += 1
377 }
378 },
379 }
380 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
381 req.Close = tc.close
382 res, err := c.Do(req)
383 if err != nil {
384 t.Fatal(err)
385 }
386 defer res.Body.Close()
387 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
388 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
389 }
390 })
391 }
392
393 }
394
395 func TestTransportIdleCacheKeys(t *testing.T) {
396 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
397 }
398 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
399 ts := newClientServerTest(t, mode, hostPortHandler).ts
400 c := ts.Client()
401 tr := c.Transport.(*Transport)
402
403 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
404 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
405 }
406
407 resp, err := c.Get(ts.URL)
408 if err != nil {
409 t.Error(err)
410 }
411 io.ReadAll(resp.Body)
412
413 keys := tr.IdleConnKeysForTesting()
414 if e, g := 1, len(keys); e != g {
415 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
416 }
417
418 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
419 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
420 }
421
422 tr.CloseIdleConnections()
423 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
424 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
425 }
426 }
427
428
429
430 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
431 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
432 const msg = "foobar"
433
434 var addrSeen map[string]int
435 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
436 addrSeen[r.RemoteAddr]++
437 if r.URL.Path == "/chunked/" {
438 w.WriteHeader(200)
439 w.(Flusher).Flush()
440 } else {
441 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
442 w.WriteHeader(200)
443 }
444 w.Write([]byte(msg))
445 })).ts
446
447 for pi, path := range []string{"/content-length/", "/chunked/"} {
448 wantLen := []int{len(msg), -1}[pi]
449 addrSeen = make(map[string]int)
450 for i := 0; i < 3; i++ {
451 res, err := ts.Client().Get(ts.URL + path)
452 if err != nil {
453 t.Errorf("Get %s: %v", path, err)
454 continue
455 }
456
457
458
459
460
461 defer res.Body.Close()
462
463 if res.ContentLength != int64(wantLen) {
464 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
465 }
466 got, err := io.ReadAll(res.Body)
467 if string(got) != msg || err != nil {
468 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
469 }
470 }
471 if len(addrSeen) != 1 {
472 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
473 }
474 }
475 }
476
// TestTransportMaxPerHostIdleConns checks that Transport.MaxIdleConnsPerHost
// caps how many idle connections the Transport keeps for a single host.
// Three requests are held open in parallel, then released one at a time,
// and the idle pool is inspected after each release.
func TestTransportMaxPerHostIdleConns(t *testing.T) {
	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
}
func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
	// stop is closed when the test returns, unblocking any handler or
	// doReq goroutine still waiting on a channel operation.
	stop := make(chan struct{})
	defer close(stop)

	// Each handler announces itself on gotReq, then blocks until the test
	// hands it a response body on resch (or stop is closed).
	resch := make(chan string)
	gotReq := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReq <- true
		var msg string
		select {
		case <-stop:
			return
		case msg = <-resch:
		}
		_, err := w.Write([]byte(msg))
		if err != nil {
			t.Errorf("Write: %v", err)
			return
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)
	maxIdleConnsPerHost := 2
	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost

	// Start 3 outstanding requests and wait for the server to get them.
	// Their responses hang until the test writes to resch.
	donech := make(chan bool)
	doReq := func() {
		// Signal completion on donech (the value is t.Failed(), which the
		// receiver below discards) unless the test has already returned.
		defer func() {
			select {
			case <-stop:
				return
			case donech <- t.Failed():
			}
		}()
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
			return
		}
		// NOTE(review): the body is read to EOF but not closed; the idle
		// counts checked below appear to rely on EOF alone having returned
		// the connection to the pool by the time donech is signaled.
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("ReadAll: %v", err)
			return
		}
	}
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq
	go doReq()
	<-gotReq

	// All three connections are still busy, so nothing can be idle yet.
	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
	}

	// Release the first response: one idle conn under this host's key.
	resch <- "res1"
	<-donech
	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
	}
	addr := ts.Listener.Addr().String()
	cacheKey := "|http|" + addr
	if keys[0] != cacheKey {
		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
	}
	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
	}

	// Release the second response: two idle conns, still within the cap.
	resch <- "res2"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
		t.Errorf("after second response, idle conns = %d; want %d", g, w)
	}

	// Release the third response: a third idle conn would exceed the cap,
	// so the count must remain maxIdleConnsPerHost.
	resch <- "res3"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
		t.Errorf("after third response, idle conns = %d; want %d", g, w)
	}
}
565
566 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
567 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
568 }
569 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
570 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
571 _, err := w.Write([]byte("foo"))
572 if err != nil {
573 t.Fatalf("Write: %v", err)
574 }
575 })).ts
576 c := ts.Client()
577 tr := c.Transport.(*Transport)
578 dialStarted := make(chan struct{})
579 stallDial := make(chan struct{})
580 tr.Dial = func(network, addr string) (net.Conn, error) {
581 dialStarted <- struct{}{}
582 <-stallDial
583 return net.Dial(network, addr)
584 }
585
586 tr.DisableKeepAlives = true
587 tr.MaxConnsPerHost = 1
588
589 preDial := make(chan struct{})
590 reqComplete := make(chan struct{})
591 doReq := func(reqId string) {
592 req, _ := NewRequest("GET", ts.URL, nil)
593 trace := &httptrace.ClientTrace{
594 GetConn: func(hostPort string) {
595 preDial <- struct{}{}
596 },
597 }
598 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
599 resp, err := tr.RoundTrip(req)
600 if err != nil {
601 t.Errorf("unexpected error for request %s: %v", reqId, err)
602 }
603 _, err = io.ReadAll(resp.Body)
604 if err != nil {
605 t.Errorf("unexpected error for request %s: %v", reqId, err)
606 }
607 reqComplete <- struct{}{}
608 }
609
610 go doReq("req1")
611 <-preDial
612 <-dialStarted
613
614
615 go doReq("req2")
616 <-preDial
617 select {
618 case <-dialStarted:
619 t.Error("req2 dial started while req1 dial in progress")
620 return
621 default:
622 }
623
624
625 stallDial <- struct{}{}
626 <-reqComplete
627
628
629 <-dialStarted
630 stallDial <- struct{}{}
631 <-reqComplete
632 }
633
634 func TestTransportMaxConnsPerHost(t *testing.T) {
635 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
636 }
637 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
638 CondSkipHTTP2(t)
639
640 h := HandlerFunc(func(w ResponseWriter, r *Request) {
641 _, err := w.Write([]byte("foo"))
642 if err != nil {
643 t.Fatalf("Write: %v", err)
644 }
645 })
646
647 ts := newClientServerTest(t, mode, h).ts
648 c := ts.Client()
649 tr := c.Transport.(*Transport)
650 tr.MaxConnsPerHost = 1
651
652 mu := sync.Mutex{}
653 var conns []net.Conn
654 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
655 tr.Dial = func(network, addr string) (net.Conn, error) {
656 atomic.AddInt32(&dialCnt, 1)
657 c, err := net.Dial(network, addr)
658 mu.Lock()
659 defer mu.Unlock()
660 conns = append(conns, c)
661 return c, err
662 }
663
664 doReq := func() {
665 trace := &httptrace.ClientTrace{
666 GotConn: func(connInfo httptrace.GotConnInfo) {
667 if !connInfo.Reused {
668 atomic.AddInt32(&gotConnCnt, 1)
669 }
670 },
671 TLSHandshakeStart: func() {
672 atomic.AddInt32(&tlsHandshakeCnt, 1)
673 },
674 }
675 req, _ := NewRequest("GET", ts.URL, nil)
676 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
677
678 resp, err := c.Do(req)
679 if err != nil {
680 t.Fatalf("request failed: %v", err)
681 }
682 defer resp.Body.Close()
683 _, err = io.ReadAll(resp.Body)
684 if err != nil {
685 t.Fatalf("read body failed: %v", err)
686 }
687 }
688
689 wg := sync.WaitGroup{}
690 for i := 0; i < 10; i++ {
691 wg.Add(1)
692 go func() {
693 defer wg.Done()
694 doReq()
695 }()
696 }
697 wg.Wait()
698
699 expected := int32(tr.MaxConnsPerHost)
700 if dialCnt != expected {
701 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
702 }
703 if gotConnCnt != expected {
704 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
705 }
706 if ts.TLS != nil && tlsHandshakeCnt != expected {
707 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
708 }
709
710 if t.Failed() {
711 t.FailNow()
712 }
713
714 mu.Lock()
715 for _, c := range conns {
716 c.Close()
717 }
718 conns = nil
719 mu.Unlock()
720 tr.CloseIdleConnections()
721
722 doReq()
723 expected++
724 if dialCnt != expected {
725 t.Errorf("round 2: too many dials: %d", dialCnt)
726 }
727 if gotConnCnt != expected {
728 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
729 }
730 if ts.TLS != nil && tlsHandshakeCnt != expected {
731 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
732 }
733 }
734
735 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
736 run(t, testTransportMaxConnsPerHostDialCancellation,
737 testNotParallel,
738 []testMode{http1Mode, https1Mode, http2Mode},
739 )
740 }
741
742 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
743 CondSkipHTTP2(t)
744
745 h := HandlerFunc(func(w ResponseWriter, r *Request) {
746 _, err := w.Write([]byte("foo"))
747 if err != nil {
748 t.Fatalf("Write: %v", err)
749 }
750 })
751
752 cst := newClientServerTest(t, mode, h)
753 defer cst.close()
754 ts := cst.ts
755 c := ts.Client()
756 tr := c.Transport.(*Transport)
757 tr.MaxConnsPerHost = 1
758
759
760 ctx, cancel := context.WithCancel(context.Background())
761 defer cancel()
762 SetPendingDialHooks(cancel, nil)
763 defer SetPendingDialHooks(nil, nil)
764
765 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
766 _, err := c.Do(req)
767 if !errors.Is(err, context.Canceled) {
768 t.Errorf("expected error %v, got %v", context.Canceled, err)
769 }
770
771
772 SetPendingDialHooks(nil, nil)
773 req, _ = NewRequest("GET", ts.URL, nil)
774 resp, err := c.Do(req)
775 if err != nil {
776 t.Fatalf("request failed: %v", err)
777 }
778 defer resp.Body.Close()
779 _, err = io.ReadAll(resp.Body)
780 if err != nil {
781 t.Fatalf("read body failed: %v", err)
782 }
783 }
784
785 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
786 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
787 }
788 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
789 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
790 io.WriteString(w, r.RemoteAddr)
791 })).ts
792
793 c := ts.Client()
794 tr := c.Transport.(*Transport)
795
796 doReq := func(name string) {
797
798
799 res, err := c.Post(ts.URL, "", nil)
800 if err != nil {
801 t.Fatalf("%s: %v", name, err)
802 }
803 if res.StatusCode != 200 {
804 t.Fatalf("%s: %v", name, res.Status)
805 }
806 defer res.Body.Close()
807 slurp, err := io.ReadAll(res.Body)
808 if err != nil {
809 t.Fatalf("%s: %v", name, err)
810 }
811 t.Logf("%s: ok (%q)", name, slurp)
812 }
813
814 doReq("first")
815 keys1 := tr.IdleConnKeysForTesting()
816
817 ts.CloseClientConnections()
818
819 var keys2 []string
820 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
821 keys2 = tr.IdleConnKeysForTesting()
822 if len(keys2) != 0 {
823 if d > 0 {
824 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
825 }
826 return false
827 }
828 return true
829 })
830
831 doReq("second")
832 }
833
834
835
836 func TestTransportServerClosingUnexpectedly(t *testing.T) {
837 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
838 }
839 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
840 ts := newClientServerTest(t, mode, hostPortHandler).ts
841 c := ts.Client()
842
843 fetch := func(n, retries int) string {
844 condFatalf := func(format string, arg ...any) {
845 if retries <= 0 {
846 t.Fatalf(format, arg...)
847 }
848 t.Logf("retrying shortly after expected error: "+format, arg...)
849 time.Sleep(time.Second / time.Duration(retries))
850 }
851 for retries >= 0 {
852 retries--
853 res, err := c.Get(ts.URL)
854 if err != nil {
855 condFatalf("error in req #%d, GET: %v", n, err)
856 continue
857 }
858 body, err := io.ReadAll(res.Body)
859 if err != nil {
860 condFatalf("error in req #%d, ReadAll: %v", n, err)
861 continue
862 }
863 res.Body.Close()
864 return string(body)
865 }
866 panic("unreachable")
867 }
868
869 body1 := fetch(1, 0)
870 body2 := fetch(2, 0)
871
872
873
874
875
876
877
878
879 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
880
881 body3 := fetch(3, 5)
882
883 if body1 != body2 {
884 t.Errorf("expected body1 and body2 to be equal")
885 }
886 if body2 == body3 {
887 t.Errorf("expected body2 and body3 to be different")
888 }
889 }
890
891
892
893 func TestStressSurpriseServerCloses(t *testing.T) {
894 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
895 }
896 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
897 if testing.Short() {
898 t.Skip("skipping test in short mode")
899 }
900 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
901 w.Header().Set("Content-Length", "5")
902 w.Header().Set("Content-Type", "text/plain")
903 w.Write([]byte("Hello"))
904 w.(Flusher).Flush()
905 conn, buf, _ := w.(Hijacker).Hijack()
906 buf.Flush()
907 conn.Close()
908 })).ts
909 c := ts.Client()
910
911
912
913
914
915
916
917 const (
918 numClients = 20
919 reqsPerClient = 25
920 )
921 var wg sync.WaitGroup
922 wg.Add(numClients * reqsPerClient)
923 for i := 0; i < numClients; i++ {
924 go func() {
925 for i := 0; i < reqsPerClient; i++ {
926 res, err := c.Get(ts.URL)
927 if err == nil {
928
929
930
931
932
933
934 res.Body.Close()
935 }
936 wg.Done()
937 }
938 }()
939 }
940
941
942 wg.Wait()
943 }
944
945
946
947 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
948 func testTransportHeadResponses(t *testing.T, mode testMode) {
949 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
950 if r.Method != "HEAD" {
951 panic("expected HEAD; got " + r.Method)
952 }
953 w.Header().Set("Content-Length", "123")
954 w.WriteHeader(200)
955 })).ts
956 c := ts.Client()
957
958 for i := 0; i < 2; i++ {
959 res, err := c.Head(ts.URL)
960 if err != nil {
961 t.Errorf("error on loop %d: %v", i, err)
962 continue
963 }
964 if e, g := "123", res.Header.Get("Content-Length"); e != g {
965 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
966 }
967 if e, g := int64(123), res.ContentLength; e != g {
968 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
969 }
970 if all, err := io.ReadAll(res.Body); err != nil {
971 t.Errorf("loop %d: Body ReadAll: %v", i, err)
972 } else if len(all) != 0 {
973 t.Errorf("Bogus body %q", all)
974 }
975 }
976 }
977
978
979
// TestTransportHeadChunkedResponse checks that a HEAD response carrying a
// "Transfer-Encoding: chunked" header does not prevent connection reuse:
// two HEAD requests must come from the same client ip:port, which the
// handler echoes back in x-client-ipport.
//
// NOTE(review): marked testNotParallel, presumably because
// SetReadLoopBeforeNextReadHook installs process-global test state.
func TestTransportHeadChunkedResponse(t *testing.T) {
	run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		// A chunked Transfer-Encoding on a body-less HEAD response.
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Header().Set("x-client-ipport", r.RemoteAddr)
		w.WriteHeader(200)
	})).ts
	c := ts.Client()

	// Wait for the read loop to reach its next read (signaled through the
	// hook) before issuing the next HEAD, so the connection is available
	// for reuse.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	// NOTE(review): err is checked only after <-didRead; if a request fails
	// before any read loop runs the hook, this receive may block forever.
	res1, err := c.Head(ts.URL)
	<-didRead

	if err != nil {
		t.Fatalf("request 1 error: %v", err)
	}

	res2, err := c.Head(ts.URL)
	<-didRead

	if err != nil {
		t.Fatalf("request 2 error: %v", err)
	}
	if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
	}
}
1017
// roundTripTests drives testRoundTripGzip. Each case gives the
// Accept-Encoding the caller sets (accept, empty meaning unset), the value
// the server must observe (expectAccept), and whether the caller receives
// a still-compressed body that it must gunzip itself (compressed).
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No Accept-Encoding from the caller: the Transport requests gzip and
	// decompresses transparently.
	{accept: "", expectAccept: "gzip", compressed: false},
	// A non-gzip Accept-Encoding is passed through unchanged.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip request is passed through; the caller gets gzip bytes.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1030
1031
1032 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1033 func testRoundTripGzip(t *testing.T, mode testMode) {
1034 const responseBody = "test response body"
1035 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1036 accept := req.Header.Get("Accept-Encoding")
1037 if expect := req.FormValue("expect_accept"); accept != expect {
1038 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1039 req.FormValue("testnum"), accept, expect)
1040 }
1041 if accept == "gzip" {
1042 rw.Header().Set("Content-Encoding", "gzip")
1043 gz := gzip.NewWriter(rw)
1044 gz.Write([]byte(responseBody))
1045 gz.Close()
1046 } else {
1047 rw.Header().Set("Content-Encoding", accept)
1048 rw.Write([]byte(responseBody))
1049 }
1050 })).ts
1051 tr := ts.Client().Transport.(*Transport)
1052
1053 for i, test := range roundTripTests {
1054
1055 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1056 if test.accept != "" {
1057 req.Header.Set("Accept-Encoding", test.accept)
1058 }
1059 res, err := tr.RoundTrip(req)
1060 if err != nil {
1061 t.Errorf("%d. RoundTrip: %v", i, err)
1062 continue
1063 }
1064 var body []byte
1065 if test.compressed {
1066 var r *gzip.Reader
1067 r, err = gzip.NewReader(res.Body)
1068 if err != nil {
1069 t.Errorf("%d. gzip NewReader: %v", i, err)
1070 continue
1071 }
1072 body, err = io.ReadAll(r)
1073 res.Body.Close()
1074 } else {
1075 body, err = io.ReadAll(res.Body)
1076 }
1077 if err != nil {
1078 t.Errorf("%d. Error: %q", i, err)
1079 continue
1080 }
1081 if g, e := string(body), responseBody; g != e {
1082 t.Errorf("%d. body = %q; want %q", i, g, e)
1083 }
1084 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1085 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1086 }
1087 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1088 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1089 }
1090 }
1091
1092 }
1093
1094 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1095 func testTransportGzip(t *testing.T, mode testMode) {
1096 if mode == http2Mode {
1097 t.Skip("https://go.dev/issue/56020")
1098 }
1099 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1100 const nRandBytes = 1024 * 1024
1101 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1102 if req.Method == "HEAD" {
1103 if g := req.Header.Get("Accept-Encoding"); g != "" {
1104 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1105 }
1106 return
1107 }
1108 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1109 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1110 }
1111 rw.Header().Set("Content-Encoding", "gzip")
1112
1113 var w io.Writer = rw
1114 var buf bytes.Buffer
1115 if req.FormValue("chunked") == "0" {
1116 w = &buf
1117 defer io.Copy(rw, &buf)
1118 defer func() {
1119 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1120 }()
1121 }
1122 gz := gzip.NewWriter(w)
1123 gz.Write([]byte(testString))
1124 if req.FormValue("body") == "large" {
1125 io.CopyN(gz, rand.Reader, nRandBytes)
1126 }
1127 gz.Close()
1128 })).ts
1129 c := ts.Client()
1130
1131 for _, chunked := range []string{"1", "0"} {
1132
1133 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1134 if err != nil {
1135 t.Fatalf("large get: %v", err)
1136 }
1137 buf := make([]byte, len(testString))
1138 n, err := io.ReadFull(res.Body, buf)
1139 if err != nil {
1140 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1141 }
1142 if e, g := testString, string(buf); e != g {
1143 t.Errorf("partial read got %q, expected %q", g, e)
1144 }
1145 res.Body.Close()
1146
1147 n, err = res.Body.Read(buf)
1148 if n != 0 || err == nil {
1149 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1150 }
1151
1152
1153 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1154 if err != nil {
1155 t.Fatal(err)
1156 }
1157 body, err := io.ReadAll(res.Body)
1158 if err != nil {
1159 t.Fatal(err)
1160 }
1161 if g, e := string(body), testString; g != e {
1162 t.Fatalf("body = %q; want %q", g, e)
1163 }
1164 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1165 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1166 }
1167
1168
1169 n, err = res.Body.Read(buf)
1170 if n != 0 || err == nil {
1171 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1172 }
1173 res.Body.Close()
1174 n, err = res.Body.Read(buf)
1175 if n != 0 || err == nil {
1176 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1177 }
1178 }
1179
1180
1181 res, err := c.Head(ts.URL)
1182 if err != nil {
1183 t.Fatalf("Head: %v", err)
1184 }
1185 if res.StatusCode != 200 {
1186 t.Errorf("Head status=%d; want=200", res.StatusCode)
1187 }
1188 }
1189
1190
1191
// transport100ContinueTest is a harness for exercising the Transport's
// "Expect: 100-continue" handling against a hand-driven server connection.
type transport100ContinueTest struct {
	t *testing.T

	// reqdone is closed once the client's RoundTrip has returned; after
	// that, resp and respErr hold its results.
	reqdone chan struct{}
	resp    *Response
	respErr error

	// conn and reader are the server side of the accepted connection.
	conn   net.Conn
	reader *bufio.Reader
}

// transport100ContinueTestBody is the request body the client sends.
const transport100ContinueTestBody = "request body"
1204
1205
1206
// newTransport100ContinueTest starts a PUT request with "Expect: 100-continue"
// through a new Transport whose ExpectContinueTimeout is timeout. It accepts
// the resulting connection on a local listener and reads the request headers
// off it. The caller then scripts the server side through the returned value.
//
// Registered cleanups wait for RoundTrip to return, close idle connections,
// and fail the test if the Transport wrote any bytes the test did not consume.
func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
	ln := newLocalListener(t)
	defer ln.Close()

	test := &transport100ContinueTest{
		t:       t,
		reqdone: make(chan struct{}),
	}

	tr := &Transport{
		ExpectContinueTimeout: timeout,
	}
	go func() {
		defer close(test.reqdone)
		body := strings.NewReader(transport100ContinueTestBody)
		req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
		req.Header.Set("Expect", "100-continue")
		req.ContentLength = int64(len(transport100ContinueTestBody))
		test.resp, test.respErr = tr.RoundTrip(req)
		// RoundTrip returns a nil Response together with a non-nil error.
		// Guard the Close so such a failure is reported by wantRequestDone
		// instead of panicking here and killing the test binary.
		if test.resp != nil {
			test.resp.Body.Close()
		}
	}()

	c, err := ln.Accept()
	if err != nil {
		t.Fatalf("Accept: %v", err)
	}
	t.Cleanup(func() {
		c.Close()
	})
	br := bufio.NewReader(c)
	_, err = ReadRequest(br)
	if err != nil {
		t.Fatalf("ReadRequest: %v", err)
	}
	test.conn = c
	test.reader = br
	// Cleanups run in LIFO order, so this one runs before c is closed. It
	// drains everything the client sends until it closes the connection.
	t.Cleanup(func() {
		<-test.reqdone
		tr.CloseIdleConnections()
		got, _ := io.ReadAll(test.reader)
		if len(got) > 0 {
			t.Fatalf("Transport sent unexpected bytes: %q", got)
		}
	})

	return test
}
1254
1255
1256 func (test *transport100ContinueTest) respond(lines ...string) {
1257 for _, line := range lines {
1258 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1259 test.t.Fatalf("Write: %v", err)
1260 }
1261 }
1262 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1263 test.t.Fatalf("Write: %v", err)
1264 }
1265 }
1266
1267
// wantBodySent reads at most len(transport100ContinueTestBody) bytes from the
// server side of the connection. It fails the test on a read error or if the
// bytes read differ from the expected request body.
func (test *transport100ContinueTest) wantBodySent() {
	want := transport100ContinueTestBody
	limited := io.LimitReader(test.reader, int64(len(want)))
	data, err := io.ReadAll(limited)
	if err != nil {
		test.t.Fatalf("unexpected error reading body: %v", err)
	}
	if got := string(data); got != want {
		test.t.Fatalf("unexpected body: got %q, want %q", got, want)
	}
}
1277
1278
1279 func (test *transport100ContinueTest) wantRequestDone(want int) {
1280 <-test.reqdone
1281 if test.respErr != nil {
1282 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1283 }
1284 if got := test.resp.StatusCode; got != want {
1285 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1286 }
1287 }
1288
// TestTransportExpect100ContinueSent checks that after the server replies
// "100 Continue", the Transport sends the body and accepts the final 200.
func TestTransportExpect100ContinueSent(t *testing.T) {
	tt := newTransport100ContinueTest(t, time.Hour)
	tt.respond("HTTP/1.1 100 Continue")
	tt.wantBodySent()
	tt.respond("HTTP/1.1 200", "Content-Length: 0")
	tt.wantRequestDone(200)
}
1297
// TestTransportExpect100Continue200ResponseNoConnClose has the server send a
// final 200 immediately, with no "100 Continue" and no "Connection: close".
// The test expects the body to still be written to the connection.
func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
	tt := newTransport100ContinueTest(t, time.Hour)
	tt.respond("HTTP/1.1 200", "Content-Length: 0")
	tt.wantBodySent()
	tt.wantRequestDone(200)
}
1305
// TestTransportExpect100Continue200ResponseWithConnClose has the server send a
// final 200 with "Connection: close" and no "100 Continue". The body is never
// consumed here, so the cleanup registered by newTransport100ContinueTest
// fails the test if the Transport sends it anyway.
func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
	tt := newTransport100ContinueTest(t, time.Hour)
	tt.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
	tt.wantRequestDone(200)
}
1312
// TestTransportExpect100Continue500ResponseNoConnClose has the server send a
// final 500 without "Connection: close". The test expects the body to still be
// written, and the 500 to be returned to the caller.
func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
	tt := newTransport100ContinueTest(t, time.Hour)
	tt.respond("HTTP/1.1 500", "Content-Length: 0")
	tt.wantBodySent()
	tt.wantRequestDone(500)
}
1320
// TestTransportExpect100Continue500ResponseTimeout uses a 5ms
// ExpectContinueTimeout. The server stays silent, so the body should be sent
// once the timeout expires. Despite the test's name, the scripted final
// response is a 200.
func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
	tt := newTransport100ContinueTest(t, 5*time.Millisecond)
	tt.wantBodySent()
	tt.respond("HTTP/1.1 200", "Content-Length: 0")
	tt.wantRequestDone(200)
}
1327
// TestSOCKS5Proxy checks that a Transport whose Proxy is a socks5:// URL
// negotiates with the proxy and tunnels the request to the target server.
func TestSOCKS5Proxy(t *testing.T) {
	run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
}
func testSOCKS5Proxy(t *testing.T, mode testMode) {
	ch := make(chan string, 1)
	l := newLocalListener(t)
	defer l.Close()
	defer close(ch)
	// proxy serves one connection accepted on l as a minimal SOCKS5 server:
	// - it accepts the "no authentication" method;
	// - it handles a CONNECT to an IPv4 or IPv6 address and replies success;
	// - it reports the requested target on ch;
	// - it then copies bytes between the client and the target.
	proxy := func(t *testing.T) {
		s, err := l.Accept()
		if err != nil {
			t.Errorf("socks5 proxy Accept(): %v", err)
			return
		}
		defer s.Close()
		// 22 bytes = 4-byte request header + 16-byte IPv6 address + 2-byte port.
		var buf [22]byte
		// Greeting: version 5, one method offered, method 0 (no auth).
		if _, err := io.ReadFull(s, buf[:3]); err != nil {
			t.Errorf("socks5 proxy initial read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
			return
		}
		// Choose method 0 (no auth).
		if _, err := s.Write([]byte{5, 0}); err != nil {
			t.Errorf("socks5 proxy initial write: %v", err)
			return
		}
		// Request header: version 5, command 1 (CONNECT), reserved 0, address type.
		if _, err := io.ReadFull(s, buf[:4]); err != nil {
			t.Errorf("socks5 proxy second read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
			return
		}
		var ipLen int
		switch buf[3] {
		case 1: // IPv4
			ipLen = net.IPv4len
		case 4: // IPv6
			ipLen = net.IPv6len
		default:
			// buf[3] is the address type byte; buf[4] has not been read yet.
			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[3])
			return
		}
		// Address followed by a 2-byte big-endian port.
		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy address read: %v", err)
			return
		}
		ip := net.IP(buf[4 : ipLen+4])
		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
		// Reply by echoing the request, with reply code 0 (succeeded).
		copy(buf[:3], []byte{5, 0, 0})
		if _, err := s.Write(buf[:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy connect write: %v", err)
			return
		}
		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)

		// Relay traffic between the client and the requested target.
		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
		targetConn, err := net.Dial("tcp", targetHost)
		if err != nil {
			t.Errorf("net.Dial(%q) failed: %v", targetHost, err)
			return
		}
		go io.Copy(targetConn, s)
		io.Copy(s, targetConn)
		targetConn.Close()
	}

	pu, err := url.Parse("socks5://" + l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	sentinelHeader := "X-Sentinel"
	sentinelValue := "12345"
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set(sentinelHeader, sentinelValue)
	})
	// NOTE(review): useTLS only names the subtest; whether TLS is used is
	// decided by mode.
	for _, useTLS := range []bool{false, true} {
		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
			ts := newClientServerTest(t, mode, h).ts
			go proxy(t)
			c := ts.Client()
			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			r, err := c.Head(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			if r.Header.Get(sentinelHeader) != sentinelValue {
				t.Errorf("Failed to retrieve sentinel value")
			}
			got := <-ch
			ts.Close()
			tsu, err := url.Parse(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			want := "proxy for " + tsu.Host
			if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		})
	}
}
1435
// TestTransportProxy checks requests sent through an HTTP or HTTPS proxy to
// an HTTP or HTTPS site:
// - for a plain-HTTP site, the proxy receives the request itself, with the
//   site's absolute URL;
// - for an HTTPS site, the client issues CONNECT and tunnels the request.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ siteMode, proxyMode testMode }{
		{http1Mode, http1Mode},
		{http1Mode, https1Mode},
		{https1Mode, http1Mode},
		{https1Mode, https1Mode},
	}
	for _, testCase := range testCases {
		siteMode := testCase.siteMode
		proxyMode := testCase.proxyMode
		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// For CONNECT, act as a tunnel: hijack the client connection,
				// dial the target, reply 200 OK, then copy in both directions.
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed")
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						// After Hijack, closing the connection is our job.
						// Without this the client could wait indefinitely for
						// a CONNECT reply.
						clientConn.Close()
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						clientConn.Close()
						targetConn.Close()
						return
					}

					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			ts := newClientServerTest(t, siteMode, h1).ts
			proxy := newClientServerTest(t, proxyMode, h2).ts

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// Use the site's client when the site is HTTPS, otherwise the
			// proxy's client. NOTE(review): this presumes ts.Client() also
			// trusts an HTTPS proxy's certificate (e.g. shared test cert).
			c := proxy.Client()
			if siteMode == https1Mode {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			got := <-proxyCh
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if siteMode == https1Mode {
				// The proxy should have seen a CONNECT for the site's host.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				pu, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := pu.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The tunneled request should reach the site unchanged.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				// The proxy should have received the request itself, with
				// the site's absolute URL.
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
1552
1553
1554
1555
// TestProxyWithInfiniteHeader checks what happens when an HTTPS request goes
// through a proxy that answers its CONNECT with an endless stream of bytes
// that never completes a response header. With MaxResponseHeaderBytes set,
// the request must fail rather than hang.
func TestProxyWithInfiniteHeader(t *testing.T) {
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	cancelc := make(chan struct{})
	defer close(cancelc)

	// Fake proxy: read the CONNECT request, then write garbage until the test
	// finishes or the connection breaks.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		for {
			// Yield the processor between writes.
			runtime.Gosched()
			select {
			case <-cancelc:
				return
			default:
				// Once the client has given up and closed its end, writes
				// start failing. Stop then instead of spinning on a dead
				// connection until the test returns.
				if _, err := c.Write([]byte("infinite stream of bytes")); err != nil {
					return
				}
			}
		}
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
			// A small limit so the endless header trips it quickly.
			MaxResponseHeaderBytes: 1024,
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Errorf("unexpected Get success")
	}
}
1618
// TestOnProxyConnectResponse checks Transport.OnProxyConnectResponse:
// - it is called with the proxy URL and the CONNECT request;
// - an error it returns fails the request;
// - in that case, the single dialed proxy connection is closed.
func TestOnProxyConnectResponse(t *testing.T) {
	// In each case the proxy answers CONNECT with proxyStatusCode, and the
	// OnProxyConnectResponse hook returns err.
	var tcases = []struct {
		proxyStatusCode int
		err             error
	}{
		{
			StatusOK,
			nil,
		},
		{
			StatusForbidden,
			errors.New("403"),
		},
	}
	for _, tcase := range tcases {
		// The site handler writes nothing; an empty 200 response suffices.
		h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
		})

		h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
			// Act as a CONNECT proxy: reject with the configured status, or
			// hijack the connection and splice it to the target.
			if r.Method == "CONNECT" {
				if tcase.proxyStatusCode != StatusOK {
					w.WriteHeader(tcase.proxyStatusCode)
					return
				}
				hijacker, ok := w.(Hijacker)
				if !ok {
					t.Errorf("hijack not allowed")
					return
				}
				clientConn, _, err := hijacker.Hijack()
				if err != nil {
					t.Errorf("hijacking failed")
					return
				}
				res := &Response{
					StatusCode: StatusOK,
					Proto:      "HTTP/1.1",
					ProtoMajor: 1,
					ProtoMinor: 1,
					Header:     make(Header),
				}

				targetConn, err := net.Dial("tcp", r.URL.Host)
				if err != nil {
					t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
					// After Hijack, closing the connection is our job.
					clientConn.Close()
					return
				}

				if err := res.Write(clientConn); err != nil {
					t.Errorf("Writing 200 OK failed: %v", err)
					clientConn.Close()
					targetConn.Close()
					return
				}

				go io.Copy(targetConn, clientConn)
				go func() {
					io.Copy(clientConn, targetConn)
					targetConn.Close()
				}()
			}
		})
		ts := newClientServerTest(t, https1Mode, h1).ts
		proxy := newClientServerTest(t, https1Mode, h2).ts

		pu, err := url.Parse(proxy.URL)
		if err != nil {
			t.Fatal(err)
		}

		c := proxy.Client()

		// Count the connections the client dials and closes.
		var (
			dials  atomic.Int32
			closes atomic.Int32
		)
		c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			conn, err := net.Dial(network, addr)
			if err != nil {
				return nil, err
			}
			dials.Add(1)
			return noteCloseConn{
				Conn: conn,
				closeFunc: func() {
					closes.Add(1)
				},
			}, nil
		}

		c.Transport.(*Transport).Proxy = ProxyURL(pu)
		c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			if proxyURL.String() != pu.String() {
				t.Errorf("proxy url got %s, want %s", proxyURL, pu)
			}

			if "https://"+connectReq.URL.String() != ts.URL {
				t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
			}
			return tcase.err
		}
		wantCloses := int32(0)
		if _, err := c.Head(ts.URL); err != nil {
			// A failed CONNECT must close the connection to the proxy.
			wantCloses = 1
			if tcase.err == nil {
				// Previously an unexpected failure was silently accepted here.
				t.Errorf("got %v, want nil", err)
			} else if !strings.Contains(err.Error(), tcase.err.Error()) {
				t.Errorf("got %v, want %v", err, tcase.err)
			}
		} else if tcase.err != nil {
			t.Errorf("got nil, want %v", tcase.err)
		}
		if got, want := dials.Load(), int32(1); got != want {
			t.Errorf("got %v dials, want %v", got, want)
		}

		if got, want := closes.Load(), wantCloses; got != want {
			t.Errorf("got %v closes, want %v", got, want)
		}
	}
}
1741
1742
1743
// TestTransportProxyHTTPSConnectLeak checks the case where the proxy never
// answers an HTTPS request's CONNECT and the connect context is canceled. The
// Transport must then close its connection to the proxy, which the proxy
// observes as EOF.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	abort := make(chan struct{})
	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
		// timeout is ignored. The context is canceled only when abort is
		// closed or the parent is done.
		child, cancel := context.WithCancel(ctx)
		go func() {
			select {
			case <-abort:
			case <-child.Done():
			}
			cancel()
		}()
		return child, cancel
	})

	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()

	serverDone := make(chan struct{})
	serveProxy := func() {
		defer close(serverDone)
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer conn.Close()

		br := bufio.NewReader(conn)
		connectReq, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if connectReq.Method != "CONNECT" {
			t.Errorf("unexpected method %q", connectReq.Method)
			return
		}

		// Never answer. Cancel the client's connect context instead, and
		// expect the client to hang up.
		close(abort)
		var one [1]byte
		if _, err := br.Read(one[:]); err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
	}
	go serveProxy()

	client := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := client.Do(req); err == nil {
		t.Errorf("unexpected Get success")
	}

	// Let the proxy goroutine finish its checks before the test returns.
	<-serverDone
}
1816
1817
1818 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1819 defer afterTest(t)
1820
1821 var errDial = errors.New("some dial error")
1822
1823 tr := &Transport{
1824 Proxy: func(*Request) (*url.URL, error) {
1825 return url.Parse("http://proxy.fake.tld/")
1826 },
1827 Dial: func(string, string) (net.Conn, error) {
1828 return nil, errDial
1829 },
1830 }
1831 defer tr.CloseIdleConnections()
1832
1833 c := &Client{Transport: tr}
1834 req, _ := NewRequest("GET", "http://fake.tld", nil)
1835 res, err := c.Do(req)
1836 if err == nil {
1837 res.Body.Close()
1838 t.Fatal("wanted a non-nil error")
1839 }
1840
1841 uerr, ok := err.(*url.Error)
1842 if !ok {
1843 t.Fatalf("got %T, want *url.Error", err)
1844 }
1845 oe, ok := uerr.Err.(*net.OpError)
1846 if !ok {
1847 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1848 }
1849 want := &net.OpError{
1850 Op: "proxyconnect",
1851 Net: "tcp",
1852 Err: errDial,
1853 }
1854 if !reflect.DeepEqual(oe, want) {
1855 t.Errorf("Got error %#v; want %#v", oe, want)
1856 }
1857 }
1858
1859
1860
1861
1862
// TestTransportProxyDialDoesNotMutateProxyConnectHeader checks that dialing a
// proxy whose URL carries user credentials leaves the Transport's
// ProxyConnectHeader unmodified.
func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
	run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
}
func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
	proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
	defer proxy.Close()
	c := proxy.Client()

	tr := c.Transport.(*Transport)
	tr.Proxy = func(*Request) (*url.URL, error) {
		u, _ := url.Parse(proxy.URL)
		u.User = url.UserPassword("aladdin", "opensesame")
		return u, nil
	}
	want := tr.ProxyConnectHeader
	if want == nil {
		want = make(Header)
	}
	// Install a copy so any mutation made during the dial shows up as a
	// difference from want.
	tr.ProxyConnectHeader = want.Clone()

	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := c.Do(req); err == nil {
		t.Errorf("unexpected Get success")
	}

	if !reflect.DeepEqual(tr.ProxyConnectHeader, want) {
		t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, want)
	}
}
1896
1897
1898
1899
1900
// TestTransportGzipRecursive serves rgz with "Content-Encoding: gzip". It
// checks that the client body equals rgz, i.e. the data was decompressed
// exactly once, and that the Content-Encoding header was removed.
// NOTE(review): rgz is defined elsewhere in this file; presumably it is gzip
// data that decompresses to itself.
func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
func testTransportGzipRecursive(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(rgz)
	})).ts

	c := ts.Client()
	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Always release the response body, matching testTransportGzipShort.
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(body, rgz) {
		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
			body, rgz)
	}
	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
		t.Fatalf("Content-Encoding = %q; want %q", g, e)
	}
}
1925
1926
1927
// TestTransportGzipShort checks a response that claims gzip encoding but
// contains only the two-byte gzip magic number. Reading its body must fail
// with io.ErrUnexpectedEOF.
func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
func testTransportGzipShort(t *testing.T, mode testMode) {
	truncatedGzip := []byte{0x1f, 0x8b}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(truncatedGzip)
	})).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	switch _, err := io.ReadAll(res.Body); {
	case err == nil:
		t.Fatal("Expect an error from reading a body.")
	case err != io.ErrUnexpectedEOF:
		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
	}
}
1949
1950
// waitNumGoroutine waits for the goroutine count to drop to nmax or below.
// While the count is above nmax it makes up to 10 attempts, each sleeping
// 50ms and running a GC before re-reading the count. It returns the last
// count observed.
func waitNumGoroutine(nmax int) int {
	n := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if n <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		n = runtime.NumGoroutine()
	}
	return n
}
1960
1961
// TestTransportPersistConnLeak issues many concurrent requests, lets them
// complete, and closes idle connections. It then checks that no more than a
// handful of extra goroutines remain.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	const numReq = 25
	reqArrived := make(chan bool, numReq)
	release := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		reqArrived <- true
		<-release
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	before := runtime.NumGoroutine()

	reqFinished := make(chan bool, numReq)
	reqFailed := make(chan bool, numReq)
	for n := numReq; n > 0; n-- {
		go func() {
			res, err := c.Get(ts.URL)
			reqFinished <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				reqFailed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until every request has either reached the handler or failed.
	for n := numReq; n > 0; n-- {
		select {
		case <-reqArrived:
		case <-reqFailed:
		}
	}

	peak := runtime.NumGoroutine()

	// Let the handlers finish, then wait for every client call to return.
	close(release)
	for n := numReq; n > 0; n-- {
		<-reqFinished
	}

	tr.CloseIdleConnections()
	after := waitNumGoroutine(before + 5)

	if growth := after - before; growth > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", before, peak, after, growth)
		t.Error("too many new goroutines")
	}
}
2033
2034
2035
// TestTransportPersistConnLeakShortBody sends requests whose body is longer
// than their declared ContentLength. It checks that these requests fail and
// that they do not leave extra goroutines behind once idle connections are
// closed.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
}
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	before := runtime.NumGoroutine()
	payload := []byte("Hello")
	for n := 20; n > 0; n-- {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(payload))
		if err != nil {
			t.Fatal(err)
		}
		// Declare two fewer bytes than the body actually holds.
		req.ContentLength = int64(len(payload) - 2)
		if _, err := c.Do(req); err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	peak := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	after := waitNumGoroutine(before + 5)

	growth := after - before
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", before, peak, after, growth)
	if growth > 5 {
		t.Error("too many new goroutines")
	}
}
2076
2077
2078 type countedConn struct {
2079 net.Conn
2080 }
2081
2082
2083 type countingDialer struct {
2084 dialer net.Dialer
2085 mu sync.Mutex
2086 total, live int64
2087 }
2088
2089 func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
2090 conn, err := d.dialer.DialContext(ctx, network, address)
2091 if err != nil {
2092 return nil, err
2093 }
2094
2095 counted := new(countedConn)
2096 counted.Conn = conn
2097
2098 d.mu.Lock()
2099 defer d.mu.Unlock()
2100 d.total++
2101 d.live++
2102
2103 runtime.AddCleanup(counted, func(dd *countingDialer) { dd.decrement(nil) }, d)
2104 return counted, nil
2105 }
2106
2107 func (d *countingDialer) decrement(*countedConn) {
2108 d.mu.Lock()
2109 defer d.mu.Unlock()
2110 d.live--
2111 }
2112
2113 func (d *countingDialer) Read() (total, live int64) {
2114 d.mu.Lock()
2115 defer d.mu.Unlock()
2116 return d.total, d.live
2117 }
2118
// TestTransportPersistConnLeakNeverIdle uses a server that hijacks and
// immediately closes every connection, so client connections never become
// idle. It checks that those connections are still garbage collected
// eventually.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
}
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Drop the connection without writing any response.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	})).ts

	var d countingDialer
	c := ts.Client()
	c.Transport.(*Transport).DialContext = d.DialContext

	payload := []byte("Hello")
	const maxIters = 1 << 12
	for iter := 0; ; iter++ {
		total, live := d.Read()
		if live < total {
			// At least one dialed connection has been collected.
			return
		}
		if iter >= maxIters {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, iter)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(payload))
		if err != nil {
			t.Fatal(err)
		}
		if _, err := c.Do(req); err == nil {
			t.Fatal("expected broken connection")
		}

		runtime.GC()
	}
}
2159
2160 type countedContext struct {
2161 context.Context
2162 }
2163
2164 type contextCounter struct {
2165 mu sync.Mutex
2166 live int64
2167 }
2168
2169 func (cc *contextCounter) Track(ctx context.Context) context.Context {
2170 counted := new(countedContext)
2171 counted.Context = ctx
2172 cc.mu.Lock()
2173 defer cc.mu.Unlock()
2174 cc.live++
2175 runtime.AddCleanup(counted, func(c *contextCounter) { cc.decrement(nil) }, cc)
2176 return counted
2177 }
2178
2179 func (cc *contextCounter) decrement(*countedContext) {
2180 cc.mu.Lock()
2181 defer cc.mu.Unlock()
2182 cc.live--
2183 }
2184
2185 func (cc *contextCounter) Read() (live int64) {
2186 cc.mu.Lock()
2187 defer cc.mu.Unlock()
2188 return cc.live
2189 }
2190
// TestTransportPersistConnContextLeakMaxConnsPerHost limits the client to one
// connection per host with MaxConnsPerHost. It then checks that every context
// from a batch of 64 concurrent POSTs is eventually garbage collected after
// the requests complete.
func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
	run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
}
func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56021")
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		runtime.Gosched()
		w.WriteHeader(StatusOK)
	})).ts

	c := ts.Client()
	c.Transport.(*Transport).MaxConnsPerHost = 1

	ctx := context.Background()
	body := []byte("Hello")
	// doPosts issues 64 concurrent POSTs whose contexts are tracked by cc, and
	// waits for all of them to finish.
	doPosts := func(cc *contextCounter) {
		var wg sync.WaitGroup
		for n := 64; n > 0; n-- {
			wg.Add(1)
			go func() {
				defer wg.Done()

				ctx := cc.Track(ctx)
				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
				if err != nil {
					// req is nil here; do not go on to dereference it.
					t.Error(err)
					return
				}

				res, err := c.Do(req.WithContext(ctx))
				if err != nil {
					t.Errorf("Do failed with error: %v", err)
					return
				}
				res.Body.Close()
			}()
		}
		wg.Wait()
	}

	var initialCC contextCounter
	doPosts(&initialCC)

	// Keep running GC, with a fresh (separately tracked) batch of requests
	// before each cycle, until every context from the first batch has been
	// collected. NOTE(review): presumably the new batches displace any
	// references the Transport retains from the first batch.
	var flushCC contextCounter
	for i := 0; ; i++ {
		live := initialCC.Read()
		if live == 0 {
			break
		}
		if i >= 100 {
			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
		}
		doPosts(&flushCC)
		runtime.GC()
	}
}
2250
2251
// TestTransportIdleConnCrash has the handler call CloseIdleConnections while
// its own request is still being served. It checks that the in-flight request
// still completes without error.
func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
func testTransportIdleConnCrash(t *testing.T, mode testMode) {
	var tr *Transport

	unblock := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblock
		tr.CloseIdleConnections()
	})).ts
	c := ts.Client()
	tr = c.Transport.(*Transport)

	done := make(chan bool)
	go func() {
		if res, err := c.Get(ts.URL); err != nil {
			t.Error(err)
		} else {
			res.Body.Close()
		}
		done <- true
	}()
	unblock <- true
	<-done
}
2277
2278
2279
2280
2281
// TestIssue3644 checks that a "Connection: close" response written as many
// small pieces is received in full.
func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
func testIssue3644(t *testing.T, mode testMode) {
	const (
		numFoos = 5000
		foo     = "foo "
	)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "close")
		for n := numFoos; n > 0; n-- {
			w.Write([]byte(foo))
		}
	})).ts
	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	bs, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if want := numFoos * len(foo); len(bs) != want {
		t.Errorf("unexpected response length")
	}
}
2305
2306
2307
// TestIssue3595 has a handler reject a POST whose body comes from
// neverEnding (presumably an unbounded reader), without reading that body.
// It checks that the client still receives the handler's error message. The
// check is retried with increasing RST-avoidance delays via
// runTimeSensitiveTest.
func TestIssue3595(t *testing.T) {
	// Not parallel. NOTE(review): presumably because SetRSTAvoidanceDelay
	// adjusts shared state — confirm.
	run(t, testIssue3595, testNotParallel)
}
func testIssue3595(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const deniedMsg = "sorry, denied."
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			Error(w, deniedMsg, StatusUnauthorized)
		}))
		defer cst.close()
		ts := cst.ts
		c := ts.Client()

		res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
		if err != nil {
			return fmt.Errorf("Post: %v", err)
		}
		// Release the response body on every path, including a failed
		// ReadAll. This runs before cst.close (LIFO defers).
		defer res.Body.Close()
		got, err := io.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		t.Logf("server response:\n%s", got)
		if !strings.Contains(string(got), deniedMsg) {
			t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
		}
		return nil
	})
}
2353
2354
2355
// TestChunkedNoContent checks that repeated requests to a handler answering
// 204 No Content succeed. It runs one pass that closes each response body and
// one that leaves the bodies open.
func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
func testChunkedNoContent(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(StatusNoContent)
	})).ts

	c := ts.Client()
	const n = 4
	for _, closeBody := range []bool{true, false} {
		for i := 1; i <= n; i++ {
			res, err := c.Get(ts.URL)
			if err != nil {
				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
				continue
			}
			// The closeBody=false pass intentionally leaves bodies open.
			if closeBody {
				res.Body.Close()
			}
		}
	}
}
2377
// TestTransportConcurrency sends numReqs echo requests through one shared
// client from maxProcs*2 worker goroutines. It checks that each response body
// equals its request's "echo" parameter.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Not parallel: this test changes GOMAXPROCS and installs pending-dial
	// hooks.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// wg counts the outstanding requests. Through the hooks below it also
	// counts pending dials, so wg.Wait covers those too. NOTE(review): this
	// presumes the hooks bracket each pending dial — confirm.
	var wg sync.WaitGroup
	wg.Add(numReqs)
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	defer close(reqs)

	// fetch performs one echo request, checks the reply, and marks it done.
	fetch := func(req string) {
		defer wg.Done()
		res, err := c.Get(ts.URL + "/?echo=" + req)
		if err != nil {
			if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
				// Tolerated on NetBSD; see the linked issue.
				t.Logf("error on req %s: %v", req, err)
				t.Logf("(see https://go.dev/issue/52168)")
			} else {
				t.Errorf("error on req %s: %v", req, err)
			}
			return
		}
		all, err := io.ReadAll(res.Body)
		switch {
		case err != nil:
			t.Errorf("read error on req %s: %v", req, err)
		case string(all) != req:
			t.Errorf("body of req %s = %q; want %q", req, all, req)
		}
		res.Body.Close()
	}

	for w := maxProcs * 2; w > 0; w-- {
		go func() {
			for req := range reqs {
				fetch(req)
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2440
// TestIssue4191_InfiniteGetTimeout checks that a deadline set on the
// underlying client connection interrupts reading an endless response body.
func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	ts := newClientServerTest(t, mode, mux).ts

	// recordingDial hands each dialed connection to connc without blocking.
	// Dials made while connc is full are not reported.
	connc := make(chan net.Conn, 1)
	recordingDial := func(network, addr string) (net.Conn, error) {
		conn, err := net.Dial(network, addr)
		if err != nil {
			return nil, err
		}
		select {
		case connc <- conn:
		default:
		}
		return conn, nil
	}
	c := ts.Client()
	c.Transport.(*Transport).Dial = recordingDial

	res, err := c.Get(ts.URL + "/get")
	if err != nil {
		t.Fatalf("Error issuing GET: %v", err)
	}
	defer res.Body.Close()

	conn := <-connc
	conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
	if _, err := io.Copy(io.Discard, res.Body); err == nil {
		t.Errorf("Unexpected successful copy")
	}
}
2476
// TestIssue4191_InfiniteGetToPutTimeout uses an endless GET response body as
// the body of a PUT. Every client connection carries a deadline, so the PUT
// must fail rather than run forever.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	timeout := 100 * time.Millisecond

	c := ts.Client()
	// Each dialed connection gets a deadline of the current timeout. The
	// closure reads timeout at dial time, so raising it affects later dials.
	c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
		conn, err := net.Dial(network, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	retried := false
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if retried {
				t.Errorf("Error issuing GET: %v", err)
				break
			}
			// Allow a single retry of this run with a 10x longer deadline.
			retried = true
			t.Logf("increasing timeout")
			i--
			timeout *= 10
			continue
		}
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		sres.Body.Close()
		if err == nil {
			t.Errorf("Unexpected successful PUT")
			break
		}
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}
2542
2543 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2544 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2545 if testing.Short() {
2546 t.Skip("skipping timeout test in -short mode")
2547 }
2548
2549 timeout := 2 * time.Millisecond
2550 retry := true
2551 for retry && !t.Failed() {
2552 var srvWG sync.WaitGroup
2553 inHandler := make(chan bool, 1)
2554 mux := NewServeMux()
2555 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2556 inHandler <- true
2557 srvWG.Done()
2558 })
2559 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2560 inHandler <- true
2561 <-r.Context().Done()
2562 srvWG.Done()
2563 })
2564 ts := newClientServerTest(t, mode, mux).ts
2565
2566 c := ts.Client()
2567 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2568
2569 retry = false
2570 srvWG.Add(3)
2571 tests := []struct {
2572 path string
2573 wantTimeout bool
2574 }{
2575 {path: "/fast"},
2576 {path: "/slow", wantTimeout: true},
2577 {path: "/fast"},
2578 }
2579 for i, tt := range tests {
2580 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2581 req = req.WithT(t)
2582 res, err := c.Do(req)
2583 <-inHandler
2584 if err != nil {
2585 uerr, ok := err.(*url.Error)
2586 if !ok {
2587 t.Errorf("error is not a url.Error; got: %#v", err)
2588 continue
2589 }
2590 nerr, ok := uerr.Err.(net.Error)
2591 if !ok {
2592 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2593 continue
2594 }
2595 if !nerr.Timeout() {
2596 t.Errorf("want timeout error; got: %q", nerr)
2597 continue
2598 }
2599 if !tt.wantTimeout {
2600 if !retry {
2601
2602 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2603 timeout *= 2
2604 retry = true
2605 }
2606 }
2607 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2608 t.Errorf("%d. unexpected error: %v", i, err)
2609 }
2610 continue
2611 }
2612 if tt.wantTimeout {
2613 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2614 continue
2615 }
2616 if res.StatusCode != 200 {
2617 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2618 }
2619 }
2620
2621 srvWG.Wait()
2622 ts.Close()
2623 }
2624 }
2625
2626
// cancelTest describes one mechanism for canceling an in-flight request, so
// that a single test body can be run against every mechanism.
type cancelTest struct {
	// mode is the client/server test mode passed to newClientServerTest.
	mode testMode
	// newReq prepares req for this mechanism (for example by attaching a
	// cancel channel or a context) and returns the request to send.
	newReq func(req *Request) *Request
	// cancel cancels req. The Transport-based variant uses tr; the others
	// ignore it.
	cancel func(tr *Transport, req *Request)
	// checkErr reports a test error if err is not the error this mechanism
	// is expected to produce; when names the operation that returned err.
	checkErr func(when string, err error)
}
2633
2634
2635 func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2636 t.Run("TransportCancel", func(t *testing.T) {
2637 f(t, cancelTest{
2638 mode: mode,
2639 newReq: func(req *Request) *Request {
2640 return req
2641 },
2642 cancel: func(tr *Transport, req *Request) {
2643 tr.CancelRequest(req)
2644 },
2645 checkErr: func(when string, err error) {
2646 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2647 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2648 }
2649 },
2650 })
2651 })
2652 }
2653
2654
2655 func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2656 cancelc := make(chan struct{})
2657 cancelOnce := sync.OnceFunc(func() { close(cancelc) })
2658 f(t, cancelTest{
2659 mode: mode,
2660 newReq: func(req *Request) *Request {
2661 req.Cancel = cancelc
2662 return req
2663 },
2664 cancel: func(tr *Transport, req *Request) {
2665 cancelOnce()
2666 },
2667 checkErr: func(when string, err error) {
2668 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2669 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2670 }
2671 },
2672 })
2673 }
2674
2675
2676 func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2677 ctx, cancel := context.WithCancel(context.Background())
2678 f(t, cancelTest{
2679 mode: mode,
2680 newReq: func(req *Request) *Request {
2681 return req.WithContext(ctx)
2682 },
2683 cancel: func(tr *Transport, req *Request) {
2684 cancel()
2685 },
2686 checkErr: func(when string, err error) {
2687 if !errors.Is(err, context.Canceled) {
2688 t.Errorf("%v error = %v, want context.Canceled", when, err)
2689 }
2690 },
2691 })
2692 }
2693
2694 func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
2695 run(t, func(t *testing.T, mode testMode) {
2696 if mode == http1Mode {
2697 t.Run("TransportCancel", func(t *testing.T) {
2698 runCancelTestTransport(t, mode, f)
2699 })
2700 }
2701 t.Run("RequestCancel", func(t *testing.T) {
2702 runCancelTestChannel(t, mode, f)
2703 })
2704 t.Run("ContextCancel", func(t *testing.T) {
2705 runCancelTestContext(t, mode, f)
2706 })
2707 }, opts...)
2708 }
2709
2710 func TestTransportCancelRequest(t *testing.T) {
2711 runCancelTest(t, testTransportCancelRequest)
2712 }
2713 func testTransportCancelRequest(t *testing.T, test cancelTest) {
2714 if testing.Short() {
2715 t.Skip("skipping test in -short mode")
2716 }
2717
2718 const msg = "Hello"
2719 unblockc := make(chan bool)
2720 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2721 io.WriteString(w, msg)
2722 w.(Flusher).Flush()
2723 <-unblockc
2724 })).ts
2725 defer close(unblockc)
2726
2727 c := ts.Client()
2728 tr := c.Transport.(*Transport)
2729
2730 req, _ := NewRequest("GET", ts.URL, nil)
2731 req = test.newReq(req)
2732 res, err := c.Do(req)
2733 if err != nil {
2734 t.Fatal(err)
2735 }
2736 body := make([]byte, len(msg))
2737 n, _ := io.ReadFull(res.Body, body)
2738 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2739 t.Errorf("Body = %q; want %q", body[:n], msg)
2740 }
2741 test.cancel(tr, req)
2742
2743 tail, err := io.ReadAll(res.Body)
2744 res.Body.Close()
2745 test.checkErr("Body.Read", err)
2746 if len(tail) > 0 {
2747 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2748 }
2749
2750
2751
2752 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2753 n := tr.NumPendingRequestsForTesting()
2754 if n > 0 {
2755 if d > 0 {
2756 t.Logf("pending requests = %d after %v (want 0)", n, d)
2757 }
2758 return false
2759 }
2760 return true
2761 })
2762 }
2763
2764 func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
2765 if testing.Short() {
2766 t.Skip("skipping test in -short mode")
2767 }
2768 unblockc := make(chan bool)
2769 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2770 <-unblockc
2771 })).ts
2772 defer close(unblockc)
2773
2774 c := ts.Client()
2775 tr := c.Transport.(*Transport)
2776
2777 donec := make(chan bool)
2778 req, _ := NewRequest("GET", ts.URL, body)
2779 req = test.newReq(req)
2780 go func() {
2781 defer close(donec)
2782 c.Do(req)
2783 }()
2784
2785 unblockc <- true
2786 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2787 test.cancel(tr, req)
2788 select {
2789 case <-donec:
2790 return true
2791 default:
2792 if d > 0 {
2793 t.Logf("Do of canceled request has not returned after %v", d)
2794 }
2795 return false
2796 }
2797 })
2798 }
2799
2800 func TestTransportCancelRequestInDo(t *testing.T) {
2801 runCancelTest(t, func(t *testing.T, test cancelTest) {
2802 testTransportCancelRequestInDo(t, test, nil)
2803 })
2804 }
2805
2806 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2807 runCancelTest(t, func(t *testing.T, test cancelTest) {
2808 testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
2809 })
2810 }
2811
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}

// testTransportCancelRequestInDial cancels a request while the Transport is
// blocked inside Dial and checks that Do returns an error. The events are
// recorded in a log and compared with the expected order at the end.
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// Dial stays blocked on unblockDial until this function returns; the
	// deferred close then lets the dial goroutine finish.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial is unbuffered and only Dial receives from it, so the main
	// goroutine's send below completes only once Dial has been entered.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// A false (zero) value is treated as the test goroutine having
			// exited. NOTE(review): nothing in this function closes inDial or
			// sends false, so this branch looks purely defensive.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	inDial <- true

	eventLog.Printf("canceling")
	// Cancel twice, presumably to check that a repeated cancellation is
	// harmless.
	test.cancel(tr, req)
	test.cancel(tr, req)

	if d, ok := t.Deadline(); ok {
		// If Do never returns, panic shortly before the test deadline (at 95%
		// of the remaining time) so that the failure includes the recorded
		// events instead of being a bare timeout.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2874
2875
2876 func TestTransportCancelRequestWithBody(t *testing.T) {
2877 runCancelTest(t, testTransportCancelRequestWithBody)
2878 }
2879 func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
2880 if testing.Short() {
2881 t.Skip("skipping test in -short mode")
2882 }
2883
2884 const msg = "Hello"
2885 unblockc := make(chan struct{})
2886 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2887 io.WriteString(w, msg)
2888 w.(Flusher).Flush()
2889 <-unblockc
2890 })).ts
2891 defer close(unblockc)
2892
2893 c := ts.Client()
2894 tr := c.Transport.(*Transport)
2895
2896 req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
2897 req = test.newReq(req)
2898
2899 res, err := c.Do(req)
2900 if err != nil {
2901 t.Fatal(err)
2902 }
2903 body := make([]byte, len(msg))
2904 n, _ := io.ReadFull(res.Body, body)
2905 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2906 t.Errorf("Body = %q; want %q", body[:n], msg)
2907 }
2908 test.cancel(tr, req)
2909
2910 tail, err := io.ReadAll(res.Body)
2911 res.Body.Close()
2912 test.checkErr("Body.Read", err)
2913 if len(tail) > 0 {
2914 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2915 }
2916
2917
2918
2919 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2920 n := tr.NumPendingRequestsForTesting()
2921 if n > 0 {
2922 if d > 0 {
2923 t.Logf("pending requests = %d after %v (want 0)", n, d)
2924 }
2925 return false
2926 }
2927 return true
2928 })
2929 }
2930
2931 func TestTransportCancelRequestBeforeDo(t *testing.T) {
2932
2933 run(t, func(t *testing.T, mode testMode) {
2934 t.Run("RequestCancel", func(t *testing.T) {
2935 runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
2936 })
2937 t.Run("ContextCancel", func(t *testing.T) {
2938 runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
2939 })
2940 })
2941 }
2942 func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
2943 unblockc := make(chan bool)
2944 cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2945 <-unblockc
2946 }))
2947 defer close(unblockc)
2948
2949 c := cst.ts.Client()
2950
2951 req, _ := NewRequest("GET", cst.ts.URL, nil)
2952 req = test.newReq(req)
2953 test.cancel(cst.tr, req)
2954
2955 _, err := c.Do(req)
2956 test.checkErr("Do", err)
2957 }
2958
2959
2960 func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
2961 runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
2962 }
2963 func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
2964 defer afterTest(t)
2965
2966 serverConnCh := make(chan net.Conn, 1)
2967 tr := &Transport{
2968 Dial: func(network, addr string) (net.Conn, error) {
2969 cc, sc := net.Pipe()
2970 serverConnCh <- sc
2971 return cc, nil
2972 },
2973 }
2974 defer tr.CloseIdleConnections()
2975 errc := make(chan error, 1)
2976 req, _ := NewRequest("GET", "http://example.com/", nil)
2977 req = test.newReq(req)
2978 go func() {
2979 _, err := tr.RoundTrip(req)
2980 errc <- err
2981 }()
2982
2983 sc := <-serverConnCh
2984 verb := make([]byte, 3)
2985 if _, err := io.ReadFull(sc, verb); err != nil {
2986 t.Errorf("Error reading HTTP verb from server: %v", err)
2987 }
2988 if string(verb) != "GET" {
2989 t.Errorf("server received %q; want GET", verb)
2990 }
2991 defer sc.Close()
2992
2993 test.cancel(tr, req)
2994
2995 err := <-errc
2996 if err == nil {
2997 t.Fatalf("unexpected success from RoundTrip")
2998 }
2999 test.checkErr("RoundTrip", err)
3000 }
3001
3002
3003
3004
3005 func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
3006 func testTransportCloseResponseBody(t *testing.T, mode testMode) {
3007 writeErr := make(chan error, 1)
3008 msg := []byte("young\n")
3009 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3010 for {
3011 _, err := w.Write(msg)
3012 if err != nil {
3013 writeErr <- err
3014 return
3015 }
3016 w.(Flusher).Flush()
3017 }
3018 })).ts
3019
3020 c := ts.Client()
3021 tr := c.Transport.(*Transport)
3022
3023 req, _ := NewRequest("GET", ts.URL, nil)
3024 defer tr.CancelRequest(req)
3025
3026 res, err := c.Do(req)
3027 if err != nil {
3028 t.Fatal(err)
3029 }
3030
3031 const repeats = 3
3032 buf := make([]byte, len(msg)*repeats)
3033 want := bytes.Repeat(msg, repeats)
3034
3035 _, err = io.ReadFull(res.Body, buf)
3036 if err != nil {
3037 t.Fatal(err)
3038 }
3039 if !bytes.Equal(buf, want) {
3040 t.Fatalf("read %q; want %q", buf, want)
3041 }
3042
3043 if err := res.Body.Close(); err != nil {
3044 t.Errorf("Close = %v", err)
3045 }
3046
3047 if err := <-writeErr; err == nil {
3048 t.Errorf("expected non-nil write error")
3049 }
3050 }
3051
// fooProto is a RoundTripper for the made-up "foo" URL scheme. It never
// touches the network: every request gets a 200 response whose body echoes
// the request URL.
type fooProto struct{}

func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
3063
3064 func TestTransportAltProto(t *testing.T) {
3065 defer afterTest(t)
3066 tr := &Transport{}
3067 c := &Client{Transport: tr}
3068 tr.RegisterProtocol("foo", fooProto{})
3069 res, err := c.Get("foo://bar.com/path")
3070 if err != nil {
3071 t.Fatal(err)
3072 }
3073 bodyb, err := io.ReadAll(res.Body)
3074 if err != nil {
3075 t.Fatal(err)
3076 }
3077 body := string(bodyb)
3078 if e := "You wanted foo://bar.com/path"; body != e {
3079 t.Errorf("got response %q, want %q", body, e)
3080 }
3081 }
3082
3083 func TestTransportNoHost(t *testing.T) {
3084 defer afterTest(t)
3085 tr := &Transport{}
3086 _, err := tr.RoundTrip(&Request{
3087 Header: make(Header),
3088 URL: &url.URL{
3089 Scheme: "http",
3090 },
3091 })
3092 want := "http: no Host in request URL"
3093 if got := fmt.Sprint(err); got != want {
3094 t.Errorf("error = %v; want %q", err, want)
3095 }
3096 }
3097
3098
// TestTransportEmptyMethod checks that a client request with an empty Method
// is written to the wire as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = "" // for client requests, "" means GET
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if s := string(dump); !strings.Contains(s, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3110
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }

// testTransportSocketLateBinding holds one request (/foo) open on a first
// connection and then issues a second request (/bar). In HTTP/1 the second
// request starts a new Dial, which is held blocked; meanwhile /foo completes
// and frees its connection. The test checks that /bar is served from the same
// remote address as /foo, i.e. the freed connection is used instead of the
// still-pending dial. In HTTP/2 no second Dial is expected at all.
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate holds the /foo handler open after it has flushed its headers.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate permits one dial per value sent; closing it makes any blocked
	// Dial fail. While blocked, Dial repeatedly announces itself on dialing.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one dial, for /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Wait for /bar's second Dial to be in progress (HTTP/1), or briefly
		// check that no such Dial happens (HTTP/2), before letting /foo
		// finish and release its connection.
		if mode == http2Mode {
			// NOTE(review): a 10ms window only catches an early second Dial;
			// it is a heuristic, not a proof that none happens.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3190
3191
3192 func TestTransportReading100Continue(t *testing.T) {
3193 defer afterTest(t)
3194
3195 const numReqs = 5
3196 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
3197 reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
3198
3199 send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
3200 defer w.Close()
3201 defer r.Close()
3202 br := bufio.NewReader(r)
3203 n := 0
3204 for {
3205 n++
3206 req, err := ReadRequest(br)
3207 if err == io.EOF {
3208 return
3209 }
3210 if err != nil {
3211 t.Error(err)
3212 return
3213 }
3214 slurp, err := io.ReadAll(req.Body)
3215 if err != nil {
3216 t.Errorf("Server request body slurp: %v", err)
3217 return
3218 }
3219 id := req.Header.Get("Request-Id")
3220 resCode := req.Header.Get("X-Want-Response-Code")
3221 if resCode == "" {
3222 resCode = "100 Continue"
3223 if string(slurp) != reqBody(n) {
3224 t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
3225 }
3226 }
3227 body := fmt.Sprintf("Response number %d", n)
3228 v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
3229 Date: Thu, 28 Feb 2013 17:55:41 GMT
3230
3231 HTTP/1.1 200 OK
3232 Content-Type: text/html
3233 Echo-Request-Id: %s
3234 Content-Length: %d
3235
3236 %s`, resCode, id, len(body), body), "\n", "\r\n", -1))
3237 w.Write(v)
3238 if id == reqID(numReqs) {
3239 return
3240 }
3241 }
3242
3243 }
3244
3245 tr := &Transport{
3246 Dial: func(n, addr string) (net.Conn, error) {
3247 sr, sw := io.Pipe()
3248 cr, cw := io.Pipe()
3249 conn := &rwTestConn{
3250 Reader: cr,
3251 Writer: sw,
3252 closeFunc: func() error {
3253 sw.Close()
3254 cw.Close()
3255 return nil
3256 },
3257 }
3258 go send100Response(cw, sr)
3259 return conn, nil
3260 },
3261 DisableKeepAlives: false,
3262 }
3263 defer tr.CloseIdleConnections()
3264 c := &Client{Transport: tr}
3265
3266 testResponse := func(req *Request, name string, wantCode int) {
3267 t.Helper()
3268 res, err := c.Do(req)
3269 if err != nil {
3270 t.Fatalf("%s: Do: %v", name, err)
3271 }
3272 if res.StatusCode != wantCode {
3273 t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
3274 }
3275 if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
3276 t.Errorf("%s: response id %q != request id %q", name, idBack, id)
3277 }
3278 _, err = io.ReadAll(res.Body)
3279 if err != nil {
3280 t.Fatalf("%s: Slurp error: %v", name, err)
3281 }
3282 }
3283
3284
3285 for i := 1; i <= numReqs; i++ {
3286 req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
3287 req.Header.Set("Request-Id", reqID(i))
3288 testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
3289 }
3290 }
3291
3292
3293
3294 func TestTransportIgnore1xxResponses(t *testing.T) {
3295 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3296 }
3297 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3298 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3299 conn, buf, _ := w.(Hijacker).Hijack()
3300 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3301 buf.Flush()
3302 conn.Close()
3303 }))
3304 cst.tr.DisableKeepAlives = true
3305
3306 var got strings.Builder
3307
3308 req, _ := NewRequest("GET", cst.ts.URL, nil)
3309 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3310 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3311 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3312 return nil
3313 },
3314 }))
3315 res, err := cst.c.Do(req)
3316 if err != nil {
3317 t.Fatal(err)
3318 }
3319 defer res.Body.Close()
3320
3321 res.Write(&got)
3322 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3323 if got.String() != want {
3324 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3325 }
3326 }
3327
3328 func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
3329 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3330 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3331 w.Header().Add("X-Header", strings.Repeat("a", 100))
3332 for i := 0; i < 10; i++ {
3333 w.WriteHeader(123)
3334 }
3335 w.WriteHeader(204)
3336 }))
3337 cst.tr.DisableKeepAlives = true
3338 cst.tr.MaxResponseHeaderBytes = 1000
3339
3340 res, err := cst.c.Get(cst.ts.URL)
3341 if err == nil {
3342 res.Body.Close()
3343 t.Fatalf("RoundTrip succeeded; want error")
3344 }
3345 for _, want := range []string{
3346 "response headers exceeded",
3347 "too many 1xx",
3348 "header list too large",
3349 } {
3350 if strings.Contains(err.Error(), want) {
3351 return
3352 }
3353 }
3354 t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
3355 }
3356
3357 func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
3358 run(t, testTransportDoesNotLimitDelivered1xxResponses)
3359 }
3360 func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
3361 if mode == http2Mode {
3362 t.Skip("skip until x/net/http2 updated")
3363 }
3364 const num1xx = 10
3365 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3366 w.Header().Add("X-Header", strings.Repeat("a", 100))
3367 for i := 0; i < 10; i++ {
3368 w.WriteHeader(123)
3369 }
3370 w.WriteHeader(204)
3371 }))
3372 cst.tr.DisableKeepAlives = true
3373 cst.tr.MaxResponseHeaderBytes = 1000
3374
3375 got1xx := 0
3376 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3377 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3378 got1xx++
3379 return nil
3380 },
3381 })
3382 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
3383 res, err := cst.c.Do(req)
3384 if err != nil {
3385 t.Fatal(err)
3386 }
3387 res.Body.Close()
3388 if got1xx != num1xx {
3389 t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
3390 }
3391 }
3392
3393
3394
3395 func TestTransportTreat101Terminal(t *testing.T) {
3396 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3397 }
3398 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3399 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3400 conn, buf, _ := w.(Hijacker).Hijack()
3401 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3402 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3403 buf.Flush()
3404 conn.Close()
3405 }))
3406 res, err := cst.c.Get(cst.ts.URL)
3407 if err != nil {
3408 t.Fatal(err)
3409 }
3410 defer res.Body.Close()
3411 if res.StatusCode != StatusSwitchingProtocols {
3412 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3413 }
3414 }
3415
// proxyFromEnvTest is one case for the proxy-from-environment tests.
type proxyFromEnvTest struct {
	req string // URL to request; "" means "http://example.com"

	env      string // HTTP_PROXY / http_proxy value
	httpsenv string // HTTPS_PROXY / https_proxy value
	noenv    string // NO_PROXY / no_proxy value
	reqmeth  string // REQUEST_METHOD value

	want    string // expected proxy URL as formatted by %s ("<nil>" for none)
	wanterr error  // expected error, or nil
}

// String describes the case as space-separated key=%q pairs. Empty
// environment values are omitted; the request URL is always included, with
// its default filled in.
func (tt proxyFromEnvTest) String() string {
	var parts []string
	for _, kv := range []struct{ key, val string }{
		{"http_proxy", tt.env},
		{"https_proxy", tt.httpsenv},
		{"no_proxy", tt.noenv},
		{"request_method", tt.reqmeth},
	} {
		if kv.val != "" {
			parts = append(parts, fmt.Sprintf("%s=%q", kv.key, kv.val))
		}
	}
	req := tt.req
	if req == "" {
		req = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", req))
	return strings.Join(parts, " ")
}
3458
// proxyFromEnvTests is shared by the upper- and lower-case environment-
// variable tests below. A want of "<nil>" means no proxy is used.
var proxyFromEnvTests = []proxyFromEnvTest{
	// A proxy value without a scheme gets http://; an explicit scheme
	// (http, https, socks5, socks5h) is kept as given.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http:// request uses the env (HTTP_PROXY) value...
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// ...while an https:// request uses the httpsenv (HTTPS_PROXY) value,
	// which also defaults to the http:// scheme when none is given.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// With REQUEST_METHOD set, as in a CGI environment, the HTTP_PROXY
	// value is refused with an error.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy configured at all.
	{want: "<nil>"},

	// NO_PROXY matching: a bare domain matches itself and its subdomains; a
	// leading-dot entry does not match the bare domain; a plain string
	// suffix ("ample.com") and an unrelated domain do not match.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3489
3490 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3491 t.Helper()
3492 reqURL := tt.req
3493 if reqURL == "" {
3494 reqURL = "http://example.com"
3495 }
3496 req, _ := NewRequest("GET", reqURL, nil)
3497 url, err := proxyForRequest(req)
3498 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3499 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3500 return
3501 }
3502 if got := fmt.Sprintf("%s", url); got != tt.want {
3503 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3504 }
3505 }
3506
3507 func TestProxyFromEnvironment(t *testing.T) {
3508 ResetProxyEnv()
3509 defer ResetProxyEnv()
3510 for _, tt := range proxyFromEnvTests {
3511 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3512 os.Setenv("HTTP_PROXY", tt.env)
3513 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3514 os.Setenv("NO_PROXY", tt.noenv)
3515 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3516 ResetCachedEnvironment()
3517 return ProxyFromEnvironment(req)
3518 })
3519 }
3520 }
3521
3522 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3523 ResetProxyEnv()
3524 defer ResetProxyEnv()
3525 for _, tt := range proxyFromEnvTests {
3526 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3527 os.Setenv("http_proxy", tt.env)
3528 os.Setenv("https_proxy", tt.httpsenv)
3529 os.Setenv("no_proxy", tt.noenv)
3530 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3531 ResetCachedEnvironment()
3532 return ProxyFromEnvironment(req)
3533 })
3534 }
3535 }
3536
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}

// testIdleConnChannelLeak sends requests to several distinct host names,
// first with keep-alives disabled and then enabled. After each round it
// checks that the Transport's idle-connection wait map is empty.
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// n counts handler invocations under mu.
	// NOTE(review): n is incremented but never read in this function.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// didRead receives a value each time the read-loop hook fires. It has
	// capacity nReqs so the hook can signal without waiting for a receiver.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Every fake host name is dialed to the test server, so each request
	// targets a different host (and presumably a different pool entry)
	// while still reaching ts.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			// The response is discarded; the handler writes no body.
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}

		}

		// Wait for the read-loop hook to fire once per request before
		// inspecting the wait map.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3592
3593
3594
3595
3596 func TestTransportClosesRequestBody(t *testing.T) {
3597 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3598 }
3599 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3600 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3601 io.Copy(io.Discard, r.Body)
3602 })).ts
3603
3604 c := ts.Client()
3605
3606 closes := 0
3607
3608 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3609 if err != nil {
3610 t.Fatal(err)
3611 }
3612 res.Body.Close()
3613 if closes != 1 {
3614 t.Errorf("closes = %d; want 1", closes)
3615 }
3616 }
3617
3618 func TestTransportTLSHandshakeTimeout(t *testing.T) {
3619 defer afterTest(t)
3620 if testing.Short() {
3621 t.Skip("skipping in short mode")
3622 }
3623 ln := newLocalListener(t)
3624 defer ln.Close()
3625 testdonec := make(chan struct{})
3626 defer close(testdonec)
3627
3628 go func() {
3629 c, err := ln.Accept()
3630 if err != nil {
3631 t.Error(err)
3632 return
3633 }
3634 <-testdonec
3635 c.Close()
3636 }()
3637
3638 tr := &Transport{
3639 Dial: func(_, _ string) (net.Conn, error) {
3640 return net.Dial("tcp", ln.Addr().String())
3641 },
3642 TLSHandshakeTimeout: 250 * time.Millisecond,
3643 }
3644 cl := &Client{Transport: tr}
3645 _, err := cl.Get("https://dummy.tld/")
3646 if err == nil {
3647 t.Error("expected error")
3648 return
3649 }
3650 ue, ok := err.(*url.Error)
3651 if !ok {
3652 t.Errorf("expected url.Error; got %#v", err)
3653 return
3654 }
3655 ne, ok := ue.Err.(net.Error)
3656 if !ok {
3657 t.Errorf("expected net.Error; got %#v", err)
3658 return
3659 }
3660 if !ne.Timeout() {
3661 t.Errorf("expected timeout error; got %v", err)
3662 }
3663 if !strings.Contains(err.Error(), "handshake timeout") {
3664 t.Errorf("expected 'handshake timeout' in error; got %v", err)
3665 }
3666 }
3667
3668
3669 func TestTLSServerClosesConnection(t *testing.T) {
3670 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3671 }
3672 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3673 closedc := make(chan bool, 1)
3674 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3675 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3676 conn, _, _ := w.(Hijacker).Hijack()
3677 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3678 conn.Close()
3679 closedc <- true
3680 return
3681 }
3682 fmt.Fprintf(w, "hello")
3683 })).ts
3684
3685 c := ts.Client()
3686 tr := c.Transport.(*Transport)
3687
3688 var nSuccess = 0
3689 var errs []error
3690 const trials = 20
3691 for i := 0; i < trials; i++ {
3692 tr.CloseIdleConnections()
3693 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3694 if err != nil {
3695 t.Fatal(err)
3696 }
3697 <-closedc
3698 slurp, err := io.ReadAll(res.Body)
3699 if err != nil {
3700 t.Fatal(err)
3701 }
3702 if string(slurp) != "foo" {
3703 t.Errorf("Got %q, want foo", slurp)
3704 }
3705
3706
3707
3708 res, err = c.Get(ts.URL + "/")
3709 if err != nil {
3710 errs = append(errs, err)
3711 continue
3712 }
3713 slurp, err = io.ReadAll(res.Body)
3714 if err != nil {
3715 errs = append(errs, err)
3716 continue
3717 }
3718 nSuccess++
3719 }
3720 if nSuccess > 0 {
3721 t.Logf("successes = %d of %d", nSuccess, trials)
3722 } else {
3723 t.Errorf("All runs failed:")
3724 }
3725 for _, err := range errs {
3726 t.Logf(" err: %v", err)
3727 }
3728 }
3729
3730
3731
3732
// byteFromChanReader is an io.Reader that yields exactly one byte per Read,
// received from the channel. A Read blocks until a byte arrives; once the
// channel is closed and drained, Read returns io.EOF. A zero-length Read
// returns (0, nil) without touching the channel.
type byteFromChanReader chan byte

func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		b, ok := <-c
		if ok {
			p[0] = b
			return 1, nil
		}
		return 0, io.EOF
	}
	return 0, nil
}
3746
3747
3748
3749
3750
3751
3752
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}

// testTransportNoReuseAfterEarlyResponse checks that when a server answers a
// POST before the request body has been fully sent, the client does not reuse
// that connection for the next request.
//
// The last byte of the POST body is withheld (via finalBit) until after a
// follow-up GET has succeeded, so the POST's body write is still in progress
// while the GET is issued. The hijacked POST connection only drains input and
// never answers again, so the GET can only receive "bar" over another
// connection.
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shrink MaxWriteWaitBeforeConnReuse for this test and restore it on
	// return. NOTE(review): mutating this shared value is presumably why the
	// test is marked testNotParallel — confirm.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side connection, if any.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	var getOkay bool
	// copying tracks the goroutine that drains the hijacked connection.
	var copying sync.WaitGroup
	// closeConn closes the hijacked connection if it is still set (later
	// calls are no-ops). It logs the close only if the GET has not succeeded.
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	// On exit, close the hijacked conn so the draining io.Copy returns, then
	// wait for that goroutine to finish.
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// Non-GET (the POST below): hijack and respond immediately, without
		// reading the request body first.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))

		// Keep reading and discarding whatever the client sends.
		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	const bodySize = 256 << 10
	// The body is bodySize-1 filler bytes followed by one byte from finalBit.
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true
	// Release the final body byte and end the body.
	finalBit <- 'x'
	close(finalBit)
}
3820
3821
3822
3823 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3824 func testTransportIssue10457(t *testing.T, mode testMode) {
3825 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3826
3827
3828
3829
3830
3831 conn, _, _ := w.(Hijacker).Hijack()
3832 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3833 conn.Close()
3834 })).ts
3835 c := ts.Client()
3836
3837 res, err := c.Get(ts.URL)
3838 if err != nil {
3839 t.Fatalf("Get: %v", err)
3840 }
3841 defer res.Body.Close()
3842
3843
3844
3845
3846 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3847 t.Errorf("Foo header = %q; want %q", got, want)
3848 }
3849 }
3850
// closerFunc adapts a plain function to the io.Closer interface.
type closerFunc func() error

// Close invokes the underlying function and returns its result.
func (f closerFunc) Close() error {
	return f()
}

// writerFuncConn is a net.Conn whose Write is replaced by the write callback;
// every other method is delegated to the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write forwards p to the write callback.
func (c writerFuncConn) Write(p []byte) (n int, err error) {
	n, err = c.write(p)
	return n, err
}
3861
3862
3863
3864
3865
3866
3867
3868
3869
3870
3871
3872
3873
3874 func TestRetryRequestsOnError(t *testing.T) {
3875 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3876 }
3877 func testRetryRequestsOnError(t *testing.T, mode testMode) {
3878 newRequest := func(method, urlStr string, body io.Reader) *Request {
3879 req, err := NewRequest(method, urlStr, body)
3880 if err != nil {
3881 t.Fatal(err)
3882 }
3883 return req
3884 }
3885
3886 testCases := []struct {
3887 name string
3888 failureN int
3889 failureErr error
3890
3891
3892
3893 req func() *Request
3894 reqString string
3895 }{
3896 {
3897 name: "IdempotentNoBodySomeWritten",
3898
3899
3900 failureN: 1,
3901
3902 failureErr: ExportErrServerClosedIdle,
3903 req: func() *Request {
3904 return newRequest("GET", "http://fake.golang", nil)
3905 },
3906 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3907 },
3908 {
3909 name: "IdempotentGetBodySomeWritten",
3910
3911
3912 failureN: 1,
3913
3914 failureErr: ExportErrServerClosedIdle,
3915 req: func() *Request {
3916 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3917 },
3918 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3919 },
3920 {
3921 name: "NothingWrittenNoBody",
3922
3923
3924 failureN: 0,
3925 failureErr: errors.New("second write fails"),
3926 req: func() *Request {
3927 return newRequest("DELETE", "http://fake.golang", nil)
3928 },
3929 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3930 },
3931 {
3932 name: "NothingWrittenGetBody",
3933
3934
3935 failureN: 0,
3936 failureErr: errors.New("second write fails"),
3937
3938
3939 req: func() *Request {
3940 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3941 },
3942 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3943 },
3944 }
3945
3946 for _, tc := range testCases {
3947 t.Run(tc.name, func(t *testing.T) {
3948 var (
3949 mu sync.Mutex
3950 logbuf strings.Builder
3951 )
3952 logf := func(format string, args ...any) {
3953 mu.Lock()
3954 defer mu.Unlock()
3955 fmt.Fprintf(&logbuf, format, args...)
3956 logbuf.WriteByte('\n')
3957 }
3958
3959 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3960 logf("Handler")
3961 w.Header().Set("X-Status", "ok")
3962 })).ts
3963
3964 var writeNumAtomic int32
3965 c := ts.Client()
3966 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3967 logf("Dial")
3968 c, err := net.Dial(network, ts.Listener.Addr().String())
3969 if err != nil {
3970 logf("Dial error: %v", err)
3971 return nil, err
3972 }
3973 return &writerFuncConn{
3974 Conn: c,
3975 write: func(p []byte) (n int, err error) {
3976 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3977 logf("intentional write failure")
3978 return tc.failureN, tc.failureErr
3979 }
3980 logf("Write(%q)", p)
3981 return c.Write(p)
3982 },
3983 }, nil
3984 }
3985
3986 SetRoundTripRetried(func() {
3987 logf("Retried.")
3988 })
3989 defer SetRoundTripRetried(nil)
3990
3991 for i := 0; i < 3; i++ {
3992 t0 := time.Now()
3993 req := tc.req()
3994 res, err := c.Do(req)
3995 if err != nil {
3996 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3997 mu.Lock()
3998 got := logbuf.String()
3999 mu.Unlock()
4000 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
4001 }
4002 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
4003 }
4004 res.Body.Close()
4005 if res.Request != req {
4006 t.Errorf("Response.Request != original request; want identical Request")
4007 }
4008 }
4009
4010 mu.Lock()
4011 got := logbuf.String()
4012 mu.Unlock()
4013 want := fmt.Sprintf(`Dial
4014 Write("%s")
4015 Handler
4016 intentional write failure
4017 Retried.
4018 Dial
4019 Write("%s")
4020 Handler
4021 Write("%s")
4022 Handler
4023 `, tc.reqString, tc.reqString, tc.reqString)
4024 if got != want {
4025 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
4026 }
4027 })
4028 }
4029 }
4030
4031
func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }

// testTransportClosesBodyOnError sends a request whose body yields 1<<20
// filler bytes and then fails with fakeErr. It checks three things:
//   - Do returns an error mentioning fakeErr;
//   - the handler's read of the request body fails;
//   - the Transport closed the request body.
func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
	readBody := make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := io.ReadAll(r.Body)
		readBody <- err
	})).ts
	c := ts.Client()
	fakeErr := errors.New("fake error")
	// didClose receives a value (non-blocking) when the body is closed.
	didClose := make(chan bool, 1)
	req, _ := NewRequest("POST", ts.URL, struct {
		io.Reader
		io.Closer
	}{
		io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
		closerFunc(func() error {
			select {
			case didClose <- true:
			default:
			}
			return nil
		}),
	})
	res, err := c.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
	}
	if err := <-readBody; err == nil {
		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
	}
	select {
	case <-didClose:
	default:
		t.Errorf("didn't see Body.Close")
	}
}
4071
4072 func TestTransportDialTLS(t *testing.T) {
4073 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
4074 }
4075 func testTransportDialTLS(t *testing.T, mode testMode) {
4076 var mu sync.Mutex
4077 var gotReq, didDial bool
4078
4079 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4080 mu.Lock()
4081 gotReq = true
4082 mu.Unlock()
4083 })).ts
4084 c := ts.Client()
4085 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
4086 mu.Lock()
4087 didDial = true
4088 mu.Unlock()
4089 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4090 if err != nil {
4091 return nil, err
4092 }
4093 return c, c.Handshake()
4094 }
4095
4096 res, err := c.Get(ts.URL)
4097 if err != nil {
4098 t.Fatal(err)
4099 }
4100 res.Body.Close()
4101 mu.Lock()
4102 if !gotReq {
4103 t.Error("didn't get request")
4104 }
4105 if !didDial {
4106 t.Error("didn't use dial hook")
4107 }
4108 }
4109
4110 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
4111 func testTransportDialContext(t *testing.T, mode testMode) {
4112 ctxKey := "some-key"
4113 ctxValue := "some-value"
4114 var (
4115 mu sync.Mutex
4116 gotReq bool
4117 gotCtxValue any
4118 )
4119
4120 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4121 mu.Lock()
4122 gotReq = true
4123 mu.Unlock()
4124 })).ts
4125 c := ts.Client()
4126 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4127 mu.Lock()
4128 gotCtxValue = ctx.Value(ctxKey)
4129 mu.Unlock()
4130 return net.Dial(netw, addr)
4131 }
4132
4133 req, err := NewRequest("GET", ts.URL, nil)
4134 if err != nil {
4135 t.Fatal(err)
4136 }
4137 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4138 res, err := c.Do(req.WithContext(ctx))
4139 if err != nil {
4140 t.Fatal(err)
4141 }
4142 res.Body.Close()
4143 mu.Lock()
4144 if !gotReq {
4145 t.Error("didn't get request")
4146 }
4147 if got, want := gotCtxValue, ctxValue; got != want {
4148 t.Errorf("got context with value %v, want %v", got, want)
4149 }
4150 }
4151
4152 func TestTransportDialTLSContext(t *testing.T) {
4153 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
4154 }
4155 func testTransportDialTLSContext(t *testing.T, mode testMode) {
4156 ctxKey := "some-key"
4157 ctxValue := "some-value"
4158 var (
4159 mu sync.Mutex
4160 gotReq bool
4161 gotCtxValue any
4162 )
4163
4164 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4165 mu.Lock()
4166 gotReq = true
4167 mu.Unlock()
4168 })).ts
4169 c := ts.Client()
4170 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4171 mu.Lock()
4172 gotCtxValue = ctx.Value(ctxKey)
4173 mu.Unlock()
4174 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4175 if err != nil {
4176 return nil, err
4177 }
4178 return c, c.HandshakeContext(ctx)
4179 }
4180
4181 req, err := NewRequest("GET", ts.URL, nil)
4182 if err != nil {
4183 t.Fatal(err)
4184 }
4185 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4186 res, err := c.Do(req.WithContext(ctx))
4187 if err != nil {
4188 t.Fatal(err)
4189 }
4190 res.Body.Close()
4191 mu.Lock()
4192 if !gotReq {
4193 t.Error("didn't get request")
4194 }
4195 if got, want := gotCtxValue, ctxValue; got != want {
4196 t.Errorf("got context with value %v, want %v", got, want)
4197 }
4198 }
4199
4200
4201
// TestRoundTripReturnsProxyError checks that an error returned by the
// Transport's Proxy function is surfaced by RoundTrip.
func TestRoundTripReturnsProxyError(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return nil, errors.New("errorMessage")
		},
	}

	req, _ := NewRequest("GET", "http://example.com", nil)
	if _, err := tr.RoundTrip(req); err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	}
}
4217
4218
4219 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4220 tr := &Transport{}
4221 wantIdle := func(when string, n int) bool {
4222 got := tr.IdleConnCountForTesting("http", "example.com")
4223 if got == n {
4224 return true
4225 }
4226 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4227 return false
4228 }
4229 wantIdle("start", 0)
4230 if !tr.PutIdleTestConn("http", "example.com") {
4231 t.Fatal("put failed")
4232 }
4233 if !tr.PutIdleTestConn("http", "example.com") {
4234 t.Fatal("second put failed")
4235 }
4236 wantIdle("after put", 2)
4237 tr.CloseIdleConnections()
4238 if !tr.IsIdleForTesting() {
4239 t.Error("should be idle after CloseIdleConnections")
4240 }
4241 wantIdle("after close idle", 0)
4242 if tr.PutIdleTestConn("http", "example.com") {
4243 t.Fatal("put didn't fail")
4244 }
4245 wantIdle("after second put", 0)
4246
4247 tr.QueueForIdleConnForTesting()
4248 if tr.IsIdleForTesting() {
4249 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4250 }
4251 if !tr.PutIdleTestConn("http", "example.com") {
4252 t.Fatal("after re-activation")
4253 }
4254 wantIdle("after final put", 1)
4255 }
4256
4257
4258
4259 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4260 tr := &Transport{}
4261 wantIdle := func(when string, n int) bool {
4262 got := tr.IdleConnCountForTesting("https", "example.com:443")
4263 if got == n {
4264 return true
4265 }
4266 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4267 return false
4268 }
4269 wantIdle("start", 0)
4270 alt := funcRoundTripper(func() {})
4271 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4272 t.Fatal("put failed")
4273 }
4274 wantIdle("after put", 1)
4275 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4276 GotConn: func(httptrace.GotConnInfo) {
4277
4278 t.Error("GotConn called")
4279 },
4280 })
4281 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4282 _, err := tr.RoundTrip(req)
4283 if err != errFakeRoundTrip {
4284 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4285 }
4286 wantIdle("after round trip", 1)
4287 }
4288
4289
4290
4291
4292
4293
func TestTransportIdleConnRacesRequest(t *testing.T) {
	// NOTE(review): runSynctest presumably runs each mode inside a synctest
	// bubble, so time.Sleep below advances a fake clock — confirm.
	runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
}

// testTransportIdleConnRacesRequest races a new request against an idle
// connection that is in the middle of being closed for exceeding
// IdleConnTimeout.
//
// Request 1 starts a dial and is canceled while that dial is held. The dial is
// then released, and time is advanced past IdleConnTimeout. Request 2 is issued
// while connection closes are still held, and must succeed once they proceed.
func testTransportIdleConnRacesRequest(t *testing.T, mode testMode) {
	if mode == http2UnencryptedMode {
		t.Skip("remove skip when #70515 is fixed")
	}
	timeout := 1 * time.Millisecond
	trFunc := func(tr *Transport) {
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), trFunc, optFakeNet)
	cst.li.trackConns = true

	// Hold every dial until dialc is closed.
	dialc := make(chan struct{})
	cst.li.onDial = func() {
		<-dialc
	}
	// Hold every connection close until closec is closed.
	closec := make(chan struct{})
	cst.li.onClose = func(*fakeNetConn) {
		<-closec
	}
	ctx, cancel := context.WithCancel(context.Background())
	req1c := make(chan error)
	go func() {
		req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		resp, err := cst.c.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		req1c <- err
	}()

	// Wait until request 1 is blocked in its dial, then cancel it; it must
	// fail.
	synctest.Wait()

	cancel()
	synctest.Wait()
	if err := <-req1c; err == nil {
		t.Fatal("expected request to fail, but it succeeded")
	}

	// Release the held dial. NOTE(review): with its request gone, the new
	// connection presumably goes to the idle pool — confirm.
	close(dialc)

	// Advance time past IdleConnTimeout so the idle connection starts to
	// close; the close itself stays blocked in onClose.
	synctest.Wait()
	time.Sleep(timeout)
	synctest.Wait()

	req2c := make(chan error)
	go func() {
		resp, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			resp.Body.Close()
		}
		req2c <- err
	}()

	// Let held closes proceed; request 2 must still succeed.
	close(closec)
	if err := <-req2c; err != nil {
		t.Fatalf("Get: %v", err)
	}
}
4370
func TestTransportRemovesConnsAfterIdle(t *testing.T) {
	runSynctest(t, testTransportRemovesConnsAfterIdle)
}

// testTransportRemovesConnsAfterIdle checks two things:
//   - a request made within IdleConnTimeout of the previous one reuses its
//     connection;
//   - a request made after a further full IdleConnTimeout of idleness does not.
func testTransportRemovesConnsAfterIdle(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Second
	// At most one connection per host, so reuse (or its absence) is
	// observable.
	trFunc := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
		tr.IdleConnTimeout = timeout
	}
	// The handler echoes the client address it sees in X-Addr.
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	}), trFunc, optFakeNet)

	// makeRequest performs a GET and returns the X-Addr value. Equal values
	// indicate the same connection was used (assuming each connection has a
	// distinct client address).
	makeRequest := func() string {
		resp, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
		resp.Body.Close()
		return resp.Header.Get("X-Addr")
	}

	addr1 := makeRequest()

	time.Sleep(timeout / 2)
	synctest.Wait()
	addr2 := makeRequest()
	if addr1 != addr2 {
		t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
	}

	time.Sleep(timeout)
	synctest.Wait()
	addr3 := makeRequest()
	if addr1 == addr3 {
		t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
	}
}
4416
func TestTransportRemovesConnsAfterBroken(t *testing.T) {
	runSynctest(t, testTransportRemovesConnsAfterBroken)
}

// testTransportRemovesConnsAfterBroken checks that successive requests share a
// connection, and that after that connection is broken (by closing its peer)
// the next request uses a different one.
func testTransportRemovesConnsAfterBroken(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// At most one connection per host, so reuse (or its absence) is
	// observable.
	trFunc := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
	}
	// The handler echoes the client address it sees in X-Addr.
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	}), trFunc, optFakeNet)
	// Record connections in cst.li.conns so one can be broken below.
	cst.li.trackConns = true

	// makeRequest performs a GET and returns the X-Addr value; equal values
	// indicate the same connection was used.
	makeRequest := func() string {
		resp, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
		resp.Body.Close()
		return resp.Header.Get("X-Addr")
	}

	addr1 := makeRequest()
	addr2 := makeRequest()
	if addr1 != addr2 {
		t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
	}

	// Let everything settle, then close the peer of the first tracked
	// connection to break it.
	synctest.Wait()
	cst.li.conns[0].peer.Close()
	synctest.Wait()
	addr3 := makeRequest()
	if addr1 == addr3 {
		t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
	}
}
4460
4461
4462
4463
4464
4465 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4466 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4467 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4468 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4469 t.Error("Transport advertised gzip support in the Accept header")
4470 }
4471 if r.Header.Get("Range") == "" {
4472 t.Error("no Range in request")
4473 }
4474 })).ts
4475 c := ts.Client()
4476
4477 req, _ := NewRequest("GET", ts.URL, nil)
4478 req.Header.Set("Range", "bytes=7-11")
4479 res, err := c.Do(req)
4480 if err != nil {
4481 t.Fatal(err)
4482 }
4483 res.Body.Close()
4484 }
4485
4486
func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }

// testTransportResponseCancelRace reads a first response to completion, then
// calls CancelRequest on that already-finished request immediately before
// issuing a second request, which must still succeed.
func testTransportResponseCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Send a non-empty (1024-byte) response body.
		var b [1024]byte
		w.Write(b[:])
	})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}

	// Drain the first body completely. NOTE(review): presumably so the
	// connection returns to the idle pool and is picked up by req2 while the
	// cancel below is in flight — confirm.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	req2, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel the finished first request; this must not break req2.
	tr.CancelRequest(req)
	res, err = tr.RoundTrip(req2)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4522
4523
4524 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4525 run(t, testTransportContentEncodingCaseInsensitive)
4526 }
4527 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4528 for _, ce := range []string{"gzip", "GZIP"} {
4529 ce := ce
4530 t.Run(ce, func(t *testing.T) {
4531 const encodedString = "Hello Gopher"
4532 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4533 w.Header().Set("Content-Encoding", ce)
4534 gz := gzip.NewWriter(w)
4535 gz.Write([]byte(encodedString))
4536 gz.Close()
4537 })).ts
4538
4539 res, err := ts.Client().Get(ts.URL)
4540 if err != nil {
4541 t.Fatal(err)
4542 }
4543
4544 body, err := io.ReadAll(res.Body)
4545 res.Body.Close()
4546 if err != nil {
4547 t.Fatal(err)
4548 }
4549
4550 if string(body) != encodedString {
4551 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4552 }
4553 })
4554 }
4555 }
4556
4557
func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
}

// testConnClosedBeforeRequestIsWritten dials a fake connection on which every
// Read and Write fails, and checks that the resulting POST fails and that its
// request body was closed exactly once.
func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
		func(tr *Transport) {
			tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
				// A connection that is broken from the start.
				return &funcConn{
					read: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
					write: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
				}, nil
			}
		},
	).ts

	// Delay entry into RoundTrip slightly. NOTE(review): presumably this lets
	// the connection's read side observe its read error before the request
	// is written, and the package-level hook is why the test is
	// testNotParallel — confirm.
	SetEnterRoundTripHook(func() {
		time.Sleep(1 * time.Millisecond)
	})
	defer SetEnterRoundTripHook(nil)
	var closes int
	_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
	if err == nil {
		t.Fatalf("expected request to fail, but it did not")
	}
	if closes != 1 {
		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
	}
}
4595
4596
4597
4598
// logWritesConn is a fake net.Conn for driving a Transport by hand. Every
// Write is recorded in writes (under mu) and forwarded to w. The first Read
// blocks until a reader is delivered on rch, and that reader serves all
// Reads. Close does nothing. Only Write, Read and Close are defined here;
// other net.Conn methods go to the embedded Conn.
type logWritesConn struct {
	net.Conn

	w io.Writer // destination for every Write

	rch <-chan io.Reader // delivers the reader used for all Reads
	r   io.Reader        // reader taken from rch on the first Read

	mu     sync.Mutex // guards writes
	writes []string   // payload of each Write, in order
}

// Write records p and then forwards it to c.w, holding c.mu throughout.
func (c *logWritesConn) Write(p []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, string(p))
	n, err = c.w.Write(p)
	return n, err
}

// Read waits on rch only for the first call; afterwards it reads from the
// reader received then.
func (c *logWritesConn) Read(p []byte) (n int, err error) {
	if c.r != nil {
		return c.r.Read(p)
	}
	c.r = <-c.rch
	return c.r.Read(p)
}

// Close does nothing and always succeeds.
func (c *logWritesConn) Close() error {
	return nil
}
4626
4627
4628 func TestTransportFlushesBodyChunks(t *testing.T) {
4629 defer afterTest(t)
4630 resBody := make(chan io.Reader, 1)
4631 connr, connw := io.Pipe()
4632 lw := &logWritesConn{
4633 rch: resBody,
4634 w: connw,
4635 }
4636 tr := &Transport{
4637 Dial: func(network, addr string) (net.Conn, error) {
4638 return lw, nil
4639 },
4640 }
4641 bodyr, bodyw := io.Pipe()
4642 go func() {
4643 defer bodyw.Close()
4644 for i := 0; i < 3; i++ {
4645 fmt.Fprintf(bodyw, "num%d\n", i)
4646 }
4647 }()
4648 resc := make(chan *Response)
4649 go func() {
4650 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4651 req.Header.Set("User-Agent", "x")
4652 res, err := tr.RoundTrip(req)
4653 if err != nil {
4654 t.Errorf("RoundTrip: %v", err)
4655 close(resc)
4656 return
4657 }
4658 resc <- res
4659
4660 }()
4661
4662 req, err := ReadRequest(bufio.NewReader(connr))
4663 if err != nil {
4664 t.Fatal(err)
4665 }
4666 io.Copy(io.Discard, req.Body)
4667
4668
4669 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4670 res, ok := <-resc
4671 if !ok {
4672 return
4673 }
4674 defer res.Body.Close()
4675
4676 want := []string{
4677 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4678 "5\r\nnum0\n\r\n",
4679 "5\r\nnum1\n\r\n",
4680 "5\r\nnum2\n\r\n",
4681 "0\r\n\r\n",
4682 }
4683 if !slices.Equal(lw.writes, want) {
4684 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4685 }
4686 }
4687
4688
4689 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4690 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4691 gotReq := make(chan struct{})
4692 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4693 close(gotReq)
4694 }))
4695
4696 pr, pw := io.Pipe()
4697 req, err := NewRequest("POST", cst.ts.URL, pr)
4698 if err != nil {
4699 t.Fatal(err)
4700 }
4701 gotRes := make(chan struct{})
4702 go func() {
4703 defer close(gotRes)
4704 res, err := cst.tr.RoundTrip(req)
4705 if err != nil {
4706 t.Error(err)
4707 return
4708 }
4709 res.Body.Close()
4710 }()
4711
4712 <-gotReq
4713 pw.Close()
4714 <-gotRes
4715 }
4716
// wgReadCloser wraps an io.Reader and calls wg.Done the first time it is
// closed. Later Close calls leave wg alone and return net.ErrClosed.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4732
4733 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4734
4735 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4736 }
4737 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4738 if testing.Short() {
4739 t.Skip("skipping in short mode")
4740 }
4741
4742 runTimeSensitiveTest(t, []time.Duration{
4743 1 * time.Millisecond,
4744 5 * time.Millisecond,
4745 10 * time.Millisecond,
4746 50 * time.Millisecond,
4747 100 * time.Millisecond,
4748 500 * time.Millisecond,
4749 time.Second,
4750 5 * time.Second,
4751 }, func(t *testing.T, timeout time.Duration) error {
4752 SetRSTAvoidanceDelay(t, timeout)
4753 t.Logf("set RST avoidance delay to %v", timeout)
4754
4755 const contentLengthLimit = 1024 * 1024
4756 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4757 if r.ContentLength >= contentLengthLimit {
4758 w.WriteHeader(StatusBadRequest)
4759 r.Body.Close()
4760 return
4761 }
4762 w.WriteHeader(StatusOK)
4763 }))
4764
4765
4766 defer cst.close()
4767 ts := cst.ts
4768 c := ts.Client()
4769
4770 count := 100
4771
4772 bigBody := strings.Repeat("a", contentLengthLimit*2)
4773 var wg sync.WaitGroup
4774 defer wg.Wait()
4775 getBody := func() (io.ReadCloser, error) {
4776 wg.Add(1)
4777 body := &wgReadCloser{
4778 Reader: strings.NewReader(bigBody),
4779 wg: &wg,
4780 }
4781 return body, nil
4782 }
4783
4784 for i := 0; i < count; i++ {
4785 reqBody, _ := getBody()
4786 req, err := NewRequest("PUT", ts.URL, reqBody)
4787 if err != nil {
4788 reqBody.Close()
4789 t.Fatal(err)
4790 }
4791 req.ContentLength = int64(len(bigBody))
4792 req.GetBody = getBody
4793
4794 resp, err := c.Do(req)
4795 if err != nil {
4796 return fmt.Errorf("Do %d: %v", i, err)
4797 } else {
4798 resp.Body.Close()
4799 if resp.StatusCode != 400 {
4800 t.Errorf("Expected status code 400, got %v", resp.Status)
4801 }
4802 }
4803 }
4804 return nil
4805 })
4806 }
4807
4808 func TestTransportAutomaticHTTP2(t *testing.T) {
4809 testTransportAutoHTTP(t, &Transport{}, true)
4810 }
4811
4812 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4813 testTransportAutoHTTP(t, &Transport{
4814 ForceAttemptHTTP2: true,
4815 TLSClientConfig: new(tls.Config),
4816 }, true)
4817 }
4818
4819
4820 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4821 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4822 }
4823
4824 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4825 testTransportAutoHTTP(t, &Transport{
4826 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4827 }, false)
4828 }
4829
4830 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4831 testTransportAutoHTTP(t, &Transport{
4832 TLSClientConfig: new(tls.Config),
4833 }, false)
4834 }
4835
4836 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4837 testTransportAutoHTTP(t, &Transport{
4838 ExpectContinueTimeout: 1 * time.Second,
4839 }, true)
4840 }
4841
4842 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4843 var d net.Dialer
4844 testTransportAutoHTTP(t, &Transport{
4845 Dial: d.Dial,
4846 }, false)
4847 }
4848
4849 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4850 var d net.Dialer
4851 testTransportAutoHTTP(t, &Transport{
4852 DialContext: d.DialContext,
4853 }, false)
4854 }
4855
4856 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4857 testTransportAutoHTTP(t, &Transport{
4858 DialTLS: func(network, addr string) (net.Conn, error) {
4859 panic("unused")
4860 },
4861 }, false)
4862 }
4863
4864 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4865 CondSkipHTTP2(t)
4866 _, err := tr.RoundTrip(new(Request))
4867 if err == nil {
4868 t.Error("expected error from RoundTrip")
4869 }
4870 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4871 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4872 }
4873 }
4874
4875
4876
4877
4878
4879
4880
4881
4882 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4883 run(t, testTransportReuseConnEmptyResponseBody)
4884 }
4885 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4886 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4887 w.Header().Set("X-Addr", r.RemoteAddr)
4888
4889 }))
4890 n := 100
4891 if testing.Short() {
4892 n = 10
4893 }
4894 var firstAddr string
4895 for i := 0; i < n; i++ {
4896 res, err := cst.c.Get(cst.ts.URL)
4897 if err != nil {
4898 log.Fatal(err)
4899 }
4900 addr := res.Header.Get("X-Addr")
4901 if i == 0 {
4902 firstAddr = addr
4903 } else if addr != firstAddr {
4904 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4905 }
4906 res.Body.Close()
4907 }
4908 }
4909
4910
// TestNoCrashReturningTransportAltConn exercises a specific path: a TLS
// connection that negotiated a TLSNextProto protocol ("foo") finishes
// dialing only after its request has already been canceled.
//
// The test expects three things:
//   - Do fails with a *url.Error wrapping ExportErrRequestCanceledConn.
//   - The "foo" TLSNextProto constructor is still invoked for the late
//     connection.
//   - The RoundTripper that constructor returns is never used.
//
// Judging by its name, the test guards against a crash in that path.
// NOTE(review): confirm against the original issue.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Count dials still in progress so wg.Wait at the end can wait for them.
	// NOTE(review): assumes SetPendingDialHooks(before, after) brackets each
	// background dial in the Transport. Verify in the export_test helpers.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// Fake server. It accepts one TLS connection that advertises "foo",
	// completes the handshake, and holds the connection open until the
	// test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The request is canceled through the legacy Request.Cancel channel,
	// which DialTLS closes below.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	// doReturned: signaled by the test once Do has returned.
	// madeRoundTripper: signaled when the "foo" TLSNextProto constructor runs.
	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		// DialTLS completes a handshake that negotiates "foo". It then
		// cancels the request and waits until Do has returned. Only after
		// that is the connection handed to the Transport, so the connection
		// arrives after the request has been abandoned.
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release DialTLS. The late "foo" connection should still reach the
	// TLSNextProto constructor. Then wait for the tracked dial work to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4992
4993 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4994 run(t, func(t *testing.T, mode testMode) {
4995 testTransportReuseConnection_Gzip(t, mode, true)
4996 })
4997 }
4998
4999 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
5000 run(t, func(t *testing.T, mode testMode) {
5001 testTransportReuseConnection_Gzip(t, mode, false)
5002 })
5003 }
5004
5005
// testTransportReuseConnection_Gzip makes two GET requests. Each response
// carries "Content-Encoding: gzip" and the rgz bytes as its body. The client
// reads len(rgz) bytes of each response body. The test fails unless both
// requests reach the server from the same remote address, i.e. over the
// same connection.
//
// When chunked is true, the handler flushes before writing the body, so the
// headers go out before the body length is known. That presumably yields a
// chunked response in HTTP/1.1.
func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
	addr := make(chan string, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addr <- r.RemoteAddr
		w.Header().Set("Content-Encoding", "gzip")
		if chunked {
			w.(Flusher).Flush()
		}
		w.Write(rgz)
	})).ts
	c := ts.Client()

	// Log connection lifecycle events to make a reuse failure easier to diagnose.
	trace := &httptrace.ClientTrace{
		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)

	for i := 0; i < 2; i++ {
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(ctx)
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		buf := make([]byte, len(rgz))
		if n, err := io.ReadFull(res.Body, buf); err != nil {
			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
		}
		// NOTE(review): res.Body is never closed here. The RemoteAddr check
		// below still requires the second request to reuse the first
		// connection. The test therefore appears to rely on reading the
		// body alone being enough to release the connection. Confirm before
		// adding a Close, which could hide a regression.
	}
	a1, a2 := <-addr, <-addr
	if a1 != a2 {
		t.Fatalf("didn't reuse connection")
	}
}
5047
5048 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
5049 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
5050 if mode == http2Mode {
5051 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
5052 }
5053 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5054 if r.URL.Path == "/long" {
5055 w.Header().Set("Long", strings.Repeat("a", 1<<20))
5056 }
5057 })).ts
5058 c := ts.Client()
5059 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
5060
5061 if res, err := c.Get(ts.URL); err != nil {
5062 t.Fatal(err)
5063 } else {
5064 res.Body.Close()
5065 }
5066
5067 res, err := c.Get(ts.URL + "/long")
5068 if err == nil {
5069 defer res.Body.Close()
5070 var n int64
5071 for k, vv := range res.Header {
5072 for _, v := range vv {
5073 n += int64(len(k)) + int64(len(v))
5074 }
5075 }
5076 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
5077 }
5078 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
5079 t.Errorf("got error: %v; want %q", err, want)
5080 }
5081 }
5082
5083 func TestTransportEventTrace(t *testing.T) {
5084 run(t, func(t *testing.T, mode testMode) {
5085 testTransportEventTrace(t, mode, false)
5086 }, testNotParallel)
5087 }
5088
5089
5090 func TestTransportEventTrace_NoHooks(t *testing.T) {
5091 run(t, func(t *testing.T, mode testMode) {
5092 testTransportEventTrace(t, mode, true)
5093 }, testNotParallel)
5094 }
5095
// testTransportEventTrace checks the httptrace events logged for one request.
// The request is a POST with "Expect: 100-continue" to the fake host
// dns-is-faked.golang, whose lookup is answered with the test server's IP.
// A second GET is then made with the same trace to confirm that GetConn
// fires again.
//
// When noHooks is true, the ClientTrace is reset to its zero value before
// use. Only the first request is made, and no event checks run.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The second (GET) request only needs an empty 200 response.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		// Don't respond until the client's WroteRequest hook has fired, so
		// that event is logged before the response arrives. Skipped when
		// the hooks are cleared, since nothing would ever signal.
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one line to buf. Hooks may run concurrently, so mu guards buf.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS: an alternate resolver, installed via
	// nettrace.LookupIPAltResolverKey, resolves dns-is-faked.golang to the
	// listener's IP. Any other lookup is reported as an error.
	// NOTE(review): assumes the net package consults this context value
	// instead of real DNS.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook. The request then runs with a non-nil
		// ClientTrace whose func fields are all nil.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// With no hooks there is nothing to check. Reaching this point
		// shows a full request completes when every hook in the trace is
		// nil.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce and wantOnceOrMore check how many times a substring appears
	// in the event log.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")

		// Each header written for this request is reported individually.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// The fake resolver must have prevented any real (UDP) DNS traffic.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Make a second request with the same trace. GetConn should now have
	// been logged twice in total.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
5277
5278 func TestTransportEventTraceTLSVerify(t *testing.T) {
5279 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5280 }
5281 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5282 var mu sync.Mutex
5283 var buf strings.Builder
5284 logf := func(format string, args ...any) {
5285 mu.Lock()
5286 defer mu.Unlock()
5287 fmt.Fprintf(&buf, format, args...)
5288 buf.WriteByte('\n')
5289 }
5290
5291 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5292 t.Error("Unexpected request")
5293 }), func(ts *httptest.Server) {
5294 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5295 logf("%s", p)
5296 return len(p), nil
5297 }), "", 0)
5298 }).ts
5299
5300 certpool := x509.NewCertPool()
5301 certpool.AddCert(ts.Certificate())
5302
5303 c := &Client{Transport: &Transport{
5304 TLSClientConfig: &tls.Config{
5305 ServerName: "dns-is-faked.golang",
5306 RootCAs: certpool,
5307 },
5308 }}
5309
5310 trace := &httptrace.ClientTrace{
5311 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5312 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5313 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5314 },
5315 }
5316
5317 req, _ := NewRequest("GET", ts.URL, nil)
5318 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5319 _, err := c.Do(req)
5320 if err == nil {
5321 t.Error("Expected request to fail TLS verification")
5322 }
5323
5324 mu.Lock()
5325 got := buf.String()
5326 mu.Unlock()
5327
5328 wantOnce := func(sub string) {
5329 if strings.Count(got, sub) != 1 {
5330 t.Errorf("expected substring %q exactly once in output.", sub)
5331 }
5332 }
5333
5334 wantOnce("TLSHandshakeStart")
5335 wantOnce("TLSHandshakeDone")
5336 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5337
5338 if t.Failed() {
5339 t.Errorf("Output:\n%s", got)
5340 }
5341 }
5342
// isDNSHijacked reports whether the system resolver returns any address for
// "dns-should-not-resolve.golang", a name that should not resolve. A
// non-empty answer means the resolver rewrites failed lookups. The lookup
// runs at most once; sync.OnceValue caches the result.
var isDNSHijacked = sync.OnceValue(func() bool {
	// The lookup error is irrelevant; only whether addresses came back matters.
	addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
	return len(addrs) > 0
})
5347
5348 func skipIfDNSHijacked(t *testing.T) {
5349
5350
5351
5352 if isDNSHijacked() {
5353 t.Skip("skipping; test requires non-hijacking DNS server")
5354 }
5355 }
5356
5357 func TestTransportEventTraceRealDNS(t *testing.T) {
5358 skipIfDNSHijacked(t)
5359 defer afterTest(t)
5360 tr := &Transport{}
5361 defer tr.CloseIdleConnections()
5362 c := &Client{Transport: tr}
5363
5364 var mu sync.Mutex
5365 var buf strings.Builder
5366 logf := func(format string, args ...any) {
5367 mu.Lock()
5368 defer mu.Unlock()
5369 fmt.Fprintf(&buf, format, args...)
5370 buf.WriteByte('\n')
5371 }
5372
5373 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5374 trace := &httptrace.ClientTrace{
5375 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5376 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5377 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5378 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5379 }
5380 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5381
5382 resp, err := c.Do(req)
5383 if err == nil {
5384 resp.Body.Close()
5385 t.Fatal("expected error during DNS lookup")
5386 }
5387
5388 mu.Lock()
5389 got := buf.String()
5390 mu.Unlock()
5391
5392 wantSub := func(sub string) {
5393 if !strings.Contains(got, sub) {
5394 t.Errorf("expected substring %q in output.", sub)
5395 }
5396 }
5397 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5398 wantSub("DNSDone: {Addrs:[] Err:")
5399 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5400 t.Errorf("should not see Connect events")
5401 }
5402 if t.Failed() {
5403 t.Errorf("Output:\n%s", got)
5404 }
5405 }
5406
5407
// TestTransportRejectsAlphaPort verifies that a URL whose port contains
// non-digit characters is rejected. The error must be a *url.Error whose
// underlying error describes the invalid port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	const (
		rawURL = "http://dummy.tld:123foo/bar"
		want   = `invalid port ":123foo" after host`
	)
	res, err := Get(rawURL)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	if got := ue.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
5424
5425
5426
5427 func TestTLSHandshakeTrace(t *testing.T) {
5428 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5429 }
5430 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5431 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5432
5433 var mu sync.Mutex
5434 var start, done bool
5435 trace := &httptrace.ClientTrace{
5436 TLSHandshakeStart: func() {
5437 mu.Lock()
5438 defer mu.Unlock()
5439 start = true
5440 },
5441 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5442 mu.Lock()
5443 defer mu.Unlock()
5444 done = true
5445 if err != nil {
5446 t.Fatal("Expected error to be nil but was:", err)
5447 }
5448 },
5449 }
5450
5451 c := ts.Client()
5452 req, err := NewRequest("GET", ts.URL, nil)
5453 if err != nil {
5454 t.Fatal("Unable to construct test request:", err)
5455 }
5456 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5457
5458 r, err := c.Do(req)
5459 if err != nil {
5460 t.Fatal("Unexpected error making request:", err)
5461 }
5462 r.Body.Close()
5463 mu.Lock()
5464 defer mu.Unlock()
5465 if !start {
5466 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5467 }
5468 if !done {
5469 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5470 }
5471 }
5472
5473 func TestTransportMaxIdleConns(t *testing.T) {
5474 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5475 }
5476 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5477 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5478
5479 })).ts
5480 c := ts.Client()
5481 tr := c.Transport.(*Transport)
5482 tr.MaxIdleConns = 4
5483
5484 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5485 if err != nil {
5486 t.Fatal(err)
5487 }
5488 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5489 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5490 })
5491
5492 hitHost := func(n int) {
5493 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5494 req = req.WithContext(ctx)
5495 res, err := c.Do(req)
5496 if err != nil {
5497 t.Fatal(err)
5498 }
5499 res.Body.Close()
5500 }
5501 for i := 0; i < 4; i++ {
5502 hitHost(i)
5503 }
5504 want := []string{
5505 "|http|host-0.dns-is-faked.golang:" + port,
5506 "|http|host-1.dns-is-faked.golang:" + port,
5507 "|http|host-2.dns-is-faked.golang:" + port,
5508 "|http|host-3.dns-is-faked.golang:" + port,
5509 }
5510 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5511 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5512 }
5513
5514
5515 hitHost(4)
5516 want = []string{
5517 "|http|host-1.dns-is-faked.golang:" + port,
5518 "|http|host-2.dns-is-faked.golang:" + port,
5519 "|http|host-3.dns-is-faked.golang:" + port,
5520 "|http|host-4.dns-is-faked.golang:" + port,
5521 }
5522 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5523 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5524 }
5525 }
5526
5527 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
// testTransportIdleConnTimeout verifies that Transport.IdleConnTimeout
// eventually removes an idle connection.
//
// Each attempt makes three requests spaced timeout/2 apart and requires them
// all to reuse the same single idle connection. If any request sees the
// connection disappear or change, the timeout is treated as too short for
// this machine: it is doubled and the whole attempt restarts. Once three
// requests succeed, the test waits for the idle set to become empty.
// The test is skipped in -short mode.
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Millisecond
timeoutLoop:
	for {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// No body; only the connection's idle lifetime matters.
		}))
		tr := cst.tr
		tr.IdleConnTimeout = timeout
		defer tr.CloseIdleConnections()
		c := &Client{Transport: tr}

		// idleConns lists the Transport's idle connections, using the
		// HTTP/2-specific accessor in HTTP/2 mode.
		idleConns := func() []string {
			if mode == http2Mode {
				return tr.IdleConnStrsForTesting_h2()
			} else {
				return tr.IdleConnStrsForTesting()
			}
		}

		// conn records the first idle connection seen. doReq returns false
		// (retry with a longer timeout) if request n did not end with
		// exactly that connection idle.
		var conn string
		doReq := func(n int) (timeoutOk bool) {
			req, _ := NewRequest("GET", cst.ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
				PutIdleConn: func(err error) {
					if err != nil {
						t.Errorf("failed to keep idle conn: %v", err)
					}
				},
			}))
			res, err := c.Do(req)
			if err != nil {
				if strings.Contains(err.Error(), "use of closed network connection") {
					t.Logf("req %v: connection closed prematurely", n)
					return false
				}
			}
			// NOTE(review): other Do errors are not reported directly; the
			// idle-connection checks below are relied on to catch them.
			if err == nil {
				res.Body.Close()
			}
			conns := idleConns()
			if len(conns) != 1 {
				if len(conns) == 0 {
					t.Logf("req %v: no idle conns", n)
					return false
				}
				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
			}
			if conn == "" {
				conn = conns[0]
			}
			if conn != conns[0] {
				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
				return false
			}
			return true
		}
		for i := 0; i < 3; i++ {
			if !doReq(i) {
				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
				timeout *= 2
				cst.close()
				continue timeoutLoop
			}
			time.Sleep(timeout / 2)
		}

		// Wait for the idle connection to be closed by the timeout.
		// NOTE(review): assumes waitCondition polls the func every
		// timeout/2, passing the elapsed time d, until it returns true.
		// Elapsed times are logged once d reaches 1.5x the timeout.
		waitCondition(t, timeout/2, func(d time.Duration) bool {
			if got := idleConns(); len(got) != 0 {
				if d >= timeout*3/2 {
					t.Logf("after %v, idle conns = %q", d, got)
				}
				return false
			}
			return true
		})
		break
	}
}
5611
5612
5613
5614
5615
5616
5617
5618
5619
5620
5621
5622
5623 func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
// testIdleConnH2Crash makes an HTTP/2 request whose context is canceled from
// inside DialTLS. The cancellation happens after the TLS handshake has
// negotiated "h2" but before the connection is handed to the Transport.
// DialTLS then holds the connection until Do has returned its error, so the
// Transport receives a fresh h2 connection for a request that is already
// gone. Finally the test sleeps for 10x IdleConnTimeout so any idle timer on
// that connection fires while the test is still running.
// Per its name, the test presumably guards against a crash in that path.
// NOTE(review): TODO confirm.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler is never expected to run meaningfully; do nothing.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// The connection is ready; cancel the request before returning it.
		cancel()

		// Hand the connection over only after Do has returned (or the test
		// is ending), so it arrives for an already-canceled request.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the connection's idle timeout (5ms) ample time to fire while the
	// test is still alive.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5673
// funcConn is a net.Conn whose Read and Write delegate to the supplied
// functions. Close is a no-op that always returns nil. Every other net.Conn
// method is promoted from the embedded Conn; if that field is left nil,
// calling such a method panics.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards to c.read.
func (c funcConn) Read(b []byte) (int, error) {
	return c.read(b)
}

// Write forwards to c.write.
func (c funcConn) Write(b []byte) (int, error) {
	return c.write(b)
}

// Close does nothing and reports success.
func (c funcConn) Close() error {
	return nil
}
5683
5684
5685
5686 func TestTransportReturnsPeekError(t *testing.T) {
5687 errValue := errors.New("specific error value")
5688
5689 wrote := make(chan struct{})
5690 wroteOnce := sync.OnceFunc(func() { close(wrote) })
5691
5692 tr := &Transport{
5693 Dial: func(network, addr string) (net.Conn, error) {
5694 c := funcConn{
5695 read: func([]byte) (int, error) {
5696 <-wrote
5697 return 0, errValue
5698 },
5699 write: func(p []byte) (int, error) {
5700 wroteOnce()
5701 return len(p), nil
5702 },
5703 }
5704 return c, nil
5705 },
5706 }
5707 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5708 if err != errValue {
5709 t.Errorf("error = %#v; want %v", err, errValue)
5710 }
5711 }
5712
5713
5714 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
// testTransportIDNA requests a URL whose host is the Unicode name uniDomain.
// It verifies that the Transport uses the Punycode form punyDomain
// everywhere the name appears:
//   - the GetConn and DNSStart trace hooks,
//   - the DNS lookup, answered by a fake resolver,
//   - the Host header seen by the server,
//   - in HTTP/2 mode, the TLS ServerName.
func testTransportIDNA(t *testing.T, mode testMode) {
	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is captured by the handler below and filled in once the server
	// is listening. No request is sent before then.
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		// Marks the response as coming from this handler.
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	// port is already declared in this scope, so := assigns to it here
	// (only ip and err are new). The handler's captured port is therefore
	// updated, not shadowed.
	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS: only punyDomain is resolved, to the listener's IP; any other
	// lookup fails the test.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
5782
5783
5784 func TestTransportProxyConnectHeader(t *testing.T) {
5785 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5786 }
5787 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5788 reqc := make(chan *Request, 1)
5789 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5790 if r.Method != "CONNECT" {
5791 t.Errorf("method = %q; want CONNECT", r.Method)
5792 }
5793 reqc <- r
5794 c, _, err := w.(Hijacker).Hijack()
5795 if err != nil {
5796 t.Errorf("Hijack: %v", err)
5797 return
5798 }
5799 c.Close()
5800 })).ts
5801
5802 c := ts.Client()
5803 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5804 return url.Parse(ts.URL)
5805 }
5806 c.Transport.(*Transport).ProxyConnectHeader = Header{
5807 "User-Agent": {"foo"},
5808 "Other": {"bar"},
5809 }
5810
5811 res, err := c.Get("https://dummy.tld/")
5812 if err == nil {
5813 res.Body.Close()
5814 t.Errorf("unexpected success")
5815 }
5816
5817 r := <-reqc
5818 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5819 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5820 }
5821 if got, want := r.Header.Get("Other"), "bar"; got != want {
5822 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5823 }
5824 }
5825
5826 func TestTransportProxyGetConnectHeader(t *testing.T) {
5827 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5828 }
5829 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5830 reqc := make(chan *Request, 1)
5831 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5832 if r.Method != "CONNECT" {
5833 t.Errorf("method = %q; want CONNECT", r.Method)
5834 }
5835 reqc <- r
5836 c, _, err := w.(Hijacker).Hijack()
5837 if err != nil {
5838 t.Errorf("Hijack: %v", err)
5839 return
5840 }
5841 c.Close()
5842 })).ts
5843
5844 c := ts.Client()
5845 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5846 return url.Parse(ts.URL)
5847 }
5848
5849 c.Transport.(*Transport).ProxyConnectHeader = Header{
5850 "User-Agent": {"foo"},
5851 "Other": {"bar"},
5852 }
5853 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5854 return Header{
5855 "User-Agent": {"foo2"},
5856 "Other": {"bar2"},
5857 }, nil
5858 }
5859
5860 res, err := c.Get("https://dummy.tld/")
5861 if err == nil {
5862 res.Body.Close()
5863 t.Errorf("unexpected success")
5864 }
5865
5866 r := <-reqc
5867 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5868 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5869 }
5870 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5871 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5872 }
5873 }
5874
// errFakeRoundTrip is the error every funcRoundTripper returns.
var errFakeRoundTrip error = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that calls itself for its side effect
// and then always fails.
type funcRoundTripper func()

// RoundTrip ignores the request, invokes fn, and returns a nil Response
// together with errFakeRoundTrip.
func (fn funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	fn()
	return nil, errFakeRoundTrip
}
5883
// wantBody validates a (response, error) pair as returned by a client call.
// A non-nil err is returned unchanged. Otherwise the whole body is read and
// compared with want, then closed. The result is nil only if the body
// matched and closed cleanly. On a read error or mismatch the body is left
// open.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	body, readErr := io.ReadAll(res.Body)
	switch {
	case readErr != nil:
		return fmt.Errorf("error reading body: %v", readErr)
	case string(body) != want:
		return fmt.Errorf("body = %q; want %q", body, want)
	}
	if closeErr := res.Body.Close(); closeErr != nil {
		return fmt.Errorf("body Close = %v", closeErr)
	}
	return nil
}
5900
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// can be bound, the test fails via t.Fatal with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
5911
// countCloseReader is an io.ReadCloser. Reads go to the embedded Reader, and
// each Close increments the counter pointed to by n.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close records the call and always succeeds.
func (r countCloseReader) Close() error {
	*r.n++
	return nil
}
5921
5922
// rgz is a gzip stream used as the response body in gzip tests.
//
// Header: magic 0x1f 0x8b, method 0x08 (deflate), and flags 0x08 (FNAME,
// per RFC 1952). The modification time is zero, and the XFL and OS bytes
// are 0. The stored file name is "recursive\x00".
//
// NOTE(review): presumably a gzip "quine" whose decompressed output equals
// these same bytes. testTransportReuseConnection_Gzip reads len(rgz) bytes
// from the decoded body. TODO confirm.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5957
5958
5959
5960 func TestMissingStatusNoPanic(t *testing.T) {
5961 t.Parallel()
5962
5963 const want = "unknown status code"
5964
5965 ln := newLocalListener(t)
5966 addr := ln.Addr().String()
5967 done := make(chan bool)
5968 fullAddrURL := fmt.Sprintf("http://%s", addr)
5969 raw := "HTTP/1.1 400\r\n" +
5970 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5971 "Content-Type: text/html; charset=utf-8\r\n" +
5972 "Content-Length: 10\r\n" +
5973 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5974 "Vary: Accept-Encoding\r\n\r\n" +
5975 "Aloha Olaa"
5976
5977 go func() {
5978 defer close(done)
5979
5980 conn, _ := ln.Accept()
5981 if conn != nil {
5982 io.WriteString(conn, raw)
5983 io.ReadAll(conn)
5984 conn.Close()
5985 }
5986 }()
5987
5988 proxyURL, err := url.Parse(fullAddrURL)
5989 if err != nil {
5990 t.Fatalf("proxyURL: %v", err)
5991 }
5992
5993 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5994
5995 req, _ := NewRequest("GET", "https://golang.org/", nil)
5996 res, err, panicked := doFetchCheckPanic(tr, req)
5997 if panicked {
5998 t.Error("panicked, expecting an error")
5999 }
6000 if res != nil && res.Body != nil {
6001 io.Copy(io.Discard, res.Body)
6002 res.Body.Close()
6003 }
6004
6005 if err == nil || !strings.Contains(err.Error(), want) {
6006 t.Errorf("got=%v want=%q", err, want)
6007 }
6008
6009 ln.Close()
6010 <-done
6011 }
6012
// doFetchCheckPanic calls tr.RoundTrip(req). If RoundTrip panics, the panic
// is recovered and reported as panicked == true, with res and err left nil.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		// recover returns non-nil only while a panic is unwinding.
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
6022
6023
6024
6025 func TestNoBodyOnChunked304Response(t *testing.T) {
6026 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
6027 }
6028 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
6029 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6030 conn, buf, _ := w.(Hijacker).Hijack()
6031 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
6032 buf.Flush()
6033 conn.Close()
6034 }))
6035
6036
6037
6038
6039
6040 cst.tr.DisableKeepAlives = true
6041
6042 res, err := cst.c.Get(cst.ts.URL)
6043 if err != nil {
6044 t.Fatal(err)
6045 }
6046
6047 if res.Body != NoBody {
6048 t.Errorf("Unexpected body on 304 response")
6049 }
6050 }
6051
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write hands p to the wrapped function and returns its results unchanged.
func (fw funcWriter) Write(p []byte) (n int, err error) {
	n, err = fw(p)
	return n, err
}
6055
// doneContext embeds a Context but always reports itself as already done.
// Err returns the error it was built with.
type doneContext struct {
	context.Context
	err error
}

// Done returns a channel that is already closed. Each call makes a new one.
func (doneContext) Done() <-chan struct{} {
	closed := make(chan struct{})
	defer close(closed)
	return closed
}

// Err returns the stored error.
func (dc doneContext) Err() error {
	return dc.err
}
6068
6069
6070 func TestTransportCheckContextDoneEarly(t *testing.T) {
6071 tr := &Transport{}
6072 req, _ := NewRequest("GET", "http://fake.example/", nil)
6073 wantErr := errors.New("some error")
6074 req = req.WithContext(doneContext{context.Background(), wantErr})
6075 _, err := tr.RoundTrip(req)
6076 if err != wantErr {
6077 t.Errorf("error = %v; want %v", err, wantErr)
6078 }
6079 }
6080
6081
6082
6083
6084
6085
// TestClientTimeoutKillsConn_BeforeHeaders checks what happens when
// Client.Timeout expires before any response headers are sent: the client
// must close the connection. The handler hijacks the connection and expects
// a one-byte read to return (0, io.EOF).
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a tiny timeout. Each attempt that never observes the handler
	// doubles it and retries against a fresh server.
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Never respond; wait for the server to mark the request done.
			// NOTE(review): presumably this happens once the client gives
			// up and closes its end of the connection.
			<-r.Context().Done()

			// Either the test abandoned this attempt, or tell it the
			// handler was reached.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client side should already be closed, so a read on the
			// hijacked connection must see a clean EOF.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		// Give the handler up to 10x the timeout to be reached.
		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler was not seen in time; for example, the request may
			// have been abandoned before reaching the server. Release any
			// handler still pending, tear down this server, and retry with a
			// longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
6144
6145
6146
6147
6148
6149
// TestClientTimeoutKillsConn_AfterHeaders cancels a request after its
// response headers arrived but before the body is complete. The client must
// close the connection: reading the response body fails, and the hijacking
// handler's read of the raw connection fails too.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise 100 body bytes and flush the headers, so the client's Do
		// can return before any body bytes are sent.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		// Wait until the test lets the handler continue, or bail out if the
		// test already failed.
		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Send only 3 of the promised 100 bytes.
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// The client should have closed its end, so this read must fail.
		// NOTE(review): the exact error (io.EOF, connection reset, ...) is
		// presumably platform-dependent, which is why the check only
		// rejects err == nil.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// NOTE(review): a very long but non-zero Timeout. Presumably this keeps
	// the client's timeout code paths active without ever firing; the test
	// actually aborts the request through req.Cancel below.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Cancel while the body is still outstanding. The handler is blocked
	// sending on the unbuffered inHandler channel at this point, so it has
	// not written any body bytes yet.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Let the handler proceed, then wait for it to finish its checks.
	<-inHandler
	<-handlerDone
}
6215
6216 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
6217 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
6218 }
6219 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
6220 done := make(chan struct{})
6221 defer close(done)
6222 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6223 conn, _, err := w.(Hijacker).Hijack()
6224 if err != nil {
6225 t.Error(err)
6226 return
6227 }
6228 defer conn.Close()
6229 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
6230 bs := bufio.NewScanner(conn)
6231 bs.Scan()
6232 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
6233 <-done
6234 }))
6235
6236 req, _ := NewRequest("GET", cst.ts.URL, nil)
6237 req.Header.Set("Upgrade", "foo")
6238 req.Header.Set("Connection", "upgrade")
6239 res, err := cst.c.Do(req)
6240 if err != nil {
6241 t.Fatal(err)
6242 }
6243 if res.StatusCode != 101 {
6244 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
6245 }
6246 rwc, ok := res.Body.(io.ReadWriteCloser)
6247 if !ok {
6248 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
6249 }
6250 defer rwc.Close()
6251 bs := bufio.NewScanner(rwc)
6252 if !bs.Scan() {
6253 t.Fatalf("expected readable input")
6254 }
6255 if got, want := bs.Text(), "Some buffered data"; got != want {
6256 t.Errorf("read %q; want %q", got, want)
6257 }
6258 io.WriteString(rwc, "echo\n")
6259 if !bs.Scan() {
6260 t.Fatalf("expected another line")
6261 }
6262 if got, want := bs.Text(), "ECHO"; got != want {
6263 t.Errorf("read %q; want %q", got, want)
6264 }
6265 }
6266
6267 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
6268 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6269 const target = "backend:443"
6270 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6271 if r.Method != "CONNECT" {
6272 t.Errorf("unexpected method %q", r.Method)
6273 w.WriteHeader(500)
6274 return
6275 }
6276 if r.RequestURI != target {
6277 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6278 w.WriteHeader(500)
6279 return
6280 }
6281 nc, brw, err := w.(Hijacker).Hijack()
6282 if err != nil {
6283 t.Error(err)
6284 return
6285 }
6286 defer nc.Close()
6287 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6288
6289 for {
6290 line, err := brw.ReadString('\n')
6291 if err != nil {
6292 if err != io.EOF {
6293 t.Error(err)
6294 }
6295 return
6296 }
6297 io.WriteString(brw, strings.ToUpper(line))
6298 brw.Flush()
6299 }
6300 }))
6301 pr, pw := io.Pipe()
6302 defer pw.Close()
6303 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6304 if err != nil {
6305 t.Fatal(err)
6306 }
6307 req.URL.Opaque = target
6308 res, err := cst.c.Do(req)
6309 if err != nil {
6310 t.Fatal(err)
6311 }
6312 defer res.Body.Close()
6313 if res.StatusCode != 200 {
6314 t.Fatalf("status code = %d; want 200", res.StatusCode)
6315 }
6316 br := bufio.NewReader(res.Body)
6317 for _, str := range []string{"foo", "bar", "baz"} {
6318 fmt.Fprintf(pw, "%s\n", str)
6319 got, err := br.ReadString('\n')
6320 if err != nil {
6321 t.Fatal(err)
6322 }
6323 got = strings.TrimSpace(got)
6324 want := strings.ToUpper(str)
6325 if got != want {
6326 t.Fatalf("got %q; want %q", got, want)
6327 }
6328 }
6329 }
6330
6331 func TestTransportRequestReplayable(t *testing.T) {
6332 someBody := io.NopCloser(strings.NewReader(""))
6333 tests := []struct {
6334 name string
6335 req *Request
6336 want bool
6337 }{
6338 {
6339 name: "GET",
6340 req: &Request{Method: "GET"},
6341 want: true,
6342 },
6343 {
6344 name: "GET_http.NoBody",
6345 req: &Request{Method: "GET", Body: NoBody},
6346 want: true,
6347 },
6348 {
6349 name: "GET_body",
6350 req: &Request{Method: "GET", Body: someBody},
6351 want: false,
6352 },
6353 {
6354 name: "POST",
6355 req: &Request{Method: "POST"},
6356 want: false,
6357 },
6358 {
6359 name: "POST_idempotency-key",
6360 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6361 want: true,
6362 },
6363 {
6364 name: "POST_x-idempotency-key",
6365 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6366 want: true,
6367 },
6368 {
6369 name: "POST_body",
6370 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6371 want: false,
6372 },
6373 }
6374 for _, tt := range tests {
6375 t.Run(tt.name, func(t *testing.T) {
6376 got := tt.req.ExportIsReplayable()
6377 if got != tt.want {
6378 t.Errorf("replyable = %v; want %v", got, tt.want)
6379 }
6380 })
6381 }
6382 }
6383
6384
6385
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method was used. Tests use it to see whether the transport took the
// io.ReaderFrom path when writing a request body.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool
}

// ReadFrom sets ReadFromCalled and then delegates to the wrapped TCPConn.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (int64, error) {
	m.ReadFromCalled = true
	written, err := m.TCPConn.ReadFrom(src)
	return written, err
}
6396
6397 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6398 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6399 nBytes := int64(1 << 10)
6400 newFileFunc := func() (r io.Reader, done func(), err error) {
6401 f, err := os.CreateTemp("", "net-http-newfilefunc")
6402 if err != nil {
6403 return nil, nil, err
6404 }
6405
6406
6407 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6408 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6409 }
6410 if _, err := f.Seek(0, 0); err != nil {
6411 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6412 }
6413
6414 done = func() {
6415 f.Close()
6416 os.Remove(f.Name())
6417 }
6418
6419 return f, done, nil
6420 }
6421
6422 newBufferFunc := func() (io.Reader, func(), error) {
6423 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6424 }
6425
6426 cases := []struct {
6427 name string
6428 readerFunc func() (io.Reader, func(), error)
6429 contentLength int64
6430 expectedReadFrom bool
6431 }{
6432 {
6433 name: "file, length",
6434 readerFunc: newFileFunc,
6435 contentLength: nBytes,
6436 expectedReadFrom: true,
6437 },
6438 {
6439 name: "file, no length",
6440 readerFunc: newFileFunc,
6441 },
6442 {
6443 name: "file, negative length",
6444 readerFunc: newFileFunc,
6445 contentLength: -1,
6446 },
6447 {
6448 name: "buffer",
6449 contentLength: nBytes,
6450 readerFunc: newBufferFunc,
6451 },
6452 {
6453 name: "buffer, no length",
6454 readerFunc: newBufferFunc,
6455 },
6456 {
6457 name: "buffer, length -1",
6458 contentLength: -1,
6459 readerFunc: newBufferFunc,
6460 },
6461 }
6462
6463 for _, tc := range cases {
6464 t.Run(tc.name, func(t *testing.T) {
6465 r, cleanup, err := tc.readerFunc()
6466 if err != nil {
6467 t.Fatal(err)
6468 }
6469 defer cleanup()
6470
6471 tConn := &testMockTCPConn{}
6472 trFunc := func(tr *Transport) {
6473 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6474 var d net.Dialer
6475 conn, err := d.DialContext(ctx, network, addr)
6476 if err != nil {
6477 return nil, err
6478 }
6479
6480 tcpConn, ok := conn.(*net.TCPConn)
6481 if !ok {
6482 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6483 }
6484
6485 tConn.TCPConn = tcpConn
6486 return tConn, nil
6487 }
6488 }
6489
6490 cst := newClientServerTest(
6491 t,
6492 mode,
6493 HandlerFunc(func(w ResponseWriter, r *Request) {
6494 io.Copy(io.Discard, r.Body)
6495 r.Body.Close()
6496 w.WriteHeader(200)
6497 }),
6498 trFunc,
6499 )
6500
6501 req, err := NewRequest("PUT", cst.ts.URL, r)
6502 if err != nil {
6503 t.Fatal(err)
6504 }
6505 req.ContentLength = tc.contentLength
6506 req.Header.Set("Content-Type", "application/octet-stream")
6507 resp, err := cst.c.Do(req)
6508 if err != nil {
6509 t.Fatal(err)
6510 }
6511 defer resp.Body.Close()
6512 if resp.StatusCode != 200 {
6513 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6514 }
6515
6516 expectedReadFrom := tc.expectedReadFrom
6517 if mode != http1Mode {
6518 expectedReadFrom = false
6519 }
6520 if !tConn.ReadFromCalled && expectedReadFrom {
6521 t.Fatalf("did not call ReadFrom")
6522 }
6523
6524 if tConn.ReadFromCalled && !expectedReadFrom {
6525 t.Fatalf("ReadFrom was unexpectedly invoked")
6526 }
6527 })
6528 }
6529 }
6530
// TestTransportClone sets every exported Transport field to a non-zero value
// and checks that Clone copies all of them, including the TLSNextProto map
// entries. It also checks that cloning a zero Transport leaves TLSNextProto
// nil.
func TestTransportClone(t *testing.T) {
	src := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		HTTP2:                  &HTTP2Config{MaxConcurrentStreams: 1},
		Protocols:              &Protocols{},
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	src.Protocols.SetHTTP1(true)
	src.Protocols.SetHTTP2(true)

	cloned := src.Clone()
	val := reflect.ValueOf(cloned).Elem()
	typ := val.Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		field := typ.Field(i)
		if token.IsExported(field.Name) && val.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", field.Name)
		}
	}

	if _, ok := cloned.TLSNextProto["foo"]; !ok {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// A nil TLSNextProto must stay nil after cloning.
	if new(Transport).Clone().TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6589
6590 func TestIs408(t *testing.T) {
6591 tests := []struct {
6592 in string
6593 want bool
6594 }{
6595 {"HTTP/1.0 408", true},
6596 {"HTTP/1.1 408", true},
6597 {"HTTP/1.8 408", true},
6598 {"HTTP/2.0 408", false},
6599 {"HTTP/1.1 408 ", true},
6600 {"HTTP/1.1 40", false},
6601 {"http/1.0 408", false},
6602 {"HTTP/1-1 408", false},
6603 }
6604 for _, tt := range tests {
6605 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6606 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6607 }
6608 }
6609 }
6610
6611 func TestTransportIgnores408(t *testing.T) {
6612 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6613 }
6614 func testTransportIgnores408(t *testing.T, mode testMode) {
6615
6616 defer log.SetOutput(log.Writer())
6617
6618 var logout strings.Builder
6619 log.SetOutput(&logout)
6620
6621 const target = "backend:443"
6622
6623 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6624 nc, _, err := w.(Hijacker).Hijack()
6625 if err != nil {
6626 t.Error(err)
6627 return
6628 }
6629 defer nc.Close()
6630 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6631 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6632 }))
6633 req, err := NewRequest("GET", cst.ts.URL, nil)
6634 if err != nil {
6635 t.Fatal(err)
6636 }
6637 res, err := cst.c.Do(req)
6638 if err != nil {
6639 t.Fatal(err)
6640 }
6641 slurp, err := io.ReadAll(res.Body)
6642 if err != nil {
6643 t.Fatal(err)
6644 }
6645 if err != nil {
6646 t.Fatal(err)
6647 }
6648 if string(slurp) != "ok" {
6649 t.Fatalf("got %q; want ok", slurp)
6650 }
6651
6652 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6653 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6654 if d > 0 {
6655 t.Logf("%v idle conns still present after %v", n, d)
6656 }
6657 return false
6658 }
6659 return true
6660 })
6661 if got := logout.String(); got != "" {
6662 t.Fatalf("expected no log output; got: %s", got)
6663 }
6664 }
6665
6666 func TestInvalidHeaderResponse(t *testing.T) {
6667 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6668 }
6669 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6670 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6671 conn, buf, _ := w.(Hijacker).Hijack()
6672 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6673 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6674 "Content-Type: text/html; charset=utf-8\r\n" +
6675 "Content-Length: 0\r\n" +
6676 "Foo : bar\r\n\r\n"))
6677 buf.Flush()
6678 conn.Close()
6679 }))
6680 res, err := cst.c.Get(cst.ts.URL)
6681 if err != nil {
6682 t.Fatal(err)
6683 }
6684 defer res.Body.Close()
6685 if v := res.Header.Get("Foo"); v != "" {
6686 t.Errorf(`unexpected "Foo" header: %q`, v)
6687 }
6688 if v := res.Header.Get("Foo "); v != "bar" {
6689 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6690 }
6691 }
6692
// bodyCloser is a request body that is always empty and remembers whether
// Close was called.
type bodyCloser bool

// Close records that the body was closed.
func (bc *bodyCloser) Close() error {
	*bc = bodyCloser(true)
	return nil
}

// Read reports io.EOF at once; the body never yields any bytes.
func (bc *bodyCloser) Read(b []byte) (n int, err error) {
	n, err = 0, io.EOF
	return n, err
}
6702
6703
6704
6705 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6706 run(t, testTransportClosesBodyOnInvalidRequests)
6707 }
6708 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6709 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6710 t.Errorf("Should not have been invoked")
6711 })).ts
6712
6713 u, _ := url.Parse(cst.URL)
6714
6715 tests := []struct {
6716 name string
6717 req *Request
6718 wantErr string
6719 }{
6720 {
6721 name: "invalid method",
6722 req: &Request{
6723 Method: " ",
6724 URL: u,
6725 },
6726 wantErr: `invalid method " "`,
6727 },
6728 {
6729 name: "nil URL",
6730 req: &Request{
6731 Method: "GET",
6732 },
6733 wantErr: `nil Request.URL`,
6734 },
6735 {
6736 name: "invalid header key",
6737 req: &Request{
6738 Method: "GET",
6739 Header: Header{"💡": {"emoji"}},
6740 URL: u,
6741 },
6742 wantErr: `invalid header field name "💡"`,
6743 },
6744 {
6745 name: "invalid header value",
6746 req: &Request{
6747 Method: "POST",
6748 Header: Header{"key": {"\x19"}},
6749 URL: u,
6750 },
6751 wantErr: `invalid header field value for "key"`,
6752 },
6753 {
6754 name: "non HTTP(s) scheme",
6755 req: &Request{
6756 Method: "POST",
6757 URL: &url.URL{Scheme: "faux"},
6758 },
6759 wantErr: `unsupported protocol scheme "faux"`,
6760 },
6761 {
6762 name: "no Host in URL",
6763 req: &Request{
6764 Method: "POST",
6765 URL: &url.URL{Scheme: "http"},
6766 },
6767 wantErr: `no Host in request URL`,
6768 },
6769 }
6770
6771 for _, tt := range tests {
6772 t.Run(tt.name, func(t *testing.T) {
6773 var bc bodyCloser
6774 req := tt.req
6775 req.Body = &bc
6776 _, err := cst.Client().Do(tt.req)
6777 if err == nil {
6778 t.Fatal("Expected an error")
6779 }
6780 if !bc {
6781 t.Fatal("Expected body to have been closed")
6782 }
6783 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6784 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6785 }
6786 })
6787 }
6788 }
6789
6790
6791
// breakableConn is a net.Conn whose writes can be made to fail on demand
// through the brokenState it shares with the test.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is the mutex-guarded "broken" flag shared by breakableConns.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with a synthetic error while the shared state is marked
// broken, and otherwise forwards to the underlying Conn. The lock is held
// for the whole underlying write.
func (bc *breakableConn) Write(b []byte) (n int, err error) {
	bc.Lock()
	defer bc.Unlock()
	if !bc.broken {
		return bc.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6810
6811
6812 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6813 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6814 }
6815 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6816 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6817
6818 var brokenState brokenState
6819
6820 const numReqs = 5
6821 var numDials, gotConns uint32
6822
6823 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6824 atomic.AddUint32(&numDials, 1)
6825 c, err := net.Dial(netw, addr)
6826 if err != nil {
6827 t.Errorf("unexpected Dial error: %v", err)
6828 return nil, err
6829 }
6830 return &breakableConn{c, &brokenState}, err
6831 }
6832
6833 for i := 1; i <= numReqs; i++ {
6834 brokenState.Lock()
6835 brokenState.broken = false
6836 brokenState.Unlock()
6837
6838
6839
6840
6841 doBreak := i != numReqs
6842
6843 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6844 GotConn: func(info httptrace.GotConnInfo) {
6845 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6846 atomic.AddUint32(&gotConns, 1)
6847 },
6848 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6849 brokenState.Lock()
6850 defer brokenState.Unlock()
6851 if doBreak {
6852 brokenState.broken = true
6853 }
6854 },
6855 })
6856 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6857 if err != nil {
6858 t.Fatal(err)
6859 }
6860 _, err = cst.c.Do(req)
6861 if doBreak != (err != nil) {
6862 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6863 }
6864 }
6865 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6866 t.Errorf("GotConn calls = %v; want %v", got, want)
6867 }
6868 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6869 t.Errorf("Dials = %v; want %v", got, want)
6870 }
6871 }
6872
6873
6874
6875
6876
6877 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6878 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6879 }
6880 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6881 CondSkipHTTP2(t)
6882
6883 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6884 _, err := w.Write([]byte("foo"))
6885 if err != nil {
6886 t.Fatalf("Write: %v", err)
6887 }
6888 })
6889
6890 ts := newClientServerTest(t, mode, h).ts
6891
6892 c := ts.Client()
6893 tr := c.Transport.(*Transport)
6894 tr.MaxConnsPerHost = 1
6895
6896 errCh := make(chan error, 300)
6897 doReq := func() {
6898 resp, err := c.Get(ts.URL)
6899 if err != nil {
6900 errCh <- fmt.Errorf("request failed: %v", err)
6901 return
6902 }
6903 defer resp.Body.Close()
6904 _, err = io.ReadAll(resp.Body)
6905 if err != nil {
6906 errCh <- fmt.Errorf("read body failed: %v", err)
6907 }
6908 }
6909
6910 var wg sync.WaitGroup
6911 for i := 0; i < 300; i++ {
6912 wg.Add(1)
6913 go func() {
6914 defer wg.Done()
6915 doReq()
6916 }()
6917 }
6918 wg.Wait()
6919 close(errCh)
6920
6921 for err := range errCh {
6922 t.Errorf("error occurred: %v", err)
6923 }
6924 }
6925
6926
6927
6928
6929 func TestAltProtoCancellation(t *testing.T) {
6930 defer afterTest(t)
6931 tr := &Transport{}
6932 c := &Client{
6933 Transport: tr,
6934 Timeout: time.Millisecond,
6935 }
6936 tr.RegisterProtocol("cancel", cancelProto{})
6937 _, err := c.Get("cancel://bar.com/path")
6938 if err == nil {
6939 t.Error("request unexpectedly succeeded")
6940 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6941 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6942 }
6943 }
6944
// errCancelProto is the error cancelProto reports once its request is
// canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is an alternate-protocol RoundTripper that never answers on
// its own. It blocks until the request's Cancel channel yields, then fails
// with errCancelProto.
type cancelProto struct{}

// RoundTrip waits on req.Cancel and then returns errCancelProto.
func (cancelProto) RoundTrip(req *Request) (res *Response, err error) {
	<-req.Cancel
	err = errCancelProto
	return res, err
}
6953
// roundTripFunc adapts a plain function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip calls f(r) and returns its results unchanged.
func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
	res, err := f(r)
	return res, err
}
6957
6958
6959 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6960 func testIssue32441(t *testing.T, mode testMode) {
6961 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6962 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6963 t.Error("body length is zero")
6964 }
6965 })).ts
6966 c := ts.Client()
6967 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6968
6969 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6970 t.Error("body length is zero during round trip")
6971 }
6972 return nil, ErrSkipAltProtocol
6973 }))
6974 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6975 t.Error(err)
6976 }
6977 }
6978
6979
6980
6981 func TestTransportRejectsSignInContentLength(t *testing.T) {
6982 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6983 }
6984 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6985 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6986 w.Header().Set("Content-Length", "+3")
6987 w.Write([]byte("abc"))
6988 })).ts
6989
6990 c := cst.Client()
6991 res, err := c.Get(cst.URL)
6992 if err == nil || res != nil {
6993 t.Fatal("Expected a non-nil error and a nil http.Response")
6994 }
6995 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6996 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6997 }
6998 }
6999
7000
// dumpConn is a net.Conn that writes to its embedded Writer and reads from
// its embedded Reader. Every other method is a no-op that reports success
// (or a nil address).
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error {
	return nil
}

func (*dumpConn) LocalAddr() net.Addr {
	return nil
}

func (*dumpConn) RemoteAddr() net.Addr {
	return nil
}

func (*dumpConn) SetDeadline(_ time.Time) error {
	return nil
}

func (*dumpConn) SetReadDeadline(_ time.Time) error {
	return nil
}

func (*dumpConn) SetWriteDeadline(_ time.Time) error {
	return nil
}
7012
7013
7014
// delegateReader is an io.Reader whose real source arrives later on channel
// c. The first Read blocks until a reader is received, and later Reads go
// straight to it. If c is closed before a reader arrives, Read fails with
// "delegate closed".
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r != nil {
		return d.r.Read(p)
	}
	next, ok := <-d.c
	if !ok {
		return 0, errors.New("delegate closed")
	}
	d.r = next
	return d.r.Read(p)
}
7029
// testTransportRace runs one RoundTrip of req over a fake connection
// (dumpConn). Client writes go into an in-memory pipe, and a helper goroutine
// plays the server. The RoundTrip result is deliberately ignored. Afterwards
// req.Body is set back to its original value.
// NOTE(review): that final mutation of req after RoundTrip returns is
// presumably the point: under -race it exposes transport goroutines that
// still touch the request.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	// The response source is supplied later by the helper goroutine.
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})

	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Drain the request body. io.Pipe writes block until they are
			// read, so this keeps the transport's body writes moving.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		// Hand the transport a canned 204 response to read...
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		// ...unless the main goroutine is already waiting on quitReadCh
		// (RoundTrip returned). Then close dr.c so any pending Read fails
		// with "delegate closed".
		case quitReadCh <- struct{}{}:

			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Closing the write side makes the helper goroutine's pipe reads return
	// io.EOF if it is still reading. Then wait for the helper to finish.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
7073
7074
7075
7076
7077
7078 func TestErrorWriteLoopRace(t *testing.T) {
7079 if testing.Short() {
7080 return
7081 }
7082 t.Parallel()
7083 for i := 0; i < 1000; i++ {
7084 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
7085 ctx, cancel := context.WithTimeout(context.Background(), delay)
7086 defer cancel()
7087
7088 r := bytes.NewBuffer(make([]byte, 10000))
7089 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
7090 if err != nil {
7091 t.Fatal(err)
7092 }
7093
7094 testTransportRace(req)
7095 }
7096 }
7097
7098
7099
7100
// TestCancelRequestWhenSharingConnection sets MaxConnsPerHost = 1, so two
// requests must share one connection. Request 1 completes, but its
// PutIdleConn trace hook is held open. Meanwhile request 2 is issued, reaches
// the handler, and is canceled. Request 2 must fail with context.Canceled,
// and everything must shut down once the hook is released.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation publishes a channel on reqc and waits for the
	// test to signal (or close) it before answering.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Publish a channel, then block in the hook until the test
				// closes it. putidlec is closed right after the send, so the
				// hook is expected to run only once.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Wait for request 1 to reach the handler, then let it respond.
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	// Wait until request 1's PutIdleConn hook is running (and blocked).
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Only now unblock request 1's PutIdleConn hook.
		close(idlec)
	}()

	// Request 2 has reached the handler; cancel it while the handler is
	// still blocked.
	r2c := <-reqc
	cancel()

	// Wait for request 2's goroutine to finish and close idlec.
	<-idlec

	// Let the second handler invocation return, then wait for both clients.
	close(r2c)
	wg.Wait()
}
7186
7187 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
7188 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
7189 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7190 go io.Copy(io.Discard, req.Body)
7191 panic(ErrAbortHandler)
7192 })).ts
7193
7194 var wg sync.WaitGroup
7195 for i := 0; i < 2; i++ {
7196 wg.Add(1)
7197 go func() {
7198 defer wg.Done()
7199 for j := 0; j < 10; j++ {
7200 const reqLen = 6 * 1024 * 1024
7201 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
7202 req.ContentLength = reqLen
7203 resp, _ := ts.Client().Transport.RoundTrip(req)
7204 if resp != nil {
7205 resp.Body.Close()
7206 }
7207 }
7208 }()
7209 }
7210 wg.Wait()
7211 }
7212
7213 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
7214 func testRequestSanitization(t *testing.T, mode testMode) {
7215 if mode == http2Mode {
7216
7217 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
7218 }
7219 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7220 if h, ok := req.Header["X-Evil"]; ok {
7221 t.Errorf("request has X-Evil header: %q", h)
7222 }
7223 })).ts
7224 req, _ := NewRequest("GET", ts.URL, nil)
7225 req.Host = "go.dev\r\nX-Evil:evil"
7226 resp, _ := ts.Client().Do(req)
7227 if resp != nil {
7228 resp.Body.Close()
7229 }
7230 }
7231
7232 func TestProxyAuthHeader(t *testing.T) {
7233
7234 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
7235 }
7236 func testProxyAuthHeader(t *testing.T, mode testMode) {
7237 const username = "u"
7238 const password = "@/?!"
7239 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7240
7241
7242 var r2 Request
7243 r2.Header = Header{
7244 "Authorization": req.Header["Proxy-Authorization"],
7245 }
7246 gotuser, gotpass, ok := r2.BasicAuth()
7247 if !ok || gotuser != username || gotpass != password {
7248 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
7249 }
7250 }))
7251 u, err := url.Parse(cst.ts.URL)
7252 if err != nil {
7253 t.Fatal(err)
7254 }
7255 u.User = url.UserPassword(username, password)
7256 t.Setenv("HTTP_PROXY", u.String())
7257 cst.tr.Proxy = ProxyURL(u)
7258 resp, err := cst.c.Get("http://_/")
7259 if err != nil {
7260 t.Fatal(err)
7261 }
7262 resp.Body.Close()
7263 }
7264
7265
// TestTransportReqCancelerCleanupOnRequestBodyWriteError sends a POST with a
// 1 GiB body to a hand-rolled server. The server reads only one byte of the
// request, writes a complete response, and later closes the connection.
// After the response body is closed, the test waits until
// tr.NumPendingRequestsForTesting() reports 0. Judging by the name, this
// guards against per-request cancellation state being left behind when the
// request-body write fails.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	// done tells the fake server when it may close the connection.
	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}
		// Wait until the client has started sending the request (one byte
		// is enough) before writing the response. Presumably this is so the
		// transport does not see the response as arriving on an idle
		// connection.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		// Send a complete 3-byte response without reading the rest of the
		// request body.
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// The hook sends on an unbuffered channel, so it blocks until the test
	// receives from didRead below. The hook name suggests it runs in the
	// transport's read loop before its next read.
	// NOTE(review): only one value is ever received, so this assumes the hook
	// fires at most once before the test ends. Confirm.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// The body (1<<30 bytes) is far larger than the one byte the server
	// reads, so writing it is expected to fail once the server closes the
	// connection.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	// Let the server close the connection (presumably while the client is
	// still writing the body).
	close(done)

	// Wait for the read-loop hook to fire before closing the response body.
	<-didRead

	resp.Body.Close()

	// Wait until the Transport reports no pending requests. waitCondition is
	// defined elsewhere; it is assumed to call the function repeatedly until
	// it returns true, passing the time waited so far as d. Pending counts
	// are logged only when d > 0.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7326
7327 func TestValidateClientRequestTrailers(t *testing.T) {
7328 run(t, testValidateClientRequestTrailers)
7329 }
7330
7331 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7332 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7333 rw.Write([]byte("Hello"))
7334 })).ts
7335
7336 cases := []struct {
7337 trailer Header
7338 wantErr string
7339 }{
7340 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7341 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7342 }
7343
7344 for i, tt := range cases {
7345 testName := fmt.Sprintf("%s%d", mode, i)
7346 t.Run(testName, func(t *testing.T) {
7347 req, err := NewRequest("GET", cst.URL, nil)
7348 if err != nil {
7349 t.Fatal(err)
7350 }
7351 req.Trailer = tt.trailer
7352 res, err := cst.Client().Do(req)
7353 if err == nil {
7354 t.Fatal("Expected an error")
7355 }
7356 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7357 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7358 }
7359 if res != nil {
7360 t.Fatal("Unexpected non-nil response")
7361 }
7362 })
7363 }
7364 }
7365
7366 func TestTransportServerProtocols(t *testing.T) {
7367 CondSkipHTTP2(t)
7368 DefaultTransport.(*Transport).CloseIdleConnections()
7369
7370 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7371 if err != nil {
7372 t.Fatal(err)
7373 }
7374 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7375 if err != nil {
7376 t.Fatal(err)
7377 }
7378 certpool := x509.NewCertPool()
7379 certpool.AddCert(leafCert)
7380
7381 for _, test := range []struct {
7382 name string
7383 scheme string
7384 setup func(t *testing.T)
7385 transport func(*Transport)
7386 server func(*Server)
7387 want string
7388 }{{
7389 name: "http default",
7390 scheme: "http",
7391 want: "HTTP/1.1",
7392 }, {
7393 name: "https default",
7394 scheme: "https",
7395 transport: func(tr *Transport) {
7396
7397 },
7398 want: "HTTP/1.1",
7399 }, {
7400 name: "https transport protocols include HTTP2",
7401 scheme: "https",
7402 transport: func(tr *Transport) {
7403
7404
7405 tr.Protocols = &Protocols{}
7406 tr.Protocols.SetHTTP1(true)
7407 tr.Protocols.SetHTTP2(true)
7408 },
7409 want: "HTTP/2.0",
7410 }, {
7411 name: "https transport protocols only include HTTP1",
7412 scheme: "https",
7413 transport: func(tr *Transport) {
7414
7415 tr.Protocols = &Protocols{}
7416 tr.Protocols.SetHTTP1(true)
7417 },
7418 want: "HTTP/1.1",
7419 }, {
7420 name: "https transport ForceAttemptHTTP2",
7421 scheme: "https",
7422 transport: func(tr *Transport) {
7423
7424 tr.ForceAttemptHTTP2 = true
7425 },
7426 want: "HTTP/2.0",
7427 }, {
7428 name: "https transport protocols override TLSNextProto",
7429 scheme: "https",
7430 transport: func(tr *Transport) {
7431
7432
7433
7434 tr.Protocols = &Protocols{}
7435 tr.Protocols.SetHTTP1(true)
7436 tr.Protocols.SetHTTP2(true)
7437 tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
7438 },
7439 want: "HTTP/2.0",
7440 }, {
7441 name: "https server disables HTTP2 with TLSNextProto",
7442 scheme: "https",
7443 server: func(srv *Server) {
7444
7445
7446 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7447 },
7448 want: "HTTP/1.1",
7449 }, {
7450 name: "https server Protocols overrides empty TLSNextProto",
7451 scheme: "https",
7452 server: func(srv *Server) {
7453
7454
7455 srv.Protocols = &Protocols{}
7456 srv.Protocols.SetHTTP1(true)
7457 srv.Protocols.SetHTTP2(true)
7458 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7459 },
7460 want: "HTTP/2.0",
7461 }, {
7462 name: "https server protocols only include HTTP1",
7463 scheme: "https",
7464 server: func(srv *Server) {
7465 srv.Protocols = &Protocols{}
7466 srv.Protocols.SetHTTP1(true)
7467 },
7468 want: "HTTP/1.1",
7469 }, {
7470 name: "https server protocols include HTTP2",
7471 scheme: "https",
7472 server: func(srv *Server) {
7473 srv.Protocols = &Protocols{}
7474 srv.Protocols.SetHTTP1(true)
7475 srv.Protocols.SetHTTP2(true)
7476 },
7477 want: "HTTP/2.0",
7478 }, {
7479 name: "GODEBUG disables HTTP2 client",
7480 scheme: "https",
7481 setup: func(t *testing.T) {
7482 t.Setenv("GODEBUG", "http2client=0")
7483 },
7484 transport: func(tr *Transport) {
7485
7486
7487 tr.Protocols = &Protocols{}
7488 tr.Protocols.SetHTTP1(true)
7489 tr.Protocols.SetHTTP2(true)
7490 },
7491 want: "HTTP/1.1",
7492 }, {
7493 name: "GODEBUG disables HTTP2 server",
7494 scheme: "https",
7495 setup: func(t *testing.T) {
7496 t.Setenv("GODEBUG", "http2server=0")
7497 },
7498 transport: func(tr *Transport) {
7499
7500
7501 tr.Protocols = &Protocols{}
7502 tr.Protocols.SetHTTP1(true)
7503 tr.Protocols.SetHTTP2(true)
7504 },
7505 want: "HTTP/1.1",
7506 }, {
7507 name: "unencrypted HTTP2 with prior knowledge",
7508 scheme: "http",
7509 transport: func(tr *Transport) {
7510 tr.Protocols = &Protocols{}
7511 tr.Protocols.SetUnencryptedHTTP2(true)
7512 },
7513 server: func(srv *Server) {
7514 srv.Protocols = &Protocols{}
7515 srv.Protocols.SetHTTP1(true)
7516 srv.Protocols.SetUnencryptedHTTP2(true)
7517 },
7518 want: "HTTP/2.0",
7519 }, {
7520 name: "unencrypted HTTP2 only on server",
7521 scheme: "http",
7522 transport: func(tr *Transport) {
7523 tr.Protocols = &Protocols{}
7524 tr.Protocols.SetUnencryptedHTTP2(true)
7525 },
7526 server: func(srv *Server) {
7527 srv.Protocols = &Protocols{}
7528 srv.Protocols.SetUnencryptedHTTP2(true)
7529 },
7530 want: "HTTP/2.0",
7531 }, {
7532 name: "unencrypted HTTP2 with no server support",
7533 scheme: "http",
7534 transport: func(tr *Transport) {
7535 tr.Protocols = &Protocols{}
7536 tr.Protocols.SetUnencryptedHTTP2(true)
7537 },
7538 server: func(srv *Server) {
7539 srv.Protocols = &Protocols{}
7540 srv.Protocols.SetHTTP1(true)
7541 },
7542 want: "error",
7543 }, {
7544 name: "HTTP1 with no server support",
7545 scheme: "http",
7546 transport: func(tr *Transport) {
7547 tr.Protocols = &Protocols{}
7548 tr.Protocols.SetHTTP1(true)
7549 },
7550 server: func(srv *Server) {
7551 srv.Protocols = &Protocols{}
7552 srv.Protocols.SetUnencryptedHTTP2(true)
7553 },
7554 want: "error",
7555 }, {
7556 name: "HTTPS1 with no server support",
7557 scheme: "https",
7558 transport: func(tr *Transport) {
7559 tr.Protocols = &Protocols{}
7560 tr.Protocols.SetHTTP1(true)
7561 },
7562 server: func(srv *Server) {
7563 srv.Protocols = &Protocols{}
7564 srv.Protocols.SetHTTP2(true)
7565 },
7566 want: "error",
7567 }} {
7568 t.Run(test.name, func(t *testing.T) {
7569
7570
7571 srv := &Server{
7572 TLSConfig: &tls.Config{
7573 Certificates: []tls.Certificate{cert},
7574 },
7575 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
7576 w.Header().Set("X-Proto", req.Proto)
7577 }),
7578 }
7579 tr := &Transport{
7580 TLSClientConfig: &tls.Config{
7581 RootCAs: certpool,
7582 },
7583 }
7584
7585 if test.setup != nil {
7586 test.setup(t)
7587 }
7588 if test.server != nil {
7589 test.server(srv)
7590 }
7591 if test.transport != nil {
7592 test.transport(tr)
7593 } else {
7594 tr.Protocols = &Protocols{}
7595 tr.Protocols.SetHTTP1(true)
7596 tr.Protocols.SetHTTP2(true)
7597 }
7598
7599 listener := newLocalListener(t)
7600 srvc := make(chan error, 1)
7601 go func() {
7602 switch test.scheme {
7603 case "http":
7604 srvc <- srv.Serve(listener)
7605 case "https":
7606 srvc <- srv.ServeTLS(listener, "", "")
7607 }
7608 }()
7609 t.Cleanup(func() {
7610 srv.Close()
7611 <-srvc
7612 })
7613
7614 client := &Client{Transport: tr}
7615 resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
7616 if err != nil {
7617 if test.want == "error" {
7618 return
7619 }
7620 t.Fatal(err)
7621 }
7622 if got := resp.Header.Get("X-Proto"); got != test.want {
7623 t.Fatalf("request proto %q, want %q", got, test.want)
7624 }
7625 })
7626 }
7627 }
7628
View as plain text