package main

import (
	"bytes"
	"fmt"
	"log"
	"strings"
	"text/template"
)

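// ssaTemplates holds the text/template fragments ("header", "case", "footer",
// "zeroing", "ending") that are stitched together to produce the generated
// amd64 ssaGenSIMDValue dispatch function.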
var (
	ssaTemplates = template.Must(template.New("simdSSA").Parse(`{{define "header"}}` + generatedHeader + `
package amd64

import (
	"cmd/compile/internal/ssa"
	"cmd/compile/internal/ssagen"
	"cmd/internal/obj"
	"cmd/internal/obj/x86"
)

func ssaGenSIMDValue(s *ssagen.State, v *ssa.Value) bool {
	var p *obj.Prog
	switch v.Op {{"{"}}{{end}}
{{define "case"}}
	case {{.Cases}}:
		p = {{.Helper}}(s, v)
{{end}}
{{define "footer"}}
	default:
		// Unknown reg shape
		return false
	}
{{end}}
{{define "zeroing"}}
	// Masked operations are always compiled with zeroing.
	switch v.Op {
	case {{.}}:
		x86.ParseSuffix(p, "Z")
	}
{{end}}
{{define "ending"}}
	return true
}
{{end}}`))
)

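// tplSSAData is the payload for the "case" template: the comma-separated SSA
// op cases and the name of the amd64 helper that assembles them.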
type tplSSAData struct {
	Cases  string
	Helper string
}

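// writeSIMDSSA generates the amd64 ssaGenSIMDValue function that maps each
// SIMD SSA op to the register-shape helper that emits its Prog, plus the
// zeroing-suffix fixup for masked ops. It returns the generated source in a
// buffer.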
func writeSIMDSSA(ops []Operation) *bytes.Buffer {
	var ZeroingMask []string
	// regInfoKeys lists every register shape that has a simd* helper;
	// classifying an op to a shape outside this set panics below.
	regInfoKeys := []string{
		"v11",
		"v21",
		"v2k",
		"v2kv",
		"v2kk",
		"vkv",
		"v31",
		"v3kv",
		"v11Imm8",
		"vkvImm8",
		"v21Imm8",
		"v2kImm8",
		"v2kkImm8",
		"v31ResultInArg0",
		"v3kvResultInArg0",
		"vfpv",
		"vfpkv",
		"vgpvImm8",
		"vgpImm8",
		"v2kvImm8",
		"vkvload",
		"v21load",
		"v31loadResultInArg0",
		"v3kvloadResultInArg0",
		"v2kvload",
		"v2kload",
		"v11load",
		"v11loadImm8",
		"vkvloadImm8",
		"v21loadImm8",
		"v2kloadImm8",
		"v2kkloadImm8",
		"v2kvloadImm8",
		"v31ResultInArg0Imm8",
		"v31loadResultInArg0Imm8",
		"v21ResultInArg0",
		"v21ResultInArg0Imm8",
		"v31x0AtIn2ResultInArg0",
		"v2kvResultInArg0",
	}
	regInfoSet := map[string][]string{}
	for _, key := range regInfoKeys {
		regInfoSet[key] = []string{}
	}

	seen := map[string]struct{}{}
	allUnseen := make(map[string][]Operation)
	allUnseenCaseStr := make(map[string][]string)
	classifyOp := func(op Operation, maskType maskShape, shapeIn inShape, shapeOut outShape, caseStr string, mem memShape) error {
		regShape, err := op.regShape(mem)
		if err != nil {
			return err
		}
		if regShape == "v01load" {
			regShape = "vload"
		}
		if shapeOut == OneVregOutAtIn {
			regShape += "ResultInArg0"
		}
		if shapeIn == OneImmIn || shapeIn == OneKmaskImmIn {
			regShape += "Imm8"
		}
		regShape, err = rewriteVecAsScalarRegInfo(op, regShape)
		if err != nil {
			return err
		}
		if _, ok := regInfoSet[regShape]; !ok {
			allUnseen[regShape] = append(allUnseen[regShape], op)
			allUnseenCaseStr[regShape] = append(allUnseenCaseStr[regShape], caseStr)
		}
		regInfoSet[regShape] = append(regInfoSet[regShape], caseStr)
		if mem == NoMem && op.hasMaskedMerging(maskType, shapeOut) {
			regShapeMerging := regShape
			if shapeOut != OneVregOutAtIn {
				// The merging form takes the destination as an extra input:
				// copy op.In before appending op.Out[0] so the operand slice
				// of the original op is not mutated, then recompute the shape.
				newIn := make([]Operand, len(op.In), len(op.In)+1)
				copy(newIn, op.In)
				op.In = newIn
				op.In = append(op.In, op.Out[0])
				op.sortOperand()
				regShapeMerging, err = op.regShape(mem)
				if err != nil {
					return err
				}
				regShapeMerging += "ResultInArg0"
			}
			if _, ok := regInfoSet[regShapeMerging]; !ok {
				allUnseen[regShapeMerging] = append(allUnseen[regShapeMerging], op)
				allUnseenCaseStr[regShapeMerging] = append(allUnseenCaseStr[regShapeMerging], caseStr+"Merging")
			}
			regInfoSet[regShapeMerging] = append(regInfoSet[regShapeMerging], caseStr+"Merging")
		}
		return nil
	}
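	// Deduplicate ops by machine op name, classify each one into a register
	// shape, and record which masked cases need the zeroing ("Z") suffix.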
	for _, op := range ops {
		shapeIn, shapeOut, maskType, _, gOp := op.shape()
		asm := machineOpName(maskType, gOp)
		if _, ok := seen[asm]; ok {
			continue
		}
		seen[asm] = struct{}{}
		caseStr := fmt.Sprintf("ssa.OpAMD64%s", asm)
		isZeroMasking := false
		if shapeIn == OneKmaskIn || shapeIn == OneKmaskImmIn {
			if gOp.Zeroing == nil || *gOp.Zeroing {
				ZeroingMask = append(ZeroingMask, caseStr)
				isZeroMasking = true
			}
		}
		if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr, NoMem); err != nil {
			panic(err)
		}
		if op.MemFeatures != nil && *op.MemFeatures == "vbcst" {
			// The op also has a broadcast-from-memory form: rewrite its last
			// vector register operand into a memory operand.
			op = rewriteLastVregToMem(op)
			// Classify the load variant as well; a load variant without a
			// matching register shape is only reported in verbose mode and
			// otherwise skipped.
			if err := classifyOp(op, maskType, shapeIn, shapeOut, caseStr+"load", VregMemIn); err != nil {
				if *Verbose {
					log.Printf("Seen error: %v", err)
				}
			} else if isZeroMasking {
				ZeroingMask = append(ZeroingMask, caseStr+"load")
			}
		}
	}
	if len(allUnseen) != 0 {
		allKeys := make([]string, 0)
		for k := range allUnseen {
			allKeys = append(allKeys, k)
		}
		panic(fmt.Errorf("unsupported register constraint for prog, please update gen_simdssa.go and amd64/ssa.go: %+v\nAll keys: %v\nCases: %v\n", allUnseen, allKeys, allUnseenCaseStr))
	}

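	// Emit the generated file: header, one case block per register shape,
	// the default footer, the zeroing-suffix switch, and the final return.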
	buffer := new(bytes.Buffer)

	if err := ssaTemplates.ExecuteTemplate(buffer, "header", nil); err != nil {
		panic(fmt.Errorf("failed to execute header template: %w", err))
	}

	for _, regShape := range regInfoKeys {
		cases := regInfoSet[regShape]
		if len(cases) == 0 {
			continue
		}
		data := tplSSAData{
			Cases:  strings.Join(cases, ",\n\t\t"),
			Helper: "simd" + capitalizeFirst(regShape),
		}
		if err := ssaTemplates.ExecuteTemplate(buffer, "case", data); err != nil {
			panic(fmt.Errorf("failed to execute case template for %s: %w", regShape, err))
		}
	}

	if err := ssaTemplates.ExecuteTemplate(buffer, "footer", nil); err != nil {
		panic(fmt.Errorf("failed to execute footer template: %w", err))
	}

	if len(ZeroingMask) != 0 {
		if err := ssaTemplates.ExecuteTemplate(buffer, "zeroing", strings.Join(ZeroingMask, ",\n\t\t")); err != nil {
			panic(fmt.Errorf("failed to execute zeroing template: %w", err))
		}
	}

	if err := ssaTemplates.ExecuteTemplate(buffer, "ending", nil); err != nil {
		panic(fmt.Errorf("failed to execute ending template: %w", err))
	}

	return buffer
}