Source file src/encoding/json/v2/arshal_time.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.jsonv2
     6  
     7  package json
     8  
     9  import (
    10  	"bytes"
    11  	"cmp"
    12  	"errors"
    13  	"fmt"
    14  	"math"
    15  	"math/bits"
    16  	"reflect"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  
    21  	"encoding/json/internal"
    22  	"encoding/json/internal/jsonflags"
    23  	"encoding/json/internal/jsonopts"
    24  	"encoding/json/internal/jsonwire"
    25  	"encoding/json/jsontext"
    26  )
    27  
// Cached reflect.Type values for the two time types that
// makeTimeArshaler compares against.
var (
	timeDurationType = reflect.TypeFor[time.Duration]()
	timeTimeType     = reflect.TypeFor[time.Time]()
)
    32  
// makeTimeArshaler installs custom marshal and unmarshal functions on fncs
// when t is exactly time.Duration or time.Time and returns fncs.
// For any other type, fncs is returned unmodified.
func makeTimeArshaler(fncs *arshaler, t reflect.Type) *arshaler {
	// Ideally, time types would implement MarshalerTo and UnmarshalerFrom,
	// but that would incur a dependency on package json from package time.
	// Given how widely used time is, it is more acceptable that we incur a
	// dependency on time from json.
	//
	// Injecting the arshaling functionality like this will not be identical
	// to actually declaring methods on the time types since embedding of the
	// time types will not be able to forward this functionality.
	switch t {
	case timeDurationType:
		fncs.nonDefault = true
		// Keep the previously installed marshal function so that it can be
		// used when FormatTimeWithLegacySemantics is set and no format applies.
		marshalNano := fncs.marshal
		fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
			xe := export.Encoder(enc)
			var m durationArshaler
			// A format is only honored at the depth it was specified for.
			if mo.Format != "" && mo.FormatDepth == xe.Tokens.Depth() {
				if !m.initFormat(mo.Format) {
					return newInvalidFormatError(enc, t, mo)
				}
			} else if mo.Flags.Get(jsonflags.FormatTimeWithLegacySemantics) {
				return marshalNano(enc, va, mo)
			} else {
				// TODO(https://go.dev/issue/71631): Decide on default duration representation.
				return newMarshalErrorBefore(enc, t, errors.New("no default representation (see https://go.dev/issue/71631); specify an explicit format"))
			}

			// TODO(https://go.dev/issue/62121): Use reflect.Value.AssertTo.
			m.td = *va.Addr().Interface().(*time.Duration)
			// Emit a JSON string unless the format is numeric, the value is not
			// in object-name position, and StringifyNumbers is not set.
			k := stringOrNumberKind(!m.isNumeric() || xe.Tokens.Last.NeedObjectName() || mo.Flags.Get(jsonflags.StringifyNumbers))
			if err := xe.AppendRaw(k, true, m.appendMarshal); err != nil {
				// Syntactic and I/O errors are returned as is;
				// everything else is wrapped as a marshal error.
				if !isSyntacticError(err) && !export.IsIOError(err) {
					err = newMarshalErrorBefore(enc, t, err)
				}
				return err
			}
			return nil
		}
		// Keep the previously installed unmarshal function for the legacy path.
		unmarshalNano := fncs.unmarshal
		fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
			xd := export.Decoder(dec)
			var u durationArshaler
			// A format is only honored at the depth it was specified for.
			if uo.Format != "" && uo.FormatDepth == xd.Tokens.Depth() {
				if !u.initFormat(uo.Format) {
					return newInvalidFormatError(dec, t, uo)
				}
			} else if uo.Flags.Get(jsonflags.FormatTimeWithLegacySemantics) {
				return unmarshalNano(dec, va, uo)
			} else {
				// TODO(https://go.dev/issue/71631): Decide on default duration representation.
				return newUnmarshalErrorBeforeWithSkipping(dec, uo, t, errors.New("no default representation (see https://go.dev/issue/71631); specify an explicit format"))
			}

			// stringify is computed before reading the value since it depends
			// on the decoder state prior to consuming the next value.
			stringify := !u.isNumeric() || xd.Tokens.Last.NeedObjectName() || uo.Flags.Get(jsonflags.StringifyNumbers)
			var flags jsonwire.ValueFlags
			td := va.Addr().Interface().(*time.Duration)
			val, err := xd.ReadValue(&flags)
			if err != nil {
				return err
			}
			switch k := val.Kind(); k {
			case 'n':
				// A JSON null zeroes the duration unless legacy merge semantics apply.
				if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
					*td = time.Duration(0)
				}
				return nil
			case '"':
				if !stringify {
					break // a string is not acceptable; report the error below
				}
				val = jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
				if err := u.unmarshal(val); err != nil {
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*td = u.td
				return nil
			case '0':
				if stringify {
					break // a number is not acceptable; report the error below
				}
				if err := u.unmarshal(val); err != nil {
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*td = u.td
				return nil
			}
			// The JSON kind does not match what the representation requires.
			return newUnmarshalErrorAfter(dec, t, nil)
		}
	case timeTimeType:
		fncs.nonDefault = true
		fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) (err error) {
			xe := export.Encoder(enc)
			var m timeArshaler
			// Without an applicable format, the zero timeArshaler is used,
			// which encodes as RFC 3339 (see timeArshaler.appendMarshal).
			if mo.Format != "" && mo.FormatDepth == xe.Tokens.Depth() {
				if !m.initFormat(mo.Format) {
					return newInvalidFormatError(enc, t, mo)
				}
			}

			// TODO(https://go.dev/issue/62121): Use reflect.Value.AssertTo.
			m.tt = *va.Addr().Interface().(*time.Time)
			// Emit a JSON string unless the format is numeric, the value is not
			// in object-name position, and StringifyNumbers is not set.
			k := stringOrNumberKind(!m.isNumeric() || xe.Tokens.Last.NeedObjectName() || mo.Flags.Get(jsonflags.StringifyNumbers))
			if err := xe.AppendRaw(k, !m.hasCustomFormat(), m.appendMarshal); err != nil {
				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalJSON") // unlike unmarshal, always wrapped
				}
				if !isSyntacticError(err) && !export.IsIOError(err) {
					err = newMarshalErrorBefore(enc, t, err)
				}
				return err
			}
			return nil
		}
		fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) (err error) {
			xd := export.Decoder(dec)
			var u timeArshaler
			if uo.Format != "" && uo.FormatDepth == xd.Tokens.Depth() {
				if !u.initFormat(uo.Format) {
					return newInvalidFormatError(dec, t, uo)
				}
			} else if uo.Flags.Get(jsonflags.FormatTimeWithLegacySemantics) {
				// Skip the extra strictness checks in timeArshaler.unmarshal.
				u.looseRFC3339 = true
			}

			// stringify is computed before reading the value since it depends
			// on the decoder state prior to consuming the next value.
			stringify := !u.isNumeric() || xd.Tokens.Last.NeedObjectName() || uo.Flags.Get(jsonflags.StringifyNumbers)
			var flags jsonwire.ValueFlags
			tt := va.Addr().Interface().(*time.Time)
			val, err := xd.ReadValue(&flags)
			if err != nil {
				return err
			}
			switch k := val.Kind(); k {
			case 'n':
				// A JSON null zeroes the time unless legacy merge semantics apply.
				if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
					*tt = time.Time{}
				}
				return nil
			case '"':
				if !stringify {
					break // a string is not acceptable; report the error below
				}
				val = jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
				if err := u.unmarshal(val); err != nil {
					if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
						return err // unlike marshal, never wrapped
					}
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*tt = u.tt
				return nil
			case '0':
				if stringify {
					break // a number is not acceptable; report the error below
				}
				if err := u.unmarshal(val); err != nil {
					if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
						return err // unlike marshal, never wrapped
					}
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*tt = u.tt
				return nil
			}
			// The JSON kind does not match what the representation requires.
			return newUnmarshalErrorAfter(dec, t, nil)
		}
	}
	return fncs
}
   201  
// durationArshaler pairs a time.Duration with the representation
// used to format and parse it.
type durationArshaler struct {
	td time.Duration // value read by appendMarshal and written by unmarshal

	// base records the representation where:
	//   - 0 uses time.Duration.String
	//   - 1e0, 1e3, 1e6, or 1e9 use a decimal encoding of the duration as
	//     nanoseconds, microseconds, milliseconds, or seconds.
	//   - 8601 uses ISO 8601
	base uint64
}
   212  
   213  func (a *durationArshaler) initFormat(format string) (ok bool) {
   214  	switch format {
   215  	case "units":
   216  		a.base = 0
   217  	case "sec":
   218  		a.base = 1e9
   219  	case "milli":
   220  		a.base = 1e6
   221  	case "micro":
   222  		a.base = 1e3
   223  	case "nano":
   224  		a.base = 1e0
   225  	case "iso8601":
   226  		a.base = 8601
   227  	default:
   228  		return false
   229  	}
   230  	return true
   231  }
   232  
   233  func (a *durationArshaler) isNumeric() bool {
   234  	return a.base != 0 && a.base != 8601
   235  }
   236  
   237  func (a *durationArshaler) appendMarshal(b []byte) ([]byte, error) {
   238  	switch a.base {
   239  	case 0:
   240  		return append(b, a.td.String()...), nil
   241  	case 8601:
   242  		return appendDurationISO8601(b, a.td), nil
   243  	default:
   244  		return appendDurationBase10(b, a.td, a.base), nil
   245  	}
   246  }
   247  
   248  func (a *durationArshaler) unmarshal(b []byte) (err error) {
   249  	switch a.base {
   250  	case 0:
   251  		a.td, err = time.ParseDuration(string(b))
   252  	case 8601:
   253  		a.td, err = parseDurationISO8601(b)
   254  	default:
   255  		a.td, err = parseDurationBase10(b, a.base)
   256  	}
   257  	return err
   258  }
   259  
// timeArshaler pairs a time.Time with the representation
// used to format and parse it.
type timeArshaler struct {
	tt time.Time // value read by appendMarshal and written by unmarshal

	// base records the representation where:
	//   - 0 uses RFC 3339 encoding of the timestamp
	//   - 1e0, 1e3, 1e6, or 1e9 use a decimal encoding of the timestamp as
	//     seconds, milliseconds, microseconds, or nanoseconds since Unix epoch.
	//   - math.MaxUint uses time.Time.Format to encode the timestamp
	base   uint64
	format string // time format passed to time.Parse

	// looseRFC3339 disables the additional RFC 3339 strictness checks
	// that unmarshal otherwise applies after time.Time.UnmarshalText.
	looseRFC3339 bool
}
   273  
   274  func (a *timeArshaler) initFormat(format string) bool {
   275  	// We assume that an exported constant in the time package will
   276  	// always start with an uppercase ASCII letter.
   277  	if len(format) == 0 {
   278  		return false
   279  	}
   280  	a.base = math.MaxUint // implies custom format
   281  	if c := format[0]; !('a' <= c && c <= 'z') && !('A' <= c && c <= 'Z') {
   282  		a.format = format
   283  		return true
   284  	}
   285  	switch format {
   286  	case "ANSIC":
   287  		a.format = time.ANSIC
   288  	case "UnixDate":
   289  		a.format = time.UnixDate
   290  	case "RubyDate":
   291  		a.format = time.RubyDate
   292  	case "RFC822":
   293  		a.format = time.RFC822
   294  	case "RFC822Z":
   295  		a.format = time.RFC822Z
   296  	case "RFC850":
   297  		a.format = time.RFC850
   298  	case "RFC1123":
   299  		a.format = time.RFC1123
   300  	case "RFC1123Z":
   301  		a.format = time.RFC1123Z
   302  	case "RFC3339":
   303  		a.base = 0
   304  		a.format = time.RFC3339
   305  	case "RFC3339Nano":
   306  		a.base = 0
   307  		a.format = time.RFC3339Nano
   308  	case "Kitchen":
   309  		a.format = time.Kitchen
   310  	case "Stamp":
   311  		a.format = time.Stamp
   312  	case "StampMilli":
   313  		a.format = time.StampMilli
   314  	case "StampMicro":
   315  		a.format = time.StampMicro
   316  	case "StampNano":
   317  		a.format = time.StampNano
   318  	case "DateTime":
   319  		a.format = time.DateTime
   320  	case "DateOnly":
   321  		a.format = time.DateOnly
   322  	case "TimeOnly":
   323  		a.format = time.TimeOnly
   324  	case "unix":
   325  		a.base = 1e0
   326  	case "unixmilli":
   327  		a.base = 1e3
   328  	case "unixmicro":
   329  		a.base = 1e6
   330  	case "unixnano":
   331  		a.base = 1e9
   332  	default:
   333  		// Reject any Go identifier in case new constants are supported.
   334  		if strings.TrimFunc(format, isLetterOrDigit) == "" {
   335  			return false
   336  		}
   337  		a.format = format
   338  	}
   339  	return true
   340  }
   341  
   342  func (a *timeArshaler) isNumeric() bool {
   343  	return int(a.base) > 0
   344  }
   345  
   346  func (a *timeArshaler) hasCustomFormat() bool {
   347  	return a.base == math.MaxUint
   348  }
   349  
   350  func (a *timeArshaler) appendMarshal(b []byte) ([]byte, error) {
   351  	switch a.base {
   352  	case 0:
   353  		format := cmp.Or(a.format, time.RFC3339Nano)
   354  		n0 := len(b)
   355  		b = a.tt.AppendFormat(b, format)
   356  		// Not all Go timestamps can be represented as valid RFC 3339.
   357  		// Explicitly check for these edge cases.
   358  		// See https://go.dev/issue/4556 and https://go.dev/issue/54580.
   359  		switch b := b[n0:]; {
   360  		case b[len("9999")] != '-': // year must be exactly 4 digits wide
   361  			return b, errors.New("year outside of range [0,9999]")
   362  		case b[len(b)-1] != 'Z':
   363  			c := b[len(b)-len("Z07:00")]
   364  			if ('0' <= c && c <= '9') || parseDec2(b[len(b)-len("07:00"):]) >= 24 {
   365  				return b, errors.New("timezone hour outside of range [0,23]")
   366  			}
   367  		}
   368  		return b, nil
   369  	case math.MaxUint:
   370  		return a.tt.AppendFormat(b, a.format), nil
   371  	default:
   372  		return appendTimeUnix(b, a.tt, a.base), nil
   373  	}
   374  }
   375  
   376  func (a *timeArshaler) unmarshal(b []byte) (err error) {
   377  	switch a.base {
   378  	case 0:
   379  		// Use time.Time.UnmarshalText to avoid possible string allocation.
   380  		if err := a.tt.UnmarshalText(b); err != nil {
   381  			return err
   382  		}
   383  		// TODO(https://go.dev/issue/57912):
   384  		// RFC 3339 specifies the grammar for a valid timestamp.
   385  		// However, the parsing functionality in "time" is too loose and
   386  		// incorrectly accepts invalid timestamps as valid.
   387  		// Remove these manual checks when "time" checks it for us.
   388  		newParseError := func(layout, value, layoutElem, valueElem, message string) error {
   389  			return &time.ParseError{Layout: layout, Value: value, LayoutElem: layoutElem, ValueElem: valueElem, Message: message}
   390  		}
   391  		switch {
   392  		case a.looseRFC3339:
   393  			return nil
   394  		case b[len("2006-01-02T")+1] == ':': // hour must be two digits
   395  			return newParseError(time.RFC3339, string(b), "15", string(b[len("2006-01-02T"):][:1]), "")
   396  		case b[len("2006-01-02T15:04:05")] == ',': // sub-second separator must be a period
   397  			return newParseError(time.RFC3339, string(b), ".", ",", "")
   398  		case b[len(b)-1] != 'Z':
   399  			switch {
   400  			case parseDec2(b[len(b)-len("07:00"):]) >= 24: // timezone hour must be in range
   401  				return newParseError(time.RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone hour out of range")
   402  			case parseDec2(b[len(b)-len("00"):]) >= 60: // timezone minute must be in range
   403  				return newParseError(time.RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone minute out of range")
   404  			}
   405  		}
   406  		return nil
   407  	case math.MaxUint:
   408  		a.tt, err = time.Parse(a.format, string(b))
   409  		return err
   410  	default:
   411  		a.tt, err = parseTimeUnix(b, a.base)
   412  		return err
   413  	}
   414  }
   415  
   416  // appendDurationBase10 appends d formatted as a decimal fractional number,
   417  // where pow10 is a power-of-10 used to scale down the number.
   418  func appendDurationBase10(b []byte, d time.Duration, pow10 uint64) []byte {
   419  	b, n := mayAppendDurationSign(b, d)            // append sign
   420  	whole, frac := bits.Div64(0, n, uint64(pow10)) // compute whole and frac fields
   421  	b = strconv.AppendUint(b, whole, 10)           // append whole field
   422  	return appendFracBase10(b, frac, pow10)        // append frac field
   423  }
   424  
   425  // parseDurationBase10 parses d from a decimal fractional number,
   426  // where pow10 is a power-of-10 used to scale up the number.
   427  func parseDurationBase10(b []byte, pow10 uint64) (time.Duration, error) {
   428  	suffix, neg := consumeSign(b, false)                     // consume sign
   429  	wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
   430  	whole, okWhole := jsonwire.ParseUint(wholeBytes)         // parse whole field; may overflow
   431  	frac, okFrac := parseFracBase10(fracBytes, pow10)        // parse frac field
   432  	hi, lo := bits.Mul64(whole, uint64(pow10))               // overflow if hi > 0
   433  	sum, co := bits.Add64(lo, uint64(frac), 0)               // overflow if co > 0
   434  	switch d := mayApplyDurationSign(sum, neg); {            // overflow if neg != (d < 0)
   435  	case (!okWhole && whole != math.MaxUint64) || !okFrac:
   436  		return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrSyntax)
   437  	case !okWhole || hi > 0 || co > 0 || neg != (d < 0):
   438  		return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrRange)
   439  	default:
   440  		return d, nil
   441  	}
   442  }
   443  
   444  // appendDurationISO8601 appends an ISO 8601 duration with a restricted grammar,
   445  // where leading and trailing zeroes and zero-value designators are omitted.
   446  // It only uses hour, minute, and second designators since ISO 8601 defines
   447  // those as being "accurate", while year, month, week, and day are "nominal".
   448  func appendDurationISO8601(b []byte, d time.Duration) []byte {
   449  	if d == 0 {
   450  		return append(b, "PT0S"...)
   451  	}
   452  	b, n := mayAppendDurationSign(b, d)
   453  	b = append(b, "PT"...)
   454  	n, nsec := bits.Div64(0, n, 1e9)  // compute nsec field
   455  	n, sec := bits.Div64(0, n, 60)    // compute sec field
   456  	hour, min := bits.Div64(0, n, 60) // compute hour and min fields
   457  	if hour > 0 {
   458  		b = append(strconv.AppendUint(b, hour, 10), 'H')
   459  	}
   460  	if min > 0 {
   461  		b = append(strconv.AppendUint(b, min, 10), 'M')
   462  	}
   463  	if sec > 0 || nsec > 0 {
   464  		b = append(appendFracBase10(strconv.AppendUint(b, sec, 10), nsec, 1e9), 'S')
   465  	}
   466  	return b
   467  }
   468  
// daysPerYear is the exact average number of days in a year according to
// the Gregorian calendar, which has an extra day each year that is
// a multiple of 4, unless it is evenly divisible by 100 but not by 400.
// This does not take into account leap seconds, which are not deterministic.
const daysPerYear = 365.2425
   474  
// errInaccurateDateUnits is wrapped in the error that parseDurationISO8601
// reports when the input uses the nominal year, month, week, or day units.
var errInaccurateDateUnits = errors.New("inaccurate year, month, week, or day units")
   476  
   477  // parseDurationISO8601 parses a duration according to ISO 8601-1:2019,
   478  // section 5.5.2.2 and 5.5.2.3 with the following restrictions or extensions:
   479  //
   480  //   - A leading minus sign is permitted for negative duration according
   481  //     to ISO 8601-2:2019, section 4.4.1.9. We do not permit negative values
   482  //     for each "time scale component", which is permitted by section 4.4.1.1,
   483  //     but rarely supported by parsers.
   484  //
   485  //   - A leading plus sign is permitted (and ignored).
   486  //     This is not required by ISO 8601, but not forbidden either.
   487  //     There is some precedent for this as it is supported by the principle of
   488  //     duration arithmetic as specified in ISO 8601-2-2019, section 14.1.
   489  //     Of note, the JavaScript grammar for ISO 8601 permits a leading plus sign.
   490  //
   491  //   - A fractional value is only permitted for accurate units
   492  //     (i.e., hour, minute, and seconds) in the last time component,
   493  //     which is permissible by ISO 8601-1:2019, section 5.5.2.3.
   494  //
   495  //   - Both periods ('.') and commas (',') are supported as the separator
   496  //     between the integer part and fraction part of a number,
   497  //     as specified in ISO 8601-1:2019, section 3.2.6.
   498  //     While ISO 8601 recommends comma as the default separator,
//     most formatters use a period.
   500  //
   501  //   - Leading zeros are ignored. This is not required by ISO 8601,
   502  //     but also not forbidden by the standard. Many parsers support this.
   503  //
   504  //   - Lowercase designators are supported. This is not required by ISO 8601,
   505  //     but also not forbidden by the standard. Many parsers support this.
   506  //
   507  // If the nominal units of year, month, week, or day are present,
   508  // this produces a best-effort value and also reports [errInaccurateDateUnits].
   509  //
   510  // The accepted grammar is identical to JavaScript's Duration:
   511  //
   512  //	https://tc39.es/proposal-temporal/#prod-Duration
   513  //
   514  // We follow JavaScript's grammar as JSON itself is derived from JavaScript.
   515  // The Temporal.Duration.toJSON method is guaranteed to produce an output
   516  // that can be parsed by this function so long as arithmetic in JavaScript
// does not use a largestUnit value higher than "hours" (which is the default).
   518  // Even if it does, this will do a best-effort parsing with inaccurate units,
   519  // but report [errInaccurateDateUnits].
   520  func parseDurationISO8601(b []byte) (time.Duration, error) {
   521  	var invalid, overflow, inaccurate, sawFrac bool
   522  	var sumNanos, n, co uint64
   523  
   524  	// cutBytes is like [bytes.Cut], but uses either c0 or c1 as the separator.
   525  	cutBytes := func(b []byte, c0, c1 byte) (prefix, suffix []byte, ok bool) {
   526  		for i, c := range b {
   527  			if c == c0 || c == c1 {
   528  				return b[:i], b[i+1:], true
   529  			}
   530  		}
   531  		return b, nil, false
   532  	}
   533  
   534  	// mayParseUnit attempts to parse another date or time number
   535  	// identified by the desHi and desLo unit characters.
   536  	// If the part is absent for current unit, it returns b as is.
   537  	mayParseUnit := func(b []byte, desHi, desLo byte, unit time.Duration) []byte {
   538  		number, suffix, ok := cutBytes(b, desHi, desLo)
   539  		if !ok || sawFrac {
   540  			return b // designator is not present or already saw fraction, which can only be in the last component
   541  		}
   542  
   543  		// Parse the number.
   544  		// A fraction allowed for the accurate units in the last part.
   545  		whole, frac, ok := cutBytes(number, '.', ',')
   546  		if ok {
   547  			sawFrac = true
   548  			invalid = invalid || len(frac) == len("") || unit > time.Hour
   549  			if unit == time.Second {
   550  				n, ok = parsePaddedBase10(frac, uint64(time.Second))
   551  				invalid = invalid || !ok
   552  			} else {
   553  				f, err := strconv.ParseFloat("0."+string(frac), 64)
   554  				invalid = invalid || err != nil || len(bytes.Trim(frac[len("."):], "0123456789")) > 0
   555  				n = uint64(math.Round(f * float64(unit))) // never overflows since f is within [0..1]
   556  			}
   557  			sumNanos, co = bits.Add64(sumNanos, n, 0) // overflow if co > 0
   558  			overflow = overflow || co > 0
   559  		}
   560  		for len(whole) > 1 && whole[0] == '0' {
   561  			whole = whole[len("0"):] // trim leading zeros
   562  		}
   563  		n, ok := jsonwire.ParseUint(whole)         // overflow if !ok && MaxUint64
   564  		hi, lo := bits.Mul64(n, uint64(unit))      // overflow if hi > 0
   565  		sumNanos, co = bits.Add64(sumNanos, lo, 0) // overflow if co > 0
   566  		invalid = invalid || (!ok && n != math.MaxUint64)
   567  		overflow = overflow || (!ok && n == math.MaxUint64) || hi > 0 || co > 0
   568  		inaccurate = inaccurate || unit > time.Hour
   569  		return suffix
   570  	}
   571  
   572  	suffix, neg := consumeSign(b, true)
   573  	prefix, suffix, okP := cutBytes(suffix, 'P', 'p')
   574  	durDate, durTime, okT := cutBytes(suffix, 'T', 't')
   575  	invalid = invalid || len(prefix) > 0 || !okP || (okT && len(durTime) == 0) || len(durDate)+len(durTime) == 0
   576  	if len(durDate) > 0 { // nominal portion of the duration
   577  		durDate = mayParseUnit(durDate, 'Y', 'y', time.Duration(daysPerYear*24*60*60*1e9))
   578  		durDate = mayParseUnit(durDate, 'M', 'm', time.Duration(daysPerYear/12*24*60*60*1e9))
   579  		durDate = mayParseUnit(durDate, 'W', 'w', time.Duration(7*24*60*60*1e9))
   580  		durDate = mayParseUnit(durDate, 'D', 'd', time.Duration(24*60*60*1e9))
   581  		invalid = invalid || len(durDate) > 0 // unknown elements
   582  	}
   583  	if len(durTime) > 0 { // accurate portion of the duration
   584  		durTime = mayParseUnit(durTime, 'H', 'h', time.Duration(60*60*1e9))
   585  		durTime = mayParseUnit(durTime, 'M', 'm', time.Duration(60*1e9))
   586  		durTime = mayParseUnit(durTime, 'S', 's', time.Duration(1e9))
   587  		invalid = invalid || len(durTime) > 0 // unknown elements
   588  	}
   589  	d := mayApplyDurationSign(sumNanos, neg)
   590  	overflow = overflow || (neg != (d < 0) && d != 0) // overflows signed duration
   591  
   592  	switch {
   593  	case invalid:
   594  		return 0, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, strconv.ErrSyntax)
   595  	case overflow:
   596  		return 0, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, strconv.ErrRange)
   597  	case inaccurate:
   598  		return d, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, errInaccurateDateUnits)
   599  	default:
   600  		return d, nil
   601  	}
   602  }
   603  
   604  // mayAppendDurationSign appends a negative sign if n is negative.
   605  func mayAppendDurationSign(b []byte, d time.Duration) ([]byte, uint64) {
   606  	if d < 0 {
   607  		b = append(b, '-')
   608  		d *= -1
   609  	}
   610  	return b, uint64(d)
   611  }
   612  
   613  // mayApplyDurationSign inverts n if neg is specified.
   614  func mayApplyDurationSign(n uint64, neg bool) time.Duration {
   615  	if neg {
   616  		return -1 * time.Duration(n)
   617  	} else {
   618  		return +1 * time.Duration(n)
   619  	}
   620  }
   621  
   622  // appendTimeUnix appends t formatted as a decimal fractional number,
   623  // where pow10 is a power-of-10 used to scale up the number.
   624  func appendTimeUnix(b []byte, t time.Time, pow10 uint64) []byte {
   625  	sec, nsec := t.Unix(), int64(t.Nanosecond())
   626  	if sec < 0 {
   627  		b = append(b, '-')
   628  		sec, nsec = negateSecNano(sec, nsec)
   629  	}
   630  	switch {
   631  	case pow10 == 1e0: // fast case where units is in seconds
   632  		b = strconv.AppendUint(b, uint64(sec), 10)
   633  		return appendFracBase10(b, uint64(nsec), 1e9)
   634  	case uint64(sec) < 1e9: // intermediate case where units is not seconds, but no overflow
   635  		b = strconv.AppendUint(b, uint64(sec)*uint64(pow10)+uint64(uint64(nsec)/(1e9/pow10)), 10)
   636  		return appendFracBase10(b, (uint64(nsec)*pow10)%1e9, 1e9)
   637  	default: // slow case where units is not seconds and overflow would occur
   638  		b = strconv.AppendUint(b, uint64(sec), 10)
   639  		b = appendPaddedBase10(b, uint64(nsec)/(1e9/pow10), pow10)
   640  		return appendFracBase10(b, (uint64(nsec)*pow10)%1e9, 1e9)
   641  	}
   642  }
   643  
   644  // parseTimeUnix parses t formatted as a decimal fractional number,
   645  // where pow10 is a power-of-10 used to scale down the number.
   646  func parseTimeUnix(b []byte, pow10 uint64) (time.Time, error) {
   647  	suffix, neg := consumeSign(b, false)                     // consume sign
   648  	wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
   649  	whole, okWhole := jsonwire.ParseUint(wholeBytes)         // parse whole field; may overflow
   650  	frac, okFrac := parseFracBase10(fracBytes, 1e9/pow10)    // parse frac field
   651  	var sec, nsec int64
   652  	switch {
   653  	case pow10 == 1e0: // fast case where units is in seconds
   654  		sec = int64(whole) // check overflow later after negation
   655  		nsec = int64(frac) // cannot overflow
   656  	case okWhole: // intermediate case where units is not seconds, but no overflow
   657  		sec = int64(whole / pow10)                     // check overflow later after negation
   658  		nsec = int64((whole%pow10)*(1e9/pow10) + frac) // cannot overflow
   659  	case !okWhole && whole == math.MaxUint64: // slow case where units is not seconds and overflow occurred
   660  		width := int(math.Log10(float64(pow10)))                                // compute len(strconv.Itoa(pow10-1))
   661  		whole, okWhole = jsonwire.ParseUint(wholeBytes[:len(wholeBytes)-width]) // parse the upper whole field
   662  		mid, _ := parsePaddedBase10(wholeBytes[len(wholeBytes)-width:], pow10)  // parse the lower whole field
   663  		sec = int64(whole)                                                      // check overflow later after negation
   664  		nsec = int64(mid*(1e9/pow10) + frac)                                    // cannot overflow
   665  	}
   666  	if neg {
   667  		sec, nsec = negateSecNano(sec, nsec)
   668  	}
   669  	switch t := time.Unix(sec, nsec).UTC(); {
   670  	case (!okWhole && whole != math.MaxUint64) || !okFrac:
   671  		return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrSyntax)
   672  	case !okWhole || neg != (t.Unix() < 0):
   673  		return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrRange)
   674  	default:
   675  		return t, nil
   676  	}
   677  }
   678  
   679  // negateSecNano negates a Unix timestamp, where nsec must be within [0, 1e9).
   680  func negateSecNano(sec, nsec int64) (int64, int64) {
   681  	sec = ^sec               // twos-complement negation (i.e., -1*sec + 1)
   682  	nsec = -nsec + 1e9       // negate nsec and add 1e9 (which is the extra +1 from sec negation)
   683  	sec += int64(nsec / 1e9) // handle possible overflow of nsec if it started as zero
   684  	nsec %= 1e9              // ensure nsec stays within [0, 1e9)
   685  	return sec, nsec
   686  }
   687  
   688  // appendFracBase10 appends the fraction of n/max10,
   689  // where max10 is a power-of-10 that is larger than n.
   690  func appendFracBase10(b []byte, n, max10 uint64) []byte {
   691  	if n == 0 {
   692  		return b
   693  	}
   694  	return bytes.TrimRight(appendPaddedBase10(append(b, '.'), n, max10), "0")
   695  }
   696  
   697  // parseFracBase10 parses the fraction of n/max10,
   698  // where max10 is a power-of-10 that is larger than n.
   699  func parseFracBase10(b []byte, max10 uint64) (n uint64, ok bool) {
   700  	switch {
   701  	case len(b) == 0:
   702  		return 0, true
   703  	case len(b) < len(".0") || b[0] != '.':
   704  		return 0, false
   705  	}
   706  	return parsePaddedBase10(b[len("."):], max10)
   707  }
   708  
   709  // appendPaddedBase10 appends a zero-padded encoding of n,
   710  // where max10 is a power-of-10 that is larger than n.
   711  func appendPaddedBase10(b []byte, n, max10 uint64) []byte {
   712  	if n < max10/10 {
   713  		// Formatting of n is shorter than log10(max10),
   714  		// so add max10/10 to ensure the length is equal to log10(max10).
   715  		i := len(b)
   716  		b = strconv.AppendUint(b, n+max10/10, 10)
   717  		b[i]-- // subtract the addition of max10/10
   718  		return b
   719  	}
   720  	return strconv.AppendUint(b, n, 10)
   721  }
   722  
   723  // parsePaddedBase10 parses b as the zero-padded encoding of n,
   724  // where max10 is a power-of-10 that is larger than n.
   725  // Truncated suffix is treated as implicit zeros.
   726  // Extended suffix is ignored, but verified to contain only digits.
   727  func parsePaddedBase10(b []byte, max10 uint64) (n uint64, ok bool) {
   728  	pow10 := uint64(1)
   729  	for pow10 < max10 {
   730  		n *= 10
   731  		if len(b) > 0 {
   732  			if b[0] < '0' || '9' < b[0] {
   733  				return n, false
   734  			}
   735  			n += uint64(b[0] - '0')
   736  			b = b[1:]
   737  		}
   738  		pow10 *= 10
   739  	}
   740  	if len(b) > 0 && len(bytes.TrimRight(b, "0123456789")) > 0 {
   741  		return n, false // trailing characters are not digits
   742  	}
   743  	return n, true
   744  }
   745  
   746  // consumeSign consumes an optional leading negative or positive sign.
   747  func consumeSign(b []byte, allowPlus bool) ([]byte, bool) {
   748  	if len(b) > 0 {
   749  		if b[0] == '-' {
   750  			return b[len("-"):], true
   751  		} else if b[0] == '+' && allowPlus {
   752  			return b[len("+"):], false
   753  		}
   754  	}
   755  	return b, false
   756  }
   757  
   758  // bytesCutByte is similar to bytes.Cut(b, []byte{c}),
   759  // except c may optionally be included as part of the suffix.
   760  func bytesCutByte(b []byte, c byte, include bool) ([]byte, []byte) {
   761  	if i := bytes.IndexByte(b, c); i >= 0 {
   762  		if include {
   763  			return b[:i], b[i:]
   764  		}
   765  		return b[:i], b[i+1:]
   766  	}
   767  	return b, nil
   768  }
   769  
   770  // parseDec2 parses b as an unsigned, base-10, 2-digit number.
   771  // The result is undefined if digits are not base-10.
   772  func parseDec2(b []byte) byte {
   773  	if len(b) < 2 {
   774  		return 0
   775  	}
   776  	return 10*(b[0]-'0') + (b[1] - '0')
   777  }
   778  

View as plain text