1
2
3
4
5 package main
6
7 import (
8 "fmt"
9 "sync"
10
11 . "github.com/mmcloughlin/avo/build"
12 . "github.com/mmcloughlin/avo/operand"
13 . "github.com/mmcloughlin/avo/reg"
14 )
15
16
17
18 func main() {
19 Package("crypto/internal/fips140/aes")
20 ConstraintExpr("!purego")
21
22 ctrBlocks(1)
23 ctrBlocks(2)
24 ctrBlocks(4)
25 ctrBlocks(8)
26
27 Generate()
28 }
29
// ctrBlocks emits the assembly function ctrBlocks<numBlocks>Asm. The function
// computes numBlocks 16-byte blocks of AES-CTR keystream, XORs them with src
// and writes the result to dst.
//
// Arguments are loaded with Param/Load; the matching Go declaration is not in
// this file.
//   - nr: number of AES rounds. The dispatch below is written for 10, 12 and
//     14. The value is not validated: anything below 12 runs 10 rounds,
//     anything above 12 runs 14.
//   - xk: pointer to the expanded key, read as consecutive 16-byte round keys.
//     The initial whitening key is at offset 0, followed by one key per round.
//   - dst, src: numBlocks*16 bytes are read from src and written to dst using
//     unaligned MOVUPS loads/stores.
//   - ivlo, ivhi: low and high 64-bit halves of a 128-bit counter. Block i
//     encrypts counter (ivhi:ivlo)+i, with the carry propagated into ivhi.
//     Each counter is byte-reversed, so it is laid out big-endian.
//
// Statement order is significant. Every call appends an instruction, or
// allocates a virtual register, in the function being generated.
func ctrBlocks(numBlocks int) {
	Implement(fmt.Sprintf("ctrBlocks%dAsm", numBlocks))

	rounds := Load(Param("nr"), GP64())
	xk := Load(Param("xk"), GP64())
	dst := Load(Param("dst"), GP64())
	src := Load(Param("src"), GP64())
	ivlo := Load(Param("ivlo"), GP64())
	ivhi := Load(Param("ivhi"), GP64())

	// PSHUFB control mask that reverses all 16 bytes of a register. It turns
	// the little-endian (ivhi:ivlo) value into a big-endian counter block.
	bswap := XMM()
	MOVOU(bswapMask(), bswap)

	blocks := make([]VecVirtual, numBlocks)

	// Lay out the counter blocks.
	//
	// The 8-block case is specialized. Counter 0 is built from ivlo/ivhi as
	// usual. If ivlo+7 does not carry out of 64 bits, counters 1..7 are made
	// by repeatedly adding 1 to the low qword lane inside XMM registers. This
	// avoids issuing a MOVQ/PINSRQ pair from general-purpose registers for
	// every block. If the low half would wrap, a scalar path with a full
	// 128-bit ADDQ/ADCQ increment is used instead.
	if numBlocks == 8 {
		for i := range blocks {
			blocks[i] = XMM()
		}

		// base holds the not-yet-byte-swapped counter advanced on the fast
		// path. addVec holds the {low: 1, high: 0} increment.
		base := XMM()
		tmp := GP64()
		addVec := XMM()

		// Counter 0 has ivlo in lane 0 and ivhi in lane 1. Keep an unswapped
		// copy in base before byte-reversing blocks[0].
		MOVQ(ivlo, blocks[0])
		PINSRQ(Imm(1), ivhi, blocks[0])
		MOVAPS(blocks[0], base)
		PSHUFB(bswap, blocks[0])

		// CF is set iff ivlo + (numBlocks-1) overflows 64 bits, i.e. some
		// counter in this batch would need a carry into ivhi.
		MOVQ(ivlo, tmp)
		ADDQ(Imm(uint64(numBlocks-1)), tmp)
		slowLabel := fmt.Sprintf("ctr%d_slow", numBlocks)
		doneLabel := fmt.Sprintf("ctr%d_done", numBlocks)
		JC(LabelRef(slowLabel))

		// Fast path: build addVec = {1, 0}. PADDQ adds each 64-bit lane
		// independently, with no carry between lanes. That is safe here
		// because the check above ruled out a carry out of the low half.
		XORQ(tmp, tmp)
		INCQ(tmp)
		PXOR(addVec, addVec)
		PINSRQ(Imm(0), tmp, addVec)

		for i := 1; i < numBlocks; i++ {
			PADDQ(addVec, base)
			MOVAPS(base, blocks[i])
		}
		JMP(LabelRef(doneLabel))

		// Slow path: the same 128-bit increment as the generic loop below.
		// It starts at counter 1 because counter 0 is already built. The
		// last increment is skipped because no further block needs it.
		Label(slowLabel)
		ADDQ(Imm(1), ivlo)
		ADCQ(Imm(0), ivhi)
		for i := 1; i < numBlocks; i++ {
			MOVQ(ivlo, blocks[i])
			PINSRQ(Imm(1), ivhi, blocks[i])
			if i < numBlocks-1 {
				ADDQ(Imm(1), ivlo)
				ADCQ(Imm(0), ivhi)
			}
		}

		Label(doneLabel)

		// Both paths leave counters 1..numBlocks-1 unswapped, so byte-reverse
		// them once after the paths join.
		for i := 1; i < numBlocks; i++ {
			PSHUFB(bswap, blocks[i])
		}
	} else {
		// Generic path: build each counter from ivlo/ivhi and byte-reverse it.
		// Then advance the 128-bit counter, with ADCQ carrying into ivhi. The
		// final increment is skipped because no further block needs it.
		for i := 0; i < numBlocks; i++ {
			x := XMM()
			blocks[i] = x

			MOVQ(ivlo, x)
			PINSRQ(Imm(1), ivhi, x)
			PSHUFB(bswap, x)
			if i < numBlocks-1 {
				ADDQ(Imm(1), ivlo)
				ADCQ(Imm(0), ivhi)
			}
		}
	}

	// Initial key addition with the round key at xk, then advance xk to the
	// first round key.
	aesRoundStart(blocks, Mem{Base: xk})
	ADDQ(Imm(16), xk)

	// Dispatch on nr, compared unsigned against 12:
	//   - nr == 12 jumps to enc192 (12 rounds).
	//   - nr below 12 jumps to enc128 (10 rounds).
	//   - Anything larger falls through (14 rounds).
	SUBQ(Imm(12), rounds)
	JE(LabelRef("enc192"))
	JB(LabelRef("enc128"))

	// Two extra rounds, reached only on the 14-round (AES-256) path.
	aesRound(blocks, Mem{Base: xk})
	aesRound(blocks, Mem{Base: xk}.Offset(16))
	ADDQ(Imm(32), xk)

	// Two rounds shared by the 12- and 14-round paths.
	Label("enc192")
	aesRound(blocks, Mem{Base: xk})
	aesRound(blocks, Mem{Base: xk}.Offset(16))
	ADDQ(Imm(32), xk)

	// The final 10 rounds are common to every path. The last round uses
	// AESENCLAST.
	Label("enc128")
	for i := 0; i < 9; i++ {
		aesRound(blocks, Mem{Base: xk}.Offset(16*i))
	}
	aesRoundLast(blocks, Mem{Base: xk}.Offset(16*9))

	// XOR the keystream with src and store the result to dst.
	for i, b := range blocks {
		x := XMM()

		MOVUPS(Mem{Base: src}.Offset(16*i), x)
		PXOR(b, x)
		MOVUPS(x, Mem{Base: dst}.Offset(16*i))
	}

	RET()
}
157
158 func aesRoundStart(blocks []VecVirtual, k Mem) {
159 x := XMM()
160 MOVUPS(k, x)
161 for _, b := range blocks {
162 PXOR(x, b)
163 }
164 }
165
166 func aesRound(blocks []VecVirtual, k Mem) {
167 x := XMM()
168 MOVUPS(k, x)
169 for _, b := range blocks {
170 AESENC(x, b)
171 }
172 }
173
174 func aesRoundLast(blocks []VecVirtual, k Mem) {
175 x := XMM()
176 MOVUPS(k, x)
177 for _, b := range blocks {
178 AESENCLAST(x, b)
179 }
180 }
181
182 var bswapMask = sync.OnceValue(func() Mem {
183 bswapMask := GLOBL("bswapMask", NOPTR|RODATA)
184 DATA(0x00, U64(0x08090a0b0c0d0e0f))
185 DATA(0x08, U64(0x0001020304050607))
186 return bswapMask
187 })
188
View as plain text