1
2
3
4
5 package fipstest
6
7 import (
8 "bytes"
9 "crypto/elliptic"
10 "crypto/internal/cryptotest"
11 "crypto/internal/fips140/nistec"
12 "fmt"
13 "math/big"
14 "math/rand"
15 "testing"
16 )
17
18 func TestNISTECAllocations(t *testing.T) {
19 cryptotest.SkipTestAllocations(t)
20 t.Run("P224", func(t *testing.T) {
21 if allocs := testing.AllocsPerRun(10, func() {
22 p := nistec.NewP224Point().SetGenerator()
23 scalar := make([]byte, 28)
24 rand.Read(scalar)
25 p.ScalarBaseMult(scalar)
26 p.ScalarMult(p, scalar)
27 out := p.Bytes()
28 if _, err := nistec.NewP224Point().SetBytes(out); err != nil {
29 t.Fatal(err)
30 }
31 out = p.BytesCompressed()
32 if _, err := p.SetBytes(out); err != nil {
33 t.Fatal(err)
34 }
35 }); allocs > 0 {
36 t.Errorf("expected zero allocations, got %0.1f", allocs)
37 }
38 })
39 t.Run("P256", func(t *testing.T) {
40 if allocs := testing.AllocsPerRun(10, func() {
41 p := nistec.NewP256Point().SetGenerator()
42 scalar := make([]byte, 32)
43 rand.Read(scalar)
44 p.ScalarBaseMult(scalar)
45 p.ScalarMult(p, scalar)
46 out := p.Bytes()
47 if _, err := nistec.NewP256Point().SetBytes(out); err != nil {
48 t.Fatal(err)
49 }
50 out = p.BytesCompressed()
51 if _, err := p.SetBytes(out); err != nil {
52 t.Fatal(err)
53 }
54 }); allocs > 0 {
55 t.Errorf("expected zero allocations, got %0.1f", allocs)
56 }
57 })
58 t.Run("P384", func(t *testing.T) {
59 if allocs := testing.AllocsPerRun(10, func() {
60 p := nistec.NewP384Point().SetGenerator()
61 scalar := make([]byte, 48)
62 rand.Read(scalar)
63 p.ScalarBaseMult(scalar)
64 p.ScalarMult(p, scalar)
65 out := p.Bytes()
66 if _, err := nistec.NewP384Point().SetBytes(out); err != nil {
67 t.Fatal(err)
68 }
69 out = p.BytesCompressed()
70 if _, err := p.SetBytes(out); err != nil {
71 t.Fatal(err)
72 }
73 }); allocs > 0 {
74 t.Errorf("expected zero allocations, got %0.1f", allocs)
75 }
76 })
77 t.Run("P521", func(t *testing.T) {
78 if allocs := testing.AllocsPerRun(10, func() {
79 p := nistec.NewP521Point().SetGenerator()
80 scalar := make([]byte, 66)
81 rand.Read(scalar)
82 p.ScalarBaseMult(scalar)
83 p.ScalarMult(p, scalar)
84 out := p.Bytes()
85 if _, err := nistec.NewP521Point().SetBytes(out); err != nil {
86 t.Fatal(err)
87 }
88 out = p.BytesCompressed()
89 if _, err := p.SetBytes(out); err != nil {
90 t.Fatal(err)
91 }
92 }); allocs > 0 {
93 t.Errorf("expected zero allocations, got %0.1f", allocs)
94 }
95 })
96 }
97
// nistPoint is the method set the generic helpers in this file require from a
// nistec point type. It is used as P nistPoint[P], so T is the concrete point
// type itself and every method takes and returns values of that same type.
type nistPoint[T any] interface {
	// Construction and encoding.
	SetGenerator() T
	SetBytes([]byte) (T, error)
	Bytes() []byte

	// Group operations.
	Add(T, T) T
	Double(T) T

	// Scalar multiplication, variable-base and fixed-base.
	ScalarMult(T, []byte) (T, error)
	ScalarBaseMult([]byte) (T, error)
}
107
108 func TestEquivalents(t *testing.T) {
109 t.Run("P224", func(t *testing.T) {
110 testEquivalents(t, nistec.NewP224Point, elliptic.P224())
111 })
112 t.Run("P256", func(t *testing.T) {
113 testEquivalents(t, nistec.NewP256Point, elliptic.P256())
114 })
115 t.Run("P384", func(t *testing.T) {
116 testEquivalents(t, nistec.NewP384Point, elliptic.P384())
117 })
118 t.Run("P521", func(t *testing.T) {
119 testEquivalents(t, nistec.NewP521Point, elliptic.P521())
120 })
121 }
122
123 func testEquivalents[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
124 p := newPoint().SetGenerator()
125
126 elementSize := (c.Params().BitSize + 7) / 8
127 two := make([]byte, elementSize)
128 two[len(two)-1] = 2
129 nPlusTwo := make([]byte, elementSize)
130 new(big.Int).Add(c.Params().N, big.NewInt(2)).FillBytes(nPlusTwo)
131
132 p1 := newPoint().Double(p)
133 p2 := newPoint().Add(p, p)
134 p3, err := newPoint().ScalarMult(p, two)
135 fatalIfErr(t, err)
136 p4, err := newPoint().ScalarBaseMult(two)
137 fatalIfErr(t, err)
138 p5, err := newPoint().ScalarMult(p, nPlusTwo)
139 fatalIfErr(t, err)
140 p6, err := newPoint().ScalarBaseMult(nPlusTwo)
141 fatalIfErr(t, err)
142
143 if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
144 t.Error("P+P != 2*P")
145 }
146 if !bytes.Equal(p1.Bytes(), p3.Bytes()) {
147 t.Error("P+P != [2]P")
148 }
149 if !bytes.Equal(p1.Bytes(), p4.Bytes()) {
150 t.Error("G+G != [2]G")
151 }
152 if !bytes.Equal(p1.Bytes(), p5.Bytes()) {
153 t.Error("P+P != [N+2]P")
154 }
155 if !bytes.Equal(p1.Bytes(), p6.Bytes()) {
156 t.Error("G+G != [N+2]G")
157 }
158 }
159
160 func TestScalarMult(t *testing.T) {
161 t.Run("P224", func(t *testing.T) {
162 testScalarMult(t, nistec.NewP224Point, elliptic.P224())
163 })
164 t.Run("P256", func(t *testing.T) {
165 testScalarMult(t, nistec.NewP256Point, elliptic.P256())
166 })
167 t.Run("P384", func(t *testing.T) {
168 testScalarMult(t, nistec.NewP384Point, elliptic.P384())
169 })
170 t.Run("P521", func(t *testing.T) {
171 testScalarMult(t, nistec.NewP521Point, elliptic.P521())
172 })
173 }
174
175 func testScalarMult[P nistPoint[P]](t *testing.T, newPoint func() P, c elliptic.Curve) {
176 G := newPoint().SetGenerator()
177 checkScalar := func(t *testing.T, scalar []byte) {
178 p1, err := newPoint().ScalarBaseMult(scalar)
179 fatalIfErr(t, err)
180 p2, err := newPoint().ScalarMult(G, scalar)
181 fatalIfErr(t, err)
182 if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
183 t.Error("[k]G != ScalarBaseMult(k)")
184 }
185
186 expectInfinity := new(big.Int).Mod(new(big.Int).SetBytes(scalar), c.Params().N).Sign() == 0
187 if expectInfinity {
188 if !bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
189 t.Error("ScalarBaseMult(k) != ∞")
190 }
191 if !bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
192 t.Error("[k]G != ∞")
193 }
194 } else {
195 if bytes.Equal(p1.Bytes(), newPoint().Bytes()) {
196 t.Error("ScalarBaseMult(k) == ∞")
197 }
198 if bytes.Equal(p2.Bytes(), newPoint().Bytes()) {
199 t.Error("[k]G == ∞")
200 }
201 }
202
203 d := new(big.Int).SetBytes(scalar)
204 d.Sub(c.Params().N, d)
205 d.Mod(d, c.Params().N)
206 g1, err := newPoint().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
207 fatalIfErr(t, err)
208 g1.Add(g1, p1)
209 if !bytes.Equal(g1.Bytes(), newPoint().Bytes()) {
210 t.Error("[N - k]G + [k]G != ∞")
211 }
212 }
213
214 byteLen := len(c.Params().N.Bytes())
215 bitLen := c.Params().N.BitLen()
216 t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
217 t.Run("1", func(t *testing.T) {
218 checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
219 })
220 t.Run("N-1", func(t *testing.T) {
221 checkScalar(t, new(big.Int).Sub(c.Params().N, big.NewInt(1)).Bytes())
222 })
223 t.Run("N", func(t *testing.T) { checkScalar(t, c.Params().N.Bytes()) })
224 t.Run("N+1", func(t *testing.T) {
225 checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(1)).Bytes())
226 })
227 t.Run("all1s", func(t *testing.T) {
228 s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
229 s.Sub(s, big.NewInt(1))
230 checkScalar(t, s.Bytes())
231 })
232 if testing.Short() {
233 return
234 }
235 for i := 0; i < bitLen; i++ {
236 t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
237 s := new(big.Int).Lsh(big.NewInt(1), uint(i))
238 checkScalar(t, s.FillBytes(make([]byte, byteLen)))
239 })
240 }
241 for i := 0; i <= 64; i++ {
242 t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
243 checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen)))
244 })
245 }
246
247
248 for i := int64(-64); i <= 64; i++ {
249 t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) {
250 checkScalar(t, new(big.Int).Add(c.Params().N, big.NewInt(i)).Bytes())
251 })
252 }
253 }
254
// fatalIfErr stops the current test with t.Fatal(err) when err is non-nil and
// does nothing otherwise. It calls t.Helper so failures are attributed to the
// caller's line.
func fatalIfErr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}
261
View as plain text