Source file src/simd/archsimd/_gen/midway/intersect_simd_ops.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"flag"
    11  	"fmt"
    12  	"go/ast"
    13  	"go/format"
    14  	"go/parser"
    15  	"go/token"
    16  	"io"
    17  	"log"
    18  	"os"
    19  	"path/filepath"
    20  	"slices"
    21  	"sort"
    22  	"strings"
    23  	"unicode"
    24  	"unicode/utf8"
    25  
    26  	"gopkg.in/yaml.v3"
    27  )
    28  
// MethodSet maps a method name to the declaration it was parsed from.
type MethodSet map[string]*ast.FuncDecl

// TypeMethods maps an "arch-Type" key, as produced by combine, to the
// methods collected for that receiver type on that architecture.
type TypeMethods map[string]MethodSet
    31  
// Comments holds the documentation text loaded from comments.yaml that is
// attached to the generated declarations.
type Comments struct {
	// Types is looked up by generated type name (for example "_simd" or elem+"s").
	Types     map[string]string            `yaml:"types"`
	// Functions is looked up by "Load"+elem, "Load"+elem+"Part" and
	// "Broadcast"+elem; the "default_LoadSlice", "default_LoadPart" and
	// "default_Broadcast" entries are fmt format strings used as fallbacks.
	Functions map[string]string            `yaml:"functions"`
	// Methods is looked up by generated type name, then by method name.
	Methods   map[string]map[string]string `yaml:"methods"`
}
    37  
// goRoot is the root of the Go tree; both the archsimd inputs and the
// generated outputs are located relative to it.
var goRoot = flag.String("goroot", "../../../../..", "Go root")

// verbose enables the diagnostic messages printed through pv in main.
var verbose = flag.Bool("v", false, "Be much chattier about processing")
    40  
// ArchAndFiles pairs an architecture name with the archsimd source files
// (names relative to the archsimd directory) that are parsed for it.
type ArchAndFiles struct {
	arch  string
	files []string
}
    45  
// TypeMethod identifies a method by receiver type name (t) and method name (m).
// It is used as a map key and is sorted by t, then m.
type TypeMethod struct {
	t, m string
}
    49  
    50  type whyMissing struct {
    51  	wasm128, arm128, amd128, amd256, amd512 bool
    52  }
    53  
    54  func (w whyMissing) String() string {
    55  	why := ""
    56  	if w.wasm128 {
    57  		why += " wasm"
    58  	}
    59  	if w.arm128 {
    60  		why += " neon"
    61  	}
    62  	if w.amd128 {
    63  		why += " avx"
    64  	}
    65  	if w.amd256 {
    66  		why += " avx2"
    67  	}
    68  	if w.amd512 {
    69  		why += " avx512"
    70  	}
    71  	return why[1:]
    72  }
    73  
    74  func combine(arch, typ string) string {
    75  	return arch + "-" + typ
    76  }
    77  
    78  func main() {
    79  	minorProblem := false
    80  
    81  	flag.Parse()
    82  
    83  	var comments Comments
    84  	commentsData, err := os.ReadFile("comments.yaml")
    85  	if err != nil {
    86  		log.Fatalf("Failed to read comments.yaml: %v", err)
    87  	}
    88  	if err := yaml.Unmarshal(commentsData, &comments); err != nil {
    89  		log.Fatalf("Failed to parse comments.yaml: %v", err)
    90  	}
    91  
    92  	pv := func(f string, s ...any) {
    93  		if *verbose {
    94  			fmt.Fprintf(os.Stderr, f, s...)
    95  		}
    96  	}
    97  	pw := func(f string, s ...any) {
    98  		minorProblem = true
    99  		fmt.Fprintf(os.Stderr, f, s...)
   100  	}
   101  
   102  	// Hardcoded path to archsimd
   103  	archSimdPath := *goRoot + "/src/simd/archsimd"
   104  
   105  	// Hardcoded list of files
   106  	amd64Files := []string{"ops_amd64.go", "compare_gen_amd64.go", "types_amd64.go",
   107  		"other_gen_amd64.go", "extra_amd64.go", "maskmerge_gen_amd64.go",
   108  		"shuffles_amd64.go", "slice_gen_amd64.go", "slicepart_amd64.go",
   109  		"slicepart_128.go", "string.go", "ops_emulated_amd64.go"}
   110  	wasmFiles := []string{"ops_wasm.go", "types_wasm.go", "slicepart_wasm.go",
   111  		"string.go", "slicepart_128.go", "ops_emulated_wasm.go"}
   112  	neonFiles := []string{"clmul_arm64.go", "compare_gen_arm64.go",
   113  		"maskmerge_gen_arm64.go", "ops_arm64.go", "slicepart_128.go",
   114  		"ops_internal_arm64.go", "other_gen_arm64.go", "slice_gen_arm64.go",
   115  		"slicepart_arm64.go", "types_arm64.go"}
   116  
   117  	emulatedFile := *goRoot + "/src/simd/simd_emulated.go"
   118  
   119  	archAndFiles := []ArchAndFiles{
   120  		ArchAndFiles{"wasm", wasmFiles},
   121  		ArchAndFiles{"amd64", amd64Files},
   122  		ArchAndFiles{"arm64", neonFiles},
   123  	}
   124  
   125  	// Categories based on bit size
   126  	// 128-bit map: ElementType -> TypeName
   127  	map128 := map[string]string{
   128  		"Int8":    "Int8x16",
   129  		"Int16":   "Int16x8",
   130  		"Int32":   "Int32x4",
   131  		"Int64":   "Int64x2",
   132  		"Uint8":   "Uint8x16",
   133  		"Uint16":  "Uint16x8",
   134  		"Uint32":  "Uint32x4",
   135  		"Uint64":  "Uint64x2",
   136  		"Float32": "Float32x4",
   137  		"Float64": "Float64x2",
   138  		"Mask8":   "Mask8x16",
   139  		"Mask16":  "Mask16x8",
   140  		"Mask32":  "Mask32x4",
   141  		"Mask64":  "Mask64x2",
   142  	}
   143  
   144  	// 256-bit map: ElementType -> TypeName
   145  	map256 := map[string]string{
   146  		"Int8":    "Int8x32",
   147  		"Int16":   "Int16x16",
   148  		"Int32":   "Int32x8",
   149  		"Int64":   "Int64x4",
   150  		"Uint8":   "Uint8x32",
   151  		"Uint16":  "Uint16x16",
   152  		"Uint32":  "Uint32x8",
   153  		"Uint64":  "Uint64x4",
   154  		"Float32": "Float32x8",
   155  		"Float64": "Float64x4",
   156  		"Mask8":   "Mask8x32",
   157  		"Mask16":  "Mask16x16",
   158  		"Mask32":  "Mask32x8",
   159  		"Mask64":  "Mask64x4",
   160  	}
   161  
   162  	map512 := map[string]string{
   163  		"Int8":    "Int8x64",
   164  		"Int16":   "Int16x32",
   165  		"Int32":   "Int32x16",
   166  		"Int64":   "Int64x8",
   167  		"Uint8":   "Uint8x64",
   168  		"Uint16":  "Uint16x32",
   169  		"Uint32":  "Uint32x16",
   170  		"Uint64":  "Uint64x8",
   171  		"Float32": "Float32x16",
   172  		"Float64": "Float64x8",
   173  		"Mask8":   "Mask8x64",
   174  		"Mask16":  "Mask16x32",
   175  		"Mask32":  "Mask32x16",
   176  		"Mask64":  "Mask64x8",
   177  	}
   178  
   179  	sizeForType := make(map[string]int)
   180  
   181  	methodsByType := make(TypeMethods)
   182  
   183  	allMethodNames := make(map[string]bool)
   184  
   185  	missing := make(map[string]whyMissing)
   186  
   187  	fset := token.NewFileSet()
   188  
   189  	knownReceivers := make(map[string]string)
   190  	for k, v := range map128 {
   191  		knownReceivers[v] = k + "s"
   192  		sizeForType[v] = 128
   193  	}
   194  	for k, v := range map256 {
   195  		knownReceivers[v] = k + "s"
   196  		sizeForType[v] = 256
   197  	}
   198  	for k, v := range map512 {
   199  		knownReceivers[v] = k + "s"
   200  		sizeForType[v] = 512
   201  	}
   202  
   203  	receiver := func(funcDecl *ast.FuncDecl) string {
   204  		if funcDecl.Recv == nil {
   205  			return ""
   206  		}
   207  		recvType := ""
   208  		for _, field := range funcDecl.Recv.List {
   209  			// We assume single receiver
   210  			if ident, ok := field.Type.(*ast.Ident); ok {
   211  				recvType = ident.Name
   212  			} else if star, ok := field.Type.(*ast.StarExpr); ok {
   213  				if ident, ok := star.X.(*ast.Ident); ok {
   214  					recvType = ident.Name
   215  				}
   216  			}
   217  		}
   218  		return recvType
   219  	}
   220  
   221  	emulated := make(map[TypeMethod]bool)
   222  	f, err := parser.ParseFile(fset, emulatedFile, nil, parser.ParseComments)
   223  	if err != nil {
   224  		log.Fatalf("Failed to parse %s: %v", emulatedFile, err)
   225  	}
   226  
   227  	for _, decl := range f.Decls {
   228  		if funcDecl, ok := decl.(*ast.FuncDecl); ok {
   229  			if receiver := receiver(funcDecl); receiver != "" {
   230  				method := funcDecl.Name.Name
   231  				// Exported methods only (must begin with uppercase)
   232  				if m, _ := utf8.DecodeRuneInString(method); unicode.IsUpper(m) {
   233  					emulated[TypeMethod{receiver, method}] = true
   234  				}
   235  			}
   236  		}
   237  	}
   238  
   239  	for _, aaf := range archAndFiles {
   240  		for _, fname := range aaf.files {
   241  			path := filepath.Join(archSimdPath, fname)
   242  			f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
   243  			if err != nil {
   244  				log.Fatalf("Failed to parse %s: %v", path, err)
   245  			}
   246  
   247  			lci := 0
   248  			fComments := f.Comments
   249  
   250  			for _, decl := range f.Decls {
   251  				if funcDecl, ok := decl.(*ast.FuncDecl); ok {
   252  
   253  					lastComment := ""
   254  					for ; lci < len(fComments) && fComments[lci].Pos() > funcDecl.Pos(); lci++ {
   255  						lastComment = fComments[lci].Text()
   256  					}
   257  
   258  					recvType := receiver(funcDecl)
   259  
   260  					if recvType == "" || knownReceivers[recvType] == "" {
   261  						continue
   262  					}
   263  
   264  					methodName := funcDecl.Name.Name
   265  
   266  					if strings.Contains(funcDecl.Doc.Text(), "Deprecated:") {
   267  						pv("Skipping deprecated %s.%s\n", recvType, methodName)
   268  						continue
   269  					}
   270  
   271  					if strings.Contains(lastComment, "Deprecated:") {
   272  						pv("Skipping MAYBE deprecated %s.%s (check comment)\n", recvType, methodName)
   273  						continue
   274  					}
   275  
   276  					if sizeForType[recvType] == 128 {
   277  						if s := funcDecl.Doc.Text(); strings.Contains(s, "AVX512") || strings.Contains(s, "AVX2") {
   278  							pv("Skipping 128-bit %s.%s because AVX2/AVX512\n", recvType, methodName)
   279  							continue
   280  						}
   281  					}
   282  					if sizeForType[recvType] == 256 {
   283  						if s := funcDecl.Doc.Text(); strings.Contains(s, "AVX512") {
   284  							pv("Skipping 256-bit %s.%s because AVX512\n", recvType, methodName)
   285  							continue
   286  						}
   287  					}
   288  
   289  					eltType := recvType[:strings.Index(recvType, "x")]
   290  
   291  					// Allow reinterpret vectors.
   292  					if xAt := strings.Index(methodName, "x"); xAt != -1 && (strings.HasPrefix(methodName, "As") || strings.HasPrefix(methodName, "ToInt") && strings.HasPrefix(eltType, "Mask")) {
   293  						// We think this is fine, even if it changes the number of elements in the vector.
   294  						// Tweak the method name so that they will line up properly.
   295  						methodName = methodName[:xAt] + "s"
   296  					} else if strings.HasPrefix(methodName, "Broadcast") {
   297  						// Broadcast is okay
   298  					} else {
   299  						// Exclude "grouped", "Store" (not slice), and vector-size-changing methods.
   300  						if strings.Contains(methodName, "Group") {
   301  							pv("Skipping grouped method %s.%s\n", recvType, methodName)
   302  							continue
   303  						}
   304  						if methodName == "StoreArray" || methodName == "StoreMasked" {
   305  							pv("Skipping fixed-size Store method method %s.%s\n", recvType, methodName)
   306  							continue
   307  						}
   308  						if methodName == "ToBits" && recvType[0] == 'M' {
   309  							pv("Skipping Mask ToBits method (has varying return type) %s.%s\n", recvType, methodName)
   310  							continue
   311  						}
   312  						if lastChar := methodName[len(methodName)-1]; unicode.IsDigit(rune(lastChar)) && lastChar != eltType[len(eltType)-1] {
   313  							pv("Skipping size-changing method %s.%s\n", recvType, methodName)
   314  							continue
   315  						}
   316  					}
   317  
   318  					archReceiver := combine(aaf.arch, recvType)
   319  
   320  					if methodsByType[archReceiver] == nil {
   321  						methodsByType[archReceiver] = make(MethodSet)
   322  					}
   323  					methodsByType[archReceiver][methodName] = funcDecl
   324  					allMethodNames[methodName] = true
   325  				}
   326  			}
   327  		}
   328  	}
   329  
   330  	intersectionByElem := make(map[string][]string)
   331  
   332  	doWrites := func(w io.Writer) {
   333  
   334  		p := func(s ...any) { fmt.Fprint(w, s...) }
   335  		pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
   336  		nl := func() { fmt.Fprintln(w) }
   337  
   338  		// elems is a slice of stems of vector types.
   339  		elems := []string{"Int8", "Int16", "Int32", "Int64", "Uint8", "Uint16", "Uint32", "Uint64", "Float32", "Float64", "Mask8", "Mask16", "Mask32", "Mask64"}
   340  
   341  		fmt.Fprintln(w,
   342  			`// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.
   343  
   344  //go:build goexperiment.simd && (amd64 || wasm || arm64)
   345  
   346  // Computed intersection of methods for supported SIMD architectures and vector widths
   347  
   348  package simd`)
   349  
   350  		if c := comments.Types["_simd"]; c != "" {
   351  			pf("// %s\n", c)
   352  		}
   353  		p("type _simd struct {\n\t_ [0]func(*_simd) *_simd\n}\n")
   354  
   355  		sigForMethod := make(map[string]*ast.FuncDecl)
   356  
   357  		// xlateType translates a type by replacing instances of types with keys in knownReceivers with their values,
   358  		// and generates the string representation of the resulting type.  E.g., []Int8x32 -> []Int8s
   359  		// (because Int8x32 -> Int8s in knownReceivers
   360  		var xlateType func(ast.Expr) string
   361  		xlateType = func(e ast.Expr) string {
   362  			switch t := e.(type) {
   363  			case *ast.Ident:
   364  				if mapped, ok := knownReceivers[t.Name]; ok {
   365  					return mapped
   366  				}
   367  				return t.Name
   368  			case *ast.StarExpr:
   369  				return "*" + xlateType(t.X)
   370  			case *ast.ArrayType:
   371  				lenStr := ""
   372  				if t.Len != nil {
   373  					var buf strings.Builder
   374  					format.Node(&buf, token.NewFileSet(), t.Len)
   375  					lenStr = buf.String()
   376  				}
   377  				return "[" + lenStr + "]" + xlateType(t.Elt)
   378  			case *ast.SelectorExpr:
   379  				return xlateType(t.X) + "." + t.Sel.Name
   380  			case *ast.Ellipsis:
   381  				return "..." + xlateType(t.Elt)
   382  			default:
   383  				var buf strings.Builder
   384  				format.Node(&buf, token.NewFileSet(), t)
   385  				return buf.String()
   386  			}
   387  		}
   388  
   389  		toScalar := func(s string) string {
   390  			if strings.HasPrefix(s, "Mask") {
   391  				return "int" + s[4:]
   392  			}
   393  			return strings.ToLower(s)
   394  		}
   395  
   396  		for _, elem := range elems {
   397  			type128 := map128[elem]
   398  			type256 := map256[elem]
   399  			type512 := map512[elem]
   400  
   401  			methods128w := methodsByType[combine("wasm", type128)]
   402  			methods128n := methodsByType[combine("arm64", type128)]
   403  			methods128 := methodsByType[combine("amd64", type128)]
   404  			methods256 := methodsByType[combine("amd64", type256)]
   405  			methods512 := methodsByType[combine("amd64", type512)]
   406  
   407  			var intersection []string
   408  			var missingNames []string
   409  			for m := range allMethodNames {
   410  				if wasm128, arm128, amd128, amd256, amd512 :=
   411  					methods128w[m] == nil, methods128n[m] == nil, methods128[m] == nil, methods256[m] == nil, methods512[m] == nil; !wasm128 && !arm128 && !amd128 && !amd256 && !amd512 {
   412  					intersection = append(intersection, m)
   413  					sigForMethod[m] = methods512[m] // Use 512-bit signature (arbitrary choice, they should match)
   414  				} else if !(wasm128 && arm128 && amd128 && amd256 && amd512) {
   415  					missing[m] = whyMissing{wasm128, arm128, amd128, amd256, amd512}
   416  					missingNames = append(missingNames, m)
   417  				}
   418  			}
   419  			sort.Strings(missingNames)
   420  
   421  			for _, m := range missingNames {
   422  				pv("Missing implementation for %ss.%s on %s\n", elem, m, missing[m].String())
   423  			}
   424  
   425  			sort.Strings(intersection)
   426  
   427  			intersectionByElem[elem] = intersection
   428  
   429  			if c := comments.Types[elem+"s"]; c != "" {
   430  				pf("// %s\n", c)
   431  			}
   432  			pf("type %ss struct {\n\t_       _simd\n\tatLeast [2]uint64 // the actual vector size may be larger.\n}\n", elem)
   433  
   434  			if elem[0] != 'M' {
   435  				// cannot load masks
   436  
   437  				loadComment := comments.Functions["Load"+elem]
   438  				if loadComment == "" && comments.Functions["default_LoadSlice"] != "" {
   439  					loadComment = fmt.Sprintf(comments.Functions["default_LoadSlice"], elem, toScalar(elem), elem)
   440  				}
   441  				if loadComment != "" {
   442  					pf("// %s\n", loadComment)
   443  				}
   444  				pf("func Load%ss([]%s) %ss\n", elem, toScalar(elem), elem)
   445  
   446  				loadPartComment := comments.Functions["Load"+elem+"Part"]
   447  				if loadPartComment == "" && comments.Functions["default_LoadPart"] != "" {
   448  					loadPartComment = fmt.Sprintf(comments.Functions["default_LoadPart"], elem, toScalar(elem), elem)
   449  				}
   450  				if loadPartComment != "" {
   451  					pf("// %s\n", loadPartComment)
   452  				}
   453  				pf("func Load%ssPart([]%s) (%ss, int)\n", elem, toScalar(elem), elem)
   454  
   455  				broadcastComment := comments.Functions["Broadcast"+elem]
   456  				if broadcastComment == "" && comments.Functions["default_Broadcast"] != "" {
   457  					broadcastComment = fmt.Sprintf(comments.Functions["default_Broadcast"], elem)
   458  				}
   459  				if broadcastComment != "" {
   460  					pf("// %s\n", broadcastComment)
   461  				}
   462  				pf("func Broadcast%ss(%s) %ss\n", elem, toScalar(elem), elem)
   463  
   464  			}
   465  
   466  			for _, m := range intersection {
   467  				fd := sigForMethod[m]
   468  				elems := elem + "s"
   469  				methodComment := ""
   470  				if typeMethods, ok := comments.Methods[elem+"s"]; ok {
   471  					methodComment = typeMethods[m]
   472  				}
   473  				if methodComment != "" {
   474  					pf("// %s\n", methodComment)
   475  				} else {
   476  					pw("Missing doc comment (in midway/comments.yaml) for %s.%s\n", elems, m)
   477  				}
   478  				pf("func (x %s) %s(", elems, m)
   479  
   480  				if !emulated[TypeMethod{elems, m}] {
   481  					pw("Missing emulated method for %s.%s\n", elems, m)
   482  				} else {
   483  					delete(emulated, TypeMethod{elems, m})
   484  				}
   485  
   486  				if fd.Type.Params != nil {
   487  					for i, field := range fd.Type.Params.List {
   488  						if i > 0 {
   489  							p(", ")
   490  						}
   491  						if len(field.Names) > 0 {
   492  							for j, name := range field.Names {
   493  								if j > 0 {
   494  									p(", ")
   495  								}
   496  								p(name.Name)
   497  							}
   498  							p(" ")
   499  						}
   500  						p(xlateType(field.Type))
   501  					}
   502  				}
   503  				p(")")
   504  
   505  				if fd.Type.Results != nil && len(fd.Type.Results.List) > 0 {
   506  					p(" ")
   507  					needsParens := len(fd.Type.Results.List) > 1 || (len(fd.Type.Results.List) == 1 && len(fd.Type.Results.List[0].Names) > 0)
   508  					if needsParens {
   509  						p("(")
   510  					}
   511  					for i, field := range fd.Type.Results.List {
   512  						if i > 0 {
   513  							p(", ")
   514  						}
   515  						if len(field.Names) > 0 {
   516  							for j, name := range field.Names {
   517  								if j > 0 {
   518  									p(", ")
   519  								}
   520  								p(name.Name)
   521  							}
   522  							p(" ")
   523  						}
   524  						p(xlateType(field.Type))
   525  					}
   526  					if needsParens {
   527  						p(")")
   528  					}
   529  				}
   530  				nl()
   531  			}
   532  		}
   533  	}
   534  	formatAndWrite(*goRoot+"/src/simd/simd.go", doWrites)
   535  	var extraMocks []TypeMethod
   536  	for x := range emulated {
   537  		extraMocks = append(extraMocks, x)
   538  	}
   539  	slices.SortFunc(extraMocks, func(a, b TypeMethod) int {
   540  		if c := strings.Compare(a.t, b.t); c != 0 {
   541  			return c
   542  		}
   543  		return strings.Compare(a.m, b.m)
   544  	})
   545  
   546  	for _, x := range extraMocks {
   547  		pw("%s contains %s.%s missing from intersected methods\n", emulatedFile, x.t, x.m)
   548  	}
   549  
   550  	elems := []string{"Int8", "Int16", "Int32", "Int64", "Uint8", "Uint16", "Uint32", "Uint64", "Float32", "Float64", "Mask8", "Mask16", "Mask32", "Mask64"}
   551  
   552  	for _, aaf := range archAndFiles {
   553  		arch := aaf.arch
   554  		doArchWrites := func(w io.Writer) {
   555  			p := func(s ...any) { fmt.Fprint(w, s...) }
   556  			pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
   557  			nl := func() { fmt.Fprintln(w) }
   558  
   559  			pf("// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.\n\n")
   560  			pf("//go:build goexperiment.simd && %s\n\n", arch)
   561  			pf("package bridge\n\n")
   562  			pf("import \"simd/archsimd\"\n\n")
   563  			pf("\n")
   564  			pf("// These types/methods/functions forward calls to their counterparts in simd/archsimd.\n")
   565  			pf("// Interposing this package allows a clean separation of \"simd\" from \"archsimd\" and\n")
   566  			pf("// also allows additional useful exported declarations that would weirdly pollute archsimd.\n")
   567  			pf("\n")
   568  
   569  			var typesForArch []string
   570  			for t := range knownReceivers {
   571  				if methodsByType[combine(arch, t)] != nil {
   572  					typesForArch = append(typesForArch, t)
   573  				}
   574  			}
   575  			sort.Strings(typesForArch)
   576  
   577  			toScalar := func(s string) string {
   578  				if strings.HasPrefix(s, "Mask") {
   579  					return "int" + s[4:]
   580  				}
   581  				return strings.ToLower(s)
   582  			}
   583  
   584  			for _, t := range typesForArch {
   585  				pf("type %s archsimd.%s\n", t, t)
   586  				if xAt := strings.Index(t, "x"); xAt != -1 && !strings.HasPrefix(t, "Mask") {
   587  					elem := t[:xAt]
   588  					scalar := toScalar(elem)
   589  					pf("func Load%s(s []%s) %s {\n\treturn %s(archsimd.Load%s(s))\n}\n", t, scalar, t, t, t)
   590  					pf("func Load%sPart(s []%s) (%s, int) {\n\tv, n := archsimd.Load%sPart(s)\n\treturn %s(v), n\n}\n", t, scalar, t, t, t)
   591  					pf("func Broadcast%s(x %s) %s {\n\treturn %s(archsimd.Broadcast%s(x))\n}\n", t, scalar, t, t, t)
   592  				}
   593  			}
   594  			nl()
   595  
   596  			typeStr := func(e ast.Expr) string {
   597  				var buf strings.Builder
   598  				format.Node(&buf, token.NewFileSet(), e)
   599  				return buf.String()
   600  			}
   601  
   602  			convertArg := func(name string, e ast.Expr) string {
   603  				switch t := e.(type) {
   604  				case *ast.Ident:
   605  					if _, ok := knownReceivers[t.Name]; ok {
   606  						return fmt.Sprintf("archsimd.%s(%s)", t.Name, name)
   607  					}
   608  				case *ast.StarExpr:
   609  					if ident, ok := t.X.(*ast.Ident); ok {
   610  						if _, ok := knownReceivers[ident.Name]; ok {
   611  							return fmt.Sprintf("(*archsimd.%s)(%s)", ident.Name, name)
   612  						}
   613  					}
   614  				}
   615  				return name
   616  			}
   617  
   618  			wrapResult := func(call string, e ast.Expr) string {
   619  				switch t := e.(type) {
   620  				case *ast.Ident:
   621  					if _, ok := knownReceivers[t.Name]; ok {
   622  						return fmt.Sprintf("%s(%s)", t.Name, call)
   623  					}
   624  				case *ast.StarExpr:
   625  					if ident, ok := t.X.(*ast.Ident); ok {
   626  						if _, ok := knownReceivers[ident.Name]; ok {
   627  							return fmt.Sprintf("(*%s)(%s)", ident.Name, call)
   628  						}
   629  					}
   630  				}
   631  				return call
   632  			}
   633  
   634  			for _, elem := range elems {
   635  				intersection := intersectionByElem[elem]
   636  				for _, m := range intersection {
   637  					for _, t := range typesForArch {
   638  						if map128[elem] != t && map256[elem] != t && map512[elem] != t {
   639  							continue
   640  						}
   641  						fd := methodsByType[combine(arch, t)][m]
   642  						if fd == nil {
   643  							continue
   644  						}
   645  						pf("func (x %s) %s(", t, fd.Name.Name)
   646  						var args []string
   647  						if fd.Type.Params != nil {
   648  							paramCount := 0
   649  							for _, field := range fd.Type.Params.List {
   650  								if len(field.Names) > 0 {
   651  									for _, name := range field.Names {
   652  										if paramCount > 0 {
   653  											p(", ")
   654  										}
   655  										pf("%s %s", name.Name, typeStr(field.Type))
   656  										args = append(args, convertArg(name.Name, field.Type))
   657  										paramCount++
   658  									}
   659  								} else {
   660  									if paramCount > 0 {
   661  										p(", ")
   662  									}
   663  									paramName := fmt.Sprintf("p%d", paramCount)
   664  									pf("%s %s", paramName, typeStr(field.Type))
   665  									args = append(args, convertArg(paramName, field.Type))
   666  									paramCount++
   667  								}
   668  							}
   669  						}
   670  						p(")")
   671  
   672  						var results []ast.Expr
   673  						if fd.Type.Results != nil {
   674  							p(" ")
   675  							needsParens := len(fd.Type.Results.List) > 1 || (len(fd.Type.Results.List) == 1 && len(fd.Type.Results.List[0].Names) > 0)
   676  							if needsParens {
   677  								p("(")
   678  							}
   679  							for i, field := range fd.Type.Results.List {
   680  								if i > 0 {
   681  									p(", ")
   682  								}
   683  								results = append(results, field.Type)
   684  								p(typeStr(field.Type))
   685  							}
   686  							if needsParens {
   687  								p(")")
   688  							}
   689  						}
   690  
   691  						p(" {\n\t")
   692  						if len(results) > 0 {
   693  							p("return ")
   694  						}
   695  
   696  						callStr := fmt.Sprintf("(archsimd.%s(x)).%s(%s)", t, fd.Name.Name, strings.Join(args, ", "))
   697  						if len(results) == 1 {
   698  							p(wrapResult(callStr, results[0]))
   699  						} else {
   700  							p(callStr)
   701  						}
   702  						p("\n}\n\n")
   703  					}
   704  				}
   705  			}
   706  		}
   707  		archDir := filepath.Join(*goRoot, "src", "simd", "internal", "bridge")
   708  		os.MkdirAll(archDir, 0755)
   709  		filename := filepath.Join(archDir, "decls_"+arch+".go")
   710  		formatAndWrite(filename, doArchWrites)
   711  
   712  		doToFromWrites := func(w io.Writer) {
   713  			pf := func(f string, s ...any) { fmt.Fprintf(w, f, s...) }
   714  
   715  			pf("// Code generated by 'go run -C $GOROOT/src/simd/archsimd/_gen/midway'; DO NOT EDIT.\n\n")
   716  			pf("//go:build goexperiment.simd && %s\n\n", arch)
   717  			pf("package simd\n\n")
   718  			pf("import (\n\t\"simd/archsimd\"\n\t\"simd/internal/bridge\"\n)\n\n")
   719  
   720  			for _, elem := range elems {
   721  				var archTypes []string
   722  				if methodsByType[combine(arch, map128[elem])] != nil {
   723  					archTypes = append(archTypes, map128[elem])
   724  				}
   725  				if methodsByType[combine(arch, map256[elem])] != nil {
   726  					archTypes = append(archTypes, map256[elem])
   727  				}
   728  				if methodsByType[combine(arch, map512[elem])] != nil {
   729  					archTypes = append(archTypes, map512[elem])
   730  				}
   731  
   732  				if len(archTypes) == 0 {
   733  					continue
   734  				}
   735  
   736  				pf("func (x %ss) ToArch() any\n\n", elem)
   737  
   738  				var intfOpts []string
   739  				for _, t := range archTypes {
   740  					intfOpts = append(intfOpts, "archsimd."+t)
   741  				}
   742  				pf("type archSimd%ss interface {\n\t%s\n}\n\n", elem, strings.Join(intfOpts, " | "))
   743  
   744  				pf("func %ssFromArch[T archSimd%ss](x T) %ss {\n", elem, elem, elem)
   745  				pf("\tswitch a := any(x).(type) {\n")
   746  				pf("\t// The return expression is written this way because the code will be rewritten\n")
   747  				pf("\t// with %ss replaced by one of the arch types, and without the any-assert\n", elem)
   748  				pf("\t// hack the rewritten code would not pass type checking.\n")
   749  				pf("\t// The backend of the compiler will eat this and turn it into no code at all,\n")
   750  				pf("\t// assuming it inlines.\n")
   751  
   752  				for _, t := range archTypes {
   753  					pf("\tcase archsimd.%s:\n", t)
   754  					pf("\t\tvar t bridge.%s = bridge.%s(a)\n", t, t)
   755  					pf("\t\treturn (any(t)).(%ss)\n", elem)
   756  				}
   757  				pf("\t}\n\tpanic(\"wrong type\")\n}\n\n")
   758  			}
   759  		}
   760  		toFromFilename := filepath.Join(*goRoot, "src", "simd", "tofrom_"+arch+".go")
   761  		formatAndWrite(toFromFilename, doToFromWrites)
   762  	}
   763  
   764  	if minorProblem {
   765  		pw("The logged warnings did not prevent generation of the midway API files, but the API is flawed (lacks emulations, documentation, etc).\n")
   766  	}
   767  }
   768  
   769  // numberLines takes a slice of bytes, and returns a string where each line
   770  // is numbered, starting from 1.
   771  func numberLines(data []byte) string {
   772  	var buf bytes.Buffer
   773  	r := bytes.NewReader(data)
   774  	s := bufio.NewScanner(r)
   775  	for i := 1; s.Scan(); i++ {
   776  		fmt.Fprintf(&buf, "%d: %s\n", i, s.Text())
   777  	}
   778  	return buf.String()
   779  }
   780  
   781  func formatAndWrite(filename string, doWrites func(w io.Writer)) {
   782  	if filename == "" {
   783  		return
   784  	}
   785  	f, err := os.Create(filename)
   786  	if err != nil {
   787  		log.Fatal(err)
   788  	}
   789  	defer f.Close()
   790  
   791  	out := new(bytes.Buffer)
   792  	doWrites(out)
   793  
   794  	b, err := format.Source(out.Bytes())
   795  	if err != nil {
   796  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
   797  		fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
   798  		fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
   799  		os.Exit(1)
   800  	} else {
   801  		f.Write(b)
   802  		f.Close()
   803  	}
   804  }
   805  

View as plain text