5 package sync
6
7 import (
8 "internal/abi"
9 "internal/goarch"
10 "sync/atomic"
11 "unsafe"
12 )
13
14
15
16
17
18
19
20
// HashTrieMap is a concurrent hash-trie map from keys of type K to values of
// type V. The zero value is ready to use: every exported operation calls init,
// which lazily sets up the root node, hash function, value-equality function,
// and seed.
type HashTrieMap[K comparable, V any] struct {
	inited   atomic.Uint32                  // 0 until initSlow has run; checked lock-free by init.
	initMu   Mutex                          // serializes initSlow.
	root     atomic.Pointer[indirect[K, V]] // root indirect node of the trie; never nil after init.
	keyHash  hashFunc                       // hash function for K, taken from the runtime map[K]V type.
	valEqual equalFunc                      // equality for V; nil when V is not comparable (see CompareAndSwap).
	seed     uintptr                        // per-map hash seed from runtime_rand.
}
29
30 func (ht *HashTrieMap[K, V]) init() {
31 if ht.inited.Load() == 0 {
32 ht.initSlow()
33 }
34 }
35
36
// initSlow performs one-time initialization under initMu. It is safe to call
// from multiple goroutines racing through init; losers observe inited != 0
// under the lock and return immediately.
func (ht *HashTrieMap[K, V]) initSlow() {
	ht.initMu.Lock()
	defer ht.initMu.Unlock()

	if ht.inited.Load() != 0 {
		// Someone else initialized the map while we waited for the lock.
		return
	}

	// Derive the key hash function and (if V is comparable) the value
	// equality function from the runtime's representation of map[K]V,
	// and install a fresh empty root node.
	var m map[K]V
	mapType := abi.TypeOf(m).MapType()
	ht.root.Store(newIndirectNode[K, V](nil))
	ht.keyHash = mapType.Hasher
	ht.valEqual = mapType.Elem.Equal
	ht.seed = uintptr(runtime_rand())

	// Publish initialization last: init reads this flag without the lock.
	ht.inited.Store(1)
}
57
// hashFunc hashes the pointed-to key with the given seed.
type hashFunc func(unsafe.Pointer, uintptr) uintptr

// equalFunc reports whether the two pointed-to values are equal.
type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool
60
61
62
63
// Load returns the value stored in the map for a key, or the zero value if
// no value is present. The ok result indicates whether the value was found.
// The walk is entirely lock-free: it only reads atomic child pointers.
func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	i := ht.root.Load()
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2

		// Consume the next nChildrenLog2 bits of the hash, from the
		// most-significant end, to pick the child slot.
		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
		if n == nil {
			return *new(V), false
		}
		if n.isEntry {
			// Reached an entry chain; scan it for the key.
			return n.entry().lookup(key)
		}
		i = n.indirect()
	}
	panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
}
84
85
86
87
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Optimistically (lock-free) find the key or a candidate insert point.
		i = ht.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Empty slot: candidate for insertion.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// Found an entry chain. If the key is already there, we're done;
				// otherwise this slot is where we'd insert (possibly by
				// expanding the chain into indirect nodes).
				if v, ok := n.entry().lookup(key); ok {
					return v, true
				}
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}

		// Take the node's lock and re-check what we saw, since it may have
		// changed between the optimistic walk and acquiring the lock.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// Still true under the lock; proceed with the insert.
			break
		}
		// The slot became an indirect node, or this node was pruned (dead);
		// retry from the root.
		i.mu.Unlock()
	}

	// i.mu is held here; the deferred unlock covers every return path below.
	defer i.mu.Unlock()

	var oldEntry *entry[K, V]
	if n != nil {
		oldEntry = n.entry()
		if v, ok := oldEntry.lookup(key); ok {
			// The key appeared while we were acquiring the lock.
			return v, true
		}
	}
	newEntry := newEntryNode(key, value)
	if oldEntry == nil {
		// Easy case: store the new entry into the empty slot.
		slot.Store(&newEntry.node)
	} else {
		// Expand the existing entry and the new one into a subtree (or an
		// overflow chain on full hash collision). Publish via the slot last,
		// so readers never observe oldEntry missing from the tree.
		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	return value, false
}
164
165
166
// expand takes oldEntry and newEntry, whose hashes agree on all bits above
// hashShift, and builds the subtree needed to hold both. newHash is the hash
// of newEntry's key. Returns the node to publish in oldEntry's former slot.
// Must be called with parent's lock held (the caller publishes the result).
func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
	// Full hash collision: chain instead of expanding.
	oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
	if oldHash == newHash {
		// Put the old entry on the new entry's overflow list and return the
		// new entry as the chain head.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// The hashes differ somewhere below hashShift; add indirect levels until
	// the two entries land in different child slots.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			// The hashes diverge at this level; each entry gets its own slot.
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		// Still colliding at this level; add another layer of indirection.
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}
197
198
199 func (ht *HashTrieMap[K, V]) Store(key K, old V) {
200 _, _ = ht.Swap(key, old)
201 }
202
203
204
// Swap swaps the value for a key and returns the previous value if any.
// The loaded result reports whether the key was present.
func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Optimistically (lock-free) find the key or a candidate insert point.
		i = ht.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Empty slot: candidate for insertion.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// Found an entry chain; try to swap in place.
				old, swapped := n.entry().swap(key, new)
				if swapped {
					return old, true
				}
				// Key not in the chain; we'll have to insert here.
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}

		// Take the node's lock and re-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// Still true under the lock; proceed with the insert.
			break
		}
		// The slot became an indirect node, or this node was pruned; retry.
		i.mu.Unlock()
	}

	// i.mu is held here; the deferred unlock covers every return path below.
	defer i.mu.Unlock()

	var zero V
	var oldEntry *entry[K, V]
	if n != nil {
		// An entry may have appeared (or changed) between the optimistic walk
		// and taking the lock; retry the swap under the lock.
		oldEntry = n.entry()
		old, swapped := oldEntry.swap(key, new)
		if swapped {
			return old, true
		}
	}
	// No existing entry for the key; insert a fresh one.
	newEntry := newEntryNode(key, new)
	if oldEntry == nil {
		// Easy case: store the new entry into the empty slot.
		slot.Store(&newEntry.node)
	} else {
		// Expand the existing entry and the new one into a subtree (or an
		// overflow chain on full hash collision). Publish via the slot last,
		// so readers never observe oldEntry missing from the tree.
		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	return zero, false
}
285
286
287
288
// CompareAndSwap swaps the old and new values for key if the value stored in
// the map is equal to old. Panics if V is not a comparable type (in which
// case initSlow left valEqual nil).
func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
	ht.init()
	if ht.valEqual == nil {
		panic("called CompareAndSwap when value is not of comparable type")
	}
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	for {
		// Find the key, or return if it's not there.
		i := ht.root.Load()
		hashShift := 8 * goarch.PtrSize
		found := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot := &i.children[(hash>>hashShift)&nChildrenMask]
			n := slot.Load()
			if n == nil {
				// Not found.
				return false
			}
			if n.isEntry {
				// Found an entry chain; delegate to its lock-free CAS.
				return n.entry().compareAndSwap(key, old, new, ht.valEqual)
			}
			i = n.indirect()
		}
		// NOTE(review): found is never set to true, so falling out of the walk
		// always panics here — the enclosing for loop can never repeat. The
		// loop and flag appear vestigial.
		if !found {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}
	}
}
320
321
322
// LoadAndDelete deletes the value for a key, returning the previous value if
// any. The loaded result reports whether the key was present.
func (ht *HashTrieMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
	ht.init()
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	// Find the entry chain for the key. On success, find returns with i.mu
	// held; we are responsible for unlocking it on every path below.
	i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
	if n == nil {
		// Key not present. i may still be non-nil (and locked) if the slot
		// emptied between find's walk and its lock acquisition.
		if i != nil {
			i.mu.Unlock()
		}
		return *new(V), false
	}

	// Remove the key from the entry chain.
	v, e, loaded := n.entry().loadAndDelete(key)
	if !loaded {
		// The key vanished from the chain before we got the lock.
		i.mu.Unlock()
		return *new(V), false
	}
	if e != nil {
		// The chain still has entries; publish the remaining chain head.
		slot.Store(&e.node)
		i.mu.Unlock()
		return v, true
	}
	// The chain is now empty; clear the slot.
	slot.Store(nil)

	// Walk up the tree, pruning indirect nodes that became empty. The root
	// (parent == nil) is never pruned.
	for i.parent != nil && i.empty() {
		if hashShift == 8*goarch.PtrSize {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}
		hashShift += nChildrenLog2

		// Lock the parent, mark this node dead, then unlink it. dead is set
		// while both locks are held; inserters check dead under the node's
		// lock, so they cannot insert into a pruned node.
		parent := i.parent
		parent.mu.Lock()
		i.dead.Store(true)
		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
		i.mu.Unlock()
		i = parent
	}
	i.mu.Unlock()
	return v, true
}
371
372
373 func (ht *HashTrieMap[K, V]) Delete(key K) {
374 _, _ = ht.LoadAndDelete(key)
375 }
376
377
378
379
380
381
// CompareAndDelete deletes the entry for key if its value is equal to old.
// Panics if V is not a comparable type (in which case initSlow left valEqual
// nil). Returns whether a deletion happened.
func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
	ht.init()
	if ht.valEqual == nil {
		panic("called CompareAndDelete when value is not of comparable type")
	}
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	// Find the entry chain for the key. On success, find returns with i.mu
	// held; we are responsible for unlocking it on every path below.
	i, hashShift, slot, n := ht.find(key, hash, nil, *new(V))
	if n == nil {
		// Key not present. i may still be non-nil (and locked) if the slot
		// emptied between find's walk and its lock acquisition.
		if i != nil {
			i.mu.Unlock()
		}
		return false
	}

	// Remove the key from the chain, but only if its value matches old.
	e, deleted := n.entry().compareAndDelete(key, old, ht.valEqual)
	if !deleted {
		// No matching key/value pair in the chain.
		i.mu.Unlock()
		return false
	}
	if e != nil {
		// The chain still has entries; publish the remaining chain head.
		slot.Store(&e.node)
		i.mu.Unlock()
		return true
	}
	// The chain is now empty; clear the slot.
	slot.Store(nil)

	// Walk up the tree, pruning indirect nodes that became empty. The root
	// (parent == nil) is never pruned.
	for i.parent != nil && i.empty() {
		if hashShift == 8*goarch.PtrSize {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}
		hashShift += nChildrenLog2

		// Lock the parent, mark this node dead, then unlink it; see the
		// matching loop in LoadAndDelete for the locking rationale.
		parent := i.parent
		parent.mu.Lock()
		i.dead.Store(true)
		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
		i.mu.Unlock()
		i = parent
	}
	i.mu.Unlock()
	return true
}
433
434
435
436
437
438
439
// find locates the entry chain for key (hash must be key's hash). If valEqual
// is non-nil, the chain must also contain an entry whose value equals value
// (via lookupWithValue); passing valEqual == nil matches on key alone.
//
// If the key is not found during the optimistic walk, find returns i == nil
// holding no locks. Otherwise it returns with i.mu LOCKED — the caller must
// unlock it — where slot is the child slot of i that held the chain, and n is
// the chain head re-loaded under the lock (n may be nil if the slot emptied
// between the walk and the lock acquisition).
func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value V) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) {
	for {
		// Optimistically (lock-free) walk toward the key.
		i = ht.root.Load()
		hashShift = 8 * goarch.PtrSize
		found := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Not found.
				i = nil
				return
			}
			if n.isEntry {
				// Found an entry chain; check it matches key (and value).
				if _, ok := n.entry().lookupWithValue(key, value, valEqual); !ok {
					// No match; report not-found.
					i = nil
					n = nil
					return
				}
				found = true
				break
			}
			i = n.indirect()
		}
		if !found {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}

		// Take the node's lock and re-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if !i.dead.Load() && (n == nil || n.isEntry) {
			// Still a live node holding an entry (or now-empty) slot; return
			// with the lock held so the caller's operation is serialized.
			return
		}
		// The slot became an indirect node, or this node was pruned; retry.
		i.mu.Unlock()
	}
}
486
487
488
489
490
491
492
493
494
// All returns an iterator function over each key and value present in the map.
//
// Iteration takes no locks (see iter), so under concurrent modification it
// need not correspond to any consistent snapshot of the map's contents. Note
// that the root is loaded when the returned function is invoked, not when All
// itself is called.
func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) {
	ht.init()
	return func(yield func(key K, value V) bool) {
		ht.iter(ht.root.Load(), yield)
	}
}
501
502
503
504
505
506
// Range calls yield sequentially for each key and value present in the map,
// stopping early if yield returns false. Like All, it takes no locks and so
// is not a consistent snapshot under concurrent modification.
func (ht *HashTrieMap[K, V]) Range(yield func(K, V) bool) {
	ht.init()
	ht.iter(ht.root.Load(), yield)
}
511
512 func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
513 for j := range i.children {
514 n := i.children[j].Load()
515 if n == nil {
516 continue
517 }
518 if !n.isEntry {
519 if !ht.iter(n.indirect(), yield) {
520 return false
521 }
522 continue
523 }
524 e := n.entry()
525 for e != nil {
526 if !yield(e.key, *e.value.Load()) {
527 return false
528 }
529 e = e.overflow.Load()
530 }
531 }
532 return true
533 }
534
535
536 func (ht *HashTrieMap[K, V]) Clear() {
537 ht.init()
538
539
540
541 ht.root.Store(newIndirectNode[K, V](nil))
542 }
543
const (
	// nChildrenLog2 is the log2 of the fan-out of each indirect node: each
	// level of the trie consumes this many bits of the key's hash.
	nChildrenLog2 = 4
	// nChildren is the number of child slots per indirect node.
	nChildren = 1 << nChildrenLog2
	// nChildrenMask extracts a child index from a shifted hash.
	nChildrenMask = nChildren - 1
)
553
554
// indirect is an internal (non-leaf) node in the hash-trie.
type indirect[K comparable, V any] struct {
	node[K, V]                                     // embedded header; isEntry is false.
	dead     atomic.Bool                           // set under mu when this node is unlinked from its parent.
	mu       Mutex                                 // guards mutation of children (inserts and pruning).
	parent   *indirect[K, V]                       // nil only for the root node.
	children [nChildren]atomic.Pointer[node[K, V]] // child slots, indexed by hash bits.
}
562
563 func newIndirectNode[K comparable, V any](parent *indirect[K, V]) *indirect[K, V] {
564 return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
565 }
566
567 func (i *indirect[K, V]) empty() bool {
568 nc := 0
569 for j := range i.children {
570 if i.children[j].Load() != nil {
571 nc++
572 }
573 }
574 return nc == 0
575 }
576
577
// entry is a leaf node holding one key/value pair. Entries whose keys hash
// identically at every level are chained through overflow.
type entry[K comparable, V any] struct {
	node[K, V]                             // embedded header; isEntry is true.
	overflow atomic.Pointer[entry[K, V]]   // next entry with a fully colliding hash, if any.
	key      K
	value    atomic.Pointer[V]             // boxed so the value can be swapped atomically.
}
584
585 func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] {
586 e := &entry[K, V]{
587 node: node[K, V]{isEntry: true},
588 key: key,
589 }
590 e.value.Store(&value)
591 return e
592 }
593
594 func (e *entry[K, V]) lookup(key K) (V, bool) {
595 for e != nil {
596 if e.key == key {
597 return *e.value.Load(), true
598 }
599 e = e.overflow.Load()
600 }
601 return *new(V), false
602 }
603
604 func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
605 for e != nil {
606 oldp := e.value.Load()
607 if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(oldp), abi.NoEscape(unsafe.Pointer(&value)))) {
608 return *oldp, true
609 }
610 e = e.overflow.Load()
611 }
612 return *new(V), false
613 }
614
615
616
617
618
619 func (head *entry[K, V]) swap(key K, newv V) (V, bool) {
620 if head.key == key {
621 vp := new(V)
622 *vp = newv
623 oldp := head.value.Swap(vp)
624 return *oldp, true
625 }
626 i := &head.overflow
627 e := i.Load()
628 for e != nil {
629 if e.key == key {
630 vp := new(V)
631 *vp = newv
632 oldp := e.value.Swap(vp)
633 return *oldp, true
634 }
635 i = &e.overflow
636 e = e.overflow.Load()
637 }
638 var zero V
639 return zero, false
640 }
641
642
643
644
645
// compareAndSwap replaces the value for key with newv in the chain headed by
// head, but only if the current value compares equal to oldv under valEqual.
// Returns whether a swap happened.
//
// Unlike swap, the caller in this file (HashTrieMap.CompareAndSwap) invokes
// this without holding the node lock: the value pointer is replaced with an
// atomic CompareAndSwap, and the entire scan retries (outerLoop) whenever
// another writer raced in between the Load and the CompareAndSwap.
func (head *entry[K, V]) compareAndSwap(key K, oldv, newv V, valEqual equalFunc) bool {
	// vbox caches the boxed new value across retries, so repeated trips
	// around outerLoop don't allocate again.
	var vbox *V
outerLoop:
	for {
		oldvp := head.value.Load()
		if head.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
			// Head matches; try to install the new value.
			if vbox == nil {
				// Box newv lazily: only allocate once a matching entry has
				// actually been found.
				vbox = new(V)
				*vbox = newv
			}
			if head.value.CompareAndSwap(oldvp, vbox) {
				return true
			}
			// Value changed between Load and CAS; re-check from the top.
			continue outerLoop
		}
		i := &head.overflow
		e := i.Load()
		for e != nil {
			oldvp := e.value.Load()
			if e.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
				if vbox == nil {
					// Same lazy boxing as in the head case.
					vbox = new(V)
					*vbox = newv
				}
				if e.value.CompareAndSwap(oldvp, vbox) {
					return true
				}
				// Lost the race on this entry; restart the whole scan.
				continue outerLoop
			}
			i = &e.overflow
			e = e.overflow.Load()
		}
		return false
	}
}
688
689
690
691
692
// loadAndDelete removes key from the chain headed by head.
//
// Returns the deleted value (if any), the new head of the chain (nil when
// the chain becomes empty), and whether a deletion happened.
//
// Callers in this file invoke this while holding the lock of the indirect
// node the chain hangs off of, so the chain links are stable here.
func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
	if head.key == key {
		// Drop the head of the list: its overflow (possibly nil) becomes the
		// new chain head.
		return *head.value.Load(), head.overflow.Load(), true
	}
	// i trails the link that points at e, so a matching entry can be unlinked
	// by storing its successor through i.
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if e.key == key {
			i.Store(e.overflow.Load())
			return *e.value.Load(), head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	return *new(V), head, false
}
710
711
712
713
714
// compareAndDelete removes the entry for key from the chain headed by head,
// but only if its current value compares equal to value under valEqual.
//
// Returns the new head of the chain (nil when the chain becomes empty) and
// whether a deletion happened. Like loadAndDelete, callers invoke this while
// holding the enclosing indirect node's lock.
func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) {
	if head.key == key && valEqual(unsafe.Pointer(head.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
		// Drop the head of the list.
		return head.overflow.Load(), true
	}
	// i trails the link that points at e, so a matching entry can be unlinked
	// by storing its successor through i.
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if e.key == key && valEqual(unsafe.Pointer(e.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) {
			i.Store(e.overflow.Load())
			return head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	return head, false
}
732
733
734
// node is the header embedded in both node kinds in the trie. A *node is
// really either an *entry or an *indirect; isEntry discriminates, and the
// entry/indirect methods perform the corresponding unsafe downcasts.
type node[K comparable, V any] struct {
	isEntry bool
}
738
739 func (n *node[K, V]) entry() *entry[K, V] {
740 if !n.isEntry {
741 panic("called entry on non-entry node")
742 }
743 return (*entry[K, V])(unsafe.Pointer(n))
744 }
745
746 func (n *node[K, V]) indirect() *indirect[K, V] {
747 if n.isEntry {
748 panic("called indirect on entry node")
749 }
750 return (*indirect[K, V])(unsafe.Pointer(n))
751 }
752
753
754
755
756
757 func runtime_rand() uint64
758