Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "internal/synctest"
26 "io"
27 "log"
28 mrand "math/rand"
29 "net"
30 . "net/http"
31 "net/http/httptest"
32 "net/http/httptrace"
33 "net/http/httputil"
34 "net/http/internal/testcert"
35 "net/textproto"
36 "net/url"
37 "os"
38 "reflect"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "testing"
46 "testing/iotest"
47 "time"
48
49 "golang.org/x/net/http/httpguts"
50 )
51
52
53
54
55
56 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
57 if r.FormValue("close") == "true" {
58 w.Header().Set("Connection", "close")
59 }
60 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
61 w.Write([]byte(r.RemoteAddr))
62
63
64
65 if c, ok := ResponseWriterConnForTesting(w); ok {
66 fmt.Fprintf(w, ", %T %p", c, c)
67 }
68 })
69
70
71 type testCloseConn struct {
72 net.Conn
73 set *testConnSet
74 }
75
76 func (c *testCloseConn) Close() error {
77 c.set.remove(c)
78 return c.Conn.Close()
79 }
80
81
82
83 type testConnSet struct {
84 t *testing.T
85 mu sync.Mutex
86 closed map[net.Conn]bool
87 list []net.Conn
88 }
89
90 func (tcs *testConnSet) insert(c net.Conn) {
91 tcs.mu.Lock()
92 defer tcs.mu.Unlock()
93 tcs.closed[c] = false
94 tcs.list = append(tcs.list, c)
95 }
96
97 func (tcs *testConnSet) remove(c net.Conn) {
98 tcs.mu.Lock()
99 defer tcs.mu.Unlock()
100 tcs.closed[c] = true
101 }
102
103
104 func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
105 connSet := &testConnSet{
106 t: t,
107 closed: make(map[net.Conn]bool),
108 }
109 dial := func(n, addr string) (net.Conn, error) {
110 c, err := net.Dial(n, addr)
111 if err != nil {
112 return nil, err
113 }
114 tc := &testCloseConn{c, connSet}
115 connSet.insert(tc)
116 return tc, nil
117 }
118 return connSet, dial
119 }
120
121 func (tcs *testConnSet) check(t *testing.T) {
122 tcs.mu.Lock()
123 defer tcs.mu.Unlock()
124 for i := 4; i >= 0; i-- {
125 for i, c := range tcs.list {
126 if tcs.closed[c] {
127 continue
128 }
129 if i != 0 {
130
131
132 tcs.mu.Unlock()
133 time.Sleep(50 * time.Millisecond)
134 tcs.mu.Lock()
135 continue
136 }
137 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
138 }
139 }
140 }
141
142 func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
143 func testReuseRequest(t *testing.T, mode testMode) {
144 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
145 w.Write([]byte("{}"))
146 })).ts
147
148 c := ts.Client()
149 req, _ := NewRequest("GET", ts.URL, nil)
150 res, err := c.Do(req)
151 if err != nil {
152 t.Fatal(err)
153 }
154 err = res.Body.Close()
155 if err != nil {
156 t.Fatal(err)
157 }
158
159 res, err = c.Do(req)
160 if err != nil {
161 t.Fatal(err)
162 }
163 err = res.Body.Close()
164 if err != nil {
165 t.Fatal(err)
166 }
167 }
168
169
170
171 func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
172 func testTransportKeepAlives(t *testing.T, mode testMode) {
173 ts := newClientServerTest(t, mode, hostPortHandler).ts
174
175 c := ts.Client()
176 for _, disableKeepAlive := range []bool{false, true} {
177 c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
178 fetch := func(n int) string {
179 res, err := c.Get(ts.URL)
180 if err != nil {
181 t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
182 }
183 body, err := io.ReadAll(res.Body)
184 if err != nil {
185 t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
186 }
187 return string(body)
188 }
189
190 body1 := fetch(1)
191 body2 := fetch(2)
192
193 bodiesDiffer := body1 != body2
194 if bodiesDiffer != disableKeepAlive {
195 t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
196 disableKeepAlive, bodiesDiffer, body1, body2)
197 }
198 }
199 }
200
201 func TestTransportConnectionCloseOnResponse(t *testing.T) {
202 run(t, testTransportConnectionCloseOnResponse)
203 }
204 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
205 ts := newClientServerTest(t, mode, hostPortHandler).ts
206
207 connSet, testDial := makeTestDial(t)
208
209 c := ts.Client()
210 tr := c.Transport.(*Transport)
211 tr.Dial = testDial
212
213 for _, connectionClose := range []bool{false, true} {
214 fetch := func(n int) string {
215 req := new(Request)
216 var err error
217 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
218 if err != nil {
219 t.Fatalf("URL parse error: %v", err)
220 }
221 req.Method = "GET"
222 req.Proto = "HTTP/1.1"
223 req.ProtoMajor = 1
224 req.ProtoMinor = 1
225
226 res, err := c.Do(req)
227 if err != nil {
228 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
229 }
230 defer res.Body.Close()
231 body, err := io.ReadAll(res.Body)
232 if err != nil {
233 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
234 }
235 return string(body)
236 }
237
238 body1 := fetch(1)
239 body2 := fetch(2)
240 bodiesDiffer := body1 != body2
241 if bodiesDiffer != connectionClose {
242 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
243 connectionClose, bodiesDiffer, body1, body2)
244 }
245
246 tr.CloseIdleConnections()
247 }
248
249 connSet.check(t)
250 }
251
252
253
254
255
256
257
258 func TestTransportConnectionCloseOnRequest(t *testing.T) {
259 run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
260 }
261 func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
262 ts := newClientServerTest(t, mode, hostPortHandler).ts
263
264 connSet, testDial := makeTestDial(t)
265
266 c := ts.Client()
267 tr := c.Transport.(*Transport)
268 tr.Dial = testDial
269 for _, reqClose := range []bool{false, true} {
270 fetch := func(n int) string {
271 req := new(Request)
272 var err error
273 req.URL, err = url.Parse(ts.URL)
274 if err != nil {
275 t.Fatalf("URL parse error: %v", err)
276 }
277 req.Method = "GET"
278 req.Proto = "HTTP/1.1"
279 req.ProtoMajor = 1
280 req.ProtoMinor = 1
281 req.Close = reqClose
282
283 res, err := c.Do(req)
284 if err != nil {
285 t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
286 }
287 if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
288 t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
289 reqClose, got, !reqClose)
290 }
291 body, err := io.ReadAll(res.Body)
292 if err != nil {
293 t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
294 }
295 return string(body)
296 }
297
298 body1 := fetch(1)
299 body2 := fetch(2)
300
301 got := 1
302 if body1 != body2 {
303 got++
304 }
305 want := 1
306 if reqClose {
307 want = 2
308 }
309 if got != want {
310 t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
311 reqClose, got, want, body1, body2)
312 }
313
314 tr.CloseIdleConnections()
315 }
316
317 connSet.check(t)
318 }
319
320
321
322
323 func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
324 run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
325 }
326 func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
327 ts := newClientServerTest(t, mode, hostPortHandler).ts
328
329 c := ts.Client()
330 c.Transport.(*Transport).DisableKeepAlives = true
331
332 res, err := c.Get(ts.URL)
333 if err != nil {
334 t.Fatal(err)
335 }
336 res.Body.Close()
337 if res.Header.Get("X-Saw-Close") != "true" {
338 t.Errorf("handler didn't see Connection: close ")
339 }
340 }
341
342
343
344 func TestTransportRespectRequestWantsClose(t *testing.T) {
345 run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
346 }
347 func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
348 tests := []struct {
349 disableKeepAlives bool
350 close bool
351 }{
352 {disableKeepAlives: false, close: false},
353 {disableKeepAlives: false, close: true},
354 {disableKeepAlives: true, close: false},
355 {disableKeepAlives: true, close: true},
356 }
357
358 for _, tc := range tests {
359 t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
360 func(t *testing.T) {
361 ts := newClientServerTest(t, mode, hostPortHandler).ts
362
363 c := ts.Client()
364 c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
365 req, err := NewRequest("GET", ts.URL, nil)
366 if err != nil {
367 t.Fatal(err)
368 }
369 count := 0
370 trace := &httptrace.ClientTrace{
371 WroteHeaderField: func(key string, field []string) {
372 if key != "Connection" {
373 return
374 }
375 if httpguts.HeaderValuesContainsToken(field, "close") {
376 count += 1
377 }
378 },
379 }
380 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
381 req.Close = tc.close
382 res, err := c.Do(req)
383 if err != nil {
384 t.Fatal(err)
385 }
386 defer res.Body.Close()
387 if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
388 t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
389 }
390 })
391 }
392
393 }
394
395 func TestTransportIdleCacheKeys(t *testing.T) {
396 run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
397 }
398 func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
399 ts := newClientServerTest(t, mode, hostPortHandler).ts
400 c := ts.Client()
401 tr := c.Transport.(*Transport)
402
403 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
404 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
405 }
406
407 resp, err := c.Get(ts.URL)
408 if err != nil {
409 t.Error(err)
410 }
411 io.ReadAll(resp.Body)
412
413 keys := tr.IdleConnKeysForTesting()
414 if e, g := 1, len(keys); e != g {
415 t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
416 }
417
418 if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
419 t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
420 }
421
422 tr.CloseIdleConnections()
423 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
424 t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
425 }
426 }
427
428
429
430 func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
431 func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
432 const msg = "foobar"
433
434 var addrSeen map[string]int
435 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
436 addrSeen[r.RemoteAddr]++
437 if r.URL.Path == "/chunked/" {
438 w.WriteHeader(200)
439 w.(Flusher).Flush()
440 } else {
441 w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
442 w.WriteHeader(200)
443 }
444 w.Write([]byte(msg))
445 })).ts
446
447 for pi, path := range []string{"/content-length/", "/chunked/"} {
448 wantLen := []int{len(msg), -1}[pi]
449 addrSeen = make(map[string]int)
450 for i := 0; i < 3; i++ {
451 res, err := ts.Client().Get(ts.URL + path)
452 if err != nil {
453 t.Errorf("Get %s: %v", path, err)
454 continue
455 }
456
457
458
459
460
461 defer res.Body.Close()
462
463 if res.ContentLength != int64(wantLen) {
464 t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
465 }
466 got, err := io.ReadAll(res.Body)
467 if string(got) != msg || err != nil {
468 t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
469 }
470 }
471 if len(addrSeen) != 1 {
472 t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
473 }
474 }
475 }
476
477 func TestTransportMaxPerHostIdleConns(t *testing.T) {
478 run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
479 }
480 func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
481 stop := make(chan struct{})
482 defer close(stop)
483
484 resch := make(chan string)
485 gotReq := make(chan bool)
486 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
487 gotReq <- true
488 var msg string
489 select {
490 case <-stop:
491 return
492 case msg = <-resch:
493 }
494 _, err := w.Write([]byte(msg))
495 if err != nil {
496 t.Errorf("Write: %v", err)
497 return
498 }
499 })).ts
500
501 c := ts.Client()
502 tr := c.Transport.(*Transport)
503 maxIdleConnsPerHost := 2
504 tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
505
506
507
508 donech := make(chan bool)
509 doReq := func() {
510 defer func() {
511 select {
512 case <-stop:
513 return
514 case donech <- t.Failed():
515 }
516 }()
517 resp, err := c.Get(ts.URL)
518 if err != nil {
519 t.Error(err)
520 return
521 }
522 if _, err := io.ReadAll(resp.Body); err != nil {
523 t.Errorf("ReadAll: %v", err)
524 return
525 }
526 }
527 go doReq()
528 <-gotReq
529 go doReq()
530 <-gotReq
531 go doReq()
532 <-gotReq
533
534 if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
535 t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
536 }
537
538 resch <- "res1"
539 <-donech
540 keys := tr.IdleConnKeysForTesting()
541 if e, g := 1, len(keys); e != g {
542 t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
543 }
544 addr := ts.Listener.Addr().String()
545 cacheKey := "|http|" + addr
546 if keys[0] != cacheKey {
547 t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
548 }
549 if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
550 t.Errorf("after first response, expected %d idle conns; got %d", e, g)
551 }
552
553 resch <- "res2"
554 <-donech
555 if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
556 t.Errorf("after second response, idle conns = %d; want %d", g, w)
557 }
558
559 resch <- "res3"
560 <-donech
561 if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
562 t.Errorf("after third response, idle conns = %d; want %d", g, w)
563 }
564 }
565
566 func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
567 run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
568 }
569 func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
570 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
571 _, err := w.Write([]byte("foo"))
572 if err != nil {
573 t.Fatalf("Write: %v", err)
574 }
575 })).ts
576 c := ts.Client()
577 tr := c.Transport.(*Transport)
578 dialStarted := make(chan struct{})
579 stallDial := make(chan struct{})
580 tr.Dial = func(network, addr string) (net.Conn, error) {
581 dialStarted <- struct{}{}
582 <-stallDial
583 return net.Dial(network, addr)
584 }
585
586 tr.DisableKeepAlives = true
587 tr.MaxConnsPerHost = 1
588
589 preDial := make(chan struct{})
590 reqComplete := make(chan struct{})
591 doReq := func(reqId string) {
592 req, _ := NewRequest("GET", ts.URL, nil)
593 trace := &httptrace.ClientTrace{
594 GetConn: func(hostPort string) {
595 preDial <- struct{}{}
596 },
597 }
598 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
599 resp, err := tr.RoundTrip(req)
600 if err != nil {
601 t.Errorf("unexpected error for request %s: %v", reqId, err)
602 }
603 _, err = io.ReadAll(resp.Body)
604 if err != nil {
605 t.Errorf("unexpected error for request %s: %v", reqId, err)
606 }
607 reqComplete <- struct{}{}
608 }
609
610 go doReq("req1")
611 <-preDial
612 <-dialStarted
613
614
615 go doReq("req2")
616 <-preDial
617 select {
618 case <-dialStarted:
619 t.Error("req2 dial started while req1 dial in progress")
620 return
621 default:
622 }
623
624
625 stallDial <- struct{}{}
626 <-reqComplete
627
628
629 <-dialStarted
630 stallDial <- struct{}{}
631 <-reqComplete
632 }
633
634 func TestTransportMaxConnsPerHost(t *testing.T) {
635 run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
636 }
637 func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
638 CondSkipHTTP2(t)
639
640 h := HandlerFunc(func(w ResponseWriter, r *Request) {
641 _, err := w.Write([]byte("foo"))
642 if err != nil {
643 t.Fatalf("Write: %v", err)
644 }
645 })
646
647 ts := newClientServerTest(t, mode, h).ts
648 c := ts.Client()
649 tr := c.Transport.(*Transport)
650 tr.MaxConnsPerHost = 1
651
652 mu := sync.Mutex{}
653 var conns []net.Conn
654 var dialCnt, gotConnCnt, tlsHandshakeCnt int32
655 tr.Dial = func(network, addr string) (net.Conn, error) {
656 atomic.AddInt32(&dialCnt, 1)
657 c, err := net.Dial(network, addr)
658 mu.Lock()
659 defer mu.Unlock()
660 conns = append(conns, c)
661 return c, err
662 }
663
664 doReq := func() {
665 trace := &httptrace.ClientTrace{
666 GotConn: func(connInfo httptrace.GotConnInfo) {
667 if !connInfo.Reused {
668 atomic.AddInt32(&gotConnCnt, 1)
669 }
670 },
671 TLSHandshakeStart: func() {
672 atomic.AddInt32(&tlsHandshakeCnt, 1)
673 },
674 }
675 req, _ := NewRequest("GET", ts.URL, nil)
676 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
677
678 resp, err := c.Do(req)
679 if err != nil {
680 t.Fatalf("request failed: %v", err)
681 }
682 defer resp.Body.Close()
683 _, err = io.ReadAll(resp.Body)
684 if err != nil {
685 t.Fatalf("read body failed: %v", err)
686 }
687 }
688
689 wg := sync.WaitGroup{}
690 for i := 0; i < 10; i++ {
691 wg.Add(1)
692 go func() {
693 defer wg.Done()
694 doReq()
695 }()
696 }
697 wg.Wait()
698
699 expected := int32(tr.MaxConnsPerHost)
700 if dialCnt != expected {
701 t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
702 }
703 if gotConnCnt != expected {
704 t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
705 }
706 if ts.TLS != nil && tlsHandshakeCnt != expected {
707 t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
708 }
709
710 if t.Failed() {
711 t.FailNow()
712 }
713
714 mu.Lock()
715 for _, c := range conns {
716 c.Close()
717 }
718 conns = nil
719 mu.Unlock()
720 tr.CloseIdleConnections()
721
722 doReq()
723 expected++
724 if dialCnt != expected {
725 t.Errorf("round 2: too many dials: %d", dialCnt)
726 }
727 if gotConnCnt != expected {
728 t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
729 }
730 if ts.TLS != nil && tlsHandshakeCnt != expected {
731 t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
732 }
733 }
734
735 func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
736 run(t, testTransportMaxConnsPerHostDialCancellation,
737 testNotParallel,
738 []testMode{http1Mode, https1Mode, http2Mode},
739 )
740 }
741
742 func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
743 CondSkipHTTP2(t)
744
745 h := HandlerFunc(func(w ResponseWriter, r *Request) {
746 _, err := w.Write([]byte("foo"))
747 if err != nil {
748 t.Fatalf("Write: %v", err)
749 }
750 })
751
752 cst := newClientServerTest(t, mode, h)
753 defer cst.close()
754 ts := cst.ts
755 c := ts.Client()
756 tr := c.Transport.(*Transport)
757 tr.MaxConnsPerHost = 1
758
759
760 ctx, cancel := context.WithCancel(context.Background())
761 defer cancel()
762 SetPendingDialHooks(cancel, nil)
763 defer SetPendingDialHooks(nil, nil)
764
765 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
766 _, err := c.Do(req)
767 if !errors.Is(err, context.Canceled) {
768 t.Errorf("expected error %v, got %v", context.Canceled, err)
769 }
770
771
772 SetPendingDialHooks(nil, nil)
773 req, _ = NewRequest("GET", ts.URL, nil)
774 resp, err := c.Do(req)
775 if err != nil {
776 t.Fatalf("request failed: %v", err)
777 }
778 defer resp.Body.Close()
779 _, err = io.ReadAll(resp.Body)
780 if err != nil {
781 t.Fatalf("read body failed: %v", err)
782 }
783 }
784
785 func TestTransportRemovesDeadIdleConnections(t *testing.T) {
786 run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
787 }
788 func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
789 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
790 io.WriteString(w, r.RemoteAddr)
791 })).ts
792
793 c := ts.Client()
794 tr := c.Transport.(*Transport)
795
796 doReq := func(name string) {
797
798
799 res, err := c.Post(ts.URL, "", nil)
800 if err != nil {
801 t.Fatalf("%s: %v", name, err)
802 }
803 if res.StatusCode != 200 {
804 t.Fatalf("%s: %v", name, res.Status)
805 }
806 defer res.Body.Close()
807 slurp, err := io.ReadAll(res.Body)
808 if err != nil {
809 t.Fatalf("%s: %v", name, err)
810 }
811 t.Logf("%s: ok (%q)", name, slurp)
812 }
813
814 doReq("first")
815 keys1 := tr.IdleConnKeysForTesting()
816
817 ts.CloseClientConnections()
818
819 var keys2 []string
820 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
821 keys2 = tr.IdleConnKeysForTesting()
822 if len(keys2) != 0 {
823 if d > 0 {
824 t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
825 }
826 return false
827 }
828 return true
829 })
830
831 doReq("second")
832 }
833
834
835
836 func TestTransportServerClosingUnexpectedly(t *testing.T) {
837 run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
838 }
839 func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
840 ts := newClientServerTest(t, mode, hostPortHandler).ts
841 c := ts.Client()
842
843 fetch := func(n, retries int) string {
844 condFatalf := func(format string, arg ...any) {
845 if retries <= 0 {
846 t.Fatalf(format, arg...)
847 }
848 t.Logf("retrying shortly after expected error: "+format, arg...)
849 time.Sleep(time.Second / time.Duration(retries))
850 }
851 for retries >= 0 {
852 retries--
853 res, err := c.Get(ts.URL)
854 if err != nil {
855 condFatalf("error in req #%d, GET: %v", n, err)
856 continue
857 }
858 body, err := io.ReadAll(res.Body)
859 if err != nil {
860 condFatalf("error in req #%d, ReadAll: %v", n, err)
861 continue
862 }
863 res.Body.Close()
864 return string(body)
865 }
866 panic("unreachable")
867 }
868
869 body1 := fetch(1, 0)
870 body2 := fetch(2, 0)
871
872
873
874
875
876
877
878
879 ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
880
881 body3 := fetch(3, 5)
882
883 if body1 != body2 {
884 t.Errorf("expected body1 and body2 to be equal")
885 }
886 if body2 == body3 {
887 t.Errorf("expected body2 and body3 to be different")
888 }
889 }
890
891
892
893 func TestStressSurpriseServerCloses(t *testing.T) {
894 run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
895 }
896 func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
897 if testing.Short() {
898 t.Skip("skipping test in short mode")
899 }
900 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
901 w.Header().Set("Content-Length", "5")
902 w.Header().Set("Content-Type", "text/plain")
903 w.Write([]byte("Hello"))
904 w.(Flusher).Flush()
905 conn, buf, _ := w.(Hijacker).Hijack()
906 buf.Flush()
907 conn.Close()
908 })).ts
909 c := ts.Client()
910
911
912
913
914
915
916
917 const (
918 numClients = 20
919 reqsPerClient = 25
920 )
921 var wg sync.WaitGroup
922 wg.Add(numClients * reqsPerClient)
923 for i := 0; i < numClients; i++ {
924 go func() {
925 for i := 0; i < reqsPerClient; i++ {
926 res, err := c.Get(ts.URL)
927 if err == nil {
928
929
930
931
932
933
934 res.Body.Close()
935 }
936 wg.Done()
937 }
938 }()
939 }
940
941
942 wg.Wait()
943 }
944
945
946
947 func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
948 func testTransportHeadResponses(t *testing.T, mode testMode) {
949 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
950 if r.Method != "HEAD" {
951 panic("expected HEAD; got " + r.Method)
952 }
953 w.Header().Set("Content-Length", "123")
954 w.WriteHeader(200)
955 })).ts
956 c := ts.Client()
957
958 for i := 0; i < 2; i++ {
959 res, err := c.Head(ts.URL)
960 if err != nil {
961 t.Errorf("error on loop %d: %v", i, err)
962 continue
963 }
964 if e, g := "123", res.Header.Get("Content-Length"); e != g {
965 t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
966 }
967 if e, g := int64(123), res.ContentLength; e != g {
968 t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
969 }
970 if all, err := io.ReadAll(res.Body); err != nil {
971 t.Errorf("loop %d: Body ReadAll: %v", i, err)
972 } else if len(all) != 0 {
973 t.Errorf("Bogus body %q", all)
974 }
975 }
976 }
977
978
979
980 func TestTransportHeadChunkedResponse(t *testing.T) {
981 run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
982 }
983 func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
984 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
985 if r.Method != "HEAD" {
986 panic("expected HEAD; got " + r.Method)
987 }
988 w.Header().Set("Transfer-Encoding", "chunked")
989 w.Header().Set("x-client-ipport", r.RemoteAddr)
990 w.WriteHeader(200)
991 })).ts
992 c := ts.Client()
993
994
995
996 didRead := make(chan bool)
997 SetReadLoopBeforeNextReadHook(func() { didRead <- true })
998 defer SetReadLoopBeforeNextReadHook(nil)
999
1000 res1, err := c.Head(ts.URL)
1001 <-didRead
1002
1003 if err != nil {
1004 t.Fatalf("request 1 error: %v", err)
1005 }
1006
1007 res2, err := c.Head(ts.URL)
1008 <-didRead
1009
1010 if err != nil {
1011 t.Fatalf("request 2 error: %v", err)
1012 }
1013 if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
1014 t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
1015 }
1016 }
1017
// roundTripTests describes how Transport.RoundTrip should treat a request's
// Accept-Encoding header. Entries are indexed by position (testnum).
var roundTripTests = []struct {
	accept       string // Accept-Encoding set on the request ("" = none)
	expectAccept string // Accept-Encoding the server should receive
	compressed   bool   // whether the caller must gunzip the body itself
}{
	// No Accept-Encoding: the Transport asks for gzip itself and hands the
	// caller a decompressed body.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other Accept-Encoding is passed through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit gzip request is passed through; the body stays compressed.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1030
1031
1032 func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
1033 func testRoundTripGzip(t *testing.T, mode testMode) {
1034 const responseBody = "test response body"
1035 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1036 accept := req.Header.Get("Accept-Encoding")
1037 if expect := req.FormValue("expect_accept"); accept != expect {
1038 t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
1039 req.FormValue("testnum"), accept, expect)
1040 }
1041 if accept == "gzip" {
1042 rw.Header().Set("Content-Encoding", "gzip")
1043 gz := gzip.NewWriter(rw)
1044 gz.Write([]byte(responseBody))
1045 gz.Close()
1046 } else {
1047 rw.Header().Set("Content-Encoding", accept)
1048 rw.Write([]byte(responseBody))
1049 }
1050 })).ts
1051 tr := ts.Client().Transport.(*Transport)
1052
1053 for i, test := range roundTripTests {
1054
1055 req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
1056 if test.accept != "" {
1057 req.Header.Set("Accept-Encoding", test.accept)
1058 }
1059 res, err := tr.RoundTrip(req)
1060 if err != nil {
1061 t.Errorf("%d. RoundTrip: %v", i, err)
1062 continue
1063 }
1064 var body []byte
1065 if test.compressed {
1066 var r *gzip.Reader
1067 r, err = gzip.NewReader(res.Body)
1068 if err != nil {
1069 t.Errorf("%d. gzip NewReader: %v", i, err)
1070 continue
1071 }
1072 body, err = io.ReadAll(r)
1073 res.Body.Close()
1074 } else {
1075 body, err = io.ReadAll(res.Body)
1076 }
1077 if err != nil {
1078 t.Errorf("%d. Error: %q", i, err)
1079 continue
1080 }
1081 if g, e := string(body), responseBody; g != e {
1082 t.Errorf("%d. body = %q; want %q", i, g, e)
1083 }
1084 if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
1085 t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
1086 }
1087 if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
1088 t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
1089 }
1090 }
1091
1092 }
1093
1094 func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1095 func testTransportGzip(t *testing.T, mode testMode) {
1096 if mode == http2Mode {
1097 t.Skip("https://go.dev/issue/56020")
1098 }
1099 const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1100 const nRandBytes = 1024 * 1024
1101 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1102 if req.Method == "HEAD" {
1103 if g := req.Header.Get("Accept-Encoding"); g != "" {
1104 t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1105 }
1106 return
1107 }
1108 if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1109 t.Errorf("Accept-Encoding = %q, want %q", g, e)
1110 }
1111 rw.Header().Set("Content-Encoding", "gzip")
1112
1113 var w io.Writer = rw
1114 var buf bytes.Buffer
1115 if req.FormValue("chunked") == "0" {
1116 w = &buf
1117 defer io.Copy(rw, &buf)
1118 defer func() {
1119 rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1120 }()
1121 }
1122 gz := gzip.NewWriter(w)
1123 gz.Write([]byte(testString))
1124 if req.FormValue("body") == "large" {
1125 io.CopyN(gz, rand.Reader, nRandBytes)
1126 }
1127 gz.Close()
1128 })).ts
1129 c := ts.Client()
1130
1131 for _, chunked := range []string{"1", "0"} {
1132
1133 res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1134 if err != nil {
1135 t.Fatalf("large get: %v", err)
1136 }
1137 buf := make([]byte, len(testString))
1138 n, err := io.ReadFull(res.Body, buf)
1139 if err != nil {
1140 t.Fatalf("partial read of large response: size=%d, %v", n, err)
1141 }
1142 if e, g := testString, string(buf); e != g {
1143 t.Errorf("partial read got %q, expected %q", g, e)
1144 }
1145 res.Body.Close()
1146
1147 n, err = res.Body.Read(buf)
1148 if n != 0 || err == nil {
1149 t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1150 }
1151
1152
1153 res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1154 if err != nil {
1155 t.Fatal(err)
1156 }
1157 body, err := io.ReadAll(res.Body)
1158 if err != nil {
1159 t.Fatal(err)
1160 }
1161 if g, e := string(body), testString; g != e {
1162 t.Fatalf("body = %q; want %q", g, e)
1163 }
1164 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1165 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1166 }
1167
1168
1169 n, err = res.Body.Read(buf)
1170 if n != 0 || err == nil {
1171 t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1172 }
1173 res.Body.Close()
1174 n, err = res.Body.Read(buf)
1175 if n != 0 || err == nil {
1176 t.Errorf("expected Read error after Close; got %d, %v", n, err)
1177 }
1178 }
1179
1180
1181 res, err := c.Head(ts.URL)
1182 if err != nil {
1183 t.Fatalf("Head: %v", err)
1184 }
1185 if res.StatusCode != 200 {
1186 t.Errorf("Head status=%d; want=200", res.StatusCode)
1187 }
1188 }
1189
1190
1191
1192 type transport100ContinueTest struct {
1193 t *testing.T
1194
1195 reqdone chan struct{}
1196 resp *Response
1197 respErr error
1198
1199 conn net.Conn
1200 reader *bufio.Reader
1201 }
1202
1203 const transport100ContinueTestBody = "request body"
1204
1205
1206
1207 func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
1208 ln := newLocalListener(t)
1209 defer ln.Close()
1210
1211 test := &transport100ContinueTest{
1212 t: t,
1213 reqdone: make(chan struct{}),
1214 }
1215
1216 tr := &Transport{
1217 ExpectContinueTimeout: timeout,
1218 }
1219 go func() {
1220 defer close(test.reqdone)
1221 body := strings.NewReader(transport100ContinueTestBody)
1222 req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
1223 req.Header.Set("Expect", "100-continue")
1224 req.ContentLength = int64(len(transport100ContinueTestBody))
1225 test.resp, test.respErr = tr.RoundTrip(req)
1226 test.resp.Body.Close()
1227 }()
1228
1229 c, err := ln.Accept()
1230 if err != nil {
1231 t.Fatalf("Accept: %v", err)
1232 }
1233 t.Cleanup(func() {
1234 c.Close()
1235 })
1236 br := bufio.NewReader(c)
1237 _, err = ReadRequest(br)
1238 if err != nil {
1239 t.Fatalf("ReadRequest: %v", err)
1240 }
1241 test.conn = c
1242 test.reader = br
1243 t.Cleanup(func() {
1244 <-test.reqdone
1245 tr.CloseIdleConnections()
1246 got, _ := io.ReadAll(test.reader)
1247 if len(got) > 0 {
1248 t.Fatalf("Transport sent unexpected bytes: %q", got)
1249 }
1250 })
1251
1252 return test
1253 }
1254
1255
1256 func (test *transport100ContinueTest) respond(lines ...string) {
1257 for _, line := range lines {
1258 if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1259 test.t.Fatalf("Write: %v", err)
1260 }
1261 }
1262 if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1263 test.t.Fatalf("Write: %v", err)
1264 }
1265 }
1266
1267
1268 func (test *transport100ContinueTest) wantBodySent() {
1269 got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1270 if err != nil {
1271 test.t.Fatalf("unexpected error reading body: %v", err)
1272 }
1273 if got, want := string(got), transport100ContinueTestBody; got != want {
1274 test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1275 }
1276 }
1277
1278
1279 func (test *transport100ContinueTest) wantRequestDone(want int) {
1280 <-test.reqdone
1281 if test.respErr != nil {
1282 test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1283 }
1284 if got := test.resp.StatusCode; got != want {
1285 test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1286 }
1287 }
1288
1289 func TestTransportExpect100ContinueSent(t *testing.T) {
1290 test := newTransport100ContinueTest(t, 1*time.Hour)
1291
1292 test.respond("HTTP/1.1 100 Continue")
1293 test.wantBodySent()
1294 test.respond("HTTP/1.1 200", "Content-Length: 0")
1295 test.wantRequestDone(200)
1296 }
1297
1298 func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1299 test := newTransport100ContinueTest(t, 1*time.Hour)
1300
1301 test.respond("HTTP/1.1 200", "Content-Length: 0")
1302 test.wantBodySent()
1303 test.wantRequestDone(200)
1304 }
1305
1306 func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1307 test := newTransport100ContinueTest(t, 1*time.Hour)
1308
1309 test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1310 test.wantRequestDone(200)
1311 }
1312
1313 func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1314 test := newTransport100ContinueTest(t, 1*time.Hour)
1315
1316 test.respond("HTTP/1.1 500", "Content-Length: 0")
1317 test.wantBodySent()
1318 test.wantRequestDone(500)
1319 }
1320
1321 func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1322 test := newTransport100ContinueTest(t, 5*time.Millisecond)
1323 test.wantBodySent()
1324 test.respond("HTTP/1.1 200", "Content-Length: 0")
1325 test.wantRequestDone(200)
1326 }
1327
1328 func TestSOCKS5Proxy(t *testing.T) {
1329 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1330 }
1331 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1332 ch := make(chan string, 1)
1333 l := newLocalListener(t)
1334 defer l.Close()
1335 defer close(ch)
1336 proxy := func(t *testing.T) {
1337 s, err := l.Accept()
1338 if err != nil {
1339 t.Errorf("socks5 proxy Accept(): %v", err)
1340 return
1341 }
1342 defer s.Close()
1343 var buf [22]byte
1344 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1345 t.Errorf("socks5 proxy initial read: %v", err)
1346 return
1347 }
1348 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1349 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1350 return
1351 }
1352 if _, err := s.Write([]byte{5, 0}); err != nil {
1353 t.Errorf("socks5 proxy initial write: %v", err)
1354 return
1355 }
1356 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1357 t.Errorf("socks5 proxy second read: %v", err)
1358 return
1359 }
1360 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1361 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1362 return
1363 }
1364 var ipLen int
1365 switch buf[3] {
1366 case 1:
1367 ipLen = net.IPv4len
1368 case 4:
1369 ipLen = net.IPv6len
1370 default:
1371 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1372 return
1373 }
1374 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1375 t.Errorf("socks5 proxy address read: %v", err)
1376 return
1377 }
1378 ip := net.IP(buf[4 : ipLen+4])
1379 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1380 copy(buf[:3], []byte{5, 0, 0})
1381 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1382 t.Errorf("socks5 proxy connect write: %v", err)
1383 return
1384 }
1385 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1386
1387
1388 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1389 targetConn, err := net.Dial("tcp", targetHost)
1390 if err != nil {
1391 t.Errorf("net.Dial failed")
1392 return
1393 }
1394 go io.Copy(targetConn, s)
1395 io.Copy(s, targetConn)
1396 targetConn.Close()
1397 }
1398
1399 pu, err := url.Parse("socks5://" + l.Addr().String())
1400 if err != nil {
1401 t.Fatal(err)
1402 }
1403
1404 sentinelHeader := "X-Sentinel"
1405 sentinelValue := "12345"
1406 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1407 w.Header().Set(sentinelHeader, sentinelValue)
1408 })
1409 for _, useTLS := range []bool{false, true} {
1410 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1411 ts := newClientServerTest(t, mode, h).ts
1412 go proxy(t)
1413 c := ts.Client()
1414 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1415 r, err := c.Head(ts.URL)
1416 if err != nil {
1417 t.Fatal(err)
1418 }
1419 if r.Header.Get(sentinelHeader) != sentinelValue {
1420 t.Errorf("Failed to retrieve sentinel value")
1421 }
1422 got := <-ch
1423 ts.Close()
1424 tsu, err := url.Parse(ts.URL)
1425 if err != nil {
1426 t.Fatal(err)
1427 }
1428 want := "proxy for " + tsu.Host
1429 if got != want {
1430 t.Errorf("got %q, want %q", got, want)
1431 }
1432 })
1433 }
1434 }
1435
1436 func TestTransportProxy(t *testing.T) {
1437 defer afterTest(t)
1438 testCases := []struct{ siteMode, proxyMode testMode }{
1439 {http1Mode, http1Mode},
1440 {http1Mode, https1Mode},
1441 {https1Mode, http1Mode},
1442 {https1Mode, https1Mode},
1443 }
1444 for _, testCase := range testCases {
1445 siteMode := testCase.siteMode
1446 proxyMode := testCase.proxyMode
1447 t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
1448 siteCh := make(chan *Request, 1)
1449 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1450 siteCh <- r
1451 })
1452 proxyCh := make(chan *Request, 1)
1453 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1454 proxyCh <- r
1455
1456 if r.Method == "CONNECT" {
1457 hijacker, ok := w.(Hijacker)
1458 if !ok {
1459 t.Errorf("hijack not allowed")
1460 return
1461 }
1462 clientConn, _, err := hijacker.Hijack()
1463 if err != nil {
1464 t.Errorf("hijacking failed")
1465 return
1466 }
1467 res := &Response{
1468 StatusCode: StatusOK,
1469 Proto: "HTTP/1.1",
1470 ProtoMajor: 1,
1471 ProtoMinor: 1,
1472 Header: make(Header),
1473 }
1474
1475 targetConn, err := net.Dial("tcp", r.URL.Host)
1476 if err != nil {
1477 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1478 return
1479 }
1480
1481 if err := res.Write(clientConn); err != nil {
1482 t.Errorf("Writing 200 OK failed: %v", err)
1483 return
1484 }
1485
1486 go io.Copy(targetConn, clientConn)
1487 go func() {
1488 io.Copy(clientConn, targetConn)
1489 targetConn.Close()
1490 }()
1491 }
1492 })
1493 ts := newClientServerTest(t, siteMode, h1).ts
1494 proxy := newClientServerTest(t, proxyMode, h2).ts
1495
1496 pu, err := url.Parse(proxy.URL)
1497 if err != nil {
1498 t.Fatal(err)
1499 }
1500
1501
1502
1503
1504 c := proxy.Client()
1505 if siteMode == https1Mode {
1506 c = ts.Client()
1507 }
1508
1509 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1510 if _, err := c.Head(ts.URL); err != nil {
1511 t.Error(err)
1512 }
1513 got := <-proxyCh
1514 c.Transport.(*Transport).CloseIdleConnections()
1515 ts.Close()
1516 proxy.Close()
1517 if siteMode == https1Mode {
1518
1519 if got.Method != "CONNECT" {
1520 t.Errorf("Wrong method for secure proxying: %q", got.Method)
1521 }
1522 gotHost := got.URL.Host
1523 pu, err := url.Parse(ts.URL)
1524 if err != nil {
1525 t.Fatal("Invalid site URL")
1526 }
1527 if wantHost := pu.Host; gotHost != wantHost {
1528 t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
1529 }
1530
1531
1532 next := <-siteCh
1533 if next.Method != "HEAD" {
1534 t.Errorf("Wrong method at destination: %s", next.Method)
1535 }
1536 if nextURL := next.URL.String(); nextURL != "/" {
1537 t.Errorf("Wrong URL at destination: %s", nextURL)
1538 }
1539 } else {
1540 if got.Method != "HEAD" {
1541 t.Errorf("Wrong method for destination: %q", got.Method)
1542 }
1543 gotURL := got.URL.String()
1544 wantURL := ts.URL + "/"
1545 if gotURL != wantURL {
1546 t.Errorf("Got URL %q, want %q", gotURL, wantURL)
1547 }
1548 }
1549 })
1550 }
1551 }
1552
1553
1554
1555
1556 func TestProxyWithInfiniteHeader(t *testing.T) {
1557 defer afterTest(t)
1558
1559 ln := newLocalListener(t)
1560 defer ln.Close()
1561 cancelc := make(chan struct{})
1562 defer close(cancelc)
1563
1564
1565
1566 go func() {
1567 c, err := ln.Accept()
1568 if err != nil {
1569 t.Errorf("Accept: %v", err)
1570 return
1571 }
1572 defer c.Close()
1573
1574 br := bufio.NewReader(c)
1575 cr, err := ReadRequest(br)
1576 if err != nil {
1577 t.Errorf("proxy server failed to read CONNECT request")
1578 return
1579 }
1580 if cr.Method != "CONNECT" {
1581 t.Errorf("unexpected method %q", cr.Method)
1582 return
1583 }
1584
1585
1586 for {
1587
1588
1589
1590 runtime.Gosched()
1591 select {
1592 case <-cancelc:
1593 return
1594 default:
1595 c.Write([]byte("infinite stream of bytes"))
1596 }
1597 }
1598 }()
1599
1600 c := &Client{
1601 Transport: &Transport{
1602 Proxy: func(*Request) (*url.URL, error) {
1603 return url.Parse("http://" + ln.Addr().String())
1604 },
1605
1606 MaxResponseHeaderBytes: 1024,
1607 },
1608 }
1609 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1610 if err != nil {
1611 t.Fatal(err)
1612 }
1613 _, err = c.Do(req)
1614 if err == nil {
1615 t.Errorf("unexpected Get success")
1616 }
1617 }
1618
1619 func TestOnProxyConnectResponse(t *testing.T) {
1620
1621 var tcases = []struct {
1622 proxyStatusCode int
1623 err error
1624 }{
1625 {
1626 StatusOK,
1627 nil,
1628 },
1629 {
1630 StatusForbidden,
1631 errors.New("403"),
1632 },
1633 }
1634 for _, tcase := range tcases {
1635 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1636
1637 })
1638
1639 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1640
1641 if r.Method == "CONNECT" {
1642 if tcase.proxyStatusCode != StatusOK {
1643 w.WriteHeader(tcase.proxyStatusCode)
1644 return
1645 }
1646 hijacker, ok := w.(Hijacker)
1647 if !ok {
1648 t.Errorf("hijack not allowed")
1649 return
1650 }
1651 clientConn, _, err := hijacker.Hijack()
1652 if err != nil {
1653 t.Errorf("hijacking failed")
1654 return
1655 }
1656 res := &Response{
1657 StatusCode: StatusOK,
1658 Proto: "HTTP/1.1",
1659 ProtoMajor: 1,
1660 ProtoMinor: 1,
1661 Header: make(Header),
1662 }
1663
1664 targetConn, err := net.Dial("tcp", r.URL.Host)
1665 if err != nil {
1666 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1667 return
1668 }
1669
1670 if err := res.Write(clientConn); err != nil {
1671 t.Errorf("Writing 200 OK failed: %v", err)
1672 return
1673 }
1674
1675 go io.Copy(targetConn, clientConn)
1676 go func() {
1677 io.Copy(clientConn, targetConn)
1678 targetConn.Close()
1679 }()
1680 }
1681 })
1682 ts := newClientServerTest(t, https1Mode, h1).ts
1683 proxy := newClientServerTest(t, https1Mode, h2).ts
1684
1685 pu, err := url.Parse(proxy.URL)
1686 if err != nil {
1687 t.Fatal(err)
1688 }
1689
1690 c := proxy.Client()
1691
1692 var (
1693 dials atomic.Int32
1694 closes atomic.Int32
1695 )
1696 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1697 conn, err := net.Dial(network, addr)
1698 if err != nil {
1699 return nil, err
1700 }
1701 dials.Add(1)
1702 return noteCloseConn{
1703 Conn: conn,
1704 closeFunc: func() {
1705 closes.Add(1)
1706 },
1707 }, nil
1708 }
1709
1710 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1711 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1712 if proxyURL.String() != pu.String() {
1713 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1714 }
1715
1716 if "https://"+connectReq.URL.String() != ts.URL {
1717 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1718 }
1719 return tcase.err
1720 }
1721 wantCloses := int32(0)
1722 if _, err := c.Head(ts.URL); err != nil {
1723 wantCloses = 1
1724 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1725 t.Errorf("got %v, want %v", err, tcase.err)
1726 }
1727 } else {
1728 if tcase.err != nil {
1729 t.Errorf("got %v, want nil", err)
1730 }
1731 }
1732 if got, want := dials.Load(), int32(1); got != want {
1733 t.Errorf("got %v dials, want %v", got, want)
1734 }
1735
1736 if got, want := closes.Load(), wantCloses; got != want {
1737 t.Errorf("got %v closes, want %v", got, want)
1738 }
1739 }
1740 }
1741
1742
1743
1744 func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
1745 cancelc := make(chan struct{})
1746 SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
1747 ctx, cancel := context.WithCancel(ctx)
1748 go func() {
1749 select {
1750 case <-cancelc:
1751 case <-ctx.Done():
1752 }
1753 cancel()
1754 }()
1755 return ctx, cancel
1756 })
1757
1758 defer afterTest(t)
1759
1760 ln := newLocalListener(t)
1761 defer ln.Close()
1762 listenerDone := make(chan struct{})
1763 go func() {
1764 defer close(listenerDone)
1765 c, err := ln.Accept()
1766 if err != nil {
1767 t.Errorf("Accept: %v", err)
1768 return
1769 }
1770 defer c.Close()
1771
1772 br := bufio.NewReader(c)
1773 cr, err := ReadRequest(br)
1774 if err != nil {
1775 t.Errorf("proxy server failed to read CONNECT request")
1776 return
1777 }
1778 if cr.Method != "CONNECT" {
1779 t.Errorf("unexpected method %q", cr.Method)
1780 return
1781 }
1782
1783
1784
1785
1786 close(cancelc)
1787 var buf [1]byte
1788 _, err = br.Read(buf[:])
1789 if err != io.EOF {
1790 t.Errorf("proxy server Read err = %v; want EOF", err)
1791 }
1792 return
1793 }()
1794
1795 c := &Client{
1796 Transport: &Transport{
1797 Proxy: func(*Request) (*url.URL, error) {
1798 return url.Parse("http://" + ln.Addr().String())
1799 },
1800 },
1801 }
1802 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1803 if err != nil {
1804 t.Fatal(err)
1805 }
1806 _, err = c.Do(req)
1807 if err == nil {
1808 t.Errorf("unexpected Get success")
1809 }
1810
1811
1812
1813
1814 <-listenerDone
1815 }
1816
1817
1818 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1819 defer afterTest(t)
1820
1821 var errDial = errors.New("some dial error")
1822
1823 tr := &Transport{
1824 Proxy: func(*Request) (*url.URL, error) {
1825 return url.Parse("http://proxy.fake.tld/")
1826 },
1827 Dial: func(string, string) (net.Conn, error) {
1828 return nil, errDial
1829 },
1830 }
1831 defer tr.CloseIdleConnections()
1832
1833 c := &Client{Transport: tr}
1834 req, _ := NewRequest("GET", "http://fake.tld", nil)
1835 res, err := c.Do(req)
1836 if err == nil {
1837 res.Body.Close()
1838 t.Fatal("wanted a non-nil error")
1839 }
1840
1841 uerr, ok := err.(*url.Error)
1842 if !ok {
1843 t.Fatalf("got %T, want *url.Error", err)
1844 }
1845 oe, ok := uerr.Err.(*net.OpError)
1846 if !ok {
1847 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1848 }
1849 want := &net.OpError{
1850 Op: "proxyconnect",
1851 Net: "tcp",
1852 Err: errDial,
1853 }
1854 if !reflect.DeepEqual(oe, want) {
1855 t.Errorf("Got error %#v; want %#v", oe, want)
1856 }
1857 }
1858
1859
1860
1861
1862
1863 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1864 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1865 }
1866 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1867 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1868 defer proxy.Close()
1869 c := proxy.Client()
1870
1871 tr := c.Transport.(*Transport)
1872 tr.Proxy = func(*Request) (*url.URL, error) {
1873 u, _ := url.Parse(proxy.URL)
1874 u.User = url.UserPassword("aladdin", "opensesame")
1875 return u, nil
1876 }
1877 h := tr.ProxyConnectHeader
1878 if h == nil {
1879 h = make(Header)
1880 }
1881 tr.ProxyConnectHeader = h.Clone()
1882
1883 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1884 if err != nil {
1885 t.Fatal(err)
1886 }
1887 _, err = c.Do(req)
1888 if err == nil {
1889 t.Errorf("unexpected Get success")
1890 }
1891
1892 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1893 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1894 }
1895 }
1896
1897
1898
1899
1900
1901 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1902 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1903 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1904 w.Header().Set("Content-Encoding", "gzip")
1905 w.Write(rgz)
1906 })).ts
1907
1908 c := ts.Client()
1909 res, err := c.Get(ts.URL)
1910 if err != nil {
1911 t.Fatal(err)
1912 }
1913 body, err := io.ReadAll(res.Body)
1914 if err != nil {
1915 t.Fatal(err)
1916 }
1917 if !bytes.Equal(body, rgz) {
1918 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1919 body, rgz)
1920 }
1921 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1922 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1923 }
1924 }
1925
1926
1927
1928 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1929 func testTransportGzipShort(t *testing.T, mode testMode) {
1930 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1931 w.Header().Set("Content-Encoding", "gzip")
1932 w.Write([]byte{0x1f, 0x8b})
1933 })).ts
1934
1935 c := ts.Client()
1936 res, err := c.Get(ts.URL)
1937 if err != nil {
1938 t.Fatal(err)
1939 }
1940 defer res.Body.Close()
1941 _, err = io.ReadAll(res.Body)
1942 if err == nil {
1943 t.Fatal("Expect an error from reading a body.")
1944 }
1945 if err != io.ErrUnexpectedEOF {
1946 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1947 }
1948 }
1949
1950
// waitNumGoroutine polls runtime.NumGoroutine until it is at most nmax,
// sleeping 50ms and running a GC before each re-check, for at most 10
// re-checks. It returns the last count observed, which may still exceed
// nmax.
func waitNumGoroutine(nmax int) int {
	n := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if n <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		n = runtime.NumGoroutine()
	}
	return n
}
1960
1961
// TestTransportPersistConnLeak issues numReq concurrent GETs and then closes
// idle connections. It fails if more than 5 extra goroutines remain compared
// with the count taken before the requests started.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	// The handler signals gotReqCh when a request arrives and then blocks
	// until unblockCh is closed, so all requests are in flight at once.
	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline, taken before any client goroutines are started.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until every request has either reached the handler or failed.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// Request is blocked in the handler.
		case <-failed:
			// A client-side error (already logged above) also satisfies the
			// wait; it does not fail the test by itself.
		}
	}

	// Peak count; used only in the failure log below.
	nhigh := runtime.NumGoroutine()

	// Release all blocked handlers.
	close(unblockCh)

	// Wait for every client goroutine to finish its Get.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Allow a slack of 5 goroutines over the baseline.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
2033
2034
2035
2036 func TestTransportPersistConnLeakShortBody(t *testing.T) {
2037 run(t, testTransportPersistConnLeakShortBody, testNotParallel)
2038 }
2039 func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
2040 if mode == http2Mode {
2041 t.Skip("flaky in HTTP/2")
2042 }
2043
2044
2045 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2046 })).ts
2047 c := ts.Client()
2048 tr := c.Transport.(*Transport)
2049
2050 n0 := runtime.NumGoroutine()
2051 body := []byte("Hello")
2052 for i := 0; i < 20; i++ {
2053 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2054 if err != nil {
2055 t.Fatal(err)
2056 }
2057 req.ContentLength = int64(len(body) - 2)
2058 _, err = c.Do(req)
2059 if err == nil {
2060 t.Fatal("Expect an error from writing too long of a body.")
2061 }
2062 }
2063 nhigh := runtime.NumGoroutine()
2064 tr.CloseIdleConnections()
2065 nfinal := waitNumGoroutine(n0 + 5)
2066
2067 growth := nfinal - n0
2068
2069
2070
2071 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
2072 if int(growth) > 5 {
2073 t.Error("too many new goroutines")
2074 }
2075 }
2076
2077
// countedConn wraps a net.Conn so that every dialed connection is a distinct
// heap object whose collection countingDialer can observe.
type countedConn struct {
	net.Conn
}

// countingDialer dials with its embedded net.Dialer. It records how many
// connections it has produced (total) and how many of their wrappers have
// not yet been reclaimed by the garbage collector (live).
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials address and returns the connection wrapped in a
// countedConn, counting it in both total and live. A runtime cleanup
// decrements live once the wrapper becomes unreachable.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	defer d.mu.Unlock()
	d.total++
	d.live++
	// The cleanup gets d as its argument and must not reference wrapped;
	// otherwise wrapped would stay reachable forever.
	runtime.AddCleanup(wrapped, func(owner *countingDialer) { owner.decrement(nil) }, d)
	return wrapped, nil
}

// decrement records that one counted connection has been collected. Its
// parameter is unused.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns the total number of connections dialed and how many are
// still counted as live.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
2118
2119 func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
2120 run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
2121 }
2122 func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
2123 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2124
2125 conn, _, err := w.(Hijacker).Hijack()
2126 if err != nil {
2127 t.Errorf("Hijack failed unexpectedly: %v", err)
2128 return
2129 }
2130 conn.Close()
2131 })).ts
2132
2133 var d countingDialer
2134 c := ts.Client()
2135 c.Transport.(*Transport).DialContext = d.DialContext
2136
2137 body := []byte("Hello")
2138 for i := 0; ; i++ {
2139 total, live := d.Read()
2140 if live < total {
2141 break
2142 }
2143 if i >= 1<<12 {
2144 t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
2145 }
2146
2147 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2148 if err != nil {
2149 t.Fatal(err)
2150 }
2151 _, err = c.Do(req)
2152 if err == nil {
2153 t.Fatal("expected broken connection")
2154 }
2155
2156 runtime.GC()
2157 }
2158 }
2159
// countedContext wraps a context.Context so that every tracked context is a
// distinct heap object whose collection contextCounter can observe.
type countedContext struct {
	context.Context
}

// contextCounter counts contexts returned by Track whose wrappers have not
// yet been reclaimed by the garbage collector.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx, counts the wrapper as live, and registers a runtime
// cleanup that decrements the count once the wrapper becomes unreachable.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	counted := &countedContext{Context: ctx}
	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live++
	// Use the value passed as the cleanup argument rather than capturing cc
	// in the closure, matching countingDialer.DialContext. The cleanup must
	// not reference counted, or counted would never be collected.
	runtime.AddCleanup(counted, func(c *contextCounter) { c.decrement(nil) }, cc)
	return counted
}

// decrement records that one tracked context has been collected. Its
// parameter is unused.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	defer cc.mu.Unlock()
	cc.live--
}

// Read returns the number of tracked contexts still counted as live.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	defer cc.mu.Unlock()
	return cc.live
}
2190
2191 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2192 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2193 }
2194 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2195 if mode == http2Mode {
2196 t.Skip("https://go.dev/issue/56021")
2197 }
2198
2199 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2200 runtime.Gosched()
2201 w.WriteHeader(StatusOK)
2202 })).ts
2203
2204 c := ts.Client()
2205 c.Transport.(*Transport).MaxConnsPerHost = 1
2206
2207 ctx := context.Background()
2208 body := []byte("Hello")
2209 doPosts := func(cc *contextCounter) {
2210 var wg sync.WaitGroup
2211 for n := 64; n > 0; n-- {
2212 wg.Add(1)
2213 go func() {
2214 defer wg.Done()
2215
2216 ctx := cc.Track(ctx)
2217 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2218 if err != nil {
2219 t.Error(err)
2220 }
2221
2222 _, err = c.Do(req.WithContext(ctx))
2223 if err != nil {
2224 t.Errorf("Do failed with error: %v", err)
2225 }
2226 }()
2227 }
2228 wg.Wait()
2229 }
2230
2231 var initialCC contextCounter
2232 doPosts(&initialCC)
2233
2234
2235
2236
2237 var flushCC contextCounter
2238 for i := 0; ; i++ {
2239 live := initialCC.Read()
2240 if live == 0 {
2241 break
2242 }
2243 if i >= 100 {
2244 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2245 }
2246 doPosts(&flushCC)
2247 runtime.GC()
2248 }
2249 }
2250
2251
2252 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2253 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2254 var tr *Transport
2255
2256 unblockCh := make(chan bool, 1)
2257 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2258 <-unblockCh
2259 tr.CloseIdleConnections()
2260 })).ts
2261 c := ts.Client()
2262 tr = c.Transport.(*Transport)
2263
2264 didreq := make(chan bool)
2265 go func() {
2266 res, err := c.Get(ts.URL)
2267 if err != nil {
2268 t.Error(err)
2269 } else {
2270 res.Body.Close()
2271 }
2272 didreq <- true
2273 }()
2274 unblockCh <- true
2275 <-didreq
2276 }
2277
2278
2279
2280
2281
2282 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2283 func testIssue3644(t *testing.T, mode testMode) {
2284 const numFoos = 5000
2285 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2286 w.Header().Set("Connection", "close")
2287 for i := 0; i < numFoos; i++ {
2288 w.Write([]byte("foo "))
2289 }
2290 })).ts
2291 c := ts.Client()
2292 res, err := c.Get(ts.URL)
2293 if err != nil {
2294 t.Fatal(err)
2295 }
2296 defer res.Body.Close()
2297 bs, err := io.ReadAll(res.Body)
2298 if err != nil {
2299 t.Fatal(err)
2300 }
2301 if len(bs) != numFoos*len("foo ") {
2302 t.Errorf("unexpected response length")
2303 }
2304 }
2305
2306
2307
2308 func TestIssue3595(t *testing.T) {
2309
2310 run(t, testIssue3595, testNotParallel)
2311 }
2312 func testIssue3595(t *testing.T, mode testMode) {
2313 runTimeSensitiveTest(t, []time.Duration{
2314 1 * time.Millisecond,
2315 5 * time.Millisecond,
2316 10 * time.Millisecond,
2317 50 * time.Millisecond,
2318 100 * time.Millisecond,
2319 500 * time.Millisecond,
2320 time.Second,
2321 5 * time.Second,
2322 }, func(t *testing.T, timeout time.Duration) error {
2323 SetRSTAvoidanceDelay(t, timeout)
2324 t.Logf("set RST avoidance delay to %v", timeout)
2325
2326 const deniedMsg = "sorry, denied."
2327 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2328 Error(w, deniedMsg, StatusUnauthorized)
2329 }))
2330
2331
2332 defer cst.close()
2333 ts := cst.ts
2334 c := ts.Client()
2335
2336 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2337 if err != nil {
2338 return fmt.Errorf("Post: %v", err)
2339 }
2340 got, err := io.ReadAll(res.Body)
2341 if err != nil {
2342 return fmt.Errorf("Body ReadAll: %v", err)
2343 }
2344 t.Logf("server response:\n%s", got)
2345 if !strings.Contains(string(got), deniedMsg) {
2346
2347
2348 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2349 }
2350 return nil
2351 })
2352 }
2353
2354
2355
2356 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2357 func testChunkedNoContent(t *testing.T, mode testMode) {
2358 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2359 w.WriteHeader(StatusNoContent)
2360 })).ts
2361
2362 c := ts.Client()
2363 for _, closeBody := range []bool{true, false} {
2364 const n = 4
2365 for i := 1; i <= n; i++ {
2366 res, err := c.Get(ts.URL)
2367 if err != nil {
2368 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2369 } else {
2370 if closeBody {
2371 res.Body.Close()
2372 }
2373 }
2374 }
2375 }
2376 }
2377
// TestTransportConcurrency sends numReqs GETs through one client from a pool
// of 2*maxProcs worker goroutines. It checks that each response body echoes
// its request's "echo" query parameter.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Larger settings normally, smaller ones in -short mode.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Set GOMAXPROCS to maxProcs now and restore the previous value (returned
	// by the inner call) when the test returns.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// One count per request; each worker calls wg.Done once per request.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Also wait for dials tracked by the pending-dial hooks: the first hook
	// adds to wg and wg.Done releases it. Presumably these fire around dials
	// that no request ends up using, so wg.Wait does not return while such
	// a dial is still outstanding. TODO confirm hook semantics.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs makes the workers' range loops exit when the test returns.
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Known NetBSD flake: log instead of failing.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2440
2441 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2442 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2443 mux := NewServeMux()
2444 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2445 io.Copy(w, neverEnding('a'))
2446 })
2447 ts := newClientServerTest(t, mode, mux).ts
2448
2449 connc := make(chan net.Conn, 1)
2450 c := ts.Client()
2451 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2452 conn, err := net.Dial(n, addr)
2453 if err != nil {
2454 return nil, err
2455 }
2456 select {
2457 case connc <- conn:
2458 default:
2459 }
2460 return conn, nil
2461 }
2462
2463 res, err := c.Get(ts.URL + "/get")
2464 if err != nil {
2465 t.Fatalf("Error issuing GET: %v", err)
2466 }
2467 defer res.Body.Close()
2468
2469 conn := <-connc
2470 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2471 _, err = io.Copy(io.Discard, res.Body)
2472 if err == nil {
2473 t.Errorf("Unexpected successful copy")
2474 }
2475 }
2476
// TestIssue4191_InfiniteGetToPutTimeout streams an endless GET response body
// as the body of a PUT, over client connections that carry an absolute
// deadline. It checks that the PUT fails instead of running forever.
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
}
func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
	// Set to true to log connection traffic and progress.
	const debug = false
	mux := NewServeMux()
	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
		io.Copy(w, neverEnding('a'))
	})
	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
		defer r.Body.Close()
		io.Copy(io.Discard, r.Body)
	})
	ts := newClientServerTest(t, mode, mux).ts
	timeout := 100 * time.Millisecond

	c := ts.Client()
	// Every dialed connection gets a deadline timeout after dialing. The
	// closure reads timeout at dial time, so later increases apply to new
	// connections.
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		conn, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		conn.SetDeadline(time.Now().Add(timeout))
		if debug {
			conn = NewLoggingConn("client", conn)
		}
		return conn, nil
	}

	getFailed := false
	nRuns := 5
	if testing.Short() {
		nRuns = 1
	}
	for i := 0; i < nRuns; i++ {
		if debug {
			println("run", i+1, "of", nRuns)
		}
		sres, err := c.Get(ts.URL + "/get")
		if err != nil {
			if !getFailed {
				// Treat the first GET failure as the deadline being too
				// short: repeat this run (i--) with a 10x longer timeout.
				getFailed = true
				t.Logf("increasing timeout")
				i--
				timeout *= 10
				continue
			}
			t.Errorf("Error issuing GET: %v", err)
			break
		}
		// Use the endless GET body as the PUT body; the deadline must make
		// the PUT fail.
		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
		_, err = c.Do(req)
		if err == nil {
			sres.Body.Close()
			t.Errorf("Unexpected successful PUT")
			break
		}
		sres.Body.Close()
	}
	if debug {
		println("tests complete; waiting for handlers to finish")
	}
	ts.Close()
}
2542
2543 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2544 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2545 if testing.Short() {
2546 t.Skip("skipping timeout test in -short mode")
2547 }
2548
2549 timeout := 2 * time.Millisecond
2550 retry := true
2551 for retry && !t.Failed() {
2552 var srvWG sync.WaitGroup
2553 inHandler := make(chan bool, 1)
2554 mux := NewServeMux()
2555 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2556 inHandler <- true
2557 srvWG.Done()
2558 })
2559 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2560 inHandler <- true
2561 <-r.Context().Done()
2562 srvWG.Done()
2563 })
2564 ts := newClientServerTest(t, mode, mux).ts
2565
2566 c := ts.Client()
2567 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2568
2569 retry = false
2570 srvWG.Add(3)
2571 tests := []struct {
2572 path string
2573 wantTimeout bool
2574 }{
2575 {path: "/fast"},
2576 {path: "/slow", wantTimeout: true},
2577 {path: "/fast"},
2578 }
2579 for i, tt := range tests {
2580 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2581 req = req.WithT(t)
2582 res, err := c.Do(req)
2583 <-inHandler
2584 if err != nil {
2585 uerr, ok := err.(*url.Error)
2586 if !ok {
2587 t.Errorf("error is not a url.Error; got: %#v", err)
2588 continue
2589 }
2590 nerr, ok := uerr.Err.(net.Error)
2591 if !ok {
2592 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2593 continue
2594 }
2595 if !nerr.Timeout() {
2596 t.Errorf("want timeout error; got: %q", nerr)
2597 continue
2598 }
2599 if !tt.wantTimeout {
2600 if !retry {
2601
2602 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2603 timeout *= 2
2604 retry = true
2605 }
2606 }
2607 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2608 t.Errorf("%d. unexpected error: %v", i, err)
2609 }
2610 continue
2611 }
2612 if tt.wantTimeout {
2613 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2614 continue
2615 }
2616 if res.StatusCode != 200 {
2617 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2618 }
2619 }
2620
2621 srvWG.Wait()
2622 ts.Close()
2623 }
2624 }
2625
2626
2627 type cancelTest struct {
2628 mode testMode
2629 newReq func(req *Request) *Request
2630 cancel func(tr *Transport, req *Request)
2631 checkErr func(when string, err error)
2632 }
2633
2634
2635 func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2636 t.Run("TransportCancel", func(t *testing.T) {
2637 f(t, cancelTest{
2638 mode: mode,
2639 newReq: func(req *Request) *Request {
2640 return req
2641 },
2642 cancel: func(tr *Transport, req *Request) {
2643 tr.CancelRequest(req)
2644 },
2645 checkErr: func(when string, err error) {
2646 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2647 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2648 }
2649 },
2650 })
2651 })
2652 }
2653
2654
2655 func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2656 cancelc := make(chan struct{})
2657 cancelOnce := sync.OnceFunc(func() { close(cancelc) })
2658 f(t, cancelTest{
2659 mode: mode,
2660 newReq: func(req *Request) *Request {
2661 req.Cancel = cancelc
2662 return req
2663 },
2664 cancel: func(tr *Transport, req *Request) {
2665 cancelOnce()
2666 },
2667 checkErr: func(when string, err error) {
2668 if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2669 t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2670 }
2671 },
2672 })
2673 }
2674
2675
2676 func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2677 ctx, cancel := context.WithCancel(context.Background())
2678 f(t, cancelTest{
2679 mode: mode,
2680 newReq: func(req *Request) *Request {
2681 return req.WithContext(ctx)
2682 },
2683 cancel: func(tr *Transport, req *Request) {
2684 cancel()
2685 },
2686 checkErr: func(when string, err error) {
2687 if !errors.Is(err, context.Canceled) {
2688 t.Errorf("%v error = %v, want context.Canceled", when, err)
2689 }
2690 },
2691 })
2692 }
2693
2694 func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
2695 run(t, func(t *testing.T, mode testMode) {
2696 if mode == http1Mode {
2697 t.Run("TransportCancel", func(t *testing.T) {
2698 runCancelTestTransport(t, mode, f)
2699 })
2700 }
2701 t.Run("RequestCancel", func(t *testing.T) {
2702 runCancelTestChannel(t, mode, f)
2703 })
2704 t.Run("ContextCancel", func(t *testing.T) {
2705 runCancelTestContext(t, mode, f)
2706 })
2707 }, opts...)
2708 }
2709
2710 func TestTransportCancelRequest(t *testing.T) {
2711 runCancelTest(t, testTransportCancelRequest)
2712 }
2713 func testTransportCancelRequest(t *testing.T, test cancelTest) {
2714 if testing.Short() {
2715 t.Skip("skipping test in -short mode")
2716 }
2717
2718 const msg = "Hello"
2719 unblockc := make(chan bool)
2720 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2721 io.WriteString(w, msg)
2722 w.(Flusher).Flush()
2723 <-unblockc
2724 })).ts
2725 defer close(unblockc)
2726
2727 c := ts.Client()
2728 tr := c.Transport.(*Transport)
2729
2730 req, _ := NewRequest("GET", ts.URL, nil)
2731 req = test.newReq(req)
2732 res, err := c.Do(req)
2733 if err != nil {
2734 t.Fatal(err)
2735 }
2736 body := make([]byte, len(msg))
2737 n, _ := io.ReadFull(res.Body, body)
2738 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2739 t.Errorf("Body = %q; want %q", body[:n], msg)
2740 }
2741 test.cancel(tr, req)
2742
2743 tail, err := io.ReadAll(res.Body)
2744 res.Body.Close()
2745 test.checkErr("Body.Read", err)
2746 if len(tail) > 0 {
2747 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2748 }
2749
2750
2751
2752 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2753 n := tr.NumPendingRequestsForTesting()
2754 if n > 0 {
2755 if d > 0 {
2756 t.Logf("pending requests = %d after %v (want 0)", n, d)
2757 }
2758 return false
2759 }
2760 return true
2761 })
2762 }
2763
2764 func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
2765 if testing.Short() {
2766 t.Skip("skipping test in -short mode")
2767 }
2768 unblockc := make(chan bool)
2769 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2770 <-unblockc
2771 })).ts
2772 defer close(unblockc)
2773
2774 c := ts.Client()
2775 tr := c.Transport.(*Transport)
2776
2777 donec := make(chan bool)
2778 req, _ := NewRequest("GET", ts.URL, body)
2779 req = test.newReq(req)
2780 go func() {
2781 defer close(donec)
2782 c.Do(req)
2783 }()
2784
2785 unblockc <- true
2786 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2787 test.cancel(tr, req)
2788 select {
2789 case <-donec:
2790 return true
2791 default:
2792 if d > 0 {
2793 t.Logf("Do of canceled request has not returned after %v", d)
2794 }
2795 return false
2796 }
2797 })
2798 }
2799
2800 func TestTransportCancelRequestInDo(t *testing.T) {
2801 runCancelTest(t, func(t *testing.T, test cancelTest) {
2802 testTransportCancelRequestInDo(t, test, nil)
2803 })
2804 }
2805
2806 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2807 runCancelTest(t, func(t *testing.T, test cancelTest) {
2808 testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
2809 })
2810 }
2811
// TestTransportCancelRequestInDial cancels a request while the Transport's
// Dial hook is blocked. Using an event log, it checks that Do returns an
// error only after the cancellation has happened.
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// logbuf records the order of events from the dialer, the request
	// goroutine and this goroutine; it is compared against want below.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial releases a blocked Dial; it is closed when the test returns,
	// at which point Dial fails with "nope".
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// inDial is a handshake: the send below completes only after Dial has
	// logged "dial: blocking" and is waiting to receive from it.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	inDial <- true

	eventLog.Printf("canceling")
	// cancel is called twice, presumably to check that repeated
	// cancellation is harmless.
	test.cancel(tr, req)
	test.cancel(tr, req)

	if d, ok := t.Deadline(); ok {
		// If Do never returns, panic shortly before the test deadline so the
		// event log appears in the failure output.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2874
2875
2876 func TestTransportCancelRequestWithBody(t *testing.T) {
2877 runCancelTest(t, testTransportCancelRequestWithBody)
2878 }
2879 func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
2880 if testing.Short() {
2881 t.Skip("skipping test in -short mode")
2882 }
2883
2884 const msg = "Hello"
2885 unblockc := make(chan struct{})
2886 ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2887 io.WriteString(w, msg)
2888 w.(Flusher).Flush()
2889 <-unblockc
2890 })).ts
2891 defer close(unblockc)
2892
2893 c := ts.Client()
2894 tr := c.Transport.(*Transport)
2895
2896 req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
2897 req = test.newReq(req)
2898
2899 res, err := c.Do(req)
2900 if err != nil {
2901 t.Fatal(err)
2902 }
2903 body := make([]byte, len(msg))
2904 n, _ := io.ReadFull(res.Body, body)
2905 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2906 t.Errorf("Body = %q; want %q", body[:n], msg)
2907 }
2908 test.cancel(tr, req)
2909
2910 tail, err := io.ReadAll(res.Body)
2911 res.Body.Close()
2912 test.checkErr("Body.Read", err)
2913 if len(tail) > 0 {
2914 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2915 }
2916
2917
2918
2919 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2920 n := tr.NumPendingRequestsForTesting()
2921 if n > 0 {
2922 if d > 0 {
2923 t.Logf("pending requests = %d after %v (want 0)", n, d)
2924 }
2925 return false
2926 }
2927 return true
2928 })
2929 }
2930
2931 func TestTransportCancelRequestBeforeDo(t *testing.T) {
2932
2933 run(t, func(t *testing.T, mode testMode) {
2934 t.Run("RequestCancel", func(t *testing.T) {
2935 runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
2936 })
2937 t.Run("ContextCancel", func(t *testing.T) {
2938 runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
2939 })
2940 })
2941 }
2942 func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
2943 unblockc := make(chan bool)
2944 cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2945 <-unblockc
2946 }))
2947 defer close(unblockc)
2948
2949 c := cst.ts.Client()
2950
2951 req, _ := NewRequest("GET", cst.ts.URL, nil)
2952 req = test.newReq(req)
2953 test.cancel(cst.tr, req)
2954
2955 _, err := c.Do(req)
2956 test.checkErr("Do", err)
2957 }
2958
2959
2960 func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
2961 runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
2962 }
2963 func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
2964 defer afterTest(t)
2965
2966 serverConnCh := make(chan net.Conn, 1)
2967 tr := &Transport{
2968 Dial: func(network, addr string) (net.Conn, error) {
2969 cc, sc := net.Pipe()
2970 serverConnCh <- sc
2971 return cc, nil
2972 },
2973 }
2974 defer tr.CloseIdleConnections()
2975 errc := make(chan error, 1)
2976 req, _ := NewRequest("GET", "http://example.com/", nil)
2977 req = test.newReq(req)
2978 go func() {
2979 _, err := tr.RoundTrip(req)
2980 errc <- err
2981 }()
2982
2983 sc := <-serverConnCh
2984 verb := make([]byte, 3)
2985 if _, err := io.ReadFull(sc, verb); err != nil {
2986 t.Errorf("Error reading HTTP verb from server: %v", err)
2987 }
2988 if string(verb) != "GET" {
2989 t.Errorf("server received %q; want GET", verb)
2990 }
2991 defer sc.Close()
2992
2993 test.cancel(tr, req)
2994
2995 err := <-errc
2996 if err == nil {
2997 t.Fatalf("unexpected success from RoundTrip")
2998 }
2999 test.checkErr("RoundTrip", err)
3000 }
3001
3002
3003
3004
3005 func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
3006 func testTransportCloseResponseBody(t *testing.T, mode testMode) {
3007 writeErr := make(chan error, 1)
3008 msg := []byte("young\n")
3009 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3010 for {
3011 _, err := w.Write(msg)
3012 if err != nil {
3013 writeErr <- err
3014 return
3015 }
3016 w.(Flusher).Flush()
3017 }
3018 })).ts
3019
3020 c := ts.Client()
3021 tr := c.Transport.(*Transport)
3022
3023 req, _ := NewRequest("GET", ts.URL, nil)
3024 defer tr.CancelRequest(req)
3025
3026 res, err := c.Do(req)
3027 if err != nil {
3028 t.Fatal(err)
3029 }
3030
3031 const repeats = 3
3032 buf := make([]byte, len(msg)*repeats)
3033 want := bytes.Repeat(msg, repeats)
3034
3035 _, err = io.ReadFull(res.Body, buf)
3036 if err != nil {
3037 t.Fatal(err)
3038 }
3039 if !bytes.Equal(buf, want) {
3040 t.Fatalf("read %q; want %q", buf, want)
3041 }
3042
3043 if err := res.Body.Close(); err != nil {
3044 t.Errorf("Close = %v", err)
3045 }
3046
3047 if err := <-writeErr; err == nil {
3048 t.Errorf("expected non-nil write error")
3049 }
3050 }
3051
// fooProto is a RoundTripper that answers every request locally with a
// 200 response whose body echoes the request URL.
type fooProto struct{}

func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
3063
3064 func TestTransportAltProto(t *testing.T) {
3065 defer afterTest(t)
3066 tr := &Transport{}
3067 c := &Client{Transport: tr}
3068 tr.RegisterProtocol("foo", fooProto{})
3069 res, err := c.Get("foo://bar.com/path")
3070 if err != nil {
3071 t.Fatal(err)
3072 }
3073 bodyb, err := io.ReadAll(res.Body)
3074 if err != nil {
3075 t.Fatal(err)
3076 }
3077 body := string(bodyb)
3078 if e := "You wanted foo://bar.com/path"; body != e {
3079 t.Errorf("got response %q, want %q", body, e)
3080 }
3081 }
3082
3083 func TestTransportNoHost(t *testing.T) {
3084 defer afterTest(t)
3085 tr := &Transport{}
3086 _, err := tr.RoundTrip(&Request{
3087 Header: make(Header),
3088 URL: &url.URL{
3089 Scheme: "http",
3090 },
3091 })
3092 want := "http: no Host in request URL"
3093 if got := fmt.Sprint(err); got != want {
3094 t.Errorf("error = %v; want %q", err, want)
3095 }
3096 }
3097
3098
// TestTransportEmptyMethod checks that an outgoing request with an empty
// Method is written as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = ""
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if wire := string(dump); !strings.Contains(wire, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3110
// TestTransportSocketLateBinding checks that a request waiting on a new dial
// is sent on an existing connection instead if that connection becomes idle
// first. In HTTP/2 mode the second request is expected to share the first
// connection without dialing at all.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate keeps the /foo handler running after it has flushed its
	// headers, which keeps the /foo connection busy.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate allows one real dial per true value, and closing it makes a
	// waiting dial fail. While a dial waits on dialGate, it keeps offering
	// a value on dialing to announce that it is blocked.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Allow exactly one dial, used by /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Once /bar is (in HTTP/1 mode) known to be waiting on a new dial,
		// let /foo finish so its connection becomes idle and can be
		// reused by /bar.
		if mode == http2Mode {
			// No second dial is expected in HTTP/2 mode. Watch briefly
			// for one before releasing /foo.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	// /bar must have been served over the same connection as /foo.
	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3190
3191
// TestTransportReading100Continue checks that, for several requests in a row
// over an in-memory connection, the client skips a "100 Continue" interim
// response and returns the 200 response that follows it.
func TestTransportReading100Continue(t *testing.T) {
	defer afterTest(t)

	const numReqs = 5
	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }

	// send100Response plays the server end of a dialed connection. It reads
	// requests from r. For each request it writes to w an interim response
	// (status from X-Want-Response-Code, default "100 Continue"), followed
	// immediately by a 200 response echoing Request-Id. It returns on EOF,
	// on a read error, or after answering reqID(numReqs).
	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
		defer w.Close()
		defer r.Close()
		br := bufio.NewReader(r)
		n := 0
		for {
			n++
			req, err := ReadRequest(br)
			if err == io.EOF {
				return
			}
			if err != nil {
				t.Error(err)
				return
			}
			slurp, err := io.ReadAll(req.Body)
			if err != nil {
				t.Errorf("Server request body slurp: %v", err)
				return
			}
			id := req.Header.Get("Request-Id")
			resCode := req.Header.Get("X-Want-Response-Code")
			if resCode == "" {
				resCode = "100 Continue"
				// n counts requests on this connection, so this check
				// assumes all requests reuse a single connection.
				// NOTE(review): err is always nil at this point, so
				// printing it adds nothing.
				if string(slurp) != reqBody(n) {
					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
				}
			}
			body := fmt.Sprintf("Response number %d", n)
			// Both responses are written with LF line endings and then
			// converted to CRLF.
			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
Date: Thu, 28 Feb 2013 17:55:41 GMT

HTTP/1.1 200 OK
Content-Type: text/html
Echo-Request-Id: %s
Content-Length: %d

%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
			w.Write(v)
			if id == reqID(numReqs) {
				return
			}
		}

	}

	// Each dial builds a connection from two io.Pipes and starts a
	// send100Response goroutine serving the other ends.
	tr := &Transport{
		Dial: func(n, addr string) (net.Conn, error) {
			sr, sw := io.Pipe() // client writes to sw; server reads sr
			cr, cw := io.Pipe() // server writes to cw; client reads cr
			conn := &rwTestConn{
				Reader: cr,
				Writer: sw,
				closeFunc: func() error {
					sw.Close()
					cw.Close()
					return nil
				},
			}
			go send100Response(cw, sr)
			return conn, nil
		},
		DisableKeepAlives: false,
	}
	defer tr.CloseIdleConnections()
	c := &Client{Transport: tr}

	// testResponse sends req and checks its status code and echoed
	// Request-Id, and that its body can be read completely.
	testResponse := func(req *Request, name string, wantCode int) {
		t.Helper()
		res, err := c.Do(req)
		if err != nil {
			t.Fatalf("%s: Do: %v", name, err)
		}
		if res.StatusCode != wantCode {
			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
		}
		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
		}
		_, err = io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: Slurp error: %v", name, err)
		}
	}

	// Send numReqs POSTs in sequence; each should end in a 200 response.
	for i := 1; i <= numReqs; i++ {
		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
		req.Header.Set("Request-Id", reqID(i))
		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
	}
}
3291
3292
3293
3294 func TestTransportIgnore1xxResponses(t *testing.T) {
3295 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3296 }
3297 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3298 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3299 conn, buf, _ := w.(Hijacker).Hijack()
3300 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3301 buf.Flush()
3302 conn.Close()
3303 }))
3304 cst.tr.DisableKeepAlives = true
3305
3306 var got strings.Builder
3307
3308 req, _ := NewRequest("GET", cst.ts.URL, nil)
3309 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3310 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3311 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3312 return nil
3313 },
3314 }))
3315 res, err := cst.c.Do(req)
3316 if err != nil {
3317 t.Fatal(err)
3318 }
3319 defer res.Body.Close()
3320
3321 res.Write(&got)
3322 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3323 if got.String() != want {
3324 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3325 }
3326 }
3327
3328 func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
3329 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3330 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3331 w.Header().Add("X-Header", strings.Repeat("a", 100))
3332 for i := 0; i < 10; i++ {
3333 w.WriteHeader(123)
3334 }
3335 w.WriteHeader(204)
3336 }))
3337 cst.tr.DisableKeepAlives = true
3338 cst.tr.MaxResponseHeaderBytes = 1000
3339
3340 res, err := cst.c.Get(cst.ts.URL)
3341 if err == nil {
3342 res.Body.Close()
3343 t.Fatalf("RoundTrip succeeded; want error")
3344 }
3345 for _, want := range []string{
3346 "response headers exceeded",
3347 "too many 1xx",
3348 "header list too large",
3349 } {
3350 if strings.Contains(err.Error(), want) {
3351 return
3352 }
3353 }
3354 t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
3355 }
3356
3357 func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
3358 run(t, testTransportDoesNotLimitDelivered1xxResponses)
3359 }
3360 func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
3361 if mode == http2Mode {
3362 t.Skip("skip until x/net/http2 updated")
3363 }
3364 const num1xx = 10
3365 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3366 w.Header().Add("X-Header", strings.Repeat("a", 100))
3367 for i := 0; i < 10; i++ {
3368 w.WriteHeader(123)
3369 }
3370 w.WriteHeader(204)
3371 }))
3372 cst.tr.DisableKeepAlives = true
3373 cst.tr.MaxResponseHeaderBytes = 1000
3374
3375 got1xx := 0
3376 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3377 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3378 got1xx++
3379 return nil
3380 },
3381 })
3382 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
3383 res, err := cst.c.Do(req)
3384 if err != nil {
3385 t.Fatal(err)
3386 }
3387 res.Body.Close()
3388 if got1xx != num1xx {
3389 t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
3390 }
3391 }
3392
3393
3394
3395 func TestTransportTreat101Terminal(t *testing.T) {
3396 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3397 }
3398 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3399 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3400 conn, buf, _ := w.(Hijacker).Hijack()
3401 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3402 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3403 buf.Flush()
3404 conn.Close()
3405 }))
3406 res, err := cst.c.Get(cst.ts.URL)
3407 if err != nil {
3408 t.Fatal(err)
3409 }
3410 defer res.Body.Close()
3411 if res.StatusCode != StatusSwitchingProtocols {
3412 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3413 }
3414 }
3415
// proxyFromEnvTest is one case for the proxy-from-environment tests.
type proxyFromEnvTest struct {
	req string // request URL; "http://example.com" when empty

	env      string // value for HTTP_PROXY / http_proxy
	httpsenv string // value for HTTPS_PROXY / https_proxy
	noenv    string // value for NO_PROXY / no_proxy
	reqmeth  string // value for REQUEST_METHOD

	want    string // expected proxy URL formatted with %s ("<nil>" for none)
	wanterr error  // expected error, compared by its %v text
}

// String describes the case as space-separated name=%q pairs, one for each
// non-empty environment value, followed by the request URL.
func (t proxyFromEnvTest) String() string {
	var parts []string
	add := func(name, value string) {
		if value != "" {
			parts = append(parts, fmt.Sprintf("%s=%q", name, value))
		}
	}
	add("http_proxy", t.env)
	add("https_proxy", t.httpsenv)
	add("no_proxy", t.noenv)
	add("request_method", t.reqmeth)

	req := t.req
	if req == "" {
		req = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", req))
	return strings.Join(parts, " ")
}
3458
// proxyFromEnvTests are the cases shared by TestProxyFromEnvironment and
// TestProxyFromEnvironmentLowerCase. want is the expected proxy URL
// formatted with %s, or "<nil>" when no proxy should be used.
var proxyFromEnvTests = []proxyFromEnvTest{
	// HTTP_PROXY values with and without a scheme; a missing scheme
	// yields an http:// proxy URL.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http:// request uses HTTP_PROXY, not HTTPS_PROXY.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https:// request uses HTTPS_PROXY (defaulting to http:// when it
	// has no scheme).
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// When REQUEST_METHOD is set (a CGI environment), HTTP_PROXY is
	// refused with an error.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy variables at all: no proxy.
	{want: "<nil>"},

	// NO_PROXY matching: "example.com" excludes example.com and
	// foo.example.com; ".example.com" does not exclude example.com
	// itself; "ample.com" and ".foo.com" do not exclude example.com.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3489
3490 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3491 t.Helper()
3492 reqURL := tt.req
3493 if reqURL == "" {
3494 reqURL = "http://example.com"
3495 }
3496 req, _ := NewRequest("GET", reqURL, nil)
3497 url, err := proxyForRequest(req)
3498 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3499 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3500 return
3501 }
3502 if got := fmt.Sprintf("%s", url); got != tt.want {
3503 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3504 }
3505 }
3506
3507 func TestProxyFromEnvironment(t *testing.T) {
3508 ResetProxyEnv()
3509 defer ResetProxyEnv()
3510 for _, tt := range proxyFromEnvTests {
3511 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3512 os.Setenv("HTTP_PROXY", tt.env)
3513 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3514 os.Setenv("NO_PROXY", tt.noenv)
3515 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3516 ResetCachedEnvironment()
3517 return ProxyFromEnvironment(req)
3518 })
3519 }
3520 }
3521
3522 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3523 ResetProxyEnv()
3524 defer ResetProxyEnv()
3525 for _, tt := range proxyFromEnvTests {
3526 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3527 os.Setenv("http_proxy", tt.env)
3528 os.Setenv("https_proxy", tt.httpsenv)
3529 os.Setenv("no_proxy", tt.noenv)
3530 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3531 ResetCachedEnvironment()
3532 return ProxyFromEnvironment(req)
3533 })
3534 }
3535 }
3536
// TestIdleConnChannelLeak sends requests to several distinct host names,
// with keep-alives disabled and then enabled. After each round it checks
// that the transport's idle-connection wait map is empty.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// NOTE(review): n counts handler invocations but is never read in this
	// function.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// didRead receives a value each time the read-loop hook fires. The
	// hook presumably runs before a connection reads its next response;
	// see SetReadLoopBeforeNextReadHook.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Route every request to the test server, whatever host it names.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// Run the same requests with keep-alives disabled, then enabled.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// The response body is deliberately not closed. The handler
			// writes no body, so the response is presumably empty and
			// needs no Close; TODO confirm.
		}

		// Wait for the hook to fire once per request before inspecting the
		// transport's internal state.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3592
3593
3594
3595
3596 func TestTransportClosesRequestBody(t *testing.T) {
3597 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3598 }
3599 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3600 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3601 io.Copy(io.Discard, r.Body)
3602 })).ts
3603
3604 c := ts.Client()
3605
3606 closes := 0
3607
3608 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3609 if err != nil {
3610 t.Fatal(err)
3611 }
3612 res.Body.Close()
3613 if closes != 1 {
3614 t.Errorf("closes = %d; want 1", closes)
3615 }
3616 }
3617
3618 func TestTransportTLSHandshakeTimeout(t *testing.T) {
3619 defer afterTest(t)
3620 if testing.Short() {
3621 t.Skip("skipping in short mode")
3622 }
3623 ln := newLocalListener(t)
3624 defer ln.Close()
3625 testdonec := make(chan struct{})
3626 defer close(testdonec)
3627
3628 go func() {
3629 c, err := ln.Accept()
3630 if err != nil {
3631 t.Error(err)
3632 return
3633 }
3634 <-testdonec
3635 c.Close()
3636 }()
3637
3638 tr := &Transport{
3639 Dial: func(_, _ string) (net.Conn, error) {
3640 return net.Dial("tcp", ln.Addr().String())
3641 },
3642 TLSHandshakeTimeout: 250 * time.Millisecond,
3643 }
3644 cl := &Client{Transport: tr}
3645 _, err := cl.Get("https://dummy.tld/")
3646 if err == nil {
3647 t.Error("expected error")
3648 return
3649 }
3650 ue, ok := err.(*url.Error)
3651 if !ok {
3652 t.Errorf("expected url.Error; got %#v", err)
3653 return
3654 }
3655 ne, ok := ue.Err.(net.Error)
3656 if !ok {
3657 t.Errorf("expected net.Error; got %#v", err)
3658 return
3659 }
3660 if !ne.Timeout() {
3661 t.Errorf("expected timeout error; got %v", err)
3662 }
3663 if !strings.Contains(err.Error(), "handshake timeout") {
3664 t.Errorf("expected 'handshake timeout' in error; got %v", err)
3665 }
3666 }
3667
3668
3669 func TestTLSServerClosesConnection(t *testing.T) {
3670 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3671 }
3672 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3673 closedc := make(chan bool, 1)
3674 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3675 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3676 conn, _, _ := w.(Hijacker).Hijack()
3677 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3678 conn.Close()
3679 closedc <- true
3680 return
3681 }
3682 fmt.Fprintf(w, "hello")
3683 })).ts
3684
3685 c := ts.Client()
3686 tr := c.Transport.(*Transport)
3687
3688 var nSuccess = 0
3689 var errs []error
3690 const trials = 20
3691 for i := 0; i < trials; i++ {
3692 tr.CloseIdleConnections()
3693 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3694 if err != nil {
3695 t.Fatal(err)
3696 }
3697 <-closedc
3698 slurp, err := io.ReadAll(res.Body)
3699 if err != nil {
3700 t.Fatal(err)
3701 }
3702 if string(slurp) != "foo" {
3703 t.Errorf("Got %q, want foo", slurp)
3704 }
3705
3706
3707
3708 res, err = c.Get(ts.URL + "/")
3709 if err != nil {
3710 errs = append(errs, err)
3711 continue
3712 }
3713 slurp, err = io.ReadAll(res.Body)
3714 if err != nil {
3715 errs = append(errs, err)
3716 continue
3717 }
3718 nSuccess++
3719 }
3720 if nSuccess > 0 {
3721 t.Logf("successes = %d of %d", nSuccess, trials)
3722 } else {
3723 t.Errorf("All runs failed:")
3724 }
3725 for _, err := range errs {
3726 t.Logf(" err: %v", err)
3727 }
3728 }
3729
3730
3731
3732
// byteFromChanReader is an io.Reader that returns one byte per Read call,
// received from the channel. Read blocks until a byte is available. It
// reports io.EOF once the channel is closed and drained. A Read into an
// empty buffer returns (0, nil) without touching the channel.
type byteFromChanReader chan byte

func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	b, ok := <-c
	if ok {
		p[0] = b
		return 1, nil
	}
	return 0, io.EOF
}
3746
3747
3748
3749
3750
3751
3752
3753 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
3754 run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
3755 }
// testTransportNoReuseAfterEarlyResponse sends a POST that the server answers
// ("foo") on a hijacked connection before the request body is complete. The
// body's final byte is withheld until later. A following GET must then
// receive "bar" from the normal handler. If the client had reused the POST's
// connection, the GET would have gone to the hijacked conn, which only
// discards input.
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shorten *MaxWriteWaitBeforeConnReuse for this test and restore the
	// previous value on return.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the server side of the hijacked POST connection.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay records that the GET succeeded; it only suppresses the log line
	// in closeConn.
	var getOkay bool
	// copying tracks the goroutine that drains the hijacked connection.
	var copying sync.WaitGroup
	// closeConn closes the hijacked server connection if one is set. It
	// clears sconn.c afterwards, so repeated calls close it at most once.
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	// On return, close the hijacked conn (which should end the drainer's
	// io.Copy) and wait for the drainer to exit.
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: hijack the connection and send a complete response without
		// reading the request body. Then keep reading and discarding whatever
		// the client still sends.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))

		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The POST body is bodySize bytes: bodySize-1 bytes of 'x', then one byte
	// read from finalBit. That read blocks until the test sends the byte
	// below, so the request body cannot finish before then.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	// The POST's body is still incomplete here; the GET must be served by the
	// regular handler on a different connection.
	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	// Only now let the POST body finish.
	getOkay = true
	finalBit <- 'x'
	close(finalBit)
}
3820
3821
3822
3823 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3824 func testTransportIssue10457(t *testing.T, mode testMode) {
3825 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3826
3827
3828
3829
3830
3831 conn, _, _ := w.(Hijacker).Hijack()
3832 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3833 conn.Close()
3834 })).ts
3835 c := ts.Client()
3836
3837 res, err := c.Get(ts.URL)
3838 if err != nil {
3839 t.Fatalf("Get: %v", err)
3840 }
3841 defer res.Body.Close()
3842
3843
3844
3845
3846 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3847 t.Errorf("Foo header = %q; want %q", got, want)
3848 }
3849 }
3850
// closerFunc adapts an ordinary function to the io.Closer interface.
type closerFunc func() error

// Close calls the underlying function and returns its error.
func (f closerFunc) Close() error {
	err := f()
	return err
}
3854
// writerFuncConn wraps a net.Conn so that Write goes through the write hook.
// Every other method is promoted from the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error) // replaces Conn.Write
}

// Write forwards p to the write hook.
func (c writerFuncConn) Write(p []byte) (n int, err error) {
	hook := c.write
	return hook(p)
}
3861
3862
3863
3864
3865
3866
3867
3868
3869
3870
3871
3872
3873
3874 func TestRetryRequestsOnError(t *testing.T) {
3875 run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3876 }
3877 func testRetryRequestsOnError(t *testing.T, mode testMode) {
3878 newRequest := func(method, urlStr string, body io.Reader) *Request {
3879 req, err := NewRequest(method, urlStr, body)
3880 if err != nil {
3881 t.Fatal(err)
3882 }
3883 return req
3884 }
3885
3886 testCases := []struct {
3887 name string
3888 failureN int
3889 failureErr error
3890
3891
3892
3893 req func() *Request
3894 reqString string
3895 }{
3896 {
3897 name: "IdempotentNoBodySomeWritten",
3898
3899
3900 failureN: 1,
3901
3902 failureErr: ExportErrServerClosedIdle,
3903 req: func() *Request {
3904 return newRequest("GET", "http://fake.golang", nil)
3905 },
3906 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3907 },
3908 {
3909 name: "IdempotentGetBodySomeWritten",
3910
3911
3912 failureN: 1,
3913
3914 failureErr: ExportErrServerClosedIdle,
3915 req: func() *Request {
3916 return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3917 },
3918 reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3919 },
3920 {
3921 name: "NothingWrittenNoBody",
3922
3923
3924 failureN: 0,
3925 failureErr: errors.New("second write fails"),
3926 req: func() *Request {
3927 return newRequest("DELETE", "http://fake.golang", nil)
3928 },
3929 reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3930 },
3931 {
3932 name: "NothingWrittenGetBody",
3933
3934
3935 failureN: 0,
3936 failureErr: errors.New("second write fails"),
3937
3938
3939 req: func() *Request {
3940 return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3941 },
3942 reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3943 },
3944 }
3945
3946 for _, tc := range testCases {
3947 t.Run(tc.name, func(t *testing.T) {
3948 var (
3949 mu sync.Mutex
3950 logbuf strings.Builder
3951 )
3952 logf := func(format string, args ...any) {
3953 mu.Lock()
3954 defer mu.Unlock()
3955 fmt.Fprintf(&logbuf, format, args...)
3956 logbuf.WriteByte('\n')
3957 }
3958
3959 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3960 logf("Handler")
3961 w.Header().Set("X-Status", "ok")
3962 })).ts
3963
3964 var writeNumAtomic int32
3965 c := ts.Client()
3966 c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3967 logf("Dial")
3968 c, err := net.Dial(network, ts.Listener.Addr().String())
3969 if err != nil {
3970 logf("Dial error: %v", err)
3971 return nil, err
3972 }
3973 return &writerFuncConn{
3974 Conn: c,
3975 write: func(p []byte) (n int, err error) {
3976 if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3977 logf("intentional write failure")
3978 return tc.failureN, tc.failureErr
3979 }
3980 logf("Write(%q)", p)
3981 return c.Write(p)
3982 },
3983 }, nil
3984 }
3985
3986 SetRoundTripRetried(func() {
3987 logf("Retried.")
3988 })
3989 defer SetRoundTripRetried(nil)
3990
3991 for i := 0; i < 3; i++ {
3992 t0 := time.Now()
3993 req := tc.req()
3994 res, err := c.Do(req)
3995 if err != nil {
3996 if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3997 mu.Lock()
3998 got := logbuf.String()
3999 mu.Unlock()
4000 t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
4001 }
4002 t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
4003 }
4004 res.Body.Close()
4005 if res.Request != req {
4006 t.Errorf("Response.Request != original request; want identical Request")
4007 }
4008 }
4009
4010 mu.Lock()
4011 got := logbuf.String()
4012 mu.Unlock()
4013 want := fmt.Sprintf(`Dial
4014 Write("%s")
4015 Handler
4016 intentional write failure
4017 Retried.
4018 Dial
4019 Write("%s")
4020 Handler
4021 Write("%s")
4022 Handler
4023 `, tc.reqString, tc.reqString, tc.reqString)
4024 if got != want {
4025 t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
4026 }
4027 })
4028 }
4029 }
4030
4031
4032 func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
4033 func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
4034 readBody := make(chan error, 1)
4035 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4036 _, err := io.ReadAll(r.Body)
4037 readBody <- err
4038 })).ts
4039 c := ts.Client()
4040 fakeErr := errors.New("fake error")
4041 didClose := make(chan bool, 1)
4042 req, _ := NewRequest("POST", ts.URL, struct {
4043 io.Reader
4044 io.Closer
4045 }{
4046 io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
4047 closerFunc(func() error {
4048 select {
4049 case didClose <- true:
4050 default:
4051 }
4052 return nil
4053 }),
4054 })
4055 res, err := c.Do(req)
4056 if res != nil {
4057 defer res.Body.Close()
4058 }
4059 if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
4060 t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
4061 }
4062 if err := <-readBody; err == nil {
4063 t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
4064 }
4065 select {
4066 case <-didClose:
4067 default:
4068 t.Errorf("didn't see Body.Close")
4069 }
4070 }
4071
4072 func TestTransportDialTLS(t *testing.T) {
4073 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
4074 }
4075 func testTransportDialTLS(t *testing.T, mode testMode) {
4076 var mu sync.Mutex
4077 var gotReq, didDial bool
4078
4079 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4080 mu.Lock()
4081 gotReq = true
4082 mu.Unlock()
4083 })).ts
4084 c := ts.Client()
4085 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
4086 mu.Lock()
4087 didDial = true
4088 mu.Unlock()
4089 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4090 if err != nil {
4091 return nil, err
4092 }
4093 return c, c.Handshake()
4094 }
4095
4096 res, err := c.Get(ts.URL)
4097 if err != nil {
4098 t.Fatal(err)
4099 }
4100 res.Body.Close()
4101 mu.Lock()
4102 if !gotReq {
4103 t.Error("didn't get request")
4104 }
4105 if !didDial {
4106 t.Error("didn't use dial hook")
4107 }
4108 }
4109
4110 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
4111 func testTransportDialContext(t *testing.T, mode testMode) {
4112 ctxKey := "some-key"
4113 ctxValue := "some-value"
4114 var (
4115 mu sync.Mutex
4116 gotReq bool
4117 gotCtxValue any
4118 )
4119
4120 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4121 mu.Lock()
4122 gotReq = true
4123 mu.Unlock()
4124 })).ts
4125 c := ts.Client()
4126 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4127 mu.Lock()
4128 gotCtxValue = ctx.Value(ctxKey)
4129 mu.Unlock()
4130 return net.Dial(netw, addr)
4131 }
4132
4133 req, err := NewRequest("GET", ts.URL, nil)
4134 if err != nil {
4135 t.Fatal(err)
4136 }
4137 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4138 res, err := c.Do(req.WithContext(ctx))
4139 if err != nil {
4140 t.Fatal(err)
4141 }
4142 res.Body.Close()
4143 mu.Lock()
4144 if !gotReq {
4145 t.Error("didn't get request")
4146 }
4147 if got, want := gotCtxValue, ctxValue; got != want {
4148 t.Errorf("got context with value %v, want %v", got, want)
4149 }
4150 }
4151
4152 func TestTransportDialTLSContext(t *testing.T) {
4153 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
4154 }
4155 func testTransportDialTLSContext(t *testing.T, mode testMode) {
4156 ctxKey := "some-key"
4157 ctxValue := "some-value"
4158 var (
4159 mu sync.Mutex
4160 gotReq bool
4161 gotCtxValue any
4162 )
4163
4164 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4165 mu.Lock()
4166 gotReq = true
4167 mu.Unlock()
4168 })).ts
4169 c := ts.Client()
4170 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4171 mu.Lock()
4172 gotCtxValue = ctx.Value(ctxKey)
4173 mu.Unlock()
4174 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4175 if err != nil {
4176 return nil, err
4177 }
4178 return c, c.HandshakeContext(ctx)
4179 }
4180
4181 req, err := NewRequest("GET", ts.URL, nil)
4182 if err != nil {
4183 t.Fatal(err)
4184 }
4185 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4186 res, err := c.Do(req.WithContext(ctx))
4187 if err != nil {
4188 t.Fatal(err)
4189 }
4190 res.Body.Close()
4191 mu.Lock()
4192 if !gotReq {
4193 t.Error("didn't get request")
4194 }
4195 if got, want := gotCtxValue, ctxValue; got != want {
4196 t.Errorf("got context with value %v, want %v", got, want)
4197 }
4198 }
4199
4200
4201
// TestRoundTripReturnsProxyError checks that an error from Transport.Proxy
// makes RoundTrip fail.
func TestRoundTripReturnsProxyError(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) {
			return nil, errors.New("errorMessage")
		},
	}

	req, _ := NewRequest("GET", "http://example.com", nil)
	if _, err := tr.RoundTrip(req); err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	}
}
4217
4218
4219 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4220 tr := &Transport{}
4221 wantIdle := func(when string, n int) bool {
4222 got := tr.IdleConnCountForTesting("http", "example.com")
4223 if got == n {
4224 return true
4225 }
4226 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4227 return false
4228 }
4229 wantIdle("start", 0)
4230 if !tr.PutIdleTestConn("http", "example.com") {
4231 t.Fatal("put failed")
4232 }
4233 if !tr.PutIdleTestConn("http", "example.com") {
4234 t.Fatal("second put failed")
4235 }
4236 wantIdle("after put", 2)
4237 tr.CloseIdleConnections()
4238 if !tr.IsIdleForTesting() {
4239 t.Error("should be idle after CloseIdleConnections")
4240 }
4241 wantIdle("after close idle", 0)
4242 if tr.PutIdleTestConn("http", "example.com") {
4243 t.Fatal("put didn't fail")
4244 }
4245 wantIdle("after second put", 0)
4246
4247 tr.QueueForIdleConnForTesting()
4248 if tr.IsIdleForTesting() {
4249 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4250 }
4251 if !tr.PutIdleTestConn("http", "example.com") {
4252 t.Fatal("after re-activation")
4253 }
4254 wantIdle("after final put", 1)
4255 }
4256
4257
4258
4259 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4260 tr := &Transport{}
4261 wantIdle := func(when string, n int) bool {
4262 got := tr.IdleConnCountForTesting("https", "example.com:443")
4263 if got == n {
4264 return true
4265 }
4266 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4267 return false
4268 }
4269 wantIdle("start", 0)
4270 alt := funcRoundTripper(func() {})
4271 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4272 t.Fatal("put failed")
4273 }
4274 wantIdle("after put", 1)
4275 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4276 GotConn: func(httptrace.GotConnInfo) {
4277
4278 t.Error("GotConn called")
4279 },
4280 })
4281 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4282 _, err := tr.RoundTrip(req)
4283 if err != errFakeRoundTrip {
4284 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4285 }
4286 wantIdle("after round trip", 1)
4287 }
4288
4289
4290
4291
4292
4293
4294 func TestTransportIdleConnRacesRequest(t *testing.T) {
4295
4296
4297 runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
4298 }
// testTransportIdleConnRacesRequest uses fake networking under synctest to
// stage this sequence:
//   - request 1 starts a dial that is held open, then request 1 is canceled
//     and must fail;
//   - the dial is released and IdleConnTimeout elapses while connection
//     closes are held open;
//   - request 2 is issued during that window and must succeed once closes
//     are released.
func testTransportIdleConnRacesRequest(t *testing.T, mode testMode) {
	if mode == http2UnencryptedMode {
		t.Skip("remove skip when #70515 is fixed")
	}
	timeout := 1 * time.Millisecond
	trFunc := func(tr *Transport) {
		tr.IdleConnTimeout = timeout
	}
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
	}), trFunc, optFakeNet)
	cst.li.trackConns = true

	// Hold every dial until dialc is closed.
	dialc := make(chan struct{})
	cst.li.onDial = func() {
		<-dialc
	}
	// Hold every connection close until closec is closed.
	closec := make(chan struct{})
	cst.li.onClose = func(*fakeNetConn) {
		<-closec
	}
	ctx, cancel := context.WithCancel(context.Background())
	req1c := make(chan error)
	go func() {
		req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		resp, err := cst.c.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		req1c <- err
	}()

	// Let request 1 run until it blocks (presumably in the held dial).
	synctest.Wait()

	// Cancel request 1 while its dial is still pending; it must fail.
	cancel()
	synctest.Wait()
	if err := <-req1c; err == nil {
		t.Fatal("expected request to fail, but it succeeded")
	}

	// Release the dial. No request is waiting for the resulting connection,
	// so it is presumably parked in the idle pool.
	close(dialc)

	// Advance past IdleConnTimeout. Presumably the idle timer then closes
	// the connection, and that close is what blocks in onClose.
	synctest.Wait()
	time.Sleep(timeout)
	synctest.Wait()

	req2c := make(chan error)
	go func() {
		resp, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			resp.Body.Close()
		}
		req2c <- err
	}()

	// Release the pending close; request 2 must still succeed.
	close(closec)
	if err := <-req2c; err != nil {
		t.Fatalf("Get: %v", err)
	}
}
4370
4371 func TestTransportRemovesConnsAfterIdle(t *testing.T) {
4372 runSynctest(t, testTransportRemovesConnsAfterIdle)
4373 }
4374 func testTransportRemovesConnsAfterIdle(t *testing.T, mode testMode) {
4375 if testing.Short() {
4376 t.Skip("skipping in short mode")
4377 }
4378
4379 timeout := 1 * time.Second
4380 trFunc := func(tr *Transport) {
4381 tr.MaxConnsPerHost = 1
4382 tr.MaxIdleConnsPerHost = 1
4383 tr.IdleConnTimeout = timeout
4384 }
4385 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4386 w.Header().Set("X-Addr", r.RemoteAddr)
4387 }), trFunc, optFakeNet)
4388
4389
4390
4391 makeRequest := func() string {
4392 resp, err := cst.c.Get(cst.ts.URL)
4393 if err != nil {
4394 t.Fatalf("got error: %s", err)
4395 }
4396 resp.Body.Close()
4397 return resp.Header.Get("X-Addr")
4398 }
4399
4400 addr1 := makeRequest()
4401
4402 time.Sleep(timeout / 2)
4403 synctest.Wait()
4404 addr2 := makeRequest()
4405 if addr1 != addr2 {
4406 t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
4407 }
4408
4409 time.Sleep(timeout)
4410 synctest.Wait()
4411 addr3 := makeRequest()
4412 if addr1 == addr3 {
4413 t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
4414 }
4415 }
4416
4417 func TestTransportRemovesConnsAfterBroken(t *testing.T) {
4418 runSynctest(t, testTransportRemovesConnsAfterBroken)
4419 }
4420 func testTransportRemovesConnsAfterBroken(t *testing.T, mode testMode) {
4421 if testing.Short() {
4422 t.Skip("skipping in short mode")
4423 }
4424
4425 trFunc := func(tr *Transport) {
4426 tr.MaxConnsPerHost = 1
4427 tr.MaxIdleConnsPerHost = 1
4428 }
4429 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4430 w.Header().Set("X-Addr", r.RemoteAddr)
4431 }), trFunc, optFakeNet)
4432 cst.li.trackConns = true
4433
4434
4435
4436 makeRequest := func() string {
4437 resp, err := cst.c.Get(cst.ts.URL)
4438 if err != nil {
4439 t.Fatalf("got error: %s", err)
4440 }
4441 resp.Body.Close()
4442 return resp.Header.Get("X-Addr")
4443 }
4444
4445 addr1 := makeRequest()
4446 addr2 := makeRequest()
4447 if addr1 != addr2 {
4448 t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
4449 }
4450
4451
4452 synctest.Wait()
4453 cst.li.conns[0].peer.Close()
4454 synctest.Wait()
4455 addr3 := makeRequest()
4456 if addr1 == addr3 {
4457 t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
4458 }
4459 }
4460
4461
4462
4463
4464
4465 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4466 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4467 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4468 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4469 t.Error("Transport advertised gzip support in the Accept header")
4470 }
4471 if r.Header.Get("Range") == "" {
4472 t.Error("no Range in request")
4473 }
4474 })).ts
4475 c := ts.Client()
4476
4477 req, _ := NewRequest("GET", ts.URL, nil)
4478 req.Header.Set("Range", "bytes=7-11")
4479 res, err := c.Do(req)
4480 if err != nil {
4481 t.Fatal(err)
4482 }
4483 res.Body.Close()
4484 }
4485
4486
4487 func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
4488 func testTransportResponseCancelRace(t *testing.T, mode testMode) {
4489 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4490
4491 var b [1024]byte
4492 w.Write(b[:])
4493 })).ts
4494 tr := ts.Client().Transport.(*Transport)
4495
4496 req, err := NewRequest("GET", ts.URL, nil)
4497 if err != nil {
4498 t.Fatal(err)
4499 }
4500 res, err := tr.RoundTrip(req)
4501 if err != nil {
4502 t.Fatal(err)
4503 }
4504
4505
4506
4507 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4508 t.Fatal(err)
4509 }
4510
4511 req2, err := NewRequest("GET", ts.URL, nil)
4512 if err != nil {
4513 t.Fatal(err)
4514 }
4515 tr.CancelRequest(req)
4516 res, err = tr.RoundTrip(req2)
4517 if err != nil {
4518 t.Fatal(err)
4519 }
4520 res.Body.Close()
4521 }
4522
4523
4524 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4525 run(t, testTransportContentEncodingCaseInsensitive)
4526 }
4527 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4528 for _, ce := range []string{"gzip", "GZIP"} {
4529 t.Run(ce, func(t *testing.T) {
4530 const encodedString = "Hello Gopher"
4531 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4532 w.Header().Set("Content-Encoding", ce)
4533 gz := gzip.NewWriter(w)
4534 gz.Write([]byte(encodedString))
4535 gz.Close()
4536 })).ts
4537
4538 res, err := ts.Client().Get(ts.URL)
4539 if err != nil {
4540 t.Fatal(err)
4541 }
4542
4543 body, err := io.ReadAll(res.Body)
4544 res.Body.Close()
4545 if err != nil {
4546 t.Fatal(err)
4547 }
4548
4549 if string(body) != encodedString {
4550 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4551 }
4552 })
4553 }
4554 }
4555
4556
4557 func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4558 run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4559 }
4560 func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4561 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4562 func(tr *Transport) {
4563 tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4564
4565 return &funcConn{
4566 read: func([]byte) (int, error) {
4567 return 0, errors.New("error")
4568 },
4569 write: func([]byte) (int, error) {
4570 return 0, errors.New("error")
4571 },
4572 }, nil
4573 }
4574 },
4575 ).ts
4576
4577
4578
4579
4580
4581 SetEnterRoundTripHook(func() {
4582 time.Sleep(1 * time.Millisecond)
4583 })
4584 defer SetEnterRoundTripHook(nil)
4585 var closes int
4586 _, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4587 if err == nil {
4588 t.Fatalf("expected request to fail, but it did not")
4589 }
4590 if closes != 1 {
4591 t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4592 }
4593 }
4594
4595
4596
4597
// logWritesConn is a fake net.Conn. Each Write is recorded in writes (under
// mu) and forwarded to w. Read uses a reader taken from rch on the first
// Read. Close is a no-op. The embedded Conn is not set by the visible
// caller, so calling any other net.Conn method would panic.
type logWritesConn struct {
	net.Conn

	w io.Writer // receives every Write

	rch <-chan io.Reader // supplies the reader for Read
	r   io.Reader        // reader received from rch, once available

	mu     sync.Mutex
	writes []string // payload of each Write, in order; guarded by mu
}

// Write records p and forwards it to c.w.
func (c *logWritesConn) Write(p []byte) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, string(p))
	return c.w.Write(p)
}

// Read reads from the reader delivered on rch, blocking for it on first use.
func (c *logWritesConn) Read(p []byte) (n int, err error) {
	src := c.r
	if src == nil {
		src = <-c.rch
		c.r = src
	}
	return src.Read(p)
}

// Close does nothing and always succeeds.
func (c *logWritesConn) Close() error {
	return nil
}
4625
4626
4627 func TestTransportFlushesBodyChunks(t *testing.T) {
4628 defer afterTest(t)
4629 resBody := make(chan io.Reader, 1)
4630 connr, connw := io.Pipe()
4631 lw := &logWritesConn{
4632 rch: resBody,
4633 w: connw,
4634 }
4635 tr := &Transport{
4636 Dial: func(network, addr string) (net.Conn, error) {
4637 return lw, nil
4638 },
4639 }
4640 bodyr, bodyw := io.Pipe()
4641 go func() {
4642 defer bodyw.Close()
4643 for i := 0; i < 3; i++ {
4644 fmt.Fprintf(bodyw, "num%d\n", i)
4645 }
4646 }()
4647 resc := make(chan *Response)
4648 go func() {
4649 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4650 req.Header.Set("User-Agent", "x")
4651 res, err := tr.RoundTrip(req)
4652 if err != nil {
4653 t.Errorf("RoundTrip: %v", err)
4654 close(resc)
4655 return
4656 }
4657 resc <- res
4658
4659 }()
4660
4661 req, err := ReadRequest(bufio.NewReader(connr))
4662 if err != nil {
4663 t.Fatal(err)
4664 }
4665 io.Copy(io.Discard, req.Body)
4666
4667
4668 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4669 res, ok := <-resc
4670 if !ok {
4671 return
4672 }
4673 defer res.Body.Close()
4674
4675 want := []string{
4676 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4677 "5\r\nnum0\n\r\n",
4678 "5\r\nnum1\n\r\n",
4679 "5\r\nnum2\n\r\n",
4680 "0\r\n\r\n",
4681 }
4682 if !slices.Equal(lw.writes, want) {
4683 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4684 }
4685 }
4686
4687
4688 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4689 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4690 gotReq := make(chan struct{})
4691 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4692 close(gotReq)
4693 }))
4694
4695 pr, pw := io.Pipe()
4696 req, err := NewRequest("POST", cst.ts.URL, pr)
4697 if err != nil {
4698 t.Fatal(err)
4699 }
4700 gotRes := make(chan struct{})
4701 go func() {
4702 defer close(gotRes)
4703 res, err := cst.tr.RoundTrip(req)
4704 if err != nil {
4705 t.Error(err)
4706 return
4707 }
4708 res.Body.Close()
4709 }()
4710
4711 <-gotReq
4712 pw.Close()
4713 <-gotRes
4714 }
4715
// wgReadCloser is an io.ReadCloser whose first Close marks it closed and
// calls wg.Done. Any later Close returns net.ErrClosed.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4730
4731
4732 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4733
4734 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4735 }
4736 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4737 if testing.Short() {
4738 t.Skip("skipping in short mode")
4739 }
4740
4741 runTimeSensitiveTest(t, []time.Duration{
4742 1 * time.Millisecond,
4743 5 * time.Millisecond,
4744 10 * time.Millisecond,
4745 50 * time.Millisecond,
4746 100 * time.Millisecond,
4747 500 * time.Millisecond,
4748 time.Second,
4749 5 * time.Second,
4750 }, func(t *testing.T, timeout time.Duration) error {
4751 SetRSTAvoidanceDelay(t, timeout)
4752 t.Logf("set RST avoidance delay to %v", timeout)
4753
4754 const contentLengthLimit = 1024 * 1024
4755 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4756 if r.ContentLength >= contentLengthLimit {
4757 w.WriteHeader(StatusBadRequest)
4758 r.Body.Close()
4759 return
4760 }
4761 w.WriteHeader(StatusOK)
4762 }))
4763
4764
4765 defer cst.close()
4766 ts := cst.ts
4767 c := ts.Client()
4768
4769 count := 100
4770
4771 bigBody := strings.Repeat("a", contentLengthLimit*2)
4772 var wg sync.WaitGroup
4773 defer wg.Wait()
4774 getBody := func() (io.ReadCloser, error) {
4775 wg.Add(1)
4776 body := &wgReadCloser{
4777 Reader: strings.NewReader(bigBody),
4778 wg: &wg,
4779 }
4780 return body, nil
4781 }
4782
4783 for i := 0; i < count; i++ {
4784 reqBody, _ := getBody()
4785 req, err := NewRequest("PUT", ts.URL, reqBody)
4786 if err != nil {
4787 reqBody.Close()
4788 t.Fatal(err)
4789 }
4790 req.ContentLength = int64(len(bigBody))
4791 req.GetBody = getBody
4792
4793 resp, err := c.Do(req)
4794 if err != nil {
4795 return fmt.Errorf("Do %d: %v", i, err)
4796 } else {
4797 resp.Body.Close()
4798 if resp.StatusCode != 400 {
4799 t.Errorf("Expected status code 400, got %v", resp.Status)
4800 }
4801 }
4802 }
4803 return nil
4804 })
4805 }
4806
4807 func TestTransportAutomaticHTTP2(t *testing.T) {
4808 testTransportAutoHTTP(t, &Transport{}, true)
4809 }
4810
4811 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4812 testTransportAutoHTTP(t, &Transport{
4813 ForceAttemptHTTP2: true,
4814 TLSClientConfig: new(tls.Config),
4815 }, true)
4816 }
4817
4818
4819 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4820 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4821 }
4822
4823 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4824 testTransportAutoHTTP(t, &Transport{
4825 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4826 }, false)
4827 }
4828
4829 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4830 testTransportAutoHTTP(t, &Transport{
4831 TLSClientConfig: new(tls.Config),
4832 }, false)
4833 }
4834
4835 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4836 testTransportAutoHTTP(t, &Transport{
4837 ExpectContinueTimeout: 1 * time.Second,
4838 }, true)
4839 }
4840
4841 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4842 var d net.Dialer
4843 testTransportAutoHTTP(t, &Transport{
4844 Dial: d.Dial,
4845 }, false)
4846 }
4847
4848 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4849 var d net.Dialer
4850 testTransportAutoHTTP(t, &Transport{
4851 DialContext: d.DialContext,
4852 }, false)
4853 }
4854
// TestTransportAutomaticHTTP2_DialTLS checks that a custom DialTLS function
// prevents automatic HTTP/2 registration. The dialer is never expected to be
// invoked, so it panics if it is.
func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
	neverDial := func(network, addr string) (net.Conn, error) {
		panic("unused")
	}
	testTransportAutoHTTP(t, &Transport{DialTLS: neverDial}, false)
}
4862
// testTransportAutoHTTP calls tr.RoundTrip with an empty (and therefore
// invalid) Request, which must fail. It then checks whether tr.TLSNextProto
// now holds a non-nil "h2" entry, comparing the result against wantH2. The
// RoundTrip call is presumably what makes the Transport perform its one-time
// HTTP/2 setup — confirm against Transport internals.
func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
	CondSkipHTTP2(t)
	if _, err := tr.RoundTrip(new(Request)); err == nil {
		t.Error("expected error from RoundTrip")
	}
	gotH2 := tr.TLSNextProto["h2"] != nil
	if gotH2 != wantH2 {
		t.Errorf("HTTP/2 registered = %v; want %v", gotH2, wantH2)
	}
}
4873
4874
4875
4876
4877
4878
4879
4880
// TestTransportReuseConnEmptyResponseBody runs the empty-body connection
// reuse check in every test mode.
func TestTransportReuseConnEmptyResponseBody(t *testing.T) { run(t, testTransportReuseConnEmptyResponseBody) }
// testTransportReuseConnEmptyResponseBody makes n GET requests to a handler
// that writes no body. The handler echoes the server-side RemoteAddr in the
// X-Addr header, and the test checks every request reported the same address,
// i.e. the client reused a single connection.
func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
		// Intentionally no body: the response body is empty.
	}))
	n := 100
	if testing.Short() {
		n = 10
	}
	var firstAddr string
	for i := 0; i < n; i++ {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			// Use t.Fatal rather than log.Fatal: log.Fatal exits the whole
			// test binary, skipping cleanup and failure attribution.
			t.Fatal(err)
		}
		addr := res.Header.Get("X-Addr")
		// Close before any Fatalf so the body is not leaked on failure.
		res.Body.Close()
		if i == 0 {
			firstAddr = addr
		} else if addr != firstAddr {
			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
		}
	}
}
4908
4909
// TestNoCrashReturningTransportAltConn drives a TLS dial that negotiates a
// protocol ("foo") registered in TLSNextProto, where the request is canceled
// during the dial and the conn is handed back only after Do has returned.
// It checks that:
//   - Do fails with a *url.Error wrapping ExportErrRequestCanceledConn;
//   - the "foo" TLSNextProto constructor is still invoked for the late conn;
//   - the RoundTripper it returns is never used;
//   - the pending dial finishes (wg.Wait).
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// wg tracks outstanding dials through the pending-dial test hooks
	// (presumably Add on dial start and Done on dial end — confirm in the
	// hook's definition), so the test can wait for the dial to complete.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	testDone := make(chan struct{})
	defer close(testDone)
	// Server side: accept one TLS conn advertising only "foo", complete the
	// handshake, and hold the conn open until the test finishes.
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			// Invoked for conns that negotiated "foo". The RoundTripper it
			// returns must never actually serve a request.
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// Plain dials are unexpected; the https request must use DialTLS.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		// DialTLS connects to the local server (ignoring the requested
		// address) and completes the "foo" handshake. It then cancels the
		// request and blocks until Do has returned before handing back the
		// conn, so the Transport receives a conn for an already-canceled
		// request.
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Let DialTLS return its conn, then wait for the "foo" RoundTripper to
	// be built for it and for the pending dial to be accounted for.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4991
4992 func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4993 run(t, func(t *testing.T, mode testMode) {
4994 testTransportReuseConnection_Gzip(t, mode, true)
4995 })
4996 }
4997
4998 func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4999 run(t, func(t *testing.T, mode testMode) {
5000 testTransportReuseConnection_Gzip(t, mode, false)
5001 })
5002 }
5003
5004
// testTransportReuseConnection_Gzip makes two GET requests to a handler that
// replies with the bytes of rgz under "Content-Encoding: gzip", flushing first
// when chunked is true. It checks that the server saw the same RemoteAddr for
// both requests, i.e. the client reused one connection.
func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
	// Receives the server-side RemoteAddr of each of the two requests.
	addr := make(chan string, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addr <- r.RemoteAddr
		w.Header().Set("Content-Encoding", "gzip")
		if chunked {
			w.(Flusher).Flush()
		}
		w.Write(rgz)
	})).ts
	c := ts.Client()

	// Log connection-pool events so a reuse failure is easier to diagnose.
	trace := &httptrace.ClientTrace{
		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)

	for i := 0; i < 2; i++ {
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(ctx)
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		// NOTE(review): reading exactly len(rgz) bytes assumes the body as
		// seen here (presumably decoded by the Transport) is at least that
		// long — confirm against the contents of rgz.
		buf := make([]byte, len(rgz))
		if n, err := io.ReadFull(res.Body, buf); err != nil {
			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
		}
		// Note: res.Body is not closed here. The reuse check below therefore
		// requires the connection to be reused after the body has been read
		// but not closed. Presumably intentional — confirm before adding a
		// Close.
	}
	a1, a2 := <-addr, <-addr
	if a1 != a2 {
		t.Fatalf("didn't reuse connection")
	}
}
5046
// TestTransportResponseHeaderLength runs the MaxResponseHeaderBytes check in
// every test mode.
func TestTransportResponseHeaderLength(t *testing.T) {
	run(t, testTransportResponseHeaderLength)
}
// testTransportResponseHeaderLength sets MaxResponseHeaderBytes to 512 KiB.
// A request with small headers must succeed, while a response carrying a
// 1 MiB header must fail with the "server response headers exceeded" error.
// The check is skipped in HTTP/2 mode.
func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
	}
	const longHeaderPath = "/long"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.URL.Path != longHeaderPath {
			return
		}
		w.Header().Set("Long", strings.Repeat("a", 1<<20))
	})).ts
	client := ts.Client()
	client.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10

	// Small headers: must be accepted.
	small, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	small.Body.Close()

	// 1 MiB header: must be rejected by the 512 KiB limit.
	res, err := client.Get(ts.URL + longHeaderPath)
	if err == nil {
		defer res.Body.Close()
		var total int64
		for name, values := range res.Header {
			for _, v := range values {
				total += int64(len(name)) + int64(len(v))
			}
		}
		t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, total)
	}
	if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
		t.Errorf("got error: %v; want %q", err, want)
	}
}
5081
// TestTransportEventTrace runs the client trace test with all hooks installed,
// serially (testNotParallel).
func TestTransportEventTrace(t *testing.T) {
	withHooks := func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, false)
	}
	run(t, withHooks, testNotParallel)
}
5087
5088
// TestTransportEventTrace_NoHooks runs the client trace test with an empty
// ClientTrace (no hooks set), serially (testNotParallel).
func TestTransportEventTrace_NoHooks(t *testing.T) {
	withoutHooks := func(t *testing.T, mode testMode) {
		testTransportEventTrace(t, mode, true)
	}
	run(t, withoutHooks, testNotParallel)
}
5094
// testTransportEventTrace issues a POST with "Expect: 100-continue" to
// dns-is-faked.golang, a name resolved by a fake resolver injected through the
// request context. Every httptrace.ClientTrace hook appends a line to a log,
// and the test checks which events were recorded.
//
// When noHooks is true, the trace is replaced with an empty ClientTrace and
// only the request's success is checked. Otherwise, a second GET through the
// same trace checks that GetConn is reported again.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Signaled by the WroteRequest hook. The POST handler waits on it (when
	// hooks are installed) before writing its response.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The second (GET) request needs no body and no synchronization.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one formatted line to buf under mu; hooks may be invoked
	// from goroutines other than the test's.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS: the resolver stored under LookupIPAltResolverKey maps
	// "dns-is-faked.golang" to the test server's IP; any other lookup fails
	// the test.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Reset every hook to nil to check that the transport tolerates a
		// ClientTrace with no callbacks set.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// Nothing was logged with an empty trace. Reaching this point shows
		// the request succeeded without hooks, which is all this mode checks.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce requires sub to occur exactly once in the log.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	// wantOnceOrMore requires sub to occur at least once in the log.
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// Second request through the same trace: GetConn must be reported
	// again, so "Getting conn for ..." appears twice in total.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}
}
5276
// TestTransportEventTraceTLSVerify runs the TLS-verification trace test in
// the TLS-enabled modes only.
func TestTransportEventTraceTLSVerify(t *testing.T) {
	tlsModes := []testMode{https1Mode, http2Mode}
	run(t, testTransportEventTraceTLSVerify, tlsModes)
}
// testTransportEventTraceTLSVerify sends a request whose TLS ServerName
// ("dns-is-faked.golang") does not match the test certificate. The request
// must fail. The TLSHandshakeStart and TLSHandshakeDone hooks must each fire
// exactly once, with the certificate verification error recorded. Server error
// log output goes into the same log so it appears in the failure dump.
func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
	var (
		logMu  sync.Mutex
		logBuf strings.Builder
	)
	logf := func(format string, args ...any) {
		logMu.Lock()
		defer logMu.Unlock()
		fmt.Fprintf(&logBuf, format, args...)
		logBuf.WriteByte('\n')
	}

	serverErrLog := log.New(funcWriter(func(p []byte) (int, error) {
		logf("%s", p)
		return len(p), nil
	}), "", 0)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Error("Unexpected request")
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = serverErrLog
	}).ts

	roots := x509.NewCertPool()
	roots.AddCert(ts.Certificate())
	client := &Client{Transport: &Transport{
		TLSClientConfig: &tls.Config{
			ServerName: "dns-is-faked.golang",
			RootCAs:    roots,
		},
	}}

	trace := &httptrace.ClientTrace{
		TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
			logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
		},
	}
	req, _ := NewRequest("GET", ts.URL, nil)
	traced := req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
	if _, err := client.Do(traced); err == nil {
		t.Error("Expected request to fail TLS verification")
	}

	logMu.Lock()
	got := logBuf.String()
	logMu.Unlock()

	for _, sub := range []string{
		"TLSHandshakeStart",
		"TLSHandshakeDone",
		"err = tls: failed to verify certificate: x509: certificate is valid for example.com",
	} {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}

	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}
}
5341
5342 var isDNSHijacked = sync.OnceValue(func() bool {
5343 addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
5344 return len(addrs) != 0
5345 })
5346
// skipIfDNSHijacked skips t when isDNSHijacked reports that the resolver
// answers for names that should not resolve. Tests that depend on a lookup
// failing cannot pass in such an environment.
func skipIfDNSHijacked(t *testing.T) {
	if !isDNSHijacked() {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
5355
// TestTransportEventTraceRealDNS requests a name that must not resolve, using
// the real resolver. It checks that the DNSStart and DNSDone hooks fire, with
// DNSDone reporting an empty address list, and that neither Connect hook
// fires.
func TestTransportEventTraceRealDNS(t *testing.T) {
	skipIfDNSHijacked(t)
	defer afterTest(t)
	tr := &Transport{}
	defer tr.CloseIdleConnections()
	client := &Client{Transport: tr}

	var (
		mu  sync.Mutex
		out strings.Builder
	)
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&out, format, args...)
		out.WriteByte('\n')
	}

	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
	ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
		DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
		DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
		ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
		ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
	})
	if resp, err := client.Do(req.WithContext(ctx)); err == nil {
		resp.Body.Close()
		t.Fatal("expected error during DNS lookup")
	}

	mu.Lock()
	got := out.String()
	mu.Unlock()

	for _, sub := range []string{
		"DNSStart: {Host:dns-should-not-resolve.golang}",
		"DNSDone: {Addrs:[] Err:",
	} {
		if !strings.Contains(got, sub) {
			t.Errorf("expected substring %q in output.", sub)
		}
	}
	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
		t.Errorf("should not see Connect events")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}
}
5405
5406
5407 func TestTransportRejectsAlphaPort(t *testing.T) {
5408 res, err := Get("http://dummy.tld:123foo/bar")
5409 if err == nil {
5410 res.Body.Close()
5411 t.Fatal("unexpected success")
5412 }
5413 ue, ok := err.(*url.Error)
5414 if !ok {
5415 t.Fatalf("got %#v; want *url.Error", err)
5416 }
5417 got := ue.Err.Error()
5418 want := `invalid port ":123foo" after host`
5419 if got != want {
5420 t.Errorf("got error %q; want %q", got, want)
5421 }
5422 }
5423
5424
5425
// TestTLSHandshakeTrace runs the TLS handshake trace test in the TLS-enabled
// modes only.
func TestTLSHandshakeTrace(t *testing.T) {
	tlsModes := []testMode{https1Mode, http2Mode}
	run(t, testTLSHandshakeTrace, tlsModes)
}
// testTLSHandshakeTrace makes one request over TLS and checks that the
// TLSHandshakeStart and TLSHandshakeDone hooks were both called, and that the
// handshake reported no error.
func testTLSHandshakeTrace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts

	// start and done are written by the hooks and read at the end; mu
	// guards them because the hooks may run on another goroutine.
	var mu sync.Mutex
	var start, done bool
	trace := &httptrace.ClientTrace{
		TLSHandshakeStart: func() {
			mu.Lock()
			defer mu.Unlock()
			start = true
		},
		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
			mu.Lock()
			defer mu.Unlock()
			done = true
			if err != nil {
				// This hook may run on a transport-internal goroutine, where
				// t.Fatal (FailNow) is not permitted; record with t.Error.
				t.Error("Expected error to be nil but was:", err)
			}
		},
	}

	c := ts.Client()
	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal("Unable to construct test request:", err)
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

	r, err := c.Do(req)
	if err != nil {
		t.Fatal("Unexpected error making request:", err)
	}
	r.Body.Close()

	// Snapshot the flags under the lock, then report both failures if any.
	mu.Lock()
	gotStart, gotDone := start, done
	mu.Unlock()
	if !gotStart {
		t.Error("Expected TLSHandshakeStart to be called, but wasn't")
	}
	if !gotDone {
		t.Error("Expected TLSHandshakeDone to be called, but wasn't")
	}
}
5471
// TestTransportMaxIdleConns runs the MaxIdleConns eviction test in HTTP/1
// mode only.
func TestTransportMaxIdleConns(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testTransportMaxIdleConns, modes)
}
// testTransportMaxIdleConns sets MaxIdleConns to 4 and requests distinct fake
// hostnames, all of which resolve to the test server. After hosts 0-3 the idle
// pool must hold one conn per host. After a fifth host, the pool must hold
// hosts 1-4; host-0's idle conn is the one no longer present.
func testTransportMaxIdleConns(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// No body needed; only connection pooling matters.
	})).ts
	client := ts.Client()
	tr := client.Transport.(*Transport)
	tr.MaxIdleConns = 4

	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	// Fake DNS: every hostname resolves to the test server's IP, so each
	// host gets its own idle-conn key while reaching the same server.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	// hostPort returns "host-<n>.dns-is-faked.golang:<port>".
	hostPort := func(n int) string {
		return fmt.Sprintf("host-%d.dns-is-faked.golang:%s", n, port)
	}
	hitHost := func(n int) {
		req, _ := NewRequest("GET", "http://"+hostPort(n), nil)
		res, err := client.Do(req.WithContext(ctx))
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
	}
	// idleKeys lists the expected idle-conn keys for hosts first..last.
	idleKeys := func(first, last int) []string {
		keys := make([]string, 0, last-first+1)
		for n := first; n <= last; n++ {
			keys = append(keys, "|http|"+hostPort(n))
		}
		return keys
	}

	for n := 0; n < 4; n++ {
		hitHost(n)
	}
	if got, want := tr.IdleConnKeysForTesting(), idleKeys(0, 3); !slices.Equal(got, want) {
		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
	}

	// A fifth host exceeds MaxIdleConns; host-0 must drop out of the pool.
	hitHost(4)
	if got, want := tr.IdleConnKeysForTesting(), idleKeys(1, 4); !slices.Equal(got, want) {
		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
	}
}
5525
// TestTransportIdleConnTimeout runs the IdleConnTimeout test in every test
// mode.
func TestTransportIdleConnTimeout(t *testing.T) {
	run(t, testTransportIdleConnTimeout)
}
// testTransportIdleConnTimeout checks two things:
//   - a single idle connection is reused across requests spaced at half of
//     IdleConnTimeout;
//   - that connection disappears from the idle pool once the timeout elapses.
//
// Scheduling delays can make a short timeout expire early. The test therefore
// starts at 1ms and, whenever a round observes the conn closed, missing, or
// replaced, doubles the timeout and retries with a fresh server.
func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	timeout := 1 * time.Millisecond
timeoutLoop:
	for {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// No body needed; only connection reuse matters.
		}))
		tr := cst.tr
		tr.IdleConnTimeout = timeout
		// NOTE(review): deferred inside the retry loop, so one
		// CloseIdleConnections call per attempt runs when the function
		// returns.
		defer tr.CloseIdleConnections()
		c := &Client{Transport: tr}

		// idleConns lists the transport's idle connections, using the
		// HTTP/2-specific accessor in http2Mode.
		idleConns := func() []string {
			if mode == http2Mode {
				return tr.IdleConnStrsForTesting_h2()
			} else {
				return tr.IdleConnStrsForTesting()
			}
		}

		// conn remembers the idle connection seen on the first request.
		var conn string
		// doReq issues request n and checks that exactly one idle conn exists
		// and that it matches conn. It returns false (meaning "timeout too
		// short, retry") if the conn was closed under the request, no idle
		// conn exists, or the idle conn changed.
		doReq := func(n int) (timeoutOk bool) {
			req, _ := NewRequest("GET", cst.ts.URL, nil)
			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
				PutIdleConn: func(err error) {
					if err != nil {
						t.Errorf("failed to keep idle conn: %v", err)
					}
				},
			}))
			res, err := c.Do(req)
			if err != nil {
				if strings.Contains(err.Error(), "use of closed network connection") {
					t.Logf("req %v: connection closed prematurely", n)
					return false
				}
			}
			if err == nil {
				res.Body.Close()
			}
			conns := idleConns()
			if len(conns) != 1 {
				if len(conns) == 0 {
					t.Logf("req %v: no idle conns", n)
					return false
				}
				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
			}
			if conn == "" {
				conn = conns[0]
			}
			if conn != conns[0] {
				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
				return false
			}
			return true
		}
		// Three requests separated by timeout/2; the idle conn should
		// survive each gap.
		for i := 0; i < 3; i++ {
			if !doReq(i) {
				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
				timeout *= 2
				cst.close()
				continue timeoutLoop
			}
			time.Sleep(timeout / 2)
		}

		// Now poll until the idle conn has been removed by the timeout.
		waitCondition(t, timeout/2, func(d time.Duration) bool {
			if got := idleConns(); len(got) != 0 {
				if d >= timeout*3/2 {
					t.Logf("after %v, idle conns = %q", d, got)
				}
				return false
			}
			return true
		})
		break
	}
}
5610
5611
5612
5613
5614
5615
5616
5617
5618
5619
5620
5621
// TestIdleConnH2Crash runs the late-arriving HTTP/2 conn test in HTTP/2 mode
// only.
func TestIdleConnH2Crash(t *testing.T) {
	modes := []testMode{http2Mode}
	run(t, testIdleConnH2Crash, modes)
}
// testIdleConnH2Crash sets IdleConnTimeout to 5ms and installs a custom
// DialTLS. After a successful "h2" handshake, DialTLS cancels the request's
// context, then waits until Do has returned its error before handing the conn
// back to the transport. The test then sleeps for 10x IdleConnTimeout.
// NOTE(review): there is no assertion after the sleep; presumably the test
// passes as long as nothing panics while the idle timer fires — confirm.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// No response body needed.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is signaled once Do has returned; DialTLS waits for it (or
	// for the test to end) before returning its conn.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// The handshake succeeded; cancel the request so Do fails while
		// this dial is still outstanding.
		cancel()

		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Allow well over IdleConnTimeout for the late conn's idle handling to
	// run before the test returns.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5672
5673 type funcConn struct {
5674 net.Conn
5675 read func([]byte) (int, error)
5676 write func([]byte) (int, error)
5677 }
5678
5679 func (c funcConn) Read(p []byte) (int, error) { return c.read(p) }
5680 func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
5681 func (c funcConn) Close() error { return nil }
5682
5683
5684
// TestTransportReturnsPeekError dials a fake conn whose reads block until
// something has been written and then fail with a specific error value. It
// checks that RoundTrip returns that exact error value.
func TestTransportReturnsPeekError(t *testing.T) {
	errValue := errors.New("specific error value")

	wrote := make(chan struct{})
	signalWrote := sync.OnceFunc(func() { close(wrote) })

	dial := func(network, addr string) (net.Conn, error) {
		return funcConn{
			read: func([]byte) (int, error) {
				// Fail reads only after the request has been written.
				<-wrote
				return 0, errValue
			},
			write: func(p []byte) (int, error) {
				signalWrote()
				return len(p), nil
			},
		}, nil
	}
	tr := &Transport{Dial: dial}
	if _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)); err != errValue {
		t.Errorf("error = %#v; want %v", err, errValue)
	}
}
5711
5712
// TestTransportIDNA runs the internationalized-hostname test in every test
// mode.
func TestTransportIDNA(t *testing.T) {
	run(t, testTransportIDNA)
}
// testTransportIDNA requests a Unicode hostname (uniDomain) and checks that
// its punycode form (punyDomain) is used everywhere the test can observe it:
//   - the Host header the handler receives;
//   - the TLS ServerName, in http2Mode;
//   - the GetConn hostPort;
//   - the DNSStart host;
//   - the name passed to the fake resolver.
func testTransportIDNA(t *testing.T, mode testMode) {
	const uniDomain = "гофер.го"
	const punyDomain = "xn--c1ae0ajs.xn--c1aw"

	// port is assigned by the net.SplitHostPort call below (that `:=` reuses
	// this variable because it is in the same scope) and read by the handler.
	var port string
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		want := punyDomain + ":" + port
		if r.Host != want {
			t.Errorf("Host header = %q; want %q", r.Host, want)
		}
		if mode == http2Mode {
			if r.TLS == nil {
				t.Errorf("r.TLS == nil")
			} else if r.TLS.ServerName != punyDomain {
				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
			}
		}
		// Marker proving the response came from this handler.
		w.Header().Set("Hit-Handler", "1")
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})

	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}

	// Fake DNS: only the punycode name may be looked up, and it resolves to
	// the test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != punyDomain {
			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
	trace := &httptrace.ClientTrace{
		GetConn: func(hostPort string) {
			want := net.JoinHostPort(punyDomain, port)
			if hostPort != want {
				t.Errorf("getting conn for %q; want %q", hostPort, want)
			}
		},
		DNSStart: func(e httptrace.DNSStartInfo) {
			if e.Host != punyDomain {
				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
			}
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	// Use the Transport directly (no Client), then verify the handler's
	// marker header is present, dumping the response otherwise.
	res, err := cst.tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.Header.Get("Hit-Handler") != "1" {
		out, err := httputil.DumpResponse(res, true)
		if err != nil {
			t.Fatal(err)
		}
		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
	}
}
5781
5782
// TestTransportProxyConnectHeader runs the ProxyConnectHeader test in HTTP/1
// mode only.
func TestTransportProxyConnectHeader(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testTransportProxyConnectHeader, modes)
}
// testTransportProxyConnectHeader routes an https request through the test
// server acting as a proxy. It checks that the CONNECT request carries the
// headers from Transport.ProxyConnectHeader. The proxy hijacks and closes the
// conn, so the request itself is expected to fail.
func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
	connectReqs := make(chan *Request, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		connectReqs <- r
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		conn.Close()
	})).ts

	client := ts.Client()
	tr := client.Transport.(*Transport)
	tr.Proxy = func(*Request) (*url.URL, error) {
		return url.Parse(ts.URL)
	}
	tr.ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}

	if res, err := client.Get("https://dummy.tld/"); err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}

	connectReq := <-connectReqs
	for _, tc := range []struct{ key, want string }{
		{"User-Agent", "foo"},
		{"Other", "bar"},
	} {
		if got := connectReq.Header.Get(tc.key); got != tc.want {
			t.Errorf("CONNECT request %s = %q; want %q", tc.key, got, tc.want)
		}
	}
}
5824
// TestTransportProxyGetConnectHeader runs the GetProxyConnectHeader test in
// HTTP/1 mode only.
func TestTransportProxyGetConnectHeader(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, testTransportProxyGetConnectHeader, modes)
}
// testTransportProxyGetConnectHeader sets both ProxyConnectHeader and
// GetProxyConnectHeader. It checks that the CONNECT request sent to the proxy
// carries the headers returned by GetProxyConnectHeader (foo2/bar2), not the
// static ones. The proxy hijacks and closes the conn, so the request itself is
// expected to fail.
func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
	connectReqs := make(chan *Request, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", r.Method)
		}
		connectReqs <- r
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		conn.Close()
	})).ts

	client := ts.Client()
	tr := client.Transport.(*Transport)
	tr.Proxy = func(*Request) (*url.URL, error) {
		return url.Parse(ts.URL)
	}
	tr.ProxyConnectHeader = Header{
		"User-Agent": {"foo"},
		"Other":      {"bar"},
	}
	tr.GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
		return Header{
			"User-Agent": {"foo2"},
			"Other":      {"bar2"},
		}, nil
	}

	if res, err := client.Get("https://dummy.tld/"); err == nil {
		res.Body.Close()
		t.Errorf("unexpected success")
	}

	connectReq := <-connectReqs
	for _, tc := range []struct{ key, want string }{
		{"User-Agent", "foo2"},
		{"Other", "bar2"},
	} {
		if got := connectReq.Header.Get(tc.key); got != tc.want {
			t.Errorf("CONNECT request %s = %q; want %q", tc.key, got, tc.want)
		}
	}
}
5873
// errFakeRoundTrip is the error every funcRoundTripper.RoundTrip call returns.
var errFakeRoundTrip error = errors.New("fake roundtrip")
5875
5876 type funcRoundTripper func()
5877
5878 func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
5879 fn()
5880 return nil, errFakeRoundTrip
5881 }
5882
// wantBody reads all of res.Body and checks that it equals want.
//
// If err is non-nil it is returned unchanged and res is not touched, which
// lets callers pass the results of Client.Do/Get directly. Otherwise it
// returns an error if reading fails, the body differs from want, or closing
// the body fails; nil on success. The body is closed on every path where res
// is used, so a failed check does not leak the response body.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		res.Body.Close()
		return fmt.Errorf("error reading body: %v", err)
	}
	if string(slurp) != want {
		res.Body.Close()
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if err := res.Body.Close(); err != nil {
		return fmt.Errorf("body Close = %v", err)
	}
	return nil
}
5899
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// works, the test fails with the last error.
func newLocalListener(t *testing.T) net.Listener {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
		t.Fatal(err)
	}
	return ln
}
5910
5911 type countCloseReader struct {
5912 n *int
5913 io.Reader
5914 }
5915
5916 func (cr countCloseReader) Close() error {
5917 (*cr.n)++
5918 return nil
5919 }
5920
5921
// rgz is a fixed gzip stream used as a response body sent with
// "Content-Encoding: gzip". It begins with the gzip magic bytes 0x1f 0x8b,
// and its header sets the FNAME flag; the name bytes spell "recursive".
// NOTE(review): the reuse test reads len(rgz) bytes from the response body,
// which suggests the stream decompresses to at least its own length —
// confirm before changing these bytes.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5956
5957
5958
5959 func TestMissingStatusNoPanic(t *testing.T) {
5960 t.Parallel()
5961
5962 const want = "unknown status code"
5963
5964 ln := newLocalListener(t)
5965 addr := ln.Addr().String()
5966 done := make(chan bool)
5967 fullAddrURL := fmt.Sprintf("http://%s", addr)
5968 raw := "HTTP/1.1 400\r\n" +
5969 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5970 "Content-Type: text/html; charset=utf-8\r\n" +
5971 "Content-Length: 10\r\n" +
5972 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5973 "Vary: Accept-Encoding\r\n\r\n" +
5974 "Aloha Olaa"
5975
5976 go func() {
5977 defer close(done)
5978
5979 conn, _ := ln.Accept()
5980 if conn != nil {
5981 io.WriteString(conn, raw)
5982 io.ReadAll(conn)
5983 conn.Close()
5984 }
5985 }()
5986
5987 proxyURL, err := url.Parse(fullAddrURL)
5988 if err != nil {
5989 t.Fatalf("proxyURL: %v", err)
5990 }
5991
5992 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5993
5994 req, _ := NewRequest("GET", "https://golang.org/", nil)
5995 res, err, panicked := doFetchCheckPanic(tr, req)
5996 if panicked {
5997 t.Error("panicked, expecting an error")
5998 }
5999 if res != nil && res.Body != nil {
6000 io.Copy(io.Discard, res.Body)
6001 res.Body.Close()
6002 }
6003
6004 if err == nil || !strings.Contains(err.Error(), want) {
6005 t.Errorf("got=%v want=%q", err, want)
6006 }
6007
6008 ln.Close()
6009 <-done
6010 }
6011
// doFetchCheckPanic runs tr.RoundTrip(req) and reports whether it panicked.
// A panic is recovered and turned into panicked == true; res and err then
// stay nil because the assignment never happens.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, panicked
}
6021
6022
6023
6024 func TestNoBodyOnChunked304Response(t *testing.T) {
6025 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
6026 }
6027 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
6028 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6029 conn, buf, _ := w.(Hijacker).Hijack()
6030 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
6031 buf.Flush()
6032 conn.Close()
6033 }))
6034
6035
6036
6037
6038
6039 cst.tr.DisableKeepAlives = true
6040
6041 res, err := cst.c.Get(cst.ts.URL)
6042 if err != nil {
6043 t.Fatal(err)
6044 }
6045
6046 if res.Body != NoBody {
6047 t.Errorf("Unexpected body on 304 response")
6048 }
6049 }
6050
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write hands p to the wrapped function and returns its results unchanged.
func (fw funcWriter) Write(p []byte) (n int, err error) {
	n, err = fw(p)
	return n, err
}
6054
// doneContext wraps a Context but always presents itself as finished: Done
// returns a channel that is already closed, and Err returns the stored err.
// Every other method is delegated to the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a newly made, already-closed channel on every call.
func (doneContext) Done() <-chan struct{} {
	ch := make(chan struct{})
	defer close(ch)
	return ch
}

// Err returns the error stored in d, whatever the embedded Context says.
func (d doneContext) Err() error {
	return d.err
}
6067
6068
6069 func TestTransportCheckContextDoneEarly(t *testing.T) {
6070 tr := &Transport{}
6071 req, _ := NewRequest("GET", "http://fake.example/", nil)
6072 wantErr := errors.New("some error")
6073 req = req.WithContext(doneContext{context.Background(), wantErr})
6074 _, err := tr.RoundTrip(req)
6075 if err != wantErr {
6076 t.Errorf("error = %v; want %v", err, wantErr)
6077 }
6078 }
6079
6080
6081
6082
6083
6084
// TestClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// fires before any response headers are written, the handler sees its
// request context end and then reads EOF from the hijacked connection.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// The client timeout starts tiny and doubles on each retry (see the
	// tooSlow case below) until the handler is actually observed.
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Never respond; wait until the request's context is done.
			<-r.Context().Done()

			// Either the test abandoned this attempt, or tell the test
			// that the handler was reached.
			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Take over the raw connection and expect a read to return
			// (0, io.EOF). NOTE(review): presumably because the client
			// closed the connection when its timeout fired.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		// Give the handler 10x the timeout to report in.
		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler was not observed in time; abandon this server
			// (releasing any pending handler via cancelHandler) and retry
			// with a doubled timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			<-handlerDone
		}
		break
	}
}
6143
6144
6145
6146
6147
6148
// TestClientTimeoutKillsConn_AfterHeaders checks that canceling a request
// after its response headers have arrived, with the body still incomplete,
// makes the client's body read fail. It also checks that the handler's
// later read from the hijacked connection returns an error.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise a 100-byte body and flush the headers right away, so
		// the client's Do can return before any body bytes are sent.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		// Block until the test receives from inHandler (it does so only
		// after its body read below) or abandons the handler.
		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Write only 3 of the promised 100 body bytes.
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// Any error is accepted here, not only io.EOF.
		// NOTE(review): presumably the exact error seen after the client
		// drops the connection varies (e.g. EOF vs. a reset); confirm.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// Use a long but non-zero Client.Timeout. Rather than waiting for it to
	// fire, the test cancels the request through req.Cancel below.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Headers are in. Cancel the request; reading the rest of the body
	// must then fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Let the handler proceed, then wait for its checks to finish.
	<-inHandler
	<-handlerDone
}
6214
6215 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
6216 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
6217 }
6218 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
6219 done := make(chan struct{})
6220 defer close(done)
6221 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6222 conn, _, err := w.(Hijacker).Hijack()
6223 if err != nil {
6224 t.Error(err)
6225 return
6226 }
6227 defer conn.Close()
6228 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
6229 bs := bufio.NewScanner(conn)
6230 bs.Scan()
6231 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
6232 <-done
6233 }))
6234
6235 req, _ := NewRequest("GET", cst.ts.URL, nil)
6236 req.Header.Set("Upgrade", "foo")
6237 req.Header.Set("Connection", "upgrade")
6238 res, err := cst.c.Do(req)
6239 if err != nil {
6240 t.Fatal(err)
6241 }
6242 if res.StatusCode != 101 {
6243 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
6244 }
6245 rwc, ok := res.Body.(io.ReadWriteCloser)
6246 if !ok {
6247 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
6248 }
6249 defer rwc.Close()
6250 bs := bufio.NewScanner(rwc)
6251 if !bs.Scan() {
6252 t.Fatalf("expected readable input")
6253 }
6254 if got, want := bs.Text(), "Some buffered data"; got != want {
6255 t.Errorf("read %q; want %q", got, want)
6256 }
6257 io.WriteString(rwc, "echo\n")
6258 if !bs.Scan() {
6259 t.Fatalf("expected another line")
6260 }
6261 if got, want := bs.Text(), "ECHO"; got != want {
6262 t.Errorf("read %q; want %q", got, want)
6263 }
6264 }
6265
6266 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
6267 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6268 const target = "backend:443"
6269 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6270 if r.Method != "CONNECT" {
6271 t.Errorf("unexpected method %q", r.Method)
6272 w.WriteHeader(500)
6273 return
6274 }
6275 if r.RequestURI != target {
6276 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6277 w.WriteHeader(500)
6278 return
6279 }
6280 nc, brw, err := w.(Hijacker).Hijack()
6281 if err != nil {
6282 t.Error(err)
6283 return
6284 }
6285 defer nc.Close()
6286 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6287
6288 for {
6289 line, err := brw.ReadString('\n')
6290 if err != nil {
6291 if err != io.EOF {
6292 t.Error(err)
6293 }
6294 return
6295 }
6296 io.WriteString(brw, strings.ToUpper(line))
6297 brw.Flush()
6298 }
6299 }))
6300 pr, pw := io.Pipe()
6301 defer pw.Close()
6302 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6303 if err != nil {
6304 t.Fatal(err)
6305 }
6306 req.URL.Opaque = target
6307 res, err := cst.c.Do(req)
6308 if err != nil {
6309 t.Fatal(err)
6310 }
6311 defer res.Body.Close()
6312 if res.StatusCode != 200 {
6313 t.Fatalf("status code = %d; want 200", res.StatusCode)
6314 }
6315 br := bufio.NewReader(res.Body)
6316 for _, str := range []string{"foo", "bar", "baz"} {
6317 fmt.Fprintf(pw, "%s\n", str)
6318 got, err := br.ReadString('\n')
6319 if err != nil {
6320 t.Fatal(err)
6321 }
6322 got = strings.TrimSpace(got)
6323 want := strings.ToUpper(str)
6324 if got != want {
6325 t.Fatalf("got %q; want %q", got, want)
6326 }
6327 }
6328 }
6329
6330 func TestTransportRequestReplayable(t *testing.T) {
6331 someBody := io.NopCloser(strings.NewReader(""))
6332 tests := []struct {
6333 name string
6334 req *Request
6335 want bool
6336 }{
6337 {
6338 name: "GET",
6339 req: &Request{Method: "GET"},
6340 want: true,
6341 },
6342 {
6343 name: "GET_http.NoBody",
6344 req: &Request{Method: "GET", Body: NoBody},
6345 want: true,
6346 },
6347 {
6348 name: "GET_body",
6349 req: &Request{Method: "GET", Body: someBody},
6350 want: false,
6351 },
6352 {
6353 name: "POST",
6354 req: &Request{Method: "POST"},
6355 want: false,
6356 },
6357 {
6358 name: "POST_idempotency-key",
6359 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6360 want: true,
6361 },
6362 {
6363 name: "POST_x-idempotency-key",
6364 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6365 want: true,
6366 },
6367 {
6368 name: "POST_body",
6369 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6370 want: false,
6371 },
6372 }
6373 for _, tt := range tests {
6374 t.Run(tt.name, func(t *testing.T) {
6375 got := tt.req.ExportIsReplayable()
6376 if got != tt.want {
6377 t.Errorf("replyable = %v; want %v", got, tt.want)
6378 }
6379 })
6380 }
6381 }
6382
6383
6384
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method was invoked. Tests use it to detect the io.ReaderFrom path.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool
}

// ReadFrom sets ReadFromCalled and then delegates to the embedded TCPConn.
func (c *testMockTCPConn) ReadFrom(r io.Reader) (n int64, err error) {
	c.ReadFromCalled = true
	n, err = c.TCPConn.ReadFrom(r)
	return n, err
}
6395
6396 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6397 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6398 nBytes := int64(1 << 10)
6399 newFileFunc := func() (r io.Reader, done func(), err error) {
6400 f, err := os.CreateTemp("", "net-http-newfilefunc")
6401 if err != nil {
6402 return nil, nil, err
6403 }
6404
6405
6406 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6407 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6408 }
6409 if _, err := f.Seek(0, 0); err != nil {
6410 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6411 }
6412
6413 done = func() {
6414 f.Close()
6415 os.Remove(f.Name())
6416 }
6417
6418 return f, done, nil
6419 }
6420
6421 newBufferFunc := func() (io.Reader, func(), error) {
6422 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6423 }
6424
6425 cases := []struct {
6426 name string
6427 readerFunc func() (io.Reader, func(), error)
6428 contentLength int64
6429 expectedReadFrom bool
6430 }{
6431 {
6432 name: "file, length",
6433 readerFunc: newFileFunc,
6434 contentLength: nBytes,
6435 expectedReadFrom: true,
6436 },
6437 {
6438 name: "file, no length",
6439 readerFunc: newFileFunc,
6440 },
6441 {
6442 name: "file, negative length",
6443 readerFunc: newFileFunc,
6444 contentLength: -1,
6445 },
6446 {
6447 name: "buffer",
6448 contentLength: nBytes,
6449 readerFunc: newBufferFunc,
6450 },
6451 {
6452 name: "buffer, no length",
6453 readerFunc: newBufferFunc,
6454 },
6455 {
6456 name: "buffer, length -1",
6457 contentLength: -1,
6458 readerFunc: newBufferFunc,
6459 },
6460 }
6461
6462 for _, tc := range cases {
6463 t.Run(tc.name, func(t *testing.T) {
6464 r, cleanup, err := tc.readerFunc()
6465 if err != nil {
6466 t.Fatal(err)
6467 }
6468 defer cleanup()
6469
6470 tConn := &testMockTCPConn{}
6471 trFunc := func(tr *Transport) {
6472 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6473 var d net.Dialer
6474 conn, err := d.DialContext(ctx, network, addr)
6475 if err != nil {
6476 return nil, err
6477 }
6478
6479 tcpConn, ok := conn.(*net.TCPConn)
6480 if !ok {
6481 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6482 }
6483
6484 tConn.TCPConn = tcpConn
6485 return tConn, nil
6486 }
6487 }
6488
6489 cst := newClientServerTest(
6490 t,
6491 mode,
6492 HandlerFunc(func(w ResponseWriter, r *Request) {
6493 io.Copy(io.Discard, r.Body)
6494 r.Body.Close()
6495 w.WriteHeader(200)
6496 }),
6497 trFunc,
6498 )
6499
6500 req, err := NewRequest("PUT", cst.ts.URL, r)
6501 if err != nil {
6502 t.Fatal(err)
6503 }
6504 req.ContentLength = tc.contentLength
6505 req.Header.Set("Content-Type", "application/octet-stream")
6506 resp, err := cst.c.Do(req)
6507 if err != nil {
6508 t.Fatal(err)
6509 }
6510 defer resp.Body.Close()
6511 if resp.StatusCode != 200 {
6512 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6513 }
6514
6515 expectedReadFrom := tc.expectedReadFrom
6516 if mode != http1Mode {
6517 expectedReadFrom = false
6518 }
6519 if !tConn.ReadFromCalled && expectedReadFrom {
6520 t.Fatalf("did not call ReadFrom")
6521 }
6522
6523 if tConn.ReadFromCalled && !expectedReadFrom {
6524 t.Fatalf("ReadFrom was unexpectedly invoked")
6525 }
6526 })
6527 }
6528 }
6529
// TestTransportClone checks that Transport.Clone copies every exported
// field. A Transport with all exported fields set to non-zero values is
// cloned, and each exported field of the copy must be non-zero. It also
// checks that TLSNextProto entries survive the clone, and that cloning a
// zero Transport leaves TLSNextProto nil.
func TestTransportClone(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		HTTP2:                  &HTTP2Config{MaxConcurrentStreams: 1},
		Protocols:              &Protocols{},
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	tr.Protocols.SetHTTP1(true)
	tr.Protocols.SetHTTP2(true)
	tr2 := tr.Clone()

	cloned := reflect.ValueOf(tr2).Elem()
	typ := cloned.Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		name := typ.Field(i).Name
		if token.IsExported(name) && cloned.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", name)
		}
	}

	if _, found := tr2.TLSNextProto["foo"]; !found {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// A zero Transport must clone to one whose TLSNextProto is still nil.
	tr = new(Transport)
	tr2 = tr.Clone()
	if tr2.TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6588
6589 func TestIs408(t *testing.T) {
6590 tests := []struct {
6591 in string
6592 want bool
6593 }{
6594 {"HTTP/1.0 408", true},
6595 {"HTTP/1.1 408", true},
6596 {"HTTP/1.8 408", true},
6597 {"HTTP/2.0 408", false},
6598 {"HTTP/1.1 408 ", true},
6599 {"HTTP/1.1 40", false},
6600 {"http/1.0 408", false},
6601 {"HTTP/1-1 408", false},
6602 }
6603 for _, tt := range tests {
6604 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6605 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6606 }
6607 }
6608 }
6609
6610 func TestTransportIgnores408(t *testing.T) {
6611 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6612 }
6613 func testTransportIgnores408(t *testing.T, mode testMode) {
6614
6615 defer log.SetOutput(log.Writer())
6616
6617 var logout strings.Builder
6618 log.SetOutput(&logout)
6619
6620 const target = "backend:443"
6621
6622 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6623 nc, _, err := w.(Hijacker).Hijack()
6624 if err != nil {
6625 t.Error(err)
6626 return
6627 }
6628 defer nc.Close()
6629 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6630 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6631 }))
6632 req, err := NewRequest("GET", cst.ts.URL, nil)
6633 if err != nil {
6634 t.Fatal(err)
6635 }
6636 res, err := cst.c.Do(req)
6637 if err != nil {
6638 t.Fatal(err)
6639 }
6640 slurp, err := io.ReadAll(res.Body)
6641 if err != nil {
6642 t.Fatal(err)
6643 }
6644 if err != nil {
6645 t.Fatal(err)
6646 }
6647 if string(slurp) != "ok" {
6648 t.Fatalf("got %q; want ok", slurp)
6649 }
6650
6651 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6652 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6653 if d > 0 {
6654 t.Logf("%v idle conns still present after %v", n, d)
6655 }
6656 return false
6657 }
6658 return true
6659 })
6660 if got := logout.String(); got != "" {
6661 t.Fatalf("expected no log output; got: %s", got)
6662 }
6663 }
6664
6665 func TestInvalidHeaderResponse(t *testing.T) {
6666 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6667 }
6668 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6669 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6670 conn, buf, _ := w.(Hijacker).Hijack()
6671 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6672 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6673 "Content-Type: text/html; charset=utf-8\r\n" +
6674 "Content-Length: 0\r\n" +
6675 "Foo : bar\r\n\r\n"))
6676 buf.Flush()
6677 conn.Close()
6678 }))
6679 res, err := cst.c.Get(cst.ts.URL)
6680 if err != nil {
6681 t.Fatal(err)
6682 }
6683 defer res.Body.Close()
6684 if v := res.Header.Get("Foo"); v != "" {
6685 t.Errorf(`unexpected "Foo" header: %q`, v)
6686 }
6687 if v := res.Header.Get("Foo "); v != "bar" {
6688 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6689 }
6690 }
6691
// bodyCloser is an always-empty io.ReadCloser. Its own boolean value
// records whether Close has been called.
type bodyCloser bool

// Close marks bc as closed and never fails.
func (bc *bodyCloser) Close() error {
	*bc = bodyCloser(true)
	return nil
}

// Read always reports end of stream without producing data.
func (bc *bodyCloser) Read(b []byte) (n int, err error) {
	err = io.EOF
	return n, err
}
6701
6702
6703
6704 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6705 run(t, testTransportClosesBodyOnInvalidRequests)
6706 }
6707 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6708 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6709 t.Errorf("Should not have been invoked")
6710 })).ts
6711
6712 u, _ := url.Parse(cst.URL)
6713
6714 tests := []struct {
6715 name string
6716 req *Request
6717 wantErr string
6718 }{
6719 {
6720 name: "invalid method",
6721 req: &Request{
6722 Method: " ",
6723 URL: u,
6724 },
6725 wantErr: `invalid method " "`,
6726 },
6727 {
6728 name: "nil URL",
6729 req: &Request{
6730 Method: "GET",
6731 },
6732 wantErr: `nil Request.URL`,
6733 },
6734 {
6735 name: "invalid header key",
6736 req: &Request{
6737 Method: "GET",
6738 Header: Header{"💡": {"emoji"}},
6739 URL: u,
6740 },
6741 wantErr: `invalid header field name "💡"`,
6742 },
6743 {
6744 name: "invalid header value",
6745 req: &Request{
6746 Method: "POST",
6747 Header: Header{"key": {"\x19"}},
6748 URL: u,
6749 },
6750 wantErr: `invalid header field value for "key"`,
6751 },
6752 {
6753 name: "non HTTP(s) scheme",
6754 req: &Request{
6755 Method: "POST",
6756 URL: &url.URL{Scheme: "faux"},
6757 },
6758 wantErr: `unsupported protocol scheme "faux"`,
6759 },
6760 {
6761 name: "no Host in URL",
6762 req: &Request{
6763 Method: "POST",
6764 URL: &url.URL{Scheme: "http"},
6765 },
6766 wantErr: `no Host in request URL`,
6767 },
6768 }
6769
6770 for _, tt := range tests {
6771 t.Run(tt.name, func(t *testing.T) {
6772 var bc bodyCloser
6773 req := tt.req
6774 req.Body = &bc
6775 _, err := cst.Client().Do(tt.req)
6776 if err == nil {
6777 t.Fatal("Expected an error")
6778 }
6779 if !bc {
6780 t.Fatal("Expected body to have been closed")
6781 }
6782 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6783 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6784 }
6785 })
6786 }
6787 }
6788
6789
6790
// breakableConn wraps a net.Conn so that its writes can be made to fail on
// demand through a brokenState, which may be shared by several connections.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded flag. While broken is true, every
// breakableConn that shares it rejects writes.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write fails with "some write error" while the shared state is broken and
// otherwise forwards b to the wrapped connection. The state's lock is held
// for the whole underlying write.
func (w *breakableConn) Write(b []byte) (n int, err error) {
	w.Lock()
	defer w.Unlock()
	if !w.broken {
		return w.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6809
6810
6811 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6812 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6813 }
6814 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6815 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6816
6817 var brokenState brokenState
6818
6819 const numReqs = 5
6820 var numDials, gotConns uint32
6821
6822 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6823 atomic.AddUint32(&numDials, 1)
6824 c, err := net.Dial(netw, addr)
6825 if err != nil {
6826 t.Errorf("unexpected Dial error: %v", err)
6827 return nil, err
6828 }
6829 return &breakableConn{c, &brokenState}, err
6830 }
6831
6832 for i := 1; i <= numReqs; i++ {
6833 brokenState.Lock()
6834 brokenState.broken = false
6835 brokenState.Unlock()
6836
6837
6838
6839
6840 doBreak := i != numReqs
6841
6842 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6843 GotConn: func(info httptrace.GotConnInfo) {
6844 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6845 atomic.AddUint32(&gotConns, 1)
6846 },
6847 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6848 brokenState.Lock()
6849 defer brokenState.Unlock()
6850 if doBreak {
6851 brokenState.broken = true
6852 }
6853 },
6854 })
6855 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6856 if err != nil {
6857 t.Fatal(err)
6858 }
6859 _, err = cst.c.Do(req)
6860 if doBreak != (err != nil) {
6861 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6862 }
6863 }
6864 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6865 t.Errorf("GotConn calls = %v; want %v", got, want)
6866 }
6867 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6868 t.Errorf("Dials = %v; want %v", got, want)
6869 }
6870 }
6871
6872
6873
6874
6875
6876 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6877 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6878 }
6879 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6880 CondSkipHTTP2(t)
6881
6882 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6883 _, err := w.Write([]byte("foo"))
6884 if err != nil {
6885 t.Fatalf("Write: %v", err)
6886 }
6887 })
6888
6889 ts := newClientServerTest(t, mode, h).ts
6890
6891 c := ts.Client()
6892 tr := c.Transport.(*Transport)
6893 tr.MaxConnsPerHost = 1
6894
6895 errCh := make(chan error, 300)
6896 doReq := func() {
6897 resp, err := c.Get(ts.URL)
6898 if err != nil {
6899 errCh <- fmt.Errorf("request failed: %v", err)
6900 return
6901 }
6902 defer resp.Body.Close()
6903 _, err = io.ReadAll(resp.Body)
6904 if err != nil {
6905 errCh <- fmt.Errorf("read body failed: %v", err)
6906 }
6907 }
6908
6909 var wg sync.WaitGroup
6910 for i := 0; i < 300; i++ {
6911 wg.Add(1)
6912 go func() {
6913 defer wg.Done()
6914 doReq()
6915 }()
6916 }
6917 wg.Wait()
6918 close(errCh)
6919
6920 for err := range errCh {
6921 t.Errorf("error occurred: %v", err)
6922 }
6923 }
6924
6925
6926
6927
6928 func TestAltProtoCancellation(t *testing.T) {
6929 defer afterTest(t)
6930 tr := &Transport{}
6931 c := &Client{
6932 Transport: tr,
6933 Timeout: time.Millisecond,
6934 }
6935 tr.RegisterProtocol("cancel", cancelProto{})
6936 _, err := c.Get("cancel://bar.com/path")
6937 if err == nil {
6938 t.Error("request unexpectedly succeeded")
6939 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6940 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6941 }
6942 }
6943
// errCancelProto is what cancelProto.RoundTrip returns once its request is canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never completes a request by itself.
// It waits on the request's Cancel channel and then fails.
type cancelProto struct{}

// RoundTrip blocks until a receive from req.Cancel succeeds (a value is sent
// or the channel is closed), then returns errCancelProto. With a nil Cancel
// channel it blocks forever.
func (p cancelProto) RoundTrip(req *Request) (res *Response, err error) {
	cancel := req.Cancel
	<-cancel
	err = errCancelProto
	return res, err
}
6952
// roundTripFunc lets an ordinary function serve as a RoundTripper.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip calls the wrapped function with r and returns its results.
func (fn roundTripFunc) RoundTrip(r *Request) (*Response, error) {
	res, err := fn(r)
	return res, err
}
6956
6957
6958 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6959 func testIssue32441(t *testing.T, mode testMode) {
6960 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6961 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6962 t.Error("body length is zero")
6963 }
6964 })).ts
6965 c := ts.Client()
6966 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6967
6968 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6969 t.Error("body length is zero during round trip")
6970 }
6971 return nil, ErrSkipAltProtocol
6972 }))
6973 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6974 t.Error(err)
6975 }
6976 }
6977
6978
6979
6980 func TestTransportRejectsSignInContentLength(t *testing.T) {
6981 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6982 }
6983 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6984 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6985 w.Header().Set("Content-Length", "+3")
6986 w.Write([]byte("abc"))
6987 })).ts
6988
6989 c := cst.Client()
6990 res, err := c.Get(cst.URL)
6991 if err == nil || res != nil {
6992 t.Fatal("Expected a non-nil error and a nil http.Response")
6993 }
6994 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6995 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6996 }
6997 }
6998
6999
// dumpConn is a minimal net.Conn assembled from a separate Writer and Reader.
// Close and the deadline setters are no-ops that report success, and both
// addresses are nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error { return nil }

func (*dumpConn) LocalAddr() net.Addr { return nil }

func (*dumpConn) RemoteAddr() net.Addr { return nil }

func (*dumpConn) SetDeadline(_ time.Time) error { return nil }

func (*dumpConn) SetReadDeadline(_ time.Time) error { return nil }

func (*dumpConn) SetWriteDeadline(_ time.Time) error { return nil }
7011
7012
7013
// delegateReader picks its data source lazily. The first Read blocks until
// a reader arrives on c, and that reader then serves this and all later
// reads. If c is closed before a reader is delivered, Read fails.
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

// Read obtains the delegate from r.c on first use and forwards p to it.
func (r *delegateReader) Read(p []byte) (int, error) {
	if r.r != nil {
		return r.r.Read(p)
	}
	src, ok := <-r.c
	if !ok {
		return 0, errors.New("delegate closed")
	}
	r.r = src
	return src.Read(p)
}
7028
// testTransportRace performs one RoundTrip of req over a fake in-memory
// connection. Everything the Transport writes goes into a pipe that a
// helper goroutine reads. The response is handed to the Transport (through
// a delegateReader) only after that goroutine has parsed the request and
// drained its body. The request's original Body is restored before
// returning.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})
	// Fake server: read the request the Transport writes into the pipe.
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Consume the whole request body before offering a response.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// The caller is already waiting in <-quitReadCh after RoundTrip,
			// so no response will be read. Close the delegate so any pending
			// or later Read fails with "delegate closed".
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Closing the pipe's writer makes pending reads on pr return EOF, which
	// unblocks ReadRequest above if the request was not fully written
	// (e.g. the request context expired first).
	pw.Close()
	<-quitReadCh

	req.Body = save
}
7072
7073
7074
7075
7076
7077 func TestErrorWriteLoopRace(t *testing.T) {
7078 if testing.Short() {
7079 return
7080 }
7081 t.Parallel()
7082 for i := 0; i < 1000; i++ {
7083 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
7084 ctx, cancel := context.WithTimeout(context.Background(), delay)
7085 defer cancel()
7086
7087 r := bytes.NewBuffer(make([]byte, 10000))
7088 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
7089 if err != nil {
7090 t.Fatal(err)
7091 }
7092
7093 testTransportRace(req)
7094 }
7095 }
7096
7097
7098
7099
// TestCancelRequestWhenSharingConnection uses MaxConnsPerHost = 1 and checks
// that a second request can be canceled while the first request's
// PutIdleConn trace hook is still blocked. Do must then return an error
// matching context.Canceled for that second request.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation publishes a channel on reqc and blocks until
	// the test closes (or sends on) it.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Hand the test a channel and block in this hook until the
				// test closes it. NOTE(review): this presumably stalls the
				// connection's read loop while the conn sits in the idle pool.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Let request 1's handler return, then wait until request 1 reaches
	// its PutIdleConn hook (failing if request 1 errors first).
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Request 2 has returned; now release request 1's blocked
		// PutIdleConn hook (this also signals the test below).
		close(idlec)
	}()

	// Once request 2 reaches its handler, cancel it and wait (via idlec)
	// for its Do to return before letting the handler finish.
	r2c := <-reqc
	cancel()

	<-idlec

	close(r2c)
	wg.Wait()
}
7185
7186 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
7187 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
7188 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7189 go io.Copy(io.Discard, req.Body)
7190 panic(ErrAbortHandler)
7191 })).ts
7192
7193 var wg sync.WaitGroup
7194 for i := 0; i < 2; i++ {
7195 wg.Add(1)
7196 go func() {
7197 defer wg.Done()
7198 for j := 0; j < 10; j++ {
7199 const reqLen = 6 * 1024 * 1024
7200 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
7201 req.ContentLength = reqLen
7202 resp, _ := ts.Client().Transport.RoundTrip(req)
7203 if resp != nil {
7204 resp.Body.Close()
7205 }
7206 }
7207 }()
7208 }
7209 wg.Wait()
7210 }
7211
7212 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
7213 func testRequestSanitization(t *testing.T, mode testMode) {
7214 if mode == http2Mode {
7215
7216 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
7217 }
7218 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7219 if h, ok := req.Header["X-Evil"]; ok {
7220 t.Errorf("request has X-Evil header: %q", h)
7221 }
7222 })).ts
7223 req, _ := NewRequest("GET", ts.URL, nil)
7224 req.Host = "go.dev\r\nX-Evil:evil"
7225 resp, _ := ts.Client().Do(req)
7226 if resp != nil {
7227 resp.Body.Close()
7228 }
7229 }
7230
7231 func TestProxyAuthHeader(t *testing.T) {
7232
7233 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
7234 }
7235 func testProxyAuthHeader(t *testing.T, mode testMode) {
7236 const username = "u"
7237 const password = "@/?!"
7238 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7239
7240
7241 var r2 Request
7242 r2.Header = Header{
7243 "Authorization": req.Header["Proxy-Authorization"],
7244 }
7245 gotuser, gotpass, ok := r2.BasicAuth()
7246 if !ok || gotuser != username || gotpass != password {
7247 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
7248 }
7249 }))
7250 u, err := url.Parse(cst.ts.URL)
7251 if err != nil {
7252 t.Fatal(err)
7253 }
7254 u.User = url.UserPassword(username, password)
7255 t.Setenv("HTTP_PROXY", u.String())
7256 cst.tr.Proxy = ProxyURL(u)
7257 resp, err := cst.c.Get("http://_/")
7258 if err != nil {
7259 t.Fatal(err)
7260 }
7261 resp.Body.Close()
7262 }
7263
7264
// TestTransportReqCancelerCleanupOnRequestBodyWriteError sets up the
// following sequence:
//   - A fake server reads one byte of a very large request body and replies
//     immediately.
//   - The server then closes the connection while the client is still
//     writing the body.
//
// The test checks that, once the response body is closed, the Transport
// reports no pending requests.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	// done tells the fake server when it may close the connection.
	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}
		// Wait for the first request byte before responding, so the response
		// only arrives after the client has started sending the request.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		// Reply without consuming the rest of the request body.
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// Install a test hook that signals on the unbuffered didRead channel.
	// Judging by its name, it fires before the Transport's read loop starts
	// its next read (TODO confirm against the hook's definition).
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// A 1<<30-byte body that the server never fully reads. Writing it is
	// expected to fail once the server closes the connection.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	// Let the server close the connection mid-body-write.
	close(done)

	// Wait for the read-loop hook to fire before closing the response body,
	// so the close happens after the Transport has reacted to the closed
	// connection.
	<-didRead

	resp.Body.Close()

	// Poll until the Transport reports zero pending requests. waitCondition is
	// defined elsewhere; presumably it retries with a growing delay starting
	// at 10ms. d is the elapsed wait passed to the callback.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7325
7326 func TestValidateClientRequestTrailers(t *testing.T) {
7327 run(t, testValidateClientRequestTrailers)
7328 }
7329
7330 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7331 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7332 rw.Write([]byte("Hello"))
7333 })).ts
7334
7335 cases := []struct {
7336 trailer Header
7337 wantErr string
7338 }{
7339 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7340 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7341 }
7342
7343 for i, tt := range cases {
7344 testName := fmt.Sprintf("%s%d", mode, i)
7345 t.Run(testName, func(t *testing.T) {
7346 req, err := NewRequest("GET", cst.URL, nil)
7347 if err != nil {
7348 t.Fatal(err)
7349 }
7350 req.Trailer = tt.trailer
7351 res, err := cst.Client().Do(req)
7352 if err == nil {
7353 t.Fatal("Expected an error")
7354 }
7355 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7356 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7357 }
7358 if res != nil {
7359 t.Fatal("Unexpected non-nil response")
7360 }
7361 })
7362 }
7363 }
7364
7365 func TestTransportServerProtocols(t *testing.T) {
7366 CondSkipHTTP2(t)
7367 DefaultTransport.(*Transport).CloseIdleConnections()
7368
7369 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7370 if err != nil {
7371 t.Fatal(err)
7372 }
7373 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7374 if err != nil {
7375 t.Fatal(err)
7376 }
7377 certpool := x509.NewCertPool()
7378 certpool.AddCert(leafCert)
7379
7380 for _, test := range []struct {
7381 name string
7382 scheme string
7383 setup func(t *testing.T)
7384 transport func(*Transport)
7385 server func(*Server)
7386 want string
7387 }{{
7388 name: "http default",
7389 scheme: "http",
7390 want: "HTTP/1.1",
7391 }, {
7392 name: "https default",
7393 scheme: "https",
7394 transport: func(tr *Transport) {
7395
7396 },
7397 want: "HTTP/1.1",
7398 }, {
7399 name: "https transport protocols include HTTP2",
7400 scheme: "https",
7401 transport: func(tr *Transport) {
7402
7403
7404 tr.Protocols = &Protocols{}
7405 tr.Protocols.SetHTTP1(true)
7406 tr.Protocols.SetHTTP2(true)
7407 },
7408 want: "HTTP/2.0",
7409 }, {
7410 name: "https transport protocols only include HTTP1",
7411 scheme: "https",
7412 transport: func(tr *Transport) {
7413
7414 tr.Protocols = &Protocols{}
7415 tr.Protocols.SetHTTP1(true)
7416 },
7417 want: "HTTP/1.1",
7418 }, {
7419 name: "https transport ForceAttemptHTTP2",
7420 scheme: "https",
7421 transport: func(tr *Transport) {
7422
7423 tr.ForceAttemptHTTP2 = true
7424 },
7425 want: "HTTP/2.0",
7426 }, {
7427 name: "https transport protocols override TLSNextProto",
7428 scheme: "https",
7429 transport: func(tr *Transport) {
7430
7431
7432
7433 tr.Protocols = &Protocols{}
7434 tr.Protocols.SetHTTP1(true)
7435 tr.Protocols.SetHTTP2(true)
7436 tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
7437 },
7438 want: "HTTP/2.0",
7439 }, {
7440 name: "https server disables HTTP2 with TLSNextProto",
7441 scheme: "https",
7442 server: func(srv *Server) {
7443
7444
7445 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7446 },
7447 want: "HTTP/1.1",
7448 }, {
7449 name: "https server Protocols overrides empty TLSNextProto",
7450 scheme: "https",
7451 server: func(srv *Server) {
7452
7453
7454 srv.Protocols = &Protocols{}
7455 srv.Protocols.SetHTTP1(true)
7456 srv.Protocols.SetHTTP2(true)
7457 srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
7458 },
7459 want: "HTTP/2.0",
7460 }, {
7461 name: "https server protocols only include HTTP1",
7462 scheme: "https",
7463 server: func(srv *Server) {
7464 srv.Protocols = &Protocols{}
7465 srv.Protocols.SetHTTP1(true)
7466 },
7467 want: "HTTP/1.1",
7468 }, {
7469 name: "https server protocols include HTTP2",
7470 scheme: "https",
7471 server: func(srv *Server) {
7472 srv.Protocols = &Protocols{}
7473 srv.Protocols.SetHTTP1(true)
7474 srv.Protocols.SetHTTP2(true)
7475 },
7476 want: "HTTP/2.0",
7477 }, {
7478 name: "GODEBUG disables HTTP2 client",
7479 scheme: "https",
7480 setup: func(t *testing.T) {
7481 t.Setenv("GODEBUG", "http2client=0")
7482 },
7483 transport: func(tr *Transport) {
7484
7485
7486 tr.Protocols = &Protocols{}
7487 tr.Protocols.SetHTTP1(true)
7488 tr.Protocols.SetHTTP2(true)
7489 },
7490 want: "HTTP/1.1",
7491 }, {
7492 name: "GODEBUG disables HTTP2 server",
7493 scheme: "https",
7494 setup: func(t *testing.T) {
7495 t.Setenv("GODEBUG", "http2server=0")
7496 },
7497 transport: func(tr *Transport) {
7498
7499
7500 tr.Protocols = &Protocols{}
7501 tr.Protocols.SetHTTP1(true)
7502 tr.Protocols.SetHTTP2(true)
7503 },
7504 want: "HTTP/1.1",
7505 }, {
7506 name: "unencrypted HTTP2 with prior knowledge",
7507 scheme: "http",
7508 transport: func(tr *Transport) {
7509 tr.Protocols = &Protocols{}
7510 tr.Protocols.SetUnencryptedHTTP2(true)
7511 },
7512 server: func(srv *Server) {
7513 srv.Protocols = &Protocols{}
7514 srv.Protocols.SetHTTP1(true)
7515 srv.Protocols.SetUnencryptedHTTP2(true)
7516 },
7517 want: "HTTP/2.0",
7518 }, {
7519 name: "unencrypted HTTP2 only on server",
7520 scheme: "http",
7521 transport: func(tr *Transport) {
7522 tr.Protocols = &Protocols{}
7523 tr.Protocols.SetUnencryptedHTTP2(true)
7524 },
7525 server: func(srv *Server) {
7526 srv.Protocols = &Protocols{}
7527 srv.Protocols.SetUnencryptedHTTP2(true)
7528 },
7529 want: "HTTP/2.0",
7530 }, {
7531 name: "unencrypted HTTP2 with no server support",
7532 scheme: "http",
7533 transport: func(tr *Transport) {
7534 tr.Protocols = &Protocols{}
7535 tr.Protocols.SetUnencryptedHTTP2(true)
7536 },
7537 server: func(srv *Server) {
7538 srv.Protocols = &Protocols{}
7539 srv.Protocols.SetHTTP1(true)
7540 },
7541 want: "error",
7542 }, {
7543 name: "HTTP1 with no server support",
7544 scheme: "http",
7545 transport: func(tr *Transport) {
7546 tr.Protocols = &Protocols{}
7547 tr.Protocols.SetHTTP1(true)
7548 },
7549 server: func(srv *Server) {
7550 srv.Protocols = &Protocols{}
7551 srv.Protocols.SetUnencryptedHTTP2(true)
7552 },
7553 want: "error",
7554 }, {
7555 name: "HTTPS1 with no server support",
7556 scheme: "https",
7557 transport: func(tr *Transport) {
7558 tr.Protocols = &Protocols{}
7559 tr.Protocols.SetHTTP1(true)
7560 },
7561 server: func(srv *Server) {
7562 srv.Protocols = &Protocols{}
7563 srv.Protocols.SetHTTP2(true)
7564 },
7565 want: "error",
7566 }} {
7567 t.Run(test.name, func(t *testing.T) {
7568
7569
7570 srv := &Server{
7571 TLSConfig: &tls.Config{
7572 Certificates: []tls.Certificate{cert},
7573 },
7574 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
7575 w.Header().Set("X-Proto", req.Proto)
7576 }),
7577 }
7578 tr := &Transport{
7579 TLSClientConfig: &tls.Config{
7580 RootCAs: certpool,
7581 },
7582 }
7583
7584 if test.setup != nil {
7585 test.setup(t)
7586 }
7587 if test.server != nil {
7588 test.server(srv)
7589 }
7590 if test.transport != nil {
7591 test.transport(tr)
7592 } else {
7593 tr.Protocols = &Protocols{}
7594 tr.Protocols.SetHTTP1(true)
7595 tr.Protocols.SetHTTP2(true)
7596 }
7597
7598 listener := newLocalListener(t)
7599 srvc := make(chan error, 1)
7600 go func() {
7601 switch test.scheme {
7602 case "http":
7603 srvc <- srv.Serve(listener)
7604 case "https":
7605 srvc <- srv.ServeTLS(listener, "", "")
7606 }
7607 }()
7608 t.Cleanup(func() {
7609 srv.Close()
7610 <-srvc
7611 })
7612
7613 client := &Client{Transport: tr}
7614 resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
7615 if err != nil {
7616 if test.want == "error" {
7617 return
7618 }
7619 t.Fatal(err)
7620 }
7621 if got := resp.Header.Get("X-Proto"); got != test.want {
7622 t.Fatalf("request proto %q, want %q", got, test.want)
7623 }
7624 })
7625 }
7626 }
7627
7628 func TestIssue61474(t *testing.T) {
7629 run(t, testIssue61474, []testMode{http2Mode})
7630 }
7631 func testIssue61474(t *testing.T, mode testMode) {
7632 if testing.Short() {
7633 return
7634 }
7635
7636
7637
7638
7639 t.Skip("test is too large")
7640
7641 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7642 }), func(tr *Transport) {
7643 tr.MaxConnsPerHost = 1
7644 })
7645 var wg sync.WaitGroup
7646 defer wg.Wait()
7647 for range 100000 {
7648 wg.Go(func() {
7649 ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond)
7650 defer cancel()
7651 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
7652 resp, err := cst.c.Do(req)
7653 if err == nil {
7654 resp.Body.Close()
7655 }
7656 })
7657 }
7658 }
7659
View as plain text