Source file src/cmd/compile/internal/ssa/known_bits.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
// fold computes the known bits of v, memoizing the result in kb.entries.
// known has a 1 in every bit position that has been proved, and value holds
// the proved bits; both results obey the knownBitsEntry invariants.
func (kb *knownBitsState) fold(v *Value) (value, known int64) {
	if kb.seenValues.Test(uint32(v.ID)) {
		return kb.entries[v.ID].value, kb.entries[v.ID].known
	}
	defer func() {
		// maintain the invariants:
		// 3. booleans are stored as 1 byte values which are either 0 or 1.
		if v.Type.IsBoolean() {
			value &= 1
			known |= ^1 // every bit above bit 0 is known (zeroed by "value &= known" below)
		}

		// 2. all values are sign-extended to int64 (inspired by RISC-V's xlen=64)
		switch v.Type.Size() {
		case 1:
			value = int64(int8(value))
			known = int64(int8(known))
		case 2:
			value = int64(int16(value))
			known = int64(int16(known))
		case 4:
			value = int64(int32(value))
			known = int64(int32(known))
		case 8:
			// already 64 bits wide; nothing to extend
		default:
			panic("unreachable; unknown integer size")
		}

		// 1. unknown bits are always set to 0 inside value
		value &= known

		if v.Block.Func.pass.debug > 1 {
			v.Block.Func.Warnl(v.Pos, "known bits state %v: k:%d v:%d", v, known, value)
		}
		kb.entries[v.ID].known = known
		kb.entries[v.ID].value = value
	}()
	kb.seenValues.Set(uint32(v.ID)) // set seen early to give up on loops

	switch v.Op {
	// TODO: rotates, ...
	case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool:
		// Constants: every bit is known.
		return v.AuxInt, -1
	case OpAnd64, OpAnd32, OpAnd16, OpAnd8, OpAndB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// A result bit is known when it is 1 in both inputs, or when it is a
		// known 0 in either input (a known 0 forces the AND result to 0).
		onesInBoth := x & y
		zerosInX := ^x & xk
		zerosInY := ^y & yk
		return x & y, onesInBoth | zerosInX | zerosInY
	case OpOr64, OpOr32, OpOr16, OpOr8, OpOrB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// A result bit is known when it is a known 0 in both inputs, or a 1
		// in either input (a known 1 forces the OR result to 1; unknown bits
		// of value are stored as 0, so x and y only carry known ones).
		zerosInBoth := ^x & ^y & (xk & yk)
		onesInX := x
		onesInY := y
		return x | y, onesInX | onesInY | zerosInBoth
	case OpXor64, OpXor32, OpXor16, OpXor8:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// An XOR result bit is known only where both input bits are known.
		return x ^ y, xk & yk
	case OpCom64, OpCom32, OpCom16, OpCom8, OpNot:
		// Complement flips the value but leaves the set of known bits intact.
		x, xk := kb.fold(v.Args[0])
		return ^x, xk
	case OpPhi:
		// Intersect the lattices of all live incoming edges: a bit stays
		// known only if every live input knows it and agrees on its value.
		set := false
		for i, arg := range v.Args {
			if !kb.isLiveInEdge(v.Block, uint(i)) {
				continue // dead edges cannot contribute a value
			}
			a, k := kb.fold(arg)
			if !set {
				value, known = a, k
				set = true
			} else {
				// Drop bits that disagree, then bits the new input doesn't know.
				known &^= value ^ a
				known &= k
			}
			if known == 0 {
				break // nothing left to learn
			}
		}
		return value, known
	case OpCopy, OpCvtBoolToUint8,
		OpSignExt8to16, OpSignExt8to32, OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// The defer block handles maintaining the sign-extension invariant using v.Type.Size()
		// thus we can just pass Truncs as-is.
		OpTrunc64to32, OpTrunc64to16, OpTrunc64to8, OpTrunc32to16, OpTrunc32to8, OpTrunc16to8:
		return kb.fold(v.Args[0])
	case OpEq64, OpEq32, OpEq16, OpEq8, OpEqB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		differentBits := x ^ y
		if differentBits&xk&yk != 0 {
			// Some bit is known on both sides and differs: definitely unequal.
			return 0, -1
		}
		if xk == -1 && yk == -1 {
			// Both operands fully known: fold the comparison outright.
			return boolToAuxInt(x == y), -1
		}
		// Bit 0 of the boolean result is unknown; all higher bits are known 0.
		return 0, -1 << 1
	case OpNeq64, OpNeq32, OpNeq16, OpNeq8, OpNeqB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		differentBits := x ^ y
		if differentBits&xk&yk != 0 {
			// Some bit is known on both sides and differs: definitely unequal.
			return 1, -1
		}
		if xk == -1 && yk == -1 {
			// Both operands fully known: fold the comparison outright.
			return boolToAuxInt(x != y), -1
		}
		// Bit 0 of the boolean result is unknown; all higher bits are known 0.
		return 0, -1 << 1
	case OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32, OpZeroExt16to64, OpZeroExt32to64:
		x, k := kb.fold(v.Args[0])
		srcSize := v.Args[0].Type.Size() * 8
		mask := int64(1<<srcSize - 1)
		// Every bit above the source width becomes a known zero.
		return x & mask, k | ^mask
	case OpLsh8x8, OpLsh16x8, OpLsh32x8, OpLsh64x8,
		OpLsh8x16, OpLsh16x16, OpLsh32x16, OpLsh64x16,
		OpLsh8x32, OpLsh16x32, OpLsh32x32, OpLsh64x32,
		OpLsh8x64, OpLsh16x64, OpLsh32x64, OpLsh64x64:
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			// Left shift: shifted-in low bits are known zeros.
			return x << shift, xk<<shift | (1<<shift - 1)
		})
	case OpRsh8Ux8, OpRsh16Ux8, OpRsh32Ux8, OpRsh64Ux8,
		OpRsh8Ux16, OpRsh16Ux16, OpRsh32Ux16, OpRsh64Ux16,
		OpRsh8Ux32, OpRsh16Ux32, OpRsh32Ux32, OpRsh64Ux32,
		OpRsh8Ux64, OpRsh16Ux64, OpRsh32Ux64, OpRsh64Ux64:
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			// Unsigned right shift: first undo the stored sign-extension so x
			// is the unsigned xSize-bit operand; bits above the source width
			// are then known zeros, and shifted-in high bits are known zeros.
			x &= (1<<xSize - 1)
			xk |= -1 << xSize
			return int64(uint64(x) >> shift), int64(uint64(xk)>>shift | (^uint64(0) << (64 - shift)))
		})
	case OpRsh8x8, OpRsh16x8, OpRsh32x8, OpRsh64x8,
		OpRsh8x16, OpRsh16x16, OpRsh32x16, OpRsh64x16,
		OpRsh8x32, OpRsh16x32, OpRsh32x32, OpRsh64x32,
		OpRsh8x64, OpRsh16x64, OpRsh32x64, OpRsh64x64:
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			// Arithmetic right shift: the (possibly known) sign bit is
			// replicated into the vacated positions by the >> itself.
			return x >> shift, xk >> shift
		})
	default:
		// Unhandled op: nothing is known about the result.
		return 0, 0
	}
}
   150  
   151  // knownBits does constant folding across bitfields
   152  func knownBits(f *Func) {
   153  	kb := &knownBitsState{
   154  		entries:         f.Cache.allocKnownBitsEntriesSlice(f.NumValues()),
   155  		seenValues:      f.Cache.allocBitset(f.NumValues()),
   156  		reachableBlocks: f.Cache.allocBitset(f.NumBlocks()),
   157  	}
   158  	defer f.Cache.freeKnownBitsEntriesSlice(kb.entries)
   159  	defer f.Cache.freeBitset(kb.seenValues)
   160  	defer f.Cache.freeBitset(kb.reachableBlocks)
   161  	clear(kb.seenValues)
   162  	clear(kb.entries)
   163  	clear(kb.reachableBlocks)
   164  
   165  	blocks := f.postorder()
   166  	for _, b := range blocks {
   167  		kb.reachableBlocks.Set(uint32(b.ID))
   168  	}
   169  
   170  	for _, b := range blocks {
   171  		for _, v := range b.Values {
   172  			if v.Uses == 0 || !(v.Type.IsInteger() || v.Type.IsBoolean()) {
   173  				continue
   174  			}
   175  			switch v.Op {
   176  			case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool:
   177  				continue
   178  			}
   179  			val, k := kb.fold(v)
   180  			if k != -1 {
   181  				continue
   182  			}
   183  			if f.pass.debug > 0 {
   184  				var pval any = val
   185  				if v.Type.IsBoolean() {
   186  					pval = val != 0
   187  				}
   188  				f.Warnl(v.Pos, "known value of %v (%v): %v", v, v.Op, pval)
   189  			}
   190  			var c *Value
   191  			switch v.Type.Size() {
   192  			case 1:
   193  				if v.Type.IsBoolean() {
   194  					c = f.ConstBool(v.Type, val != 0)
   195  					break
   196  				}
   197  				c = f.ConstInt8(v.Type, int8(val))
   198  			case 2:
   199  				c = f.ConstInt16(v.Type, int16(val))
   200  			case 4:
   201  				c = f.ConstInt32(v.Type, int32(val))
   202  			case 8:
   203  				c = f.ConstInt64(v.Type, val)
   204  			default:
   205  				panic("unreachable; unknown integer size")
   206  			}
   207  			v.copyOf(c)
   208  		}
   209  	}
   210  }
   211  
// knownBitsState holds the per-function scratch state of the knownBits pass.
// Its slices and bitsets are sized by the function's value/block counts and
// are allocated from (and returned to) the pass cache by knownBits.
type knownBitsState struct {
	entries         []knownBitsEntry // indexed by Value.ID
	seenValues      bitset           // indexed by Value.ID (at the bit level)
	reachableBlocks bitset           // indexed by Block.ID (at the bit level)
}
   217  
// knownBitsEntry is the memoized fold result for a single Value: which bits
// are known, and what those bits are.
type knownBitsEntry struct {
	// Three invariants:
	// 1. unknown bits are always set to 0 inside value
	// 2. all values are sign-extended to int64 (inspired by RISC-V's xlen=64)
	//    This means let's say you know an 8 bits value is 0b10??????,
	//    known = int64(int8(0b11000000))
	//    value = int64(int8(0b10000000))
	// 3. booleans are stored as 1 byte values which are either 0 or 1.
	known, value int64
}
   228  
   229  func (kb *knownBitsState) isLiveInEdge(b *Block, index uint) bool {
   230  	inEdge := b.Preds[index]
   231  	return kb.isLiveOutEdge(inEdge.b, uint(inEdge.i))
   232  }
   233  
   234  func (kb *knownBitsState) isLiveOutEdge(b *Block, index uint) bool {
   235  	if !kb.reachableBlocks.Test(uint32(b.ID)) {
   236  		return false
   237  	}
   238  
   239  	switch b.Kind {
   240  	case BlockFirst:
   241  		return index == 0
   242  	case BlockPlain, BlockIf, BlockDefer, BlockRet, BlockRetJmp, BlockExit, BlockJumpTable:
   243  		return true
   244  	default:
   245  		panic("unreachable; unknown block kind")
   246  	}
   247  }
   248  
// computeKnownBitsForShift computes the known bits for a shift operation.
// Considering the following piece of code x = x << uint8(i)
// The algorithm is based on two observations:
//
//  1. computing a shift of a lattice by a constant (i) is easy:
//     value, known = x<<i, xk<<i|(1<<i-1)
//     each point in the lattice is shifted by the constant, all new shifted in bits are known zeros.
//
//  2. x = uint8(x) << i is equivalent to
//
//     switch i {
//     case 0:  x0 = x << 0
//     case 1:  x1 = x << 1
//     case 2:  x2 = x << 2
//     case 3:  x3 = x << 3
//     case 4:  x4 = x << 4
//     case 5:  x5 = x << 5
//     case 6:  x6 = x << 6
//     case 7:  x7 = x << 7
//     default: xd = x << 8
//     }
//     x = phi(x0, x1, x2, x3, x4, x5, x6, x7, xd)
//
// The algorithm below then models the phi in the equivalence above using the same intersection algorithm phi uses.
// We also leverage known bits of the shift amount to remove "branches" in the switch that are proved to be impossible.
//
// doShiftByAConst computes observation 1 for one concrete shift amount; a
// shift of 64 stands for "shifted completely out" (the switch default above).
func (kb *knownBitsState) computeKnownBitsForShift(v *Value, doShiftByAConst func(x, xk, xSize, shift int64) (value, known int64)) (value, known int64) {
	xSize := v.Args[0].Type.Size() * 8
	x, xk := kb.fold(v.Args[0])
	y, yk := kb.fold(v.Args[1])
	// Unknown bits of y are stored as 0, so uint64(y) is a lower bound on any
	// concrete shift amount; if even the bound reaches xSize the shift always
	// saturates.
	if uint64(y) >= uint64(xSize) {
		return doShiftByAConst(x, xk, xSize, 64)
	}

	set := false
	// A non-zero AuxInt marks the shift amount as bounded (< width); see the
	// "isn't bounded" note below.
	if v.AuxInt == 0 && uint64(^yk) >= uint64(xSize) {
		// this implements the default case of the equivalent switch above.
		// if the shift isn't bounded and there are unknown bits above the shift size we might completely stomp all bits.

		value, known = doShiftByAConst(x, xk, xSize, 64)
		set = true
	}
	// Only the low bits of the shift amount select among the in-range cases.
	yk &= xSize - 1

	// Enumerate every shift amount consistent with the known bits of y and
	// intersect the resulting lattices, exactly like OpPhi does.
	for i := range xSize {
		if i&yk != y {
			continue // this shift amount contradicts a known bit of y
		}
		a, k := doShiftByAConst(x, xk, xSize, int64(i))
		if !set {
			value, known = a, k
			set = true
		} else {
			// Keep a bit known only if every feasible case agrees on it.
			known &^= value ^ a
			known &= k
		}
		if known == 0 {
			break // nothing left to learn
		}
	}

	// Re-establish invariant 1: unknown bits of value are zero.
	return value & known, known
}
   311  

View as plain text