Source file src/crypto/internal/fips140/aes/_asm/ctr/ctr_amd64_asm.go

// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package main

import (
	"fmt"
	"sync"

	. "github.com/mmcloughlin/avo/build"
	. "github.com/mmcloughlin/avo/operand"
	. "github.com/mmcloughlin/avo/reg"
)

//go:generate go run . -out ../../ctr_amd64.s

func main() {
	Package("crypto/internal/fips140/aes")
	ConstraintExpr("!purego")

	ctrBlocks(1)
	ctrBlocks(2)
	ctrBlocks(4)
	ctrBlocks(8)

	Generate()
}
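
// avo resolves the Param references in ctrBlocks against stub declarations in
// the crypto/internal/fips140/aes package named above. A rough, illustrative
// sketch of what such a stub could look like (the parameter names match the
// Param calls below; the pointer types and the //go:noescape directive are
// assumptions, not taken from this file):
//
//	//go:noescape
//	func ctrBlocks1Asm(nr int, xk *[60]uint32, dst, src *byte, ivlo, ivhi uint64)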

func ctrBlocks(numBlocks int) {
	Implement(fmt.Sprintf("ctrBlocks%dAsm", numBlocks))

	rounds := Load(Param("nr"), GP64())
	xk := Load(Param("xk"), GP64())
	dst := Load(Param("dst"), GP64())
	src := Load(Param("src"), GP64())
	ivlo := Load(Param("ivlo"), GP64())
	ivhi := Load(Param("ivhi"), GP64())

	bswap := XMM()
	MOVOU(bswapMask(), bswap)

	blocks := make([]VecVirtual, numBlocks)

	// For the 8-block case, optimize counter generation. Build the first
	// counter as usual, then check whether any of the remaining seven
	// increments would carry out of the low 64 bits. When they would not
	// (the common case), stay entirely in XMM registers and avoid repeated
	// general-purpose to XMM moves. Otherwise, fall back to the scalar path
	// that propagates the carry into the high half.
	if numBlocks == 8 {
		for i := range blocks {
			blocks[i] = XMM()
		}

		base := XMM()
		tmp := GP64()
		addVec := XMM()

		MOVQ(ivlo, blocks[0])
		PINSRQ(Imm(1), ivhi, blocks[0])
		MOVAPS(blocks[0], base)
		PSHUFB(bswap, blocks[0])

		// Check whether any of the remaining seven increments would carry
		// out of the low 64 bits.
		MOVQ(ivlo, tmp)
		ADDQ(Imm(uint64(numBlocks-1)), tmp)
		slowLabel := fmt.Sprintf("ctr%d_slow", numBlocks)
		doneLabel := fmt.Sprintf("ctr%d_done", numBlocks)
		JC(LabelRef(slowLabel))

		// Fast branch: build an XMM increment vector holding 1 in its low
		// quadword. Adding it to the base counter with PADDQ yields each
		// subsequent counter without disturbing the high quadword.
		XORQ(tmp, tmp)
		INCQ(tmp)
		PXOR(addVec, addVec)
		PINSRQ(Imm(0), tmp, addVec)

		for i := 1; i < numBlocks; i++ {
			PADDQ(addVec, base)
			MOVAPS(base, blocks[i])
		}
		JMP(LabelRef(doneLabel))

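		// Slow branch: one of the counters wraps the low 64 bits, so perform
		// full 128-bit increments in GP registers and move each counter into
		// its XMM register.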
		Label(slowLabel)
		ADDQ(Imm(1), ivlo)
		ADCQ(Imm(0), ivhi)
		for i := 1; i < numBlocks; i++ {
			MOVQ(ivlo, blocks[i])
			PINSRQ(Imm(1), ivhi, blocks[i])
			if i < numBlocks-1 {
				ADDQ(Imm(1), ivlo)
				ADCQ(Imm(0), ivhi)
			}
		}

		Label(doneLabel)

		// Both paths leave blocks[1] through blocks[7] holding little-endian
		// counters, so convert them to big-endian with a single shared
		// shuffle sequence after the branch. blocks[0] was already converted
		// above.
		for i := 1; i < numBlocks; i++ {
			PSHUFB(bswap, blocks[i])
		}
	} else {
		// Lay out counter block plaintext.
		for i := 0; i < numBlocks; i++ {
			x := XMM()
			blocks[i] = x

			MOVQ(ivlo, x)
			PINSRQ(Imm(1), ivhi, x)
			PSHUFB(bswap, x)
			if i < numBlocks-1 {
				ADDQ(Imm(1), ivlo)
				ADCQ(Imm(0), ivhi)
			}
		}
	}
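
	// At this point each blocks[i] holds the big-endian counter block for
	// counter value iv+i.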

	// Initial key add.
	aesRoundStart(blocks, Mem{Base: xk})
	ADDQ(Imm(16), xk)

	// Branch based on the number of rounds: nr is 10, 12, or 14, so after
	// subtracting 12 the result is negative for 128-bit keys, zero for
	// 192-bit keys, and positive for 256-bit keys.
	SUBQ(Imm(12), rounds)
	JE(LabelRef("enc192"))
	JB(LabelRef("enc128"))

	// Two extra rounds for 256-bit keys.
	aesRound(blocks, Mem{Base: xk})
	aesRound(blocks, Mem{Base: xk}.Offset(16))
	ADDQ(Imm(32), xk)

	// Two extra rounds for 192-bit keys.
	Label("enc192")
	aesRound(blocks, Mem{Base: xk})
	aesRound(blocks, Mem{Base: xk}.Offset(16))
	ADDQ(Imm(32), xk)

	// The final ten rounds, shared by all key sizes; for 128-bit keys they
	// are the only rounds after the initial key add. The last round uses
	// AESENCLAST instead of AESENC.
	Label("enc128")
	for i := 0; i < 9; i++ {
		aesRound(blocks, Mem{Base: xk}.Offset(16*i))
	}
	aesRoundLast(blocks, Mem{Base: xk}.Offset(16*9))

	// XOR the encrypted counter blocks with src and write the result to dst.
	for i, b := range blocks {
		x := XMM()

		MOVUPS(Mem{Base: src}.Offset(16*i), x)
		PXOR(b, x)
		MOVUPS(x, Mem{Base: dst}.Offset(16*i))
	}

	RET()
}

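// aesRoundStart performs the initial AddRoundKey step: it loads one round key
// and XORs it into every block.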
func aesRoundStart(blocks []VecVirtual, k Mem) {
	x := XMM()
	MOVUPS(k, x)
	for _, b := range blocks {
		PXOR(x, b)
	}
}

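// aesRound applies one AESENC round with the given round key to every block.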
func aesRound(blocks []VecVirtual, k Mem) {
	x := XMM()
	MOVUPS(k, x)
	for _, b := range blocks {
		AESENC(x, b)
	}
}

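// aesRoundLast applies the final AESENCLAST round with the given round key to
// every block.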
func aesRoundLast(blocks []VecVirtual, k Mem) {
	x := XMM()
	MOVUPS(k, x)
	for _, b := range blocks {
		AESENCLAST(x, b)
	}
}

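// bswapMask defines (once) and returns a 16-byte PSHUFB mask that reverses
// the byte order of an XMM register, turning a little-endian counter into a
// big-endian counter block.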
var bswapMask = sync.OnceValue(func() Mem {
	bswapMask := GLOBL("bswapMask", NOPTR|RODATA)
	DATA(0x00, U64(0x08090a0b0c0d0e0f))
	DATA(0x08, U64(0x0001020304050607))
	return bswapMask
})