Source file src/encoding/json/v2/arshal_time.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.jsonv2
     6  
     7  package json
     8  
     9  import (
    10  	"bytes"
    11  	"cmp"
    12  	"errors"
    13  	"fmt"
    14  	"math"
    15  	"math/bits"
    16  	"reflect"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  
    21  	"encoding/json/internal"
    22  	"encoding/json/internal/jsonflags"
    23  	"encoding/json/internal/jsonopts"
    24  	"encoding/json/internal/jsonwire"
    25  	"encoding/json/jsontext"
    26  )
    27  
// Reflected types of the two time types that makeTimeArshaler special-cases.
// They are computed once at package initialization so that the type switch
// in makeTimeArshaler can compare against them directly.
var (
	timeDurationType = reflect.TypeFor[time.Duration]()
	timeTimeType     = reflect.TypeFor[time.Time]()
)
    32  
    33  func makeTimeArshaler(fncs *arshaler, t reflect.Type) *arshaler {
    34  	// Ideally, time types would implement MarshalerTo and UnmarshalerFrom,
    35  	// but that would incur a dependency on package json from package time.
    36  	// Given how widely used time is, it is more acceptable that we incur a
    37  	// dependency on time from json.
    38  	//
    39  	// Injecting the arshaling functionality like this will not be identical
    40  	// to actually declaring methods on the time types since embedding of the
    41  	// time types will not be able to forward this functionality.
    42  	switch t {
    43  	case timeDurationType:
    44  		fncs.nonDefault = true
    45  		marshalNano := fncs.marshal
    46  		fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
    47  			xe := export.Encoder(enc)
    48  			var m durationArshaler
    49  			if mo.Flags.Has(jsonflags.FormatTag) {
    50  				if !m.initFormat(mo.Format) {
    51  					return newInvalidFormatError(enc, t)
    52  				}
    53  			} else if mo.Flags.Get(jsonflags.FormatDurationAsNano) {
    54  				return marshalNano(enc, va, mo)
    55  			} else {
    56  				// TODO(https://go.dev/issue/71631): Decide on default duration representation.
    57  				var workaround string
    58  				if mo.Flags.Get(jsonflags.FormatTagSupported) {
    59  					workaround = "; specify an explicit format"
    60  				}
    61  				return newMarshalErrorBefore(enc, t, errors.New("no default representation"+workaround))
    62  			}
    63  			if mo.Flags.Get(jsonflags.StringTag) && !m.isNumeric() && !mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
    64  				return newMarshalErrorBefore(enc, t, errInvalidStringTag)
    65  			}
    66  
    67  			m.td, _ = reflect.TypeAssert[time.Duration](va.Value)
    68  			k := stringOrNumberKind(!m.isNumeric() || xe.Tokens.Last.NeedObjectName() || mo.Flags.Get(jsonflags.StringifyNumbers|jsonflags.StringTag))
    69  			if err := xe.AppendRaw(k, true, m.appendMarshal); err != nil {
    70  				if !isSyntacticError(err) && !export.IsIOError(err) {
    71  					err = newMarshalErrorBefore(enc, t, err)
    72  				}
    73  				return err
    74  			}
    75  			return nil
    76  		}
    77  		unmarshalNano := fncs.unmarshal
    78  		fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
    79  			xd := export.Decoder(dec)
    80  			var u durationArshaler
    81  			if uo.Flags.Has(jsonflags.FormatTag) {
    82  				if !u.initFormat(uo.Format) {
    83  					return newInvalidFormatError(dec, t)
    84  				}
    85  			} else if uo.Flags.Get(jsonflags.FormatDurationAsNano) {
    86  				return unmarshalNano(dec, va, uo)
    87  			} else {
    88  				// TODO(https://go.dev/issue/71631): Decide on default duration representation.
    89  				var workaround string
    90  				if uo.Flags.Get(jsonflags.FormatTagSupported) {
    91  					workaround = "; specify an explicit format"
    92  				}
    93  				return newUnmarshalErrorBeforeWithSkipping(dec, t, errors.New("no default representation"+workaround))
    94  			}
    95  			if uo.Flags.Get(jsonflags.StringTag) && !u.isNumeric() && !uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
    96  				return newUnmarshalErrorBeforeWithSkipping(dec, t, errInvalidStringTag)
    97  			}
    98  
    99  			stringify := !u.isNumeric() || xd.Tokens.Last.NeedObjectName() || uo.Flags.Get(jsonflags.StringifyNumbers|jsonflags.StringTag)
   100  			var flags jsonwire.ValueFlags
   101  			td, _ := reflect.TypeAssert[*time.Duration](va.Addr())
   102  			val, err := xd.ReadValue(&flags)
   103  			if err != nil {
   104  				return err
   105  			}
   106  			switch k := val.Kind(); k {
   107  			case 'n':
   108  				if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
   109  					*td = time.Duration(0)
   110  				}
   111  				return nil
   112  			case '"':
   113  				if !stringify {
   114  					break
   115  				}
   116  				val = jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
   117  				if err := u.unmarshal(val); err != nil {
   118  					return newUnmarshalErrorAfter(dec, t, err)
   119  				}
   120  				*td = u.td
   121  				return nil
   122  			case '0':
   123  				if stringify {
   124  					break
   125  				}
   126  				if err := u.unmarshal(val); err != nil {
   127  					return newUnmarshalErrorAfter(dec, t, err)
   128  				}
   129  				*td = u.td
   130  				return nil
   131  			}
   132  			return newUnmarshalErrorAfter(dec, t, nil)
   133  		}
   134  	case timeTimeType:
   135  		fncs.nonDefault = true
   136  		fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) (err error) {
   137  			xe := export.Encoder(enc)
   138  			var m timeArshaler
   139  			if mo.Flags.Has(jsonflags.FormatTag) {
   140  				if !m.initFormat(mo.Format) {
   141  					return newInvalidFormatError(enc, t)
   142  				}
   143  			}
   144  			if mo.Flags.Get(jsonflags.StringTag) && !m.isNumeric() && !mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   145  				return newMarshalErrorBefore(enc, t, errInvalidStringTag)
   146  			}
   147  
   148  			m.tt, _ = reflect.TypeAssert[time.Time](va.Value)
   149  			k := stringOrNumberKind(!m.isNumeric() || xe.Tokens.Last.NeedObjectName() || mo.Flags.Get(jsonflags.StringifyNumbers|jsonflags.StringTag))
   150  			if err := xe.AppendRaw(k, !m.hasCustomFormat(), m.appendMarshal); err != nil {
   151  				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   152  					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalJSON") // unlike unmarshal, always wrapped
   153  				}
   154  				if !isSyntacticError(err) && !export.IsIOError(err) {
   155  					err = newMarshalErrorBefore(enc, t, err)
   156  				}
   157  				return err
   158  			}
   159  			return nil
   160  		}
   161  		fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) (err error) {
   162  			xd := export.Decoder(dec)
   163  			var u timeArshaler
   164  			if uo.Flags.Has(jsonflags.FormatTag) {
   165  				if !u.initFormat(uo.Format) {
   166  					return newInvalidFormatError(dec, t)
   167  				}
   168  			} else if uo.Flags.Get(jsonflags.ParseTimeWithLooseRFC3339) {
   169  				u.looseRFC3339 = true
   170  			}
   171  			if uo.Flags.Get(jsonflags.StringTag) && !u.isNumeric() && !uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   172  				return newUnmarshalErrorBeforeWithSkipping(dec, t, errInvalidStringTag)
   173  			}
   174  
   175  			stringify := !u.isNumeric() || xd.Tokens.Last.NeedObjectName() || uo.Flags.Get(jsonflags.StringifyNumbers|jsonflags.StringTag)
   176  			var flags jsonwire.ValueFlags
   177  			tt, _ := reflect.TypeAssert[*time.Time](va.Addr())
   178  			val, err := xd.ReadValue(&flags)
   179  			if err != nil {
   180  				return err
   181  			}
   182  			switch k := val.Kind(); k {
   183  			case 'n':
   184  				if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
   185  					*tt = time.Time{}
   186  				}
   187  				return nil
   188  			case '"':
   189  				if !stringify {
   190  					break
   191  				}
   192  				val = jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
   193  				if err := u.unmarshal(val); err != nil {
   194  					if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   195  						return err // unlike marshal, never wrapped
   196  					}
   197  					return newUnmarshalErrorAfter(dec, t, err)
   198  				}
   199  				*tt = u.tt
   200  				return nil
   201  			case '0':
   202  				if stringify {
   203  					break
   204  				}
   205  				if err := u.unmarshal(val); err != nil {
   206  					if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
   207  						return err // unlike marshal, never wrapped
   208  					}
   209  					return newUnmarshalErrorAfter(dec, t, err)
   210  				}
   211  				*tt = u.tt
   212  				return nil
   213  			}
   214  			return newUnmarshalErrorAfter(dec, t, nil)
   215  		}
   216  	}
   217  	return fncs
   218  }
   219  
// durationArshaler holds a time.Duration together with the representation
// that its methods use to format and parse it.
type durationArshaler struct {
	// td is the value to marshal, or the result of the last unmarshal.
	td time.Duration

	// base records the representation where:
	//   - 0 uses time.Duration.String
	//   - 1e0, 1e3, 1e6, or 1e9 use a decimal encoding of the duration as
	//     nanoseconds, microseconds, milliseconds, or seconds.
	//   - 8601 uses ISO 8601
	base uint64
}
   230  
   231  func (a *durationArshaler) initFormat(format string) (ok bool) {
   232  	switch format {
   233  	case "units":
   234  		a.base = 0
   235  	case "sec":
   236  		a.base = 1e9
   237  	case "milli":
   238  		a.base = 1e6
   239  	case "micro":
   240  		a.base = 1e3
   241  	case "nano":
   242  		a.base = 1e0
   243  	case "iso8601":
   244  		a.base = 8601
   245  	default:
   246  		return false
   247  	}
   248  	return true
   249  }
   250  
   251  func (a *durationArshaler) isNumeric() bool {
   252  	return a.base != 0 && a.base != 8601
   253  }
   254  
   255  func (a *durationArshaler) appendMarshal(b []byte) ([]byte, error) {
   256  	switch a.base {
   257  	case 0:
   258  		return append(b, a.td.String()...), nil
   259  	case 8601:
   260  		return appendDurationISO8601(b, a.td), nil
   261  	default:
   262  		return appendDurationBase10(b, a.td, a.base), nil
   263  	}
   264  }
   265  
   266  func (a *durationArshaler) unmarshal(b []byte) (err error) {
   267  	switch a.base {
   268  	case 0:
   269  		a.td, err = time.ParseDuration(string(b))
   270  	case 8601:
   271  		a.td, err = parseDurationISO8601(b)
   272  	default:
   273  		a.td, err = parseDurationBase10(b, a.base)
   274  	}
   275  	return err
   276  }
   277  
// timeArshaler holds a time.Time together with the representation
// that its methods use to format and parse it.
type timeArshaler struct {
	// tt is the value to marshal, or the result of the last unmarshal.
	tt time.Time

	// base records the representation where:
	//   - 0 uses RFC 3339 encoding of the timestamp
	//   - 1e0, 1e3, 1e6, or 1e9 use a decimal encoding of the timestamp as
	//     seconds, milliseconds, microseconds, or nanoseconds since Unix epoch.
	//   - math.MaxUint uses time.Time.Format to encode the timestamp
	base   uint64
	format string // time layout passed to time.Time.AppendFormat and time.Parse

	// looseRFC3339 disables the extra strictness checks that unmarshal
	// applies on top of time.Time.UnmarshalText for RFC 3339 input.
	looseRFC3339 bool
}
   291  
   292  func (a *timeArshaler) initFormat(format string) bool {
   293  	// We assume that an exported constant in the time package will
   294  	// always start with an uppercase ASCII letter.
   295  	if len(format) == 0 {
   296  		return false
   297  	}
   298  	a.base = math.MaxUint // implies custom format
   299  	if c := format[0]; !('a' <= c && c <= 'z') && !('A' <= c && c <= 'Z') {
   300  		a.format = format
   301  		return true
   302  	}
   303  	switch format {
   304  	case "ANSIC":
   305  		a.format = time.ANSIC
   306  	case "UnixDate":
   307  		a.format = time.UnixDate
   308  	case "RubyDate":
   309  		a.format = time.RubyDate
   310  	case "RFC822":
   311  		a.format = time.RFC822
   312  	case "RFC822Z":
   313  		a.format = time.RFC822Z
   314  	case "RFC850":
   315  		a.format = time.RFC850
   316  	case "RFC1123":
   317  		a.format = time.RFC1123
   318  	case "RFC1123Z":
   319  		a.format = time.RFC1123Z
   320  	case "RFC3339":
   321  		a.base = 0
   322  		a.format = time.RFC3339
   323  	case "RFC3339Nano":
   324  		a.base = 0
   325  		a.format = time.RFC3339Nano
   326  	case "Kitchen":
   327  		a.format = time.Kitchen
   328  	case "Stamp":
   329  		a.format = time.Stamp
   330  	case "StampMilli":
   331  		a.format = time.StampMilli
   332  	case "StampMicro":
   333  		a.format = time.StampMicro
   334  	case "StampNano":
   335  		a.format = time.StampNano
   336  	case "DateTime":
   337  		a.format = time.DateTime
   338  	case "DateOnly":
   339  		a.format = time.DateOnly
   340  	case "TimeOnly":
   341  		a.format = time.TimeOnly
   342  	case "unix":
   343  		a.base = 1e0
   344  	case "unixmilli":
   345  		a.base = 1e3
   346  	case "unixmicro":
   347  		a.base = 1e6
   348  	case "unixnano":
   349  		a.base = 1e9
   350  	default:
   351  		// Reject any Go identifier in case new constants are supported.
   352  		if strings.TrimFunc(format, isLetterOrDigit) == "" {
   353  			return false
   354  		}
   355  		a.format = format
   356  	}
   357  	return true
   358  }
   359  
   360  func (a *timeArshaler) isNumeric() bool {
   361  	return int(a.base) > 0
   362  }
   363  
   364  func (a *timeArshaler) hasCustomFormat() bool {
   365  	return a.base == math.MaxUint
   366  }
   367  
   368  func (a *timeArshaler) appendMarshal(b []byte) ([]byte, error) {
   369  	switch a.base {
   370  	case 0:
   371  		format := cmp.Or(a.format, time.RFC3339Nano)
   372  		n0 := len(b)
   373  		b = a.tt.AppendFormat(b, format)
   374  		// Not all Go timestamps can be represented as valid RFC 3339.
   375  		// Explicitly check for these edge cases.
   376  		// See https://go.dev/issue/4556 and https://go.dev/issue/54580.
   377  		switch b := b[n0:]; {
   378  		case b[len("9999")] != '-': // year must be exactly 4 digits wide
   379  			return b, errors.New("year outside of range [0,9999]")
   380  		case b[len(b)-1] != 'Z':
   381  			c := b[len(b)-len("Z07:00")]
   382  			if ('0' <= c && c <= '9') || parseDec2(b[len(b)-len("07:00"):]) >= 24 {
   383  				return b, errors.New("timezone hour outside of range [0,23]")
   384  			}
   385  		}
   386  		return b, nil
   387  	case math.MaxUint:
   388  		return a.tt.AppendFormat(b, a.format), nil
   389  	default:
   390  		return appendTimeUnix(b, a.tt, a.base), nil
   391  	}
   392  }
   393  
   394  func (a *timeArshaler) unmarshal(b []byte) (err error) {
   395  	switch a.base {
   396  	case 0:
   397  		// Use time.Time.UnmarshalText to avoid possible string allocation.
   398  		if err := a.tt.UnmarshalText(b); err != nil {
   399  			return err
   400  		}
   401  		// TODO(https://go.dev/issue/57912):
   402  		// RFC 3339 specifies the grammar for a valid timestamp.
   403  		// However, the parsing functionality in "time" is too loose and
   404  		// incorrectly accepts invalid timestamps as valid.
   405  		// Remove these manual checks when "time" checks it for us.
   406  		newParseError := func(layout, value, layoutElem, valueElem, message string) error {
   407  			return &time.ParseError{Layout: layout, Value: value, LayoutElem: layoutElem, ValueElem: valueElem, Message: message}
   408  		}
   409  		switch {
   410  		case a.looseRFC3339:
   411  			return nil
   412  		case b[len("2006-01-02T")+1] == ':': // hour must be two digits
   413  			return newParseError(time.RFC3339, string(b), "15", string(b[len("2006-01-02T"):][:1]), "")
   414  		case b[len("2006-01-02T15:04:05")] == ',': // sub-second separator must be a period
   415  			return newParseError(time.RFC3339, string(b), ".", ",", "")
   416  		case b[len(b)-1] != 'Z':
   417  			switch {
   418  			case parseDec2(b[len(b)-len("07:00"):]) >= 24: // timezone hour must be in range
   419  				return newParseError(time.RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone hour out of range")
   420  			case parseDec2(b[len(b)-len("00"):]) >= 60: // timezone minute must be in range
   421  				return newParseError(time.RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone minute out of range")
   422  			}
   423  		}
   424  		return nil
   425  	case math.MaxUint:
   426  		a.tt, err = time.Parse(a.format, string(b))
   427  		return err
   428  	default:
   429  		a.tt, err = parseTimeUnix(b, a.base)
   430  		return err
   431  	}
   432  }
   433  
   434  // appendDurationBase10 appends d formatted as a decimal fractional number,
   435  // where pow10 is a power-of-10 used to scale down the number.
   436  func appendDurationBase10(b []byte, d time.Duration, pow10 uint64) []byte {
   437  	b, n := mayAppendDurationSign(b, d)            // append sign
   438  	whole, frac := bits.Div64(0, n, uint64(pow10)) // compute whole and frac fields
   439  	b = strconv.AppendUint(b, whole, 10)           // append whole field
   440  	return appendFracBase10(b, frac, pow10)        // append frac field
   441  }
   442  
   443  // parseDurationBase10 parses d from a decimal fractional number,
   444  // where pow10 is a power-of-10 used to scale up the number.
   445  func parseDurationBase10(b []byte, pow10 uint64) (time.Duration, error) {
   446  	suffix, neg := consumeSign(b, false)                     // consume sign
   447  	wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
   448  	whole, okWhole := jsonwire.ParseUint(wholeBytes)         // parse whole field; may overflow
   449  	frac, okFrac := parseFracBase10(fracBytes, pow10)        // parse frac field
   450  	hi, lo := bits.Mul64(whole, uint64(pow10))               // overflow if hi > 0
   451  	sum, co := bits.Add64(lo, uint64(frac), 0)               // overflow if co > 0
   452  	switch d := mayApplyDurationSign(sum, neg); {            // overflow if neg != (d < 0)
   453  	case (!okWhole && whole != math.MaxUint64) || !okFrac:
   454  		return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrSyntax)
   455  	case !okWhole || hi > 0 || co > 0 || neg != (d < 0):
   456  		return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrRange)
   457  	default:
   458  		return d, nil
   459  	}
   460  }
   461  
   462  // appendDurationISO8601 appends an ISO 8601 duration with a restricted grammar,
   463  // where leading and trailing zeroes and zero-value designators are omitted.
   464  // It only uses hour, minute, and second designators since ISO 8601 defines
   465  // those as being "accurate", while year, month, week, and day are "nominal".
   466  func appendDurationISO8601(b []byte, d time.Duration) []byte {
   467  	if d == 0 {
   468  		return append(b, "PT0S"...)
   469  	}
   470  	b, n := mayAppendDurationSign(b, d)
   471  	b = append(b, "PT"...)
   472  	n, nsec := bits.Div64(0, n, 1e9)  // compute nsec field
   473  	n, sec := bits.Div64(0, n, 60)    // compute sec field
   474  	hour, min := bits.Div64(0, n, 60) // compute hour and min fields
   475  	if hour > 0 {
   476  		b = append(strconv.AppendUint(b, hour, 10), 'H')
   477  	}
   478  	if min > 0 {
   479  		b = append(strconv.AppendUint(b, min, 10), 'M')
   480  	}
   481  	if sec > 0 || nsec > 0 {
   482  		b = append(appendFracBase10(strconv.AppendUint(b, sec, 10), nsec, 1e9), 'S')
   483  	}
   484  	return b
   485  }
   486  
// daysPerYear is the exact average number of days in a year according to
// the Gregorian calendar, which has an extra day each year that is
// a multiple of 4, unless it is evenly divisible by 100 but not by 400.
// This does not take into account leap seconds, which are not deterministic.
const daysPerYear = 365.2425

// errInaccurateDateUnits is wrapped in the error reported by
// parseDurationISO8601 when the input makes use of a year, month, week,
// or day designator, for which only a best-effort value can be computed.
var errInaccurateDateUnits = errors.New("inaccurate year, month, week, or day units")
   494  
   495  // parseDurationISO8601 parses a duration according to ISO 8601-1:2019,
   496  // section 5.5.2.2 and 5.5.2.3 with the following restrictions or extensions:
   497  //
   498  //   - A leading minus sign is permitted for negative duration according
   499  //     to ISO 8601-2:2019, section 4.4.1.9. We do not permit negative values
   500  //     for each "time scale component", which is permitted by section 4.4.1.1,
   501  //     but rarely supported by parsers.
   502  //
   503  //   - A leading plus sign is permitted (and ignored).
   504  //     This is not required by ISO 8601, but not forbidden either.
   505  //     There is some precedent for this as it is supported by the principle of
   506  //     duration arithmetic as specified in ISO 8601-2-2019, section 14.1.
   507  //     Of note, the JavaScript grammar for ISO 8601 permits a leading plus sign.
   508  //
   509  //   - A fractional value is only permitted for accurate units
   510  //     (i.e., hour, minute, and seconds) in the last time component,
   511  //     which is permissible by ISO 8601-1:2019, section 5.5.2.3.
   512  //
   513  //   - Both periods ('.') and commas (',') are supported as the separator
   514  //     between the integer part and fraction part of a number,
   515  //     as specified in ISO 8601-1:2019, section 3.2.6.
   516  //     While ISO 8601 recommends comma as the default separator,
   517  //     most formatters use a period.
   518  //
   519  //   - Leading zeros are ignored. This is not required by ISO 8601,
   520  //     but also not forbidden by the standard. Many parsers support this.
   521  //
   522  //   - Lowercase designators are supported. This is not required by ISO 8601,
   523  //     but also not forbidden by the standard. Many parsers support this.
   524  //
   525  // If the nominal units of year, month, week, or day are present,
   526  // this produces a best-effort value and also reports [errInaccurateDateUnits].
   527  //
   528  // The accepted grammar is identical to JavaScript's Duration:
   529  //
   530  //	https://tc39.es/proposal-temporal/#prod-Duration
   531  //
   532  // We follow JavaScript's grammar as JSON itself is derived from JavaScript.
   533  // The Temporal.Duration.toJSON method is guaranteed to produce an output
   534  // that can be parsed by this function so long as arithmetic in JavaScript
   535  // does not use a largestUnit value higher than "hours" (which is the default).
   536  // Even if it does, this will do a best-effort parsing with inaccurate units,
   537  // but report [errInaccurateDateUnits].
   538  func parseDurationISO8601(b []byte) (time.Duration, error) {
   539  	var invalid, overflow, inaccurate, sawFrac bool
   540  	var sumNanos, n, co uint64
   541  
   542  	// cutBytes is like [bytes.Cut], but uses either c0 or c1 as the separator.
   543  	cutBytes := func(b []byte, c0, c1 byte) (prefix, suffix []byte, ok bool) {
   544  		for i, c := range b {
   545  			if c == c0 || c == c1 {
   546  				return b[:i], b[i+1:], true
   547  			}
   548  		}
   549  		return b, nil, false
   550  	}
   551  
   552  	// mayParseUnit attempts to parse another date or time number
   553  	// identified by the desHi and desLo unit characters.
   554  	// If the part is absent for current unit, it returns b as is.
   555  	mayParseUnit := func(b []byte, desHi, desLo byte, unit time.Duration) []byte {
   556  		number, suffix, ok := cutBytes(b, desHi, desLo)
   557  		if !ok || sawFrac {
   558  			return b // designator is not present or already saw fraction, which can only be in the last component
   559  		}
   560  
   561  		// Parse the number.
   562  		// A fraction is allowed for the accurate units in the last part.
   563  		whole, frac, ok := cutBytes(number, '.', ',')
   564  		if ok {
   565  			sawFrac = true
   566  			invalid = invalid || len(frac) == len("") || unit > time.Hour
   567  			if unit == time.Second {
   568  				n, ok = parsePaddedBase10(frac, uint64(time.Second))
   569  				invalid = invalid || !ok
   570  			} else {
   571  				f, err := strconv.ParseFloat("0."+string(frac), 64)
   572  				invalid = invalid || err != nil || len(bytes.Trim(frac[len("."):], "0123456789")) > 0
   573  				n = uint64(math.Round(f * float64(unit))) // never overflows since f is within [0..1]
   574  			}
   575  			sumNanos, co = bits.Add64(sumNanos, n, 0) // overflow if co > 0
   576  			overflow = overflow || co > 0
   577  		}
   578  		for len(whole) > 1 && whole[0] == '0' {
   579  			whole = whole[len("0"):] // trim leading zeros
   580  		}
   581  		n, ok := jsonwire.ParseUint(whole)         // overflow if !ok && MaxUint64
   582  		hi, lo := bits.Mul64(n, uint64(unit))      // overflow if hi > 0
   583  		sumNanos, co = bits.Add64(sumNanos, lo, 0) // overflow if co > 0
   584  		invalid = invalid || (!ok && n != math.MaxUint64)
   585  		overflow = overflow || (!ok && n == math.MaxUint64) || hi > 0 || co > 0
   586  		inaccurate = inaccurate || unit > time.Hour
   587  		return suffix
   588  	}
   589  
   590  	suffix, neg := consumeSign(b, true)
   591  	prefix, suffix, okP := cutBytes(suffix, 'P', 'p')
   592  	durDate, durTime, okT := cutBytes(suffix, 'T', 't')
   593  	invalid = invalid || len(prefix) > 0 || !okP || (okT && len(durTime) == 0) || len(durDate)+len(durTime) == 0
   594  	if len(durDate) > 0 { // nominal portion of the duration
   595  		durDate = mayParseUnit(durDate, 'Y', 'y', time.Duration(daysPerYear*24*60*60*1e9))
   596  		durDate = mayParseUnit(durDate, 'M', 'm', time.Duration(daysPerYear/12*24*60*60*1e9))
   597  		durDate = mayParseUnit(durDate, 'W', 'w', time.Duration(7*24*60*60*1e9))
   598  		durDate = mayParseUnit(durDate, 'D', 'd', time.Duration(24*60*60*1e9))
   599  		invalid = invalid || len(durDate) > 0 // unknown elements
   600  	}
   601  	if len(durTime) > 0 { // accurate portion of the duration
   602  		durTime = mayParseUnit(durTime, 'H', 'h', time.Duration(60*60*1e9))
   603  		durTime = mayParseUnit(durTime, 'M', 'm', time.Duration(60*1e9))
   604  		durTime = mayParseUnit(durTime, 'S', 's', time.Duration(1e9))
   605  		invalid = invalid || len(durTime) > 0 // unknown elements
   606  	}
   607  	d := mayApplyDurationSign(sumNanos, neg)
   608  	overflow = overflow || (neg != (d < 0) && d != 0) // overflows signed duration
   609  
   610  	switch {
   611  	case invalid:
   612  		return 0, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, strconv.ErrSyntax)
   613  	case overflow:
   614  		return 0, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, strconv.ErrRange)
   615  	case inaccurate:
   616  		return d, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, errInaccurateDateUnits)
   617  	default:
   618  		return d, nil
   619  	}
   620  }
   621  
   622  // mayAppendDurationSign appends a negative sign if n is negative.
   623  func mayAppendDurationSign(b []byte, d time.Duration) ([]byte, uint64) {
   624  	if d < 0 {
   625  		b = append(b, '-')
   626  		d *= -1
   627  	}
   628  	return b, uint64(d)
   629  }
   630  
   631  // mayApplyDurationSign inverts n if neg is specified.
   632  func mayApplyDurationSign(n uint64, neg bool) time.Duration {
   633  	if neg {
   634  		return -1 * time.Duration(n)
   635  	} else {
   636  		return +1 * time.Duration(n)
   637  	}
   638  }
   639  
   640  // appendTimeUnix appends t formatted as a decimal fractional number,
   641  // where pow10 is a power-of-10 used to scale up the number.
   642  func appendTimeUnix(b []byte, t time.Time, pow10 uint64) []byte {
   643  	sec, nsec := t.Unix(), int64(t.Nanosecond())
   644  	if sec < 0 {
   645  		b = append(b, '-')
   646  		sec, nsec = negateSecNano(sec, nsec)
   647  	}
   648  	switch {
   649  	case pow10 == 1e0: // fast case where units is in seconds
   650  		b = strconv.AppendUint(b, uint64(sec), 10)
   651  		return appendFracBase10(b, uint64(nsec), 1e9)
   652  	case uint64(sec) < 1e9: // intermediate case where units is not seconds, but no overflow
   653  		b = strconv.AppendUint(b, uint64(sec)*uint64(pow10)+uint64(uint64(nsec)/(1e9/pow10)), 10)
   654  		return appendFracBase10(b, (uint64(nsec)*pow10)%1e9, 1e9)
   655  	default: // slow case where units is not seconds and overflow would occur
   656  		b = strconv.AppendUint(b, uint64(sec), 10)
   657  		b = appendPaddedBase10(b, uint64(nsec)/(1e9/pow10), pow10)
   658  		return appendFracBase10(b, (uint64(nsec)*pow10)%1e9, 1e9)
   659  	}
   660  }
   661  
   662  // parseTimeUnix parses t formatted as a decimal fractional number,
   663  // where pow10 is a power-of-10 used to scale down the number.
   664  func parseTimeUnix(b []byte, pow10 uint64) (time.Time, error) {
   665  	suffix, neg := consumeSign(b, false)                     // consume sign
   666  	wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
   667  	whole, okWhole := jsonwire.ParseUint(wholeBytes)         // parse whole field; may overflow
   668  	frac, okFrac := parseFracBase10(fracBytes, 1e9/pow10)    // parse frac field
   669  	var sec, nsec int64
   670  	switch {
   671  	case pow10 == 1e0: // fast case where units is in seconds
   672  		sec = int64(whole) // check overflow later after negation
   673  		nsec = int64(frac) // cannot overflow
   674  	case okWhole: // intermediate case where units is not seconds, but no overflow
   675  		sec = int64(whole / pow10)                     // check overflow later after negation
   676  		nsec = int64((whole%pow10)*(1e9/pow10) + frac) // cannot overflow
   677  	case !okWhole && whole == math.MaxUint64: // slow case where units is not seconds and overflow occurred
   678  		width := int(math.Log10(float64(pow10)))                                // compute len(strconv.Itoa(pow10-1))
   679  		whole, okWhole = jsonwire.ParseUint(wholeBytes[:len(wholeBytes)-width]) // parse the upper whole field
   680  		mid, _ := parsePaddedBase10(wholeBytes[len(wholeBytes)-width:], pow10)  // parse the lower whole field
   681  		sec = int64(whole)                                                      // check overflow later after negation
   682  		nsec = int64(mid*(1e9/pow10) + frac)                                    // cannot overflow
   683  	}
   684  	if neg {
   685  		sec, nsec = negateSecNano(sec, nsec)
   686  	}
   687  	switch t := time.Unix(sec, nsec).UTC(); {
   688  	case (!okWhole && whole != math.MaxUint64) || !okFrac:
   689  		return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrSyntax)
   690  	case !okWhole || neg != (t.Unix() < 0):
   691  		return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrRange)
   692  	default:
   693  		return t, nil
   694  	}
   695  }
   696  
   697  // negateSecNano negates a Unix timestamp, where nsec must be within [0, 1e9).
   698  func negateSecNano(sec, nsec int64) (int64, int64) {
   699  	sec = ^sec               // twos-complement negation (i.e., -1*sec + 1)
   700  	nsec = -nsec + 1e9       // negate nsec and add 1e9 (which is the extra +1 from sec negation)
   701  	sec += int64(nsec / 1e9) // handle possible overflow of nsec if it started as zero
   702  	nsec %= 1e9              // ensure nsec stays within [0, 1e9)
   703  	return sec, nsec
   704  }
   705  
   706  // appendFracBase10 appends the fraction of n/max10,
   707  // where max10 is a power-of-10 that is larger than n.
   708  func appendFracBase10(b []byte, n, max10 uint64) []byte {
   709  	if n == 0 {
   710  		return b
   711  	}
   712  	return bytes.TrimRight(appendPaddedBase10(append(b, '.'), n, max10), "0")
   713  }
   714  
   715  // parseFracBase10 parses the fraction of n/max10,
   716  // where max10 is a power-of-10 that is larger than n.
   717  func parseFracBase10(b []byte, max10 uint64) (n uint64, ok bool) {
   718  	switch {
   719  	case len(b) == 0:
   720  		return 0, true
   721  	case len(b) < len(".0") || b[0] != '.':
   722  		return 0, false
   723  	}
   724  	return parsePaddedBase10(b[len("."):], max10)
   725  }
   726  
   727  // appendPaddedBase10 appends a zero-padded encoding of n,
   728  // where max10 is a power-of-10 that is larger than n.
   729  func appendPaddedBase10(b []byte, n, max10 uint64) []byte {
   730  	if n < max10/10 {
   731  		// Formatting of n is shorter than log10(max10),
   732  		// so add max10/10 to ensure the length is equal to log10(max10).
   733  		i := len(b)
   734  		b = strconv.AppendUint(b, n+max10/10, 10)
   735  		b[i]-- // subtract the addition of max10/10
   736  		return b
   737  	}
   738  	return strconv.AppendUint(b, n, 10)
   739  }
   740  
   741  // parsePaddedBase10 parses b as the zero-padded encoding of n,
   742  // where max10 is a power-of-10 that is larger than n.
   743  // Truncated suffix is treated as implicit zeros.
   744  // Extended suffix is ignored, but verified to contain only digits.
   745  func parsePaddedBase10(b []byte, max10 uint64) (n uint64, ok bool) {
   746  	pow10 := uint64(1)
   747  	for pow10 < max10 {
   748  		n *= 10
   749  		if len(b) > 0 {
   750  			if b[0] < '0' || '9' < b[0] {
   751  				return n, false
   752  			}
   753  			n += uint64(b[0] - '0')
   754  			b = b[1:]
   755  		}
   756  		pow10 *= 10
   757  	}
   758  	if len(b) > 0 && len(bytes.TrimRight(b, "0123456789")) > 0 {
   759  		return n, false // trailing characters are not digits
   760  	}
   761  	return n, true
   762  }
   763  
   764  // consumeSign consumes an optional leading negative or positive sign.
   765  func consumeSign(b []byte, allowPlus bool) ([]byte, bool) {
   766  	if len(b) > 0 {
   767  		if b[0] == '-' {
   768  			return b[len("-"):], true
   769  		} else if b[0] == '+' && allowPlus {
   770  			return b[len("+"):], false
   771  		}
   772  	}
   773  	return b, false
   774  }
   775  
   776  // bytesCutByte is similar to bytes.Cut(b, []byte{c}),
   777  // except c may optionally be included as part of the suffix.
   778  func bytesCutByte(b []byte, c byte, include bool) ([]byte, []byte) {
   779  	if i := bytes.IndexByte(b, c); i >= 0 {
   780  		if include {
   781  			return b[:i], b[i:]
   782  		}
   783  		return b[:i], b[i+1:]
   784  	}
   785  	return b, nil
   786  }
   787  
   788  // parseDec2 parses b as an unsigned, base-10, 2-digit number.
   789  // The result is undefined if digits are not base-10.
   790  func parseDec2(b []byte) byte {
   791  	if len(b) < 2 {
   792  		return 0
   793  	}
   794  	return 10*(b[0]-'0') + (b[1] - '0')
   795  }
   796  

View as plain text