1
2
3
4
5
6
7 package json
8
9 import (
10 "errors"
11 "fmt"
12 "reflect"
13 "sync"
14
15 "encoding/json/internal"
16 "encoding/json/internal/jsonflags"
17 "encoding/json/internal/jsonopts"
18 "encoding/json/jsontext"
19 )
20
21
22
23
24
25
26
27
var (
	// SkipFunc may be returned by functions built with [MarshalToFunc] or
	// [UnmarshalFromFunc] to decline handling a value. The next applicable
	// function in the list is then tried, and if none remain the default
	// behavior is used.
	//
	// A function that returns SkipFunc must not have read or written any
	// tokens; if it has, the skip is turned into an error instead.
	// Functions built with [MarshalFunc] or [UnmarshalFunc] cannot skip, and
	// SkipFunc returned from them is reported as an error.
	SkipFunc = errors.New("json: skip function")

	// errSkipMutation replaces SkipFunc when the function returning it had
	// already changed the token stream.
	errSkipMutation = errors.New("must not read or write any tokens when skipping")

	// errNonSingularValue is reported when a stream-based function succeeds
	// without ending exactly one value past where it started, at the same
	// nesting depth.
	errNonSingularValue = errors.New("must read or write exactly one value")
)
32
33
34
35
36
37
38 type Marshalers = typedMarshalers
39
40
41
42
43
44
45
46
47
48
49
50
51 func JoinMarshalers(ms ...*Marshalers) *Marshalers {
52 return newMarshalers(ms...)
53 }
54
55
56
57
58
59
60 type Unmarshalers = typedUnmarshalers
61
62
63
64
65
66
67
68
69
70
71
72
73 func JoinUnmarshalers(us ...*Unmarshalers) *Unmarshalers {
74 return newUnmarshalers(us...)
75 }
76
77 type typedMarshalers = typedArshalers[jsontext.Encoder]
78 type typedUnmarshalers = typedArshalers[jsontext.Decoder]
79 type typedArshalers[Coder any] struct {
80 nonComparable
81
82 fncVals []typedArshaler[Coder]
83 fncCache sync.Map
84
85
86
87
88
89
90
91
92
93
94 fromAny bool
95 }
96 type typedMarshaler = typedArshaler[jsontext.Encoder]
97 type typedUnmarshaler = typedArshaler[jsontext.Decoder]
98 type typedArshaler[Coder any] struct {
99 typ reflect.Type
100 fnc func(*Coder, addressableValue, *jsonopts.Struct) error
101 maySkip bool
102 }
103
104 func newMarshalers(ms ...*Marshalers) *Marshalers { return newTypedArshalers(ms...) }
105 func newUnmarshalers(us ...*Unmarshalers) *Unmarshalers { return newTypedArshalers(us...) }
106 func newTypedArshalers[Coder any](as ...*typedArshalers[Coder]) *typedArshalers[Coder] {
107 var a typedArshalers[Coder]
108 for _, a2 := range as {
109 if a2 != nil {
110 a.fncVals = append(a.fncVals, a2.fncVals...)
111 a.fromAny = a.fromAny || a2.fromAny
112 }
113 }
114 if len(a.fncVals) == 0 {
115 return nil
116 }
117 return &a
118 }
119
120 func (a *typedArshalers[Coder]) lookup(fnc func(*Coder, addressableValue, *jsonopts.Struct) error, t reflect.Type) (func(*Coder, addressableValue, *jsonopts.Struct) error, bool) {
121 if a == nil {
122 return fnc, false
123 }
124 if v, ok := a.fncCache.Load(t); ok {
125 if v == nil {
126 return fnc, false
127 }
128 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
129 }
130
131
132
133 var fncs []func(*Coder, addressableValue, *jsonopts.Struct) error
134 for _, fncVal := range a.fncVals {
135 if !castableTo(t, fncVal.typ) {
136 continue
137 }
138 fncs = append(fncs, fncVal.fnc)
139 if !fncVal.maySkip {
140 break
141 }
142 }
143
144 if len(fncs) == 0 {
145 a.fncCache.Store(t, nil)
146 return fnc, false
147 }
148
149
150 fncDefault := fnc
151 fnc = func(c *Coder, v addressableValue, o *jsonopts.Struct) error {
152 for _, fnc := range fncs {
153 if err := fnc(c, v, o); err != SkipFunc {
154 return err
155 }
156 }
157 return fncDefault(c, v, o)
158 }
159
160
161 v, _ := a.fncCache.LoadOrStore(t, fnc)
162 return v.(func(*Coder, addressableValue, *jsonopts.Struct) error), true
163 }
164
165
166
167
168
169
170
171
172
173
174 func MarshalFunc[T any](fn func(T) ([]byte, error)) *Marshalers {
175 t := reflect.TypeFor[T]()
176 assertCastableTo(t, true)
177 typFnc := typedMarshaler{
178 typ: t,
179 fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
180 v, _ := reflect.TypeAssert[T](va.castTo(t))
181 val, err := fn(v)
182 if err != nil {
183 err = wrapSkipFunc(err, "marshal function of type func(T) ([]byte, error)")
184 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
185 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
186 }
187 err = newMarshalErrorBefore(enc, t, err)
188 return collapseSemanticErrors(err)
189 }
190 if err := enc.WriteValue(val); err != nil {
191 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
192 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalFunc")
193 }
194 if isSyntacticError(err) {
195 err = newMarshalErrorBefore(enc, t, err)
196 }
197 return err
198 }
199 return nil
200 },
201 }
202 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
203 }
204
205
206
207
208
209
210
211
212
213
214
215
216
217 func MarshalToFunc[T any](fn func(*jsontext.Encoder, T) error) *Marshalers {
218 t := reflect.TypeFor[T]()
219 assertCastableTo(t, true)
220 typFnc := typedMarshaler{
221 typ: t,
222 fnc: func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
223 xe := export.Encoder(enc)
224 prevDepth, prevLength := xe.Tokens.DepthLength()
225 xe.Flags.Set(jsonflags.WithinArshalCall | 1)
226 v, _ := reflect.TypeAssert[T](va.castTo(t))
227 err := fn(enc, v)
228 xe.Flags.Set(jsonflags.WithinArshalCall | 0)
229 currDepth, currLength := xe.Tokens.DepthLength()
230 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
231 err = errNonSingularValue
232 }
233 if err != nil {
234 if err == SkipFunc {
235 if prevDepth == currDepth && prevLength == currLength {
236 return SkipFunc
237 }
238 err = errSkipMutation
239 }
240 if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
241 return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalToFunc")
242 }
243 if !export.IsIOError(err) {
244 err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
245 }
246 return err
247 }
248 return nil
249 },
250 maySkip: true,
251 }
252 return &Marshalers{fncVals: []typedMarshaler{typFnc}, fromAny: castableToFromAny(t)}
253 }
254
255
256
257
258
259
260
261
262
263
264 func UnmarshalFunc[T any](fn func([]byte, T) error) *Unmarshalers {
265 t := reflect.TypeFor[T]()
266 assertCastableTo(t, false)
267 typFnc := typedUnmarshaler{
268 typ: t,
269 fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
270 val, err := dec.ReadValue()
271 if err != nil {
272 return err
273 }
274 v, _ := reflect.TypeAssert[T](va.castTo(t))
275 err = fn(val, v)
276 if err != nil {
277 err = wrapSkipFunc(err, "unmarshal function of type func([]byte, T) error")
278 if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
279 return err
280 }
281 err = newUnmarshalErrorAfter(dec, t, err)
282 return collapseSemanticErrors(err)
283 }
284 return nil
285 },
286 }
287 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
288 }
289
290
291
292
293
294
295
296
297
298
299
300
301 func UnmarshalFromFunc[T any](fn func(*jsontext.Decoder, T) error) *Unmarshalers {
302 t := reflect.TypeFor[T]()
303 assertCastableTo(t, false)
304 typFnc := typedUnmarshaler{
305 typ: t,
306 fnc: func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
307 xd := export.Decoder(dec)
308 prevDepth, prevLength := xd.Tokens.DepthLength()
309 xd.Flags.Set(jsonflags.WithinArshalCall | 1)
310 v, _ := reflect.TypeAssert[T](va.castTo(t))
311 err := fn(dec, v)
312 xd.Flags.Set(jsonflags.WithinArshalCall | 0)
313 currDepth, currLength := xd.Tokens.DepthLength()
314 if err == nil && (prevDepth != currDepth || prevLength+1 != currLength) {
315 err = errNonSingularValue
316 }
317 if err != nil {
318 if err == SkipFunc {
319 if prevDepth == currDepth && prevLength == currLength {
320 return SkipFunc
321 }
322 err = errSkipMutation
323 }
324 if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
325 if err2 := xd.SkipUntil(prevDepth, prevLength+1); err2 != nil {
326 return err2
327 }
328 return err
329 }
330 if !isSyntacticError(err) && !export.IsIOError(err) {
331 err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
332 }
333 return err
334 }
335 return nil
336 },
337 maySkip: true,
338 }
339 return &Unmarshalers{fncVals: []typedUnmarshaler{typFnc}, fromAny: castableToFromAny(t)}
340 }
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
// assertCastableTo panics unless "to" is a type a user function may be
// registered for.
//
// Interface types and unnamed pointer types are always accepted. Only
// unnamed pointers are allowed because taking the address of a value yields
// an unnamed pointer type. Other non-pointer types are accepted only when
// marshal is true; for unmarshal such a receiver could not modify the
// destination.
func assertCastableTo(to reflect.Type, marshal bool) {
	allowed := marshal // the default for non-pointer, non-interface kinds
	switch to.Kind() {
	case reflect.Interface:
		allowed = true
	case reflect.Pointer:
		allowed = to.Name() == ""
	}
	if allowed {
		return
	}

	format := "input type %v must be an interface type or an unnamed pointer type"
	if marshal {
		format = "input type %v must be an interface type, an unnamed pointer type, or a non-pointer type"
	}
	panic(fmt.Sprintf(format, to))
}
380
381
382
383
384
// castableTo reports whether a value of type "from" can be handed to a
// function registered for type "to":
//   - interface "to": *from must implement it. The method set of *from
//     includes the methods of from, so value receivers also qualify.
//   - pointer "to": it must be exactly *from (the usual unmarshal case).
//   - any other "to": it must be exactly from (the usual marshal case).
func castableTo(from, to reflect.Type) bool {
	kind := to.Kind()
	if kind == reflect.Interface {
		return reflect.PointerTo(from).Implements(to)
	}
	if kind == reflect.Pointer {
		return reflect.PointerTo(from) == to
	}
	return from == to
}
403
404
405
406
407
408
409 func (va addressableValue) castTo(to reflect.Type) reflect.Value {
410 switch to.Kind() {
411 case reflect.Interface:
412 return va.Addr().Convert(to)
413 case reflect.Pointer:
414 return va.Addr()
415 default:
416 return va.Value
417 }
418 }
419
420
421
422 func castableToFromAny(to reflect.Type) bool {
423 for _, from := range []reflect.Type{anyType, boolType, stringType, float64Type, mapStringAnyType, sliceAnyType} {
424 if castableTo(from, to) {
425 return true
426 }
427 }
428 return false
429 }
430
431 func wrapSkipFunc(err error, what string) error {
432 if err == SkipFunc {
433 return errors.New(what + " cannot be skipped")
434 }
435 return err
436 }
437
View as plain text