Source file
src/math/big/natmul.go
1
2
3
4
5
6
7 package big
8
9
10
11
// karatsubaThreshold is the operand length below which mul and karatsuba
// use basicMul instead of splitting the operands further.
var karatsubaThreshold = 40
13
14
15
16 func (z nat) mul(stk *stack, x, y nat) nat {
17 m := len(x)
18 n := len(y)
19
20 switch {
21 case m < n:
22 return z.mul(stk, y, x)
23 case m == 0 || n == 0:
24 return z[:0]
25 case n == 1:
26 return z.mulAddWW(x, y[0], 0)
27 }
28
29
30
31 if alias(z, x) || alias(z, y) {
32 z = nil
33 }
34 z = z.make(m + n)
35
36
37 if n < karatsubaThreshold {
38 basicMul(z, x, y)
39 return z.norm()
40 }
41
42 if stk == nil {
43 stk = getStack()
44 defer stk.free()
45 }
46
47
48
49
50 karatsuba(stk, z[:2*n], x[:n], y)
51
52 if n < m {
53 clear(z[2*n:])
54 defer stk.restore(stk.save())
55 t := stk.nat(2 * n)
56 for i := n; i < m; i += n {
57 t = t.mul(stk, x[i:min(i+n, len(x))], y)
58 addTo(z[i:], t)
59 }
60 }
61
62 return z.norm()
63 }
64
65
66
67
// basicSqrThreshold and karatsubaSqrThreshold select the squaring algorithm
// by operand length n: operands with n below both thresholds are squared
// with basicMul, operands with n < karatsubaSqrThreshold go to basicSqr,
// and everything else goes to karatsubaSqr.
var basicSqrThreshold = 12
var karatsubaSqrThreshold = 80
70
71
72
73 func (z nat) sqr(stk *stack, x nat) nat {
74 n := len(x)
75 switch {
76 case n == 0:
77 return z[:0]
78 case n == 1:
79 d := x[0]
80 z = z.make(2)
81 z[1], z[0] = mulWW(d, d)
82 return z.norm()
83 }
84
85 if alias(z, x) {
86 z = nil
87 }
88 z = z.make(2 * n)
89
90 if n < basicSqrThreshold && n < karatsubaSqrThreshold {
91 basicMul(z, x, x)
92 return z.norm()
93 }
94
95 if stk == nil {
96 stk = getStack()
97 defer stk.free()
98 }
99
100 if n < karatsubaSqrThreshold {
101 basicSqr(stk, z, x)
102 return z.norm()
103 }
104
105 karatsubaSqr(stk, z, x)
106 return z.norm()
107 }
108
109
110
111
112
113 func basicSqr(stk *stack, z, x nat) {
114 n := len(x)
115 if n < basicSqrThreshold {
116 basicMul(z, x, x)
117 return
118 }
119
120 defer stk.restore(stk.save())
121 t := stk.nat(2 * n)
122 clear(t)
123 z[1], z[0] = mulWW(x[0], x[0])
124 for i := 1; i < n; i++ {
125 d := x[i]
126
127 z[2*i+1], z[2*i] = mulWW(d, d)
128
129 t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
130 }
131 t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1)
132 addVV(z, z, t)
133 }
134
135
136 func (z nat) mulAddWW(x nat, y, r Word) nat {
137 m := len(x)
138 if m == 0 || y == 0 {
139 return z.setWord(r)
140 }
141
142
143 z = z.make(m + 1)
144 z[m] = mulAddVWW(z[0:m], x, y, r)
145
146 return z.norm()
147 }
148
149
150
151 func basicMul(z, x, y nat) {
152 clear(z[0 : len(x)+len(y)])
153 for i, d := range y {
154 if d != 0 {
155 z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
156 }
157 }
158 }
159
160
161
162
163
// karatsuba multiplies x by y and writes the product into z without
// normalizing it. It panics unless len(x) == len(y) and len(z) == 2*len(y).
// Scratch space comes from stk, which is restored to its saved state on return.
func karatsuba(stk *stack, z, x, y nat) {
	n := len(y)
	if len(x) != n || len(z) != 2*n {
		panic("bad karatsuba length")
	}

	// Operands below the threshold, and operands too short to split
	// (n < 2), are multiplied directly.
	if n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}

	// Split both operands at n2 = ceil(n/2) into a low part (x0, y0) and a
	// high part (x1, y1). Writing B for the weight of element n2:
	//
	//	x*y = z2*B*B + z1*B + z0
	//
	// with
	//
	//	z0 = x0*y0
	//	z2 = x1*y1
	//	z1 = x0*y1 + x1*y0 = (x0-x1)*(y1-y0) + z0 + z2
	//
	// so three half-length multiplications are enough instead of four.
	n2 := (n + 1) / 2
	x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
	y0, y1 := &Int{abs: y[:n2].norm()}, &Int{abs: y[n2:].norm()}
	z0 := &Int{abs: z[0 : 2*n2]} // z0 is built in place in the low part of z
	z2 := &Int{abs: z[2*n2:]}    // z2 is built in place in the high part of z

	// z1 gets scratch storage from stk. The differences tx and ty are
	// stored in the part of z that later holds z0, which is still unused here.
	defer stk.restore(stk.save())
	z1 := &Int{abs: stk.nat(2*n2 + 1)}
	tx := &Int{abs: z[0:n2]}
	ty := &Int{abs: z[n2 : 2*n2]}

	// z1 = (x0-x1)*(y1-y0). This must run before z is cleared, because
	// tx and ty live inside z.
	tx.Sub(x0, x1)
	ty.Sub(y1, y0)
	z1.mul(stk, tx, ty)

	// Zero z (discarding tx and ty), form z0 and z2 in place, finish z1,
	// and add it into z starting at element n2.
	clear(z)
	z0.mul(stk, x0, y0)
	z2.mul(stk, x1, y1)
	z1.Add(z1, z0)
	z1.Add(z1, z2)
	addTo(z[n2:], z1.abs)

	// Optional self-check, disabled by the constant below: compare against
	// basicMul and, on mismatch, dump all intermediate values and panic.
	const debug = false
	if debug {
		zz := make(nat, len(z))
		basicMul(zz, x, y)
		if z.cmp(zz) != 0 {
			// The temporaries shared storage with z, so recompute them into
			// fresh Ints for the trace. tx and ty are both negated relative
			// to the computation above, which leaves their product unchanged.
			z0 = new(Int)
			z0.mul(stk, x0, y0)
			tx = new(Int).Sub(x1, x0)
			ty = new(Int).Sub(y0, y1)
			z2 = new(Int)
			z2.mul(stk, x1, y1)
			print("karatsuba wrong\n")
			trace("x ", &Int{abs: x})
			trace("y ", &Int{abs: y})
			trace("z ", &Int{abs: z})
			trace("zz", &Int{abs: zz})
			trace("x0", x0)
			trace("x1", x1)
			trace("y0", y0)
			trace("y1", y1)
			trace("tx", tx)
			trace("ty", ty)
			trace("z0", z0)
			trace("z1", z1)
			trace("z2", z2)
			panic("karatsuba")
		}
	}

}
251
252
253
254
255
256
// karatsubaSqr squares x and writes the result into z without normalizing
// it. It panics unless len(z) == 2*len(x).
// Scratch space comes from stk, which is restored to its saved state on return.
func karatsubaSqr(stk *stack, z, x nat) {
	n := len(x)
	if len(z) != 2*n {
		panic("bad karatsubaSqr length")
	}

	// Operands below the threshold, and operands too short to split
	// (n < 2), are squared by basicSqr.
	if n < karatsubaSqrThreshold || n < 2 {
		basicSqr(stk, z, x)
		return
	}

	// Split x at n2 = ceil(n/2) into a low part x0 and a high part x1.
	// Writing B for the weight of element n2:
	//
	//	x*x = z2*B*B + z1*B + z0
	//
	// with
	//
	//	z0 = x0*x0
	//	z2 = x1*x1
	//	z1 = 2*x0*x1 = -(x0-x1)*(x0-x1) + z0 + z2
	//
	// so three half-length squarings are enough.
	n2 := (n + 1) / 2
	x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
	z0 := &Int{abs: z[0 : 2*n2]} // z0 is built in place in the low part of z
	z2 := &Int{abs: z[2*n2:]}    // z2 is built in place in the high part of z

	// z1 gets scratch storage from stk. The difference tx is stored in the
	// part of z that later holds z0, which is still unused here.
	defer stk.restore(stk.save())
	z1 := &Int{abs: stk.nat(2*n2 + 1)}
	tx := &Int{abs: z[0:n2]}

	// z1 = -(x0-x1)^2. The sign of tx does not matter because only its
	// magnitude is squared; the result is then marked negative.
	tx.Sub(x0, x1)
	z1.abs = z1.abs.sqr(stk, tx.abs)
	z1.neg = true

	// Zero z (discarding tx), form z0 and z2 in place, finish z1, and add
	// it into z starting at element n2.
	clear(z)
	z0.abs = z0.abs.sqr(stk, x0.abs)
	z2.abs = z2.abs.sqr(stk, x1.abs)
	z1.Add(z1, z0)
	z1.Add(z1, z2)
	addTo(z[n2:], z1.abs)

	// Optional self-check, disabled by the constant below: compare against
	// basicSqr and, on mismatch, dump all intermediate values and panic.
	const debug = false
	if debug {
		zz := make(nat, len(z))
		basicSqr(stk, zz, x)
		if z.cmp(zz) != 0 {
			// The temporaries shared storage with z, so recompute them into
			// fresh Ints for the trace.
			tx = new(Int).Sub(x0, x1)
			z0 = new(Int).Mul(x0, x0)
			z2 = new(Int).Mul(x1, x1)
			z1 = new(Int).Mul(tx, tx)
			z1.Neg(z1)
			z1.Add(z1, z0)
			z1.Add(z1, z2)
			print("karatsubaSqr wrong\n")
			trace("x ", &Int{abs: x})
			trace("z ", &Int{abs: z})
			trace("zz", &Int{abs: zz})
			trace("x0", x0)
			trace("x1", x1)
			trace("z0", z0)
			trace("z1", z1)
			trace("z2", z2)
			panic("karatsubaSqr")
		}
	}
}
333
334
335 func ifmt(x *Int) string {
336 neg, s, t := "", x.Text(16), ""
337 if s == "" {
338 s = "0x0"
339 }
340 if s[0] == '-' {
341 neg, s = "-", s[1:]
342 }
343
344
345 const D = _W / 4
346 for len(s) > D {
347 s, t = s[:len(s)-D], s[len(s)-D:]+"_"+t
348 }
349 return neg + s + t
350 }
351
352
353 func trace(name string, x *Int) {
354 print(name, "=", ifmt(x), "\n")
355 }
356
View as plain text