Source file src/cmd/fix/main.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/scanner"
    15  	"go/token"
    16  	"go/version"
    17  	"internal/diff"
    18  	"io"
    19  	"io/fs"
    20  	"os"
    21  	"path/filepath"
    22  	"sort"
    23  	"strings"
    24  )
    25  
    26  var (
    27  	fset     = token.NewFileSet()
    28  	exitCode = 0
    29  )
    30  
    31  var allowedRewrites = flag.String("r", "",
    32  	"restrict the rewrites to this comma-separated list")
    33  
    34  var forceRewrites = flag.String("force", "",
    35  	"force these fixes to run even if the code looks updated")
    36  
    37  var allowed, force map[string]bool
    38  
    39  var (
    40  	doDiff    = flag.Bool("diff", false, "display diffs instead of rewriting files")
    41  	goVersion = flag.String("go", "", "go language version for files")
    42  )
    43  
    44  // enable for debugging fix failures
    45  const debug = false // display incorrectly reformatted source and exit
    46  
    47  func usage() {
    48  	fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
    49  	flag.PrintDefaults()
    50  	fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
    51  	sort.Sort(byName(fixes))
    52  	for _, f := range fixes {
    53  		if f.disabled {
    54  			fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name)
    55  		} else {
    56  			fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
    57  		}
    58  		desc := strings.TrimSpace(f.desc)
    59  		desc = strings.ReplaceAll(desc, "\n", "\n\t")
    60  		fmt.Fprintf(os.Stderr, "\t%s\n", desc)
    61  	}
    62  	os.Exit(2)
    63  }
    64  
    65  func main() {
    66  	flag.Usage = usage
    67  	flag.Parse()
    68  
    69  	if !version.IsValid(*goVersion) {
    70  		report(fmt.Errorf("invalid -go=%s", *goVersion))
    71  		os.Exit(exitCode)
    72  	}
    73  
    74  	sort.Sort(byDate(fixes))
    75  
    76  	if *allowedRewrites != "" {
    77  		allowed = make(map[string]bool)
    78  		for _, f := range strings.Split(*allowedRewrites, ",") {
    79  			allowed[f] = true
    80  		}
    81  	}
    82  
    83  	if *forceRewrites != "" {
    84  		force = make(map[string]bool)
    85  		for _, f := range strings.Split(*forceRewrites, ",") {
    86  			force[f] = true
    87  		}
    88  	}
    89  
    90  	if flag.NArg() == 0 {
    91  		if err := processFile("standard input", true); err != nil {
    92  			report(err)
    93  		}
    94  		os.Exit(exitCode)
    95  	}
    96  
    97  	for i := 0; i < flag.NArg(); i++ {
    98  		path := flag.Arg(i)
    99  		switch dir, err := os.Stat(path); {
   100  		case err != nil:
   101  			report(err)
   102  		case dir.IsDir():
   103  			walkDir(path)
   104  		default:
   105  			if err := processFile(path, false); err != nil {
   106  				report(err)
   107  			}
   108  		}
   109  	}
   110  
   111  	os.Exit(exitCode)
   112  }
   113  
   114  const parserMode = parser.ParseComments
   115  
   116  func gofmtFile(f *ast.File) ([]byte, error) {
   117  	var buf bytes.Buffer
   118  	if err := format.Node(&buf, fset, f); err != nil {
   119  		return nil, err
   120  	}
   121  	return buf.Bytes(), nil
   122  }
   123  
   124  func processFile(filename string, useStdin bool) error {
   125  	var f *os.File
   126  	var err error
   127  	var fixlog strings.Builder
   128  
   129  	if useStdin {
   130  		f = os.Stdin
   131  	} else {
   132  		f, err = os.Open(filename)
   133  		if err != nil {
   134  			return err
   135  		}
   136  		defer f.Close()
   137  	}
   138  
   139  	src, err := io.ReadAll(f)
   140  	if err != nil {
   141  		return err
   142  	}
   143  
   144  	file, err := parser.ParseFile(fset, filename, src, parserMode)
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	// Make sure file is in canonical format.
   150  	// This "fmt" pseudo-fix cannot be disabled.
   151  	newSrc, err := gofmtFile(file)
   152  	if err != nil {
   153  		return err
   154  	}
   155  	if !bytes.Equal(newSrc, src) {
   156  		newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode)
   157  		if err != nil {
   158  			return err
   159  		}
   160  		file = newFile
   161  		fmt.Fprintf(&fixlog, " fmt")
   162  	}
   163  
   164  	// Apply all fixes to file.
   165  	newFile := file
   166  	fixed := false
   167  	for _, fix := range fixes {
   168  		if allowed != nil && !allowed[fix.name] {
   169  			continue
   170  		}
   171  		if fix.disabled && !force[fix.name] {
   172  			continue
   173  		}
   174  		if fix.f(newFile) {
   175  			fixed = true
   176  			fmt.Fprintf(&fixlog, " %s", fix.name)
   177  
   178  			// AST changed.
   179  			// Print and parse, to update any missing scoping
   180  			// or position information for subsequent fixers.
   181  			newSrc, err := gofmtFile(newFile)
   182  			if err != nil {
   183  				return err
   184  			}
   185  			newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
   186  			if err != nil {
   187  				if debug {
   188  					fmt.Printf("%s", newSrc)
   189  					report(err)
   190  					os.Exit(exitCode)
   191  				}
   192  				return err
   193  			}
   194  		}
   195  	}
   196  	if !fixed {
   197  		return nil
   198  	}
   199  	fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
   200  
   201  	// Print AST.  We did that after each fix, so this appears
   202  	// redundant, but it is necessary to generate gofmt-compatible
   203  	// source code in a few cases. The official gofmt style is the
   204  	// output of the printer run on a standard AST generated by the parser,
   205  	// but the source we generated inside the loop above is the
   206  	// output of the printer run on a mangled AST generated by a fixer.
   207  	newSrc, err = gofmtFile(newFile)
   208  	if err != nil {
   209  		return err
   210  	}
   211  
   212  	if *doDiff {
   213  		os.Stdout.Write(diff.Diff(filename, src, "fixed/"+filename, newSrc))
   214  		return nil
   215  	}
   216  
   217  	if useStdin {
   218  		os.Stdout.Write(newSrc)
   219  		return nil
   220  	}
   221  
   222  	return os.WriteFile(f.Name(), newSrc, 0)
   223  }
   224  
   225  func gofmt(n any) string {
   226  	var gofmtBuf strings.Builder
   227  	if err := format.Node(&gofmtBuf, fset, n); err != nil {
   228  		return "<" + err.Error() + ">"
   229  	}
   230  	return gofmtBuf.String()
   231  }
   232  
   233  func report(err error) {
   234  	scanner.PrintError(os.Stderr, err)
   235  	exitCode = 2
   236  }
   237  
   238  func walkDir(path string) {
   239  	filepath.WalkDir(path, visitFile)
   240  }
   241  
   242  func visitFile(path string, f fs.DirEntry, err error) error {
   243  	if err == nil && isGoFile(f) {
   244  		err = processFile(path, false)
   245  	}
   246  	if err != nil {
   247  		report(err)
   248  	}
   249  	return nil
   250  }
   251  
   252  func isGoFile(f fs.DirEntry) bool {
   253  	// ignore non-Go files
   254  	name := f.Name()
   255  	return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
   256  }
   257  

View as plain text