1
2
3
4
5
6
7 package json
8
9 import (
10 "errors"
11 "fmt"
12 "io"
13 "reflect"
14 "sync"
15
16 "encoding/json/internal"
17 "encoding/json/internal/jsonflags"
18 "encoding/json/internal/jsonopts"
19 "encoding/json/jsontext"
20 )
21
// Sentinel errors reported when a function built by MarshalToFunc or
// UnmarshalFromFunc misuses the encoder or decoder it was given.
var (
	// errUnsupportedMutation is reported when such a function returns an error
	// matching errors.ErrUnsupported after it has already read or written
	// tokens, so it can no longer decline in favor of another function.
	errUnsupportedMutation = errors.New("unsupported calls must not read or write any tokens")

	// errNonSingularValue is reported when such a function returns nil but did
	// not leave the token stream exactly one value further along at the depth
	// where it started.
	errNonSingularValue = errors.New("must read or write exactly one value")
)
24
25
26
27
28
29
// Marshalers is an ordered list of functions that override how values of
// specific types are marshaled. Values are built with MarshalFunc or
// MarshalToFunc and may be combined with JoinMarshalers. A nil *Marshalers
// is valid and overrides nothing.
type Marshalers = typedMarshalers
31
32
33
34
35
36
37
38
39
40
41
42
43
44 func JoinMarshalers(ms ...*Marshalers) *Marshalers {
45 return newMarshalers(ms...)
46 }
47
48
49
50
51
52
// Unmarshalers is an ordered list of functions that override how values of
// specific types are unmarshaled. Values are built with UnmarshalFunc or
// UnmarshalFromFunc and may be combined with JoinUnmarshalers. A nil
// *Unmarshalers is valid and overrides nothing.
type Unmarshalers = typedUnmarshalers
54
55
56
57
58
59
60
61
62
63
64
65
66
67 func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
68 return newUnmarshalers(us...)
69 }
70
// typedMarshalers is the concrete type behind Marshalers.
type typedMarshalers = typedArshalers[jsontext.Encoder]

// typedUnmarshalers is the concrete type behind Unmarshalers.
type typedUnmarshalers = typedArshalers[jsontext.Decoder]

// typedArshalers is an ordered list of type-specific functions. Coder is
// jsontext.Encoder for marshal functions and jsontext.Decoder for unmarshal
// functions.
type typedArshalers[Coder any] struct {
	// NOTE(review): nonComparable is declared elsewhere; presumably it makes
	// this struct non-comparable with ==. Confirm at its declaration.
	nonComparable

	// fncVals holds the registered functions in precedence order; lookup
	// consults earlier entries first.
	fncVals []typedArshaler[Coder]
	// fncCache memoizes lookup results keyed by reflect.Type. A stored nil
	// records that no entry in fncVals applies to that type; otherwise the
	// value is the composed function built by lookup.
	fncCache sync.Map

	// fromAny reports whether any entry in fncVals applies to one of the types
	// checked by castableToFromAny (anyType, boolType, stringType, float64Type,
	// mapStringAnyType, sliceAnyType). NOTE(review): presumably consulted
	// elsewhere to decide whether specialized handling of `any` values is
	// safe. Confirm at the use site.
	fromAny bool
}

// typedMarshaler is a single registered marshal function.
type typedMarshaler = typedArshaler[jsontext.Encoder]

// typedUnmarshaler is a single registered unmarshal function.
type typedUnmarshaler = typedArshaler[jsontext.Decoder]

// typedArshaler is one registered marshal or unmarshal function.
type typedArshaler[Coder any] struct {
	// typ is the type the function was registered for. lookup applies the
	// function to values of type t when castableTo(t, typ) holds.
	typ reflect.Type
	// fnc marshals or unmarshals an addressable value using the coder.
	fnc func(*Coder, addressableValue, *jsonopts.Struct) error
	// maySkip reports whether fnc may decline by returning an error matching
	// errors.ErrUnsupported. lookup stops collecting candidates after the
	// first applicable function whose maySkip is false.
	maySkip bool
}
97
98 func newMarshalers(ms ...*Marshalers) *Marshalers { return newTypedArshalers(ms...) }
99 func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
100 func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
101 var a typedArshalers[Coder]
102 for _, a2 := range as {
103 if a2 != nil {
104 a.fncVals = append(a.fncVals, a2.fncVals...)
105 a.fromAny = a.fromAny || a2.fromAny
106 }
107 }
108 if len(a.fncVals) == 0 {
109 return nil
110 }
111 return &a
112 }
113
114 func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
115 if a == nil {
116 return fnc, false
117 }
118 if v, ok := a.fncCache.Load(t); ok {
119 if v == nil {
120 return fnc, false
121 }
122 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
123 }
124
125
126
127 var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
128 for _, fncVal := range a.fncVals {
129 if !castableTo(t, fncVal.typ) {
130 continue
131 }
132 fncs = append(fncs, fncVal.fnc)
133 if !fncVal.maySkip {
134 break
135 }
136 }
137
138 if len(fncs) == 0 {
139 a.fncCache.Store(t, nil)
140 return fnc, false
141 }
142
143
144 fncDefault := fnc
145 fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
146 for _, fnc := range fncs {
147 if err := fnc(c, v, o); !errors.Is(err, errors.ErrUnsupported) {
148 return err
149 }
150 }
151 return fncDefault(c, v, o)
152 }
153
154
155 v, _ := a.fncCache.LoadOrStore(t, fnc)
156 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
157 }
158
159
160
161
162
163
164
165
166
167
168 func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
169 t := reflect.TypeFor[T]()
170 assertCastableTo(t, true)
171 typFnc := typedMarshaler{
172 typ: t,
173 fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
174 v, _ := reflect.TypeAssert[T](va.castTo(t))
175 val, err := fn(v)
176 if err != nil {
177 err = wrapErrUnsupported(err, "marshal function of type func(T) ([]byte, error)")
178 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
179 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
180 }
181 err = newMarshalErrorBefore(enc, t, err)
182 return collapseSemanticErrors(err)
183 }
184 if err := enc.WriteValue(val); err != nil {
185 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
186 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
187 }
188 if isSyntacticError(err) {
189 err = newMarshalErrorBefore(enc, t, err)
190 }
191 return err
192 }
193 return nil
194 },
195 }
196 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
197 }
198
199
200
201
202
203
204
205
206
207
208
209
210
// MarshalToFunc builds a Marshalers holding a single function that overrides
// how values of type T are marshaled, by writing directly to the provided
// encoder.
//
// T must be an interface type, an unnamed pointer type, or a non-pointer
// type; a named pointer type causes a panic (see assertCastableTo). When T is
// a pointer or interface type, fn receives a non-nil pointer to the value
// being marshaled (castTo passes va.Addr()).
//
// fn must write exactly one value at the depth where it starts. If it returns
// nil after writing more or fewer, errNonSingularValue is reported.
//
// fn may decline by returning an error matching errors.ErrUnsupported without
// having written anything. The next applicable function, or else the default
// behavior, is then used. Declining after writing is reported as
// errUnsupportedMutation.
//
// Under legacy error semantics, errors are wrapped in a marshaler error tagged
// "MarshalToFunc". Otherwise, non-I/O errors are annotated with T and the
// position at which fn began writing.
func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
	t := reflect.TypeFor[T]()
	assertCastableTo(t, true)
	typFnc := typedMarshaler{
		typ: t,
		fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
			xe := export.Encoder(enc)
			// Snapshot the encoder position (presumably nesting depth and the
			// count of values at that depth) to check what fn wrote.
			prevDepth, prevLength := xe.Tokens.DepthLength()
			// Presumably marks the encoder as inside a caller-supplied
			// function while fn runs. NOTE(review): the reset below is not
			// deferred, so the flag stays set if fn panics.
			xe.Flags.Set(jsonflags.WithinArshalCall | 1)
			v, _ := reflect.TypeAssert[T](va.castTo(t))
			err := fn(enc, v)
			xe.Flags.Set(jsonflags.WithinArshalCall | 0)
			currDepth, currLength := xe.Tokens.DepthLength()
			// Success requires exactly one new value at the starting depth.
			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
				err = errNonSingularValue
			}
			if err != nil {
				if errors.Is(err, errors.ErrUnsupported) {
					// Declining is only allowed when nothing was written, so
					// the next function can start from the same position.
					if prevDepth == currDepth && prevLength == currLength {
						return err
					}
					err = errUnsupportedMutation
				}
				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc") // always wrapped on the marshal side
				}
				if !export.IsIOError(err) {
					err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
				}
				return err
			}
			return nil
		},
		// fn may return errors.ErrUnsupported to defer to the next function.
		maySkip: true,
	}
	return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
}
248
249
250
251
252
253
254
255
256
257
258 func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
259 t := reflect.TypeFor[T]()
260 assertCastableTo(t, false)
261 typFnc := typedUnmarshaler{
262 typ: t,
263 fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
264 val, err := dec.ReadValue()
265 if err != nil {
266 return err
267 }
268 v, _ := reflect.TypeAssert[T](va.castTo(t))
269 err = fn(val, v)
270 if err != nil {
271 err = wrapErrUnsupported(err, "unmarshal function of type func([]byte, T) error")
272 if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
273 return err
274 }
275 err = newUnmarshalErrorAfter(dec, t, err)
276 return collapseSemanticErrors(err)
277 }
278 return nil
279 },
280 }
281 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
282 }
283
284
285
286
287
288
289
290
291
292
293
294
// UnmarshalFromFunc builds an Unmarshalers holding a single function that
// overrides how values of type T are unmarshaled, by reading directly from the
// provided decoder.
//
// T must be an interface type or an unnamed pointer type; any other T causes
// a panic (see assertCastableTo). fn receives a non-nil pointer to the value
// being unmarshaled (castTo passes va.Addr()).
//
// If the decoder is at depth 1 and at end of input, io.EOF is returned
// without calling fn.
//
// fn must read exactly one value at the depth where it starts. If it returns
// nil after reading more or fewer, errNonSingularValue is reported.
//
// fn may decline by returning an error matching errors.ErrUnsupported without
// having read anything. The next applicable function, or else the default
// behavior, is then used. Declining after reading is reported as
// errUnsupportedMutation.
//
// Under legacy error semantics, the decoder is advanced with SkipUntil
// (presumably past the rest of the value fn started). An error from SkipUntil
// takes precedence; otherwise fn's error is returned unwrapped. Without legacy
// semantics, errors that are neither syntactic nor I/O errors are annotated
// with T and the position at which fn began reading.
func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
	t := reflect.TypeFor[T]()
	assertCastableTo(t, false)
	typFnc := typedUnmarshaler{
		typ: t,
		fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
			xd := export.Decoder(dec)
			// Snapshot the decoder position (presumably nesting depth and the
			// count of values at that depth) to check what fn consumed.
			prevDepth, prevLength := xd.Tokens.DepthLength()
			// Report a clean EOF up front rather than through fn.
			if prevDepth == 1 && xd.AtEOF() {
				return io.EOF
			}
			// Presumably marks the decoder as inside a caller-supplied
			// function while fn runs. NOTE(review): the reset below is not
			// deferred, so the flag stays set if fn panics.
			xd.Flags.Set(jsonflags.WithinArshalCall | 1)
			v, _ := reflect.TypeAssert[T](va.castTo(t))
			err := fn(dec, v)
			xd.Flags.Set(jsonflags.WithinArshalCall | 0)
			currDepth, currLength := xd.Tokens.DepthLength()
			// Success requires exactly one new value at the starting depth.
			if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
				err = errNonSingularValue
			}
			if err != nil {
				if errors.Is(err, errors.ErrUnsupported) {
					// Declining is only allowed when nothing was consumed, so
					// the next function can start from the same position.
					if prevDepth == currDepth && prevLength == currLength {
						return err
					}
					err = errUnsupportedMutation
				}
				if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
					if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
						return err2
					}
					return err // unlike the marshal side, never wrapped
				}
				if !isSyntacticError(err) && !export.IsIOError(err) {
					err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
				}
				return err
			}
			return nil
		},
		// fn may return errors.ErrUnsupported to defer to the next function.
		maySkip: true,
	}
	return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
}
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
// assertCastableTo panics unless "to" is acceptable as the type parameter of a
// marshal function (marshal == true) or an unmarshal function (marshal ==
// false).
//
// Rules:
//   - Interface types and unnamed pointer types are always accepted.
//   - Non-pointer types are accepted only for marshaling. Presumably this is
//     because an unmarshal function receiving T by value could not store
//     into it.
//   - Named pointer types are never accepted. castTo produces pointers via
//     Addr, which always have an unnamed type, so a named pointer could never
//     match.
func assertCastableTo(to reflect.Type, marshal bool) {
	var ok bool
	switch to.Kind() {
	case reflect.Interface:
		ok = true
	case reflect.Pointer:
		ok = to.Name() == ""
	default:
		ok = marshal
	}
	if ok {
		return
	}
	if !marshal {
		panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
	}
	panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
}
377
378
379
380
381
// castableTo reports whether a value of type "from" can be handed to a
// function whose parameter type is "to":
//   - interface "to": *from must implement it;
//   - pointer "to": it must be exactly *from;
//   - otherwise: "to" must be identical to "from".
//
// This must stay in sync with addressableValue.castTo, which performs the
// corresponding conversion.
func castableTo(from, to reflect.Type) bool {
	kind := to.Kind()
	if kind == reflect.Interface {
		return reflect.PointerTo(from).Implements(to)
	}
	if kind == reflect.Pointer {
		return to == reflect.PointerTo(from)
	}
	return to == from
}
400
401
402
403
404
405
406 func (va addressableValue) castTo(to reflect.Type) reflect.Value {
407 switch to.Kind() {
408 case reflect.Interface:
409 return va.Addr().Convert(to)
410 case reflect.Pointer:
411 return va.Addr()
412 default:
413 return va.Value
414 }
415 }
416
417
418
419 func castableToFromAny(to reflect.Type) bool {
420 for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
421 if castableTo(from, to) {
422 return true
423 }
424 }
425 return false
426 }
427
View as plain text