Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 "net/http"
30 . "net/http"
31 "net/http/httptest"
32 "net/http/httptrace"
33 "net/http/httputil"
34 "net/http/internal/testcert"
35 "net/textproto"
36 "net/url"
37 "os"
38 "reflect"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "testing"
46 "testing/iotest"
47 "testing/synctest"
48 "time"
49
50 "golang.org/x/net/http/httpguts"
51 )
52
53
54
55
56
57 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
58 if r.FormValue("close") == "true" {
59 w.Header().Set("Connection", "close")
60 }
61 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
62 w.Write([]byte(r.RemoteAddr))
63
64
65
66 if c, ok := ResponseWriterConnForTesting(w); ok {
67 fmt.Fprintf(w, ", %T %p", c, c)
68 }
69 })
70
71
72 type testCloseConn struct {
73 net.Conn
74 set *testConnSet
75 }
76
77 func (c *testCloseConn) Close() error {
78 c.set.remove(c)
79 return c.Conn.Close()
80 }
81
82
83
// testConnSet records every connection handed out by a test dialer, in
// dial order, together with whether each one has since been closed.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex        // guards closed and list
	closed map[net.Conn]bool // true once the conn has been closed
	list   []net.Conn        // all conns, in insertion order
}

// insert registers c as a new, still-open connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. The connection stays in list so that check
// can still report on it.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}
103
104
105 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
106 connSet := &testConnSet{
107 t: t,
108 closed: make(map[net.Conn]bool),
109 }
110 dial := func(n, addr string) (net.Conn, error) {
111 c, err := net.Dial(n, addr)
112 if err != nil {
113 return nil, err
114 }
115 tc := &testCloseConn{c, connSet}
116 connSet.insert(tc)
117 return tc, nil
118 }
119 return connSet, dial
120 }
121
122 func (tcs *testConnSet) check(t *testing.T) {
123 tcs.mu.Lock()
124 defer tcs.mu.Unlock()
125 for i := 4; i >= 0; i-- {
126 for i, c := range tcs.list {
127 if tcs.closed[c] {
128 continue
129 }
130 if i != 0 {
131
132
133 tcs.mu.Unlock()
134 time.Sleep(50 * time.Millisecond)
135 tcs.mu.Lock()
136 continue
137 }
138 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
139 }
140 }
141 }
142
143 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
144 func testReuseRequest(t *testing.T, mode testMode) {
145 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
146 w.Write([]byte("{}"))
147 })).ts
148
149 c := ts.Client()
150 req, _ := NewRequest("GET", ts.URL, nil)
151 res, err := c.Do(req)
152 if err != nil {
153 t.Fatal(err)
154 }
155 err = res.Body.Close()
156 if err != nil {
157 t.Fatal(err)
158 }
159
160 res, err = c.Do(req)
161 if err != nil {
162 t.Fatal(err)
163 }
164 err = res.Body.Close()
165 if err != nil {
166 t.Fatal(err)
167 }
168 }
169
170
171
172 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
173 func testTransportKeepAlives(t *testing.T, mode testMode) {
174 ts := newClientServerTest(t, mode, hostPortHandler).ts
175
176 c := ts.Client()
177 for _, disableKeepAlive := range []bool{false, true} {
178 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
179 fetch := func(n int) string {
180 res, err := c.Get(ts.URL)
181 if err != nil {
182 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
183 }
184 body, err := io.ReadAll(res.Body)
185 if err != nil {
186 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
187 }
188 return string(body)
189 }
190
191 body1 := fetch(1)
192 body2 := fetch(2)
193
194 bodiesDiffer := body1 != body2
195 if bodiesDiffer != disableKeepAlive {
196 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
197 disableKeepAlive, bodiesDiffer, body1, body2)
198 }
199 }
200 }
201
202 func TestTransportConnectionCloseOnResponse(t *testing.T) {
203 run(t, testTransportConnectionCloseOnResponse, http3SkippedMode)
204 }
205 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
206 ts := newClientServerTest(t, mode, hostPortHandler).ts
207
208 connSet, testDial := makeTestDial(t)
209
210 c := ts.Client()
211 tr := c.Transport.(*Transport)
212 tr.Dial = testDial
213
214 for _, connectionClose := range []bool{false, true} {
215 fetch := func(n int) string {
216 req := new(Request)
217 var err error
218 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
219 if err != nil {
220 t.Fatalf("URL parse error: %v", err)
221 }
222 req.Method = "GET"
223 req.Proto = "HTTP/1.1"
224 req.ProtoMajor = 1
225 req.ProtoMinor = 1
226
227 res, err := c.Do(req)
228 if err != nil {
229 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
230 }
231 defer res.Body.Close()
232 body, err := io.ReadAll(res.Body)
233 if err != nil {
234 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
235 }
236 return string(body)
237 }
238
239 body1 := fetch(1)
240 body2 := fetch(2)
241 bodiesDiffer := body1 != body2
242 if bodiesDiffer != connectionClose {
243 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
244 connectionClose, bodiesDiffer, body1, body2)
245 }
246
247 tr.CloseIdleConnections()
248 }
249
250 connSet.check(t)
251 }
252
253
254
255
256
257
258
259 func TestTransportConnectionCloseOnRequest(t *testing.T) {
260 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
261 }
262 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
263 ts := newClientServerTest(t, mode, hostPortHandler).ts
264
265 connSet, testDial := makeTestDial(t)
266
267 c := ts.Client()
268 tr := c.Transport.(*Transport)
269 tr.Dial = testDial
270 for _, reqClose := range []bool{false, true} {
271 fetch := func(n int) string {
272 req := new(Request)
273 var err error
274 req.URL, err = url.Parse(ts.URL)
275 if err != nil {
276 t.Fatalf("URL parse error: %v", err)
277 }
278 req.Method = "GET"
279 req.Proto = "HTTP/1.1"
280 req.ProtoMajor = 1
281 req.ProtoMinor = 1
282 req.Close = reqClose
283
284 res, err := c.Do(req)
285 if err != nil {
286 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
287 }
288 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
289 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
290 reqClose, got, !reqClose)
291 }
292 body, err := io.ReadAll(res.Body)
293 if err != nil {
294 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
295 }
296 return string(body)
297 }
298
299 body1 := fetch(1)
300 body2 := fetch(2)
301
302 got := 1
303 if body1 != body2 {
304 got++
305 }
306 want := 1
307 if reqClose {
308 want = 2
309 }
310 if got != want {
311 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
312 reqClose, got, want, body1, body2)
313 }
314
315 tr.CloseIdleConnections()
316 }
317
318 connSet.check(t)
319 }
320
321
322
323
324 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
325 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
326 }
327 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
328 ts := newClientServerTest(t, mode, hostPortHandler).ts
329
330 c := ts.Client()
331 c.Transport.(*Transport).DisableKeepAlives = true
332
333 res, err := c.Get(ts.URL)
334 if err != nil {
335 t.Fatal(err)
336 }
337 res.Body.Close()
338 if res.Header.Get("X-Saw-Close") != "true" {
339 t.Errorf("handler didn't see Connection: close ")
340 }
341 }
342
343
344
345 func TestTransportRespectRequestWantsClose(t *testing.T) {
346 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
347 }
348 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
349 tests := []struct {
350 disableKeepAlives bool
351 close bool
352 }{
353 {disableKeepAlives: false, close: false},
354 {disableKeepAlives: false, close: true},
355 {disableKeepAlives: true, close: false},
356 {disableKeepAlives: true, close: true},
357 }
358
359 for _, tc := range tests {
360 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
361 func(t *testing.T) {
362 ts := newClientServerTest(t, mode, hostPortHandler).ts
363
364 c := ts.Client()
365 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
366 req, err := NewRequest("GET", ts.URL, nil)
367 if err != nil {
368 t.Fatal(err)
369 }
370 count := 0
371 trace := &httptrace.ClientTrace{
372 WroteHeaderField: func(key string, field []string) {
373 if key != "Connection" {
374 return
375 }
376 if httpguts.HeaderValuesContainsToken(field, "close") {
377 count += 1
378 }
379 },
380 }
381 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
382 req.Close = tc.close
383 res, err := c.Do(req)
384 if err != nil {
385 t.Fatal(err)
386 }
387 defer res.Body.Close()
388 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
389 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
390 }
391 })
392 }
393
394 }
395
396 func TestTransportIdleCacheKeys(t *testing.T) {
397 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
398 }
399 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
400 ts := newClientServerTest(t, mode, hostPortHandler).ts
401 c := ts.Client()
402 tr := c.Transport.(*Transport)
403
404 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
405 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
406 }
407
408 resp, err := c.Get(ts.URL)
409 if err != nil {
410 t.Error(err)
411 }
412 io.ReadAll(resp.Body)
413
414 keys := tr.IdleConnKeysForTesting()
415 if e, g := 1, len(keys); e != g {
416 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
417 }
418
419 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
420 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
421 }
422
423 tr.CloseIdleConnections()
424 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
425 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
426 }
427 }
428
429
430
431 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
432 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
433 const msg = "foobar"
434 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
435 w.Header().Set("Remote-Addr", r.RemoteAddr)
436 if r.URL.Path == "/chunked/" {
437 w.WriteHeader(200)
438 w.(Flusher).Flush()
439 } else {
440 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
441 w.WriteHeader(200)
442 }
443 w.Write([]byte(msg))
444 })).ts
445
446 for pi, path := range []string{"/content-length/", "/chunked/"} {
447 wantLen := []int{len(msg), -1}[pi]
448 addrSeen := make(map[string]int)
449 for range 3 {
450 res, err := ts.Client().Get(ts.URL + path)
451 if err != nil {
452 t.Errorf("Get %s: %v", path, err)
453 continue
454 }
455 defer res.Body.Close()
456
457 if res.ContentLength != int64(wantLen) {
458 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
459 }
460 got, err := io.ReadAll(res.Body)
461 if string(got) != msg || err != nil {
462 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
463 }
464 addrSeen[res.Header.Get("Remote-Addr")]++
465 }
466 if len(addrSeen) != 1 {
467 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
468 }
469 }
470 }
471
472
473
474
475
476 func TestTransportNotReadToEndConnectionReuse(t *testing.T) {
477 run(t, testTransportNotReadToEndConnectionReuse, []testMode{http1Mode, https1Mode})
478 }
479 func testTransportNotReadToEndConnectionReuse(t *testing.T, mode testMode) {
480 tests := []struct {
481 name string
482 bodyLen int
483 contentLenKnown bool
484 headRequest bool
485 timeBetweenReqs time.Duration
486 responseTime time.Duration
487 wantReuse bool
488 }{
489 {
490 name: "unconsumed body within drain limit",
491 bodyLen: 200 * 1024,
492 timeBetweenReqs: http.MaxPostCloseReadTime,
493 wantReuse: true,
494 },
495 {
496 name: "unconsumed body within drain limit with known length",
497 bodyLen: 200 * 1024,
498 contentLenKnown: true,
499 timeBetweenReqs: http.MaxPostCloseReadTime,
500 wantReuse: true,
501 },
502 {
503 name: "unconsumed body larger than drain limit",
504 bodyLen: 500 * 1024,
505 timeBetweenReqs: http.MaxPostCloseReadTime,
506 wantReuse: false,
507 },
508 {
509 name: "unconsumed body larger than drain limit with known length",
510 bodyLen: 500 * 1024,
511 contentLenKnown: true,
512 timeBetweenReqs: http.MaxPostCloseReadTime,
513 wantReuse: false,
514 },
515 {
516 name: "new requests start before drain for old requests are finished",
517 bodyLen: 200 * 1024,
518 timeBetweenReqs: 0,
519 responseTime: time.Minute,
520 wantReuse: false,
521 },
522 {
523
524
525 name: "unconsumed body larger than drain limit for HEAD request",
526 bodyLen: 500 * 1024,
527 headRequest: true,
528 wantReuse: true,
529 },
530 }
531
532 for _, tc := range tests {
533 subtest := func(t *testing.T) {
534 addrSeen := make(map[string]int)
535 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
536 addrSeen[r.RemoteAddr]++
537 time.Sleep(tc.responseTime)
538 if tc.contentLenKnown {
539 w.Header().Add("Content-Length", strconv.Itoa(tc.bodyLen))
540 }
541 w.Write(slices.Repeat([]byte("a"), tc.bodyLen))
542 }), optFakeNet).ts
543
544 var wg sync.WaitGroup
545 for range 10 {
546 wg.Go(func() {
547 method := http.MethodGet
548 if tc.headRequest {
549 method = http.MethodHead
550 }
551 ctx, cancel := context.WithCancel(context.Background())
552 req, err := http.NewRequestWithContext(ctx, method, ts.URL, nil)
553 if err != nil {
554 log.Fatal(err)
555 }
556 resp, err := ts.Client().Do(req)
557 if err != nil {
558 t.Fatal(err)
559 }
560 if resp.StatusCode != http.StatusOK {
561 t.Errorf("expected HTTP 200, got: %v", resp.StatusCode)
562 }
563 resp.Body.Close()
564
565
566 cancel()
567 if n, err := resp.Body.Read([]byte{}); n != 0 || err == nil {
568 t.Errorf("read after body has been closed should not succeed, but read %v byte with %v error", n, err)
569 }
570 })
571 time.Sleep(tc.timeBetweenReqs)
572 synctest.Wait()
573 }
574 wg.Wait()
575 if (len(addrSeen) == 1) != tc.wantReuse {
576 t.Errorf("want connection reuse to be %v, but %v connections were created", tc.wantReuse, len(addrSeen))
577 }
578 }
579 t.Run(tc.name, func(t *testing.T) {
580 synctest.Test(t, subtest)
581 })
582 }
583 }
584
585 func TestTransportMaxPerHostIdleConns(t *testing.T) {
586 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
587 }
588 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
589 stop := make(chan struct{})
590 defer close(stop)
591
592 resch := make(chan string)
593 gotReq := make(chan bool)
594 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
595 gotReq <- true
596 var msg string
597 select {
598 case <-stop:
599 return
600 case msg = <-resch:
601 }
602 _, err := w.Write([]byte(msg))
603 if err != nil {
604 t.Errorf("Write: %v", err)
605 return
606 }
607 })).ts
608
609 c := ts.Client()
610 tr := c.Transport.(*Transport)
611 maxIdleConnsPerHost := 2
612 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
613
614
615
616 donech := make(chan bool)
617 doReq := func() {
618 defer func() {
619 select {
620 case <-stop:
621 return
622 case donech <- t.Failed():
623 }
624 }()
625 resp, err := c.Get(ts.URL)
626 if err != nil {
627 t.Error(err)
628 return
629 }
630 if _, err := io.ReadAll(resp.Body); err != nil {
631 t.Errorf("ReadAll: %v", err)
632 return
633 }
634 }
635 go doReq()
636 <-gotReq
637 go doReq()
638 <-gotReq
639 go doReq()
640 <-gotReq
641
642 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
643 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
644 }
645
646 resch <- "res1"
647 <-donech
648 keys := tr.IdleConnKeysForTesting()
649 if e, g := 1, len(keys); e != g {
650 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
651 }
652 addr := ts.Listener.Addr().String()
653 cacheKey := "|http|" + addr
654 if keys[0] != cacheKey {
655 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
656 }
657 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
658 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
659 }
660
661 resch <- "res2"
662 <-donech
663 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
664 t.Errorf("after second response, idle conns = %d; want %d", g, w)
665 }
666
667 resch <- "res3"
668 <-donech
669 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
670 t.Errorf("after third response, idle conns = %d; want %d", g, w)
671 }
672 }
673
674 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
675 run(t, testTransportMaxConnsPerHostIncludeDialInProgress, http3SkippedMode)
676 }
677 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
678 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
679 _, err := w.Write([]byte("foo"))
680 if err != nil {
681 t.Fatalf("Write: %v", err)
682 }
683 })).ts
684 c := ts.Client()
685 tr := c.Transport.(*Transport)
686 dialStarted := make(chan struct{})
687 stallDial := make(chan struct{})
688 tr.Dial = func(network, addr string) (net.Conn, error) {
689 dialStarted <- struct{}{}
690 <-stallDial
691 return net.Dial(network, addr)
692 }
693
694 tr.DisableKeepAlives = true
695 tr.MaxConnsPerHost = 1
696
697 preDial := make(chan struct{})
698 reqComplete := make(chan struct{})
699 doReq := func(reqId string) {
700 req, _ := NewRequest("GET", ts.URL, nil)
701 trace := &httptrace.ClientTrace{
702 GetConn: func(hostPort string) {
703 preDial <- struct{}{}
704 },
705 }
706 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
707 resp, err := tr.RoundTrip(req)
708 if err != nil {
709 t.Errorf("unexpected error for request %s: %v", reqId, err)
710 }
711 _, err = io.ReadAll(resp.Body)
712 if err != nil {
713 t.Errorf("unexpected error for request %s: %v", reqId, err)
714 }
715 reqComplete <- struct{}{}
716 }
717
718 go doReq("req1")
719 <-preDial
720 <-dialStarted
721
722
723 go doReq("req2")
724 <-preDial
725 select {
726 case <-dialStarted:
727 t.Error("req2 dial started while req1 dial in progress")
728 return
729 default:
730 }
731
732
733 stallDial <- struct{}{}
734 <-reqComplete
735
736
737 <-dialStarted
738 stallDial <- struct{}{}
739 <-reqComplete
740 }
741
742 func TestTransportMaxConnsPerHost(t *testing.T) {
743 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
744 }
745 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
746 CondSkipHTTP2(t)
747
748 h := HandlerFunc(func(w ResponseWriter, r *Request) {
749 _, err := w.Write([]byte("foo"))
750 if err != nil {
751 t.Fatalf("Write: %v", err)
752 }
753 })
754
755 ts := newClientServerTest(t, mode, h).ts
756 c := ts.Client()
757 tr := c.Transport.(*Transport)
758 tr.MaxConnsPerHost = 1
759
760 mu := sync.Mutex{}
761 var conns []net.Conn
762 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
763 tr.Dial = func(network, addr string) (net.Conn, error) {
764 atomic.AddInt32(&dialCnt, 1)
765 c, err := net.Dial(network, addr)
766 mu.Lock()
767 defer mu.Unlock()
768 conns = append(conns, c)
769 return c, err
770 }
771
772 doReq := func() {
773 trace := &httptrace.ClientTrace{
774 GotConn: func(connInfo httptrace.GotConnInfo) {
775 if !connInfo.Reused {
776 atomic.AddInt32(&gotConnCnt, 1)
777 }
778 },
779 TLSHandshakeStart: func() {
780 atomic.AddInt32(&tlsHandshakeCnt, 1)
781 },
782 }
783 req, _ := NewRequest("GET", ts.URL, nil)
784 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
785
786 resp, err := c.Do(req)
787 if err != nil {
788 t.Fatalf("request failed: %v", err)
789 }
790 defer resp.Body.Close()
791 _, err = io.ReadAll(resp.Body)
792 if err != nil {
793 t.Fatalf("read body failed: %v", err)
794 }
795 }
796
797 wg := sync.WaitGroup{}
798 for i := 0; i < 10; i++ {
799 wg.Add(1)
800 go func() {
801 defer wg.Done()
802 doReq()
803 }()
804 }
805 wg.Wait()
806
807 expected := int32(tr.MaxConnsPerHost)
808 if dialCnt != expected {
809 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
810 }
811 if gotConnCnt != expected {
812 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
813 }
814 if ts.TLS != nil && tlsHandshakeCnt != expected {
815 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
816 }
817
818 if t.Failed() {
819 t.FailNow()
820 }
821
822 mu.Lock()
823 for _, c := range conns {
824 c.Close()
825 }
826 conns = nil
827 mu.Unlock()
828 tr.CloseIdleConnections()
829
830 doReq()
831 expected++
832 if dialCnt != expected {
833 t.Errorf("round 2: too many dials: %d", dialCnt)
834 }
835 if gotConnCnt != expected {
836 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
837 }
838 if ts.TLS != nil && tlsHandshakeCnt != expected {
839 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
840 }
841 }
842
843 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
844 run(t, testTransportMaxConnsPerHostDialCancellation,
845 testNotParallel,
846 []testMode{http1Mode, https1Mode, http2Mode},
847 )
848 }
849
850 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
851 CondSkipHTTP2(t)
852
853 h := HandlerFunc(func(w ResponseWriter, r *Request) {
854 _, err := w.Write([]byte("foo"))
855 if err != nil {
856 t.Fatalf("Write: %v", err)
857 }
858 })
859
860 cst := newClientServerTest(t, mode, h)
861 defer cst.close()
862 ts := cst.ts
863 c := ts.Client()
864 tr := c.Transport.(*Transport)
865 tr.MaxConnsPerHost = 1
866
867
868 ctx, cancel := context.WithCancel(context.Background())
869 defer cancel()
870 SetPendingDialHooks(cancel, nil)
871 defer SetPendingDialHooks(nil, nil)
872
873 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
874 _, err := c.Do(req)
875 if !errors.Is(err, context.Canceled) {
876 t.Errorf("expected error %v, got %v", context.Canceled, err)
877 }
878
879
880 SetPendingDialHooks(nil, nil)
881 req, _ = NewRequest("GET", ts.URL, nil)
882 resp, err := c.Do(req)
883 if err != nil {
884 t.Fatalf("request failed: %v", err)
885 }
886 defer resp.Body.Close()
887 _, err = io.ReadAll(resp.Body)
888 if err != nil {
889 t.Fatalf("read body failed: %v", err)
890 }
891 }
892
893 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
894 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
895 }
896 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
897 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
898 io.WriteString(w, r.RemoteAddr)
899 })).ts
900
901 c := ts.Client()
902 tr := c.Transport.(*Transport)
903
904 doReq := func(name string) {
905
906
907 res, err := c.Post(ts.URL, "", nil)
908 if err != nil {
909 t.Fatalf("%s: %v", name, err)
910 }
911 if res.StatusCode != 200 {
912 t.Fatalf("%s: %v", name, res.Status)
913 }
914 defer res.Body.Close()
915 slurp, err := io.ReadAll(res.Body)
916 if err != nil {
917 t.Fatalf("%s: %v", name, err)
918 }
919 t.Logf("%s: ok (%q)", name, slurp)
920 }
921
922 doReq("first")
923 keys1 := tr.IdleConnKeysForTesting()
924
925 ts.CloseClientConnections()
926
927 var keys2 []string
928 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
929 keys2 = tr.IdleConnKeysForTesting()
930 if len(keys2) != 0 {
931 if d > 0 {
932 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
933 }
934 return false
935 }
936 return true
937 })
938
939 doReq("second")
940 }
941
942
943
944 func TestTransportServerClosingUnexpectedly(t *testing.T) {
945 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
946 }
947 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
948 ts := newClientServerTest(t, mode, hostPortHandler).ts
949 c := ts.Client()
950
951 fetch := func(n, retries int) string {
952 condFatalf := func(format string, arg ...any) {
953 if retries <= 0 {
954 t.Fatalf(format, arg...)
955 }
956 t.Logf("retrying shortly after expected error: "+format, arg...)
957 time.Sleep(time.Second / time.Duration(retries))
958 }
959 for retries >= 0 {
960 retries--
961 res, err := c.Get(ts.URL)
962 if err != nil {
963 condFatalf("error in req #%d, GET: %v", n, err)
964 continue
965 }
966 body, err := io.ReadAll(res.Body)
967 if err != nil {
968 condFatalf("error in req #%d, ReadAll: %v", n, err)
969 continue
970 }
971 res.Body.Close()
972 return string(body)
973 }
974 panic("unreachable")
975 }
976
977 body1 := fetch(1, 0)
978 body2 := fetch(2, 0)
979
980
981
982
983
984
985
986
987 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
988
989 body3 := fetch(3, 5)
990
991 if body1 != body2 {
992 t.Errorf("expected body1 and body2 to be equal")
993 }
994 if body2 == body3 {
995 t.Errorf("expected body2 and body3 to be different")
996 }
997 }
998
999
1000
1001 func TestStressSurpriseServerCloses(t *testing.T) {
1002 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
1003 }
1004 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
1005 if testing.Short() {
1006 t.Skip("skipping test in short mode")
1007 }
1008 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1009 w.Header().Set("Content-Length", "5")
1010 w.Header().Set("Content-Type", "text/plain")
1011 w.Write([]byte("Hello"))
1012 w.(Flusher).Flush()
1013 conn, buf, _ := w.(Hijacker).Hijack()
1014 buf.Flush()
1015 conn.Close()
1016 })).ts
1017 c := ts.Client()
1018
1019
1020
1021
1022
1023
1024
1025 const (
1026 numClients = 20
1027 reqsPerClient = 25
1028 )
1029 var wg sync.WaitGroup
1030 wg.Add(numClients * reqsPerClient)
1031 for i := 0; i < numClients; i++ {
1032 go func() {
1033 for i := 0; i < reqsPerClient; i++ {
1034 res, err := c.Get(ts.URL)
1035 if err == nil {
1036
1037
1038
1039
1040
1041
1042 res.Body.Close()
1043 }
1044 wg.Done()
1045 }
1046 }()
1047 }
1048
1049
1050 wg.Wait()
1051 }
1052
1053
1054
1055 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
1056 func testTransportHeadResponses(t *testing.T, mode testMode) {
1057 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1058 if r.Method != "HEAD" {
1059 panic("expected HEAD; got " + r.Method)
1060 }
1061 w.Header().Set("Content-Length", "123")
1062 w.WriteHeader(200)
1063 })).ts
1064 c := ts.Client()
1065
1066 for i := 0; i < 2; i++ {
1067 res, err := c.Head(ts.URL)
1068 if err != nil {
1069 t.Errorf("error on loop %d: %v", i, err)
1070 continue
1071 }
1072 if e, g := "123", res.Header.Get("Content-Length"); e != g {
1073 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
1074 }
1075 if e, g := int64(123), res.ContentLength; e != g {
1076 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
1077 }
1078 if all, err := io.ReadAll(res.Body); err != nil {
1079 t.Errorf("loop %d: Body ReadAll: %v", i, err)
1080 } else if len(all) != 0 {
1081 t.Errorf("Bogus body %q", all)
1082 }
1083 }
1084 }
1085
1086
1087
1088 func TestTransportHeadChunkedResponse(t *testing.T) {
1089 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
1090 }
1091 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
1092 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1093 if r.Method != "HEAD" {
1094 panic("expected HEAD; got " + r.Method)
1095 }
1096 w.Header().Set("Transfer-Encoding", "chunked")
1097 w.Header().Set("x-client-ipport", r.RemoteAddr)
1098 w.WriteHeader(200)
1099 })).ts
1100 c := ts.Client()
1101
1102
1103
1104 didRead := make(chan bool)
1105 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
1106 defer SetReadLoopBeforeNextReadHook(nil)
1107
1108 res1, err := c.Head(ts.URL)
1109 <-didRead
1110
1111 if err != nil {
1112 t.Fatalf("request 1 error: %v", err)
1113 }
1114
1115 res2, err := c.Head(ts.URL)
1116 <-didRead
1117
1118 if err != nil {
1119 t.Fatalf("request 2 error: %v", err)
1120 }
1121 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1122 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1123 }
1124 }
1125
// roundTripTests is the case table for testRoundTripGzip.
//   - accept: the Accept-Encoding header the caller sets ("" means none).
//   - expectAccept: the value the server should receive.
//   - compressed: whether the body reaches the caller still gzip-encoded.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No caller-supplied header: the Transport requests gzip itself and
	// returns a decompressed body.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other caller-supplied encoding is passed through unchanged.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// A caller that explicitly asks for gzip receives the compressed body.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1138
1139
1140 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip, http3SkippedMode) }
1141 func testRoundTripGzip(t *testing.T, mode testMode) {
1142 const responseBody = "test response body"
1143 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1144 accept := req.Header.Get("Accept-Encoding")
1145 if expect := req.FormValue("expect_accept"); accept != expect {
1146 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1147 req.FormValue("testnum"), accept, expect)
1148 }
1149 if accept == "gzip" {
1150 rw.Header().Set("Content-Encoding", "gzip")
1151 gz := gzip.NewWriter(rw)
1152 gz.Write([]byte(responseBody))
1153 gz.Close()
1154 } else {
1155 rw.Header().Set("Content-Encoding", accept)
1156 rw.Write([]byte(responseBody))
1157 }
1158 })).ts
1159 tr := ts.Client().Transport.(*Transport)
1160
1161 for i, test := range roundTripTests {
1162
1163 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1164 if test.accept != "" {
1165 req.Header.Set("Accept-Encoding", test.accept)
1166 }
1167 res, err := tr.RoundTrip(req)
1168 if err != nil {
1169 t.Errorf("%d. RoundTrip: %v", i, err)
1170 continue
1171 }
1172 var body []byte
1173 if test.compressed {
1174 var r *gzip.Reader
1175 r, err = gzip.NewReader(res.Body)
1176 if err != nil {
1177 t.Errorf("%d. gzip NewReader: %v", i, err)
1178 continue
1179 }
1180 body, err = io.ReadAll(r)
1181 res.Body.Close()
1182 } else {
1183 body, err = io.ReadAll(res.Body)
1184 }
1185 if err != nil {
1186 t.Errorf("%d. Error: %q", i, err)
1187 continue
1188 }
1189 if g, e := string(body), responseBody; g != e {
1190 t.Errorf("%d. body = %q; want %q", i, g, e)
1191 }
1192 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1193 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1194 }
1195 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1196 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1197 }
1198 }
1199
1200 }
1201
1202 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip, http3SkippedMode) }
1203 func testTransportGzip(t *testing.T, mode testMode) {
1204 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1205 const nRandBytes = 1024 * 1024
1206 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1207 if req.Method == "HEAD" {
1208 if g := req.Header.Get("Accept-Encoding"); g != "" {
1209 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1210 }
1211 return
1212 }
1213 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1214 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1215 }
1216 rw.Header().Set("Content-Encoding", "gzip")
1217
1218 var w io.Writer = rw
1219 var buf bytes.Buffer
1220 if req.FormValue("chunked") == "0" {
1221 w = &buf
1222 defer io.Copy(rw, &buf)
1223 defer func() {
1224 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1225 }()
1226 }
1227 gz := gzip.NewWriter(w)
1228 gz.Write([]byte(testString))
1229 if req.FormValue("body") == "large" {
1230 io.CopyN(gz, rand.Reader, nRandBytes)
1231 }
1232 gz.Close()
1233 })).ts
1234 c := ts.Client()
1235
1236 for _, chunked := range []string{"1", "0"} {
1237
1238 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1239 if err != nil {
1240 t.Fatalf("large get: %v", err)
1241 }
1242 buf := make([]byte, len(testString))
1243 n, err := io.ReadFull(res.Body, buf)
1244 if err != nil {
1245 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1246 }
1247 if e, g := testString, string(buf); e != g {
1248 t.Errorf("partial read got %q, expected %q", g, e)
1249 }
1250 res.Body.Close()
1251
1252 n, err = res.Body.Read(buf)
1253 if n != 0 || err == nil {
1254 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1255 }
1256
1257
1258 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1259 if err != nil {
1260 t.Fatal(err)
1261 }
1262 body, err := io.ReadAll(res.Body)
1263 if err != nil {
1264 t.Fatal(err)
1265 }
1266 if g, e := string(body), testString; g != e {
1267 t.Fatalf("body = %q; want %q", g, e)
1268 }
1269 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1270 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1271 }
1272
1273
1274 n, err = res.Body.Read(buf)
1275 if n != 0 || err == nil {
1276 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1277 }
1278 res.Body.Close()
1279 n, err = res.Body.Read(buf)
1280 if n != 0 || err == nil {
1281 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1282 }
1283 }
1284
1285
1286 res, err := c.Head(ts.URL)
1287 if err != nil {
1288 t.Fatalf("Head: %v", err)
1289 }
1290 if res.StatusCode != 200 {
1291 t.Errorf("Head status=%d; want=200", res.StatusCode)
1292 }
1293 }
1294
1295
1296
1297 type transport100ContinueTest struct {
1298 t *testing.T
1299
1300 reqdone chan struct{}
1301 resp *Response
1302 respErr error
1303
1304 conn net.Conn
1305 reader *bufio.Reader
1306 }
1307
1308 const transport100ContinueTestBody = "request body"
1309
1310
1311
1312 func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
1313 ln := newLocalListener(t)
1314 defer ln.Close()
1315
1316 test := &transport100ContinueTest{
1317 t: t,
1318 reqdone: make(chan struct{}),
1319 }
1320
1321 tr := &Transport{
1322 ExpectContinueTimeout: timeout,
1323 }
1324 go func() {
1325 defer close(test.reqdone)
1326 body := strings.NewReader(transport100ContinueTestBody)
1327 req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
1328 req.Header.Set("Expect", "100-continue")
1329 req.ContentLength = int64(len(transport100ContinueTestBody))
1330 test.resp, test.respErr = tr.RoundTrip(req)
1331 test.resp.Body.Close()
1332 }()
1333
1334 c, err := ln.Accept()
1335 if err != nil {
1336 t.Fatalf("Accept: %v", err)
1337 }
1338 t.Cleanup(func() {
1339 c.Close()
1340 })
1341 br := bufio.NewReader(c)
1342 _, err = ReadRequest(br)
1343 if err != nil {
1344 t.Fatalf("ReadRequest: %v", err)
1345 }
1346 test.conn = c
1347 test.reader = br
1348 t.Cleanup(func() {
1349 <-test.reqdone
1350 tr.CloseIdleConnections()
1351 got, _ := io.ReadAll(test.reader)
1352 if len(got) > 0 {
1353 t.Fatalf("Transport sent unexpected bytes: %q", got)
1354 }
1355 })
1356
1357 return test
1358 }
1359
1360
1361 func (test *transport100ContinueTest) respond(lines ...string) {
1362 for _, line := range lines {
1363 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1364 test.t.Fatalf("Write: %v", err)
1365 }
1366 }
1367 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1368 test.t.Fatalf("Write: %v", err)
1369 }
1370 }
1371
1372
1373 func (test *transport100ContinueTest) wantBodySent() {
1374 got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1375 if err != nil {
1376 test.t.Fatalf("unexpected error reading body: %v", err)
1377 }
1378 if got, want := string(got), transport100ContinueTestBody; got != want {
1379 test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1380 }
1381 }
1382
1383
1384 func (test *transport100ContinueTest) wantRequestDone(want int) {
1385 <-test.reqdone
1386 if test.respErr != nil {
1387 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1388 }
1389 if got := test.resp.StatusCode; got != want {
1390 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1391 }
1392 }
1393
// TestTransportExpect100ContinueSent: the server replies "100 Continue", so the
// Transport must send the body and then read the final 200 response.
func TestTransportExpect100ContinueSent(t *testing.T) {
	test := newTransport100ContinueTest(t, 1*time.Hour)

	test.respond("HTTP/1.1 100 Continue")
	test.wantBodySent()
	test.respond("HTTP/1.1 200", "Content-Length: 0")
	test.wantRequestDone(200)
}

// TestTransportExpect100Continue200ResponseNoConnClose: the server skips
// "100 Continue" and sends a final 200 without "Connection: close". The test
// still expects the body on the wire, presumably because the connection stays
// usable and must be left in sync.
func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
	test := newTransport100ContinueTest(t, 1*time.Hour)

	test.respond("HTTP/1.1 200", "Content-Length: 0")
	test.wantBodySent()
	test.wantRequestDone(200)
}

// TestTransportExpect100Continue200ResponseWithConnClose: a final 200 with
// "Connection: close" arrives before "100 Continue". The body must NOT be sent;
// the cleanup in newTransport100ContinueTest fails the test on any unread bytes.
func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
	test := newTransport100ContinueTest(t, 1*time.Hour)

	test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
	test.wantRequestDone(200)
}

// TestTransportExpect100Continue500ResponseNoConnClose: like the 200 case
// without "Connection: close", but with a 500 status; the body is still
// expected on the wire.
func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
	test := newTransport100ContinueTest(t, 1*time.Hour)

	test.respond("HTTP/1.1 500", "Content-Length: 0")
	test.wantBodySent()
	test.wantRequestDone(500)
}

// TestTransportExpect100Continue500ResponseTimeout: with a 5ms
// ExpectContinueTimeout and no server reply, the Transport must send the body
// once the timeout expires; a 200 response then completes the request.
// NOTE(review): despite the name, this test responds with 200, not 500.
func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
	test := newTransport100ContinueTest(t, 5*time.Millisecond)
	test.wantBodySent()
	test.respond("HTTP/1.1 200", "Content-Length: 0")
	test.wantRequestDone(200)
}
1432
1433 func TestSOCKS5Proxy(t *testing.T) {
1434 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1435 }
1436 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1437 ch := make(chan string, 1)
1438 l := newLocalListener(t)
1439 defer l.Close()
1440 defer close(ch)
1441 proxy := func(t *testing.T) {
1442 s, err := l.Accept()
1443 if err != nil {
1444 t.Errorf("socks5 proxy Accept(): %v", err)
1445 return
1446 }
1447 defer s.Close()
1448 var buf [22]byte
1449 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1450 t.Errorf("socks5 proxy initial read: %v", err)
1451 return
1452 }
1453 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1454 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1455 return
1456 }
1457 if _, err := s.Write([]byte{5, 0}); err != nil {
1458 t.Errorf("socks5 proxy initial write: %v", err)
1459 return
1460 }
1461 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1462 t.Errorf("socks5 proxy second read: %v", err)
1463 return
1464 }
1465 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1466 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1467 return
1468 }
1469 var ipLen int
1470 switch buf[3] {
1471 case 1:
1472 ipLen = net.IPv4len
1473 case 4:
1474 ipLen = net.IPv6len
1475 default:
1476 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1477 return
1478 }
1479 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1480 t.Errorf("socks5 proxy address read: %v", err)
1481 return
1482 }
1483 ip := net.IP(buf[4 : ipLen+4])
1484 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1485 copy(buf[:3], []byte{5, 0, 0})
1486 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1487 t.Errorf("socks5 proxy connect write: %v", err)
1488 return
1489 }
1490 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1491
1492
1493 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1494 targetConn, err := net.Dial("tcp", targetHost)
1495 if err != nil {
1496 t.Errorf("net.Dial failed")
1497 return
1498 }
1499 go io.Copy(targetConn, s)
1500 io.Copy(s, targetConn)
1501 targetConn.Close()
1502 }
1503
1504 pu, err := url.Parse("socks5://" + l.Addr().String())
1505 if err != nil {
1506 t.Fatal(err)
1507 }
1508
1509 sentinelHeader := "X-Sentinel"
1510 sentinelValue := "12345"
1511 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1512 w.Header().Set(sentinelHeader, sentinelValue)
1513 })
1514 for _, useTLS := range []bool{false, true} {
1515 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1516 ts := newClientServerTest(t, mode, h).ts
1517 go proxy(t)
1518 c := ts.Client()
1519 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1520 r, err := c.Head(ts.URL)
1521 if err != nil {
1522 t.Fatal(err)
1523 }
1524 if r.Header.Get(sentinelHeader) != sentinelValue {
1525 t.Errorf("Failed to retrieve sentinel value")
1526 }
1527 got := <-ch
1528 ts.Close()
1529 tsu, err := url.Parse(ts.URL)
1530 if err != nil {
1531 t.Fatal(err)
1532 }
1533 want := "proxy for " + tsu.Host
1534 if got != want {
1535 t.Errorf("got %q, want %q", got, want)
1536 }
1537 })
1538 }
1539 }
1540
// TestTransportProxy verifies proxying for every combination of HTTP/HTTPS
// target site and HTTP/HTTPS proxy. For an HTTPS site the Transport must tunnel
// through the proxy with CONNECT. For an HTTP site it must send the request to
// the proxy with an absolute URL.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ siteMode, proxyMode testMode }{
		{http1Mode, http1Mode},
		{http1Mode, https1Mode},
		{https1Mode, http1Mode},
		{https1Mode, https1Mode},
	}
	for _, testCase := range testCases {
		siteMode := testCase.siteMode
		proxyMode := testCase.proxyMode
		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
			// h1 is the target site; it records the request it receives.
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			// h2 is the proxy. It records the request it receives and, for
			// CONNECT, hijacks the connection and tunnels it to the target.
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// Minimal CONNECT proxy implementation.
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed")
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					// Dial the target before replying, so a dial failure
					// is reported without sending 200 OK.
					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						return
					}

					// Relay bytes in both directions.
					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			ts := newClientServerTest(t, siteMode, h1).ts
			proxy := newClientServerTest(t, proxyMode, h2).ts

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// Pick a client whose TLS configuration fits the scenario: the
			// proxy's client for an HTTP site, the site's client for an
			// HTTPS site.
			// NOTE(review): for an HTTPS site behind an HTTPS proxy, this
			// presumably relies on both test servers sharing a certificate
			// that ts.Client() trusts — confirm.
			c := proxy.Client()
			if siteMode == https1Mode {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			got := <-proxyCh
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if siteMode == https1Mode {
				// HTTPS site: the proxy must have received a CONNECT for
				// the site's host:port.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				pu, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := pu.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The tunnelled request reaches the site with a path-only URL.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				// HTTP site: the proxy itself receives the request, with an
				// absolute URL.
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
1657
1658
1659
1660
// TestProxyWithInfiniteHeader verifies that a CONNECT response whose header
// never ends is rejected once it exceeds Transport.MaxResponseHeaderBytes,
// instead of being read forever.
func TestProxyWithInfiniteHeader(t *testing.T) {
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	// Closed when the test returns; stops the fake proxy's write loop.
	cancelc := make(chan struct{})
	defer close(cancelc)

	// Fake proxy: read the CONNECT request, then answer with an endless
	// stream of bytes that never forms a complete header block.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Keep writing until the test finishes. Write errors are ignored,
		// presumably because the client closes the connection once it
		// gives up.
		for {
			// Yield between writes (presumably so this loop does not hog a
			// CPU while the client is reading).
			runtime.Gosched()
			select {
			case <-cancelc:
				return
			default:
				c.Write([]byte("infinite stream of bytes"))
			}
		}
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
			// Cap the CONNECT response header size so the client gives up.
			MaxResponseHeaderBytes: 1024,
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}
}
1723
1724 func TestOnProxyConnectResponse(t *testing.T) {
1725
1726 var tcases = []struct {
1727 proxyStatusCode int
1728 err error
1729 }{
1730 {
1731 StatusOK,
1732 nil,
1733 },
1734 {
1735 StatusForbidden,
1736 errors.New("403"),
1737 },
1738 }
1739 for _, tcase := range tcases {
1740 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1741
1742 })
1743
1744 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1745
1746 if r.Method == "CONNECT" {
1747 if tcase.proxyStatusCode != StatusOK {
1748 w.WriteHeader(tcase.proxyStatusCode)
1749 return
1750 }
1751 hijacker, ok := w.(Hijacker)
1752 if !ok {
1753 t.Errorf("hijack not allowed")
1754 return
1755 }
1756 clientConn, _, err := hijacker.Hijack()
1757 if err != nil {
1758 t.Errorf("hijacking failed")
1759 return
1760 }
1761 res := &Response{
1762 StatusCode: StatusOK,
1763 Proto: "HTTP/1.1",
1764 ProtoMajor: 1,
1765 ProtoMinor: 1,
1766 Header: make(Header),
1767 }
1768
1769 targetConn, err := net.Dial("tcp", r.URL.Host)
1770 if err != nil {
1771 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1772 return
1773 }
1774
1775 if err := res.Write(clientConn); err != nil {
1776 t.Errorf("Writing 200 OK failed: %v", err)
1777 return
1778 }
1779
1780 go io.Copy(targetConn, clientConn)
1781 go func() {
1782 io.Copy(clientConn, targetConn)
1783 targetConn.Close()
1784 }()
1785 }
1786 })
1787 ts := newClientServerTest(t, https1Mode, h1).ts
1788 proxy := newClientServerTest(t, https1Mode, h2).ts
1789
1790 pu, err := url.Parse(proxy.URL)
1791 if err != nil {
1792 t.Fatal(err)
1793 }
1794
1795 c := proxy.Client()
1796
1797 var (
1798 dials atomic.Int32
1799 closes atomic.Int32
1800 )
1801 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1802 conn, err := net.Dial(network, addr)
1803 if err != nil {
1804 return nil, err
1805 }
1806 dials.Add(1)
1807 return noteCloseConn{
1808 Conn: conn,
1809 closeFunc: func() {
1810 closes.Add(1)
1811 },
1812 }, nil
1813 }
1814
1815 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1816 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1817 if proxyURL.String() != pu.String() {
1818 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1819 }
1820
1821 if "https://"+connectReq.URL.String() != ts.URL {
1822 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1823 }
1824 return tcase.err
1825 }
1826 wantCloses := int32(0)
1827 if _, err := c.Head(ts.URL); err != nil {
1828 wantCloses = 1
1829 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1830 t.Errorf("got %v, want %v", err, tcase.err)
1831 }
1832 } else {
1833 if tcase.err != nil {
1834 t.Errorf("got %v, want nil", err)
1835 }
1836 }
1837 if got, want := dials.Load(), int32(1); got != want {
1838 t.Errorf("got %v dials, want %v", got, want)
1839 }
1840
1841 if got, want := closes.Load(), wantCloses; got != want {
1842 t.Errorf("got %v closes, want %v", got, want)
1843 }
1844 }
1845 }
1846
1847
1848
// TestTransportProxyHTTPSConnectLeak verifies what happens when a CONNECT to an
// HTTPS-via-proxy target is cancelled while the Transport waits for the
// proxy's reply. The request must fail, and the Transport must close the proxy
// connection (the fake proxy then reads EOF) rather than leak it.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	// Replace the proxy-connect timeout context with one that is cancelled
	// as soon as cancelc is closed, making the cancellation deterministic.
	cancelc := make(chan struct{})
	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
		ctx, cancel := context.WithCancel(ctx)
		go func() {
			select {
			case <-cancelc:
			case <-ctx.Done():
			}
			cancel()
		}()
		return ctx, cancel
	})

	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	listenerDone := make(chan struct{})
	// Fake proxy: read the CONNECT request, never reply, and wait for the
	// client to close the connection.
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Cancel the client's CONNECT while it waits for our response, then
		// expect the client to close the connection.
		close(cancelc)
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
		return
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait for the fake proxy to observe the close before finishing, so
	// its t.Errorf calls happen while the test is still running.
	<-listenerDone
}
1921
1922
1923 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1924 defer afterTest(t)
1925
1926 var errDial = errors.New("some dial error")
1927
1928 tr := &Transport{
1929 Proxy: func(*Request) (*url.URL, error) {
1930 return url.Parse("http://proxy.fake.tld/")
1931 },
1932 Dial: func(string, string) (net.Conn, error) {
1933 return nil, errDial
1934 },
1935 }
1936 defer tr.CloseIdleConnections()
1937
1938 c := &Client{Transport: tr}
1939 req, _ := NewRequest("GET", "http://fake.tld", nil)
1940 res, err := c.Do(req)
1941 if err == nil {
1942 res.Body.Close()
1943 t.Fatal("wanted a non-nil error")
1944 }
1945
1946 uerr, ok := err.(*url.Error)
1947 if !ok {
1948 t.Fatalf("got %T, want *url.Error", err)
1949 }
1950 oe, ok := uerr.Err.(*net.OpError)
1951 if !ok {
1952 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1953 }
1954 want := &net.OpError{
1955 Op: "proxyconnect",
1956 Net: "tcp",
1957 Err: errDial,
1958 }
1959 if !reflect.DeepEqual(oe, want) {
1960 t.Errorf("Got error %#v; want %#v", oe, want)
1961 }
1962 }
1963
1964
1965
1966
1967
1968 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1969 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader, http3SkippedMode)
1970 }
1971 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1972 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1973 defer proxy.Close()
1974 c := proxy.Client()
1975
1976 tr := c.Transport.(*Transport)
1977 tr.Proxy = func(*Request) (*url.URL, error) {
1978 u, _ := url.Parse(proxy.URL)
1979 u.User = url.UserPassword("aladdin", "opensesame")
1980 return u, nil
1981 }
1982 h := tr.ProxyConnectHeader
1983 if h == nil {
1984 h = make(Header)
1985 }
1986 tr.ProxyConnectHeader = h.Clone()
1987
1988 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1989 if err != nil {
1990 t.Fatal(err)
1991 }
1992 _, err = c.Do(req)
1993 if err == nil {
1994 t.Errorf("unexpected Get success")
1995 }
1996
1997 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1998 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1999 }
2000 }
2001
2002
2003
2004
2005
2006 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive, http3SkippedMode) }
2007 func testTransportGzipRecursive(t *testing.T, mode testMode) {
2008 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2009 w.Header().Set("Content-Encoding", "gzip")
2010 w.Write(rgz)
2011 })).ts
2012
2013 c := ts.Client()
2014 res, err := c.Get(ts.URL)
2015 if err != nil {
2016 t.Fatal(err)
2017 }
2018 body, err := io.ReadAll(res.Body)
2019 if err != nil {
2020 t.Fatal(err)
2021 }
2022 if !bytes.Equal(body, rgz) {
2023 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
2024 body, rgz)
2025 }
2026 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
2027 t.Fatalf("Content-Encoding = %q; want %q", g, e)
2028 }
2029 }
2030
2031
2032
2033 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort, http3SkippedMode) }
2034 func testTransportGzipShort(t *testing.T, mode testMode) {
2035 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2036 w.Header().Set("Content-Encoding", "gzip")
2037 w.Write([]byte{0x1f, 0x8b})
2038 })).ts
2039
2040 c := ts.Client()
2041 res, err := c.Get(ts.URL)
2042 if err != nil {
2043 t.Fatal(err)
2044 }
2045 defer res.Body.Close()
2046 _, err = io.ReadAll(res.Body)
2047 if err == nil {
2048 t.Fatal("Expect an error from reading a body.")
2049 }
2050 if err != io.ErrUnexpectedEOF {
2051 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
2052 }
2053 }
2054
// TestTransportGzipConcurrentCloseAndRead checks a Read on a gzip response
// body started concurrently with Close. Once Close returns and all goroutines
// have settled, that Read must have returned, and with a non-nil error.
func TestTransportGzipConcurrentCloseAndRead(t *testing.T) {
	runSynctest(t, testTransportGzipConcurrentCloseAndRead, http3SkippedMode)
}
func testTransportGzipConcurrentCloseAndRead(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Send a gzip-encoded response header and flush it without writing
		// any body bytes.
		w.Header().Set("Content-Encoding", "gzip")
		w.WriteHeader(200)
		w.(Flusher).Flush()
	}), optFakeNet)

	// NOTE(review): why keep-alives are disabled for HTTP/1 is not visible
	// here — TODO confirm the intent.
	if mode == http1Mode {
		cst.tr.DisableKeepAlives = true
	}

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	// Start a Read in the background, then wait until every goroutine in
	// the synctest bubble is blocked, so the Read is in progress (or already
	// done) before Close is called.
	readErrCh := make(chan error, 1)
	go func() {
		var buf [10]byte
		_, err := res.Body.Read(buf[:])
		readErrCh <- err
	}()
	synctest.Wait()

	if err := res.Body.Close(); err != nil {
		t.Fatal(err)
	}
	synctest.Wait()

	// After Close and another quiescent point, the Read must have finished.
	select {
	case err := <-readErrCh:
		if err == nil {
			t.Error("Read returned nil error, want error")
		}
	default:
		t.Fatal("Read did not unblock on Close")
	}
}
2099
2100
// waitNumGoroutine polls until the number of goroutines drops to nmax or below,
// running a GC after each 50ms pause, for at most 10 attempts. It returns the
// last count observed, which may still exceed nmax.
func waitNumGoroutine(nmax int) int {
	const (
		maxAttempts = 10
		pause       = 50 * time.Millisecond
	)
	current := runtime.NumGoroutine()
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		if current <= nmax {
			break
		}
		time.Sleep(pause)
		runtime.GC()
		current = runtime.NumGoroutine()
	}
	return current
}
2110
2111
// TestTransportPersistConnLeak checks that many concurrent requests on one
// Transport leave at most a handful of extra goroutines behind. The check runs
// after all responses are consumed and idle connections are closed.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode || mode == http3Mode {
		t.Skip("flaky in HTTP/2 and HTTP/3")
	}

	// Each handler signals its arrival on gotReqCh and then blocks until
	// unblockCh is closed, so all requests are in flight simultaneously.
	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count before any requests start.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait for every request to reach a handler or fail.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// ok
		case <-failed:
			// A failed request is counted here so the loop still
			// terminates (presumably it never reached a handler).
		}
	}

	nhigh := runtime.NumGoroutine()

	// Let all handlers reply.
	close(unblockCh)

	// Wait for all clients to finish.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow a small slack of up to 5 extra goroutines.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
2183
2184
2185
// TestTransportPersistConnLeakShortBody checks a request whose body is longer
// than its declared ContentLength. Such requests must fail, and repeating
// them must not leave more than a handful of extra goroutines behind.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
}
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
	if mode == http2Mode || mode == http3Mode {
		t.Skip("flaky in HTTP/2 and HTTP/3")
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count before any requests.
	n0 := runtime.NumGoroutine()
	body := []byte("Hello")
	for i := 0; i < 20; i++ {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		// Declare a length 2 bytes shorter than the actual body.
		req.ContentLength = int64(len(body) - 2)
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	nhigh := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow a small slack of up to 5 extra goroutines.
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
	if int(growth) > 5 {
		t.Error("too many new goroutines")
	}
}
2226
2227
// countedConn wraps a net.Conn so that its collection by the garbage collector
// can be observed through a runtime cleanup.
type countedConn struct {
	net.Conn
}

// countingDialer dials connections and tracks how many it has created (total)
// and how many of those have not yet been garbage collected (live).
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials with the embedded net.Dialer. On success it increments
// both counters and returns the connection wrapped in a countedConn. A runtime
// cleanup decrements live once that wrapper becomes unreachable.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	d.total++
	d.live++
	// The dialer is passed as the cleanup argument; the cleanup must not
	// reference wrapped, or wrapped could never be collected.
	runtime.AddCleanup(wrapped, func(owner *countingDialer) { owner.decrement(nil) }, d)
	d.mu.Unlock()

	return wrapped, nil
}

// decrement records that one counted connection was collected. Its argument is
// ignored.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns a consistent snapshot of the total and live counters.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
2268
// TestTransportPersistConnLeakNeverIdle checks connections whose only request
// failed, because the server hijacks and closes them without responding. Such
// connections must eventually be garbage collected rather than retained by
// the Transport.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
}
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Drop the connection without writing any response.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	})).ts

	// countingDialer reports how many dialed conns are still reachable.
	var d countingDialer
	c := ts.Client()
	c.Transport.(*Transport).DialContext = d.DialContext

	body := []byte("Hello")
	// Keep issuing failing requests and running GC until at least one dialed
	// connection has been collected, giving up after 4096 iterations.
	for i := 0; ; i++ {
		total, live := d.Read()
		if live < total {
			break
		}
		if i >= 1<<12 {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
		if err != nil {
			t.Fatal(err)
		}
		_, err = c.Do(req)
		if err == nil {
			t.Fatal("expected broken connection")
		}

		runtime.GC()
	}
}
2309
// countedContext wraps a context.Context so that its collection by the garbage
// collector can be observed through a runtime cleanup.
type countedContext struct {
	context.Context
}

// contextCounter counts how many contexts returned by Track are still reachable
// (not yet garbage collected).
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a countedContext and increments the live count. It also
// registers a runtime cleanup that decrements the count once the wrapper
// becomes unreachable, and returns the wrapper.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	counted := &countedContext{Context: ctx}
	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live++
	// Use the cleanup's argument instead of capturing cc in the closure.
	// Passing state through arg is the intended runtime.AddCleanup pattern
	// and keeps the cleanup function free of captured state.
	runtime.AddCleanup(counted, func(c *contextCounter) { c.decrement(nil) }, cc)
	return counted
}

// decrement records that one tracked context was collected. Its argument is
// ignored.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live--
}

// Read returns the current number of tracked contexts not yet collected.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	defer cc.mu.Unlock()
	return cc.live
}
2340
2341 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2342 run(t, testTransportPersistConnContextLeakMaxConnsPerHost, http3SkippedMode)
2343 }
2344 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2345 if mode == http2Mode {
2346 t.Skip("https://go.dev/issue/56021")
2347 }
2348
2349 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2350 runtime.Gosched()
2351 w.WriteHeader(StatusOK)
2352 })).ts
2353
2354 c := ts.Client()
2355 c.Transport.(*Transport).MaxConnsPerHost = 1
2356
2357 ctx := context.Background()
2358 body := []byte("Hello")
2359 doPosts := func(cc *contextCounter) {
2360 var wg sync.WaitGroup
2361 for n := 64; n > 0; n-- {
2362 wg.Add(1)
2363 go func() {
2364 defer wg.Done()
2365
2366 ctx := cc.Track(ctx)
2367 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2368 if err != nil {
2369 t.Error(err)
2370 }
2371
2372 _, err = c.Do(req.WithContext(ctx))
2373 if err != nil {
2374 t.Errorf("Do failed with error: %v", err)
2375 }
2376 }()
2377 }
2378 wg.Wait()
2379 }
2380
2381 var initialCC contextCounter
2382 doPosts(&initialCC)
2383
2384
2385
2386
2387 var flushCC contextCounter
2388 for i := 0; ; i++ {
2389 live := initialCC.Read()
2390 if live == 0 {
2391 break
2392 }
2393 if i >= 100 {
2394 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2395 }
2396 doPosts(&flushCC)
2397 runtime.GC()
2398 }
2399 }
2400
2401
2402 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash, http3SkippedMode) }
2403 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2404 var tr *Transport
2405
2406 unblockCh := make(chan bool, 1)
2407 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2408 <-unblockCh
2409 tr.CloseIdleConnections()
2410 })).ts
2411 c := ts.Client()
2412 tr = c.Transport.(*Transport)
2413
2414 didreq := make(chan bool)
2415 go func() {
2416 res, err := c.Get(ts.URL)
2417 if err != nil {
2418 t.Error(err)
2419 } else {
2420 res.Body.Close()
2421 }
2422 didreq <- true
2423 }()
2424 unblockCh <- true
2425 <-didreq
2426 }
2427
2428
2429
2430
2431
2432 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2433 func testIssue3644(t *testing.T, mode testMode) {
2434 const numFoos = 5000
2435 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2436 w.Header().Set("Connection", "close")
2437 for i := 0; i < numFoos; i++ {
2438 w.Write([]byte("foo "))
2439 }
2440 })).ts
2441 c := ts.Client()
2442 res, err := c.Get(ts.URL)
2443 if err != nil {
2444 t.Fatal(err)
2445 }
2446 defer res.Body.Close()
2447 bs, err := io.ReadAll(res.Body)
2448 if err != nil {
2449 t.Fatal(err)
2450 }
2451 if len(bs) != numFoos*len("foo ") {
2452 t.Errorf("unexpected response length")
2453 }
2454 }
2455
2456
2457
2458 func TestIssue3595(t *testing.T) {
2459
2460
2461 run(t, testIssue3595, testNotParallel, http3SkippedMode)
2462 }
2463 func testIssue3595(t *testing.T, mode testMode) {
2464 runTimeSensitiveTest(t, []time.Duration{
2465 1 * time.Millisecond,
2466 5 * time.Millisecond,
2467 10 * time.Millisecond,
2468 50 * time.Millisecond,
2469 100 * time.Millisecond,
2470 500 * time.Millisecond,
2471 time.Second,
2472 5 * time.Second,
2473 }, func(t *testing.T, timeout time.Duration) error {
2474 SetRSTAvoidanceDelay(t, timeout)
2475 t.Logf("set RST avoidance delay to %v", timeout)
2476
2477 const deniedMsg = "sorry, denied."
2478 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2479 Error(w, deniedMsg, StatusUnauthorized)
2480 }))
2481
2482
2483 defer cst.close()
2484 ts := cst.ts
2485 c := ts.Client()
2486
2487 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2488 if err != nil {
2489 return fmt.Errorf("Post: %v", err)
2490 }
2491 got, err := io.ReadAll(res.Body)
2492 if err != nil {
2493 return fmt.Errorf("Body ReadAll: %v", err)
2494 }
2495 t.Logf("server response:\n%s", got)
2496 if !strings.Contains(string(got), deniedMsg) {
2497
2498
2499 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2500 }
2501 return nil
2502 })
2503 }
2504
2505
2506
2507 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2508 func testChunkedNoContent(t *testing.T, mode testMode) {
2509 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2510 w.WriteHeader(StatusNoContent)
2511 })).ts
2512
2513 c := ts.Client()
2514 for _, closeBody := range []bool{true, false} {
2515 const n = 4
2516 for i := 1; i <= n; i++ {
2517 res, err := c.Get(ts.URL)
2518 if err != nil {
2519 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2520 } else {
2521 if closeBody {
2522 res.Body.Close()
2523 }
2524 }
2525 }
2526 }
2527 }
2528
// TestTransportConcurrency sends many concurrent GET requests through a single
// Transport and checks that each response body echoes its own request's
// "echo" query parameter.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Not parallel: this test changes GOMAXPROCS (process-wide) and installs
	// the pending-dial test hooks.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Raise GOMAXPROCS for this test and restore the previous value on return.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	var wg sync.WaitGroup
	wg.Add(numReqs)

	// The WaitGroup holds one count per request plus, through these hooks,
	// one per pending dial: a count is added when a dial is registered as
	// pending and released when it completes. wg.Wait below therefore also
	// waits for dials that no request ended up using. (Presumed hook
	// semantics; confirm against SetPendingDialHooks.)
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	defer close(reqs)

	// Worker pool: each goroutine pulls request IDs from reqs until the
	// channel is closed when the test returns.
	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Connection resets are a known NetBSD issue under
						// this load (see the linked issue); log them instead
						// of failing the test.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2591
2592 func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
2593 run(t, testIssue4191_InfiniteGetTimeout, http3SkippedMode)
2594 }
2595 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2596 mux := NewServeMux()
2597 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2598 io.Copy(w, neverEnding('a'))
2599 })
2600 ts := newClientServerTest(t, mode, mux).ts
2601
2602 connc := make(chan net.Conn, 1)
2603 c := ts.Client()
2604 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2605 conn, err := net.Dial(n, addr)
2606 if err != nil {
2607 return nil, err
2608 }
2609 select {
2610 case connc <- conn:
2611 default:
2612 }
2613 return conn, nil
2614 }
2615
2616 res, err := c.Get(ts.URL + "/get")
2617 if err != nil {
2618 t.Fatalf("Error issuing GET: %v", err)
2619 }
2620 defer res.Body.Close()
2621
2622 conn := <-connc
2623 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2624 _, err = io.Copy(io.Discard, res.Body)
2625 if err == nil {
2626 t.Errorf("Unexpected successful copy")
2627 }
2628 }
2629
// TestIssue4191_InfiniteGetToPutTimeout pipes an endless GET response body
// into the body of a PUT request, over client connections that carry an
// absolute deadline, and checks that the PUT fails instead of succeeding.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	timeout := 100 * time.Millisecond

	c := ts.Client()
	// Every client connection gets a deadline `timeout` after it is dialed,
	// so all I/O on it eventually fails. The closure reads `timeout` at dial
	// time, so increasing it below affects subsequent dials.
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// The GET itself can hit the deadline on a slow machine.
				// Repeat this run once with a 10x longer timeout before
				// treating the failure as a test error.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		if err == nil {
			sres.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
		sres.Body.Close()
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}
2695
// TestTransportResponseHeaderTimeout checks that Transport.ResponseHeaderTimeout
// fails a request whose handler never responds ("/slow") with a
// "timeout awaiting response headers" error, while requests answered promptly
// ("/fast") succeed.
func TestTransportResponseHeaderTimeout(t *testing.T) {
	run(t, testTransportResponseHeaderTimeout, http3SkippedMode)
}
func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping timeout test in -short mode")
	}

	// Start with a very short timeout. If a /fast request times out anyway,
	// the whole table is rerun with a doubled timeout (see retry below).
	timeout := 2 * time.Millisecond
	retry := true
	for retry && !t.Failed() {
		// srvWG counts the three handler invocations in the table below.
		var srvWG sync.WaitGroup
		inHandler := make(chan bool, 1)
		mux := NewServeMux()
		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
			inHandler <- true
			srvWG.Done()
		})
		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
			inHandler <- true
			// Block without responding until the request is abandoned.
			<-r.Context().Done()
			srvWG.Done()
		})
		ts := newClientServerTest(t, mode, mux).ts

		c := ts.Client()
		c.Transport.(*Transport).ResponseHeaderTimeout = timeout

		retry = false
		srvWG.Add(3)
		tests := []struct {
			path        string
			wantTimeout bool
		}{
			{path: "/fast"},
			{path: "/slow", wantTimeout: true},
			{path: "/fast"},
		}
		for i, tt := range tests {
			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
			req = req.WithT(t)
			res, err := c.Do(req)
			// Wait for the handler to have run, whatever the outcome.
			<-inHandler
			if err != nil {
				uerr, ok := err.(*url.Error)
				if !ok {
					t.Errorf("error is not a url.Error; got: %#v", err)
					continue
				}
				nerr, ok := uerr.Err.(net.Error)
				if !ok {
					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
					continue
				}
				if !nerr.Timeout() {
					t.Errorf("want timeout error; got: %q", nerr)
					continue
				}
				if !tt.wantTimeout {
					if !retry {
						// A /fast request timing out most likely means the
						// machine is slow, not that the Transport is wrong:
						// double the timeout (at most once per round) and
						// rerun the table.
						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
						timeout *= 2
						retry = true
					}
				}
				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
					t.Errorf("%d. unexpected error: %v", i, err)
				}
				continue
			}
			if tt.wantTimeout {
				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
				continue
			}
			if res.StatusCode != 200 {
				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
			}
		}

		srvWG.Wait()
		ts.Close()
	}
}
2780
2781
// cancelTest describes one mechanism for canceling a request, so the same test
// body can be run once per mechanism (see runCancelTest).
//
//   - mode is the client/server test mode.
//   - newReq prepares a request so that cancel can later cancel it.
//   - cancel cancels a request that was prepared by newReq.
//   - checkErr reports a test error (tagged with when) unless err is the
//     error this mechanism is expected to produce.
type cancelTest struct {
	mode     testMode
	newReq   func(req *Request) *Request
	cancel   func(tr *Transport, req *Request)
	checkErr func(when string, err error)
}
2788
2789
2790 func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2791 t.Run("TransportCancel", func(t *testing.T) {
2792 f(t, cancelTest{
2793 mode: mode,
2794 newReq: func(req *Request) *Request {
2795 return req
2796 },
2797 cancel: func(tr *Transport, req *Request) {
2798 tr.CancelRequest(req)
2799 },
2800 checkErr: func(when string, err error) {
2801 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2802 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2803 }
2804 },
2805 })
2806 })
2807 }
2808
2809
2810 func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2811 cancelc := make(chan struct{})
2812 cancelOnce := sync.OnceFunc(func() { close(cancelc) })
2813 f(t, cancelTest{
2814 mode: mode,
2815 newReq: func(req *Request) *Request {
2816 req.Cancel = cancelc
2817 return req
2818 },
2819 cancel: func(tr *Transport, req *Request) {
2820 cancelOnce()
2821 },
2822 checkErr: func(when string, err error) {
2823 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2824 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2825 }
2826 },
2827 })
2828 }
2829
2830
2831 func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2832 ctx, cancel := context.WithCancel(context.Background())
2833 f(t, cancelTest{
2834 mode: mode,
2835 newReq: func(req *Request) *Request {
2836 return req.WithContext(ctx)
2837 },
2838 cancel: func(tr *Transport, req *Request) {
2839 cancel()
2840 },
2841 checkErr: func(when string, err error) {
2842 if !errors.Is(err, context.Canceled) {
2843 t.Errorf("%v error = %v, want context.Canceled", when, err)
2844 }
2845 },
2846 })
2847 }
2848
2849 func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
2850 run(t, func(t *testing.T, mode testMode) {
2851 if mode == http1Mode {
2852 t.Run("TransportCancel", func(t *testing.T) {
2853 runCancelTestTransport(t, mode, f)
2854 })
2855 }
2856 t.Run("RequestCancel", func(t *testing.T) {
2857 runCancelTestChannel(t, mode, f)
2858 })
2859 t.Run("ContextCancel", func(t *testing.T) {
2860 runCancelTestContext(t, mode, f)
2861 })
2862 }, opts...)
2863 }
2864
2865 func TestTransportCancelRequest(t *testing.T) {
2866 runCancelTest(t, testTransportCancelRequest, http3SkippedMode)
2867 }
2868 func testTransportCancelRequest(t *testing.T, test cancelTest) {
2869 if testing.Short() {
2870 t.Skip("skipping test in -short mode")
2871 }
2872
2873 const msg = "Hello"
2874 unblockc := make(chan bool)
2875 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2876 io.WriteString(w, msg)
2877 w.(Flusher).Flush()
2878 <-unblockc
2879 })).ts
2880 defer close(unblockc)
2881
2882 c := ts.Client()
2883 tr := c.Transport.(*Transport)
2884
2885 req, _ := NewRequest("GET", ts.URL, nil)
2886 req = test.newReq(req)
2887 res, err := c.Do(req)
2888 if err != nil {
2889 t.Fatal(err)
2890 }
2891 body := make([]byte, len(msg))
2892 n, _ := io.ReadFull(res.Body, body)
2893 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2894 t.Errorf("Body = %q; want %q", body[:n], msg)
2895 }
2896 test.cancel(tr, req)
2897
2898 tail, err := io.ReadAll(res.Body)
2899 res.Body.Close()
2900 test.checkErr("Body.Read", err)
2901 if len(tail) > 0 {
2902 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2903 }
2904
2905
2906
2907 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2908 n := tr.NumPendingRequestsForTesting()
2909 if n > 0 {
2910 if d > 0 {
2911 t.Logf("pending requests = %d after %v (want 0)", n, d)
2912 }
2913 return false
2914 }
2915 return true
2916 })
2917 }
2918
2919 func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
2920 if testing.Short() {
2921 t.Skip("skipping test in -short mode")
2922 }
2923 unblockc := make(chan bool)
2924 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2925 <-unblockc
2926 })).ts
2927 defer close(unblockc)
2928
2929 c := ts.Client()
2930 tr := c.Transport.(*Transport)
2931
2932 donec := make(chan bool)
2933 req, _ := NewRequest("GET", ts.URL, body)
2934 req = test.newReq(req)
2935 go func() {
2936 defer close(donec)
2937 c.Do(req)
2938 }()
2939
2940 unblockc <- true
2941 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2942 test.cancel(tr, req)
2943 select {
2944 case <-donec:
2945 return true
2946 default:
2947 if d > 0 {
2948 t.Logf("Do of canceled request has not returned after %v", d)
2949 }
2950 return false
2951 }
2952 })
2953 }
2954
2955 func TestTransportCancelRequestInDo(t *testing.T) {
2956 runCancelTest(t, func(t *testing.T, test cancelTest) {
2957 testTransportCancelRequestInDo(t, test, nil)
2958 })
2959 }
2960
2961 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2962 runCancelTest(t, func(t *testing.T, test cancelTest) {
2963 testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
2964 })
2965 }
2966
// TestTransportCancelRequestInDial cancels a request while the Transport is
// blocked in its Dial function. The dial stays blocked until the test
// returns, so Do must return a cancellation error without waiting for the
// dial; an event log checks the order of dial, cancel, and Do's return.
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// Closed only when the test returns, which happens after Do has returned.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			// Logged before the handshake on inDial, so "dial: blocking"
			// always precedes the test's "canceling" entry.
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	// Wait until the Transport is inside Dial.
	inDial <- true

	eventLog.Printf("canceling")
	// Canceling twice presumably checks that repeated cancellation is harmless.
	test.cancel(tr, req)
	test.cancel(tr, req)

	if d, ok := t.Deadline(); ok {
		// If the test is about to hit its deadline, panic with the event log
		// so a hang can be diagnosed; leave about 5% of the time for cleanup.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
3029
3030
3031 func TestTransportCancelRequestWithBody(t *testing.T) {
3032 runCancelTest(t, testTransportCancelRequestWithBody, http3SkippedMode)
3033 }
3034 func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
3035 if testing.Short() {
3036 t.Skip("skipping test in -short mode")
3037 }
3038
3039 const msg = "Hello"
3040 unblockc := make(chan struct{})
3041 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3042 io.WriteString(w, msg)
3043 w.(Flusher).Flush()
3044 <-unblockc
3045 })).ts
3046 defer close(unblockc)
3047
3048 c := ts.Client()
3049 tr := c.Transport.(*Transport)
3050
3051 req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
3052 req = test.newReq(req)
3053
3054 res, err := c.Do(req)
3055 if err != nil {
3056 t.Fatal(err)
3057 }
3058 body := make([]byte, len(msg))
3059 n, _ := io.ReadFull(res.Body, body)
3060 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
3061 t.Errorf("Body = %q; want %q", body[:n], msg)
3062 }
3063 test.cancel(tr, req)
3064
3065 tail, err := io.ReadAll(res.Body)
3066 res.Body.Close()
3067 test.checkErr("Body.Read", err)
3068 if len(tail) > 0 {
3069 t.Errorf("Spurious bytes from Body.Read: %q", tail)
3070 }
3071
3072
3073
3074 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
3075 n := tr.NumPendingRequestsForTesting()
3076 if n > 0 {
3077 if d > 0 {
3078 t.Logf("pending requests = %d after %v (want 0)", n, d)
3079 }
3080 return false
3081 }
3082 return true
3083 })
3084 }
3085
3086 func TestTransportCancelRequestBeforeDo(t *testing.T) {
3087
3088 run(t, func(t *testing.T, mode testMode) {
3089 t.Run("RequestCancel", func(t *testing.T) {
3090 runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
3091 })
3092 t.Run("ContextCancel", func(t *testing.T) {
3093 runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
3094 })
3095 }, http3SkippedMode)
3096 }
3097 func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
3098 unblockc := make(chan bool)
3099 cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3100 <-unblockc
3101 }))
3102 defer close(unblockc)
3103
3104 c := cst.ts.Client()
3105
3106 req, _ := NewRequest("GET", cst.ts.URL, nil)
3107 req = test.newReq(req)
3108 test.cancel(cst.tr, req)
3109
3110 _, err := c.Do(req)
3111 test.checkErr("Do", err)
3112 }
3113
3114
3115 func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
3116 runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
3117 }
3118 func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
3119 defer afterTest(t)
3120
3121 serverConnCh := make(chan net.Conn, 1)
3122 tr := &Transport{
3123 Dial: func(network, addr string) (net.Conn, error) {
3124 cc, sc := net.Pipe()
3125 serverConnCh <- sc
3126 return cc, nil
3127 },
3128 }
3129 defer tr.CloseIdleConnections()
3130 errc := make(chan error, 1)
3131 req, _ := NewRequest("GET", "http://example.com/", nil)
3132 req = test.newReq(req)
3133 go func() {
3134 _, err := tr.RoundTrip(req)
3135 errc <- err
3136 }()
3137
3138 sc := <-serverConnCh
3139 verb := make([]byte, 3)
3140 if _, err := io.ReadFull(sc, verb); err != nil {
3141 t.Errorf("Error reading HTTP verb from server: %v", err)
3142 }
3143 if string(verb) != "GET" {
3144 t.Errorf("server received %q; want GET", verb)
3145 }
3146 defer sc.Close()
3147
3148 test.cancel(tr, req)
3149
3150 err := <-errc
3151 if err == nil {
3152 t.Fatalf("unexpected success from RoundTrip")
3153 }
3154 test.checkErr("RoundTrip", err)
3155 }
3156
3157
3158
3159
3160 func TestTransportCloseResponseBody(t *testing.T) {
3161 run(t, testTransportCloseResponseBody, http3SkippedMode)
3162 }
3163 func testTransportCloseResponseBody(t *testing.T, mode testMode) {
3164 writeErr := make(chan error, 1)
3165 msg := []byte("young\n")
3166 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3167 for {
3168 _, err := w.Write(msg)
3169 if err != nil {
3170 writeErr <- err
3171 return
3172 }
3173 w.(Flusher).Flush()
3174 }
3175 })).ts
3176
3177 c := ts.Client()
3178 tr := c.Transport.(*Transport)
3179
3180 req, _ := NewRequest("GET", ts.URL, nil)
3181 defer tr.CancelRequest(req)
3182
3183 res, err := c.Do(req)
3184 if err != nil {
3185 t.Fatal(err)
3186 }
3187
3188 const repeats = 3
3189 buf := make([]byte, len(msg)*repeats)
3190 want := bytes.Repeat(msg, repeats)
3191
3192 _, err = io.ReadFull(res.Body, buf)
3193 if err != nil {
3194 t.Fatal(err)
3195 }
3196 if !bytes.Equal(buf, want) {
3197 t.Fatalf("read %q; want %q", buf, want)
3198 }
3199
3200 if err := res.Body.Close(); err != nil {
3201 t.Errorf("Close = %v", err)
3202 }
3203
3204 if err := <-writeErr; err == nil {
3205 t.Errorf("expected non-nil write error")
3206 }
3207 }
3208
// fooProto is a RoundTripper that never touches the network: every request
// gets a 200 OK response whose body names the requested URL.
type fooProto struct{}

func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
3220
3221 func TestTransportAltProto(t *testing.T) {
3222 defer afterTest(t)
3223 tr := &Transport{}
3224 c := &Client{Transport: tr}
3225 tr.RegisterProtocol("foo", fooProto{})
3226 res, err := c.Get("foo://bar.com/path")
3227 if err != nil {
3228 t.Fatal(err)
3229 }
3230 bodyb, err := io.ReadAll(res.Body)
3231 if err != nil {
3232 t.Fatal(err)
3233 }
3234 body := string(bodyb)
3235 if e := "You wanted foo://bar.com/path"; body != e {
3236 t.Errorf("got response %q, want %q", body, e)
3237 }
3238 }
3239
3240 func TestTransportNoHost(t *testing.T) {
3241 defer afterTest(t)
3242 tr := &Transport{}
3243 _, err := tr.RoundTrip(&Request{
3244 Header: make(Header),
3245 URL: &url.URL{
3246 Scheme: "http",
3247 },
3248 })
3249 want := "http: no Host in request URL"
3250 if got := fmt.Sprint(err); got != want {
3251 t.Errorf("error = %v; want %q", err, want)
3252 }
3253 }
3254
3255
// TestTransportEmptyMethod checks that a request with an empty Method is
// written out as a GET request.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = ""
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Contains(dump, []byte("GET ")) {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3267
// TestTransportSocketLateBinding checks that a request which is waiting on a
// new dial can instead be served on an existing connection that becomes idle
// first. /foo holds the only connection while /bar's dial is kept blocked,
// and /bar must then report the same remote address as /foo.
func TestTransportSocketLateBinding(t *testing.T) {
	run(t, testTransportSocketLateBinding, http3SkippedMode)
}
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate admits one real dial per value sent; closing it makes blocked
	// dials fail. While a dial waits, it repeatedly offers a value on dialing
	// so the test can observe that a dial is in progress.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one real dial, for /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// In HTTP/1 mode, /foo's connection is busy, so /bar starts a new
		// dial. Wait until that dial is blocked before letting /foo finish,
		// which frees its connection for /bar.
		if mode == http2Mode {
			// HTTP/2 multiplexes /bar onto the existing connection, so no
			// second dial is expected; wait briefly to catch one if it
			// happens.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3349
3350
// TestTransportReading100Continue runs POST requests against a fake server on
// in-memory pipes that answers each request with a 1xx response ("100 Continue"
// unless X-Want-Response-Code overrides it) immediately followed by a final
// 200 response. It checks that the client skips the 1xx, returns the 200, and
// pairs each response with its request via the echoed Request-Id.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response is the fake server for one connection: it reads
	// requests from r and writes raw responses to w. n counts requests on
	// this connection only, so the body check below also fails if a request
	// arrives on a fresh connection instead of reusing the first one.
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// An interim response followed by the final response, with
			// line endings converted to CRLF.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	tr := &Transport{
		// Each dial creates two pipes: the client writes to sw (read by
		// the fake server from sr), and the fake server writes to cw (read
		// by the client from cr).
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe()
			cr, cw := io.Pipe()
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks the status code, the echoed
	// Request-Id, and that the body can be read in full.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Send numReqs POSTs in sequence, each expecting a 100 Continue followed
	// by a final 200.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
3450
3451
3452
3453 func TestTransportIgnore1xxResponses(t *testing.T) {
3454 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3455 }
3456 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3457 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3458 conn, buf, _ := w.(Hijacker).Hijack()
3459 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3460 buf.Flush()
3461 conn.Close()
3462 }))
3463 cst.tr.DisableKeepAlives = true
3464
3465 var got strings.Builder
3466
3467 req, _ := NewRequest("GET", cst.ts.URL, nil)
3468 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3469 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3470 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3471 return nil
3472 },
3473 }))
3474 res, err := cst.c.Do(req)
3475 if err != nil {
3476 t.Fatal(err)
3477 }
3478 defer res.Body.Close()
3479
3480 res.Write(&got)
3481 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3482 if got.String() != want {
3483 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3484 }
3485 }
3486
3487 func TestTransportLimits1xxResponses(t *testing.T) {
3488 run(t, testTransportLimits1xxResponses, http3SkippedMode)
3489 }
3490 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3491 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3492 w.Header().Add("X-Header", strings.Repeat("a", 100))
3493 for i := 0; i < 10; i++ {
3494 w.WriteHeader(123)
3495 }
3496 w.WriteHeader(204)
3497 }))
3498 cst.tr.DisableKeepAlives = true
3499 cst.tr.MaxResponseHeaderBytes = 1000
3500
3501 res, err := cst.c.Get(cst.ts.URL)
3502 if err == nil {
3503 res.Body.Close()
3504 t.Fatalf("RoundTrip succeeded; want error")
3505 }
3506 for _, want := range []string{
3507 "response headers exceeded",
3508 "too many 1xx",
3509 "header list too large",
3510 } {
3511 if strings.Contains(err.Error(), want) {
3512 return
3513 }
3514 }
3515 t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
3516 }
3517
3518 func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
3519 run(t, testTransportDoesNotLimitDelivered1xxResponses, http3SkippedMode)
3520 }
3521 func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
3522 if mode == http2Mode {
3523 t.Skip("skip until x/net/http2 updated")
3524 }
3525 const num1xx = 10
3526 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3527 w.Header().Add("X-Header", strings.Repeat("a", 100))
3528 for i := 0; i < 10; i++ {
3529 w.WriteHeader(123)
3530 }
3531 w.WriteHeader(204)
3532 }))
3533 cst.tr.DisableKeepAlives = true
3534 cst.tr.MaxResponseHeaderBytes = 1000
3535
3536 got1xx := 0
3537 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3538 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3539 got1xx++
3540 return nil
3541 },
3542 })
3543 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
3544 res, err := cst.c.Do(req)
3545 if err != nil {
3546 t.Fatal(err)
3547 }
3548 res.Body.Close()
3549 if got1xx != num1xx {
3550 t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
3551 }
3552 }
3553
3554
3555
3556 func TestTransportTreat101Terminal(t *testing.T) {
3557 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3558 }
3559 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3560 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3561 conn, buf, _ := w.(Hijacker).Hijack()
3562 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3563 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3564 buf.Flush()
3565 conn.Close()
3566 }))
3567 res, err := cst.c.Get(cst.ts.URL)
3568 if err != nil {
3569 t.Fatal(err)
3570 }
3571 defer res.Body.Close()
3572 if res.StatusCode != StatusSwitchingProtocols {
3573 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3574 }
3575 }
3576
// proxyFromEnvTest is one proxy-selection case: the request URL, the
// environment settings in effect, and the expected proxy URL or error.
type proxyFromEnvTest struct {
	req string // request URL; "" means http://example.com

	env      string // http_proxy value
	httpsenv string // https_proxy value
	noenv    string // no_proxy value
	reqmeth  string // request_method value

	want    string // expected proxy URL formatted with %s ("<nil>" for none)
	wanterr error  // expected error, compared via %v formatting
}

// String describes the case as space-separated key="value" pairs, listing only
// the environment settings that are set and always ending with the request URL.
func (t proxyFromEnvTest) String() string {
	var parts []string
	for _, kv := range []struct{ key, val string }{
		{"http_proxy", t.env},
		{"https_proxy", t.httpsenv},
		{"no_proxy", t.noenv},
		{"request_method", t.reqmeth},
	} {
		if kv.val != "" {
			parts = append(parts, fmt.Sprintf("%s=%q", kv.key, kv.val))
		}
	}
	req := t.req
	if req == "" {
		req = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", req))
	return strings.TrimSpace(strings.Join(parts, " "))
}
3619
// proxyFromEnvTests lists proxy-selection cases driven by environment
// settings. want is the selected proxy URL formatted with %s, "<nil>" when no
// proxy should be used.
var proxyFromEnvTests = []proxyFromEnvTest{
	// Scheme-less values default to http://; explicit schemes (including
	// socks5/socks5h) are kept.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// A plain http request uses http_proxy, not https_proxy.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https request uses https_proxy; a scheme-less value still defaults
	// to http://.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// When a request method is set in the environment (a CGI environment,
	// per the expected error), the http proxy setting is refused with an
	// error.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy settings at all: no proxy.
	{want: "<nil>"},

	// no_proxy matching: "example.com" excludes the host and its subdomains;
	// ".example.com" does not exclude the bare host; "ample.com" is not a
	// domain-boundary match; unrelated entries have no effect.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3650
// testProxyForRequest builds a GET request for tt.req (defaulting to
// "http://example.com"), resolves its proxy with proxyForRequest, and
// reports a test error if the resulting error text or proxy URL text
// differs from tt.wanterr / tt.want.
//
// Both sides are compared via fmt so that a nil error and a nil *url.URL
// both render as "<nil>", which is how the table spells "no error" and
// "no proxy".
func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
	t.Helper()
	reqURL := tt.req
	if reqURL == "" {
		reqURL = "http://example.com"
	}
	req, err := NewRequest("GET", reqURL, nil)
	if err != nil {
		// Without this check a bad table URL would pass a nil *Request on
		// and panic inside proxyForRequest.
		t.Fatalf("%v: NewRequest(%q): %v", tt, reqURL, err)
	}
	// Named u (not url) so the net/url package is not shadowed.
	u, err := proxyForRequest(req)
	if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
		t.Errorf("%v: got error = %q, want %q", tt, g, e)
		return
	}
	// %s rather than u.String(): String would panic on a nil *url.URL,
	// whereas fmt prints "<nil>".
	if got := fmt.Sprintf("%s", u); got != tt.want {
		t.Errorf("%v: got URL = %q, want %q", tt, got, tt.want)
	}
}
3667
// TestProxyFromEnvironment runs proxyFromEnvTests with the proxy settings
// supplied through the upper-case HTTP_PROXY, HTTPS_PROXY and NO_PROXY
// variables.
func TestProxyFromEnvironment(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		resolve := func(req *Request) (*url.URL, error) {
			vars := [][2]string{
				{"HTTP_PROXY", tc.env},
				{"HTTPS_PROXY", tc.httpsenv},
				{"NO_PROXY", tc.noenv},
				{"REQUEST_METHOD", tc.reqmeth},
			}
			for _, kv := range vars {
				os.Setenv(kv[0], kv[1])
			}
			// Make the newly set variables take effect; presumably this
			// drops a cached copy of the proxy environment.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, resolve)
	}
}

// TestProxyFromEnvironmentLowerCase runs the same table with the lower-case
// http_proxy, https_proxy and no_proxy variables (REQUEST_METHOD stays
// upper-case).
func TestProxyFromEnvironmentLowerCase(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		resolve := func(req *Request) (*url.URL, error) {
			vars := [][2]string{
				{"http_proxy", tc.env},
				{"https_proxy", tc.httpsenv},
				{"no_proxy", tc.noenv},
				{"REQUEST_METHOD", tc.reqmeth},
			}
			for _, kv := range vars {
				os.Setenv(kv[0], kv[1])
			}
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, resolve)
	}
}
3697
// TestIdleConnChannelLeak runs testIdleConnChannelLeak serially in HTTP/1 mode.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}

// testIdleConnChannelLeak issues requests to several distinct host names,
// first with keep-alives disabled and then enabled. After each round it
// checks that the Transport's idle-conn wait map (as reported by
// IdleConnWaitMapSizeForTesting) is empty.
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// The handler counts requests under mu; the count is not asserted.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// The read-loop hook signals didRead each time a connection's read loop
	// is about to read its next response; the buffer of nReqs lets the hook
	// fire without blocking.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Route every fake host name to the test server, so each
	// "foo-host-N.tld" reaches ts while the Transport treats it as a
	// different host.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// NOTE(review): the response body is not closed. The handler
			// writes no body, which presumably makes that unnecessary.
		}

		// Wait until each of the nReqs read loops has reported in before
		// inspecting the wait map.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3753
3754
3755
3756
3757 func TestTransportClosesRequestBody(t *testing.T) {
3758 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3759 }
3760 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3761 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3762 io.Copy(io.Discard, r.Body)
3763 })).ts
3764
3765 c := ts.Client()
3766
3767 closes := 0
3768
3769 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3770 if err != nil {
3771 t.Fatal(err)
3772 }
3773 res.Body.Close()
3774 if closes != 1 {
3775 t.Errorf("closes = %d; want 1", closes)
3776 }
3777 }
3778
3779 func TestTransportTLSHandshakeTimeout(t *testing.T) {
3780 defer afterTest(t)
3781 if testing.Short() {
3782 t.Skip("skipping in short mode")
3783 }
3784 ln := newLocalListener(t)
3785 defer ln.Close()
3786 testdonec := make(chan struct{})
3787 defer close(testdonec)
3788
3789 go func() {
3790 c, err := ln.Accept()
3791 if err != nil {
3792 t.Error(err)
3793 return
3794 }
3795 <-testdonec
3796 c.Close()
3797 }()
3798
3799 tr := &Transport{
3800 Dial: func(_, _ string) (net.Conn, error) {
3801 return net.Dial("tcp", ln.Addr().String())
3802 },
3803 TLSHandshakeTimeout: 250 * time.Millisecond,
3804 }
3805 cl := &Client{Transport: tr}
3806 _, err := cl.Get("https://dummy.tld/")
3807 if err == nil {
3808 t.Error("expected error")
3809 return
3810 }
3811 ue, ok := err.(*url.Error)
3812 if !ok {
3813 t.Errorf("expected url.Error; got %#v", err)
3814 return
3815 }
3816 ne, ok := ue.Err.(net.Error)
3817 if !ok {
3818 t.Errorf("expected net.Error; got %#v", err)
3819 return
3820 }
3821 if !ne.Timeout() {
3822 t.Errorf("expected timeout error; got %v", err)
3823 }
3824 if !strings.Contains(err.Error(), "handshake timeout") {
3825 t.Errorf("expected 'handshake timeout' in error; got %v", err)
3826 }
3827 }
3828
3829
func TestTLSServerClosesConnection(t *testing.T) {
	run(t, testTLSServerClosesConnection, []testMode{https1Mode})
}

// testTLSServerClosesConnection repeatedly fetches /keep-alive-then-die and
// then makes a normal GET. The /keep-alive-then-die handler hijacks the TLS
// connection, writes a complete keep-alive response, and closes the
// connection. A failure of the follow-up GET is only recorded and logged;
// the test fails only if that GET fails in every one of the trials.
func testTLSServerClosesConnection(t *testing.T, mode testMode) {
	closedc := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
			conn, _, _ := w.(Hijacker).Hijack()
			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
			conn.Close()
			closedc <- true
			return
		}
		fmt.Fprintf(w, "hello")
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	nSuccess := 0
	var errs []error
	const trials = 20
	for i := 0; i < trials; i++ {
		tr.CloseIdleConnections()
		res, err := c.Get(ts.URL + "/keep-alive-then-die")
		if err != nil {
			t.Fatal(err)
		}
		// Wait until the server has closed its side of the connection.
		<-closedc
		slurp, err := io.ReadAll(res.Body)
		// Always close the body, per the Client/Response contract.
		res.Body.Close()
		if err != nil {
			t.Fatal(err)
		}
		if string(slurp) != "foo" {
			t.Errorf("Got %q, want foo", slurp)
		}

		// The connection used above is now closed by the server; the next
		// request should end up on a working connection. NOTE(review):
		// failures here are tolerated, presumably because the request can
		// race with the Transport noticing the close.
		res, err = c.Get(ts.URL + "/")
		if err != nil {
			errs = append(errs, err)
			continue
		}
		_, err = io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			errs = append(errs, err)
			continue
		}
		nSuccess++
	}
	if nSuccess > 0 {
		t.Logf("successes = %d of %d", nSuccess, trials)
	} else {
		t.Errorf("All runs failed:")
	}
	for _, err := range errs {
		t.Logf("  err: %v", err)
	}
}
3890
3891
3892
3893
3894 type byteFromChanReader chan byte
3895
3896 func (c byteFromChanReader) Read(p []byte) (n int, err error) {
3897 if len(p) == 0 {
3898 return
3899 }
3900 b, ok := <-c
3901 if !ok {
3902 return 0, io.EOF
3903 }
3904 p[0] = b
3905 return 1, nil
3906 }
3907
3908
3909
3910
3911
3912
3913
// TestTransportNoReuseAfterEarlyResponse runs serially in HTTP/1 mode,
// because it modifies the package-level MaxWriteWaitBeforeConnReuse.
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}

// testTransportNoReuseAfterEarlyResponse sends a POST whose final body byte
// is withheld (finalBit), while the hijacking handler answers "foo" right
// away. It then issues a GET that must return "bar"; only after that does it
// release the final POST byte. NOTE(review): per the test name, the GET is
// expected not to reuse the POST's still-writing connection, with
// MaxWriteWaitBeforeConnReuse (shortened to 10ms here) bounding the wait.
// Confirm against transport.go.
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shorten MaxWriteWaitBeforeConnReuse for this test; restore it on exit.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the POST's hijacked server-side conn, guarded by the
	// embedded Mutex, so it can be closed when the test ends.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set once the GET has succeeded; closeConn logs the close
	// only when it has not.
	var getOkay bool
	// willCopy tracks the goroutine that drains the hijacked conn.
	var willCopy sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	// On exit, close the hijacked conn (which ends the drain goroutine's
	// io.Copy) and wait for that goroutine to finish.
	defer func() {
		closeConn()
		willCopy.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// GETs get an ordinary "bar" response.
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POSTs: hijack, remember the conn, reply "foo" immediately
		// without reading the request body, then drain whatever the client
		// keeps sending in the background.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()

		willCopy.Add(1)
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
		go func() {
			io.Copy(io.Discard, conn)
			willCopy.Done()
		}()
	})).ts
	c := ts.Client()

	// A 256 KiB body: bodySize-1 bytes from neverEnding('x'), then a last
	// byte that blocks in finalBit until the test sends it below.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true
	// Release the final POST byte so its body write can complete.
	finalBit <- 'x'
	close(finalBit)
}
3981
3982
3983
3984 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3985 func testTransportIssue10457(t *testing.T, mode testMode) {
3986 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3987
3988
3989
3990
3991
3992 conn, _, _ := w.(Hijacker).Hijack()
3993 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3994 conn.Close()
3995 })).ts
3996 c := ts.Client()
3997
3998 res, err := c.Get(ts.URL)
3999 if err != nil {
4000 t.Fatalf("Get: %v", err)
4001 }
4002 defer res.Body.Close()
4003
4004
4005
4006
4007 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
4008 t.Errorf("Foo header = %q; want %q", got, want)
4009 }
4010 }
4011
4012 type closerFunc func() error
4013
4014 func (f closerFunc) Close() error { return f() }
4015
4016 type writerFuncConn struct {
4017 net.Conn
4018 write func(p []byte) (n int, err error)
4019 }
4020
4021 func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
4022
4023
4024
4025
4026
4027
4028
4029
4030
4031
4032
4033
4034
func TestRetryRequestsOnError(t *testing.T) {
	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
}

// testRetryRequestsOnError checks that the Transport retries a request whose
// write fails on a reused connection.
//
// Each subtest sends the same request three times through a Dial hook.
// Connections from that hook fail the second Write overall, returning
// (failureN, failureErr). The expected event log is:
//   - the first request dials and succeeds;
//   - the second request's write fails on that connection, and the
//     Transport retries it on a newly dialed connection;
//   - the third request reuses the new connection.
//
// Every Response.Request must also be the *Request passed to Do.
func testRetryRequestsOnError(t *testing.T, mode testMode) {
	newRequest := func(method, urlStr string, body io.Reader) *Request {
		req, err := NewRequest(method, urlStr, body)
		if err != nil {
			t.Fatal(err)
		}
		return req
	}

	testCases := []struct {
		name       string
		failureN   int
		failureErr error

		// req is called once per attempt, so every c.Do gets a fresh
		// Request (and an unread Body, where there is one). reqString is
		// the expected request bytes in Go-escaped form, matching the %q
		// used when logging writes.
		req       func() *Request
		reqString string
	}{
		{
			name: "IdempotentNoBodySomeWritten",
			// Claim that a byte was written before the failure, so this
			// is not the "nothing was sent" case.
			failureN: 1,
			// The Transport's server-closed-idle error, exported for tests.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", nil)
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "IdempotentGetBodySomeWritten",
			// Same as above, but the GET carries a body.
			failureN: 1,
			// The Transport's server-closed-idle error, exported for tests.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
		{
			name: "NothingWrittenNoBody",
			// Report that nothing was written, with an arbitrary error.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			req: func() *Request {
				return newRequest("DELETE", "http://fake.golang", nil)
			},
			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "NothingWrittenGetBody",
			// Report that nothing was written, with an arbitrary error.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			// A POST with a body; the expected log still shows a retry,
			// presumably because no bytes were reported written.
			req: func() *Request {
				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var (
				mu     sync.Mutex
				logbuf strings.Builder
			)
			// logf appends one line to the event log compared at the end.
			logf := func(format string, args ...any) {
				mu.Lock()
				defer mu.Unlock()
				fmt.Fprintf(&logbuf, format, args...)
				logbuf.WriteByte('\n')
			}

			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				logf("Handler")
				w.Header().Set("X-Status", "ok")
			})).ts

			// writeNum counts Writes across every connection this subtest
			// dials.
			var writeNum atomic.Int32
			c := ts.Client()
			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
				logf("Dial")
				// Named conn so the client c is not shadowed.
				conn, err := net.Dial(network, ts.Listener.Addr().String())
				if err != nil {
					logf("Dial error: %v", err)
					return nil, err
				}
				return &writerFuncConn{
					Conn: conn,
					write: func(p []byte) (n int, err error) {
						// Fail exactly the second Write overall.
						if writeNum.Add(1) == 2 {
							logf("intentional write failure")
							return tc.failureN, tc.failureErr
						}
						logf("Write(%q)", p)
						return conn.Write(p)
					},
				}, nil
			}

			SetRoundTripRetried(func() {
				logf("Retried.")
			})
			defer SetRoundTripRetried(nil)

			for i := 0; i < 3; i++ {
				t0 := time.Now()
				req := tc.req()
				res, err := c.Do(req)
				if err != nil {
					// A fast failure is a real bug. A slow one likely means the
					// connection was not offered for reuse within
					// MaxWriteWaitBeforeConnReuse, which defeats the scenario.
					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
						mu.Lock()
						got := logbuf.String()
						mu.Unlock()
						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
					}
					// %v so the duration prints with its unit (the old %d
					// printed raw nanoseconds).
					t.Skipf("connection likely wasn't recycled within %v, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
				}
				res.Body.Close()
				if res.Request != req {
					t.Errorf("Response.Request != original request; want identical Request")
				}
			}

			mu.Lock()
			got := logbuf.String()
			mu.Unlock()
			want := fmt.Sprintf(`Dial
Write("%s")
Handler
intentional write failure
Retried.
Dial
Write("%s")
Handler
Write("%s")
Handler
`, tc.reqString, tc.reqString, tc.reqString)
			if got != want {
				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
			}
		})
	}
}
4191
4192
// TestTransportClosesBodyOnError checks that when reading the request body
// fails, Do reports that error and the body is closed.
func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }

// testTransportClosesBodyOnError POSTs a body that yields 1 MiB of data and
// then fakeErr. Three things must hold:
//   - Do fails with an error whose text contains fakeErr's text;
//   - the handler's read of the body fails;
//   - the body's Close has been called.
func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
	// Buffered so the handler can always report its read result.
	readBody := make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := io.ReadAll(r.Body)
		readBody <- err
	})).ts
	c := ts.Client()
	fakeErr := errors.New("fake error")
	// Buffered, with a non-blocking send in Close, so repeated Close calls
	// never block.
	didClose := make(chan bool, 1)
	req, _ := NewRequest("POST", ts.URL, struct {
		io.Reader
		io.Closer
	}{
		// 1<<20 bytes from neverEnding('x'), then a read that fails with fakeErr.
		io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
		closerFunc(func() error {
			select {
			case didClose <- true:
			default:
			}
			return nil
		}),
	})
	// NOTE(review): Expect: 100-continue presumably changes when the client
	// starts sending the body relative to the server's reply. Confirm why this
	// test depends on it.
	req.Header.Set("Expect", "100-continue")
	res, err := c.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
	}
	if err := <-readBody; err == nil {
		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
	}
	select {
	case <-didClose:
	default:
		t.Errorf("didn't see Body.Close")
	}
}
4240
func TestTransportDialTLS(t *testing.T) {
	run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
}

// testTransportDialTLS installs a DialTLS hook that performs the TLS dial
// itself, using the client's TLSClientConfig. It checks that a request both
// went through the hook and reached the handler.
func testTransportDialTLS(t *testing.T, mode testMode) {
	var mu sync.Mutex
	var gotReq, didDial bool

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.DialTLS = func(netw, addr string) (net.Conn, error) {
		mu.Lock()
		didDial = true
		mu.Unlock()
		// tc, not c, so the TLS conn does not shadow the client.
		tc, err := tls.Dial(netw, addr, tr.TLSClientConfig)
		if err != nil {
			return nil, err
		}
		return tc, tc.Handshake()
	}

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// Release the lock on return so that a late handler or dial hook can
	// never block on it.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if !didDial {
		t.Error("didn't use dial hook")
	}
}
4278
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext, http3SkippedMode) }

// testTransportDialContext checks that the context passed to
// Transport.DialContext carries a value stored in the request's context.
func testTransportDialContext(t *testing.T, mode testMode) {
	// The context.WithValue docs say keys should not be a built-in type such
	// as string, to avoid collisions between packages, so use a dedicated key
	// type.
	type contextKey string
	ctxKey := contextKey("some-key")
	ctxValue := "some-value"
	var (
		mu          sync.Mutex
		gotReq      bool
		gotCtxValue any
	)

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		gotCtxValue = ctx.Value(ctxKey)
		mu.Unlock()
		return net.Dial(netw, addr)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// Release the lock on return so a late handler or dial hook cannot block.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if got, want := gotCtxValue, ctxValue; got != want {
		t.Errorf("got context with value %v, want %v", got, want)
	}
}
4320
func TestTransportDialTLSContext(t *testing.T) {
	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
}

// testTransportDialTLSContext checks that the context passed to
// Transport.DialTLSContext carries a value stored in the request's context.
// The hook performs the TLS dial itself using the client's TLSClientConfig.
func testTransportDialTLSContext(t *testing.T, mode testMode) {
	// A dedicated key type rather than a bare string, per the
	// context.WithValue documentation.
	type contextKey string
	ctxKey := contextKey("some-key")
	ctxValue := "some-value"
	var (
		mu          sync.Mutex
		gotReq      bool
		gotCtxValue any
	)

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		gotCtxValue = ctx.Value(ctxKey)
		mu.Unlock()
		// tc, not c, so the TLS conn does not shadow the client.
		tc, err := tls.Dial(netw, addr, tr.TLSClientConfig)
		if err != nil {
			return nil, err
		}
		return tc, tc.HandshakeContext(ctx)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// Release the lock on return so a late handler or dial hook cannot block.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if got, want := gotCtxValue, ctxValue; got != want {
		t.Errorf("got context with value %v, want %v", got, want)
	}
}
4368
4369
4370
// TestRoundTripReturnsProxyError checks that an error returned by the
// Transport's Proxy function is propagated out of RoundTrip.
func TestRoundTripReturnsProxyError(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return nil, errors.New("errorMessage")
		},
	}

	req, _ := NewRequest("GET", "http://example.com", nil)
	if _, err := tr.RoundTrip(req); err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	}
}
4386
4387
4388 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4389 tr := &Transport{}
4390 wantIdle := func(when string, n int) bool {
4391 got := tr.IdleConnCountForTesting("http", "example.com")
4392 if got == n {
4393 return true
4394 }
4395 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4396 return false
4397 }
4398 wantIdle("start", 0)
4399 if !tr.PutIdleTestConn("http", "example.com") {
4400 t.Fatal("put failed")
4401 }
4402 if !tr.PutIdleTestConn("http", "example.com") {
4403 t.Fatal("second put failed")
4404 }
4405 wantIdle("after put", 2)
4406 tr.CloseIdleConnections()
4407 if !tr.IsIdleForTesting() {
4408 t.Error("should be idle after CloseIdleConnections")
4409 }
4410 wantIdle("after close idle", 0)
4411 if tr.PutIdleTestConn("http", "example.com") {
4412 t.Fatal("put didn't fail")
4413 }
4414 wantIdle("after second put", 0)
4415
4416 tr.QueueForIdleConnForTesting()
4417 if tr.IsIdleForTesting() {
4418 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4419 }
4420 if !tr.PutIdleTestConn("http", "example.com") {
4421 t.Fatal("after re-activation")
4422 }
4423 wantIdle("after final put", 1)
4424 }
4425
4426
4427
4428 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4429 tr := &Transport{}
4430 wantIdle := func(when string, n int) bool {
4431 got := tr.IdleConnCountForTesting("https", "example.com:443")
4432 if got == n {
4433 return true
4434 }
4435 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4436 return false
4437 }
4438 wantIdle("start", 0)
4439 alt := funcRoundTripper(func() {})
4440 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4441 t.Fatal("put failed")
4442 }
4443 wantIdle("after put", 1)
4444 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4445 GotConn: func(httptrace.GotConnInfo) {
4446
4447 t.Error("GotConn called")
4448 },
4449 })
4450 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4451 _, err := tr.RoundTrip(req)
4452 if err != errFakeRoundTrip {
4453 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4454 }
4455 wantIdle("after round trip", 1)
4456 }
4457
4458
4459
4460
4461
4462
// TestTransportIdleConnRacesRequest runs under runSynctest in HTTP/1 and
// unencrypted HTTP/2 modes (the latter is currently skipped below).
func TestTransportIdleConnRacesRequest(t *testing.T) {
	runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
}

// testTransportIdleConnRacesRequest uses the fake network, where dials block
// until dialc is closed and connection closes block until closec is closed.
//
// The sequence is:
//  1. A first request is canceled while its dial is blocked, and must fail.
//  2. The dial is released and the fake clock advanced past IdleConnTimeout
//     (1ms).
//  3. A second request is started, then closec is released.
//  4. The second request must succeed.
//
// NOTE(review): presumably, in step 2 the freshly dialed conn entered the
// idle pool and its idle-timeout Close is left blocked on closec while the
// second request arrives. Confirm.
func testTransportIdleConnRacesRequest(t *testing.T, mode testMode) {
	if mode == http2UnencryptedMode {
		t.Skip("remove skip when #70515 is fixed")
	}
	timeout := 1 * time.Millisecond
	trFunc := func(tr *Transport) {
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), trFunc, optFakeNet)
	cst.li.trackConns = true

	// Fake-network hooks: onDial blocks every dial until dialc is closed,
	// and onClose blocks every conn close until closec is closed.
	dialc := make(chan struct{})
	cst.li.onDial = func() {
		<-dialc
	}
	closec := make(chan struct{})
	cst.li.onClose = func(*fakeNetConn) {
		<-closec
	}
	ctx, cancel := context.WithCancel(context.Background())
	req1c := make(chan error)
	go func() {
		req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		resp, err := cst.c.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		req1c <- err
	}()

	// Let the first request block in its dial.
	synctest.Wait()

	// Cancel it while the dial is still blocked; it must fail.
	cancel()
	synctest.Wait()
	if err := <-req1c; err == nil {
		t.Fatal("expected request to fail, but it succeeded")
	}

	// Release the dial, let everything settle, then advance the fake clock
	// past IdleConnTimeout.
	close(dialc)
	synctest.Wait()
	time.Sleep(timeout)
	synctest.Wait()

	req2c := make(chan error)
	go func() {
		resp, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			resp.Body.Close()
		}
		req2c <- err
	}()

	// Unblock pending closes; the second request must still succeed.
	close(closec)
	if err := <-req2c; err != nil {
		t.Fatalf("Get: %v", err)
	}
}
4539
// TestTransportRemovesConnsAfterIdle runs under runSynctest in
// http3SkippedMode.
func TestTransportRemovesConnsAfterIdle(t *testing.T) {
	runSynctest(t, testTransportRemovesConnsAfterIdle, http3SkippedMode)
}

// testTransportRemovesConnsAfterIdle identifies the client connection
// behind each request by the RemoteAddr that the handler echoes in X-Addr.
// A request made half an IdleConnTimeout after the first must reuse its
// connection. A request made a further full IdleConnTimeout later must not.
func testTransportRemovesConnsAfterIdle(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Second
	// At most one connection (and one idle connection) per host, so reuse
	// is observable through the remote address.
	trFunc := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	}), trFunc, optFakeNet)

	// makeRequest does a GET and returns the server-observed remote address
	// of the connection that carried it.
	makeRequest := func() string {
		resp, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
		resp.Body.Close()
		return resp.Header.Get("X-Addr")
	}

	addr1 := makeRequest()

	// Presumably inside the synctest bubble set up by runSynctest, so these
	// sleeps advance the fake clock; synctest.Wait lets timers and
	// goroutines settle before the next request.
	time.Sleep(timeout / 2)
	synctest.Wait()
	addr2 := makeRequest()
	if addr1 != addr2 {
		t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
	}

	time.Sleep(timeout)
	synctest.Wait()
	addr3 := makeRequest()
	if addr1 == addr3 {
		t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
	}
}
4585
// TestTransportRemovesConnsAfterBroken runs under runSynctest in
// http3SkippedMode.
func TestTransportRemovesConnsAfterBroken(t *testing.T) {
	runSynctest(t, testTransportRemovesConnsAfterBroken, http3SkippedMode)
}

// testTransportRemovesConnsAfterBroken checks that back-to-back requests
// share a connection, identified by the RemoteAddr that the handler echoes
// in X-Addr. Once the peer of the first tracked fake conn is closed, the
// next request must use a different connection.
func testTransportRemovesConnsAfterBroken(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// At most one connection (and one idle connection) per host.
	trFunc := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	}), trFunc, optFakeNet)
	// Track fake conns so the test can reach cst.li.conns below.
	cst.li.trackConns = true

	// makeRequest does a GET and returns the server-observed remote address
	// of the connection that carried it.
	makeRequest := func() string {
		resp, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
		resp.Body.Close()
		return resp.Header.Get("X-Addr")
	}

	addr1 := makeRequest()
	addr2 := makeRequest()
	if addr1 != addr2 {
		t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
	}

	// Break the connection: close the peer end of the first tracked fake
	// conn, waiting for goroutines to settle before and after.
	synctest.Wait()
	cst.li.conns[0].peer.Close()
	synctest.Wait()
	addr3 := makeRequest()
	if addr1 == addr3 {
		t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
	}
}
4629
4630
4631
4632
4633
func TestTransportRangeAndGzip(t *testing.T) {
	run(t, testTransportRangeAndGzip)
}

// testTransportRangeAndGzip sends a GET with a Range header. On the server
// side it checks that the Range header arrived and that the Transport did
// not add gzip to Accept-Encoding for this request.
func testTransportRangeAndGzip(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		if ae := r.Header.Get("Accept-Encoding"); strings.Contains(ae, "gzip") {
			t.Error("Transport advertised gzip support in the Accept header")
		}
		if r.Header.Get("Range") == "" {
			t.Error("no Range in request")
		}
	})
	ts := newClientServerTest(t, mode, handler).ts
	client := ts.Client()

	req, _ := NewRequest("GET", ts.URL, nil)
	req.Header.Set("Range", "bytes=7-11")
	res, err := client.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4654
4655
4656 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
4657 func testTransportResponseCancelRace(t *testing.T, mode testMode) {
4658 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4659
4660 var b [1024]byte
4661 w.Write(b[:])
4662 })).ts
4663 tr := ts.Client().Transport.(*Transport)
4664
4665 req, err := NewRequest("GET", ts.URL, nil)
4666 if err != nil {
4667 t.Fatal(err)
4668 }
4669 res, err := tr.RoundTrip(req)
4670 if err != nil {
4671 t.Fatal(err)
4672 }
4673
4674
4675
4676 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4677 t.Fatal(err)
4678 }
4679
4680 req2, err := NewRequest("GET", ts.URL, nil)
4681 if err != nil {
4682 t.Fatal(err)
4683 }
4684 tr.CancelRequest(req)
4685 res, err = tr.RoundTrip(req2)
4686 if err != nil {
4687 t.Fatal(err)
4688 }
4689 res.Body.Close()
4690 }
4691
4692
func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
	run(t, testTransportContentEncodingCaseInsensitive, http3SkippedMode)
}

// testTransportContentEncodingCaseInsensitive serves a gzip-compressed body
// with Content-Encoding spelled "gzip" and then "GZIP". In both cases the
// client must hand back the decompressed text.
func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
	const encodedString = "Hello Gopher"
	for _, encoding := range []string{"gzip", "GZIP"} {
		t.Run(encoding, func(t *testing.T) {
			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				w.Header().Set("Content-Encoding", encoding)
				zw := gzip.NewWriter(w)
				zw.Write([]byte(encodedString))
				zw.Close()
			})).ts

			res, err := ts.Client().Get(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Fatal(err)
			}

			if got := string(body); got != encodedString {
				t.Fatalf("Expected body %q, got: %q\n", encodedString, got)
			}
		})
	}
}
4724
4725
// TestConnClosedBeforeRequestIsWritten runs serially in HTTP/1 mode, since
// it installs a global enter-round-trip hook.
func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
}

// testConnClosedBeforeRequestIsWritten dials connections whose every Read
// and Write fails. It POSTs a countCloseReader body and checks two things:
// the POST fails, and the request body was closed exactly once.
func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
		func(tr *Transport) {
			tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
				// Every read and write on the dialed conn fails immediately.
				return &funcConn{
					read: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
					write: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
				}, nil
			}
		},
	).ts

	// NOTE(review): the 1ms sleep on entering RoundTrip presumably gives the
	// failing conn time to be seen as broken before the request is written.
	// Confirm against where the hook is called in transport.go.
	SetEnterRoundTripHook(func() {
		time.Sleep(1 * time.Millisecond)
	})
	defer SetEnterRoundTripHook(nil)
	var closes int
	_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
	if err == nil {
		t.Fatalf("expected request to fail, but it did not")
	}
	if closes != 1 {
		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
	}
}
4763
4764
4765
4766
4767 type logWritesConn struct {
4768 net.Conn
4769
4770 w io.Writer
4771
4772 rch <-chan io.Reader
4773 r io.Reader
4774
4775 mu sync.Mutex
4776 writes []string
4777 }
4778
4779 func (c *logWritesConn) Write(p []byte) (n int, err error) {
4780 c.mu.Lock()
4781 defer c.mu.Unlock()
4782 c.writes = append(c.writes, string(p))
4783 return c.w.Write(p)
4784 }
4785
4786 func (c *logWritesConn) Read(p []byte) (n int, err error) {
4787 if c.r == nil {
4788 c.r = <-c.rch
4789 }
4790 return c.r.Read(p)
4791 }
4792
4793 func (c *logWritesConn) Close() error { return nil }
4794
4795
// TestTransportFlushesBodyChunks checks that each chunk of a streamed
// (chunked) request body reaches the connection as its own write. The
// logging conn must see, in order:
//   - the request header;
//   - one write per chunk;
//   - the terminating zero-length chunk.
func TestTransportFlushesBodyChunks(t *testing.T) {
	defer afterTest(t)
	// resBody hands logWritesConn the bytes to serve as the response.
	resBody := make(chan io.Reader, 1)
	// Everything the Transport writes to lw goes into connw and can be read
	// back from connr.
	connr, connw := io.Pipe()
	lw := &logWritesConn{
		rch: resBody,
		w:   connw,
	}
	// Every dial returns the same logging conn.
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			return lw, nil
		},
	}
	// The request body is a pipe fed three short lines by a goroutine.
	bodyr, bodyw := io.Pipe()
	go func() {
		defer bodyw.Close()
		for i := 0; i < 3; i++ {
			fmt.Fprintf(bodyw, "num%d\n", i)
		}
	}()
	// Send the request in the background. resc receives the response, or is
	// closed if RoundTrip fails (after the error is reported).
	resc := make(chan *Response)
	go func() {
		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
		req.Header.Set("User-Agent", "x")
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("RoundTrip: %v", err)
			close(resc)
			return
		}
		resc <- res

	}()

	// Act as the server: parse the request off the pipe and drain its body.
	req, err := ReadRequest(bufio.NewReader(connr))
	if err != nil {
		t.Fatal(err)
	}
	io.Copy(io.Discard, req.Body)

	// Supply the response only after the full request has been read.
	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
	res, ok := <-resc
	if !ok {
		return
	}
	defer res.Body.Close()

	// Header, then three 5-byte chunks, then the final zero-size chunk.
	want := []string{
		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
		"5\r\nnum0\n\r\n",
		"5\r\nnum1\n\r\n",
		"5\r\nnum2\n\r\n",
		"0\r\n\r\n",
	}
	// NOTE(review): lw.writes is read here without lw.mu. This relies on
	// every Write having completed before the pipe reads above returned.
	// Confirm.
	if !slices.Equal(lw.writes, want) {
		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
	}
}
4855
4856
4857 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4858 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4859 gotReq := make(chan struct{})
4860 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4861 close(gotReq)
4862 }))
4863
4864 pr, pw := io.Pipe()
4865 req, err := NewRequest("POST", cst.ts.URL, pr)
4866 if err != nil {
4867 t.Fatal(err)
4868 }
4869 gotRes := make(chan struct{})
4870 go func() {
4871 defer close(gotRes)
4872 res, err := cst.tr.RoundTrip(req)
4873 if err != nil {
4874 t.Error(err)
4875 return
4876 }
4877 res.Body.Close()
4878 }()
4879
4880 <-gotReq
4881 pw.Close()
4882 <-gotRes
4883 }
4884
// wgReadCloser wraps an io.Reader. The first Close marks it closed and calls
// wg.Done. Later Close calls return net.ErrClosed and leave wg untouched.
// The closed flag is not synchronized.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

// Close signals wg exactly once; repeated calls report net.ErrClosed.
func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4899
4900
4901 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4902
4903 run(t, testTransportPrefersResponseOverWriteError, testNotParallel, http3SkippedMode)
4904 }
4905 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4906 if testing.Short() {
4907 t.Skip("skipping in short mode")
4908 }
4909
4910 runTimeSensitiveTest(t, []time.Duration{
4911 1 * time.Millisecond,
4912 5 * time.Millisecond,
4913 10 * time.Millisecond,
4914 50 * time.Millisecond,
4915 100 * time.Millisecond,
4916 500 * time.Millisecond,
4917 time.Second,
4918 5 * time.Second,
4919 }, func(t *testing.T, timeout time.Duration) error {
4920 SetRSTAvoidanceDelay(t, timeout)
4921 t.Logf("set RST avoidance delay to %v", timeout)
4922
4923 const contentLengthLimit = 1024 * 1024
4924 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4925 if r.ContentLength >= contentLengthLimit {
4926 w.WriteHeader(StatusBadRequest)
4927 r.Body.Close()
4928 return
4929 }
4930 w.WriteHeader(StatusOK)
4931 }))
4932
4933
4934 defer cst.close()
4935 ts := cst.ts
4936 c := ts.Client()
4937
4938 count := 100
4939
4940 bigBody := strings.Repeat("a", contentLengthLimit*2)
4941 var wg sync.WaitGroup
4942 defer wg.Wait()
4943 getBody := func() (io.ReadCloser, error) {
4944 wg.Add(1)
4945 body := &wgReadCloser{
4946 Reader: strings.NewReader(bigBody),
4947 wg: &wg,
4948 }
4949 return body, nil
4950 }
4951
4952 for i := 0; i < count; i++ {
4953 reqBody, _ := getBody()
4954 req, err := NewRequest("PUT", ts.URL, reqBody)
4955 if err != nil {
4956 reqBody.Close()
4957 t.Fatal(err)
4958 }
4959 req.ContentLength = int64(len(bigBody))
4960 req.GetBody = getBody
4961
4962 resp, err := c.Do(req)
4963 if err != nil {
4964 return fmt.Errorf("Do %d: %v", i, err)
4965 } else {
4966 resp.Body.Close()
4967 if resp.StatusCode != 400 {
4968 t.Errorf("Expected status code 400, got %v", resp.Status)
4969 }
4970 }
4971 }
4972 return nil
4973 })
4974 }
4975
4976 func TestTransportAutomaticHTTP2(t *testing.T) {
4977 testTransportAutoHTTP(t, &Transport{}, true)
4978 }
4979
4980 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4981 testTransportAutoHTTP(t, &Transport{
4982 ForceAttemptHTTP2: true,
4983 TLSClientConfig: new(tls.Config),
4984 }, true)
4985 }
4986
4987
4988 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4989 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4990 }
4991
4992 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4993 testTransportAutoHTTP(t, &Transport{
4994 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4995 }, false)
4996 }
4997
4998 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4999 testTransportAutoHTTP(t, &Transport{
5000 TLSClientConfig: new(tls.Config),
5001 }, false)
5002 }
5003
5004 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
5005 testTransportAutoHTTP(t, &Transport{
5006 ExpectContinueTimeout: 1 * time.Second,
5007 }, true)
5008 }
5009
5010 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
5011 var d net.Dialer
5012 testTransportAutoHTTP(t, &Transport{
5013 Dial: d.Dial,
5014 }, false)
5015 }
5016
5017 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
5018 var d net.Dialer
5019 testTransportAutoHTTP(t, &Transport{
5020 DialContext: d.DialContext,
5021 }, false)
5022 }
5023
5024 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
5025 testTransportAutoHTTP(t, &Transport{
5026 DialTLS: func(network, addr string) (net.Conn, error) {
5027 panic("unused")
5028 },
5029 }, false)
5030 }
5031
5032 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
5033 CondSkipHTTP2(t)
5034 _, err := tr.RoundTrip(new(Request))
5035 if err == nil {
5036 t.Error("expected error from RoundTrip")
5037 }
5038 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
5039 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
5040 }
5041 }
5042
5043
5044
5045
5046
5047
5048
5049
5050 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
5051 run(t, testTransportReuseConnEmptyResponseBody)
5052 }
5053 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
5054 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5055 w.Header().Set("X-Addr", r.RemoteAddr)
5056
5057 }))
5058 n := 100
5059 if testing.Short() {
5060 n = 10
5061 }
5062 var firstAddr string
5063 for i := 0; i < n; i++ {
5064 res, err := cst.c.Get(cst.ts.URL)
5065 if err != nil {
5066 log.Fatal(err)
5067 }
5068 addr := res.Header.Get("X-Addr")
5069 if i == 0 {
5070 firstAddr = addr
5071 } else if addr != firstAddr {
5072 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
5073 }
5074 res.Body.Close()
5075 }
5076 }
5077
5078
// TestNoCrashReturningTransportAltConn covers this sequence:
//   - A request is canceled while DialTLS is still holding a fully handshaken
//     connection.
//   - That connection negotiated an ALPN protocol ("foo") that has a
//     TLSNextProto handler.
//
// The test expects Do to fail with errRequestCanceledConn. It also expects
// the "foo" handler to run (madeRoundTripper) without its RoundTripper ever
// being used for the request. Judging by its name, it is a regression test
// for a crash on this path.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// The pending-dial hooks increment/decrement wg (presumably around each
	// background dial — TODO confirm). The final wg.Wait keeps the test alive
	// until that dial has been fully handled.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	testDone := make(chan struct{})
	defer close(testDone)
	// TLS server offering only the "foo" protocol. It accepts one connection,
	// completes the handshake, and keeps the connection open until the test
	// returns.
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				// The returned RoundTripper must never be used, since the
				// request was canceled before the conn was handed over.
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// Only DialTLS should be used for an https URL.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request, then hold the negotiated conn until Do has
			// returned. The Transport therefore receives it only after the
			// request is gone.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release DialTLS, wait for the "foo" TLSNextProto hook to run, and wait
	// for the pending dial to be accounted for.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
5160
5161 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
5162 run(t, func(t *testing.T, mode testMode) {
5163 testTransportReuseConnection_Gzip(t, mode, true)
5164 })
5165 }
5166
5167 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
5168 run(t, func(t *testing.T, mode testMode) {
5169 testTransportReuseConnection_Gzip(t, mode, false)
5170 })
5171 }
5172
5173
5174 func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
5175 addr := make(chan string, 2)
5176 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5177 addr <- r.RemoteAddr
5178 w.Header().Set("Content-Encoding", "gzip")
5179 if chunked {
5180 w.(Flusher).Flush()
5181 }
5182 w.Write(rgz)
5183 })).ts
5184 c := ts.Client()
5185
5186 trace := &httptrace.ClientTrace{
5187 GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
5188 GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
5189 PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
5190 ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
5191 ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
5192 }
5193 ctx := httptrace.WithClientTrace(context.Background(), trace)
5194
5195 for i := 0; i < 2; i++ {
5196 req, _ := NewRequest("GET", ts.URL, nil)
5197 req = req.WithContext(ctx)
5198 res, err := c.Do(req)
5199 if err != nil {
5200 t.Fatal(err)
5201 }
5202 buf := make([]byte, len(rgz))
5203 if n, err := io.ReadFull(res.Body, buf); err != nil {
5204 t.Errorf("%d. ReadFull = %v, %v", i, n, err)
5205 }
5206
5207
5208
5209 }
5210 a1, a2 := <-addr, <-addr
5211 if a1 != a2 {
5212 t.Fatalf("didn't reuse connection")
5213 }
5214 }
5215
5216 func TestTransportResponseHeaderLength(t *testing.T) {
5217 run(t, testTransportResponseHeaderLength, http3SkippedMode)
5218 }
5219 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
5220 if mode == http2Mode {
5221 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
5222 }
5223 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5224 if r.URL.Path == "/long" {
5225 w.Header().Set("Long", strings.Repeat("a", 1<<20))
5226 }
5227 })).ts
5228 c := ts.Client()
5229 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
5230
5231 if res, err := c.Get(ts.URL); err != nil {
5232 t.Fatal(err)
5233 } else {
5234 res.Body.Close()
5235 }
5236
5237 res, err := c.Get(ts.URL + "/long")
5238 if err == nil {
5239 defer res.Body.Close()
5240 var n int64
5241 for k, vv := range res.Header {
5242 for _, v := range vv {
5243 n += int64(len(k)) + int64(len(v))
5244 }
5245 }
5246 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
5247 }
5248 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
5249 t.Errorf("got error: %v; want %q", err, want)
5250 }
5251 }
5252
5253 func TestTransportEventTrace(t *testing.T) {
5254 run(t, func(t *testing.T, mode testMode) {
5255 testTransportEventTrace(t, mode, false)
5256 }, testNotParallel, http3SkippedMode)
5257 }
5258
5259
5260 func TestTransportEventTrace_NoHooks(t *testing.T) {
5261 run(t, func(t *testing.T, mode testMode) {
5262 testTransportEventTrace(t, mode, true)
5263 }, testNotParallel, http3SkippedMode)
5264 }
5265
// testTransportEventTrace sends a POST with "Expect: 100-continue" to
// "dns-is-faked.golang", which a context-installed fake resolver maps to the
// test server's IP. It logs every httptrace.ClientTrace hook into buf and then
// checks which events appeared, and how often. A follow-up GET checks that
// GetConn fires again.
//
// When noHooks is true, every hook is cleared (an empty ClientTrace is still
// attached). The handler then does not wait for WroteRequest, and only the
// first request's response is checked.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET below only needs an empty 200.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			// Don't respond until the client's WroteRequest hook has fired.
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	var mu sync.Mutex // guards buf
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake resolver: "dns-is-faked.golang" resolves to the server's IP, and
	// any other lookup is reported as an error.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			// Unblocks the handler (buffered, so never blocks here).
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook; the (empty) trace is still attached below.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// No hooks means no trace output to inspect; the successful request
		// above is the whole check for this variant.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// NOTE(review): the remaining header-field checks only run in the
		// non-HTTP/2 branch; it is unclear whether they are truly specific to
		// HTTP/1.
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// The fake resolver must have been used, so no real DNS traffic.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// A second request with the same trace must log GetConn again, giving two
	// "Getting conn" lines in total.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
5447
5448 func TestTransportEventTraceTLSVerify(t *testing.T) {
5449 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5450 }
5451 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5452 var mu sync.Mutex
5453 var buf strings.Builder
5454 logf := func(format string, args ...any) {
5455 mu.Lock()
5456 defer mu.Unlock()
5457 fmt.Fprintf(&buf, format, args...)
5458 buf.WriteByte('\n')
5459 }
5460
5461 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5462 t.Error("Unexpected request")
5463 }), func(ts *httptest.Server) {
5464 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5465 logf("%s", p)
5466 return len(p), nil
5467 }), "", 0)
5468 }).ts
5469
5470 certpool := x509.NewCertPool()
5471 certpool.AddCert(ts.Certificate())
5472
5473 c := &Client{Transport: &Transport{
5474 TLSClientConfig: &tls.Config{
5475 ServerName: "dns-is-faked.golang",
5476 RootCAs: certpool,
5477 },
5478 }}
5479
5480 trace := &httptrace.ClientTrace{
5481 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5482 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5483 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5484 },
5485 }
5486
5487 req, _ := NewRequest("GET", ts.URL, nil)
5488 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5489 _, err := c.Do(req)
5490 if err == nil {
5491 t.Error("Expected request to fail TLS verification")
5492 }
5493
5494 mu.Lock()
5495 got := buf.String()
5496 mu.Unlock()
5497
5498 wantOnce := func(sub string) {
5499 if strings.Count(got, sub) != 1 {
5500 t.Errorf("expected substring %q exactly once in output.", sub)
5501 }
5502 }
5503
5504 wantOnce("TLSHandshakeStart")
5505 wantOnce("TLSHandshakeDone")
5506 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5507
5508 if t.Failed() {
5509 t.Errorf("Output:\n%s", got)
5510 }
5511 }
5512
// isDNSHijacked reports whether the system resolver returns addresses for
// "dns-should-not-resolve.golang", a name that should never resolve. The
// lookup runs at most once, and its result is cached by sync.OnceValue.
var isDNSHijacked = sync.OnceValue(func() bool {
	// The lookup error is deliberately ignored; only whether any address
	// came back matters.
	addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
	return len(addrs) > 0
})
5517
5518 func skipIfDNSHijacked(t *testing.T) {
5519
5520
5521
5522 if isDNSHijacked() {
5523 t.Skip("skipping; test requires non-hijacking DNS server")
5524 }
5525 }
5526
5527 func TestTransportEventTraceRealDNS(t *testing.T) {
5528 skipIfDNSHijacked(t)
5529 defer afterTest(t)
5530 tr := &Transport{}
5531 defer tr.CloseIdleConnections()
5532 c := &Client{Transport: tr}
5533
5534 var mu sync.Mutex
5535 var buf strings.Builder
5536 logf := func(format string, args ...any) {
5537 mu.Lock()
5538 defer mu.Unlock()
5539 fmt.Fprintf(&buf, format, args...)
5540 buf.WriteByte('\n')
5541 }
5542
5543 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5544 trace := &httptrace.ClientTrace{
5545 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5546 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5547 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5548 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5549 }
5550 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5551
5552 resp, err := c.Do(req)
5553 if err == nil {
5554 resp.Body.Close()
5555 t.Fatal("expected error during DNS lookup")
5556 }
5557
5558 mu.Lock()
5559 got := buf.String()
5560 mu.Unlock()
5561
5562 wantSub := func(sub string) {
5563 if !strings.Contains(got, sub) {
5564 t.Errorf("expected substring %q in output.", sub)
5565 }
5566 }
5567 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5568 wantSub("DNSDone: {Addrs:[] Err:")
5569 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5570 t.Errorf("should not see Connect events")
5571 }
5572 if t.Failed() {
5573 t.Errorf("Output:\n%s", got)
5574 }
5575 }
5576
5577
// TestTransportRejectsAlphaPort checks that fetching a URL whose port contains
// non-digit characters fails with a *url.Error. The wrapped error must be the
// exact "invalid port" message.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	urlErr, isURLErr := err.(*url.Error)
	if !isURLErr {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	const want = `invalid port ":123foo" after host`
	if got := urlErr.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
5594
5595
5596
5597 func TestTLSHandshakeTrace(t *testing.T) {
5598 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5599 }
5600 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5601 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5602
5603 var mu sync.Mutex
5604 var start, done bool
5605 trace := &httptrace.ClientTrace{
5606 TLSHandshakeStart: func() {
5607 mu.Lock()
5608 defer mu.Unlock()
5609 start = true
5610 },
5611 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5612 mu.Lock()
5613 defer mu.Unlock()
5614 done = true
5615 if err != nil {
5616 t.Fatal("Expected error to be nil but was:", err)
5617 }
5618 },
5619 }
5620
5621 c := ts.Client()
5622 req, err := NewRequest("GET", ts.URL, nil)
5623 if err != nil {
5624 t.Fatal("Unable to construct test request:", err)
5625 }
5626 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5627
5628 r, err := c.Do(req)
5629 if err != nil {
5630 t.Fatal("Unexpected error making request:", err)
5631 }
5632 r.Body.Close()
5633 mu.Lock()
5634 defer mu.Unlock()
5635 if !start {
5636 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5637 }
5638 if !done {
5639 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5640 }
5641 }
5642
5643 func TestTransportMaxIdleConns(t *testing.T) {
5644 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5645 }
5646 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5647 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5648
5649 })).ts
5650 c := ts.Client()
5651 tr := c.Transport.(*Transport)
5652 tr.MaxIdleConns = 4
5653
5654 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5655 if err != nil {
5656 t.Fatal(err)
5657 }
5658 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5659 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5660 })
5661
5662 hitHost := func(n int) {
5663 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5664 req = req.WithContext(ctx)
5665 res, err := c.Do(req)
5666 if err != nil {
5667 t.Fatal(err)
5668 }
5669 res.Body.Close()
5670 }
5671 for i := 0; i < 4; i++ {
5672 hitHost(i)
5673 }
5674 want := []string{
5675 "|http|host-0.dns-is-faked.golang:" + port,
5676 "|http|host-1.dns-is-faked.golang:" + port,
5677 "|http|host-2.dns-is-faked.golang:" + port,
5678 "|http|host-3.dns-is-faked.golang:" + port,
5679 }
5680 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5681 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5682 }
5683
5684
5685 hitHost(4)
5686 want = []string{
5687 "|http|host-1.dns-is-faked.golang:" + port,
5688 "|http|host-2.dns-is-faked.golang:" + port,
5689 "|http|host-3.dns-is-faked.golang:" + port,
5690 "|http|host-4.dns-is-faked.golang:" + port,
5691 }
5692 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5693 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5694 }
5695 }
5696
5697 func TestTransportIdleConnTimeout(t *testing.T) {
5698 run(t, testTransportIdleConnTimeout, http3SkippedMode)
5699 }
// testTransportIdleConnTimeout checks two things:
//   - One idle connection is reused across three requests spaced timeout/2
//     apart.
//   - That connection is removed from the idle pool once it has been idle
//     longer than IdleConnTimeout.
//
// The test is timing-sensitive. Whenever a request sees the connection go
// away too early, the whole test restarts with a doubled timeout, starting
// at 1ms.
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Millisecond
timeoutLoop:
	for {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Empty response.
		}))
		tr := cst.tr
		tr.IdleConnTimeout = timeout
		// NOTE(review): defer inside the loop; every attempt's Transport is
		// only cleaned up when the test function returns.
		defer tr.CloseIdleConnections()
		c := &Client{Transport: tr}

		idleConns := func() []string {
			return tr.IdleConnStrsForTesting()
		}

		var conn string // the idle conn seen by the first request
		// doReq makes request n and reports false if the timeout looks too
		// short. That covers three cases: the conn was closed mid-request, no
		// idle conn remains, or the idle conn changed.
		doReq := func(n int) (timeoutOk bool) {
			req, _ := NewRequest("GET", cst.ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
				PutIdleConn: func(err error) {
					if err != nil {
						t.Errorf("failed to keep idle conn: %v", err)
					}
				},
			}))
			res, err := c.Do(req)
			if err != nil {
				if strings.Contains(err.Error(), "use of closed network connection") {
					t.Logf("req %v: connection closed prematurely", n)
					return false
				}
				// NOTE(review): other Do errors fall through and are not
				// reported; the idle-conn checks below decide the outcome.
			}
			if err == nil {
				res.Body.Close()
			}
			conns := idleConns()
			if len(conns) != 1 {
				if len(conns) == 0 {
					t.Logf("req %v: no idle conns", n)
					return false
				}
				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
			}
			if conn == "" {
				conn = conns[0]
			}
			if conn != conns[0] {
				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
				return false
			}
			return true
		}
		// Three requests, each timeout/2 apart, must all reuse the same conn.
		for i := 0; i < 3; i++ {
			if !doReq(i) {
				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
				timeout *= 2
				cst.close()
				continue timeoutLoop
			}
			time.Sleep(timeout / 2)
		}

		// Now wait (polling every timeout/2) for the idle conn to be reaped.
		// Log the leftover conns once 1.5× the timeout has elapsed.
		waitCondition(t, timeout/2, func(d time.Duration) bool {
			if got := idleConns(); len(got) != 0 {
				if d >= timeout*3/2 {
					t.Logf("after %v, idle conns = %q", d, got)
				}
				return false
			}
			return true
		})
		break
	}
}
5779
5780
5781
5782
5783
5784
5785
5786
5787
5788
5789
5790
5791 func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
// testIdleConnH2Crash, judging by its name, is a regression test for a crash.
// The scenario is an HTTP/2 connection whose dial completes after its
// request's context was canceled:
//   - DialTLS establishes an h2 TLS conn, cancels ctx, and then holds the conn
//     until Do has returned its error.
//   - The conn is then handed back.
//
// The test finally sleeps for 10× IdleConnTimeout, so that any idle-timeout
// handling of that late connection happens while the test is still running.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Never reached: the request is canceled during the dial.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request now that a usable h2 conn exists...
		cancel()

		// ...and hand the conn back only after Do has returned (or the
		// test is ending).
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the late conn time to be pooled and hit IdleConnTimeout.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5841
// funcConn is a net.Conn whose Read and Write call the supplied functions.
// Close always succeeds. Every other net.Conn method comes from the embedded
// Conn, which may be left nil, in which case those methods must not be called.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read delegates to c.read.
func (c funcConn) Read(p []byte) (int, error) {
	n, err := c.read(p)
	return n, err
}

// Write delegates to c.write.
func (c funcConn) Write(p []byte) (int, error) {
	n, err := c.write(p)
	return n, err
}

// Close is a no-op that always succeeds.
func (c funcConn) Close() error {
	return nil
}
5851
5852
5853
5854 func TestTransportReturnsPeekError(t *testing.T) {
5855 errValue := errors.New("specific error value")
5856
5857 wrote := make(chan struct{})
5858 wroteOnce := sync.OnceFunc(func() { close(wrote) })
5859
5860 tr := &Transport{
5861 Dial: func(network, addr string) (net.Conn, error) {
5862 c := funcConn{
5863 read: func([]byte) (int, error) {
5864 <-wrote
5865 return 0, errValue
5866 },
5867 write: func(p []byte) (int, error) {
5868 wroteOnce()
5869 return len(p), nil
5870 },
5871 }
5872 return c, nil
5873 },
5874 }
5875 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5876 if err != errValue {
5877 t.Errorf("error = %#v; want %v", err, errValue)
5878 }
5879 }
5880
5881
5882 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA, http3SkippedMode) }
5883 func testTransportIDNA(t *testing.T, mode testMode) {
5884 const uniDomain = "гофер.го"
5885 const punyDomain = "xn--c1ae0ajs.xn--c1aw"
5886
5887 var port string
5888 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5889 want := punyDomain + ":" + port
5890 if r.Host != want {
5891 t.Errorf("Host header = %q; want %q", r.Host, want)
5892 }
5893 if mode == http2Mode {
5894 if r.TLS == nil {
5895 t.Errorf("r.TLS == nil")
5896 } else if r.TLS.ServerName != punyDomain {
5897 t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
5898 }
5899 }
5900 w.Header().Set("Hit-Handler", "1")
5901 }), func(tr *Transport) {
5902 if tr.TLSClientConfig != nil {
5903 tr.TLSClientConfig.InsecureSkipVerify = true
5904 }
5905 })
5906
5907 ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
5908 if err != nil {
5909 t.Fatal(err)
5910 }
5911
5912
5913 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
5914 if host != punyDomain {
5915 t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
5916 return nil, nil
5917 }
5918 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5919 })
5920
5921 req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
5922 trace := &httptrace.ClientTrace{
5923 GetConn: func(hostPort string) {
5924 want := net.JoinHostPort(punyDomain, port)
5925 if hostPort != want {
5926 t.Errorf("getting conn for %q; want %q", hostPort, want)
5927 }
5928 },
5929 DNSStart: func(e httptrace.DNSStartInfo) {
5930 if e.Host != punyDomain {
5931 t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
5932 }
5933 },
5934 }
5935 req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
5936
5937 res, err := cst.tr.RoundTrip(req)
5938 if err != nil {
5939 t.Fatal(err)
5940 }
5941 defer res.Body.Close()
5942 if res.Header.Get("Hit-Handler") != "1" {
5943 out, err := httputil.DumpResponse(res, true)
5944 if err != nil {
5945 t.Fatal(err)
5946 }
5947 t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
5948 }
5949 }
5950
5951
5952 func TestTransportProxyConnectHeader(t *testing.T) {
5953 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5954 }
5955 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5956 reqc := make(chan *Request, 1)
5957 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5958 if r.Method != "CONNECT" {
5959 t.Errorf("method = %q; want CONNECT", r.Method)
5960 }
5961 reqc <- r
5962 c, _, err := w.(Hijacker).Hijack()
5963 if err != nil {
5964 t.Errorf("Hijack: %v", err)
5965 return
5966 }
5967 c.Close()
5968 })).ts
5969
5970 c := ts.Client()
5971 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5972 return url.Parse(ts.URL)
5973 }
5974 c.Transport.(*Transport).ProxyConnectHeader = Header{
5975 "User-Agent": {"foo"},
5976 "Other": {"bar"},
5977 }
5978
5979 res, err := c.Get("https://dummy.tld/")
5980 if err == nil {
5981 res.Body.Close()
5982 t.Errorf("unexpected success")
5983 }
5984
5985 r := <-reqc
5986 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5987 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5988 }
5989 if got, want := r.Header.Get("Other"), "bar"; got != want {
5990 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5991 }
5992 }
5993
5994 func TestTransportProxyGetConnectHeader(t *testing.T) {
5995 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5996 }
5997 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5998 reqc := make(chan *Request, 1)
5999 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6000 if r.Method != "CONNECT" {
6001 t.Errorf("method = %q; want CONNECT", r.Method)
6002 }
6003 reqc <- r
6004 c, _, err := w.(Hijacker).Hijack()
6005 if err != nil {
6006 t.Errorf("Hijack: %v", err)
6007 return
6008 }
6009 c.Close()
6010 })).ts
6011
6012 c := ts.Client()
6013 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
6014 return url.Parse(ts.URL)
6015 }
6016
6017 c.Transport.(*Transport).ProxyConnectHeader = Header{
6018 "User-Agent": {"foo"},
6019 "Other": {"bar"},
6020 }
6021 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
6022 return Header{
6023 "User-Agent": {"foo2"},
6024 "Other": {"bar2"},
6025 }, nil
6026 }
6027
6028 res, err := c.Get("https://dummy.tld/")
6029 if err == nil {
6030 res.Body.Close()
6031 t.Errorf("unexpected success")
6032 }
6033
6034 r := <-reqc
6035 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
6036 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
6037 }
6038 if got, want := r.Header.Get("Other"), "bar2"; got != want {
6039 t.Errorf("CONNECT request Other = %q; want %q", got, want)
6040 }
6041 }
6042
// errFakeRoundTrip is the error every funcRoundTripper reports.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the wrapped function only for
// its side effects. It then always fails with errFakeRoundTrip.
type funcRoundTripper func()

// RoundTrip ignores the request, invokes the function, and returns a nil
// response together with errFakeRoundTrip.
func (f funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	f()
	var none *Response
	return none, errFakeRoundTrip
}
6051
// wantBody checks the result of an HTTP request: it returns err unchanged if
// it is non-nil. Otherwise it reads res.Body completely and reports an error
// if reading fails, if the content differs from want, or if closing the body
// fails. The body is closed on every path after a successful request.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		// Close before bailing out so a failed check does not leak the body.
		res.Body.Close()
		return fmt.Errorf("error reading body: %v", err)
	}
	if string(slurp) != want {
		res.Body.Close()
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if err := res.Body.Close(); err != nil {
		return fmt.Errorf("body Close = %v", err)
	}
	return nil
}
6068
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If both fail,
// the test is aborted with the last error.
func newLocalListener(t *testing.T) net.Listener {
	var lastErr error
	for _, cand := range [...][2]string{
		{"tcp", "127.0.0.1:0"},
		{"tcp6", "[::1]:0"},
	} {
		ln, err := net.Listen(cand[0], cand[1])
		if err == nil {
			return ln
		}
		lastErr = err
	}
	t.Fatal(lastErr)
	return nil
}
6079
// countCloseReader pairs an io.Reader with a Close method that only bumps
// the counter n points to. Tests use it to observe how many times a body was
// closed. Reads are served by the embedded Reader.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close records one more close in *cr.n and always succeeds.
func (cr countCloseReader) Close() error {
	*cr.n += 1
	return nil
}
6089
6090
// rgz is a hand-built gzip stream (RFC 1952). Its first bytes are the gzip
// magic 0x1f 0x8b, compression method 0x08 (deflate), and flags 0x08 (FNAME),
// followed by the NUL-terminated member name "recursive" (0x72 ... 0x65 0x00).
// The deflate payload follows.
// NOTE(review): the name suggests a self-reproducing ("recursive") gzip, used
// to check that a body is not decompressed more than once. Its users are
// outside this chunk, so confirm.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
6125
6126
6127
6128 func TestMissingStatusNoPanic(t *testing.T) {
6129 t.Parallel()
6130
6131 const want = "unknown status code"
6132
6133 ln := newLocalListener(t)
6134 addr := ln.Addr().String()
6135 done := make(chan bool)
6136 fullAddrURL := fmt.Sprintf("http://%s", addr)
6137 raw := "HTTP/1.1 400\r\n" +
6138 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6139 "Content-Type: text/html; charset=utf-8\r\n" +
6140 "Content-Length: 10\r\n" +
6141 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
6142 "Vary: Accept-Encoding\r\n\r\n" +
6143 "Aloha Olaa"
6144
6145 go func() {
6146 defer close(done)
6147
6148 conn, _ := ln.Accept()
6149 if conn != nil {
6150 io.WriteString(conn, raw)
6151 io.ReadAll(conn)
6152 conn.Close()
6153 }
6154 }()
6155
6156 proxyURL, err := url.Parse(fullAddrURL)
6157 if err != nil {
6158 t.Fatalf("proxyURL: %v", err)
6159 }
6160
6161 tr := &Transport{Proxy: ProxyURL(proxyURL)}
6162
6163 req, _ := NewRequest("GET", "https://golang.org/", nil)
6164 res, err, panicked := doFetchCheckPanic(tr, req)
6165 if panicked {
6166 t.Error("panicked, expecting an error")
6167 }
6168 if res != nil && res.Body != nil {
6169 io.Copy(io.Discard, res.Body)
6170 res.Body.Close()
6171 }
6172
6173 if err == nil || !strings.Contains(err.Error(), want) {
6174 t.Errorf("got=%v want=%q", err, want)
6175 }
6176
6177 ln.Close()
6178 <-done
6179 }
6180
// doFetchCheckPanic calls tr.RoundTrip(req) and converts a panic into a
// reported flag instead of crashing the test. If RoundTrip panics, res and
// err are nil and panicked is true.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
6190
6191
6192
6193 func TestNoBodyOnChunked304Response(t *testing.T) {
6194 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
6195 }
6196 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
6197 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6198 conn, buf, _ := w.(Hijacker).Hijack()
6199 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
6200 buf.Flush()
6201 conn.Close()
6202 }))
6203
6204
6205
6206
6207
6208 cst.tr.DisableKeepAlives = true
6209
6210 res, err := cst.c.Get(cst.ts.URL)
6211 if err != nil {
6212 t.Fatal(err)
6213 }
6214
6215 if res.Body != NoBody {
6216 t.Errorf("Unexpected body on 304 response")
6217 }
6218 }
6219
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write hands p to the underlying function and returns its results as-is.
func (f funcWriter) Write(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
6223
// doneContext is a context.Context whose Done channel is always already
// closed and whose Err reports the stored err. All other methods come from
// the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a new, already-closed channel, so receiving from it never
// blocks.
func (doneContext) Done() <-chan struct{} {
	closed := make(chan struct{})
	close(closed)
	return closed
}

// Err reports the error the doneContext was built with.
func (d doneContext) Err() error {
	return d.err
}
6236
6237
6238 func TestTransportCheckContextDoneEarly(t *testing.T) {
6239 tr := &Transport{}
6240 req, _ := NewRequest("GET", "http://fake.example/", nil)
6241 wantErr := errors.New("some error")
6242 req = req.WithContext(doneContext{context.Background(), wantErr})
6243 _, err := tr.RoundTrip(req)
6244 if err != wantErr {
6245 t.Errorf("error = %v; want %v", err, wantErr)
6246 }
6247 }
6248
6249
6250
6251
6252
6253
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// expires before any response headers arrive, the client closes the
// connection. The handler hijacks the connection after the request context is
// canceled and expects its next Read to return io.EOF.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a very short client timeout and double it on every retry
	// (see the tooSlow case below).
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Never send headers. Wait until the server cancels the request
			// context, which happens when the client's connection closes.
			<-r.Context().Done()

			// Either the test gave up on this attempt, or it is waiting to
			// learn that the handler ran.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client should already have closed its side, so a read on
			// the hijacked connection must report EOF.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler did not report in within 10x the timeout.
			// NOTE(review): presumably the client timed out before the
			// request reached the handler (for example on a slow machine),
			// so abandon this server and retry with a longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
6312
6313
6314
6315
6316
6317
// TestClientTimeoutKillsConn_AfterHeaders checks that canceling a request
// after its headers have arrived, while the body is incomplete, makes the
// body read fail and makes the client close the connection. The handler's
// Read on the hijacked connection must then return an error.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise 100 body bytes and flush the headers, so the client's Do
		// returns while the body is still incomplete.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		// Proceed only once the test receives from inHandler, which it does
		// after its body read has already failed.
		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// Any error is accepted here, not only io.EOF.
		// NOTE(review): presumably because a client-side close can surface
		// as EOF or as a connection reset, depending on timing and platform.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// NOTE(review): Timeout is long but non-zero, and the request is canceled
	// through req.Cancel instead. Presumably this keeps the Timeout-related
	// client code paths active without making the test wait. Confirm.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Cancel while the body is outstanding; reading the rest must fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Let the handler run its connection check, then wait for it to finish.
	<-inHandler
	<-handlerDone
}
6383
// TestTransportResponseBodyWritableOnProtocolSwitch checks that after a
// "101 Switching Protocols" response, Response.Body implements
// io.ReadWriteCloser and carries the upgraded connection in both directions.
// That includes data the server sent in the same write as the 101 headers.
func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
	run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
}
func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
	// done keeps the handler, and with it the hijacked connection, alive
	// until the test function returns.
	done := make(chan struct{})
	defer close(done)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		// The 101 response and the first line of upgraded-protocol data go
		// out in one write; the client must still see that line through the
		// body. NOTE(review): the mixed-case "upgRADe" token looks deliberate,
		// presumably to check case-insensitive Connection matching.
		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
		// Echo one line back in upper case.
		bs := bufio.NewScanner(conn)
		bs.Scan()
		fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
		<-done
	}))

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req.Header.Set("Upgrade", "foo")
	req.Header.Set("Connection", "upgrade")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 101 {
		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
	}
	// The upgraded connection is exposed through the response body.
	rwc, ok := res.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
	}
	defer rwc.Close()
	bs := bufio.NewScanner(rwc)
	if !bs.Scan() {
		t.Fatalf("expected readable input")
	}
	if got, want := bs.Text(), "Some buffered data"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}
	// Write through the body and expect the server's upper-cased echo.
	io.WriteString(rwc, "echo\n")
	if !bs.Scan() {
		t.Fatalf("expected another line")
	}
	if got, want := bs.Text(), "ECHO"; got != want {
		t.Errorf("read %q; want %q", got, want)
	}
}
6434
// TestTransportCONNECTBidi checks that a client-issued CONNECT request
// streams in both directions. The request body, an io.Pipe, keeps being
// written after the 200 response has been received, and each line written is
// echoed back upper-cased through the response body.
func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
func testTransportCONNECTBidi(t *testing.T, mode testMode) {
	const target = "backend:443"
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("unexpected method %q", r.Method)
			w.WriteHeader(500)
			return
		}
		if r.RequestURI != target {
			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
			w.WriteHeader(500)
			return
		}
		// Accept the tunnel by hand, then echo each line upper-cased until
		// the client closes its side.
		nc, brw, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer nc.Close()
		nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))

		for {
			line, err := brw.ReadString('\n')
			if err != nil {
				if err != io.EOF {
					t.Error(err)
				}
				return
			}
			io.WriteString(brw, strings.ToUpper(line))
			brw.Flush()
		}
	}))
	// The request body is a pipe, so the test can keep writing tunnel data
	// after Do has returned.
	pr, pw := io.Pipe()
	defer pw.Close()
	req, err := NewRequest("CONNECT", cst.ts.URL, pr)
	if err != nil {
		t.Fatal(err)
	}
	// Opaque sets the CONNECT target (the request-URI) to "backend:443".
	req.URL.Opaque = target
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Fatalf("status code = %d; want 200", res.StatusCode)
	}
	br := bufio.NewReader(res.Body)
	for _, str := range []string{"foo", "bar", "baz"} {
		fmt.Fprintf(pw, "%s\n", str)
		got, err := br.ReadString('\n')
		if err != nil {
			t.Fatal(err)
		}
		got = strings.TrimSpace(got)
		want := strings.ToUpper(str)
		if got != want {
			t.Fatalf("got %q; want %q", got, want)
		}
	}
}
6498
6499 func TestTransportRequestReplayable(t *testing.T) {
6500 someBody := io.NopCloser(strings.NewReader(""))
6501 tests := []struct {
6502 name string
6503 req *Request
6504 want bool
6505 }{
6506 {
6507 name: "GET",
6508 req: &Request{Method: "GET"},
6509 want: true,
6510 },
6511 {
6512 name: "GET_http.NoBody",
6513 req: &Request{Method: "GET", Body: NoBody},
6514 want: true,
6515 },
6516 {
6517 name: "GET_body",
6518 req: &Request{Method: "GET", Body: someBody},
6519 want: false,
6520 },
6521 {
6522 name: "POST",
6523 req: &Request{Method: "POST"},
6524 want: false,
6525 },
6526 {
6527 name: "POST_idempotency-key",
6528 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6529 want: true,
6530 },
6531 {
6532 name: "POST_x-idempotency-key",
6533 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6534 want: true,
6535 },
6536 {
6537 name: "POST_body",
6538 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6539 want: false,
6540 },
6541 }
6542 for _, tt := range tests {
6543 t.Run(tt.name, func(t *testing.T) {
6544 got := tt.req.ExportIsReplayable()
6545 if got != tt.want {
6546 t.Errorf("replyable = %v; want %v", got, tt.want)
6547 }
6548 })
6549 }
6550 }
6551
6552
6553
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method was used. Tests use it to tell whether a writer took the
// io.ReaderFrom path instead of plain Write calls.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool
}

// ReadFrom marks the connection as used via ReadFrom, then delegates to the
// wrapped TCPConn.
func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
	c.ReadFromCalled = true
	n, err := c.TCPConn.ReadFrom(r)
	return n, err
}
6564
6565 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6566 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6567 nBytes := int64(1 << 10)
6568 newFileFunc := func() (r io.Reader, done func(), err error) {
6569 f, err := os.CreateTemp("", "net-http-newfilefunc")
6570 if err != nil {
6571 return nil, nil, err
6572 }
6573
6574
6575 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6576 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6577 }
6578 if _, err := f.Seek(0, 0); err != nil {
6579 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6580 }
6581
6582 done = func() {
6583 f.Close()
6584 os.Remove(f.Name())
6585 }
6586
6587 return f, done, nil
6588 }
6589
6590 newBufferFunc := func() (io.Reader, func(), error) {
6591 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6592 }
6593
6594 cases := []struct {
6595 name string
6596 readerFunc func() (io.Reader, func(), error)
6597 contentLength int64
6598 expectedReadFrom bool
6599 }{
6600 {
6601 name: "file, length",
6602 readerFunc: newFileFunc,
6603 contentLength: nBytes,
6604 expectedReadFrom: true,
6605 },
6606 {
6607 name: "file, no length",
6608 readerFunc: newFileFunc,
6609 },
6610 {
6611 name: "file, negative length",
6612 readerFunc: newFileFunc,
6613 contentLength: -1,
6614 },
6615 {
6616 name: "buffer",
6617 contentLength: nBytes,
6618 readerFunc: newBufferFunc,
6619 },
6620 {
6621 name: "buffer, no length",
6622 readerFunc: newBufferFunc,
6623 },
6624 {
6625 name: "buffer, length -1",
6626 contentLength: -1,
6627 readerFunc: newBufferFunc,
6628 },
6629 }
6630
6631 for _, tc := range cases {
6632 t.Run(tc.name, func(t *testing.T) {
6633 r, cleanup, err := tc.readerFunc()
6634 if err != nil {
6635 t.Fatal(err)
6636 }
6637 defer cleanup()
6638
6639 tConn := &testMockTCPConn{}
6640 trFunc := func(tr *Transport) {
6641 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6642 var d net.Dialer
6643 conn, err := d.DialContext(ctx, network, addr)
6644 if err != nil {
6645 return nil, err
6646 }
6647
6648 tcpConn, ok := conn.(*net.TCPConn)
6649 if !ok {
6650 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6651 }
6652
6653 tConn.TCPConn = tcpConn
6654 return tConn, nil
6655 }
6656 }
6657
6658 cst := newClientServerTest(
6659 t,
6660 mode,
6661 HandlerFunc(func(w ResponseWriter, r *Request) {
6662 io.Copy(io.Discard, r.Body)
6663 r.Body.Close()
6664 w.WriteHeader(200)
6665 }),
6666 trFunc,
6667 )
6668
6669 req, err := NewRequest("PUT", cst.ts.URL, r)
6670 if err != nil {
6671 t.Fatal(err)
6672 }
6673 req.ContentLength = tc.contentLength
6674 req.Header.Set("Content-Type", "application/octet-stream")
6675 resp, err := cst.c.Do(req)
6676 if err != nil {
6677 t.Fatal(err)
6678 }
6679 defer resp.Body.Close()
6680 if resp.StatusCode != 200 {
6681 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6682 }
6683
6684 expectedReadFrom := tc.expectedReadFrom
6685 if mode != http1Mode {
6686 expectedReadFrom = false
6687 }
6688 if !tConn.ReadFromCalled && expectedReadFrom {
6689 t.Fatalf("did not call ReadFrom")
6690 }
6691
6692 if tConn.ReadFromCalled && !expectedReadFrom {
6693 t.Fatalf("ReadFrom was unexpectedly invoked")
6694 }
6695 })
6696 }
6697 }
6698
6699 func TestTransportClone(t *testing.T) {
6700 tr := &Transport{
6701 Proxy: func(*Request) (*url.URL, error) { panic("") },
6702 OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
6703 return nil
6704 },
6705 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
6706 Dial: func(network, addr string) (net.Conn, error) { panic("") },
6707 DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
6708 DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
6709 TLSClientConfig: new(tls.Config),
6710 TLSHandshakeTimeout: time.Second,
6711 DisableKeepAlives: true,
6712 DisableCompression: true,
6713 MaxIdleConns: 1,
6714 MaxIdleConnsPerHost: 1,
6715 MaxConnsPerHost: 1,
6716 IdleConnTimeout: time.Second,
6717 ResponseHeaderTimeout: time.Second,
6718 ExpectContinueTimeout: time.Second,
6719 ProxyConnectHeader: Header{},
6720 GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
6721 MaxResponseHeaderBytes: 1,
6722 ForceAttemptHTTP2: true,
6723 HTTP2: &HTTP2Config{MaxConcurrentStreams: 1},
6724 Protocols: &Protocols{},
6725 TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
6726 "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
6727 },
6728 ReadBufferSize: 1,
6729 WriteBufferSize: 1,
6730 }
6731 tr.Protocols.SetHTTP1(true)
6732 tr.Protocols.SetHTTP2(true)
6733 tr2 := tr.Clone()
6734 rv := reflect.ValueOf(tr2).Elem()
6735 rt := rv.Type()
6736 for i := 0; i < rt.NumField(); i++ {
6737 sf := rt.Field(i)
6738 if !token.IsExported(sf.Name) {
6739 continue
6740 }
6741 if rv.Field(i).IsZero() {
6742 t.Errorf("cloned field t2.%s is zero", sf.Name)
6743 }
6744 }
6745
6746 if _, ok := tr2.TLSNextProto["foo"]; !ok {
6747 t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
6748 }
6749
6750
6751 tr = new(Transport)
6752 tr2 = tr.Clone()
6753 if tr2.TLSNextProto != nil {
6754 t.Errorf("Transport.TLSNextProto unexpected non-nil")
6755 }
6756 }
6757
6758 func TestIs408(t *testing.T) {
6759 tests := []struct {
6760 in string
6761 want bool
6762 }{
6763 {"HTTP/1.0 408", true},
6764 {"HTTP/1.1 408", true},
6765 {"HTTP/1.8 408", true},
6766 {"HTTP/2.0 408", false},
6767 {"HTTP/1.1 408 ", true},
6768 {"HTTP/1.1 40", false},
6769 {"http/1.0 408", false},
6770 {"HTTP/1-1 408", false},
6771 }
6772 for _, tt := range tests {
6773 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6774 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6775 }
6776 }
6777 }
6778
6779 func TestTransportIgnores408(t *testing.T) {
6780 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6781 }
6782 func testTransportIgnores408(t *testing.T, mode testMode) {
6783
6784 defer log.SetOutput(log.Writer())
6785
6786 var logout strings.Builder
6787 log.SetOutput(&logout)
6788
6789 const target = "backend:443"
6790
6791 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6792 nc, _, err := w.(Hijacker).Hijack()
6793 if err != nil {
6794 t.Error(err)
6795 return
6796 }
6797 defer nc.Close()
6798 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6799 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6800 }))
6801 req, err := NewRequest("GET", cst.ts.URL, nil)
6802 if err != nil {
6803 t.Fatal(err)
6804 }
6805 res, err := cst.c.Do(req)
6806 if err != nil {
6807 t.Fatal(err)
6808 }
6809 slurp, err := io.ReadAll(res.Body)
6810 if err != nil {
6811 t.Fatal(err)
6812 }
6813 if err != nil {
6814 t.Fatal(err)
6815 }
6816 if string(slurp) != "ok" {
6817 t.Fatalf("got %q; want ok", slurp)
6818 }
6819
6820 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6821 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6822 if d > 0 {
6823 t.Logf("%v idle conns still present after %v", n, d)
6824 }
6825 return false
6826 }
6827 return true
6828 })
6829 if got := logout.String(); got != "" {
6830 t.Fatalf("expected no log output; got: %s", got)
6831 }
6832 }
6833
6834 func TestInvalidHeaderResponse(t *testing.T) {
6835 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6836 }
6837 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6838 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6839 conn, buf, _ := w.(Hijacker).Hijack()
6840 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6841 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6842 "Content-Type: text/html; charset=utf-8\r\n" +
6843 "Content-Length: 0\r\n" +
6844 "Foo : bar\r\n\r\n"))
6845 buf.Flush()
6846 conn.Close()
6847 }))
6848 res, err := cst.c.Get(cst.ts.URL)
6849 if err != nil {
6850 t.Fatal(err)
6851 }
6852 defer res.Body.Close()
6853 if v := res.Header.Get("Foo"); v != "" {
6854 t.Errorf(`unexpected "Foo" header: %q`, v)
6855 }
6856 if v := res.Header.Get("Foo "); v != "bar" {
6857 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6858 }
6859 }
6860
// bodyCloser is an always-empty request body that records in its own value
// whether Close has been called.
type bodyCloser bool

// Close sets the flag; it never fails.
func (bc *bodyCloser) Close() error {
	*bc = bodyCloser(true)
	return nil
}

// Read never yields data: every call reports io.EOF immediately.
func (bc *bodyCloser) Read(b []byte) (n int, err error) {
	n, err = 0, io.EOF
	return n, err
}
6870
6871
6872
6873 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6874 run(t, testTransportClosesBodyOnInvalidRequests)
6875 }
6876 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6877 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6878 t.Errorf("Should not have been invoked")
6879 })).ts
6880
6881 u, _ := url.Parse(cst.URL)
6882
6883 tests := []struct {
6884 name string
6885 req *Request
6886 wantErr string
6887 }{
6888 {
6889 name: "invalid method",
6890 req: &Request{
6891 Method: " ",
6892 URL: u,
6893 },
6894 wantErr: `invalid method " "`,
6895 },
6896 {
6897 name: "nil URL",
6898 req: &Request{
6899 Method: "GET",
6900 },
6901 wantErr: `nil Request.URL`,
6902 },
6903 {
6904 name: "invalid header key",
6905 req: &Request{
6906 Method: "GET",
6907 Header: Header{"💡": {"emoji"}},
6908 URL: u,
6909 },
6910 wantErr: `invalid header field name "💡"`,
6911 },
6912 {
6913 name: "invalid header value",
6914 req: &Request{
6915 Method: "POST",
6916 Header: Header{"key": {"\x19"}},
6917 URL: u,
6918 },
6919 wantErr: `invalid header field value for "key"`,
6920 },
6921 {
6922 name: "non HTTP(s) scheme",
6923 req: &Request{
6924 Method: "POST",
6925 URL: &url.URL{Scheme: "faux"},
6926 },
6927 wantErr: `unsupported protocol scheme "faux"`,
6928 },
6929 {
6930 name: "no Host in URL",
6931 req: &Request{
6932 Method: "POST",
6933 URL: &url.URL{Scheme: "http"},
6934 },
6935 wantErr: `no Host in request URL`,
6936 },
6937 }
6938
6939 for _, tt := range tests {
6940 t.Run(tt.name, func(t *testing.T) {
6941 var bc bodyCloser
6942 req := tt.req
6943 req.Body = &bc
6944 _, err := cst.Client().Do(tt.req)
6945 if err == nil {
6946 t.Fatal("Expected an error")
6947 }
6948 if !bc {
6949 t.Fatal("Expected body to have been closed")
6950 }
6951 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6952 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6953 }
6954 })
6955 }
6956 }
6957
6958
6959
// breakableConn wraps a net.Conn whose writes can be made to fail on demand
// through a shared brokenState.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag that may be shared by several
// breakableConns. While broken is true, their writes fail.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with "some write error" while the shared state is broken and
// otherwise forwards to the wrapped connection. The shared lock is held for
// the whole call, including the underlying Write.
func (c *breakableConn) Write(b []byte) (n int, err error) {
	c.Lock()
	defer c.Unlock()
	if !c.broken {
		return c.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6978
6979
// TestDontCacheBrokenHTTP2Conn makes numReqs sequential requests. For every
// request but the last, writes on the connection are made to fail as soon as
// the TLS handshake completes, so those requests are expected to fail and the
// last one to succeed. The test then expects one dial per request (a broken
// connection is never reused) and exactly one GotConn event.
func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
	run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
}
func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)

	// Shared by every connection dialed below; flipping broken makes their
	// writes fail.
	var brokenState brokenState

	const numReqs = 5
	var numDials, gotConns uint32 // updated atomically from trace/dial hooks

	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
		atomic.AddUint32(&numDials, 1)
		c, err := net.Dial(netw, addr)
		if err != nil {
			t.Errorf("unexpected Dial error: %v", err)
			return nil, err
		}
		return &breakableConn{c, &brokenState}, err
	}

	for i := 1; i <= numReqs; i++ {
		brokenState.Lock()
		brokenState.broken = false
		brokenState.Unlock()

		// Break the connection for every request except the last one.
		doBreak := i != numReqs

		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(info httptrace.GotConnInfo) {
				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
				atomic.AddUint32(&gotConns, 1)
			},
			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
				// Break writes right after the handshake, so the failure hits
				// the first writes after TLS.
				brokenState.Lock()
				defer brokenState.Unlock()
				if doBreak {
					brokenState.broken = true
				}
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		_, err = cst.c.Do(req)
		if doBreak != (err != nil) {
			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
		}
	}
	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
		t.Errorf("GotConn calls = %v; want %v", got, want)
	}
	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
		t.Errorf("Dials = %v; want %v", got, want)
	}
}
7040
7041
7042
7043
7044
7045 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
7046 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
7047 }
7048 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
7049 CondSkipHTTP2(t)
7050
7051 h := HandlerFunc(func(w ResponseWriter, r *Request) {
7052 _, err := w.Write([]byte("foo"))
7053 if err != nil {
7054 t.Fatalf("Write: %v", err)
7055 }
7056 })
7057
7058 ts := newClientServerTest(t, mode, h).ts
7059
7060 c := ts.Client()
7061 tr := c.Transport.(*Transport)
7062 tr.MaxConnsPerHost = 1
7063
7064 errCh := make(chan error, 300)
7065 doReq := func() {
7066 resp, err := c.Get(ts.URL)
7067 if err != nil {
7068 errCh <- fmt.Errorf("request failed: %v", err)
7069 return
7070 }
7071 defer resp.Body.Close()
7072 _, err = io.ReadAll(resp.Body)
7073 if err != nil {
7074 errCh <- fmt.Errorf("read body failed: %v", err)
7075 }
7076 }
7077
7078 var wg sync.WaitGroup
7079 for i := 0; i < 300; i++ {
7080 wg.Add(1)
7081 go func() {
7082 defer wg.Done()
7083 doReq()
7084 }()
7085 }
7086 wg.Wait()
7087 close(errCh)
7088
7089 for err := range errCh {
7090 t.Errorf("error occurred: %v", err)
7091 }
7092 }
7093
7094
7095
7096
7097 func TestAltProtoCancellation(t *testing.T) {
7098 defer afterTest(t)
7099 tr := &Transport{}
7100 c := &Client{
7101 Transport: tr,
7102 Timeout: time.Millisecond,
7103 }
7104 tr.RegisterProtocol("cancel", cancelProto{})
7105 _, err := c.Get("cancel://bar.com/path")
7106 if err == nil {
7107 t.Error("request unexpectedly succeeded")
7108 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
7109 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
7110 }
7111 }
7112
// errCancelProto is what cancelProto reports once its request is canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never finishes a request on its own. It
// waits for the request's Cancel channel to become readable, then fails with
// errCancelProto.
type cancelProto struct{}

func (cancelProto) RoundTrip(req *Request) (*Response, error) {
	cancel := req.Cancel
	<-cancel
	return nil, errCancelProto
}
7121
// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip forwards r to f and returns f's results unchanged.
func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
	return f(r)
}
7125
7126
7127 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
7128 func testIssue32441(t *testing.T, mode testMode) {
7129 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7130 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
7131 t.Error("body length is zero")
7132 }
7133 })).ts
7134 c := ts.Client()
7135 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
7136
7137 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
7138 t.Error("body length is zero during round trip")
7139 }
7140 return nil, ErrSkipAltProtocol
7141 }))
7142 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
7143 t.Error(err)
7144 }
7145 }
7146
7147
7148
7149 func TestTransportRejectsSignInContentLength(t *testing.T) {
7150 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
7151 }
7152 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
7153 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7154 w.Header().Set("Content-Length", "+3")
7155 w.Write([]byte("abc"))
7156 })).ts
7157
7158 c := cst.Client()
7159 res, err := c.Get(cst.URL)
7160 if err == nil || res != nil {
7161 t.Fatal("Expected a non-nil error and a nil http.Response")
7162 }
7163 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
7164 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
7165 }
7166 }
7167
7168
// dumpConn is a minimal net.Conn. Writes go to the embedded Writer and reads
// come from the embedded Reader. Close and the deadline setters do nothing and
// report success, and both addresses are nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

// Compile-time check that *dumpConn satisfies net.Conn.
var _ net.Conn = (*dumpConn)(nil)

func (*dumpConn) Close() error                       { return nil }
func (*dumpConn) LocalAddr() net.Addr                { return nil }
func (*dumpConn) RemoteAddr() net.Addr               { return nil }
func (*dumpConn) SetDeadline(time.Time) error        { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error    { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error   { return nil }
7180
7181
7182
// delegateReader is an io.Reader whose real source arrives later over c.
// The first Read waits for a reader on c; every Read after that forwards to
// the received reader. If c is closed before a reader arrives, Read fails.
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

// Read forwards to the delegate, first receiving it from c if none has been
// received yet. It returns an error if c turns out to be closed.
func (r *delegateReader) Read(p []byte) (int, error) {
	if r.r != nil {
		return r.r.Read(p)
	}
	src, ok := <-r.c
	r.r = src
	if !ok {
		return 0, errors.New("delegate closed")
	}
	return src.Read(p)
}
7197
// testTransportRace performs one RoundTrip of req over a fake connection and
// then restores req.Body to the value it had on entry. It takes no
// *testing.T and ignores RoundTrip's result. NOTE(review): callers appear to
// rely on panics or the race detector to surface failures; confirm.
//
// The fake connection is a dumpConn. Everything the Transport writes goes into
// an io.Pipe that a helper goroutine parses with ReadRequest. Reads come from
// a delegateReader, which the helper feeds a canned
// "204 No Content, Connection: close" response after consuming the request.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	// Note: t is the *Transport under test, not a *testing.T. Every dial
	// returns a dumpConn over the same pipe writer and delegateReader.
	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	// quitReadCh is closed when the helper goroutine exits. The helper may
	// also send on it once (see the select below).
	quitReadCh := make(chan struct{})

	go func() {
		defer close(quitReadCh)

		// This req is the request parsed from the pipe; it shadows the
		// function parameter.
		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// io.Pipe writes block until they are read, so drain the body
			// to let the Transport finish writing it.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// The main goroutine receives this send only after RoundTrip has
			// returned and the pipe writer is closed. Close dr.c so any
			// pending or later Read on dr fails instead of blocking forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Closing the pipe writer makes the helper's pipe reads return EOF, so
	// it can reach the select above. Then wait for the helper to finish.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
7241
7242
7243
7244
7245
7246 func TestErrorWriteLoopRace(t *testing.T) {
7247 if testing.Short() {
7248 return
7249 }
7250 t.Parallel()
7251 for i := 0; i < 1000; i++ {
7252 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
7253 ctx, cancel := context.WithTimeout(context.Background(), delay)
7254 defer cancel()
7255
7256 r := bytes.NewBuffer(make([]byte, 10000))
7257 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
7258 if err != nil {
7259 t.Fatal(err)
7260 }
7261
7262 testTransportRace(req)
7263 }
7264 }
7265
7266
7267
7268
// TestCancelRequestWhenSharingConnection runs two requests through a Transport
// limited to a single connection. Request 1's PutIdleConn trace hook is held
// open while request 2 reaches the server on that connection. Request 2 is
// then canceled. The test checks that request 2 reports context.Canceled and
// that everything finishes once the hook is released.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// For each request that reaches the handler, reqc receives a channel. The
	// handler blocks until the test closes that channel.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	// At most one connection, so both requests must use the same one.
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Hand the test a channel and block in this hook until the
				// test (via request 2's goroutine) closes it. putidlec is
				// closed after the single send, so a second PutIdleConn call
				// for request 1 would panic; the test expects exactly one.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Let request 1's handler respond, then wait until request 1's
	// PutIdleConn hook has been entered (and is now blocked).
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Only after request 2 has returned: release request 1's blocked
		// PutIdleConn hook and the main goroutine's <-idlec.
		close(idlec)
	}()

	// Wait for request 2 to reach the handler, then cancel it while that
	// handler is still blocked.
	r2c := <-reqc
	cancel()

	// idlec closes only once request 2 has returned; then let its handler finish.
	<-idlec

	close(r2c)
	wg.Wait()
}
7354
7355 func TestHandlerAbortRacesBodyRead(t *testing.T) {
7356 run(t, testHandlerAbortRacesBodyRead, http3SkippedMode)
7357 }
7358 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
7359 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7360 go io.Copy(io.Discard, req.Body)
7361 panic(ErrAbortHandler)
7362 })).ts
7363
7364 var wg sync.WaitGroup
7365 for i := 0; i < 2; i++ {
7366 wg.Add(1)
7367 go func() {
7368 defer wg.Done()
7369 for j := 0; j < 10; j++ {
7370 const reqLen = 6 * 1024 * 1024
7371 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
7372 req.ContentLength = reqLen
7373 resp, _ := ts.Client().Transport.RoundTrip(req)
7374 if resp != nil {
7375 resp.Body.Close()
7376 }
7377 }
7378 }()
7379 }
7380 wg.Wait()
7381 }
7382
7383 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
7384 func testRequestSanitization(t *testing.T, mode testMode) {
7385 if mode == http2Mode {
7386
7387 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
7388 }
7389 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7390 if h, ok := req.Header["X-Evil"]; ok {
7391 t.Errorf("request has X-Evil header: %q", h)
7392 }
7393 })).ts
7394 req, _ := NewRequest("GET", ts.URL, nil)
7395 req.Host = "go.dev\r\nX-Evil:evil"
7396 resp, _ := ts.Client().Do(req)
7397 if resp != nil {
7398 resp.Body.Close()
7399 }
7400 }
7401
7402 func TestProxyAuthHeader(t *testing.T) {
7403
7404 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
7405 }
7406 func testProxyAuthHeader(t *testing.T, mode testMode) {
7407 const username = "u"
7408 const password = "@/?!"
7409 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7410
7411
7412 var r2 Request
7413 r2.Header = Header{
7414 "Authorization": req.Header["Proxy-Authorization"],
7415 }
7416 gotuser, gotpass, ok := r2.BasicAuth()
7417 if !ok || gotuser != username || gotpass != password {
7418 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
7419 }
7420 }))
7421 u, err := url.Parse(cst.ts.URL)
7422 if err != nil {
7423 t.Fatal(err)
7424 }
7425 u.User = url.UserPassword(username, password)
7426 t.Setenv("HTTP_PROXY", u.String())
7427 cst.tr.Proxy = ProxyURL(u)
7428 resp, err := cst.c.Get("http://_/")
7429 if err != nil {
7430 t.Fatal(err)
7431 }
7432 resp.Body.Close()
7433 }
7434
7435
// TestTransportReqCancelerCleanupOnRequestBodyWriteError sends a POST with a
// huge body to a raw listener. The listener responds after reading a single
// byte and then closes the connection. The test checks that, after the
// response body is closed, the Transport eventually reports no pending
// requests (NumPendingRequestsForTesting == 0).
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	// done tells the server goroutine to close its connection.
	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}

		// Read just one byte of the request, then respond without consuming
		// the rest of the body.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// Signaled (unbuffered) each time the read loop hook fires.
	// NOTE(review): presumably the hook runs before the Transport's read loop
	// attempts its next read; confirm in transport.go.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// A 1 GiB body, far more than the server reads, so the body write is
	// presumably still in progress when the server closes the connection.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	close(done)

	// The server now closes the connection. Wait for the read loop to reach
	// its hook before closing the response body.
	<-didRead

	resp.Body.Close()

	// Poll until the Transport no longer tracks any pending request, logging
	// the count on retries.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7496
7497 func TestValidateClientRequestTrailers(t *testing.T) {
7498 run(t, testValidateClientRequestTrailers)
7499 }
7500
7501 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7502 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7503 rw.Write([]byte("Hello"))
7504 })).ts
7505
7506 cases := []struct {
7507 trailer Header
7508 wantErr string
7509 }{
7510 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7511 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7512 }
7513
7514 for i, tt := range cases {
7515 testName := fmt.Sprintf("%s%d", mode, i)
7516 t.Run(testName, func(t *testing.T) {
7517 req, err := NewRequest("GET", cst.URL, nil)
7518 if err != nil {
7519 t.Fatal(err)
7520 }
7521 req.Trailer = tt.trailer
7522 res, err := cst.Client().Do(req)
7523 if err == nil {
7524 t.Fatal("Expected an error")
7525 }
7526 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7527 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7528 }
7529 if res != nil {
7530 t.Fatal("Unexpected non-nil response")
7531 }
7532 })
7533 }
7534 }
7535
7536 func TestTransportServerProtocols(t *testing.T) {
7537 CondSkipHTTP2(t)
7538 DefaultTransport.(*Transport).CloseIdleConnections()
7539
7540 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7541 if err != nil {
7542 t.Fatal(err)
7543 }
7544 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7545 if err != nil {
7546 t.Fatal(err)
7547 }
7548 certpool := x509.NewCertPool()
7549 certpool.AddCert(leafCert)
7550
7551 for _, test := range []struct {
7552 name string
7553 scheme string
7554 setup func(t *testing.T)
7555 transport func(*Transport)
7556 server func(*Server)
7557 want string
7558 }{{
7559 name: "http default",
7560 scheme: "http",
7561 want: "HTTP/1.1",
7562 }, {
7563 name: "https default",
7564 scheme: "https",
7565 transport: func(tr *Transport) {
7566
7567 },
7568 want: "HTTP/1.1",
7569 }, {
7570 name: "https transport protocols include HTTP2",
7571 scheme: "https",
7572 transport: func(tr *Transport) {
7573
7574
7575 tr.Protocols = &Protocols{}
7576 tr.Protocols.SetHTTP1(true)
7577 tr.Protocols.SetHTTP2(true)
7578 },
7579 want: "HTTP/2.0",
7580 }, {
7581 name: "https transport protocols only include HTTP1",
7582 scheme: "https",
7583 transport: func(tr *Transport) {
7584
7585 tr.Protocols = &Protocols{}
7586 tr.Protocols.SetHTTP1(true)
7587 },
7588 want: "HTTP/1.1",
7589 }, {
7590 name: "https transport ForceAttemptHTTP2",
7591 scheme: "https",
7592 transport: func(tr *Transport) {
7593
7594 tr.ForceAttemptHTTP2 = true
7595 },
7596 want: "HTTP/2.0",
7597 }, {
7598 name: "https transport protocols override TLSNextProto",
7599 scheme: "https",
7600 transport: func(tr *Transport) {
7601
7602
7603
7604 tr.Protocols = &Protocols{}
7605 tr.Protocols.SetHTTP1(true)
7606 tr.Protocols.SetHTTP2(true)
7607 tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
7608 },
7609 want: "HTTP/2.0",
7610 }, {
7611 name: "https server disables HTTP2 with TLSNextProto",
7612 scheme: "https",
7613 server: func(srv *Server) {
7614
7615
7616 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7617 },
7618 want: "HTTP/1.1",
7619 }, {
7620 name: "https server Protocols overrides empty TLSNextProto",
7621 scheme: "https",
7622 server: func(srv *Server) {
7623
7624
7625 srv.Protocols = &Protocols{}
7626 srv.Protocols.SetHTTP1(true)
7627 srv.Protocols.SetHTTP2(true)
7628 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7629 },
7630 want: "HTTP/2.0",
7631 }, {
7632 name: "https server protocols only include HTTP1",
7633 scheme: "https",
7634 server: func(srv *Server) {
7635 srv.Protocols = &Protocols{}
7636 srv.Protocols.SetHTTP1(true)
7637 },
7638 want: "HTTP/1.1",
7639 }, {
7640 name: "https server protocols include HTTP2",
7641 scheme: "https",
7642 server: func(srv *Server) {
7643 srv.Protocols = &Protocols{}
7644 srv.Protocols.SetHTTP1(true)
7645 srv.Protocols.SetHTTP2(true)
7646 },
7647 want: "HTTP/2.0",
7648 }, {
7649 name: "GODEBUG disables HTTP2 client",
7650 scheme: "https",
7651 setup: func(t *testing.T) {
7652 t.Setenv("GODEBUG", "http2client=0")
7653 },
7654 transport: func(tr *Transport) {
7655
7656
7657 tr.Protocols = &Protocols{}
7658 tr.Protocols.SetHTTP1(true)
7659 tr.Protocols.SetHTTP2(true)
7660 },
7661 want: "HTTP/1.1",
7662 }, {
7663 name: "GODEBUG disables HTTP2 server",
7664 scheme: "https",
7665 setup: func(t *testing.T) {
7666 t.Setenv("GODEBUG", "http2server=0")
7667 },
7668 transport: func(tr *Transport) {
7669
7670
7671 tr.Protocols = &Protocols{}
7672 tr.Protocols.SetHTTP1(true)
7673 tr.Protocols.SetHTTP2(true)
7674 },
7675 want: "HTTP/1.1",
7676 }, {
7677 name: "unencrypted HTTP2 with prior knowledge",
7678 scheme: "http",
7679 transport: func(tr *Transport) {
7680 tr.Protocols = &Protocols{}
7681 tr.Protocols.SetUnencryptedHTTP2(true)
7682 },
7683 server: func(srv *Server) {
7684 srv.Protocols = &Protocols{}
7685 srv.Protocols.SetHTTP1(true)
7686 srv.Protocols.SetUnencryptedHTTP2(true)
7687 },
7688 want: "HTTP/2.0",
7689 }, {
7690 name: "unencrypted HTTP2 only on server",
7691 scheme: "http",
7692 transport: func(tr *Transport) {
7693 tr.Protocols = &Protocols{}
7694 tr.Protocols.SetUnencryptedHTTP2(true)
7695 },
7696 server: func(srv *Server) {
7697 srv.Protocols = &Protocols{}
7698 srv.Protocols.SetUnencryptedHTTP2(true)
7699 },
7700 want: "HTTP/2.0",
7701 }, {
7702 name: "unencrypted HTTP2 with no server support",
7703 scheme: "http",
7704 transport: func(tr *Transport) {
7705 tr.Protocols = &Protocols{}
7706 tr.Protocols.SetUnencryptedHTTP2(true)
7707 },
7708 server: func(srv *Server) {
7709 srv.Protocols = &Protocols{}
7710 srv.Protocols.SetHTTP1(true)
7711 },
7712 want: "error",
7713 }, {
7714 name: "HTTP1 with no server support",
7715 scheme: "http",
7716 transport: func(tr *Transport) {
7717 tr.Protocols = &Protocols{}
7718 tr.Protocols.SetHTTP1(true)
7719 },
7720 server: func(srv *Server) {
7721 srv.Protocols = &Protocols{}
7722 srv.Protocols.SetUnencryptedHTTP2(true)
7723 },
7724 want: "error",
7725 }, {
7726 name: "HTTPS1 with no server support",
7727 scheme: "https",
7728 transport: func(tr *Transport) {
7729 tr.Protocols = &Protocols{}
7730 tr.Protocols.SetHTTP1(true)
7731 },
7732 server: func(srv *Server) {
7733 srv.Protocols = &Protocols{}
7734 srv.Protocols.SetHTTP2(true)
7735 },
7736 want: "error",
7737 }} {
7738 t.Run(test.name, func(t *testing.T) {
7739
7740
7741 srv := &Server{
7742 TLSConfig: &tls.Config{
7743 Certificates: []tls.Certificate{cert},
7744 },
7745 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
7746 w.Header().Set("X-Proto", req.Proto)
7747 }),
7748 }
7749 tr := &Transport{
7750 TLSClientConfig: &tls.Config{
7751 RootCAs: certpool,
7752 },
7753 }
7754
7755 if test.setup != nil {
7756 test.setup(t)
7757 }
7758 if test.server != nil {
7759 test.server(srv)
7760 }
7761 if test.transport != nil {
7762 test.transport(tr)
7763 } else {
7764 tr.Protocols = &Protocols{}
7765 tr.Protocols.SetHTTP1(true)
7766 tr.Protocols.SetHTTP2(true)
7767 }
7768
7769 listener := newLocalListener(t)
7770 srvc := make(chan error, 1)
7771 go func() {
7772 switch test.scheme {
7773 case "http":
7774 srvc <- srv.Serve(listener)
7775 case "https":
7776 srvc <- srv.ServeTLS(listener, "", "")
7777 }
7778 }()
7779 t.Cleanup(func() {
7780 srv.Close()
7781 <-srvc
7782 })
7783
7784 client := &Client{Transport: tr}
7785 resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
7786 if err != nil {
7787 if test.want == "error" {
7788 return
7789 }
7790 t.Fatal(err)
7791 }
7792 if got := resp.Header.Get("X-Proto"); got != test.want {
7793 t.Fatalf("request proto %q, want %q", got, test.want)
7794 }
7795 })
7796 }
7797 }
7798
7799 func TestIssue61474(t *testing.T) {
7800 run(t, testIssue61474, []testMode{http2Mode})
7801 }
7802 func testIssue61474(t *testing.T, mode testMode) {
7803 if testing.Short() {
7804 return
7805 }
7806
7807
7808
7809
7810 t.Skip("test is too large")
7811
7812 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7813 }), func(tr *Transport) {
7814 tr.MaxConnsPerHost = 1
7815 })
7816 var wg sync.WaitGroup
7817 defer wg.Wait()
7818 for range 100000 {
7819 wg.Go(func() {
7820 ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond)
7821 defer cancel()
7822 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
7823 resp, err := cst.c.Do(req)
7824 if err == nil {
7825 resp.Body.Close()
7826 }
7827 })
7828 }
7829 }
7830
View as plain text