Source file src/encoding/json/v2/arshal_funcs.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.jsonv2
     6  
     7  package json
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"reflect"
    14  	"sync"
    15  
    16  	"encoding/json/internal"
    17  	"encoding/json/internal/jsonflags"
    18  	"encoding/json/internal/jsonopts"
    19  	"encoding/json/jsontext"
    20  )
    21  
    22  // SkipFunc may be returned by [MarshalToFunc] and [UnmarshalFromFunc] functions.
    23  //
    24  // Any function that returns SkipFunc must not cause observable side effects
    25  // on the provided [jsontext.Encoder] or [jsontext.Decoder].
    26  // For example, it is permissible to call [jsontext.Decoder.PeekKind],
    27  // but not permissible to call [jsontext.Decoder.ReadToken] or
    28  // [jsontext.Encoder.WriteToken] since such methods mutate the state.
    29  var SkipFunc = errors.New("json: skip function")
    30  
    31  var errSkipMutation = errors.New("must not read or write any tokens when skipping")
    32  var errNonSingularValue = errors.New("must read or write exactly one value")
    33  
    34  // Marshalers is a list of functions that may override the marshal behavior
    35  // of specific types. Populate [WithMarshalers] to use it with
    36  // [Marshal], [MarshalWrite], or [MarshalEncode].
    37  // A nil *Marshalers is equivalent to an empty list.
    38  // There are no exported fields or methods on Marshalers.
    39  type Marshalers = typedMarshalers
    40  
    41  // JoinMarshalers constructs a flattened list of marshal functions.
    42  // If multiple functions in the list are applicable for a value of a given type,
    43  // then those earlier in the list take precedence over those that come later.
    44  // If a function returns [SkipFunc], then the next applicable function is called,
    45  // otherwise the default marshaling behavior is used.
    46  //
    47  // For example:
    48  //
    49  //	m1 := JoinMarshalers(f1, f2)
    50  //	m2 := JoinMarshalers(f0, m1, f3)     // equivalent to m3
    51  //	m3 := JoinMarshalers(f0, f1, f2, f3) // equivalent to m2
    52  func JoinMarshalers(ms ...*Marshalers) *Marshalers {
    53  	return newMarshalers(ms...)
    54  }
    55  
    56  // Unmarshalers is a list of functions that may override the unmarshal behavior
    57  // of specific types. Populate [WithUnmarshalers] to use it with
    58  // [Unmarshal], [UnmarshalRead], or [UnmarshalDecode].
    59  // A nil *Unmarshalers is equivalent to an empty list.
    60  // There are no exported fields or methods on Unmarshalers.
    61  type Unmarshalers = typedUnmarshalers
    62  
    63  // JoinUnmarshalers constructs a flattened list of unmarshal functions.
    64  // If multiple functions in the list are applicable for a value of a given type,
    65  // then those earlier in the list take precedence over those that come later.
    66  // If a function returns [SkipFunc], then the next applicable function is called,
    67  // otherwise the default unmarshaling behavior is used.
    68  //
    69  // For example:
    70  //
    71  //	u1 := JoinUnmarshalers(f1, f2)
    72  //	u2 := JoinUnmarshalers(f0, u1, f3)     // equivalent to u3
    73  //	u3 := JoinUnmarshalers(f0, f1, f2, f3) // equivalent to u2
    74  func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
    75  	return newUnmarshalers(us...)
    76  }
    77  
    78  type typedMarshalers = typedArshalers[jsontext.Encoder]
    79  type typedUnmarshalers = typedArshalers[jsontext.Decoder]
    80  type typedArshalers[Coder any] struct {
    81  	nonComparable
    82  
    83  	fncVals  []typedArshaler[Coder]
    84  	fncCache sync.Map // map[reflect.Type]arshaler
    85  
    86  	// fromAny reports whether any of Go types used to represent arbitrary JSON
    87  	// (i.e., any, bool, string, float64, map[string]any, or []any) matches
    88  	// any of the provided type-specific arshalers.
    89  	//
    90  	// This bit of information is needed in arshal_default.go to determine
    91  	// whether to use the specialized logic in arshal_any.go to handle
    92  	// the any interface type. The logic in arshal_any.go does not support
    93  	// type-specific arshal functions, so we must avoid using that logic
    94  	// if this is true.
    95  	fromAny bool
    96  }
    97  type typedMarshaler = typedArshaler[jsontext.Encoder]
    98  type typedUnmarshaler = typedArshaler[jsontext.Decoder]
    99  type typedArshaler[Coder any] struct {
   100  	typ     reflect.Type
   101  	fnc     func(*Coder, addressableValue, *jsonopts.Struct) error
   102  	maySkip bool
   103  }
   104  
   105  func newMarshalers(ms ...*Marshalers) *Marshalers       { return newTypedArshalers(ms...) }
   106  func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
   107  func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
   108  	var a typedArshalers[Coder]
   109  	for _, a2 := range as {
   110  		if a2 != nil {
   111  			a.fncVals = append(a.fncVals, a2.fncVals...)
   112  			a.fromAny = a.fromAny || a2.fromAny
   113  		}
   114  	}
   115  	if len(a.fncVals) == 0 {
   116  		return nil
   117  	}
   118  	return &a
   119  }
   120  
   121  func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
   122  	if a == nil {
   123  		return fnc, false
   124  	}
   125  	if v, ok := a.fncCache.Load(t); ok {
   126  		if v == nil {
   127  			return fnc, false
   128  		}
   129  		return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   130  	}
   131  
   132  	// Collect a list of arshalers that can be called for this type.
   133  	// This list may be longer than 1 since some arshalers can be skipped.
   134  	var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
   135  	for _, fncVal := range a.fncVals {
   136  		if !castableTo(t, fncVal.typ) {
   137  			continue
   138  		}
   139  		fncs = append(fncs, fncVal.fnc)
   140  		if !fncVal.maySkip {
   141  			break // subsequent arshalers will never be called
   142  		}
   143  	}
   144  
   145  	if len(fncs) == 0 {
   146  		a.fncCache.Store(t, nil) // nil to indicate that no funcs found
   147  		return fnc, false
   148  	}
   149  
   150  	// Construct an arshaler that may call every applicable arshaler.
   151  	fncDefault := fnc
   152  	fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
   153  		for _, fnc := range fncs {
   154  			if err := fnc(c, v, o); err != SkipFunc {
   155  				return err // may be nil or non-nil
   156  			}
   157  		}
   158  		return fncDefault(c, v, o)
   159  	}
   160  
   161  	// Use the first stored so duplicate work can be garbage collected.
   162  	v, _ := a.fncCache.LoadOrStore(t, fnc)
   163  	return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   164  }
   165  
   166  // MarshalFunc constructs a type-specific marshaler that
   167  // specifies how to marshal values of type T.
   168  // T can be any type except a named pointer.
   169  // The function is always provided with a non-nil pointer value
   170  // if T is an interface or pointer type.
   171  //
   172  // The function must marshal exactly one JSON value.
   173  // The value of T must not be retained outside the function call.
   174  // It may not return [SkipFunc].
   175  func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
   176  	t := reflect.TypeFor[T]()
   177  	assertCastableTo(t, true)
   178  	typFnc := typedMarshaler{
   179  		typ: t,
   180  		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
   181  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   182  			val, err := fn(v)
   183  			if err != nil {
   184  				err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
   185  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   186  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   187  				}
   188  				err = newMarshalErrorBefore(enc, t, err)
   189  				return collapseSemanticErrors(err)
   190  			}
   191  			if err := enc.WriteValue(val); err != nil {
   192  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   193  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   194  				}
   195  				if isSyntacticError(err) {
   196  					err = newMarshalErrorBefore(enc, t, err)
   197  				}
   198  				return err
   199  			}
   200  			return nil
   201  		},
   202  	}
   203  	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
   204  }
   205  
   206  // MarshalToFunc constructs a type-specific marshaler that
   207  // specifies how to marshal values of type T.
   208  // T can be any type except a named pointer.
   209  // The function is always provided with a non-nil pointer value
   210  // if T is an interface or pointer type.
   211  //
   212  // The function must marshal exactly one JSON value by calling write methods
   213  // on the provided encoder. It may return [SkipFunc] such that marshaling can
   214  // move on to the next marshal function. However, no mutable method calls may
   215  // be called on the encoder if [SkipFunc] is returned.
   216  // The pointer to [jsontext.Encoder] and the value of T
   217  // must not be retained outside the function call.
   218  func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
   219  	t := reflect.TypeFor[T]()
   220  	assertCastableTo(t, true)
   221  	typFnc := typedMarshaler{
   222  		typ: t,
   223  		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
   224  			xe := export.Encoder(enc)
   225  			prevDepth, prevLength := xe.Tokens.DepthLength()
   226  			xe.Flags.Set(jsonflags.WithinArshalCall | 1)
   227  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   228  			err := fn(enc, v)
   229  			xe.Flags.Set(jsonflags.WithinArshalCall | 0)
   230  			currDepth, currLength := xe.Tokens.DepthLength()
   231  			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
   232  				err = errNonSingularValue
   233  			}
   234  			if err != nil {
   235  				if err == SkipFunc {
   236  					if prevDepth == currDepth && prevLength == currLength {
   237  						return SkipFunc
   238  					}
   239  					err = errSkipMutation
   240  				}
   241  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   242  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc") // unlike unmarshal, always wrapped
   243  				}
   244  				if !export.IsIOError(err) {
   245  					err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
   246  				}
   247  				return err
   248  			}
   249  			return nil
   250  		},
   251  		maySkip: true,
   252  	}
   253  	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
   254  }
   255  
   256  // UnmarshalFunc constructs a type-specific unmarshaler that
   257  // specifies how to unmarshal values of type T.
   258  // T must be an unnamed pointer or an interface type.
   259  // The function is always provided with a non-nil pointer value.
   260  //
   261  // The function must unmarshal exactly one JSON value.
   262  // The input []byte must not be mutated.
   263  // The input []byte and value T must not be retained outside the function call.
   264  // It may not return [SkipFunc].
   265  func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
   266  	t := reflect.TypeFor[T]()
   267  	assertCastableTo(t, false)
   268  	typFnc := typedUnmarshaler{
   269  		typ: t,
   270  		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
   271  			val, err := dec.ReadValue()
   272  			if err != nil {
   273  				return err // must be a syntactic or I/O error
   274  			}
   275  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   276  			err = fn(val, v)
   277  			if err != nil {
   278  				err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
   279  				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   280  					return err // unlike marshal, never wrapped
   281  				}
   282  				err = newUnmarshalErrorAfter(dec, t, err)
   283  				return collapseSemanticErrors(err)
   284  			}
   285  			return nil
   286  		},
   287  	}
   288  	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
   289  }
   290  
   291  // UnmarshalFromFunc constructs a type-specific unmarshaler that
   292  // specifies how to unmarshal values of type T.
   293  // T must be an unnamed pointer or an interface type.
   294  // The function is always provided with a non-nil pointer value.
   295  //
   296  // The function must unmarshal exactly one JSON value by calling read methods
   297  // on the provided decoder. It may return [SkipFunc] such that unmarshaling can
   298  // move on to the next unmarshal function. However, no mutable method calls may
   299  // be called on the decoder if [SkipFunc] is returned.
   300  // The pointer to [jsontext.Decoder] and the value of T
   301  // must not be retained outside the function call.
   302  func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
   303  	t := reflect.TypeFor[T]()
   304  	assertCastableTo(t, false)
   305  	typFnc := typedUnmarshaler{
   306  		typ: t,
   307  		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
   308  			xd := export.Decoder(dec)
   309  			prevDepth, prevLength := xd.Tokens.DepthLength()
   310  			if prevDepth == 1 && xd.AtEOF() {
   311  				return io.EOF // check EOF early to avoid fn reporting an EOF
   312  			}
   313  			xd.Flags.Set(jsonflags.WithinArshalCall | 1)
   314  			v, _ := reflect.TypeAssert[T](va.castTo(t))
   315  			err := fn(dec, v)
   316  			xd.Flags.Set(jsonflags.WithinArshalCall | 0)
   317  			currDepth, currLength := xd.Tokens.DepthLength()
   318  			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
   319  				err = errNonSingularValue
   320  			}
   321  			if err != nil {
   322  				if err == SkipFunc {
   323  					if prevDepth == currDepth && prevLength == currLength {
   324  						return SkipFunc
   325  					}
   326  					err = errSkipMutation
   327  				}
   328  				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   329  					if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
   330  						return err2
   331  					}
   332  					return err // unlike marshal, never wrapped
   333  				}
   334  				if !isSyntacticError(err) && !export.IsIOError(err) {
   335  					err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
   336  				}
   337  				return err
   338  			}
   339  			return nil
   340  		},
   341  		maySkip: true,
   342  	}
   343  	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
   344  }
   345  
   346  // assertCastableTo asserts that "to" is a valid type to be casted to.
   347  // These are the Go types that type-specific arshalers may operate upon.
   348  //
   349  // Let AllTypes be the universal set of all possible Go types.
   350  // This function generally asserts that:
   351  //
   352  //	len([from for from in AllTypes if castableTo(from, to)]) > 0
   353  //
   354  // otherwise it panics.
   355  //
   356  // As a special-case if marshal is false, then we forbid any non-pointer or
   357  // non-interface type since it is almost always a bug trying to unmarshal
   358  // into something where the end-user caller did not pass in an addressable value
   359  // since they will not observe the mutations.
   360  func assertCastableTo(to reflect.Type, marshal bool) {
   361  	switch to.Kind() {
   362  	case reflect.Interface:
   363  		return
   364  	case reflect.Pointer:
   365  		// Only allow unnamed pointers to be consistent with the fact that
   366  		// taking the address of a value produces an unnamed pointer type.
   367  		if to.Name() == "" {
   368  			return
   369  		}
   370  	default:
   371  		// Technically, non-pointer types are permissible for unmarshal.
   372  		// However, they are often a bug since the receiver would be immutable.
   373  		// Thus, only allow them for marshaling.
   374  		if marshal {
   375  			return
   376  		}
   377  	}
   378  	if marshal {
   379  		panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
   380  	} else {
   381  		panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
   382  	}
   383  }
   384  
   385  // castableTo checks whether values of type "from" can be casted to type "to".
   386  // Nil pointer or interface "from" values are never considered castable.
   387  //
   388  // This function must be kept in sync with addressableValue.castTo.
   389  func castableTo(from, to reflect.Type) bool {
   390  	switch to.Kind() {
   391  	case reflect.Interface:
   392  		// TODO: This breaks when ordinary interfaces can have type sets
   393  		// since interfaces now exist where only the value form of a type (T)
   394  		// implements the interface, but not the pointer variant (*T).
   395  		// See https://go.dev/issue/45346.
   396  		return reflect.PointerTo(from).Implements(to)
   397  	case reflect.Pointer:
   398  		// Common case for unmarshaling.
   399  		// From must be a concrete or interface type.
   400  		return reflect.PointerTo(from) == to
   401  	default:
   402  		// Common case for marshaling.
   403  		// From must be a concrete type.
   404  		return from == to
   405  	}
   406  }
   407  
   408  // castTo casts va to the specified type.
   409  // If the type is an interface, then the underlying type will always
   410  // be a non-nil pointer to a concrete type.
   411  //
   412  // Requirement: castableTo(va.Type(), to) must hold.
   413  func (va addressableValue) castTo(to reflect.Type) reflect.Value {
   414  	switch to.Kind() {
   415  	case reflect.Interface:
   416  		return va.Addr().Convert(to)
   417  	case reflect.Pointer:
   418  		return va.Addr()
   419  	default:
   420  		return va.Value
   421  	}
   422  }
   423  
   424  // castableToFromAny reports whether "to" can be casted to from any
   425  // of the dynamic types used to represent arbitrary JSON.
   426  func castableToFromAny(to reflect.Type) bool {
   427  	for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
   428  		if castableTo(from, to) {
   429  			return true
   430  		}
   431  	}
   432  	return false
   433  }
   434  
   435  func wrapSkipFunc(err error, what string) error {
   436  	if err == SkipFunc {
   437  		return errors.New(what + " cannot be skipped")
   438  	}
   439  	return err
   440  }
   441  

View as plain text