Source file src/internal/concurrent/hashtriemap_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package concurrent
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"runtime"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"unsafe"
    16  )
    17  
    18  func TestHashTrieMap(t *testing.T) {
    19  	testHashTrieMap(t, func() *HashTrieMap[string, int] {
    20  		return NewHashTrieMap[string, int]()
    21  	})
    22  }
    23  
    24  func TestHashTrieMapBadHash(t *testing.T) {
    25  	testHashTrieMap(t, func() *HashTrieMap[string, int] {
    26  		// Stub out the good hash function with a terrible one.
    27  		// Everything should still work as expected.
    28  		m := NewHashTrieMap[string, int]()
    29  		m.keyHash = func(_ unsafe.Pointer, _ uintptr) uintptr {
    30  			return 0
    31  		}
    32  		return m
    33  	})
    34  }
    35  
    36  func testHashTrieMap(t *testing.T, newMap func() *HashTrieMap[string, int]) {
    37  	t.Run("LoadEmpty", func(t *testing.T) {
    38  		m := newMap()
    39  
    40  		for _, s := range testData {
    41  			expectMissing(t, s, 0)(m.Load(s))
    42  		}
    43  	})
    44  	t.Run("LoadOrStore", func(t *testing.T) {
    45  		m := newMap()
    46  
    47  		for i, s := range testData {
    48  			expectMissing(t, s, 0)(m.Load(s))
    49  			expectStored(t, s, i)(m.LoadOrStore(s, i))
    50  			expectPresent(t, s, i)(m.Load(s))
    51  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
    52  		}
    53  		for i, s := range testData {
    54  			expectPresent(t, s, i)(m.Load(s))
    55  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
    56  		}
    57  	})
    58  	t.Run("CompareAndDeleteAll", func(t *testing.T) {
    59  		m := newMap()
    60  
    61  		for range 3 {
    62  			for i, s := range testData {
    63  				expectMissing(t, s, 0)(m.Load(s))
    64  				expectStored(t, s, i)(m.LoadOrStore(s, i))
    65  				expectPresent(t, s, i)(m.Load(s))
    66  				expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
    67  			}
    68  			for i, s := range testData {
    69  				expectPresent(t, s, i)(m.Load(s))
    70  				expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt))
    71  				expectDeleted(t, s, i)(m.CompareAndDelete(s, i))
    72  				expectNotDeleted(t, s, i)(m.CompareAndDelete(s, i))
    73  				expectMissing(t, s, 0)(m.Load(s))
    74  			}
    75  			for _, s := range testData {
    76  				expectMissing(t, s, 0)(m.Load(s))
    77  			}
    78  		}
    79  	})
    80  	t.Run("CompareAndDeleteOne", func(t *testing.T) {
    81  		m := newMap()
    82  
    83  		for i, s := range testData {
    84  			expectMissing(t, s, 0)(m.Load(s))
    85  			expectStored(t, s, i)(m.LoadOrStore(s, i))
    86  			expectPresent(t, s, i)(m.Load(s))
    87  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
    88  		}
    89  		expectNotDeleted(t, testData[15], math.MaxInt)(m.CompareAndDelete(testData[15], math.MaxInt))
    90  		expectDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15))
    91  		expectNotDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15))
    92  		for i, s := range testData {
    93  			if i == 15 {
    94  				expectMissing(t, s, 0)(m.Load(s))
    95  			} else {
    96  				expectPresent(t, s, i)(m.Load(s))
    97  			}
    98  		}
    99  	})
   100  	t.Run("DeleteMultiple", func(t *testing.T) {
   101  		m := newMap()
   102  
   103  		for i, s := range testData {
   104  			expectMissing(t, s, 0)(m.Load(s))
   105  			expectStored(t, s, i)(m.LoadOrStore(s, i))
   106  			expectPresent(t, s, i)(m.Load(s))
   107  			expectLoaded(t, s, i)(m.LoadOrStore(s, 0))
   108  		}
   109  		for _, i := range []int{1, 105, 6, 85} {
   110  			expectNotDeleted(t, testData[i], math.MaxInt)(m.CompareAndDelete(testData[i], math.MaxInt))
   111  			expectDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i))
   112  			expectNotDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i))
   113  		}
   114  		for i, s := range testData {
   115  			if i == 1 || i == 105 || i == 6 || i == 85 {
   116  				expectMissing(t, s, 0)(m.Load(s))
   117  			} else {
   118  				expectPresent(t, s, i)(m.Load(s))
   119  			}
   120  		}
   121  	})
   122  	t.Run("All", func(t *testing.T) {
   123  		m := newMap()
   124  
   125  		testAll(t, m, testDataMap(testData[:]), func(_ string, _ int) bool {
   126  			return true
   127  		})
   128  	})
   129  	t.Run("AllDelete", func(t *testing.T) {
   130  		m := newMap()
   131  
   132  		testAll(t, m, testDataMap(testData[:]), func(s string, i int) bool {
   133  			expectDeleted(t, s, i)(m.CompareAndDelete(s, i))
   134  			return true
   135  		})
   136  		for _, s := range testData {
   137  			expectMissing(t, s, 0)(m.Load(s))
   138  		}
   139  	})
   140  	t.Run("ConcurrentLifecycleUnsharedKeys", func(t *testing.T) {
   141  		m := newMap()
   142  
   143  		gmp := runtime.GOMAXPROCS(-1)
   144  		var wg sync.WaitGroup
   145  		for i := range gmp {
   146  			wg.Add(1)
   147  			go func(id int) {
   148  				defer wg.Done()
   149  
   150  				makeKey := func(s string) string {
   151  					return s + "-" + strconv.Itoa(id)
   152  				}
   153  				for _, s := range testData {
   154  					key := makeKey(s)
   155  					expectMissing(t, key, 0)(m.Load(key))
   156  					expectStored(t, key, id)(m.LoadOrStore(key, id))
   157  					expectPresent(t, key, id)(m.Load(key))
   158  					expectLoaded(t, key, id)(m.LoadOrStore(key, 0))
   159  				}
   160  				for _, s := range testData {
   161  					key := makeKey(s)
   162  					expectPresent(t, key, id)(m.Load(key))
   163  					expectDeleted(t, key, id)(m.CompareAndDelete(key, id))
   164  					expectMissing(t, key, 0)(m.Load(key))
   165  				}
   166  				for _, s := range testData {
   167  					key := makeKey(s)
   168  					expectMissing(t, key, 0)(m.Load(key))
   169  				}
   170  			}(i)
   171  		}
   172  		wg.Wait()
   173  	})
   174  	t.Run("ConcurrentDeleteSharedKeys", func(t *testing.T) {
   175  		m := newMap()
   176  
   177  		// Load up the map.
   178  		for i, s := range testData {
   179  			expectMissing(t, s, 0)(m.Load(s))
   180  			expectStored(t, s, i)(m.LoadOrStore(s, i))
   181  		}
   182  		gmp := runtime.GOMAXPROCS(-1)
   183  		var wg sync.WaitGroup
   184  		for i := range gmp {
   185  			wg.Add(1)
   186  			go func(id int) {
   187  				defer wg.Done()
   188  
   189  				for i, s := range testData {
   190  					expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt))
   191  					m.CompareAndDelete(s, i)
   192  					expectMissing(t, s, 0)(m.Load(s))
   193  				}
   194  				for _, s := range testData {
   195  					expectMissing(t, s, 0)(m.Load(s))
   196  				}
   197  			}(i)
   198  		}
   199  		wg.Wait()
   200  	})
   201  }
   202  
   203  func testAll[K, V comparable](t *testing.T, m *HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) {
   204  	for k, v := range testData {
   205  		expectStored(t, k, v)(m.LoadOrStore(k, v))
   206  	}
   207  	visited := make(map[K]int)
   208  	m.All()(func(key K, got V) bool {
   209  		want, ok := testData[key]
   210  		if !ok {
   211  			t.Errorf("unexpected key %v in map", key)
   212  			return false
   213  		}
   214  		if got != want {
   215  			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
   216  			return false
   217  		}
   218  		visited[key]++
   219  		return yield(key, got)
   220  	})
   221  	for key, n := range visited {
   222  		if n > 1 {
   223  			t.Errorf("visited key %v more than once", key)
   224  		}
   225  	}
   226  }
   227  
   228  func expectPresent[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) {
   229  	t.Helper()
   230  	return func(got V, ok bool) {
   231  		t.Helper()
   232  
   233  		if !ok {
   234  			t.Errorf("expected key %v to be present in map", key)
   235  		}
   236  		if ok && got != want {
   237  			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
   238  		}
   239  	}
   240  }
   241  
   242  func expectMissing[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) {
   243  	t.Helper()
   244  	if want != *new(V) {
   245  		// This is awkward, but the want argument is necessary to smooth over type inference.
   246  		// Just make sure the want argument always looks the same.
   247  		panic("expectMissing must always have a zero value variable")
   248  	}
   249  	return func(got V, ok bool) {
   250  		t.Helper()
   251  
   252  		if ok {
   253  			t.Errorf("expected key %v to be missing from map, got value %v", key, got)
   254  		}
   255  		if !ok && got != want {
   256  			t.Errorf("expected missing key %v to be paired with the zero value; got %v", key, got)
   257  		}
   258  	}
   259  }
   260  
   261  func expectLoaded[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) {
   262  	t.Helper()
   263  	return func(got V, loaded bool) {
   264  		t.Helper()
   265  
   266  		if !loaded {
   267  			t.Errorf("expected key %v to have been loaded, not stored", key)
   268  		}
   269  		if got != want {
   270  			t.Errorf("expected key %v to have value %v, got %v", key, want, got)
   271  		}
   272  	}
   273  }
   274  
   275  func expectStored[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) {
   276  	t.Helper()
   277  	return func(got V, loaded bool) {
   278  		t.Helper()
   279  
   280  		if loaded {
   281  			t.Errorf("expected inserted key %v to have been stored, not loaded", key)
   282  		}
   283  		if got != want {
   284  			t.Errorf("expected inserted key %v to have value %v, got %v", key, want, got)
   285  		}
   286  	}
   287  }
   288  
   289  func expectDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) {
   290  	t.Helper()
   291  	return func(deleted bool) {
   292  		t.Helper()
   293  
   294  		if !deleted {
   295  			t.Errorf("expected key %v with value %v to be in map and deleted", key, old)
   296  		}
   297  	}
   298  }
   299  
   300  func expectNotDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) {
   301  	t.Helper()
   302  	return func(deleted bool) {
   303  		t.Helper()
   304  
   305  		if deleted {
   306  			t.Errorf("expected key %v with value %v to not be in map and thus not deleted", key, old)
   307  		}
   308  	}
   309  }
   310  
   311  func testDataMap(data []string) map[string]int {
   312  	m := make(map[string]int)
   313  	for i, s := range data {
   314  		m[s] = i
   315  	}
   316  	return m
   317  }
   318  
   319  var (
   320  	testDataSmall [8]string
   321  	testData      [128]string
   322  	testDataLarge [128 << 10]string
   323  )
   324  
   325  func init() {
   326  	for i := range testDataSmall {
   327  		testDataSmall[i] = fmt.Sprintf("%b", i)
   328  	}
   329  	for i := range testData {
   330  		testData[i] = fmt.Sprintf("%b", i)
   331  	}
   332  	for i := range testDataLarge {
   333  		testDataLarge[i] = fmt.Sprintf("%b", i)
   334  	}
   335  }
   336  
   337  func dumpMap[K, V comparable](ht *HashTrieMap[K, V]) {
   338  	dumpNode(ht, &ht.root.node, 0)
   339  }
   340  
   341  func dumpNode[K, V comparable](ht *HashTrieMap[K, V], n *node[K, V], depth int) {
   342  	var sb strings.Builder
   343  	for range depth {
   344  		fmt.Fprintf(&sb, "\t")
   345  	}
   346  	prefix := sb.String()
   347  	if n.isEntry {
   348  		e := n.entry()
   349  		for e != nil {
   350  			fmt.Printf("%s%p [Entry Key=%v Value=%v Overflow=%p, Hash=%016x]\n", prefix, e, e.key, e.value, e.overflow.Load(), ht.keyHash(unsafe.Pointer(&e.key), ht.seed))
   351  			e = e.overflow.Load()
   352  		}
   353  		return
   354  	}
   355  	i := n.indirect()
   356  	fmt.Printf("%s%p [Indirect Parent=%p Dead=%t Children=[", prefix, i, i.parent, i.dead.Load())
   357  	for j := range i.children {
   358  		c := i.children[j].Load()
   359  		fmt.Printf("%p", c)
   360  		if j != len(i.children)-1 {
   361  			fmt.Printf(", ")
   362  		}
   363  	}
   364  	fmt.Printf("]]\n")
   365  	for j := range i.children {
   366  		c := i.children[j].Load()
   367  		if c != nil {
   368  			dumpNode(ht, c, depth+1)
   369  		}
   370  	}
   371  }
   372  

View as plain text