Source file
src/simd/_gen/simdgen/gen_simdMachineOps.go
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "fmt"
10 "log"
11 "sort"
12 "strings"
13 )
14
// simdMachineOpsTmpl is the text template that writeSIMDMachineOps executes.
// It renders a Go source file declaring simdAMD64Ops, a function that takes
// one regInfo parameter per supported register-constraint name and returns a
// []opData literal.
//
// The template ranges over six lists of the template data, in this order:
//   - OpsData:           plain entries.
//   - OpsDataImm:        entries that additionally carry aux: "UInt8".
//   - OpsDataLoad:       entries with aux: "SymOff" and symEffect: "Read".
//   - OpsDataImmLoad:    entries with aux: "SymValAndOff" and symEffect: "Read".
//   - OpsDataMerging:    entries named "<OpName>Merging"; commutative is
//     hard-coded to false and resultInArg0 to true, so the Comm and
//     ResultInArg0 fields of the data are not consulted.
//   - OpsDataImmMerging: like OpsDataMerging, plus aux: "UInt8".
//
// NOTE(review): the parameter list of simdAMD64Ops must name every regInfo
// string that the generator can emit in a "reg:" field, otherwise the
// generated file references an undefined identifier.
15 const simdMachineOpsTmpl = `
16 package main
17
18 func simdAMD64Ops(v11, v21, v2k, vkv, v2kv, v2kk, v31, v3kv, vgpv, vgp, vfpv, vfpkv, w11, w21, w2k, wkw, w2kw, w2kk, w31, w3kw, wgpw, wgp, wfpw, wfpkw,
19 wkwload, v21load, v31load, v11load, w21load, w31load, w2kload, w2kwload, w11load, w3kwload, w2kkload, v31x0AtIn2 regInfo) []opData {
20 return []opData{
21 {{- range .OpsData }}
22 {name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
23 {{- end }}
24 {{- range .OpsDataImm }}
25 {name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: {{.Comm}}, typ: "{{.Type}}", resultInArg0: {{.ResultInArg0}}},
26 {{- end }}
27 {{- range .OpsDataLoad}}
28 {name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
29 {{- end}}
30 {{- range .OpsDataImmLoad}}
31 {name: "{{.OpName}}", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: {{.Comm}}, typ: "{{.Type}}", aux: "SymValAndOff", symEffect: "Read", resultInArg0: {{.ResultInArg0}}},
32 {{- end}}
33 {{- range .OpsDataMerging }}
34 {name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", commutative: false, typ: "{{.Type}}", resultInArg0: true},
35 {{- end }}
36 {{- range .OpsDataImmMerging }}
37 {name: "{{.OpName}}Merging", argLength: {{.OpInLen}}, reg: {{.RegInfo}}, asm: "{{.Asm}}", aux: "UInt8", commutative: false, typ: "{{.Type}}", resultInArg0: true},
38 {{- end }}
39 }
40 }
41 `
42
43
44
45 func writeSIMDMachineOps(ops []Operation) *bytes.Buffer {
46 t := templateOf(simdMachineOpsTmpl, "simdAMD64Ops")
47 buffer := new(bytes.Buffer)
48 buffer.WriteString(generatedHeader)
49
50 type opData struct {
51 OpName string
52 Asm string
53 OpInLen int
54 RegInfo string
55 Comm bool
56 Type string
57 ResultInArg0 bool
58 }
59 type machineOpsData struct {
60 OpsData []opData
61 OpsDataImm []opData
62 OpsDataLoad []opData
63 OpsDataImmLoad []opData
64 OpsDataMerging []opData
65 OpsDataImmMerging []opData
66 }
67
68 regInfoSet := map[string]bool{
69 "v11": true, "v21": true, "v2k": true, "v2kv": true, "v2kk": true, "vkv": true, "v31": true, "v3kv": true, "vgpv": true, "vgp": true, "vfpv": true, "vfpkv": true,
70 "w11": true, "w21": true, "w2k": true, "w2kw": true, "w2kk": true, "wkw": true, "w31": true, "w3kw": true, "wgpw": true, "wgp": true, "wfpw": true, "wfpkw": true,
71 "wkwload": true, "v21load": true, "v31load": true, "v11load": true, "w21load": true, "w31load": true, "w2kload": true, "w2kwload": true, "w11load": true,
72 "w3kwload": true, "w2kkload": true, "v31x0AtIn2": true}
73 opsData := make([]opData, 0)
74 opsDataImm := make([]opData, 0)
75 opsDataLoad := make([]opData, 0)
76 opsDataImmLoad := make([]opData, 0)
77 opsDataMerging := make([]opData, 0)
78 opsDataImmMerging := make([]opData, 0)
79
80
81 best := make(map[string]Operation)
82 var mOpOrder []string
83 countOverrides := func(s []Operand) int {
84 a := 0
85 for _, o := range s {
86 if o.OverwriteBase != nil {
87 a++
88 }
89 }
90 return a
91 }
92 for _, op := range ops {
93 _, _, maskType, _, gOp := op.shape()
94 asm := machineOpName(maskType, gOp)
95 other, ok := best[asm]
96 if !ok {
97 best[asm] = op
98 mOpOrder = append(mOpOrder, asm)
99 continue
100 }
101
102 if countOverrides(op.In)+countOverrides(op.Out) < countOverrides(other.In)+countOverrides(other.Out) {
103 best[asm] = op
104 }
105 }
106
107 regInfoErrs := make([]error, 0)
108 regInfoMissing := make(map[string]bool, 0)
109 for _, asm := range mOpOrder {
110 op := best[asm]
111 shapeIn, shapeOut, maskType, _, gOp := op.shape()
112
113
114
115 makeRegInfo := func(op Operation, mem memShape) (string, error) {
116 regInfo, err := op.regShape(mem)
117 if err != nil {
118 panic(err)
119 }
120 regInfo, err = rewriteVecAsScalarRegInfo(op, regInfo)
121 if err != nil {
122 if mem == NoMem || mem == InvalidMem {
123 panic(err)
124 }
125 return "", err
126 }
127 if regInfo == "v01load" {
128 regInfo = "vload"
129 }
130
131 if strings.Contains(op.CPUFeature, "AVX512") {
132 regInfo = strings.ReplaceAll(regInfo, "v", "w")
133 }
134 if _, ok := regInfoSet[regInfo]; !ok {
135 regInfoErrs = append(regInfoErrs, fmt.Errorf("unsupported register constraint, please update the template and AMD64Ops.go: %s. Op is %s", regInfo, op))
136 regInfoMissing[regInfo] = true
137 }
138 return regInfo, nil
139 }
140 regInfo, err := makeRegInfo(op, NoMem)
141 if err != nil {
142 panic(err)
143 }
144 var outType string
145 if shapeOut == OneVregOut || shapeOut == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
146
147 outType = fmt.Sprintf("Vec%d", *gOp.Out[0].Bits)
148 } else if shapeOut == OneGregOut {
149 outType = gOp.GoType()
150 } else if shapeOut == OneKmaskOut {
151 outType = "Mask"
152 } else {
153 panic(fmt.Errorf("simdgen does not recognize this output shape: %d", shapeOut))
154 }
155 resultInArg0 := false
156 if shapeOut == OneVregOutAtIn {
157 resultInArg0 = true
158 }
159 var memOpData *opData
160 regInfoMerging := regInfo
161 hasMerging := false
162 if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
163
164
165 opMem := rewriteLastVregToMem(op)
166 regInfo, err := makeRegInfo(opMem, VregMemIn)
167 if err != nil {
168
169
170
171 if *Verbose {
172 log.Printf("Seen error: %e", err)
173 }
174 } else {
175 memOpData = &opData{asm + "load", gOp.Asm, len(gOp.In) + 1, regInfo, false, outType, resultInArg0}
176 }
177 }
178 hasMerging = gOp.hasMaskedMerging(maskType, shapeOut)
179 if hasMerging && !resultInArg0 {
180
181
182 newIn := make([]Operand, len(op.In), len(op.In)+1)
183 copy(newIn, op.In)
184 op.In = newIn
185 op.In = append(op.In, op.Out[0])
186 op.sortOperand()
187 regInfoMerging, err = makeRegInfo(op, NoMem)
188 if err != nil {
189 panic(err)
190 }
191 }
192
193 if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
194 opsDataImm = append(opsDataImm, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
195 if memOpData != nil {
196 if *op.MemFeatures != "vbcst" {
197 panic("simdgen only knows vbcst for mem ops for now")
198 }
199 opsDataImmLoad = append(opsDataImmLoad, *memOpData)
200 }
201 if hasMerging {
202 mergingLen := len(gOp.In)
203 if !resultInArg0 {
204 mergingLen++
205 }
206 opsDataImmMerging = append(opsDataImmMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
207 }
208 } else {
209 opsData = append(opsData, opData{asm, gOp.Asm, len(gOp.In), regInfo, gOp.Commutative, outType, resultInArg0})
210 if memOpData != nil {
211 if *op.MemFeatures != "vbcst" {
212 panic("simdgen only knows vbcst for mem ops for now")
213 }
214 opsDataLoad = append(opsDataLoad, *memOpData)
215 }
216 if hasMerging {
217 mergingLen := len(gOp.In)
218 if !resultInArg0 {
219 mergingLen++
220 }
221 opsDataMerging = append(opsDataMerging, opData{asm, gOp.Asm, mergingLen, regInfoMerging, gOp.Commutative, outType, resultInArg0})
222 }
223 }
224 }
225 if len(regInfoErrs) != 0 {
226 for _, e := range regInfoErrs {
227 log.Printf("Errors: %e\n", e)
228 }
229 panic(fmt.Errorf("these regInfo unseen: %v", regInfoMissing))
230 }
231 sort.Slice(opsData, func(i, j int) bool {
232 return compareNatural(opsData[i].OpName, opsData[j].OpName) < 0
233 })
234 sort.Slice(opsDataImm, func(i, j int) bool {
235 return compareNatural(opsDataImm[i].OpName, opsDataImm[j].OpName) < 0
236 })
237 sort.Slice(opsDataLoad, func(i, j int) bool {
238 return compareNatural(opsDataLoad[i].OpName, opsDataLoad[j].OpName) < 0
239 })
240 sort.Slice(opsDataImmLoad, func(i, j int) bool {
241 return compareNatural(opsDataImmLoad[i].OpName, opsDataImmLoad[j].OpName) < 0
242 })
243 sort.Slice(opsDataMerging, func(i, j int) bool {
244 return compareNatural(opsDataMerging[i].OpName, opsDataMerging[j].OpName) < 0
245 })
246 sort.Slice(opsDataImmMerging, func(i, j int) bool {
247 return compareNatural(opsDataImmMerging[i].OpName, opsDataImmMerging[j].OpName) < 0
248 })
249 err := t.Execute(buffer, machineOpsData{opsData, opsDataImm, opsDataLoad, opsDataImmLoad,
250 opsDataMerging, opsDataImmMerging})
251 if err != nil {
252 panic(fmt.Errorf("failed to execute template: %w", err))
253 }
254
255 return buffer
256 }
257
View as plain text