Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/modernize/stditerators.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modernize
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/token"
    11  	"go/types"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  	"golang.org/x/tools/go/ast/edge"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  	"golang.org/x/tools/internal/analysis/analyzerutil"
    18  	typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
    19  	"golang.org/x/tools/internal/astutil"
    20  	"golang.org/x/tools/internal/goplsexport"
    21  	"golang.org/x/tools/internal/refactor"
    22  	"golang.org/x/tools/internal/stdlib"
    23  	"golang.org/x/tools/internal/typesinternal/typeindex"
    24  )
    25  
    26  var stditeratorsAnalyzer = &analysis.Analyzer{
    27  	Name: "stditerators",
    28  	Doc:  analyzerutil.MustExtractDoc(doc, "stditerators"),
    29  	Requires: []*analysis.Analyzer{
    30  		typeindexanalyzer.Analyzer,
    31  	},
    32  	Run: stditerators,
    33  	URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#stditerators",
    34  }
    35  
    36  func init() {
    37  	// Export to gopls until this is a published modernizer.
    38  	goplsexport.StdIteratorsModernizer = stditeratorsAnalyzer
    39  }
    40  
    41  // stditeratorsTable records std types that have legacy T.{Len,At}
    42  // iteration methods as well as a newer T.All method that returns an
    43  // iter.Seq.
    44  var stditeratorsTable = [...]struct {
    45  	pkgpath, typename, lenmethod, atmethod, itermethod, elemname string
    46  
    47  	seqn int // 1 or 2 => "for x" or "for _, x"
    48  }{
    49  	// Example: in go/types, (*Tuple).Variables returns an
    50  	// iterator that replaces a loop over (*Tuple).{Len,At}.
    51  	// The loop variable is named "v".
    52  	{"go/types", "Interface", "NumEmbeddeds", "EmbeddedType", "EmbeddedTypes", "etyp", 1},
    53  	{"go/types", "Interface", "NumExplicitMethods", "ExplicitMethod", "ExplicitMethods", "method", 1},
    54  	{"go/types", "Interface", "NumMethods", "Method", "Methods", "method", 1},
    55  	{"go/types", "MethodSet", "Len", "At", "Methods", "method", 1},
    56  	{"go/types", "Named", "NumMethods", "Method", "Methods", "method", 1},
    57  	{"go/types", "Scope", "NumChildren", "Child", "Children", "child", 1},
    58  	{"go/types", "Struct", "NumFields", "Field", "Fields", "field", 1},
    59  	{"go/types", "Tuple", "Len", "At", "Variables", "v", 1},
    60  	{"go/types", "TypeList", "Len", "At", "Types", "t", 1},
    61  	{"go/types", "TypeParamList", "Len", "At", "TypeParams", "tparam", 1},
    62  	{"go/types", "Union", "Len", "Term", "Terms", "term", 1},
    63  	{"reflect", "Type", "NumField", "Field", "Fields", "field", 1},
    64  	{"reflect", "Type", "NumMethod", "Method", "Methods", "method", 1},
    65  	{"reflect", "Type", "NumIn", "In", "Ins", "in", 1},
    66  	{"reflect", "Type", "NumOut", "Out", "Outs", "out", 1},
    67  	{"reflect", "Value", "NumField", "Field", "Fields", "field", 2},
    68  	{"reflect", "Value", "NumMethod", "Method", "Methods", "method", 2},
    69  }
    70  
    71  // stditerators suggests fixes to replace loops using Len/At-style
    72  // iterator APIs by a range loop over an iterator. The set of
    73  // participating types and methods is defined by [iteratorsTable].
    74  //
    75  // Pattern:
    76  //
    77  //	for i := 0; i < x.Len(); i++ {
    78  //		use(x.At(i))
    79  //	}
    80  //
    81  // =>
    82  //
    83  //	for elem := range x.All() {
    84  //		use(elem)
    85  //	}
    86  //
    87  // Variant:
    88  //
    89  //	for i := range x.Len() { ... }
    90  //
    91  // Note: Iterators have a dynamic cost. How do we know that
    92  // the user hasn't intentionally chosen not to use an
    93  // iterator for that reason? We don't want to go fix to
    94  // undo optimizations. Do we need a suppression mechanism?
    95  //
    96  // TODO(adonovan): recognize the more complex patterns that
    97  // could make full use of both components of an iter.Seq2, e.g.
    98  //
    99  //	for i := 0; i < v.NumField(); i++ {
   100  //		use(v.Field(i), v.Type().Field(i))
   101  //	}
   102  //
   103  // =>
   104  //
   105  //	for structField, field := range v.Fields() {
   106  //		use(structField, field)
   107  //	}
   108  func stditerators(pass *analysis.Pass) (any, error) {
   109  	var (
   110  		index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
   111  		info  = pass.TypesInfo
   112  	)
   113  
   114  	for _, row := range stditeratorsTable {
   115  		// Don't offer fixes within the package
   116  		// that defines the iterator in question.
   117  		if within(pass, row.pkgpath) {
   118  			continue
   119  		}
   120  
   121  		var (
   122  			lenMethod = index.Selection(row.pkgpath, row.typename, row.lenmethod)
   123  			atMethod  = index.Selection(row.pkgpath, row.typename, row.atmethod)
   124  		)
   125  
   126  		// chooseName returns an appropriate fresh name
   127  		// for the index variable of the iterator loop
   128  		// whose body is specified.
   129  		//
   130  		// If the loop body starts with
   131  		//
   132  		//     for ... { e := x.At(i); use(e) }
   133  		//
   134  		// or
   135  		//
   136  		//     for ... { if e := x.At(i); cond { use(e) } }
   137  		//
   138  		// then chooseName prefers the name e and additionally
   139  		// returns the var's symbol. We'll transform this to:
   140  		//
   141  		//     for e := range x.Len() { e := e; use(e) }
   142  		//
   143  		// which leaves a redundant assignment that a
   144  		// subsequent 'forvar' pass will eliminate.
   145  		chooseName := func(curBody inspector.Cursor, x ast.Expr, i *types.Var) (string, *types.Var) {
   146  
   147  			// isVarAssign reports whether stmt has the form v := x.At(i)
   148  			// and returns the variable if so.
   149  			isVarAssign := func(stmt ast.Stmt) *types.Var {
   150  				if assign, ok := stmt.(*ast.AssignStmt); ok &&
   151  					assign.Tok == token.DEFINE &&
   152  					len(assign.Lhs) == 1 &&
   153  					len(assign.Rhs) == 1 &&
   154  					is[*ast.Ident](assign.Lhs[0]) {
   155  					// call to x.At(i)?
   156  					if call, ok := assign.Rhs[0].(*ast.CallExpr); ok &&
   157  						typeutil.Callee(info, call) == atMethod &&
   158  						astutil.EqualSyntax(ast.Unparen(call.Fun).(*ast.SelectorExpr).X, x) &&
   159  						is[*ast.Ident](call.Args[0]) &&
   160  						info.Uses[call.Args[0].(*ast.Ident)] == i {
   161  						// Have: elem := x.At(i)
   162  						id := assign.Lhs[0].(*ast.Ident)
   163  						return info.Defs[id].(*types.Var)
   164  					}
   165  				}
   166  				return nil
   167  			}
   168  
   169  			body := curBody.Node().(*ast.BlockStmt)
   170  			if len(body.List) > 0 {
   171  				// Is body { elem := x.At(i); ... } ?
   172  				if v := isVarAssign(body.List[0]); v != nil {
   173  					return v.Name(), v
   174  				}
   175  
   176  				// Or { if elem := x.At(i); cond { ... } } ?
   177  				if ifstmt, ok := body.List[0].(*ast.IfStmt); ok && ifstmt.Init != nil {
   178  					if v := isVarAssign(ifstmt.Init); v != nil {
   179  						return v.Name(), v
   180  					}
   181  				}
   182  			}
   183  
   184  			loop := curBody.Parent().Node()
   185  
   186  			// Choose a fresh name only if
   187  			// (a) the preferred name is already declared here, and
   188  			// (b) there are references to it from the loop body.
   189  			// TODO(adonovan): this pattern also appears in errorsastype,
   190  			// and is wanted elsewhere; factor.
   191  			name := row.elemname
   192  			if v := lookup(info, curBody, name); v != nil {
   193  				// is it free in body?
   194  				for curUse := range index.Uses(v) {
   195  					if curBody.Contains(curUse) {
   196  						name = refactor.FreshName(info.Scopes[loop], loop.Pos(), name)
   197  						break
   198  					}
   199  				}
   200  			}
   201  			return name, nil
   202  		}
   203  
   204  		// Process each call of x.Len().
   205  	nextCall:
   206  		for curLenCall := range index.Calls(lenMethod) {
   207  			lenSel, ok := ast.Unparen(curLenCall.Node().(*ast.CallExpr).Fun).(*ast.SelectorExpr)
   208  			if !ok {
   209  				continue
   210  			}
   211  			// lenSel is "x.Len"
   212  
   213  			var (
   214  				rng      analysis.Range   // where to report diagnostic
   215  				curBody  inspector.Cursor // loop body
   216  				indexVar *types.Var       // old loop index var
   217  				elemVar  *types.Var       // existing "elem := x.At(i)" var, if present
   218  				elem     string           // name for new loop var
   219  				edits    []analysis.TextEdit
   220  			)
   221  
   222  			// Analyze enclosing loop.
   223  			switch first(curLenCall.ParentEdge()) {
   224  			case edge.BinaryExpr_Y:
   225  				// pattern 1: for i := 0; i < x.Len(); i++ { ... }
   226  				var (
   227  					curCmp = curLenCall.Parent()
   228  					cmp    = curCmp.Node().(*ast.BinaryExpr)
   229  				)
   230  				if cmp.Op != token.LSS ||
   231  					!astutil.IsChildOf(curCmp, edge.ForStmt_Cond) {
   232  					continue
   233  				}
   234  				if id, ok := cmp.X.(*ast.Ident); ok {
   235  					// Have: for _; i < x.Len(); _ { ... }
   236  					var (
   237  						v      = info.Uses[id].(*types.Var)
   238  						curFor = curCmp.Parent()
   239  						loop   = curFor.Node().(*ast.ForStmt)
   240  					)
   241  					if v != isIncrementLoop(info, loop) {
   242  						continue
   243  					}
   244  					// Have: for i := 0; i < x.Len(); i++ { ... }.
   245  					//       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   246  					rng = astutil.RangeOf(loop.For, loop.Post.End())
   247  					indexVar = v
   248  					curBody = curFor.ChildAt(edge.ForStmt_Body, -1)
   249  					elem, elemVar = chooseName(curBody, lenSel.X, indexVar)
   250  					elemPrefix := cond(row.seqn == 2, "_, ", "")
   251  
   252  					//	for i       := 0; i < x.Len(); i++ {
   253  					//          ----       -------  ---  -----
   254  					//	for elem    := range  x.All()      {
   255  					// or   for _, elem := ...
   256  					edits = []analysis.TextEdit{
   257  						{
   258  							Pos:     v.Pos(),
   259  							End:     v.Pos() + token.Pos(len(v.Name())),
   260  							NewText: []byte(elemPrefix + elem),
   261  						},
   262  						{
   263  							Pos:     loop.Init.(*ast.AssignStmt).Rhs[0].Pos(),
   264  							End:     cmp.Y.Pos(),
   265  							NewText: []byte("range "),
   266  						},
   267  						{
   268  							Pos:     lenSel.Sel.Pos(),
   269  							End:     lenSel.Sel.End(),
   270  							NewText: []byte(row.itermethod),
   271  						},
   272  						{
   273  							Pos: curLenCall.Node().End(),
   274  							End: loop.Post.End(),
   275  						},
   276  					}
   277  				}
   278  
   279  			case edge.RangeStmt_X:
   280  				// pattern 2: for i := range x.Len() { ... }
   281  				var (
   282  					curRange = curLenCall.Parent()
   283  					loop     = curRange.Node().(*ast.RangeStmt)
   284  				)
   285  				if id, ok := loop.Key.(*ast.Ident); ok &&
   286  					loop.Value == nil &&
   287  					loop.Tok == token.DEFINE {
   288  					// Have: for i := range x.Len() { ... }
   289  					//                ~~~~~~~~~~~~~
   290  
   291  					rng = astutil.RangeOf(loop.Range, loop.X.End())
   292  					indexVar = info.Defs[id].(*types.Var)
   293  					curBody = curRange.ChildAt(edge.RangeStmt_Body, -1)
   294  					elem, elemVar = chooseName(curBody, lenSel.X, indexVar)
   295  					elemPrefix := cond(row.seqn == 2, "_, ", "")
   296  
   297  					//	for i    := range x.Len() {
   298  					//          ----            ---
   299  					//	for elem := range x.All() {
   300  					edits = []analysis.TextEdit{
   301  						{
   302  							Pos:     loop.Key.Pos(),
   303  							End:     loop.Key.End(),
   304  							NewText: []byte(elemPrefix + elem),
   305  						},
   306  						{
   307  							Pos:     lenSel.Sel.Pos(),
   308  							End:     lenSel.Sel.End(),
   309  							NewText: []byte(row.itermethod),
   310  						},
   311  					}
   312  				}
   313  			}
   314  
   315  			if indexVar == nil {
   316  				continue // no loop of the required form
   317  			}
   318  
   319  			// TODO(adonovan): what about possible
   320  			// modifications of x within the loop?
   321  			// Aliasing seems to make a conservative
   322  			// treatment impossible.
   323  
   324  			// Check that all uses of var i within loop body are x.At(i).
   325  			for curUse := range index.Uses(indexVar) {
   326  				if !curBody.Contains(curUse) {
   327  					continue
   328  				}
   329  				if ek, argidx := curUse.ParentEdge(); ek != edge.CallExpr_Args || argidx != 0 {
   330  					continue nextCall // use is not arg of call
   331  				}
   332  				curAtCall := curUse.Parent()
   333  				atCall := curAtCall.Node().(*ast.CallExpr)
   334  				if typeutil.Callee(info, atCall) != atMethod {
   335  					continue nextCall // use is not arg of call to T.At
   336  				}
   337  				atSel := ast.Unparen(atCall.Fun).(*ast.SelectorExpr)
   338  
   339  				// Check receivers of Len, At calls match (syntactically).
   340  				if !astutil.EqualSyntax(lenSel.X, atSel.X) {
   341  					continue nextCall
   342  				}
   343  
   344  				// At each point of use, check that
   345  				// the fresh variable is not shadowed
   346  				// by an intervening local declaration
   347  				// (or by the idiomatic elemVar optionally
   348  				// found by chooseName).
   349  				if obj := lookup(info, curAtCall, elem); obj != nil && obj != elemVar && obj.Pos() > indexVar.Pos() {
   350  					// (Ideally, instead of giving up, we would
   351  					// embellish the name and try again.)
   352  					continue nextCall
   353  				}
   354  
   355  				// use(x.At(i))
   356  				//     -------
   357  				// use(elem   )
   358  				edits = append(edits, analysis.TextEdit{
   359  					Pos:     atCall.Pos(),
   360  					End:     atCall.End(),
   361  					NewText: []byte(elem),
   362  				})
   363  			}
   364  
   365  			// Check file Go version is new enough for the iterator method.
   366  			// (In the long run, version filters are not highly selective,
   367  			// so there's no need to do them first, especially as this check
   368  			// may be somewhat expensive.)
   369  			if v, err := methodGoVersion(row.pkgpath, row.typename, row.itermethod); err != nil {
   370  				panic(err)
   371  			} else if !analyzerutil.FileUsesGoVersion(pass, astutil.EnclosingFile(curLenCall), v.String()) {
   372  				continue nextCall
   373  			}
   374  
   375  			pass.Report(analysis.Diagnostic{
   376  				Pos: rng.Pos(),
   377  				End: rng.End(),
   378  				Message: fmt.Sprintf("%s/%s loop can simplified using %s.%s iteration",
   379  					row.lenmethod, row.atmethod, row.typename, row.itermethod),
   380  				SuggestedFixes: []analysis.SuggestedFix{{
   381  					Message: fmt.Sprintf(
   382  						"Replace %s/%s loop with %s.%s iteration",
   383  						row.lenmethod, row.atmethod, row.typename, row.itermethod),
   384  					TextEdits: edits,
   385  				}},
   386  			})
   387  		}
   388  	}
   389  	return nil, nil
   390  }
   391  
   392  // -- helpers --
   393  
   394  // methodGoVersion reports the version at which the method
   395  // (pkgpath.recvtype).method appeared in the standard library.
   396  func methodGoVersion(pkgpath, recvtype, method string) (stdlib.Version, error) {
   397  	// TODO(adonovan): opt: this might be inefficient for large packages
   398  	// like go/types. If so, memoize using a map (and kill two birds with
   399  	// one stone by also memoizing the 'within' check above).
   400  	for _, sym := range stdlib.PackageSymbols[pkgpath] {
   401  		if sym.Kind == stdlib.Method {
   402  			_, recv, name := sym.SplitMethod()
   403  			if recv == recvtype && name == method {
   404  				return sym.Version, nil
   405  			}
   406  		}
   407  	}
   408  	return 0, fmt.Errorf("methodGoVersion: %s.%s.%s missing from stdlib manifest", pkgpath, recvtype, method)
   409  }
   410  

View as plain text