Source file src/context/afterfunc_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package context_test
     6  
     7  import (
     8  	"context"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  // afterFuncContext is a context that's not one of the types
    15  // defined in context.go, that supports registering AfterFuncs.
    16  type afterFuncContext struct {
    17  	mu         sync.Mutex
    18  	afterFuncs map[*byte]func()
    19  	done       chan struct{}
    20  	err        error
    21  }
    22  
    23  func newAfterFuncContext() context.Context {
    24  	return &afterFuncContext{}
    25  }
    26  
    27  func (c *afterFuncContext) Deadline() (time.Time, bool) {
    28  	return time.Time{}, false
    29  }
    30  
    31  func (c *afterFuncContext) Done() <-chan struct{} {
    32  	c.mu.Lock()
    33  	defer c.mu.Unlock()
    34  	if c.done == nil {
    35  		c.done = make(chan struct{})
    36  	}
    37  	return c.done
    38  }
    39  
    40  func (c *afterFuncContext) Err() error {
    41  	c.mu.Lock()
    42  	defer c.mu.Unlock()
    43  	return c.err
    44  }
    45  
    46  func (c *afterFuncContext) Value(key any) any {
    47  	return nil
    48  }
    49  
    50  func (c *afterFuncContext) AfterFunc(f func()) func() bool {
    51  	c.mu.Lock()
    52  	defer c.mu.Unlock()
    53  	k := new(byte)
    54  	if c.afterFuncs == nil {
    55  		c.afterFuncs = make(map[*byte]func())
    56  	}
    57  	c.afterFuncs[k] = f
    58  	return func() bool {
    59  		c.mu.Lock()
    60  		defer c.mu.Unlock()
    61  		_, ok := c.afterFuncs[k]
    62  		delete(c.afterFuncs, k)
    63  		return ok
    64  	}
    65  }
    66  
    67  func (c *afterFuncContext) cancel(err error) {
    68  	c.mu.Lock()
    69  	defer c.mu.Unlock()
    70  	if c.err != nil {
    71  		return
    72  	}
    73  	c.err = err
    74  	for _, f := range c.afterFuncs {
    75  		go f()
    76  	}
    77  	c.afterFuncs = nil
    78  }
    79  
    80  func TestCustomContextAfterFuncCancel(t *testing.T) {
    81  	ctx0 := &afterFuncContext{}
    82  	ctx1, cancel := context.WithCancel(ctx0)
    83  	defer cancel()
    84  	ctx0.cancel(context.Canceled)
    85  	<-ctx1.Done()
    86  }
    87  
    88  func TestCustomContextAfterFuncTimeout(t *testing.T) {
    89  	ctx0 := &afterFuncContext{}
    90  	ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
    91  	defer cancel()
    92  	ctx0.cancel(context.Canceled)
    93  	<-ctx1.Done()
    94  }
    95  
    96  func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
    97  	ctx0 := &afterFuncContext{}
    98  	donec := make(chan struct{})
    99  	stop := context.AfterFunc(ctx0, func() {
   100  		close(donec)
   101  	})
   102  	defer stop()
   103  	ctx0.cancel(context.Canceled)
   104  	<-donec
   105  }
   106  
   107  func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
   108  	ctx0 := &afterFuncContext{}
   109  	_, cancel1 := context.WithCancel(ctx0)
   110  	_, cancel2 := context.WithCancel(ctx0)
   111  	if got, want := len(ctx0.afterFuncs), 2; got != want {
   112  		t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
   113  	}
   114  	cancel1()
   115  	cancel2()
   116  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   117  		t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
   118  	}
   119  }
   120  
   121  func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
   122  	ctx0 := &afterFuncContext{}
   123  	_, cancel := context.WithTimeout(ctx0, veryLongDuration)
   124  	if got, want := len(ctx0.afterFuncs), 1; got != want {
   125  		t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
   126  	}
   127  	cancel()
   128  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   129  		t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
   130  	}
   131  }
   132  
   133  func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
   134  	ctx0 := &afterFuncContext{}
   135  	stop := context.AfterFunc(ctx0, func() {})
   136  	if got, want := len(ctx0.afterFuncs), 1; got != want {
   137  		t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
   138  	}
   139  	stop()
   140  	if got, want := len(ctx0.afterFuncs), 0; got != want {
   141  		t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
   142  	}
   143  }
   144  

View as plain text