1
2
3
4
5 package cryptotest
6
7 import (
8 "bytes"
9 "crypto/cipher"
10 "crypto/subtle"
11 "fmt"
12 "strings"
13 "testing"
14 )
15
16
// bufLens holds the buffer lengths, in bytes, that the stream tests iterate
// over. It includes the empty buffer as well as small and large sizes.
var bufLens = []int{0, 1, 3, 4, 8, 10, 15, 16, 20, 32, 50, 4096, 5000}

// bufCap is the size, in bytes, of the buffers that tests allocate once and
// then slice down to the individual lengths listed in bufLens.
var bufCap = 10000
21
22
23
24
25
// MakeStream is a constructor for the cipher.Stream under test.
//
// TestStream calls it many times and compares the output of separate
// instances with each other, so every call is expected to return a stream
// that starts from the same initial state.
type MakeStream func() cipher.Stream
27
28
29
30 func TestStream(t *testing.T, ms MakeStream) {
31
32 t.Run("XORSemantics", func(t *testing.T) {
33 if strings.Contains(t.Name(), "TestCFBStream") {
34
35
36 t.Skip("CFB implements cipher.Stream but does not follow XOR semantics")
37 }
38
39
40 t.Run("Roundtrip", func(t *testing.T) {
41
42 for _, length := range bufLens {
43 t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
44 rng := newRandReader(t)
45
46 plaintext := make([]byte, length)
47 rng.Read(plaintext)
48
49 ciphertext := make([]byte, length)
50 decrypted := make([]byte, length)
51
52 ms().XORKeyStream(ciphertext, plaintext)
53 ms().XORKeyStream(decrypted, ciphertext)
54 if !bytes.Equal(decrypted, plaintext) {
55 t.Errorf("plaintext is different after an encrypt/decrypt cycle; got %s, want %s", truncateHex(decrypted), truncateHex(plaintext))
56 }
57 })
58 }
59 })
60
61
62
63 t.Run("DirectXOR", func(t *testing.T) {
64
65 for _, length := range bufLens {
66 t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
67 rng := newRandReader(t)
68
69 plaintext := make([]byte, length)
70 rng.Read(plaintext)
71
72
73 stream, directXOR := make([]byte, length), make([]byte, length)
74 ms().XORKeyStream(stream, stream)
75
76 subtle.XORBytes(directXOR, stream, plaintext)
77
78
79 ciphertext := make([]byte, length)
80 ms().XORKeyStream(ciphertext, plaintext)
81 if !bytes.Equal(ciphertext, directXOR) {
82 t.Errorf("xor semantics were not preserved; got %s, want %s", truncateHex(ciphertext), truncateHex(directXOR))
83 }
84 })
85 }
86 })
87 })
88
89 t.Run("EmptyInput", func(t *testing.T) {
90 rng := newRandReader(t)
91
92 src, dst := make([]byte, 100), make([]byte, 100)
93 rng.Read(dst)
94 before := bytes.Clone(dst)
95
96 ms().XORKeyStream(dst, src[:0])
97 if !bytes.Equal(dst, before) {
98 t.Errorf("XORKeyStream modified dst on empty input; got %s, want %s", truncateHex(dst), truncateHex(before))
99 }
100 })
101
102 t.Run("AlterInput", func(t *testing.T) {
103 rng := newRandReader(t)
104 src, dst, before := make([]byte, bufCap), make([]byte, bufCap), make([]byte, bufCap)
105 rng.Read(src)
106
107 for _, length := range bufLens {
108
109 t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
110 copy(before, src)
111
112 ms().XORKeyStream(dst[:length], src[:length])
113 if !bytes.Equal(src, before) {
114 t.Errorf("XORKeyStream modified src; got %s, want %s", truncateHex(src), truncateHex(before))
115 }
116 })
117 }
118 })
119
120 t.Run("Aliasing", func(t *testing.T) {
121 rng := newRandReader(t)
122
123 buff, expectedOutput := make([]byte, bufCap), make([]byte, bufCap)
124
125 for _, length := range bufLens {
126
127 rng.Read(buff)
128 ms().XORKeyStream(expectedOutput[:length], buff[:length])
129
130
131
132 ms().XORKeyStream(buff[:length], buff[:length])
133 if !bytes.Equal(buff[:length], expectedOutput[:length]) {
134 t.Errorf("block cipher produced different output when dst = src; got %x, want %x", buff[:length], expectedOutput[:length])
135 }
136 }
137 })
138
139 t.Run("OutOfBoundsWrite", func(t *testing.T) {
140 rng := newRandReader(t)
141
142 plaintext := make([]byte, bufCap)
143 rng.Read(plaintext)
144 ciphertext := make([]byte, bufCap)
145
146 for _, length := range bufLens {
147 copy(ciphertext, plaintext)
148
149 t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
150 mustPanic(t, "output smaller than input", func() { ms().XORKeyStream(ciphertext[:length], plaintext) })
151
152 if !bytes.Equal(ciphertext[length:], plaintext[length:]) {
153 t.Errorf("XORKeyStream did out of bounds write; got %s, want %s", truncateHex(ciphertext[length:]), truncateHex(plaintext[length:]))
154 }
155 })
156 }
157 })
158
159 t.Run("BufferOverlap", func(t *testing.T) {
160 rng := newRandReader(t)
161
162 buff := make([]byte, bufCap)
163 rng.Read(buff)
164
165 for _, length := range bufLens {
166 if length == 0 || length == 1 {
167 continue
168 }
169
170 t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
171
172 src := buff[:length]
173 dst := buff[1 : length+1]
174 mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
175
176
177 src = buff[:length]
178 dst = buff[length-1 : 2*length-1]
179 mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
180
181
182 src = buff[length-1 : 2*length-1]
183 dst = buff[:length]
184 mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
185 })
186 }
187 })
188
189 t.Run("KeepState", func(t *testing.T) {
190 rng := newRandReader(t)
191
192 plaintext := make([]byte, bufCap)
193 rng.Read(plaintext)
194 ciphertext := make([]byte, bufCap)
195
196
197 ms().XORKeyStream(ciphertext, plaintext)
198
199 for _, step := range bufLens {
200 if step == 0 {
201 continue
202 }
203 stepMsg := fmt.Sprintf("step %d: ", step)
204
205 dst := make([]byte, bufCap)
206
207
208 stream := ms()
209 i := 0
210 for i+step < len(plaintext) {
211 stream.XORKeyStream(dst[i:], plaintext[i:i+step])
212 i += step
213 }
214 stream.XORKeyStream(dst[i:], plaintext[i:])
215
216 if !bytes.Equal(dst, ciphertext) {
217 t.Errorf(stepMsg+"successive XORKeyStream calls returned a different result than a single one; got %s, want %s", truncateHex(dst), truncateHex(ciphertext))
218 }
219 }
220 })
221 }
222
223
224
225
226
227 func TestStreamFromBlock(t *testing.T, block cipher.Block, blockMode func(b cipher.Block, iv []byte) cipher.Stream) {
228
229 t.Run("WrongIVLen", func(t *testing.T) {
230 t.Skip("see Issue 68377")
231
232 rng := newRandReader(t)
233 iv := make([]byte, block.BlockSize()+1)
234 rng.Read(iv)
235 mustPanic(t, "IV length must equal block size", func() { blockMode(block, iv) })
236 })
237
238 t.Run("BlockModeStream", func(t *testing.T) {
239 rng := newRandReader(t)
240 iv := make([]byte, block.BlockSize())
241 rng.Read(iv)
242
243 TestStream(t, func() cipher.Stream { return blockMode(block, iv) })
244 })
245 }
246
// truncateHex returns b as lowercase hexadecimal. Only the first 50 bytes are
// rendered; when b is longer than that, "..." is appended to signal the cut.
func truncateHex(b []byte) string {
	const limit = 50

	ellipsis := ""
	if len(b) > limit {
		b, ellipsis = b[:limit], "..."
	}
	return fmt.Sprintf("%x", b) + ellipsis
}
255
View as plain text