1
2
3
4
5
6
7 package json
8
9 import (
10 "errors"
11 "fmt"
12 "io"
13 "reflect"
14 "sync"
15
16 "encoding/json/internal"
17 "encoding/json/internal/jsonflags"
18 "encoding/json/internal/jsonopts"
19 "encoding/json/jsontext"
20 )
21
22
23
24
25
26
27
28
// Sentinel errors used by the user-provided marshal and unmarshal
// function wrappers in this file.
var (
	// SkipFunc may be returned by functions passed to MarshalToFunc and
	// UnmarshalFromFunc to decline handling a value, in which case the next
	// matching function (or the default behavior) is tried instead.
	// It is only honored if the function did not change the token state of
	// the encoder or decoder; otherwise an error is reported.
	// Functions passed to MarshalFunc and UnmarshalFunc cannot skip:
	// returning SkipFunc from them is converted into an error.
	SkipFunc = errors.New("json: skip function")

	// errSkipMutation is reported when a function returns SkipFunc
	// after having already read or written tokens.
	errSkipMutation = errors.New("must not read or write any tokens when skipping")

	// errNonSingularValue is reported when a function returns successfully
	// without having read or written exactly one JSON value.
	errNonSingularValue = errors.New("must read or write exactly one value")
)
33
34
35
36
37
38
39 type Marshalers = typedMarshalers
40
41
42
43
44
45
46
47
48
49
50
51
52 func JoinMarshalers(ms ...*Marshalers) *Marshalers {
53 return newMarshalers(ms...)
54 }
55
56
57
58
59
60
61 type Unmarshalers = typedUnmarshalers
62
63
64
65
66
67
68
69
70
71
72
73
74 func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
75 return newUnmarshalers(us...)
76 }
77
78 type typedMarshalers = typedArshalers[jsontext.Encoder]
79 type typedUnmarshalers = typedArshalers[jsontext.Decoder]
80 type typedArshalers[Coder any] struct {
81 nonComparable
82
83 fncVals []typedArshaler[Coder]
84 fncCache sync.Map
85
86
87
88
89
90
91
92
93
94
95 fromAny bool
96 }
97 type typedMarshaler = typedArshaler[jsontext.Encoder]
98 type typedUnmarshaler = typedArshaler[jsontext.Decoder]
99 type typedArshaler[Coder any] struct {
100 typ reflect.Type
101 fnc func(*Coder, addressableValue, *jsonopts.Struct) error
102 maySkip bool
103 }
104
105 func newMarshalers(ms ...*Marshalers) *Marshalers { return newTypedArshalers(ms...) }
106 func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
107 func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
108 var a typedArshalers[Coder]
109 for _, a2 := range as {
110 if a2 != nil {
111 a.fncVals = append(a.fncVals, a2.fncVals...)
112 a.fromAny = a.fromAny || a2.fromAny
113 }
114 }
115 if len(a.fncVals) == 0 {
116 return nil
117 }
118 return &a
119 }
120
121 func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
122 if a == nil {
123 return fnc, false
124 }
125 if v, ok := a.fncCache.Load(t); ok {
126 if v == nil {
127 return fnc, false
128 }
129 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
130 }
131
132
133
134 var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
135 for _, fncVal := range a.fncVals {
136 if !castableTo(t, fncVal.typ) {
137 continue
138 }
139 fncs = append(fncs, fncVal.fnc)
140 if !fncVal.maySkip {
141 break
142 }
143 }
144
145 if len(fncs) == 0 {
146 a.fncCache.Store(t, nil)
147 return fnc, false
148 }
149
150
151 fncDefault := fnc
152 fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
153 for _, fnc := range fncs {
154 if err := fnc(c, v, o); err != SkipFunc {
155 return err
156 }
157 }
158 return fncDefault(c, v, o)
159 }
160
161
162 v, _ := a.fncCache.LoadOrStore(t, fnc)
163 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
164 }
165
166
167
168
169
170
171
172
173
174
175 func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
176 t := reflect.TypeFor[T]()
177 assertCastableTo(t, true)
178 typFnc := typedMarshaler{
179 typ: t,
180 fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
181 v, _ := reflect.TypeAssert[T](va.castTo(t))
182 val, err := fn(v)
183 if err != nil {
184 err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
185 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
186 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
187 }
188 err = newMarshalErrorBefore(enc, t, err)
189 return collapseSemanticErrors(err)
190 }
191 if err := enc.WriteValue(val); err != nil {
192 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
193 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
194 }
195 if isSyntacticError(err) {
196 err = newMarshalErrorBefore(enc, t, err)
197 }
198 return err
199 }
200 return nil
201 },
202 }
203 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
204 }
205
206
207
208
209
210
211
212
213
214
215
216
217
218 func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
219 t := reflect.TypeFor[T]()
220 assertCastableTo(t, true)
221 typFnc := typedMarshaler{
222 typ: t,
223 fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
224 xe := export.Encoder(enc)
225 prevDepth, prevLength := xe.Tokens.DepthLength()
226 xe.Flags.Set(jsonflags.WithinArshalCall | 1)
227 v, _ := reflect.TypeAssert[T](va.castTo(t))
228 err := fn(enc, v)
229 xe.Flags.Set(jsonflags.WithinArshalCall | 0)
230 currDepth, currLength := xe.Tokens.DepthLength()
231 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
232 err = errNonSingularValue
233 }
234 if err != nil {
235 if err == SkipFunc {
236 if prevDepth == currDepth && prevLength == currLength {
237 return SkipFunc
238 }
239 err = errSkipMutation
240 }
241 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
242 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc")
243 }
244 if !export.IsIOError(err) {
245 err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
246 }
247 return err
248 }
249 return nil
250 },
251 maySkip: true,
252 }
253 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
254 }
255
256
257
258
259
260
261
262
263
264
265 func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
266 t := reflect.TypeFor[T]()
267 assertCastableTo(t, false)
268 typFnc := typedUnmarshaler{
269 typ: t,
270 fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
271 val, err := dec.ReadValue()
272 if err != nil {
273 return err
274 }
275 v, _ := reflect.TypeAssert[T](va.castTo(t))
276 err = fn(val, v)
277 if err != nil {
278 err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
279 if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
280 return err
281 }
282 err = newUnmarshalErrorAfter(dec, t, err)
283 return collapseSemanticErrors(err)
284 }
285 return nil
286 },
287 }
288 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
289 }
290
291
292
293
294
295
296
297
298
299
300
301
302 func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
303 t := reflect.TypeFor[T]()
304 assertCastableTo(t, false)
305 typFnc := typedUnmarshaler{
306 typ: t,
307 fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
308 xd := export.Decoder(dec)
309 prevDepth, prevLength := xd.Tokens.DepthLength()
310 if prevDepth == 1 && xd.AtEOF() {
311 return io.EOF
312 }
313 xd.Flags.Set(jsonflags.WithinArshalCall | 1)
314 v, _ := reflect.TypeAssert[T](va.castTo(t))
315 err := fn(dec, v)
316 xd.Flags.Set(jsonflags.WithinArshalCall | 0)
317 currDepth, currLength := xd.Tokens.DepthLength()
318 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
319 err = errNonSingularValue
320 }
321 if err != nil {
322 if err == SkipFunc {
323 if prevDepth == currDepth && prevLength == currLength {
324 return SkipFunc
325 }
326 err = errSkipMutation
327 }
328 if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
329 if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
330 return err2
331 }
332 return err
333 }
334 if !isSyntacticError(err) && !export.IsIOError(err) {
335 err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
336 }
337 return err
338 }
339 return nil
340 },
341 maySkip: true,
342 }
343 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
344 }
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
// assertCastableTo panics unless to is an acceptable type parameter for a
// user-provided function.
//
// Interface types and unnamed pointer types are always accepted.
// Non-pointer types are accepted only when marshal is true.
// Named pointer types are never accepted.
func assertCastableTo(to reflect.Type, marshal bool) {
	switch kind := to.Kind(); {
	case kind == reflect.Interface:
		return
	case kind == reflect.Pointer && to.Name() == "":
		// Unnamed pointer type.
		return
	case kind != reflect.Pointer && marshal:
		// Non-pointer type, which is tolerated for marshaling only.
		return
	}
	if marshal {
		panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
	}
	panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
}
384
385
386
387
388
// castableTo reports whether an addressable value of type from can be
// handed to a function registered for type to:
//   - if to is an interface, a pointer to from must implement it;
//   - if to is a pointer, it must be exactly the pointer-to-from type;
//   - otherwise, from and to must be the identical type.
func castableTo(from, to reflect.Type) bool {
	kind := to.Kind()
	if kind == reflect.Interface {
		return reflect.PointerTo(from).Implements(to)
	}
	candidate := from
	if kind == reflect.Pointer {
		candidate = reflect.PointerTo(from)
	}
	return candidate == to
}
407
408
409
410
411
412
413 func (va addressableValue) castTo(to reflect.Type) reflect.Value {
414 switch to.Kind() {
415 case reflect.Interface:
416 return va.Addr().Convert(to)
417 case reflect.Pointer:
418 return va.Addr()
419 default:
420 return va.Value
421 }
422 }
423
424
425
426 func castableToFromAny(to reflect.Type) bool {
427 for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
428 if castableTo(from, to) {
429 return true
430 }
431 }
432 return false
433 }
434
435 func wrapSkipFunc(err error, what string) error {
436 if err == SkipFunc {
437 return errors.New(what + " cannot be skipped")
438 }
439 return err
440 }
441
View as plain text