1
2
3
4
5 package concurrent
6
7 import (
8 "internal/abi"
9 "internal/goarch"
10 "math/rand/v2"
11 "sync"
12 "sync/atomic"
13 "unsafe"
14 )
15
16
17
18
19
// HashTrieMap is a concurrent hash-trie map keyed by K with values of type V.
// Reads (Load, Enumerate) are lock-free; mutations (LoadOrStore,
// CompareAndDelete) lock only the single indirect node being modified.
// The zero value is not usable; construct with NewHashTrieMap.
type HashTrieMap[K, V comparable] struct {
	root     *indirect[K, V] // root indirect node of the trie; set once at construction
	keyHash  hashFunc        // runtime hasher for K (from the map type descriptor)
	keyEqual equalFunc       // runtime equality for K
	valEqual equalFunc       // runtime equality for V; used by CompareAndDelete
	seed     uintptr         // per-map random hash seed, fixed at construction
}
27
28
29 func NewHashTrieMap[K, V comparable]() *HashTrieMap[K, V] {
30 var m map[K]V
31 mapType := abi.TypeOf(m).MapType()
32 ht := &HashTrieMap[K, V]{
33 root: newIndirectNode[K, V](nil),
34 keyHash: mapType.Hasher,
35 keyEqual: mapType.Key.Equal,
36 valEqual: mapType.Elem.Equal,
37 seed: uintptr(rand.Uint64()),
38 }
39 return ht
40 }
41
// hashFunc matches the runtime's key-hashing signature: it takes a pointer
// to the key and a seed, and returns the hash.
type hashFunc func(unsafe.Pointer, uintptr) uintptr

// equalFunc matches the runtime's equality signature: it reports whether
// the two pointed-to values are equal.
type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool
44
45
46
47
48 func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
49 hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
50
51 i := ht.root
52 hashShift := 8 * goarch.PtrSize
53 for hashShift != 0 {
54 hashShift -= nChildrenLog2
55
56 n := i.children[(hash>>hashShift)&nChildrenMask].Load()
57 if n == nil {
58 return *new(V), false
59 }
60 if n.isEntry {
61 return n.entry().lookup(key, ht.keyEqual)
62 }
63 i = n.indirect()
64 }
65 panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
66 }
67
68
69
70
// LoadOrStore returns the existing value for key if present. Otherwise, it
// stores and returns the given value. The loaded result is true if the value
// was loaded, false if stored.
func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Optimistically descend without locks to find the key or a
		// candidate slot for insertion.
		i = ht.root
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Nil slot: candidate insertion point.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// Found an entry node, which is as deep as the descent
				// can go. If the key matches, return it; otherwise this
				// entry may need to be expanded into a subtree.
				if v, ok := n.entry().lookup(key, ht.keyEqual); ok {
					return v, true
				}
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}

		// Grab the node's lock and re-check what we saw: the slot must
		// still be nil-or-entry and the node must not have been pruned.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// Still true; proceed with the insert while holding the lock.
			break
		}
		// The trie changed underneath us; unlock and retry from the root.
		i.mu.Unlock()
	}
	// N.B. i.mu is still held from the break out of the loop above; defer
	// the unlock here so every return path below releases it.
	defer i.mu.Unlock()

	var oldEntry *entry[K, V]
	if n != nil {
		oldEntry = n.entry()
		if v, ok := oldEntry.lookup(key, ht.keyEqual); ok {
			// The reload under the lock found exactly what we wanted.
			return v, true
		}
	}
	newEntry := newEntryNode(key, value)
	if oldEntry == nil {
		// Empty slot: just publish the new entry.
		slot.Store(&newEntry.node)
	} else {
		// Slot holds a different key: expand it into a subtree (or an
		// overflow chain on full hash collision) containing both entries.
		//
		// Publishing via a single Store makes oldEntry and newEntry visible
		// together; readers never observe oldEntry missing from the tree.
		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	return value, false
}
146
147
148
// expand takes oldEntry and newEntry, whose hashes agree on all bits above
// hashShift, and builds the subtree (rooted under parent) needed to hold
// both. It returns the node to publish into the conflicting slot.
// Must be called with parent's lock held (the new nodes are unpublished
// until the caller stores the result).
func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
	// Full hash collision: chain the old entry behind the new one.
	oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
	if oldHash == newHash {
		// newEntry becomes the head; oldEntry survives on its overflow list.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// Hashes differ somewhere below hashShift: add indirect nodes until the
	// two hashes pick different child slots.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2 // descend one level per iteration
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			// The hashes diverge at this level: place each entry in its slot.
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		// Still colliding at this level: add another indirect node below.
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}
179
180
181
182
183
// CompareAndDelete deletes the entry for key if its value equals old
// (compared with the runtime's equality for V). It reports whether an
// entry was deleted. After deleting, it prunes any indirect nodes left
// empty on the path back toward the root.
func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Optimistically descend without locks, looking for the key.
		i = ht.root
		hashShift = 8 * goarch.PtrSize
		found := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Nothing hashed here: nothing to delete.
				return
			}
			if n.isEntry {
				// Found an entry; check whether it actually holds key.
				if _, ok := n.entry().lookup(key, ht.keyEqual); !ok {
					// Different key: nothing to delete.
					return
				}
				// There is something to delete.
				found = true
				break
			}
			i = n.indirect()
		}
		if !found {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}

		// Grab the node's lock and re-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if !i.dead.Load() {
			if n == nil {
				// The entry vanished while we raced to the lock.
				i.mu.Unlock()
				return
			}
			if n.isEntry {
				// Still an entry; proceed with the delete under the lock.
				break
			}
		}
		// Node was pruned or the slot was expanded; retry from the root.
		i.mu.Unlock()
	}
	// Try to delete from the entry's chain (head or overflow).
	e, deleted := n.entry().compareAndDelete(key, old, ht.keyEqual, ht.valEqual)
	if !deleted {
		// Value didn't match (or key was removed concurrently); no change.
		i.mu.Unlock()
		return false
	}
	if e != nil {
		// Only one link of the chain was removed; publish the new head.
		// The parent cannot be empty, so no pruning is needed.
		slot.Store(&e.node)
		i.mu.Unlock()
		return true
	}
	// The whole entry is gone; clear the slot.
	slot.Store(nil)

	// Walk back toward the root, pruning indirect nodes that are now empty.
	// Lock ordering is child-then-parent; marking the child dead forces any
	// racing writer that locked it to retry.
	for i.parent != nil && i.empty() {
		if hashShift == 8*goarch.PtrSize {
			panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
		}
		hashShift += nChildrenLog2

		// Remove the current (dead) node from its parent.
		parent := i.parent
		parent.mu.Lock()
		i.dead.Store(true)
		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
		i.mu.Unlock()
		i = parent
	}
	i.mu.Unlock()
	return true
}
272
273
274
275
276
277
// Enumerate produces all key-value pairs in the map by calling yield for
// each one, stopping early if yield returns false. It reads the trie with
// atomic loads only, so it is safe to run concurrently with mutations;
// it does NOT represent a consistent snapshot, and no particular order
// is guaranteed.
func (ht *HashTrieMap[K, V]) Enumerate(yield func(key K, value V) bool) {
	ht.iter(ht.root, yield)
}
281
282 func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
283 for j := range i.children {
284 n := i.children[j].Load()
285 if n == nil {
286 continue
287 }
288 if !n.isEntry {
289 if !ht.iter(n.indirect(), yield) {
290 return false
291 }
292 continue
293 }
294 e := n.entry()
295 for e != nil {
296 if !yield(e.key, e.value) {
297 return false
298 }
299 e = e.overflow.Load()
300 }
301 }
302 return true
303 }
304
const (
	// Each indirect node has 16 children, so each level of the trie
	// consumes 4 bits of the hash. NOTE(review): the original comments
	// were stripped; 16 presumably balances load performance against
	// node size — confirm against upstream before changing.
	nChildrenLog2 = 4
	nChildren     = 1 << nChildrenLog2
	nChildrenMask = nChildren - 1
)
314
315
// indirect is an internal (non-leaf) node of the hash-trie.
type indirect[K, V comparable] struct {
	node[K, V]                  // embedded header; isEntry is always false
	dead       atomic.Bool      // set (under mu) when this node is unlinked from its parent
	mu         sync.Mutex       // protects mutation of children and of entry-node children
	parent     *indirect[K, V]  // nil only for the root
	children   [nChildren]atomic.Pointer[node[K, V]] // child slots, indexed by hash bits
}
323
324 func newIndirectNode[K, V comparable](parent *indirect[K, V]) *indirect[K, V] {
325 return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
326 }
327
328 func (i *indirect[K, V]) empty() bool {
329 nc := 0
330 for j := range i.children {
331 if i.children[j].Load() != nil {
332 nc++
333 }
334 }
335 return nc == 0
336 }
337
338
// entry is a leaf node of the hash-trie holding one key-value pair.
type entry[K, V comparable] struct {
	node[K, V]                             // embedded header; isEntry is always true
	overflow   atomic.Pointer[entry[K, V]] // chain of entries whose keys fully collide in hash
	key        K
	value      V
}
345
346 func newEntryNode[K, V comparable](key K, value V) *entry[K, V] {
347 return &entry[K, V]{
348 node: node[K, V]{isEntry: true},
349 key: key,
350 value: value,
351 }
352 }
353
354 func (e *entry[K, V]) lookup(key K, equal equalFunc) (V, bool) {
355 for e != nil {
356 if equal(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) {
357 return e.value, true
358 }
359 e = e.overflow.Load()
360 }
361 return *new(V), false
362 }
363
364
365
366
367
// compareAndDelete removes the entry in head's chain whose key and value both
// compare equal to key and value. It returns the new head of the chain (nil
// if the chain is now empty) and whether anything was deleted.
//
// Must be called with the lock of the indirect node that owns head held,
// since it mutates overflow pointers in place.
func (head *entry[K, V]) compareAndDelete(key K, value V, keyEqual, valEqual equalFunc) (*entry[K, V], bool) {
	if keyEqual(unsafe.Pointer(&head.key), abi.NoEscape(unsafe.Pointer(&key))) &&
		valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
		// The head itself matches: drop it by promoting its overflow.
		return head.overflow.Load(), true
	}
	// Walk the overflow chain keeping i as the link that points at e,
	// so a match can be spliced out in place.
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if keyEqual(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) &&
			valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
			// Unlink e; the head is unchanged.
			i.Store(e.overflow.Load())
			return head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	// No link matched both key and value.
	return head, false
}
387
388
389
// node is the common header embedded in both indirect and entry nodes;
// isEntry discriminates which concrete type a *node actually heads.
type node[K, V comparable] struct {
	isEntry bool
}
393
394 func (n *node[K, V]) entry() *entry[K, V] {
395 if !n.isEntry {
396 panic("called entry on non-entry node")
397 }
398 return (*entry[K, V])(unsafe.Pointer(n))
399 }
400
401 func (n *node[K, V]) indirect() *indirect[K, V] {
402 if n.isEntry {
403 panic("called indirect on entry node")
404 }
405 return (*indirect[K, V])(unsafe.Pointer(n))
406 }
407
View as plain text