Source file src/cmd/compile/internal/compare/compare.go

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package compare contains code for generating comparison
     6  // routines for structs, strings and interfaces.
     7  package compare
     8  
     9  import (
    10  	"cmd/compile/internal/base"
    11  	"cmd/compile/internal/ir"
    12  	"cmd/compile/internal/typecheck"
    13  	"cmd/compile/internal/types"
    14  	"fmt"
    15  	"math/bits"
    16  	"sort"
    17  )
    18  
    19  // IsRegularMemory reports whether t can be compared/hashed as regular memory.
    20  func IsRegularMemory(t *types.Type) bool {
    21  	return types.AlgType(t) == types.AMEM
    22  }
    23  
    24  // Memrun finds runs of struct fields for which memory-only algs are appropriate.
    25  // t is the parent struct type, and start is the field index at which to start the run.
    26  // size is the length in bytes of the memory included in the run.
    27  // next is the index just after the end of the memory run.
    28  func Memrun(t *types.Type, start int) (size int64, next int) {
    29  	next = start
    30  	for {
    31  		next++
    32  		if next == t.NumFields() {
    33  			break
    34  		}
    35  		// Stop run after a padded field.
    36  		if types.IsPaddedField(t, next-1) {
    37  			break
    38  		}
    39  		// Also, stop before a blank or non-memory field.
    40  		if f := t.Field(next); f.Sym.IsBlank() || !IsRegularMemory(f.Type) {
    41  			break
    42  		}
    43  		// For issue 46283, don't combine fields if the resulting load would
    44  		// require a larger alignment than the component fields.
    45  		if base.Ctxt.Arch.Alignment > 1 {
    46  			align := t.Alignment()
    47  			if off := t.Field(start).Offset; off&(align-1) != 0 {
    48  				// Offset is less aligned than the containing type.
    49  				// Use offset to determine alignment.
    50  				align = 1 << uint(bits.TrailingZeros64(uint64(off)))
    51  			}
    52  			size := t.Field(next).End() - t.Field(start).Offset
    53  			if size > align {
    54  				break
    55  			}
    56  		}
    57  	}
    58  	return t.Field(next-1).End() - t.Field(start).Offset, next
    59  }
    60  
    61  // EqCanPanic reports whether == on type t could panic (has an interface somewhere).
    62  // t must be comparable.
    63  func EqCanPanic(t *types.Type) bool {
    64  	switch t.Kind() {
    65  	default:
    66  		return false
    67  	case types.TINTER:
    68  		return true
    69  	case types.TARRAY:
    70  		return EqCanPanic(t.Elem())
    71  	case types.TSTRUCT:
    72  		for _, f := range t.Fields() {
    73  			if !f.Sym.IsBlank() && EqCanPanic(f.Type) {
    74  				return true
    75  			}
    76  		}
    77  		return false
    78  	}
    79  }
    80  
    81  // EqStructCost returns the cost of an equality comparison of two structs.
    82  //
    83  // The cost is determined using an algorithm which takes into consideration
    84  // the size of the registers in the current architecture and the size of the
    85  // memory-only fields in the struct.
    86  func EqStructCost(t *types.Type) int64 {
    87  	cost := int64(0)
    88  
    89  	for i, fields := 0, t.Fields(); i < len(fields); {
    90  		f := fields[i]
    91  
    92  		// Skip blank-named fields.
    93  		if f.Sym.IsBlank() {
    94  			i++
    95  			continue
    96  		}
    97  
    98  		n, _, next := eqStructFieldCost(t, i)
    99  
   100  		cost += n
   101  		i = next
   102  	}
   103  
   104  	return cost
   105  }
   106  
   107  // eqStructFieldCost returns the cost of an equality comparison of two struct fields.
   108  // t is the parent struct type, and i is the index of the field in the parent struct type.
   109  // eqStructFieldCost may compute the cost of several adjacent fields at once. It returns
   110  // the cost, the size of the set of fields it computed the cost for (in bytes), and the
   111  // index of the first field not part of the set of fields for which the cost
   112  // has already been calculated.
   113  func eqStructFieldCost(t *types.Type, i int) (int64, int64, int) {
   114  	var (
   115  		cost    = int64(0)
   116  		regSize = int64(types.RegSize)
   117  
   118  		size int64
   119  		next int
   120  	)
   121  
   122  	if base.Ctxt.Arch.CanMergeLoads {
   123  		// If we can merge adjacent loads then we can calculate the cost of the
   124  		// comparison using the size of the memory run and the size of the registers.
   125  		size, next = Memrun(t, i)
   126  		cost = size / regSize
   127  		if size%regSize != 0 {
   128  			cost++
   129  		}
   130  		return cost, size, next
   131  	}
   132  
   133  	// If we cannot merge adjacent loads then we have to use the size of the
   134  	// field and take into account the type to determine how many loads and compares
   135  	// are needed.
   136  	ft := t.Field(i).Type
   137  	size = ft.Size()
   138  	next = i + 1
   139  
   140  	return calculateCostForType(ft), size, next
   141  }
   142  
   143  func calculateCostForType(t *types.Type) int64 {
   144  	var cost int64
   145  	switch t.Kind() {
   146  	case types.TSTRUCT:
   147  		return EqStructCost(t)
   148  	case types.TSLICE:
   149  		// Slices are not comparable.
   150  		base.Fatalf("calculateCostForType: unexpected slice type")
   151  	case types.TARRAY:
   152  		elemCost := calculateCostForType(t.Elem())
   153  		cost = t.NumElem() * elemCost
   154  	case types.TSTRING, types.TINTER, types.TCOMPLEX64, types.TCOMPLEX128:
   155  		cost = 2
   156  	case types.TINT64, types.TUINT64:
   157  		cost = 8 / int64(types.RegSize)
   158  	default:
   159  		cost = 1
   160  	}
   161  	return cost
   162  }
   163  
   164  // EqStruct compares two structs np and nq for equality.
   165  // It works by building a list of boolean conditions to satisfy.
   166  // Conditions must be evaluated in the returned order and
   167  // properly short-circuited by the caller.
   168  // The first return value is the flattened list of conditions,
   169  // the second value is a boolean indicating whether any of the
   170  // comparisons could panic.
   171  func EqStruct(t *types.Type, np, nq ir.Node) ([]ir.Node, bool) {
   172  	// The conditions are a list-of-lists. Conditions are reorderable
   173  	// within each inner list. The outer lists must be evaluated in order.
   174  	var conds [][]ir.Node
   175  	conds = append(conds, []ir.Node{})
   176  	and := func(n ir.Node) {
   177  		i := len(conds) - 1
   178  		conds[i] = append(conds[i], n)
   179  	}
   180  
   181  	// Walk the struct using memequal for runs of AMEM
   182  	// and calling specific equality tests for the others.
   183  	for i, fields := 0, t.Fields(); i < len(fields); {
   184  		f := fields[i]
   185  
   186  		// Skip blank-named fields.
   187  		if f.Sym.IsBlank() {
   188  			i++
   189  			continue
   190  		}
   191  
   192  		typeCanPanic := EqCanPanic(f.Type)
   193  
   194  		// Compare non-memory fields with field equality.
   195  		if !IsRegularMemory(f.Type) {
   196  			if typeCanPanic {
   197  				// Enforce ordering by starting a new set of reorderable conditions.
   198  				conds = append(conds, []ir.Node{})
   199  			}
   200  			switch {
   201  			case f.Type.IsString():
   202  				p := typecheck.DotField(base.Pos, typecheck.Expr(np), i)
   203  				q := typecheck.DotField(base.Pos, typecheck.Expr(nq), i)
   204  				eqlen, eqmem := EqString(p, q)
   205  				and(eqlen)
   206  				and(eqmem)
   207  			default:
   208  				and(eqfield(np, nq, i))
   209  			}
   210  			if typeCanPanic {
   211  				// Also enforce ordering after something that can panic.
   212  				conds = append(conds, []ir.Node{})
   213  			}
   214  			i++
   215  			continue
   216  		}
   217  
   218  		cost, size, next := eqStructFieldCost(t, i)
   219  		if cost <= 4 {
   220  			// Cost of 4 or less: use plain field equality.
   221  			for j := i; j < next; j++ {
   222  				and(eqfield(np, nq, j))
   223  			}
   224  		} else {
   225  			// Higher cost: use memequal.
   226  			cc := eqmem(np, nq, i, size)
   227  			and(cc)
   228  		}
   229  		i = next
   230  	}
   231  
   232  	// Sort conditions to put runtime calls last.
   233  	// Preserve the rest of the ordering.
   234  	var flatConds []ir.Node
   235  	for _, c := range conds {
   236  		isCall := func(n ir.Node) bool {
   237  			return n.Op() == ir.OCALL || n.Op() == ir.OCALLFUNC
   238  		}
   239  		sort.SliceStable(c, func(i, j int) bool {
   240  			return !isCall(c[i]) && isCall(c[j])
   241  		})
   242  		flatConds = append(flatConds, c...)
   243  	}
   244  	return flatConds, len(conds) > 1
   245  }
   246  
   247  // EqString returns the nodes
   248  //
   249  //	len(s) == len(t)
   250  //
   251  // and
   252  //
   253  //	memequal(s.ptr, t.ptr, len(s))
   254  //
   255  // which can be used to construct string equality comparison.
   256  // eqlen must be evaluated before eqmem, and shortcircuiting is required.
   257  func EqString(s, t ir.Node) (eqlen *ir.BinaryExpr, eqmem *ir.CallExpr) {
   258  	s = typecheck.Conv(s, types.Types[types.TSTRING])
   259  	t = typecheck.Conv(t, types.Types[types.TSTRING])
   260  	sptr := ir.NewUnaryExpr(base.Pos, ir.OSPTR, s)
   261  	tptr := ir.NewUnaryExpr(base.Pos, ir.OSPTR, t)
   262  	slen := typecheck.Conv(ir.NewUnaryExpr(base.Pos, ir.OLEN, s), types.Types[types.TUINTPTR])
   263  	tlen := typecheck.Conv(ir.NewUnaryExpr(base.Pos, ir.OLEN, t), types.Types[types.TUINTPTR])
   264  
   265  	// Pick the 3rd arg to memequal. Both slen and tlen are fine to use, because we short
   266  	// circuit the memequal call if they aren't the same. But if one is a constant some
   267  	// memequal optimizations are easier to apply.
   268  	probablyConstant := func(n ir.Node) bool {
   269  		if n.Op() == ir.OCONVNOP {
   270  			n = n.(*ir.ConvExpr).X
   271  		}
   272  		if n.Op() == ir.OLITERAL {
   273  			return true
   274  		}
   275  		if n.Op() != ir.ONAME {
   276  			return false
   277  		}
   278  		name := n.(*ir.Name)
   279  		if name.Class != ir.PAUTO {
   280  			return false
   281  		}
   282  		if def := name.Defn; def == nil {
   283  			// n starts out as the empty string
   284  			return true
   285  		} else if def.Op() == ir.OAS && (def.(*ir.AssignStmt).Y == nil || def.(*ir.AssignStmt).Y.Op() == ir.OLITERAL) {
   286  			// n starts out as a constant string
   287  			return true
   288  		}
   289  		return false
   290  	}
   291  	cmplen := slen
   292  	if probablyConstant(t) && !probablyConstant(s) {
   293  		cmplen = tlen
   294  	}
   295  
   296  	fn := typecheck.LookupRuntime("memequal", types.Types[types.TUINT8], types.Types[types.TUINT8])
   297  	call := typecheck.Call(base.Pos, fn, []ir.Node{sptr, tptr, ir.Copy(cmplen)}, false).(*ir.CallExpr)
   298  
   299  	cmp := ir.NewBinaryExpr(base.Pos, ir.OEQ, slen, tlen)
   300  	cmp = typecheck.Expr(cmp).(*ir.BinaryExpr)
   301  	cmp.SetType(types.Types[types.TBOOL])
   302  	return cmp, call
   303  }
   304  
   305  // EqInterface returns the nodes
   306  //
   307  //	s.tab == t.tab (or s.typ == t.typ, as appropriate)
   308  //
   309  // and
   310  //
   311  //	ifaceeq(s.tab, s.data, t.data) (or efaceeq(s.typ, s.data, t.data), as appropriate)
   312  //
   313  // which can be used to construct interface equality comparison.
   314  // eqtab must be evaluated before eqdata, and shortcircuiting is required.
   315  func EqInterface(s, t ir.Node) (eqtab *ir.BinaryExpr, eqdata *ir.CallExpr) {
   316  	if !types.Identical(s.Type(), t.Type()) {
   317  		base.Fatalf("EqInterface %v %v", s.Type(), t.Type())
   318  	}
   319  	// func ifaceeq(tab *uintptr, x, y unsafe.Pointer) (ret bool)
   320  	// func efaceeq(typ *uintptr, x, y unsafe.Pointer) (ret bool)
   321  	var fn ir.Node
   322  	if s.Type().IsEmptyInterface() {
   323  		fn = typecheck.LookupRuntime("efaceeq")
   324  	} else {
   325  		fn = typecheck.LookupRuntime("ifaceeq")
   326  	}
   327  
   328  	stab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s)
   329  	ttab := ir.NewUnaryExpr(base.Pos, ir.OITAB, t)
   330  	sdata := ir.NewUnaryExpr(base.Pos, ir.OIDATA, s)
   331  	tdata := ir.NewUnaryExpr(base.Pos, ir.OIDATA, t)
   332  	sdata.SetType(types.Types[types.TUNSAFEPTR])
   333  	tdata.SetType(types.Types[types.TUNSAFEPTR])
   334  	sdata.SetTypecheck(1)
   335  	tdata.SetTypecheck(1)
   336  
   337  	call := typecheck.Call(base.Pos, fn, []ir.Node{stab, sdata, tdata}, false).(*ir.CallExpr)
   338  
   339  	cmp := ir.NewBinaryExpr(base.Pos, ir.OEQ, stab, ttab)
   340  	cmp = typecheck.Expr(cmp).(*ir.BinaryExpr)
   341  	cmp.SetType(types.Types[types.TBOOL])
   342  	return cmp, call
   343  }
   344  
   345  // eqfield returns the node
   346  //
   347  //	p.field == q.field
   348  func eqfield(p, q ir.Node, field int) ir.Node {
   349  	nx := typecheck.DotField(base.Pos, typecheck.Expr(p), field)
   350  	ny := typecheck.DotField(base.Pos, typecheck.Expr(q), field)
   351  	return typecheck.Expr(ir.NewBinaryExpr(base.Pos, ir.OEQ, nx, ny))
   352  }
   353  
   354  // eqmem returns the node
   355  //
   356  //	memequal(&p.field, &q.field, size)
   357  func eqmem(p, q ir.Node, field int, size int64) ir.Node {
   358  	nx := typecheck.Expr(typecheck.NodAddr(typecheck.DotField(base.Pos, p, field)))
   359  	ny := typecheck.Expr(typecheck.NodAddr(typecheck.DotField(base.Pos, q, field)))
   360  
   361  	fn, needsize := eqmemfunc(size, nx.Type().Elem())
   362  	call := ir.NewCallExpr(base.Pos, ir.OCALL, fn, nil)
   363  	call.Args.Append(nx)
   364  	call.Args.Append(ny)
   365  	if needsize {
   366  		call.Args.Append(ir.NewInt(base.Pos, size))
   367  	}
   368  
   369  	return call
   370  }
   371  
   372  func eqmemfunc(size int64, t *types.Type) (fn *ir.Name, needsize bool) {
   373  	if !base.Ctxt.Arch.CanMergeLoads && t.Alignment() < int64(base.Ctxt.Arch.Alignment) && t.Alignment() < t.Size() {
   374  		// We can't use larger comparisons if the value might not be aligned
   375  		// enough for the larger comparison. See issues 46283 and 67160.
   376  		size = 0
   377  	}
   378  	switch size {
   379  	case 1, 2, 4, 8, 16:
   380  		buf := fmt.Sprintf("memequal%d", int(size)*8)
   381  		return typecheck.LookupRuntime(buf, t, t), false
   382  	}
   383  
   384  	return typecheck.LookupRuntime("memequal", t, t), true
   385  }
   386  

View as plain text