Source file src/net/http/transport_dial_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http_test
     6  
     7  import (
     8  	"context"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"net/http/httptrace"
    13  	"testing"
    14  )
    15  
    16  func TestTransportPoolConnReusePriorConnection(t *testing.T) {
    17  	dt := newTransportDialTester(t, http1Mode)
    18  
    19  	// First request creates a new connection.
    20  	rt1 := dt.roundTrip()
    21  	c1 := dt.wantDial()
    22  	c1.finish(nil)
    23  	rt1.wantDone(c1)
    24  	rt1.finish()
    25  
    26  	// Second request reuses the first connection.
    27  	rt2 := dt.roundTrip()
    28  	rt2.wantDone(c1)
    29  	rt2.finish()
    30  }
    31  
    32  func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
    33  	dt := newTransportDialTester(t, http1Mode)
    34  
    35  	// First request creates a new connection.
    36  	rt1 := dt.roundTrip()
    37  	c1 := dt.wantDial()
    38  	c1.finish(nil)
    39  	rt1.wantDone(c1)
    40  
    41  	// Second request is made while the first request is still using its connection,
    42  	// so it goes on a new connection.
    43  	rt2 := dt.roundTrip()
    44  	c2 := dt.wantDial()
    45  	c2.finish(nil)
    46  	rt2.wantDone(c2)
    47  }
    48  
    49  func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
    50  	dt := newTransportDialTester(t, http1Mode)
    51  
    52  	// First request creates a new connection.
    53  	rt1 := dt.roundTrip()
    54  	c1 := dt.wantDial()
    55  	c1.finish(nil)
    56  	rt1.wantDone(c1)
    57  
    58  	// Second request is made while the first request is still using its connection.
    59  	// The first connection completes while the second Dial is in progress, so the
    60  	// second request uses the first connection.
    61  	rt2 := dt.roundTrip()
    62  	c2 := dt.wantDial()
    63  	rt1.finish()
    64  	rt2.wantDone(c1)
    65  
    66  	// This section is a bit overfitted to the current Transport implementation:
    67  	// A third request starts. We have an in-progress dial that was started by rt2,
    68  	// but this new request (rt3) is going to ignore it and make a dial of its own.
    69  	// rt3 will use the first of these dials that completes.
    70  	rt3 := dt.roundTrip()
    71  	c3 := dt.wantDial()
    72  	c2.finish(nil)
    73  	rt3.wantDone(c2)
    74  
    75  	c3.finish(nil)
    76  }
    77  
    78  // A transportDialTester manages a test of a connection's Dials.
    79  type transportDialTester struct {
    80  	t   *testing.T
    81  	cst *clientServerTest
    82  
    83  	dials chan *transportDialTesterConn // each new conn is sent to this channel
    84  
    85  	roundTripCount int
    86  	dialCount      int
    87  }
    88  
    89  // A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
    90  type transportDialTesterRoundTrip struct {
    91  	t *testing.T
    92  
    93  	roundTripID int                // distinguishes RoundTrips in logs
    94  	cancel      context.CancelFunc // cancels the Request context
    95  	reqBody     io.WriteCloser     // write half of the Request.Body
    96  	finished    bool
    97  
    98  	done chan struct{} // closed when RoundTrip returns:w
    99  	res  *http.Response
   100  	err  error
   101  	conn *transportDialTesterConn
   102  }
   103  
   104  // A transportDialTesterConn is a client connection created by the Transport as
   105  // part of a dial test.
   106  type transportDialTesterConn struct {
   107  	t *testing.T
   108  
   109  	connID int        // distinguished Dials in logs
   110  	ready  chan error // sent on to complete the Dial
   111  
   112  	net.Conn
   113  }
   114  
   115  func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
   116  	t.Helper()
   117  	dt := &transportDialTester{
   118  		t:     t,
   119  		dials: make(chan *transportDialTesterConn),
   120  	}
   121  	dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   122  		// Write response headers when we receive a request.
   123  		http.NewResponseController(w).EnableFullDuplex()
   124  		w.WriteHeader(200)
   125  		http.NewResponseController(w).Flush()
   126  		// Wait for the client to send the request body,
   127  		// to synchronize with the rest of the test.
   128  		io.ReadAll(r.Body)
   129  	}), func(tr *http.Transport) {
   130  		tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
   131  			c := &transportDialTesterConn{
   132  				t:     t,
   133  				ready: make(chan error),
   134  			}
   135  			// Notify the test that a Dial has started,
   136  			// and wait for the test to notify us that it should complete.
   137  			dt.dials <- c
   138  			if err := <-c.ready; err != nil {
   139  				return nil, err
   140  			}
   141  			nc, err := net.Dial(network, address)
   142  			if err != nil {
   143  				return nil, err
   144  			}
   145  			// Use the *transportDialTesterConn as the net.Conn,
   146  			// to let tests associate requests with connections.
   147  			c.Conn = nc
   148  			return c, err
   149  		}
   150  	})
   151  	return dt
   152  }
   153  
   154  // roundTrip starts a RoundTrip.
   155  // It returns immediately, without waiting for the RoundTrip call to complete.
   156  func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
   157  	dt.t.Helper()
   158  	ctx, cancel := context.WithCancel(context.Background())
   159  	pr, pw := io.Pipe()
   160  	rt := &transportDialTesterRoundTrip{
   161  		t:           dt.t,
   162  		roundTripID: dt.roundTripCount,
   163  		done:        make(chan struct{}),
   164  		reqBody:     pw,
   165  		cancel:      cancel,
   166  	}
   167  	dt.roundTripCount++
   168  	dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
   169  	dt.t.Cleanup(func() {
   170  		rt.cancel()
   171  		rt.finish()
   172  	})
   173  	go func() {
   174  		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
   175  			GotConn: func(info httptrace.GotConnInfo) {
   176  				rt.conn = info.Conn.(*transportDialTesterConn)
   177  			},
   178  		})
   179  		req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
   180  		req.Header.Set("Content-Type", "text/plain")
   181  		rt.res, rt.err = dt.cst.tr.RoundTrip(req)
   182  		dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
   183  		close(rt.done)
   184  	}()
   185  	return rt
   186  }
   187  
   188  // wantDone indicates that a RoundTrip should have returned.
   189  func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
   190  	rt.t.Helper()
   191  	<-rt.done
   192  	if rt.err != nil {
   193  		rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
   194  	}
   195  	if rt.conn != c {
   196  		rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
   197  	}
   198  }
   199  
   200  // finish completes a RoundTrip by sending the request body, consuming the response body,
   201  // and closing the response body.
   202  func (rt *transportDialTesterRoundTrip) finish() {
   203  	rt.t.Helper()
   204  
   205  	if rt.finished {
   206  		return
   207  	}
   208  	rt.finished = true
   209  
   210  	<-rt.done
   211  
   212  	if rt.err != nil {
   213  		return
   214  	}
   215  	rt.reqBody.Close()
   216  	io.ReadAll(rt.res.Body)
   217  	rt.res.Body.Close()
   218  	rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
   219  }
   220  
   221  // wantDial waits for the Transport to start a Dial.
   222  func (dt *transportDialTester) wantDial() *transportDialTesterConn {
   223  	c := <-dt.dials
   224  	c.connID = dt.dialCount
   225  	dt.dialCount++
   226  	dt.t.Logf("Dial %v: started", c.connID)
   227  	return c
   228  }
   229  
   230  // finish completes a Dial.
   231  func (c *transportDialTesterConn) finish(err error) {
   232  	c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
   233  	c.ready <- err
   234  	close(c.ready)
   235  }
   236  

View as plain text