Source file src/encoding/json/v2/arshal_funcs.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.jsonv2
     6  
     7  package json
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"reflect"
    13  	"sync"
    14  
    15  	"encoding/json/internal"
    16  	"encoding/json/internal/jsonflags"
    17  	"encoding/json/internal/jsonopts"
    18  	"encoding/json/jsontext"
    19  )
    20  
    21  // SkipFunc may be returned by [MarshalToFunc] and [UnmarshalFromFunc] functions.
    22  //
    23  // Any function that returns SkipFunc must not cause observable side effects
    24  // on the provided [jsontext.Encoder] or [jsontext.Decoder].
    25  // For example, it is permissible to call [jsontext.Decoder.PeekKind],
    26  // but not permissible to call [jsontext.Decoder.ReadToken] or
    27  // [jsontext.Encoder.WriteToken] since such methods mutate the state.
    28  var SkipFunc = errors.New("json: skip function")
    29  
    30  var errSkipMutation = errors.New("must not read or write any tokens when skipping")
    31  var errNonSingularValue = errors.New("must read or write exactly one value")
    32  
    33  // Marshalers is a list of functions that may override the marshal behavior
    34  // of specific types. Populate [WithMarshalers] to use it with
    35  // [Marshal], [MarshalWrite], or [MarshalEncode].
    36  // A nil *Marshalers is equivalent to an empty list.
    37  // There are no exported fields or methods on Marshalers.
    38  type Marshalers = typedMarshalers
    39  
    40  // JoinMarshalers constructs a flattened list of marshal functions.
    41  // If multiple functions in the list are applicable for a value of a given type,
    42  // then those earlier in the list take precedence over those that come later.
    43  // If a function returns [SkipFunc], then the next applicable function is called,
    44  // otherwise the default marshaling behavior is used.
    45  //
    46  // For example:
    47  //
    48  //	m1 := JoinMarshalers(f1, f2)
    49  //	m2 := JoinMarshalers(f0, m1, f3)     // equivalent to m3
    50  //	m3 := JoinMarshalers(f0, f1, f2, f3) // equivalent to m2
    51  func JoinMarshalers(ms ...*Marshalers) *Marshalers {
    52  	return newMarshalers(ms...)
    53  }
    54  
    55  // Unmarshalers is a list of functions that may override the unmarshal behavior
    56  // of specific types. Populate [WithUnmarshalers] to use it with
    57  // [Unmarshal], [UnmarshalRead], or [UnmarshalDecode].
    58  // A nil *Unmarshalers is equivalent to an empty list.
    59  // There are no exported fields or methods on Unmarshalers.
    60  type Unmarshalers = typedUnmarshalers
    61  
    62  // JoinUnmarshalers constructs a flattened list of unmarshal functions.
    63  // If multiple functions in the list are applicable for a value of a given type,
    64  // then those earlier in the list take precedence over those that come later.
    65  // If a function returns [SkipFunc], then the next applicable function is called,
    66  // otherwise the default unmarshaling behavior is used.
    67  //
    68  // For example:
    69  //
    70  //	u1 := JoinUnmarshalers(f1, f2)
    71  //	u2 := JoinUnmarshalers(f0, u1, f3)     // equivalent to u3
    72  //	u3 := JoinUnmarshalers(f0, f1, f2, f3) // equivalent to u2
    73  func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
    74  	return newUnmarshalers(us...)
    75  }
    76  
    77  type typedMarshalers = typedArshalers[jsontext.Encoder]
    78  type typedUnmarshalers = typedArshalers[jsontext.Decoder]
    79  type typedArshalers[Coder any] struct {
    80  	nonComparable
    81  
    82  	fncVals  []typedArshaler[Coder]
    83  	fncCache sync.Map // map[reflect.Type]arshaler
    84  
    85  	// fromAny reports whether any of Go types used to represent arbitrary JSON
    86  	// (i.e., any, bool, string, float64, map[string]any, or []any) matches
    87  	// any of the provided type-specific arshalers.
    88  	//
    89  	// This bit of information is needed in arshal_default.go to determine
    90  	// whether to use the specialized logic in arshal_any.go to handle
    91  	// the any interface type. The logic in arshal_any.go does not support
    92  	// type-specific arshal functions, so we must avoid using that logic
    93  	// if this is true.
    94  	fromAny bool
    95  }
    96  type typedMarshaler = typedArshaler[jsontext.Encoder]
    97  type typedUnmarshaler = typedArshaler[jsontext.Decoder]
    98  type typedArshaler[Coder any] struct {
    99  	typ     reflect.Type
   100  	fnc     func(*Coder, addressableValue, *jsonopts.Struct) error
   101  	maySkip bool
   102  }
   103  
   104  func newMarshalers(ms ...*Marshalers) *Marshalers       { return newTypedArshalers(ms...) }
   105  func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
   106  func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
   107  	var a typedArshalers[Coder]
   108  	for _, a2 := range as {
   109  		if a2 != nil {
   110  			a.fncVals = append(a.fncVals, a2.fncVals...)
   111  			a.fromAny = a.fromAny || a2.fromAny
   112  		}
   113  	}
   114  	if len(a.fncVals) == 0 {
   115  		return nil
   116  	}
   117  	return &a
   118  }
   119  
   120  func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
   121  	if a == nil {
   122  		return fnc, false
   123  	}
   124  	if v, ok := a.fncCache.Load(t); ok {
   125  		if v == nil {
   126  			return fnc, false
   127  		}
   128  		return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   129  	}
   130  
   131  	// Collect a list of arshalers that can be called for this type.
   132  	// This list may be longer than 1 since some arshalers can be skipped.
   133  	var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
   134  	for _, fncVal := range a.fncVals {
   135  		if !castableTo(t, fncVal.typ) {
   136  			continue
   137  		}
   138  		fncs = append(fncs, fncVal.fnc)
   139  		if !fncVal.maySkip {
   140  			break // subsequent arshalers will never be called
   141  		}
   142  	}
   143  
   144  	if len(fncs) == 0 {
   145  		a.fncCache.Store(t, nil) // nil to indicate that no funcs found
   146  		return fnc, false
   147  	}
   148  
   149  	// Construct an arshaler that may call every applicable arshaler.
   150  	fncDefault := fnc
   151  	fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
   152  		for _, fnc := range fncs {
   153  			if err := fnc(c, v, o); err != SkipFunc {
   154  				return err // may be nil or non-nil
   155  			}
   156  		}
   157  		return fncDefault(c, v, o)
   158  	}
   159  
   160  	// Use the first stored so duplicate work can be garbage collected.
   161  	v, _ := a.fncCache.LoadOrStore(t, fnc)
   162  	return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
   163  }
   164  
   165  // MarshalFunc constructs a type-specific marshaler that
   166  // specifies how to marshal values of type T.
   167  // T can be any type except a named pointer.
   168  // The function is always provided with a non-nil pointer value
   169  // if T is an interface or pointer type.
   170  //
   171  // The function must marshal exactly one JSON value.
   172  // The value of T must not be retained outside the function call.
   173  // It may not return [SkipFunc].
   174  func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
   175  	t := reflect.TypeFor[T]()
   176  	assertCastableTo(t, true)
   177  	typFnc := typedMarshaler{
   178  		typ: t,
   179  		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
   180  			val, err := fn(va.castTo(t).Interface().(T))
   181  			if err != nil {
   182  				err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
   183  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   184  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   185  				}
   186  				err = newMarshalErrorBefore(enc, t, err)
   187  				return collapseSemanticErrors(err)
   188  			}
   189  			if err := enc.WriteValue(val); err != nil {
   190  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   191  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc") // unlike unmarshal, always wrapped
   192  				}
   193  				if isSyntacticError(err) {
   194  					err = newMarshalErrorBefore(enc, t, err)
   195  				}
   196  				return err
   197  			}
   198  			return nil
   199  		},
   200  	}
   201  	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
   202  }
   203  
   204  // MarshalToFunc constructs a type-specific marshaler that
   205  // specifies how to marshal values of type T.
   206  // T can be any type except a named pointer.
   207  // The function is always provided with a non-nil pointer value
   208  // if T is an interface or pointer type.
   209  //
   210  // The function must marshal exactly one JSON value by calling write methods
   211  // on the provided encoder. It may return [SkipFunc] such that marshaling can
   212  // move on to the next marshal function. However, no mutable method calls may
   213  // be called on the encoder if [SkipFunc] is returned.
   214  // The pointer to [jsontext.Encoder] and the value of T
   215  // must not be retained outside the function call.
   216  func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
   217  	t := reflect.TypeFor[T]()
   218  	assertCastableTo(t, true)
   219  	typFnc := typedMarshaler{
   220  		typ: t,
   221  		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
   222  			xe := export.Encoder(enc)
   223  			prevDepth, prevLength := xe.Tokens.DepthLength()
   224  			xe.Flags.Set(jsonflags.WithinArshalCall | 1)
   225  			err := fn(enc, va.castTo(t).Interface().(T))
   226  			xe.Flags.Set(jsonflags.WithinArshalCall | 0)
   227  			currDepth, currLength := xe.Tokens.DepthLength()
   228  			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
   229  				err = errNonSingularValue
   230  			}
   231  			if err != nil {
   232  				if err == SkipFunc {
   233  					if prevDepth == currDepth && prevLength == currLength {
   234  						return SkipFunc
   235  					}
   236  					err = errSkipMutation
   237  				}
   238  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   239  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc") // unlike unmarshal, always wrapped
   240  				}
   241  				if !export.IsIOError(err) {
   242  					err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
   243  				}
   244  				return err
   245  			}
   246  			return nil
   247  		},
   248  		maySkip: true,
   249  	}
   250  	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
   251  }
   252  
   253  // UnmarshalFunc constructs a type-specific unmarshaler that
   254  // specifies how to unmarshal values of type T.
   255  // T must be an unnamed pointer or an interface type.
   256  // The function is always provided with a non-nil pointer value.
   257  //
   258  // The function must unmarshal exactly one JSON value.
   259  // The input []byte must not be mutated.
   260  // The input []byte and value T must not be retained outside the function call.
   261  // It may not return [SkipFunc].
   262  func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
   263  	t := reflect.TypeFor[T]()
   264  	assertCastableTo(t, false)
   265  	typFnc := typedUnmarshaler{
   266  		typ: t,
   267  		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
   268  			val, err := dec.ReadValue()
   269  			if err != nil {
   270  				return err // must be a syntactic or I/O error
   271  			}
   272  			err = fn(val, va.castTo(t).Interface().(T))
   273  			if err != nil {
   274  				err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
   275  				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   276  					return err // unlike marshal, never wrapped
   277  				}
   278  				err = newUnmarshalErrorAfter(dec, t, err)
   279  				return collapseSemanticErrors(err)
   280  			}
   281  			return nil
   282  		},
   283  	}
   284  	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
   285  }
   286  
   287  // UnmarshalFromFunc constructs a type-specific unmarshaler that
   288  // specifies how to unmarshal values of type T.
   289  // T must be an unnamed pointer or an interface type.
   290  // The function is always provided with a non-nil pointer value.
   291  //
   292  // The function must unmarshal exactly one JSON value by calling read methods
   293  // on the provided decoder. It may return [SkipFunc] such that unmarshaling can
   294  // move on to the next unmarshal function. However, no mutable method calls may
   295  // be called on the decoder if [SkipFunc] is returned.
   296  // The pointer to [jsontext.Decoder] and the value of T
   297  // must not be retained outside the function call.
   298  func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
   299  	t := reflect.TypeFor[T]()
   300  	assertCastableTo(t, false)
   301  	typFnc := typedUnmarshaler{
   302  		typ: t,
   303  		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
   304  			xd := export.Decoder(dec)
   305  			prevDepth, prevLength := xd.Tokens.DepthLength()
   306  			xd.Flags.Set(jsonflags.WithinArshalCall | 1)
   307  			err := fn(dec, va.castTo(t).Interface().(T))
   308  			xd.Flags.Set(jsonflags.WithinArshalCall | 0)
   309  			currDepth, currLength := xd.Tokens.DepthLength()
   310  			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
   311  				err = errNonSingularValue
   312  			}
   313  			if err != nil {
   314  				if err == SkipFunc {
   315  					if prevDepth == currDepth && prevLength == currLength {
   316  						return SkipFunc
   317  					}
   318  					err = errSkipMutation
   319  				}
   320  				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   321  					if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
   322  						return err2
   323  					}
   324  					return err // unlike marshal, never wrapped
   325  				}
   326  				if !isSyntacticError(err) && !export.IsIOError(err) {
   327  					err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
   328  				}
   329  				return err
   330  			}
   331  			return nil
   332  		},
   333  		maySkip: true,
   334  	}
   335  	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
   336  }
   337  
   338  // assertCastableTo asserts that "to" is a valid type to be casted to.
   339  // These are the Go types that type-specific arshalers may operate upon.
   340  //
   341  // Let AllTypes be the universal set of all possible Go types.
   342  // This function generally asserts that:
   343  //
   344  //	len([from for from in AllTypes if castableTo(from, to)]) > 0
   345  //
   346  // otherwise it panics.
   347  //
   348  // As a special-case if marshal is false, then we forbid any non-pointer or
   349  // non-interface type since it is almost always a bug trying to unmarshal
   350  // into something where the end-user caller did not pass in an addressable value
   351  // since they will not observe the mutations.
   352  func assertCastableTo(to reflect.Type, marshal bool) {
   353  	switch to.Kind() {
   354  	case reflect.Interface:
   355  		return
   356  	case reflect.Pointer:
   357  		// Only allow unnamed pointers to be consistent with the fact that
   358  		// taking the address of a value produces an unnamed pointer type.
   359  		if to.Name() == "" {
   360  			return
   361  		}
   362  	default:
   363  		// Technically, non-pointer types are permissible for unmarshal.
   364  		// However, they are often a bug since the receiver would be immutable.
   365  		// Thus, only allow them for marshaling.
   366  		if marshal {
   367  			return
   368  		}
   369  	}
   370  	if marshal {
   371  		panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
   372  	} else {
   373  		panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
   374  	}
   375  }
   376  
   377  // castableTo checks whether values of type "from" can be casted to type "to".
   378  // Nil pointer or interface "from" values are never considered castable.
   379  //
   380  // This function must be kept in sync with addressableValue.castTo.
   381  func castableTo(from, to reflect.Type) bool {
   382  	switch to.Kind() {
   383  	case reflect.Interface:
   384  		// TODO: This breaks when ordinary interfaces can have type sets
   385  		// since interfaces now exist where only the value form of a type (T)
   386  		// implements the interface, but not the pointer variant (*T).
   387  		// See https://go.dev/issue/45346.
   388  		return reflect.PointerTo(from).Implements(to)
   389  	case reflect.Pointer:
   390  		// Common case for unmarshaling.
   391  		// From must be a concrete or interface type.
   392  		return reflect.PointerTo(from) == to
   393  	default:
   394  		// Common case for marshaling.
   395  		// From must be a concrete type.
   396  		return from == to
   397  	}
   398  }
   399  
   400  // castTo casts va to the specified type.
   401  // If the type is an interface, then the underlying type will always
   402  // be a non-nil pointer to a concrete type.
   403  //
   404  // Requirement: castableTo(va.Type(), to) must hold.
   405  func (va addressableValue) castTo(to reflect.Type) reflect.Value {
   406  	switch to.Kind() {
   407  	case reflect.Interface:
   408  		return va.Addr().Convert(to)
   409  	case reflect.Pointer:
   410  		return va.Addr()
   411  	default:
   412  		return va.Value
   413  	}
   414  }
   415  
   416  // castableToFromAny reports whether "to" can be casted to from any
   417  // of the dynamic types used to represent arbitrary JSON.
   418  func castableToFromAny(to reflect.Type) bool {
   419  	for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
   420  		if castableTo(from, to) {
   421  			return true
   422  		}
   423  	}
   424  	return false
   425  }
   426  
   427  func wrapSkipFunc(err error, what string) error {
   428  	if err == SkipFunc {
   429  		return errors.New(what + " cannot be skipped")
   430  	}
   431  	return err
   432  }
   433  

View as plain text