Source file src/simd/archsimd/_gen/simdgen/gen_simdssa.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"log"
    11  	"strings"
    12  	"text/template"
    13  )
    14  
var (
	// ssaTemplates contains the text/template fragments that assemble the
	// generated ssaGenSIMDValue function in package amd64:
	//   - "header":  the generated-file banner, package clause, imports, and
	//                the opening of the switch on v.Op.
	//   - "case":    one case arm mapping a group of ops (.Cases) to the
	//                shared lowering helper (.Helper); fed with tplSSAData.
	//   - "footer":  the default arm, returning false for unknown reg shapes.
	//   - "zeroing": a second switch over v.Op appending the "Z" suffix to
	//                ops compiled with zeroing-mask semantics; fed with the
	//                comma-joined case list as a plain string.
	//   - "ending":  the closing "return true" and final brace.
	ssaTemplates = template.Must(template.New("simdSSA").Parse(`{{define "header"}}` + generatedHeader + `
package amd64

import (
	"cmd/compile/internal/ssa"
	"cmd/compile/internal/ssagen"
	"cmd/internal/obj"
	"cmd/internal/obj/x86"
)

func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
	var p *obj.Prog
	switch v.Op {{"{"}}{{end}}
{{define "case"}}
	case {{.Cases}}:
		p = {{.Helper}}(s, v)
{{end}}
{{define "footer"}}
	default:
		// Unknown reg shape
		return false
	}
{{end}}
{{define "zeroing"}}
	// Masked operation are always compiled with zeroing.
	switch v.Op {
	case {{.}}:
		x86.ParseSuffix(p, "Z")
	}
{{end}}
{{define "ending"}}
	return true
}
{{end}}`))
)
    51  
// tplSSAData is the data for the "case" template: one case arm of the
// generated switch over v.Op.
type tplSSAData struct {
	Cases  string // comma/newline-joined "ssa.OpAMD64..." case expressions
	Helper string // lowering helper name, "simd" + capitalized reg shape (e.g. "simdV11")
}
    56  
    57  // writeSIMDSSA generates the ssa to prog lowering codes and writes it to simdssa.go
    58  // within the specified directory.
    59  func writeSIMDSSA(ops []Operation) *bytes.Buffer {
    60  	var ZeroingMask []string
    61  	regInfoKeys := []string{
    62  		"v11",
    63  		"v21",
    64  		"v2k",
    65  		"v2kv",
    66  		"v2kk",
    67  		"vkv",
    68  		"v31",
    69  		"v3kv",
    70  		"v11Imm8",
    71  		"vkvImm8",
    72  		"v21Imm8",
    73  		"v2kImm8",
    74  		"v2kkImm8",
    75  		"v31ResultInArg0",
    76  		"v3kvResultInArg0",
    77  		"vfpv",
    78  		"vfpkv",
    79  		"vgpvImm8",
    80  		"vgpImm8",
    81  		"v2kvImm8",
    82  		"vkvload",
    83  		"v21load",
    84  		"v31loadResultInArg0",
    85  		"v3kvloadResultInArg0",
    86  		"v2kvload",
    87  		"v2kload",
    88  		"v11load",
    89  		"v11loadImm8",
    90  		"vkvloadImm8",
    91  		"v21loadImm8",
    92  		"v2kloadImm8",
    93  		"v2kkloadImm8",
    94  		"v2kvloadImm8",
    95  		"v31ResultInArg0Imm8",
    96  		"v31loadResultInArg0Imm8",
    97  		"v21ResultInArg0",
    98  		"v21ResultInArg0Imm8",
    99  		"v31x0AtIn2ResultInArg0",
   100  		"v2kvResultInArg0",
   101  	}
   102  	regInfoSet := map[string][]string{}
   103  	for _, key := range regInfoKeys {
   104  		regInfoSet[key] = []string{}
   105  	}
   106  
   107  	seen := map[string]struct{}{}
   108  	allUnseen := make(map[string][]Operation)
   109  	allUnseenCaseStr := make(map[string][]string)
   110  	classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
   111  		regShape, err := op.regShape(mem)
   112  		if err != nil {
   113  			return err
   114  		}
   115  		if regShape == "v01load" {
   116  			regShape = "vload"
   117  		}
   118  		if shapeOut == OneVregOutAtIn {
   119  			regShape += "ResultInArg0"
   120  		}
   121  		if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
   122  			regShape += "Imm8"
   123  		}
   124  		regShape, err = rewriteVecAsScalarRegInfo(op, regShape)
   125  		if err != nil {
   126  			return err
   127  		}
   128  		if _, ok := regInfoSet[regShape]; !ok {
   129  			allUnseen[regShape] = append(allUnseen[regShape], op)
   130  			allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr)
   131  		}
   132  		regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
   133  		if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) {
   134  			regShapeMerging := regShape
   135  			if shapeOut != OneVregOutAtIn {
   136  				// We have to copy the slice here becasue the sort will be visible from other
   137  				// aliases when no reslicing is happening.
   138  				newIn := make([]Operand, len(op.In), len(op.In)+1)
   139  				copy(newIn, op.In)
   140  				op.In = newIn
   141  				op.In = append(op.In, op.Out[0])
   142  				op.sortOperand()
   143  				regShapeMerging, err = op.regShape(mem)
   144  				regShapeMerging += "ResultInArg0"
   145  			}
   146  			if err != nil {
   147  				return err
   148  			}
   149  			if _, ok := regInfoSet[regShapeMerging]; !ok {
   150  				allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op)
   151  				allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging")
   152  			}
   153  			regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging")
   154  		}
   155  		return nil
   156  	}
   157  	for _, op := range ops {
   158  		shapeIn, shapeOut, maskType, _, gOp := op.shape()
   159  		asm := machineOpName(maskType, gOp)
   160  		if _, ok := seen[asm]; ok {
   161  			continue
   162  		}
   163  		seen[asm] = struct{}{}
   164  		caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
   165  		isZeroMasking := false
   166  		if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
   167  			if gOp.Zeroing == nil || *gOp.Zeroing {
   168  				ZeroingMask = append(ZeroingMask, caseStr)
   169  				isZeroMasking = true
   170  			}
   171  		}
   172  		if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil {
   173  			panic(err)
   174  		}
   175  		if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
   176  			// Make a full vec memory variant
   177  			op = rewriteLastVregToMem(op)
   178  			// Ignore the error
   179  			// an error could be triggered by [checkVecAsScalar].
   180  			// TODO: make [checkVecAsScalar] aware of mem ops.
   181  			if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
   182  				if *Verbose {
   183  					log.Printf("Seen error: %e", err)
   184  				}
   185  			} else if isZeroMasking {
   186  				ZeroingMask = append(ZeroingMask, caseStr+"load")
   187  			}
   188  		}
   189  	}
   190  	if len(allUnseen) != 0 {
   191  		allKeys := make([]string, 0)
   192  		for k := range allUnseen {
   193  			allKeys = append(allKeys, k)
   194  		}
   195  		panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\n, cases: %v\n", allUnseen, allKeys, allUnseenCaseStr))
   196  	}
   197  
   198  	buffer := new(bytes.Buffer)
   199  
   200  	if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
   201  		panic(fmt.Errorf("failed to execute header template: %w", err))
   202  	}
   203  
   204  	for _, regShape := range regInfoKeys {
   205  		// Stable traversal of regInfoSet
   206  		cases := regInfoSet[regShape]
   207  		if len(cases) == 0 {
   208  			continue
   209  		}
   210  		data := tplSSAData{
   211  			Cases:  strings.Join(cases, ",\n\t\t"),
   212  			Helper: "simd" + capitalizeFirst(regShape),
   213  		}
   214  		if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
   215  			panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
   216  		}
   217  	}
   218  
   219  	if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
   220  		panic(fmt.Errorf("failed to execute footer template: %w", err))
   221  	}
   222  
   223  	if len(ZeroingMask) != 0 {
   224  		if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
   225  			panic(fmt.Errorf("failed to execute footer template: %w", err))
   226  		}
   227  	}
   228  
   229  	if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
   230  		panic(fmt.Errorf("failed to execute footer template: %w", err))
   231  	}
   232  
   233  	return buffer
   234  }
   235  

View as plain text