Source file src/cmd/gofmt/gofmt_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"flag"
    10  	"internal/diff"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"testing"
    15  	"text/scanner"
    16  )
    17  
    // update is set by the -update flag. When true, runTest overwrites a
// mismatching golden file with the actual gofmt output instead of failing,
// except when the input file is its own golden file.
var (
	update = flag.Bool("update", false, "update .golden files")
)
    19  
    // gofmtFlags looks for a comment of the form
//
//	//gofmt flags
//
// within the first maxLines lines of the given file,
// and returns the flags string, if any. Otherwise it
// returns the empty string. Errors opening or scanning
// the file are ignored and also yield the empty string.
func gofmtFlags(filename string, maxLines int) string {
	f, err := os.Open(filename)
	if err != nil {
		return "" // ignore errors - they will be found later
	}
	defer f.Close()

	// initialize scanner
	var s scanner.Scanner
	s.Init(f)
	s.Error = func(*scanner.Scanner, string) {}       // ignore errors
	s.Mode = scanner.GoTokens &^ scanner.SkipComments // want comments

	// look for //gofmt comment
	const prefix = "//gofmt "
	for {
		tok := s.Scan()
		// Check the line of the token just scanned. Testing s.Line before
		// Scan (as a loop condition) inspects the previous token and lets
		// one token beyond line maxLines slip through.
		if tok == scanner.EOF || s.Line > maxLines {
			return ""
		}
		if tok == scanner.Comment {
			if t := s.TokenText(); strings.HasPrefix(t, prefix) {
				return strings.TrimSpace(t[len(prefix):])
			}
		}
	}
}
    55  
    56  func runTest(t *testing.T, in, out string) {
    57  	// process flags
    58  	*simplifyAST = false
    59  	*rewriteRule = ""
    60  	info, err := os.Lstat(in)
    61  	if err != nil {
    62  		t.Error(err)
    63  		return
    64  	}
    65  	for _, flag := range strings.Split(gofmtFlags(in, 20), " ") {
    66  		elts := strings.SplitN(flag, "=", 2)
    67  		name := elts[0]
    68  		value := ""
    69  		if len(elts) == 2 {
    70  			value = elts[1]
    71  		}
    72  		switch name {
    73  		case "":
    74  			// no flags
    75  		case "-r":
    76  			*rewriteRule = value
    77  		case "-s":
    78  			*simplifyAST = true
    79  		case "-stdin":
    80  			// fake flag - pretend input is from stdin
    81  			info = nil
    82  		default:
    83  			t.Errorf("unrecognized flag name: %s", name)
    84  		}
    85  	}
    86  
    87  	initParserMode()
    88  	initRewrite()
    89  
    90  	const maxWeight = 2 << 20
    91  	var buf, errBuf bytes.Buffer
    92  	s := newSequencer(maxWeight, &buf, &errBuf)
    93  	s.Add(fileWeight(in, info), func(r *reporter) error {
    94  		return processFile(in, info, nil, r)
    95  	})
    96  	if errBuf.Len() > 0 {
    97  		t.Logf("%q", errBuf.Bytes())
    98  	}
    99  	if s.GetExitCode() != 0 {
   100  		t.Fail()
   101  	}
   102  
   103  	expected, err := os.ReadFile(out)
   104  	if err != nil {
   105  		t.Error(err)
   106  		return
   107  	}
   108  
   109  	if got := buf.Bytes(); !bytes.Equal(got, expected) {
   110  		if *update {
   111  			if in != out {
   112  				if err := os.WriteFile(out, got, 0666); err != nil {
   113  					t.Error(err)
   114  				}
   115  				return
   116  			}
   117  			// in == out: don't accidentally destroy input
   118  			t.Errorf("WARNING: -update did not rewrite input file %s", in)
   119  		}
   120  
   121  		t.Errorf("(gofmt %s) != %s (see %s.gofmt)\n%s", in, out, in,
   122  			diff.Diff("expected", expected, "got", got))
   123  		if err := os.WriteFile(in+".gofmt", got, 0666); err != nil {
   124  			t.Error(err)
   125  		}
   126  	}
   127  }
   128  
   129  // TestRewrite processes testdata/*.input files and compares them to the
   130  // corresponding testdata/*.golden files. The gofmt flags used to process
   131  // a file must be provided via a comment of the form
   132  //
   133  //	//gofmt flags
   134  //
   135  // in the processed file within the first 20 lines, if any.
   136  func TestRewrite(t *testing.T) {
   137  	// determine input files
   138  	match, err := filepath.Glob("testdata/*.input")
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  
   143  	// add larger examples
   144  	match = append(match, "gofmt.go", "gofmt_test.go")
   145  
   146  	for _, in := range match {
   147  		name := filepath.Base(in)
   148  		t.Run(name, func(t *testing.T) {
   149  			out := in // for files where input and output are identical
   150  			if strings.HasSuffix(in, ".input") {
   151  				out = in[:len(in)-len(".input")] + ".golden"
   152  			}
   153  			runTest(t, in, out)
   154  			if in != out && !t.Failed() {
   155  				// Check idempotence.
   156  				runTest(t, out, out)
   157  			}
   158  		})
   159  	}
   160  }
   161  
   // TestCRLF is the test case for issue 3961. It checks the CRLF fixtures:
// the input must contain CR/LF line endings and the golden file must not
// contain any CRs. Each file is checked independently. A file that cannot be
// read is reported once, without a follow-up (misleading) content error.
func TestCRLF(t *testing.T) {
	const input = "testdata/crlf.input"   // must contain CR/LF's
	const golden = "testdata/crlf.golden" // must not contain any CR's

	if data, err := os.ReadFile(input); err != nil {
		t.Error(err)
	} else if !bytes.Contains(data, []byte("\r\n")) {
		t.Errorf("%s contains no CR/LF's", input)
	}

	if data, err := os.ReadFile(golden); err != nil {
		t.Error(err)
	} else if bytes.Contains(data, []byte("\r")) {
		t.Errorf("%s contains CR's", golden)
	}
}
   183  
   184  func TestBackupFile(t *testing.T) {
   185  	dir, err := os.MkdirTemp("", "gofmt_test")
   186  	if err != nil {
   187  		t.Fatal(err)
   188  	}
   189  	defer os.RemoveAll(dir)
   190  	name, err := backupFile(filepath.Join(dir, "foo.go"), []byte("  package main"), 0644)
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  	t.Logf("Created: %s", name)
   195  }
   196  

View as plain text