Source file src/cmd/compile/internal/midway/rewrite.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package midway
     6  
     7  import (
     8  	"cmd/compile/internal/syntax"
     9  	"cmd/compile/internal/types2"
    10  	"fmt"
    11  	"internal/buildcfg"
    12  	"strings"
    13  )
    14  
    15  // "Midway" rewriting
    16  //
// Go attempts to provide a package similar to the "Highway" library
    18  // for C++ (https://google.github.io/highway).  The library package is "simd"
    19  // and defines vector types with unspecified widths that are bound to particular
    20  // machine dependent types as late as program execution.  This is accomplished
    21  // by rewriting code that depends on these types into code that references
    22  // architecture-specific types, perhaps more than once, and if necessary
    23  // dynamically choosing which version to execute based on hardware attributes.
    24  //
    25  // The rewriting takes place early in the compiler, after type checking but
    26  // before conversion to "unified" IR.  To ensure that types are correctly set
    27  // on the modified version of the code, type checking information is reset and
// the type checking phase is re-run.  This places some limits on the shape of
    29  // the rewrites, but it also ensures that the rewritten code is well-formed.
    30  //
    31  // Rewritten code does not reference "archsimd" types directly, but instead
    32  // references types in a "bridge" package that filters the available methods
    33  // and adds a few more.  The package used relies on a builder/compiler hack;
    34  // the compiler's type checker enforces export naming conventions, but the
    35  // build system limits visibility to unrelated "internal" packages and can be
    36  // modified to allow access in special cases (like this one).  This allows the
    37  // rewritten code to reference types, functions, and methods that are not
    38  // accessible otherwise.
    39  
    40  type Rewriter struct {
    41  	pkg      *types2.Package
    42  	analyzer *Analyzer
    43  	info     *types2.Info
    44  	sizes    []int
    45  }
    46  
    47  func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
    48  	return &Rewriter{
    49  		pkg:      pkg,
    50  		info:     info,
    51  		analyzer: analyzer,
    52  		sizes:    sizes,
    53  	}
    54  }
    55  
    56  func (r *Rewriter) Rewrite(files []*syntax.File) {
    57  
    58  	// First duplicate and specialize all dependent functions and variables.
    59  	for _, fileAST := range files {
    60  
    61  		var newDecls []syntax.Decl
    62  		for _, k := range r.sizes {
    63  			newDecls = r.generateForSize(fileAST, k, newDecls)
    64  		}
    65  
    66  		// Then replace original functions with dispatchers.
    67  		r.generateDispatchers(fileAST)
    68  
    69  		fileAST.DeclList = append(fileAST.DeclList, newDecls...)
    70  	}
    71  }
    72  
    73  func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
    74  	var newDecls []syntax.Decl
    75  
    76  	for _, decl := range fileAST.DeclList {
    77  		switch d := decl.(type) {
    78  		case *syntax.FuncDecl:
    79  			if d.Name == nil {
    80  				newDecls = append(newDecls, d)
    81  				continue
    82  			}
    83  			obj := r.info.Defs[d.Name]
    84  			if !r.analyzer.dependentObj[obj] || r.analyzer.inSimd {
    85  				newDecls = append(newDecls, d)
    86  				continue
    87  			}
    88  
    89  			sig, ok := obj.Type().(*types2.Signature)
    90  			if !ok {
    91  				newDecls = append(newDecls, d)
    92  				continue
    93  			}
    94  
    95  			if r.analyzer.HasDependentSignature(sig) {
    96  				if o := r.info.Defs[d.Name]; o != nil && !o.Exported() {
    97  					// Drop unexported dependent signatures entirely
    98  					continue
    99  				}
   100  				d.Body = r.blockOf(d.Pos(), r.panicStmt(d.Pos(), "unexpected call of original function rewritten to specialized SIMD"))
   101  				newDecls = append(newDecls, d)
   102  				continue
   103  			}
   104  
   105  			// Clean signature -> Replace body with dispatcher
   106  			d.Body = r.createDispatcherBody(d, sig)
   107  			newDecls = append(newDecls, d)
   108  
   109  		case *syntax.VarDecl:
   110  			// Keep var decls even if rewritten, so that pre-rewrite code parses correctly.
   111  			// TODO figure out how to deal with side-effects in initializers.
   112  			newDecls = append(newDecls, d)
   113  		case *syntax.TypeDecl:
   114  			if !r.analyzer.dependentObj[r.info.Defs[d.Name]] || r.analyzer.inSimd {
   115  				newDecls = append(newDecls, d)
   116  			}
   117  		default:
   118  			newDecls = append(newDecls, decl)
   119  		}
   120  	}
   121  
   122  	fileAST.DeclList = newDecls
   123  
   124  	if !r.analyzer.inSimd {
   125  		// Inject an import to the bridge package (if not exists)
   126  		hasArchSimd := false
   127  		var simdImport *syntax.ImportDecl
   128  		for _, decl := range fileAST.DeclList {
   129  			if imp, ok := decl.(*syntax.ImportDecl); ok {
   130  				if imp.Path.Value == `"`+archFullPkg+`"` {
   131  					hasArchSimd = true
   132  				}
   133  				if imp.Path.Value == `"`+simdPkg+`"` {
   134  					simdImport = imp
   135  				}
   136  
   137  			}
   138  		}
   139  		p := simdImport.Pos()
   140  		if !hasArchSimd {
   141  			r.injectImport(fileAST, archFullPkg, p)
   142  		}
   143  
   144  		// Ensure at least one use of "simd"
   145  		// var _ = simd.VectorBitLen()
   146  		fun := &syntax.SelectorExpr{
   147  			X:   syntax.NewName(p, simdPkg), // Assume this is resolvable
   148  			Sel: syntax.NewName(p, vectorSizeFn),
   149  		}
   150  		fun.SetPos(p)
   151  		call := &syntax.CallExpr{Fun: fun}
   152  		call.SetPos(p)
   153  
   154  		name := syntax.NewName(p, "_")
   155  
   156  		varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
   157  		varDecl.SetPos(p)
   158  		fileAST.DeclList = append(fileAST.DeclList, varDecl)
   159  	}
   160  }
   161  
   162  func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
   163  	importDecl := &syntax.ImportDecl{
   164  		Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
   165  	}
   166  	importDecl.Path.SetPos(simdImportPos)
   167  	importDecl.SetPos(simdImportPos)
   168  	fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
   169  }
   170  
   171  func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
   172  
   173  	// Build call arguments from the function parameters
   174  	args := func() []syntax.Expr {
   175  		var args []syntax.Expr
   176  		if d.Type.ParamList != nil {
   177  			for _, field := range d.Type.ParamList {
   178  				if field.Name != nil {
   179  					paramName := syntax.NewName(field.Pos(), field.Name.Value)
   180  					args = append(args, paramName)
   181  				}
   182  			}
   183  		}
   184  		return args
   185  	}
   186  
   187  	// Slap a pos on an expression
   188  	pe := func(e syntax.Expr) syntax.Expr {
   189  		e.SetPos(d.Pos())
   190  		return e
   191  	}
   192  	// Slap a pos on a statement
   193  	ps := func(e syntax.Stmt) syntax.Stmt {
   194  		e.SetPos(d.Pos())
   195  		return e
   196  	}
   197  
   198  	// switch ast node.
   199  	// the goal is something like (for now, till there are finer-grained choices)
   200  	// switch simd.VectorSize() {
   201  	//   case 128: if simd.Emulated() { call the specialize-for-emulation-code(args) }
   202  	//             else { call the specialize-for-128-code(args) }
   203  	//   case 256: call the specialize-for-256-code(args)
   204  	//   etc
   205  	// }
   206  	//
   207  	// the cases above deal with the usual `return call(...)` vs `call(...); return`
   208  	switchStmt := &syntax.SwitchStmt{
   209  		Tag: pe(&syntax.CallExpr{
   210  			Fun: pe(&syntax.SelectorExpr{
   211  				X:   syntax.NewName(d.Pos(), simdPkg), // Assume this is resolvable
   212  				Sel: syntax.NewName(d.Pos(), vectorSizeFn),
   213  			}),
   214  		}),
   215  		Body: []*syntax.CaseClause{},
   216  	}
   217  
   218  	var emulation syntax.Stmt
   219  
   220  	for _, k := range r.sizes {
   221  		fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
   222  		fnIdent := syntax.NewName(d.Pos(), fnName)
   223  
   224  		callExpr := pe(&syntax.CallExpr{
   225  			Fun:     pe(fnIdent),
   226  			ArgList: args(),
   227  		})
   228  
   229  		// callReturnStmt is either `return call(...)` or `call(...); return`
   230  		var callReturnStmt syntax.Stmt
   231  		if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
   232  			callReturnStmt = &syntax.ReturnStmt{Results: callExpr}
   233  		} else {
   234  			callReturnStmt = &syntax.BlockStmt{
   235  				List: []syntax.Stmt{
   236  					ps(&syntax.ExprStmt{X: callExpr}),
   237  					ps(&syntax.ReturnStmt{}),
   238  				},
   239  				Rbrace: d.Pos(),
   240  			}
   241  		}
   242  		callReturnStmt.SetPos(d.Pos())
   243  
   244  		if k == 0 {
   245  			// emulation == `if simd.Emulated() { callReturnStmt }`
   246  			// save it for the first part of the 128 case.
   247  			cond := pe(&syntax.CallExpr{
   248  				Fun: pe(&syntax.SelectorExpr{
   249  					X:   syntax.NewName(d.Pos(), simdPkg), // Assume this is resolvable
   250  					Sel: syntax.NewName(d.Pos(), emulatedFn),
   251  				})})
   252  
   253  			blockStmt, ok := callReturnStmt.(*syntax.BlockStmt)
   254  			if !ok {
   255  				blockStmt = &syntax.BlockStmt{
   256  					List:   []syntax.Stmt{callReturnStmt},
   257  					Rbrace: d.Pos(),
   258  				}
   259  				blockStmt.SetPos(d.Pos())
   260  			}
   261  
   262  			emulation = ps(&syntax.IfStmt{
   263  				Cond: cond,
   264  				Then: blockStmt,
   265  			})
   266  			continue
   267  		}
   268  
   269  		var caseBody []syntax.Stmt
   270  		// assume that 128 is a case; when we do scalable simd, this may change.
   271  		// For now, if there is emulation, it is 128-bit (only).
   272  		if emulation != nil && k == 128 {
   273  			caseBody = append(caseBody, emulation)
   274  			emulation = nil
   275  		}
   276  
   277  		caseClause := &syntax.CaseClause{
   278  			Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
   279  			Body:  append(caseBody, callReturnStmt),
   280  		}
   281  		caseClause.SetPos(d.Pos())
   282  		switchStmt.Body = append(switchStmt.Body, caseClause)
   283  	}
   284  
   285  	panicStmt := r.panicStmt(d.Pos(), "unsupported vector size in simd-rewritten code")
   286  	return r.blockOf(d.Pos(), switchStmt, panicStmt)
   287  }
   288  
   289  func (r *Rewriter) blockOf(p syntax.Pos, stmts ...syntax.Stmt) *syntax.BlockStmt {
   290  	for _, s := range stmts {
   291  		s.SetPos(p)
   292  	}
   293  	blockStmt := &syntax.BlockStmt{List: stmts}
   294  	blockStmt.SetPos(p)
   295  	return blockStmt
   296  }
   297  
   298  func (r *Rewriter) panicStmt(p syntax.Pos, unquotedMessage string) *syntax.ExprStmt {
   299  	pe := func(e syntax.Expr) syntax.Expr {
   300  		e.SetPos(p)
   301  		return e
   302  	}
   303  	fnName := "panic"
   304  	fnIdent := pe(syntax.NewName(p, fnName))
   305  	callExpr := pe(&syntax.CallExpr{
   306  		Fun:     fnIdent,
   307  		ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: `"` + unquotedMessage + `"`, Kind: syntax.StringLit})},
   308  	})
   309  	panicStmt := &syntax.ExprStmt{X: callExpr}
   310  	panicStmt.SetPos(p)
   311  	return panicStmt
   312  }
   313  
   314  func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
   315  	copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
   316  	for _, decl := range fileAST.DeclList {
   317  		if r.shouldIncludeDecl(decl) {
   318  			newDecl := copier.CopyDecl(decl)
   319  			newDecls = append(newDecls, newDecl)
   320  		}
   321  	}
   322  	return newDecls
   323  }
   324  
   325  func nameToElemBitWidth(name string) int {
   326  	var width int
   327  	switch name {
   328  	case "Int8s", "Uint8s", "Mask8s":
   329  		width = 8
   330  	case "Int16s", "Uint16s", "Mask16s":
   331  		width = 16
   332  	case "Int32s", "Uint32s", "Float32s", "Mask32s":
   333  		width = 32
   334  	case "Int64s", "Uint64s", "Float64s", "Mask64s":
   335  		width = 64
   336  	}
   337  	return width
   338  }
   339  
   340  func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
   341  	// Files (and declarations) in the simd package are excluded
   342  	// from processing, except for those that whose name begins
   343  	// with "tofrom_".
   344  	if r.analyzer.inSimd {
   345  		theFile := decl.Pos().Base().Filename()
   346  
   347  		lastSlash := strings.LastIndex(theFile, simdPkg+"/")
   348  		lastBackslash := strings.LastIndex(theFile, simdPkg+"\\")
   349  
   350  		// Windows paths can be chaos, all we care, is whether the very last part
   351  		// of the path is any-path-separator + "tofrom_" + anything-else, given that
   352  		// we already know that we are in the simd package.
   353  		maxSlash := max(lastSlash, lastBackslash)
   354  		if maxSlash == -1 {
   355  			return false
   356  		}
   357  		if !strings.HasPrefix(theFile[maxSlash:], simdPkg+"/tofrom_") &&
   358  			!strings.HasPrefix(theFile[maxSlash:], simdPkg+"\\tofrom_") {
   359  			return false
   360  		}
   361  	}
   362  
   363  	switch d := decl.(type) {
   364  	case *syntax.FuncDecl:
   365  		if d.Name != nil {
   366  			return r.analyzer.dependentObj[r.info.Defs[d.Name]]
   367  		}
   368  	case *syntax.TypeDecl:
   369  		return r.analyzer.dependentObj[r.info.Defs[d.Name]]
   370  	case *syntax.VarDecl:
   371  		for _, name := range d.NameList {
   372  			if r.analyzer.dependentObj[r.info.Defs[name]] {
   373  				return true
   374  			}
   375  		}
   376  	}
   377  	return false
   378  }
   379  
   380  // Generate an API matching the standalone compilation call
   381  func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
   382  	if !buildcfg.Experiment.SIMD {
   383  		return false
   384  	}
   385  
   386  	switch buildcfg.GOARCH {
   387  	case "wasm", "amd64", "arm64":
   388  	default:
   389  		return false
   390  	}
   391  
   392  	sizes := rewriteSizes()
   393  	if len(sizes) == 0 {
   394  		return false
   395  	}
   396  	analyzer := NewAnalyzer(pkg, info)
   397  	if !analyzer.Analyze(files) {
   398  		return false
   399  	}
   400  
   401  	CheckPositions(files, "before midway")
   402  
   403  	rewriter := NewRewriter(pkg, info, analyzer, sizes)
   404  	rewriter.Rewrite(files)
   405  
   406  	CheckPositions(files, "after midway")
   407  
   408  	return true
   409  }
   410  

View as plain text