Source file src/context/x_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context_test
     6  
     7  import (
     8  	. "context"
     9  	"errors"
    10  	"fmt"
    11  	"math/rand"
    12  	"runtime"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  )
    18  
    19  // Each XTestFoo in context_test.go must be called from a TestFoo here to run.
    20  func TestParentFinishesChild(t *testing.T) {
    21  	XTestParentFinishesChild(t) // uses unexported context types
    22  }
    23  func TestChildFinishesFirst(t *testing.T) {
    24  	XTestChildFinishesFirst(t) // uses unexported context types
    25  }
    26  func TestCancelRemoves(t *testing.T) {
    27  	XTestCancelRemoves(t) // uses unexported context types
    28  }
    29  func TestCustomContextGoroutines(t *testing.T) {
    30  	XTestCustomContextGoroutines(t) // reads the context.goroutines counter
    31  }
    32  
    33  // The following are regular tests in package context_test.
    34  
    35  // otherContext is a Context that's not one of the types defined in context.go.
    36  // This lets us test code paths that differ based on the underlying type of the
    37  // Context.
    38  type otherContext struct {
    39  	Context
    40  }
    41  
    42  const (
    43  	shortDuration    = 1 * time.Millisecond // a reasonable duration to block in a test
    44  	veryLongDuration = 1000 * time.Hour     // an arbitrary upper bound on the test's running time
    45  )
    46  
    47  // quiescent returns an arbitrary duration by which the program should have
    48  // completed any remaining work and reached a steady (idle) state.
    49  func quiescent(t *testing.T) time.Duration {
    50  	deadline, ok := t.Deadline()
    51  	if !ok {
    52  		return 5 * time.Second
    53  	}
    54  
    55  	const arbitraryCleanupMargin = 1 * time.Second
    56  	return time.Until(deadline) - arbitraryCleanupMargin
    57  }
    58  func TestBackground(t *testing.T) {
    59  	c := Background()
    60  	if c == nil {
    61  		t.Fatalf("Background returned nil")
    62  	}
    63  	select {
    64  	case x := <-c.Done():
    65  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    66  	default:
    67  	}
    68  	if got, want := fmt.Sprint(c), "context.Background"; got != want {
    69  		t.Errorf("Background().String() = %q want %q", got, want)
    70  	}
    71  }
    72  
    73  func TestTODO(t *testing.T) {
    74  	c := TODO()
    75  	if c == nil {
    76  		t.Fatalf("TODO returned nil")
    77  	}
    78  	select {
    79  	case x := <-c.Done():
    80  		t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
    81  	default:
    82  	}
    83  	if got, want := fmt.Sprint(c), "context.TODO"; got != want {
    84  		t.Errorf("TODO().String() = %q want %q", got, want)
    85  	}
    86  }
    87  
    88  func TestWithCancel(t *testing.T) {
    89  	c1, cancel := WithCancel(Background())
    90  
    91  	if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
    92  		t.Errorf("c1.String() = %q want %q", got, want)
    93  	}
    94  
    95  	o := otherContext{c1}
    96  	c2, _ := WithCancel(o)
    97  	contexts := []Context{c1, o, c2}
    98  
    99  	for i, c := range contexts {
   100  		if d := c.Done(); d == nil {
   101  			t.Errorf("c[%d].Done() == %v want non-nil", i, d)
   102  		}
   103  		if e := c.Err(); e != nil {
   104  			t.Errorf("c[%d].Err() == %v want nil", i, e)
   105  		}
   106  
   107  		select {
   108  		case x := <-c.Done():
   109  			t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
   110  		default:
   111  		}
   112  	}
   113  
   114  	cancel() // Should propagate synchronously.
   115  	for i, c := range contexts {
   116  		select {
   117  		case <-c.Done():
   118  		default:
   119  			t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
   120  		}
   121  		if e := c.Err(); e != Canceled {
   122  			t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
   123  		}
   124  	}
   125  }
   126  
   127  func testDeadline(c Context, name string, t *testing.T) {
   128  	t.Helper()
   129  	d := quiescent(t)
   130  	timer := time.NewTimer(d)
   131  	defer timer.Stop()
   132  	select {
   133  	case <-timer.C:
   134  		t.Fatalf("%s: context not timed out after %v", name, d)
   135  	case <-c.Done():
   136  	}
   137  	if e := c.Err(); e != DeadlineExceeded {
   138  		t.Errorf("%s: c.Err() == %v; want %v", name, e, DeadlineExceeded)
   139  	}
   140  }
   141  
   142  func TestDeadline(t *testing.T) {
   143  	t.Parallel()
   144  
   145  	c, _ := WithDeadline(Background(), time.Now().Add(shortDuration))
   146  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   147  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   148  	}
   149  	testDeadline(c, "WithDeadline", t)
   150  
   151  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   152  	o := otherContext{c}
   153  	testDeadline(o, "WithDeadline+otherContext", t)
   154  
   155  	c, _ = WithDeadline(Background(), time.Now().Add(shortDuration))
   156  	o = otherContext{c}
   157  	c, _ = WithDeadline(o, time.Now().Add(veryLongDuration))
   158  	testDeadline(c, "WithDeadline+otherContext+WithDeadline", t)
   159  
   160  	c, _ = WithDeadline(Background(), time.Now().Add(-shortDuration))
   161  	testDeadline(c, "WithDeadline+inthepast", t)
   162  
   163  	c, _ = WithDeadline(Background(), time.Now())
   164  	testDeadline(c, "WithDeadline+now", t)
   165  }
   166  
   167  func TestTimeout(t *testing.T) {
   168  	t.Parallel()
   169  
   170  	c, _ := WithTimeout(Background(), shortDuration)
   171  	if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
   172  		t.Errorf("c.String() = %q want prefix %q", got, prefix)
   173  	}
   174  	testDeadline(c, "WithTimeout", t)
   175  
   176  	c, _ = WithTimeout(Background(), shortDuration)
   177  	o := otherContext{c}
   178  	testDeadline(o, "WithTimeout+otherContext", t)
   179  
   180  	c, _ = WithTimeout(Background(), shortDuration)
   181  	o = otherContext{c}
   182  	c, _ = WithTimeout(o, veryLongDuration)
   183  	testDeadline(c, "WithTimeout+otherContext+WithTimeout", t)
   184  }
   185  
   186  func TestCanceledTimeout(t *testing.T) {
   187  	c, _ := WithTimeout(Background(), time.Second)
   188  	o := otherContext{c}
   189  	c, cancel := WithTimeout(o, veryLongDuration)
   190  	cancel() // Should propagate synchronously.
   191  	select {
   192  	case <-c.Done():
   193  	default:
   194  		t.Errorf("<-c.Done() blocked, but shouldn't have")
   195  	}
   196  	if e := c.Err(); e != Canceled {
   197  		t.Errorf("c.Err() == %v want %v", e, Canceled)
   198  	}
   199  }
   200  
   201  type key1 int
   202  type key2 int
   203  
   204  func (k key2) String() string { return fmt.Sprintf("%[1]T(%[1]d)", k) }
   205  
   206  var k1 = key1(1)
   207  var k2 = key2(1) // same int as k1, different type
   208  var k3 = key2(3) // same type as k2, different int
   209  
   210  func TestValues(t *testing.T) {
   211  	check := func(c Context, nm, v1, v2, v3 string) {
   212  		if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
   213  			t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
   214  		}
   215  		if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
   216  			t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
   217  		}
   218  		if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
   219  			t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
   220  		}
   221  	}
   222  
   223  	c0 := Background()
   224  	check(c0, "c0", "", "", "")
   225  
   226  	c1 := WithValue(Background(), k1, "c1k1")
   227  	check(c1, "c1", "c1k1", "", "")
   228  
   229  	if got, want := fmt.Sprint(c1), `context.Background.WithValue(context_test.key1, c1k1)`; got != want {
   230  		t.Errorf("c.String() = %q want %q", got, want)
   231  	}
   232  
   233  	c2 := WithValue(c1, k2, "c2k2")
   234  	check(c2, "c2", "c1k1", "c2k2", "")
   235  
   236  	if got, want := fmt.Sprint(c2), `context.Background.WithValue(context_test.key1, c1k1).WithValue(context_test.key2(1), c2k2)`; got != want {
   237  		t.Errorf("c.String() = %q want %q", got, want)
   238  	}
   239  
   240  	c3 := WithValue(c2, k3, "c3k3")
   241  	check(c3, "c2", "c1k1", "c2k2", "c3k3")
   242  
   243  	c4 := WithValue(c3, k1, nil)
   244  	check(c4, "c4", "", "c2k2", "c3k3")
   245  
   246  	o0 := otherContext{Background()}
   247  	check(o0, "o0", "", "", "")
   248  
   249  	o1 := otherContext{WithValue(Background(), k1, "c1k1")}
   250  	check(o1, "o1", "c1k1", "", "")
   251  
   252  	o2 := WithValue(o1, k2, "o2k2")
   253  	check(o2, "o2", "c1k1", "o2k2", "")
   254  
   255  	o3 := otherContext{c4}
   256  	check(o3, "o3", "", "c2k2", "c3k3")
   257  
   258  	o4 := WithValue(o3, k3, nil)
   259  	check(o4, "o4", "", "c2k2", "")
   260  }
   261  
   262  func TestAllocs(t *testing.T) {
   263  	bg := Background()
   264  	for _, test := range []struct {
   265  		desc       string
   266  		f          func()
   267  		limit      float64
   268  		gccgoLimit float64
   269  	}{
   270  		{
   271  			desc:       "Background()",
   272  			f:          func() { Background() },
   273  			limit:      0,
   274  			gccgoLimit: 0,
   275  		},
   276  		{
   277  			desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
   278  			f: func() {
   279  				c := WithValue(bg, k1, nil)
   280  				c.Value(k1)
   281  			},
   282  			limit:      3,
   283  			gccgoLimit: 3,
   284  		},
   285  		{
   286  			desc: "WithTimeout(bg, 1*time.Nanosecond)",
   287  			f: func() {
   288  				c, _ := WithTimeout(bg, 1*time.Nanosecond)
   289  				<-c.Done()
   290  			},
   291  			limit:      12,
   292  			gccgoLimit: 15,
   293  		},
   294  		{
   295  			desc: "WithCancel(bg)",
   296  			f: func() {
   297  				c, cancel := WithCancel(bg)
   298  				cancel()
   299  				<-c.Done()
   300  			},
   301  			limit:      5,
   302  			gccgoLimit: 8,
   303  		},
   304  		{
   305  			desc: "WithTimeout(bg, 5*time.Millisecond)",
   306  			f: func() {
   307  				c, cancel := WithTimeout(bg, 5*time.Millisecond)
   308  				cancel()
   309  				<-c.Done()
   310  			},
   311  			limit:      8,
   312  			gccgoLimit: 25,
   313  		},
   314  	} {
   315  		limit := test.limit
   316  		if runtime.Compiler == "gccgo" {
   317  			// gccgo does not yet do escape analysis.
   318  			// TODO(iant): Remove this when gccgo does do escape analysis.
   319  			limit = test.gccgoLimit
   320  		}
   321  		numRuns := 100
   322  		if testing.Short() {
   323  			numRuns = 10
   324  		}
   325  		if n := testing.AllocsPerRun(numRuns, test.f); n > limit {
   326  			t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
   327  		}
   328  	}
   329  }
   330  
   331  func TestSimultaneousCancels(t *testing.T) {
   332  	root, cancel := WithCancel(Background())
   333  	m := map[Context]CancelFunc{root: cancel}
   334  	q := []Context{root}
   335  	// Create a tree of contexts.
   336  	for len(q) != 0 && len(m) < 100 {
   337  		parent := q[0]
   338  		q = q[1:]
   339  		for i := 0; i < 4; i++ {
   340  			ctx, cancel := WithCancel(parent)
   341  			m[ctx] = cancel
   342  			q = append(q, ctx)
   343  		}
   344  	}
   345  	// Start all the cancels in a random order.
   346  	var wg sync.WaitGroup
   347  	wg.Add(len(m))
   348  	for _, cancel := range m {
   349  		go func(cancel CancelFunc) {
   350  			cancel()
   351  			wg.Done()
   352  		}(cancel)
   353  	}
   354  
   355  	d := quiescent(t)
   356  	stuck := make(chan struct{})
   357  	timer := time.AfterFunc(d, func() { close(stuck) })
   358  	defer timer.Stop()
   359  
   360  	// Wait on all the contexts in a random order.
   361  	for ctx := range m {
   362  		select {
   363  		case <-ctx.Done():
   364  		case <-stuck:
   365  			buf := make([]byte, 10<<10)
   366  			n := runtime.Stack(buf, true)
   367  			t.Fatalf("timed out after %v waiting for <-ctx.Done(); stacks:\n%s", d, buf[:n])
   368  		}
   369  	}
   370  	// Wait for all the cancel functions to return.
   371  	done := make(chan struct{})
   372  	go func() {
   373  		wg.Wait()
   374  		close(done)
   375  	}()
   376  	select {
   377  	case <-done:
   378  	case <-stuck:
   379  		buf := make([]byte, 10<<10)
   380  		n := runtime.Stack(buf, true)
   381  		t.Fatalf("timed out after %v waiting for cancel functions; stacks:\n%s", d, buf[:n])
   382  	}
   383  }
   384  
   385  func TestInterlockedCancels(t *testing.T) {
   386  	parent, cancelParent := WithCancel(Background())
   387  	child, cancelChild := WithCancel(parent)
   388  	go func() {
   389  		<-parent.Done()
   390  		cancelChild()
   391  	}()
   392  	cancelParent()
   393  	d := quiescent(t)
   394  	timer := time.NewTimer(d)
   395  	defer timer.Stop()
   396  	select {
   397  	case <-child.Done():
   398  	case <-timer.C:
   399  		buf := make([]byte, 10<<10)
   400  		n := runtime.Stack(buf, true)
   401  		t.Fatalf("timed out after %v waiting for child.Done(); stacks:\n%s", d, buf[:n])
   402  	}
   403  }
   404  
   405  func TestLayersCancel(t *testing.T) {
   406  	testLayers(t, time.Now().UnixNano(), false)
   407  }
   408  
   409  func TestLayersTimeout(t *testing.T) {
   410  	testLayers(t, time.Now().UnixNano(), true)
   411  }
   412  
   413  func testLayers(t *testing.T, seed int64, testTimeout bool) {
   414  	t.Parallel()
   415  
   416  	r := rand.New(rand.NewSource(seed))
   417  	prefix := fmt.Sprintf("seed=%d", seed)
   418  	errorf := func(format string, a ...any) {
   419  		t.Errorf(prefix+format, a...)
   420  	}
   421  	const (
   422  		minLayers = 30
   423  	)
   424  	type value int
   425  	var (
   426  		vals      []*value
   427  		cancels   []CancelFunc
   428  		numTimers int
   429  		ctx       = Background()
   430  	)
   431  	for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
   432  		switch r.Intn(3) {
   433  		case 0:
   434  			v := new(value)
   435  			ctx = WithValue(ctx, v, v)
   436  			vals = append(vals, v)
   437  		case 1:
   438  			var cancel CancelFunc
   439  			ctx, cancel = WithCancel(ctx)
   440  			cancels = append(cancels, cancel)
   441  		case 2:
   442  			var cancel CancelFunc
   443  			d := veryLongDuration
   444  			if testTimeout {
   445  				d = shortDuration
   446  			}
   447  			ctx, cancel = WithTimeout(ctx, d)
   448  			cancels = append(cancels, cancel)
   449  			numTimers++
   450  		}
   451  	}
   452  	checkValues := func(when string) {
   453  		for _, key := range vals {
   454  			if val := ctx.Value(key).(*value); key != val {
   455  				errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
   456  			}
   457  		}
   458  	}
   459  	if !testTimeout {
   460  		select {
   461  		case <-ctx.Done():
   462  			errorf("ctx should not be canceled yet")
   463  		default:
   464  		}
   465  	}
   466  	if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
   467  		t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
   468  	}
   469  	t.Log(ctx)
   470  	checkValues("before cancel")
   471  	if testTimeout {
   472  		d := quiescent(t)
   473  		timer := time.NewTimer(d)
   474  		defer timer.Stop()
   475  		select {
   476  		case <-ctx.Done():
   477  		case <-timer.C:
   478  			errorf("ctx should have timed out after %v", d)
   479  		}
   480  		checkValues("after timeout")
   481  	} else {
   482  		cancel := cancels[r.Intn(len(cancels))]
   483  		cancel()
   484  		select {
   485  		case <-ctx.Done():
   486  		default:
   487  			errorf("ctx should be canceled")
   488  		}
   489  		checkValues("after cancel")
   490  	}
   491  }
   492  
   493  func TestWithCancelCanceledParent(t *testing.T) {
   494  	parent, pcancel := WithCancelCause(Background())
   495  	cause := fmt.Errorf("Because!")
   496  	pcancel(cause)
   497  
   498  	c, _ := WithCancel(parent)
   499  	select {
   500  	case <-c.Done():
   501  	default:
   502  		t.Errorf("child not done immediately upon construction")
   503  	}
   504  	if got, want := c.Err(), Canceled; got != want {
   505  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   506  	}
   507  	if got, want := Cause(c), cause; got != want {
   508  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   509  	}
   510  }
   511  
   512  func TestWithCancelSimultaneouslyCanceledParent(t *testing.T) {
   513  	// Cancel the parent goroutine concurrently with creating a child.
   514  	for i := 0; i < 100; i++ {
   515  		parent, pcancel := WithCancelCause(Background())
   516  		cause := fmt.Errorf("Because!")
   517  		go pcancel(cause)
   518  
   519  		c, _ := WithCancel(parent)
   520  		<-c.Done()
   521  		if got, want := c.Err(), Canceled; got != want {
   522  			t.Errorf("child not canceled; got = %v, want = %v", got, want)
   523  		}
   524  		if got, want := Cause(c), cause; got != want {
   525  			t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   526  		}
   527  	}
   528  }
   529  
   530  func TestWithValueChecksKey(t *testing.T) {
   531  	panicVal := recoveredValue(func() { _ = WithValue(Background(), []byte("foo"), "bar") })
   532  	if panicVal == nil {
   533  		t.Error("expected panic")
   534  	}
   535  	panicVal = recoveredValue(func() { _ = WithValue(Background(), nil, "bar") })
   536  	if got, want := fmt.Sprint(panicVal), "nil key"; got != want {
   537  		t.Errorf("panic = %q; want %q", got, want)
   538  	}
   539  }
   540  
   541  func TestInvalidDerivedFail(t *testing.T) {
   542  	panicVal := recoveredValue(func() { _, _ = WithCancel(nil) })
   543  	if panicVal == nil {
   544  		t.Error("expected panic")
   545  	}
   546  	panicVal = recoveredValue(func() { _, _ = WithDeadline(nil, time.Now().Add(shortDuration)) })
   547  	if panicVal == nil {
   548  		t.Error("expected panic")
   549  	}
   550  	panicVal = recoveredValue(func() { _ = WithValue(nil, "foo", "bar") })
   551  	if panicVal == nil {
   552  		t.Error("expected panic")
   553  	}
   554  }
   555  
   556  func recoveredValue(fn func()) (v any) {
   557  	defer func() { v = recover() }()
   558  	fn()
   559  	return
   560  }
   561  
   562  func TestDeadlineExceededSupportsTimeout(t *testing.T) {
   563  	i, ok := DeadlineExceeded.(interface {
   564  		Timeout() bool
   565  	})
   566  	if !ok {
   567  		t.Fatal("DeadlineExceeded does not support Timeout interface")
   568  	}
   569  	if !i.Timeout() {
   570  		t.Fatal("wrong value for timeout")
   571  	}
   572  }
   573  func TestCause(t *testing.T) {
   574  	var (
   575  		forever       = 1e6 * time.Second
   576  		parentCause   = fmt.Errorf("parentCause")
   577  		childCause    = fmt.Errorf("childCause")
   578  		tooSlow       = fmt.Errorf("tooSlow")
   579  		finishedEarly = fmt.Errorf("finishedEarly")
   580  	)
   581  	for _, test := range []struct {
   582  		name  string
   583  		ctx   func() Context
   584  		err   error
   585  		cause error
   586  	}{
   587  		{
   588  			name:  "Background",
   589  			ctx:   Background,
   590  			err:   nil,
   591  			cause: nil,
   592  		},
   593  		{
   594  			name:  "TODO",
   595  			ctx:   TODO,
   596  			err:   nil,
   597  			cause: nil,
   598  		},
   599  		{
   600  			name: "WithCancel",
   601  			ctx: func() Context {
   602  				ctx, cancel := WithCancel(Background())
   603  				cancel()
   604  				return ctx
   605  			},
   606  			err:   Canceled,
   607  			cause: Canceled,
   608  		},
   609  		{
   610  			name: "WithCancelCause",
   611  			ctx: func() Context {
   612  				ctx, cancel := WithCancelCause(Background())
   613  				cancel(parentCause)
   614  				return ctx
   615  			},
   616  			err:   Canceled,
   617  			cause: parentCause,
   618  		},
   619  		{
   620  			name: "WithCancelCause nil",
   621  			ctx: func() Context {
   622  				ctx, cancel := WithCancelCause(Background())
   623  				cancel(nil)
   624  				return ctx
   625  			},
   626  			err:   Canceled,
   627  			cause: Canceled,
   628  		},
   629  		{
   630  			name: "WithCancelCause: parent cause before child",
   631  			ctx: func() Context {
   632  				ctx, cancelParent := WithCancelCause(Background())
   633  				ctx, cancelChild := WithCancelCause(ctx)
   634  				cancelParent(parentCause)
   635  				cancelChild(childCause)
   636  				return ctx
   637  			},
   638  			err:   Canceled,
   639  			cause: parentCause,
   640  		},
   641  		{
   642  			name: "WithCancelCause: parent cause after child",
   643  			ctx: func() Context {
   644  				ctx, cancelParent := WithCancelCause(Background())
   645  				ctx, cancelChild := WithCancelCause(ctx)
   646  				cancelChild(childCause)
   647  				cancelParent(parentCause)
   648  				return ctx
   649  			},
   650  			err:   Canceled,
   651  			cause: childCause,
   652  		},
   653  		{
   654  			name: "WithCancelCause: parent cause before nil",
   655  			ctx: func() Context {
   656  				ctx, cancelParent := WithCancelCause(Background())
   657  				ctx, cancelChild := WithCancel(ctx)
   658  				cancelParent(parentCause)
   659  				cancelChild()
   660  				return ctx
   661  			},
   662  			err:   Canceled,
   663  			cause: parentCause,
   664  		},
   665  		{
   666  			name: "WithCancelCause: parent cause after nil",
   667  			ctx: func() Context {
   668  				ctx, cancelParent := WithCancelCause(Background())
   669  				ctx, cancelChild := WithCancel(ctx)
   670  				cancelChild()
   671  				cancelParent(parentCause)
   672  				return ctx
   673  			},
   674  			err:   Canceled,
   675  			cause: Canceled,
   676  		},
   677  		{
   678  			name: "WithCancelCause: child cause after nil",
   679  			ctx: func() Context {
   680  				ctx, cancelParent := WithCancel(Background())
   681  				ctx, cancelChild := WithCancelCause(ctx)
   682  				cancelParent()
   683  				cancelChild(childCause)
   684  				return ctx
   685  			},
   686  			err:   Canceled,
   687  			cause: Canceled,
   688  		},
   689  		{
   690  			name: "WithCancelCause: child cause before nil",
   691  			ctx: func() Context {
   692  				ctx, cancelParent := WithCancel(Background())
   693  				ctx, cancelChild := WithCancelCause(ctx)
   694  				cancelChild(childCause)
   695  				cancelParent()
   696  				return ctx
   697  			},
   698  			err:   Canceled,
   699  			cause: childCause,
   700  		},
   701  		{
   702  			name: "WithTimeout",
   703  			ctx: func() Context {
   704  				ctx, cancel := WithTimeout(Background(), 0)
   705  				cancel()
   706  				return ctx
   707  			},
   708  			err:   DeadlineExceeded,
   709  			cause: DeadlineExceeded,
   710  		},
   711  		{
   712  			name: "WithTimeout canceled",
   713  			ctx: func() Context {
   714  				ctx, cancel := WithTimeout(Background(), forever)
   715  				cancel()
   716  				return ctx
   717  			},
   718  			err:   Canceled,
   719  			cause: Canceled,
   720  		},
   721  		{
   722  			name: "WithTimeoutCause",
   723  			ctx: func() Context {
   724  				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
   725  				cancel()
   726  				return ctx
   727  			},
   728  			err:   DeadlineExceeded,
   729  			cause: tooSlow,
   730  		},
   731  		{
   732  			name: "WithTimeoutCause canceled",
   733  			ctx: func() Context {
   734  				ctx, cancel := WithTimeoutCause(Background(), forever, tooSlow)
   735  				cancel()
   736  				return ctx
   737  			},
   738  			err:   Canceled,
   739  			cause: Canceled,
   740  		},
   741  		{
   742  			name: "WithTimeoutCause stacked",
   743  			ctx: func() Context {
   744  				ctx, cancel := WithCancelCause(Background())
   745  				ctx, _ = WithTimeoutCause(ctx, 0, tooSlow)
   746  				cancel(finishedEarly)
   747  				return ctx
   748  			},
   749  			err:   DeadlineExceeded,
   750  			cause: tooSlow,
   751  		},
   752  		{
   753  			name: "WithTimeoutCause stacked canceled",
   754  			ctx: func() Context {
   755  				ctx, cancel := WithCancelCause(Background())
   756  				ctx, _ = WithTimeoutCause(ctx, forever, tooSlow)
   757  				cancel(finishedEarly)
   758  				return ctx
   759  			},
   760  			err:   Canceled,
   761  			cause: finishedEarly,
   762  		},
   763  		{
   764  			name: "WithoutCancel",
   765  			ctx: func() Context {
   766  				return WithoutCancel(Background())
   767  			},
   768  			err:   nil,
   769  			cause: nil,
   770  		},
   771  		{
   772  			name: "WithoutCancel canceled",
   773  			ctx: func() Context {
   774  				ctx, cancel := WithCancelCause(Background())
   775  				ctx = WithoutCancel(ctx)
   776  				cancel(finishedEarly)
   777  				return ctx
   778  			},
   779  			err:   nil,
   780  			cause: nil,
   781  		},
   782  		{
   783  			name: "WithoutCancel timeout",
   784  			ctx: func() Context {
   785  				ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow)
   786  				ctx = WithoutCancel(ctx)
   787  				cancel()
   788  				return ctx
   789  			},
   790  			err:   nil,
   791  			cause: nil,
   792  		},
   793  	} {
   794  		test := test
   795  		t.Run(test.name, func(t *testing.T) {
   796  			t.Parallel()
   797  			ctx := test.ctx()
   798  			if got, want := ctx.Err(), test.err; want != got {
   799  				t.Errorf("ctx.Err() = %v want %v", got, want)
   800  			}
   801  			if got, want := Cause(ctx), test.cause; want != got {
   802  				t.Errorf("Cause(ctx) = %v want %v", got, want)
   803  			}
   804  		})
   805  	}
   806  }
   807  
   808  func TestCauseRace(t *testing.T) {
   809  	cause := errors.New("TestCauseRace")
   810  	ctx, cancel := WithCancelCause(Background())
   811  	go func() {
   812  		cancel(cause)
   813  	}()
   814  	for {
   815  		// Poll Cause, rather than waiting for Done, to test that
   816  		// access to the underlying cause is synchronized properly.
   817  		if err := Cause(ctx); err != nil {
   818  			if err != cause {
   819  				t.Errorf("Cause returned %v, want %v", err, cause)
   820  			}
   821  			break
   822  		}
   823  		runtime.Gosched()
   824  	}
   825  }
   826  
   827  func TestWithoutCancel(t *testing.T) {
   828  	key, value := "key", "value"
   829  	ctx := WithValue(Background(), key, value)
   830  	ctx = WithoutCancel(ctx)
   831  	if d, ok := ctx.Deadline(); !d.IsZero() || ok != false {
   832  		t.Errorf("ctx.Deadline() = %v, %v want zero, false", d, ok)
   833  	}
   834  	if done := ctx.Done(); done != nil {
   835  		t.Errorf("ctx.Deadline() = %v want nil", done)
   836  	}
   837  	if err := ctx.Err(); err != nil {
   838  		t.Errorf("ctx.Err() = %v want nil", err)
   839  	}
   840  	if v := ctx.Value(key); v != value {
   841  		t.Errorf("ctx.Value(%q) = %q want %q", key, v, value)
   842  	}
   843  }
   844  
   845  type customDoneContext struct {
   846  	Context
   847  	donec chan struct{}
   848  }
   849  
   850  func (c *customDoneContext) Done() <-chan struct{} {
   851  	return c.donec
   852  }
   853  
   854  func TestCustomContextPropagation(t *testing.T) {
   855  	cause := errors.New("TestCustomContextPropagation")
   856  	donec := make(chan struct{})
   857  	ctx1, cancel1 := WithCancelCause(Background())
   858  	ctx2 := &customDoneContext{
   859  		Context: ctx1,
   860  		donec:   donec,
   861  	}
   862  	ctx3, cancel3 := WithCancel(ctx2)
   863  	defer cancel3()
   864  
   865  	cancel1(cause)
   866  	close(donec)
   867  
   868  	<-ctx3.Done()
   869  	if got, want := ctx3.Err(), Canceled; got != want {
   870  		t.Errorf("child not canceled; got = %v, want = %v", got, want)
   871  	}
   872  	if got, want := Cause(ctx3), cause; got != want {
   873  		t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
   874  	}
   875  }
   876  
   877  // customCauseContext is a custom Context used to test context.Cause.
   878  type customCauseContext struct {
   879  	mu   sync.Mutex
   880  	done chan struct{}
   881  	err  error
   882  
   883  	cancelChild CancelFunc
   884  }
   885  
   886  func (ccc *customCauseContext) Deadline() (deadline time.Time, ok bool) {
   887  	return
   888  }
   889  
   890  func (ccc *customCauseContext) Done() <-chan struct{} {
   891  	ccc.mu.Lock()
   892  	defer ccc.mu.Unlock()
   893  	return ccc.done
   894  }
   895  
   896  func (ccc *customCauseContext) Err() error {
   897  	ccc.mu.Lock()
   898  	defer ccc.mu.Unlock()
   899  	return ccc.err
   900  }
   901  
   902  func (ccc *customCauseContext) Value(key any) any {
   903  	return nil
   904  }
   905  
   906  func (ccc *customCauseContext) cancel() {
   907  	ccc.mu.Lock()
   908  	ccc.err = Canceled
   909  	close(ccc.done)
   910  	cancelChild := ccc.cancelChild
   911  	ccc.mu.Unlock()
   912  
   913  	if cancelChild != nil {
   914  		cancelChild()
   915  	}
   916  }
   917  
   918  func (ccc *customCauseContext) setCancelChild(cancelChild CancelFunc) {
   919  	ccc.cancelChild = cancelChild
   920  }
   921  
   922  func TestCustomContextCause(t *testing.T) {
   923  	// Test if we cancel a custom context, Err and Cause return Canceled.
   924  	ccc := &customCauseContext{
   925  		done: make(chan struct{}),
   926  	}
   927  	ccc.cancel()
   928  	if got := ccc.Err(); got != Canceled {
   929  		t.Errorf("ccc.Err() = %v, want %v", got, Canceled)
   930  	}
   931  	if got := Cause(ccc); got != Canceled {
   932  		t.Errorf("Cause(ccc) = %v, want %v", got, Canceled)
   933  	}
   934  
   935  	// Test that if we pass a custom context to WithCancelCause,
   936  	// and then cancel that child context with a cause,
   937  	// that the cause of the child canceled context is correct
   938  	// but that the parent custom context is not canceled.
   939  	ccc = &customCauseContext{
   940  		done: make(chan struct{}),
   941  	}
   942  	ctx, causeFunc := WithCancelCause(ccc)
   943  	cause := errors.New("TestCustomContextCause")
   944  	causeFunc(cause)
   945  	if got := ctx.Err(); got != Canceled {
   946  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   947  	}
   948  	if got := Cause(ctx); got != cause {
   949  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   950  	}
   951  	if got := ccc.Err(); got != nil {
   952  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, nil)
   953  	}
   954  	if got := Cause(ccc); got != nil {
   955  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, nil)
   956  	}
   957  
   958  	// Test that if we now cancel the parent custom context,
   959  	// the cause of the child canceled context is still correct,
   960  	// and the parent custom context is canceled without a cause.
   961  	ccc.cancel()
   962  	if got := ctx.Err(); got != Canceled {
   963  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   964  	}
   965  	if got := Cause(ctx); got != cause {
   966  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, cause)
   967  	}
   968  	if got := ccc.Err(); got != Canceled {
   969  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   970  	}
   971  	if got := Cause(ccc); got != Canceled {
   972  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
   973  	}
   974  
   975  	// Test that if we associate a custom context with a child,
   976  	// then canceling the custom context cancels the child.
   977  	ccc = &customCauseContext{
   978  		done: make(chan struct{}),
   979  	}
   980  	ctx, cancelFunc := WithCancel(ccc)
   981  	ccc.setCancelChild(cancelFunc)
   982  	ccc.cancel()
   983  	if got := ctx.Err(); got != Canceled {
   984  		t.Errorf("after CancelCauseFunc ctx.Err() = %v, want %v", got, Canceled)
   985  	}
   986  	if got := Cause(ctx); got != Canceled {
   987  		t.Errorf("after CancelCauseFunc Cause(ctx) = %v, want %v", got, Canceled)
   988  	}
   989  	if got := ccc.Err(); got != Canceled {
   990  		t.Errorf("after CancelCauseFunc ccc.Err() = %v, want %v", got, Canceled)
   991  	}
   992  	if got := Cause(ccc); got != Canceled {
   993  		t.Errorf("after CancelCauseFunc Cause(ccc) = %v, want %v", got, Canceled)
   994  	}
   995  }
   996  
   997  func TestAfterFuncCalledAfterCancel(t *testing.T) {
   998  	ctx, cancel := WithCancel(Background())
   999  	donec := make(chan struct{})
  1000  	stop := AfterFunc(ctx, func() {
  1001  		close(donec)
  1002  	})
  1003  	select {
  1004  	case <-donec:
  1005  		t.Fatalf("AfterFunc called before context is done")
  1006  	case <-time.After(shortDuration):
  1007  	}
  1008  	cancel()
  1009  	select {
  1010  	case <-donec:
  1011  	case <-time.After(veryLongDuration):
  1012  		t.Fatalf("AfterFunc not called after context is canceled")
  1013  	}
  1014  	if stop() {
  1015  		t.Fatalf("stop() = true, want false")
  1016  	}
  1017  }
  1018  
  1019  func TestAfterFuncCalledAfterTimeout(t *testing.T) {
  1020  	ctx, cancel := WithTimeout(Background(), shortDuration)
  1021  	defer cancel()
  1022  	donec := make(chan struct{})
  1023  	AfterFunc(ctx, func() {
  1024  		close(donec)
  1025  	})
  1026  	select {
  1027  	case <-donec:
  1028  	case <-time.After(veryLongDuration):
  1029  		t.Fatalf("AfterFunc not called after context is canceled")
  1030  	}
  1031  }
  1032  
  1033  func TestAfterFuncCalledImmediately(t *testing.T) {
  1034  	ctx, cancel := WithCancel(Background())
  1035  	cancel()
  1036  	donec := make(chan struct{})
  1037  	AfterFunc(ctx, func() {
  1038  		close(donec)
  1039  	})
  1040  	select {
  1041  	case <-donec:
  1042  	case <-time.After(veryLongDuration):
  1043  		t.Fatalf("AfterFunc not called for already-canceled context")
  1044  	}
  1045  }
  1046  
  1047  func TestAfterFuncNotCalledAfterStop(t *testing.T) {
  1048  	ctx, cancel := WithCancel(Background())
  1049  	donec := make(chan struct{})
  1050  	stop := AfterFunc(ctx, func() {
  1051  		close(donec)
  1052  	})
  1053  	if !stop() {
  1054  		t.Fatalf("stop() = false, want true")
  1055  	}
  1056  	cancel()
  1057  	select {
  1058  	case <-donec:
  1059  		t.Fatalf("AfterFunc called for already-canceled context")
  1060  	case <-time.After(shortDuration):
  1061  	}
  1062  	if stop() {
  1063  		t.Fatalf("stop() = true, want false")
  1064  	}
  1065  }
  1066  
  1067  // This test verifies that canceling a context does not block waiting for AfterFuncs to finish.
  1068  func TestAfterFuncCalledAsynchronously(t *testing.T) {
  1069  	ctx, cancel := WithCancel(Background())
  1070  	donec := make(chan struct{})
  1071  	stop := AfterFunc(ctx, func() {
  1072  		// The channel send blocks until donec is read from.
  1073  		donec <- struct{}{}
  1074  	})
  1075  	defer stop()
  1076  	cancel()
  1077  	// After cancel returns, read from donec and unblock the AfterFunc.
  1078  	select {
  1079  	case <-donec:
  1080  	case <-time.After(veryLongDuration):
  1081  		t.Fatalf("AfterFunc not called after context is canceled")
  1082  	}
  1083  }
  1084  

View as plain text