1
2
3
4
5 package mlkem
6
7 import (
8 "bytes"
9 "crypto/internal/fips140/mlkem"
10 "crypto/internal/fips140/sha3"
11 "crypto/rand"
12 "encoding/hex"
13 "flag"
14 "testing"
15 )
16
// encapsulationKey is the method set shared by the encapsulation key types of
// every parameter set exercised by the generic test helpers in this file.
type encapsulationKey interface {
	// Encapsulate produces a fresh ciphertext together with the shared key
	// it carries.
	Encapsulate() (ciphertext, sharedKey []byte)
	// Bytes returns the encoded form of the key.
	Bytes() []byte
}
21
22 type decapsulationKey[E encapsulationKey] interface {
23 Bytes() []byte
24 Decapsulate([]byte) ([]byte, error)
25 EncapsulationKey() E
26 }
27
28 func TestRoundTrip(t *testing.T) {
29 t.Run("768", func(t *testing.T) {
30 testRoundTrip(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
31 })
32 t.Run("1024", func(t *testing.T) {
33 testRoundTrip(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
34 })
35 }
36
37 func testRoundTrip[E encapsulationKey, D decapsulationKey[E]](
38 t *testing.T, generateKey func() (D, error),
39 newEncapsulationKey func([]byte) (E, error),
40 newDecapsulationKey func([]byte) (D, error)) {
41 dk, err := generateKey()
42 if err != nil {
43 t.Fatal(err)
44 }
45 ek := dk.EncapsulationKey()
46 c, Ke := ek.Encapsulate()
47 Kd, err := dk.Decapsulate(c)
48 if err != nil {
49 t.Fatal(err)
50 }
51 if !bytes.Equal(Ke, Kd) {
52 t.Fail()
53 }
54
55 ek1, err := newEncapsulationKey(ek.Bytes())
56 if err != nil {
57 t.Fatal(err)
58 }
59 if !bytes.Equal(ek.Bytes(), ek1.Bytes()) {
60 t.Fail()
61 }
62 dk1, err := newDecapsulationKey(dk.Bytes())
63 if err != nil {
64 t.Fatal(err)
65 }
66 if !bytes.Equal(dk.Bytes(), dk1.Bytes()) {
67 t.Fail()
68 }
69 c1, Ke1 := ek1.Encapsulate()
70 Kd1, err := dk1.Decapsulate(c1)
71 if err != nil {
72 t.Fatal(err)
73 }
74 if !bytes.Equal(Ke1, Kd1) {
75 t.Fail()
76 }
77
78 dk2, err := generateKey()
79 if err != nil {
80 t.Fatal(err)
81 }
82 if bytes.Equal(dk.EncapsulationKey().Bytes(), dk2.EncapsulationKey().Bytes()) {
83 t.Fail()
84 }
85 if bytes.Equal(dk.Bytes(), dk2.Bytes()) {
86 t.Fail()
87 }
88
89 c2, Ke2 := dk.EncapsulationKey().Encapsulate()
90 if bytes.Equal(c, c2) {
91 t.Fail()
92 }
93 if bytes.Equal(Ke, Ke2) {
94 t.Fail()
95 }
96 }
97
98 func TestBadLengths(t *testing.T) {
99 t.Run("768", func(t *testing.T) {
100 testBadLengths(t, GenerateKey768, NewEncapsulationKey768, NewDecapsulationKey768)
101 })
102 t.Run("1024", func(t *testing.T) {
103 testBadLengths(t, GenerateKey1024, NewEncapsulationKey1024, NewDecapsulationKey1024)
104 })
105 }
106
107 func testBadLengths[E encapsulationKey, D decapsulationKey[E]](
108 t *testing.T, generateKey func() (D, error),
109 newEncapsulationKey func([]byte) (E, error),
110 newDecapsulationKey func([]byte) (D, error)) {
111 dk, err := generateKey()
112 dkBytes := dk.Bytes()
113 if err != nil {
114 t.Fatal(err)
115 }
116 ek := dk.EncapsulationKey()
117 ekBytes := dk.EncapsulationKey().Bytes()
118 c, _ := ek.Encapsulate()
119
120 for i := 0; i < len(dkBytes)-1; i++ {
121 if _, err := newDecapsulationKey(dkBytes[:i]); err == nil {
122 t.Errorf("expected error for dk length %d", i)
123 }
124 }
125 dkLong := dkBytes
126 for i := 0; i < 100; i++ {
127 dkLong = append(dkLong, 0)
128 if _, err := newDecapsulationKey(dkLong); err == nil {
129 t.Errorf("expected error for dk length %d", len(dkLong))
130 }
131 }
132
133 for i := 0; i < len(ekBytes)-1; i++ {
134 if _, err := newEncapsulationKey(ekBytes[:i]); err == nil {
135 t.Errorf("expected error for ek length %d", i)
136 }
137 }
138 ekLong := ekBytes
139 for i := 0; i < 100; i++ {
140 ekLong = append(ekLong, 0)
141 if _, err := newEncapsulationKey(ekLong); err == nil {
142 t.Errorf("expected error for ek length %d", len(ekLong))
143 }
144 }
145
146 for i := 0; i < len(c)-1; i++ {
147 if _, err := dk.Decapsulate(c[:i]); err == nil {
148 t.Errorf("expected error for c length %d", i)
149 }
150 }
151 cLong := c
152 for i := 0; i < 100; i++ {
153 cLong = append(cLong, 0)
154 if _, err := dk.Decapsulate(cLong); err == nil {
155 t.Errorf("expected error for c length %d", len(cLong))
156 }
157 }
158 }
159
// millionFlag is the -million command-line flag; when set, TestAccumulated
// runs 1,000,000 iterations instead of its default count.
var millionFlag = flag.Bool("million", false, "run the million vector test")
161
162
163
164 func TestAccumulated(t *testing.T) {
165 n := 10000
166 expected := "8a518cc63da366322a8e7a818c7a0d63483cb3528d34a4cf42f35d5ad73f22fc"
167 if testing.Short() {
168 n = 100
169 expected = "1114b1b6699ed191734fa339376afa7e285c9e6acf6ff0177d346696ce564415"
170 }
171 if *millionFlag {
172 n = 1000000
173 expected = "424bf8f0e8ae99b78d788a6e2e8e9cdaf9773fc0c08a6f433507cb559edfd0f0"
174 }
175
176 s := sha3.NewShake128()
177 o := sha3.NewShake128()
178 seed := make([]byte, SeedSize)
179 var msg [32]byte
180 ct1 := make([]byte, CiphertextSize768)
181
182 for i := 0; i < n; i++ {
183 s.Read(seed)
184 dk, err := NewDecapsulationKey768(seed)
185 if err != nil {
186 t.Fatal(err)
187 }
188 ek := dk.EncapsulationKey()
189 o.Write(ek.Bytes())
190
191 s.Read(msg[:])
192 ct, k := ek.key.EncapsulateInternal(&msg)
193 o.Write(ct)
194 o.Write(k)
195
196 kk, err := dk.Decapsulate(ct)
197 if err != nil {
198 t.Fatal(err)
199 }
200 if !bytes.Equal(kk, k) {
201 t.Errorf("k: got %x, expected %x", kk, k)
202 }
203
204 s.Read(ct1)
205 k1, err := dk.Decapsulate(ct1)
206 if err != nil {
207 t.Fatal(err)
208 }
209 o.Write(k1)
210 }
211
212 got := hex.EncodeToString(o.Sum(nil))
213 if got != expected {
214 t.Errorf("got %s, expected %s", got, expected)
215 }
216 }
217
// sink receives a byte of every benchmark result; presumably this keeps the
// compiler from discarding the benchmarked calls as dead code.
var sink byte
219
220 func BenchmarkKeyGen(b *testing.B) {
221 var d, z [32]byte
222 rand.Read(d[:])
223 rand.Read(z[:])
224 b.ResetTimer()
225 for i := 0; i < b.N; i++ {
226 dk := mlkem.GenerateKeyInternal768(&d, &z)
227 sink ^= dk.EncapsulationKey().Bytes()[0]
228 }
229 }
230
231 func BenchmarkEncaps(b *testing.B) {
232 seed := make([]byte, SeedSize)
233 rand.Read(seed)
234 var m [32]byte
235 rand.Read(m[:])
236 dk, err := NewDecapsulationKey768(seed)
237 if err != nil {
238 b.Fatal(err)
239 }
240 ekBytes := dk.EncapsulationKey().Bytes()
241 b.ResetTimer()
242 for i := 0; i < b.N; i++ {
243 ek, err := NewEncapsulationKey768(ekBytes)
244 if err != nil {
245 b.Fatal(err)
246 }
247 c, K := ek.key.EncapsulateInternal(&m)
248 sink ^= c[0] ^ K[0]
249 }
250 }
251
252 func BenchmarkDecaps(b *testing.B) {
253 dk, err := GenerateKey768()
254 if err != nil {
255 b.Fatal(err)
256 }
257 ek := dk.EncapsulationKey()
258 c, _ := ek.Encapsulate()
259 b.ResetTimer()
260 for i := 0; i < b.N; i++ {
261 K, _ := dk.Decapsulate(c)
262 sink ^= K[0]
263 }
264 }
265
266 func BenchmarkRoundTrip(b *testing.B) {
267 dk, err := GenerateKey768()
268 if err != nil {
269 b.Fatal(err)
270 }
271 ek := dk.EncapsulationKey()
272 ekBytes := ek.Bytes()
273 c, _ := ek.Encapsulate()
274 if err != nil {
275 b.Fatal(err)
276 }
277 b.Run("Alice", func(b *testing.B) {
278 for i := 0; i < b.N; i++ {
279 dkS, err := GenerateKey768()
280 if err != nil {
281 b.Fatal(err)
282 }
283 ekS := dkS.EncapsulationKey().Bytes()
284 sink ^= ekS[0]
285
286 Ks, err := dk.Decapsulate(c)
287 if err != nil {
288 b.Fatal(err)
289 }
290 sink ^= Ks[0]
291 }
292 })
293 b.Run("Bob", func(b *testing.B) {
294 for i := 0; i < b.N; i++ {
295 ek, err := NewEncapsulationKey768(ekBytes)
296 if err != nil {
297 b.Fatal(err)
298 }
299 cS, Ks := ek.Encapsulate()
300 if err != nil {
301 b.Fatal(err)
302 }
303 sink ^= cS[0] ^ Ks[0]
304 }
305 })
306 }
307
308
309 func TestConstantSizes(t *testing.T) {
310 if SharedKeySize != mlkem.SharedKeySize {
311 t.Errorf("SharedKeySize mismatch: got %d, want %d", SharedKeySize, mlkem.SharedKeySize)
312 }
313
314 if SeedSize != mlkem.SeedSize {
315 t.Errorf("SeedSize mismatch: got %d, want %d", SeedSize, mlkem.SeedSize)
316 }
317
318 if CiphertextSize768 != mlkem.CiphertextSize768 {
319 t.Errorf("CiphertextSize768 mismatch: got %d, want %d", CiphertextSize768, mlkem.CiphertextSize768)
320 }
321
322 if EncapsulationKeySize768 != mlkem.EncapsulationKeySize768 {
323 t.Errorf("EncapsulationKeySize768 mismatch: got %d, want %d", EncapsulationKeySize768, mlkem.EncapsulationKeySize768)
324 }
325
326 if CiphertextSize1024 != mlkem.CiphertextSize1024 {
327 t.Errorf("CiphertextSize1024 mismatch: got %d, want %d", CiphertextSize1024, mlkem.CiphertextSize1024)
328 }
329
330 if EncapsulationKeySize1024 != mlkem.EncapsulationKeySize1024 {
331 t.Errorf("EncapsulationKeySize1024 mismatch: got %d, want %d", EncapsulationKeySize1024, mlkem.EncapsulationKeySize1024)
332 }
333 }
334
View as plain text