Source file src/errors/wrap_test.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package errors_test
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io/fs"
    11  	"os"
    12  	"reflect"
    13  	"testing"
    14  )
    15  
    16  func TestIs(t *testing.T) {
    17  	err1 := errors.New("1")
    18  	erra := wrapped{"wrap 2", err1}
    19  	errb := wrapped{"wrap 3", erra}
    20  
    21  	err3 := errors.New("3")
    22  
    23  	poser := &poser{"either 1 or 3", func(err error) bool {
    24  		return err == err1 || err == err3
    25  	}}
    26  
    27  	testCases := []struct {
    28  		err    error
    29  		target error
    30  		match  bool
    31  	}{
    32  		{nil, nil, true},
    33  		{nil, err1, false},
    34  		{err1, nil, false},
    35  		{err1, err1, true},
    36  		{erra, err1, true},
    37  		{errb, err1, true},
    38  		{err1, err3, false},
    39  		{erra, err3, false},
    40  		{errb, err3, false},
    41  		{poser, err1, true},
    42  		{poser, err3, true},
    43  		{poser, erra, false},
    44  		{poser, errb, false},
    45  		{errorUncomparable{}, errorUncomparable{}, true},
    46  		{errorUncomparable{}, &errorUncomparable{}, false},
    47  		{&errorUncomparable{}, errorUncomparable{}, true},
    48  		{&errorUncomparable{}, &errorUncomparable{}, false},
    49  		{errorUncomparable{}, err1, false},
    50  		{&errorUncomparable{}, err1, false},
    51  		{multiErr{}, err1, false},
    52  		{multiErr{err1, err3}, err1, true},
    53  		{multiErr{err3, err1}, err1, true},
    54  		{multiErr{err1, err3}, errors.New("x"), false},
    55  		{multiErr{err3, errb}, errb, true},
    56  		{multiErr{err3, errb}, erra, true},
    57  		{multiErr{err3, errb}, err1, true},
    58  		{multiErr{errb, err3}, err1, true},
    59  		{multiErr{poser}, err1, true},
    60  		{multiErr{poser}, err3, true},
    61  		{multiErr{nil}, nil, false},
    62  	}
    63  	for _, tc := range testCases {
    64  		t.Run("", func(t *testing.T) {
    65  			if got := errors.Is(tc.err, tc.target); got != tc.match {
    66  				t.Errorf("Is(%v, %v) = %v, want %v", tc.err, tc.target, got, tc.match)
    67  			}
    68  		})
    69  	}
    70  }
    71  
    72  type poser struct {
    73  	msg string
    74  	f   func(error) bool
    75  }
    76  
    77  var poserPathErr = &fs.PathError{Op: "poser"}
    78  
    79  func (p *poser) Error() string     { return p.msg }
    80  func (p *poser) Is(err error) bool { return p.f(err) }
    81  func (p *poser) As(err any) bool {
    82  	switch x := err.(type) {
    83  	case **poser:
    84  		*x = p
    85  	case *errorT:
    86  		*x = errorT{"poser"}
    87  	case **fs.PathError:
    88  		*x = poserPathErr
    89  	default:
    90  		return false
    91  	}
    92  	return true
    93  }
    94  
    95  func TestAs(t *testing.T) {
    96  	var errT errorT
    97  	var errP *fs.PathError
    98  	var timeout interface{ Timeout() bool }
    99  	var p *poser
   100  	_, errF := os.Open("non-existing")
   101  	poserErr := &poser{"oh no", nil}
   102  
   103  	testCases := []struct {
   104  		err    error
   105  		target any
   106  		match  bool
   107  		want   any // value of target on match
   108  	}{{
   109  		nil,
   110  		&errP,
   111  		false,
   112  		nil,
   113  	}, {
   114  		wrapped{"pitied the fool", errorT{"T"}},
   115  		&errT,
   116  		true,
   117  		errorT{"T"},
   118  	}, {
   119  		errF,
   120  		&errP,
   121  		true,
   122  		errF,
   123  	}, {
   124  		errorT{},
   125  		&errP,
   126  		false,
   127  		nil,
   128  	}, {
   129  		wrapped{"wrapped", nil},
   130  		&errT,
   131  		false,
   132  		nil,
   133  	}, {
   134  		&poser{"error", nil},
   135  		&errT,
   136  		true,
   137  		errorT{"poser"},
   138  	}, {
   139  		&poser{"path", nil},
   140  		&errP,
   141  		true,
   142  		poserPathErr,
   143  	}, {
   144  		poserErr,
   145  		&p,
   146  		true,
   147  		poserErr,
   148  	}, {
   149  		errors.New("err"),
   150  		&timeout,
   151  		false,
   152  		nil,
   153  	}, {
   154  		errF,
   155  		&timeout,
   156  		true,
   157  		errF,
   158  	}, {
   159  		wrapped{"path error", errF},
   160  		&timeout,
   161  		true,
   162  		errF,
   163  	}, {
   164  		multiErr{},
   165  		&errT,
   166  		false,
   167  		nil,
   168  	}, {
   169  		multiErr{errors.New("a"), errorT{"T"}},
   170  		&errT,
   171  		true,
   172  		errorT{"T"},
   173  	}, {
   174  		multiErr{errorT{"T"}, errors.New("a")},
   175  		&errT,
   176  		true,
   177  		errorT{"T"},
   178  	}, {
   179  		multiErr{errorT{"a"}, errorT{"b"}},
   180  		&errT,
   181  		true,
   182  		errorT{"a"},
   183  	}, {
   184  		multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}},
   185  		&errT,
   186  		true,
   187  		errorT{"a"},
   188  	}, {
   189  		multiErr{wrapped{"path error", errF}},
   190  		&timeout,
   191  		true,
   192  		errF,
   193  	}, {
   194  		multiErr{nil},
   195  		&errT,
   196  		false,
   197  		nil,
   198  	}}
   199  	for i, tc := range testCases {
   200  		name := fmt.Sprintf("%d:As(Errorf(..., %v), %v)", i, tc.err, tc.target)
   201  		// Clear the target pointer, in case it was set in a previous test.
   202  		rtarget := reflect.ValueOf(tc.target)
   203  		rtarget.Elem().Set(reflect.Zero(reflect.TypeOf(tc.target).Elem()))
   204  		t.Run(name, func(t *testing.T) {
   205  			match := errors.As(tc.err, tc.target)
   206  			if match != tc.match {
   207  				t.Fatalf("match: got %v; want %v", match, tc.match)
   208  			}
   209  			if !match {
   210  				return
   211  			}
   212  			if got := rtarget.Elem().Interface(); got != tc.want {
   213  				t.Fatalf("got %#v, want %#v", got, tc.want)
   214  			}
   215  		})
   216  	}
   217  }
   218  
   219  func TestAsValidation(t *testing.T) {
   220  	var s string
   221  	testCases := []any{
   222  		nil,
   223  		(*int)(nil),
   224  		"error",
   225  		&s,
   226  	}
   227  	err := errors.New("error")
   228  	for _, tc := range testCases {
   229  		t.Run(fmt.Sprintf("%T(%v)", tc, tc), func(t *testing.T) {
   230  			defer func() {
   231  				recover()
   232  			}()
   233  			if errors.As(err, tc) {
   234  				t.Errorf("As(err, %T(%v)) = true, want false", tc, tc)
   235  				return
   236  			}
   237  			t.Errorf("As(err, %T(%v)) did not panic", tc, tc)
   238  		})
   239  	}
   240  }
   241  
   242  func TestAsType(t *testing.T) {
   243  	var errT errorT
   244  	var errP *fs.PathError
   245  	type timeout interface {
   246  		Timeout() bool
   247  		error
   248  	}
   249  	_, errF := os.Open("non-existing")
   250  	poserErr := &poser{"oh no", nil}
   251  
   252  	testAsType(t,
   253  		nil,
   254  		errP,
   255  		false,
   256  	)
   257  	testAsType(t,
   258  		wrapped{"pitied the fool", errorT{"T"}},
   259  		errorT{"T"},
   260  		true,
   261  	)
   262  	testAsType(t,
   263  		errF,
   264  		errF,
   265  		true,
   266  	)
   267  	testAsType(t,
   268  		errT,
   269  		errP,
   270  		false,
   271  	)
   272  	testAsType(t,
   273  		wrapped{"wrapped", nil},
   274  		errT,
   275  		false,
   276  	)
   277  	testAsType(t,
   278  		&poser{"error", nil},
   279  		errorT{"poser"},
   280  		true,
   281  	)
   282  	testAsType(t,
   283  		&poser{"path", nil},
   284  		poserPathErr,
   285  		true,
   286  	)
   287  	testAsType(t,
   288  		poserErr,
   289  		poserErr,
   290  		true,
   291  	)
   292  	testAsType(t,
   293  		errors.New("err"),
   294  		timeout(nil),
   295  		false,
   296  	)
   297  	testAsType(t,
   298  		errF,
   299  		errF.(timeout),
   300  		true)
   301  	testAsType(t,
   302  		wrapped{"path error", errF},
   303  		errF.(timeout),
   304  		true,
   305  	)
   306  	testAsType(t,
   307  		multiErr{},
   308  		errT,
   309  		false,
   310  	)
   311  	testAsType(t,
   312  		multiErr{errors.New("a"), errorT{"T"}},
   313  		errorT{"T"},
   314  		true,
   315  	)
   316  	testAsType(t,
   317  		multiErr{errorT{"T"}, errors.New("a")},
   318  		errorT{"T"},
   319  		true,
   320  	)
   321  	testAsType(t,
   322  		multiErr{errorT{"a"}, errorT{"b"}},
   323  		errorT{"a"},
   324  		true,
   325  	)
   326  	testAsType(t,
   327  		multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}},
   328  		errorT{"a"},
   329  		true,
   330  	)
   331  	testAsType(t,
   332  		multiErr{wrapped{"path error", errF}},
   333  		errF.(timeout),
   334  		true,
   335  	)
   336  	testAsType(t,
   337  		multiErr{nil},
   338  		errT,
   339  		false,
   340  	)
   341  }
   342  
   343  type compError interface {
   344  	comparable
   345  	error
   346  }
   347  
   348  func testAsType[E compError](t *testing.T, err error, want E, wantOK bool) {
   349  	t.Helper()
   350  	name := fmt.Sprintf("AsType[%T](Errorf(..., %v))", want, err)
   351  	t.Run(name, func(t *testing.T) {
   352  		got, gotOK := errors.AsType[E](err)
   353  		if gotOK != wantOK || got != want {
   354  			t.Fatalf("got %v, %t; want %v, %t", got, gotOK, want, wantOK)
   355  		}
   356  	})
   357  }
   358  
   359  func BenchmarkIs(b *testing.B) {
   360  	err1 := errors.New("1")
   361  	err2 := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}}
   362  
   363  	for i := 0; i < b.N; i++ {
   364  		if !errors.Is(err2, err1) {
   365  			b.Fatal("Is failed")
   366  		}
   367  	}
   368  }
   369  
   370  func BenchmarkAs(b *testing.B) {
   371  	err := multiErr{multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}}}
   372  	for i := 0; i < b.N; i++ {
   373  		var target errorT
   374  		if !errors.As(err, &target) {
   375  			b.Fatal("As failed")
   376  		}
   377  	}
   378  }
   379  
   380  func BenchmarkAsType(b *testing.B) {
   381  	err := multiErr{multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}}}
   382  	for range b.N {
   383  		if _, ok := errors.AsType[errorT](err); !ok {
   384  			b.Fatal("AsType failed")
   385  		}
   386  	}
   387  }
   388  
   389  func TestUnwrap(t *testing.T) {
   390  	err1 := errors.New("1")
   391  	erra := wrapped{"wrap 2", err1}
   392  
   393  	testCases := []struct {
   394  		err  error
   395  		want error
   396  	}{
   397  		{nil, nil},
   398  		{wrapped{"wrapped", nil}, nil},
   399  		{err1, nil},
   400  		{erra, err1},
   401  		{wrapped{"wrap 3", erra}, erra},
   402  	}
   403  	for _, tc := range testCases {
   404  		if got := errors.Unwrap(tc.err); got != tc.want {
   405  			t.Errorf("Unwrap(%v) = %v, want %v", tc.err, got, tc.want)
   406  		}
   407  	}
   408  }
   409  
   410  type errorT struct{ s string }
   411  
   412  func (e errorT) Error() string { return fmt.Sprintf("errorT(%s)", e.s) }
   413  
   414  type wrapped struct {
   415  	msg string
   416  	err error
   417  }
   418  
   419  func (e wrapped) Error() string { return e.msg }
   420  func (e wrapped) Unwrap() error { return e.err }
   421  
   422  type multiErr []error
   423  
   424  func (m multiErr) Error() string   { return "multiError" }
   425  func (m multiErr) Unwrap() []error { return []error(m) }
   426  
   427  type errorUncomparable struct {
   428  	f []string
   429  }
   430  
   431  func (errorUncomparable) Error() string {
   432  	return "uncomparable error"
   433  }
   434  
   435  func (errorUncomparable) Is(target error) bool {
   436  	_, ok := target.(errorUncomparable)
   437  	return ok
   438  }
   439  

View as plain text