Source file src/internal/nettest/listener_test.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package nettest_test
     6  
     7  import (
     8  	"errors"
     9  	"internal/nettest"
    10  	"io"
    11  	"net"
    12  	"net/netip"
    13  	"slices"
    14  	"testing"
    15  	"testing/synctest"
    16  )
    17  
    18  func TestListenerNewConn(t *testing.T) {
    19  	synctest.Test(t, func(t *testing.T) {
    20  		li := nettest.NewListener()
    21  		defer li.Close()
    22  
    23  		// Create several connections in parallel.
    24  		want := []string{"a", "b", "c"}
    25  		for i := range len(want) {
    26  			go func() {
    27  				conn := li.NewConn()
    28  				defer conn.Close()
    29  				n, err := conn.Write([]byte(want[i]))
    30  				if n != len(want[i]) || err != nil {
    31  					t.Errorf("conn%v.Write() = %v, %v; want %v, nil", i, n, err, len(want[i]))
    32  				}
    33  			}()
    34  		}
    35  
    36  		// Accept the connections in parallel as well.
    37  		got := make([]string, len(want))
    38  		for i := range len(want) {
    39  			go func() {
    40  				conn, err := li.Accept()
    41  				if err != nil {
    42  					t.Errorf("li.Accept() = %v", err)
    43  				}
    44  				b, err := io.ReadAll(conn)
    45  				if err != nil {
    46  					t.Errorf("io.ReadAll(conn%v) = %v", i, err)
    47  				}
    48  				got[i] = string(b)
    49  			}()
    50  		}
    51  
    52  		synctest.Wait()
    53  		slices.Sort(got)
    54  		slices.Sort(want)
    55  		if !slices.Equal(got, want) {
    56  			t.Errorf("connections read %v; want %q", got, want)
    57  		}
    58  	})
    59  }
    60  
    61  func TestListenerInterruptAccept(t *testing.T) {
    62  	synctest.Test(t, func(t *testing.T) {
    63  		li := nettest.NewListener()
    64  
    65  		var acceptErr error
    66  		go func() {
    67  			_, acceptErr = li.Accept()
    68  		}()
    69  
    70  		synctest.Wait()
    71  		if acceptErr != nil {
    72  			t.Fatalf("li.Accept() = %v, want still running before close", acceptErr)
    73  		}
    74  
    75  		li.Close()
    76  		synctest.Wait()
    77  		if !errors.Is(acceptErr, net.ErrClosed) {
    78  			t.Fatalf("li.Accept() = %v, want ErrClosed", acceptErr)
    79  		}
    80  	})
    81  }
    82  
    83  func TestListenerAddresses(t *testing.T) {
    84  	synctest.Test(t, func(t *testing.T) {
    85  		srvaddr := netip.MustParseAddrPort("10.0.0.1:80")
    86  		cliaddr := netip.MustParseAddrPort("10.0.0.2:1234")
    87  
    88  		li := nettest.NewListener()
    89  		defer li.Close()
    90  
    91  		li.SetAddr(net.TCPAddrFromAddrPort(srvaddr))
    92  		if got, want := li.Addr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
    93  			t.Errorf("li.Addr() = %v, want %v", got, want)
    94  		}
    95  
    96  		cli := li.NewConnConfig(func(conn *nettest.Conn) {
    97  			conn.SetLocalAddr(net.TCPAddrFromAddrPort(cliaddr))
    98  		})
    99  		srvc, err := li.Accept()
   100  		if err != nil {
   101  			t.Fatalf("li.Accept() = %v", err)
   102  		}
   103  		srv := srvc.(*nettest.Conn)
   104  
   105  		if cli.Peer() != srv {
   106  			t.Errorf("cli.Peer() != srv; should be the same")
   107  		}
   108  		if srv.Peer() != cli {
   109  			t.Errorf("cli.Peer() != srv; should be the same")
   110  		}
   111  
   112  		if got, want := cli.LocalAddr().(*net.TCPAddr).AddrPort(), cliaddr; got != want {
   113  			t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
   114  		}
   115  		if got, want := cli.RemoteAddr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
   116  			t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
   117  		}
   118  		if got, want := srv.LocalAddr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
   119  			t.Errorf("srv.LocalAddr() = %v, want %v", got, want)
   120  		}
   121  		if got, want := srv.RemoteAddr().(*net.TCPAddr).AddrPort(), cliaddr; got != want {
   122  			t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
   123  		}
   124  	})
   125  }
   126  
   127  func wantListenerAccept(t *testing.T, li *nettest.Listener, want *nettest.Conn) {
   128  	t.Helper()
   129  	got, err := li.Accept()
   130  	if err != nil {
   131  		t.Fatalf("li.Accept() = %v, want conn", err)
   132  	}
   133  	if got != want {
   134  		t.Fatalf("li.Accept() returned unexpected conn")
   135  	}
   136  }
   137  
   138  func wantListenerAcceptErr(t *testing.T, li *nettest.Listener, want error) {
   139  	t.Helper()
   140  	got, err := li.Accept()
   141  	if got != nil || !isOpError(err, want) {
   142  		t.Fatalf("li.Accept() = %p, %v; want nil, OpError{Err: %q}", got, err, want)
   143  	}
   144  }
   145  
   146  func wantListenerAcceptBlocked(t *testing.T, li *nettest.Listener) {
   147  	cancelErr := errors.New("cancel")
   148  	done := false
   149  	go func() {
   150  		got, err := li.Accept()
   151  		if got != nil || !errors.Is(err, cancelErr) {
   152  			t.Errorf("li.Accept = %p, %v; want nil, cancelErr", got, err)
   153  		}
   154  		done = true
   155  	}()
   156  	synctest.Wait()
   157  	if done {
   158  		t.Fatalf("Accept unexpectedly returned before canceling")
   159  	}
   160  	li.SetAcceptError(cancelErr)
   161  	synctest.Wait()
   162  	li.SetAcceptError(nil)
   163  	if !done {
   164  		t.Fatalf("Accept unexpectedly did not return after canceling")
   165  	}
   166  }
   167  
   168  func TestListenerSetAcceptError(t *testing.T) {
   169  	synctest.Test(t, func(t *testing.T) {
   170  		acceptErr := errors.New("accept error")
   171  		li := nettest.NewListener()
   172  		defer li.Close()
   173  		li.SetAcceptError(acceptErr)
   174  
   175  		// Accept conns from queue before returning error.
   176  		c1 := li.NewConn()
   177  		wantListenerAccept(t, li, c1.Peer())
   178  		wantListenerAcceptErr(t, li, acceptErr)
   179  
   180  		// Add a new conn, suppressing error until the queue is empty.
   181  		c2 := li.NewConn()
   182  		wantListenerAccept(t, li, c2.Peer())
   183  		wantListenerAcceptErr(t, li, acceptErr)
   184  
   185  		// Error may be cleared.
   186  		li.SetAcceptError(nil)
   187  		wantListenerAcceptBlocked(t, li)
   188  
   189  		// ErrClosed takes precedence over accept error.
   190  		li.SetAcceptError(acceptErr)
   191  		li.Close()
   192  		wantListenerAcceptErr(t, li, net.ErrClosed)
   193  	})
   194  }
   195  
   196  func TestListenerSetCloseError(t *testing.T) {
   197  	synctest.Test(t, func(t *testing.T) {
   198  		li := nettest.NewListener()
   199  		closeErr := errors.New("close error")
   200  		li.SetCloseError(closeErr)
   201  
   202  		// First close uses the user-provided error.
   203  		if err := li.Close(); !isOpError(err, closeErr) {
   204  			t.Fatalf("li.Close() = %v; want OpError wrapping accept error", err)
   205  		}
   206  
   207  		// Repeated closes return ErrClosed.
   208  		if err := li.Close(); !isOpError(err, net.ErrClosed) {
   209  			t.Fatalf("li.Close() = %v; want OpError wrapping net.ErrClosed", err)
   210  		}
   211  	})
   212  }
   213  

View as plain text