1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/token"
15 "log"
16 "os"
17 "strings"
18
19 "golang.org/x/tools/go/ast/astutil"
20
21 internalastutil "runtime/_mkmalloc/astutil"
22 )
23
// stdout, when set with -stdout, makes main print only the generated
// size-class source to standard output and return without writing any of the
// generated files. Like every flag it is only read once flag.Parse has run.
var stdout = flag.Bool("stdout", false, "write sizeclasses source to stdout instead of writing the generated files")
25
26 func makeSizeToSizeClass(classes []class) []uint8 {
27 sc := uint8(0)
28 ret := make([]uint8, smallScanNoHeaderMax+1)
29 for i := range ret {
30 if i > classes[sc].size {
31 sc++
32 }
33 ret[i] = sc
34 }
35 return ret
36 }
37
38 func main() {
39 log.SetFlags(0)
40 log.SetPrefix("mkmalloc: ")
41
42 classes := makeClasses()
43 sizeToSizeClass := makeSizeToSizeClass(classes)
44
45 if *stdout {
46 if _, err := os.Stdout.Write(mustFormat(generateSizeClasses(classes))); err != nil {
47 log.Fatal(err)
48 }
49 return
50 }
51
52 sizeclasesesfile := "../../internal/runtime/gc/sizeclasses.go"
53 if err := os.WriteFile(sizeclasesesfile, mustFormat(generateSizeClasses(classes)), 0666); err != nil {
54 log.Fatal(err)
55 }
56
57 outfile := "../malloc_generated.go"
58 if err := os.WriteFile(outfile, mustFormat(inline(specializedMallocConfig(classes, sizeToSizeClass))), 0666); err != nil {
59 log.Fatal(err)
60 }
61
62 tablefile := "../malloc_tables_generated.go"
63 if err := os.WriteFile(tablefile, mustFormat(generateTable(sizeToSizeClass)), 0666); err != nil {
64 log.Fatal(err)
65 }
66 }
67
68
// withLineNumbers returns a copy of b in which every line is prefixed with its
// 1-based line number and ": ". It makes errors about generated source easier
// to locate.
func withLineNumbers(b []byte) []byte {
	var numbered bytes.Buffer
	// SplitAfter keeps each "\n" attached to its line. When b ends in a
	// newline, or is empty, it also yields one trailing empty chunk, which is
	// skipped here.
	for idx, line := range bytes.SplitAfter(b, []byte("\n")) {
		if len(line) == 0 {
			continue
		}
		fmt.Fprintf(&numbered, "%d: %s", idx+1, line)
	}
	return numbered.Bytes()
}
78
79
80 func mustFormat(b []byte) []byte {
81 formatted, err := format.Source(b)
82 if err != nil {
83 log.Fatalf("error formatting source: %v\nsource:\n%s\n", err, withLineNumbers(b))
84 }
85 return formatted
86 }
87
88
89
// generatorConfig describes one generated output: the template file to parse
// and the functions to stamp out of it.
type generatorConfig struct {
	file  string // path of the Go file holding the template functions
	specs []spec // functions to generate, emitted in this order
}

// spec describes one generated function.
type spec struct {
	name         string // name given to the generated copy
	templateFunc string // function in the template file that is copied
	ops          []op   // rewrites applied to the copy, in order
}

// replacementKind identifies the kind of rewrite an op performs.
type replacementKind int

const (
	inlineFunc  replacementKind = iota // inline calls to op.from using the body of the function op.to
	subBasicLit                        // replace identifier op.from with the basic literal op.to
)

// op is a single rewrite step applied to a template copy. Its fields are
// written positionally elsewhere, so their order matters.
type op struct {
	kind replacementKind
	from string
	to   string
}
120
// smallScanNoHeaderSCFuncName returns the name of the generated small scan,
// no-header allocator for size class sc. Such functions exist only for
// classes 1 through scMax; every other class maps to "mallocPanic".
func smallScanNoHeaderSCFuncName(sc, scMax uint8) string {
	if sc >= 1 && sc <= scMax {
		return fmt.Sprintf("mallocgcSmallScanNoHeaderSC%d", sc)
	}
	return "mallocPanic"
}
127
// tinyFuncName returns the name of the generated tiny allocator for an
// allocation of exactly size bytes. It returns "mallocPanic" when no such
// function is generated.
//
// Tiny functions are stamped only for sizes 1 through 15: specializedMallocConfig
// loops over uintptr(16) and skips 0. The previous bound (smallScanNoHeaderMax)
// would have produced names such as "mallocTiny100" that are never generated.
func tinyFuncName(size uintptr) string {
	// Exclusive upper bound on sizes that have a mallocTiny function. Keep it
	// in sync with the tiny loop in specializedMallocConfig.
	const tinyFuncLimit = 16
	if size == 0 || size >= tinyFuncLimit {
		return "mallocPanic"
	}
	return fmt.Sprintf("mallocTiny%d", size)
}
134
// smallNoScanSCFuncName returns the name of the generated small noscan
// allocator for size class sc. Such functions exist only for classes 2
// through scMax; every other class maps to "mallocPanic".
func smallNoScanSCFuncName(sc, scMax uint8) string {
	switch {
	case sc >= 2 && sc <= scMax:
		return fmt.Sprintf("mallocgcSmallNoScanSC%d", sc)
	default:
		return "mallocPanic"
	}
}
141
142
143
144 func specializedMallocConfig(classes []class, sizeToSizeClass []uint8) generatorConfig {
145 config := generatorConfig{file: "../malloc_stubs.go"}
146
147
148
149
150 scMax := sizeToSizeClass[smallScanNoHeaderMax]
151
152 str := fmt.Sprint
153
154
155 {
156 const noscan = 0
157 for sc := uint8(0); sc <= scMax; sc++ {
158 if sc == 0 {
159 continue
160 }
161 name := smallScanNoHeaderSCFuncName(sc, scMax)
162 elemsize := classes[sc].size
163 config.specs = append(config.specs, spec{
164 templateFunc: "mallocStub",
165 name: name,
166 ops: []op{
167 {inlineFunc, "inlinedMalloc", "smallScanNoHeaderStub"},
168 {inlineFunc, "heapSetTypeNoHeaderStub", "heapSetTypeNoHeaderStub"},
169 {inlineFunc, "nextFreeFastStub", "nextFreeFastStub"},
170 {inlineFunc, "writeHeapBitsSmallStub", "writeHeapBitsSmallStub"},
171 {subBasicLit, "elemsize_", str(elemsize)},
172 {subBasicLit, "sizeclass_", str(sc)},
173 {subBasicLit, "noscanint_", str(noscan)},
174 },
175 })
176 }
177 }
178
179
180 {
181 const noscan = 1
182
183
184 tinySizeClass := sizeToSizeClass[tinySize]
185 for s := range uintptr(16) {
186 if s == 0 {
187 continue
188 }
189 name := tinyFuncName(s)
190 elemsize := classes[tinySizeClass].size
191 config.specs = append(config.specs, spec{
192 templateFunc: "mallocStub",
193 name: name,
194 ops: []op{
195 {inlineFunc, "inlinedMalloc", "tinyStub"},
196 {inlineFunc, "nextFreeFastTiny", "nextFreeFastTiny"},
197 {subBasicLit, "elemsize_", str(elemsize)},
198 {subBasicLit, "sizeclass_", str(tinySizeClass)},
199 {subBasicLit, "size_", str(s)},
200 {subBasicLit, "noscanint_", str(noscan)},
201 },
202 })
203 }
204
205
206 for sc := uint8(tinySizeClass); sc <= scMax; sc++ {
207 name := smallNoScanSCFuncName(sc, scMax)
208 elemsize := classes[sc].size
209 config.specs = append(config.specs, spec{
210 templateFunc: "mallocStub",
211 name: name,
212 ops: []op{
213 {inlineFunc, "inlinedMalloc", "smallNoScanStub"},
214 {inlineFunc, "nextFreeFastStub", "nextFreeFastStub"},
215 {subBasicLit, "elemsize_", str(elemsize)},
216 {subBasicLit, "sizeclass_", str(sc)},
217 {subBasicLit, "noscanint_", str(noscan)},
218 },
219 })
220 }
221 }
222
223 return config
224 }
225
226
227 func inline(config generatorConfig) []byte {
228 var out bytes.Buffer
229
230
231 fset := token.NewFileSet()
232 f, err := parser.ParseFile(fset, config.file, nil, 0)
233 if err != nil {
234 log.Fatalf("parsing %s: %v", config.file, err)
235 }
236
237
238
239
240
241
242 funcDecls := map[string]*ast.FuncDecl{}
243 importDecls := []*ast.GenDecl{}
244 for _, decl := range f.Decls {
245 switch decl := decl.(type) {
246 case *ast.FuncDecl:
247 funcDecls[decl.Name.Name] = decl
248 case *ast.GenDecl:
249 if decl.Tok.String() == "import" {
250 importDecls = append(importDecls, decl)
251 continue
252 }
253 }
254 }
255
256
257 out.WriteString("// Code generated by mkmalloc.go; DO NOT EDIT.\n\n")
258 out.WriteString("package " + f.Name.Name + "\n\n")
259 for _, importDecl := range importDecls {
260 out.Write(mustFormatNode(fset, importDecl))
261 out.WriteString("\n\n")
262 }
263
264
265 for _, spec := range config.specs {
266
267 containingFuncCopy := internalastutil.CloneNode(funcDecls[spec.templateFunc])
268 if containingFuncCopy == nil {
269 log.Fatal("did not find", spec.templateFunc)
270 }
271 containingFuncCopy.Name.Name = spec.name
272
273
274 stamped := ast.Node(containingFuncCopy)
275 for _, repl := range spec.ops {
276 if toDecl, ok := funcDecls[repl.to]; ok {
277 stamped = inlineFunction(stamped, repl.from, toDecl)
278 } else {
279 stamped = substituteWithBasicLit(stamped, repl.from, repl.to)
280 }
281 }
282
283 out.Write(mustFormatNode(fset, stamped))
284 out.WriteString("\n\n")
285 }
286
287 return out.Bytes()
288 }
289
290
291
292 func substituteWithBasicLit(node ast.Node, from, to string) ast.Node {
293
294 toExpr, err := parser.ParseExpr(to)
295 if err != nil {
296 log.Fatalf("parsing expr %q: %v", to, err)
297 }
298 if _, ok := toExpr.(*ast.BasicLit); !ok {
299 log.Fatalf("op 'to' expr %q is not a basic literal", to)
300 }
301 return astutil.Apply(node, func(cursor *astutil.Cursor) bool {
302 if isIdentWithName(cursor.Node(), from) {
303 cursor.Replace(toExpr)
304 }
305 return true
306 }, nil)
307 }
308
309
310
311
312
313
314 func inlineFunction(node ast.Node, from string, toDecl *ast.FuncDecl) ast.Node {
315 return astutil.Apply(node, func(cursor *astutil.Cursor) bool {
316 switch node := cursor.Node().(type) {
317 case *ast.AssignStmt:
318
319
320 if len(node.Rhs) == 1 && isCallTo(node.Rhs[0], from) {
321 args := node.Rhs[0].(*ast.CallExpr).Args
322 if !argsMatchParameters(args, toDecl.Type.Params) {
323 log.Fatalf("applying op: arguments to %v don't match parameter names of %v: %v", from, toDecl.Name, debugPrint(args...))
324 }
325 replaceAssignment(cursor, node, toDecl)
326 }
327 return false
328 case *ast.CallExpr:
329
330 if isCallTo(node, from) {
331 if _, ok := cursor.Parent().(*ast.AssignStmt); !ok {
332 log.Fatalf("applying op: all calls to function %q being replaced must appear in an assignment statement, appears in %T", from, cursor.Parent())
333 }
334 }
335 }
336 return true
337 }, nil)
338 }
339
340
341
// argsMatchParameters reports whether args is exactly the parameter names of
// params, in order. Each argument must be a bare identifier spelled like the
// corresponding parameter, and the counts must agree. Unnamed parameters
// contribute no names. inlineFunction requires this because the inlined body
// is spliced in without renaming its parameters.
func argsMatchParameters(args []ast.Expr, params *ast.FieldList) bool {
	var names []*ast.Ident
	for _, field := range params.List {
		names = append(names, field.Names...)
	}
	if len(names) != len(args) {
		return false
	}
	for i, arg := range args {
		ident, ok := arg.(*ast.Ident)
		if !ok || ident.Name != names[i].Name {
			return false
		}
	}
	return true
}
360
361
// isIdentWithName reports whether expr is an *ast.Ident spelled name. Any
// other node, including nil, yields false.
func isIdentWithName(expr ast.Node, name string) bool {
	if ident, ok := expr.(*ast.Ident); ok {
		return ident.Name == name
	}
	return false
}
369
370
// isCallTo reports whether expr is a call whose function is the bare
// identifier name. Method calls (x.name()) and parenthesized callees do not
// match.
func isCallTo(expr ast.Expr, name string) bool {
	call, isCall := expr.(*ast.CallExpr)
	if !isCall {
		return false
	}
	fn, isIdent := call.Fun.(*ast.Ident)
	return isIdent && fn.Name == name
}
378
379
380
381
382 func replaceAssignment(cursor *astutil.Cursor, assign *ast.AssignStmt, funcdecl *ast.FuncDecl) {
383 if !hasTerminatingReturn(funcdecl.Body) {
384 log.Fatal("function being inlined must have a return at the end")
385 }
386
387 body := internalastutil.CloneNode(funcdecl.Body)
388 if hasTerminatingAndNonterminatingReturn(funcdecl.Body) {
389
390
391
392
393 body = addContinues(cursor, assign, body, everythingFollowingInParent(cursor)).(*ast.BlockStmt)
394 }
395
396 if len(body.List) < 1 {
397 log.Fatal("replacing with empty bodied function")
398 }
399
400
401
402
403
404
405 beforeReturn, ret := body.List[:len(body.List)-1], body.List[len(body.List)-1]
406 returnStmt, ok := ret.(*ast.ReturnStmt)
407 if !ok {
408 log.Fatal("last stmt in function we're replacing with should be a return")
409 }
410 results := returnStmt.Results
411
412
413 for _, stmt := range beforeReturn {
414 cursor.InsertBefore(stmt)
415 }
416
417
418 replaceWithAssignment(cursor, assign.Lhs, results, assign.Tok)
419 }
420
421
// hasTerminatingReturn reports whether the last statement of block is a
// return statement. A nil or empty block has no terminating return. It is
// reported as false instead of panicking on the out-of-range index, so callers'
// "must end in a return" diagnostics fire as intended.
func hasTerminatingReturn(block *ast.BlockStmt) bool {
	if block == nil || len(block.List) == 0 {
		return false
	}
	_, ok := block.List[len(block.List)-1].(*ast.ReturnStmt)
	return ok
}
426
427
428
// hasTerminatingAndNonterminatingReturn reports whether block ends in a
// return statement and also contains at least one other return (an early
// return) somewhere inside its earlier statements.
//
// Returns inside function literals are not counted: they leave the literal,
// not the function being inlined. A nil or empty block yields false.
func hasTerminatingAndNonterminatingReturn(block *ast.BlockStmt) bool {
	if block == nil || len(block.List) == 0 {
		return false
	}
	last := len(block.List) - 1
	if _, ok := block.List[last].(*ast.ReturnStmt); !ok {
		return false
	}
	found := false
	for _, stmt := range block.List[:last] {
		ast.Inspect(stmt, func(n ast.Node) bool {
			switch n.(type) {
			case *ast.ReturnStmt:
				found = true
				return false
			case *ast.FuncLit:
				return false // a closure's returns don't exit the inlined function
			}
			return !found
		})
		if found {
			return true
		}
	}
	return false
}
446
447
448
449 func everythingFollowingInParent(cursor *astutil.Cursor) *ast.BlockStmt {
450 parent := cursor.Parent()
451 block, ok := parent.(*ast.BlockStmt)
452 if !ok {
453 log.Fatal("internal error: in everythingFollowingInParent, cursor doesn't point to element in block list")
454 }
455
456 blockcopy := internalastutil.CloneNode(block)
457 blockcopy.List = blockcopy.List[cursor.Index()+1:]
458
459 if _, ok := blockcopy.List[len(blockcopy.List)-1].(*ast.ReturnStmt); !ok {
460 log.Printf("%s", mustFormatNode(token.NewFileSet(), blockcopy))
461 log.Fatal("internal error: parent doesn't end in a return")
462 }
463 return blockcopy
464 }
465
466
467
468
469
470 func addContinues(cursor *astutil.Cursor, assignNode *ast.AssignStmt, toBlock *ast.BlockStmt, continueBlock *ast.BlockStmt) ast.Node {
471 if !hasTerminatingReturn(continueBlock) {
472 log.Fatal("the block being continued to in addContinues must end in a return")
473 }
474 applyFunc := func(cursor *astutil.Cursor) bool {
475 ret, ok := cursor.Node().(*ast.ReturnStmt)
476 if !ok {
477 return true
478 }
479
480 if cursor.Parent() == toBlock && cursor.Index() == len(toBlock.List)-1 {
481 return false
482 }
483
484
485
486
487 replaceWithAssignment(cursor, assignNode.Lhs, ret.Results, assignNode.Tok)
488 cursor.InsertAfter(internalastutil.CloneNode(continueBlock))
489
490 return false
491 }
492 return astutil.Apply(toBlock, applyFunc, nil)
493 }
494
495
496 func debugPrint(nodes ...ast.Expr) string {
497 var b strings.Builder
498 for i, node := range nodes {
499 b.Write(mustFormatNode(token.NewFileSet(), node))
500 if i != len(nodes)-1 {
501 b.WriteString(", ")
502 }
503 }
504 return b.String()
505 }
506
507
// mustFormatNode prints node as gofmt-style Go source, using fset for
// position information. node may be any value accepted by go/format.Node. If
// printing fails, it exits via log.Fatalf instead of returning partial output.
func mustFormatNode(fset *token.FileSet, node any) []byte {
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, node); err != nil {
		log.Fatalf("error formatting %T: %v", node, err)
	}
	return buf.Bytes()
}
513
514
515
516
517
518 func mustMatchExprs(lhs []ast.Expr, rhs []ast.Expr) ([]ast.Expr, []ast.Expr) {
519 if len(lhs) != len(rhs) {
520 log.Fatal("exprs don't match", debugPrint(lhs...), debugPrint(rhs...))
521 }
522
523 var newLhs, newRhs []ast.Expr
524 for i := range lhs {
525 lhsIdent, ok1 := lhs[i].(*ast.Ident)
526 rhsIdent, ok2 := rhs[i].(*ast.Ident)
527 if ok1 && ok2 && lhsIdent.Name == rhsIdent.Name {
528 continue
529 }
530 newLhs = append(newLhs, lhs[i])
531 newRhs = append(newRhs, rhs[i])
532 }
533
534 return newLhs, newRhs
535 }
536
537
538
539
540 func replaceWithAssignment(cursor *astutil.Cursor, lhs, rhs []ast.Expr, tok token.Token) {
541 newLhs, newRhs := mustMatchExprs(lhs, rhs)
542 if len(newLhs) == 0 {
543 cursor.Delete()
544 return
545 }
546 if len(newRhs) == 1 {
547 if lit, ok := newRhs[0].(*ast.BasicLit); ok {
548 constDecl := &ast.DeclStmt{
549 Decl: &ast.GenDecl{
550 Tok: token.CONST,
551 Specs: []ast.Spec{
552 &ast.ValueSpec{
553 Names: []*ast.Ident{newLhs[0].(*ast.Ident)},
554 Values: []ast.Expr{lit},
555 },
556 },
557 },
558 }
559 cursor.Replace(constDecl)
560 return
561 }
562 }
563 newAssignment := &ast.AssignStmt{
564 Lhs: newLhs,
565 Rhs: newRhs,
566 Tok: tok,
567 }
568 cursor.Replace(newAssignment)
569 }
570
571
572 func generateTable(sizeToSizeClass []uint8) []byte {
573 scMax := sizeToSizeClass[smallScanNoHeaderMax]
574
575 var b bytes.Buffer
576 fmt.Fprintln(&b, `// Code generated by mkmalloc.go; DO NOT EDIT.
577 //go:build !plan9
578
579 package runtime
580
581 import "unsafe"
582
583 var mallocScanTable = [513]func(size uintptr, typ *_type, needzero bool) unsafe.Pointer{`)
584
585 for i := range uintptr(smallScanNoHeaderMax + 1) {
586 fmt.Fprintf(&b, "%s,\n", smallScanNoHeaderSCFuncName(sizeToSizeClass[i], scMax))
587 }
588
589 fmt.Fprintln(&b, `
590 }
591
592 var mallocNoScanTable = [513]func(size uintptr, typ *_type, needzero bool) unsafe.Pointer{`)
593 for i := range uintptr(smallScanNoHeaderMax + 1) {
594 if i < 16 {
595 fmt.Fprintf(&b, "%s,\n", tinyFuncName(i))
596 } else {
597 fmt.Fprintf(&b, "%s,\n", smallNoScanSCFuncName(sizeToSizeClass[i], scMax))
598 }
599 }
600
601 fmt.Fprintln(&b, `
602 }`)
603
604 return b.Bytes()
605 }
606
View as plain text