1
2
3
4
5 package midway
6
7 import (
8 "cmd/compile/internal/syntax"
9 "cmd/compile/internal/types2"
10 "fmt"
11 "internal/buildcfg"
12 "strings"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 type Rewriter struct {
41 pkg *types2.Package
42 analyzer *Analyzer
43 info *types2.Info
44 sizes []int
45 }
46
47 func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
48 return &Rewriter{
49 pkg: pkg,
50 info: info,
51 analyzer: analyzer,
52 sizes: sizes,
53 }
54 }
55
56 func (r *Rewriter) Rewrite(files []*syntax.File) {
57
58
59 for _, fileAST := range files {
60
61 var newDecls []syntax.Decl
62 for _, k := range r.sizes {
63 newDecls = r.generateForSize(fileAST, k, newDecls)
64 }
65
66
67 r.generateDispatchers(fileAST)
68
69 fileAST.DeclList = append(fileAST.DeclList, newDecls...)
70 }
71 }
72
73 func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
74 var newDecls []syntax.Decl
75
76 for _, decl := range fileAST.DeclList {
77 switch d := decl.(type) {
78 case *syntax.FuncDecl:
79 if d.Name == nil {
80 newDecls = append(newDecls, d)
81 continue
82 }
83 obj := r.info.Defs[d.Name]
84 if !r.analyzer.dependentObj[obj] || r.analyzer.inSimd {
85 newDecls = append(newDecls, d)
86 continue
87 }
88
89 sig, ok := obj.Type().(*types2.Signature)
90 if !ok {
91 newDecls = append(newDecls, d)
92 continue
93 }
94
95 if r.analyzer.HasDependentSignature(sig) {
96 if o := r.info.Defs[d.Name]; o != nil && !o.Exported() {
97
98 continue
99 }
100 d.Body = r.blockOf(d.Pos(), r.panicStmt(d.Pos(), "unexpected call of original function rewritten to specialized SIMD"))
101 newDecls = append(newDecls, d)
102 continue
103 }
104
105
106 d.Body = r.createDispatcherBody(d, sig)
107 newDecls = append(newDecls, d)
108
109 case *syntax.VarDecl:
110
111
112 newDecls = append(newDecls, d)
113 case *syntax.TypeDecl:
114 if !r.analyzer.dependentObj[r.info.Defs[d.Name]] || r.analyzer.inSimd {
115 newDecls = append(newDecls, d)
116 }
117 default:
118 newDecls = append(newDecls, decl)
119 }
120 }
121
122 fileAST.DeclList = newDecls
123
124 if !r.analyzer.inSimd {
125
126 hasArchSimd := false
127 var simdImport *syntax.ImportDecl
128 for _, decl := range fileAST.DeclList {
129 if imp, ok := decl.(*syntax.ImportDecl); ok {
130 if imp.Path.Value == `"`+archFullPkg+`"` {
131 hasArchSimd = true
132 }
133 if imp.Path.Value == `"`+simdPkg+`"` {
134 simdImport = imp
135 }
136
137 }
138 }
139 p := simdImport.Pos()
140 if !hasArchSimd {
141 r.injectImport(fileAST, archFullPkg, p)
142 }
143
144
145
146 fun := &syntax.SelectorExpr{
147 X: syntax.NewName(p, simdPkg),
148 Sel: syntax.NewName(p, vectorSizeFn),
149 }
150 fun.SetPos(p)
151 call := &syntax.CallExpr{Fun: fun}
152 call.SetPos(p)
153
154 name := syntax.NewName(p, "_")
155
156 varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
157 varDecl.SetPos(p)
158 fileAST.DeclList = append(fileAST.DeclList, varDecl)
159 }
160 }
161
162 func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
163 importDecl := &syntax.ImportDecl{
164 Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
165 }
166 importDecl.Path.SetPos(simdImportPos)
167 importDecl.SetPos(simdImportPos)
168 fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
169 }
170
171 func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
172
173
174 args := func() []syntax.Expr {
175 var args []syntax.Expr
176 if d.Type.ParamList != nil {
177 for _, field := range d.Type.ParamList {
178 if field.Name != nil {
179 paramName := syntax.NewName(field.Pos(), field.Name.Value)
180 args = append(args, paramName)
181 }
182 }
183 }
184 return args
185 }
186
187
188 pe := func(e syntax.Expr) syntax.Expr {
189 e.SetPos(d.Pos())
190 return e
191 }
192
193 ps := func(e syntax.Stmt) syntax.Stmt {
194 e.SetPos(d.Pos())
195 return e
196 }
197
198
199
200
201
202
203
204
205
206
207
208 switchStmt := &syntax.SwitchStmt{
209 Tag: pe(&syntax.CallExpr{
210 Fun: pe(&syntax.SelectorExpr{
211 X: syntax.NewName(d.Pos(), simdPkg),
212 Sel: syntax.NewName(d.Pos(), vectorSizeFn),
213 }),
214 }),
215 Body: []*syntax.CaseClause{},
216 }
217
218 var emulation syntax.Stmt
219
220 for _, k := range r.sizes {
221 fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
222 fnIdent := syntax.NewName(d.Pos(), fnName)
223
224 callExpr := pe(&syntax.CallExpr{
225 Fun: pe(fnIdent),
226 ArgList: args(),
227 })
228
229
230 var callReturnStmt syntax.Stmt
231 if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
232 callReturnStmt = &syntax.ReturnStmt{Results: callExpr}
233 } else {
234 callReturnStmt = &syntax.BlockStmt{
235 List: []syntax.Stmt{
236 ps(&syntax.ExprStmt{X: callExpr}),
237 ps(&syntax.ReturnStmt{}),
238 },
239 Rbrace: d.Pos(),
240 }
241 }
242 callReturnStmt.SetPos(d.Pos())
243
244 if k == 0 {
245
246
247 cond := pe(&syntax.CallExpr{
248 Fun: pe(&syntax.SelectorExpr{
249 X: syntax.NewName(d.Pos(), simdPkg),
250 Sel: syntax.NewName(d.Pos(), emulatedFn),
251 })})
252
253 blockStmt, ok := callReturnStmt.(*syntax.BlockStmt)
254 if !ok {
255 blockStmt = &syntax.BlockStmt{
256 List: []syntax.Stmt{callReturnStmt},
257 Rbrace: d.Pos(),
258 }
259 blockStmt.SetPos(d.Pos())
260 }
261
262 emulation = ps(&syntax.IfStmt{
263 Cond: cond,
264 Then: blockStmt,
265 })
266 continue
267 }
268
269 var caseBody []syntax.Stmt
270
271
272 if emulation != nil && k == 128 {
273 caseBody = append(caseBody, emulation)
274 emulation = nil
275 }
276
277 caseClause := &syntax.CaseClause{
278 Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
279 Body: append(caseBody, callReturnStmt),
280 }
281 caseClause.SetPos(d.Pos())
282 switchStmt.Body = append(switchStmt.Body, caseClause)
283 }
284
285 panicStmt := r.panicStmt(d.Pos(), "unsupported vector size in simd-rewritten code")
286 return r.blockOf(d.Pos(), switchStmt, panicStmt)
287 }
288
289 func (r *Rewriter) blockOf(p syntax.Pos, stmts ...syntax.Stmt) *syntax.BlockStmt {
290 for _, s := range stmts {
291 s.SetPos(p)
292 }
293 blockStmt := &syntax.BlockStmt{List: stmts}
294 blockStmt.SetPos(p)
295 return blockStmt
296 }
297
298 func (r *Rewriter) panicStmt(p syntax.Pos, unquotedMessage string) *syntax.ExprStmt {
299 pe := func(e syntax.Expr) syntax.Expr {
300 e.SetPos(p)
301 return e
302 }
303 fnName := "panic"
304 fnIdent := pe(syntax.NewName(p, fnName))
305 callExpr := pe(&syntax.CallExpr{
306 Fun: fnIdent,
307 ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: `"` + unquotedMessage + `"`, Kind: syntax.StringLit})},
308 })
309 panicStmt := &syntax.ExprStmt{X: callExpr}
310 panicStmt.SetPos(p)
311 return panicStmt
312 }
313
314 func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
315 copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
316 for _, decl := range fileAST.DeclList {
317 if r.shouldIncludeDecl(decl) {
318 newDecl := copier.CopyDecl(decl)
319 newDecls = append(newDecls, newDecl)
320 }
321 }
322 return newDecls
323 }
324
// nameToElemBitWidth maps an element-type name such as "Int32s" or "Mask8s"
// to its element width in bits (8, 16, 32, or 64). It returns 0 for any name
// it does not recognize. Matching is exact and case-sensitive.
func nameToElemBitWidth(name string) int {
	switch name {
	case "Int8s", "Uint8s", "Mask8s":
		return 8
	case "Int16s", "Uint16s", "Mask16s":
		return 16
	case "Int32s", "Uint32s", "Float32s", "Mask32s":
		return 32
	case "Int64s", "Uint64s", "Float64s", "Mask64s":
		return 64
	default:
		return 0
	}
}
339
340 func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
341
342
343
344 if r.analyzer.inSimd {
345 theFile := decl.Pos().Base().Filename()
346
347 lastSlash := strings.LastIndex(theFile, simdPkg+"/")
348 lastBackslash := strings.LastIndex(theFile, simdPkg+"\\")
349
350
351
352
353 maxSlash := max(lastSlash, lastBackslash)
354 if maxSlash == -1 {
355 return false
356 }
357 if !strings.HasPrefix(theFile[maxSlash:], simdPkg+"/tofrom_") &&
358 !strings.HasPrefix(theFile[maxSlash:], simdPkg+"\\tofrom_") {
359 return false
360 }
361 }
362
363 switch d := decl.(type) {
364 case *syntax.FuncDecl:
365 if d.Name != nil {
366 return r.analyzer.dependentObj[r.info.Defs[d.Name]]
367 }
368 case *syntax.TypeDecl:
369 return r.analyzer.dependentObj[r.info.Defs[d.Name]]
370 case *syntax.VarDecl:
371 for _, name := range d.NameList {
372 if r.analyzer.dependentObj[r.info.Defs[name]] {
373 return true
374 }
375 }
376 }
377 return false
378 }
379
380
381 func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
382 if !buildcfg.Experiment.SIMD {
383 return false
384 }
385
386 switch buildcfg.GOARCH {
387 case "wasm", "amd64", "arm64":
388 default:
389 return false
390 }
391
392 sizes := rewriteSizes()
393 if len(sizes) == 0 {
394 return false
395 }
396 analyzer := NewAnalyzer(pkg, info)
397 if !analyzer.Analyze(files) {
398 return false
399 }
400
401 CheckPositions(files, "before midway")
402
403 rewriter := NewRewriter(pkg, info, analyzer, sizes)
404 rewriter.Rewrite(files)
405
406 CheckPositions(files, "after midway")
407
408 return true
409 }
410
View as plain text