Source file src/encoding/json/v2/arshal_time.go

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.jsonv2
     6  
     7  package json
     8  
     9  import (
    10  	"bytes"
    11  	"cmp"
    12  	"errors"
    13  	"fmt"
    14  	"math"
    15  	"math/bits"
    16  	"reflect"
    17  	"strconv"
    18  	"strings"
    19  	"time"
    20  
    21  	"encoding/json/internal"
    22  	"encoding/json/internal/jsonflags"
    23  	"encoding/json/internal/jsonopts"
    24  	"encoding/json/internal/jsonwire"
    25  	"encoding/json/jsontext"
    26  )
    27  
// Cached reflect.Type values for the two time types that receive
// specialized handling; makeTimeArshaler switches on them.
var (
	timeDurationType = reflect.TypeFor[time.Duration]()
	timeTimeType     = reflect.TypeFor[time.Time]()
)
    32  
// makeTimeArshaler installs specialized marshal and unmarshal functions on
// fncs when t is exactly time.Duration or time.Time, and returns fncs.
// For any other type, fncs is returned without modification.
//
// In every installed function, a format is only honored when the Format
// option is non-empty and FormatDepth equals the current token depth.
func makeTimeArshaler(fncs *arshaler, t reflect.Type) *arshaler {
	// Ideally, time types would implement MarshalerTo and UnmarshalerFrom,
	// but that would incur a dependency on package json from package time.
	// Given how widely used time is, it is more acceptable that we incur a
	// dependency on time from json.
	//
	// Injecting the arshaling functionality like this will not be identical
	// to actually declaring methods on the time types since embedding of the
	// time types will not be able to forward this functionality.
	switch t {
	case timeDurationType:
		fncs.nonDefault = true
		// Keep the previously installed marshaler so that it can be used
		// when the FormatDurationAsNano flag is set and no format applies.
		marshalNano := fncs.marshal
		fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
			xe := export.Encoder(enc)
			var m durationArshaler
			if mo.Format != "" && mo.FormatDepth == xe.Tokens.Depth() {
				if !m.initFormat(mo.Format) {
					return newInvalidFormatError(enc, t)
				}
			} else if mo.Flags.Get(jsonflags.FormatDurationAsNano) {
				return marshalNano(enc, va, mo)
			} else {
				// TODO(https://go.dev/issue/71631): Decide on default duration representation.
				return newMarshalErrorBefore(enc, t, errors.New("no default representation (see https://go.dev/issue/71631); specify an explicit format"))
			}

			m.td, _ = reflect.TypeAssert[time.Duration](va.Value)
			// Emit a string unless the format is numeric, the value is not
			// being written as an object name, and StringifyNumbers is unset.
			k := stringOrNumberKind(!m.isNumeric() || xe.Tokens.Last.NeedObjectName() || mo.Flags.Get(jsonflags.StringifyNumbers))
			if err := xe.AppendRaw(k, true, m.appendMarshal); err != nil {
				// Syntactic and I/O errors are returned as is;
				// everything else is wrapped as a marshal error.
				if !isSyntacticError(err) && !export.IsIOError(err) {
					err = newMarshalErrorBefore(enc, t, err)
				}
				return err
			}
			return nil
		}
		// Keep the previously installed unmarshaler for the
		// FormatDurationAsNano fallback, mirroring marshalNano above.
		unmarshalNano := fncs.unmarshal
		fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
			xd := export.Decoder(dec)
			var u durationArshaler
			if uo.Format != "" && uo.FormatDepth == xd.Tokens.Depth() {
				if !u.initFormat(uo.Format) {
					return newInvalidFormatError(dec, t)
				}
			} else if uo.Flags.Get(jsonflags.FormatDurationAsNano) {
				return unmarshalNano(dec, va, uo)
			} else {
				// TODO(https://go.dev/issue/71631): Decide on default duration representation.
				return newUnmarshalErrorBeforeWithSkipping(dec, t, errors.New("no default representation (see https://go.dev/issue/71631); specify an explicit format"))
			}

			// stringify records whether the value is expected as a JSON string
			// (true) or as a JSON number (false); a mismatch is an error below.
			stringify := !u.isNumeric() || xd.Tokens.Last.NeedObjectName() || uo.Flags.Get(jsonflags.StringifyNumbers)
			var flags jsonwire.ValueFlags
			td, _ := reflect.TypeAssert[*time.Duration](va.Addr())
			val, err := xd.ReadValue(&flags)
			if err != nil {
				return err
			}
			switch k := val.Kind(); k {
			case 'n':
				// A JSON null zeroes the destination,
				// unless MergeWithLegacySemantics is set.
				if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
					*td = time.Duration(0)
				}
				return nil
			case '"':
				if !stringify {
					break
				}
				val = jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
				if err := u.unmarshal(val); err != nil {
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*td = u.td
				return nil
			case '0':
				if stringify {
					break
				}
				if err := u.unmarshal(val); err != nil {
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*td = u.td
				return nil
			}
			// Reached for any other JSON kind, or when a string/number
			// was seen where the other one was expected.
			return newUnmarshalErrorAfter(dec, t, nil)
		}
	case timeTimeType:
		fncs.nonDefault = true
		fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) (err error) {
			xe := export.Encoder(enc)
			// The zero timeArshaler (base 0) is used when no format applies,
			// which appendMarshal treats as RFC 3339.
			var m timeArshaler
			if mo.Format != "" && mo.FormatDepth == xe.Tokens.Depth() {
				if !m.initFormat(mo.Format) {
					return newInvalidFormatError(enc, t)
				}
			}

			m.tt, _ = reflect.TypeAssert[time.Time](va.Value)
			k := stringOrNumberKind(!m.isNumeric() || xe.Tokens.Last.NeedObjectName() || mo.Flags.Get(jsonflags.StringifyNumbers))
			// NOTE(review): the second AppendRaw argument is false only for
			// custom layouts; its exact meaning is defined by AppendRaw,
			// which is not visible here — confirm before relying on it.
			if err := xe.AppendRaw(k, !m.hasCustomFormat(), m.appendMarshal); err != nil {
				if mo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
					return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalJSON") // unlike unmarshal, always wrapped
				}
				if !isSyntacticError(err) && !export.IsIOError(err) {
					err = newMarshalErrorBefore(enc, t, err)
				}
				return err
			}
			return nil
		}
		fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) (err error) {
			xd := export.Decoder(dec)
			var u timeArshaler
			if uo.Format != "" && uo.FormatDepth == xd.Tokens.Depth() {
				if !u.initFormat(uo.Format) {
					return newInvalidFormatError(dec, t)
				}
			} else if uo.Flags.Get(jsonflags.ParseTimeWithLooseRFC3339) {
				// Loose parsing is only considered when no format applies.
				u.looseRFC3339 = true
			}

			// stringify records whether the value is expected as a JSON string
			// (true) or as a JSON number (false); a mismatch is an error below.
			stringify := !u.isNumeric() || xd.Tokens.Last.NeedObjectName() || uo.Flags.Get(jsonflags.StringifyNumbers)
			var flags jsonwire.ValueFlags
			tt, _ := reflect.TypeAssert[*time.Time](va.Addr())
			val, err := xd.ReadValue(&flags)
			if err != nil {
				return err
			}
			switch k := val.Kind(); k {
			case 'n':
				// A JSON null zeroes the destination,
				// unless MergeWithLegacySemantics is set.
				if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
					*tt = time.Time{}
				}
				return nil
			case '"':
				if !stringify {
					break
				}
				val = jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
				if err := u.unmarshal(val); err != nil {
					if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
						return err // unlike marshal, never wrapped
					}
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*tt = u.tt
				return nil
			case '0':
				if stringify {
					break
				}
				if err := u.unmarshal(val); err != nil {
					if uo.Flags.Get(jsonflags.ReportErrorsWithLegacySemantics) {
						return err // unlike marshal, never wrapped
					}
					return newUnmarshalErrorAfter(dec, t, err)
				}
				*tt = u.tt
				return nil
			}
			// Reached for any other JSON kind, or when a string/number
			// was seen where the other one was expected.
			return newUnmarshalErrorAfter(dec, t, nil)
		}
	}
	return fncs
}
   199  
// durationArshaler holds a duration value together with the representation
// used to format or parse it. The zero value uses time.Duration.String.
type durationArshaler struct {
	// td is the value to marshal, or the result of the last unmarshal.
	td time.Duration

	// base records the representation where:
	//   - 0 uses time.Duration.String
	//   - 1e0, 1e3, 1e6, or 1e9 use a decimal encoding of the duration as
	//     nanoseconds, microseconds, milliseconds, or seconds.
	//   - 8601 uses ISO 8601
	base uint64
}
   210  
   211  func (a *durationArshaler) initFormat(format string) (ok bool) {
   212  	switch format {
   213  	case "units":
   214  		a.base = 0
   215  	case "sec":
   216  		a.base = 1e9
   217  	case "milli":
   218  		a.base = 1e6
   219  	case "micro":
   220  		a.base = 1e3
   221  	case "nano":
   222  		a.base = 1e0
   223  	case "iso8601":
   224  		a.base = 8601
   225  	default:
   226  		return false
   227  	}
   228  	return true
   229  }
   230  
   231  func (a *durationArshaler) isNumeric() bool {
   232  	return a.base != 0 && a.base != 8601
   233  }
   234  
   235  func (a *durationArshaler) appendMarshal(b []byte) ([]byte, error) {
   236  	switch a.base {
   237  	case 0:
   238  		return append(b, a.td.String()...), nil
   239  	case 8601:
   240  		return appendDurationISO8601(b, a.td), nil
   241  	default:
   242  		return appendDurationBase10(b, a.td, a.base), nil
   243  	}
   244  }
   245  
   246  func (a *durationArshaler) unmarshal(b []byte) (err error) {
   247  	switch a.base {
   248  	case 0:
   249  		a.td, err = time.ParseDuration(string(b))
   250  	case 8601:
   251  		a.td, err = parseDurationISO8601(b)
   252  	default:
   253  		a.td, err = parseDurationBase10(b, a.base)
   254  	}
   255  	return err
   256  }
   257  
// timeArshaler holds a timestamp together with the representation
// used to format or parse it. The zero value uses RFC 3339.
type timeArshaler struct {
	// tt is the value to marshal, or the result of the last unmarshal.
	tt time.Time

	// base records the representation where:
	//   - 0 uses RFC 3339 encoding of the timestamp
	//   - 1e0, 1e3, 1e6, or 1e9 use a decimal encoding of the timestamp as
	//     seconds, milliseconds, microseconds, or nanoseconds since Unix epoch.
	//   - math.MaxUint uses time.Time.Format to encode the timestamp
	base   uint64
	format string // time format passed to time.Parse

	// looseRFC3339 disables the extra RFC 3339 strictness checks
	// that unmarshal otherwise applies when base is 0.
	looseRFC3339 bool
}
   271  
   272  func (a *timeArshaler) initFormat(format string) bool {
   273  	// We assume that an exported constant in the time package will
   274  	// always start with an uppercase ASCII letter.
   275  	if len(format) == 0 {
   276  		return false
   277  	}
   278  	a.base = math.MaxUint // implies custom format
   279  	if c := format[0]; !('a' <= c && c <= 'z') && !('A' <= c && c <= 'Z') {
   280  		a.format = format
   281  		return true
   282  	}
   283  	switch format {
   284  	case "ANSIC":
   285  		a.format = time.ANSIC
   286  	case "UnixDate":
   287  		a.format = time.UnixDate
   288  	case "RubyDate":
   289  		a.format = time.RubyDate
   290  	case "RFC822":
   291  		a.format = time.RFC822
   292  	case "RFC822Z":
   293  		a.format = time.RFC822Z
   294  	case "RFC850":
   295  		a.format = time.RFC850
   296  	case "RFC1123":
   297  		a.format = time.RFC1123
   298  	case "RFC1123Z":
   299  		a.format = time.RFC1123Z
   300  	case "RFC3339":
   301  		a.base = 0
   302  		a.format = time.RFC3339
   303  	case "RFC3339Nano":
   304  		a.base = 0
   305  		a.format = time.RFC3339Nano
   306  	case "Kitchen":
   307  		a.format = time.Kitchen
   308  	case "Stamp":
   309  		a.format = time.Stamp
   310  	case "StampMilli":
   311  		a.format = time.StampMilli
   312  	case "StampMicro":
   313  		a.format = time.StampMicro
   314  	case "StampNano":
   315  		a.format = time.StampNano
   316  	case "DateTime":
   317  		a.format = time.DateTime
   318  	case "DateOnly":
   319  		a.format = time.DateOnly
   320  	case "TimeOnly":
   321  		a.format = time.TimeOnly
   322  	case "unix":
   323  		a.base = 1e0
   324  	case "unixmilli":
   325  		a.base = 1e3
   326  	case "unixmicro":
   327  		a.base = 1e6
   328  	case "unixnano":
   329  		a.base = 1e9
   330  	default:
   331  		// Reject any Go identifier in case new constants are supported.
   332  		if strings.TrimFunc(format, isLetterOrDigit) == "" {
   333  			return false
   334  		}
   335  		a.format = format
   336  	}
   337  	return true
   338  }
   339  
   340  func (a *timeArshaler) isNumeric() bool {
   341  	return int(a.base) > 0
   342  }
   343  
   344  func (a *timeArshaler) hasCustomFormat() bool {
   345  	return a.base == math.MaxUint
   346  }
   347  
// appendMarshal appends the textual form of a.tt to b according to a.base:
// RFC 3339 for base 0, the layout in a.format for math.MaxUint,
// and a decimal Unix timestamp for any other base.
//
// Only the RFC 3339 path can fail: it reports an error when the formatted
// year is not exactly four digits wide or when a numeric timezone offset
// has an hour outside of [0,23].
func (a *timeArshaler) appendMarshal(b []byte) ([]byte, error) {
	switch a.base {
	case 0:
		// a.format is empty when no format was specified at all.
		format := cmp.Or(a.format, time.RFC3339Nano)
		n0 := len(b)
		b = a.tt.AppendFormat(b, format)
		// Not all Go timestamps can be represented as valid RFC 3339.
		// Explicitly check for these edge cases.
		// See https://go.dev/issue/4556 and https://go.dev/issue/54580.
		//
		// Within this switch, b is shadowed by the newly appended bytes only,
		// so that is also what the error returns below hand back.
		switch b := b[n0:]; {
		case b[len("9999")] != '-': // year must be exactly 4 digits wide
			return b, errors.New("year outside of range [0,9999]")
		case b[len(b)-1] != 'Z':
			// c is the byte where the sign of a "±hh:mm" offset is expected.
			c := b[len(b)-len("Z07:00")]
			if ('0' <= c && c <= '9') || parseDec2(b[len(b)-len("07:00"):]) >= 24 {
				return b, errors.New("timezone hour outside of range [0,23]")
			}
		}
		return b, nil
	case math.MaxUint:
		return a.tt.AppendFormat(b, a.format), nil
	default:
		return appendTimeUnix(b, a.tt, a.base), nil
	}
}
   373  
// unmarshal parses b into a.tt according to a.base:
// RFC 3339 for base 0, the layout in a.format for math.MaxUint,
// and a decimal Unix timestamp for any other base.
//
// For RFC 3339, additional strictness checks are applied after parsing
// unless a.looseRFC3339 is set; a failed check is reported as a
// *time.ParseError, although a.tt has already been assigned by then.
func (a *timeArshaler) unmarshal(b []byte) (err error) {
	switch a.base {
	case 0:
		// Use time.Time.UnmarshalText to avoid possible string allocation.
		if err := a.tt.UnmarshalText(b); err != nil {
			return err
		}
		// TODO(https://go.dev/issue/57912):
		// RFC 3339 specifies the grammar for a valid timestamp.
		// However, the parsing functionality in "time" is too loose and
		// incorrectly accepts invalid timestamps as valid.
		// Remove these manual checks when "time" checks it for us.
		newParseError := func(layout, value, layoutElem, valueElem, message string) error {
			return &time.ParseError{Layout: layout, Value: value, LayoutElem: layoutElem, ValueElem: valueElem, Message: message}
		}
		// The fixed offsets below index into b without length checks;
		// they assume that a successful UnmarshalText implies b is
		// at least as long as the referenced layout prefixes.
		switch {
		case a.looseRFC3339:
			return nil
		case b[len("2006-01-02T")+1] == ':': // hour must be two digits
			return newParseError(time.RFC3339, string(b), "15", string(b[len("2006-01-02T"):][:1]), "")
		case b[len("2006-01-02T15:04:05")] == ',': // sub-second separator must be a period
			return newParseError(time.RFC3339, string(b), ".", ",", "")
		case b[len(b)-1] != 'Z':
			switch {
			case parseDec2(b[len(b)-len("07:00"):]) >= 24: // timezone hour must be in range
				return newParseError(time.RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone hour out of range")
			case parseDec2(b[len(b)-len("00"):]) >= 60: // timezone minute must be in range
				return newParseError(time.RFC3339, string(b), "Z07:00", string(b[len(b)-len("Z07:00"):]), ": timezone minute out of range")
			}
		}
		return nil
	case math.MaxUint:
		a.tt, err = time.Parse(a.format, string(b))
		return err
	default:
		a.tt, err = parseTimeUnix(b, a.base)
		return err
	}
}
   413  
   414  // appendDurationBase10 appends d formatted as a decimal fractional number,
   415  // where pow10 is a power-of-10 used to scale down the number.
   416  func appendDurationBase10(b []byte, d time.Duration, pow10 uint64) []byte {
   417  	b, n := mayAppendDurationSign(b, d)            // append sign
   418  	whole, frac := bits.Div64(0, n, uint64(pow10)) // compute whole and frac fields
   419  	b = strconv.AppendUint(b, whole, 10)           // append whole field
   420  	return appendFracBase10(b, frac, pow10)        // append frac field
   421  }
   422  
   423  // parseDurationBase10 parses d from a decimal fractional number,
   424  // where pow10 is a power-of-10 used to scale up the number.
   425  func parseDurationBase10(b []byte, pow10 uint64) (time.Duration, error) {
   426  	suffix, neg := consumeSign(b, false)                     // consume sign
   427  	wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
   428  	whole, okWhole := jsonwire.ParseUint(wholeBytes)         // parse whole field; may overflow
   429  	frac, okFrac := parseFracBase10(fracBytes, pow10)        // parse frac field
   430  	hi, lo := bits.Mul64(whole, uint64(pow10))               // overflow if hi > 0
   431  	sum, co := bits.Add64(lo, uint64(frac), 0)               // overflow if co > 0
   432  	switch d := mayApplyDurationSign(sum, neg); {            // overflow if neg != (d < 0)
   433  	case (!okWhole && whole != math.MaxUint64) || !okFrac:
   434  		return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrSyntax)
   435  	case !okWhole || hi > 0 || co > 0 || neg != (d < 0):
   436  		return 0, fmt.Errorf("invalid duration %q: %w", b, strconv.ErrRange)
   437  	default:
   438  		return d, nil
   439  	}
   440  }
   441  
   442  // appendDurationISO8601 appends an ISO 8601 duration with a restricted grammar,
   443  // where leading and trailing zeroes and zero-value designators are omitted.
   444  // It only uses hour, minute, and second designators since ISO 8601 defines
   445  // those as being "accurate", while year, month, week, and day are "nominal".
   446  func appendDurationISO8601(b []byte, d time.Duration) []byte {
   447  	if d == 0 {
   448  		return append(b, "PT0S"...)
   449  	}
   450  	b, n := mayAppendDurationSign(b, d)
   451  	b = append(b, "PT"...)
   452  	n, nsec := bits.Div64(0, n, 1e9)  // compute nsec field
   453  	n, sec := bits.Div64(0, n, 60)    // compute sec field
   454  	hour, min := bits.Div64(0, n, 60) // compute hour and min fields
   455  	if hour > 0 {
   456  		b = append(strconv.AppendUint(b, hour, 10), 'H')
   457  	}
   458  	if min > 0 {
   459  		b = append(strconv.AppendUint(b, min, 10), 'M')
   460  	}
   461  	if sec > 0 || nsec > 0 {
   462  		b = append(appendFracBase10(strconv.AppendUint(b, sec, 10), nsec, 1e9), 'S')
   463  	}
   464  	return b
   465  }
   466  
// daysPerYear is the exact average number of days in a year according to
// the Gregorian calendar, which has an extra day each year that is
// a multiple of 4, unless it is evenly divisible by 100 but not by 400.
// This does not take into account leap seconds, which are not deterministic.
const daysPerYear = 365.2425

// errInaccurateDateUnits is reported by parseDurationISO8601, alongside a
// best-effort value, when the input uses year, month, week, or day units.
var errInaccurateDateUnits = errors.New("inaccurate year, month, week, or day units")
   474  
   475  // parseDurationISO8601 parses a duration according to ISO 8601-1:2019,
   476  // section 5.5.2.2 and 5.5.2.3 with the following restrictions or extensions:
   477  //
   478  //   - A leading minus sign is permitted for negative duration according
   479  //     to ISO 8601-2:2019, section 4.4.1.9. We do not permit negative values
   480  //     for each "time scale component", which is permitted by section 4.4.1.1,
   481  //     but rarely supported by parsers.
   482  //
   483  //   - A leading plus sign is permitted (and ignored).
   484  //     This is not required by ISO 8601, but not forbidden either.
   485  //     There is some precedent for this as it is supported by the principle of
   486  //     duration arithmetic as specified in ISO 8601-2-2019, section 14.1.
   487  //     Of note, the JavaScript grammar for ISO 8601 permits a leading plus sign.
   488  //
   489  //   - A fractional value is only permitted for accurate units
   490  //     (i.e., hour, minute, and seconds) in the last time component,
   491  //     which is permissible by ISO 8601-1:2019, section 5.5.2.3.
   492  //
   493  //   - Both periods ('.') and commas (',') are supported as the separator
   494  //     between the integer part and fraction part of a number,
   495  //     as specified in ISO 8601-1:2019, section 3.2.6.
   496  //     While ISO 8601 recommends comma as the default separator,
   497  //     most formatters uses a period.
   498  //
   499  //   - Leading zeros are ignored. This is not required by ISO 8601,
   500  //     but also not forbidden by the standard. Many parsers support this.
   501  //
   502  //   - Lowercase designators are supported. This is not required by ISO 8601,
   503  //     but also not forbidden by the standard. Many parsers support this.
   504  //
   505  // If the nominal units of year, month, week, or day are present,
   506  // this produces a best-effort value and also reports [errInaccurateDateUnits].
   507  //
   508  // The accepted grammar is identical to JavaScript's Duration:
   509  //
   510  //	https://tc39.es/proposal-temporal/#prod-Duration
   511  //
   512  // We follow JavaScript's grammar as JSON itself is derived from JavaScript.
   513  // The Temporal.Duration.toJSON method is guaranteed to produce an output
   514  // that can be parsed by this function so long as arithmetic in JavaScript
   515  // do not use a largestUnit value higher than "hours" (which is the default).
   516  // Even if it does, this will do a best-effort parsing with inaccurate units,
   517  // but report [errInaccurateDateUnits].
   518  func parseDurationISO8601(b []byte) (time.Duration, error) {
   519  	var invalid, overflow, inaccurate, sawFrac bool
   520  	var sumNanos, n, co uint64
   521  
   522  	// cutBytes is like [bytes.Cut], but uses either c0 or c1 as the separator.
   523  	cutBytes := func(b []byte, c0, c1 byte) (prefix, suffix []byte, ok bool) {
   524  		for i, c := range b {
   525  			if c == c0 || c == c1 {
   526  				return b[:i], b[i+1:], true
   527  			}
   528  		}
   529  		return b, nil, false
   530  	}
   531  
   532  	// mayParseUnit attempts to parse another date or time number
   533  	// identified by the desHi and desLo unit characters.
   534  	// If the part is absent for current unit, it returns b as is.
   535  	mayParseUnit := func(b []byte, desHi, desLo byte, unit time.Duration) []byte {
   536  		number, suffix, ok := cutBytes(b, desHi, desLo)
   537  		if !ok || sawFrac {
   538  			return b // designator is not present or already saw fraction, which can only be in the last component
   539  		}
   540  
   541  		// Parse the number.
   542  		// A fraction allowed for the accurate units in the last part.
   543  		whole, frac, ok := cutBytes(number, '.', ',')
   544  		if ok {
   545  			sawFrac = true
   546  			invalid = invalid || len(frac) == len("") || unit > time.Hour
   547  			if unit == time.Second {
   548  				n, ok = parsePaddedBase10(frac, uint64(time.Second))
   549  				invalid = invalid || !ok
   550  			} else {
   551  				f, err := strconv.ParseFloat("0."+string(frac), 64)
   552  				invalid = invalid || err != nil || len(bytes.Trim(frac[len("."):], "0123456789")) > 0
   553  				n = uint64(math.Round(f * float64(unit))) // never overflows since f is within [0..1]
   554  			}
   555  			sumNanos, co = bits.Add64(sumNanos, n, 0) // overflow if co > 0
   556  			overflow = overflow || co > 0
   557  		}
   558  		for len(whole) > 1 && whole[0] == '0' {
   559  			whole = whole[len("0"):] // trim leading zeros
   560  		}
   561  		n, ok := jsonwire.ParseUint(whole)         // overflow if !ok && MaxUint64
   562  		hi, lo := bits.Mul64(n, uint64(unit))      // overflow if hi > 0
   563  		sumNanos, co = bits.Add64(sumNanos, lo, 0) // overflow if co > 0
   564  		invalid = invalid || (!ok && n != math.MaxUint64)
   565  		overflow = overflow || (!ok && n == math.MaxUint64) || hi > 0 || co > 0
   566  		inaccurate = inaccurate || unit > time.Hour
   567  		return suffix
   568  	}
   569  
   570  	suffix, neg := consumeSign(b, true)
   571  	prefix, suffix, okP := cutBytes(suffix, 'P', 'p')
   572  	durDate, durTime, okT := cutBytes(suffix, 'T', 't')
   573  	invalid = invalid || len(prefix) > 0 || !okP || (okT && len(durTime) == 0) || len(durDate)+len(durTime) == 0
   574  	if len(durDate) > 0 { // nominal portion of the duration
   575  		durDate = mayParseUnit(durDate, 'Y', 'y', time.Duration(daysPerYear*24*60*60*1e9))
   576  		durDate = mayParseUnit(durDate, 'M', 'm', time.Duration(daysPerYear/12*24*60*60*1e9))
   577  		durDate = mayParseUnit(durDate, 'W', 'w', time.Duration(7*24*60*60*1e9))
   578  		durDate = mayParseUnit(durDate, 'D', 'd', time.Duration(24*60*60*1e9))
   579  		invalid = invalid || len(durDate) > 0 // unknown elements
   580  	}
   581  	if len(durTime) > 0 { // accurate portion of the duration
   582  		durTime = mayParseUnit(durTime, 'H', 'h', time.Duration(60*60*1e9))
   583  		durTime = mayParseUnit(durTime, 'M', 'm', time.Duration(60*1e9))
   584  		durTime = mayParseUnit(durTime, 'S', 's', time.Duration(1e9))
   585  		invalid = invalid || len(durTime) > 0 // unknown elements
   586  	}
   587  	d := mayApplyDurationSign(sumNanos, neg)
   588  	overflow = overflow || (neg != (d < 0) && d != 0) // overflows signed duration
   589  
   590  	switch {
   591  	case invalid:
   592  		return 0, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, strconv.ErrSyntax)
   593  	case overflow:
   594  		return 0, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, strconv.ErrRange)
   595  	case inaccurate:
   596  		return d, fmt.Errorf("invalid ISO 8601 duration %q: %w", b, errInaccurateDateUnits)
   597  	default:
   598  		return d, nil
   599  	}
   600  }
   601  
   602  // mayAppendDurationSign appends a negative sign if n is negative.
   603  func mayAppendDurationSign(b []byte, d time.Duration) ([]byte, uint64) {
   604  	if d < 0 {
   605  		b = append(b, '-')
   606  		d *= -1
   607  	}
   608  	return b, uint64(d)
   609  }
   610  
   611  // mayApplyDurationSign inverts n if neg is specified.
   612  func mayApplyDurationSign(n uint64, neg bool) time.Duration {
   613  	if neg {
   614  		return -1 * time.Duration(n)
   615  	} else {
   616  		return +1 * time.Duration(n)
   617  	}
   618  }
   619  
   620  // appendTimeUnix appends t formatted as a decimal fractional number,
   621  // where pow10 is a power-of-10 used to scale up the number.
   622  func appendTimeUnix(b []byte, t time.Time, pow10 uint64) []byte {
   623  	sec, nsec := t.Unix(), int64(t.Nanosecond())
   624  	if sec < 0 {
   625  		b = append(b, '-')
   626  		sec, nsec = negateSecNano(sec, nsec)
   627  	}
   628  	switch {
   629  	case pow10 == 1e0: // fast case where units is in seconds
   630  		b = strconv.AppendUint(b, uint64(sec), 10)
   631  		return appendFracBase10(b, uint64(nsec), 1e9)
   632  	case uint64(sec) < 1e9: // intermediate case where units is not seconds, but no overflow
   633  		b = strconv.AppendUint(b, uint64(sec)*uint64(pow10)+uint64(uint64(nsec)/(1e9/pow10)), 10)
   634  		return appendFracBase10(b, (uint64(nsec)*pow10)%1e9, 1e9)
   635  	default: // slow case where units is not seconds and overflow would occur
   636  		b = strconv.AppendUint(b, uint64(sec), 10)
   637  		b = appendPaddedBase10(b, uint64(nsec)/(1e9/pow10), pow10)
   638  		return appendFracBase10(b, (uint64(nsec)*pow10)%1e9, 1e9)
   639  	}
   640  }
   641  
// parseTimeUnix parses t formatted as a decimal fractional number,
// where pow10 is a power-of-10 used to scale down the number.
//
// Malformed input is reported as strconv.ErrSyntax and an out-of-range
// value as strconv.ErrRange; the zero time is returned in both cases.
// The result is in UTC.
func parseTimeUnix(b []byte, pow10 uint64) (time.Time, error) {
	suffix, neg := consumeSign(b, false)                     // consume sign
	wholeBytes, fracBytes := bytesCutByte(suffix, '.', true) // consume whole and frac fields
	whole, okWhole := jsonwire.ParseUint(wholeBytes)         // parse whole field; may overflow
	frac, okFrac := parseFracBase10(fracBytes, 1e9/pow10)    // parse frac field
	var sec, nsec int64
	// If none of the cases below match (i.e., the whole field is malformed),
	// sec and nsec stay zero and the error is reported by the final switch.
	switch {
	case pow10 == 1e0: // fast case where units is in seconds
		sec = int64(whole) // check overflow later after negation
		nsec = int64(frac) // cannot overflow
	case okWhole: // intermediate case where units is not seconds, but no overflow
		sec = int64(whole / pow10)                     // check overflow later after negation
		nsec = int64((whole%pow10)*(1e9/pow10) + frac) // cannot overflow
	case !okWhole && whole == math.MaxUint64: // slow case where units is not seconds and overflow occurred
		// Re-parse the whole field as two pieces. This overwrites whole and
		// okWhole, which the final switch uses to classify the error.
		width := int(math.Log10(float64(pow10)))                                // compute len(strconv.Itoa(pow10-1))
		whole, okWhole = jsonwire.ParseUint(wholeBytes[:len(wholeBytes)-width]) // parse the upper whole field
		mid, _ := parsePaddedBase10(wholeBytes[len(wholeBytes)-width:], pow10)  // parse the lower whole field
		sec = int64(whole)                                                      // check overflow later after negation
		nsec = int64(mid*(1e9/pow10) + frac)                                    // cannot overflow
	}
	if neg {
		sec, nsec = negateSecNano(sec, nsec)
	}
	// A sign mismatch between the input and the result indicates overflow.
	// As written, this also reports a negative-signed zero as a range error.
	switch t := time.Unix(sec, nsec).UTC(); {
	case (!okWhole && whole != math.MaxUint64) || !okFrac:
		return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrSyntax)
	case !okWhole || neg != (t.Unix() < 0):
		return time.Time{}, fmt.Errorf("invalid time %q: %w", b, strconv.ErrRange)
	default:
		return t, nil
	}
}
   676  
   677  // negateSecNano negates a Unix timestamp, where nsec must be within [0, 1e9).
   678  func negateSecNano(sec, nsec int64) (int64, int64) {
   679  	sec = ^sec               // twos-complement negation (i.e., -1*sec + 1)
   680  	nsec = -nsec + 1e9       // negate nsec and add 1e9 (which is the extra +1 from sec negation)
   681  	sec += int64(nsec / 1e9) // handle possible overflow of nsec if it started as zero
   682  	nsec %= 1e9              // ensure nsec stays within [0, 1e9)
   683  	return sec, nsec
   684  }
   685  
   686  // appendFracBase10 appends the fraction of n/max10,
   687  // where max10 is a power-of-10 that is larger than n.
   688  func appendFracBase10(b []byte, n, max10 uint64) []byte {
   689  	if n == 0 {
   690  		return b
   691  	}
   692  	return bytes.TrimRight(appendPaddedBase10(append(b, '.'), n, max10), "0")
   693  }
   694  
   695  // parseFracBase10 parses the fraction of n/max10,
   696  // where max10 is a power-of-10 that is larger than n.
   697  func parseFracBase10(b []byte, max10 uint64) (n uint64, ok bool) {
   698  	switch {
   699  	case len(b) == 0:
   700  		return 0, true
   701  	case len(b) < len(".0") || b[0] != '.':
   702  		return 0, false
   703  	}
   704  	return parsePaddedBase10(b[len("."):], max10)
   705  }
   706  
   707  // appendPaddedBase10 appends a zero-padded encoding of n,
   708  // where max10 is a power-of-10 that is larger than n.
   709  func appendPaddedBase10(b []byte, n, max10 uint64) []byte {
   710  	if n < max10/10 {
   711  		// Formatting of n is shorter than log10(max10),
   712  		// so add max10/10 to ensure the length is equal to log10(max10).
   713  		i := len(b)
   714  		b = strconv.AppendUint(b, n+max10/10, 10)
   715  		b[i]-- // subtract the addition of max10/10
   716  		return b
   717  	}
   718  	return strconv.AppendUint(b, n, 10)
   719  }
   720  
   721  // parsePaddedBase10 parses b as the zero-padded encoding of n,
   722  // where max10 is a power-of-10 that is larger than n.
   723  // Truncated suffix is treated as implicit zeros.
   724  // Extended suffix is ignored, but verified to contain only digits.
   725  func parsePaddedBase10(b []byte, max10 uint64) (n uint64, ok bool) {
   726  	pow10 := uint64(1)
   727  	for pow10 < max10 {
   728  		n *= 10
   729  		if len(b) > 0 {
   730  			if b[0] < '0' || '9' < b[0] {
   731  				return n, false
   732  			}
   733  			n += uint64(b[0] - '0')
   734  			b = b[1:]
   735  		}
   736  		pow10 *= 10
   737  	}
   738  	if len(b) > 0 && len(bytes.TrimRight(b, "0123456789")) > 0 {
   739  		return n, false // trailing characters are not digits
   740  	}
   741  	return n, true
   742  }
   743  
   744  // consumeSign consumes an optional leading negative or positive sign.
   745  func consumeSign(b []byte, allowPlus bool) ([]byte, bool) {
   746  	if len(b) > 0 {
   747  		if b[0] == '-' {
   748  			return b[len("-"):], true
   749  		} else if b[0] == '+' && allowPlus {
   750  			return b[len("+"):], false
   751  		}
   752  	}
   753  	return b, false
   754  }
   755  
   756  // bytesCutByte is similar to bytes.Cut(b, []byte{c}),
   757  // except c may optionally be included as part of the suffix.
   758  func bytesCutByte(b []byte, c byte, include bool) ([]byte, []byte) {
   759  	if i := bytes.IndexByte(b, c); i >= 0 {
   760  		if include {
   761  			return b[:i], b[i:]
   762  		}
   763  		return b[:i], b[i+1:]
   764  	}
   765  	return b, nil
   766  }
   767  
   768  // parseDec2 parses b as an unsigned, base-10, 2-digit number.
   769  // The result is undefined if digits are not base-10.
   770  func parseDec2(b []byte) byte {
   771  	if len(b) < 2 {
   772  		return 0
   773  	}
   774  	return 10*(b[0]-'0') + (b[1] - '0')
   775  }
   776  

View as plain text