Source file src/math/big/natmul.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Multiplication.
     6  
     7  package big
     8  
     9  // Operands that are shorter than karatsubaThreshold are multiplied using
    10  // "grade school" multiplication; for longer operands the Karatsuba algorithm
    11  // is used.
    12  var karatsubaThreshold = 40 // see calibrate_test.go
    13  
    14  // mul sets z = x*y, using stk for temporary storage.
    15  // The caller may pass stk == nil to request that mul obtain and release one itself.
    16  func (z nat) mul(stk *stack, x, y nat) nat {
    17  	m := len(x)
    18  	n := len(y)
    19  
    20  	switch {
    21  	case m < n:
    22  		return z.mul(stk, y, x)
    23  	case m == 0 || n == 0:
    24  		return z[:0]
    25  	case n == 1:
    26  		return z.mulAddWW(x, y[0], 0)
    27  	}
    28  	// m >= n > 1
    29  
    30  	// determine if z can be reused
    31  	if alias(z, x) || alias(z, y) {
    32  		z = nil // z is an alias for x or y - cannot reuse
    33  	}
    34  	z = z.make(m + n)
    35  
    36  	// use basic multiplication if the numbers are small
    37  	if n < karatsubaThreshold {
    38  		basicMul(z, x, y)
    39  		return z.norm()
    40  	}
    41  
    42  	if stk == nil {
    43  		stk = getStack()
    44  		defer stk.free()
    45  	}
    46  
    47  	// Let x = x1:x0 where x0 is the same length as y.
    48  	// Compute z = x0*y and then add in x1*y in sections
    49  	// if needed.
    50  	karatsuba(stk, z[:2*n], x[:n], y)
    51  
    52  	if n < m {
    53  		clear(z[2*n:])
    54  		defer stk.restore(stk.save())
    55  		t := stk.nat(2 * n)
    56  		for i := n; i < m; i += n {
    57  			t = t.mul(stk, x[i:min(i+n, len(x))], y)
    58  			addTo(z[i:], t)
    59  		}
    60  	}
    61  
    62  	return z.norm()
    63  }
    64  
    65  // Operands that are shorter than basicSqrThreshold are squared using
    66  // "grade school" multiplication; for operands longer than karatsubaSqrThreshold
    67  // we use the Karatsuba algorithm optimized for x == y.
    68  var basicSqrThreshold = 12     // see calibrate_test.go
    69  var karatsubaSqrThreshold = 80 // see calibrate_test.go
    70  
    71  // sqr sets z = x*x, using stk for temporary storage.
    72  // The caller may pass stk == nil to request that sqr obtain and release one itself.
    73  func (z nat) sqr(stk *stack, x nat) nat {
    74  	n := len(x)
    75  	switch {
    76  	case n == 0:
    77  		return z[:0]
    78  	case n == 1:
    79  		d := x[0]
    80  		z = z.make(2)
    81  		z[1], z[0] = mulWW(d, d)
    82  		return z.norm()
    83  	}
    84  
    85  	if alias(z, x) {
    86  		z = nil // z is an alias for x - cannot reuse
    87  	}
    88  	z = z.make(2 * n)
    89  
    90  	if n < basicSqrThreshold && n < karatsubaSqrThreshold {
    91  		basicMul(z, x, x)
    92  		return z.norm()
    93  	}
    94  
    95  	if stk == nil {
    96  		stk = getStack()
    97  		defer stk.free()
    98  	}
    99  
   100  	if n < karatsubaSqrThreshold {
   101  		basicSqr(stk, z, x)
   102  		return z.norm()
   103  	}
   104  
   105  	karatsubaSqr(stk, z, x)
   106  	return z.norm()
   107  }
   108  
   109  // basicSqr sets z = x*x and is asymptotically faster than basicMul
   110  // by about a factor of 2, but slower for small arguments due to overhead.
   111  // Requirements: len(x) > 0, len(z) == 2*len(x)
   112  // The (non-normalized) result is placed in z.
   113  func basicSqr(stk *stack, z, x nat) {
   114  	n := len(x)
   115  	if n < basicSqrThreshold {
   116  		basicMul(z, x, x)
   117  		return
   118  	}
   119  
   120  	defer stk.restore(stk.save())
   121  	t := stk.nat(2 * n)
   122  	clear(t)
   123  	z[1], z[0] = mulWW(x[0], x[0]) // the initial square
   124  	for i := 1; i < n; i++ {
   125  		d := x[i]
   126  		// z collects the squares x[i] * x[i]
   127  		z[2*i+1], z[2*i] = mulWW(d, d)
   128  		// t collects the products x[i] * x[j] where j < i
   129  		t[2*i] = addMulVVW(t[i:2*i], x[0:i], d)
   130  	}
   131  	t[2*n-1] = shlVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
   132  	addVV(z, z, t)                              // combine the result
   133  }
   134  
   135  // mulAddWW returns z = x*y + r.
   136  func (z nat) mulAddWW(x nat, y, r Word) nat {
   137  	m := len(x)
   138  	if m == 0 || y == 0 {
   139  		return z.setWord(r) // result is r
   140  	}
   141  	// m > 0
   142  
   143  	z = z.make(m + 1)
   144  	z[m] = mulAddVWW(z[0:m], x, y, r)
   145  
   146  	return z.norm()
   147  }
   148  
   149  // basicMul multiplies x and y and leaves the result in z.
   150  // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
   151  func basicMul(z, x, y nat) {
   152  	clear(z[0 : len(x)+len(y)]) // initialize z
   153  	for i, d := range y {
   154  		if d != 0 {
   155  			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
   156  		}
   157  	}
   158  }
   159  
   160  // karatsuba multiplies x and y,
   161  // writing the (non-normalized) result to z.
   162  // x and y must have the same length n,
   163  // and z must have length twice that.
   164  func karatsuba(stk *stack, z, x, y nat) {
   165  	n := len(y)
   166  	if len(x) != n || len(z) != 2*n {
   167  		panic("bad karatsuba length")
   168  	}
   169  
   170  	// Fall back to basic algorithm if small enough.
   171  	if n < karatsubaThreshold || n < 2 {
   172  		basicMul(z, x, y)
   173  		return
   174  	}
   175  
   176  	// Let the notation x1:x0 denote the nat (x1<<N)+x0 for some N,
   177  	// and similarly z2:z1:z0 = (z2<<2N)+(z1<<N)+z0.
   178  	//
   179  	// (Note that z0, z1, z2 might be ≥ 2**N, in which case the high
   180  	// bits of, say, z0 are being added to the low bits of z1 in this notation.)
   181  	//
   182  	// Karatsuba multiplication is based on the observation that
   183  	//
   184  	//	x1:x0 * y1:y0 = x1*y1:(x0*y1+y0*x1):x0*y0
   185  	//	              = x1*y1:((x0-x1)*(y1-y0)+x1*y1+x0*y0):x0*y0
   186  	//
   187  	// The second form uses only three half-width multiplications
   188  	// instead of the four that the straightforward first form does.
   189  	//
   190  	// We call the three pieces z0, z1, z2:
   191  	//
   192  	//	z0 = x0*y0
   193  	//	z2 = x1*y1
   194  	//	z1 = (x0-x1)*(y1-y0) + z0 + z2
   195  
   196  	n2 := (n + 1) / 2
   197  	x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
   198  	y0, y1 := &Int{abs: y[:n2].norm()}, &Int{abs: y[n2:].norm()}
   199  	z0 := &Int{abs: z[0 : 2*n2]}
   200  	z2 := &Int{abs: z[2*n2:]}
   201  
   202  	// Allocate temporary storage for z1; repurpose z0 to hold tx and ty.
   203  	defer stk.restore(stk.save())
   204  	z1 := &Int{abs: stk.nat(2*n2 + 1)}
   205  	tx := &Int{abs: z[0:n2]}
   206  	ty := &Int{abs: z[n2 : 2*n2]}
   207  
   208  	tx.Sub(x0, x1)
   209  	ty.Sub(y1, y0)
   210  	z1.mul(stk, tx, ty)
   211  
   212  	clear(z)
   213  	z0.mul(stk, x0, y0)
   214  	z2.mul(stk, x1, y1)
   215  	z1.Add(z1, z0)
   216  	z1.Add(z1, z2)
   217  	addTo(z[n2:], z1.abs)
   218  
   219  	// Debug mode: double-check answer and print trace on failure.
   220  	const debug = false
   221  	if debug {
   222  		zz := make(nat, len(z))
   223  		basicMul(zz, x, y)
   224  		if z.cmp(zz) != 0 {
   225  			// All the temps were aliased to z and gone. Recompute.
   226  			z0 = new(Int)
   227  			z0.mul(stk, x0, y0)
   228  			tx = new(Int).Sub(x1, x0)
   229  			ty = new(Int).Sub(y0, y1)
   230  			z2 = new(Int)
   231  			z2.mul(stk, x1, y1)
   232  			print("karatsuba wrong\n")
   233  			trace("x ", &Int{abs: x})
   234  			trace("y ", &Int{abs: y})
   235  			trace("z ", &Int{abs: z})
   236  			trace("zz", &Int{abs: zz})
   237  			trace("x0", x0)
   238  			trace("x1", x1)
   239  			trace("y0", y0)
   240  			trace("y1", y1)
   241  			trace("tx", tx)
   242  			trace("ty", ty)
   243  			trace("z0", z0)
   244  			trace("z1", z1)
   245  			trace("z2", z2)
   246  			panic("karatsuba")
   247  		}
   248  	}
   249  
   250  }
   251  
// karatsubaSqr squares x,
// writing the (non-normalized) result to z.
// z must have length 2*len(x).
// It is analogous to [karatsuba] but can run faster
// knowing both multiplicands are the same value.
// Operands shorter than karatsubaSqrThreshold (or than 2 words) are
// delegated to basicSqr. Scratch space is taken from stk and released
// before returning.
func karatsubaSqr(stk *stack, z, x nat) {
	n := len(x)
	if len(z) != 2*n {
		panic("bad karatsubaSqr length")
	}

	if n < karatsubaSqrThreshold || n < 2 {
		basicSqr(stk, z, x)
		return
	}

	// Recall that for karatsuba we want to compute:
	//
	//	x1:x0 * y1:y0 = x1y1:(x0y1+y0x1):x0y0
	//	              = x1y1:((x0-x1)*(y1-y0)+x1y1+x0y0):x0y0
	//	              = z2:z1:z0
	// where:
	//
	//	z0 = x0y0
	//	z2 = x1y1
	//	z1 = (x0-x1)*(y1-y0) + z0 + z2
	//
	// When x = y, these simplify to:
	//
	//	z0 = x0²
	//	z2 = x1²
	//	z1 = z0 + z2 - (x0-x1)²

	n2 := (n + 1) / 2
	x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
	// z0 and z2 are views into the low and high parts of z.
	z0 := &Int{abs: z[0 : 2*n2]}
	z2 := &Int{abs: z[2*n2:]}

	// Allocate temporary storage for z1; repurpose z0 to hold tx.
	// z1 gets one word beyond 2*n2 to leave room for a carry when
	// z0 and z2 are added in.
	defer stk.restore(stk.save())
	z1 := &Int{abs: stk.nat(2*n2 + 1)}
	tx := &Int{abs: z[0:n2]}

	// tx lives in the same words as z0, so it must be consumed
	// before z is cleared below.
	tx.Sub(x0, x1)
	z1.abs = z1.abs.sqr(stk, tx.abs)
	// z1 = -(x0-x1)². NOTE(review): when tx is zero this sets neg on an
	// empty abs (a non-normalized negative zero); correctness relies on
	// Int.Add normalizing the sign of its result — confirm.
	z1.neg = true

	clear(z)
	z0.abs = z0.abs.sqr(stk, x0.abs)
	z2.abs = z2.abs.sqr(stk, x1.abs)
	z1.Add(z1, z0)
	z1.Add(z1, z2)
	addTo(z[n2:], z1.abs)

	// Debug mode: double-check answer and print trace on failure.
	const debug = false
	if debug {
		zz := make(nat, len(z))
		basicSqr(stk, zz, x)
		if z.cmp(zz) != 0 {
			// All the temps were aliased to z and gone. Recompute.
			tx = new(Int).Sub(x0, x1)
			z0 = new(Int).Mul(x0, x0)
			z2 = new(Int).Mul(x1, x1)
			z1 = new(Int).Mul(tx, tx)
			z1.Neg(z1)
			z1.Add(z1, z0)
			z1.Add(z1, z2)
			print("karatsubaSqr wrong\n")
			trace("x ", &Int{abs: x})
			trace("z ", &Int{abs: z})
			trace("zz", &Int{abs: zz})
			trace("x0", x0)
			trace("x1", x1)
			trace("z0", z0)
			trace("z1", z1)
			trace("z2", z2)
			panic("karatsubaSqr")
		}
	}
}
   333  
   334  // ifmt returns the debug formatting of the Int x: 0xHEX.
   335  func ifmt(x *Int) string {
   336  	neg, s, t := "", x.Text(16), ""
   337  	if s == "" { // happens for denormalized zero
   338  		s = "0x0"
   339  	}
   340  	if s[0] == '-' {
   341  		neg, s = "-", s[1:]
   342  	}
   343  
   344  	// Add _ between words.
   345  	const D = _W / 4 // digits per chunk
   346  	for len(s) > D {
   347  		s, t = s[:len(s)-D], s[len(s)-D:]+"_"+t
   348  	}
   349  	return neg + s + t
   350  }
   351  
   352  // trace prints a single debug value.
   353  func trace(name string, x *Int) {
   354  	print(name, "=", ifmt(x), "\n")
   355  }
   356  

View as plain text