1
2
3
4
5 package main
6
7
8
9
10 import (
11 "bytes"
12 "fmt"
13 "go/format"
14 "io"
15 "log"
16 "os"
17 )
18
// allocator describes one pooled allocator for genAllocator to emit.
//
// name is the suffix of every generated identifier (poolFree<name>,
// alloc<name>, free<name>, c.hdr<name>). typ is the Go type handed out; a
// leading '*' selects the pointer code path, anything else is pooled behind
// a pointer.
//
// The remaining string fields are fmt templates instantiated with Go
// expressions:
//   - mak: builds a new value; its %s receives the capacity expression.
//   - capacity: yields the capacity of the value substituted for %s.
//   - resize: reslices the value (first %s) to length n (second %s); empty
//     means the value is returned without resizing.
//   - clear: resets the value substituted for %s before it is pooled.
//
// minLog and maxLog bound the power-of-two size classes: requests below
// 1<<minLog are rounded up to it, and maxLog-minLog pools are emitted.
type allocator struct {
	name, typ                    string
	mak, capacity, resize, clear string
	minLog, maxLog               int
}
29
// derived describes an allocator with no pools of its own; genDerived emits
// alloc<name>/free<name> methods that reinterpret the storage of the
// allocator whose name equals base. typ is the slice type handed out and
// must start with "[]".
type derived struct {
	name, typ, base string
}
35
36 func genAllocators() {
37 allocators := []allocator{
38 {
39 name: "ValueSlice",
40 typ: "[]*Value",
41 capacity: "cap(%s)",
42 mak: "make([]*Value, %s)",
43 resize: "%s[:%s]",
44 clear: "clear(%s)",
45 minLog: 5,
46 maxLog: 32,
47 },
48 {
49 name: "LimitSlice",
50 typ: "[]limit",
51 capacity: "cap(%s)",
52 mak: "make([]limit, %s)",
53 resize: "%s[:%s]",
54 clear: "clear(%s)",
55 minLog: 3,
56 maxLog: 30,
57 },
58 {
59 name: "SparseSet",
60 typ: "*sparseSet",
61 capacity: "%s.cap()",
62 mak: "newSparseSet(%s)",
63 resize: "",
64 clear: "%s.clear()",
65 minLog: 5,
66 maxLog: 32,
67 },
68 {
69 name: "SparseMap",
70 typ: "*sparseMap",
71 capacity: "%s.cap()",
72 mak: "newSparseMap(%s)",
73 resize: "",
74 clear: "%s.clear()",
75 minLog: 5,
76 maxLog: 32,
77 },
78 {
79 name: "SparseMapPos",
80 typ: "*sparseMapPos",
81 capacity: "%s.cap()",
82 mak: "newSparseMapPos(%s)",
83 resize: "",
84 clear: "%s.clear()",
85 minLog: 5,
86 maxLog: 32,
87 },
88 }
89 deriveds := []derived{
90 {
91 name: "BlockSlice",
92 typ: "[]*Block",
93 base: "ValueSlice",
94 },
95 {
96 name: "Int64",
97 typ: "[]int64",
98 base: "LimitSlice",
99 },
100 {
101 name: "IntSlice",
102 typ: "[]int",
103 base: "LimitSlice",
104 },
105 {
106 name: "Int32Slice",
107 typ: "[]int32",
108 base: "LimitSlice",
109 },
110 {
111 name: "Int8Slice",
112 typ: "[]int8",
113 base: "LimitSlice",
114 },
115 {
116 name: "BoolSlice",
117 typ: "[]bool",
118 base: "LimitSlice",
119 },
120 {
121 name: "IDSlice",
122 typ: "[]ID",
123 base: "LimitSlice",
124 },
125 {
126 name: "UintSlice",
127 typ: "[]uint",
128 base: "LimitSlice",
129 },
130 {
131 name: "KnownBitsEntriesSlice",
132 typ: "[]knownBitsEntry",
133 base: "LimitSlice",
134 },
135 }
136
137 w := new(bytes.Buffer)
138 fmt.Fprintf(w, "// Code generated from _gen/allocators.go using 'go generate'; DO NOT EDIT.\n")
139 fmt.Fprintln(w)
140 fmt.Fprintln(w, "package ssa")
141
142 fmt.Fprintln(w, "import (")
143 fmt.Fprintln(w, "\"internal/unsafeheader\"")
144 fmt.Fprintln(w, "\"math/bits\"")
145 fmt.Fprintln(w, "\"sync\"")
146 fmt.Fprintln(w, "\"unsafe\"")
147 fmt.Fprintln(w, ")")
148 for _, a := range allocators {
149 genAllocator(w, a)
150 }
151 for _, d := range deriveds {
152 for _, base := range allocators {
153 if base.name == d.base {
154 genDerived(w, d, base)
155 break
156 }
157 }
158 }
159
160 b := w.Bytes()
161 var err error
162 b, err = format.Source(b)
163 if err != nil {
164 fmt.Printf("%s\n", w.Bytes())
165 panic(err)
166 }
167
168 if err := os.WriteFile(outFile("allocators.go"), b, 0666); err != nil {
169 log.Fatalf("can't write output: %v\n", err)
170 }
171 }
172 func genAllocator(w io.Writer, a allocator) {
173 fmt.Fprintf(w, "var poolFree%s [%d]sync.Pool\n", a.name, a.maxLog-a.minLog)
174 fmt.Fprintf(w, "func (c *Cache) alloc%s(n int) %s {\n", a.name, a.typ)
175 fmt.Fprintf(w, "var s %s\n", a.typ)
176 fmt.Fprintf(w, "n2 := n\n")
177 fmt.Fprintf(w, "if n2 < %d { n2 = %d }\n", 1<<a.minLog, 1<<a.minLog)
178 fmt.Fprintf(w, "b := bits.Len(uint(n2-1))\n")
179 fmt.Fprintf(w, "v := poolFree%s[b-%d].Get()\n", a.name, a.minLog)
180 fmt.Fprintf(w, "if v == nil {\n")
181 fmt.Fprintf(w, " s = %s\n", fmt.Sprintf(a.mak, "1<<b"))
182 fmt.Fprintf(w, "} else {\n")
183 if a.typ[0] == '*' {
184 fmt.Fprintf(w, "s = v.(%s)\n", a.typ)
185 } else {
186 fmt.Fprintf(w, "sp := v.(*%s)\n", a.typ)
187 fmt.Fprintf(w, "s = *sp\n")
188 fmt.Fprintf(w, "*sp = nil\n")
189 fmt.Fprintf(w, "c.hdr%s = append(c.hdr%s, sp)\n", a.name, a.name)
190 }
191 fmt.Fprintf(w, "}\n")
192 if a.resize != "" {
193 fmt.Fprintf(w, "s = %s\n", fmt.Sprintf(a.resize, "s", "n"))
194 }
195 fmt.Fprintf(w, "return s\n")
196 fmt.Fprintf(w, "}\n")
197 fmt.Fprintf(w, "func (c *Cache) free%s(s %s) {\n", a.name, a.typ)
198 fmt.Fprintf(w, "%s\n", fmt.Sprintf(a.clear, "s"))
199 fmt.Fprintf(w, "b := bits.Len(uint(%s) - 1)\n", fmt.Sprintf(a.capacity, "s"))
200 if a.typ[0] == '*' {
201 fmt.Fprintf(w, "poolFree%s[b-%d].Put(s)\n", a.name, a.minLog)
202 } else {
203 fmt.Fprintf(w, "var sp *%s\n", a.typ)
204 fmt.Fprintf(w, "if len(c.hdr%s) == 0 {\n", a.name)
205 fmt.Fprintf(w, " sp = new(%s)\n", a.typ)
206 fmt.Fprintf(w, "} else {\n")
207 fmt.Fprintf(w, " sp = c.hdr%s[len(c.hdr%s)-1]\n", a.name, a.name)
208 fmt.Fprintf(w, " c.hdr%s[len(c.hdr%s)-1] = nil\n", a.name, a.name)
209 fmt.Fprintf(w, " c.hdr%s = c.hdr%s[:len(c.hdr%s)-1]\n", a.name, a.name, a.name)
210 fmt.Fprintf(w, "}\n")
211 fmt.Fprintf(w, "*sp = s\n")
212 fmt.Fprintf(w, "poolFree%s[b-%d].Put(sp)\n", a.name, a.minLog)
213 }
214 fmt.Fprintf(w, "}\n")
215 }
216 func genDerived(w io.Writer, d derived, base allocator) {
217 fmt.Fprintf(w, "func (c *Cache) alloc%s(n int) %s {\n", d.name, d.typ)
218 if d.typ[:2] != "[]" || base.typ[:2] != "[]" {
219 panic(fmt.Sprintf("bad derived types: %s %s", d.typ, base.typ))
220 }
221 fmt.Fprintf(w, "var base %s\n", base.typ[2:])
222 fmt.Fprintf(w, "var derived %s\n", d.typ[2:])
223 fmt.Fprintf(w, "if unsafe.Sizeof(base)%%unsafe.Sizeof(derived) != 0 { panic(\"bad\") }\n")
224 fmt.Fprintf(w, "scale := unsafe.Sizeof(base)/unsafe.Sizeof(derived)\n")
225 fmt.Fprintf(w, "b := c.alloc%s(int((uintptr(n)+scale-1)/scale))\n", base.name)
226 fmt.Fprintf(w, "s := unsafeheader.Slice {\n")
227 fmt.Fprintf(w, " Data: unsafe.Pointer(&b[0]),\n")
228 fmt.Fprintf(w, " Len: n,\n")
229 fmt.Fprintf(w, " Cap: cap(b)*int(scale),\n")
230 fmt.Fprintf(w, " }\n")
231 fmt.Fprintf(w, "return *(*%s)(unsafe.Pointer(&s))\n", d.typ)
232 fmt.Fprintf(w, "}\n")
233 fmt.Fprintf(w, "func (c *Cache) free%s(s %s) {\n", d.name, d.typ)
234 fmt.Fprintf(w, "var base %s\n", base.typ[2:])
235 fmt.Fprintf(w, "var derived %s\n", d.typ[2:])
236 fmt.Fprintf(w, "scale := unsafe.Sizeof(base)/unsafe.Sizeof(derived)\n")
237 fmt.Fprintf(w, "b := unsafeheader.Slice {\n")
238 fmt.Fprintf(w, " Data: unsafe.Pointer(&s[0]),\n")
239 fmt.Fprintf(w, " Len: int((uintptr(len(s))+scale-1)/scale),\n")
240 fmt.Fprintf(w, " Cap: int((uintptr(cap(s))+scale-1)/scale),\n")
241 fmt.Fprintf(w, " }\n")
242 fmt.Fprintf(w, "c.free%s(*(*%s)(unsafe.Pointer(&b)))\n", base.name, base.typ)
243 fmt.Fprintf(w, "}\n")
244 }
245
View as plain text