Source file src/net/http/httptest/server_test.go

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package httptest
     6  
     7  import (
     8  	"bufio"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"sync"
    13  	"testing"
    14  )
    15  
    16  type newServerFunc func(http.Handler) *Server
    17  
    18  var newServers = map[string]newServerFunc{
    19  	"NewServer":    NewServer,
    20  	"NewTLSServer": NewTLSServer,
    21  
    22  	// The manual variants of newServer create a Server manually by only filling
    23  	// in the exported fields of Server.
    24  	"NewServerManual": func(h http.Handler) *Server {
    25  		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
    26  		ts.Start()
    27  		return ts
    28  	},
    29  	"NewTLSServerManual": func(h http.Handler) *Server {
    30  		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
    31  		ts.StartTLS()
    32  		return ts
    33  	},
    34  }
    35  
    36  func TestServer(t *testing.T) {
    37  	for _, name := range []string{"NewServer", "NewServerManual"} {
    38  		t.Run(name, func(t *testing.T) {
    39  			newServer := newServers[name]
    40  			t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
    41  			t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
    42  			t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
    43  			t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
    44  			t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
    45  		})
    46  	}
    47  	for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
    48  		t.Run(name, func(t *testing.T) {
    49  			newServer := newServers[name]
    50  			t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
    51  			t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
    52  		})
    53  	}
    54  }
    55  
    56  func testServer(t *testing.T, newServer newServerFunc) {
    57  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    58  		w.Write([]byte("hello"))
    59  	}))
    60  	defer ts.Close()
    61  	res, err := http.Get(ts.URL)
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  	got, err := io.ReadAll(res.Body)
    66  	res.Body.Close()
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	if string(got) != "hello" {
    71  		t.Errorf("got %q, want hello", string(got))
    72  	}
    73  }
    74  
    75  // Issue 12781
    76  func testGetAfterClose(t *testing.T, newServer newServerFunc) {
    77  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    78  		w.Write([]byte("hello"))
    79  	}))
    80  
    81  	res, err := http.Get(ts.URL)
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	got, err := io.ReadAll(res.Body)
    86  	res.Body.Close()
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	if string(got) != "hello" {
    91  		t.Fatalf("got %q, want hello", string(got))
    92  	}
    93  
    94  	ts.Close()
    95  
    96  	res, err = http.Get(ts.URL)
    97  	if err == nil {
    98  		body, _ := io.ReadAll(res.Body)
    99  		t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
   100  	}
   101  }
   102  
   103  func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
   104  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   105  		w.Write([]byte("hello"))
   106  	}))
   107  	dial := func() net.Conn {
   108  		c, err := net.Dial("tcp", ts.Listener.Addr().String())
   109  		if err != nil {
   110  			t.Fatal(err)
   111  		}
   112  		return c
   113  	}
   114  
   115  	// Keep one connection in StateNew (connected, but not sending anything)
   116  	cnew := dial()
   117  	defer cnew.Close()
   118  
   119  	// Keep one connection in StateIdle (idle after a request)
   120  	cidle := dial()
   121  	defer cidle.Close()
   122  	cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
   123  	_, err := http.ReadResponse(bufio.NewReader(cidle), nil)
   124  	if err != nil {
   125  		t.Fatal(err)
   126  	}
   127  
   128  	ts.Close() // test we don't hang here forever.
   129  }
   130  
   131  // Issue 14290
   132  func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
   133  	var s *Server
   134  	s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   135  		s.CloseClientConnections()
   136  	}))
   137  	defer s.Close()
   138  	res, err := http.Get(s.URL)
   139  	if err == nil {
   140  		res.Body.Close()
   141  		t.Fatalf("Unexpected response: %#v", res)
   142  	}
   143  }
   144  
   145  // Tests that the Server.Client method works and returns an http.Client that can hit
   146  // NewTLSServer without cert warnings.
   147  func testServerClient(t *testing.T, newTLSServer newServerFunc) {
   148  	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   149  		w.Write([]byte("hello"))
   150  	}))
   151  	defer ts.Close()
   152  	client := ts.Client()
   153  	res, err := client.Get(ts.URL)
   154  	if err != nil {
   155  		t.Fatal(err)
   156  	}
   157  	got, err := io.ReadAll(res.Body)
   158  	res.Body.Close()
   159  	if err != nil {
   160  		t.Fatal(err)
   161  	}
   162  	if string(got) != "hello" {
   163  		t.Errorf("got %q, want hello", string(got))
   164  	}
   165  }
   166  
   167  // Tests that the Server.Client.Transport interface is implemented
   168  // by a *http.Transport.
   169  func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
   170  	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   171  	}))
   172  	defer ts.Close()
   173  	client := ts.Client()
   174  	if _, ok := client.Transport.(*http.Transport); !ok {
   175  		t.Errorf("got %T, want *http.Transport", client.Transport)
   176  	}
   177  }
   178  
   179  // Tests that the TLS Server.Client.Transport interface is implemented
   180  // by a *http.Transport.
   181  func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
   182  	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   183  	}))
   184  	defer ts.Close()
   185  	client := ts.Client()
   186  	if _, ok := client.Transport.(*http.Transport); !ok {
   187  		t.Errorf("got %T, want *http.Transport", client.Transport)
   188  	}
   189  }
   190  
   191  type onlyCloseListener struct {
   192  	net.Listener
   193  }
   194  
   195  func (onlyCloseListener) Close() error { return nil }
   196  
   197  // Issue 19729: panic in Server.Close for values created directly
   198  // without a constructor (so the unexported client field is nil).
   199  func TestServerZeroValueClose(t *testing.T) {
   200  	ts := &Server{
   201  		Listener: onlyCloseListener{},
   202  		Config:   &http.Server{},
   203  	}
   204  
   205  	ts.Close() // tests that it doesn't panic
   206  }
   207  
   208  // Issue 51799: test hijacking a connection and then closing it
   209  // concurrently with closing the server.
   210  func TestCloseHijackedConnection(t *testing.T) {
   211  	hijacked := make(chan net.Conn)
   212  	ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   213  		defer close(hijacked)
   214  		hj, ok := w.(http.Hijacker)
   215  		if !ok {
   216  			t.Fatal("failed to hijack")
   217  		}
   218  		c, _, err := hj.Hijack()
   219  		if err != nil {
   220  			t.Fatal(err)
   221  		}
   222  		hijacked <- c
   223  	}))
   224  
   225  	var wg sync.WaitGroup
   226  	wg.Add(1)
   227  	go func() {
   228  		defer wg.Done()
   229  		req, err := http.NewRequest("GET", ts.URL, nil)
   230  		if err != nil {
   231  			t.Log(err)
   232  		}
   233  		// Use a client not associated with the Server.
   234  		var c http.Client
   235  		resp, err := c.Do(req)
   236  		if err != nil {
   237  			t.Log(err)
   238  			return
   239  		}
   240  		resp.Body.Close()
   241  	}()
   242  
   243  	wg.Add(1)
   244  	conn := <-hijacked
   245  	go func(conn net.Conn) {
   246  		defer wg.Done()
   247  		// Close the connection and then inform the Server that
   248  		// we closed it.
   249  		conn.Close()
   250  		ts.Config.ConnState(conn, http.StateClosed)
   251  	}(conn)
   252  
   253  	wg.Add(1)
   254  	go func() {
   255  		defer wg.Done()
   256  		ts.Close()
   257  	}()
   258  	wg.Wait()
   259  }
   260  
   261  func TestTLSServerWithHTTP2(t *testing.T) {
   262  	modes := []struct {
   263  		name      string
   264  		wantProto string
   265  	}{
   266  		{"http1", "HTTP/1.1"},
   267  		{"http2", "HTTP/2.0"},
   268  	}
   269  
   270  	for _, tt := range modes {
   271  		t.Run(tt.name, func(t *testing.T) {
   272  			cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   273  				w.Header().Set("X-Proto", r.Proto)
   274  			}))
   275  
   276  			switch tt.name {
   277  			case "http2":
   278  				cst.EnableHTTP2 = true
   279  				cst.StartTLS()
   280  			default:
   281  				cst.Start()
   282  			}
   283  
   284  			defer cst.Close()
   285  
   286  			res, err := cst.Client().Get(cst.URL)
   287  			if err != nil {
   288  				t.Fatalf("Failed to make request: %v", err)
   289  			}
   290  			if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
   291  				t.Fatalf("X-Proto header mismatch:\n\tgot:  %q\n\twant: %q", g, w)
   292  			}
   293  		})
   294  	}
   295  }
   296  

View as plain text