Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 "net/http"
30 . "net/http"
31 "net/http/httptest"
32 "net/http/httptrace"
33 "net/http/httputil"
34 "net/http/internal/testcert"
35 "net/textproto"
36 "net/url"
37 "os"
38 "reflect"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "testing"
46 "testing/iotest"
47 "testing/synctest"
48 "time"
49
50 "golang.org/x/net/http/httpguts"
51 )
52
53
54
55
56
57 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
58 if r.FormValue("close") == "true" {
59 w.Header().Set("Connection", "close")
60 }
61 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
62 w.Write([]byte(r.RemoteAddr))
63
64
65
66 if c, ok := ResponseWriterConnForTesting(w); ok {
67 fmt.Fprintf(w, ", %T %p", c, c)
68 }
69 })
70
71
// testCloseConn wraps a net.Conn so that closing it is first recorded in the
// owning testConnSet, then forwarded to the wrapped connection.
type testCloseConn struct {
	net.Conn
	set *testConnSet
}

// Close marks c as closed in its testConnSet and then closes the wrapped
// connection, returning that connection's Close error.
func (c *testCloseConn) Close() error {
	c.set.remove(c)
	return c.Conn.Close()
}

// testConnSet tracks every connection handed out by a makeTestDial dialer,
// together with whether it has been closed, so a test can later assert via
// check that no connection was left open.
type testConnSet struct {
	t      *testing.T
	mu     sync.Mutex        // guards closed and list
	closed map[net.Conn]bool // conn -> has Close been called
	list   []net.Conn        // all inserted conns, in insertion order
}

// insert records c as a new, still-open connection.
func (tcs *testConnSet) insert(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = false
	tcs.list = append(tcs.list, c)
}

// remove marks c as closed. It does not drop c from the list, so check can
// still report on every connection ever inserted.
func (tcs *testConnSet) remove(c net.Conn) {
	tcs.mu.Lock()
	defer tcs.mu.Unlock()
	tcs.closed[c] = true
}

// makeTestDial returns a fresh testConnSet and a Transport.Dial-compatible
// function. Every connection that function dials successfully is wrapped in
// a testCloseConn and registered with the returned set.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	connSet := &testConnSet{
		t:      t,
		closed: make(map[net.Conn]bool),
	}
	dial := func(n, addr string) (net.Conn, error) {
		c, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		tc := &testCloseConn{Conn: c, set: connSet}
		connSet.insert(tc)
		return tc, nil
	}
	return connSet, dial
}

// check reports a test error for each tracked connection that has not been
// closed. Connections may be closed asynchronously, so while any is still
// open it retries up to 4 times, 50ms apart, before reporting. tcs.mu is
// released during each sleep so concurrent Close calls can make progress.
// Each leaked connection is reported exactly once.
func (tcs *testConnSet) check(t *testing.T) {
	t.Helper()
	tcs.mu.Lock()
	defer tcs.mu.Unlock()

	// anyOpen must be called with tcs.mu held.
	anyOpen := func() bool {
		return slices.ContainsFunc(tcs.list, func(c net.Conn) bool { return !tcs.closed[c] })
	}
	for retries := 4; retries > 0 && anyOpen(); retries-- {
		tcs.mu.Unlock()
		time.Sleep(50 * time.Millisecond)
		tcs.mu.Lock()
	}

	for i, c := range tcs.list {
		if !tcs.closed[c] {
			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
		}
	}
}
142
143 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
144 func testReuseRequest(t *testing.T, mode testMode) {
145 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
146 w.Write([]byte("{}"))
147 })).ts
148
149 c := ts.Client()
150 req, _ := NewRequest("GET", ts.URL, nil)
151 res, err := c.Do(req)
152 if err != nil {
153 t.Fatal(err)
154 }
155 err = res.Body.Close()
156 if err != nil {
157 t.Fatal(err)
158 }
159
160 res, err = c.Do(req)
161 if err != nil {
162 t.Fatal(err)
163 }
164 err = res.Body.Close()
165 if err != nil {
166 t.Fatal(err)
167 }
168 }
169
170
171
172 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
173 func testTransportKeepAlives(t *testing.T, mode testMode) {
174 ts := newClientServerTest(t, mode, hostPortHandler).ts
175
176 c := ts.Client()
177 for _, disableKeepAlive := range []bool{false, true} {
178 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
179 fetch := func(n int) string {
180 res, err := c.Get(ts.URL)
181 if err != nil {
182 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
183 }
184 body, err := io.ReadAll(res.Body)
185 if err != nil {
186 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
187 }
188 return string(body)
189 }
190
191 body1 := fetch(1)
192 body2 := fetch(2)
193
194 bodiesDiffer := body1 != body2
195 if bodiesDiffer != disableKeepAlive {
196 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
197 disableKeepAlive, bodiesDiffer, body1, body2)
198 }
199 }
200 }
201
202 func TestTransportConnectionCloseOnResponse(t *testing.T) {
203 run(t, testTransportConnectionCloseOnResponse)
204 }
205 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
206 ts := newClientServerTest(t, mode, hostPortHandler).ts
207
208 connSet, testDial := makeTestDial(t)
209
210 c := ts.Client()
211 tr := c.Transport.(*Transport)
212 tr.Dial = testDial
213
214 for _, connectionClose := range []bool{false, true} {
215 fetch := func(n int) string {
216 req := new(Request)
217 var err error
218 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
219 if err != nil {
220 t.Fatalf("URL parse error: %v", err)
221 }
222 req.Method = "GET"
223 req.Proto = "HTTP/1.1"
224 req.ProtoMajor = 1
225 req.ProtoMinor = 1
226
227 res, err := c.Do(req)
228 if err != nil {
229 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
230 }
231 defer res.Body.Close()
232 body, err := io.ReadAll(res.Body)
233 if err != nil {
234 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
235 }
236 return string(body)
237 }
238
239 body1 := fetch(1)
240 body2 := fetch(2)
241 bodiesDiffer := body1 != body2
242 if bodiesDiffer != connectionClose {
243 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
244 connectionClose, bodiesDiffer, body1, body2)
245 }
246
247 tr.CloseIdleConnections()
248 }
249
250 connSet.check(t)
251 }
252
253
254
255
256
257
258
259 func TestTransportConnectionCloseOnRequest(t *testing.T) {
260 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
261 }
262 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
263 ts := newClientServerTest(t, mode, hostPortHandler).ts
264
265 connSet, testDial := makeTestDial(t)
266
267 c := ts.Client()
268 tr := c.Transport.(*Transport)
269 tr.Dial = testDial
270 for _, reqClose := range []bool{false, true} {
271 fetch := func(n int) string {
272 req := new(Request)
273 var err error
274 req.URL, err = url.Parse(ts.URL)
275 if err != nil {
276 t.Fatalf("URL parse error: %v", err)
277 }
278 req.Method = "GET"
279 req.Proto = "HTTP/1.1"
280 req.ProtoMajor = 1
281 req.ProtoMinor = 1
282 req.Close = reqClose
283
284 res, err := c.Do(req)
285 if err != nil {
286 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
287 }
288 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
289 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
290 reqClose, got, !reqClose)
291 }
292 body, err := io.ReadAll(res.Body)
293 if err != nil {
294 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
295 }
296 return string(body)
297 }
298
299 body1 := fetch(1)
300 body2 := fetch(2)
301
302 got := 1
303 if body1 != body2 {
304 got++
305 }
306 want := 1
307 if reqClose {
308 want = 2
309 }
310 if got != want {
311 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
312 reqClose, got, want, body1, body2)
313 }
314
315 tr.CloseIdleConnections()
316 }
317
318 connSet.check(t)
319 }
320
321
322
323
324 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
325 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
326 }
327 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
328 ts := newClientServerTest(t, mode, hostPortHandler).ts
329
330 c := ts.Client()
331 c.Transport.(*Transport).DisableKeepAlives = true
332
333 res, err := c.Get(ts.URL)
334 if err != nil {
335 t.Fatal(err)
336 }
337 res.Body.Close()
338 if res.Header.Get("X-Saw-Close") != "true" {
339 t.Errorf("handler didn't see Connection: close ")
340 }
341 }
342
343
344
345 func TestTransportRespectRequestWantsClose(t *testing.T) {
346 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
347 }
348 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
349 tests := []struct {
350 disableKeepAlives bool
351 close bool
352 }{
353 {disableKeepAlives: false, close: false},
354 {disableKeepAlives: false, close: true},
355 {disableKeepAlives: true, close: false},
356 {disableKeepAlives: true, close: true},
357 }
358
359 for _, tc := range tests {
360 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
361 func(t *testing.T) {
362 ts := newClientServerTest(t, mode, hostPortHandler).ts
363
364 c := ts.Client()
365 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
366 req, err := NewRequest("GET", ts.URL, nil)
367 if err != nil {
368 t.Fatal(err)
369 }
370 count := 0
371 trace := &httptrace.ClientTrace{
372 WroteHeaderField: func(key string, field []string) {
373 if key != "Connection" {
374 return
375 }
376 if httpguts.HeaderValuesContainsToken(field, "close") {
377 count += 1
378 }
379 },
380 }
381 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
382 req.Close = tc.close
383 res, err := c.Do(req)
384 if err != nil {
385 t.Fatal(err)
386 }
387 defer res.Body.Close()
388 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
389 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
390 }
391 })
392 }
393
394 }
395
396 func TestTransportIdleCacheKeys(t *testing.T) {
397 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
398 }
399 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
400 ts := newClientServerTest(t, mode, hostPortHandler).ts
401 c := ts.Client()
402 tr := c.Transport.(*Transport)
403
404 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
405 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
406 }
407
408 resp, err := c.Get(ts.URL)
409 if err != nil {
410 t.Error(err)
411 }
412 io.ReadAll(resp.Body)
413
414 keys := tr.IdleConnKeysForTesting()
415 if e, g := 1, len(keys); e != g {
416 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
417 }
418
419 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
420 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
421 }
422
423 tr.CloseIdleConnections()
424 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
425 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
426 }
427 }
428
429
430
431 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
432 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
433 const msg = "foobar"
434
435 var addrSeen map[string]int
436 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
437 addrSeen[r.RemoteAddr]++
438 if r.URL.Path == "/chunked/" {
439 w.WriteHeader(200)
440 w.(Flusher).Flush()
441 } else {
442 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
443 w.WriteHeader(200)
444 }
445 w.Write([]byte(msg))
446 })).ts
447
448 for pi, path := range []string{"/content-length/", "/chunked/"} {
449 wantLen := []int{len(msg), -1}[pi]
450 addrSeen = make(map[string]int)
451 for i := 0; i < 3; i++ {
452 res, err := ts.Client().Get(ts.URL + path)
453 if err != nil {
454 t.Errorf("Get %s: %v", path, err)
455 continue
456 }
457
458
459
460
461
462 defer res.Body.Close()
463
464 if res.ContentLength != int64(wantLen) {
465 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
466 }
467 got, err := io.ReadAll(res.Body)
468 if string(got) != msg || err != nil {
469 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
470 }
471 }
472 if len(addrSeen) != 1 {
473 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
474 }
475 }
476 }
477
478
479
480
481
482 func TestTransportNotReadToEndConnectionReuse(t *testing.T) {
483 run(t, testTransportNotReadToEndConnectionReuse, []testMode{http1Mode, https1Mode})
484 }
485 func testTransportNotReadToEndConnectionReuse(t *testing.T, mode testMode) {
486 tests := []struct {
487 name string
488 bodyLen int
489 contentLenKnown bool
490 headRequest bool
491 timeBetweenReqs time.Duration
492 responseTime time.Duration
493 wantReuse bool
494 }{
495 {
496 name: "unconsumed body within drain limit",
497 bodyLen: 200 * 1024,
498 timeBetweenReqs: http.MaxPostCloseReadTime,
499 wantReuse: true,
500 },
501 {
502 name: "unconsumed body within drain limit with known length",
503 bodyLen: 200 * 1024,
504 contentLenKnown: true,
505 timeBetweenReqs: http.MaxPostCloseReadTime,
506 wantReuse: true,
507 },
508 {
509 name: "unconsumed body larger than drain limit",
510 bodyLen: 500 * 1024,
511 timeBetweenReqs: http.MaxPostCloseReadTime,
512 wantReuse: false,
513 },
514 {
515 name: "unconsumed body larger than drain limit with known length",
516 bodyLen: 500 * 1024,
517 contentLenKnown: true,
518 timeBetweenReqs: http.MaxPostCloseReadTime,
519 wantReuse: false,
520 },
521 {
522 name: "new requests start before drain for old requests are finished",
523 bodyLen: 200 * 1024,
524 timeBetweenReqs: 0,
525 responseTime: time.Minute,
526 wantReuse: false,
527 },
528 {
529
530
531 name: "unconsumed body larger than drain limit for HEAD request",
532 bodyLen: 500 * 1024,
533 headRequest: true,
534 wantReuse: true,
535 },
536 }
537
538 for _, tc := range tests {
539 subtest := func(t *testing.T) {
540 addrSeen := make(map[string]int)
541 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
542 addrSeen[r.RemoteAddr]++
543 time.Sleep(tc.responseTime)
544 if tc.contentLenKnown {
545 w.Header().Add("Content-Length", strconv.Itoa(tc.bodyLen))
546 }
547 w.Write(slices.Repeat([]byte("a"), tc.bodyLen))
548 }), optFakeNet).ts
549
550 var wg sync.WaitGroup
551 for range 10 {
552 wg.Go(func() {
553 method := http.MethodGet
554 if tc.headRequest {
555 method = http.MethodHead
556 }
557 ctx, cancel := context.WithCancel(context.Background())
558 req, err := http.NewRequestWithContext(ctx, method, ts.URL, nil)
559 if err != nil {
560 log.Fatal(err)
561 }
562 resp, err := ts.Client().Do(req)
563 if err != nil {
564 t.Fatal(err)
565 }
566 if resp.StatusCode != http.StatusOK {
567 t.Errorf("expected HTTP 200, got: %v", resp.StatusCode)
568 }
569 resp.Body.Close()
570
571
572 cancel()
573 if n, err := resp.Body.Read([]byte{}); n != 0 || err == nil {
574 t.Errorf("read after body has been closed should not succeed, but read %v byte with %v error", n, err)
575 }
576 })
577 time.Sleep(tc.timeBetweenReqs)
578 synctest.Wait()
579 }
580 wg.Wait()
581 if (len(addrSeen) == 1) != tc.wantReuse {
582 t.Errorf("want connection reuse to be %v, but %v connections were created", tc.wantReuse, len(addrSeen))
583 }
584 }
585 t.Run(tc.name, func(t *testing.T) {
586 synctest.Test(t, subtest)
587 })
588 }
589 }
590
591 func TestTransportMaxPerHostIdleConns(t *testing.T) {
592 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
593 }
594 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
595 stop := make(chan struct{})
596 defer close(stop)
597
598 resch := make(chan string)
599 gotReq := make(chan bool)
600 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
601 gotReq <- true
602 var msg string
603 select {
604 case <-stop:
605 return
606 case msg = <-resch:
607 }
608 _, err := w.Write([]byte(msg))
609 if err != nil {
610 t.Errorf("Write: %v", err)
611 return
612 }
613 })).ts
614
615 c := ts.Client()
616 tr := c.Transport.(*Transport)
617 maxIdleConnsPerHost := 2
618 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
619
620
621
622 donech := make(chan bool)
623 doReq := func() {
624 defer func() {
625 select {
626 case <-stop:
627 return
628 case donech <- t.Failed():
629 }
630 }()
631 resp, err := c.Get(ts.URL)
632 if err != nil {
633 t.Error(err)
634 return
635 }
636 if _, err := io.ReadAll(resp.Body); err != nil {
637 t.Errorf("ReadAll: %v", err)
638 return
639 }
640 }
641 go doReq()
642 <-gotReq
643 go doReq()
644 <-gotReq
645 go doReq()
646 <-gotReq
647
648 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
649 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
650 }
651
652 resch <- "res1"
653 <-donech
654 keys := tr.IdleConnKeysForTesting()
655 if e, g := 1, len(keys); e != g {
656 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
657 }
658 addr := ts.Listener.Addr().String()
659 cacheKey := "|http|" + addr
660 if keys[0] != cacheKey {
661 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
662 }
663 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
664 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
665 }
666
667 resch <- "res2"
668 <-donech
669 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
670 t.Errorf("after second response, idle conns = %d; want %d", g, w)
671 }
672
673 resch <- "res3"
674 <-donech
675 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
676 t.Errorf("after third response, idle conns = %d; want %d", g, w)
677 }
678 }
679
680 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
681 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
682 }
683 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
684 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
685 _, err := w.Write([]byte("foo"))
686 if err != nil {
687 t.Fatalf("Write: %v", err)
688 }
689 })).ts
690 c := ts.Client()
691 tr := c.Transport.(*Transport)
692 dialStarted := make(chan struct{})
693 stallDial := make(chan struct{})
694 tr.Dial = func(network, addr string) (net.Conn, error) {
695 dialStarted <- struct{}{}
696 <-stallDial
697 return net.Dial(network, addr)
698 }
699
700 tr.DisableKeepAlives = true
701 tr.MaxConnsPerHost = 1
702
703 preDial := make(chan struct{})
704 reqComplete := make(chan struct{})
705 doReq := func(reqId string) {
706 req, _ := NewRequest("GET", ts.URL, nil)
707 trace := &httptrace.ClientTrace{
708 GetConn: func(hostPort string) {
709 preDial <- struct{}{}
710 },
711 }
712 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
713 resp, err := tr.RoundTrip(req)
714 if err != nil {
715 t.Errorf("unexpected error for request %s: %v", reqId, err)
716 }
717 _, err = io.ReadAll(resp.Body)
718 if err != nil {
719 t.Errorf("unexpected error for request %s: %v", reqId, err)
720 }
721 reqComplete <- struct{}{}
722 }
723
724 go doReq("req1")
725 <-preDial
726 <-dialStarted
727
728
729 go doReq("req2")
730 <-preDial
731 select {
732 case <-dialStarted:
733 t.Error("req2 dial started while req1 dial in progress")
734 return
735 default:
736 }
737
738
739 stallDial <- struct{}{}
740 <-reqComplete
741
742
743 <-dialStarted
744 stallDial <- struct{}{}
745 <-reqComplete
746 }
747
748 func TestTransportMaxConnsPerHost(t *testing.T) {
749 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
750 }
751 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
752 CondSkipHTTP2(t)
753
754 h := HandlerFunc(func(w ResponseWriter, r *Request) {
755 _, err := w.Write([]byte("foo"))
756 if err != nil {
757 t.Fatalf("Write: %v", err)
758 }
759 })
760
761 ts := newClientServerTest(t, mode, h).ts
762 c := ts.Client()
763 tr := c.Transport.(*Transport)
764 tr.MaxConnsPerHost = 1
765
766 mu := sync.Mutex{}
767 var conns []net.Conn
768 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
769 tr.Dial = func(network, addr string) (net.Conn, error) {
770 atomic.AddInt32(&dialCnt, 1)
771 c, err := net.Dial(network, addr)
772 mu.Lock()
773 defer mu.Unlock()
774 conns = append(conns, c)
775 return c, err
776 }
777
778 doReq := func() {
779 trace := &httptrace.ClientTrace{
780 GotConn: func(connInfo httptrace.GotConnInfo) {
781 if !connInfo.Reused {
782 atomic.AddInt32(&gotConnCnt, 1)
783 }
784 },
785 TLSHandshakeStart: func() {
786 atomic.AddInt32(&tlsHandshakeCnt, 1)
787 },
788 }
789 req, _ := NewRequest("GET", ts.URL, nil)
790 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
791
792 resp, err := c.Do(req)
793 if err != nil {
794 t.Fatalf("request failed: %v", err)
795 }
796 defer resp.Body.Close()
797 _, err = io.ReadAll(resp.Body)
798 if err != nil {
799 t.Fatalf("read body failed: %v", err)
800 }
801 }
802
803 wg := sync.WaitGroup{}
804 for i := 0; i < 10; i++ {
805 wg.Add(1)
806 go func() {
807 defer wg.Done()
808 doReq()
809 }()
810 }
811 wg.Wait()
812
813 expected := int32(tr.MaxConnsPerHost)
814 if dialCnt != expected {
815 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
816 }
817 if gotConnCnt != expected {
818 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
819 }
820 if ts.TLS != nil && tlsHandshakeCnt != expected {
821 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
822 }
823
824 if t.Failed() {
825 t.FailNow()
826 }
827
828 mu.Lock()
829 for _, c := range conns {
830 c.Close()
831 }
832 conns = nil
833 mu.Unlock()
834 tr.CloseIdleConnections()
835
836 doReq()
837 expected++
838 if dialCnt != expected {
839 t.Errorf("round 2: too many dials: %d", dialCnt)
840 }
841 if gotConnCnt != expected {
842 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
843 }
844 if ts.TLS != nil && tlsHandshakeCnt != expected {
845 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
846 }
847 }
848
849 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
850 run(t, testTransportMaxConnsPerHostDialCancellation,
851 testNotParallel,
852 []testMode{http1Mode, https1Mode, http2Mode},
853 )
854 }
855
856 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
857 CondSkipHTTP2(t)
858
859 h := HandlerFunc(func(w ResponseWriter, r *Request) {
860 _, err := w.Write([]byte("foo"))
861 if err != nil {
862 t.Fatalf("Write: %v", err)
863 }
864 })
865
866 cst := newClientServerTest(t, mode, h)
867 defer cst.close()
868 ts := cst.ts
869 c := ts.Client()
870 tr := c.Transport.(*Transport)
871 tr.MaxConnsPerHost = 1
872
873
874 ctx, cancel := context.WithCancel(context.Background())
875 defer cancel()
876 SetPendingDialHooks(cancel, nil)
877 defer SetPendingDialHooks(nil, nil)
878
879 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
880 _, err := c.Do(req)
881 if !errors.Is(err, context.Canceled) {
882 t.Errorf("expected error %v, got %v", context.Canceled, err)
883 }
884
885
886 SetPendingDialHooks(nil, nil)
887 req, _ = NewRequest("GET", ts.URL, nil)
888 resp, err := c.Do(req)
889 if err != nil {
890 t.Fatalf("request failed: %v", err)
891 }
892 defer resp.Body.Close()
893 _, err = io.ReadAll(resp.Body)
894 if err != nil {
895 t.Fatalf("read body failed: %v", err)
896 }
897 }
898
899 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
900 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
901 }
902 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
903 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
904 io.WriteString(w, r.RemoteAddr)
905 })).ts
906
907 c := ts.Client()
908 tr := c.Transport.(*Transport)
909
910 doReq := func(name string) {
911
912
913 res, err := c.Post(ts.URL, "", nil)
914 if err != nil {
915 t.Fatalf("%s: %v", name, err)
916 }
917 if res.StatusCode != 200 {
918 t.Fatalf("%s: %v", name, res.Status)
919 }
920 defer res.Body.Close()
921 slurp, err := io.ReadAll(res.Body)
922 if err != nil {
923 t.Fatalf("%s: %v", name, err)
924 }
925 t.Logf("%s: ok (%q)", name, slurp)
926 }
927
928 doReq("first")
929 keys1 := tr.IdleConnKeysForTesting()
930
931 ts.CloseClientConnections()
932
933 var keys2 []string
934 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
935 keys2 = tr.IdleConnKeysForTesting()
936 if len(keys2) != 0 {
937 if d > 0 {
938 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
939 }
940 return false
941 }
942 return true
943 })
944
945 doReq("second")
946 }
947
948
949
950 func TestTransportServerClosingUnexpectedly(t *testing.T) {
951 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
952 }
953 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
954 ts := newClientServerTest(t, mode, hostPortHandler).ts
955 c := ts.Client()
956
957 fetch := func(n, retries int) string {
958 condFatalf := func(format string, arg ...any) {
959 if retries <= 0 {
960 t.Fatalf(format, arg...)
961 }
962 t.Logf("retrying shortly after expected error: "+format, arg...)
963 time.Sleep(time.Second / time.Duration(retries))
964 }
965 for retries >= 0 {
966 retries--
967 res, err := c.Get(ts.URL)
968 if err != nil {
969 condFatalf("error in req #%d, GET: %v", n, err)
970 continue
971 }
972 body, err := io.ReadAll(res.Body)
973 if err != nil {
974 condFatalf("error in req #%d, ReadAll: %v", n, err)
975 continue
976 }
977 res.Body.Close()
978 return string(body)
979 }
980 panic("unreachable")
981 }
982
983 body1 := fetch(1, 0)
984 body2 := fetch(2, 0)
985
986
987
988
989
990
991
992
993 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
994
995 body3 := fetch(3, 5)
996
997 if body1 != body2 {
998 t.Errorf("expected body1 and body2 to be equal")
999 }
1000 if body2 == body3 {
1001 t.Errorf("expected body2 and body3 to be different")
1002 }
1003 }
1004
1005
1006
1007 func TestStressSurpriseServerCloses(t *testing.T) {
1008 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
1009 }
1010 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
1011 if testing.Short() {
1012 t.Skip("skipping test in short mode")
1013 }
1014 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1015 w.Header().Set("Content-Length", "5")
1016 w.Header().Set("Content-Type", "text/plain")
1017 w.Write([]byte("Hello"))
1018 w.(Flusher).Flush()
1019 conn, buf, _ := w.(Hijacker).Hijack()
1020 buf.Flush()
1021 conn.Close()
1022 })).ts
1023 c := ts.Client()
1024
1025
1026
1027
1028
1029
1030
1031 const (
1032 numClients = 20
1033 reqsPerClient = 25
1034 )
1035 var wg sync.WaitGroup
1036 wg.Add(numClients * reqsPerClient)
1037 for i := 0; i < numClients; i++ {
1038 go func() {
1039 for i := 0; i < reqsPerClient; i++ {
1040 res, err := c.Get(ts.URL)
1041 if err == nil {
1042
1043
1044
1045
1046
1047
1048 res.Body.Close()
1049 }
1050 wg.Done()
1051 }
1052 }()
1053 }
1054
1055
1056 wg.Wait()
1057 }
1058
1059
1060
1061 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
1062 func testTransportHeadResponses(t *testing.T, mode testMode) {
1063 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1064 if r.Method != "HEAD" {
1065 panic("expected HEAD; got " + r.Method)
1066 }
1067 w.Header().Set("Content-Length", "123")
1068 w.WriteHeader(200)
1069 })).ts
1070 c := ts.Client()
1071
1072 for i := 0; i < 2; i++ {
1073 res, err := c.Head(ts.URL)
1074 if err != nil {
1075 t.Errorf("error on loop %d: %v", i, err)
1076 continue
1077 }
1078 if e, g := "123", res.Header.Get("Content-Length"); e != g {
1079 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
1080 }
1081 if e, g := int64(123), res.ContentLength; e != g {
1082 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
1083 }
1084 if all, err := io.ReadAll(res.Body); err != nil {
1085 t.Errorf("loop %d: Body ReadAll: %v", i, err)
1086 } else if len(all) != 0 {
1087 t.Errorf("Bogus body %q", all)
1088 }
1089 }
1090 }
1091
1092
1093
1094 func TestTransportHeadChunkedResponse(t *testing.T) {
1095 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
1096 }
1097 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
1098 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1099 if r.Method != "HEAD" {
1100 panic("expected HEAD; got " + r.Method)
1101 }
1102 w.Header().Set("Transfer-Encoding", "chunked")
1103 w.Header().Set("x-client-ipport", r.RemoteAddr)
1104 w.WriteHeader(200)
1105 })).ts
1106 c := ts.Client()
1107
1108
1109
1110 didRead := make(chan bool)
1111 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
1112 defer SetReadLoopBeforeNextReadHook(nil)
1113
1114 res1, err := c.Head(ts.URL)
1115 <-didRead
1116
1117 if err != nil {
1118 t.Fatalf("request 1 error: %v", err)
1119 }
1120
1121 res2, err := c.Head(ts.URL)
1122 <-didRead
1123
1124 if err != nil {
1125 t.Fatalf("request 2 error: %v", err)
1126 }
1127 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1128 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1129 }
1130 }
1131
// roundTripTests lists Accept-Encoding cases for testRoundTripGzip. Each case
// gives:
//   - accept: the header the caller sets on the request ("" means unset);
//   - expectAccept: the header value the server should receive;
//   - compressed: whether the body the caller reads is still gzip-compressed.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No Accept-Encoding from the caller: the server sees "gzip", and the
	// caller reads the body without decompressing it itself.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other encoding is passed through unchanged.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit "gzip" is passed through, and the caller must gunzip.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1144
1145
1146 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1147 func testRoundTripGzip(t *testing.T, mode testMode) {
1148 const responseBody = "test response body"
1149 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1150 accept := req.Header.Get("Accept-Encoding")
1151 if expect := req.FormValue("expect_accept"); accept != expect {
1152 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1153 req.FormValue("testnum"), accept, expect)
1154 }
1155 if accept == "gzip" {
1156 rw.Header().Set("Content-Encoding", "gzip")
1157 gz := gzip.NewWriter(rw)
1158 gz.Write([]byte(responseBody))
1159 gz.Close()
1160 } else {
1161 rw.Header().Set("Content-Encoding", accept)
1162 rw.Write([]byte(responseBody))
1163 }
1164 })).ts
1165 tr := ts.Client().Transport.(*Transport)
1166
1167 for i, test := range roundTripTests {
1168
1169 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1170 if test.accept != "" {
1171 req.Header.Set("Accept-Encoding", test.accept)
1172 }
1173 res, err := tr.RoundTrip(req)
1174 if err != nil {
1175 t.Errorf("%d. RoundTrip: %v", i, err)
1176 continue
1177 }
1178 var body []byte
1179 if test.compressed {
1180 var r *gzip.Reader
1181 r, err = gzip.NewReader(res.Body)
1182 if err != nil {
1183 t.Errorf("%d. gzip NewReader: %v", i, err)
1184 continue
1185 }
1186 body, err = io.ReadAll(r)
1187 res.Body.Close()
1188 } else {
1189 body, err = io.ReadAll(res.Body)
1190 }
1191 if err != nil {
1192 t.Errorf("%d. Error: %q", i, err)
1193 continue
1194 }
1195 if g, e := string(body), responseBody; g != e {
1196 t.Errorf("%d. body = %q; want %q", i, g, e)
1197 }
1198 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1199 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1200 }
1201 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1202 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1203 }
1204 }
1205
1206 }
1207
1208 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
// testTransportGzip checks that the Transport requests gzip encoding,
// transparently decodes gzip responses (both chunked and with an explicit
// Content-Length), strips the Content-Encoding header from decoded
// responses, fails reads after a body is exhausted or closed, and sends
// HEAD requests without an Accept-Encoding header.
func testTransportGzip(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56020")
	}
	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	const nRandBytes = 1024 * 1024
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if req.Method == "HEAD" {
			if g := req.Header.Get("Accept-Encoding"); g != "" {
				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
			}
			return
		}
		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
			t.Errorf("Accept-Encoding = %q, want %q", g, e)
		}
		rw.Header().Set("Content-Encoding", "gzip")

		var w io.Writer = rw
		var buf bytes.Buffer
		if req.FormValue("chunked") == "0" {
			// Buffer the compressed body so it can be sent with an explicit
			// Content-Length. Deferred calls run LIFO: the header is set
			// first, then the buffered bytes are copied to the response.
			w = &buf
			defer io.Copy(rw, &buf)
			defer func() {
				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
			}()
		}
		gz := gzip.NewWriter(w)
		gz.Write([]byte(testString))
		if req.FormValue("body") == "large" {
			io.CopyN(gz, rand.Reader, nRandBytes)
		}
		gz.Close()
	})).ts
	c := ts.Client()

	for _, chunked := range []string{"1", "0"} {
		// Fetch a large body, read only a prefix, and close it early;
		// reads after Close must fail.
		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
		if err != nil {
			t.Fatalf("large get: %v", err)
		}
		buf := make([]byte, len(testString))
		n, err := io.ReadFull(res.Body, buf)
		if err != nil {
			t.Fatalf("partial read of large response: size=%d, %v", n, err)
		}
		if e, g := testString, string(buf); e != g {
			t.Errorf("partial read got %q, expected %q", g, e)
		}
		res.Body.Close()

		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
		}

		// Fetch the small body. It must be fully decoded, and the
		// Content-Encoding header must be removed.
		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
		if err != nil {
			t.Fatal(err)
		}
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		if g, e := string(body), testString; g != e {
			t.Fatalf("body = %q; want %q", g, e)
		}
		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
			t.Fatalf("Content-Encoding = %q; want %q", g, e)
		}

		// Once the body is exhausted, further reads must fail, both before
		// and after Close.
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
		}
		res.Body.Close()
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after Close; got %d, %v", n, err)
		}
	}

	// A HEAD request must succeed; the handler verifies it carries no
	// Accept-Encoding header.
	res, err := c.Head(ts.URL)
	if err != nil {
		t.Fatalf("Head: %v", err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Errorf("Head status=%d; want=200", res.StatusCode)
	}
}
1303
1304
1305
// transport100ContinueTest is a harness for exercising the Transport's
// "Expect: 100-continue" handling against a server connection that the
// test drives by hand.
type transport100ContinueTest struct {
	t *testing.T

	// reqdone is closed when the client's RoundTrip goroutine finishes.
	// resp and respErr hold RoundTrip's results and are only meaningful
	// after reqdone is closed.
	reqdone chan struct{}
	resp    *Response
	respErr error

	// conn and reader are the server side of the accepted connection.
	conn   net.Conn
	reader *bufio.Reader
}
1316
// transport100ContinueTestBody is the PUT body the client sends in the
// 100-continue tests.
const transport100ContinueTestBody = "request body"
1318
1319
1320
// newTransport100ContinueTest starts, in a goroutine, a RoundTrip of a PUT
// request carrying "Expect: 100-continue" to a local listener. It accepts
// that connection and reads the request headers, then returns a harness the
// caller uses to script the server's responses.
//
// timeout becomes the Transport's ExpectContinueTimeout.
//
// A registered cleanup waits for RoundTrip to return, closes idle
// connections, and fails the test if the client sent any bytes the test
// did not consume.
func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
	ln := newLocalListener(t)
	defer ln.Close()

	test := &transport100ContinueTest{
		t:       t,
		reqdone: make(chan struct{}),
	}

	tr := &Transport{
		ExpectContinueTimeout: timeout,
	}
	go func() {
		defer close(test.reqdone)
		body := strings.NewReader(transport100ContinueTestBody)
		req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
		req.Header.Set("Expect", "100-continue")
		req.ContentLength = int64(len(transport100ContinueTestBody))
		test.resp, test.respErr = tr.RoundTrip(req)
		// resp is nil when RoundTrip fails. wantRequestDone reports that
		// error, so do not dereference resp here.
		if test.respErr == nil {
			test.resp.Body.Close()
		}
	}()

	c, err := ln.Accept()
	if err != nil {
		t.Fatalf("Accept: %v", err)
	}
	t.Cleanup(func() {
		c.Close()
	})
	br := bufio.NewReader(c)
	_, err = ReadRequest(br)
	if err != nil {
		t.Fatalf("ReadRequest: %v", err)
	}
	test.conn = c
	test.reader = br
	t.Cleanup(func() {
		<-test.reqdone
		tr.CloseIdleConnections()
		got, _ := io.ReadAll(test.reader)
		if len(got) > 0 {
			t.Fatalf("Transport sent unexpected bytes: %q", got)
		}
	})

	return test
}
1368
1369
1370 func (test *transport100ContinueTest) respond(lines ...string) {
1371 for _, line := range lines {
1372 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1373 test.t.Fatalf("Write: %v", err)
1374 }
1375 }
1376 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1377 test.t.Fatalf("Write: %v", err)
1378 }
1379 }
1380
1381
// wantBodySent reads up to len(transport100ContinueTestBody) bytes from the
// server side of the connection. It fails the test on a read error, or if
// the bytes read differ from the expected request body.
func (test *transport100ContinueTest) wantBodySent() {
	limit := int64(len(transport100ContinueTestBody))
	raw, err := io.ReadAll(io.LimitReader(test.reader, limit))
	if err != nil {
		test.t.Fatalf("unexpected error reading body: %v", err)
	}
	want := transport100ContinueTestBody
	if got := string(raw); got != want {
		test.t.Fatalf("unexpected body: got %q, want %q", got, want)
	}
}
1391
1392
1393 func (test *transport100ContinueTest) wantRequestDone(want int) {
1394 <-test.reqdone
1395 if test.respErr != nil {
1396 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1397 }
1398 if got := test.resp.StatusCode; got != want {
1399 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1400 }
1401 }
1402
1403 func TestTransportExpect100ContinueSent(t *testing.T) {
1404 test := newTransport100ContinueTest(t, 1*time.Hour)
1405
1406 test.respond("HTTP/1.1 100 Continue")
1407 test.wantBodySent()
1408 test.respond("HTTP/1.1 200", "Content-Length: 0")
1409 test.wantRequestDone(200)
1410 }
1411
1412 func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1413 test := newTransport100ContinueTest(t, 1*time.Hour)
1414
1415 test.respond("HTTP/1.1 200", "Content-Length: 0")
1416 test.wantBodySent()
1417 test.wantRequestDone(200)
1418 }
1419
1420 func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1421 test := newTransport100ContinueTest(t, 1*time.Hour)
1422
1423 test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1424 test.wantRequestDone(200)
1425 }
1426
1427 func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1428 test := newTransport100ContinueTest(t, 1*time.Hour)
1429
1430 test.respond("HTTP/1.1 500", "Content-Length: 0")
1431 test.wantBodySent()
1432 test.wantRequestDone(500)
1433 }
1434
1435 func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1436 test := newTransport100ContinueTest(t, 5*time.Millisecond)
1437 test.wantBodySent()
1438 test.respond("HTTP/1.1 200", "Content-Length: 0")
1439 test.wantRequestDone(200)
1440 }
1441
// TestSOCKS5Proxy runs testSOCKS5Proxy over HTTP/1, HTTPS/1 and HTTP/2.
func TestSOCKS5Proxy(t *testing.T) {
	modes := []testMode{http1Mode, https1Mode, http2Mode}
	run(t, testSOCKS5Proxy, modes)
}
// testSOCKS5Proxy checks that a Transport configured with a socks5:// proxy
// URL tunnels its requests through the proxy to the target server.
//
// The fake proxy below accepts one connection per subtest. It speaks just
// enough SOCKS5 to establish the tunnel: the no-auth method and a CONNECT
// to an IPv4 or IPv6 address. It reports the requested target on ch, then
// relays bytes between the client and the target.
func testSOCKS5Proxy(t *testing.T, mode testMode) {
	ch := make(chan string, 1)
	l := newLocalListener(t)
	defer l.Close()
	defer close(ch)
	proxy := func(t *testing.T) {
		s, err := l.Accept()
		if err != nil {
			t.Errorf("socks5 proxy Accept(): %v", err)
			return
		}
		defer s.Close()
		// Large enough for the 4-byte request header, an IPv6 address
		// (16 bytes) and a 2-byte port.
		var buf [22]byte
		// Greeting: version 5, one method, "no authentication required".
		if _, err := io.ReadFull(s, buf[:3]); err != nil {
			t.Errorf("socks5 proxy initial read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
			return
		}
		if _, err := s.Write([]byte{5, 0}); err != nil {
			t.Errorf("socks5 proxy initial write: %v", err)
			return
		}
		// Request header: version 5, CONNECT, reserved, address type.
		if _, err := io.ReadFull(s, buf[:4]); err != nil {
			t.Errorf("socks5 proxy second read: %v", err)
			return
		}
		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
			return
		}
		var ipLen int
		switch buf[3] {
		case 1: // IPv4 address
			ipLen = net.IPv4len
		case 4: // IPv6 address
			ipLen = net.IPv6len
		default:
			// The address type is buf[3]; buf[4] has not been read yet.
			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[3])
			return
		}
		// Address followed by a 2-byte big-endian port.
		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy address read: %v", err)
			return
		}
		ip := net.IP(buf[4 : ipLen+4])
		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
		// Reply: version 5, succeeded, reserved, then echo back the address
		// type, address and port.
		copy(buf[:3], []byte{5, 0, 0})
		if _, err := s.Write(buf[:ipLen+6]); err != nil {
			t.Errorf("socks5 proxy connect write: %v", err)
			return
		}
		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)

		// Relay traffic between the client and the target.
		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
		targetConn, err := net.Dial("tcp", targetHost)
		if err != nil {
			t.Errorf("net.Dial(%q) failed: %v", targetHost, err)
			return
		}
		go io.Copy(targetConn, s)
		// Returns once the target side of the tunnel is closed.
		io.Copy(s, targetConn)
		targetConn.Close()
	}

	pu, err := url.Parse("socks5://" + l.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	sentinelHeader := "X-Sentinel"
	sentinelValue := "12345"
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set(sentinelHeader, sentinelValue)
	})
	for _, useTLS := range []bool{false, true} {
		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
			ts := newClientServerTest(t, mode, h).ts
			go proxy(t)
			c := ts.Client()
			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			r, err := c.Head(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			if r.Header.Get(sentinelHeader) != sentinelValue {
				t.Errorf("Failed to retrieve sentinel value")
			}
			got := <-ch
			ts.Close()
			tsu, err := url.Parse(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			want := "proxy for " + tsu.Host
			if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		})
	}
}
1549
// TestTransportProxy checks proxying for every combination of HTTP/HTTPS
// site and HTTP/HTTPS proxy.
//
//   - For an HTTP site, the proxy must receive the HEAD request itself, with
//     the absolute site URL.
//   - For an HTTPS site, the proxy must first receive a CONNECT for the
//     site's host. The proxy handler below implements the tunnel by
//     hijacking the connection and relaying bytes to the target, after which
//     the site receives the HEAD for "/".
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ siteMode, proxyMode testMode }{
		{http1Mode, http1Mode},
		{http1Mode, https1Mode},
		{https1Mode, http1Mode},
		{https1Mode, https1Mode},
	}
	for _, testCase := range testCases {
		siteMode := testCase.siteMode
		proxyMode := testCase.proxyMode
		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// Implement a minimal CONNECT proxy.
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed")
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						// After Hijack the handler owns clientConn and must
						// close it.
						clientConn.Close()
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						clientConn.Close()
						targetConn.Close()
						return
					}

					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			ts := newClientServerTest(t, siteMode, h1).ts
			proxy := newClientServerTest(t, proxyMode, h2).ts

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// Use the site's client when the site is HTTPS, otherwise the
			// proxy's. Presumably this ensures the client's TLS config
			// trusts whichever test server it must verify.
			c := proxy.Client()
			if siteMode == https1Mode {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			got := <-proxyCh
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if siteMode == https1Mode {
				// The proxy's first request must be a CONNECT for the site.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				siteURL, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := siteURL.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The tunneled request must then reach the site itself.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
1666
1667
1668
1669
// TestProxyWithInfiniteHeader sends a CONNECT to a fake proxy that answers
// with an endless stream of bytes and never finishes a response header.
// With MaxResponseHeaderBytes set to 1024, the request is expected to fail
// rather than succeed.
func TestProxyWithInfiniteHeader(t *testing.T) {
	defer afterTest(t)

	lis := newLocalListener(t)
	defer lis.Close()
	stop := make(chan struct{})
	defer close(stop)

	// Fake proxy: read the CONNECT request, then keep writing bytes until
	// stop is closed.
	go func() {
		conn, err := lis.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer conn.Close()

		connectReq, err := ReadRequest(bufio.NewReader(conn))
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if connectReq.Method != "CONNECT" {
			t.Errorf("unexpected method %q", connectReq.Method)
			return
		}

		for {
			// Yield between writes so other goroutines get scheduled.
			runtime.Gosched()
			select {
			case <-stop:
				return
			default:
			}
			conn.Write([]byte("infinite stream of bytes"))
		}
	}()

	client := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + lis.Addr().String())
			},
			// Bound how much header the client will read.
			MaxResponseHeaderBytes: 1024,
		},
	}
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := client.Do(req); err == nil {
		t.Errorf("unexpected Get success")
	}
}
1732
// TestOnProxyConnectResponse checks that Transport.OnProxyConnectResponse is
// invoked with the configured proxy URL and the CONNECT request for the
// target site. It covers two cases:
//
//   - The proxy accepts the CONNECT and the hook returns nil. The request
//     must succeed, with one dial and no closes.
//   - The proxy rejects the CONNECT with 403 and the hook returns an
//     error. The request must fail with an error mentioning "403", and the
//     single dialed connection must be closed.
func TestOnProxyConnectResponse(t *testing.T) {

	var tcases = []struct {
		proxyStatusCode int
		err             error
	}{
		{
			StatusOK,
			nil,
		},
		{
			StatusForbidden,
			errors.New("403"),
		},
	}
	for _, tcase := range tcases {
		h1 := HandlerFunc(func(w ResponseWriter, r *Request) {

		})

		h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
			// Implement a minimal CONNECT proxy.
			if r.Method == "CONNECT" {
				if tcase.proxyStatusCode != StatusOK {
					w.WriteHeader(tcase.proxyStatusCode)
					return
				}
				hijacker, ok := w.(Hijacker)
				if !ok {
					t.Errorf("hijack not allowed")
					return
				}
				clientConn, _, err := hijacker.Hijack()
				if err != nil {
					t.Errorf("hijacking failed")
					return
				}
				res := &Response{
					StatusCode: StatusOK,
					Proto:      "HTTP/1.1",
					ProtoMajor: 1,
					ProtoMinor: 1,
					Header:     make(Header),
				}

				targetConn, err := net.Dial("tcp", r.URL.Host)
				if err != nil {
					t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
					// After Hijack the handler owns clientConn and must
					// close it.
					clientConn.Close()
					return
				}

				if err := res.Write(clientConn); err != nil {
					t.Errorf("Writing 200 OK failed: %v", err)
					clientConn.Close()
					targetConn.Close()
					return
				}

				go io.Copy(targetConn, clientConn)
				go func() {
					io.Copy(clientConn, targetConn)
					targetConn.Close()
				}()
			}
		})
		ts := newClientServerTest(t, https1Mode, h1).ts
		proxy := newClientServerTest(t, https1Mode, h2).ts

		pu, err := url.Parse(proxy.URL)
		if err != nil {
			t.Fatal(err)
		}

		c := proxy.Client()

		// Count dials and closes of the client's connections to the proxy.
		var (
			dials  atomic.Int32
			closes atomic.Int32
		)
		c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			conn, err := net.Dial(network, addr)
			if err != nil {
				return nil, err
			}
			dials.Add(1)
			return noteCloseConn{
				Conn: conn,
				closeFunc: func() {
					closes.Add(1)
				},
			}, nil
		}

		c.Transport.(*Transport).Proxy = ProxyURL(pu)
		c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			if proxyURL.String() != pu.String() {
				t.Errorf("proxy url got %s, want %s", proxyURL, pu)
			}

			if "https://"+connectReq.URL.String() != ts.URL {
				t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
			}
			return tcase.err
		}
		wantCloses := int32(0)
		if _, err := c.Head(ts.URL); err != nil {
			// A failed CONNECT must not leave the proxy connection open.
			wantCloses = 1
			if tcase.err == nil {
				t.Errorf("got %v, want nil", err)
			} else if !strings.Contains(err.Error(), tcase.err.Error()) {
				t.Errorf("got %v, want %v", err, tcase.err)
			}
		} else if tcase.err != nil {
			t.Errorf("got nil error, want %v", tcase.err)
		}
		if got, want := dials.Load(), int32(1); got != want {
			t.Errorf("got %v dials, want %v", got, want)
		}
		// On success the connection stays open for reuse, so no closes.
		if got, want := closes.Load(), wantCloses; got != want {
			t.Errorf("got %v closes, want %v", got, want)
		}
	}
}
1855
1856
1857
1858 func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
1859 cancelc := make(chan struct{})
1860 SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
1861 ctx, cancel := context.WithCancel(ctx)
1862 go func() {
1863 select {
1864 case <-cancelc:
1865 case <-ctx.Done():
1866 }
1867 cancel()
1868 }()
1869 return ctx, cancel
1870 })
1871
1872 defer afterTest(t)
1873
1874 ln := newLocalListener(t)
1875 defer ln.Close()
1876 listenerDone := make(chan struct{})
1877 go func() {
1878 defer close(listenerDone)
1879 c, err := ln.Accept()
1880 if err != nil {
1881 t.Errorf("Accept: %v", err)
1882 return
1883 }
1884 defer c.Close()
1885
1886 br := bufio.NewReader(c)
1887 cr, err := ReadRequest(br)
1888 if err != nil {
1889 t.Errorf("proxy server failed to read CONNECT request")
1890 return
1891 }
1892 if cr.Method != "CONNECT" {
1893 t.Errorf("unexpected method %q", cr.Method)
1894 return
1895 }
1896
1897
1898
1899
1900 close(cancelc)
1901 var buf [1]byte
1902 _, err = br.Read(buf[:])
1903 if err != io.EOF {
1904 t.Errorf("proxy server Read err = %v; want EOF", err)
1905 }
1906 return
1907 }()
1908
1909 c := &Client{
1910 Transport: &Transport{
1911 Proxy: func(*Request) (*url.URL, error) {
1912 return url.Parse("http://" + ln.Addr().String())
1913 },
1914 },
1915 }
1916 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1917 if err != nil {
1918 t.Fatal(err)
1919 }
1920 _, err = c.Do(req)
1921 if err == nil {
1922 t.Errorf("unexpected Get success")
1923 }
1924
1925
1926
1927
1928 <-listenerDone
1929 }
1930
1931
1932 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1933 defer afterTest(t)
1934
1935 var errDial = errors.New("some dial error")
1936
1937 tr := &Transport{
1938 Proxy: func(*Request) (*url.URL, error) {
1939 return url.Parse("http://proxy.fake.tld/")
1940 },
1941 Dial: func(string, string) (net.Conn, error) {
1942 return nil, errDial
1943 },
1944 }
1945 defer tr.CloseIdleConnections()
1946
1947 c := &Client{Transport: tr}
1948 req, _ := NewRequest("GET", "http://fake.tld", nil)
1949 res, err := c.Do(req)
1950 if err == nil {
1951 res.Body.Close()
1952 t.Fatal("wanted a non-nil error")
1953 }
1954
1955 uerr, ok := err.(*url.Error)
1956 if !ok {
1957 t.Fatalf("got %T, want *url.Error", err)
1958 }
1959 oe, ok := uerr.Err.(*net.OpError)
1960 if !ok {
1961 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1962 }
1963 want := &net.OpError{
1964 Op: "proxyconnect",
1965 Net: "tcp",
1966 Err: errDial,
1967 }
1968 if !reflect.DeepEqual(oe, want) {
1969 t.Errorf("Got error %#v; want %#v", oe, want)
1970 }
1971 }
1972
1973
1974
1975
1976
// TestTransportProxyDialDoesNotMutateProxyConnectHeader runs its
// mode-parameterized body under the default set of test modes.
func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
	body := testTransportProxyDialDoesNotMutateProxyConnectHeader
	run(t, body)
}
// testTransportProxyDialDoesNotMutateProxyConnectHeader makes a proxied
// request whose proxy URL carries user credentials. It checks that the
// Transport's ProxyConnectHeader is unchanged afterwards.
func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
	proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
	defer proxy.Close()
	client := proxy.Client()

	tr := client.Transport.(*Transport)
	tr.Proxy = func(*Request) (*url.URL, error) {
		u, _ := url.Parse(proxy.URL)
		u.User = url.UserPassword("aladdin", "opensesame")
		return u, nil
	}

	// Keep a snapshot (allocating one if unset) and give the Transport its
	// own copy, so any mutation shows up as a difference.
	want := tr.ProxyConnectHeader
	if want == nil {
		want = make(Header)
	}
	tr.ProxyConnectHeader = want.Clone()

	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := client.Do(req); err == nil {
		t.Errorf("unexpected Get success")
	}

	if got := tr.ProxyConnectHeader; !reflect.DeepEqual(got, want) {
		t.Errorf("tr.ProxyConnectHeader = %v; want %v", got, want)
	}
}
2010
2011
2012
2013
2014
// TestTransportGzipRecursive runs testTransportGzipRecursive under the
// default set of test modes.
func TestTransportGzipRecursive(t *testing.T) {
	run(t, testTransportGzipRecursive)
}
// testTransportGzipRecursive serves rgz (defined elsewhere in this file)
// with "Content-Encoding: gzip". It checks that the Transport decodes the
// body exactly once: the client must receive rgz itself, with the
// Content-Encoding header removed. The test therefore relies on rgz
// decompressing to itself.
func testTransportGzipRecursive(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(rgz)
	})).ts

	c := ts.Client()
	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(body, rgz) {
		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
			body, rgz)
	}
	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
		t.Fatalf("Content-Encoding = %q; want %q", g, e)
	}
}
2039
2040
2041
// TestTransportGzipShort runs testTransportGzipShort under the default set
// of test modes.
func TestTransportGzipShort(t *testing.T) {
	run(t, testTransportGzipShort)
}
// testTransportGzipShort serves a "gzip" body consisting only of the
// two-byte gzip magic number. It checks that reading the decoded body fails
// with exactly io.ErrUnexpectedEOF.
func testTransportGzipShort(t *testing.T, mode testMode) {
	truncated := []byte{0x1f, 0x8b}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(truncated)
	})).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()

	switch _, err := io.ReadAll(res.Body); {
	case err == nil:
		t.Fatal("Expect an error from reading a body.")
	case err != io.ErrUnexpectedEOF:
		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
	}
}
2063
2064
// waitNumGoroutine polls runtime.NumGoroutine until it reports at most nmax
// goroutines, or until it has retried 10 times. Before each retry it sleeps
// 50ms and forces a GC. It returns the last count observed.
func waitNumGoroutine(nmax int) int {
	n := runtime.NumGoroutine()
	for retries := 0; retries < 10; retries++ {
		if n <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		n = runtime.NumGoroutine()
	}
	return n
}
2074
2075
// TestTransportPersistConnLeak runs testTransportPersistConnLeak serially,
// since it measures process-wide goroutine counts.
func TestTransportPersistConnLeak(t *testing.T) {
	body := testTransportPersistConnLeak
	run(t, body, testNotParallel)
}
// testTransportPersistConnLeak issues numReq concurrent GETs against a
// handler that blocks until released. After every request has finished and
// idle connections are closed, the goroutine count must come back to within
// 5 of its starting value.
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for range numReq {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until each request has either reached the handler or failed.
	// A request that fails never reaches the handler.
	for range numReq {
		select {
		case <-gotReqCh:
		case <-failed:
		}
	}

	nhigh := runtime.NumGoroutine()

	// Let every blocked handler finish.
	close(unblockCh)

	// Wait for each client goroutine to get its result.
	for range numReq {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	if growth := nfinal - n0; growth > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
2147
2148
2149
// TestTransportPersistConnLeakShortBody runs its body serially, since it
// measures process-wide goroutine counts.
func TestTransportPersistConnLeakShortBody(t *testing.T) {
	body := testTransportPersistConnLeakShortBody
	run(t, body, testNotParallel)
}
// testTransportPersistConnLeakShortBody sends 20 POSTs whose actual body is
// longer than the declared ContentLength; each must fail. After idle
// connections are closed, the goroutine count must come back to within 5
// of its starting value.
func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	n0 := runtime.NumGoroutine()
	payload := []byte("Hello")
	for range 20 {
		req, err := NewRequest("POST", ts.URL, bytes.NewReader(payload))
		if err != nil {
			t.Fatal(err)
		}
		// Declare two bytes fewer than are actually supplied.
		req.ContentLength = int64(len(payload) - 2)
		if _, err := c.Do(req); err == nil {
			t.Fatal("Expect an error from writing too long of a body.")
		}
	}
	nhigh := runtime.NumGoroutine()
	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0
	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
	if growth > 5 {
		t.Error("too many new goroutines")
	}
}
2190
2191
// countedConn wraps a net.Conn handed out by countingDialer. Each wrapper is
// its own heap object, so a runtime cleanup can be attached to it and run
// once the wrapper becomes unreachable.
type countedConn struct {
	net.Conn
}
2195
2196
// countingDialer dials with a net.Dialer and counts the connections it has
// created. total is the number ever created; live is the number whose
// *countedConn wrapper has not yet been cleaned up by the GC. mu guards
// total and live.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}
2202
// DialContext dials network/address with d.dialer and returns the
// connection wrapped in a *countedConn. A successful dial increments both
// total and live. A runtime cleanup on the wrapper decrements live once the
// wrapper is garbage collected.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	defer d.mu.Unlock()
	d.total++
	d.live++
	// The cleanup gets d as its argument and must not reference wrapped;
	// otherwise wrapped would stay reachable and never be collected.
	runtime.AddCleanup(wrapped, func(owner *countingDialer) { owner.decrement(nil) }, d)
	return wrapped, nil
}
2220
2221 func (d *countingDialer) decrement(*countedConn) {
2222 d.mu.Lock()
2223 defer d.mu.Unlock()
2224 d.live--
2225 }
2226
2227 func (d *countingDialer) Read() (total, live int64) {
2228 d.mu.Lock()
2229 defer d.mu.Unlock()
2230 return d.total, d.live
2231 }
2232
// TestTransportPersistConnLeakNeverIdle runs its body in HTTP/1 mode only,
// since the handler hijacks the connection.
func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
	onlyHTTP1 := []testMode{http1Mode}
	run(t, testTransportPersistConnLeakNeverIdle, onlyHTTP1)
}
// testTransportPersistConnLeakNeverIdle repeatedly POSTs to a handler that
// hijacks and immediately closes the connection, forcing a GC after each
// attempt. It passes once at least one dialed connection has been garbage
// collected (live < total), showing that connections broken before ever
// going idle are not retained. It fails after 1<<12 attempts.
func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Drop the connection without writing any response.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack failed unexpectedly: %v", err)
			return
		}
		conn.Close()
	})).ts

	var dialer countingDialer
	client := ts.Client()
	client.Transport.(*Transport).DialContext = dialer.DialContext

	payload := []byte("Hello")
	for attempt := 0; ; attempt++ {
		total, live := dialer.Read()
		if live < total {
			return
		}
		if attempt >= 1<<12 {
			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, attempt)
		}

		req, err := NewRequest("POST", ts.URL, bytes.NewReader(payload))
		if err != nil {
			t.Fatal(err)
		}
		if _, err := client.Do(req); err == nil {
			t.Fatal("expected broken connection")
		}

		runtime.GC()
	}
}
2273
// countedContext wraps a context.Context tracked by a contextCounter. Each
// wrapper is its own heap object, so a runtime cleanup can be attached to it.
type countedContext struct {
	context.Context
}
2277
// contextCounter counts contexts returned by Track whose wrappers have not
// yet been garbage collected. mu guards live.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}
2282
// Track wraps ctx in a new *countedContext and increments the live count.
// It registers a runtime cleanup that decrements the count once the wrapper
// is garbage collected. Callers use the returned context in place of ctx.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	counted := &countedContext{Context: ctx}
	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live++
	// Use the cleanup's argument rather than capturing cc in the closure.
	// The cleanup must not reference counted itself, or counted would never
	// be collected.
	runtime.AddCleanup(counted, func(owner *contextCounter) { owner.decrement(nil) }, cc)
	return counted
}
2292
2293 func (cc *contextCounter) decrement(*countedContext) {
2294 cc.mu.Lock()
2295 defer cc.mu.Unlock()
2296 cc.live--
2297 }
2298
2299 func (cc *contextCounter) Read() (live int64) {
2300 cc.mu.Lock()
2301 defer cc.mu.Unlock()
2302 return cc.live
2303 }
2304
// TestTransportPersistConnContextLeakMaxConnsPerHost runs its body under the
// default set of test modes.
func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
	body := testTransportPersistConnContextLeakMaxConnsPerHost
	run(t, body)
}
// testTransportPersistConnContextLeakMaxConnsPerHost sets MaxConnsPerHost to
// 1 so that concurrent requests queue for a connection. It checks that the
// Transport does not keep request contexts alive after the requests finish.
//
// Each request's context is wrapped by a contextCounter. The test then
// sends further batches of traffic and forces GCs until all of the first
// batch's contexts have been collected, failing after 100 rounds.
func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56021")
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		runtime.Gosched()
		w.WriteHeader(StatusOK)
	})).ts

	c := ts.Client()
	c.Transport.(*Transport).MaxConnsPerHost = 1

	ctx := context.Background()
	body := []byte("Hello")
	doPosts := func(cc *contextCounter) {
		var wg sync.WaitGroup
		for n := 64; n > 0; n-- {
			wg.Add(1)
			go func() {
				defer wg.Done()

				ctx := cc.Track(ctx)
				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
				if err != nil {
					t.Error(err)
					// req is nil here; using it below would panic.
					return
				}

				_, err = c.Do(req.WithContext(ctx))
				if err != nil {
					t.Errorf("Do failed with error: %v", err)
				}
			}()
		}
		wg.Wait()
	}

	var initialCC contextCounter
	doPosts(&initialCC)

	// Keep issuing batches (counted separately) and forcing GCs until every
	// context from the first batch has been collected.
	var flushCC contextCounter
	for i := 0; ; i++ {
		live := initialCC.Read()
		if live == 0 {
			break
		}
		if i >= 100 {
			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
		}
		doPosts(&flushCC)
		runtime.GC()
	}
}
2364
2365
// TestTransportIdleConnCrash runs testTransportIdleConnCrash under the
// default set of test modes.
func TestTransportIdleConnCrash(t *testing.T) {
	run(t, testTransportIdleConnCrash)
}
// testTransportIdleConnCrash calls CloseIdleConnections from inside a
// handler while the client's request on that same Transport is still in
// flight. The request must still complete without error.
func testTransportIdleConnCrash(t *testing.T, mode testMode) {
	var tr *Transport

	unblockCh := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockCh
		tr.CloseIdleConnections()
	})).ts
	client := ts.Client()
	tr = client.Transport.(*Transport)

	done := make(chan bool)
	go func() {
		res, err := client.Get(ts.URL)
		if err == nil {
			res.Body.Close()
		} else {
			t.Error(err)
		}
		done <- true
	}()
	unblockCh <- true
	<-done
}
2391
2392
2393
2394
2395
// TestIssue3644 runs testIssue3644 under the default set of test modes.
func TestIssue3644(t *testing.T) {
	run(t, testIssue3644)
}
// testIssue3644 checks that a response sent with "Connection: close" is read
// in full. The handler writes numFoos copies of "foo ", and the client must
// receive every byte.
func testIssue3644(t *testing.T, mode testMode) {
	const numFoos = 5000
	const foo = "foo "
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "close")
		for range numFoos {
			w.Write([]byte(foo))
		}
	})).ts
	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if want := numFoos * len(foo); len(got) != want {
		t.Errorf("unexpected response length")
	}
}
2419
2420
2421
// TestIssue3595 runs testIssue3595 serially. Its body adjusts the RST
// avoidance delay via SetRSTAvoidanceDelay, which is presumably
// process-wide test state.
func TestIssue3595(t *testing.T) {
	body := testIssue3595
	run(t, body, testNotParallel)
}
// testIssue3595 checks that when the server rejects a request while the
// client is still uploading an endless body, the client still receives the
// server's error response. It retries with increasing RST avoidance delays
// via runTimeSensitiveTest.
func testIssue3595(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const deniedMsg = "sorry, denied."
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			Error(w, deniedMsg, StatusUnauthorized)
		}))
		// Close the server explicitly at the end of this attempt rather
		// than waiting for test cleanup.
		defer cst.close()
		ts := cst.ts
		c := ts.Client()

		res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
		if err != nil {
			return fmt.Errorf("Post: %w", err)
		}
		// Runs before cst.close, since defers are LIFO.
		defer res.Body.Close()
		got, err := io.ReadAll(res.Body)
		if err != nil {
			return fmt.Errorf("Body ReadAll: %w", err)
		}
		t.Logf("server response:\n%s", got)
		if !strings.Contains(string(got), deniedMsg) {
			t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
		}
		return nil
	})
}
2467
2468
2469
// TestChunkedNoContent runs testChunkedNoContent under the default set of
// test modes.
func TestChunkedNoContent(t *testing.T) {
	run(t, testChunkedNoContent)
}
// testChunkedNoContent makes repeated GETs to a handler that answers 204 No
// Content. The first round closes each response body and the second leaves
// them open. Every request must succeed in both rounds.
func testChunkedNoContent(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(StatusNoContent)
	})).ts

	client := ts.Client()
	for _, closeBody := range []bool{true, false} {
		const n = 4
		for i := 1; i <= n; i++ {
			res, err := client.Get(ts.URL)
			if err != nil {
				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
				continue
			}
			if closeBody {
				res.Body.Close()
			}
		}
	}
}
2491
// TestTransportConcurrency sends many GET requests concurrently through one
// HTTP/1 client and checks that every response echoes the "echo" query
// parameter of its own request.
func TestTransportConcurrency(t *testing.T) {
	// Not parallel: the test changes GOMAXPROCS and installs pending-dial
	// hooks (presumably process-wide — confirm).
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Fewer worker goroutines and requests in -short mode.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Set GOMAXPROCS to maxProcs now and restore the previous value on return.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// One count per request; workers call wg.Done exactly once per request,
	// on both the error and success paths.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Also count each dial the Transport has pending, so wg.Wait does not
	// return while a dial is still in flight. NOTE(review): presumably a
	// request can be served by a different connection than the dial it
	// triggered, leaving that dial outstanding — confirm.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs on return ends the workers' range loops.
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// On NetBSD, connection resets are only logged,
						// not treated as failures (see the issue below).
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2554
2555 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2556 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2557 mux := NewServeMux()
2558 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2559 io.Copy(w, neverEnding('a'))
2560 })
2561 ts := newClientServerTest(t, mode, mux).ts
2562
2563 connc := make(chan net.Conn, 1)
2564 c := ts.Client()
2565 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2566 conn, err := net.Dial(n, addr)
2567 if err != nil {
2568 return nil, err
2569 }
2570 select {
2571 case connc <- conn:
2572 default:
2573 }
2574 return conn, nil
2575 }
2576
2577 res, err := c.Get(ts.URL + "/get")
2578 if err != nil {
2579 t.Fatalf("Error issuing GET: %v", err)
2580 }
2581 defer res.Body.Close()
2582
2583 conn := <-connc
2584 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2585 _, err = io.Copy(io.Discard, res.Body)
2586 if err == nil {
2587 t.Errorf("Unexpected successful copy")
2588 }
2589 }
2590
2591 func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
2592 run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
2593 }
2594 func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
2595 const debug = false
2596 mux := NewServeMux()
2597 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2598 io.Copy(w, neverEnding('a'))
2599 })
2600 mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
2601 defer r.Body.Close()
2602 io.Copy(io.Discard, r.Body)
2603 })
2604 ts := newClientServerTest(t, mode, mux).ts
2605 timeout := 100 * time.Millisecond
2606
2607 c := ts.Client()
2608 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2609 conn, err := net.Dial(n, addr)
2610 if err != nil {
2611 return nil, err
2612 }
2613 conn.SetDeadline(time.Now().Add(timeout))
2614 if debug {
2615 conn = NewLoggingConn("client", conn)
2616 }
2617 return conn, nil
2618 }
2619
2620 getFailed := false
2621 nRuns := 5
2622 if testing.Short() {
2623 nRuns = 1
2624 }
2625 for i := 0; i < nRuns; i++ {
2626 if debug {
2627 println("run", i+1, "of", nRuns)
2628 }
2629 sres, err := c.Get(ts.URL + "/get")
2630 if err != nil {
2631 if !getFailed {
2632
2633 getFailed = true
2634 t.Logf("increasing timeout")
2635 i--
2636 timeout *= 10
2637 continue
2638 }
2639 t.Errorf("Error issuing GET: %v", err)
2640 break
2641 }
2642 req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
2643 _, err = c.Do(req)
2644 if err == nil {
2645 sres.Body.Close()
2646 t.Errorf("Unexpected successful PUT")
2647 break
2648 }
2649 sres.Body.Close()
2650 }
2651 if debug {
2652 println("tests complete; waiting for handlers to finish")
2653 }
2654 ts.Close()
2655 }
2656
// TestTransportResponseHeaderTimeout checks that Transport.ResponseHeaderTimeout
// fails a request whose handler never responds with a "timeout awaiting
// response headers" error, while promptly answered requests succeed.
func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping timeout test in -short mode")
	}

	// Start with a very short timeout. Whenever a /fast request unexpectedly
	// times out, double it and rerun the whole sequence against a new server.
	timeout := 2 * time.Millisecond
	retry := true
	for retry && !t.Failed() {
		// srvWG counts handler completions so the server is only closed
		// after every handler has returned.
		var srvWG sync.WaitGroup
		// inHandler lets the test wait until each request reached its handler.
		inHandler := make(chan bool, 1)
		mux := NewServeMux()
		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
			inHandler <- true
			srvWG.Done()
		})
		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
			inHandler <- true
			// Never write a response; return only once the request's
			// context is done.
			<-r.Context().Done()
			srvWG.Done()
		})
		ts := newClientServerTest(t, mode, mux).ts

		c := ts.Client()
		c.Transport.(*Transport).ResponseHeaderTimeout = timeout

		retry = false
		srvWG.Add(3) // one per entry in tests
		tests := []struct {
			path        string
			wantTimeout bool
		}{
			{path: "/fast"},
			{path: "/slow", wantTimeout: true},
			{path: "/fast"},
		}
		for i, tt := range tests {
			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
			req = req.WithT(t)
			res, err := c.Do(req)
			<-inHandler
			if err != nil {
				// The error must be a *url.Error wrapping a net.Error
				// that reports Timeout.
				uerr, ok := err.(*url.Error)
				if !ok {
					t.Errorf("error is not a url.Error; got: %#v", err)
					continue
				}
				nerr, ok := uerr.Err.(net.Error)
				if !ok {
					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
					continue
				}
				if !nerr.Timeout() {
					t.Errorf("want timeout error; got: %q", nerr)
					continue
				}
				if !tt.wantTimeout {
					if !retry {
						// A fast request timed out, so the timeout is too
						// short for this machine: double it and schedule
						// another pass.
						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
						timeout *= 2
						retry = true
					}
				}
				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
					t.Errorf("%d. unexpected error: %v", i, err)
				}
				continue
			}
			if tt.wantTimeout {
				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
				continue
			}
			if res.StatusCode != 200 {
				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
			}
		}

		srvWG.Wait()
		ts.Close()
	}
}
2739
2740
// cancelTest describes one mechanism for canceling a request, so that the same
// test body can be run against every mechanism.
type cancelTest struct {
	// mode is the client/server test mode to run under.
	mode testMode
	// newReq prepares req for this mechanism (for example by attaching a
	// Cancel channel or a context) and returns the request to send.
	newReq func(req *Request) *Request
	// cancel cancels req, the request returned by newReq.
	cancel func(tr *Transport, req *Request)
	// checkErr reports a test error if err is not the error this mechanism
	// should produce; when names the step that returned err.
	checkErr func(when string, err error)
}
2747
2748
2749 func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2750 t.Run("TransportCancel", func(t *testing.T) {
2751 f(t, cancelTest{
2752 mode: mode,
2753 newReq: func(req *Request) *Request {
2754 return req
2755 },
2756 cancel: func(tr *Transport, req *Request) {
2757 tr.CancelRequest(req)
2758 },
2759 checkErr: func(when string, err error) {
2760 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2761 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2762 }
2763 },
2764 })
2765 })
2766 }
2767
2768
2769 func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2770 cancelc := make(chan struct{})
2771 cancelOnce := sync.OnceFunc(func() { close(cancelc) })
2772 f(t, cancelTest{
2773 mode: mode,
2774 newReq: func(req *Request) *Request {
2775 req.Cancel = cancelc
2776 return req
2777 },
2778 cancel: func(tr *Transport, req *Request) {
2779 cancelOnce()
2780 },
2781 checkErr: func(when string, err error) {
2782 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2783 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2784 }
2785 },
2786 })
2787 }
2788
2789
2790 func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2791 ctx, cancel := context.WithCancel(context.Background())
2792 f(t, cancelTest{
2793 mode: mode,
2794 newReq: func(req *Request) *Request {
2795 return req.WithContext(ctx)
2796 },
2797 cancel: func(tr *Transport, req *Request) {
2798 cancel()
2799 },
2800 checkErr: func(when string, err error) {
2801 if !errors.Is(err, context.Canceled) {
2802 t.Errorf("%v error = %v, want context.Canceled", when, err)
2803 }
2804 },
2805 })
2806 }
2807
2808 func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
2809 run(t, func(t *testing.T, mode testMode) {
2810 if mode == http1Mode {
2811 t.Run("TransportCancel", func(t *testing.T) {
2812 runCancelTestTransport(t, mode, f)
2813 })
2814 }
2815 t.Run("RequestCancel", func(t *testing.T) {
2816 runCancelTestChannel(t, mode, f)
2817 })
2818 t.Run("ContextCancel", func(t *testing.T) {
2819 runCancelTestContext(t, mode, f)
2820 })
2821 }, opts...)
2822 }
2823
// TestTransportCancelRequest checks what happens when a GET is canceled after
// its headers and part of its body have arrived. Further body reads must fail
// with the mechanism's cancellation error and return no extra bytes, and the
// Transport must then report no pending requests.
func TestTransportCancelRequest(t *testing.T) {
	runCancelTest(t, testTransportCancelRequest)
}
func testTransportCancelRequest(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	// The handler writes and flushes msg, then blocks until the test returns.
	unblockc := make(chan bool)
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	req = test.newReq(req)
	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Read exactly the flushed prefix, so the cancel below lands mid-body.
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	test.cancel(tr, req)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// The Transport may finish its bookkeeping asynchronously, so poll until
	// no requests remain pending. Logging happens only when d > 0, which
	// presumably means a retry — confirm against waitCondition.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
2877
2878 func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
2879 if testing.Short() {
2880 t.Skip("skipping test in -short mode")
2881 }
2882 unblockc := make(chan bool)
2883 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2884 <-unblockc
2885 })).ts
2886 defer close(unblockc)
2887
2888 c := ts.Client()
2889 tr := c.Transport.(*Transport)
2890
2891 donec := make(chan bool)
2892 req, _ := NewRequest("GET", ts.URL, body)
2893 req = test.newReq(req)
2894 go func() {
2895 defer close(donec)
2896 c.Do(req)
2897 }()
2898
2899 unblockc <- true
2900 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2901 test.cancel(tr, req)
2902 select {
2903 case <-donec:
2904 return true
2905 default:
2906 if d > 0 {
2907 t.Logf("Do of canceled request has not returned after %v", d)
2908 }
2909 return false
2910 }
2911 })
2912 }
2913
2914 func TestTransportCancelRequestInDo(t *testing.T) {
2915 runCancelTest(t, func(t *testing.T, test cancelTest) {
2916 testTransportCancelRequestInDo(t, test, nil)
2917 })
2918 }
2919
2920 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2921 runCancelTest(t, func(t *testing.T, test cancelTest) {
2922 testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
2923 })
2924 }
2925
// TestTransportCancelRequestInDial checks that canceling a request while its
// connection is still being dialed makes Do return the cancellation error
// without waiting for the dial to finish.
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records events from all goroutines in order; the sequence is
	// compared against want at the end.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is closed only when the test returns, so the dial stays
	// blocked for the entire cancellation.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// The test's send on inDial proves the dial has started. Receiving
	// false makes the dial give up immediately.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	inDial <- true

	eventLog.Printf("canceling")
	// Cancel twice, exercising a repeated cancellation of the same request.
	test.cancel(tr, req)
	test.cancel(tr, req)

	// If the test has a deadline, panic at 95% of the remaining time with the
	// recorded events, so a hang is reported with context.
	if d, ok := t.Deadline(); ok {
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2988
2989
// TestTransportCancelRequestWithBody checks what happens when a POST that has
// a request body is canceled after part of its response body has arrived.
// Further body reads must fail with the mechanism's cancellation error and
// return no extra bytes, and the Transport must then report no pending
// requests.
func TestTransportCancelRequestWithBody(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestWithBody)
}
func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}

	const msg = "Hello"
	// The handler writes and flushes msg, then blocks until the test returns.
	unblockc := make(chan struct{})
	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
	req = test.newReq(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Read exactly the flushed prefix, so the cancel below lands mid-body.
	body := make([]byte, len(msg))
	n, _ := io.ReadFull(res.Body, body)
	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
		t.Errorf("Body = %q; want %q", body[:n], msg)
	}
	test.cancel(tr, req)

	tail, err := io.ReadAll(res.Body)
	res.Body.Close()
	test.checkErr("Body.Read", err)
	if len(tail) > 0 {
		t.Errorf("Spurious bytes from Body.Read: %q", tail)
	}

	// The Transport may finish its bookkeeping asynchronously, so poll until
	// no requests remain pending. Logging happens only when d > 0, which
	// presumably means a retry — confirm against waitCondition.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
3044
3045 func TestTransportCancelRequestBeforeDo(t *testing.T) {
3046
3047 run(t, func(t *testing.T, mode testMode) {
3048 t.Run("RequestCancel", func(t *testing.T) {
3049 runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
3050 })
3051 t.Run("ContextCancel", func(t *testing.T) {
3052 runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
3053 })
3054 })
3055 }
3056 func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
3057 unblockc := make(chan bool)
3058 cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3059 <-unblockc
3060 }))
3061 defer close(unblockc)
3062
3063 c := cst.ts.Client()
3064
3065 req, _ := NewRequest("GET", cst.ts.URL, nil)
3066 req = test.newReq(req)
3067 test.cancel(cst.tr, req)
3068
3069 _, err := c.Do(req)
3070 test.checkErr("Do", err)
3071 }
3072
3073
// TestTransportCancelRequestBeforeResponseHeaders checks that canceling a
// request after it has been written, but before any response arrives, makes
// RoundTrip fail with the mechanism's cancellation error.
func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
}
func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
	defer afterTest(t)

	// The Transport dials an in-memory pipe. The test plays the server on
	// the other end and never writes a response.
	serverConnCh := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			cc, sc := net.Pipe()
			serverConnCh <- sc
			return cc, nil
		},
	}
	defer tr.CloseIdleConnections()
	errc := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	req = test.newReq(req)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	sc := <-serverConnCh
	// Read the start of the request line, which shows the request is being
	// written before the test cancels it.
	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(verb) != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}
	defer sc.Close()

	test.cancel(tr, req)

	err := <-errc
	if err == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	test.checkErr("RoundTrip", err)
}
3115
3116
3117
3118
// TestTransportCloseResponseBody checks that closing a partially read response
// body succeeds and that the server then sees a write error. The handler writes
// forever, so the Close must stop it.
func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
func testTransportCloseResponseBody(t *testing.T, mode testMode) {
	// writeErr receives the first error the handler gets from Write.
	writeErr := make(chan error, 1)
	msg := []byte("young\n")
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		for {
			_, err := w.Write(msg)
			if err != nil {
				writeErr <- err
				return
			}
			w.(Flusher).Flush()
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	req, _ := NewRequest("GET", ts.URL, nil)
	// NOTE(review): presumably a cleanup in case the test fails before the
	// body is closed — confirm.
	defer tr.CancelRequest(req)

	res, err := c.Do(req)
	if err != nil {
		t.Fatal(err)
	}

	// Read a few copies of msg to show the stream is flowing before closing.
	const repeats = 3
	buf := make([]byte, len(msg)*repeats)
	want := bytes.Repeat(msg, repeats)

	_, err = io.ReadFull(res.Body, buf)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, want) {
		t.Fatalf("read %q; want %q", buf, want)
	}

	if err := res.Body.Close(); err != nil {
		t.Errorf("Close = %v", err)
	}

	// Blocks until the handler observes a failed Write.
	if err := <-writeErr; err == nil {
		t.Errorf("expected non-nil write error")
	}
}
3165
// fooProto is a RoundTripper for the custom "foo" URL scheme. It answers every
// request with 200 OK and a body naming the requested URL.
type fooProto struct{}

// RoundTrip returns a canned 200 response whose body is
// "You wanted " followed by req.URL.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := io.NopCloser(strings.NewReader("You wanted " + req.URL.String()))
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     Header{},
		Body:       body,
	}, nil
}
3177
3178 func TestTransportAltProto(t *testing.T) {
3179 defer afterTest(t)
3180 tr := &Transport{}
3181 c := &Client{Transport: tr}
3182 tr.RegisterProtocol("foo", fooProto{})
3183 res, err := c.Get("foo://bar.com/path")
3184 if err != nil {
3185 t.Fatal(err)
3186 }
3187 bodyb, err := io.ReadAll(res.Body)
3188 if err != nil {
3189 t.Fatal(err)
3190 }
3191 body := string(bodyb)
3192 if e := "You wanted foo://bar.com/path"; body != e {
3193 t.Errorf("got response %q, want %q", body, e)
3194 }
3195 }
3196
3197 func TestTransportNoHost(t *testing.T) {
3198 defer afterTest(t)
3199 tr := &Transport{}
3200 _, err := tr.RoundTrip(&Request{
3201 Header: make(Header),
3202 URL: &url.URL{
3203 Scheme: "http",
3204 },
3205 })
3206 want := "http: no Host in request URL"
3207 if got := fmt.Sprint(err); got != want {
3208 t.Errorf("error = %v; want %q", err, want)
3209 }
3210 }
3211
3212
// TestTransportEmptyMethod checks that a request whose Method is the empty
// string is written out as a GET request.
func TestTransportEmptyMethod(t *testing.T) {
	req, err := NewRequest("GET", "http://foo.com/", nil)
	if err != nil {
		t.Fatal(err)
	}
	// For client requests an empty Method means GET.
	req.Method = ""
	got, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.Contains(string(got), "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", got)
	}
}
3224
// TestTransportSocketLateBinding covers the case where one request holds a
// connection while a second request's dial is still blocked. When the first
// connection becomes idle, the second request must be served on it rather
// than waiting for the new dial. In HTTP/2 mode no second dial may happen at
// all.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// /foo sends its headers (including the client's address), then holds
	// the connection until fooGate is signaled.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// Each dial waits for a value on dialGate. While waiting it keeps
	// offering true on dialing, so the test can observe that a dial is in
	// progress. A closed dialGate makes a waiting dial fail.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one dial, for /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	// Once /bar's dial is observed to be blocked (HTTP/1), release /foo and
	// drain it so its connection becomes idle. In HTTP/2 mode, a second
	// dial within 10ms is an error.
	fooDone := make(chan struct{})
	go func() {
		if mode == http2Mode {
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	// Deferred after close(dialGate), so it runs first: wait for the /foo
	// goroutine, then fail any still-blocked dial.
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3304
3305
// TestTransportReading100Continue checks that the client skips an interim
// "100 Continue" response (including its headers) and returns the final
// 200 OK response that follows it, for several requests in a row.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response plays the server side of one connection. For each
	// request it reads, it writes an interim status (100 Continue unless the
	// request's X-Want-Response-Code header names another) followed by a
	// 200 OK that echoes the Request-Id header. It numbers requests per
	// connection, so the body check below assumes requests reuse the same
	// connection (NOTE(review): confirm). It returns on EOF, on error, or
	// after answering the request with id reqID(numReqs).
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// Interim response and final response back to back, with
			// line endings converted to CRLF.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	tr := &Transport{
		// Each dial returns an in-memory connection built from two pipes.
		// The client writes to sw (the fake server reads sr), and the fake
		// server writes to cw (the client reads cr).
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe()
			cr, cw := io.Pipe()
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks the final status code, the echoed
	// Request-Id (when one was set), and that the body can be read.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Send numReqs POSTs in sequence; each is answered with 100 Continue
	// followed by 200 OK.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
3405
3406
3407
3408 func TestTransportIgnore1xxResponses(t *testing.T) {
3409 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3410 }
3411 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3412 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3413 conn, buf, _ := w.(Hijacker).Hijack()
3414 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3415 buf.Flush()
3416 conn.Close()
3417 }))
3418 cst.tr.DisableKeepAlives = true
3419
3420 var got strings.Builder
3421
3422 req, _ := NewRequest("GET", cst.ts.URL, nil)
3423 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3424 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3425 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3426 return nil
3427 },
3428 }))
3429 res, err := cst.c.Do(req)
3430 if err != nil {
3431 t.Fatal(err)
3432 }
3433 defer res.Body.Close()
3434
3435 res.Write(&got)
3436 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3437 if got.String() != want {
3438 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3439 }
3440 }
3441
3442 func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
3443 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3444 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3445 w.Header().Add("X-Header", strings.Repeat("a", 100))
3446 for i := 0; i < 10; i++ {
3447 w.WriteHeader(123)
3448 }
3449 w.WriteHeader(204)
3450 }))
3451 cst.tr.DisableKeepAlives = true
3452 cst.tr.MaxResponseHeaderBytes = 1000
3453
3454 res, err := cst.c.Get(cst.ts.URL)
3455 if err == nil {
3456 res.Body.Close()
3457 t.Fatalf("RoundTrip succeeded; want error")
3458 }
3459 for _, want := range []string{
3460 "response headers exceeded",
3461 "too many 1xx",
3462 "header list too large",
3463 } {
3464 if strings.Contains(err.Error(), want) {
3465 return
3466 }
3467 }
3468 t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
3469 }
3470
3471 func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
3472 run(t, testTransportDoesNotLimitDelivered1xxResponses)
3473 }
3474 func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
3475 if mode == http2Mode {
3476 t.Skip("skip until x/net/http2 updated")
3477 }
3478 const num1xx = 10
3479 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3480 w.Header().Add("X-Header", strings.Repeat("a", 100))
3481 for i := 0; i < 10; i++ {
3482 w.WriteHeader(123)
3483 }
3484 w.WriteHeader(204)
3485 }))
3486 cst.tr.DisableKeepAlives = true
3487 cst.tr.MaxResponseHeaderBytes = 1000
3488
3489 got1xx := 0
3490 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3491 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3492 got1xx++
3493 return nil
3494 },
3495 })
3496 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
3497 res, err := cst.c.Do(req)
3498 if err != nil {
3499 t.Fatal(err)
3500 }
3501 res.Body.Close()
3502 if got1xx != num1xx {
3503 t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
3504 }
3505 }
3506
3507
3508
3509 func TestTransportTreat101Terminal(t *testing.T) {
3510 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3511 }
3512 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3513 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3514 conn, buf, _ := w.(Hijacker).Hijack()
3515 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3516 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3517 buf.Flush()
3518 conn.Close()
3519 }))
3520 res, err := cst.c.Get(cst.ts.URL)
3521 if err != nil {
3522 t.Fatal(err)
3523 }
3524 defer res.Body.Close()
3525 if res.StatusCode != StatusSwitchingProtocols {
3526 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3527 }
3528 }
3529
// proxyFromEnvTest is one case in the proxy-from-environment tables.
type proxyFromEnvTest struct {
	req string // request URL; empty means "http://example.com"

	env      string // value for the HTTP_PROXY (or http_proxy) variable
	httpsenv string // value for the HTTPS_PROXY (or https_proxy) variable
	noenv    string // value for the NO_PROXY (or no_proxy) variable
	reqmeth  string // value for the REQUEST_METHOD variable

	want    string // expected proxy URL in %s form ("<nil>" for none)
	wanterr error  // expected error, compared by its %v text
}

// String describes the case for failure messages. It lists only the non-empty
// environment settings, in a fixed order, followed by the request URL.
func (t proxyFromEnvTest) String() string {
	settings := []struct {
		format string
		value  string
	}{
		{"http_proxy=%q", t.env},
		{"https_proxy=%q", t.httpsenv},
		{"no_proxy=%q", t.noenv},
		{"request_method=%q", t.reqmeth},
	}
	var parts []string
	for _, s := range settings {
		if s.value != "" {
			parts = append(parts, fmt.Sprintf(s.format, s.value))
		}
	}
	req := t.req
	if req == "" {
		req = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", req))
	return strings.TrimSpace(strings.Join(parts, " "))
}
3572
// proxyFromEnvTests lists the environment settings and expected results shared
// by the ProxyFromEnvironment tests. want is the proxy URL in %s form, so
// "<nil>" means no proxy is used.
var proxyFromEnvTests = []proxyFromEnvTest{
	// A proxy value without a scheme gets http://; explicit schemes
	// (including socks5 and socks5h) are kept.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http request uses the http proxy even when an https proxy is set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https request uses the https proxy; a scheme-less value still
	// defaults to http://.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// With REQUEST_METHOD set, as in a CGI environment, the HTTP_PROXY value
	// is refused with an error. NOTE(review): presumably because a CGI
	// request can control HTTP_PROXY — see golang.org/s/cgihttpproxy.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy configured: no proxy URL.
	{want: "<nil>"},

	// NO_PROXY matching. An entry matches the host itself and its
	// subdomains; a leading-dot entry does not match the bare domain; and a
	// partial-label suffix such as "ample.com" does not match.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3603
// testProxyForRequest builds a GET request for tt.req (defaulting to
// http://example.com), resolves its proxy with proxyForRequest, and reports a
// test error if the result differs from tt.want / tt.wanterr. Both sides are
// compared by their %v / %s string forms, so a nil proxy URL is expected as
// "<nil>".
func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
	t.Helper()
	reqURL := tt.req
	if reqURL == "" {
		reqURL = "http://example.com"
	}
	// A malformed table entry must fail loudly here instead of passing a nil
	// *Request to proxyForRequest.
	req, err := NewRequest("GET", reqURL, nil)
	if err != nil {
		t.Fatalf("%v: NewRequest(%q): %v", tt, reqURL, err)
	}
	// Named proxyURL rather than url so the net/url package is not shadowed.
	proxyURL, err := proxyForRequest(req)
	if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
		t.Errorf("%v: got error = %q, want %q", tt, g, e)
		return
	}
	if got := fmt.Sprintf("%s", proxyURL); got != tt.want {
		t.Errorf("%v: got URL = %q, want %q", tt, got, tt.want)
	}
}
3620
// TestProxyFromEnvironment runs proxyFromEnvTests through ProxyFromEnvironment
// using the upper-case variable names (HTTP_PROXY, HTTPS_PROXY, NO_PROXY).
func TestProxyFromEnvironment(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		setEnvAndResolve := func(req *Request) (*url.URL, error) {
			vars := [...][2]string{
				{"HTTP_PROXY", tc.env},
				{"HTTPS_PROXY", tc.httpsenv},
				{"NO_PROXY", tc.noenv},
				{"REQUEST_METHOD", tc.reqmeth},
			}
			for _, kv := range vars {
				os.Setenv(kv[0], kv[1])
			}
			// Drop any cached environment so the values just set are used.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, setEnvAndResolve)
	}
}
3635
// TestProxyFromEnvironmentLowerCase is TestProxyFromEnvironment with the
// lower-case variable names (http_proxy, https_proxy, no_proxy).
// REQUEST_METHOD stays upper-case.
func TestProxyFromEnvironmentLowerCase(t *testing.T) {
	ResetProxyEnv()
	defer ResetProxyEnv()
	for _, tc := range proxyFromEnvTests {
		setEnvAndResolve := func(req *Request) (*url.URL, error) {
			vars := [...][2]string{
				{"http_proxy", tc.env},
				{"https_proxy", tc.httpsenv},
				{"no_proxy", tc.noenv},
				{"REQUEST_METHOD", tc.reqmeth},
			}
			for _, kv := range vars {
				os.Setenv(kv[0], kv[1])
			}
			// Drop any cached environment so the values just set are used.
			ResetCachedEnvironment()
			return ProxyFromEnvironment(req)
		}
		testProxyForRequest(t, tc, setEnvAndResolve)
	}
}
3650
// TestIdleConnChannelLeak checks that after a series of requests to distinct
// hosts, the Transport's idle-connection wait map (as reported by
// IdleConnWaitMapSizeForTesting) is empty. It checks this both with
// keep-alives disabled and with them enabled.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// n counts handler invocations (guarded by mu). This test does not
	// otherwise inspect it.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	// didRead receives a value each time the read-loop hook fires. The hook
	// is a package-level setting, which is presumably why this test is
	// marked testNotParallel.
	const nReqs = 5
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Dial the test server whatever host the URL names. Each
	// foo-host-N.tld request presumably gets its own connection-pool key.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// First with keep-alives disabled, then with them enabled.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// The response is deliberately not closed: the handler writes no
			// body, so there is nothing to drain. NOTE(review): this relies on
			// an empty response being sent with Content-Length: 0 rather than
			// chunked. Confirm if this test becomes flaky.
		}

		// Wait for the read-loop hook to fire once per request before
		// inspecting the wait map, presumably so the check does not race
		// with the Transport's own bookkeeping.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3706
3707
3708
3709
3710 func TestTransportClosesRequestBody(t *testing.T) {
3711 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3712 }
3713 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3714 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3715 io.Copy(io.Discard, r.Body)
3716 })).ts
3717
3718 c := ts.Client()
3719
3720 closes := 0
3721
3722 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3723 if err != nil {
3724 t.Fatal(err)
3725 }
3726 res.Body.Close()
3727 if closes != 1 {
3728 t.Errorf("closes = %d; want 1", closes)
3729 }
3730 }
3731
// TestTransportTLSHandshakeTimeout points the Transport at a listener that
// accepts the TCP connection but never speaks TLS. It checks that the request
// fails with a *url.Error wrapping a net.Error whose Timeout() is true and
// whose message mentions "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	testdonec := make(chan struct{})
	acceptDone := make(chan struct{})
	// Teardown, in order:
	//   1. release the accepted connection;
	//   2. close the listener so a still-pending Accept returns;
	//   3. wait for the accept goroutine.
	// Step 3 guarantees the goroutine cannot call t.Error after this test
	// function has returned.
	defer func() {
		close(testdonec)
		ln.Close()
		<-acceptDone
	}()

	// Accept one connection and hold it open without responding until the
	// test finishes.
	go func() {
		defer close(acceptDone)
		c, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-testdonec
		c.Close()
	}()

	tr := &Transport{
		// Always dial the local listener, whatever host the URL names.
		Dial: func(_, _ string) (net.Conn, error) {
			return net.Dial("tcp", ln.Addr().String())
		},
		TLSHandshakeTimeout: 250 * time.Millisecond,
	}
	cl := &Client{Transport: tr}
	_, err := cl.Get("https://dummy.tld/")
	if err == nil {
		t.Error("expected error")
		return
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Errorf("expected url.Error; got %#v", err)
		return
	}
	ne, ok := ue.Err.(net.Error)
	if !ok {
		t.Errorf("expected net.Error; got %#v", err)
		return
	}
	if !ne.Timeout() {
		t.Errorf("expected timeout error; got %v", err)
	}
	if !strings.Contains(err.Error(), "handshake timeout") {
		t.Errorf("expected 'handshake timeout' in error; got %v", err)
	}
}
3781
3782
// TestTLSServerClosesConnection has a TLS server hijack a keep-alive
// connection, write a complete response, and close it. It checks that the
// client's next request notices the dead connection and succeeds on a new
// one, in at least one of several trials.
func TestTLSServerClosesConnection(t *testing.T) {
	run(t, testTLSServerClosesConnection, []testMode{https1Mode})
}
func testTLSServerClosesConnection(t *testing.T, mode testMode) {
	closedc := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
			conn, _, _ := w.(Hijacker).Hijack()
			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
			conn.Close()
			closedc <- true
			return
		}
		fmt.Fprintf(w, "hello")
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	nSuccess := 0
	var errs []error
	const trials = 20
	for i := 0; i < trials; i++ {
		tr.CloseIdleConnections()
		res, err := c.Get(ts.URL + "/keep-alive-then-die")
		if err != nil {
			t.Fatal(err)
		}
		<-closedc
		slurp, err := io.ReadAll(res.Body)
		// The body was read to EOF, so this Close does not affect whether
		// the Transport treats the connection as reusable.
		res.Body.Close()
		if err != nil {
			t.Fatal(err)
		}
		if string(slurp) != "foo" {
			t.Errorf("Got %q, want foo", slurp)
		}

		// The server has already closed that connection. A second request
		// must notice this and pick up a fresh connection.
		res, err = c.Get(ts.URL + "/")
		if err != nil {
			errs = append(errs, err)
			continue
		}
		_, err = io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			errs = append(errs, err)
			continue
		}
		nSuccess++
	}
	if nSuccess > 0 {
		t.Logf("successes = %d of %d", nSuccess, trials)
	} else {
		t.Errorf("All runs failed:")
	}
	for _, err := range errs {
		t.Logf("  err: %v", err)
	}
}
3843
3844
3845
3846
3847 type byteFromChanReader chan byte
3848
3849 func (c byteFromChanReader) Read(p []byte) (n int, err error) {
3850 if len(p) == 0 {
3851 return
3852 }
3853 b, ok := <-c
3854 if !ok {
3855 return 0, io.EOF
3856 }
3857 p[0] = b
3858 return 1, nil
3859 }
3860
3861
3862
3863
3864
3865
3866
// TestTransportNoReuseAfterEarlyResponse covers a server that answers a POST
// before the client has finished writing the request body. It checks that the
// client still receives that early response, and that the following GET is
// answered by the normal handler rather than on the hijacked POST connection.
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shorten MaxWriteWaitBeforeConnReuse for this test and restore the
	// previous value on return. Mutating this package-level setting is
	// presumably why the test is not run in parallel.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the server side of the hijacked POST connection.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set once the GET has succeeded. From then on, closing the
	// server connection is expected and is not logged.
	var getOkay bool
	// copying tracks the goroutine that drains the hijacked connection.
	var copying sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: take over the connection and reply immediately with "foo",
		// without reading the request body. Then keep discarding whatever the
		// client sends until the connection is closed.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive

		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The POST body is bodySize bytes long, but its last byte comes from
	// finalBit, which is only fed at the very end of the test. So the client
	// cannot finish writing the POST while the requests below are in flight.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	// The hijacked POST connection only ever writes "foo", so getting "bar"
	// here presumably means the GET went out on a different connection.
	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	// Let the outstanding POST body finish.
	getOkay = true
	finalBit <- 'x'
	close(finalBit)
}
3934
3935
3936
3937 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3938 func testTransportIssue10457(t *testing.T, mode testMode) {
3939 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3940
3941
3942
3943
3944
3945 conn, _, _ := w.(Hijacker).Hijack()
3946 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3947 conn.Close()
3948 })).ts
3949 c := ts.Client()
3950
3951 res, err := c.Get(ts.URL)
3952 if err != nil {
3953 t.Fatalf("Get: %v", err)
3954 }
3955 defer res.Body.Close()
3956
3957
3958
3959
3960 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3961 t.Errorf("Foo header = %q; want %q", got, want)
3962 }
3963 }
3964
// closerFunc adapts a plain function to io.Closer.
type closerFunc func() error

// Close calls f and returns whatever it returns.
func (f closerFunc) Close() error {
	err := f()
	return err
}
3968
3969 type writerFuncConn struct {
3970 net.Conn
3971 write func(p []byte) (n int, err error)
3972 }
3973
3974 func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
3975
3976
3977
3978
3979
3980
3981
3982
3983
3984
3985
3986
3987
// TestRetryRequestsOnError injects a failure into the second connection Write
// and checks that the Transport logs a retry and resends the request on a
// freshly dialed connection. This is checked for idempotent requests and for
// requests where nothing was written. It also checks that each Response still
// references the caller's original Request.
func TestRetryRequestsOnError(t *testing.T) {
	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
}
func testRetryRequestsOnError(t *testing.T, mode testMode) {
	newRequest := func(method, urlStr string, body io.Reader) *Request {
		req, err := NewRequest(method, urlStr, body)
		if err != nil {
			t.Fatal(err)
		}
		return req
	}

	testCases := []struct {
		name string
		// failureN and failureErr are returned, instead of writing, by the
		// second Write counted across all dialed connections.
		failureN   int
		failureErr error
		// req builds a fresh Request for every c.Do call. One Request cannot
		// be reused, because its Body would already have been consumed.
		req       func() *Request
		reqString string
	}{
		{
			name: "IdempotentNoBodySomeWritten",
			// Report that some bytes were written, so this case presumably
			// exercises the idempotent-request retry rather than the
			// "nothing was written" retry.
			failureN: 1,
			// NOTE(review): presumably the specific error the Transport
			// treats as retryable for idempotent requests. Confirm.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", nil)
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "IdempotentGetBodySomeWritten",
			// As above, but the GET carries a body.
			failureN:   1,
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
		{
			name: "NothingWrittenNoBody",
			// Report zero bytes written. This presumably lets the retry rest
			// on "nothing was sent" rather than on the request method.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			req: func() *Request {
				return newRequest("DELETE", "http://fake.golang", nil)
			},
			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "NothingWrittenGetBody",
			// Zero bytes written, as above.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			// NewRequest populates GetBody for a *strings.Reader body. The
			// retry presumably needs it in order to resend the body.
			req: func() *Request {
				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var (
				mu     sync.Mutex
				logbuf strings.Builder
			)
			logf := func(format string, args ...any) {
				mu.Lock()
				defer mu.Unlock()
				fmt.Fprintf(&logbuf, format, args...)
				logbuf.WriteByte('\n')
			}

			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				logf("Handler")
				w.Header().Set("X-Status", "ok")
			})).ts

			var writeNumAtomic int32
			c := ts.Client()
			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
				logf("Dial")
				// Named conn so it does not shadow the client c.
				conn, err := net.Dial(network, ts.Listener.Addr().String())
				if err != nil {
					logf("Dial error: %v", err)
					return nil, err
				}
				return &writerFuncConn{
					Conn: conn,
					write: func(p []byte) (n int, err error) {
						// Fail exactly the second Write, counted across all
						// connections.
						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
							logf("intentional write failure")
							return tc.failureN, tc.failureErr
						}
						logf("Write(%q)", p)
						return conn.Write(p)
					},
				}, nil
			}

			SetRoundTripRetried(func() {
				logf("Retried.")
			})
			defer SetRoundTripRetried(nil)

			// Request 1 dials and succeeds. Request 2 reuses the connection,
			// hits the injected failure, and is retried on a new dial.
			// Request 3 succeeds on the connection left from that retry.
			for i := 0; i < 3; i++ {
				t0 := time.Now()
				req := tc.req()
				res, err := c.Do(req)
				if err != nil {
					// A quick failure is a real bug. A slow one presumably
					// means the connection was not reused in time, so the
					// scenario under test did not happen.
					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
						mu.Lock()
						got := logbuf.String()
						mu.Unlock()
						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
					}
					t.Skipf("connection likely wasn't recycled within %v, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
				}
				res.Body.Close()
				if res.Request != req {
					t.Errorf("Response.Request != original request; want identical Request")
				}
			}

			mu.Lock()
			got := logbuf.String()
			mu.Unlock()
			want := fmt.Sprintf(`Dial
Write("%s")
Handler
intentional write failure
Retried.
Dial
Write("%s")
Handler
Write("%s")
Handler
`, tc.reqString, tc.reqString, tc.reqString)
			if got != want {
				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
			}
		})
	}
}
4144
4145
// TestTransportClosesBodyOnError sends a request body that fails partway
// through. It checks three things: Do reports that error, the handler sees a
// truncated body, and the Transport still closes the request body.
func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
	readBody := make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := io.ReadAll(r.Body)
		readBody <- err
	})).ts
	client := ts.Client()

	fakeErr := errors.New("fake error")
	didClose := make(chan bool, 1)
	// 1 MiB of 'x', then fakeErr.
	failingReader := io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr))
	// Record that Close happened. Never block, even if Close is called twice.
	recordClose := closerFunc(func() error {
		select {
		case didClose <- true:
		default:
		}
		return nil
	})
	req, _ := NewRequest("POST", ts.URL, struct {
		io.Reader
		io.Closer
	}{failingReader, recordClose})

	res, err := client.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
	}
	if err := <-readBody; err == nil {
		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
	}
	select {
	case <-didClose:
	default:
		t.Errorf("didn't see Body.Close")
	}
}
4185
// TestTransportDialTLS checks that an HTTPS request goes through the
// Transport's DialTLS hook and reaches the handler.
func TestTransportDialTLS(t *testing.T) {
	run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLS(t *testing.T, mode testMode) {
	// mu guards gotReq (set by the handler) and didDial (set by the hook).
	var mu sync.Mutex
	var gotReq, didDial bool

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.DialTLS = func(netw, addr string) (net.Conn, error) {
		mu.Lock()
		didDial = true
		mu.Unlock()
		// Named conn so it does not shadow the client c.
		conn, err := tls.Dial(netw, addr, tr.TLSClientConfig)
		if err != nil {
			return nil, err
		}
		return conn, conn.Handshake()
	}

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// Release the lock on return. Previously it stayed held forever, which
	// would deadlock any late handler or dial.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if !didDial {
		t.Error("didn't use dial hook")
	}
}
4223
// TestTransportDialContext checks that Transport.DialContext is called with a
// context that carries the values stored in the request's context.
func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
func testTransportDialContext(t *testing.T, mode testMode) {
	// An unexported key type avoids collisions with context keys from other
	// code. The context.WithValue docs advise against built-in key types.
	type contextKey string
	ctxKey := contextKey("some-key")
	ctxValue := "some-value"
	var (
		mu          sync.Mutex // guards gotReq and gotCtxValue
		gotReq      bool
		gotCtxValue any
	)

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		gotCtxValue = ctx.Value(ctxKey)
		mu.Unlock()
		return net.Dial(netw, addr)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// Release the lock on return instead of leaving it held forever.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if got, want := gotCtxValue, ctxValue; got != want {
		t.Errorf("got context with value %v, want %v", got, want)
	}
}
4265
// TestTransportDialTLSContext checks that Transport.DialTLSContext is called
// with a context that carries the values stored in the request's context.
func TestTransportDialTLSContext(t *testing.T) {
	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
}
func testTransportDialTLSContext(t *testing.T, mode testMode) {
	// An unexported key type avoids collisions with context keys from other
	// code. The context.WithValue docs advise against built-in key types.
	type contextKey string
	ctxKey := contextKey("some-key")
	ctxValue := "some-value"
	var (
		mu          sync.Mutex // guards gotReq and gotCtxValue
		gotReq      bool
		gotCtxValue any
	)

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		gotReq = true
		mu.Unlock()
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
		mu.Lock()
		gotCtxValue = ctx.Value(ctxKey)
		mu.Unlock()
		// Named conn so it does not shadow the client c.
		conn, err := tls.Dial(netw, addr, tr.TLSClientConfig)
		if err != nil {
			return nil, err
		}
		return conn, conn.HandshakeContext(ctx)
	}

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
	res, err := c.Do(req.WithContext(ctx))
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// Release the lock on return instead of leaving it held forever.
	mu.Lock()
	defer mu.Unlock()
	if !gotReq {
		t.Error("didn't get request")
	}
	if got, want := gotCtxValue, ctxValue; got != want {
		t.Errorf("got context with value %v, want %v", got, want)
	}
}
4313
4314
4315
// TestRoundTripReturnsProxyError checks that an error from Transport.Proxy
// is returned by RoundTrip.
func TestRoundTripReturnsProxyError(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return nil, errors.New("errorMessage")
		},
	}
	req, _ := NewRequest("GET", "http://example.com", nil)
	if _, err := tr.RoundTrip(req); err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	}
}
4331
4332
4333 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4334 tr := &Transport{}
4335 wantIdle := func(when string, n int) bool {
4336 got := tr.IdleConnCountForTesting("http", "example.com")
4337 if got == n {
4338 return true
4339 }
4340 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4341 return false
4342 }
4343 wantIdle("start", 0)
4344 if !tr.PutIdleTestConn("http", "example.com") {
4345 t.Fatal("put failed")
4346 }
4347 if !tr.PutIdleTestConn("http", "example.com") {
4348 t.Fatal("second put failed")
4349 }
4350 wantIdle("after put", 2)
4351 tr.CloseIdleConnections()
4352 if !tr.IsIdleForTesting() {
4353 t.Error("should be idle after CloseIdleConnections")
4354 }
4355 wantIdle("after close idle", 0)
4356 if tr.PutIdleTestConn("http", "example.com") {
4357 t.Fatal("put didn't fail")
4358 }
4359 wantIdle("after second put", 0)
4360
4361 tr.QueueForIdleConnForTesting()
4362 if tr.IsIdleForTesting() {
4363 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4364 }
4365 if !tr.PutIdleTestConn("http", "example.com") {
4366 t.Fatal("after re-activation")
4367 }
4368 wantIdle("after final put", 1)
4369 }
4370
4371
4372
4373 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4374 tr := &Transport{}
4375 wantIdle := func(when string, n int) bool {
4376 got := tr.IdleConnCountForTesting("https", "example.com:443")
4377 if got == n {
4378 return true
4379 }
4380 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4381 return false
4382 }
4383 wantIdle("start", 0)
4384 alt := funcRoundTripper(func() {})
4385 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4386 t.Fatal("put failed")
4387 }
4388 wantIdle("after put", 1)
4389 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4390 GotConn: func(httptrace.GotConnInfo) {
4391
4392 t.Error("GotConn called")
4393 },
4394 })
4395 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4396 _, err := tr.RoundTrip(req)
4397 if err != errFakeRoundTrip {
4398 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4399 }
4400 wantIdle("after round trip", 1)
4401 }
4402
4403
4404
4405
4406
4407
// TestTransportIdleConnRacesRequest issues a request while the Transport is
// closing an idle connection that has reached IdleConnTimeout, and checks
// that the request still succeeds. A fake network with hooks blocks dials
// (until dialc is closed) and connection closes (until closec is closed), so
// the interleaving is controlled.
func TestTransportIdleConnRacesRequest(t *testing.T) {
	// NOTE(review): the HTTP/2 mode is currently skipped inside the test
	// body (see #70515).
	runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
}
func testTransportIdleConnRacesRequest(t *testing.T, mode testMode) {
	if mode == http2UnencryptedMode {
		t.Skip("remove skip when #70515 is fixed")
	}
	timeout := 1 * time.Millisecond
	trFunc := func(tr *Transport) {
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), trFunc, optFakeNet)
	cst.li.trackConns = true

	// Hold every dial until dialc is closed, and every connection close until
	// closec is closed.
	dialc := make(chan struct{})
	cst.li.onDial = func() {
		<-dialc
	}
	closec := make(chan struct{})
	cst.li.onClose = func(*fakeNetConn) {
		<-closec
	}
	// Start a first request. Its dial blocks on dialc.
	ctx, cancel := context.WithCancel(context.Background())
	req1c := make(chan error)
	go func() {
		req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		resp, err := cst.c.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		req1c <- err
	}()

	synctest.Wait()

	// Cancel the first request while its dial is still blocked. It must fail.
	cancel()
	synctest.Wait()
	if err := <-req1c; err == nil {
		t.Fatal("expected request to fail, but it succeeded")
	}

	// Let the dial complete. The request that started it is gone, so the new
	// connection presumably goes straight into the idle pool.
	close(dialc)

	// Advance fake time past IdleConnTimeout so the idle timer fires. Its
	// attempt to close the connection presumably now blocks on closec.
	synctest.Wait()
	time.Sleep(timeout)
	synctest.Wait()

	// Issue a second request while that close is still pending.
	req2c := make(chan error)
	go func() {
		resp, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			resp.Body.Close()
		}
		req2c <- err
	}()

	// Let the close proceed. The second request must still succeed.
	close(closec)
	if err := <-req2c; err != nil {
		t.Fatalf("Get: %v", err)
	}
}
4484
4485 func TestTransportRemovesConnsAfterIdle(t *testing.T) {
4486 runSynctest(t, testTransportRemovesConnsAfterIdle)
4487 }
4488 func testTransportRemovesConnsAfterIdle(t *testing.T, mode testMode) {
4489 if testing.Short() {
4490 t.Skip("skipping in short mode")
4491 }
4492
4493 timeout := 1 * time.Second
4494 trFunc := func(tr *Transport) {
4495 tr.MaxConnsPerHost = 1
4496 tr.MaxIdleConnsPerHost = 1
4497 tr.IdleConnTimeout = timeout
4498 }
4499 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4500 w.Header().Set("X-Addr", r.RemoteAddr)
4501 }), trFunc, optFakeNet)
4502
4503
4504
4505 makeRequest := func() string {
4506 resp, err := cst.c.Get(cst.ts.URL)
4507 if err != nil {
4508 t.Fatalf("got error: %s", err)
4509 }
4510 resp.Body.Close()
4511 return resp.Header.Get("X-Addr")
4512 }
4513
4514 addr1 := makeRequest()
4515
4516 time.Sleep(timeout / 2)
4517 synctest.Wait()
4518 addr2 := makeRequest()
4519 if addr1 != addr2 {
4520 t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
4521 }
4522
4523 time.Sleep(timeout)
4524 synctest.Wait()
4525 addr3 := makeRequest()
4526 if addr1 == addr3 {
4527 t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
4528 }
4529 }
4530
// TestTransportRemovesConnsAfterBroken checks that a broken pooled
// connection is not reused. Successive requests share one connection until
// its peer is closed; after that, a new connection is used.
func TestTransportRemovesConnsAfterBroken(t *testing.T) {
	runSynctest(t, testTransportRemovesConnsAfterBroken)
}
func testTransportRemovesConnsAfterBroken(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	limitToOneConn := func(tr *Transport) {
		tr.MaxConnsPerHost = 1
		tr.MaxIdleConnsPerHost = 1
	}
	// Echo the client's address so each response identifies its connection.
	echoAddr := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	})
	cst := newClientServerTest(t, mode, echoAddr, limitToOneConn, optFakeNet)
	cst.li.trackConns = true

	servingAddr := func() string {
		resp, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatalf("got error: %s", err)
		}
		resp.Body.Close()
		return resp.Header.Get("X-Addr")
	}

	addr1 := servingAddr()
	if addr2 := servingAddr(); addr1 != addr2 {
		t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
	}

	// Once everything is quiescent, break the first tracked connection by
	// closing its peer.
	synctest.Wait()
	cst.li.conns[0].peer.Close()
	synctest.Wait()
	if addr3 := servingAddr(); addr1 == addr3 {
		t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
	}
}
4574
4575
4576
4577
4578
// TestTransportRangeAndGzip checks a request that carries a Range header.
// The Transport must not advertise gzip in Accept-Encoding, and the Range
// header must reach the server.
func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
func testTransportRangeAndGzip(t *testing.T, mode testMode) {
	checkHeaders := HandlerFunc(func(w ResponseWriter, r *Request) {
		acceptEncoding := r.Header.Get("Accept-Encoding")
		if strings.Contains(acceptEncoding, "gzip") {
			t.Error("Transport advertised gzip support in the Accept header")
		}
		if rangeHdr := r.Header.Get("Range"); rangeHdr == "" {
			t.Error("no Range in request")
		}
	})
	ts := newClientServerTest(t, mode, checkHeaders).ts

	req, _ := NewRequest("GET", ts.URL, nil)
	req.Header.Set("Range", "bytes=7-11")
	res, err := ts.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4599
4600
4601 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
4602 func testTransportResponseCancelRace(t *testing.T, mode testMode) {
4603 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4604
4605 var b [1024]byte
4606 w.Write(b[:])
4607 })).ts
4608 tr := ts.Client().Transport.(*Transport)
4609
4610 req, err := NewRequest("GET", ts.URL, nil)
4611 if err != nil {
4612 t.Fatal(err)
4613 }
4614 res, err := tr.RoundTrip(req)
4615 if err != nil {
4616 t.Fatal(err)
4617 }
4618
4619
4620
4621 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4622 t.Fatal(err)
4623 }
4624
4625 req2, err := NewRequest("GET", ts.URL, nil)
4626 if err != nil {
4627 t.Fatal(err)
4628 }
4629 tr.CancelRequest(req)
4630 res, err = tr.RoundTrip(req2)
4631 if err != nil {
4632 t.Fatal(err)
4633 }
4634 res.Body.Close()
4635 }
4636
4637
// TestTransportContentEncodingCaseInsensitive checks that a gzip-encoded
// response is decompressed regardless of the case of its Content-Encoding
// value.
func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
	run(t, testTransportContentEncodingCaseInsensitive)
}
func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
	const encodedString = "Hello Gopher"
	for _, encoding := range []string{"gzip", "GZIP"} {
		t.Run(encoding, func(t *testing.T) {
			handler := func(w ResponseWriter, r *Request) {
				w.Header().Set("Content-Encoding", encoding)
				zw := gzip.NewWriter(w)
				zw.Write([]byte(encodedString))
				zw.Close()
			}
			ts := newClientServerTest(t, mode, HandlerFunc(handler)).ts

			res, err := ts.Client().Get(ts.URL)
			if err != nil {
				t.Fatal(err)
			}
			body, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				t.Fatal(err)
			}

			if got := string(body); got != encodedString {
				t.Fatalf("Expected body %q, got: %q\n", encodedString, got)
			}
		})
	}
}
4669
4670
4671 func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4672 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4673 }
4674 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4675 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4676 func(tr *Transport) {
4677 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4678
4679 return &funcConn{
4680 read: func([]byte) (int, error) {
4681 return 0, errors.New("error")
4682 },
4683 write: func([]byte) (int, error) {
4684 return 0, errors.New("error")
4685 },
4686 }, nil
4687 }
4688 },
4689 ).ts
4690
4691
4692
4693
4694
4695 SetEnterRoundTripHook(func() {
4696 time.Sleep(1 * time.Millisecond)
4697 })
4698 defer SetEnterRoundTripHook(nil)
4699 var closes int
4700 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4701 if err == nil {
4702 t.Fatalf("expected request to fail, but it did not")
4703 }
4704 if closes != 1 {
4705 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4706 }
4707 }
4708
4709
4710
4711
4712 type logWritesConn struct {
4713 net.Conn
4714
4715 w io.Writer
4716
4717 rch <-chan io.Reader
4718 r io.Reader
4719
4720 mu sync.Mutex
4721 writes []string
4722 }
4723
4724 func (c *logWritesConn) Write(p []byte) (n int, err error) {
4725 c.mu.Lock()
4726 defer c.mu.Unlock()
4727 c.writes = append(c.writes, string(p))
4728 return c.w.Write(p)
4729 }
4730
4731 func (c *logWritesConn) Read(p []byte) (n int, err error) {
4732 if c.r == nil {
4733 c.r = <-c.rch
4734 }
4735 return c.r.Read(p)
4736 }
4737
4738 func (c *logWritesConn) Close() error { return nil }
4739
4740
4741 func TestTransportFlushesBodyChunks(t *testing.T) {
4742 defer afterTest(t)
4743 resBody := make(chan io.Reader, 1)
4744 connr, connw := io.Pipe()
4745 lw := &logWritesConn{
4746 rch: resBody,
4747 w: connw,
4748 }
4749 tr := &Transport{
4750 Dial: func(network, addr string) (net.Conn, error) {
4751 return lw, nil
4752 },
4753 }
4754 bodyr, bodyw := io.Pipe()
4755 go func() {
4756 defer bodyw.Close()
4757 for i := 0; i < 3; i++ {
4758 fmt.Fprintf(bodyw, "num%d\n", i)
4759 }
4760 }()
4761 resc := make(chan *Response)
4762 go func() {
4763 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4764 req.Header.Set("User-Agent", "x")
4765 res, err := tr.RoundTrip(req)
4766 if err != nil {
4767 t.Errorf("RoundTrip: %v", err)
4768 close(resc)
4769 return
4770 }
4771 resc <- res
4772
4773 }()
4774
4775 req, err := ReadRequest(bufio.NewReader(connr))
4776 if err != nil {
4777 t.Fatal(err)
4778 }
4779 io.Copy(io.Discard, req.Body)
4780
4781
4782 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4783 res, ok := <-resc
4784 if !ok {
4785 return
4786 }
4787 defer res.Body.Close()
4788
4789 want := []string{
4790 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4791 "5\r\nnum0\n\r\n",
4792 "5\r\nnum1\n\r\n",
4793 "5\r\nnum2\n\r\n",
4794 "0\r\n\r\n",
4795 }
4796 if !slices.Equal(lw.writes, want) {
4797 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4798 }
4799 }
4800
4801
// TestTransportFlushesRequestHeader checks that request headers are sent
// before the request body is complete. The handler must run while the body
// pipe is still open.
func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
	gotReq := make(chan struct{})
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		close(gotReq)
	}))

	bodyReader, bodyWriter := io.Pipe()
	req, err := NewRequest("POST", cst.ts.URL, bodyReader)
	if err != nil {
		t.Fatal(err)
	}

	gotRes := make(chan struct{})
	go func() {
		defer close(gotRes)
		res, err := cst.tr.RoundTrip(req)
		if err != nil {
			t.Error(err)
			return
		}
		res.Body.Close()
	}()

	// The handler has run even though the body is still open; only now
	// finish the body.
	<-gotReq
	bodyWriter.Close()
	<-gotRes
}
4829
// wgReadCloser is an io.ReadCloser whose first Close marks one unit of work
// as done on wg. Later Close calls leave wg alone and report net.ErrClosed,
// so each body decrements the WaitGroup exactly once.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool // set by the first Close
}

// Close calls wg.Done the first time it is invoked and returns nil; every
// subsequent call returns net.ErrClosed.
func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4844
4845
4846 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4847
4848 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4849 }
4850 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4851 if testing.Short() {
4852 t.Skip("skipping in short mode")
4853 }
4854
4855 runTimeSensitiveTest(t, []time.Duration{
4856 1 * time.Millisecond,
4857 5 * time.Millisecond,
4858 10 * time.Millisecond,
4859 50 * time.Millisecond,
4860 100 * time.Millisecond,
4861 500 * time.Millisecond,
4862 time.Second,
4863 5 * time.Second,
4864 }, func(t *testing.T, timeout time.Duration) error {
4865 SetRSTAvoidanceDelay(t, timeout)
4866 t.Logf("set RST avoidance delay to %v", timeout)
4867
4868 const contentLengthLimit = 1024 * 1024
4869 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4870 if r.ContentLength >= contentLengthLimit {
4871 w.WriteHeader(StatusBadRequest)
4872 r.Body.Close()
4873 return
4874 }
4875 w.WriteHeader(StatusOK)
4876 }))
4877
4878
4879 defer cst.close()
4880 ts := cst.ts
4881 c := ts.Client()
4882
4883 count := 100
4884
4885 bigBody := strings.Repeat("a", contentLengthLimit*2)
4886 var wg sync.WaitGroup
4887 defer wg.Wait()
4888 getBody := func() (io.ReadCloser, error) {
4889 wg.Add(1)
4890 body := &wgReadCloser{
4891 Reader: strings.NewReader(bigBody),
4892 wg: &wg,
4893 }
4894 return body, nil
4895 }
4896
4897 for i := 0; i < count; i++ {
4898 reqBody, _ := getBody()
4899 req, err := NewRequest("PUT", ts.URL, reqBody)
4900 if err != nil {
4901 reqBody.Close()
4902 t.Fatal(err)
4903 }
4904 req.ContentLength = int64(len(bigBody))
4905 req.GetBody = getBody
4906
4907 resp, err := c.Do(req)
4908 if err != nil {
4909 return fmt.Errorf("Do %d: %v", i, err)
4910 } else {
4911 resp.Body.Close()
4912 if resp.StatusCode != 400 {
4913 t.Errorf("Expected status code 400, got %v", resp.Status)
4914 }
4915 }
4916 }
4917 return nil
4918 })
4919 }
4920
4921 func TestTransportAutomaticHTTP2(t *testing.T) {
4922 testTransportAutoHTTP(t, &Transport{}, true)
4923 }
4924
4925 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4926 testTransportAutoHTTP(t, &Transport{
4927 ForceAttemptHTTP2: true,
4928 TLSClientConfig: new(tls.Config),
4929 }, true)
4930 }
4931
4932
4933 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4934 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4935 }
4936
4937 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4938 testTransportAutoHTTP(t, &Transport{
4939 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4940 }, false)
4941 }
4942
4943 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4944 testTransportAutoHTTP(t, &Transport{
4945 TLSClientConfig: new(tls.Config),
4946 }, false)
4947 }
4948
4949 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4950 testTransportAutoHTTP(t, &Transport{
4951 ExpectContinueTimeout: 1 * time.Second,
4952 }, true)
4953 }
4954
4955 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4956 var d net.Dialer
4957 testTransportAutoHTTP(t, &Transport{
4958 Dial: d.Dial,
4959 }, false)
4960 }
4961
4962 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4963 var d net.Dialer
4964 testTransportAutoHTTP(t, &Transport{
4965 DialContext: d.DialContext,
4966 }, false)
4967 }
4968
4969 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4970 testTransportAutoHTTP(t, &Transport{
4971 DialTLS: func(network, addr string) (net.Conn, error) {
4972 panic("unused")
4973 },
4974 }, false)
4975 }
4976
4977 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4978 CondSkipHTTP2(t)
4979 _, err := tr.RoundTrip(new(Request))
4980 if err == nil {
4981 t.Error("expected error from RoundTrip")
4982 }
4983 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4984 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4985 }
4986 }
4987
4988
4989
4990
4991
4992
4993
4994
4995 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4996 run(t, testTransportReuseConnEmptyResponseBody)
4997 }
4998 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4999 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5000 w.Header().Set("X-Addr", r.RemoteAddr)
5001
5002 }))
5003 n := 100
5004 if testing.Short() {
5005 n = 10
5006 }
5007 var firstAddr string
5008 for i := 0; i < n; i++ {
5009 res, err := cst.c.Get(cst.ts.URL)
5010 if err != nil {
5011 log.Fatal(err)
5012 }
5013 addr := res.Header.Get("X-Addr")
5014 if i == 0 {
5015 firstAddr = addr
5016 } else if addr != firstAddr {
5017 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
5018 }
5019 res.Body.Close()
5020 }
5021 }
5022
5023
// TestNoCrashReturningTransportAltConn covers a TLS dial that finishes only
// after the request that started it has been canceled. The dial negotiates
// ALPN protocol "foo", which has a TLSNextProto handler. That handler must
// still be invoked for the orphaned connection (signaled on
// madeRoundTripper), but the RoundTripper it returns must never be used.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// The pending-dial hooks Add to and release wg. The final wg.Wait keeps
	// the test alive until those dials are accounted for. Exactly when the
	// hooks fire is defined by SetPendingDialHooks, which is not shown here.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// Server side: accept one TLS connection that offers only "foo",
	// complete the handshake, and hold the connection open until the test
	// returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// The host name is fake. All dialing goes through DialTLS below, which
	// always connects to ln.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// Cancel the request now that the connection is ready. Then
			// hold the connection back until Do has returned, so the
			// Transport receives it only after nobody is waiting for it.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release the blocked dial. Then wait for the "foo" handler to run on
	// the orphaned connection and for the pending dials to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
5105
5106 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
5107 run(t, func(t *testing.T, mode testMode) {
5108 testTransportReuseConnection_Gzip(t, mode, true)
5109 })
5110 }
5111
5112 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
5113 run(t, func(t *testing.T, mode testMode) {
5114 testTransportReuseConnection_Gzip(t, mode, false)
5115 })
5116 }
5117
5118
5119 func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
5120 addr := make(chan string, 2)
5121 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5122 addr <- r.RemoteAddr
5123 w.Header().Set("Content-Encoding", "gzip")
5124 if chunked {
5125 w.(Flusher).Flush()
5126 }
5127 w.Write(rgz)
5128 })).ts
5129 c := ts.Client()
5130
5131 trace := &httptrace.ClientTrace{
5132 GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
5133 GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
5134 PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
5135 ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
5136 ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
5137 }
5138 ctx := httptrace.WithClientTrace(context.Background(), trace)
5139
5140 for i := 0; i < 2; i++ {
5141 req, _ := NewRequest("GET", ts.URL, nil)
5142 req = req.WithContext(ctx)
5143 res, err := c.Do(req)
5144 if err != nil {
5145 t.Fatal(err)
5146 }
5147 buf := make([]byte, len(rgz))
5148 if n, err := io.ReadFull(res.Body, buf); err != nil {
5149 t.Errorf("%d. ReadFull = %v, %v", i, n, err)
5150 }
5151
5152
5153
5154 }
5155 a1, a2 := <-addr, <-addr
5156 if a1 != a2 {
5157 t.Fatalf("didn't reuse connection")
5158 }
5159 }
5160
5161 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
5162 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
5163 if mode == http2Mode {
5164 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
5165 }
5166 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5167 if r.URL.Path == "/long" {
5168 w.Header().Set("Long", strings.Repeat("a", 1<<20))
5169 }
5170 })).ts
5171 c := ts.Client()
5172 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
5173
5174 if res, err := c.Get(ts.URL); err != nil {
5175 t.Fatal(err)
5176 } else {
5177 res.Body.Close()
5178 }
5179
5180 res, err := c.Get(ts.URL + "/long")
5181 if err == nil {
5182 defer res.Body.Close()
5183 var n int64
5184 for k, vv := range res.Header {
5185 for _, v := range vv {
5186 n += int64(len(k)) + int64(len(v))
5187 }
5188 }
5189 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
5190 }
5191 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
5192 t.Errorf("got error: %v; want %q", err, want)
5193 }
5194 }
5195
// TestTransportEventTrace checks that the httptrace.ClientTrace hooks fire
// the expected number of times for a POST with "Expect: 100-continue" sent
// to a fake host name. It also checks that a second request triggers
// another GetConn.
func TestTransportEventTrace(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, false)
	}, testNotParallel)
}

// TestTransportEventTrace_NoHooks runs the same requests with an empty
// ClientTrace, so no hooks are set.
func TestTransportEventTrace_NoHooks(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, true)
	}, testNotParallel)
}

// testTransportEventTrace runs the trace test. When noHooks is true every
// hook is cleared, and the test only checks that the request succeeds.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Receives one value from each WroteRequest hook call. It is buffered
	// because the GET handler below never drains it.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The second (GET) request gets an empty reply.
		if r.Method == "GET" {
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		// Hold the response until the client reports WroteRequest, so that
		// hook is guaranteed to run before the response arrives.
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one line per event to buf. The hooks may be called from
	// Transport goroutines, hence the mutex.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Install an alternate resolver that maps "dns-is-faked.golang" to the
	// test server's IP. The DNS hooks then fire without a real lookup.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Drop every hook. The request must still work with an empty trace.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// No hooks ran, so nothing was logged. In this mode the whole check
		// is that the round trip completed without error or hang.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	// The first request must have used a brand-new connection.
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		// The individual header-field hooks are only checked for HTTP/1.
		wantOnce("PutIdleConn = <nil>")
		// The request sets no User-Agent, so the default one is expected.
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// With the alternate resolver in place, no DNS traffic over UDP should
	// have been attempted.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Make a second request with the same trace. GetConn must now have
	// been logged twice in total.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
5390
5391 func TestTransportEventTraceTLSVerify(t *testing.T) {
5392 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5393 }
5394 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5395 var mu sync.Mutex
5396 var buf strings.Builder
5397 logf := func(format string, args ...any) {
5398 mu.Lock()
5399 defer mu.Unlock()
5400 fmt.Fprintf(&buf, format, args...)
5401 buf.WriteByte('\n')
5402 }
5403
5404 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5405 t.Error("Unexpected request")
5406 }), func(ts *httptest.Server) {
5407 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5408 logf("%s", p)
5409 return len(p), nil
5410 }), "", 0)
5411 }).ts
5412
5413 certpool := x509.NewCertPool()
5414 certpool.AddCert(ts.Certificate())
5415
5416 c := &Client{Transport: &Transport{
5417 TLSClientConfig: &tls.Config{
5418 ServerName: "dns-is-faked.golang",
5419 RootCAs: certpool,
5420 },
5421 }}
5422
5423 trace := &httptrace.ClientTrace{
5424 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5425 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5426 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5427 },
5428 }
5429
5430 req, _ := NewRequest("GET", ts.URL, nil)
5431 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5432 _, err := c.Do(req)
5433 if err == nil {
5434 t.Error("Expected request to fail TLS verification")
5435 }
5436
5437 mu.Lock()
5438 got := buf.String()
5439 mu.Unlock()
5440
5441 wantOnce := func(sub string) {
5442 if strings.Count(got, sub) != 1 {
5443 t.Errorf("expected substring %q exactly once in output.", sub)
5444 }
5445 }
5446
5447 wantOnce("TLSHandshakeStart")
5448 wantOnce("TLSHandshakeDone")
5449 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5450
5451 if t.Failed() {
5452 t.Errorf("Output:\n%s", got)
5453 }
5454 }
5455
// isDNSHijacked reports whether the system resolver returns any address for
// a name that should not exist. The lookup runs at most once; later calls
// reuse the first answer.
var isDNSHijacked = sync.OnceValue(func() bool {
	// The lookup error does not matter; only whether any address came back.
	addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
	return len(addrs) > 0
})

// skipIfDNSHijacked skips the calling test when the resolver answers for
// nonexistent names, because such tests depend on lookups of made-up hosts
// failing.
func skipIfDNSHijacked(t *testing.T) {
	if !isDNSHijacked() {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
5469
5470 func TestTransportEventTraceRealDNS(t *testing.T) {
5471 skipIfDNSHijacked(t)
5472 defer afterTest(t)
5473 tr := &Transport{}
5474 defer tr.CloseIdleConnections()
5475 c := &Client{Transport: tr}
5476
5477 var mu sync.Mutex
5478 var buf strings.Builder
5479 logf := func(format string, args ...any) {
5480 mu.Lock()
5481 defer mu.Unlock()
5482 fmt.Fprintf(&buf, format, args...)
5483 buf.WriteByte('\n')
5484 }
5485
5486 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5487 trace := &httptrace.ClientTrace{
5488 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5489 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5490 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5491 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5492 }
5493 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5494
5495 resp, err := c.Do(req)
5496 if err == nil {
5497 resp.Body.Close()
5498 t.Fatal("expected error during DNS lookup")
5499 }
5500
5501 mu.Lock()
5502 got := buf.String()
5503 mu.Unlock()
5504
5505 wantSub := func(sub string) {
5506 if !strings.Contains(got, sub) {
5507 t.Errorf("expected substring %q in output.", sub)
5508 }
5509 }
5510 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5511 wantSub("DNSDone: {Addrs:[] Err:")
5512 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5513 t.Errorf("should not see Connect events")
5514 }
5515 if t.Failed() {
5516 t.Errorf("Output:\n%s", got)
5517 }
5518 }
5519
5520
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// non-digit characters is rejected with a *url.Error that describes the bad
// port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	const (
		rawURL = "http://dummy.tld:123foo/bar"
		want   = `invalid port ":123foo" after host`
	)
	res, err := Get(rawURL)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	urlErr, isURLErr := err.(*url.Error)
	if !isURLErr {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	if got := urlErr.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
5537
5538
5539
5540 func TestTLSHandshakeTrace(t *testing.T) {
5541 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5542 }
5543 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5544 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5545
5546 var mu sync.Mutex
5547 var start, done bool
5548 trace := &httptrace.ClientTrace{
5549 TLSHandshakeStart: func() {
5550 mu.Lock()
5551 defer mu.Unlock()
5552 start = true
5553 },
5554 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5555 mu.Lock()
5556 defer mu.Unlock()
5557 done = true
5558 if err != nil {
5559 t.Fatal("Expected error to be nil but was:", err)
5560 }
5561 },
5562 }
5563
5564 c := ts.Client()
5565 req, err := NewRequest("GET", ts.URL, nil)
5566 if err != nil {
5567 t.Fatal("Unable to construct test request:", err)
5568 }
5569 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5570
5571 r, err := c.Do(req)
5572 if err != nil {
5573 t.Fatal("Unexpected error making request:", err)
5574 }
5575 r.Body.Close()
5576 mu.Lock()
5577 defer mu.Unlock()
5578 if !start {
5579 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5580 }
5581 if !done {
5582 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5583 }
5584 }
5585
5586 func TestTransportMaxIdleConns(t *testing.T) {
5587 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5588 }
5589 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5590 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5591
5592 })).ts
5593 c := ts.Client()
5594 tr := c.Transport.(*Transport)
5595 tr.MaxIdleConns = 4
5596
5597 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5598 if err != nil {
5599 t.Fatal(err)
5600 }
5601 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5602 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5603 })
5604
5605 hitHost := func(n int) {
5606 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5607 req = req.WithContext(ctx)
5608 res, err := c.Do(req)
5609 if err != nil {
5610 t.Fatal(err)
5611 }
5612 res.Body.Close()
5613 }
5614 for i := 0; i < 4; i++ {
5615 hitHost(i)
5616 }
5617 want := []string{
5618 "|http|host-0.dns-is-faked.golang:" + port,
5619 "|http|host-1.dns-is-faked.golang:" + port,
5620 "|http|host-2.dns-is-faked.golang:" + port,
5621 "|http|host-3.dns-is-faked.golang:" + port,
5622 }
5623 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5624 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5625 }
5626
5627
5628 hitHost(4)
5629 want = []string{
5630 "|http|host-1.dns-is-faked.golang:" + port,
5631 "|http|host-2.dns-is-faked.golang:" + port,
5632 "|http|host-3.dns-is-faked.golang:" + port,
5633 "|http|host-4.dns-is-faked.golang:" + port,
5634 }
5635 if got := tr.IdleConnKeysForTesting(); !slices.Equal(got, want) {
5636 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5637 }
5638 }
5639
// TestTransportIdleConnTimeout checks Transport.IdleConnTimeout. A
// connection reused every timeout/2 stays the single cached idle
// connection. Once left unused, it must disappear from the idle pool.
func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Start with a very short timeout. Double it and retry whenever an
	// attempt shows the machine is too slow for it.
	timeout := 1 * time.Millisecond
timeoutLoop:
	for {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Empty response; only connection reuse matters.
		}))
		tr := cst.tr
		tr.IdleConnTimeout = timeout
		// Deferred once per attempt; all of them run when the test returns.
		defer tr.CloseIdleConnections()
		c := &Client{Transport: tr}

		// idleConns lists the cached idle connections, using the HTTP/2
		// pool in http2Mode.
		idleConns := func() []string {
			if mode == http2Mode {
				return tr.IdleConnStrsForTesting_h2()
			} else {
				return tr.IdleConnStrsForTesting()
			}
		}

		// conn remembers the first idle connection seen, so later requests
		// can confirm they reused it.
		var conn string
		// doReq makes request n. It returns false when the timing was too
		// tight: the connection was closed, nothing was idle, or the
		// cached connection changed. The caller then retries with a longer
		// timeout.
		doReq := func(n int) (timeoutOk bool) {
			req, _ := NewRequest("GET", cst.ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
				PutIdleConn: func(err error) {
					if err != nil {
						t.Errorf("failed to keep idle conn: %v", err)
					}
				},
			}))
			res, err := c.Do(req)
			if err != nil {
				if strings.Contains(err.Error(), "use of closed network connection") {
					t.Logf("req %v: connection closed prematurely", n)
					return false
				}
				// Any other error is not treated as fatal here; the
				// idle-conn checks below decide the outcome.
			}
			if err == nil {
				res.Body.Close()
			}
			conns := idleConns()
			if len(conns) != 1 {
				if len(conns) == 0 {
					t.Logf("req %v: no idle conns", n)
					return false
				}
				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
			}
			if conn == "" {
				conn = conns[0]
			}
			if conn != conns[0] {
				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
				return false
			}
			return true
		}
		// Three requests, half a timeout apart, should keep one connection
		// alive the whole time.
		for i := 0; i < 3; i++ {
			if !doReq(i) {
				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
				timeout *= 2
				cst.close()
				continue timeoutLoop
			}
			time.Sleep(timeout / 2)
		}

		// Now leave the connection idle; it must leave the pool.
		// waitCondition is defined elsewhere. It is assumed to re-run the
		// callback, starting after timeout/2 and passing the elapsed
		// duration d, until the callback returns true.
		waitCondition(t, timeout/2, func(d time.Duration) bool {
			if got := idleConns(); len(got) != 0 {
				if d >= timeout*3/2 {
					t.Logf("after %v, idle conns = %q", d, got)
				}
				return false
			}
			return true
		})
		break
	}
}
5724
5725
5726
5727
5728
5729
5730
5731
5732
5733
5734
5735
// TestIdleConnH2Crash sets up an HTTP/2 connection whose request is
// canceled while the connection is being dialed. The finished connection
// therefore has no waiting caller. The test then waits several
// IdleConnTimeout periods so any idle timer on that connection fires.
// Judging by the name, it guards against a crash on this path: it passes as
// long as nothing panics and Do fails.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler does nothing; the request is expected to fail first.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// A good h2 connection exists; cancel the request that asked for it.
		cancel()

		// Hand the connection back only after Do has reported its error
		// (or the test is ending), so nobody is waiting for it.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Let the orphaned connection sit well past IdleConnTimeout while the
	// test is still running.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5786
// funcConn is a net.Conn whose Read and Write delegate to the read and
// write fields, and whose Close always succeeds. Every other method comes
// from the embedded net.Conn, which may be left nil (calling one of those
// methods would then panic).
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards to fc.read.
func (fc funcConn) Read(p []byte) (int, error) { return fc.read(p) }

// Write forwards to fc.write.
func (fc funcConn) Write(p []byte) (int, error) { return fc.write(p) }

// Close does nothing and reports success.
func (fc funcConn) Close() error { return nil }
5796
5797
5798
5799 func TestTransportReturnsPeekError(t *testing.T) {
5800 errValue := errors.New("specific error value")
5801
5802 wrote := make(chan struct{})
5803 wroteOnce := sync.OnceFunc(func() { close(wrote) })
5804
5805 tr := &Transport{
5806 Dial: func(network, addr string) (net.Conn, error) {
5807 c := funcConn{
5808 read: func([]byte) (int, error) {
5809 <-wrote
5810 return 0, errValue
5811 },
5812 write: func(p []byte) (int, error) {
5813 wroteOnce()
5814 return len(p), nil
5815 },
5816 }
5817 return c, nil
5818 },
5819 }
5820 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5821 if err != errValue {
5822 t.Errorf("error = %#v; want %v", err, errValue)
5823 }
5824 }
5825
5826
5827 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
5828 func testTransportIDNA(t *testing.T, mode testMode) {
5829 const uniDomain = "гофер.го"
5830 const punyDomain = "xn--c1ae0ajs.xn--c1aw"
5831
5832 var port string
5833 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5834 want := punyDomain + ":" + port
5835 if r.Host != want {
5836 t.Errorf("Host header = %q; want %q", r.Host, want)
5837 }
5838 if mode == http2Mode {
5839 if r.TLS == nil {
5840 t.Errorf("r.TLS == nil")
5841 } else if r.TLS.ServerName != punyDomain {
5842 t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
5843 }
5844 }
5845 w.Header().Set("Hit-Handler", "1")
5846 }), func(tr *Transport) {
5847 if tr.TLSClientConfig != nil {
5848 tr.TLSClientConfig.InsecureSkipVerify = true
5849 }
5850 })
5851
5852 ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
5853 if err != nil {
5854 t.Fatal(err)
5855 }
5856
5857
5858 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
5859 if host != punyDomain {
5860 t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
5861 return nil, nil
5862 }
5863 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5864 })
5865
5866 req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
5867 trace := &httptrace.ClientTrace{
5868 GetConn: func(hostPort string) {
5869 want := net.JoinHostPort(punyDomain, port)
5870 if hostPort != want {
5871 t.Errorf("getting conn for %q; want %q", hostPort, want)
5872 }
5873 },
5874 DNSStart: func(e httptrace.DNSStartInfo) {
5875 if e.Host != punyDomain {
5876 t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
5877 }
5878 },
5879 }
5880 req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
5881
5882 res, err := cst.tr.RoundTrip(req)
5883 if err != nil {
5884 t.Fatal(err)
5885 }
5886 defer res.Body.Close()
5887 if res.Header.Get("Hit-Handler") != "1" {
5888 out, err := httputil.DumpResponse(res, true)
5889 if err != nil {
5890 t.Fatal(err)
5891 }
5892 t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
5893 }
5894 }
5895
5896
5897 func TestTransportProxyConnectHeader(t *testing.T) {
5898 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5899 }
5900 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5901 reqc := make(chan *Request, 1)
5902 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5903 if r.Method != "CONNECT" {
5904 t.Errorf("method = %q; want CONNECT", r.Method)
5905 }
5906 reqc <- r
5907 c, _, err := w.(Hijacker).Hijack()
5908 if err != nil {
5909 t.Errorf("Hijack: %v", err)
5910 return
5911 }
5912 c.Close()
5913 })).ts
5914
5915 c := ts.Client()
5916 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5917 return url.Parse(ts.URL)
5918 }
5919 c.Transport.(*Transport).ProxyConnectHeader = Header{
5920 "User-Agent": {"foo"},
5921 "Other": {"bar"},
5922 }
5923
5924 res, err := c.Get("https://dummy.tld/")
5925 if err == nil {
5926 res.Body.Close()
5927 t.Errorf("unexpected success")
5928 }
5929
5930 r := <-reqc
5931 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5932 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5933 }
5934 if got, want := r.Header.Get("Other"), "bar"; got != want {
5935 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5936 }
5937 }
5938
5939 func TestTransportProxyGetConnectHeader(t *testing.T) {
5940 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5941 }
5942 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5943 reqc := make(chan *Request, 1)
5944 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5945 if r.Method != "CONNECT" {
5946 t.Errorf("method = %q; want CONNECT", r.Method)
5947 }
5948 reqc <- r
5949 c, _, err := w.(Hijacker).Hijack()
5950 if err != nil {
5951 t.Errorf("Hijack: %v", err)
5952 return
5953 }
5954 c.Close()
5955 })).ts
5956
5957 c := ts.Client()
5958 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5959 return url.Parse(ts.URL)
5960 }
5961
5962 c.Transport.(*Transport).ProxyConnectHeader = Header{
5963 "User-Agent": {"foo"},
5964 "Other": {"bar"},
5965 }
5966 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5967 return Header{
5968 "User-Agent": {"foo2"},
5969 "Other": {"bar2"},
5970 }, nil
5971 }
5972
5973 res, err := c.Get("https://dummy.tld/")
5974 if err == nil {
5975 res.Body.Close()
5976 t.Errorf("unexpected success")
5977 }
5978
5979 r := <-reqc
5980 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5981 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5982 }
5983 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5984 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5985 }
5986 }
5987
// errFakeRoundTrip is the error every funcRoundTripper.RoundTrip call returns.
var errFakeRoundTrip error = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the wrapped function for its
// side effect and then fails every request with errFakeRoundTrip.
type funcRoundTripper func()

// RoundTrip ignores the request, calls f, and returns (nil, errFakeRoundTrip).
func (f funcRoundTripper) RoundTrip(_ *Request) (res *Response, err error) {
	f()
	err = errFakeRoundTrip
	return res, err
}
5996
// wantBody checks the result of an HTTP call: if err is non-nil it is
// returned unchanged; otherwise res.Body is read fully, closed, and compared
// with want. The body is always closed once it has been read, even when the
// read fails or the content does not match, so the connection is not leaked.
// Errors are reported in priority order: read error, content mismatch, then
// Close error.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, readErr := io.ReadAll(res.Body)
	closeErr := res.Body.Close()
	if readErr != nil {
		return fmt.Errorf("error reading body: %w", readErr)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %w", closeErr)
	}
	return nil
}
6013
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]); if both fail
// the test is stopped with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	var (
		ln  net.Listener
		err error
	)
	for _, cand := range [...]struct{ network, address string }{
		{"tcp", "127.0.0.1:0"},
		{"tcp6", "[::1]:0"},
	} {
		if ln, err = net.Listen(cand.network, cand.address); err == nil {
			return ln
		}
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal does not return
}
6024
// countCloseReader is an io.ReadCloser: reads go to the embedded Reader and
// every Close call increments the shared counter *n.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close records the call in *n and never fails.
func (c countCloseReader) Close() error {
	*c.n += 1
	return nil
}
6034
6035
// rgz is a hand-crafted gzip stream: the header bytes are the gzip magic
// (0x1f 0x8b), compression method 8 (deflate), and the FNAME flag, followed
// by the NUL-terminated file name "recursive". The remaining bytes are the
// compressed payload and trailer.
// NOTE(review): judging by the name, the payload presumably decompresses to
// (a prefix of) the stream itself, i.e. a self-reproducing "recursive" gzip —
// confirm against the test that uses it.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
6070
6071
6072
6073 func TestMissingStatusNoPanic(t *testing.T) {
6074 t.Parallel()
6075
6076 const want = "unknown status code"
6077
6078 ln := newLocalListener(t)
6079 addr := ln.Addr().String()
6080 done := make(chan bool)
6081 fullAddrURL := fmt.Sprintf("http://%s", addr)
6082 raw := "HTTP/1.1 400\r\n" +
6083 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6084 "Content-Type: text/html; charset=utf-8\r\n" +
6085 "Content-Length: 10\r\n" +
6086 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
6087 "Vary: Accept-Encoding\r\n\r\n" +
6088 "Aloha Olaa"
6089
6090 go func() {
6091 defer close(done)
6092
6093 conn, _ := ln.Accept()
6094 if conn != nil {
6095 io.WriteString(conn, raw)
6096 io.ReadAll(conn)
6097 conn.Close()
6098 }
6099 }()
6100
6101 proxyURL, err := url.Parse(fullAddrURL)
6102 if err != nil {
6103 t.Fatalf("proxyURL: %v", err)
6104 }
6105
6106 tr := &Transport{Proxy: ProxyURL(proxyURL)}
6107
6108 req, _ := NewRequest("GET", "https://golang.org/", nil)
6109 res, err, panicked := doFetchCheckPanic(tr, req)
6110 if panicked {
6111 t.Error("panicked, expecting an error")
6112 }
6113 if res != nil && res.Body != nil {
6114 io.Copy(io.Discard, res.Body)
6115 res.Body.Close()
6116 }
6117
6118 if err == nil || !strings.Contains(err.Error(), want) {
6119 t.Errorf("got=%v want=%q", err, want)
6120 }
6121
6122 ln.Close()
6123 <-done
6124 }
6125
// doFetchCheckPanic calls tr.RoundTrip(req) and reports through panicked
// whether that call panicked. The panic is recovered; in that case res and
// err are left at their zero values because the assignment never completes.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
6135
6136
6137
6138 func TestNoBodyOnChunked304Response(t *testing.T) {
6139 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
6140 }
6141 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
6142 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6143 conn, buf, _ := w.(Hijacker).Hijack()
6144 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
6145 buf.Flush()
6146 conn.Close()
6147 }))
6148
6149
6150
6151
6152
6153 cst.tr.DisableKeepAlives = true
6154
6155 res, err := cst.c.Get(cst.ts.URL)
6156 if err != nil {
6157 t.Fatal(err)
6158 }
6159
6160 if res.Body != NoBody {
6161 t.Errorf("Unexpected body on 304 response")
6162 }
6163 }
6164
// funcWriter adapts an ordinary function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write hands p to the wrapped function and returns its results unchanged.
func (fw funcWriter) Write(p []byte) (n int, err error) {
	n, err = fw(p)
	return n, err
}
6168
// doneContext overrides Done and Err of the embedded Context so that it
// always looks finished: Done returns an already-closed channel and Err
// returns the configured err. Every other method is promoted from Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a new channel that is closed before being returned.
func (doneContext) Done() <-chan struct{} {
	finished := make(chan struct{})
	close(finished)
	return finished
}

// Err returns the error the doneContext was built with.
func (dc doneContext) Err() error {
	return dc.err
}
6181
6182
6183 func TestTransportCheckContextDoneEarly(t *testing.T) {
6184 tr := &Transport{}
6185 req, _ := NewRequest("GET", "http://fake.example/", nil)
6186 wantErr := errors.New("some error")
6187 req = req.WithContext(doneContext{context.Background(), wantErr})
6188 _, err := tr.RoundTrip(req)
6189 if err != wantErr {
6190 t.Errorf("error = %v; want %v", err, wantErr)
6191 }
6192 }
6193
6194
6195
6196
6197
6198
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// expires before any response headers are sent, the client closes the
// underlying connection: the handler hijacks the connection after the
// request context is done and expects its Read to return io.EOF.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a very short client timeout; it is doubled on each retry
	// when the handler is not observed in time (see the tooSlow case below).
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Wait until the server notices the client gave up.
			<-r.Context().Done()

			// Either this attempt was abandoned by the test (cancelHandler),
			// or announce that the handler is running.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client should already have closed its end, so a Read on
			// the hijacked connection is expected to see a clean io.EOF.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler did not report in within 10x the timeout (for
			// example on a slow machine). Release any pending handler, tear
			// this server down, and retry with a doubled timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
6257
6258
6259
6260
6261
6262
// TestClientTimeoutKillsConn_AfterHeaders checks that when a request is
// canceled after its response headers have arrived but before the body is
// complete, reading the body fails and the client closes the connection:
// the handler's Read on the hijacked connection must return an error.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise a 100-byte body and send the headers immediately.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		// Send only part of the promised body, then wait for the client.
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// Any error is accepted here (not specifically io.EOF); the check is
		// only that the client tore the connection down rather than leaving
		// it open.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// NOTE(review): the Timeout is set very large, so the cancellation here
	// is actually driven by req.Cancel; presumably a non-zero Timeout is
	// needed to exercise the Client's timeout code path — confirm intent.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Headers are in; cancel mid-body and expect the body read to fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Wait for the handler to observe the closed connection and finish.
	<-inHandler
	<-handlerDone
}
6328
6329 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
6330 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
6331 }
6332 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
6333 done := make(chan struct{})
6334 defer close(done)
6335 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6336 conn, _, err := w.(Hijacker).Hijack()
6337 if err != nil {
6338 t.Error(err)
6339 return
6340 }
6341 defer conn.Close()
6342 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
6343 bs := bufio.NewScanner(conn)
6344 bs.Scan()
6345 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
6346 <-done
6347 }))
6348
6349 req, _ := NewRequest("GET", cst.ts.URL, nil)
6350 req.Header.Set("Upgrade", "foo")
6351 req.Header.Set("Connection", "upgrade")
6352 res, err := cst.c.Do(req)
6353 if err != nil {
6354 t.Fatal(err)
6355 }
6356 if res.StatusCode != 101 {
6357 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
6358 }
6359 rwc, ok := res.Body.(io.ReadWriteCloser)
6360 if !ok {
6361 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
6362 }
6363 defer rwc.Close()
6364 bs := bufio.NewScanner(rwc)
6365 if !bs.Scan() {
6366 t.Fatalf("expected readable input")
6367 }
6368 if got, want := bs.Text(), "Some buffered data"; got != want {
6369 t.Errorf("read %q; want %q", got, want)
6370 }
6371 io.WriteString(rwc, "echo\n")
6372 if !bs.Scan() {
6373 t.Fatalf("expected another line")
6374 }
6375 if got, want := bs.Text(), "ECHO"; got != want {
6376 t.Errorf("read %q; want %q", got, want)
6377 }
6378 }
6379
6380 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
6381 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6382 const target = "backend:443"
6383 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6384 if r.Method != "CONNECT" {
6385 t.Errorf("unexpected method %q", r.Method)
6386 w.WriteHeader(500)
6387 return
6388 }
6389 if r.RequestURI != target {
6390 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6391 w.WriteHeader(500)
6392 return
6393 }
6394 nc, brw, err := w.(Hijacker).Hijack()
6395 if err != nil {
6396 t.Error(err)
6397 return
6398 }
6399 defer nc.Close()
6400 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6401
6402 for {
6403 line, err := brw.ReadString('\n')
6404 if err != nil {
6405 if err != io.EOF {
6406 t.Error(err)
6407 }
6408 return
6409 }
6410 io.WriteString(brw, strings.ToUpper(line))
6411 brw.Flush()
6412 }
6413 }))
6414 pr, pw := io.Pipe()
6415 defer pw.Close()
6416 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6417 if err != nil {
6418 t.Fatal(err)
6419 }
6420 req.URL.Opaque = target
6421 res, err := cst.c.Do(req)
6422 if err != nil {
6423 t.Fatal(err)
6424 }
6425 defer res.Body.Close()
6426 if res.StatusCode != 200 {
6427 t.Fatalf("status code = %d; want 200", res.StatusCode)
6428 }
6429 br := bufio.NewReader(res.Body)
6430 for _, str := range []string{"foo", "bar", "baz"} {
6431 fmt.Fprintf(pw, "%s\n", str)
6432 got, err := br.ReadString('\n')
6433 if err != nil {
6434 t.Fatal(err)
6435 }
6436 got = strings.TrimSpace(got)
6437 want := strings.ToUpper(str)
6438 if got != want {
6439 t.Fatalf("got %q; want %q", got, want)
6440 }
6441 }
6442 }
6443
6444 func TestTransportRequestReplayable(t *testing.T) {
6445 someBody := io.NopCloser(strings.NewReader(""))
6446 tests := []struct {
6447 name string
6448 req *Request
6449 want bool
6450 }{
6451 {
6452 name: "GET",
6453 req: &Request{Method: "GET"},
6454 want: true,
6455 },
6456 {
6457 name: "GET_http.NoBody",
6458 req: &Request{Method: "GET", Body: NoBody},
6459 want: true,
6460 },
6461 {
6462 name: "GET_body",
6463 req: &Request{Method: "GET", Body: someBody},
6464 want: false,
6465 },
6466 {
6467 name: "POST",
6468 req: &Request{Method: "POST"},
6469 want: false,
6470 },
6471 {
6472 name: "POST_idempotency-key",
6473 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6474 want: true,
6475 },
6476 {
6477 name: "POST_x-idempotency-key",
6478 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6479 want: true,
6480 },
6481 {
6482 name: "POST_body",
6483 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6484 want: false,
6485 },
6486 }
6487 for _, tt := range tests {
6488 t.Run(tt.name, func(t *testing.T) {
6489 got := tt.req.ExportIsReplayable()
6490 if got != tt.want {
6491 t.Errorf("replyable = %v; want %v", got, tt.want)
6492 }
6493 })
6494 }
6495 }
6496
6497
6498
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method was ever invoked, so tests can tell whether a body was written
// through the connection's io.ReaderFrom implementation.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool
}

// ReadFrom sets ReadFromCalled and then delegates to the wrapped TCPConn.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (n int64, err error) {
	m.ReadFromCalled = true
	n, err = m.TCPConn.ReadFrom(src)
	return n, err
}
6509
6510 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6511 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6512 nBytes := int64(1 << 10)
6513 newFileFunc := func() (r io.Reader, done func(), err error) {
6514 f, err := os.CreateTemp("", "net-http-newfilefunc")
6515 if err != nil {
6516 return nil, nil, err
6517 }
6518
6519
6520 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6521 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6522 }
6523 if _, err := f.Seek(0, 0); err != nil {
6524 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6525 }
6526
6527 done = func() {
6528 f.Close()
6529 os.Remove(f.Name())
6530 }
6531
6532 return f, done, nil
6533 }
6534
6535 newBufferFunc := func() (io.Reader, func(), error) {
6536 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6537 }
6538
6539 cases := []struct {
6540 name string
6541 readerFunc func() (io.Reader, func(), error)
6542 contentLength int64
6543 expectedReadFrom bool
6544 }{
6545 {
6546 name: "file, length",
6547 readerFunc: newFileFunc,
6548 contentLength: nBytes,
6549 expectedReadFrom: true,
6550 },
6551 {
6552 name: "file, no length",
6553 readerFunc: newFileFunc,
6554 },
6555 {
6556 name: "file, negative length",
6557 readerFunc: newFileFunc,
6558 contentLength: -1,
6559 },
6560 {
6561 name: "buffer",
6562 contentLength: nBytes,
6563 readerFunc: newBufferFunc,
6564 },
6565 {
6566 name: "buffer, no length",
6567 readerFunc: newBufferFunc,
6568 },
6569 {
6570 name: "buffer, length -1",
6571 contentLength: -1,
6572 readerFunc: newBufferFunc,
6573 },
6574 }
6575
6576 for _, tc := range cases {
6577 t.Run(tc.name, func(t *testing.T) {
6578 r, cleanup, err := tc.readerFunc()
6579 if err != nil {
6580 t.Fatal(err)
6581 }
6582 defer cleanup()
6583
6584 tConn := &testMockTCPConn{}
6585 trFunc := func(tr *Transport) {
6586 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6587 var d net.Dialer
6588 conn, err := d.DialContext(ctx, network, addr)
6589 if err != nil {
6590 return nil, err
6591 }
6592
6593 tcpConn, ok := conn.(*net.TCPConn)
6594 if !ok {
6595 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6596 }
6597
6598 tConn.TCPConn = tcpConn
6599 return tConn, nil
6600 }
6601 }
6602
6603 cst := newClientServerTest(
6604 t,
6605 mode,
6606 HandlerFunc(func(w ResponseWriter, r *Request) {
6607 io.Copy(io.Discard, r.Body)
6608 r.Body.Close()
6609 w.WriteHeader(200)
6610 }),
6611 trFunc,
6612 )
6613
6614 req, err := NewRequest("PUT", cst.ts.URL, r)
6615 if err != nil {
6616 t.Fatal(err)
6617 }
6618 req.ContentLength = tc.contentLength
6619 req.Header.Set("Content-Type", "application/octet-stream")
6620 resp, err := cst.c.Do(req)
6621 if err != nil {
6622 t.Fatal(err)
6623 }
6624 defer resp.Body.Close()
6625 if resp.StatusCode != 200 {
6626 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6627 }
6628
6629 expectedReadFrom := tc.expectedReadFrom
6630 if mode != http1Mode {
6631 expectedReadFrom = false
6632 }
6633 if !tConn.ReadFromCalled && expectedReadFrom {
6634 t.Fatalf("did not call ReadFrom")
6635 }
6636
6637 if tConn.ReadFromCalled && !expectedReadFrom {
6638 t.Fatalf("ReadFrom was unexpectedly invoked")
6639 }
6640 })
6641 }
6642 }
6643
6644 func TestTransportClone(t *testing.T) {
6645 tr := &Transport{
6646 Proxy: func(*Request) (*url.URL, error) { panic("") },
6647 OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
6648 return nil
6649 },
6650 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
6651 Dial: func(network, addr string) (net.Conn, error) { panic("") },
6652 DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
6653 DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
6654 TLSClientConfig: new(tls.Config),
6655 TLSHandshakeTimeout: time.Second,
6656 DisableKeepAlives: true,
6657 DisableCompression: true,
6658 MaxIdleConns: 1,
6659 MaxIdleConnsPerHost: 1,
6660 MaxConnsPerHost: 1,
6661 IdleConnTimeout: time.Second,
6662 ResponseHeaderTimeout: time.Second,
6663 ExpectContinueTimeout: time.Second,
6664 ProxyConnectHeader: Header{},
6665 GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
6666 MaxResponseHeaderBytes: 1,
6667 ForceAttemptHTTP2: true,
6668 HTTP2: &HTTP2Config{MaxConcurrentStreams: 1},
6669 Protocols: &Protocols{},
6670 TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
6671 "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
6672 },
6673 ReadBufferSize: 1,
6674 WriteBufferSize: 1,
6675 }
6676 tr.Protocols.SetHTTP1(true)
6677 tr.Protocols.SetHTTP2(true)
6678 tr2 := tr.Clone()
6679 rv := reflect.ValueOf(tr2).Elem()
6680 rt := rv.Type()
6681 for i := 0; i < rt.NumField(); i++ {
6682 sf := rt.Field(i)
6683 if !token.IsExported(sf.Name) {
6684 continue
6685 }
6686 if rv.Field(i).IsZero() {
6687 t.Errorf("cloned field t2.%s is zero", sf.Name)
6688 }
6689 }
6690
6691 if _, ok := tr2.TLSNextProto["foo"]; !ok {
6692 t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
6693 }
6694
6695
6696 tr = new(Transport)
6697 tr2 = tr.Clone()
6698 if tr2.TLSNextProto != nil {
6699 t.Errorf("Transport.TLSNextProto unexpected non-nil")
6700 }
6701 }
6702
6703 func TestIs408(t *testing.T) {
6704 tests := []struct {
6705 in string
6706 want bool
6707 }{
6708 {"HTTP/1.0 408", true},
6709 {"HTTP/1.1 408", true},
6710 {"HTTP/1.8 408", true},
6711 {"HTTP/2.0 408", false},
6712 {"HTTP/1.1 408 ", true},
6713 {"HTTP/1.1 40", false},
6714 {"http/1.0 408", false},
6715 {"HTTP/1-1 408", false},
6716 }
6717 for _, tt := range tests {
6718 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6719 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6720 }
6721 }
6722 }
6723
6724 func TestTransportIgnores408(t *testing.T) {
6725 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6726 }
6727 func testTransportIgnores408(t *testing.T, mode testMode) {
6728
6729 defer log.SetOutput(log.Writer())
6730
6731 var logout strings.Builder
6732 log.SetOutput(&logout)
6733
6734 const target = "backend:443"
6735
6736 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6737 nc, _, err := w.(Hijacker).Hijack()
6738 if err != nil {
6739 t.Error(err)
6740 return
6741 }
6742 defer nc.Close()
6743 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6744 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6745 }))
6746 req, err := NewRequest("GET", cst.ts.URL, nil)
6747 if err != nil {
6748 t.Fatal(err)
6749 }
6750 res, err := cst.c.Do(req)
6751 if err != nil {
6752 t.Fatal(err)
6753 }
6754 slurp, err := io.ReadAll(res.Body)
6755 if err != nil {
6756 t.Fatal(err)
6757 }
6758 if err != nil {
6759 t.Fatal(err)
6760 }
6761 if string(slurp) != "ok" {
6762 t.Fatalf("got %q; want ok", slurp)
6763 }
6764
6765 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6766 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6767 if d > 0 {
6768 t.Logf("%v idle conns still present after %v", n, d)
6769 }
6770 return false
6771 }
6772 return true
6773 })
6774 if got := logout.String(); got != "" {
6775 t.Fatalf("expected no log output; got: %s", got)
6776 }
6777 }
6778
6779 func TestInvalidHeaderResponse(t *testing.T) {
6780 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6781 }
6782 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6783 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6784 conn, buf, _ := w.(Hijacker).Hijack()
6785 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6786 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6787 "Content-Type: text/html; charset=utf-8\r\n" +
6788 "Content-Length: 0\r\n" +
6789 "Foo : bar\r\n\r\n"))
6790 buf.Flush()
6791 conn.Close()
6792 }))
6793 res, err := cst.c.Get(cst.ts.URL)
6794 if err != nil {
6795 t.Fatal(err)
6796 }
6797 defer res.Body.Close()
6798 if v := res.Header.Get("Foo"); v != "" {
6799 t.Errorf(`unexpected "Foo" header: %q`, v)
6800 }
6801 if v := res.Header.Get("Foo "); v != "bar" {
6802 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6803 }
6804 }
6805
// bodyCloser is an always-empty io.ReadCloser that records whether Close has
// been called by flipping itself to true.
type bodyCloser bool

// Close marks the bodyCloser as closed and never fails.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}

// Read reports end of stream immediately without touching p.
func (b *bodyCloser) Read(p []byte) (int, error) {
	return 0, io.EOF
}
6815
6816
6817
6818 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6819 run(t, testTransportClosesBodyOnInvalidRequests)
6820 }
6821 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6822 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6823 t.Errorf("Should not have been invoked")
6824 })).ts
6825
6826 u, _ := url.Parse(cst.URL)
6827
6828 tests := []struct {
6829 name string
6830 req *Request
6831 wantErr string
6832 }{
6833 {
6834 name: "invalid method",
6835 req: &Request{
6836 Method: " ",
6837 URL: u,
6838 },
6839 wantErr: `invalid method " "`,
6840 },
6841 {
6842 name: "nil URL",
6843 req: &Request{
6844 Method: "GET",
6845 },
6846 wantErr: `nil Request.URL`,
6847 },
6848 {
6849 name: "invalid header key",
6850 req: &Request{
6851 Method: "GET",
6852 Header: Header{"💡": {"emoji"}},
6853 URL: u,
6854 },
6855 wantErr: `invalid header field name "💡"`,
6856 },
6857 {
6858 name: "invalid header value",
6859 req: &Request{
6860 Method: "POST",
6861 Header: Header{"key": {"\x19"}},
6862 URL: u,
6863 },
6864 wantErr: `invalid header field value for "key"`,
6865 },
6866 {
6867 name: "non HTTP(s) scheme",
6868 req: &Request{
6869 Method: "POST",
6870 URL: &url.URL{Scheme: "faux"},
6871 },
6872 wantErr: `unsupported protocol scheme "faux"`,
6873 },
6874 {
6875 name: "no Host in URL",
6876 req: &Request{
6877 Method: "POST",
6878 URL: &url.URL{Scheme: "http"},
6879 },
6880 wantErr: `no Host in request URL`,
6881 },
6882 }
6883
6884 for _, tt := range tests {
6885 t.Run(tt.name, func(t *testing.T) {
6886 var bc bodyCloser
6887 req := tt.req
6888 req.Body = &bc
6889 _, err := cst.Client().Do(tt.req)
6890 if err == nil {
6891 t.Fatal("Expected an error")
6892 }
6893 if !bc {
6894 t.Fatal("Expected body to have been closed")
6895 }
6896 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6897 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6898 }
6899 })
6900 }
6901 }
6902
6903
6904
// breakableConn wraps a net.Conn whose writes can be made to fail on demand
// through a brokenState that may be shared by several connections.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag; while broken is true, Write on every
// breakableConn referencing it fails.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with "some write error" while the shared state is broken and
// otherwise forwards to the wrapped connection. The state's lock is held for
// the duration of the underlying Write.
func (bc *breakableConn) Write(p []byte) (int, error) {
	bc.brokenState.Lock()
	defer bc.brokenState.Unlock()
	if !bc.brokenState.broken {
		return bc.Conn.Write(p)
	}
	return 0, errors.New("some write error")
}
6923
6924
6925 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6926 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6927 }
6928 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6929 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6930
6931 var brokenState brokenState
6932
6933 const numReqs = 5
6934 var numDials, gotConns uint32
6935
6936 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6937 atomic.AddUint32(&numDials, 1)
6938 c, err := net.Dial(netw, addr)
6939 if err != nil {
6940 t.Errorf("unexpected Dial error: %v", err)
6941 return nil, err
6942 }
6943 return &breakableConn{c, &brokenState}, err
6944 }
6945
6946 for i := 1; i <= numReqs; i++ {
6947 brokenState.Lock()
6948 brokenState.broken = false
6949 brokenState.Unlock()
6950
6951
6952
6953
6954 doBreak := i != numReqs
6955
6956 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6957 GotConn: func(info httptrace.GotConnInfo) {
6958 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6959 atomic.AddUint32(&gotConns, 1)
6960 },
6961 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6962 brokenState.Lock()
6963 defer brokenState.Unlock()
6964 if doBreak {
6965 brokenState.broken = true
6966 }
6967 },
6968 })
6969 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6970 if err != nil {
6971 t.Fatal(err)
6972 }
6973 _, err = cst.c.Do(req)
6974 if doBreak != (err != nil) {
6975 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6976 }
6977 }
6978 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6979 t.Errorf("GotConn calls = %v; want %v", got, want)
6980 }
6981 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6982 t.Errorf("Dials = %v; want %v", got, want)
6983 }
6984 }
6985
6986
6987
6988
6989
6990 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6991 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6992 }
6993 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6994 CondSkipHTTP2(t)
6995
6996 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6997 _, err := w.Write([]byte("foo"))
6998 if err != nil {
6999 t.Fatalf("Write: %v", err)
7000 }
7001 })
7002
7003 ts := newClientServerTest(t, mode, h).ts
7004
7005 c := ts.Client()
7006 tr := c.Transport.(*Transport)
7007 tr.MaxConnsPerHost = 1
7008
7009 errCh := make(chan error, 300)
7010 doReq := func() {
7011 resp, err := c.Get(ts.URL)
7012 if err != nil {
7013 errCh <- fmt.Errorf("request failed: %v", err)
7014 return
7015 }
7016 defer resp.Body.Close()
7017 _, err = io.ReadAll(resp.Body)
7018 if err != nil {
7019 errCh <- fmt.Errorf("read body failed: %v", err)
7020 }
7021 }
7022
7023 var wg sync.WaitGroup
7024 for i := 0; i < 300; i++ {
7025 wg.Add(1)
7026 go func() {
7027 defer wg.Done()
7028 doReq()
7029 }()
7030 }
7031 wg.Wait()
7032 close(errCh)
7033
7034 for err := range errCh {
7035 t.Errorf("error occurred: %v", err)
7036 }
7037 }
7038
7039
7040
7041
7042 func TestAltProtoCancellation(t *testing.T) {
7043 defer afterTest(t)
7044 tr := &Transport{}
7045 c := &Client{
7046 Transport: tr,
7047 Timeout: time.Millisecond,
7048 }
7049 tr.RegisterProtocol("cancel", cancelProto{})
7050 _, err := c.Get("cancel://bar.com/path")
7051 if err == nil {
7052 t.Error("request unexpectedly succeeded")
7053 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
7054 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
7055 }
7056 }
7057
// errCancelProto is the error cancelProto returns once its request is canceled.
var errCancelProto error = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never completes a request by itself: it
// blocks until a receive from req.Cancel succeeds (e.g. the channel is
// closed) and then fails with errCancelProto. With a nil req.Cancel it blocks
// forever.
type cancelProto struct{}

func (cancelProto) RoundTrip(req *Request) (res *Response, err error) {
	<-req.Cancel
	err = errCancelProto
	return res, err
}
7066
// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip calls the wrapped function with r and returns its results as-is.
func (fn roundTripFunc) RoundTrip(r *Request) (res *Response, err error) {
	res, err = fn(r)
	return res, err
}
7070
7071
7072 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
7073 func testIssue32441(t *testing.T, mode testMode) {
7074 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7075 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
7076 t.Error("body length is zero")
7077 }
7078 })).ts
7079 c := ts.Client()
7080 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
7081
7082 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
7083 t.Error("body length is zero during round trip")
7084 }
7085 return nil, ErrSkipAltProtocol
7086 }))
7087 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
7088 t.Error(err)
7089 }
7090 }
7091
7092
7093
7094 func TestTransportRejectsSignInContentLength(t *testing.T) {
7095 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
7096 }
7097 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
7098 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7099 w.Header().Set("Content-Length", "+3")
7100 w.Write([]byte("abc"))
7101 })).ts
7102
7103 c := cst.Client()
7104 res, err := c.Get(cst.URL)
7105 if err == nil || res != nil {
7106 t.Fatal("Expected a non-nil error and a nil http.Response")
7107 }
7108 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
7109 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
7110 }
7111 }
7112
7113
// dumpConn is a stand-in net.Conn. Writes go to the embedded Writer and
// reads come from the embedded Reader. Close and the three deadline setters
// do nothing and always succeed. Both addresses are reported as nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error                       { return nil }
func (*dumpConn) LocalAddr() net.Addr                { return nil }
func (*dumpConn) RemoteAddr() net.Addr               { return nil }
func (*dumpConn) SetDeadline(time.Time) error        { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error    { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error   { return nil }
7125
7126
7127
// delegateReader is an io.Reader whose real source arrives later. The first
// Read waits for a reader on c, remembers it in r, and forwards to it. Every
// later Read goes straight to r. If c is closed before any reader is
// delivered, Read fails with a "delegate closed" error.
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r == nil {
		next, ok := <-d.c
		if !ok {
			return 0, errors.New("delegate closed")
		}
		d.r = next
	}
	return d.r.Read(p)
}
7142
// testTransportRace performs one RoundTrip of req over a fake connection and
// then restores req.Body to the value it had on entry.
//
// The Transport dials a dumpConn whose writes go into an io.Pipe and whose
// reads come from a delegateReader. A helper goroutine plays the server:
//   - It parses the request from the pipe and drains its body.
//   - It then either hands the Transport a canned "204 No Content" response,
//     or, if this function reaches its final wait first, closes the delegate
//     so no Read blocks forever.
//
// The result of RoundTrip is ignored. Callers such as TestErrorWriteLoopRace
// pass requests whose context may expire at any point during the exchange.
// Presumably the point is to surface data races (e.g. under -race) rather
// than to check a particular outcome.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	// quitReadCh is closed when the fake-server goroutine exits. Before
	// that, the goroutine may also send on it. Because the channel is
	// unbuffered, that send succeeds only once the main goroutine is
	// already blocked in <-quitReadCh below.
	quitReadCh := make(chan struct{})

	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Drain the whole body. io.Pipe has no internal buffering,
			// so the Transport's body write only makes progress while
			// this side keeps reading.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// The main goroutine is already waiting to finish. Close
			// the delegate so a Read on the connection returns an
			// error instead of blocking forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Closing the write side makes any pending pipe read in the goroutine
	// return. Waiting on quitReadCh then guarantees the goroutine has
	// finished before req.Body is restored below.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
7186
7187
7188
7189
7190
7191 func TestErrorWriteLoopRace(t *testing.T) {
7192 if testing.Short() {
7193 return
7194 }
7195 t.Parallel()
7196 for i := 0; i < 1000; i++ {
7197 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
7198 ctx, cancel := context.WithTimeout(context.Background(), delay)
7199 defer cancel()
7200
7201 r := bytes.NewBuffer(make([]byte, 10000))
7202 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
7203 if err != nil {
7204 t.Fatal(err)
7205 }
7206
7207 testTransportRace(req)
7208 }
7209 }
7210
7211
7212
7213
// TestCancelRequestWhenSharingConnection checks that a request can be
// canceled while it is using a connection that request 1 has just returned
// to the idle pool. At that moment request 1's PutIdleConn trace hook is
// still blocked. MaxConnsPerHost = 1 limits the client to a single
// connection to the server.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each server-side request hands the test a channel and waits until
	// that channel is closed (or receives a value) before responding.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	// Request 1.
	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Report that the connection has been returned to the
				// idle pool, then block until the test closes ch.
				// Closing putidlec makes a second invocation panic
				// (send on a closed channel).
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Let request 1's handler respond as soon as it arrives, failing if
	// request 1 errors first.
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	// Wait for request 1's PutIdleConn hook to fire. idlec is the channel
	// that hook is now blocked on.
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	// Request 2, which will be canceled.
	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Only now unblock request 1's PutIdleConn hook. Request 2 must
		// therefore observe its cancellation while that hook is still
		// blocked.
		close(idlec)
	}()

	// Wait for request 2 to reach the server, then cancel its context.
	r2c := <-reqc
	cancel()

	// Returns once request 2's goroutine has closed idlec, i.e. after
	// request 2's Do has returned.
	<-idlec

	// Let request 2's handler finish, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
7299
7300 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
7301 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
7302 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7303 go io.Copy(io.Discard, req.Body)
7304 panic(ErrAbortHandler)
7305 })).ts
7306
7307 var wg sync.WaitGroup
7308 for i := 0; i < 2; i++ {
7309 wg.Add(1)
7310 go func() {
7311 defer wg.Done()
7312 for j := 0; j < 10; j++ {
7313 const reqLen = 6 * 1024 * 1024
7314 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
7315 req.ContentLength = reqLen
7316 resp, _ := ts.Client().Transport.RoundTrip(req)
7317 if resp != nil {
7318 resp.Body.Close()
7319 }
7320 }
7321 }()
7322 }
7323 wg.Wait()
7324 }
7325
7326 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
7327 func testRequestSanitization(t *testing.T, mode testMode) {
7328 if mode == http2Mode {
7329
7330 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
7331 }
7332 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7333 if h, ok := req.Header["X-Evil"]; ok {
7334 t.Errorf("request has X-Evil header: %q", h)
7335 }
7336 })).ts
7337 req, _ := NewRequest("GET", ts.URL, nil)
7338 req.Host = "go.dev\r\nX-Evil:evil"
7339 resp, _ := ts.Client().Do(req)
7340 if resp != nil {
7341 resp.Body.Close()
7342 }
7343 }
7344
7345 func TestProxyAuthHeader(t *testing.T) {
7346
7347 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
7348 }
7349 func testProxyAuthHeader(t *testing.T, mode testMode) {
7350 const username = "u"
7351 const password = "@/?!"
7352 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7353
7354
7355 var r2 Request
7356 r2.Header = Header{
7357 "Authorization": req.Header["Proxy-Authorization"],
7358 }
7359 gotuser, gotpass, ok := r2.BasicAuth()
7360 if !ok || gotuser != username || gotpass != password {
7361 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
7362 }
7363 }))
7364 u, err := url.Parse(cst.ts.URL)
7365 if err != nil {
7366 t.Fatal(err)
7367 }
7368 u.User = url.UserPassword(username, password)
7369 t.Setenv("HTTP_PROXY", u.String())
7370 cst.tr.Proxy = ProxyURL(u)
7371 resp, err := cst.c.Get("http://_/")
7372 if err != nil {
7373 t.Fatal(err)
7374 }
7375 resp.Body.Close()
7376 }
7377
7378
// TestTransportReqCancelerCleanupOnRequestBodyWriteError sends a POST with a
// 1 GiB body to a raw TCP server that answers after reading just one byte and
// then closes the connection. It checks that the Transport afterwards reports
// zero pending requests (NumPendingRequestsForTesting). Judging by the name,
// the path of interest is the request body write failing. TODO confirm the
// exact Transport code path exercised.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	// done tells the fake server when it may close the connection.
	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}
		// Wait until the request has started arriving before writing the
		// response; the rest of the body is never read.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// Test-only hook. Judging by its name, it fires each time the
	// Transport's read loop is about to start its next read (confirm
	// against its definition). The channel is unbuffered, so the hook
	// blocks until the test receives.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// The server never consumes this body, so writing it is expected to
	// fail once the server closes the connection.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	// Now that the response is in hand, let the server close the connection.
	close(done)

	// Wait for the read loop to reach its next read before closing the
	// response body.
	<-didRead

	resp.Body.Close()

	// Poll until the Transport reports no pending requests. The pending
	// count is logged only when d is non-zero; d is presumably the elapsed
	// polling time (waitCondition is defined elsewhere).
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7439
7440 func TestValidateClientRequestTrailers(t *testing.T) {
7441 run(t, testValidateClientRequestTrailers)
7442 }
7443
7444 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7445 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7446 rw.Write([]byte("Hello"))
7447 })).ts
7448
7449 cases := []struct {
7450 trailer Header
7451 wantErr string
7452 }{
7453 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7454 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7455 }
7456
7457 for i, tt := range cases {
7458 testName := fmt.Sprintf("%s%d", mode, i)
7459 t.Run(testName, func(t *testing.T) {
7460 req, err := NewRequest("GET", cst.URL, nil)
7461 if err != nil {
7462 t.Fatal(err)
7463 }
7464 req.Trailer = tt.trailer
7465 res, err := cst.Client().Do(req)
7466 if err == nil {
7467 t.Fatal("Expected an error")
7468 }
7469 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7470 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7471 }
7472 if res != nil {
7473 t.Fatal("Unexpected non-nil response")
7474 }
7475 })
7476 }
7477 }
7478
7479 func TestTransportServerProtocols(t *testing.T) {
7480 CondSkipHTTP2(t)
7481 DefaultTransport.(*Transport).CloseIdleConnections()
7482
7483 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7484 if err != nil {
7485 t.Fatal(err)
7486 }
7487 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7488 if err != nil {
7489 t.Fatal(err)
7490 }
7491 certpool := x509.NewCertPool()
7492 certpool.AddCert(leafCert)
7493
7494 for _, test := range []struct {
7495 name string
7496 scheme string
7497 setup func(t *testing.T)
7498 transport func(*Transport)
7499 server func(*Server)
7500 want string
7501 }{{
7502 name: "http default",
7503 scheme: "http",
7504 want: "HTTP/1.1",
7505 }, {
7506 name: "https default",
7507 scheme: "https",
7508 transport: func(tr *Transport) {
7509
7510 },
7511 want: "HTTP/1.1",
7512 }, {
7513 name: "https transport protocols include HTTP2",
7514 scheme: "https",
7515 transport: func(tr *Transport) {
7516
7517
7518 tr.Protocols = &Protocols{}
7519 tr.Protocols.SetHTTP1(true)
7520 tr.Protocols.SetHTTP2(true)
7521 },
7522 want: "HTTP/2.0",
7523 }, {
7524 name: "https transport protocols only include HTTP1",
7525 scheme: "https",
7526 transport: func(tr *Transport) {
7527
7528 tr.Protocols = &Protocols{}
7529 tr.Protocols.SetHTTP1(true)
7530 },
7531 want: "HTTP/1.1",
7532 }, {
7533 name: "https transport ForceAttemptHTTP2",
7534 scheme: "https",
7535 transport: func(tr *Transport) {
7536
7537 tr.ForceAttemptHTTP2 = true
7538 },
7539 want: "HTTP/2.0",
7540 }, {
7541 name: "https transport protocols override TLSNextProto",
7542 scheme: "https",
7543 transport: func(tr *Transport) {
7544
7545
7546
7547 tr.Protocols = &Protocols{}
7548 tr.Protocols.SetHTTP1(true)
7549 tr.Protocols.SetHTTP2(true)
7550 tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
7551 },
7552 want: "HTTP/2.0",
7553 }, {
7554 name: "https server disables HTTP2 with TLSNextProto",
7555 scheme: "https",
7556 server: func(srv *Server) {
7557
7558
7559 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7560 },
7561 want: "HTTP/1.1",
7562 }, {
7563 name: "https server Protocols overrides empty TLSNextProto",
7564 scheme: "https",
7565 server: func(srv *Server) {
7566
7567
7568 srv.Protocols = &Protocols{}
7569 srv.Protocols.SetHTTP1(true)
7570 srv.Protocols.SetHTTP2(true)
7571 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7572 },
7573 want: "HTTP/2.0",
7574 }, {
7575 name: "https server protocols only include HTTP1",
7576 scheme: "https",
7577 server: func(srv *Server) {
7578 srv.Protocols = &Protocols{}
7579 srv.Protocols.SetHTTP1(true)
7580 },
7581 want: "HTTP/1.1",
7582 }, {
7583 name: "https server protocols include HTTP2",
7584 scheme: "https",
7585 server: func(srv *Server) {
7586 srv.Protocols = &Protocols{}
7587 srv.Protocols.SetHTTP1(true)
7588 srv.Protocols.SetHTTP2(true)
7589 },
7590 want: "HTTP/2.0",
7591 }, {
7592 name: "GODEBUG disables HTTP2 client",
7593 scheme: "https",
7594 setup: func(t *testing.T) {
7595 t.Setenv("GODEBUG", "http2client=0")
7596 },
7597 transport: func(tr *Transport) {
7598
7599
7600 tr.Protocols = &Protocols{}
7601 tr.Protocols.SetHTTP1(true)
7602 tr.Protocols.SetHTTP2(true)
7603 },
7604 want: "HTTP/1.1",
7605 }, {
7606 name: "GODEBUG disables HTTP2 server",
7607 scheme: "https",
7608 setup: func(t *testing.T) {
7609 t.Setenv("GODEBUG", "http2server=0")
7610 },
7611 transport: func(tr *Transport) {
7612
7613
7614 tr.Protocols = &Protocols{}
7615 tr.Protocols.SetHTTP1(true)
7616 tr.Protocols.SetHTTP2(true)
7617 },
7618 want: "HTTP/1.1",
7619 }, {
7620 name: "unencrypted HTTP2 with prior knowledge",
7621 scheme: "http",
7622 transport: func(tr *Transport) {
7623 tr.Protocols = &Protocols{}
7624 tr.Protocols.SetUnencryptedHTTP2(true)
7625 },
7626 server: func(srv *Server) {
7627 srv.Protocols = &Protocols{}
7628 srv.Protocols.SetHTTP1(true)
7629 srv.Protocols.SetUnencryptedHTTP2(true)
7630 },
7631 want: "HTTP/2.0",
7632 }, {
7633 name: "unencrypted HTTP2 only on server",
7634 scheme: "http",
7635 transport: func(tr *Transport) {
7636 tr.Protocols = &Protocols{}
7637 tr.Protocols.SetUnencryptedHTTP2(true)
7638 },
7639 server: func(srv *Server) {
7640 srv.Protocols = &Protocols{}
7641 srv.Protocols.SetUnencryptedHTTP2(true)
7642 },
7643 want: "HTTP/2.0",
7644 }, {
7645 name: "unencrypted HTTP2 with no server support",
7646 scheme: "http",
7647 transport: func(tr *Transport) {
7648 tr.Protocols = &Protocols{}
7649 tr.Protocols.SetUnencryptedHTTP2(true)
7650 },
7651 server: func(srv *Server) {
7652 srv.Protocols = &Protocols{}
7653 srv.Protocols.SetHTTP1(true)
7654 },
7655 want: "error",
7656 }, {
7657 name: "HTTP1 with no server support",
7658 scheme: "http",
7659 transport: func(tr *Transport) {
7660 tr.Protocols = &Protocols{}
7661 tr.Protocols.SetHTTP1(true)
7662 },
7663 server: func(srv *Server) {
7664 srv.Protocols = &Protocols{}
7665 srv.Protocols.SetUnencryptedHTTP2(true)
7666 },
7667 want: "error",
7668 }, {
7669 name: "HTTPS1 with no server support",
7670 scheme: "https",
7671 transport: func(tr *Transport) {
7672 tr.Protocols = &Protocols{}
7673 tr.Protocols.SetHTTP1(true)
7674 },
7675 server: func(srv *Server) {
7676 srv.Protocols = &Protocols{}
7677 srv.Protocols.SetHTTP2(true)
7678 },
7679 want: "error",
7680 }} {
7681 t.Run(test.name, func(t *testing.T) {
7682
7683
7684 srv := &Server{
7685 TLSConfig: &tls.Config{
7686 Certificates: []tls.Certificate{cert},
7687 },
7688 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
7689 w.Header().Set("X-Proto", req.Proto)
7690 }),
7691 }
7692 tr := &Transport{
7693 TLSClientConfig: &tls.Config{
7694 RootCAs: certpool,
7695 },
7696 }
7697
7698 if test.setup != nil {
7699 test.setup(t)
7700 }
7701 if test.server != nil {
7702 test.server(srv)
7703 }
7704 if test.transport != nil {
7705 test.transport(tr)
7706 } else {
7707 tr.Protocols = &Protocols{}
7708 tr.Protocols.SetHTTP1(true)
7709 tr.Protocols.SetHTTP2(true)
7710 }
7711
7712 listener := newLocalListener(t)
7713 srvc := make(chan error, 1)
7714 go func() {
7715 switch test.scheme {
7716 case "http":
7717 srvc <- srv.Serve(listener)
7718 case "https":
7719 srvc <- srv.ServeTLS(listener, "", "")
7720 }
7721 }()
7722 t.Cleanup(func() {
7723 srv.Close()
7724 <-srvc
7725 })
7726
7727 client := &Client{Transport: tr}
7728 resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
7729 if err != nil {
7730 if test.want == "error" {
7731 return
7732 }
7733 t.Fatal(err)
7734 }
7735 if got := resp.Header.Get("X-Proto"); got != test.want {
7736 t.Fatalf("request proto %q, want %q", got, test.want)
7737 }
7738 })
7739 }
7740 }
7741
7742 func TestIssue61474(t *testing.T) {
7743 run(t, testIssue61474, []testMode{http2Mode})
7744 }
7745 func testIssue61474(t *testing.T, mode testMode) {
7746 if testing.Short() {
7747 return
7748 }
7749
7750
7751
7752
7753 t.Skip("test is too large")
7754
7755 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7756 }), func(tr *Transport) {
7757 tr.MaxConnsPerHost = 1
7758 })
7759 var wg sync.WaitGroup
7760 defer wg.Wait()
7761 for range 100000 {
7762 wg.Go(func() {
7763 ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond)
7764 defer cancel()
7765 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
7766 resp, err := cst.c.Do(req)
7767 if err == nil {
7768 resp.Body.Close()
7769 }
7770 })
7771 }
7772 }
7773
View as plain text