1
2
3
4
5 package modernize
6
7 import (
8 "fmt"
9 "go/ast"
10 "go/token"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/ast/edge"
15 "golang.org/x/tools/go/ast/inspector"
16 "golang.org/x/tools/go/types/typeutil"
17 "golang.org/x/tools/internal/analysis/analyzerutil"
18 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
19 "golang.org/x/tools/internal/astutil"
20 "golang.org/x/tools/internal/goplsexport"
21 "golang.org/x/tools/internal/refactor"
22 "golang.org/x/tools/internal/stdlib"
23 "golang.org/x/tools/internal/typesinternal/typeindex"
24 )
25
26 var stditeratorsAnalyzer = &analysis.Analyzer{
27 Name: "stditerators",
28 Doc: analyzerutil.MustExtractDoc(doc, "stditerators"),
29 Requires: []*analysis.Analyzer{
30 typeindexanalyzer.Analyzer,
31 },
32 Run: stditerators,
33 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#stditerators",
34 }
35
36 func init() {
37
38 goplsexport.StdIteratorsModernizer = stditeratorsAnalyzer
39 }
40
41
42
43
// stditeratorsTable lists standard-library types that provide both a
// legacy pair of accessor methods (an element count and an indexed
// getter) and a newer iterator method. stditerators processes the rows
// in order, rewriting
//
//	for i := 0; i < x.Len(); i++ { use(x.At(i)) }
//
// into
//
//	for elem := range x.Iter() { use(elem) }
//
// For example, the Tuple row replaces a loop over (*types.Tuple).{Len,At}
// with a range over (*types.Tuple).Variables whose loop variable is "v".
var stditeratorsTable = [...]struct {
	pkgpath    string // import path of the package declaring the type
	typename   string // name of the type within pkgpath
	lenmethod  string // method returning the number of elements
	atmethod   string // method returning the element at an index
	itermethod string // method whose result the rewritten loop ranges over
	elemname   string // preferred name of the new loop variable
	seqn       int    // 1 => "for elem := range ...", 2 => "for _, elem := range ..."
}{
	{"go/types", "Interface", "NumEmbeddeds", "EmbeddedType", "EmbeddedTypes", "etyp", 1},
	{"go/types", "Interface", "NumExplicitMethods", "ExplicitMethod", "ExplicitMethods", "method", 1},
	{"go/types", "Interface", "NumMethods", "Method", "Methods", "method", 1},
	{"go/types", "MethodSet", "Len", "At", "Methods", "method", 1},
	{"go/types", "Named", "NumMethods", "Method", "Methods", "method", 1},
	{"go/types", "Scope", "NumChildren", "Child", "Children", "child", 1},
	{"go/types", "Struct", "NumFields", "Field", "Fields", "field", 1},
	{"go/types", "Tuple", "Len", "At", "Variables", "v", 1},
	{"go/types", "TypeList", "Len", "At", "Types", "t", 1},
	{"go/types", "TypeParamList", "Len", "At", "TypeParams", "tparam", 1},
	{"go/types", "Union", "Len", "Term", "Terms", "term", 1},
	{"reflect", "Type", "NumField", "Field", "Fields", "field", 1},
	{"reflect", "Type", "NumMethod", "Method", "Methods", "method", 1},
	{"reflect", "Type", "NumIn", "In", "Ins", "in", 1},
	{"reflect", "Type", "NumOut", "Out", "Outs", "out", 1},
	{"reflect", "Value", "NumField", "Field", "Fields", "field", 2},
	{"reflect", "Value", "NumMethod", "Method", "Methods", "method", 2},
}
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108 func stditerators(pass *analysis.Pass) (any, error) {
109 var (
110 index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
111 info = pass.TypesInfo
112 )
113
114 for _, row := range stditeratorsTable {
115
116
117 if within(pass, row.pkgpath) {
118 continue
119 }
120
121 var (
122 lenMethod = index.Selection(row.pkgpath, row.typename, row.lenmethod)
123 atMethod = index.Selection(row.pkgpath, row.typename, row.atmethod)
124 )
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145 chooseName := func(curBody inspector.Cursor, x ast.Expr, i *types.Var) (string, *types.Var) {
146
147
148
149 isVarAssign := func(stmt ast.Stmt) *types.Var {
150 if assign, ok := stmt.(*ast.AssignStmt); ok &&
151 assign.Tok == token.DEFINE &&
152 len(assign.Lhs) == 1 &&
153 len(assign.Rhs) == 1 &&
154 is[*ast.Ident](assign.Lhs[0]) {
155
156 if call, ok := assign.Rhs[0].(*ast.CallExpr); ok &&
157 typeutil.Callee(info, call) == atMethod &&
158 astutil.EqualSyntax(ast.Unparen(call.Fun).(*ast.SelectorExpr).X, x) &&
159 is[*ast.Ident](call.Args[0]) &&
160 info.Uses[call.Args[0].(*ast.Ident)] == i {
161
162 id := assign.Lhs[0].(*ast.Ident)
163 return info.Defs[id].(*types.Var)
164 }
165 }
166 return nil
167 }
168
169 body := curBody.Node().(*ast.BlockStmt)
170 if len(body.List) > 0 {
171
172 if v := isVarAssign(body.List[0]); v != nil {
173 return v.Name(), v
174 }
175
176
177 if ifstmt, ok := body.List[0].(*ast.IfStmt); ok && ifstmt.Init != nil {
178 if v := isVarAssign(ifstmt.Init); v != nil {
179 return v.Name(), v
180 }
181 }
182 }
183
184 loop := curBody.Parent().Node()
185
186
187
188
189
190
191 name := row.elemname
192 if v := lookup(info, curBody, name); v != nil {
193
194 for curUse := range index.Uses(v) {
195 if curBody.Contains(curUse) {
196 name = refactor.FreshName(info.Scopes[loop], loop.Pos(), name)
197 break
198 }
199 }
200 }
201 return name, nil
202 }
203
204
205 nextCall:
206 for curLenCall := range index.Calls(lenMethod) {
207 lenSel, ok := ast.Unparen(curLenCall.Node().(*ast.CallExpr).Fun).(*ast.SelectorExpr)
208 if !ok {
209 continue
210 }
211
212
213 var (
214 rng analysis.Range
215 curBody inspector.Cursor
216 indexVar *types.Var
217 elemVar *types.Var
218 elem string
219 edits []analysis.TextEdit
220 )
221
222
223 switch first(curLenCall.ParentEdge()) {
224 case edge.BinaryExpr_Y:
225
226 var (
227 curCmp = curLenCall.Parent()
228 cmp = curCmp.Node().(*ast.BinaryExpr)
229 )
230 if cmp.Op != token.LSS ||
231 !astutil.IsChildOf(curCmp, edge.ForStmt_Cond) {
232 continue
233 }
234 if id, ok := cmp.X.(*ast.Ident); ok {
235
236 var (
237 v = info.Uses[id].(*types.Var)
238 curFor = curCmp.Parent()
239 loop = curFor.Node().(*ast.ForStmt)
240 )
241 if v != isIncrementLoop(info, loop) {
242 continue
243 }
244
245
246 rng = astutil.RangeOf(loop.For, loop.Post.End())
247 indexVar = v
248 curBody = curFor.ChildAt(edge.ForStmt_Body, -1)
249 elem, elemVar = chooseName(curBody, lenSel.X, indexVar)
250 elemPrefix := cond(row.seqn == 2, "_, ", "")
251
252
253
254
255
256 edits = []analysis.TextEdit{
257 {
258 Pos: v.Pos(),
259 End: v.Pos() + token.Pos(len(v.Name())),
260 NewText: []byte(elemPrefix + elem),
261 },
262 {
263 Pos: loop.Init.(*ast.AssignStmt).Rhs[0].Pos(),
264 End: cmp.Y.Pos(),
265 NewText: []byte("range "),
266 },
267 {
268 Pos: lenSel.Sel.Pos(),
269 End: lenSel.Sel.End(),
270 NewText: []byte(row.itermethod),
271 },
272 {
273 Pos: curLenCall.Node().End(),
274 End: loop.Post.End(),
275 },
276 }
277 }
278
279 case edge.RangeStmt_X:
280
281 var (
282 curRange = curLenCall.Parent()
283 loop = curRange.Node().(*ast.RangeStmt)
284 )
285 if id, ok := loop.Key.(*ast.Ident); ok &&
286 loop.Value == nil &&
287 loop.Tok == token.DEFINE {
288
289
290
291 rng = astutil.RangeOf(loop.Range, loop.X.End())
292 indexVar = info.Defs[id].(*types.Var)
293 curBody = curRange.ChildAt(edge.RangeStmt_Body, -1)
294 elem, elemVar = chooseName(curBody, lenSel.X, indexVar)
295 elemPrefix := cond(row.seqn == 2, "_, ", "")
296
297
298
299
300 edits = []analysis.TextEdit{
301 {
302 Pos: loop.Key.Pos(),
303 End: loop.Key.End(),
304 NewText: []byte(elemPrefix + elem),
305 },
306 {
307 Pos: lenSel.Sel.Pos(),
308 End: lenSel.Sel.End(),
309 NewText: []byte(row.itermethod),
310 },
311 }
312 }
313 }
314
315 if indexVar == nil {
316 continue
317 }
318
319
320
321
322
323
324
325 for curUse := range index.Uses(indexVar) {
326 if !curBody.Contains(curUse) {
327 continue
328 }
329 if ek, argidx := curUse.ParentEdge(); ek != edge.CallExpr_Args || argidx != 0 {
330 continue nextCall
331 }
332 curAtCall := curUse.Parent()
333 atCall := curAtCall.Node().(*ast.CallExpr)
334 if typeutil.Callee(info, atCall) != atMethod {
335 continue nextCall
336 }
337 atSel := ast.Unparen(atCall.Fun).(*ast.SelectorExpr)
338
339
340 if !astutil.EqualSyntax(lenSel.X, atSel.X) {
341 continue nextCall
342 }
343
344
345
346
347
348
349 if obj := lookup(info, curAtCall, elem); obj != nil && obj != elemVar && obj.Pos() > indexVar.Pos() {
350
351
352 continue nextCall
353 }
354
355
356
357
358 edits = append(edits, analysis.TextEdit{
359 Pos: atCall.Pos(),
360 End: atCall.End(),
361 NewText: []byte(elem),
362 })
363 }
364
365
366
367
368
369 if v, err := methodGoVersion(row.pkgpath, row.typename, row.itermethod); err != nil {
370 panic(err)
371 } else if !analyzerutil.FileUsesGoVersion(pass, astutil.EnclosingFile(curLenCall), v.String()) {
372 continue nextCall
373 }
374
375 pass.Report(analysis.Diagnostic{
376 Pos: rng.Pos(),
377 End: rng.End(),
378 Message: fmt.Sprintf("%s/%s loop can simplified using %s.%s iteration",
379 row.lenmethod, row.atmethod, row.typename, row.itermethod),
380 SuggestedFixes: []analysis.SuggestedFix{{
381 Message: fmt.Sprintf(
382 "Replace %s/%s loop with %s.%s iteration",
383 row.lenmethod, row.atmethod, row.typename, row.itermethod),
384 TextEdits: edits,
385 }},
386 })
387 }
388 }
389 return nil, nil
390 }
391
392
393
394
395
396 func methodGoVersion(pkgpath, recvtype, method string) (stdlib.Version, error) {
397
398
399
400 for _, sym := range stdlib.PackageSymbols[pkgpath] {
401 if sym.Kind == stdlib.Method {
402 _, recv, name := sym.SplitMethod()
403 if recv == recvtype && name == method {
404 return sym.Version, nil
405 }
406 }
407 }
408 return 0, fmt.Errorf("methodGoVersion: %s.%s.%s missing from stdlib manifest", pkgpath, recvtype, method)
409 }
410
View as plain text