Source file src/net/http/export_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Bridge package to expose http internals to tests in the http_test
     6  // package.
     7  
     8  package http
     9  
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"net"
    14  	"net/url"
    15  	"slices"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  )
    20  
    21  var (
    22  	DefaultUserAgent                  = defaultUserAgent
    23  	NewLoggingConn                    = newLoggingConn
    24  	ExportAppendTime                  = appendTime
    25  	ExportRefererForURL               = refererForURL
    26  	ExportServerNewConn               = (*Server).newConn
    27  	ExportCloseWriteAndWait           = (*conn).closeWriteAndWait
    28  	ExportErrRequestCanceled          = errRequestCanceled
    29  	ExportErrRequestCanceledConn      = errRequestCanceledConn
    30  	ExportErrServerClosedIdle         = errServerClosedIdle
    31  	ExportServeFile                   = serveFile
    32  	ExportScanETag                    = scanETag
    33  	ExportHttp2ConfigureServer        = http2ConfigureServer
    34  	Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect
    35  	Export_writeStatusLine            = writeStatusLine
    36  	Export_is408Message               = is408Message
    37  )
    38  
    39  var MaxWriteWaitBeforeConnReuse = &maxWriteWaitBeforeConnReuse
    40  
    41  func init() {
    42  	// We only want to pay for this cost during testing.
    43  	// When not under test, these values are always nil
    44  	// and never assigned to.
    45  	testHookMu = new(sync.Mutex)
    46  
    47  	testHookClientDoResult = func(res *Response, err error) {
    48  		if err != nil {
    49  			if _, ok := err.(*url.Error); !ok {
    50  				panic(fmt.Sprintf("unexpected Client.Do error of type %T; want *url.Error", err))
    51  			}
    52  		} else {
    53  			if res == nil {
    54  				panic("Client.Do returned nil, nil")
    55  			}
    56  			if res.Body == nil {
    57  				panic("Client.Do returned nil res.Body and no error")
    58  			}
    59  		}
    60  	}
    61  }
    62  
    63  func CondSkipHTTP2(t testing.TB) {
    64  	if omitBundledHTTP2 {
    65  		t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use")
    66  	}
    67  }
    68  
    69  var (
    70  	SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip)
    71  	SetRoundTripRetried   = hookSetter(&testHookRoundTripRetried)
    72  )
    73  
    74  func SetReadLoopBeforeNextReadHook(f func()) {
    75  	unnilTestHook(&f)
    76  	testHookReadLoopBeforeNextRead = f
    77  }
    78  
    79  // SetPendingDialHooks sets the hooks that run before and after handling
    80  // pending dials.
    81  func SetPendingDialHooks(before, after func()) {
    82  	unnilTestHook(&before)
    83  	unnilTestHook(&after)
    84  	testHookPrePendingDial, testHookPostPendingDial = before, after
    85  }
    86  
    87  func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
    88  
    89  func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
    90  	orig := testHookProxyConnectTimeout
    91  	t.Cleanup(func() {
    92  		testHookProxyConnectTimeout = orig
    93  	})
    94  	testHookProxyConnectTimeout = f
    95  }
    96  
    97  func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
    98  	return &timeoutHandler{
    99  		handler:     handler,
   100  		testContext: ctx,
   101  		// (no body)
   102  	}
   103  }
   104  
   105  func ResetCachedEnvironment() {
   106  	resetProxyConfig()
   107  }
   108  
   109  func (t *Transport) NumPendingRequestsForTesting() int {
   110  	t.reqMu.Lock()
   111  	defer t.reqMu.Unlock()
   112  	return len(t.reqCanceler)
   113  }
   114  
   115  func (t *Transport) IdleConnKeysForTesting() (keys []string) {
   116  	keys = make([]string, 0)
   117  	t.idleMu.Lock()
   118  	defer t.idleMu.Unlock()
   119  	for key := range t.idleConn {
   120  		keys = append(keys, key.String())
   121  	}
   122  	slices.Sort(keys)
   123  	return
   124  }
   125  
   126  func (t *Transport) IdleConnKeyCountForTesting() int {
   127  	t.idleMu.Lock()
   128  	defer t.idleMu.Unlock()
   129  	return len(t.idleConn)
   130  }
   131  
   132  func (t *Transport) IdleConnStrsForTesting() []string {
   133  	var ret []string
   134  	t.idleMu.Lock()
   135  	defer t.idleMu.Unlock()
   136  	for _, conns := range t.idleConn {
   137  		for _, pc := range conns {
   138  			ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String())
   139  		}
   140  	}
   141  	slices.Sort(ret)
   142  	return ret
   143  }
   144  
   145  func (t *Transport) IdleConnStrsForTesting_h2() []string {
   146  	var ret []string
   147  	noDialPool := t.h2transport.(*http2Transport).ConnPool.(http2noDialClientConnPool)
   148  	pool := noDialPool.http2clientConnPool
   149  
   150  	pool.mu.Lock()
   151  	defer pool.mu.Unlock()
   152  
   153  	for k, ccs := range pool.conns {
   154  		for _, cc := range ccs {
   155  			if cc.idleState().canTakeNewRequest {
   156  				ret = append(ret, k)
   157  			}
   158  		}
   159  	}
   160  
   161  	slices.Sort(ret)
   162  	return ret
   163  }
   164  
   165  func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
   166  	t.idleMu.Lock()
   167  	defer t.idleMu.Unlock()
   168  	key := connectMethodKey{"", scheme, addr, false}
   169  	cacheKey := key.String()
   170  	for k, conns := range t.idleConn {
   171  		if k.String() == cacheKey {
   172  			return len(conns)
   173  		}
   174  	}
   175  	return 0
   176  }
   177  
   178  func (t *Transport) IdleConnWaitMapSizeForTesting() int {
   179  	t.idleMu.Lock()
   180  	defer t.idleMu.Unlock()
   181  	return len(t.idleConnWait)
   182  }
   183  
   184  func (t *Transport) IsIdleForTesting() bool {
   185  	t.idleMu.Lock()
   186  	defer t.idleMu.Unlock()
   187  	return t.closeIdle
   188  }
   189  
   190  func (t *Transport) QueueForIdleConnForTesting() {
   191  	t.queueForIdleConn(nil)
   192  }
   193  
   194  // PutIdleTestConn reports whether it was able to insert a fresh
   195  // persistConn for scheme, addr into the idle connection pool.
   196  func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
   197  	c, _ := net.Pipe()
   198  	key := connectMethodKey{"", scheme, addr, false}
   199  
   200  	if t.MaxConnsPerHost > 0 {
   201  		// Transport is tracking conns-per-host.
   202  		// Increment connection count to account
   203  		// for new persistConn created below.
   204  		t.connsPerHostMu.Lock()
   205  		if t.connsPerHost == nil {
   206  			t.connsPerHost = make(map[connectMethodKey]int)
   207  		}
   208  		t.connsPerHost[key]++
   209  		t.connsPerHostMu.Unlock()
   210  	}
   211  
   212  	return t.tryPutIdleConn(&persistConn{
   213  		t:        t,
   214  		conn:     c,                   // dummy
   215  		closech:  make(chan struct{}), // so it can be closed
   216  		cacheKey: key,
   217  	}) == nil
   218  }
   219  
   220  // PutIdleTestConnH2 reports whether it was able to insert a fresh
   221  // HTTP/2 persistConn for scheme, addr into the idle connection pool.
   222  func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt RoundTripper) bool {
   223  	key := connectMethodKey{"", scheme, addr, false}
   224  
   225  	if t.MaxConnsPerHost > 0 {
   226  		// Transport is tracking conns-per-host.
   227  		// Increment connection count to account
   228  		// for new persistConn created below.
   229  		t.connsPerHostMu.Lock()
   230  		if t.connsPerHost == nil {
   231  			t.connsPerHost = make(map[connectMethodKey]int)
   232  		}
   233  		t.connsPerHost[key]++
   234  		t.connsPerHostMu.Unlock()
   235  	}
   236  
   237  	return t.tryPutIdleConn(&persistConn{
   238  		t:        t,
   239  		alt:      alt,
   240  		cacheKey: key,
   241  	}) == nil
   242  }
   243  
   244  // All test hooks must be non-nil so they can be called directly,
   245  // but the tests use nil to mean hook disabled.
   246  func unnilTestHook(f *func()) {
   247  	if *f == nil {
   248  		*f = nop
   249  	}
   250  }
   251  
   252  func hookSetter(dst *func()) func(func()) {
   253  	return func(fn func()) {
   254  		unnilTestHook(&fn)
   255  		*dst = fn
   256  	}
   257  }
   258  
   259  func ExportHttp2ConfigureTransport(t *Transport) error {
   260  	t2, err := http2configureTransports(t)
   261  	if err != nil {
   262  		return err
   263  	}
   264  	t.h2transport = t2
   265  	return nil
   266  }
   267  
   268  func (s *Server) ExportAllConnsIdle() bool {
   269  	s.mu.Lock()
   270  	defer s.mu.Unlock()
   271  	for c := range s.activeConn {
   272  		st, unixSec := c.getState()
   273  		if unixSec == 0 || st != StateIdle {
   274  			return false
   275  		}
   276  	}
   277  	return true
   278  }
   279  
   280  func (s *Server) ExportAllConnsByState() map[ConnState]int {
   281  	states := map[ConnState]int{}
   282  	s.mu.Lock()
   283  	defer s.mu.Unlock()
   284  	for c := range s.activeConn {
   285  		st, _ := c.getState()
   286  		states[st] += 1
   287  	}
   288  	return states
   289  }
   290  
   291  func (r *Request) WithT(t *testing.T) *Request {
   292  	return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf))
   293  }
   294  
   295  func ExportSetH2GoawayTimeout(d time.Duration) (restore func()) {
   296  	old := http2goAwayTimeout
   297  	http2goAwayTimeout = d
   298  	return func() { http2goAwayTimeout = old }
   299  }
   300  
   301  func (r *Request) ExportIsReplayable() bool { return r.isReplayable() }
   302  
   303  // ExportCloseTransportConnsAbruptly closes all idle connections from
   304  // tr in an abrupt way, just reaching into the underlying Conns and
   305  // closing them, without telling the Transport or its persistConns
   306  // that it's doing so. This is to simulate the server closing connections
   307  // on the Transport.
   308  func ExportCloseTransportConnsAbruptly(tr *Transport) {
   309  	tr.idleMu.Lock()
   310  	for _, pcs := range tr.idleConn {
   311  		for _, pc := range pcs {
   312  			pc.conn.Close()
   313  		}
   314  	}
   315  	tr.idleMu.Unlock()
   316  }
   317  
   318  // ResponseWriterConnForTesting returns w's underlying connection, if w
   319  // is a regular *response ResponseWriter.
   320  func ResponseWriterConnForTesting(w ResponseWriter) (c net.Conn, ok bool) {
   321  	if r, ok := w.(*response); ok {
   322  		return r.conn.rwc, true
   323  	}
   324  	return nil, false
   325  }
   326  
   327  func init() {
   328  	// Set the default rstAvoidanceDelay to the minimum possible value to shake
   329  	// out tests that unexpectedly depend on it. Such tests should use
   330  	// runTimeSensitiveTest and SetRSTAvoidanceDelay to explicitly raise the delay
   331  	// if needed.
   332  	rstAvoidanceDelay = 1 * time.Nanosecond
   333  }
   334  
   335  // SetRSTAvoidanceDelay sets how long we are willing to wait between calling
   336  // CloseWrite on a connection and fully closing the connection.
   337  func SetRSTAvoidanceDelay(t *testing.T, d time.Duration) {
   338  	prevDelay := rstAvoidanceDelay
   339  	t.Cleanup(func() {
   340  		rstAvoidanceDelay = prevDelay
   341  	})
   342  	rstAvoidanceDelay = d
   343  }
   344  

View as plain text