// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"bytes"
	"fmt"
	"log"
	"strings"
	"text/template"
)

var (
	ssaTemplates = template.Must(template.New("simdSSA").Parse(`
{{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.

package amd64

import (
	"cmd/compile/internal/ssa"
	"cmd/compile/internal/ssagen"
	"cmd/internal/obj"
	"cmd/internal/obj/x86"
)

func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
	var p *obj.Prog
	switch v.Op {{"{"}}{{end}}
{{define "case"}}	case {{.Cases}}:
		p = {{.Helper}}(s, v)
{{end}}
{{define "footer"}}	default:
		// Unknown reg shape
		return false
	}
{{end}}
{{define "zeroing"}}
	// Masked operations are always compiled with zeroing.
	switch v.Op {
	case {{.}}:
		x86.ParseSuffix(p, "Z")
	}
{{end}}
{{define "ending"}}
	return true
}
{{end}}`))
)

type tplSSAData struct {
	Cases  string
	Helper string
}

// writeSIMDSSA generates the SSA-to-prog lowering code that makes up
// simdssa.go and returns it as a buffer.
func writeSIMDSSA(ops []Operation) *bytes.Buffer {
	var ZeroingMask []string
	// regInfoKeys lists every register-constraint shape that amd64/ssa.go has
	// a helper for; it also fixes the emission order of the generated cases.
	regInfoKeys := []string{
		"v11", "v21", "v2k", "v2kv", "v2kk", "vkv", "v31", "v3kv",
		"v11Imm8", "vkvImm8", "v21Imm8", "v2kImm8", "v2kkImm8",
		"v31ResultInArg0", "v3kvResultInArg0",
		"vfpv", "vfpkv", "vgpvImm8", "vgpImm8", "v2kvImm8",
		"vkvload", "v21load", "v31loadResultInArg0", "v3kvloadResultInArg0",
		"v2kvload", "v2kload", "v11load", "v11loadImm8", "vkvloadImm8",
		"v21loadImm8", "v2kloadImm8", "v2kkloadImm8", "v2kvloadImm8",
		"v31ResultInArg0Imm8", "v31loadResultInArg0Imm8",
		"v21ResultInArg0", "v21ResultInArg0Imm8",
		"v31x0AtIn2ResultInArg0", "v2kvResultInArg0",
	}
	regInfoSet := map[string][]string{}
	for _, key := range regInfoKeys {
		regInfoSet[key] = []string{}
	}
	seen := map[string]struct{}{}
	allUnseen := make(map[string][]Operation)
	allUnseenCaseStr := make(map[string][]string)

	// classifyOp computes the register shape of op and files caseStr under it.
	// Shapes with no entry in regInfoSet are collected in allUnseen and
	// reported at the end.
	classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
		regShape, err := op.regShape(mem)
		if err != nil {
			return err
		}
		if regShape == "v01load" {
			regShape = "vload"
		}
		if shapeOut == OneVregOutAtIn {
			regShape += "ResultInArg0"
		}
		if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
			regShape += "Imm8"
		}
		regShape, err = rewriteVecAsScalarRegInfo(op, regShape)
		if err != nil {
			return err
		}
		if _, ok := regInfoSet[regShape]; !ok {
			allUnseen[regShape] = append(allUnseen[regShape], op)
			allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr)
		}
		regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
		if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) {
			regShapeMerging := regShape
			if shapeOut != OneVregOutAtIn {
				// We have to copy the slice here because the sort would
				// otherwise be visible through other aliases of the slice
				// when no reslicing happens.
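				// (The append and sortOperand below would mutate the backing
				// array shared with the caller's copy of op.In.)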
				newIn := make([]Operand, len(op.In), len(op.In)+1)
				copy(newIn, op.In)
				op.In = newIn
				op.In = append(op.In, op.Out[0])
				op.sortOperand()
				regShapeMerging, err = op.regShape(mem)
				if err != nil {
					return err
				}
				regShapeMerging += "ResultInArg0"
			}
			if _, ok := regInfoSet[regShapeMerging]; !ok {
				allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op)
				allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging")
			}
			regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging")
		}
		return nil
	}

	for _, op := range ops {
		shapeIn, shapeOut, maskType, _, gOp := op.shape()
		asm := machineOpName(maskType, gOp)
		if _, ok := seen[asm]; ok {
			continue
		}
		seen[asm] = struct{}{}
		caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
		isZeroMasking := false
		if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
			if gOp.Zeroing == nil || *gOp.Zeroing {
				ZeroingMask = append(ZeroingMask, caseStr)
				isZeroMasking = true
			}
		}
		if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil {
			panic(err)
		}
		if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
			// Make a full vec memory variant.
			op = rewriteLastVregToMem(op)
			// Ignore the error; it could be triggered by [checkVecAsScalar].
			// TODO: make [checkVecAsScalar] aware of mem ops.
			if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
				if *Verbose {
					log.Printf("Seen error: %v", err)
				}
			} else if isZeroMasking {
				ZeroingMask = append(ZeroingMask, caseStr+"load")
			}
		}
	}

	if len(allUnseen) != 0 {
		allKeys := make([]string, 0, len(allUnseen))
		for k := range allUnseen {
			allKeys = append(allKeys, k)
		}
		panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\nCases: %v\n", allUnseen, allKeys, allUnseenCaseStr))
	}

	buffer := new(bytes.Buffer)
	if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
		panic(fmt.Errorf("failed to execute header template: %w", err))
	}
	for _, regShape := range regInfoKeys {
		// Stable traversal of regInfoSet.
		cases := regInfoSet[regShape]
		if len(cases) == 0 {
			continue
		}
		data := tplSSAData{
			Cases:  strings.Join(cases, ",\n\t\t"),
			Helper: "simd" + capitalizeFirst(regShape),
		}
		if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
			panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
		}
	}
	if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
		panic(fmt.Errorf("failed to execute footer template: %w", err))
	}
	if len(ZeroingMask) != 0 {
		if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
			panic(fmt.Errorf("failed to execute zeroing template: %w", err))
		}
	}
	if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
		panic(fmt.Errorf("failed to execute ending template: %w", err))
	}
	return buffer
}
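
// For reference, the generated simdssa.go has roughly the shape sketched
// below, assembled from the templates above. The op names (VPADDD128,
// VPADDDMasked128) are hypothetical examples, not the output of any
// particular run; the helper name simdV21 follows the
// "simd" + capitalizeFirst(regShape) convention for the "v21" shape.
//
//	func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
//		var p *obj.Prog
//		switch v.Op {
//		case ssa.OpAMD64VPADDD128:
//			p = simdV21(s, v)
//		default:
//			// Unknown reg shape
//			return false
//		}
//		// Masked operations are always compiled with zeroing.
//		switch v.Op {
//		case ssa.OpAMD64VPADDDMasked128:
//			x86.ParseSuffix(p, "Z")
//		}
//		return true
//	}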