1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "fmt"
10 "log"
11 "strings"
12 "text/template"
13 )
14
var (
	// ssaTemplates is the parsed set of named text/template fragments that are
	// concatenated to build the generated ssaGenSIMDValue function:
	//
	//   - "header":  generated-code banner, package clause, imports, and the
	//     opening of the function up to and including "switch v.Op {".
	//   - "case":    one switch arm; expects a value with .Cases (the case
	//     expression list) and .Helper (the function called as Helper(s, v)).
	//   - "footer":  the default arm (returns false) and the closing brace of
	//     the switch.
	//   - "zeroing": a second switch on v.Op that applies the "Z" suffix via
	//     x86.ParseSuffix; expects the case expression list as its data (".").
	//   - "ending":  the final "return true" and the closing brace of the
	//     function.
	//
	// template.Must panics at package initialization if the template text
	// fails to parse.
	ssaTemplates = template.Must(template.New("simdSSA").Parse(`
{{define "header"}}// Code generated by x/arch/internal/simdgen using 'go run . -xedPath $XED_PATH -o godefs -goroot $GOROOT go.yaml types.yaml categories.yaml'; DO NOT EDIT.

package amd64

import (
"cmd/compile/internal/ssa"
"cmd/compile/internal/ssagen"
"cmd/internal/obj"
"cmd/internal/obj/x86"
)

func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
var p *obj.Prog
switch v.Op {{"{"}}{{end}}
{{define "case"}}
case {{.Cases}}:
p = {{.Helper}}(s, v)
{{end}}
{{define "footer"}}
default:
// Unknown reg shape
return false
}
{{end}}
{{define "zeroing"}}
// Masked operation are always compiled with zeroing.
switch v.Op {
case {{.}}:
x86.ParseSuffix(p, "Z")
}
{{end}}
{{define "ending"}}
return true
}
{{end}}`))
)
53
// tplSSAData is the data handed to the "case" template: it describes one
// arm of the generated switch over v.Op.
type tplSSAData struct {
	// Cases is the already-joined list of case expressions that follows the
	// "case" keyword (writeSIMDSSA joins the op constants with ",\n\t\t").
	Cases string
	// Helper is the name of the function the arm calls as Helper(s, v);
	// writeSIMDSSA builds it as "simd" + capitalizeFirst(regShape).
	Helper string
}
58
59
60
61 func writeSIMDSSA(ops []Operation) *bytes.Buffer {
62 var ZeroingMask []string
63 regInfoKeys := []string{
64 "v11",
65 "v21",
66 "v2k",
67 "v2kv",
68 "v2kk",
69 "vkv",
70 "v31",
71 "v3kv",
72 "v11Imm8",
73 "vkvImm8",
74 "v21Imm8",
75 "v2kImm8",
76 "v2kkImm8",
77 "v31ResultInArg0",
78 "v3kvResultInArg0",
79 "vfpv",
80 "vfpkv",
81 "vgpvImm8",
82 "vgpImm8",
83 "v2kvImm8",
84 "vkvload",
85 "v21load",
86 "v31loadResultInArg0",
87 "v3kvloadResultInArg0",
88 "v2kvload",
89 "v2kload",
90 "v11load",
91 "v11loadImm8",
92 "vkvloadImm8",
93 "v21loadImm8",
94 "v2kloadImm8",
95 "v2kkloadImm8",
96 "v2kvloadImm8",
97 "v31ResultInArg0Imm8",
98 "v31loadResultInArg0Imm8",
99 "v21ResultInArg0",
100 "v21ResultInArg0Imm8",
101 "v31x0AtIn2ResultInArg0",
102 "v2kvResultInArg0",
103 }
104 regInfoSet := map[string][]string{}
105 for _, key := range regInfoKeys {
106 regInfoSet[key] = []string{}
107 }
108
109 seen := map[string]struct{}{}
110 allUnseen := make(map[string][]Operation)
111 allUnseenCaseStr := make(map[string][]string)
112 classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
113 regShape, err := op.regShape(mem)
114 if err != nil {
115 return err
116 }
117 if regShape == "v01load" {
118 regShape = "vload"
119 }
120 if shapeOut == OneVregOutAtIn {
121 regShape += "ResultInArg0"
122 }
123 if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
124 regShape += "Imm8"
125 }
126 regShape, err = rewriteVecAsScalarRegInfo(op, regShape)
127 if err != nil {
128 return err
129 }
130 if _, ok := regInfoSet[regShape]; !ok {
131 allUnseen[regShape] = append(allUnseen[regShape], op)
132 allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr)
133 }
134 regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
135 if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) {
136 regShapeMerging := regShape
137 if shapeOut != OneVregOutAtIn {
138
139
140 newIn := make([]Operand, len(op.In), len(op.In)+1)
141 copy(newIn, op.In)
142 op.In = newIn
143 op.In = append(op.In, op.Out[0])
144 op.sortOperand()
145 regShapeMerging, err = op.regShape(mem)
146 regShapeMerging += "ResultInArg0"
147 }
148 if err != nil {
149 return err
150 }
151 if _, ok := regInfoSet[regShapeMerging]; !ok {
152 allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op)
153 allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging")
154 }
155 regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging")
156 }
157 return nil
158 }
159 for _, op := range ops {
160 shapeIn, shapeOut, maskType, _, gOp := op.shape()
161 asm := machineOpName(maskType, gOp)
162 if _, ok := seen[asm]; ok {
163 continue
164 }
165 seen[asm] = struct{}{}
166 caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
167 isZeroMasking := false
168 if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
169 if gOp.Zeroing == nil || *gOp.Zeroing {
170 ZeroingMask = append(ZeroingMask, caseStr)
171 isZeroMasking = true
172 }
173 }
174 if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil {
175 panic(err)
176 }
177 if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
178
179 op = rewriteLastVregToMem(op)
180
181
182
183 if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
184 if *Verbose {
185 log.Printf("Seen error: %e", err)
186 }
187 } else if isZeroMasking {
188 ZeroingMask = append(ZeroingMask, caseStr+"load")
189 }
190 }
191 }
192 if len(allUnseen) != 0 {
193 allKeys := make([]string, 0)
194 for k := range allUnseen {
195 allKeys = append(allKeys, k)
196 }
197 panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\n, cases: %v\n", allUnseen, allKeys, allUnseenCaseStr))
198 }
199
200 buffer := new(bytes.Buffer)
201
202 if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
203 panic(fmt.Errorf("failed to execute header template: %w", err))
204 }
205
206 for _, regShape := range regInfoKeys {
207
208 cases := regInfoSet[regShape]
209 if len(cases) == 0 {
210 continue
211 }
212 data := tplSSAData{
213 Cases: strings.Join(cases, ",\n\t\t"),
214 Helper: "simd" + capitalizeFirst(regShape),
215 }
216 if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
217 panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
218 }
219 }
220
221 if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
222 panic(fmt.Errorf("failed to execute footer template: %w", err))
223 }
224
225 if len(ZeroingMask) != 0 {
226 if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
227 panic(fmt.Errorf("failed to execute footer template: %w", err))
228 }
229 }
230
231 if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
232 panic(fmt.Errorf("failed to execute footer template: %w", err))
233 }
234
235 return buffer
236 }
237
View as plain text