1
2
3
4
5 package modernize
6
7
8
9 import (
10 "go/ast"
11 "go/types"
12
13 "golang.org/x/tools/go/analysis"
14 "golang.org/x/tools/go/analysis/passes/inspect"
15 "golang.org/x/tools/go/ast/edge"
16 "golang.org/x/tools/go/types/typeutil"
17 "golang.org/x/tools/internal/analysis/analyzerutil"
18 typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
19 "golang.org/x/tools/internal/astutil"
20 "golang.org/x/tools/internal/refactor"
21 "golang.org/x/tools/internal/typesinternal"
22 "golang.org/x/tools/internal/typesinternal/typeindex"
23 "golang.org/x/tools/internal/versions"
24 )
25
26 var ReflectTypeForAnalyzer = &analysis.Analyzer{
27 Name: "reflecttypefor",
28 Doc: analyzerutil.MustExtractDoc(doc, "reflecttypefor"),
29 Requires: []*analysis.Analyzer{
30 inspect.Analyzer,
31 typeindexanalyzer.Analyzer,
32 },
33 Run: reflecttypefor,
34 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#reflecttypefor",
35 }
36
37 func reflecttypefor(pass *analysis.Pass) (any, error) {
38 var (
39 index = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
40 info = pass.TypesInfo
41
42 reflectTypeOf = index.Object("reflect", "TypeOf")
43 )
44
45 for curCall := range index.Calls(reflectTypeOf) {
46 call := curCall.Node().(*ast.CallExpr)
47
48
49 expr := call.Args[0]
50 if !typesinternal.NoEffects(info, expr) {
51 continue
52 }
53
54 t := info.TypeOf(expr)
55 var edits []analysis.TextEdit
56
57
58
59 if astutil.IsChildOf(curCall, edge.SelectorExpr_X) {
60 curSel := unparenEnclosing(curCall).Parent()
61 if astutil.IsChildOf(curSel, edge.CallExpr_Fun) {
62 call2 := unparenEnclosing(curSel).Parent().Node().(*ast.CallExpr)
63 obj := typeutil.Callee(info, call2)
64 if typesinternal.IsMethodNamed(obj, "reflect", "Type", "Elem") {
65 if ptr, ok := t.(*types.Pointer); ok {
66
67 t = ptr.Elem()
68
69
70
71 edits = []analysis.TextEdit{{
72 Pos: call.End(),
73 End: call2.End(),
74 }}
75 }
76 }
77 }
78 }
79
80
81
82
83 if types.IsInterface(t) && edits == nil {
84 continue
85 }
86
87 file := astutil.EnclosingFile(curCall)
88 if !analyzerutil.FileUsesGoVersion(pass, file, versions.Go1_22) {
89 continue
90 }
91 tokFile := pass.Fset.File(file.Pos())
92
93
94
95
96
97 qual := typesinternal.FileQualifier(file, pass.Pkg)
98 tstr := types.TypeString(t, qual)
99
100 sel, ok := call.Fun.(*ast.SelectorExpr)
101 if !ok {
102 continue
103 }
104
105
106
107
108 if isComplicatedType(t) {
109 continue
110 }
111
112
113
114
115
116 oldLen := int(expr.End() - expr.Pos())
117 newLen := len(tstr)
118 if newLen >= 16 && newLen > 3*oldLen {
119 continue
120 }
121
122
123
124
125
126
127 curArg0 := curCall.ChildAt(edge.CallExpr_Args, 0)
128 edits = append(edits, refactor.DeleteUnusedVars(index, info, tokFile, curArg0)...)
129
130 pass.Report(analysis.Diagnostic{
131 Pos: call.Fun.Pos(),
132 End: call.Fun.End(),
133 Message: "reflect.TypeOf call can be simplified using TypeFor",
134 SuggestedFixes: []analysis.SuggestedFix{{
135
136
137
138 Message: "Replace TypeOf by TypeFor",
139 TextEdits: append([]analysis.TextEdit{
140 {
141 Pos: sel.Sel.Pos(),
142 End: sel.Sel.End(),
143 NewText: []byte("TypeFor[" + tstr + "]"),
144 },
145
146 {
147 Pos: call.Lparen + 1,
148 End: call.Rparen,
149 },
150 }, edits...),
151 }},
152 })
153 }
154
155 return nil, nil
156 }
157
158
159
160 func isComplicatedType(t types.Type) bool {
161 var check func(typ types.Type) bool
162 check = func(typ types.Type) bool {
163 switch t := typ.(type) {
164 case typesinternal.NamedOrAlias:
165 for ta := range t.TypeArgs().Types() {
166 if check(ta) {
167 return true
168 }
169 }
170 return false
171 case *types.Struct, *types.Interface, *types.Signature:
172
173
174 return true
175 case *types.Pointer:
176 return check(t.Elem())
177 case *types.Slice:
178 return check(t.Elem())
179 case *types.Array:
180 return check(t.Elem())
181 case *types.Chan:
182 return check(t.Elem())
183 case *types.Map:
184 return check(t.Key()) || check(t.Elem())
185 case *types.Basic:
186 return false
187 case *types.TypeParam:
188 return false
189 default:
190
191 return true
192 }
193 }
194
195 return check(t)
196 }
197
View as plain text