Source file src/cmd/vendor/golang.org/x/tools/go/analysis/passes/modernize/waitgroup.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modernize
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/printer"
    12  	"slices"
    13  
    14  	"golang.org/x/tools/go/analysis"
    15  	"golang.org/x/tools/go/analysis/passes/inspect"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  	"golang.org/x/tools/internal/analysis/analyzerutil"
    18  	typeindexanalyzer "golang.org/x/tools/internal/analysis/typeindex"
    19  	"golang.org/x/tools/internal/astutil"
    20  	"golang.org/x/tools/internal/refactor"
    21  	"golang.org/x/tools/internal/typesinternal/typeindex"
    22  	"golang.org/x/tools/internal/versions"
    23  )
    24  
    25  var WaitGroupAnalyzer = &analysis.Analyzer{
    26  	Name: "waitgroup",
    27  	Doc:  analyzerutil.MustExtractDoc(doc, "waitgroup"),
    28  	Requires: []*analysis.Analyzer{
    29  		inspect.Analyzer,
    30  		typeindexanalyzer.Analyzer,
    31  	},
    32  	Run: waitgroup,
    33  	URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/modernize#waitgroup",
    34  }
    35  
    36  // The waitgroup pass replaces old more complex code with
    37  // go1.25 added API WaitGroup.Go.
    38  //
    39  // Patterns:
    40  //
    41  //  1. wg.Add(1); go func() { defer wg.Done(); ... }()
    42  //     =>
    43  //     wg.Go(go func() { ... })
    44  //
    45  //  2. wg.Add(1); go func() { ...; wg.Done() }()
    46  //     =>
    47  //     wg.Go(go func() { ... })
    48  //
    49  // The wg.Done must occur within the first statement of the block in a
    50  // defer format or last statement of the block, and the offered fix
    51  // only removes the first/last wg.Done call. It doesn't fix existing
    52  // wrong usage of sync.WaitGroup.
    53  //
    54  // The use of WaitGroup.Go in pattern 1 implicitly introduces a
    55  // 'defer', which may change the behavior in the case of panic from
    56  // the "..." logic. In this instance, the change is safe: before and
    57  // after the transformation, an unhandled panic inevitably results in
    58  // a fatal crash. The fact that the transformed code calls wg.Done()
    59  // before the crash doesn't materially change anything. (If Done had
    60  // other effects, or blocked, or if WaitGroup.Go propagated panics
    61  // from child to parent goroutine, the argument would be different.)
    62  func waitgroup(pass *analysis.Pass) (any, error) {
    63  	var (
    64  		index             = pass.ResultOf[typeindexanalyzer.Analyzer].(*typeindex.Index)
    65  		info              = pass.TypesInfo
    66  		syncWaitGroupAdd  = index.Selection("sync", "WaitGroup", "Add")
    67  		syncWaitGroupDone = index.Selection("sync", "WaitGroup", "Done")
    68  	)
    69  	if !index.Used(syncWaitGroupDone) {
    70  		return nil, nil
    71  	}
    72  
    73  	for curAddCall := range index.Calls(syncWaitGroupAdd) {
    74  		// Extract receiver from wg.Add call.
    75  		addCall := curAddCall.Node().(*ast.CallExpr)
    76  		if !isIntLiteral(info, addCall.Args[0], 1) {
    77  			continue // not a call to wg.Add(1)
    78  		}
    79  		// Inv: the Args[0] check ensures addCall is not of
    80  		// the form sync.WaitGroup.Add(&wg, 1).
    81  		addCallRecv := ast.Unparen(addCall.Fun).(*ast.SelectorExpr).X
    82  
    83  		// Following statement must be go func() { ... } ().
    84  		curAddStmt := curAddCall.Parent()
    85  		if !is[*ast.ExprStmt](curAddStmt.Node()) {
    86  			continue // unnecessary parens?
    87  		}
    88  		curNext, ok := curAddCall.Parent().NextSibling()
    89  		if !ok {
    90  			continue // no successor
    91  		}
    92  		goStmt, ok := curNext.Node().(*ast.GoStmt)
    93  		if !ok {
    94  			continue // not a go stmt
    95  		}
    96  		lit, ok := goStmt.Call.Fun.(*ast.FuncLit)
    97  		if !ok || len(goStmt.Call.Args) != 0 {
    98  			continue // go argument is not func(){...}()
    99  		}
   100  		list := lit.Body.List
   101  		if len(list) == 0 {
   102  			continue
   103  		}
   104  
   105  		// Body must start with "defer wg.Done()" or end with "wg.Done()".
   106  		var doneStmt ast.Stmt
   107  		if deferStmt, ok := list[0].(*ast.DeferStmt); ok &&
   108  			typeutil.Callee(info, deferStmt.Call) == syncWaitGroupDone &&
   109  			astutil.EqualSyntax(ast.Unparen(deferStmt.Call.Fun).(*ast.SelectorExpr).X, addCallRecv) {
   110  			doneStmt = deferStmt // "defer wg.Done()"
   111  
   112  		} else if lastStmt, ok := list[len(list)-1].(*ast.ExprStmt); ok {
   113  			if doneCall, ok := lastStmt.X.(*ast.CallExpr); ok &&
   114  				typeutil.Callee(info, doneCall) == syncWaitGroupDone &&
   115  				astutil.EqualSyntax(ast.Unparen(doneCall.Fun).(*ast.SelectorExpr).X, addCallRecv) {
   116  				doneStmt = lastStmt // "wg.Done()"
   117  			}
   118  		}
   119  		if doneStmt == nil {
   120  			continue
   121  		}
   122  		curDoneStmt, ok := curNext.FindNode(doneStmt)
   123  		if !ok {
   124  			panic("can't find Cursor for 'done' statement")
   125  		}
   126  
   127  		file := astutil.EnclosingFile(curAddCall)
   128  		if !analyzerutil.FileUsesGoVersion(pass, file, versions.Go1_25) {
   129  			continue
   130  		}
   131  		tokFile := pass.Fset.File(file.Pos())
   132  
   133  		var addCallRecvText bytes.Buffer
   134  		err := printer.Fprint(&addCallRecvText, pass.Fset, addCallRecv)
   135  		if err != nil {
   136  			continue // error getting text for the edit
   137  		}
   138  
   139  		pass.Report(analysis.Diagnostic{
   140  			// go func() {
   141  			// ~~~~~~~~~
   142  			Pos:     goStmt.Pos(),
   143  			End:     lit.Type.End(),
   144  			Message: "Goroutine creation can be simplified using WaitGroup.Go",
   145  			SuggestedFixes: []analysis.SuggestedFix{{
   146  				Message: "Simplify by using WaitGroup.Go",
   147  				TextEdits: slices.Concat(
   148  					// delete "wg.Add(1)"
   149  					refactor.DeleteStmt(tokFile, curAddStmt),
   150  					// delete "wg.Done()" or "defer wg.Done()"
   151  					refactor.DeleteStmt(tokFile, curDoneStmt),
   152  					[]analysis.TextEdit{
   153  						// go    func()
   154  						// ------
   155  						// wg.Go(func()
   156  						{
   157  							Pos:     goStmt.Pos(),
   158  							End:     goStmt.Call.Pos(),
   159  							NewText: fmt.Appendf(nil, "%s.Go(", addCallRecvText.String()),
   160  						},
   161  						// ... }()
   162  						//      -
   163  						// ... } )
   164  						{
   165  							Pos: goStmt.Call.Lparen,
   166  							End: goStmt.Call.Rparen,
   167  						},
   168  					},
   169  				),
   170  			}},
   171  		})
   172  	}
   173  	return nil, nil
   174  }
   175  

View as plain text