Source file
src/cmd/fix/main.go
1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "flag"
10 "fmt"
11 "go/ast"
12 "go/format"
13 "go/parser"
14 "go/scanner"
15 "go/token"
16 "go/version"
17 "internal/diff"
18 "io"
19 "io/fs"
20 "os"
21 "path/filepath"
22 "sort"
23 "strings"
24 )
25
var (
	// fset records position information for every file this command
	// parses or prints.
	fset = token.NewFileSet()

	// exitCode is the status the process exits with; report sets it to 2
	// whenever an error is printed.
	exitCode = 0

	// Command-line flags.
	allowedRewrites = flag.String("r", "", "restrict the rewrites to this comma-separated list")
	forceRewrites   = flag.String("force", "", "force these fixes to run even if the code looks updated")
	doDiff          = flag.Bool("diff", false, "display diffs instead of rewriting files")
	goVersion       = flag.String("go", "", "go language version for files")

	// allowed and force are the parsed forms of -r and -force. They are
	// populated by main; a nil allowed map means no -r restriction applies.
	allowed, force map[string]bool
)

// debug, when true, makes processFile print the source it failed to
// re-parse after a fix, report the parse error, and exit immediately.
const debug = false
46
47 func usage() {
48 fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
49 flag.PrintDefaults()
50 fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
51 sort.Sort(byName(fixes))
52 for _, f := range fixes {
53 if f.disabled {
54 fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
55 } else {
56 fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
57 }
58 desc := strings.TrimSpace(f.desc)
59 desc = strings.ReplaceAll(desc, "\n", "\n\t")
60 fmt.Fprintf(os.Stderr, "\t%s\n", desc)
61 }
62 os.Exit(2)
63 }
64
65 func main() {
66 flag.Usage = usage
67 flag.Parse()
68
69 if !version.IsValid(*goVersion) {
70 report(fmt.Errorf("invalid -go=%s", *goVersion))
71 os.Exit(exitCode)
72 }
73
74 sort.Sort(byDate(fixes))
75
76 if *allowedRewrites != "" {
77 allowed = make(map[string]bool)
78 for _, f := range strings.Split(*allowedRewrites, ",") {
79 allowed[f] = true
80 }
81 }
82
83 if *forceRewrites != "" {
84 force = make(map[string]bool)
85 for _, f := range strings.Split(*forceRewrites, ",") {
86 force[f] = true
87 }
88 }
89
90 if flag.NArg() == 0 {
91 if err := processFile("standard input", true); err != nil {
92 report(err)
93 }
94 os.Exit(exitCode)
95 }
96
97 for i := 0; i < flag.NArg(); i++ {
98 path := flag.Arg(i)
99 switch dir, err := os.Stat(path); {
100 case err != nil:
101 report(err)
102 case dir.IsDir():
103 walkDir(path)
104 default:
105 if err := processFile(path, false); err != nil {
106 report(err)
107 }
108 }
109 }
110
111 os.Exit(exitCode)
112 }
113
// parserMode is the go/parser mode used for every parse; comments are kept
// so that printing the AST back out preserves them.
const parserMode parser.Mode = parser.ParseComments
115
116 func gofmtFile(f *ast.File) ([]byte, error) {
117 var buf bytes.Buffer
118 if err := format.Node(&buf, fset, f); err != nil {
119 return nil, err
120 }
121 return buf.Bytes(), nil
122 }
123
124 func processFile(filename string, useStdin bool) error {
125 var f *os.File
126 var err error
127 var fixlog strings.Builder
128
129 if useStdin {
130 f = os.Stdin
131 } else {
132 f, err = os.Open(filename)
133 if err != nil {
134 return err
135 }
136 defer f.Close()
137 }
138
139 src, err := io.ReadAll(f)
140 if err != nil {
141 return err
142 }
143
144 file, err := parser.ParseFile(fset, filename, src, parserMode)
145 if err != nil {
146 return err
147 }
148
149
150
151 newSrc, err := gofmtFile(file)
152 if err != nil {
153 return err
154 }
155 if !bytes.Equal(newSrc, src) {
156 newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
157 if err != nil {
158 return err
159 }
160 file = newFile
161 fmt.Fprintf(&fixlog, " fmt")
162 }
163
164
165 newFile := file
166 fixed := false
167 for _, fix := range fixes {
168 if allowed != nil && !allowed[fix.name] {
169 continue
170 }
171 if fix.disabled && !force[fix.name] {
172 continue
173 }
174 if fix.f(newFile) {
175 fixed = true
176 fmt.Fprintf(&fixlog, " %s", fix.name)
177
178
179
180
181 newSrc, err := gofmtFile(newFile)
182 if err != nil {
183 return err
184 }
185 newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
186 if err != nil {
187 if debug {
188 fmt.Printf("%s", newSrc)
189 report(err)
190 os.Exit(exitCode)
191 }
192 return err
193 }
194 }
195 }
196 if !fixed {
197 return nil
198 }
199 fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
200
201
202
203
204
205
206
207 newSrc, err = gofmtFile(newFile)
208 if err != nil {
209 return err
210 }
211
212 if *doDiff {
213 os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
214 return nil
215 }
216
217 if useStdin {
218 os.Stdout.Write(newSrc)
219 return nil
220 }
221
222 return os.WriteFile(f.Name(), newSrc, 0)
223 }
224
225 func gofmt(n any) string {
226 var gofmtBuf strings.Builder
227 if err := format.Node(&gofmtBuf, fset, n); err != nil {
228 return "<" + err.Error() + ">"
229 }
230 return gofmtBuf.String()
231 }
232
233 func report(err error) {
234 scanner.PrintError(os.Stderr, err)
235 exitCode = 2
236 }
237
238 func walkDir(path string) {
239 filepath.WalkDir(path, visitFile)
240 }
241
242 func visitFile(path string, f fs.DirEntry, err error) error {
243 if err == nil && isGoFile(f) {
244 err = processFile(path, false)
245 }
246 if err != nil {
247 report(err)
248 }
249 return nil
250 }
251
// isGoFile reports whether the directory entry f should be fixed. It must
// not be a directory, must not be hidden (its name starts with "."), and
// must have a ".go" suffix.
func isGoFile(f fs.DirEntry) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
}
257
View as plain text