Source file src/iter/pull_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package iter_test
     6  
     7  import (
     8  	"fmt"
     9  	. "iter"
    10  	"runtime"
    11  	"testing"
    12  )
    13  
    14  func count(n int) Seq[int] {
    15  	return func(yield func(int) bool) {
    16  		for i := range n {
    17  			if !yield(i) {
    18  				break
    19  			}
    20  		}
    21  	}
    22  }
    23  
    24  func squares(n int) Seq2[int, int64] {
    25  	return func(yield func(int, int64) bool) {
    26  		for i := range n {
    27  			if !yield(i, int64(i)*int64(i)) {
    28  				break
    29  			}
    30  		}
    31  	}
    32  }
    33  
    34  func TestPull(t *testing.T) {
    35  	for end := 0; end <= 3; end++ {
    36  		t.Run(fmt.Sprint(end), func(t *testing.T) {
    37  			ng := runtime.NumGoroutine()
    38  			wantNG := func(want int) {
    39  				if xg := runtime.NumGoroutine() - ng; xg != want {
    40  					t.Helper()
    41  					t.Errorf("have %d extra goroutines, want %d", xg, want)
    42  				}
    43  			}
    44  			wantNG(0)
    45  			next, stop := Pull(count(3))
    46  			wantNG(1)
    47  			for i := range end {
    48  				v, ok := next()
    49  				if v != i || ok != true {
    50  					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, i, true)
    51  				}
    52  				wantNG(1)
    53  			}
    54  			wantNG(1)
    55  			if end < 3 {
    56  				stop()
    57  				wantNG(0)
    58  			}
    59  			for range 2 {
    60  				v, ok := next()
    61  				if v != 0 || ok != false {
    62  					t.Fatalf("next() = %d, %v, want %d, %v", v, ok, 0, false)
    63  				}
    64  				wantNG(0)
    65  			}
    66  			wantNG(0)
    67  
    68  			stop()
    69  			stop()
    70  			stop()
    71  			wantNG(0)
    72  		})
    73  	}
    74  }
    75  
    76  func TestPull2(t *testing.T) {
    77  	for end := 0; end <= 3; end++ {
    78  		t.Run(fmt.Sprint(end), func(t *testing.T) {
    79  			ng := runtime.NumGoroutine()
    80  			wantNG := func(want int) {
    81  				if xg := runtime.NumGoroutine() - ng; xg != want {
    82  					t.Helper()
    83  					t.Errorf("have %d extra goroutines, want %d", xg, want)
    84  				}
    85  			}
    86  			wantNG(0)
    87  			next, stop := Pull2(squares(3))
    88  			wantNG(1)
    89  			for i := range end {
    90  				k, v, ok := next()
    91  				if k != i || v != int64(i*i) || ok != true {
    92  					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, i, i*i, true)
    93  				}
    94  				wantNG(1)
    95  			}
    96  			wantNG(1)
    97  			if end < 3 {
    98  				stop()
    99  				wantNG(0)
   100  			}
   101  			for range 2 {
   102  				k, v, ok := next()
   103  				if v != 0 || ok != false {
   104  					t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, 0, 0, false)
   105  				}
   106  				wantNG(0)
   107  			}
   108  			wantNG(0)
   109  
   110  			stop()
   111  			stop()
   112  			stop()
   113  			wantNG(0)
   114  		})
   115  	}
   116  }
   117  
   118  func TestPullDoubleNext(t *testing.T) {
   119  	next, _ := Pull(doDoubleNext())
   120  	nextSlot = next
   121  	next()
   122  	if nextSlot != nil {
   123  		t.Fatal("double next did not fail")
   124  	}
   125  }
   126  
   127  var nextSlot func() (int, bool)
   128  
   129  func doDoubleNext() Seq[int] {
   130  	return func(_ func(int) bool) {
   131  		defer func() {
   132  			if recover() != nil {
   133  				nextSlot = nil
   134  			}
   135  		}()
   136  		nextSlot()
   137  	}
   138  }
   139  
   140  func TestPullDoubleNext2(t *testing.T) {
   141  	next, _ := Pull2(doDoubleNext2())
   142  	nextSlot2 = next
   143  	next()
   144  	if nextSlot2 != nil {
   145  		t.Fatal("double next did not fail")
   146  	}
   147  }
   148  
   149  var nextSlot2 func() (int, int, bool)
   150  
   151  func doDoubleNext2() Seq2[int, int] {
   152  	return func(_ func(int, int) bool) {
   153  		defer func() {
   154  			if recover() != nil {
   155  				nextSlot2 = nil
   156  			}
   157  		}()
   158  		nextSlot2()
   159  	}
   160  }
   161  
   162  func TestPullDoubleYield(t *testing.T) {
   163  	_, stop := Pull(storeYield())
   164  	defer func() {
   165  		if recover() != nil {
   166  			yieldSlot = nil
   167  		}
   168  		stop()
   169  	}()
   170  	yieldSlot(5)
   171  	if yieldSlot != nil {
   172  		t.Fatal("double yield did not fail")
   173  	}
   174  }
   175  
   176  func storeYield() Seq[int] {
   177  	return func(yield func(int) bool) {
   178  		yieldSlot = yield
   179  		if !yield(5) {
   180  			return
   181  		}
   182  	}
   183  }
   184  
   185  var yieldSlot func(int) bool
   186  
   187  func TestPullDoubleYield2(t *testing.T) {
   188  	_, stop := Pull2(storeYield2())
   189  	defer func() {
   190  		if recover() != nil {
   191  			yieldSlot2 = nil
   192  		}
   193  		stop()
   194  	}()
   195  	yieldSlot2(23, 77)
   196  	if yieldSlot2 != nil {
   197  		t.Fatal("double yield did not fail")
   198  	}
   199  }
   200  
   201  func storeYield2() Seq2[int, int] {
   202  	return func(yield func(int, int) bool) {
   203  		yieldSlot2 = yield
   204  		if !yield(23, 77) {
   205  			return
   206  		}
   207  	}
   208  }
   209  
   210  var yieldSlot2 func(int, int) bool
   211  

View as plain text