Source file src/net/splice_linux_test.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build linux
     6  
     7  package net
     8  
     9  import (
    10  	"internal/poll"
    11  	"io"
    12  	"os"
    13  	"strconv"
    14  	"sync"
    15  	"syscall"
    16  	"testing"
    17  )
    18  
    19  func TestSplice(t *testing.T) {
    20  	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
    21  	if !testableNetwork("unixgram") {
    22  		t.Skip("skipping unix-to-tcp tests")
    23  	}
    24  	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
    25  	t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
    26  	t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
    27  	t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
    28  	t.Run("no-unixpacket", testSpliceNoUnixpacket)
    29  	t.Run("no-unixgram", testSpliceNoUnixgram)
    30  }
    31  
    32  func testSpliceToFile(t *testing.T, upNet, downNet string) {
    33  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
    34  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
    35  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
    36  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
    37  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
    38  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
    39  }
    40  
    41  func testSplice(t *testing.T, upNet, downNet string) {
    42  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
    43  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
    44  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
    45  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
    46  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
    47  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
    48  	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
    49  	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
    50  }
    51  
    52  type spliceTestCase struct {
    53  	upNet, downNet string
    54  
    55  	chunkSize, totalSize int
    56  	limitReadSize        int
    57  }
    58  
    59  func (tc spliceTestCase) test(t *testing.T) {
    60  	hook := hookSplice(t)
    61  
    62  	// We need to use the actual size for startTestSocketPeer when testing with LimitedReader,
    63  	// otherwise the child process created in startTestSocketPeer will hang infinitely because of
    64  	// the mismatch of data size to transfer.
    65  	size := tc.totalSize
    66  	if tc.limitReadSize > 0 {
    67  		if tc.limitReadSize < size {
    68  			size = tc.limitReadSize
    69  		}
    70  	}
    71  
    72  	clientUp, serverUp := spawnTestSocketPair(t, tc.upNet)
    73  	defer serverUp.Close()
    74  	cleanup, err := startTestSocketPeer(t, clientUp, "w", tc.chunkSize, size)
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	defer cleanup(t)
    79  	clientDown, serverDown := spawnTestSocketPair(t, tc.downNet)
    80  	defer serverDown.Close()
    81  	cleanup, err = startTestSocketPeer(t, clientDown, "r", tc.chunkSize, size)
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	defer cleanup(t)
    86  
    87  	var r io.Reader = serverUp
    88  	if tc.limitReadSize > 0 {
    89  		r = &io.LimitedReader{
    90  			N: int64(tc.limitReadSize),
    91  			R: serverUp,
    92  		}
    93  		defer serverUp.Close()
    94  	}
    95  	n, err := io.Copy(serverDown, r)
    96  	if err != nil {
    97  		t.Fatal(err)
    98  	}
    99  
   100  	if want := int64(size); want != n {
   101  		t.Errorf("want %d bytes spliced, got %d", want, n)
   102  	}
   103  
   104  	if tc.limitReadSize > 0 {
   105  		wantN := 0
   106  		if tc.limitReadSize > size {
   107  			wantN = tc.limitReadSize - size
   108  		}
   109  
   110  		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
   111  			t.Errorf("r.N = %d, want %d", n, wantN)
   112  		}
   113  	}
   114  
   115  	// poll.Splice is expected to be called when the source is not
   116  	// a wrapper or the destination is TCPConn.
   117  	if tc.limitReadSize == 0 || tc.downNet == "tcp" {
   118  		// We should have called poll.Splice with the right file descriptor arguments.
   119  		if n > 0 && !hook.called {
   120  			t.Fatal("expected poll.Splice to be called")
   121  		}
   122  
   123  		verifySpliceFds(t, serverDown, hook, "dst")
   124  		verifySpliceFds(t, serverUp, hook, "src")
   125  
   126  		// poll.Splice is expected to handle the data transmission successfully.
   127  		if !hook.handled || hook.written != int64(size) || hook.err != nil {
   128  			t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v",
   129  				size, hook.handled, hook.written, hook.err)
   130  		}
   131  	} else if hook.called {
   132  		// poll.Splice will certainly not be called when the source
   133  		// is a wrapper and the destination is not TCPConn.
   134  		t.Errorf("expected poll.Splice not be called")
   135  	}
   136  }
   137  
   138  func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) {
   139  	t.Helper()
   140  
   141  	sc, ok := c.(syscall.Conn)
   142  	if !ok {
   143  		t.Fatalf("expected syscall.Conn")
   144  	}
   145  	rc, err := sc.SyscallConn()
   146  	if err != nil {
   147  		t.Fatalf("syscall.Conn.SyscallConn error: %v", err)
   148  	}
   149  	var hookFd int
   150  	switch fdType {
   151  	case "src":
   152  		hookFd = hook.srcfd
   153  	case "dst":
   154  		hookFd = hook.dstfd
   155  	default:
   156  		t.Fatalf("unknown fdType %q", fdType)
   157  	}
   158  	if err := rc.Control(func(fd uintptr) {
   159  		if hook.called && hookFd != int(fd) {
   160  			t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd))
   161  		}
   162  	}); err != nil {
   163  		t.Fatalf("syscall.RawConn.Control error: %v", err)
   164  	}
   165  }
   166  
   167  func (tc spliceTestCase) testFile(t *testing.T) {
   168  	hook := hookSplice(t)
   169  
   170  	// We need to use the actual size for startTestSocketPeer when testing with LimitedReader,
   171  	// otherwise the child process created in startTestSocketPeer will hang infinitely because of
   172  	// the mismatch of data size to transfer.
   173  	actualSize := tc.totalSize
   174  	if tc.limitReadSize > 0 {
   175  		if tc.limitReadSize < actualSize {
   176  			actualSize = tc.limitReadSize
   177  		}
   178  	}
   179  
   180  	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
   181  	if err != nil {
   182  		t.Fatal(err)
   183  	}
   184  	defer f.Close()
   185  
   186  	client, server := spawnTestSocketPair(t, tc.upNet)
   187  	defer server.Close()
   188  
   189  	cleanup, err := startTestSocketPeer(t, client, "w", tc.chunkSize, actualSize)
   190  	if err != nil {
   191  		client.Close()
   192  		t.Fatal("failed to start splice client:", err)
   193  	}
   194  	defer cleanup(t)
   195  
   196  	var r io.Reader = server
   197  	if tc.limitReadSize > 0 {
   198  		r = &io.LimitedReader{
   199  			N: int64(tc.limitReadSize),
   200  			R: r,
   201  		}
   202  	}
   203  
   204  	got, err := io.Copy(f, r)
   205  	if err != nil {
   206  		t.Fatalf("failed to ReadFrom with error: %v", err)
   207  	}
   208  
   209  	// We shouldn't have called poll.Splice in TCPConn.WriteTo,
   210  	// it's supposed to be called from File.ReadFrom.
   211  	if got > 0 && hook.called {
   212  		t.Error("expected not poll.Splice to be called")
   213  	}
   214  
   215  	if want := int64(actualSize); got != want {
   216  		t.Errorf("got %d bytes, want %d", got, want)
   217  	}
   218  	if tc.limitReadSize > 0 {
   219  		wantN := 0
   220  		if tc.limitReadSize > actualSize {
   221  			wantN = tc.limitReadSize - actualSize
   222  		}
   223  
   224  		if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
   225  			t.Errorf("r.N = %d, want %d", gotN, wantN)
   226  		}
   227  	}
   228  }
   229  
   230  func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
   231  	// UnixConn doesn't implement io.ReaderFrom, which will fail
   232  	// the following test in asserting a UnixConn to be an io.ReaderFrom,
   233  	// so skip this test.
   234  	if downNet == "unix" {
   235  		t.Skip("skipping test on unix socket")
   236  	}
   237  
   238  	hook := hookSplice(t)
   239  
   240  	clientUp, serverUp := spawnTestSocketPair(t, upNet)
   241  	defer clientUp.Close()
   242  	clientDown, serverDown := spawnTestSocketPair(t, downNet)
   243  	defer clientDown.Close()
   244  	defer serverDown.Close()
   245  
   246  	serverUp.Close()
   247  
   248  	// We'd like to call net.spliceFrom here and check the handled return
   249  	// value, but we disable splice on old Linux kernels.
   250  	//
   251  	// In that case, poll.Splice and net.spliceFrom return a non-nil error
   252  	// and handled == false. We'd ideally like to see handled == true
   253  	// because the source reader is at EOF, but if we're running on an old
   254  	// kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
   255  	// because we won't touch the reader at all.
   256  	//
   257  	// Trying to untangle the errors from net.spliceFrom and match them
   258  	// against the errors created by the poll package would be brittle,
   259  	// so this is a higher level test.
   260  	//
   261  	// The following ReadFrom should return immediately, regardless of
   262  	// whether splice is disabled or not. The other side should then
   263  	// get a goodbye signal. Test for the goodbye signal.
   264  	msg := "bye"
   265  	go func() {
   266  		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
   267  		io.WriteString(serverDown, msg)
   268  	}()
   269  
   270  	buf := make([]byte, 3)
   271  	n, err := io.ReadFull(clientDown, buf)
   272  	if err != nil {
   273  		t.Errorf("clientDown: %v", err)
   274  	}
   275  	if string(buf) != msg {
   276  		t.Errorf("clientDown got %q, want %q", buf, msg)
   277  	}
   278  
   279  	// We should have called poll.Splice with the right file descriptor arguments.
   280  	if n > 0 && !hook.called {
   281  		t.Fatal("expected poll.Splice to be called")
   282  	}
   283  
   284  	verifySpliceFds(t, serverDown, hook, "dst")
   285  
   286  	// poll.Splice is expected to handle the data transmission but fail
   287  	// when working with a closed endpoint, return an error.
   288  	if !hook.handled || hook.written > 0 || hook.err == nil {
   289  		t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v",
   290  			hook.handled, hook.written, hook.err)
   291  	}
   292  }
   293  
   294  func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
   295  	front := newLocalListener(t, upNet)
   296  	defer front.Close()
   297  	back := newLocalListener(t, downNet)
   298  	defer back.Close()
   299  
   300  	var wg sync.WaitGroup
   301  	wg.Add(2)
   302  
   303  	proxy := func() {
   304  		src, err := front.Accept()
   305  		if err != nil {
   306  			return
   307  		}
   308  		dst, err := Dial(downNet, back.Addr().String())
   309  		if err != nil {
   310  			return
   311  		}
   312  		defer dst.Close()
   313  		defer src.Close()
   314  		go func() {
   315  			io.Copy(src, dst)
   316  			wg.Done()
   317  		}()
   318  		go func() {
   319  			io.Copy(dst, src)
   320  			wg.Done()
   321  		}()
   322  	}
   323  
   324  	go proxy()
   325  
   326  	toFront, err := Dial(upNet, front.Addr().String())
   327  	if err != nil {
   328  		t.Fatal(err)
   329  	}
   330  
   331  	io.WriteString(toFront, "foo")
   332  	toFront.Close()
   333  
   334  	fromProxy, err := back.Accept()
   335  	if err != nil {
   336  		t.Fatal(err)
   337  	}
   338  	defer fromProxy.Close()
   339  
   340  	_, err = io.ReadAll(fromProxy)
   341  	if err != nil {
   342  		t.Fatal(err)
   343  	}
   344  
   345  	wg.Wait()
   346  }
   347  
   348  func testSpliceNoUnixpacket(t *testing.T) {
   349  	clientUp, serverUp := spawnTestSocketPair(t, "unixpacket")
   350  	defer clientUp.Close()
   351  	defer serverUp.Close()
   352  	clientDown, serverDown := spawnTestSocketPair(t, "tcp")
   353  	defer clientDown.Close()
   354  	defer serverDown.Close()
   355  	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
   356  	// and handled == false.  If poll.Splice gets an EINVAL on the first
   357  	// try, it assumes the kernel it's running on doesn't support splice
   358  	// for unix sockets and returns handled == false. This works for our
   359  	// purposes by somewhat of an accident, but is not entirely correct.
   360  	//
   361  	// What we want is err == nil and handled == false, i.e. we never
   362  	// called poll.Splice, because we know the unix socket's network.
   363  	_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
   364  	if err != nil || handled != false {
   365  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   366  	}
   367  }
   368  
   369  func testSpliceNoUnixgram(t *testing.T) {
   370  	addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
   371  	if err != nil {
   372  		t.Fatal(err)
   373  	}
   374  	defer os.Remove(addr.Name)
   375  	up, err := ListenUnixgram("unixgram", addr)
   376  	if err != nil {
   377  		t.Fatal(err)
   378  	}
   379  	defer up.Close()
   380  	clientDown, serverDown := spawnTestSocketPair(t, "tcp")
   381  	defer clientDown.Close()
   382  	defer serverDown.Close()
   383  	// Analogous to testSpliceNoUnixpacket.
   384  	_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
   385  	if err != nil || handled != false {
   386  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   387  	}
   388  }
   389  
   390  func BenchmarkSplice(b *testing.B) {
   391  	testHookUninstaller.Do(uninstallTestHooks)
   392  
   393  	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
   394  	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
   395  	b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
   396  }
   397  
   398  func benchSplice(b *testing.B, upNet, downNet string) {
   399  	for i := 0; i <= 10; i++ {
   400  		chunkSize := 1 << uint(i+10)
   401  		tc := spliceTestCase{
   402  			upNet:     upNet,
   403  			downNet:   downNet,
   404  			chunkSize: chunkSize,
   405  		}
   406  
   407  		b.Run(strconv.Itoa(chunkSize), tc.bench)
   408  	}
   409  }
   410  
   411  func (tc spliceTestCase) bench(b *testing.B) {
   412  	// To benchmark the genericReadFrom code path, set this to false.
   413  	useSplice := true
   414  
   415  	clientUp, serverUp := spawnTestSocketPair(b, tc.upNet)
   416  	defer serverUp.Close()
   417  
   418  	cleanup, err := startTestSocketPeer(b, clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
   419  	if err != nil {
   420  		b.Fatal(err)
   421  	}
   422  	defer cleanup(b)
   423  
   424  	clientDown, serverDown := spawnTestSocketPair(b, tc.downNet)
   425  	defer serverDown.Close()
   426  
   427  	cleanup, err = startTestSocketPeer(b, clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
   428  	if err != nil {
   429  		b.Fatal(err)
   430  	}
   431  	defer cleanup(b)
   432  
   433  	b.SetBytes(int64(tc.chunkSize))
   434  	b.ResetTimer()
   435  
   436  	if useSplice {
   437  		_, err := io.Copy(serverDown, serverUp)
   438  		if err != nil {
   439  			b.Fatal(err)
   440  		}
   441  	} else {
   442  		type onlyReader struct {
   443  			io.Reader
   444  		}
   445  		_, err := io.Copy(serverDown, onlyReader{serverUp})
   446  		if err != nil {
   447  			b.Fatal(err)
   448  		}
   449  	}
   450  }
   451  
   452  func BenchmarkSpliceFile(b *testing.B) {
   453  	b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
   454  	b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
   455  }
   456  
   457  func benchmarkSpliceFile(b *testing.B, proto string) {
   458  	for i := 0; i <= 10; i++ {
   459  		size := 1 << (i + 10)
   460  		bench := spliceFileBench{
   461  			proto:     proto,
   462  			chunkSize: size,
   463  		}
   464  		b.Run(strconv.Itoa(size), bench.benchSpliceFile)
   465  	}
   466  }
   467  
   468  type spliceFileBench struct {
   469  	proto     string
   470  	chunkSize int
   471  }
   472  
   473  func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
   474  	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
   475  	if err != nil {
   476  		b.Fatal(err)
   477  	}
   478  	defer f.Close()
   479  
   480  	totalSize := b.N * bench.chunkSize
   481  
   482  	client, server := spawnTestSocketPair(b, bench.proto)
   483  	defer server.Close()
   484  
   485  	cleanup, err := startTestSocketPeer(b, client, "w", bench.chunkSize, totalSize)
   486  	if err != nil {
   487  		client.Close()
   488  		b.Fatalf("failed to start splice client: %v", err)
   489  	}
   490  	defer cleanup(b)
   491  
   492  	b.ReportAllocs()
   493  	b.SetBytes(int64(bench.chunkSize))
   494  	b.ResetTimer()
   495  
   496  	got, err := io.Copy(f, server)
   497  	if err != nil {
   498  		b.Fatalf("failed to ReadFrom with error: %v", err)
   499  	}
   500  	if want := int64(totalSize); got != want {
   501  		b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
   502  	}
   503  }
   504  
   505  func hookSplice(t *testing.T) *spliceHook {
   506  	t.Helper()
   507  
   508  	h := new(spliceHook)
   509  	h.install()
   510  	t.Cleanup(h.uninstall)
   511  	return h
   512  }
   513  
   514  type spliceHook struct {
   515  	called bool
   516  	dstfd  int
   517  	srcfd  int
   518  	remain int64
   519  
   520  	written int64
   521  	handled bool
   522  	err     error
   523  
   524  	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
   525  }
   526  
   527  func (h *spliceHook) install() {
   528  	h.original = pollSplice
   529  	pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
   530  		h.called = true
   531  		h.dstfd = dst.Sysfd
   532  		h.srcfd = src.Sysfd
   533  		h.remain = remain
   534  		h.written, h.handled, h.err = h.original(dst, src, remain)
   535  		return h.written, h.handled, h.err
   536  	}
   537  }
   538  
   539  func (h *spliceHook) uninstall() {
   540  	pollSplice = h.original
   541  }
   542  

View as plain text