Source file src/encoding/json/v2/arshal_funcs.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.jsonv2
     6  
     7  package json
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"reflect"
    14  	"sync"
    15  
    16  	"encoding/json/internal"
    17  	"encoding/json/internal/jsonflags"
    18  	"encoding/json/internal/jsonopts"
    19  	"encoding/json/jsontext"
    20  )
    21  
    22  var errUnsupportedMutation = errors.New("unsupported calls must not read or write any tokens")
    23  var errNonSingularValue = errors.New("must read or write exactly one value")
    24  
// Marshalers is a list of functions that may override the marshal behavior
// of specific types. Populate [WithMarshalers] to use it with
// [Marshal], [MarshalWrite], or [MarshalEncode].
// A nil *Marshalers is equivalent to an empty list.
// There are no exported fields or methods on Marshalers.
//
// Individual entries are constructed with [MarshalFunc] or [MarshalToFunc],
// and several lists are combined into one with [JoinMarshalers].
type Marshalers = typedMarshalers
    31  
    32  // JoinMarshalers constructs a flattened list of marshal functions.
    33  // If multiple functions in the list are applicable for a value of a given type,
    34  // then those earlier in the list take precedence over those that come later.
    35  // If a function returns [errors.ErrUnsupported],
    36  // then the next applicable function is called,
    37  // otherwise the default marshaling behavior is used.
    38  //
    39  // For example:
    40  //
    41  //	m1 := JoinMarshalers(f1, f2)
    42  //	m2 := JoinMarshalers(f0, m1, f3)     // equivalent to m3
    43  //	m3 := JoinMarshalers(f0, f1, f2, f3) // equivalent to m2
    44  func JoinMarshalers(ms ...*Marshalers) *Marshalers {
    45  	return newMarshalers(ms...)
    46  }
    47  
// Unmarshalers is a list of functions that may override the unmarshal behavior
// of specific types. Populate [WithUnmarshalers] to use it with
// [Unmarshal], [UnmarshalRead], or [UnmarshalDecode].
// A nil *Unmarshalers is equivalent to an empty list.
// There are no exported fields or methods on Unmarshalers.
//
// Individual entries are constructed with [UnmarshalFunc] or
// [UnmarshalFromFunc], and several lists are combined into one
// with [JoinUnmarshalers].
type Unmarshalers = typedUnmarshalers
    54  
    55  // JoinUnmarshalers constructs a flattened list of unmarshal functions.
    56  // If multiple functions in the list are applicable for a value of a given type,
    57  // then those earlier in the list take precedence over those that come later.
    58  // If a function returns [errors.ErrUnsupported],
    59  // then the next applicable function is called,
    60  // otherwise the default unmarshaling behavior is used.
    61  //
    62  // For example:
    63  //
    64  //	u1 := JoinUnmarshalers(f1, f2)
    65  //	u2 := JoinUnmarshalers(f0, u1, f3)     // equivalent to u3
    66  //	u3 := JoinUnmarshalers(f0, f1, f2, f3) // equivalent to u2
    67  func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
    68  	return newUnmarshalers(us...)
    69  }
    70  
    71  type typedMarshalers = typedArshalers[jsontext.Encoder]
    72  type typedUnmarshalers = typedArshalers[jsontext.Decoder]
    73  type typedArshalers[Coder any] struct {
    74  	nonComparable
    75  
    76  	fncVals  []typedArshaler[Coder]
    77  	fncCache sync.Map // map[reflect.Type]arshaler
    78  
    79  	// fromAny reports whether any of Go types used to represent arbitrary JSON
    80  	// (i.e., any, bool, string, float64, map[string]any, or []any) matches
    81  	// any of the provided type-specific arshalers.
    82  	//
    83  	// This bit of information is needed in arshal_default.go to determine
    84  	// whether to use the specialized logic in arshal_any.go to handle
    85  	// the any interface type. The logic in arshal_any.go does not support
    86  	// type-specific arshal functions, so we must avoid using that logic
    87  	// if this is true.
    88  	fromAny bool
    89  }
    90  type typedMarshaler = typedArshaler[jsontext.Encoder]
    91  type typedUnmarshaler = typedArshaler[jsontext.Decoder]
    92  type typedArshaler[Coder any] struct {
    93  	typ     reflect.Type
    94  	fnc     func(*Coder, addressableValue, *jsonopts.Struct) error
    95  	maySkip bool
    96  }
    97  
    98  func newMarshalers(ms ...*Marshalers) *Marshalers       { return newTypedArshalers(ms...) }
    99  func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
   100  func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
   101  	var a typedArshalers[Coder]
   102  	for _, a2 := range as {
   103  		if a2 != nil {
   104  			a.fncVals = append(a.fncVals, a2.fncVals...)
   105  			a.fromAny = a.fromAny || a2.fromAny
   106  		}
   107  	}
   108  	if len(a.fncVals) == 0 {
   109  		return nil
   110  	}
   111  	return &a
   112  }
   113  
   114  func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
   115  	if a == nil {
   116  		return fnc, false
   117  	}
   118  	if v, ok := a.fncCache.Load(t); ok {
   119  		if v == nil {
   120  			return fnc, false
   121  		}
   122  		return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   123  	}
   124  
   125  	// Collect a list of arshalers that can be called for this type.
   126  	// This list may be longer than 1 since some arshalers can be skipped.
   127  	var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
   128  	for _, fncVal := range a.fncVals {
   129  		if !castableTo(t, fncVal.typ) {
   130  			continue
   131  		}
   132  		fncs = append(fncs, fncVal.fnc)
   133  		if !fncVal.maySkip {
   134  			break // subsequent arshalers will never be called
   135  		}
   136  	}
   137  
   138  	if len(fncs) == 0 {
   139  		a.fncCache.Store(t, nil) // nil to indicate that no funcs found
   140  		return fnc, false
   141  	}
   142  
   143  	// Construct an arshaler that may call every applicable arshaler.
   144  	fncDefault := fnc
   145  	fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
   146  		for _, fnc := range fncs {
   147  			if err := fnc(c, v, o); !errors.Is(err, errors.ErrUnsupported) {
   148  				return err // may be nil or non-nil
   149  			}
   150  		}
   151  		return fncDefault(c, v, o)
   152  	}
   153  
   154  	// Use the first stored so duplicate work can be garbage collected.
   155  	v, _ := a.fncCache.LoadOrStore(t, fnc)
   156  	return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   157  }
   158  
   159  // MarshalFunc constructs a type-specific marshaler that
   160  // specifies how to marshal values of type T.
   161  // T can be any type except a named pointer.
   162  // The function is always provided with a non-nil pointer value
   163  // if T is an interface or pointer type.
   164  //
   165  // The function must marshal exactly one JSON value.
   166  // The value of T must not be retained outside the function call.
   167  // It may not return [errors.ErrUnsupported].
   168  func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
   169  	t := reflect.TypeFor[T]()
   170  	assertCastableTo(t, true)
   171  	typFnc := typedMarshaler{
   172  		typ: t,
   173  		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
   174  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   175  			val, err := fn(v)
   176  			if err != nil {
   177  				err = wrapErrUnsupported(err, "marshal function of type func(T) ([]byte, error)")
   178  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   179  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   180  				}
   181  				err = newMarshalErrorBefore(enc, t, err)
   182  				return collapseSemanticErrors(err)
   183  			}
   184  			if err := enc.WriteValue(val); err != nil {
   185  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   186  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   187  				}
   188  				if isSyntacticError(err) {
   189  					err = newMarshalErrorBefore(enc, t, err)
   190  				}
   191  				return err
   192  			}
   193  			return nil
   194  		},
   195  	}
   196  	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
   197  }
   198  
   199  // MarshalToFunc constructs a type-specific marshaler that
   200  // specifies how to marshal values of type T.
   201  // T can be any type except a named pointer.
   202  // The function is always provided with a non-nil pointer value
   203  // if T is an interface or pointer type.
   204  //
   205  // The function must marshal exactly one JSON value by calling write methods
   206  // on the provided encoder. It may return [errors.ErrUnsupported] such that marshaling can
   207  // move on to the next marshal function. However, no mutable method calls may
   208  // be called on the encoder if [errors.ErrUnsupported] is returned.
   209  // The pointer to [jsontext.Encoder] and the value of T
   210  // must not be retained outside the function call.
   211  func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
   212  	t := reflect.TypeFor[T]()
   213  	assertCastableTo(t, true)
   214  	typFnc := typedMarshaler{
   215  		typ: t,
   216  		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
   217  			xe := export.Encoder(enc)
   218  			prevDepth, prevLength := xe.Tokens.DepthLength()
   219  			xe.Flags.Set(jsonflags.WithinArshalCall | 1)
   220  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   221  			err := fn(enc, v)
   222  			xe.Flags.Set(jsonflags.WithinArshalCall | 0)
   223  			currDepth, currLength := xe.Tokens.DepthLength()
   224  			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
   225  				err = errNonSingularValue
   226  			}
   227  			if err != nil {
   228  				if errors.Is(err, errors.ErrUnsupported) {
   229  					if prevDepth == currDepth && prevLength == currLength {
   230  						return err // forward [errors.ErrUnsupported]
   231  					}
   232  					err = errUnsupportedMutation
   233  				}
   234  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   235  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc") // unlike unmarshal, always wrapped
   236  				}
   237  				if !export.IsIOError(err) {
   238  					err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
   239  				}
   240  				return err
   241  			}
   242  			return nil
   243  		},
   244  		maySkip: true,
   245  	}
   246  	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
   247  }
   248  
   249  // UnmarshalFunc constructs a type-specific unmarshaler that
   250  // specifies how to unmarshal values of type T.
   251  // T must be an unnamed pointer or an interface type.
   252  // The function is always provided with a non-nil pointer value.
   253  //
   254  // The function must unmarshal exactly one JSON value.
   255  // The input []byte must not be mutated.
   256  // The input []byte and value T must not be retained outside the function call.
   257  // It may not return [errors.ErrUnsupported].
   258  func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
   259  	t := reflect.TypeFor[T]()
   260  	assertCastableTo(t, false)
   261  	typFnc := typedUnmarshaler{
   262  		typ: t,
   263  		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
   264  			val, err := dec.ReadValue()
   265  			if err != nil {
   266  				return err // must be a syntactic or I/O error
   267  			}
   268  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   269  			err = fn(val, v)
   270  			if err != nil {
   271  				err = wrapErrUnsupported(err, "unmarshal function of type func([]byte, T) error")
   272  				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   273  					return err // unlike marshal, never wrapped
   274  				}
   275  				err = newUnmarshalErrorAfter(dec, t, err)
   276  				return collapseSemanticErrors(err)
   277  			}
   278  			return nil
   279  		},
   280  	}
   281  	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
   282  }
   283  
   284  // UnmarshalFromFunc constructs a type-specific unmarshaler that
   285  // specifies how to unmarshal values of type T.
   286  // T must be an unnamed pointer or an interface type.
   287  // The function is always provided with a non-nil pointer value.
   288  //
   289  // The function must unmarshal exactly one JSON value by calling read methods
   290  // on the provided decoder. It may return [errors.ErrUnsupported] such that unmarshaling can
   291  // move on to the next unmarshal function. However, no mutable method calls may
   292  // be called on the decoder if [errors.ErrUnsupported] is returned.
   293  // The pointer to [jsontext.Decoder] and the value of T
   294  // must not be retained outside the function call.
   295  func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
   296  	t := reflect.TypeFor[T]()
   297  	assertCastableTo(t, false)
   298  	typFnc := typedUnmarshaler{
   299  		typ: t,
   300  		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
   301  			xd := export.Decoder(dec)
   302  			prevDepth, prevLength := xd.Tokens.DepthLength()
   303  			if prevDepth == 1 && xd.AtEOF() {
   304  				return io.EOF // check EOF early to avoid fn reporting an EOF
   305  			}
   306  			xd.Flags.Set(jsonflags.WithinArshalCall | 1)
   307  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   308  			err := fn(dec, v)
   309  			xd.Flags.Set(jsonflags.WithinArshalCall | 0)
   310  			currDepth, currLength := xd.Tokens.DepthLength()
   311  			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
   312  				err = errNonSingularValue
   313  			}
   314  			if err != nil {
   315  				if errors.Is(err, errors.ErrUnsupported) {
   316  					if prevDepth == currDepth && prevLength == currLength {
   317  						return err // forward [errors.ErrUnsupported]
   318  					}
   319  					err = errUnsupportedMutation
   320  				}
   321  				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   322  					if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
   323  						return err2
   324  					}
   325  					return err // unlike marshal, never wrapped
   326  				}
   327  				if !isSyntacticError(err) && !export.IsIOError(err) {
   328  					err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
   329  				}
   330  				return err
   331  			}
   332  			return nil
   333  		},
   334  		maySkip: true,
   335  	}
   336  	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
   337  }
   338  
   339  // assertCastableTo asserts that "to" is a valid type to be casted to.
   340  // These are the Go types that type-specific arshalers may operate upon.
   341  //
   342  // Let AllTypes be the universal set of all possible Go types.
   343  // This function generally asserts that:
   344  //
   345  //	len([from for from in AllTypes if castableTo(from, to)]) > 0
   346  //
   347  // otherwise it panics.
   348  //
   349  // As a special-case if marshal is false, then we forbid any non-pointer or
   350  // non-interface type since it is almost always a bug trying to unmarshal
   351  // into something where the end-user caller did not pass in an addressable value
   352  // since they will not observe the mutations.
   353  func assertCastableTo(to reflect.Type, marshal bool) {
   354  	switch to.Kind() {
   355  	case reflect.Interface:
   356  		return
   357  	case reflect.Pointer:
   358  		// Only allow unnamed pointers to be consistent with the fact that
   359  		// taking the address of a value produces an unnamed pointer type.
   360  		if to.Name() == "" {
   361  			return
   362  		}
   363  	default:
   364  		// Technically, non-pointer types are permissible for unmarshal.
   365  		// However, they are often a bug since the receiver would be immutable.
   366  		// Thus, only allow them for marshaling.
   367  		if marshal {
   368  			return
   369  		}
   370  	}
   371  	if marshal {
   372  		panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
   373  	} else {
   374  		panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
   375  	}
   376  }
   377  
   378  // castableTo checks whether values of type "from" can be casted to type "to".
   379  // Nil pointer or interface "from" values are never considered castable.
   380  //
   381  // This function must be kept in sync with addressableValue.castTo.
   382  func castableTo(from, to reflect.Type) bool {
   383  	switch to.Kind() {
   384  	case reflect.Interface:
   385  		// TODO: This breaks when ordinary interfaces can have type sets
   386  		// since interfaces now exist where only the value form of a type (T)
   387  		// implements the interface, but not the pointer variant (*T).
   388  		// See https://go.dev/issue/45346.
   389  		return reflect.PointerTo(from).Implements(to)
   390  	case reflect.Pointer:
   391  		// Common case for unmarshaling.
   392  		// From must be a concrete or interface type.
   393  		return reflect.PointerTo(from) == to
   394  	default:
   395  		// Common case for marshaling.
   396  		// From must be a concrete type.
   397  		return from == to
   398  	}
   399  }
   400  
   401  // castTo casts va to the specified type.
   402  // If the type is an interface, then the underlying type will always
   403  // be a non-nil pointer to a concrete type.
   404  //
   405  // Requirement: castableTo(va.Type(), to) must hold.
   406  func (va addressableValue) castTo(to reflect.Type) reflect.Value {
   407  	switch to.Kind() {
   408  	case reflect.Interface:
   409  		return va.Addr().Convert(to)
   410  	case reflect.Pointer:
   411  		return va.Addr()
   412  	default:
   413  		return va.Value
   414  	}
   415  }
   416  
   417  // castableToFromAny reports whether "to" can be casted to from any
   418  // of the dynamic types used to represent arbitrary JSON.
   419  func castableToFromAny(to reflect.Type) bool {
   420  	for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
   421  		if castableTo(from, to) {
   422  			return true
   423  		}
   424  	}
   425  	return false
   426  }
   427  

View as plain text