Source file src/simd/_gen/simdgen/gen_simdMachineOps.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"log"
    11  	"sort"
    12  	"strings"
    13  )
    14  
// simdMachineOpsTmpl is the template for the generated simdAMD64Ops function,
// which returns the AMD64 SIMD machine-op table as a []opData.
//
// It is executed by writeSIMDMachineOps with a value holding six slices and
// emits one opData entry per element of each:
//   - OpsData: ops with no aux field.
//   - OpsDataImm: ops with aux "UInt8" (an immediate operand).
//   - OpsDataLoad: ops with aux "SymOff" and symEffect "Read".
//   - OpsDataImmLoad: ops with aux "SymValAndOff" and symEffect "Read".
//   - OpsDataMerging / OpsDataImmMerging: the name gets a "Merging" suffix and
//     the entry always has commutative: false and resultInArg0: true; the
//     Comm and ResultInArg0 fields of the data are not used for these.
//
// The regInfo parameter list of simdAMD64Ops must stay in sync with
// regInfoSet in writeSIMDMachineOps: a register constraint produced by the
// generator that is not in regInfoSet makes generation panic, and a name in
// regInfoSet that is missing here would not compile in the generated file.
const simdMachineOpsTmpl = `
package main

func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw,
	wkwload, v21load, v31load, v11load, w21load, w31load, w2kload, w2kwload, w11load, w3kwload, w2kkload, v31x0AtIn2 regInfo) []opData {
	return []opData{
{{- range .OpsData }}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataImm }}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
{{- end }}
{{- range .OpsDataLoad}}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataImmLoad}}
		{name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymValAndOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
{{- end}}
{{- range .OpsDataMerging }}
		{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
{{- range .OpsDataImmMerging }}
		{name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: false, typ: "{{.Type}}", resultInArg0: true},
{{- end }}
	}
}
`
    42  
    43  // writeSIMDMachineOps generates the machine ops and writes it to simdAMD64ops.go
    44  // within the specified directory.
    45  func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
    46  	t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
    47  	buffer := new(bytes.Buffer)
    48  	buffer.WriteString(generatedHeader)
    49  
    50  	type opData struct {
    51  		OpName       string
    52  		Asm          string
    53  		OpInLen      int
    54  		RegInfo      string
    55  		Comm         bool
    56  		Type         string
    57  		ResultInArg0 bool
    58  	}
    59  	type machineOpsData struct {
    60  		OpsData           []opData
    61  		OpsDataImm        []opData
    62  		OpsDataLoad       []opData
    63  		OpsDataImmLoad    []opData
    64  		OpsDataMerging    []opData
    65  		OpsDataImmMerging []opData
    66  	}
    67  
    68  	regInfoSet := map[string]bool{
    69  		"v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
    70  		"w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true,
    71  		"wkwload": true, "v21load": true, "v31load": true, "v11load": true, "w21load": true, "w31load": true, "w2kload": true, "w2kwload": true, "w11load": true,
    72  		"w3kwload": true, "w2kkload": true, "v31x0AtIn2": true}
    73  	opsData := make([]opData, 0)
    74  	opsDataImm := make([]opData, 0)
    75  	opsDataLoad := make([]opData, 0)
    76  	opsDataImmLoad := make([]opData, 0)
    77  	opsDataMerging := make([]opData, 0)
    78  	opsDataImmMerging := make([]opData, 0)
    79  
    80  	// Determine the "best" version of an instruction to use
    81  	best := make(map[string]Operation)
    82  	var mOpOrder []string
    83  	countOverrides := func(s []Operand) int {
    84  		a := 0
    85  		for _, o := range s {
    86  			if o.OverwriteBase != nil {
    87  				a++
    88  			}
    89  		}
    90  		return a
    91  	}
    92  	for _, op := range ops {
    93  		_, _, maskType, _, gOp := op.shape()
    94  		asm := machineOpName(maskType, gOp)
    95  		other, ok := best[asm]
    96  		if !ok {
    97  			best[asm] = op
    98  			mOpOrder = append(mOpOrder, asm)
    99  			continue
   100  		}
   101  		// see if "op" is better than "other"
   102  		if countOverrides(op.In)+countOverrides(op.Out) < countOverrides(other.In)+countOverrides(other.Out) {
   103  			best[asm] = op
   104  		}
   105  	}
   106  
   107  	regInfoErrs := make([]error, 0)
   108  	regInfoMissing := make(map[string]bool, 0)
   109  	for _, asm := range mOpOrder {
   110  		op := best[asm]
   111  		shapeIn, shapeOut, maskType, _, gOp := op.shape()
   112  
   113  		// TODO: all our masked operations are now zeroing, we need to generate machine ops with merging masks, maybe copy
   114  		// one here with a name suffix "Merging". The rewrite rules will need them.
   115  		makeRegInfo := func(op Operation, mem memShape) (string, error) {
   116  			regInfo, err := op.regShape(mem)
   117  			if err != nil {
   118  				panic(err)
   119  			}
   120  			regInfo, err = rewriteVecAsScalarRegInfo(op, regInfo)
   121  			if err != nil {
   122  				if mem == NoMem || mem == InvalidMem {
   123  					panic(err)
   124  				}
   125  				return "", err
   126  			}
   127  			if regInfo == "v01load" {
   128  				regInfo = "vload"
   129  			}
   130  			// Makes AVX512 operations use upper registers
   131  			if strings.Contains(op.CPUFeature, "AVX512") {
   132  				regInfo = strings.ReplaceAll(regInfo, "v", "w")
   133  			}
   134  			if _, ok := regInfoSet[regInfo]; !ok {
   135  				regInfoErrs = append(regInfoErrs, fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s.  Op is %s", regInfo, op))
   136  				regInfoMissing[regInfo] = true
   137  			}
   138  			return regInfo, nil
   139  		}
   140  		regInfo, err := makeRegInfo(op, NoMem)
   141  		if err != nil {
   142  			panic(err)
   143  		}
   144  		var outType string
   145  		if shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
   146  			// If class overwrite is happening, that's not really a mask but a vreg.
   147  			outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
   148  		} else if shapeOut == OneGregOut {
   149  			outType = gOp.GoType() // this is a straight Go type, not a VecNNN type
   150  		} else if shapeOut == OneKmaskOut {
   151  			outType = "Mask"
   152  		} else {
   153  			panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
   154  		}
   155  		resultInArg0 := false
   156  		if shapeOut == OneVregOutAtIn {
   157  			resultInArg0 = true
   158  		}
   159  		var memOpData *opData
   160  		regInfoMerging := regInfo
   161  		hasMerging := false
   162  		if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
   163  			// Right now we only have vbcst case
   164  			// Make a full vec memory variant.
   165  			opMem := rewriteLastVregToMem(op)
   166  			regInfo, err := makeRegInfo(opMem, VregMemIn)
   167  			if err != nil {
   168  				// Just skip it if it's non nill.
   169  				// an error could be triggered by [checkVecAsScalar].
   170  				// TODO: make [checkVecAsScalar] aware of mem ops.
   171  				if *Verbose {
   172  					log.Printf("Seen error: %e", err)
   173  				}
   174  			} else {
   175  				memOpData = &opData{asm + "load", gOp.Asm, len(gOp.In) + 1, regInfo, false, outType, resultInArg0}
   176  			}
   177  		}
   178  		hasMerging = gOp.hasMaskedMerging(maskType, shapeOut)
   179  		if hasMerging && !resultInArg0 {
   180  			// We have to copy the slice here becasue the sort will be visible from other
   181  			// aliases when no reslicing is happening.
   182  			newIn := make([]Operand, len(op.In), len(op.In)+1)
   183  			copy(newIn, op.In)
   184  			op.In = newIn
   185  			op.In = append(op.In, op.Out[0])
   186  			op.sortOperand()
   187  			regInfoMerging, err = makeRegInfo(op, NoMem)
   188  			if err != nil {
   189  				panic(err)
   190  			}
   191  		}
   192  
   193  		if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
   194  			opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
   195  			if memOpData != nil {
   196  				if *op.MemFeatures != "vbcst" {
   197  					panic("simdgen only knows vbcst for mem ops for now")
   198  				}
   199  				opsDataImmLoad = append(opsDataImmLoad, *memOpData)
   200  			}
   201  			if hasMerging {
   202  				mergingLen := len(gOp.In)
   203  				if !resultInArg0 {
   204  					mergingLen++
   205  				}
   206  				opsDataImmMerging = append(opsDataImmMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
   207  			}
   208  		} else {
   209  			opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
   210  			if memOpData != nil {
   211  				if *op.MemFeatures != "vbcst" {
   212  					panic("simdgen only knows vbcst for mem ops for now")
   213  				}
   214  				opsDataLoad = append(opsDataLoad, *memOpData)
   215  			}
   216  			if hasMerging {
   217  				mergingLen := len(gOp.In)
   218  				if !resultInArg0 {
   219  					mergingLen++
   220  				}
   221  				opsDataMerging = append(opsDataMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
   222  			}
   223  		}
   224  	}
   225  	if len(regInfoErrs) != 0 {
   226  		for _, e := range regInfoErrs {
   227  			log.Printf("Errors: %e\n", e)
   228  		}
   229  		panic(fmt.Errorf("these regInfo unseen: %v", regInfoMissing))
   230  	}
   231  	sort.Slice(opsData, func(i, j int) bool {
   232  		return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
   233  	})
   234  	sort.Slice(opsDataImm, func(i, j int) bool {
   235  		return compareNatural(opsDataImm[i].OpName, opsDataImm[j].OpName) < 0
   236  	})
   237  	sort.Slice(opsDataLoad, func(i, j int) bool {
   238  		return compareNatural(opsDataLoad[i].OpName, opsDataLoad[j].OpName) < 0
   239  	})
   240  	sort.Slice(opsDataImmLoad, func(i, j int) bool {
   241  		return compareNatural(opsDataImmLoad[i].OpName, opsDataImmLoad[j].OpName) < 0
   242  	})
   243  	sort.Slice(opsDataMerging, func(i, j int) bool {
   244  		return compareNatural(opsDataMerging[i].OpName, opsDataMerging[j].OpName) < 0
   245  	})
   246  	sort.Slice(opsDataImmMerging, func(i, j int) bool {
   247  		return compareNatural(opsDataImmMerging[i].OpName, opsDataImmMerging[j].OpName) < 0
   248  	})
   249  	err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, opsDataImmLoad,
   250  		opsDataMerging, opsDataImmMerging})
   251  	if err != nil {
   252  		panic(fmt.Errorf("failed to execute template: %w", err))
   253  	}
   254  
   255  	return buffer
   256  }
   257  

View as plain text