Source file
src/net/http/transport_test.go
1
2
3
4
5
6
7
8
9
10 package http_test
11
12 import (
13 "bufio"
14 "bytes"
15 "compress/gzip"
16 "context"
17 "crypto/rand"
18 "crypto/tls"
19 "crypto/x509"
20 "encoding/binary"
21 "errors"
22 "fmt"
23 "go/token"
24 "internal/nettrace"
25 "io"
26 "log"
27 mrand "math/rand"
28 "net"
29 . "net/http"
30 "net/http/httptest"
31 "net/http/httptrace"
32 "net/http/httputil"
33 "net/http/internal/testcert"
34 "net/textproto"
35 "net/url"
36 "os"
37 "reflect"
38 "runtime"
39 "strconv"
40 "strings"
41 "sync"
42 "sync/atomic"
43 "testing"
44 "testing/iotest"
45 "time"
46
47 "golang.org/x/net/http/httpguts"
48 )
49
50
51
52
53
54 var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
55 if r.FormValue("close") == "true" {
56 w.Header().Set("Connection", "close")
57 }
58 w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
59 w.Write([]byte(r.RemoteAddr))
60
61
62
63 if c, ok := ResponseWriterConnForTesting(w); ok {
64 fmt.Fprintf(w, ", %T %p", c, c)
65 }
66 })
67
68
69 type testCloseConn struct {
70 net.Conn
71 set *testConnSet
72 }
73
74 func (c *testCloseConn) Close() error {
75 c.set.remove(c)
76 return c.Conn.Close()
77 }
78
79
80
81 type testConnSet struct {
82 t *testing.T
83 mu sync.Mutex
84 closed map[net.Conn]bool
85 list []net.Conn
86 }
87
88 func (tcs *testConnSet) insert(c net.Conn) {
89 tcs.mu.Lock()
90 defer tcs.mu.Unlock()
91 tcs.closed[c] = false
92 tcs.list = append(tcs.list, c)
93 }
94
95 func (tcs *testConnSet) remove(c net.Conn) {
96 tcs.mu.Lock()
97 defer tcs.mu.Unlock()
98 tcs.closed[c] = true
99 }
100
101
// makeTestDial returns a new, empty testConnSet together with a dial
// function suitable for Transport.Dial. Each successful dial is wrapped in a
// testCloseConn and registered with the set; dial errors are returned
// unchanged and nothing is registered.
func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
	set := &testConnSet{
		t:      t,
		closed: map[net.Conn]bool{},
	}
	return set, func(n, addr string) (net.Conn, error) {
		raw, err := net.Dial(n, addr)
		if err != nil {
			return nil, err
		}
		wrapped := &testCloseConn{Conn: raw, set: set}
		set.insert(wrapped)
		return wrapped, nil
	}
}
118
119 func (tcs *testConnSet) check(t *testing.T) {
120 tcs.mu.Lock()
121 defer tcs.mu.Unlock()
122 for i := 4; i >= 0; i-- {
123 for i, c := range tcs.list {
124 if tcs.closed[c] {
125 continue
126 }
127 if i != 0 {
128
129
130 tcs.mu.Unlock()
131 time.Sleep(50 * time.Millisecond)
132 tcs.mu.Lock()
133 continue
134 }
135 t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
136 }
137 }
138 }
139
func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }

// testReuseRequest sends the very same *Request value twice through one
// Client and requires both round trips and both body closes to succeed.
func testReuseRequest(t *testing.T, mode testMode) {
	srv := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Write([]byte("{}"))
	})).ts

	client := srv.Client()
	req, _ := NewRequest("GET", srv.URL, nil)
	for attempt := 0; attempt < 2; attempt++ {
		res, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	}
}
166
167
168
func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }

// testTransportKeepAlives fetches hostPortHandler twice per setting and
// checks that the two bodies (which identify the serving connection) match
// when keep-alives are enabled and differ when DisableKeepAlives is set.
func testTransportKeepAlives(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	c := ts.Client()
	for _, disableKeepAlive := range []bool{false, true} {
		c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
		fetch := func(n int) string {
			res, err := c.Get(ts.URL)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
			}
			// Always release the response body, even on a read failure.
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)

		bodiesDiffer := body1 != body2
		if bodiesDiffer != disableKeepAlive {
			t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
				disableKeepAlive, bodiesDiffer, body1, body2)
		}
	}
}
198
199 func TestTransportConnectionCloseOnResponse(t *testing.T) {
200 run(t, testTransportConnectionCloseOnResponse)
201 }
202 func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
203 ts := newClientServerTest(t, mode, hostPortHandler).ts
204
205 connSet, testDial := makeTestDial(t)
206
207 c := ts.Client()
208 tr := c.Transport.(*Transport)
209 tr.Dial = testDial
210
211 for _, connectionClose := range []bool{false, true} {
212 fetch := func(n int) string {
213 req := new(Request)
214 var err error
215 req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
216 if err != nil {
217 t.Fatalf("URL parse error: %v", err)
218 }
219 req.Method = "GET"
220 req.Proto = "HTTP/1.1"
221 req.ProtoMajor = 1
222 req.ProtoMinor = 1
223
224 res, err := c.Do(req)
225 if err != nil {
226 t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
227 }
228 defer res.Body.Close()
229 body, err := io.ReadAll(res.Body)
230 if err != nil {
231 t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
232 }
233 return string(body)
234 }
235
236 body1 := fetch(1)
237 body2 := fetch(2)
238 bodiesDiffer := body1 != body2
239 if bodiesDiffer != connectionClose {
240 t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
241 connectionClose, bodiesDiffer, body1, body2)
242 }
243
244 tr.CloseIdleConnections()
245 }
246
247 connSet.check(t)
248 }
249
250
251
252
253
254
255
func TestTransportConnectionCloseOnRequest(t *testing.T) {
	run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
}

// testTransportConnectionCloseOnRequest checks that Request.Close is visible
// to the handler (reported back in X-Saw-Close) and that, when set, each of
// two requests is served on its own connection. It finally checks that
// every dialed connection was closed.
func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	connSet, testDial := makeTestDial(t)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.Dial = testDial
	for _, reqClose := range []bool{false, true} {
		fetch := func(n int) string {
			req := new(Request)
			var err error
			req.URL, err = url.Parse(ts.URL)
			if err != nil {
				t.Fatalf("URL parse error: %v", err)
			}
			req.Method = "GET"
			req.Proto = "HTTP/1.1"
			req.ProtoMajor = 1
			req.ProtoMinor = 1
			req.Close = reqClose

			res, err := c.Do(req)
			if err != nil {
				t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
			}
			defer res.Body.Close()
			if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
				// Report the value actually compared against.
				t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
					reqClose, got, want)
			}
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)

		got := 1
		if body1 != body2 {
			got++
		}
		want := 1
		if reqClose {
			want = 2
		}
		if got != want {
			t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
				reqClose, got, want, body1, body2)
		}

		tr.CloseIdleConnections()
	}

	connSet.check(t)
}
317
318
319
320
func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
	run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
}

// testTransportConnectionCloseOnRequestDisableKeepAlive checks that with
// DisableKeepAlives set, the handler observes the request as wanting the
// connection closed (X-Saw-Close: true).
func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	client := ts.Client()
	client.Transport.(*Transport).DisableKeepAlives = true

	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if sawClose := res.Header.Get("X-Saw-Close"); sawClose != "true" {
		t.Errorf("handler didn't see Connection: close ")
	}
}
339
340
341
func TestTransportRespectRequestWantsClose(t *testing.T) {
	run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
}

// testTransportRespectRequestWantsClose counts, via an httptrace hook, the
// "Connection" header fields containing the "close" token that the
// Transport writes. Exactly one is expected when DisableKeepAlives or
// Request.Close is set, and none otherwise.
func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
	for _, disableKeepAlives := range []bool{false, true} {
		for _, reqClose := range []bool{false, true} {
			name := fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", disableKeepAlives, reqClose)
			t.Run(name, func(t *testing.T) {
				ts := newClientServerTest(t, mode, hostPortHandler).ts

				c := ts.Client()
				c.Transport.(*Transport).DisableKeepAlives = disableKeepAlives
				req, err := NewRequest("GET", ts.URL, nil)
				if err != nil {
					t.Fatal(err)
				}
				closeHeaders := 0
				trace := &httptrace.ClientTrace{
					WroteHeaderField: func(key string, field []string) {
						if key == "Connection" && httpguts.HeaderValuesContainsToken(field, "close") {
							closeHeaders++
						}
					},
				}
				req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
				req.Close = reqClose
				res, err := c.Do(req)
				if err != nil {
					t.Fatal(err)
				}
				defer res.Body.Close()
				want := disableKeepAlives || reqClose
				if closeHeaders > 1 || (closeHeaders == 1) != want {
					t.Errorf("expecting want:%v, got 'Connection: close':%d", want, closeHeaders)
				}
			})
		}
	}
}
392
func TestTransportIdleCacheKeys(t *testing.T) {
	run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
}

// testTransportIdleCacheKeys checks that there are no idle-connection cache
// keys initially. After one completed request there must be exactly one key,
// "|http|" + the listener address, and CloseIdleConnections must remove it.
func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("Before Get expected %d idle conn cache keys; got %d", e, g)
	}

	resp, err := c.Get(ts.URL)
	if err != nil {
		// Without a response there is nothing left to check, and reading
		// resp.Body below would be a nil dereference.
		t.Fatal(err)
	}
	defer resp.Body.Close()
	io.ReadAll(resp.Body)

	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
	}

	if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
		t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
	}

	tr.CloseIdleConnections()
	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
	}
}
425
426
427
func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }

// testTransportReadToEndReusesConn issues three requests for each of a
// Content-Length response and a chunked response. Each body is read to EOF.
// The test checks ContentLength, the body, and that the server saw a single
// client address, i.e. the connection was reused.
func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
	const msg = "foobar"

	var addrSeen map[string]int
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrSeen[r.RemoteAddr]++
		switch r.URL.Path {
		case "/chunked/":
			w.WriteHeader(200)
			w.(Flusher).Flush()
		default:
			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
			w.WriteHeader(200)
		}
		w.Write([]byte(msg))
	})).ts

	cases := []struct {
		path    string
		wantLen int64
	}{
		{"/content-length/", int64(len(msg))},
		{"/chunked/", -1},
	}
	for _, tc := range cases {
		path := tc.path
		addrSeen = make(map[string]int)
		for attempt := 0; attempt < 3; attempt++ {
			res, err := ts.Client().Get(ts.URL + path)
			if err != nil {
				t.Errorf("Get %s: %v", path, err)
				continue
			}
			// Deliberately deferred until the test returns, i.e. after the
			// addrSeen check below. Presumably closing early could let the
			// connection be reused for a reason other than reading to EOF.
			defer res.Body.Close()

			if res.ContentLength != tc.wantLen {
				t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, tc.wantLen)
			}
			got, err := io.ReadAll(res.Body)
			if string(got) != msg || err != nil {
				t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
			}
		}
		if n := len(addrSeen); n != 1 {
			t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, n)
		}
	}
}
474
func TestTransportMaxPerHostIdleConns(t *testing.T) {
	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
}

// testTransportMaxPerHostIdleConns parks three concurrent requests inside the
// handler. It then releases them one at a time and checks that the idle pool
// for the host grows to 1, then 2, and then stays at MaxIdleConnsPerHost (2).
func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
	stop := make(chan struct{})
	defer close(stop)

	resch := make(chan string)
	gotReq := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReq <- true
		var msg string
		select {
		case <-stop:
			return
		case msg = <-resch:
		}
		if _, err := w.Write([]byte(msg)); err != nil {
			t.Errorf("Write: %v", err)
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)
	const maxIdleConnsPerHost = 2
	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost

	donech := make(chan bool)
	doReq := func() {
		// Signal completion (with the current failure state) unless the
		// test is already tearing down.
		defer func() {
			select {
			case <-stop:
			case donech <- t.Failed():
			}
		}()
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
			return
		}
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("ReadAll: %v", err)
		}
	}

	// Start three requests, waiting until each reaches the handler, so all
	// three are in flight at once.
	for inFlight := 0; inFlight < 3; inFlight++ {
		go doReq()
		<-gotReq
	}

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
	}

	addr := ts.Listener.Addr().String()
	cacheKey := "|http|" + addr

	resch <- "res1"
	<-donech
	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
	}
	if keys[0] != cacheKey {
		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
	}
	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
	}

	resch <- "res2"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
		t.Errorf("after second response, idle conns = %d; want %d", g, w)
	}

	resch <- "res3"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
		t.Errorf("after third response, idle conns = %d; want %d", g, w)
	}
}
563
func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
	run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
}

// testTransportMaxConnsPerHostIncludeDialInProgress checks that a dial still
// in progress counts against MaxConnsPerHost (set to 1). While req1's dial is
// stalled, req2 must not start dialing. Once req1 completes, req2's dial
// proceeds; keep-alives are disabled, so req1's connection is not reused.
func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Handlers run on server goroutines, where t.Fatal must not be used.
		if _, err := w.Write([]byte("foo")); err != nil {
			t.Errorf("Write: %v", err)
		}
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	dialStarted := make(chan struct{})
	stallDial := make(chan struct{})
	tr.Dial = func(network, addr string) (net.Conn, error) {
		dialStarted <- struct{}{}
		<-stallDial
		return net.Dial(network, addr)
	}

	tr.DisableKeepAlives = true
	tr.MaxConnsPerHost = 1

	preDial := make(chan struct{})
	reqComplete := make(chan struct{})
	doReq := func(reqID string) {
		// Always signal completion, even on failure, so the main goroutine
		// waiting on reqComplete cannot block forever.
		defer func() { reqComplete <- struct{}{} }()

		req, _ := NewRequest("GET", ts.URL, nil)
		trace := &httptrace.ClientTrace{
			GetConn: func(hostPort string) {
				preDial <- struct{}{}
			},
		}
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
		resp, err := tr.RoundTrip(req)
		if err != nil {
			// resp is nil here; do not touch resp.Body.
			t.Errorf("unexpected error for request %s: %v", reqID, err)
			return
		}
		defer resp.Body.Close()
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("unexpected error for request %s: %v", reqID, err)
		}
	}

	go doReq("req1")
	<-preDial
	<-dialStarted

	// req1's dial is stalled; req2 must reach GetConn but not start dialing.
	go doReq("req2")
	<-preDial
	select {
	case <-dialStarted:
		t.Error("req2 dial started while req1 dial in progress")
		return
	default:
	}

	// Let req1's dial finish and wait for req1 to complete.
	stallDial <- struct{}{}
	<-reqComplete

	// Now req2 may dial; let it finish too.
	<-dialStarted
	stallDial <- struct{}{}
	<-reqComplete
}
631
func TestTransportMaxConnsPerHost(t *testing.T) {
	run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
}

// testTransportMaxConnsPerHost sends ten concurrent requests with
// MaxConnsPerHost = 1. It checks that exactly one connection was dialed and
// obtained (and, for TLS servers, exactly one handshake started). It then
// closes that connection and checks that one more request dials exactly one
// new connection.
func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		// Handlers run on server goroutines, where t.Fatal must not be used.
		if _, err := w.Write([]byte("foo")); err != nil {
			t.Errorf("Write: %v", err)
		}
	})

	ts := newClientServerTest(t, mode, h).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1

	var mu sync.Mutex
	var conns []net.Conn // guarded by mu
	var dialCnt, gotConnCnt, tlsHandshakeCnt int32
	tr.Dial = func(network, addr string) (net.Conn, error) {
		atomic.AddInt32(&dialCnt, 1)
		conn, err := net.Dial(network, addr)
		if err != nil {
			// Do not record a nil conn; it would panic when closed below.
			return nil, err
		}
		mu.Lock()
		conns = append(conns, conn)
		mu.Unlock()
		return conn, nil
	}

	// doReq also runs on spawned goroutines, so it reports failures with
	// t.Errorf and returns rather than calling t.Fatal.
	doReq := func() {
		trace := &httptrace.ClientTrace{
			GotConn: func(connInfo httptrace.GotConnInfo) {
				if !connInfo.Reused {
					atomic.AddInt32(&gotConnCnt, 1)
				}
			},
			TLSHandshakeStart: func() {
				atomic.AddInt32(&tlsHandshakeCnt, 1)
			},
		}
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

		resp, err := c.Do(req)
		if err != nil {
			t.Errorf("request failed: %v", err)
			return
		}
		defer resp.Body.Close()
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("read body failed: %v", err)
		}
	}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			doReq()
		}()
	}
	wg.Wait()

	expected := int32(tr.MaxConnsPerHost)
	if n := atomic.LoadInt32(&dialCnt); n != expected {
		t.Errorf("round 1: too many dials: %d != %d", n, expected)
	}
	if n := atomic.LoadInt32(&gotConnCnt); n != expected {
		t.Errorf("round 1: too many get connections: %d != %d", n, expected)
	}
	if n := atomic.LoadInt32(&tlsHandshakeCnt); ts.TLS != nil && n != expected {
		t.Errorf("round 1: too many tls handshakes: %d != %d", n, expected)
	}

	if t.Failed() {
		t.FailNow()
	}

	mu.Lock()
	for _, conn := range conns {
		conn.Close()
	}
	conns = nil
	mu.Unlock()
	tr.CloseIdleConnections()

	doReq()
	if t.Failed() {
		// The round-2 request failed; its counters are meaningless.
		t.FailNow()
	}
	expected++
	if n := atomic.LoadInt32(&dialCnt); n != expected {
		t.Errorf("round 2: too many dials: %d != %d", n, expected)
	}
	if n := atomic.LoadInt32(&gotConnCnt); n != expected {
		t.Errorf("round 2: too many get connections: %d != %d", n, expected)
	}
	if n := atomic.LoadInt32(&tlsHandshakeCnt); ts.TLS != nil && n != expected {
		t.Errorf("round 2: too many tls handshakes: %d != %d", n, expected)
	}
}
732
func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
	run(t, testTransportMaxConnsPerHostDialCancellation,
		testNotParallel,
		[]testMode{http1Mode, https1Mode, http2Mode},
	)
}

// testTransportMaxConnsPerHostDialCancellation sets MaxConnsPerHost = 1.
// It installs the first request's context cancel func as a pending-dial hook
// and expects that request to fail with context.Canceled. It then checks
// that a follow-up request on the same Transport still succeeds.
func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		// Handlers run on server goroutines, where t.Fatal must not be used.
		if _, err := w.Write([]byte("foo")); err != nil {
			t.Errorf("Write: %v", err)
		}
	})

	cst := newClientServerTest(t, mode, h)
	defer cst.close()
	ts := cst.ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1

	// NOTE(review): SetPendingDialHooks(cancel, nil) is assumed to invoke
	// cancel while the dial is pending; confirm in export_test.go.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	SetPendingDialHooks(cancel, nil)
	defer SetPendingDialHooks(nil, nil)

	req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
	_, err := c.Do(req)
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected error %v, got %v", context.Canceled, err)
	}

	// With the hook removed, a fresh request must get a connection.
	SetPendingDialHooks(nil, nil)
	req, _ = NewRequest("GET", ts.URL, nil)
	resp, err := c.Do(req)
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}
	defer resp.Body.Close()
	if _, err := io.ReadAll(resp.Body); err != nil {
		t.Fatalf("read body failed: %v", err)
	}
}
782
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
	run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
}

// testTransportRemovesDeadIdleConnections makes one request and then closes
// all client connections on the server side. It waits until the Transport
// has dropped the dead connection from its idle pool, and then checks that
// a second request succeeds.
func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	doReq := func(name string) {
		// A POST is used, presumably so that the Transport's retry logic
		// for idempotent requests cannot mask a dead connection.
		res, err := c.Post(ts.URL, "", nil)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		// Register the close before any further check can t.Fatalf.
		defer res.Body.Close()
		if res.StatusCode != 200 {
			t.Fatalf("%s: %v", name, res.Status)
		}
		slurp, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		t.Logf("%s: ok (%q)", name, slurp)
	}

	doReq("first")
	keys1 := tr.IdleConnKeysForTesting()

	ts.CloseClientConnections()

	var keys2 []string
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		keys2 = tr.IdleConnKeysForTesting()
		if len(keys2) != 0 {
			if d > 0 {
				t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
			}
			return false
		}
		return true
	})

	doReq("second")
}
831
832
833
func TestTransportServerClosingUnexpectedly(t *testing.T) {
	run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
}

// testTransportServerClosingUnexpectedly makes two requests that must share a
// connection. It then abruptly closes the Transport's connections, and a
// third request (allowed a few retries) must succeed on a different one.
func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts
	c := ts.Client()

	// get fetches ts.URL and returns the body. After a failure it retries
	// while retries remain, sleeping time.Second/retries (so a little longer
	// each time); with no retries left it fails the test.
	get := func(reqNum, retries int) string {
		retryOrFail := func(format string, args ...any) {
			if retries <= 0 {
				t.Fatalf(format, args...)
			}
			t.Logf("retrying shortly after expected error: "+format, args...)
			time.Sleep(time.Second / time.Duration(retries))
		}
		for retries >= 0 {
			retries--
			resp, err := c.Get(ts.URL)
			if err != nil {
				retryOrFail("error in req #%d, GET: %v", reqNum, err)
				continue
			}
			payload, err := io.ReadAll(resp.Body)
			if err != nil {
				retryOrFail("error in req #%d, ReadAll: %v", reqNum, err)
				continue
			}
			resp.Body.Close()
			return string(payload)
		}
		panic("unreachable")
	}

	first := get(1, 0)
	second := get(2, 0)

	// Pull the connections out from under the Transport; the next request
	// has to notice and use a new connection.
	ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))

	third := get(3, 5)

	if first != second {
		t.Errorf("expected body1 and body2 to be equal")
	}
	if third == second {
		t.Errorf("expected body2 and body3 to be different")
	}
}
888
889
890
func TestStressSurpriseServerCloses(t *testing.T) {
	run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
}

// testStressSurpriseServerCloses runs many concurrent requests against a
// handler that writes a complete response and then closes the hijacked
// connection. Request errors are ignored; the test only requires that every
// request returns.
func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in short mode")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Length", "5")
		hdr.Set("Content-Type", "text/plain")
		w.Write([]byte("Hello"))
		w.(Flusher).Flush()
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Flush()
		conn.Close()
	})).ts
	c := ts.Client()

	const (
		numClients    = 20
		reqsPerClient = 25
	)
	var wg sync.WaitGroup
	wg.Add(numClients * reqsPerClient)
	for client := 0; client < numClients; client++ {
		go func() {
			for j := 0; j < reqsPerClient; j++ {
				if res, err := c.Get(ts.URL); err == nil {
					// The body is closed without being read.
					res.Body.Close()
				}
				wg.Done()
			}
		}()
	}

	wg.Wait()
}
942
943
944
func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }

// testTransportHeadResponses sends two HEAD requests. Each response must
// carry the handler's Content-Length (both as a header and in
// Response.ContentLength) and an empty body.
func testTransportHeadResponses(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		w.Header().Set("Content-Length", "123")
		w.WriteHeader(200)
	})).ts
	c := ts.Client()

	for i := 0; i < 2; i++ {
		res, err := c.Head(ts.URL)
		if err != nil {
			t.Errorf("error on loop %d: %v", i, err)
			continue
		}
		if e, g := "123", res.Header.Get("Content-Length"); e != g {
			t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
		}
		if e, g := int64(123), res.ContentLength; e != g {
			t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
		}
		all, err := io.ReadAll(res.Body)
		// Close each body within its iteration rather than leaking it.
		res.Body.Close()
		if err != nil {
			t.Errorf("loop %d: Body ReadAll: %v", i, err)
		} else if len(all) != 0 {
			t.Errorf("Bogus body %q", all)
		}
	}
}
975
976
977
func TestTransportHeadChunkedResponse(t *testing.T) {
	run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
}

// testTransportHeadChunkedResponse has the handler answer HEAD with
// "Transfer-Encoding: chunked" and no body. It checks that two consecutive
// HEAD requests arrive from the same client ip:port, i.e. the connection was
// reused.
func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		hdr := w.Header()
		hdr.Set("Transfer-Encoding", "chunked")
		hdr.Set("x-client-ipport", r.RemoteAddr)
		w.WriteHeader(200)
	})).ts
	c := ts.Client()

	// Wait for the Transport's read loop to reach its next read after each
	// response before sending another request. The hook is process-global,
	// which is presumably why this test is marked testNotParallel.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	head := func(n int) *Response {
		res, err := c.Head(ts.URL)
		<-didRead
		if err != nil {
			t.Fatalf("request %d error: %v", n, err)
		}
		return res
	}
	res1 := head(1)
	res2 := head(2)

	v1 := res1.Header.Get("x-client-ipport")
	v2 := res2.Header.Get("x-client-ipport")
	if v1 != v2 {
		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
	}
}
1015
// roundTripTests drives testRoundTripGzip.
var roundTripTests = []struct {
	// accept is the Accept-Encoding set on the request ("" means unset).
	accept string
	// expectAccept is the Accept-Encoding the handler should receive.
	expectAccept string
	// compressed reports whether the caller receives a still-gzipped body.
	compressed bool
}{
	// No Accept-Encoding: the handler sees "gzip", yet the caller reads a
	// plain body.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other Accept-Encoding is passed through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit "gzip" is passed through and the body stays compressed.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1028
1029
func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }

// testRoundTripGzip calls Transport.RoundTrip directly for each entry in
// roundTripTests. The handler checks the Accept-Encoding it receives and
// gzips its response only when that header is exactly "gzip". The test
// checks the decoded body and that the caller's Request header was not
// mutated. It also checks the Content-Encoding the caller observes.
func testRoundTripGzip(t *testing.T, mode testMode) {
	const responseBody = "test response body"
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		accept := req.Header.Get("Accept-Encoding")
		if expect := req.FormValue("expect_accept"); accept != expect {
			t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
				req.FormValue("testnum"), accept, expect)
		}
		if accept == "gzip" {
			rw.Header().Set("Content-Encoding", "gzip")
			gz := gzip.NewWriter(rw)
			gz.Write([]byte(responseBody))
			gz.Close()
		} else {
			rw.Header().Set("Content-Encoding", accept)
			rw.Write([]byte(responseBody))
		}
	})).ts
	tr := ts.Client().Transport.(*Transport)

	for i, test := range roundTripTests {
		req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
		if test.accept != "" {
			req.Header.Set("Accept-Encoding", test.accept)
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("%d. RoundTrip: %v", i, err)
			continue
		}
		var body []byte
		if test.compressed {
			r, gzErr := gzip.NewReader(res.Body)
			if gzErr != nil {
				res.Body.Close()
				t.Errorf("%d. gzip NewReader: %v", i, gzErr)
				continue
			}
			body, err = io.ReadAll(r)
		} else {
			body, err = io.ReadAll(res.Body)
		}
		// Close the response body on every path, compressed or not.
		res.Body.Close()
		if err != nil {
			t.Errorf("%d. Error: %q", i, err)
			continue
		}
		if g, e := string(body), responseBody; g != e {
			t.Errorf("%d. body = %q; want %q", i, g, e)
		}
		if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
			t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
		}
		if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
			t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
		}
	}
}
1091
func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }

// testTransportGzip exercises transparent gzip decoding for both chunked
// ("chunked=1") and Content-Length ("chunked=0") responses. It checks that
// a large response can be partially read and then closed, with later reads
// failing. A small response must be fully decoded with its Content-Encoding
// header removed, and reads after EOF or Close must fail. HEAD requests must
// not carry Accept-Encoding (checked in the handler).
func testTransportGzip(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56020")
	}
	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	const nRandBytes = 1024 * 1024
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if req.Method == "HEAD" {
			if g := req.Header.Get("Accept-Encoding"); g != "" {
				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
			}
			return
		}
		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
			t.Errorf("Accept-Encoding = %q, want %q", g, e)
		}
		rw.Header().Set("Content-Encoding", "gzip")

		var w io.Writer = rw
		var buf bytes.Buffer
		if req.FormValue("chunked") == "0" {
			// Buffer the whole response so a Content-Length can be set.
			// Deferred calls run LIFO: the header is set before the copy.
			w = &buf
			defer io.Copy(rw, &buf)
			defer func() {
				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
			}()
		}
		gz := gzip.NewWriter(w)
		gz.Write([]byte(testString))
		if req.FormValue("body") == "large" {
			io.CopyN(gz, rand.Reader, nRandBytes)
		}
		gz.Close()
	})).ts
	c := ts.Client()

	for _, chunked := range []string{"1", "0"} {
		// Read only the beginning of a large response, then close it early.
		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
		if err != nil {
			t.Fatalf("large get: %v", err)
		}
		buf := make([]byte, len(testString))
		n, err := io.ReadFull(res.Body, buf)
		if err != nil {
			t.Fatalf("partial read of large response: size=%d, %v", n, err)
		}
		if e, g := testString, string(buf); e != g {
			t.Errorf("partial read got %q, expected %q", g, e)
		}
		res.Body.Close()
		// Reads after Close must fail.
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
		}

		// A small response is decoded completely.
		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
		if err != nil {
			t.Fatal(err)
		}
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		if g, e := string(body), testString; g != e {
			t.Fatalf("body = %q; want %q", g, e)
		}
		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
			t.Fatalf("Content-Encoding = %q; want %q", g, e)
		}

		// Reads after EOF and after Close must both fail.
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
		}
		res.Body.Close()
		n, err = res.Body.Read(buf)
		if n != 0 || err == nil {
			t.Errorf("expected Read error after Close; got %d, %v", n, err)
		}
	}

	// A HEAD request goes through the handler's no-Accept-Encoding check.
	res, err := c.Head(ts.URL)
	if err != nil {
		t.Fatalf("Head: %v", err)
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		t.Errorf("Head status=%d; want=200", res.StatusCode)
	}
}
1187
1188
1189
func TestTransportExpect100Continue(t *testing.T) {
	run(t, testTransportExpect100Continue, []testMode{http1Mode})
}

// testTransportExpect100Continue sends PUT requests with
// "Expect: 100-continue" to several handlers. For each one it checks the
// response status and how many body bytes the client actually sent.
func testTransportExpect100Continue(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		switch req.URL.Path {
		case "/100":
			// Reads the body, so the client is expected to send it all.
			if _, err := io.Copy(io.Discard, req.Body); err != nil {
				t.Error("Failed to read Body", err)
			}
			rw.WriteHeader(StatusOK)
		case "/200":
			// Responds without reading the body; none should be sent.
			rw.WriteHeader(StatusOK)
		case "/500":
			rw.WriteHeader(StatusInternalServerError)
		case "/keepalive":
			// Writes a raw response on the hijacked connection and leaves
			// that connection open.
			_, bufrw, err := rw.(Hijacker).Hijack()
			if err != nil {
				// log.Fatal here would kill the whole test binary from a
				// server goroutine; report through t instead.
				t.Errorf("Hijack: %v", err)
				return
			}
			bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
			bufrw.WriteString("Content-Length: 0\r\n\r\n")
			bufrw.Flush()
		case "/timeout":
			// Never sends 100 Continue, but still reads the body. The table
			// below expects the client to send it after
			// ExpectContinueTimeout elapses.
			conn, bufrw, err := rw.(Hijacker).Hijack()
			if err != nil {
				t.Errorf("Hijack: %v", err)
				return
			}
			if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
				t.Error("Failed to read Body", err)
			}
			bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
			bufrw.Flush()
			conn.Close()
		}
	})).ts

	tests := []struct {
		path   string
		body   []byte
		sent   int
		status int
	}{
		{path: "/100", body: []byte("hello"), sent: 5, status: 200},
		{path: "/200", body: []byte("hello"), sent: 0, status: 200},
		{path: "/500", body: []byte("hello"), sent: 0, status: 500},
		{path: "/keepalive", body: []byte("hello"), sent: 0, status: 500},
		{path: "/timeout", body: []byte("hello"), sent: 5, status: 200},
	}

	c := ts.Client()
	for i, v := range tests {
		tr := &Transport{
			ExpectContinueTimeout: 2 * time.Second,
		}
		defer tr.CloseIdleConnections()
		c.Transport = tr
		body := bytes.NewReader(v.body)
		req, err := NewRequest("PUT", ts.URL+v.path, body)
		if err != nil {
			t.Fatal(err)
		}
		req.Header.Set("Expect", "100-continue")
		req.ContentLength = int64(len(v.body))

		resp, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		resp.Body.Close()

		// Bytes consumed from the reader are the bytes the client sent.
		sent := len(v.body) - body.Len()
		if v.status != resp.StatusCode {
			t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
		}
		if v.sent != sent {
			t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
		}
	}
}
1277
1278 func TestSOCKS5Proxy(t *testing.T) {
1279 run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1280 }
1281 func testSOCKS5Proxy(t *testing.T, mode testMode) {
1282 ch := make(chan string, 1)
1283 l := newLocalListener(t)
1284 defer l.Close()
1285 defer close(ch)
1286 proxy := func(t *testing.T) {
1287 s, err := l.Accept()
1288 if err != nil {
1289 t.Errorf("socks5 proxy Accept(): %v", err)
1290 return
1291 }
1292 defer s.Close()
1293 var buf [22]byte
1294 if _, err := io.ReadFull(s, buf[:3]); err != nil {
1295 t.Errorf("socks5 proxy initial read: %v", err)
1296 return
1297 }
1298 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1299 t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1300 return
1301 }
1302 if _, err := s.Write([]byte{5, 0}); err != nil {
1303 t.Errorf("socks5 proxy initial write: %v", err)
1304 return
1305 }
1306 if _, err := io.ReadFull(s, buf[:4]); err != nil {
1307 t.Errorf("socks5 proxy second read: %v", err)
1308 return
1309 }
1310 if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1311 t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1312 return
1313 }
1314 var ipLen int
1315 switch buf[3] {
1316 case 1:
1317 ipLen = net.IPv4len
1318 case 4:
1319 ipLen = net.IPv6len
1320 default:
1321 t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1322 return
1323 }
1324 if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1325 t.Errorf("socks5 proxy address read: %v", err)
1326 return
1327 }
1328 ip := net.IP(buf[4 : ipLen+4])
1329 port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1330 copy(buf[:3], []byte{5, 0, 0})
1331 if _, err := s.Write(buf[:ipLen+6]); err != nil {
1332 t.Errorf("socks5 proxy connect write: %v", err)
1333 return
1334 }
1335 ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1336
1337
1338 targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1339 targetConn, err := net.Dial("tcp", targetHost)
1340 if err != nil {
1341 t.Errorf("net.Dial failed")
1342 return
1343 }
1344 go io.Copy(targetConn, s)
1345 io.Copy(s, targetConn)
1346 targetConn.Close()
1347 }
1348
1349 pu, err := url.Parse("socks5://" + l.Addr().String())
1350 if err != nil {
1351 t.Fatal(err)
1352 }
1353
1354 sentinelHeader := "X-Sentinel"
1355 sentinelValue := "12345"
1356 h := HandlerFunc(func(w ResponseWriter, r *Request) {
1357 w.Header().Set(sentinelHeader, sentinelValue)
1358 })
1359 for _, useTLS := range []bool{false, true} {
1360 t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1361 ts := newClientServerTest(t, mode, h).ts
1362 go proxy(t)
1363 c := ts.Client()
1364 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1365 r, err := c.Head(ts.URL)
1366 if err != nil {
1367 t.Fatal(err)
1368 }
1369 if r.Header.Get(sentinelHeader) != sentinelValue {
1370 t.Errorf("Failed to retrieve sentinel value")
1371 }
1372 got := <-ch
1373 ts.Close()
1374 tsu, err := url.Parse(ts.URL)
1375 if err != nil {
1376 t.Fatal(err)
1377 }
1378 want := "proxy for " + tsu.Host
1379 if got != want {
1380 t.Errorf("got %q, want %q", got, want)
1381 }
1382 })
1383 }
1384 }
1385
// TestTransportProxy checks proxy routing for each combination of http/https
// site and http/https proxy. For an https site, the proxy must receive a
// CONNECT for the site's host and the site must then receive "HEAD /" through
// the tunnel. For an http site, the proxy itself must receive the HEAD,
// addressed by the site's absolute URL.
func TestTransportProxy(t *testing.T) {
	defer afterTest(t)
	testCases := []struct{ siteMode, proxyMode testMode }{
		{http1Mode, http1Mode},
		{http1Mode, https1Mode},
		{https1Mode, http1Mode},
		{https1Mode, https1Mode},
	}
	for _, testCase := range testCases {
		siteMode := testCase.siteMode
		proxyMode := testCase.proxyMode
		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
			// h1 is the destination site; it records each request it receives.
			siteCh := make(chan *Request, 1)
			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
				siteCh <- r
			})
			// h2 is the proxy; it records each request it receives.
			proxyCh := make(chan *Request, 1)
			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
				proxyCh <- r
				// For CONNECT, act as a tunneling proxy: hijack the client
				// connection, dial r.URL.Host, reply 200 OK and relay bytes.
				if r.Method == "CONNECT" {
					hijacker, ok := w.(Hijacker)
					if !ok {
						t.Errorf("hijack not allowed")
						return
					}
					clientConn, _, err := hijacker.Hijack()
					if err != nil {
						t.Errorf("hijacking failed")
						return
					}
					res := &Response{
						StatusCode: StatusOK,
						Proto:      "HTTP/1.1",
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(Header),
					}

					// Dial the target before acknowledging the CONNECT.
					targetConn, err := net.Dial("tcp", r.URL.Host)
					if err != nil {
						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
						return
					}

					if err := res.Write(clientConn); err != nil {
						t.Errorf("Writing 200 OK failed: %v", err)
						return
					}

					// Relay in both directions; these goroutines are not joined.
					go io.Copy(targetConn, clientConn)
					go func() {
						io.Copy(clientConn, targetConn)
						targetConn.Close()
					}()
				}
			})
			ts := newClientServerTest(t, siteMode, h1).ts
			proxy := newClientServerTest(t, proxyMode, h2).ts

			pu, err := url.Parse(proxy.URL)
			if err != nil {
				t.Fatal(err)
			}

			// Use the proxy server's client by default and the site's client
			// when the site is https. NOTE(review): presumably this picks the
			// client whose TLS config trusts the server the handshake is
			// actually made with; confirm.
			c := proxy.Client()
			if siteMode == https1Mode {
				c = ts.Client()
			}

			c.Transport.(*Transport).Proxy = ProxyURL(pu)
			if _, err := c.Head(ts.URL); err != nil {
				t.Error(err)
			}
			// The first request the proxy saw.
			got := <-proxyCh
			c.Transport.(*Transport).CloseIdleConnections()
			ts.Close()
			proxy.Close()
			if siteMode == https1Mode {
				// An https site must be reached through a CONNECT tunnel to its host.
				if got.Method != "CONNECT" {
					t.Errorf("Wrong method for secure proxying: %q", got.Method)
				}
				gotHost := got.URL.Host
				pu, err := url.Parse(ts.URL)
				if err != nil {
					t.Fatal("Invalid site URL")
				}
				if wantHost := pu.Host; gotHost != wantHost {
					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
				}

				// The tunneled request must arrive at the site as HEAD /.
				next := <-siteCh
				if next.Method != "HEAD" {
					t.Errorf("Wrong method at destination: %s", next.Method)
				}
				if nextURL := next.URL.String(); nextURL != "/" {
					t.Errorf("Wrong URL at destination: %s", nextURL)
				}
			} else {
				// An http site's request goes to the proxy directly, with an absolute URL.
				if got.Method != "HEAD" {
					t.Errorf("Wrong method for destination: %q", got.Method)
				}
				gotURL := got.URL.String()
				wantURL := ts.URL + "/"
				if gotURL != wantURL {
					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
				}
			}
		})
	}
}
1502
1503 func TestOnProxyConnectResponse(t *testing.T) {
1504
1505 var tcases = []struct {
1506 proxyStatusCode int
1507 err error
1508 }{
1509 {
1510 StatusOK,
1511 nil,
1512 },
1513 {
1514 StatusForbidden,
1515 errors.New("403"),
1516 },
1517 }
1518 for _, tcase := range tcases {
1519 h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1520
1521 })
1522
1523 h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1524
1525 if r.Method == "CONNECT" {
1526 if tcase.proxyStatusCode != StatusOK {
1527 w.WriteHeader(tcase.proxyStatusCode)
1528 return
1529 }
1530 hijacker, ok := w.(Hijacker)
1531 if !ok {
1532 t.Errorf("hijack not allowed")
1533 return
1534 }
1535 clientConn, _, err := hijacker.Hijack()
1536 if err != nil {
1537 t.Errorf("hijacking failed")
1538 return
1539 }
1540 res := &Response{
1541 StatusCode: StatusOK,
1542 Proto: "HTTP/1.1",
1543 ProtoMajor: 1,
1544 ProtoMinor: 1,
1545 Header: make(Header),
1546 }
1547
1548 targetConn, err := net.Dial("tcp", r.URL.Host)
1549 if err != nil {
1550 t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1551 return
1552 }
1553
1554 if err := res.Write(clientConn); err != nil {
1555 t.Errorf("Writing 200 OK failed: %v", err)
1556 return
1557 }
1558
1559 go io.Copy(targetConn, clientConn)
1560 go func() {
1561 io.Copy(clientConn, targetConn)
1562 targetConn.Close()
1563 }()
1564 }
1565 })
1566 ts := newClientServerTest(t, https1Mode, h1).ts
1567 proxy := newClientServerTest(t, https1Mode, h2).ts
1568
1569 pu, err := url.Parse(proxy.URL)
1570 if err != nil {
1571 t.Fatal(err)
1572 }
1573
1574 c := proxy.Client()
1575
1576 var (
1577 dials atomic.Int32
1578 closes atomic.Int32
1579 )
1580 c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1581 conn, err := net.Dial(network, addr)
1582 if err != nil {
1583 return nil, err
1584 }
1585 dials.Add(1)
1586 return noteCloseConn{
1587 Conn: conn,
1588 closeFunc: func() {
1589 closes.Add(1)
1590 },
1591 }, nil
1592 }
1593
1594 c.Transport.(*Transport).Proxy = ProxyURL(pu)
1595 c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1596 if proxyURL.String() != pu.String() {
1597 t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1598 }
1599
1600 if "https://"+connectReq.URL.String() != ts.URL {
1601 t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1602 }
1603 return tcase.err
1604 }
1605 wantCloses := int32(0)
1606 if _, err := c.Head(ts.URL); err != nil {
1607 wantCloses = 1
1608 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1609 t.Errorf("got %v, want %v", err, tcase.err)
1610 }
1611 } else {
1612 if tcase.err != nil {
1613 t.Errorf("got %v, want nil", err)
1614 }
1615 }
1616 if got, want := dials.Load(), int32(1); got != want {
1617 t.Errorf("got %v dials, want %v", got, want)
1618 }
1619
1620 if got, want := closes.Load(), wantCloses; got != want {
1621 t.Errorf("got %v closes, want %v", got, want)
1622 }
1623 }
1624 }
1625
1626
1627
// TestTransportProxyHTTPSConnectLeak checks that when the proxy CONNECT
// handshake is cancelled before the proxy replies, the client closes its
// connection to the proxy: the fake proxy must then read io.EOF.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
	cancelc := make(chan struct{})
	// Replace the proxy-connect timeout context with one that is cancelled
	// as soon as cancelc is closed (or when its parent is done).
	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
		ctx, cancel := context.WithCancel(ctx)
		go func() {
			select {
			case <-cancelc:
			case <-ctx.Done():
			}
			cancel()
		}()
		return ctx, cancel
	})

	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()
	listenerDone := make(chan struct{})
	// Fake proxy: accept one connection, read the CONNECT request, and
	// never answer it.
	go func() {
		defer close(listenerDone)
		c, err := ln.Accept()
		if err != nil {
			t.Errorf("Accept: %v", err)
			return
		}
		defer c.Close()

		br := bufio.NewReader(c)
		cr, err := ReadRequest(br)
		if err != nil {
			t.Errorf("proxy server failed to read CONNECT request")
			return
		}
		if cr.Method != "CONNECT" {
			t.Errorf("unexpected method %q", cr.Method)
			return
		}

		// Cancel the client's CONNECT while it is still waiting for a
		// reply, then expect the client to close its end of the
		// connection.
		close(cancelc)
		var buf [1]byte
		_, err = br.Read(buf[:])
		if err != io.EOF {
			t.Errorf("proxy server Read err = %v; want EOF", err)
		}
		return
	}()

	c := &Client{
		Transport: &Transport{
			Proxy: func(*Request) (*url.URL, error) {
				return url.Parse("http://" + ln.Addr().String())
			},
		},
	}
	// An https URL, so the fake proxy above receives a CONNECT for it.
	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
	if err != nil {
		t.Fatal(err)
	}
	_, err = c.Do(req)
	if err == nil {
		t.Errorf("unexpected Get success")
	}

	// Wait until the fake proxy has observed the client's close (or failed).
	<-listenerDone
}
1700
1701
1702 func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1703 defer afterTest(t)
1704
1705 var errDial = errors.New("some dial error")
1706
1707 tr := &Transport{
1708 Proxy: func(*Request) (*url.URL, error) {
1709 return url.Parse("http://proxy.fake.tld/")
1710 },
1711 Dial: func(string, string) (net.Conn, error) {
1712 return nil, errDial
1713 },
1714 }
1715 defer tr.CloseIdleConnections()
1716
1717 c := &Client{Transport: tr}
1718 req, _ := NewRequest("GET", "http://fake.tld", nil)
1719 res, err := c.Do(req)
1720 if err == nil {
1721 res.Body.Close()
1722 t.Fatal("wanted a non-nil error")
1723 }
1724
1725 uerr, ok := err.(*url.Error)
1726 if !ok {
1727 t.Fatalf("got %T, want *url.Error", err)
1728 }
1729 oe, ok := uerr.Err.(*net.OpError)
1730 if !ok {
1731 t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
1732 }
1733 want := &net.OpError{
1734 Op: "proxyconnect",
1735 Net: "tcp",
1736 Err: errDial,
1737 }
1738 if !reflect.DeepEqual(oe, want) {
1739 t.Errorf("Got error %#v; want %#v", oe, want)
1740 }
1741 }
1742
1743
1744
1745
1746
1747 func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1748 run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1749 }
1750 func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1751 proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1752 defer proxy.Close()
1753 c := proxy.Client()
1754
1755 tr := c.Transport.(*Transport)
1756 tr.Proxy = func(*Request) (*url.URL, error) {
1757 u, _ := url.Parse(proxy.URL)
1758 u.User = url.UserPassword("aladdin", "opensesame")
1759 return u, nil
1760 }
1761 h := tr.ProxyConnectHeader
1762 if h == nil {
1763 h = make(Header)
1764 }
1765 tr.ProxyConnectHeader = h.Clone()
1766
1767 req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1768 if err != nil {
1769 t.Fatal(err)
1770 }
1771 _, err = c.Do(req)
1772 if err == nil {
1773 t.Errorf("unexpected Get success")
1774 }
1775
1776 if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1777 t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1778 }
1779 }
1780
1781
1782
1783
1784
1785 func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1786 func testTransportGzipRecursive(t *testing.T, mode testMode) {
1787 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1788 w.Header().Set("Content-Encoding", "gzip")
1789 w.Write(rgz)
1790 })).ts
1791
1792 c := ts.Client()
1793 res, err := c.Get(ts.URL)
1794 if err != nil {
1795 t.Fatal(err)
1796 }
1797 body, err := io.ReadAll(res.Body)
1798 if err != nil {
1799 t.Fatal(err)
1800 }
1801 if !bytes.Equal(body, rgz) {
1802 t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1803 body, rgz)
1804 }
1805 if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1806 t.Fatalf("Content-Encoding = %q; want %q", g, e)
1807 }
1808 }
1809
1810
1811
1812 func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1813 func testTransportGzipShort(t *testing.T, mode testMode) {
1814 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1815 w.Header().Set("Content-Encoding", "gzip")
1816 w.Write([]byte{0x1f, 0x8b})
1817 })).ts
1818
1819 c := ts.Client()
1820 res, err := c.Get(ts.URL)
1821 if err != nil {
1822 t.Fatal(err)
1823 }
1824 defer res.Body.Close()
1825 _, err = io.ReadAll(res.Body)
1826 if err == nil {
1827 t.Fatal("Expect an error from reading a body.")
1828 }
1829 if err != io.ErrUnexpectedEOF {
1830 t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1831 }
1832 }
1833
1834
// waitNumGoroutine polls runtime.NumGoroutine until it reports at most nmax
// goroutines, sleeping 50ms and forcing a GC before each re-check. It gives
// up after ten re-checks and returns the last count it observed.
func waitNumGoroutine(nmax int) int {
	count := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if count <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		count = runtime.NumGoroutine()
	}
	return count
}
1844
1845
// TestTransportPersistConnLeak issues numReq concurrent GETs that all park in
// the handler, releases them together, and closes idle connections. It fails
// if the goroutine count ends more than 5 above its starting value.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}

	// Each handler signals arrival on gotReqCh, then waits until
	// unblockCh is closed.
	const numReq = 25
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline goroutine count before any request is issued.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait until every request has either reached the handler or failed
	// on the client side.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// The request is parked in the handler.
		case <-failed:
			// The client already logged the error. NOTE(review): this
			// failure is tolerated, not reported as a test error; confirm
			// that is intended.
		}
	}

	// Peak goroutine count while handlers are parked (only logged on failure).
	nhigh := runtime.NumGoroutine()

	// Release all parked handlers at once.
	close(unblockCh)

	// Wait for every client goroutine to receive its result.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	// Poll, with GC, until at most 5 goroutines above the baseline remain.
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// Tolerate growth of up to 5 goroutines.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
1917
1918
1919
1920 func TestTransportPersistConnLeakShortBody(t *testing.T) {
1921 run(t, testTransportPersistConnLeakShortBody, testNotParallel)
1922 }
1923 func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
1924 if mode == http2Mode {
1925 t.Skip("flaky in HTTP/2")
1926 }
1927
1928
1929 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1930 })).ts
1931 c := ts.Client()
1932 tr := c.Transport.(*Transport)
1933
1934 n0 := runtime.NumGoroutine()
1935 body := []byte("Hello")
1936 for i := 0; i < 20; i++ {
1937 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
1938 if err != nil {
1939 t.Fatal(err)
1940 }
1941 req.ContentLength = int64(len(body) - 2)
1942 _, err = c.Do(req)
1943 if err == nil {
1944 t.Fatal("Expect an error from writing too long of a body.")
1945 }
1946 }
1947 nhigh := runtime.NumGoroutine()
1948 tr.CloseIdleConnections()
1949 nfinal := waitNumGoroutine(n0 + 5)
1950
1951 growth := nfinal - n0
1952
1953
1954
1955 t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
1956 if int(growth) > 5 {
1957 t.Error("too many new goroutines")
1958 }
1959 }
1960
1961
// countedConn wraps a net.Conn in its own allocation so a finalizer can
// report when the wrapper becomes unreachable.
type countedConn struct {
	net.Conn
}

// countingDialer dials via an embedded net.Dialer. total counts every
// connection returned; live counts those whose countedConn wrapper has not
// yet been finalized by the garbage collector.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials with d.dialer, wraps the result in a countedConn,
// increments both counters, and registers a finalizer that decrements live.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	raw, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &countedConn{Conn: raw}

	d.mu.Lock()
	d.total++
	d.live++
	runtime.SetFinalizer(wrapped, d.decrement)
	d.mu.Unlock()

	return wrapped, nil
}

// decrement is the finalizer for countedConn wrappers: it decrements live.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	d.live--
	d.mu.Unlock()
}

// Read returns a consistent snapshot of the total and live counters.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	total, live = d.total, d.live
	d.mu.Unlock()
	return total, live
}
2002
2003 func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
2004 run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
2005 }
2006 func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
2007 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2008
2009 conn, _, err := w.(Hijacker).Hijack()
2010 if err != nil {
2011 t.Errorf("Hijack failed unexpectedly: %v", err)
2012 return
2013 }
2014 conn.Close()
2015 })).ts
2016
2017 var d countingDialer
2018 c := ts.Client()
2019 c.Transport.(*Transport).DialContext = d.DialContext
2020
2021 body := []byte("Hello")
2022 for i := 0; ; i++ {
2023 total, live := d.Read()
2024 if live < total {
2025 break
2026 }
2027 if i >= 1<<12 {
2028 t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
2029 }
2030
2031 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2032 if err != nil {
2033 t.Fatal(err)
2034 }
2035 _, err = c.Do(req)
2036 if err == nil {
2037 t.Fatal("expected broken connection")
2038 }
2039
2040 runtime.GC()
2041 }
2042 }
2043
// countedContext wraps a context.Context in its own allocation so a
// finalizer can report when the wrapper becomes unreachable.
type countedContext struct {
	context.Context
}

// contextCounter counts contexts returned by Track whose wrappers have not
// yet been finalized by the garbage collector.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a countedContext, increments live, and registers a
// finalizer that decrements it. The wrapper is returned.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	cc.mu.Unlock()
	return wrapped
}

// decrement is the finalizer for countedContext wrappers: it decrements live.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read returns the current number of tracked, not-yet-finalized contexts.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
2074
2075 func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2076 run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2077 }
2078 func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2079 if mode == http2Mode {
2080 t.Skip("https://go.dev/issue/56021")
2081 }
2082
2083 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2084 runtime.Gosched()
2085 w.WriteHeader(StatusOK)
2086 })).ts
2087
2088 c := ts.Client()
2089 c.Transport.(*Transport).MaxConnsPerHost = 1
2090
2091 ctx := context.Background()
2092 body := []byte("Hello")
2093 doPosts := func(cc *contextCounter) {
2094 var wg sync.WaitGroup
2095 for n := 64; n > 0; n-- {
2096 wg.Add(1)
2097 go func() {
2098 defer wg.Done()
2099
2100 ctx := cc.Track(ctx)
2101 req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2102 if err != nil {
2103 t.Error(err)
2104 }
2105
2106 _, err = c.Do(req.WithContext(ctx))
2107 if err != nil {
2108 t.Errorf("Do failed with error: %v", err)
2109 }
2110 }()
2111 }
2112 wg.Wait()
2113 }
2114
2115 var initialCC contextCounter
2116 doPosts(&initialCC)
2117
2118
2119
2120
2121 var flushCC contextCounter
2122 for i := 0; ; i++ {
2123 live := initialCC.Read()
2124 if live == 0 {
2125 break
2126 }
2127 if i >= 100 {
2128 t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2129 }
2130 doPosts(&flushCC)
2131 runtime.GC()
2132 }
2133 }
2134
2135
2136 func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2137 func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2138 var tr *Transport
2139
2140 unblockCh := make(chan bool, 1)
2141 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2142 <-unblockCh
2143 tr.CloseIdleConnections()
2144 })).ts
2145 c := ts.Client()
2146 tr = c.Transport.(*Transport)
2147
2148 didreq := make(chan bool)
2149 go func() {
2150 res, err := c.Get(ts.URL)
2151 if err != nil {
2152 t.Error(err)
2153 } else {
2154 res.Body.Close()
2155 }
2156 didreq <- true
2157 }()
2158 unblockCh <- true
2159 <-didreq
2160 }
2161
2162
2163
2164
2165
2166 func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2167 func testIssue3644(t *testing.T, mode testMode) {
2168 const numFoos = 5000
2169 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2170 w.Header().Set("Connection", "close")
2171 for i := 0; i < numFoos; i++ {
2172 w.Write([]byte("foo "))
2173 }
2174 })).ts
2175 c := ts.Client()
2176 res, err := c.Get(ts.URL)
2177 if err != nil {
2178 t.Fatal(err)
2179 }
2180 defer res.Body.Close()
2181 bs, err := io.ReadAll(res.Body)
2182 if err != nil {
2183 t.Fatal(err)
2184 }
2185 if len(bs) != numFoos*len("foo ") {
2186 t.Errorf("unexpected response length")
2187 }
2188 }
2189
2190
2191
2192 func TestIssue3595(t *testing.T) {
2193
2194 run(t, testIssue3595, testNotParallel)
2195 }
2196 func testIssue3595(t *testing.T, mode testMode) {
2197 runTimeSensitiveTest(t, []time.Duration{
2198 1 * time.Millisecond,
2199 5 * time.Millisecond,
2200 10 * time.Millisecond,
2201 50 * time.Millisecond,
2202 100 * time.Millisecond,
2203 500 * time.Millisecond,
2204 time.Second,
2205 5 * time.Second,
2206 }, func(t *testing.T, timeout time.Duration) error {
2207 SetRSTAvoidanceDelay(t, timeout)
2208 t.Logf("set RST avoidance delay to %v", timeout)
2209
2210 const deniedMsg = "sorry, denied."
2211 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2212 Error(w, deniedMsg, StatusUnauthorized)
2213 }))
2214
2215
2216 defer cst.close()
2217 ts := cst.ts
2218 c := ts.Client()
2219
2220 res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2221 if err != nil {
2222 return fmt.Errorf("Post: %v", err)
2223 }
2224 got, err := io.ReadAll(res.Body)
2225 if err != nil {
2226 return fmt.Errorf("Body ReadAll: %v", err)
2227 }
2228 t.Logf("server response:\n%s", got)
2229 if !strings.Contains(string(got), deniedMsg) {
2230
2231
2232 t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2233 }
2234 return nil
2235 })
2236 }
2237
2238
2239
2240 func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2241 func testChunkedNoContent(t *testing.T, mode testMode) {
2242 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2243 w.WriteHeader(StatusNoContent)
2244 })).ts
2245
2246 c := ts.Client()
2247 for _, closeBody := range []bool{true, false} {
2248 const n = 4
2249 for i := 1; i <= n; i++ {
2250 res, err := c.Get(ts.URL)
2251 if err != nil {
2252 t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2253 } else {
2254 if closeBody {
2255 res.Body.Close()
2256 }
2257 }
2258 }
2259 }
2260 }
2261
// TestTransportConcurrency sends numReqs GETs through one client from
// 2*maxProcs worker goroutines. It checks that every response body echoes
// the request's "echo" query parameter.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Not parallel: changes GOMAXPROCS and installs pending-dial hooks.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// Set GOMAXPROCS for the test and restore the previous value on return.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// One count per request; each worker calls wg.Done once per request.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Also count pending dials in wg, so wg.Wait below does not return while
	// a dial is still outstanding. NOTE(review): this relies on the first hook
	// running when a pending dial starts and the second when it ends; confirm
	// against SetPendingDialHooks.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs on return ends the workers' range loops.
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// Known NetBSD issue (see link below): log
						// instead of failing the test.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2324
2325 func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2326 func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2327 mux := NewServeMux()
2328 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2329 io.Copy(w, neverEnding('a'))
2330 })
2331 ts := newClientServerTest(t, mode, mux).ts
2332
2333 connc := make(chan net.Conn, 1)
2334 c := ts.Client()
2335 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2336 conn, err := net.Dial(n, addr)
2337 if err != nil {
2338 return nil, err
2339 }
2340 select {
2341 case connc <- conn:
2342 default:
2343 }
2344 return conn, nil
2345 }
2346
2347 res, err := c.Get(ts.URL + "/get")
2348 if err != nil {
2349 t.Fatalf("Error issuing GET: %v", err)
2350 }
2351 defer res.Body.Close()
2352
2353 conn := <-connc
2354 conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2355 _, err = io.Copy(io.Discard, res.Body)
2356 if err == nil {
2357 t.Errorf("Unexpected successful copy")
2358 }
2359 }
2360
2361 func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
2362 run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
2363 }
2364 func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
2365 const debug = false
2366 mux := NewServeMux()
2367 mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2368 io.Copy(w, neverEnding('a'))
2369 })
2370 mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
2371 defer r.Body.Close()
2372 io.Copy(io.Discard, r.Body)
2373 })
2374 ts := newClientServerTest(t, mode, mux).ts
2375 timeout := 100 * time.Millisecond
2376
2377 c := ts.Client()
2378 c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2379 conn, err := net.Dial(n, addr)
2380 if err != nil {
2381 return nil, err
2382 }
2383 conn.SetDeadline(time.Now().Add(timeout))
2384 if debug {
2385 conn = NewLoggingConn("client", conn)
2386 }
2387 return conn, nil
2388 }
2389
2390 getFailed := false
2391 nRuns := 5
2392 if testing.Short() {
2393 nRuns = 1
2394 }
2395 for i := 0; i < nRuns; i++ {
2396 if debug {
2397 println("run", i+1, "of", nRuns)
2398 }
2399 sres, err := c.Get(ts.URL + "/get")
2400 if err != nil {
2401 if !getFailed {
2402
2403 getFailed = true
2404 t.Logf("increasing timeout")
2405 i--
2406 timeout *= 10
2407 continue
2408 }
2409 t.Errorf("Error issuing GET: %v", err)
2410 break
2411 }
2412 req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
2413 _, err = c.Do(req)
2414 if err == nil {
2415 sres.Body.Close()
2416 t.Errorf("Unexpected successful PUT")
2417 break
2418 }
2419 sres.Body.Close()
2420 }
2421 if debug {
2422 println("tests complete; waiting for handlers to finish")
2423 }
2424 ts.Close()
2425 }
2426
2427 func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2428 func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2429 if testing.Short() {
2430 t.Skip("skipping timeout test in -short mode")
2431 }
2432
2433 timeout := 2 * time.Millisecond
2434 retry := true
2435 for retry && !t.Failed() {
2436 var srvWG sync.WaitGroup
2437 inHandler := make(chan bool, 1)
2438 mux := NewServeMux()
2439 mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2440 inHandler <- true
2441 srvWG.Done()
2442 })
2443 mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2444 inHandler <- true
2445 <-r.Context().Done()
2446 srvWG.Done()
2447 })
2448 ts := newClientServerTest(t, mode, mux).ts
2449
2450 c := ts.Client()
2451 c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2452
2453 retry = false
2454 srvWG.Add(3)
2455 tests := []struct {
2456 path string
2457 wantTimeout bool
2458 }{
2459 {path: "/fast"},
2460 {path: "/slow", wantTimeout: true},
2461 {path: "/fast"},
2462 }
2463 for i, tt := range tests {
2464 req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2465 req = req.WithT(t)
2466 res, err := c.Do(req)
2467 <-inHandler
2468 if err != nil {
2469 uerr, ok := err.(*url.Error)
2470 if !ok {
2471 t.Errorf("error is not a url.Error; got: %#v", err)
2472 continue
2473 }
2474 nerr, ok := uerr.Err.(net.Error)
2475 if !ok {
2476 t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2477 continue
2478 }
2479 if !nerr.Timeout() {
2480 t.Errorf("want timeout error; got: %q", nerr)
2481 continue
2482 }
2483 if !tt.wantTimeout {
2484 if !retry {
2485
2486 t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2487 timeout *= 2
2488 retry = true
2489 }
2490 }
2491 if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2492 t.Errorf("%d. unexpected error: %v", i, err)
2493 }
2494 continue
2495 }
2496 if tt.wantTimeout {
2497 t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2498 continue
2499 }
2500 if res.StatusCode != 200 {
2501 t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2502 }
2503 }
2504
2505 srvWG.Wait()
2506 ts.Close()
2507 }
2508 }
2509
2510 func TestTransportCancelRequest(t *testing.T) {
2511 run(t, testTransportCancelRequest, []testMode{http1Mode})
2512 }
2513 func testTransportCancelRequest(t *testing.T, mode testMode) {
2514 if testing.Short() {
2515 t.Skip("skipping test in -short mode")
2516 }
2517
2518 const msg = "Hello"
2519 unblockc := make(chan bool)
2520 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2521 io.WriteString(w, msg)
2522 w.(Flusher).Flush()
2523 <-unblockc
2524 })).ts
2525 defer close(unblockc)
2526
2527 c := ts.Client()
2528 tr := c.Transport.(*Transport)
2529
2530 req, _ := NewRequest("GET", ts.URL, nil)
2531 res, err := c.Do(req)
2532 if err != nil {
2533 t.Fatal(err)
2534 }
2535 body := make([]byte, len(msg))
2536 n, _ := io.ReadFull(res.Body, body)
2537 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2538 t.Errorf("Body = %q; want %q", body[:n], msg)
2539 }
2540 tr.CancelRequest(req)
2541
2542 tail, err := io.ReadAll(res.Body)
2543 res.Body.Close()
2544 if err != ExportErrRequestCanceled {
2545 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2546 } else if len(tail) > 0 {
2547 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2548 }
2549
2550
2551
2552 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2553 n := tr.NumPendingRequestsForTesting()
2554 if n > 0 {
2555 if d > 0 {
2556 t.Logf("pending requests = %d after %v (want 0)", n, d)
2557 }
2558 return false
2559 }
2560 return true
2561 })
2562 }
2563
// testTransportCancelRequestInDo runs Client.Do in a goroutine against a
// handler that blocks on unblockc. It then calls Transport.CancelRequest
// repeatedly until that Do call returns. body, which may be nil, is used as
// the request body.
func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	unblockc := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		<-unblockc
	})).ts
	defer close(unblockc)

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// donec is closed once the goroutine's Do call has returned.
	donec := make(chan bool)
	req, _ := NewRequest("GET", ts.URL, body)
	go func() {
		defer close(donec)
		c.Do(req)
	}()

	// This send completes only after the handler receives it, so the request
	// has reached the server by the time we start canceling.
	unblockc <- true
	// Each poll cancels again; the condition holds once Do has returned.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		tr.CancelRequest(req)
		select {
		case <-donec:
			return true
		default:
			if d > 0 {
				t.Logf("Do of canceled request has not returned after %v", d)
			}
			return false
		}
	})
}
2598
2599 func TestTransportCancelRequestInDo(t *testing.T) {
2600 run(t, func(t *testing.T, mode testMode) {
2601 testTransportCancelRequestInDo(t, mode, nil)
2602 }, []testMode{http1Mode})
2603 }
2604
2605 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2606 run(t, func(t *testing.T, mode testMode) {
2607 testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
2608 }, []testMode{http1Mode})
2609 }
2610
// TestTransportCancelRequestInDial checks that canceling a request while its
// connection dial is blocked makes Client.Do return the "request canceled
// while waiting for connection" error. Events are written to eventLog in the
// order they happen and compared against the expected sequence.
func TestTransportCancelRequestInDial(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	unblockDial := make(chan bool)
	defer close(unblockDial)

	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			// Handshake with the test goroutine. A false value means the test
			// exited. NOTE(review): nothing visible here closes inDial or
			// sends false on it — confirm this branch is reachable.
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			// Keep the dial blocked until the test returns.
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get = %v", err)
		gotres <- true
	}()

	// Wait until the transport is inside Dial.
	inDial <- true

	eventLog.Printf("canceling")
	// CancelRequest is deliberately called twice, presumably to check that a
	// repeated cancel is harmless.
	tr.CancelRequest(req)
	tr.CancelRequest(req)

	if d, ok := t.Deadline(); ok {
		// If Do never returns, panic shortly before the test deadline and
		// include the recorded events so the hang can be diagnosed.
		timeout := time.Until(d) * 19 / 20
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2668
2669 func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
2670 func testCancelRequestWithChannel(t *testing.T, mode testMode) {
2671 if testing.Short() {
2672 t.Skip("skipping test in -short mode")
2673 }
2674
2675 const msg = "Hello"
2676 unblockc := make(chan struct{})
2677 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2678 io.WriteString(w, msg)
2679 w.(Flusher).Flush()
2680 <-unblockc
2681 })).ts
2682 defer close(unblockc)
2683
2684 c := ts.Client()
2685 tr := c.Transport.(*Transport)
2686
2687 req, _ := NewRequest("GET", ts.URL, nil)
2688 cancel := make(chan struct{})
2689 req.Cancel = cancel
2690
2691 res, err := c.Do(req)
2692 if err != nil {
2693 t.Fatal(err)
2694 }
2695 body := make([]byte, len(msg))
2696 n, _ := io.ReadFull(res.Body, body)
2697 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2698 t.Errorf("Body = %q; want %q", body[:n], msg)
2699 }
2700 close(cancel)
2701
2702 tail, err := io.ReadAll(res.Body)
2703 res.Body.Close()
2704 if err != ExportErrRequestCanceled {
2705 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2706 } else if len(tail) > 0 {
2707 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2708 }
2709
2710
2711
2712 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2713 n := tr.NumPendingRequestsForTesting()
2714 if n > 0 {
2715 if d > 0 {
2716 t.Logf("pending requests = %d after %v (want 0)", n, d)
2717 }
2718 return false
2719 }
2720 return true
2721 })
2722 }
2723
2724
2725 func TestCancelRequestWithBodyWithChannel(t *testing.T) {
2726 run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
2727 }
2728 func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
2729 if testing.Short() {
2730 t.Skip("skipping test in -short mode")
2731 }
2732
2733 const msg = "Hello"
2734 unblockc := make(chan struct{})
2735 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2736 io.WriteString(w, msg)
2737 w.(Flusher).Flush()
2738 <-unblockc
2739 })).ts
2740 defer close(unblockc)
2741
2742 c := ts.Client()
2743 tr := c.Transport.(*Transport)
2744
2745 req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
2746 cancel := make(chan struct{})
2747 req.Cancel = cancel
2748
2749 res, err := c.Do(req)
2750 if err != nil {
2751 t.Fatal(err)
2752 }
2753 body := make([]byte, len(msg))
2754 n, _ := io.ReadFull(res.Body, body)
2755 if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2756 t.Errorf("Body = %q; want %q", body[:n], msg)
2757 }
2758 close(cancel)
2759
2760 tail, err := io.ReadAll(res.Body)
2761 res.Body.Close()
2762 if err != ExportErrRequestCanceled {
2763 t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
2764 } else if len(tail) > 0 {
2765 t.Errorf("Spurious bytes from Body.Read: %q", tail)
2766 }
2767
2768
2769
2770 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2771 n := tr.NumPendingRequestsForTesting()
2772 if n > 0 {
2773 if d > 0 {
2774 t.Logf("pending requests = %d after %v (want 0)", n, d)
2775 }
2776 return false
2777 }
2778 return true
2779 })
2780 }
2781
2782 func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
2783 run(t, func(t *testing.T, mode testMode) {
2784 testCancelRequestWithChannelBeforeDo(t, mode, false)
2785 })
2786 }
2787 func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
2788 run(t, func(t *testing.T, mode testMode) {
2789 testCancelRequestWithChannelBeforeDo(t, mode, true)
2790 })
2791 }
2792 func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
2793 unblockc := make(chan bool)
2794 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2795 <-unblockc
2796 })).ts
2797 defer close(unblockc)
2798
2799 c := ts.Client()
2800
2801 req, _ := NewRequest("GET", ts.URL, nil)
2802 if withCtx {
2803 ctx, cancel := context.WithCancel(context.Background())
2804 cancel()
2805 req = req.WithContext(ctx)
2806 } else {
2807 ch := make(chan struct{})
2808 req.Cancel = ch
2809 close(ch)
2810 }
2811
2812 _, err := c.Do(req)
2813 if ue, ok := err.(*url.Error); ok {
2814 err = ue.Err
2815 }
2816 if withCtx {
2817 if err != context.Canceled {
2818 t.Errorf("Do error = %v; want %v", err, context.Canceled)
2819 }
2820 } else {
2821 if err == nil || !strings.Contains(err.Error(), "canceled") {
2822 t.Errorf("Do error = %v; want cancellation", err)
2823 }
2824 }
2825 }
2826
2827
// TestTransportCancelBeforeResponseHeaders checks that CancelRequest, called
// after the request has started being written but before any response
// arrives, makes RoundTrip return ExportErrRequestCanceled. The server is
// the far end of an in-memory net.Pipe that never writes a response.
func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
	defer afterTest(t)

	// Each dial creates a pipe and passes its server end to the test.
	serverConnCh := make(chan net.Conn, 1)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			cc, sc := net.Pipe()
			serverConnCh <- sc
			return cc, nil
		},
	}
	defer tr.CloseIdleConnections()
	errc := make(chan error, 1)
	req, _ := NewRequest("GET", "http://example.com/", nil)
	go func() {
		_, err := tr.RoundTrip(req)
		errc <- err
	}()

	sc := <-serverConnCh
	// Read the method from the server end, which confirms the request is
	// being written before we cancel.
	verb := make([]byte, 3)
	if _, err := io.ReadFull(sc, verb); err != nil {
		t.Errorf("Error reading HTTP verb from server: %v", err)
	}
	if string(verb) != "GET" {
		t.Errorf("server received %q; want GET", verb)
	}
	defer sc.Close()

	tr.CancelRequest(req)

	err := <-errc
	if err == nil {
		t.Fatalf("unexpected success from RoundTrip")
	}
	if err != ExportErrRequestCanceled {
		t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
	}
}
2867
2868
2869
2870
2871 func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
2872 func testTransportCloseResponseBody(t *testing.T, mode testMode) {
2873 writeErr := make(chan error, 1)
2874 msg := []byte("young\n")
2875 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2876 for {
2877 _, err := w.Write(msg)
2878 if err != nil {
2879 writeErr <- err
2880 return
2881 }
2882 w.(Flusher).Flush()
2883 }
2884 })).ts
2885
2886 c := ts.Client()
2887 tr := c.Transport.(*Transport)
2888
2889 req, _ := NewRequest("GET", ts.URL, nil)
2890 defer tr.CancelRequest(req)
2891
2892 res, err := c.Do(req)
2893 if err != nil {
2894 t.Fatal(err)
2895 }
2896
2897 const repeats = 3
2898 buf := make([]byte, len(msg)*repeats)
2899 want := bytes.Repeat(msg, repeats)
2900
2901 _, err = io.ReadFull(res.Body, buf)
2902 if err != nil {
2903 t.Fatal(err)
2904 }
2905 if !bytes.Equal(buf, want) {
2906 t.Fatalf("read %q; want %q", buf, want)
2907 }
2908
2909 if err := res.Body.Close(); err != nil {
2910 t.Errorf("Close = %v", err)
2911 }
2912
2913 if err := <-writeErr; err == nil {
2914 t.Errorf("expected non-nil write error")
2915 }
2916 }
2917
// fooProto is a network-free RoundTripper for a custom URL scheme. It answers
// every request with a 200 OK whose body echoes the request URL.
type fooProto struct{}

func (fooProto) RoundTrip(req *Request) (*Response, error) {
	body := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     make(Header),
		Body:       io.NopCloser(strings.NewReader(body)),
	}, nil
}
2929
2930 func TestTransportAltProto(t *testing.T) {
2931 defer afterTest(t)
2932 tr := &Transport{}
2933 c := &Client{Transport: tr}
2934 tr.RegisterProtocol("foo", fooProto{})
2935 res, err := c.Get("foo://bar.com/path")
2936 if err != nil {
2937 t.Fatal(err)
2938 }
2939 bodyb, err := io.ReadAll(res.Body)
2940 if err != nil {
2941 t.Fatal(err)
2942 }
2943 body := string(bodyb)
2944 if e := "You wanted foo://bar.com/path"; body != e {
2945 t.Errorf("got response %q, want %q", body, e)
2946 }
2947 }
2948
2949 func TestTransportNoHost(t *testing.T) {
2950 defer afterTest(t)
2951 tr := &Transport{}
2952 _, err := tr.RoundTrip(&Request{
2953 Header: make(Header),
2954 URL: &url.URL{
2955 Scheme: "http",
2956 },
2957 })
2958 want := "http: no Host in request URL"
2959 if got := fmt.Sprint(err); got != want {
2960 t.Errorf("error = %v; want %q", err, want)
2961 }
2962 }
2963
2964
// TestTransportEmptyMethod checks that a request with an empty Method is
// written on the wire as a GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	req.Method = ""
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if s := string(dump); !strings.Contains(s, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
2976
// TestTransportSocketLateBinding checks that a request is not bound early to a
// connection that is still being dialed. While /foo holds the only
// connection, /bar triggers a second dial that is kept blocked. /foo's
// connection is then freed, and /bar must be served on that same connection
// (same RemoteAddr) instead of waiting for the new dial.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate holds the /foo handler open, keeping its connection busy.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// Each dial waits for a value on dialGate: true dials for real, and
	// closing the channel fails the dial. While waiting, the dial keeps
	// offering true on dialing so the test can see a dial is in progress.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	// Let exactly one dial through, for /foo.
	dialGate <- true
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// Once /bar has started a (blocked) second dial, release /foo so
		// its connection becomes idle and can be used by /bar.
		if mode == http2Mode {
			// HTTP/2 can run /bar on the existing connection, so no second
			// dial is expected. Briefly watch for one; an unexpected dial
			// is reported as an error.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3056
3057
3058 func TestTransportReading100Continue(t *testing.T) {
3059 defer afterTest(t)
3060
3061 const numReqs = 5
3062 reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
3063 reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
3064
3065 send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
3066 defer w.Close()
3067 defer r.Close()
3068 br := bufio.NewReader(r)
3069 n := 0
3070 for {
3071 n++
3072 req, err := ReadRequest(br)
3073 if err == io.EOF {
3074 return
3075 }
3076 if err != nil {
3077 t.Error(err)
3078 return
3079 }
3080 slurp, err := io.ReadAll(req.Body)
3081 if err != nil {
3082 t.Errorf("Server request body slurp: %v", err)
3083 return
3084 }
3085 id := req.Header.Get("Request-Id")
3086 resCode := req.Header.Get("X-Want-Response-Code")
3087 if resCode == "" {
3088 resCode = "100 Continue"
3089 if string(slurp) != reqBody(n) {
3090 t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
3091 }
3092 }
3093 body := fmt.Sprintf("Response number %d", n)
3094 v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
3095 Date: Thu, 28 Feb 2013 17:55:41 GMT
3096
3097 HTTP/1.1 200 OK
3098 Content-Type: text/html
3099 Echo-Request-Id: %s
3100 Content-Length: %d
3101
3102 %s`, resCode, id, len(body), body), "\n", "\r\n", -1))
3103 w.Write(v)
3104 if id == reqID(numReqs) {
3105 return
3106 }
3107 }
3108
3109 }
3110
3111 tr := &Transport{
3112 Dial: func(n, addr string) (net.Conn, error) {
3113 sr, sw := io.Pipe()
3114 cr, cw := io.Pipe()
3115 conn := &rwTestConn{
3116 Reader: cr,
3117 Writer: sw,
3118 closeFunc: func() error {
3119 sw.Close()
3120 cw.Close()
3121 return nil
3122 },
3123 }
3124 go send100Response(cw, sr)
3125 return conn, nil
3126 },
3127 DisableKeepAlives: false,
3128 }
3129 defer tr.CloseIdleConnections()
3130 c := &Client{Transport: tr}
3131
3132 testResponse := func(req *Request, name string, wantCode int) {
3133 t.Helper()
3134 res, err := c.Do(req)
3135 if err != nil {
3136 t.Fatalf("%s: Do: %v", name, err)
3137 }
3138 if res.StatusCode != wantCode {
3139 t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
3140 }
3141 if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
3142 t.Errorf("%s: response id %q != request id %q", name, idBack, id)
3143 }
3144 _, err = io.ReadAll(res.Body)
3145 if err != nil {
3146 t.Fatalf("%s: Slurp error: %v", name, err)
3147 }
3148 }
3149
3150
3151 for i := 1; i <= numReqs; i++ {
3152 req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
3153 req.Header.Set("Request-Id", reqID(i))
3154 testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
3155 }
3156 }
3157
3158
3159
3160 func TestTransportIgnore1xxResponses(t *testing.T) {
3161 run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3162 }
3163 func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3164 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3165 conn, buf, _ := w.(Hijacker).Hijack()
3166 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3167 buf.Flush()
3168 conn.Close()
3169 }))
3170 cst.tr.DisableKeepAlives = true
3171
3172 var got strings.Builder
3173
3174 req, _ := NewRequest("GET", cst.ts.URL, nil)
3175 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3176 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3177 fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3178 return nil
3179 },
3180 }))
3181 res, err := cst.c.Do(req)
3182 if err != nil {
3183 t.Fatal(err)
3184 }
3185 defer res.Body.Close()
3186
3187 res.Write(&got)
3188 want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3189 if got.String() != want {
3190 t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3191 }
3192 }
3193
3194 func TestTransportLimits1xxResponses(t *testing.T) {
3195 run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
3196 }
3197 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3198 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3199 conn, buf, _ := w.(Hijacker).Hijack()
3200 for i := 0; i < 10; i++ {
3201 buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
3202 }
3203 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3204 buf.Flush()
3205 conn.Close()
3206 }))
3207 cst.tr.DisableKeepAlives = true
3208
3209 res, err := cst.c.Get(cst.ts.URL)
3210 if res != nil {
3211 defer res.Body.Close()
3212 }
3213 got := fmt.Sprint(err)
3214 wantSub := "too many 1xx informational responses"
3215 if !strings.Contains(got, wantSub) {
3216 t.Errorf("Get error = %v; want substring %q", err, wantSub)
3217 }
3218 }
3219
3220
3221
3222 func TestTransportTreat101Terminal(t *testing.T) {
3223 run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3224 }
3225 func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3226 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3227 conn, buf, _ := w.(Hijacker).Hijack()
3228 buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3229 buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3230 buf.Flush()
3231 conn.Close()
3232 }))
3233 res, err := cst.c.Get(cst.ts.URL)
3234 if err != nil {
3235 t.Fatal(err)
3236 }
3237 defer res.Body.Close()
3238 if res.StatusCode != StatusSwitchingProtocols {
3239 t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3240 }
3241 }
3242
// proxyFromEnvTest describes one proxy-selection case: the proxy-related
// environment settings, the request URL to route, and the expected outcome.
type proxyFromEnvTest struct {
	req string // request URL; "http://example.com" when empty

	env      string // value installed as HTTP_PROXY / http_proxy
	httpsenv string // value installed as HTTPS_PROXY / https_proxy
	noenv    string // value installed as NO_PROXY / no_proxy
	reqmeth  string // value installed as REQUEST_METHOD

	want    string // expected proxy URL formatted with %s ("<nil>" for none)
	wanterr error  // expected error, compared by its %v text
}

// String renders the case as space-separated key=%q pairs. Empty environment
// fields are omitted, and the request URL always comes last.
func (t proxyFromEnvTest) String() string {
	var parts []string
	if t.env != "" {
		parts = append(parts, fmt.Sprintf("http_proxy=%q", t.env))
	}
	if t.httpsenv != "" {
		parts = append(parts, fmt.Sprintf("https_proxy=%q", t.httpsenv))
	}
	if t.noenv != "" {
		parts = append(parts, fmt.Sprintf("no_proxy=%q", t.noenv))
	}
	if t.reqmeth != "" {
		parts = append(parts, fmt.Sprintf("request_method=%q", t.reqmeth))
	}
	req := t.req
	if req == "" {
		req = "http://example.com"
	}
	parts = append(parts, fmt.Sprintf("req=%q", req))
	return strings.TrimSpace(strings.Join(parts, " "))
}
3285
// proxyFromEnvTests lists the proxy-resolution cases shared by the upper-case
// and lower-case environment-variable tests. want is the expected proxy URL
// formatted with %s, or "<nil>" when no proxy should be used.
var proxyFromEnvTests = []proxyFromEnvTest{
	// A proxy value without a scheme is treated as http://.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	// Explicit schemes are kept as given.
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// An http:// request uses the http proxy even when an https proxy is set.
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// An https:// request uses the https proxy setting. Without a scheme,
	// that value also becomes an http:// proxy URL.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// When REQUEST_METHOD is set, the HTTP proxy value is refused with an
	// error (the message refers to a CGI environment).
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy environment at all: no proxy.
	{want: "<nil>"},

	// No-proxy matching:
	//   - "example.com" excludes the host itself and its subdomains.
	//   - ".example.com" does not exclude the bare example.com.
	//   - "ample.com" is not a domain-boundary match for example.com.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3316
3317 func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3318 t.Helper()
3319 reqURL := tt.req
3320 if reqURL == "" {
3321 reqURL = "http://example.com"
3322 }
3323 req, _ := NewRequest("GET", reqURL, nil)
3324 url, err := proxyForRequest(req)
3325 if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3326 t.Errorf("%v: got error = %q, want %q", tt, g, e)
3327 return
3328 }
3329 if got := fmt.Sprintf("%s", url); got != tt.want {
3330 t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3331 }
3332 }
3333
3334 func TestProxyFromEnvironment(t *testing.T) {
3335 ResetProxyEnv()
3336 defer ResetProxyEnv()
3337 for _, tt := range proxyFromEnvTests {
3338 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3339 os.Setenv("HTTP_PROXY", tt.env)
3340 os.Setenv("HTTPS_PROXY", tt.httpsenv)
3341 os.Setenv("NO_PROXY", tt.noenv)
3342 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3343 ResetCachedEnvironment()
3344 return ProxyFromEnvironment(req)
3345 })
3346 }
3347 }
3348
3349 func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3350 ResetProxyEnv()
3351 defer ResetProxyEnv()
3352 for _, tt := range proxyFromEnvTests {
3353 testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3354 os.Setenv("http_proxy", tt.env)
3355 os.Setenv("https_proxy", tt.httpsenv)
3356 os.Setenv("no_proxy", tt.noenv)
3357 os.Setenv("REQUEST_METHOD", tt.reqmeth)
3358 ResetCachedEnvironment()
3359 return ProxyFromEnvironment(req)
3360 })
3361 }
3362 }
3363
// TestIdleConnChannelLeak checks that the transport's idle-connection wait map
// (IdleConnWaitMapSizeForTesting) is empty after nReqs requests to distinct
// host names. It runs once with keep-alives disabled and once with them
// enabled.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// mu guards n, the number of requests the handler has seen.
	// NOTE(review): n is never read in the visible code.
	var mu sync.Mutex
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// The installed hook sends on didRead each time it fires; it is removed
	// when the test returns.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Route every dial to the test server, so the fake foo-host-N.tld names
	// below all reach it.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// Run the same sequence with keep-alives disabled, then enabled.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// NOTE(review): the response body is deliberately not closed.
			// This presumably relies on the handler's empty response being
			// fully consumed by the transport itself — confirm.
		}

		// Wait for nReqs hook signals before checking the wait map, so the
		// check is not made while responses may still be in processing.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3419
3420
3421
3422
3423 func TestTransportClosesRequestBody(t *testing.T) {
3424 run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3425 }
3426 func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3427 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3428 io.Copy(io.Discard, r.Body)
3429 })).ts
3430
3431 c := ts.Client()
3432
3433 closes := 0
3434
3435 res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3436 if err != nil {
3437 t.Fatal(err)
3438 }
3439 res.Body.Close()
3440 if closes != 1 {
3441 t.Errorf("closes = %d; want 1", closes)
3442 }
3443 }
3444
// TestTransportTLSHandshakeTimeout checks that Transport.TLSHandshakeTimeout
// limits the TLS handshake. The listener accepts a TCP connection but never
// responds on it, so an https Get must fail with a *url.Error wrapping a
// timeout net.Error whose text mentions "handshake timeout".
func TestTransportTLSHandshakeTimeout(t *testing.T) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	ln := newLocalListener(t)
	defer ln.Close()
	testdonec := make(chan struct{})
	defer close(testdonec)

	// Accept one connection and hold it open, silently, until the test returns.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		<-testdonec
		c.Close()
	}()

	// Every dial goes to the silent listener, whatever host is requested.
	tr := &Transport{
		Dial: func(_, _ string) (net.Conn, error) {
			return net.Dial("tcp", ln.Addr().String())
		},
		TLSHandshakeTimeout: 250 * time.Millisecond,
	}
	cl := &Client{Transport: tr}
	_, err := cl.Get("https://dummy.tld/")
	if err == nil {
		t.Error("expected error")
		return
	}
	ue, ok := err.(*url.Error)
	if !ok {
		t.Errorf("expected url.Error; got %#v", err)
		return
	}
	ne, ok := ue.Err.(net.Error)
	if !ok {
		t.Errorf("expected net.Error; got %#v", err)
		return
	}
	if !ne.Timeout() {
		t.Errorf("expected timeout error; got %v", err)
	}
	if !strings.Contains(err.Error(), "handshake timeout") {
		t.Errorf("expected 'handshake timeout' in error; got %v", err)
	}
}
3494
3495
3496 func TestTLSServerClosesConnection(t *testing.T) {
3497 run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3498 }
3499 func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3500 closedc := make(chan bool, 1)
3501 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3502 if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3503 conn, _, _ := w.(Hijacker).Hijack()
3504 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3505 conn.Close()
3506 closedc <- true
3507 return
3508 }
3509 fmt.Fprintf(w, "hello")
3510 })).ts
3511
3512 c := ts.Client()
3513 tr := c.Transport.(*Transport)
3514
3515 var nSuccess = 0
3516 var errs []error
3517 const trials = 20
3518 for i := 0; i < trials; i++ {
3519 tr.CloseIdleConnections()
3520 res, err := c.Get(ts.URL + "/keep-alive-then-die")
3521 if err != nil {
3522 t.Fatal(err)
3523 }
3524 <-closedc
3525 slurp, err := io.ReadAll(res.Body)
3526 if err != nil {
3527 t.Fatal(err)
3528 }
3529 if string(slurp) != "foo" {
3530 t.Errorf("Got %q, want foo", slurp)
3531 }
3532
3533
3534
3535 res, err = c.Get(ts.URL + "/")
3536 if err != nil {
3537 errs = append(errs, err)
3538 continue
3539 }
3540 slurp, err = io.ReadAll(res.Body)
3541 if err != nil {
3542 errs = append(errs, err)
3543 continue
3544 }
3545 nSuccess++
3546 }
3547 if nSuccess > 0 {
3548 t.Logf("successes = %d of %d", nSuccess, trials)
3549 } else {
3550 t.Errorf("All runs failed:")
3551 }
3552 for _, err := range errs {
3553 t.Logf(" err: %v", err)
3554 }
3555 }
3556
3557
3558
3559
// byteFromChanReader is an io.Reader that returns at most one byte per Read,
// taken from the channel. A Read blocks while the channel is empty and
// returns io.EOF once the channel is closed. A zero-length Read returns
// (0, nil) without touching the channel.
type byteFromChanReader chan byte

func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		b, ok := <-c
		if !ok {
			return 0, io.EOF
		}
		p[0] = b
		return 1, nil
	}
	return 0, nil
}
3573
3574
3575
3576
3577
3578
3579
// TestTransportNoReuseAfterEarlyResponse checks that a connection whose request
// body is still being written is not reused for a new request. The server
// answers the POST before reading its body and then only discards input. The
// last body byte is withheld (finalBit) until after the following GET, so the
// GET can only succeed if it is not sent on that hijacked connection.
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Shorten MaxWriteWaitBeforeConnReuse for this test and restore it on
	// return. NOTE(review): presumably this bounds how long the transport
	// waits for a pending body write before declining to reuse the
	// connection — confirm against transport.go.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server end of the POST's connection.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set once the GET succeeded; it only silences the log below.
	var getOkay bool
	// copying tracks the goroutine draining the hijacked connection.
	var copying sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: hijack, reply at once without reading the request body,
		// then keep discarding whatever the client sends.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))

		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The POST body is bodySize bytes; the last one comes from finalBit,
	// so the body write cannot finish until the test sends it.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	// The POST body is still incomplete here; the GET must get its own
	// connection to be answered.
	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true
	// Release the final byte and end the POST body.
	finalBit <- 'x'
	close(finalBit)
}
3647
3648
3649
3650 func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3651 func testTransportIssue10457(t *testing.T, mode testMode) {
3652 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3653
3654
3655
3656
3657
3658 conn, _, _ := w.(Hijacker).Hijack()
3659 conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n"))
3660 conn.Close()
3661 })).ts
3662 c := ts.Client()
3663
3664 res, err := c.Get(ts.URL)
3665 if err != nil {
3666 t.Fatalf("Get: %v", err)
3667 }
3668 defer res.Body.Close()
3669
3670
3671
3672
3673 if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3674 t.Errorf("Foo header = %q; want %q", got, want)
3675 }
3676 }
3677
// closerFunc lets a plain function serve as an io.Closer.
type closerFunc func() error

// Close calls the underlying function and returns its result.
func (fn closerFunc) Close() error {
	return fn()
}
3681
// writerFuncConn is a net.Conn whose Write calls the write hook instead of
// the embedded Conn. Every other method comes from the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write passes p to the write hook and returns whatever the hook returns.
func (wc writerFuncConn) Write(p []byte) (n int, err error) {
	hook := wc.write
	n, err = hook(p)
	return n, err
}
3688
3689
3690
3691
3692
3693
3694
3695
3696
3697
3698
3699
3700
// TestRetryRequestsOnError checks when the Transport retries a request whose
// write fails on a reused connection. In every case, the second Write made on
// any of the subtest's connections fails with that case's (failureN,
// failureErr). The subtest then compares a log of dials, writes, handler
// calls and retries against the expected sequence.
func TestRetryRequestsOnError(t *testing.T) {
	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
}
func testRetryRequestsOnError(t *testing.T, mode testMode) {
	// newRequest wraps NewRequest and fails the test on error.
	newRequest := func(method, urlStr string, body io.Reader) *Request {
		req, err := NewRequest(method, urlStr, body)
		if err != nil {
			t.Fatal(err)
		}
		return req
	}

	testCases := []struct {
		name       string
		failureN   int   // n reported by the injected failing Write
		failureErr error // error reported by the injected failing Write
		// req builds a new Request for every c.Do call instead of reusing
		// one Request (presumably because sending consumes its Body).
		req       func() *Request
		reqString string // one written request, in %q form without the quotes
	}{
		{
			name: "IdempotentNoBodySomeWritten",
			// Claim that part of the request reached the conn, so a retry
			// cannot be justified by "nothing was written".
			failureN: 1,
			// NOTE(review): ExportErrServerClosedIdle is presumably the error
			// the Transport treats as retryable for idempotent requests.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", nil)
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "IdempotentGetBodySomeWritten",
			// Same as the previous case, but the GET carries a body.
			failureN: 1,
			// Same error as the previous case.
			failureErr: ExportErrServerClosedIdle,
			req: func() *Request {
				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
		{
			name: "NothingWrittenNoBody",
			// Report that zero bytes were written. The error is generic, so
			// the retry presumably depends on failureN == 0.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			req: func() *Request {
				return newRequest("DELETE", "http://fake.golang", nil)
			},
			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
		},
		{
			name: "NothingWrittenGetBody",
			// Zero bytes written, as in the previous case.
			failureN:   0,
			failureErr: errors.New("second write fails"),
			// The POST body is a strings.Reader. NewRequest presumably sets
			// GetBody for it, so the body can be resent on retry.
			req: func() *Request {
				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
			},
			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// logf appends one line to logbuf. The handler, the dialer and
			// the conn all call it, so logbuf is guarded by mu.
			var (
				mu     sync.Mutex
				logbuf strings.Builder
			)
			logf := func(format string, args ...any) {
				mu.Lock()
				defer mu.Unlock()
				fmt.Fprintf(&logbuf, format, args...)
				logbuf.WriteByte('\n')
			}

			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
				logf("Handler")
				w.Header().Set("X-Status", "ok")
			})).ts

			// writeNumAtomic counts Writes across every connection this
			// subtest dials; only the second one is made to fail.
			var writeNumAtomic int32
			c := ts.Client()
			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
				logf("Dial")
				// Always connect to the test server, whatever host the
				// request names (the URLs use fake.golang). This c shadows
				// the client.
				c, err := net.Dial(network, ts.Listener.Addr().String())
				if err != nil {
					logf("Dial error: %v", err)
					return nil, err
				}
				return &writerFuncConn{
					Conn: c,
					write: func(p []byte) (n int, err error) {
						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
							logf("intentional write failure")
							return tc.failureN, tc.failureErr
						}
						logf("Write(%q)", p)
						return c.Write(p)
					},
				}, nil
			}

			// Record every Transport retry in the log.
			SetRoundTripRetried(func() {
				logf("Retried.")
			})
			defer SetRoundTripRetried(nil)

			for i := 0; i < 3; i++ {
				t0 := time.Now()
				req := tc.req()
				res, err := c.Do(req)
				if err != nil {
					// A fast failure is a real bug. A slow one suggests the
					// conn was not reused within MaxWriteWaitBeforeConnReuse,
					// which breaks the scenario, so skip instead of failing.
					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
						mu.Lock()
						got := logbuf.String()
						mu.Unlock()
						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
					}
					t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
				}
				res.Body.Close()
				// Even a retried request must report the caller's Request.
				if res.Request != req {
					t.Errorf("Response.Request != original request; want identical Request")
				}
			}

			// Expected sequence:
			//   - request 1 dials and is written;
			//   - request 2 reuses that conn, hits the injected failure, is
			//     retried and redials;
			//   - request 3 reuses the new conn.
			mu.Lock()
			got := logbuf.String()
			mu.Unlock()
			want := fmt.Sprintf(`Dial
Write("%s")
Handler
intentional write failure
Retried.
Dial
Write("%s")
Handler
Write("%s")
Handler
`, tc.reqString, tc.reqString, tc.reqString)
			if got != want {
				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
			}
		})
	}
}
3857
3858
// TestTransportClosesBodyOnError checks the case where reading the request
// body fails partway through. Client.Do must return an error mentioning the
// read error, the server must see a truncated body, and the Transport must
// close the request body.
func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
	// readBody carries the handler's result of reading the request body.
	readBody := make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := io.ReadAll(r.Body)
		readBody <- err
	})).ts
	c := ts.Client()
	fakeErr := errors.New("fake error")
	didClose := make(chan bool, 1)
	// The body yields 1 MiB of 'x' and then fails with fakeErr. Its Close
	// records the call on didClose; the send is non-blocking, so repeated
	// closes are harmless.
	req, _ := NewRequest("POST", ts.URL, struct {
		io.Reader
		io.Closer
	}{
		io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
		closerFunc(func() error {
			select {
			case didClose <- true:
			default:
			}
			return nil
		}),
	})
	res, err := c.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
	}
	// The server must not see a cleanly terminated body.
	if err := <-readBody; err == nil {
		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
	}
	select {
	case <-didClose:
	default:
		t.Errorf("didn't see Body.Close")
	}
}
3898
3899 func TestTransportDialTLS(t *testing.T) {
3900 run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
3901 }
3902 func testTransportDialTLS(t *testing.T, mode testMode) {
3903 var mu sync.Mutex
3904 var gotReq, didDial bool
3905
3906 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3907 mu.Lock()
3908 gotReq = true
3909 mu.Unlock()
3910 })).ts
3911 c := ts.Client()
3912 c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
3913 mu.Lock()
3914 didDial = true
3915 mu.Unlock()
3916 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3917 if err != nil {
3918 return nil, err
3919 }
3920 return c, c.Handshake()
3921 }
3922
3923 res, err := c.Get(ts.URL)
3924 if err != nil {
3925 t.Fatal(err)
3926 }
3927 res.Body.Close()
3928 mu.Lock()
3929 if !gotReq {
3930 t.Error("didn't get request")
3931 }
3932 if !didDial {
3933 t.Error("didn't use dial hook")
3934 }
3935 }
3936
3937 func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
3938 func testTransportDialContext(t *testing.T, mode testMode) {
3939 ctxKey := "some-key"
3940 ctxValue := "some-value"
3941 var (
3942 mu sync.Mutex
3943 gotReq bool
3944 gotCtxValue any
3945 )
3946
3947 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3948 mu.Lock()
3949 gotReq = true
3950 mu.Unlock()
3951 })).ts
3952 c := ts.Client()
3953 c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3954 mu.Lock()
3955 gotCtxValue = ctx.Value(ctxKey)
3956 mu.Unlock()
3957 return net.Dial(netw, addr)
3958 }
3959
3960 req, err := NewRequest("GET", ts.URL, nil)
3961 if err != nil {
3962 t.Fatal(err)
3963 }
3964 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
3965 res, err := c.Do(req.WithContext(ctx))
3966 if err != nil {
3967 t.Fatal(err)
3968 }
3969 res.Body.Close()
3970 mu.Lock()
3971 if !gotReq {
3972 t.Error("didn't get request")
3973 }
3974 if got, want := gotCtxValue, ctxValue; got != want {
3975 t.Errorf("got context with value %v, want %v", got, want)
3976 }
3977 }
3978
3979 func TestTransportDialTLSContext(t *testing.T) {
3980 run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
3981 }
3982 func testTransportDialTLSContext(t *testing.T, mode testMode) {
3983 ctxKey := "some-key"
3984 ctxValue := "some-value"
3985 var (
3986 mu sync.Mutex
3987 gotReq bool
3988 gotCtxValue any
3989 )
3990
3991 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3992 mu.Lock()
3993 gotReq = true
3994 mu.Unlock()
3995 })).ts
3996 c := ts.Client()
3997 c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
3998 mu.Lock()
3999 gotCtxValue = ctx.Value(ctxKey)
4000 mu.Unlock()
4001 c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4002 if err != nil {
4003 return nil, err
4004 }
4005 return c, c.HandshakeContext(ctx)
4006 }
4007
4008 req, err := NewRequest("GET", ts.URL, nil)
4009 if err != nil {
4010 t.Fatal(err)
4011 }
4012 ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4013 res, err := c.Do(req.WithContext(ctx))
4014 if err != nil {
4015 t.Fatal(err)
4016 }
4017 res.Body.Close()
4018 mu.Lock()
4019 if !gotReq {
4020 t.Error("didn't get request")
4021 }
4022 if got, want := gotCtxValue, ctxValue; got != want {
4023 t.Errorf("got context with value %v, want %v", got, want)
4024 }
4025 }
4026
4027
4028
// TestRoundTripReturnsProxyError checks that an error returned by the
// Transport's Proxy function is passed back to the RoundTrip caller.
func TestRoundTripReturnsProxyError(t *testing.T) {
	proxyErr := errors.New("errorMessage")
	badProxy := func(*Request) (*url.URL, error) {
		return nil, proxyErr
	}

	tr := &Transport{Proxy: badProxy}

	req, err := NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatal(err)
	}

	res, err := tr.RoundTrip(req)
	if res != nil {
		res.Body.Close()
	}
	if err == nil {
		t.Error("Expected proxy error to be returned by RoundTrip")
	} else if !errors.Is(err, proxyErr) {
		t.Errorf("RoundTrip error = %v; want %v", err, proxyErr)
	}
}
4044
4045
4046 func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4047 tr := &Transport{}
4048 wantIdle := func(when string, n int) bool {
4049 got := tr.IdleConnCountForTesting("http", "example.com")
4050 if got == n {
4051 return true
4052 }
4053 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4054 return false
4055 }
4056 wantIdle("start", 0)
4057 if !tr.PutIdleTestConn("http", "example.com") {
4058 t.Fatal("put failed")
4059 }
4060 if !tr.PutIdleTestConn("http", "example.com") {
4061 t.Fatal("second put failed")
4062 }
4063 wantIdle("after put", 2)
4064 tr.CloseIdleConnections()
4065 if !tr.IsIdleForTesting() {
4066 t.Error("should be idle after CloseIdleConnections")
4067 }
4068 wantIdle("after close idle", 0)
4069 if tr.PutIdleTestConn("http", "example.com") {
4070 t.Fatal("put didn't fail")
4071 }
4072 wantIdle("after second put", 0)
4073
4074 tr.QueueForIdleConnForTesting()
4075 if tr.IsIdleForTesting() {
4076 t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4077 }
4078 if !tr.PutIdleTestConn("http", "example.com") {
4079 t.Fatal("after re-activation")
4080 }
4081 wantIdle("after final put", 1)
4082 }
4083
4084
4085
4086 func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4087 tr := &Transport{}
4088 wantIdle := func(when string, n int) bool {
4089 got := tr.IdleConnCountForTesting("https", "example.com:443")
4090 if got == n {
4091 return true
4092 }
4093 t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4094 return false
4095 }
4096 wantIdle("start", 0)
4097 alt := funcRoundTripper(func() {})
4098 if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4099 t.Fatal("put failed")
4100 }
4101 wantIdle("after put", 1)
4102 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4103 GotConn: func(httptrace.GotConnInfo) {
4104
4105 t.Error("GotConn called")
4106 },
4107 })
4108 req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4109 _, err := tr.RoundTrip(req)
4110 if err != errFakeRoundTrip {
4111 t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4112 }
4113 wantIdle("after round trip", 1)
4114 }
4115
// TestTransportRemovesH2ConnsAfterIdle makes an HTTP/2 request, waits well
// past IdleConnTimeout, and checks that a second request still succeeds.
// Presumably this works because the Transport removed the expired idle conn
// instead of reusing it.
func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
	run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
}
func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Start with a tiny idle timeout. If a request fails with "use of closed
	// network connection", the timeout is treated as too short: it is
	// doubled, and the whole scenario is rerun on a fresh server.
	timeout := 1 * time.Millisecond
	retry := true
	for retry {
		trFunc := func(tr *Transport) {
			tr.MaxConnsPerHost = 1
			tr.MaxIdleConnsPerHost = 1
			tr.IdleConnTimeout = timeout
		}
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)

		retry = false
		// tooShort reports whether err is the "closed network connection"
		// failure. On the first such failure in this iteration, it also
		// schedules a retry with a doubled timeout and closes cst.
		tooShort := func(err error) bool {
			if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
				return false
			}
			if !retry {
				t.Helper()
				t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
				timeout *= 2
				retry = true
				cst.close()
			}
			return true
		}

		// NOTE(review): the responses below are discarded without closing
		// their bodies.
		if _, err := cst.c.Get(cst.ts.URL); err != nil {
			if tooShort(err) {
				continue
			}
			t.Fatalf("got error: %s", err)
		}

		// Let the idle conn expire before the second request.
		time.Sleep(10 * timeout)
		if _, err := cst.c.Get(cst.ts.URL); err != nil {
			if tooShort(err) {
				continue
			}
			t.Fatalf("got error: %s", err)
		}
	}
}
4165
4166
4167
4168
4169
4170 func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4171 func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4172 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4173 if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4174 t.Error("Transport advertised gzip support in the Accept header")
4175 }
4176 if r.Header.Get("Range") == "" {
4177 t.Error("no Range in request")
4178 }
4179 })).ts
4180 c := ts.Client()
4181
4182 req, _ := NewRequest("GET", ts.URL, nil)
4183 req.Header.Set("Range", "bytes=7-11")
4184 res, err := c.Do(req)
4185 if err != nil {
4186 t.Fatal(err)
4187 }
4188 res.Body.Close()
4189 }
4190
4191
// TestTransportResponseCancelRace checks that calling CancelRequest for a
// request that has already completed does not break a later request sent
// with the same Transport.
func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
func testTransportResponseCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Send a non-empty (1 KiB) response body.
		var b [1024]byte
		w.Write(b[:])
	})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}

	// Drain the body to EOF rather than closing it early. Presumably this
	// lets the connection go back into the pool to be reused by req2,
	// which is what exposes the race.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		t.Fatal(err)
	}

	req2, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Cancel the first, already-finished request; req2 must be unaffected.
	tr.CancelRequest(req)
	res, err = tr.RoundTrip(req2)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
4227
4228
4229 func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4230 run(t, testTransportContentEncodingCaseInsensitive)
4231 }
4232 func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4233 for _, ce := range []string{"gzip", "GZIP"} {
4234 ce := ce
4235 t.Run(ce, func(t *testing.T) {
4236 const encodedString = "Hello Gopher"
4237 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4238 w.Header().Set("Content-Encoding", ce)
4239 gz := gzip.NewWriter(w)
4240 gz.Write([]byte(encodedString))
4241 gz.Close()
4242 })).ts
4243
4244 res, err := ts.Client().Get(ts.URL)
4245 if err != nil {
4246 t.Fatal(err)
4247 }
4248
4249 body, err := io.ReadAll(res.Body)
4250 res.Body.Close()
4251 if err != nil {
4252 t.Fatal(err)
4253 }
4254
4255 if string(body) != encodedString {
4256 t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4257 }
4258 })
4259 }
4260 }
4261
// TestTransportDialCancelRace cancels a request from the enter-RoundTrip
// hook, i.e. as soon as RoundTrip starts and presumably before a connection
// is dialed. RoundTrip must then report ExportErrRequestCanceled. The test is
// not run in parallel because SetEnterRoundTripHook presumably installs a
// package-level hook.
func TestTransportDialCancelRace(t *testing.T) {
	run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode})
}
func testTransportDialCancelRace(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
	tr := ts.Client().Transport.(*Transport)

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	SetEnterRoundTripHook(func() {
		tr.CancelRequest(req)
	})
	defer SetEnterRoundTripHook(nil)
	res, err := tr.RoundTrip(req)
	if err != ExportErrRequestCanceled {
		t.Errorf("expected canceled request error; got %v", err)
		// Don't leak the body if the request unexpectedly succeeded.
		if err == nil {
			res.Body.Close()
		}
	}
}
4285
4286
// TestConnClosedBeforeRequestIsWritten sends a POST over a connection whose
// reads and writes always fail. It checks that the request fails and that
// the request body is closed exactly once.
func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
}
func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
		func(tr *Transport) {
			tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
				// Every dial returns a conn whose read and write hooks fail
				// immediately, so the connection is broken from the start.
				return &funcConn{
					read: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
					write: func([]byte) (int, error) {
						return 0, errors.New("error")
					},
				}, nil
			}
		},
	).ts

	// Delay briefly on entry to RoundTrip. NOTE(review): presumably this
	// gives the persistent conn time to notice the broken connection
	// before the request is handed to it for writing.
	SetEnterRoundTripHook(func() {
		time.Sleep(1 * time.Millisecond)
	})
	defer SetEnterRoundTripHook(nil)
	// closes counts Close calls on the request body (via countCloseReader).
	var closes int
	_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
	if err == nil {
		t.Fatalf("expected request to fail, but it did not")
	}
	if closes != 1 {
		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
	}
}
4324
4325
4326
4327
// logWritesConn is a fake net.Conn used by tests.
//   - Write records a copy of each payload in writes (under mu) and then
//     forwards it to w.
//   - Read serves data from a reader delivered later over rch.
//
// The embedded net.Conn is nil, so any method not overridden here panics if
// called.
type logWritesConn struct {
	net.Conn

	w io.Writer // receives the bytes passed to Write

	rch <-chan io.Reader // delivers the reader used for all Reads
	r   io.Reader        // nil until the first Read receives it from rch

	mu     sync.Mutex
	writes []string // one entry per Write call; guarded by mu
}

// Write logs p and then forwards it to c.w, holding c.mu for the whole call.
func (c *logWritesConn) Write(p []byte) (int, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	entry := string(p)
	c.writes = append(c.writes, entry)
	return c.w.Write(p)
}

// Read waits on rch during the first call, then reads from that reader.
func (c *logWritesConn) Read(p []byte) (int, error) {
	r := c.r
	if r == nil {
		r = <-c.rch
		c.r = r
	}
	return r.Read(p)
}

// Close does nothing and always succeeds.
func (c *logWritesConn) Close() error {
	return nil
}
4355
4356
4357 func TestTransportFlushesBodyChunks(t *testing.T) {
4358 defer afterTest(t)
4359 resBody := make(chan io.Reader, 1)
4360 connr, connw := io.Pipe()
4361 lw := &logWritesConn{
4362 rch: resBody,
4363 w: connw,
4364 }
4365 tr := &Transport{
4366 Dial: func(network, addr string) (net.Conn, error) {
4367 return lw, nil
4368 },
4369 }
4370 bodyr, bodyw := io.Pipe()
4371 go func() {
4372 defer bodyw.Close()
4373 for i := 0; i < 3; i++ {
4374 fmt.Fprintf(bodyw, "num%d\n", i)
4375 }
4376 }()
4377 resc := make(chan *Response)
4378 go func() {
4379 req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4380 req.Header.Set("User-Agent", "x")
4381 res, err := tr.RoundTrip(req)
4382 if err != nil {
4383 t.Errorf("RoundTrip: %v", err)
4384 close(resc)
4385 return
4386 }
4387 resc <- res
4388
4389 }()
4390
4391 req, err := ReadRequest(bufio.NewReader(connr))
4392 if err != nil {
4393 t.Fatal(err)
4394 }
4395 io.Copy(io.Discard, req.Body)
4396
4397
4398 resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4399 res, ok := <-resc
4400 if !ok {
4401 return
4402 }
4403 defer res.Body.Close()
4404
4405 want := []string{
4406 "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4407 "5\r\nnum0\n\r\n",
4408 "5\r\nnum1\n\r\n",
4409 "5\r\nnum2\n\r\n",
4410 "0\r\n\r\n",
4411 }
4412 if !reflect.DeepEqual(lw.writes, want) {
4413 t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4414 }
4415 }
4416
4417
4418 func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4419 func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4420 gotReq := make(chan struct{})
4421 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4422 close(gotReq)
4423 }))
4424
4425 pr, pw := io.Pipe()
4426 req, err := NewRequest("POST", cst.ts.URL, pr)
4427 if err != nil {
4428 t.Fatal(err)
4429 }
4430 gotRes := make(chan struct{})
4431 go func() {
4432 defer close(gotRes)
4433 res, err := cst.tr.RoundTrip(req)
4434 if err != nil {
4435 t.Error(err)
4436 return
4437 }
4438 res.Body.Close()
4439 }()
4440
4441 <-gotReq
4442 pw.Close()
4443 <-gotRes
4444 }
4445
// wgReadCloser wraps a Reader with a Close that calls wg.Done exactly once.
// Any later Close returns net.ErrClosed and leaves wg untouched.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

// Close marks the reader closed and signals wg, on the first call only.
func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4460
4461
4462 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4463
4464 run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4465 }
4466 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4467 if testing.Short() {
4468 t.Skip("skipping in short mode")
4469 }
4470
4471 runTimeSensitiveTest(t, []time.Duration{
4472 1 * time.Millisecond,
4473 5 * time.Millisecond,
4474 10 * time.Millisecond,
4475 50 * time.Millisecond,
4476 100 * time.Millisecond,
4477 500 * time.Millisecond,
4478 time.Second,
4479 5 * time.Second,
4480 }, func(t *testing.T, timeout time.Duration) error {
4481 SetRSTAvoidanceDelay(t, timeout)
4482 t.Logf("set RST avoidance delay to %v", timeout)
4483
4484 const contentLengthLimit = 1024 * 1024
4485 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4486 if r.ContentLength >= contentLengthLimit {
4487 w.WriteHeader(StatusBadRequest)
4488 r.Body.Close()
4489 return
4490 }
4491 w.WriteHeader(StatusOK)
4492 }))
4493
4494
4495 defer cst.close()
4496 ts := cst.ts
4497 c := ts.Client()
4498
4499 count := 100
4500
4501 bigBody := strings.Repeat("a", contentLengthLimit*2)
4502 var wg sync.WaitGroup
4503 defer wg.Wait()
4504 getBody := func() (io.ReadCloser, error) {
4505 wg.Add(1)
4506 body := &wgReadCloser{
4507 Reader: strings.NewReader(bigBody),
4508 wg: &wg,
4509 }
4510 return body, nil
4511 }
4512
4513 for i := 0; i < count; i++ {
4514 reqBody, _ := getBody()
4515 req, err := NewRequest("PUT", ts.URL, reqBody)
4516 if err != nil {
4517 reqBody.Close()
4518 t.Fatal(err)
4519 }
4520 req.ContentLength = int64(len(bigBody))
4521 req.GetBody = getBody
4522
4523 resp, err := c.Do(req)
4524 if err != nil {
4525 return fmt.Errorf("Do %d: %v", i, err)
4526 } else {
4527 resp.Body.Close()
4528 if resp.StatusCode != 400 {
4529 t.Errorf("Expected status code 400, got %v", resp.Status)
4530 }
4531 }
4532 }
4533 return nil
4534 })
4535 }
4536
4537 func TestTransportAutomaticHTTP2(t *testing.T) {
4538 testTransportAutoHTTP(t, &Transport{}, true)
4539 }
4540
4541 func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4542 testTransportAutoHTTP(t, &Transport{
4543 ForceAttemptHTTP2: true,
4544 TLSClientConfig: new(tls.Config),
4545 }, true)
4546 }
4547
4548
4549 func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4550 testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4551 }
4552
4553 func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4554 testTransportAutoHTTP(t, &Transport{
4555 TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4556 }, false)
4557 }
4558
4559 func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4560 testTransportAutoHTTP(t, &Transport{
4561 TLSClientConfig: new(tls.Config),
4562 }, false)
4563 }
4564
4565 func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4566 testTransportAutoHTTP(t, &Transport{
4567 ExpectContinueTimeout: 1 * time.Second,
4568 }, true)
4569 }
4570
4571 func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4572 var d net.Dialer
4573 testTransportAutoHTTP(t, &Transport{
4574 Dial: d.Dial,
4575 }, false)
4576 }
4577
4578 func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4579 var d net.Dialer
4580 testTransportAutoHTTP(t, &Transport{
4581 DialContext: d.DialContext,
4582 }, false)
4583 }
4584
4585 func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4586 testTransportAutoHTTP(t, &Transport{
4587 DialTLS: func(network, addr string) (net.Conn, error) {
4588 panic("unused")
4589 },
4590 }, false)
4591 }
4592
4593 func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4594 CondSkipHTTP2(t)
4595 _, err := tr.RoundTrip(new(Request))
4596 if err == nil {
4597 t.Error("expected error from RoundTrip")
4598 }
4599 if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4600 t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4601 }
4602 }
4603
4604
4605
4606
4607
4608
4609
4610
4611 func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4612 run(t, testTransportReuseConnEmptyResponseBody)
4613 }
4614 func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4615 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4616 w.Header().Set("X-Addr", r.RemoteAddr)
4617
4618 }))
4619 n := 100
4620 if testing.Short() {
4621 n = 10
4622 }
4623 var firstAddr string
4624 for i := 0; i < n; i++ {
4625 res, err := cst.c.Get(cst.ts.URL)
4626 if err != nil {
4627 log.Fatal(err)
4628 }
4629 addr := res.Header.Get("X-Addr")
4630 if i == 0 {
4631 firstAddr = addr
4632 } else if addr != firstAddr {
4633 t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4634 }
4635 res.Body.Close()
4636 }
4637 }
4638
4639
// TestNoCrashReturningTransportAltConn cancels a request while a custom
// DialTLS hook is still returning a TLS conn that negotiated the custom "foo"
// protocol. The test checks the following:
//   - Do fails with ExportErrRequestCanceledConn.
//   - The late conn is still passed to the "foo" TLSNextProto constructor.
//   - The RoundTripper built from it is never called.
//   - The pending dial finishes.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// wg counts dials in progress (via SetPendingDialHooks), so the test
	// can wait at the end for the abandoned dial to finish.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// TLS server that advertises only "foo". It accepts one conn, completes
	// the handshake, and keeps the conn open until the test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		// Plain dialing must not be used; the request is https.
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		// Complete the handshake, cancel the request, and then hold the conn
		// until Do has returned. The conn is therefore delivered only after
		// the request was canceled.
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Let DialTLS return the late conn. Then wait for the "foo"
	// RoundTripper to be built, and for the pending dial to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4721
// TestTransportReuseConnection_Gzip_Chunked runs the gzip reuse test with a
// response flushed before its body is written (presumably forcing chunked
// encoding).
func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportReuseConnection_Gzip(t, mode, true)
	})
}

// TestTransportReuseConnection_Gzip_ContentLength runs the gzip reuse test
// without the early flush.
func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		testTransportReuseConnection_Gzip(t, mode, false)
	})
}

// testTransportReuseConnection_Gzip makes two requests whose responses are
// gzip-encoded (body rgz). It reads each decoded body without closing it and
// checks that both requests arrived on the same connection. If chunked is
// true, the handler flushes before writing the body.
func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
	// addr receives the client address the handler saw for each request.
	addr := make(chan string, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addr <- r.RemoteAddr
		w.Header().Set("Content-Encoding", "gzip")
		if chunked {
			w.(Flusher).Flush()
		}
		w.Write(rgz)
	})).ts
	c := ts.Client()

	// Log connection lifecycle events to help diagnose failures.
	trace := &httptrace.ClientTrace{
		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
	}
	ctx := httptrace.WithClientTrace(context.Background(), trace)

	for i := 0; i < 2; i++ {
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(ctx)
		res, err := c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		// NOTE(review): this reads len(rgz) decoded bytes; presumably rgz
		// decompresses to exactly that many bytes. Confirm against its
		// definition.
		buf := make([]byte, len(rgz))
		if n, err := io.ReadFull(res.Body, buf); err != nil {
			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
		}
		// There is deliberately no res.Body.Close here. The a1 == a2 check
		// below expects the connection to be reused anyway once the decoded
		// body has been read in full.
	}
	a1, a2 := <-addr, <-addr
	if a1 != a2 {
		t.Fatalf("didn't reuse connection")
	}
}
4776
4777 func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
4778 func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
4779 if mode == http2Mode {
4780 t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
4781 }
4782 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4783 if r.URL.Path == "/long" {
4784 w.Header().Set("Long", strings.Repeat("a", 1<<20))
4785 }
4786 })).ts
4787 c := ts.Client()
4788 c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4789
4790 if res, err := c.Get(ts.URL); err != nil {
4791 t.Fatal(err)
4792 } else {
4793 res.Body.Close()
4794 }
4795
4796 res, err := c.Get(ts.URL + "/long")
4797 if err == nil {
4798 defer res.Body.Close()
4799 var n int64
4800 for k, vv := range res.Header {
4801 for _, v := range vv {
4802 n += int64(len(k)) + int64(len(v))
4803 }
4804 }
4805 t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
4806 }
4807 if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
4808 t.Errorf("got error: %v; want %q", err, want)
4809 }
4810 }
4811
4812 func TestTransportEventTrace(t *testing.T) {
4813 run(t, func(t *testing.T, mode testMode) {
4814 testTransportEventTrace(t, mode, false)
4815 }, testNotParallel)
4816 }
4817
4818
4819 func TestTransportEventTrace_NoHooks(t *testing.T) {
4820 run(t, func(t *testing.T, mode testMode) {
4821 testTransportEventTrace(t, mode, true)
4822 }, testNotParallel)
4823 }
4824
// testTransportEventTrace sends a POST with "Expect: 100-continue" to the
// fake host "dns-is-faked.golang". A context-injected resolver maps that host
// to the test server. The test logs every httptrace.ClientTrace event into a
// buffer and checks which events appear and how often. A follow-up GET then
// checks that GetConn fires once per request.
//
// When noHooks is true, the trace is zeroed before use and the test only
// checks that the POST succeeds with the expected body.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Sent on by the WroteRequest hook. The POST handler waits on it (unless
	// noHooks), so the response is only written after the client reports the
	// request as written. The buffer lets the hook send without blocking when
	// nothing receives, e.g. for the follow-up GET.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// The follow-up GET only checks GetConn counts; send an empty reply.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	// Non-zero so the Expect: 100-continue path (Wait100Continue/Got100Continue) is exercised.
	cst.tr.ExpectContinueTimeout = 1 * time.Second

	// logf appends one line to buf. mu guards buf because hooks may be invoked
	// from other goroutines.
	var mu sync.Mutex
	var buf strings.Builder
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake resolver via nettrace.LookupIPAltResolverKey. It answers
	// only for "dns-is-faked.golang", returning the test server's IP. Any other
	// lookup fails the test.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			// Releases the POST handler (see gotWroteReqEvent above).
			gotWroteReqEvent <- struct{}{}
		},
	}
	// TLS handshake hooks are only installed, and later asserted, in HTTP/2 mode.
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// Clear every hook. The request must still work with an empty trace.
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// With no hooks there is no event log to inspect. A successful
		// round trip is all this mode checks.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce requires sub to occur exactly once in the log.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	// wantOnceOrMore requires sub to occur at least once in the log.
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	// The first request must get a freshly dialed connection.
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		// PutIdleConn and per-field WroteHeaderField events are only asserted
		// outside HTTP/2 mode.
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	// The fake resolver means no real DNS traffic should have happened.
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// A second request to the same host must invoke GetConn again,
	// so the log should now contain two GetConn lines in total.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
5006
5007 func TestTransportEventTraceTLSVerify(t *testing.T) {
5008 run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5009 }
5010 func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5011 var mu sync.Mutex
5012 var buf strings.Builder
5013 logf := func(format string, args ...any) {
5014 mu.Lock()
5015 defer mu.Unlock()
5016 fmt.Fprintf(&buf, format, args...)
5017 buf.WriteByte('\n')
5018 }
5019
5020 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5021 t.Error("Unexpected request")
5022 }), func(ts *httptest.Server) {
5023 ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5024 logf("%s", p)
5025 return len(p), nil
5026 }), "", 0)
5027 }).ts
5028
5029 certpool := x509.NewCertPool()
5030 certpool.AddCert(ts.Certificate())
5031
5032 c := &Client{Transport: &Transport{
5033 TLSClientConfig: &tls.Config{
5034 ServerName: "dns-is-faked.golang",
5035 RootCAs: certpool,
5036 },
5037 }}
5038
5039 trace := &httptrace.ClientTrace{
5040 TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5041 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5042 logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5043 },
5044 }
5045
5046 req, _ := NewRequest("GET", ts.URL, nil)
5047 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5048 _, err := c.Do(req)
5049 if err == nil {
5050 t.Error("Expected request to fail TLS verification")
5051 }
5052
5053 mu.Lock()
5054 got := buf.String()
5055 mu.Unlock()
5056
5057 wantOnce := func(sub string) {
5058 if strings.Count(got, sub) != 1 {
5059 t.Errorf("expected substring %q exactly once in output.", sub)
5060 }
5061 }
5062
5063 wantOnce("TLSHandshakeStart")
5064 wantOnce("TLSHandshakeDone")
5065 wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5066
5067 if t.Failed() {
5068 t.Errorf("Output:\n%s", got)
5069 }
5070 }
5071
// isDNSHijacked caches whether the system resolver returned addresses for a
// name that should not exist. isDNSHijackedOnce ensures the probe runs at
// most once per test binary.
var (
	isDNSHijackedOnce sync.Once
	isDNSHijacked     bool
)

// skipIfDNSHijacked skips t when looking up "dns-should-not-resolve.golang"
// yields any addresses. Such a resolver answers for nonexistent names, so
// tests that depend on a real DNS failure cannot run.
func skipIfDNSHijacked(t *testing.T) {
	isDNSHijackedOnce.Do(func() {
		addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
		isDNSHijacked = len(addrs) > 0
	})
	if !isDNSHijacked {
		return
	}
	t.Skip("skipping; test requires non-hijacking DNS server")
}
5089
5090 func TestTransportEventTraceRealDNS(t *testing.T) {
5091 skipIfDNSHijacked(t)
5092 defer afterTest(t)
5093 tr := &Transport{}
5094 defer tr.CloseIdleConnections()
5095 c := &Client{Transport: tr}
5096
5097 var mu sync.Mutex
5098 var buf strings.Builder
5099 logf := func(format string, args ...any) {
5100 mu.Lock()
5101 defer mu.Unlock()
5102 fmt.Fprintf(&buf, format, args...)
5103 buf.WriteByte('\n')
5104 }
5105
5106 req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5107 trace := &httptrace.ClientTrace{
5108 DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5109 DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5110 ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5111 ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5112 }
5113 req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5114
5115 resp, err := c.Do(req)
5116 if err == nil {
5117 resp.Body.Close()
5118 t.Fatal("expected error during DNS lookup")
5119 }
5120
5121 mu.Lock()
5122 got := buf.String()
5123 mu.Unlock()
5124
5125 wantSub := func(sub string) {
5126 if !strings.Contains(got, sub) {
5127 t.Errorf("expected substring %q in output.", sub)
5128 }
5129 }
5130 wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5131 wantSub("DNSDone: {Addrs:[] Err:")
5132 if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5133 t.Errorf("should not see Connect events")
5134 }
5135 if t.Failed() {
5136 t.Errorf("Output:\n%s", got)
5137 }
5138 }
5139
5140
// TestTransportRejectsAlphaPort checks that a URL whose port contains
// non-digit characters makes Get fail. The failure must be a *url.Error whose
// wrapped message names the bad port.
func TestTransportRejectsAlphaPort(t *testing.T) {
	res, err := Get("http://dummy.tld:123foo/bar")
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	urlErr, isURLErr := err.(*url.Error)
	if !isURLErr {
		t.Fatalf("got %#v; want *url.Error", err)
	}
	const want = `invalid port ":123foo" after host`
	if got := urlErr.Err.Error(); got != want {
		t.Errorf("got error %q; want %q", got, want)
	}
}
5157
5158
5159
5160 func TestTLSHandshakeTrace(t *testing.T) {
5161 run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5162 }
5163 func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5164 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5165
5166 var mu sync.Mutex
5167 var start, done bool
5168 trace := &httptrace.ClientTrace{
5169 TLSHandshakeStart: func() {
5170 mu.Lock()
5171 defer mu.Unlock()
5172 start = true
5173 },
5174 TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5175 mu.Lock()
5176 defer mu.Unlock()
5177 done = true
5178 if err != nil {
5179 t.Fatal("Expected error to be nil but was:", err)
5180 }
5181 },
5182 }
5183
5184 c := ts.Client()
5185 req, err := NewRequest("GET", ts.URL, nil)
5186 if err != nil {
5187 t.Fatal("Unable to construct test request:", err)
5188 }
5189 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5190
5191 r, err := c.Do(req)
5192 if err != nil {
5193 t.Fatal("Unexpected error making request:", err)
5194 }
5195 r.Body.Close()
5196 mu.Lock()
5197 defer mu.Unlock()
5198 if !start {
5199 t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5200 }
5201 if !done {
5202 t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5203 }
5204 }
5205
5206 func TestTransportMaxIdleConns(t *testing.T) {
5207 run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5208 }
5209 func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5210 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5211
5212 })).ts
5213 c := ts.Client()
5214 tr := c.Transport.(*Transport)
5215 tr.MaxIdleConns = 4
5216
5217 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5218 if err != nil {
5219 t.Fatal(err)
5220 }
5221 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5222 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5223 })
5224
5225 hitHost := func(n int) {
5226 req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5227 req = req.WithContext(ctx)
5228 res, err := c.Do(req)
5229 if err != nil {
5230 t.Fatal(err)
5231 }
5232 res.Body.Close()
5233 }
5234 for i := 0; i < 4; i++ {
5235 hitHost(i)
5236 }
5237 want := []string{
5238 "|http|host-0.dns-is-faked.golang:" + port,
5239 "|http|host-1.dns-is-faked.golang:" + port,
5240 "|http|host-2.dns-is-faked.golang:" + port,
5241 "|http|host-3.dns-is-faked.golang:" + port,
5242 }
5243 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5244 t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5245 }
5246
5247
5248 hitHost(4)
5249 want = []string{
5250 "|http|host-1.dns-is-faked.golang:" + port,
5251 "|http|host-2.dns-is-faked.golang:" + port,
5252 "|http|host-3.dns-is-faked.golang:" + port,
5253 "|http|host-4.dns-is-faked.golang:" + port,
5254 }
5255 if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5256 t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5257 }
5258 }
5259
5260 func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
5261 func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
5262 if testing.Short() {
5263 t.Skip("skipping in short mode")
5264 }
5265
5266 timeout := 1 * time.Millisecond
5267 timeoutLoop:
5268 for {
5269 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5270
5271 }))
5272 tr := cst.tr
5273 tr.IdleConnTimeout = timeout
5274 defer tr.CloseIdleConnections()
5275 c := &Client{Transport: tr}
5276
5277 idleConns := func() []string {
5278 if mode == http2Mode {
5279 return tr.IdleConnStrsForTesting_h2()
5280 } else {
5281 return tr.IdleConnStrsForTesting()
5282 }
5283 }
5284
5285 var conn string
5286 doReq := func(n int) (timeoutOk bool) {
5287 req, _ := NewRequest("GET", cst.ts.URL, nil)
5288 req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5289 PutIdleConn: func(err error) {
5290 if err != nil {
5291 t.Errorf("failed to keep idle conn: %v", err)
5292 }
5293 },
5294 }))
5295 res, err := c.Do(req)
5296 if err != nil {
5297 if strings.Contains(err.Error(), "use of closed network connection") {
5298 t.Logf("req %v: connection closed prematurely", n)
5299 return false
5300 }
5301 }
5302 res.Body.Close()
5303 conns := idleConns()
5304 if len(conns) != 1 {
5305 if len(conns) == 0 {
5306 t.Logf("req %v: no idle conns", n)
5307 return false
5308 }
5309 t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
5310 }
5311 if conn == "" {
5312 conn = conns[0]
5313 }
5314 if conn != conns[0] {
5315 t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
5316 return false
5317 }
5318 return true
5319 }
5320 for i := 0; i < 3; i++ {
5321 if !doReq(i) {
5322 t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
5323 timeout *= 2
5324 cst.close()
5325 continue timeoutLoop
5326 }
5327 time.Sleep(timeout / 2)
5328 }
5329
5330 waitCondition(t, timeout/2, func(d time.Duration) bool {
5331 if got := idleConns(); len(got) != 0 {
5332 if d >= timeout*3/2 {
5333 t.Logf("after %v, idle conns = %q", d, got)
5334 }
5335 return false
5336 }
5337 return true
5338 })
5339 break
5340 }
5341 }
5342
5343
5344
5345
5346
5347
5348
5349
5350
5351
5352
5353
5354 func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
// testIdleConnH2Crash uses a custom DialTLS that completes an h2 TLS handshake
// and then cancels the request's context. The dial only returns the connection
// after Do has already failed, so the transport receives a live connection for
// a request that no longer exists. IdleConnTimeout is 5ms, and the test sleeps
// ten times that at the end.
// NOTE(review): the name suggests a regression test for a crash in the HTTP/2
// idle-connection handling of such a late connection; the test only checks
// that nothing panics and that Do fails — confirm intent.
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler is never expected to do anything useful.
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is sent on once Do has returned its error. testDone is closed
	// when the test returns, so the dial goroutine can never block forever.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request after the connection is established...
		cancel()

		// ...and hand the connection back only after Do has observed the
		// cancellation (or the test has ended).
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Give the 5ms idle timer time to fire on the late connection.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5404
// funcConn is a net.Conn whose Read and Write delegate to the read and write
// callbacks and whose Close always succeeds without doing anything. Any other
// method goes to the embedded net.Conn, which may be nil (as in the visible
// usage), in which case calling it panics.
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

// Read forwards p to the read callback.
func (fc funcConn) Read(p []byte) (int, error) {
	return fc.read(p)
}

// Write forwards p to the write callback.
func (fc funcConn) Write(p []byte) (int, error) {
	return fc.write(p)
}

// Close is a no-op that always reports success.
func (fc funcConn) Close() error {
	return nil
}
5414
5415
5416
5417 func TestTransportReturnsPeekError(t *testing.T) {
5418 errValue := errors.New("specific error value")
5419
5420 wrote := make(chan struct{})
5421 var wroteOnce sync.Once
5422
5423 tr := &Transport{
5424 Dial: func(network, addr string) (net.Conn, error) {
5425 c := funcConn{
5426 read: func([]byte) (int, error) {
5427 <-wrote
5428 return 0, errValue
5429 },
5430 write: func(p []byte) (int, error) {
5431 wroteOnce.Do(func() { close(wrote) })
5432 return len(p), nil
5433 },
5434 }
5435 return c, nil
5436 },
5437 }
5438 _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
5439 if err != errValue {
5440 t.Errorf("error = %#v; want %v", err, errValue)
5441 }
5442 }
5443
5444
5445 func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
5446 func testTransportIDNA(t *testing.T, mode testMode) {
5447 const uniDomain = "гофер.го"
5448 const punyDomain = "xn--c1ae0ajs.xn--c1aw"
5449
5450 var port string
5451 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5452 want := punyDomain + ":" + port
5453 if r.Host != want {
5454 t.Errorf("Host header = %q; want %q", r.Host, want)
5455 }
5456 if mode == http2Mode {
5457 if r.TLS == nil {
5458 t.Errorf("r.TLS == nil")
5459 } else if r.TLS.ServerName != punyDomain {
5460 t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
5461 }
5462 }
5463 w.Header().Set("Hit-Handler", "1")
5464 }), func(tr *Transport) {
5465 if tr.TLSClientConfig != nil {
5466 tr.TLSClientConfig.InsecureSkipVerify = true
5467 }
5468 })
5469
5470 ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
5471 if err != nil {
5472 t.Fatal(err)
5473 }
5474
5475
5476 ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
5477 if host != punyDomain {
5478 t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
5479 return nil, nil
5480 }
5481 return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5482 })
5483
5484 req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
5485 trace := &httptrace.ClientTrace{
5486 GetConn: func(hostPort string) {
5487 want := net.JoinHostPort(punyDomain, port)
5488 if hostPort != want {
5489 t.Errorf("getting conn for %q; want %q", hostPort, want)
5490 }
5491 },
5492 DNSStart: func(e httptrace.DNSStartInfo) {
5493 if e.Host != punyDomain {
5494 t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
5495 }
5496 },
5497 }
5498 req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
5499
5500 res, err := cst.tr.RoundTrip(req)
5501 if err != nil {
5502 t.Fatal(err)
5503 }
5504 defer res.Body.Close()
5505 if res.Header.Get("Hit-Handler") != "1" {
5506 out, err := httputil.DumpResponse(res, true)
5507 if err != nil {
5508 t.Fatal(err)
5509 }
5510 t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
5511 }
5512 }
5513
5514
5515 func TestTransportProxyConnectHeader(t *testing.T) {
5516 run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5517 }
5518 func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5519 reqc := make(chan *Request, 1)
5520 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5521 if r.Method != "CONNECT" {
5522 t.Errorf("method = %q; want CONNECT", r.Method)
5523 }
5524 reqc <- r
5525 c, _, err := w.(Hijacker).Hijack()
5526 if err != nil {
5527 t.Errorf("Hijack: %v", err)
5528 return
5529 }
5530 c.Close()
5531 })).ts
5532
5533 c := ts.Client()
5534 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5535 return url.Parse(ts.URL)
5536 }
5537 c.Transport.(*Transport).ProxyConnectHeader = Header{
5538 "User-Agent": {"foo"},
5539 "Other": {"bar"},
5540 }
5541
5542 res, err := c.Get("https://dummy.tld/")
5543 if err == nil {
5544 res.Body.Close()
5545 t.Errorf("unexpected success")
5546 }
5547
5548 r := <-reqc
5549 if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5550 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5551 }
5552 if got, want := r.Header.Get("Other"), "bar"; got != want {
5553 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5554 }
5555 }
5556
5557 func TestTransportProxyGetConnectHeader(t *testing.T) {
5558 run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5559 }
5560 func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5561 reqc := make(chan *Request, 1)
5562 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5563 if r.Method != "CONNECT" {
5564 t.Errorf("method = %q; want CONNECT", r.Method)
5565 }
5566 reqc <- r
5567 c, _, err := w.(Hijacker).Hijack()
5568 if err != nil {
5569 t.Errorf("Hijack: %v", err)
5570 return
5571 }
5572 c.Close()
5573 })).ts
5574
5575 c := ts.Client()
5576 c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5577 return url.Parse(ts.URL)
5578 }
5579
5580 c.Transport.(*Transport).ProxyConnectHeader = Header{
5581 "User-Agent": {"foo"},
5582 "Other": {"bar"},
5583 }
5584 c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5585 return Header{
5586 "User-Agent": {"foo2"},
5587 "Other": {"bar2"},
5588 }, nil
5589 }
5590
5591 res, err := c.Get("https://dummy.tld/")
5592 if err == nil {
5593 res.Body.Close()
5594 t.Errorf("unexpected success")
5595 }
5596
5597 r := <-reqc
5598 if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5599 t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5600 }
5601 if got, want := r.Header.Get("Other"), "bar2"; got != want {
5602 t.Errorf("CONNECT request Other = %q; want %q", got, want)
5603 }
5604 }
5605
// errFakeRoundTrip is the error every funcRoundTripper returns.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the wrapped function for its
// side effects and then fails every request with errFakeRoundTrip.
type funcRoundTripper func()

// Compile-time check that funcRoundTripper satisfies RoundTripper.
var _ RoundTripper = funcRoundTripper(nil)

// RoundTrip ignores the request, calls f, and returns a nil Response with
// errFakeRoundTrip.
func (f funcRoundTripper) RoundTrip(*Request) (*Response, error) {
	f()
	return nil, errFakeRoundTrip
}
5614
// wantBody checks a (res, err) pair from a client call. It returns err
// unchanged if it is non-nil. Otherwise it returns a descriptive error if the
// body cannot be read, differs from want, or fails to close, and nil on
// success. The body is only closed once it has been read and matched want.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	body, readErr := io.ReadAll(res.Body)
	switch {
	case readErr != nil:
		return fmt.Errorf("error reading body: %v", readErr)
	case string(body) != want:
		return fmt.Errorf("body = %q; want %q", body, want)
	}
	if closeErr := res.Body.Close(); closeErr != nil {
		return fmt.Errorf("body Close = %v", closeErr)
	}
	return nil
}
5631
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 ([::1]). If neither
// works, it fails the test with the IPv6 error.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
5642
// countCloseReader wraps an io.Reader and increments *n on every Close call,
// so tests can see how many times a body was closed. Close never fails.
type countCloseReader struct {
	n *int
	io.Reader
}

// Close counts the call in *cr.n and returns nil.
func (cr countCloseReader) Close() error {
	*cr.n++
	return nil
}
5652
5653
// rgz is a gzip stream. The header bytes show the gzip magic (0x1f 0x8b), the
// deflate method (0x08), the FNAME flag (0x08), and the embedded file name
// "recursive" (0x72 ... 0x65, NUL-terminated).
// NOTE(review): presumably a gzip quine that decompresses to itself; a test
// above reads exactly len(rgz) bytes from a response body — confirm.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5688
5689
5690
5691 func TestMissingStatusNoPanic(t *testing.T) {
5692 t.Parallel()
5693
5694 const want = "unknown status code"
5695
5696 ln := newLocalListener(t)
5697 addr := ln.Addr().String()
5698 done := make(chan bool)
5699 fullAddrURL := fmt.Sprintf("http://%s", addr)
5700 raw := "HTTP/1.1 400\r\n" +
5701 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5702 "Content-Type: text/html; charset=utf-8\r\n" +
5703 "Content-Length: 10\r\n" +
5704 "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5705 "Vary: Accept-Encoding\r\n\r\n" +
5706 "Aloha Olaa"
5707
5708 go func() {
5709 defer close(done)
5710
5711 conn, _ := ln.Accept()
5712 if conn != nil {
5713 io.WriteString(conn, raw)
5714 io.ReadAll(conn)
5715 conn.Close()
5716 }
5717 }()
5718
5719 proxyURL, err := url.Parse(fullAddrURL)
5720 if err != nil {
5721 t.Fatalf("proxyURL: %v", err)
5722 }
5723
5724 tr := &Transport{Proxy: ProxyURL(proxyURL)}
5725
5726 req, _ := NewRequest("GET", "https://golang.org/", nil)
5727 res, err, panicked := doFetchCheckPanic(tr, req)
5728 if panicked {
5729 t.Error("panicked, expecting an error")
5730 }
5731 if res != nil && res.Body != nil {
5732 io.Copy(io.Discard, res.Body)
5733 res.Body.Close()
5734 }
5735
5736 if err == nil || !strings.Contains(err.Error(), want) {
5737 t.Errorf("got=%v want=%q", err, want)
5738 }
5739
5740 ln.Close()
5741 <-done
5742 }
5743
// doFetchCheckPanic calls tr.RoundTrip(req). If RoundTrip panics, the panic
// is recovered and reported as panicked == true, with res and err left nil.
// Otherwise it returns RoundTrip's results and panicked == false.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	defer func() {
		panicked = recover() != nil
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5753
5754
5755
5756 func TestNoBodyOnChunked304Response(t *testing.T) {
5757 run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
5758 }
5759 func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
5760 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5761 conn, buf, _ := w.(Hijacker).Hijack()
5762 buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
5763 buf.Flush()
5764 conn.Close()
5765 }))
5766
5767
5768
5769
5770
5771 cst.tr.DisableKeepAlives = true
5772
5773 res, err := cst.c.Get(cst.ts.URL)
5774 if err != nil {
5775 t.Fatal(err)
5776 }
5777
5778 if res.Body != NoBody {
5779 t.Errorf("Unexpected body on 304 response")
5780 }
5781 }
5782
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Compile-time check that funcWriter satisfies io.Writer.
var _ io.Writer = funcWriter(nil)

// Write passes p to the wrapped function and returns its results.
func (fn funcWriter) Write(p []byte) (int, error) {
	return fn(p)
}
5786
// doneContext is a context.Context that is already finished. Done returns a
// closed channel and Err returns the stored err. Deadline and Value come from
// the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a newly made, already-closed channel on every call.
func (doneContext) Done() <-chan struct{} {
	closed := make(chan struct{})
	close(closed)
	return closed
}

// Err returns the error the doneContext was constructed with.
func (dc doneContext) Err() error {
	return dc.err
}
5799
5800
5801 func TestTransportCheckContextDoneEarly(t *testing.T) {
5802 tr := &Transport{}
5803 req, _ := NewRequest("GET", "http://fake.example/", nil)
5804 wantErr := errors.New("some error")
5805 req = req.WithContext(doneContext{context.Background(), wantErr})
5806 _, err := tr.RoundTrip(req)
5807 if err != wantErr {
5808 t.Errorf("error = %v; want %v", err, wantErr)
5809 }
5810 }
5811
5812
5813
5814
5815
5816
5817 func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
5818 run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
5819 }
// testClientTimeoutKillsConn_BeforeHeaders checks that when Client.Timeout
// expires before response headers arrive, the Get fails and the connection is
// closed. The handler never writes a response. It waits for its request
// context to be done, hijacks the conn, and expects Read to return io.EOF.
// If the handler is not reached within 10x the timeout, the timeout is doubled
// and the test retries with a fresh server.
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	timeout := 1 * time.Millisecond
	for {
		inHandler := make(chan bool)
		// Closed to make a handler that has not yet signalled inHandler return
		// immediately (used when this attempt is abandoned or fails).
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			<-r.Context().Done()

			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// The client should have closed its side; a zero-byte read with
			// io.EOF confirms the connection was torn down.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// The handler was never reached. With a very short timeout the Get
			// can presumably fail before the request reaches the server, so
			// treat this as "timeout too short" and retry with a longer one.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			// Wait for the handler's EOF check to complete.
			<-handlerDone
		}
		break
	}
}
5875
5876
5877
5878
5879
5880
5881 func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
5882 run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
5883 }
// testClientTimeoutKillsConn_AfterHeaders checks that cancelling a request
// while its body is being read makes the body read fail and closes the
// connection. The handler sends headers promising a 100-byte body, writes only
// "foo", and then waits on a Read that must return an error (and no data)
// once the client tears the connection down.
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	inHandler := make(chan bool)
	// Closed so that a handler not yet past the inHandler handshake returns early.
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise 100 bytes of body and flush the headers so Do can return.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Only part of the promised body.
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// Any error is accepted here (the exact one may depend on timing); the
		// test only requires that Read returns, and flags it if it returns
		// data or a nil error.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// A long but non-zero Timeout, presumably so the client's timeout handling
	// is active without ever firing. The actual cancellation is driven through
	// req.Cancel below.
	cst.c.Timeout = 24 * time.Hour
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Cancel while the body is still incomplete; reading the rest must fail.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Wait for the handler to observe the closed connection.
	<-inHandler
	<-handlerDone
}
5946
5947 func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
5948 run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
5949 }
5950 func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
5951 done := make(chan struct{})
5952 defer close(done)
5953 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5954 conn, _, err := w.(Hijacker).Hijack()
5955 if err != nil {
5956 t.Error(err)
5957 return
5958 }
5959 defer conn.Close()
5960 io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
5961 bs := bufio.NewScanner(conn)
5962 bs.Scan()
5963 fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
5964 <-done
5965 }))
5966
5967 req, _ := NewRequest("GET", cst.ts.URL, nil)
5968 req.Header.Set("Upgrade", "foo")
5969 req.Header.Set("Connection", "upgrade")
5970 res, err := cst.c.Do(req)
5971 if err != nil {
5972 t.Fatal(err)
5973 }
5974 if res.StatusCode != 101 {
5975 t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
5976 }
5977 rwc, ok := res.Body.(io.ReadWriteCloser)
5978 if !ok {
5979 t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
5980 }
5981 defer rwc.Close()
5982 bs := bufio.NewScanner(rwc)
5983 if !bs.Scan() {
5984 t.Fatalf("expected readable input")
5985 }
5986 if got, want := bs.Text(), "Some buffered data"; got != want {
5987 t.Errorf("read %q; want %q", got, want)
5988 }
5989 io.WriteString(rwc, "echo\n")
5990 if !bs.Scan() {
5991 t.Fatalf("expected another line")
5992 }
5993 if got, want := bs.Text(), "ECHO"; got != want {
5994 t.Errorf("read %q; want %q", got, want)
5995 }
5996 }
5997
5998 func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
5999 func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6000 const target = "backend:443"
6001 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6002 if r.Method != "CONNECT" {
6003 t.Errorf("unexpected method %q", r.Method)
6004 w.WriteHeader(500)
6005 return
6006 }
6007 if r.RequestURI != target {
6008 t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6009 w.WriteHeader(500)
6010 return
6011 }
6012 nc, brw, err := w.(Hijacker).Hijack()
6013 if err != nil {
6014 t.Error(err)
6015 return
6016 }
6017 defer nc.Close()
6018 nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6019
6020 for {
6021 line, err := brw.ReadString('\n')
6022 if err != nil {
6023 if err != io.EOF {
6024 t.Error(err)
6025 }
6026 return
6027 }
6028 io.WriteString(brw, strings.ToUpper(line))
6029 brw.Flush()
6030 }
6031 }))
6032 pr, pw := io.Pipe()
6033 defer pw.Close()
6034 req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6035 if err != nil {
6036 t.Fatal(err)
6037 }
6038 req.URL.Opaque = target
6039 res, err := cst.c.Do(req)
6040 if err != nil {
6041 t.Fatal(err)
6042 }
6043 defer res.Body.Close()
6044 if res.StatusCode != 200 {
6045 t.Fatalf("status code = %d; want 200", res.StatusCode)
6046 }
6047 br := bufio.NewReader(res.Body)
6048 for _, str := range []string{"foo", "bar", "baz"} {
6049 fmt.Fprintf(pw, "%s\n", str)
6050 got, err := br.ReadString('\n')
6051 if err != nil {
6052 t.Fatal(err)
6053 }
6054 got = strings.TrimSpace(got)
6055 want := strings.ToUpper(str)
6056 if got != want {
6057 t.Fatalf("got %q; want %q", got, want)
6058 }
6059 }
6060 }
6061
6062 func TestTransportRequestReplayable(t *testing.T) {
6063 someBody := io.NopCloser(strings.NewReader(""))
6064 tests := []struct {
6065 name string
6066 req *Request
6067 want bool
6068 }{
6069 {
6070 name: "GET",
6071 req: &Request{Method: "GET"},
6072 want: true,
6073 },
6074 {
6075 name: "GET_http.NoBody",
6076 req: &Request{Method: "GET", Body: NoBody},
6077 want: true,
6078 },
6079 {
6080 name: "GET_body",
6081 req: &Request{Method: "GET", Body: someBody},
6082 want: false,
6083 },
6084 {
6085 name: "POST",
6086 req: &Request{Method: "POST"},
6087 want: false,
6088 },
6089 {
6090 name: "POST_idempotency-key",
6091 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6092 want: true,
6093 },
6094 {
6095 name: "POST_x-idempotency-key",
6096 req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6097 want: true,
6098 },
6099 {
6100 name: "POST_body",
6101 req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6102 want: false,
6103 },
6104 }
6105 for _, tt := range tests {
6106 t.Run(tt.name, func(t *testing.T) {
6107 got := tt.req.ExportIsReplayable()
6108 if got != tt.want {
6109 t.Errorf("replyable = %v; want %v", got, tt.want)
6110 }
6111 })
6112 }
6113 }
6114
6115
6116
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom method
// was invoked. Tests use this to see whether the request body was written
// through the io.ReaderFrom path.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool // set by the first call to ReadFrom
}

// ReadFrom notes the call, then delegates to the embedded TCPConn.
func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
	c.ReadFromCalled = true
	n, err := c.TCPConn.ReadFrom(r)
	return n, err
}
6127
6128 func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6129 func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6130 nBytes := int64(1 << 10)
6131 newFileFunc := func() (r io.Reader, done func(), err error) {
6132 f, err := os.CreateTemp("", "net-http-newfilefunc")
6133 if err != nil {
6134 return nil, nil, err
6135 }
6136
6137
6138 if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6139 return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6140 }
6141 if _, err := f.Seek(0, 0); err != nil {
6142 return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6143 }
6144
6145 done = func() {
6146 f.Close()
6147 os.Remove(f.Name())
6148 }
6149
6150 return f, done, nil
6151 }
6152
6153 newBufferFunc := func() (io.Reader, func(), error) {
6154 return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6155 }
6156
6157 cases := []struct {
6158 name string
6159 readerFunc func() (io.Reader, func(), error)
6160 contentLength int64
6161 expectedReadFrom bool
6162 }{
6163 {
6164 name: "file, length",
6165 readerFunc: newFileFunc,
6166 contentLength: nBytes,
6167 expectedReadFrom: true,
6168 },
6169 {
6170 name: "file, no length",
6171 readerFunc: newFileFunc,
6172 },
6173 {
6174 name: "file, negative length",
6175 readerFunc: newFileFunc,
6176 contentLength: -1,
6177 },
6178 {
6179 name: "buffer",
6180 contentLength: nBytes,
6181 readerFunc: newBufferFunc,
6182 },
6183 {
6184 name: "buffer, no length",
6185 readerFunc: newBufferFunc,
6186 },
6187 {
6188 name: "buffer, length -1",
6189 contentLength: -1,
6190 readerFunc: newBufferFunc,
6191 },
6192 }
6193
6194 for _, tc := range cases {
6195 t.Run(tc.name, func(t *testing.T) {
6196 r, cleanup, err := tc.readerFunc()
6197 if err != nil {
6198 t.Fatal(err)
6199 }
6200 defer cleanup()
6201
6202 tConn := &testMockTCPConn{}
6203 trFunc := func(tr *Transport) {
6204 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6205 var d net.Dialer
6206 conn, err := d.DialContext(ctx, network, addr)
6207 if err != nil {
6208 return nil, err
6209 }
6210
6211 tcpConn, ok := conn.(*net.TCPConn)
6212 if !ok {
6213 return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6214 }
6215
6216 tConn.TCPConn = tcpConn
6217 return tConn, nil
6218 }
6219 }
6220
6221 cst := newClientServerTest(
6222 t,
6223 mode,
6224 HandlerFunc(func(w ResponseWriter, r *Request) {
6225 io.Copy(io.Discard, r.Body)
6226 r.Body.Close()
6227 w.WriteHeader(200)
6228 }),
6229 trFunc,
6230 )
6231
6232 req, err := NewRequest("PUT", cst.ts.URL, r)
6233 if err != nil {
6234 t.Fatal(err)
6235 }
6236 req.ContentLength = tc.contentLength
6237 req.Header.Set("Content-Type", "application/octet-stream")
6238 resp, err := cst.c.Do(req)
6239 if err != nil {
6240 t.Fatal(err)
6241 }
6242 defer resp.Body.Close()
6243 if resp.StatusCode != 200 {
6244 t.Fatalf("status code = %d; want 200", resp.StatusCode)
6245 }
6246
6247 expectedReadFrom := tc.expectedReadFrom
6248 if mode != http1Mode {
6249 expectedReadFrom = false
6250 }
6251 if !tConn.ReadFromCalled && expectedReadFrom {
6252 t.Fatalf("did not call ReadFrom")
6253 }
6254
6255 if tConn.ReadFromCalled && !expectedReadFrom {
6256 t.Fatalf("ReadFrom was unexpectedly invoked")
6257 }
6258 })
6259 }
6260 }
6261
// TestTransportClone gives every exported Transport field a non-zero value,
// then checks that Clone copies all of them, including the TLSNextProto map's
// entries. It also checks that cloning a zero Transport leaves TLSNextProto
// nil.
func TestTransportClone(t *testing.T) {
	tr := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}

	tr2 := tr.Clone()
	cloned := reflect.ValueOf(tr2).Elem()
	typ := cloned.Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		name := typ.Field(i).Name
		if token.IsExported(name) && cloned.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", name)
		}
	}

	if _, ok := tr2.TLSNextProto["foo"]; !ok {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// Cloning a zero Transport must not create a TLSNextProto map.
	if new(Transport).Clone().TLSNextProto != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6316
6317 func TestIs408(t *testing.T) {
6318 tests := []struct {
6319 in string
6320 want bool
6321 }{
6322 {"HTTP/1.0 408", true},
6323 {"HTTP/1.1 408", true},
6324 {"HTTP/1.8 408", true},
6325 {"HTTP/2.0 408", false},
6326 {"HTTP/1.1 408 ", true},
6327 {"HTTP/1.1 40", false},
6328 {"http/1.0 408", false},
6329 {"HTTP/1-1 408", false},
6330 }
6331 for _, tt := range tests {
6332 if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6333 t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6334 }
6335 }
6336 }
6337
6338 func TestTransportIgnores408(t *testing.T) {
6339 run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6340 }
6341 func testTransportIgnores408(t *testing.T, mode testMode) {
6342
6343 defer log.SetOutput(log.Writer())
6344
6345 var logout strings.Builder
6346 log.SetOutput(&logout)
6347
6348 const target = "backend:443"
6349
6350 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6351 nc, _, err := w.(Hijacker).Hijack()
6352 if err != nil {
6353 t.Error(err)
6354 return
6355 }
6356 defer nc.Close()
6357 nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6358 nc.Write([]byte("HTTP/1.1 408 bye\r\n"))
6359 }))
6360 req, err := NewRequest("GET", cst.ts.URL, nil)
6361 if err != nil {
6362 t.Fatal(err)
6363 }
6364 res, err := cst.c.Do(req)
6365 if err != nil {
6366 t.Fatal(err)
6367 }
6368 slurp, err := io.ReadAll(res.Body)
6369 if err != nil {
6370 t.Fatal(err)
6371 }
6372 if err != nil {
6373 t.Fatal(err)
6374 }
6375 if string(slurp) != "ok" {
6376 t.Fatalf("got %q; want ok", slurp)
6377 }
6378
6379 waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6380 if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6381 if d > 0 {
6382 t.Logf("%v idle conns still present after %v", n, d)
6383 }
6384 return false
6385 }
6386 return true
6387 })
6388 if got := logout.String(); got != "" {
6389 t.Fatalf("expected no log output; got: %s", got)
6390 }
6391 }
6392
6393 func TestInvalidHeaderResponse(t *testing.T) {
6394 run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6395 }
6396 func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6397 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6398 conn, buf, _ := w.(Hijacker).Hijack()
6399 buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6400 "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6401 "Content-Type: text/html; charset=utf-8\r\n" +
6402 "Content-Length: 0\r\n" +
6403 "Foo : bar\r\n\r\n"))
6404 buf.Flush()
6405 conn.Close()
6406 }))
6407 res, err := cst.c.Get(cst.ts.URL)
6408 if err != nil {
6409 t.Fatal(err)
6410 }
6411 defer res.Body.Close()
6412 if v := res.Header.Get("Foo"); v != "" {
6413 t.Errorf(`unexpected "Foo" header: %q`, v)
6414 }
6415 if v := res.Header.Get("Foo "); v != "bar" {
6416 t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6417 }
6418 }
6419
// bodyCloser is a request body that is always at EOF and becomes true once
// Close is called, so a test can tell whether the body was closed.
type bodyCloser bool

// Close marks the body as closed; it never fails.
func (bc *bodyCloser) Close() error {
	*bc = bodyCloser(true)
	return nil
}

// Read reports io.EOF immediately without touching b.
func (bc *bodyCloser) Read(b []byte) (n int, err error) {
	n, err = 0, io.EOF
	return
}
6429
6430
6431
6432 func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6433 run(t, testTransportClosesBodyOnInvalidRequests)
6434 }
6435 func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6436 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6437 t.Errorf("Should not have been invoked")
6438 })).ts
6439
6440 u, _ := url.Parse(cst.URL)
6441
6442 tests := []struct {
6443 name string
6444 req *Request
6445 wantErr string
6446 }{
6447 {
6448 name: "invalid method",
6449 req: &Request{
6450 Method: " ",
6451 URL: u,
6452 },
6453 wantErr: `invalid method " "`,
6454 },
6455 {
6456 name: "nil URL",
6457 req: &Request{
6458 Method: "GET",
6459 },
6460 wantErr: `nil Request.URL`,
6461 },
6462 {
6463 name: "invalid header key",
6464 req: &Request{
6465 Method: "GET",
6466 Header: Header{"💡": {"emoji"}},
6467 URL: u,
6468 },
6469 wantErr: `invalid header field name "💡"`,
6470 },
6471 {
6472 name: "invalid header value",
6473 req: &Request{
6474 Method: "POST",
6475 Header: Header{"key": {"\x19"}},
6476 URL: u,
6477 },
6478 wantErr: `invalid header field value for "key"`,
6479 },
6480 {
6481 name: "non HTTP(s) scheme",
6482 req: &Request{
6483 Method: "POST",
6484 URL: &url.URL{Scheme: "faux"},
6485 },
6486 wantErr: `unsupported protocol scheme "faux"`,
6487 },
6488 {
6489 name: "no Host in URL",
6490 req: &Request{
6491 Method: "POST",
6492 URL: &url.URL{Scheme: "http"},
6493 },
6494 wantErr: `no Host in request URL`,
6495 },
6496 }
6497
6498 for _, tt := range tests {
6499 t.Run(tt.name, func(t *testing.T) {
6500 var bc bodyCloser
6501 req := tt.req
6502 req.Body = &bc
6503 _, err := cst.Client().Do(tt.req)
6504 if err == nil {
6505 t.Fatal("Expected an error")
6506 }
6507 if !bc {
6508 t.Fatal("Expected body to have been closed")
6509 }
6510 if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6511 t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6512 }
6513 })
6514 }
6515 }
6516
6517
6518
// breakableConn is a net.Conn whose writes can be made to fail on demand
// through a shared, mutex-guarded brokenState.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is the switch shared by breakableConns: while broken is true,
// their writes fail. Guard access with the embedded Mutex.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write returns a fixed error while the shared state is broken and otherwise
// forwards b to the underlying connection. The lock is held for the whole
// call.
func (w *breakableConn) Write(b []byte) (n int, err error) {
	w.Lock()
	defer w.Unlock()
	if !w.broken {
		return w.Conn.Write(b)
	}
	return 0, errors.New("some write error")
}
6537
6538
6539 func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6540 run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6541 }
6542 func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6543 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6544
6545 var brokenState brokenState
6546
6547 const numReqs = 5
6548 var numDials, gotConns uint32
6549
6550 cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6551 atomic.AddUint32(&numDials, 1)
6552 c, err := net.Dial(netw, addr)
6553 if err != nil {
6554 t.Errorf("unexpected Dial error: %v", err)
6555 return nil, err
6556 }
6557 return &breakableConn{c, &brokenState}, err
6558 }
6559
6560 for i := 1; i <= numReqs; i++ {
6561 brokenState.Lock()
6562 brokenState.broken = false
6563 brokenState.Unlock()
6564
6565
6566
6567
6568 doBreak := i != numReqs
6569
6570 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6571 GotConn: func(info httptrace.GotConnInfo) {
6572 t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6573 atomic.AddUint32(&gotConns, 1)
6574 },
6575 TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6576 brokenState.Lock()
6577 defer brokenState.Unlock()
6578 if doBreak {
6579 brokenState.broken = true
6580 }
6581 },
6582 })
6583 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6584 if err != nil {
6585 t.Fatal(err)
6586 }
6587 _, err = cst.c.Do(req)
6588 if doBreak != (err != nil) {
6589 t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6590 }
6591 }
6592 if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6593 t.Errorf("GotConn calls = %v; want %v", got, want)
6594 }
6595 if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6596 t.Errorf("Dials = %v; want %v", got, want)
6597 }
6598 }
6599
6600
6601
6602
6603
6604 func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6605 run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6606 }
6607 func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6608 CondSkipHTTP2(t)
6609
6610 h := HandlerFunc(func(w ResponseWriter, r *Request) {
6611 _, err := w.Write([]byte("foo"))
6612 if err != nil {
6613 t.Fatalf("Write: %v", err)
6614 }
6615 })
6616
6617 ts := newClientServerTest(t, mode, h).ts
6618
6619 c := ts.Client()
6620 tr := c.Transport.(*Transport)
6621 tr.MaxConnsPerHost = 1
6622
6623 errCh := make(chan error, 300)
6624 doReq := func() {
6625 resp, err := c.Get(ts.URL)
6626 if err != nil {
6627 errCh <- fmt.Errorf("request failed: %v", err)
6628 return
6629 }
6630 defer resp.Body.Close()
6631 _, err = io.ReadAll(resp.Body)
6632 if err != nil {
6633 errCh <- fmt.Errorf("read body failed: %v", err)
6634 }
6635 }
6636
6637 var wg sync.WaitGroup
6638 for i := 0; i < 300; i++ {
6639 wg.Add(1)
6640 go func() {
6641 defer wg.Done()
6642 doReq()
6643 }()
6644 }
6645 wg.Wait()
6646 close(errCh)
6647
6648 for err := range errCh {
6649 t.Errorf("error occurred: %v", err)
6650 }
6651 }
6652
6653
6654
6655
6656 func TestAltProtoCancellation(t *testing.T) {
6657 defer afterTest(t)
6658 tr := &Transport{}
6659 c := &Client{
6660 Transport: tr,
6661 Timeout: time.Millisecond,
6662 }
6663 tr.RegisterProtocol("cancel", cancelProto{})
6664 _, err := c.Get("cancel://bar.com/path")
6665 if err == nil {
6666 t.Error("request unexpectedly succeeded")
6667 } else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6668 t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6669 }
6670 }
6671
// errCancelProto is what cancelProto.RoundTrip returns after its request
// has been canceled.
var errCancelProto = errors.New("canceled as expected")

// cancelProto is a RoundTripper that never completes a request. It blocks
// until req.Cancel yields (is closed or sent on) and then fails with
// errCancelProto.
type cancelProto struct{}

func (cancelProto) RoundTrip(r *Request) (res *Response, err error) {
	<-r.Cancel
	err = errCancelProto
	return
}

// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip calls f(r) and returns its results.
func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
	resp, err := f(r)
	return resp, err
}
6684
6685
6686 func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6687 func testIssue32441(t *testing.T, mode testMode) {
6688 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6689 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6690 t.Error("body length is zero")
6691 }
6692 })).ts
6693 c := ts.Client()
6694 c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6695
6696 if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6697 t.Error("body length is zero during round trip")
6698 }
6699 return nil, ErrSkipAltProtocol
6700 }))
6701 if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6702 t.Error(err)
6703 }
6704 }
6705
6706
6707
6708 func TestTransportRejectsSignInContentLength(t *testing.T) {
6709 run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6710 }
6711 func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6712 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6713 w.Header().Set("Content-Length", "+3")
6714 w.Write([]byte("abc"))
6715 })).ts
6716
6717 c := cst.Client()
6718 res, err := c.Get(cst.URL)
6719 if err == nil || res != nil {
6720 t.Fatal("Expected a non-nil error and a nil http.Response")
6721 }
6722 if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6723 t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6724 }
6725 }
6726
6727
// dumpConn is a minimal net.Conn assembled from an io.Writer and an io.Reader.
// Close, the address accessors and the deadline setters are all no-ops.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error                      { return nil }
func (*dumpConn) LocalAddr() net.Addr               { return nil }
func (*dumpConn) RemoteAddr() net.Addr              { return nil }
func (*dumpConn) SetDeadline(_ time.Time) error      { return nil }
func (*dumpConn) SetReadDeadline(_ time.Time) error  { return nil }
func (*dumpConn) SetWriteDeadline(_ time.Time) error { return nil }
6739
6740
6741
// delegateReader forwards reads to r. While r is still unset, a Read first
// waits for a reader to arrive on c and adopts it. If c is closed instead,
// the Read fails.
type delegateReader struct {
	c chan io.Reader
	r io.Reader
}

func (r *delegateReader) Read(p []byte) (int, error) {
	if r.r != nil {
		return r.r.Read(p)
	}
	next, ok := <-r.c
	if !ok {
		return 0, errors.New("delegate closed")
	}
	r.r = next
	return r.r.Read(p)
}
6756
// testTransportRace sends req through a fresh Transport whose only
// "connection" is in memory: requests the transport writes go into a pipe
// (pw), and responses are read from a delegateReader (dr). A goroutine plays
// the server. It parses and drains the request, then either hands the client
// a canned 204 response or, if the main goroutine has already moved on,
// closes dr.c so any pending client read fails. req.Body is restored before
// returning, so the caller's Request is left as it was passed in.
// NOTE(review): presumably meant to shake out races between the transport's
// write path and request cancellation (see TestErrorWriteLoopRace); the
// RoundTrip result is deliberately ignored.
func testTransportRace(req *Request) {
	// Remember the caller's body so it can be put back on req at the end.
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		// Every dial yields the same in-memory conn: writes go to pw,
		// reads come from dr.
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})

	// Fake server.
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Consume the whole request body (until the client finishes
			// writing or pw is closed).
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
			// The client's reader picked up the canned response.
		case quitReadCh <- struct{}{}:
			// The main goroutine is already waiting on quitReadCh, so no
			// read will take the response. Close dr.c so any later read
			// through dr fails with "delegate closed".
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Closing the write side of the pipe unblocks the fake server if it is
	// still inside ReadRequest or io.Copy; then wait for it to exit.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
6800
6801
6802
6803
6804
6805 func TestErrorWriteLoopRace(t *testing.T) {
6806 if testing.Short() {
6807 return
6808 }
6809 t.Parallel()
6810 for i := 0; i < 1000; i++ {
6811 delay := time.Duration(mrand.Intn(5)) * time.Millisecond
6812 ctx, cancel := context.WithTimeout(context.Background(), delay)
6813 defer cancel()
6814
6815 r := bytes.NewBuffer(make([]byte, 10000))
6816 req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
6817 if err != nil {
6818 t.Fatal(err)
6819 }
6820
6821 testTransportRace(req)
6822 }
6823 }
6824
6825
6826
6827
6828 func TestCancelRequestWhenSharingConnection(t *testing.T) {
6829 run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
6830 }
// testCancelRequestWhenSharingConnection uses a transport limited to one
// connection per host (and one idle connection), so two sequential requests
// share a connection. The sequence is:
//  1. Request 1 completes, and its PutIdleConn trace callback is held blocked.
//  2. Request 2 is issued and reaches the handler.
//  3. Request 2 is canceled and must fail with context.Canceled.
//
// Only after request 2 has failed are request 1's callback and request 2's
// handler released.
// NOTE(review): presumably this checks that cancellation still works while the
// shared connection's read loop is stuck in request 1's callback; confirm
// against the transport code.
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each handler invocation sends a channel here and blocks until the test
	// closes it.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	// Request 1.
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Hand the test a channel and block until it is closed
				// (request 2's goroutine closes it). putidlec is closed
				// right after the first send, so a second PutIdleConn call
				// would panic on the closed channel.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec)
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Let request 1's handler respond, failing fast if request 1 errors.
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	// Wait until request 1's PutIdleConn callback is running (and blocked).
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	// Request 2.
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Request 2 is finished; release request 1's PutIdleConn callback
		// and the main goroutine's wait below.
		close(idlec)
	}()

	// Once request 2 is inside the handler, cancel it. Its handler stays
	// blocked, so the cancellation, not a response, must end the request.
	r2c := <-reqc
	cancel()

	// Wait for request 2's goroutine to finish checking its error.
	<-idlec

	// Let request 2's handler return, then wait for both goroutines.
	close(r2c)
	wg.Wait()
}
6913
6914 func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
6915 func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
6916 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6917 go io.Copy(io.Discard, req.Body)
6918 panic(ErrAbortHandler)
6919 })).ts
6920
6921 var wg sync.WaitGroup
6922 for i := 0; i < 2; i++ {
6923 wg.Add(1)
6924 go func() {
6925 defer wg.Done()
6926 for j := 0; j < 10; j++ {
6927 const reqLen = 6 * 1024 * 1024
6928 req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
6929 req.ContentLength = reqLen
6930 resp, _ := ts.Client().Transport.RoundTrip(req)
6931 if resp != nil {
6932 resp.Body.Close()
6933 }
6934 }
6935 }()
6936 }
6937 wg.Wait()
6938 }
6939
6940 func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
6941 func testRequestSanitization(t *testing.T, mode testMode) {
6942 if mode == http2Mode {
6943
6944 t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
6945 }
6946 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6947 if h, ok := req.Header["X-Evil"]; ok {
6948 t.Errorf("request has X-Evil header: %q", h)
6949 }
6950 })).ts
6951 req, _ := NewRequest("GET", ts.URL, nil)
6952 req.Host = "go.dev\r\nX-Evil:evil"
6953 resp, _ := ts.Client().Do(req)
6954 if resp != nil {
6955 resp.Body.Close()
6956 }
6957 }
6958
6959 func TestProxyAuthHeader(t *testing.T) {
6960
6961 run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
6962 }
6963 func testProxyAuthHeader(t *testing.T, mode testMode) {
6964 const username = "u"
6965 const password = "@/?!"
6966 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6967
6968
6969 var r2 Request
6970 r2.Header = Header{
6971 "Authorization": req.Header["Proxy-Authorization"],
6972 }
6973 gotuser, gotpass, ok := r2.BasicAuth()
6974 if !ok || gotuser != username || gotpass != password {
6975 t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
6976 }
6977 }))
6978 u, err := url.Parse(cst.ts.URL)
6979 if err != nil {
6980 t.Fatal(err)
6981 }
6982 u.User = url.UserPassword(username, password)
6983 t.Setenv("HTTP_PROXY", u.String())
6984 cst.tr.Proxy = ProxyURL(u)
6985 resp, err := cst.c.Get("http://_/")
6986 if err != nil {
6987 t.Fatal(err)
6988 }
6989 resp.Body.Close()
6990 }
6991
6992
// TestTransportReqCancelerCleanupOnRequestBodyWriteError POSTs a 1 GiB body
// to a raw TCP server. The server reads a single byte, replies with a
// complete 200 response, and closes the connection once the test has the
// response. After the body is closed, the test waits until the transport
// reports no pending requests (NumPendingRequestsForTesting reaches 0).
// NOTE(review): per the test name, the client is presumably still writing the
// body when the connection goes away, so the body write fails; confirm.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	done := make(chan struct{})
	// Raw server: accept one connection and read a single byte of the
	// request. Reply immediately, then close only after the test signals done.
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}

		// Read just one byte so the client knows its request was received.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// The hook fires each time the transport's read loop is about to read
	// again; the test uses it to sequence the body close below.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// A 1 GiB body, far more than the server will ever read.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	// Let the server close the connection.
	close(done)

	// Wait until the read loop reaches its next-read hook before closing the
	// response body.
	<-didRead

	resp.Body.Close()

	// The transport must eventually forget the request.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7053
7054 func TestValidateClientRequestTrailers(t *testing.T) {
7055 run(t, testValidateClientRequestTrailers)
7056 }
7057
7058 func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7059 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7060 rw.Write([]byte("Hello"))
7061 })).ts
7062
7063 cases := []struct {
7064 trailer Header
7065 wantErr string
7066 }{
7067 {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7068 {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7069 }
7070
7071 for i, tt := range cases {
7072 testName := fmt.Sprintf("%s%d", mode, i)
7073 t.Run(testName, func(t *testing.T) {
7074 req, err := NewRequest("GET", cst.URL, nil)
7075 if err != nil {
7076 t.Fatal(err)
7077 }
7078 req.Trailer = tt.trailer
7079 res, err := cst.Client().Do(req)
7080 if err == nil {
7081 t.Fatal("Expected an error")
7082 }
7083 if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7084 t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7085 }
7086 if res != nil {
7087 t.Fatal("Unexpected non-nil response")
7088 }
7089 })
7090 }
7091 }
7092
View as plain text