Source file src/encoding/json/v2/arshal_funcs.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.jsonv2
     6  
     7  package json
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"reflect"
    13  	"sync"
    14  
    15  	"encoding/json/internal"
    16  	"encoding/json/internal/jsonflags"
    17  	"encoding/json/internal/jsonopts"
    18  	"encoding/json/jsontext"
    19  )
    20  
// SkipFunc may be returned by [MarshalToFunc] and [UnmarshalFromFunc] functions.
//
// Any function that returns SkipFunc must not cause observable side effects
// on the provided [jsontext.Encoder] or [jsontext.Decoder].
// For example, it is permissible to call [jsontext.Decoder.PeekKind],
// but not permissible to call [jsontext.Decoder.ReadToken] or
// [jsontext.Encoder.WriteToken] since such methods mutate the state.
//
// Within this file the sentinel is recognized by direct comparison (==),
// so it must be returned as-is rather than wrapped in another error.
var SkipFunc = errors.New("json: skip function")
    29  
    30  var errSkipMutation = errors.New("must not read or write any tokens when skipping")
    31  var errNonSingularValue = errors.New("must read or write exactly one value")
    32  
// Marshalers is a list of functions that may override the marshal behavior
// of specific types. Populate [WithMarshalers] to use it with
// [Marshal], [MarshalWrite], or [MarshalEncode].
// A nil *Marshalers is equivalent to an empty list.
// There are no exported fields or methods on Marshalers.
//
// Values are constructed with [MarshalFunc] or [MarshalToFunc]
// and combined with [JoinMarshalers].
type Marshalers = typedMarshalers
    39  
    40  // JoinMarshalers constructs a flattened list of marshal functions.
    41  // If multiple functions in the list are applicable for a value of a given type,
    42  // then those earlier in the list take precedence over those that come later.
    43  // If a function returns [SkipFunc], then the next applicable function is called,
    44  // otherwise the default marshaling behavior is used.
    45  //
    46  // For example:
    47  //
    48  //	m1 := JoinMarshalers(f1, f2)
    49  //	m2 := JoinMarshalers(f0, m1, f3)     // equivalent to m3
    50  //	m3 := JoinMarshalers(f0, f1, f2, f3) // equivalent to m2
    51  func JoinMarshalers(ms ...*Marshalers) *Marshalers {
    52  	return newMarshalers(ms...)
    53  }
    54  
// Unmarshalers is a list of functions that may override the unmarshal behavior
// of specific types. Populate [WithUnmarshalers] to use it with
// [Unmarshal], [UnmarshalRead], or [UnmarshalDecode].
// A nil *Unmarshalers is equivalent to an empty list.
// There are no exported fields or methods on Unmarshalers.
//
// Values are constructed with [UnmarshalFunc] or [UnmarshalFromFunc]
// and combined with [JoinUnmarshalers].
type Unmarshalers = typedUnmarshalers
    61  
    62  // JoinUnmarshalers constructs a flattened list of unmarshal functions.
    63  // If multiple functions in the list are applicable for a value of a given type,
    64  // then those earlier in the list take precedence over those that come later.
    65  // If a function returns [SkipFunc], then the next applicable function is called,
    66  // otherwise the default unmarshaling behavior is used.
    67  //
    68  // For example:
    69  //
    70  //	u1 := JoinUnmarshalers(f1, f2)
    71  //	u2 := JoinUnmarshalers(f0, u1, f3)     // equivalent to u3
    72  //	u3 := JoinUnmarshalers(f0, f1, f2, f3) // equivalent to u2
    73  func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
    74  	return newUnmarshalers(us...)
    75  }
    76  
// typedMarshalers and typedUnmarshalers instantiate typedArshalers
// for the encoder and decoder respectively.
type typedMarshalers = typedArshalers[jsontext.Encoder]
type typedUnmarshalers = typedArshalers[jsontext.Decoder]

// typedArshalers is an ordered list of type-specific arshal functions
// along with a per-type cache of the function resolved by lookup.
// Coder is either jsontext.Encoder or jsontext.Decoder.
type typedArshalers[Coder any] struct {
	// NOTE(review): nonComparable is declared elsewhere in the package;
	// presumably it prevents values of this type from being compared with ==.
	nonComparable

	// fncVals holds the registered functions in precedence order;
	// lookup walks it front to back.
	fncVals  []typedArshaler[Coder]
	// fncCache memoizes the result of lookup for each Go type.
	// A nil value records that no registered function applies to the type.
	fncCache sync.Map // map[reflect.Type]arshaler

	// fromAny reports whether any of Go types used to represent arbitrary JSON
	// (i.e., any, bool, string, float64, map[string]any, or []any) matches
	// any of the provided type-specific arshalers.
	//
	// This bit of information is needed in arshal_default.go to determine
	// whether to use the specialized logic in arshal_any.go to handle
	// the any interface type. The logic in arshal_any.go does not support
	// type-specific arshal functions, so we must avoid using that logic
	// if this is true.
	fromAny bool
}
// typedMarshaler and typedUnmarshaler instantiate typedArshaler
// for the encoder and decoder respectively.
type typedMarshaler = typedArshaler[jsontext.Encoder]
type typedUnmarshaler = typedArshaler[jsontext.Decoder]

// typedArshaler is a single type-specific arshal function
// as registered through one of the constructors in this file.
type typedArshaler[Coder any] struct {
	// typ is the type T that the user-provided function operates upon.
	typ     reflect.Type
	// fnc adapts the user-provided function to the internal arshaler signature.
	fnc     func(*Coder, addressableValue, *jsonopts.Struct) error
	// maySkip reports whether fnc may return SkipFunc, in which case
	// lookup keeps collecting later functions as fallbacks.
	maySkip bool
}
   103  
   104  func newMarshalers(ms ...*Marshalers) *Marshalers       { return newTypedArshalers(ms...) }
   105  func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
   106  func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
   107  	var a typedArshalers[Coder]
   108  	for _, a2 := range as {
   109  		if a2 != nil {
   110  			a.fncVals = append(a.fncVals, a2.fncVals...)
   111  			a.fromAny = a.fromAny || a2.fromAny
   112  		}
   113  	}
   114  	if len(a.fncVals) == 0 {
   115  		return nil
   116  	}
   117  	return &a
   118  }
   119  
   120  func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
   121  	if a == nil {
   122  		return fnc, false
   123  	}
   124  	if v, ok := a.fncCache.Load(t); ok {
   125  		if v == nil {
   126  			return fnc, false
   127  		}
   128  		return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   129  	}
   130  
   131  	// Collect a list of arshalers that can be called for this type.
   132  	// This list may be longer than 1 since some arshalers can be skipped.
   133  	var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
   134  	for _, fncVal := range a.fncVals {
   135  		if !castableTo(t, fncVal.typ) {
   136  			continue
   137  		}
   138  		fncs = append(fncs, fncVal.fnc)
   139  		if !fncVal.maySkip {
   140  			break // subsequent arshalers will never be called
   141  		}
   142  	}
   143  
   144  	if len(fncs) == 0 {
   145  		a.fncCache.Store(t, nil) // nil to indicate that no funcs found
   146  		return fnc, false
   147  	}
   148  
   149  	// Construct an arshaler that may call every applicable arshaler.
   150  	fncDefault := fnc
   151  	fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
   152  		for _, fnc := range fncs {
   153  			if err := fnc(c, v, o); err != SkipFunc {
   154  				return err // may be nil or non-nil
   155  			}
   156  		}
   157  		return fncDefault(c, v, o)
   158  	}
   159  
   160  	// Use the first stored so duplicate work can be garbage collected.
   161  	v, _ := a.fncCache.LoadOrStore(t, fnc)
   162  	return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   163  }
   164  
   165  // MarshalFunc constructs a type-specific marshaler that
   166  // specifies how to marshal values of type T.
   167  // T can be any type except a named pointer.
   168  // The function is always provided with a non-nil pointer value
   169  // if T is an interface or pointer type.
   170  //
   171  // The function must marshal exactly one JSON value.
   172  // The value of T must not be retained outside the function call.
   173  // It may not return [SkipFunc].
   174  func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
   175  	t := reflect.TypeFor[T]()
   176  	assertCastableTo(t, true)
   177  	typFnc := typedMarshaler{
   178  		typ: t,
   179  		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
   180  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   181  			val, err := fn(v)
   182  			if err != nil {
   183  				err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
   184  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   185  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   186  				}
   187  				err = newMarshalErrorBefore(enc, t, err)
   188  				return collapseSemanticErrors(err)
   189  			}
   190  			if err := enc.WriteValue(val); err != nil {
   191  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   192  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   193  				}
   194  				if isSyntacticError(err) {
   195  					err = newMarshalErrorBefore(enc, t, err)
   196  				}
   197  				return err
   198  			}
   199  			return nil
   200  		},
   201  	}
   202  	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
   203  }
   204  
// MarshalToFunc constructs a type-specific marshaler that
// specifies how to marshal values of type T.
// T can be any type except a named pointer;
// MarshalToFunc panics otherwise.
// The function is always provided with a non-nil pointer value
// if T is an interface or pointer type.
//
// The function must marshal exactly one JSON value by calling write methods
// on the provided encoder. It may return [SkipFunc] such that marshaling can
// move on to the next marshal function. However, no mutable method calls may
// be called on the encoder if [SkipFunc] is returned.
// The pointer to [jsontext.Encoder] and the value of T
// must not be retained outside the function call.
func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
	t := reflect.TypeFor[T]()
	assertCastableTo(t, true)
	typFnc := typedMarshaler{
		typ: t,
		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
			xe := export.Encoder(enc)
			// Snapshot the token state before the call so that we can
			// verify afterwards how much fn actually wrote.
			prevDepth, prevLength := xe.Tokens.DepthLength()
			// NOTE(review): presumably "| 1" sets and "| 0" clears the flag
			// for the duration of the user call; confirm against jsonflags.
			xe.Flags.Set(jsonflags.WithinArshalCall | 1)
			v, _ := reflect.TypeAssert[T](va.castTo(t))
			err := fn(enc, v)
			xe.Flags.Set(jsonflags.WithinArshalCall | 0)
			currDepth, currLength := xe.Tokens.DepthLength()
			// A successful call must leave the depth unchanged
			// and advance the length by exactly one value.
			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
				err = errNonSingularValue
			}
			if err != nil {
				if err == SkipFunc {
					// Skipping is only honored if the encoder state is untouched.
					if prevDepth == currDepth && prevLength == currLength {
						return SkipFunc
					}
					err = errSkipMutation
				}
				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc") // unlike unmarshal, always wrapped
				}
				// I/O errors are returned as-is; everything else is annotated
				// with the type and the position recorded before the call.
				if !export.IsIOError(err) {
					err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
				}
				return err
			}
			return nil
		},
		maySkip: true,
	}
	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
}
   254  
   255  // UnmarshalFunc constructs a type-specific unmarshaler that
   256  // specifies how to unmarshal values of type T.
   257  // T must be an unnamed pointer or an interface type.
   258  // The function is always provided with a non-nil pointer value.
   259  //
   260  // The function must unmarshal exactly one JSON value.
   261  // The input []byte must not be mutated.
   262  // The input []byte and value T must not be retained outside the function call.
   263  // It may not return [SkipFunc].
   264  func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
   265  	t := reflect.TypeFor[T]()
   266  	assertCastableTo(t, false)
   267  	typFnc := typedUnmarshaler{
   268  		typ: t,
   269  		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
   270  			val, err := dec.ReadValue()
   271  			if err != nil {
   272  				return err // must be a syntactic or I/O error
   273  			}
   274  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   275  			err = fn(val, v)
   276  			if err != nil {
   277  				err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
   278  				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   279  					return err // unlike marshal, never wrapped
   280  				}
   281  				err = newUnmarshalErrorAfter(dec, t, err)
   282  				return collapseSemanticErrors(err)
   283  			}
   284  			return nil
   285  		},
   286  	}
   287  	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
   288  }
   289  
// UnmarshalFromFunc constructs a type-specific unmarshaler that
// specifies how to unmarshal values of type T.
// T must be an unnamed pointer or an interface type;
// UnmarshalFromFunc panics otherwise.
// The function is always provided with a non-nil pointer value.
//
// The function must unmarshal exactly one JSON value by calling read methods
// on the provided decoder. It may return [SkipFunc] such that unmarshaling can
// move on to the next unmarshal function. However, no mutable method calls may
// be called on the decoder if [SkipFunc] is returned.
// The pointer to [jsontext.Decoder] and the value of T
// must not be retained outside the function call.
func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
	t := reflect.TypeFor[T]()
	assertCastableTo(t, false)
	typFnc := typedUnmarshaler{
		typ: t,
		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
			xd := export.Decoder(dec)
			// Snapshot the token state before the call so that we can
			// verify afterwards how much fn actually read.
			prevDepth, prevLength := xd.Tokens.DepthLength()
			// NOTE(review): presumably "| 1" sets and "| 0" clears the flag
			// for the duration of the user call; confirm against jsonflags.
			xd.Flags.Set(jsonflags.WithinArshalCall | 1)
			v, _ := reflect.TypeAssert[T](va.castTo(t))
			err := fn(dec, v)
			xd.Flags.Set(jsonflags.WithinArshalCall | 0)
			currDepth, currLength := xd.Tokens.DepthLength()
			// A successful call must leave the depth unchanged
			// and advance the length by exactly one value.
			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
				err = errNonSingularValue
			}
			if err != nil {
				if err == SkipFunc {
					// Skipping is only honored if the decoder state is untouched.
					if prevDepth == currDepth && prevLength == currLength {
						return SkipFunc
					}
					err = errSkipMutation
				}
				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
					// NOTE(review): presumably this consumes the remainder of
					// the value that fn was responsible for; confirm against
					// the decoder's SkipUntil implementation.
					if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
						return err2
					}
					return err // unlike marshal, never wrapped
				}
				// Syntactic and I/O errors are returned as-is; everything else
				// is annotated with the type and the position recorded before the call.
				if !isSyntacticError(err) && !export.IsIOError(err) {
					err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
				}
				return err
			}
			return nil
		},
		maySkip: true,
	}
	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
}
   341  
   342  // assertCastableTo asserts that "to" is a valid type to be casted to.
   343  // These are the Go types that type-specific arshalers may operate upon.
   344  //
   345  // Let AllTypes be the universal set of all possible Go types.
   346  // This function generally asserts that:
   347  //
   348  //	len([from for from in AllTypes if castableTo(from, to)]) > 0
   349  //
   350  // otherwise it panics.
   351  //
   352  // As a special-case if marshal is false, then we forbid any non-pointer or
   353  // non-interface type since it is almost always a bug trying to unmarshal
   354  // into something where the end-user caller did not pass in an addressable value
   355  // since they will not observe the mutations.
   356  func assertCastableTo(to reflect.Type, marshal bool) {
   357  	switch to.Kind() {
   358  	case reflect.Interface:
   359  		return
   360  	case reflect.Pointer:
   361  		// Only allow unnamed pointers to be consistent with the fact that
   362  		// taking the address of a value produces an unnamed pointer type.
   363  		if to.Name() == "" {
   364  			return
   365  		}
   366  	default:
   367  		// Technically, non-pointer types are permissible for unmarshal.
   368  		// However, they are often a bug since the receiver would be immutable.
   369  		// Thus, only allow them for marshaling.
   370  		if marshal {
   371  			return
   372  		}
   373  	}
   374  	if marshal {
   375  		panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
   376  	} else {
   377  		panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
   378  	}
   379  }
   380  
   381  // castableTo checks whether values of type "from" can be casted to type "to".
   382  // Nil pointer or interface "from" values are never considered castable.
   383  //
   384  // This function must be kept in sync with addressableValue.castTo.
   385  func castableTo(from, to reflect.Type) bool {
   386  	switch to.Kind() {
   387  	case reflect.Interface:
   388  		// TODO: This breaks when ordinary interfaces can have type sets
   389  		// since interfaces now exist where only the value form of a type (T)
   390  		// implements the interface, but not the pointer variant (*T).
   391  		// See https://go.dev/issue/45346.
   392  		return reflect.PointerTo(from).Implements(to)
   393  	case reflect.Pointer:
   394  		// Common case for unmarshaling.
   395  		// From must be a concrete or interface type.
   396  		return reflect.PointerTo(from) == to
   397  	default:
   398  		// Common case for marshaling.
   399  		// From must be a concrete type.
   400  		return from == to
   401  	}
   402  }
   403  
   404  // castTo casts va to the specified type.
   405  // If the type is an interface, then the underlying type will always
   406  // be a non-nil pointer to a concrete type.
   407  //
   408  // Requirement: castableTo(va.Type(), to) must hold.
   409  func (va addressableValue) castTo(to reflect.Type) reflect.Value {
   410  	switch to.Kind() {
   411  	case reflect.Interface:
   412  		return va.Addr().Convert(to)
   413  	case reflect.Pointer:
   414  		return va.Addr()
   415  	default:
   416  		return va.Value
   417  	}
   418  }
   419  
   420  // castableToFromAny reports whether "to" can be casted to from any
   421  // of the dynamic types used to represent arbitrary JSON.
   422  func castableToFromAny(to reflect.Type) bool {
   423  	for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
   424  		if castableTo(from, to) {
   425  			return true
   426  		}
   427  	}
   428  	return false
   429  }
   430  
   431  func wrapSkipFunc(err error, what string) error {
   432  	if err == SkipFunc {
   433  		return errors.New(what + " cannot be skipped")
   434  	}
   435  	return err
   436  }
   437  

View as plain text