Source file src/cmd/fix/main_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"internal/diff"
    12  	"internal/testenv"
    13  	"strings"
    14  	"testing"
    15  )
    16  
// testCase describes one rewrite test run by TestRewrite.
type testCase struct {
	Name    string               // subtest name; also used as the file name when parsing
	Fn      func(*ast.File) bool // fix to apply; if nil when run, every entry in fixes is applied
	Version string               // if non-empty, *goVersion is set to this while the case runs
	In      string               // input source; must already be gofmt-formatted
	Out     string               // expected output; empty means the output must equal In
}
    24  
    25  var testCases []testCase
    26  
    27  func addTestCases(t []testCase, fn func(*ast.File) bool) {
    28  	// Fill in fn to avoid repetition in definitions.
    29  	if fn != nil {
    30  		for i := range t {
    31  			if t[i].Fn == nil {
    32  				t[i].Fn = fn
    33  			}
    34  		}
    35  	}
    36  	testCases = append(testCases, t...)
    37  }
    38  
    39  func fnop(*ast.File) bool { return false }
    40  
    41  func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
    42  	file, err := parser.ParseFile(fset, desc, in, parserMode)
    43  	if err != nil {
    44  		t.Errorf("parsing: %v", err)
    45  		return
    46  	}
    47  
    48  	outb, err := gofmtFile(file)
    49  	if err != nil {
    50  		t.Errorf("printing: %v", err)
    51  		return
    52  	}
    53  	if s := string(outb); in != s && mustBeGofmt {
    54  		t.Errorf("not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
    55  			desc, in, desc, s)
    56  		tdiff(t, "want", in, "have", s)
    57  		return
    58  	}
    59  
    60  	if fn == nil {
    61  		for _, fix := range fixes {
    62  			if fix.f(file) {
    63  				fixed = true
    64  			}
    65  		}
    66  	} else {
    67  		fixed = fn(file)
    68  	}
    69  
    70  	outb, err = gofmtFile(file)
    71  	if err != nil {
    72  		t.Errorf("printing: %v", err)
    73  		return
    74  	}
    75  
    76  	return string(outb), fixed, true
    77  }
    78  
    79  func TestRewrite(t *testing.T) {
    80  	// If cgo is enabled, enforce that cgo commands invoked by cmd/fix
    81  	// do not fail during testing.
    82  	if testenv.HasCGO() {
    83  		testenv.MustHaveGoBuild(t) // Really just 'go tool cgo', but close enough.
    84  
    85  		// The reportCgoError hook is global, so we can't set it per-test
    86  		// if we want to be able to run those tests in parallel.
    87  		// Instead, simply set it to panic on error: the goroutine dump
    88  		// from the panic should help us determine which test failed.
    89  		prevReportCgoError := reportCgoError
    90  		reportCgoError = func(err error) {
    91  			panic(fmt.Sprintf("unexpected cgo error: %v", err))
    92  		}
    93  		t.Cleanup(func() { reportCgoError = prevReportCgoError })
    94  	}
    95  
    96  	for _, tt := range testCases {
    97  		tt := tt
    98  		t.Run(tt.Name, func(t *testing.T) {
    99  			if tt.Version == "" {
   100  				if testing.Verbose() {
   101  					// Don't run in parallel: cmd/fix sometimes writes directly to stderr,
   102  					// and since -v prints which test is currently running we want that
   103  					// information to accurately correlate with the stderr output.
   104  				} else {
   105  					t.Parallel()
   106  				}
   107  			} else {
   108  				old := *goVersion
   109  				*goVersion = tt.Version
   110  				defer func() {
   111  					*goVersion = old
   112  				}()
   113  			}
   114  
   115  			// Apply fix: should get tt.Out.
   116  			out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
   117  			if !ok {
   118  				return
   119  			}
   120  
   121  			// reformat to get printing right
   122  			out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
   123  			if !ok {
   124  				return
   125  			}
   126  
   127  			if tt.Out == "" {
   128  				tt.Out = tt.In
   129  			}
   130  			if out != tt.Out {
   131  				t.Errorf("incorrect output.\n")
   132  				if !strings.HasPrefix(tt.Name, "testdata/") {
   133  					t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
   134  				}
   135  				tdiff(t, "have", out, "want", tt.Out)
   136  				return
   137  			}
   138  
   139  			if changed := out != tt.In; changed != fixed {
   140  				t.Errorf("changed=%v != fixed=%v", changed, fixed)
   141  				return
   142  			}
   143  
   144  			// Should not change if run again.
   145  			out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
   146  			if !ok {
   147  				return
   148  			}
   149  
   150  			if fixed2 {
   151  				t.Errorf("applied fixes during second round")
   152  				return
   153  			}
   154  
   155  			if out2 != out {
   156  				t.Errorf("changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
   157  					out, out2)
   158  				tdiff(t, "first", out, "second", out2)
   159  			}
   160  		})
   161  	}
   162  }
   163  
   164  func tdiff(t *testing.T, aname, a, bname, b string) {
   165  	t.Errorf("%s", diff.Diff(aname, []byte(a), bname, []byte(b)))
   166  }
   167  

View as plain text