1
2
3
4
5
6
7 package json
8
9 import (
10 "errors"
11 "fmt"
12 "reflect"
13 "sync"
14
15 "encoding/json/internal"
16 "encoding/json/internal/jsonflags"
17 "encoding/json/internal/jsonopts"
18 "encoding/json/jsontext"
19 )
20
21
22
23
24
25
26
27
// SkipFunc is a sentinel error that a function registered through
// [MarshalToFunc] or [UnmarshalFromFunc] may return to decline handling
// a value. When it is returned, the next applicable registered function
// is tried, and the default behavior is used if none remain.
//
// A function that returns SkipFunc must not have read or written any tokens
// on the provided encoder or decoder; otherwise an error is reported instead
// of skipping. Functions registered through [MarshalFunc] or [UnmarshalFunc]
// cannot be skipped: returning SkipFunc from them is reported as an error.
var SkipFunc = errors.New("json: skip function")
29
// Errors reported when a function registered through MarshalToFunc or
// UnmarshalFromFunc breaks its contract with the encoder or decoder.
var (
	// errSkipMutation is used when SkipFunc was returned even though
	// the token position changed during the call.
	errSkipMutation = errors.New("must not read or write any tokens when skipping")

	// errNonSingularValue is used when the call returned no error but
	// did not advance the token position by exactly one value.
	errNonSingularValue = errors.New("must read or write exactly one value")
)
32
33
34
35
36
37
// Marshalers is an ordered list of type-specific marshal functions that can
// take precedence over the default marshal behavior for particular Go types.
// Values are constructed with [MarshalFunc] or [MarshalToFunc] and combined
// with [JoinMarshalers].
// A nil *Marshalers is treated the same as an empty list.
type Marshalers = typedMarshalers
39
40
41
42
43
44
45
46
47
48
49
50
// JoinMarshalers flattens the provided lists into a single list of marshal
// functions, keeping the functions in argument order.
// Nil arguments are ignored, and nil is returned when the combined list
// would contain no functions.
//
// Order is significant: for a given value, registered functions are tried
// in list order, and the first applicable one that does not return
// [SkipFunc] determines the result.
func JoinMarshalers(ms ...*Marshalers) *Marshalers {
	return newMarshalers(ms...)
}
54
55
56
57
58
59
// Unmarshalers is an ordered list of type-specific unmarshal functions that
// can take precedence over the default unmarshal behavior for particular
// Go types. Values are constructed with [UnmarshalFunc] or
// [UnmarshalFromFunc] and combined with [JoinUnmarshalers].
// A nil *Unmarshalers is treated the same as an empty list.
type Unmarshalers = typedUnmarshalers
61
62
63
64
65
66
67
68
69
70
71
72
// JoinUnmarshalers flattens the provided lists into a single list of
// unmarshal functions, keeping the functions in argument order.
// Nil arguments are ignored, and nil is returned when the combined list
// would contain no functions.
//
// Order is significant: for a given value, registered functions are tried
// in list order, and the first applicable one that does not return
// [SkipFunc] determines the result.
func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
	return newUnmarshalers(us...)
}
76
// typedMarshalers and typedUnmarshalers instantiate typedArshalers for the
// encoder and the decoder respectively.
type typedMarshalers = typedArshalers[jsontext.Encoder]
type typedUnmarshalers = typedArshalers[jsontext.Decoder]

// typedArshalers is an ordered collection of type-specific marshal or
// unmarshal functions, together with a per-type cache of lookup results.
// Coder is either jsontext.Encoder or jsontext.Decoder.
type typedArshalers[Coder any] struct {
	// NOTE(review): presumably embedded to forbid comparing values with ==;
	// its definition is outside this file section.
	nonComparable

	// fncVals holds the registered functions in priority order.
	fncVals []typedArshaler[Coder]
	// fncCache memoizes lookup results, keyed by reflect.Type. The stored
	// value is the composed function for that type, or nil when no
	// registered function applies to it.
	fncCache sync.Map

	// fromAny reports whether any registered function applies to one of the
	// Go types checked by castableToFromAny (any, bool, string, float64,
	// map[string]any, or []any, judging by the type variable names).
	// NOTE(review): the consumer of this flag is outside this file section;
	// presumably it disables a specialized fast path for those types.
	fromAny bool
}
// typedMarshaler and typedUnmarshaler instantiate typedArshaler for the
// encoder and the decoder respectively.
type typedMarshaler = typedArshaler[jsontext.Encoder]
type typedUnmarshaler = typedArshaler[jsontext.Decoder]

// typedArshaler is a single registered type-specific function.
type typedArshaler[Coder any] struct {
	// typ is the Go type that the user-provided function accepts.
	typ reflect.Type
	// fnc adapts the user-provided function to the internal signature.
	fnc func(*Coder, addressableValue, *jsonopts.Struct) error
	// maySkip reports whether fnc is allowed to return SkipFunc, in which
	// case functions registered after it may still be consulted.
	maySkip bool
}
103
// newMarshalers and newUnmarshalers are the encoder- and decoder-specific
// instantiations of newTypedArshalers.
func newMarshalers(ms ...*Marshalers) *Marshalers       { return newTypedArshalers(ms...) }
func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
106 func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
107 var a typedArshalers[Coder]
108 for _, a2 := range as {
109 if a2 != nil {
110 a.fncVals = append(a.fncVals, a2.fncVals...)
111 a.fromAny = a.fromAny || a2.fromAny
112 }
113 }
114 if len(a.fncVals) == 0 {
115 return nil
116 }
117 return &a
118 }
119
120 func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
121 if a == nil {
122 return fnc, false
123 }
124 if v, ok := a.fncCache.Load(t); ok {
125 if v == nil {
126 return fnc, false
127 }
128 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
129 }
130
131
132
133 var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
134 for _, fncVal := range a.fncVals {
135 if !castableTo(t, fncVal.typ) {
136 continue
137 }
138 fncs = append(fncs, fncVal.fnc)
139 if !fncVal.maySkip {
140 break
141 }
142 }
143
144 if len(fncs) == 0 {
145 a.fncCache.Store(t, nil)
146 return fnc, false
147 }
148
149
150 fncDefault := fnc
151 fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
152 for _, fnc := range fncs {
153 if err := fnc(c, v, o); err != SkipFunc {
154 return err
155 }
156 }
157 return fncDefault(c, v, o)
158 }
159
160
161 v, _ := a.fncCache.LoadOrStore(t, fnc)
162 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
163 }
164
165
166
167
168
169
170
171
172
173
174 func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
175 t := reflect.TypeFor[T]()
176 assertCastableTo(t, true)
177 typFnc := typedMarshaler{
178 typ: t,
179 fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
180 val, err := fn(va.castTo(t).Interface().(T))
181 if err != nil {
182 err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
183 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
184 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
185 }
186 err = newMarshalErrorBefore(enc, t, err)
187 return collapseSemanticErrors(err)
188 }
189 if err := enc.WriteValue(val); err != nil {
190 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
191 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
192 }
193 if isSyntacticError(err) {
194 err = newMarshalErrorBefore(enc, t, err)
195 }
196 return err
197 }
198 return nil
199 },
200 }
201 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
202 }
203
204
205
206
207
208
209
210
211
212
213
214
215
216 func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
217 t := reflect.TypeFor[T]()
218 assertCastableTo(t, true)
219 typFnc := typedMarshaler{
220 typ: t,
221 fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
222 xe := export.Encoder(enc)
223 prevDepth, prevLength := xe.Tokens.DepthLength()
224 xe.Flags.Set(jsonflags.WithinArshalCall | 1)
225 err := fn(enc, va.castTo(t).Interface().(T))
226 xe.Flags.Set(jsonflags.WithinArshalCall | 0)
227 currDepth, currLength := xe.Tokens.DepthLength()
228 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
229 err = errNonSingularValue
230 }
231 if err != nil {
232 if err == SkipFunc {
233 if prevDepth == currDepth && prevLength == currLength {
234 return SkipFunc
235 }
236 err = errSkipMutation
237 }
238 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
239 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc")
240 }
241 if !export.IsIOError(err) {
242 err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
243 }
244 return err
245 }
246 return nil
247 },
248 maySkip: true,
249 }
250 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
251 }
252
253
254
255
256
257
258
259
260
261
262 func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
263 t := reflect.TypeFor[T]()
264 assertCastableTo(t, false)
265 typFnc := typedUnmarshaler{
266 typ: t,
267 fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
268 val, err := dec.ReadValue()
269 if err != nil {
270 return err
271 }
272 err = fn(val, va.castTo(t).Interface().(T))
273 if err != nil {
274 err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
275 if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
276 return err
277 }
278 err = newUnmarshalErrorAfter(dec, t, err)
279 return collapseSemanticErrors(err)
280 }
281 return nil
282 },
283 }
284 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
285 }
286
287
288
289
290
291
292
293
294
295
296
297
298 func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
299 t := reflect.TypeFor[T]()
300 assertCastableTo(t, false)
301 typFnc := typedUnmarshaler{
302 typ: t,
303 fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
304 xd := export.Decoder(dec)
305 prevDepth, prevLength := xd.Tokens.DepthLength()
306 xd.Flags.Set(jsonflags.WithinArshalCall | 1)
307 err := fn(dec, va.castTo(t).Interface().(T))
308 xd.Flags.Set(jsonflags.WithinArshalCall | 0)
309 currDepth, currLength := xd.Tokens.DepthLength()
310 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
311 err = errNonSingularValue
312 }
313 if err != nil {
314 if err == SkipFunc {
315 if prevDepth == currDepth && prevLength == currLength {
316 return SkipFunc
317 }
318 err = errSkipMutation
319 }
320 if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
321 if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
322 return err2
323 }
324 return err
325 }
326 if !isSyntacticError(err) && !export.IsIOError(err) {
327 err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
328 }
329 return err
330 }
331 return nil
332 },
333 maySkip: true,
334 }
335 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
336 }
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
// assertCastableTo panics unless "to" is a type that a type-specific
// marshal or unmarshal function may declare as its input.
//
// Interface types and unnamed pointer types are always accepted.
// Non-pointer, non-interface types are accepted only when marshal is true.
// Named pointer types are never accepted.
func assertCastableTo(to reflect.Type, marshal bool) {
	kind := to.Kind()
	if kind == reflect.Interface {
		return
	}
	// A pointer is acceptable only when its type has no name.
	if kind == reflect.Pointer && to.Name() == "" {
		return
	}
	// Non-pointer types are acceptable for marshaling only.
	if kind != reflect.Pointer && marshal {
		return
	}

	if marshal {
		panic(fmt.Sprintf("input type %v must be an interface type, an unnamed pointer type, or a non-pointer type", to))
	}
	panic(fmt.Sprintf("input type %v must be an interface type or an unnamed pointer type", to))
}
376
377
378
379
380
// castableTo reports whether an addressable value of type "from" can be
// presented as a value of type "to":
//
//   - if "to" is an interface, *from must implement it;
//   - if "to" is a pointer, it must be exactly *from;
//   - otherwise "to" must be identical to "from".
//
// It mirrors the conversions performed by addressableValue.castTo.
func castableTo(from, to reflect.Type) bool {
	kind := to.Kind()
	if kind != reflect.Interface && kind != reflect.Pointer {
		return from == to
	}
	ptrToFrom := reflect.PointerTo(from)
	if kind == reflect.Interface {
		return ptrToFrom.Implements(to)
	}
	return ptrToFrom == to
}
399
400
401
402
403
404
405 func (va addressableValue) castTo(to reflect.Type) reflect.Value {
406 switch to.Kind() {
407 case reflect.Interface:
408 return va.Addr().Convert(to)
409 case reflect.Pointer:
410 return va.Addr()
411 default:
412 return va.Value
413 }
414 }
415
416
417
418 func castableToFromAny(to reflect.Type) bool {
419 for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
420 if castableTo(from, to) {
421 return true
422 }
423 }
424 return false
425 }
426
427 func wrapSkipFunc(err error, what string) error {
428 if err == SkipFunc {
429 return errors.New(what + " cannot be skipped")
430 }
431 return err
432 }
433
View as plain text