1  
     2  
     3  
     4  
     5  package sync
     6  
     7  import (
     8  	"internal/abi"
     9  	"internal/goarch"
    10  	"sync/atomic"
    11  	"unsafe"
    12  )
    13  
    14  
    15  
    16  
    17  
    18  
    19  
    20  
// HashTrieMap is a generic map from K to V laid out as a trie of indirect
// nodes that are indexed by successive groups of bits of the key's hash.
//
// The fields below are filled in lazily by initSlow; the exported methods
// call init before they touch any of them.
type HashTrieMap[K comparable, V any] struct {
	// inited is nonzero once initSlow has populated the remaining fields.
	inited   atomic.Uint32
	// initMu serializes initSlow.
	initMu   Mutex
	// root is the top-level indirect node. Clear replaces it with a fresh
	// empty node.
	root     atomic.Pointer[indirect[K, V]]
	// keyHash hashes a K; it is taken from the map type descriptor of map[K]V.
	keyHash  hashFunc
	// valEqual compares two V values; it is taken from the same descriptor.
	// CompareAndSwap and CompareAndDelete panic when it is nil.
	valEqual equalFunc
	// seed is passed to keyHash on every call; it is drawn from runtime_rand
	// once during initSlow.
	seed     uintptr
}
    29  
    30  func (ht *HashTrieMap[K, V]) init() {
    31  	if ht.inited.Load() == 0 {
    32  		ht.initSlow()
    33  	}
    34  }
    35  
    36  
    37  func (ht *HashTrieMap[K, V]) initSlow() {
    38  	ht.initMu.Lock()
    39  	defer ht.initMu.Unlock()
    40  
    41  	if ht.inited.Load() != 0 {
    42  		
    43  		return
    44  	}
    45  
    46  	
    47  	
    48  	var m map[K]V
    49  	mapType := abi.TypeOf(m).MapType()
    50  	ht.root.Store(newIndirectNode[K, V](nil))
    51  	ht.keyHash = mapType.Hasher
    52  	ht.valEqual = mapType.Elem.Equal
    53  	ht.seed = uintptr(runtime_rand())
    54  
    55  	ht.inited.Store(1)
    56  }
    57  
// hashFunc computes a hash for the value behind the pointer. In this file it
// is always called with a pointer to a key and the map's seed.
type hashFunc func(unsafe.Pointer, uintptr) uintptr

// equalFunc reports whether the two pointed-to values are equal. In this file
// it is always called with pointers to two values of the map's value type.
type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool
    60  
    61  
    62  
    63  
    64  func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
    65  	ht.init()
    66  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
    67  
    68  	i := ht.root.Load()
    69  	hashShift := 8 * goarch.PtrSize
    70  	for hashShift != 0 {
    71  		hashShift -= nChildrenLog2
    72  
    73  		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
    74  		if n == nil {
    75  			return *new(V), false
    76  		}
    77  		if n.isEntry {
    78  			return n.entry().lookup(key)
    79  		}
    80  		i = n.indirect()
    81  	}
    82  	panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
    83  }
    84  
    85  
    86  
    87  
    88  func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
    89  	ht.init()
    90  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
    91  	var i *indirect[K, V]
    92  	var hashShift uint
    93  	var slot *atomic.Pointer[node[K, V]]
    94  	var n *node[K, V]
    95  	for {
    96  		
    97  		i = ht.root.Load()
    98  		hashShift = 8 * goarch.PtrSize
    99  		haveInsertPoint := false
   100  		for hashShift != 0 {
   101  			hashShift -= nChildrenLog2
   102  
   103  			slot = &i.children[(hash>>hashShift)&nChildrenMask]
   104  			n = slot.Load()
   105  			if n == nil {
   106  				
   107  				haveInsertPoint = true
   108  				break
   109  			}
   110  			if n.isEntry {
   111  				
   112  				
   113  				
   114  				if v, ok := n.entry().lookup(key); ok {
   115  					return v, true
   116  				}
   117  				haveInsertPoint = true
   118  				break
   119  			}
   120  			i = n.indirect()
   121  		}
   122  		if !haveInsertPoint {
   123  			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
   124  		}
   125  
   126  		
   127  		i.mu.Lock()
   128  		n = slot.Load()
   129  		if (n == nil || n.isEntry) && !i.dead.Load() {
   130  			
   131  			break
   132  		}
   133  		
   134  		i.mu.Unlock()
   135  	}
   136  	
   137  	
   138  	
   139  	
   140  	
   141  	defer i.mu.Unlock()
   142  
   143  	var oldEntry *entry[K, V]
   144  	if n != nil {
   145  		oldEntry = n.entry()
   146  		if v, ok := oldEntry.lookup(key); ok {
   147  			
   148  			return v, true
   149  		}
   150  	}
   151  	newEntry := newEntryNode(key, value)
   152  	if oldEntry == nil {
   153  		
   154  		slot.Store(&newEntry.node)
   155  	} else {
   156  		
   157  		
   158  		
   159  		
   160  		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
   161  	}
   162  	return value, false
   163  }
   164  
   165  
   166  
   167  func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
   168  	
   169  	oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
   170  	if oldHash == newHash {
   171  		
   172  		
   173  		newEntry.overflow.Store(oldEntry)
   174  		return &newEntry.node
   175  	}
   176  	
   177  	newIndirect := newIndirectNode(parent)
   178  	top := newIndirect
   179  	for {
   180  		if hashShift == 0 {
   181  			panic("internal/sync.HashTrieMap: ran out of hash bits while inserting (incorrect use of unsafe or cgo, or data race?)")
   182  		}
   183  		hashShift -= nChildrenLog2 
   184  		oi := (oldHash >> hashShift) & nChildrenMask
   185  		ni := (newHash >> hashShift) & nChildrenMask
   186  		if oi != ni {
   187  			newIndirect.children[oi].Store(&oldEntry.node)
   188  			newIndirect.children[ni].Store(&newEntry.node)
   189  			break
   190  		}
   191  		nextIndirect := newIndirectNode(newIndirect)
   192  		newIndirect.children[oi].Store(&nextIndirect.node)
   193  		newIndirect = nextIndirect
   194  	}
   195  	return &top.node
   196  }
   197  
   198  
   199  func (ht *HashTrieMap[K, V]) Store(key K, new V) {
   200  	_, _ = ht.Swap(key, new)
   201  }
   202  
   203  
   204  
   205  func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
   206  	ht.init()
   207  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
   208  	var i *indirect[K, V]
   209  	var hashShift uint
   210  	var slot *atomic.Pointer[node[K, V]]
   211  	var n *node[K, V]
   212  	for {
   213  		
   214  		i = ht.root.Load()
   215  		hashShift = 8 * goarch.PtrSize
   216  		haveInsertPoint := false
   217  		for hashShift != 0 {
   218  			hashShift -= nChildrenLog2
   219  
   220  			slot = &i.children[(hash>>hashShift)&nChildrenMask]
   221  			n = slot.Load()
   222  			if n == nil || n.isEntry {
   223  				
   224  				
   225  				haveInsertPoint = true
   226  				break
   227  			}
   228  			i = n.indirect()
   229  		}
   230  		if !haveInsertPoint {
   231  			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
   232  		}
   233  
   234  		
   235  		i.mu.Lock()
   236  		n = slot.Load()
   237  		if (n == nil || n.isEntry) && !i.dead.Load() {
   238  			
   239  			break
   240  		}
   241  		
   242  		i.mu.Unlock()
   243  	}
   244  	
   245  	
   246  	
   247  	
   248  	
   249  	defer i.mu.Unlock()
   250  
   251  	var zero V
   252  	var oldEntry *entry[K, V]
   253  	if n != nil {
   254  		
   255  		oldEntry = n.entry()
   256  		newEntry, old, swapped := oldEntry.swap(key, new)
   257  		if swapped {
   258  			slot.Store(&newEntry.node)
   259  			return old, true
   260  		}
   261  	}
   262  	
   263  	newEntry := newEntryNode(key, new)
   264  	if oldEntry == nil {
   265  		
   266  		slot.Store(&newEntry.node)
   267  	} else {
   268  		
   269  		
   270  		
   271  		
   272  		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
   273  	}
   274  	return zero, false
   275  }
   276  
   277  
   278  
   279  
   280  func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
   281  	ht.init()
   282  	if ht.valEqual == nil {
   283  		panic("called CompareAndSwap when value is not of comparable type")
   284  	}
   285  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
   286  
   287  	
   288  	i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
   289  	if i != nil {
   290  		defer i.mu.Unlock()
   291  	}
   292  	if n == nil {
   293  		return false
   294  	}
   295  
   296  	
   297  	e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
   298  	if !swapped {
   299  		
   300  		return false
   301  	}
   302  	
   303  	slot.Store(&e.node)
   304  	return true
   305  }
   306  
   307  
   308  
   309  func (ht *HashTrieMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
   310  	ht.init()
   311  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
   312  
   313  	
   314  	i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
   315  	if n == nil {
   316  		if i != nil {
   317  			i.mu.Unlock()
   318  		}
   319  		return *new(V), false
   320  	}
   321  
   322  	
   323  	v, e, loaded := n.entry().loadAndDelete(key)
   324  	if !loaded {
   325  		
   326  		i.mu.Unlock()
   327  		return *new(V), false
   328  	}
   329  	if e != nil {
   330  		
   331  		
   332  		slot.Store(&e.node)
   333  		i.mu.Unlock()
   334  		return v, true
   335  	}
   336  	
   337  	slot.Store(nil)
   338  
   339  	
   340  	for i.parent != nil && i.empty() {
   341  		if hashShift == 8*goarch.PtrSize {
   342  			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
   343  		}
   344  		hashShift += nChildrenLog2
   345  
   346  		
   347  		parent := i.parent
   348  		parent.mu.Lock()
   349  		i.dead.Store(true)
   350  		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
   351  		i.mu.Unlock()
   352  		i = parent
   353  	}
   354  	i.mu.Unlock()
   355  	return v, true
   356  }
   357  
   358  
   359  func (ht *HashTrieMap[K, V]) Delete(key K) {
   360  	_, _ = ht.LoadAndDelete(key)
   361  }
   362  
   363  
   364  
   365  
   366  
   367  
   368  func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
   369  	ht.init()
   370  	if ht.valEqual == nil {
   371  		panic("called CompareAndDelete when value is not of comparable type")
   372  	}
   373  	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
   374  
   375  	
   376  	i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
   377  	if n == nil {
   378  		if i != nil {
   379  			i.mu.Unlock()
   380  		}
   381  		return false
   382  	}
   383  
   384  	
   385  	e, deleted := n.entry().compareAndDelete(key, old, ht.valEqual)
   386  	if !deleted {
   387  		
   388  		i.mu.Unlock()
   389  		return false
   390  	}
   391  	if e != nil {
   392  		
   393  		
   394  		slot.Store(&e.node)
   395  		i.mu.Unlock()
   396  		return true
   397  	}
   398  	
   399  	slot.Store(nil)
   400  
   401  	
   402  	for i.parent != nil && i.empty() {
   403  		if hashShift == 8*goarch.PtrSize {
   404  			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
   405  		}
   406  		hashShift += nChildrenLog2
   407  
   408  		
   409  		parent := i.parent
   410  		parent.mu.Lock()
   411  		i.dead.Store(true)
   412  		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
   413  		i.mu.Unlock()
   414  		i = parent
   415  	}
   416  	i.mu.Unlock()
   417  	return true
   418  }
   419  
   420  
   421  
   422  
   423  
   424  
   425  
   426  func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value V) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) {
   427  	for {
   428  		
   429  		i = ht.root.Load()
   430  		hashShift = 8 * goarch.PtrSize
   431  		found := false
   432  		for hashShift != 0 {
   433  			hashShift -= nChildrenLog2
   434  
   435  			slot = &i.children[(hash>>hashShift)&nChildrenMask]
   436  			n = slot.Load()
   437  			if n == nil {
   438  				
   439  				i = nil
   440  				return
   441  			}
   442  			if n.isEntry {
   443  				
   444  				if _, ok := n.entry().lookupWithValue(key, value, valEqual); !ok {
   445  					
   446  					i = nil
   447  					n = nil
   448  					return
   449  				}
   450  				
   451  				found = true
   452  				break
   453  			}
   454  			i = n.indirect()
   455  		}
   456  		if !found {
   457  			panic("internal/sync.HashTrieMap: ran out of hash bits while iterating")
   458  		}
   459  
   460  		
   461  		i.mu.Lock()
   462  		n = slot.Load()
   463  		if !i.dead.Load() && (n == nil || n.isEntry) {
   464  			
   465  			
   466  			return
   467  		}
   468  		
   469  		i.mu.Unlock()
   470  	}
   471  }
   472  
   473  
   474  
   475  
   476  
   477  
   478  
   479  
   480  
   481  func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) {
   482  	ht.init()
   483  	return func(yield func(key K, value V) bool) {
   484  		ht.iter(ht.root.Load(), yield)
   485  	}
   486  }
   487  
   488  
   489  
   490  
   491  
   492  
   493  func (ht *HashTrieMap[K, V]) Range(yield func(K, V) bool) {
   494  	ht.init()
   495  	ht.iter(ht.root.Load(), yield)
   496  }
   497  
   498  func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
   499  	for j := range i.children {
   500  		n := i.children[j].Load()
   501  		if n == nil {
   502  			continue
   503  		}
   504  		if !n.isEntry {
   505  			if !ht.iter(n.indirect(), yield) {
   506  				return false
   507  			}
   508  			continue
   509  		}
   510  		e := n.entry()
   511  		for e != nil {
   512  			if !yield(e.key, e.value) {
   513  				return false
   514  			}
   515  			e = e.overflow.Load()
   516  		}
   517  	}
   518  	return true
   519  }
   520  
   521  
   522  func (ht *HashTrieMap[K, V]) Clear() {
   523  	ht.init()
   524  
   525  	
   526  	
   527  	ht.root.Store(newIndirectNode[K, V](nil))
   528  }
   529  
const (
	// nChildrenLog2 is the number of hash bits consumed at each level of the
	// trie: every step down subtracts it from the current hash shift.
	nChildrenLog2 = 4
	// nChildren is the number of child slots in each indirect node.
	nChildren     = 1 << nChildrenLog2
	// nChildrenMask reduces a shifted hash to a child index in
	// [0, nChildren).
	nChildrenMask = nChildren - 1
)
   539  
   540  
// indirect is an internal (non-leaf) node of the trie.
//
// The embedded node must remain the first field: node.indirect converts a
// *node to an *indirect with an unsafe pointer cast.
type indirect[K comparable, V any] struct {
	node[K, V]
	// dead is set once the node has been unlinked from its parent by a
	// delete. Writers re-check it after taking mu and retry from the root if
	// it is set.
	dead     atomic.Bool
	// mu is held by the mutating methods while they store into children.
	mu       Mutex 
	// parent is the node one level up, or nil for a root node.
	parent   *indirect[K, V]
	// children holds one slot per possible value of the hash bits used at
	// this level; each slot is nil, an entry node, or another indirect node.
	children [nChildren]atomic.Pointer[node[K, V]]
}
   548  
   549  func newIndirectNode[K comparable, V any](parent *indirect[K, V]) *indirect[K, V] {
   550  	return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
   551  }
   552  
   553  func (i *indirect[K, V]) empty() bool {
   554  	nc := 0
   555  	for j := range i.children {
   556  		if i.children[j].Load() != nil {
   557  			nc++
   558  		}
   559  	}
   560  	return nc == 0
   561  }
   562  
   563  
// entry is a leaf node of the trie holding one key/value pair.
//
// The embedded node must remain the first field: node.entry converts a *node
// to an *entry with an unsafe pointer cast.
type entry[K comparable, V any] struct {
	node[K, V]
	// overflow links to the next entry in the chain; expand uses it to chain
	// entries whose keys have the same hash. It is nil at the end of the
	// chain.
	overflow atomic.Pointer[entry[K, V]] 
	key      K
	value    V
}
   570  
   571  func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] {
   572  	return &entry[K, V]{
   573  		node:  node[K, V]{isEntry: true},
   574  		key:   key,
   575  		value: value,
   576  	}
   577  }
   578  
   579  func (e *entry[K, V]) lookup(key K) (V, bool) {
   580  	for e != nil {
   581  		if e.key == key {
   582  			return e.value, true
   583  		}
   584  		e = e.overflow.Load()
   585  	}
   586  	return *new(V), false
   587  }
   588  
   589  func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
   590  	for e != nil {
   591  		if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
   592  			return e.value, true
   593  		}
   594  		e = e.overflow.Load()
   595  	}
   596  	return *new(V), false
   597  }
   598  
   599  
   600  
   601  
   602  
   603  func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
   604  	if head.key == key {
   605  		
   606  		e := newEntryNode(key, new)
   607  		if chain := head.overflow.Load(); chain != nil {
   608  			e.overflow.Store(chain)
   609  		}
   610  		return e, head.value, true
   611  	}
   612  	i := &head.overflow
   613  	e := i.Load()
   614  	for e != nil {
   615  		if e.key == key {
   616  			eNew := newEntryNode(key, new)
   617  			eNew.overflow.Store(e.overflow.Load())
   618  			i.Store(eNew)
   619  			return head, e.value, true
   620  		}
   621  		i = &e.overflow
   622  		e = e.overflow.Load()
   623  	}
   624  	var zero V
   625  	return head, zero, false
   626  }
   627  
   628  
   629  
   630  
   631  
   632  func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
   633  	if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
   634  		
   635  		e := newEntryNode(key, new)
   636  		if chain := head.overflow.Load(); chain != nil {
   637  			e.overflow.Store(chain)
   638  		}
   639  		return e, true
   640  	}
   641  	i := &head.overflow
   642  	e := i.Load()
   643  	for e != nil {
   644  		if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
   645  			eNew := newEntryNode(key, new)
   646  			eNew.overflow.Store(e.overflow.Load())
   647  			i.Store(eNew)
   648  			return head, true
   649  		}
   650  		i = &e.overflow
   651  		e = e.overflow.Load()
   652  	}
   653  	return head, false
   654  }
   655  
   656  
   657  
   658  
   659  
   660  func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
   661  	if head.key == key {
   662  		
   663  		return head.value, head.overflow.Load(), true
   664  	}
   665  	i := &head.overflow
   666  	e := i.Load()
   667  	for e != nil {
   668  		if e.key == key {
   669  			i.Store(e.overflow.Load())
   670  			return e.value, head, true
   671  		}
   672  		i = &e.overflow
   673  		e = e.overflow.Load()
   674  	}
   675  	return *new(V), head, false
   676  }
   677  
   678  
   679  
   680  
   681  
   682  func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) {
   683  	if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
   684  		
   685  		return head.overflow.Load(), true
   686  	}
   687  	i := &head.overflow
   688  	e := i.Load()
   689  	for e != nil {
   690  		if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
   691  			i.Store(e.overflow.Load())
   692  			return head, true
   693  		}
   694  		i = &e.overflow
   695  		e = e.overflow.Load()
   696  	}
   697  	return head, false
   698  }
   699  
   700  
   701  
// node is the common header embedded as the first field of both indirect and
// entry. Child slots store *node, and isEntry tells which of the two concrete
// types a given pointer really refers to.
type node[K comparable, V any] struct {
	// isEntry is true for entry nodes and false for indirect nodes.
	isEntry bool
}
   705  
   706  func (n *node[K, V]) entry() *entry[K, V] {
   707  	if !n.isEntry {
   708  		panic("called entry on non-entry node")
   709  	}
   710  	return (*entry[K, V])(unsafe.Pointer(n))
   711  }
   712  
   713  func (n *node[K, V]) indirect() *indirect[K, V] {
   714  	if n.isEntry {
   715  		panic("called indirect on entry node")
   716  	}
   717  	return (*indirect[K, V])(unsafe.Pointer(n))
   718  }
   719  
   720  
   721  
   722  
   723  
// runtime_rand returns a uint64; initSlow uses it to pick the map's hash
// seed. The declaration has no body, so the implementation is supplied from
// outside this file.
// NOTE(review): presumably provided by the runtime via linkname — confirm.
func runtime_rand() uint64
   725  
View as plain text