Source file src/simd/_gen/simdgen/gen_simdssa.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"log"
    11  	"strings"
    12  	"text/template"
    13  )
    14  
var (
	// ssaTemplates holds the named text templates that writeSIMDSSA executes,
	// in order, to assemble the generated ssaGenSIMDValue function:
	//
	//   - "header":  file preamble, imports, and the opening of the switch on v.Op.
	//   - "case":    one switch case; expects a value with Cases and Helper fields.
	//   - "footer":  the default case and the closing brace of the switch.
	//   - "zeroing": a second switch that applies the "Z" suffix to the listed
	//     ops; expects the comma-separated case list as its data.
	//   - "ending":  the final "return true" and the function's closing brace.
	//
	// The template text is emitted verbatim into generated code, so any edit to
	// it changes the generator's output.
	ssaTemplates = template.Must(template.New("simdSSA").Parse(`
{{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.

package amd64

import (
	"cmd/compile/internal/ssa"
	"cmd/compile/internal/ssagen"
	"cmd/internal/obj"
	"cmd/internal/obj/x86"
)

func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
	var p *obj.Prog
	switch v.Op {{"{"}}{{end}}
{{define "case"}}
	case {{.Cases}}:
		p = {{.Helper}}(s, v)
{{end}}
{{define "footer"}}
	default:
		// Unknown reg shape
		return false
	}
{{end}}
{{define "zeroing"}}
	// Masked operation are always compiled with zeroing.
	switch v.Op {
	case {{.}}:
		x86.ParseSuffix(p, "Z")
	}
{{end}}
{{define "ending"}}
	return true
}
{{end}}`))
)
    53  
    54  type tplSSAData struct {
    55  	Cases  string
    56  	Helper string
    57  }
    58  
    59  // writeSIMDSSA generates the ssa to prog lowering codes and writes it to simdssa.go
    60  // within the specified directory.
    61  func writeSIMDSSA(ops []Operation) *bytes.Buffer {
    62  	var ZeroingMask []string
    63  	regInfoKeys := []string{
    64  		"v11",
    65  		"v21",
    66  		"v2k",
    67  		"v2kv",
    68  		"v2kk",
    69  		"vkv",
    70  		"v31",
    71  		"v3kv",
    72  		"v11Imm8",
    73  		"vkvImm8",
    74  		"v21Imm8",
    75  		"v2kImm8",
    76  		"v2kkImm8",
    77  		"v31ResultInArg0",
    78  		"v3kvResultInArg0",
    79  		"vfpv",
    80  		"vfpkv",
    81  		"vgpvImm8",
    82  		"vgpImm8",
    83  		"v2kvImm8",
    84  		"vkvload",
    85  		"v21load",
    86  		"v31loadResultInArg0",
    87  		"v3kvloadResultInArg0",
    88  		"v2kvload",
    89  		"v2kload",
    90  		"v11load",
    91  		"v11loadImm8",
    92  		"vkvloadImm8",
    93  		"v21loadImm8",
    94  		"v2kloadImm8",
    95  		"v2kkloadImm8",
    96  		"v2kvloadImm8",
    97  		"v31ResultInArg0Imm8",
    98  		"v31loadResultInArg0Imm8",
    99  		"v21ResultInArg0",
   100  		"v21ResultInArg0Imm8",
   101  		"v31x0AtIn2ResultInArg0",
   102  		"v2kvResultInArg0",
   103  	}
   104  	regInfoSet := map[string][]string{}
   105  	for _, key := range regInfoKeys {
   106  		regInfoSet[key] = []string{}
   107  	}
   108  
   109  	seen := map[string]struct{}{}
   110  	allUnseen := make(map[string][]Operation)
   111  	allUnseenCaseStr := make(map[string][]string)
   112  	classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
   113  		regShape, err := op.regShape(mem)
   114  		if err != nil {
   115  			return err
   116  		}
   117  		if regShape == "v01load" {
   118  			regShape = "vload"
   119  		}
   120  		if shapeOut == OneVregOutAtIn {
   121  			regShape += "ResultInArg0"
   122  		}
   123  		if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
   124  			regShape += "Imm8"
   125  		}
   126  		regShape, err = rewriteVecAsScalarRegInfo(op, regShape)
   127  		if err != nil {
   128  			return err
   129  		}
   130  		if _, ok := regInfoSet[regShape]; !ok {
   131  			allUnseen[regShape] = append(allUnseen[regShape], op)
   132  			allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr)
   133  		}
   134  		regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
   135  		if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) {
   136  			regShapeMerging := regShape
   137  			if shapeOut != OneVregOutAtIn {
   138  				// We have to copy the slice here becasue the sort will be visible from other
   139  				// aliases when no reslicing is happening.
   140  				newIn := make([]Operand, len(op.In), len(op.In)+1)
   141  				copy(newIn, op.In)
   142  				op.In = newIn
   143  				op.In = append(op.In, op.Out[0])
   144  				op.sortOperand()
   145  				regShapeMerging, err = op.regShape(mem)
   146  				regShapeMerging += "ResultInArg0"
   147  			}
   148  			if err != nil {
   149  				return err
   150  			}
   151  			if _, ok := regInfoSet[regShapeMerging]; !ok {
   152  				allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op)
   153  				allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging")
   154  			}
   155  			regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging")
   156  		}
   157  		return nil
   158  	}
   159  	for _, op := range ops {
   160  		shapeIn, shapeOut, maskType, _, gOp := op.shape()
   161  		asm := machineOpName(maskType, gOp)
   162  		if _, ok := seen[asm]; ok {
   163  			continue
   164  		}
   165  		seen[asm] = struct{}{}
   166  		caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
   167  		isZeroMasking := false
   168  		if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
   169  			if gOp.Zeroing == nil || *gOp.Zeroing {
   170  				ZeroingMask = append(ZeroingMask, caseStr)
   171  				isZeroMasking = true
   172  			}
   173  		}
   174  		if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil {
   175  			panic(err)
   176  		}
   177  		if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
   178  			// Make a full vec memory variant
   179  			op = rewriteLastVregToMem(op)
   180  			// Ignore the error
   181  			// an error could be triggered by [checkVecAsScalar].
   182  			// TODO: make [checkVecAsScalar] aware of mem ops.
   183  			if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
   184  				if *Verbose {
   185  					log.Printf("Seen error: %e", err)
   186  				}
   187  			} else if isZeroMasking {
   188  				ZeroingMask = append(ZeroingMask, caseStr+"load")
   189  			}
   190  		}
   191  	}
   192  	if len(allUnseen) != 0 {
   193  		allKeys := make([]string, 0)
   194  		for k := range allUnseen {
   195  			allKeys = append(allKeys, k)
   196  		}
   197  		panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\n, cases: %v\n", allUnseen, allKeys, allUnseenCaseStr))
   198  	}
   199  
   200  	buffer := new(bytes.Buffer)
   201  
   202  	if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
   203  		panic(fmt.Errorf("failed to execute header template: %w", err))
   204  	}
   205  
   206  	for _, regShape := range regInfoKeys {
   207  		// Stable traversal of regInfoSet
   208  		cases := regInfoSet[regShape]
   209  		if len(cases) == 0 {
   210  			continue
   211  		}
   212  		data := tplSSAData{
   213  			Cases:  strings.Join(cases, ",\n\t\t"),
   214  			Helper: "simd" + capitalizeFirst(regShape),
   215  		}
   216  		if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
   217  			panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
   218  		}
   219  	}
   220  
   221  	if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
   222  		panic(fmt.Errorf("failed to execute footer template: %w", err))
   223  	}
   224  
   225  	if len(ZeroingMask) != 0 {
   226  		if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
   227  			panic(fmt.Errorf("failed to execute footer template: %w", err))
   228  		}
   229  	}
   230  
   231  	if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
   232  		panic(fmt.Errorf("failed to execute footer template: %w", err))
   233  	}
   234  
   235  	return buffer
   236  }
   237  

View as plain text