...

Source file src/github.com/pelletier/go-toml/v2/unmarshaler.go

Documentation: github.com/pelletier/go-toml/v2

     1  package toml
     2  
     3  import (
     4  	"encoding"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"math"
    10  	"reflect"
    11  	"strings"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/pelletier/go-toml/v2/internal/danger"
    16  	"github.com/pelletier/go-toml/v2/internal/tracker"
    17  	"github.com/pelletier/go-toml/v2/unstable"
    18  )
    19  
    20  // Unmarshal deserializes a TOML document into a Go value.
    21  //
    22  // It is a shortcut for Decoder.Decode() with the default options.
    23  func Unmarshal(data []byte, v interface{}) error {
    24  	p := unstable.Parser{}
    25  	p.Reset(data)
    26  	d := decoder{p: &p}
    27  
    28  	return d.FromParser(v)
    29  }
    30  
    31  // Decoder reads and decode a TOML document from an input stream.
    32  type Decoder struct {
    33  	// input
    34  	r io.Reader
    35  
    36  	// global settings
    37  	strict bool
    38  }
    39  
    40  // NewDecoder creates a new Decoder that will read from r.
    41  func NewDecoder(r io.Reader) *Decoder {
    42  	return &Decoder{r: r}
    43  }
    44  
    45  // DisallowUnknownFields causes the Decoder to return an error when the
    46  // destination is a struct and the input contains a key that does not match a
    47  // non-ignored field.
    48  //
    49  // In that case, the Decoder returns a StrictMissingError that can be used to
    50  // retrieve the individual errors as well as generate a human readable
    51  // description of the missing fields.
    52  func (d *Decoder) DisallowUnknownFields() *Decoder {
    53  	d.strict = true
    54  	return d
    55  }
    56  
    57  // Decode the whole content of r into v.
    58  //
    59  // By default, values in the document that don't exist in the target Go value
    60  // are ignored. See Decoder.DisallowUnknownFields() to change this behavior.
    61  //
    62  // When a TOML local date, time, or date-time is decoded into a time.Time, its
    63  // value is represented in time.Local timezone. Otherwise the appropriate Local*
    64  // structure is used. For time values, precision up to the nanosecond is
    65  // supported by truncating extra digits.
    66  //
    67  // Empty tables decoded in an interface{} create an empty initialized
    68  // map[string]interface{}.
    69  //
    70  // Types implementing the encoding.TextUnmarshaler interface are decoded from a
    71  // TOML string.
    72  //
    73  // When decoding a number, go-toml will return an error if the number is out of
    74  // bounds for the target type (which includes negative numbers when decoding
    75  // into an unsigned int).
    76  //
    77  // If an error occurs while decoding the content of the document, this function
    78  // returns a toml.DecodeError, providing context about the issue. When using
    79  // strict mode and a field is missing, a `toml.StrictMissingError` is
    80  // returned. In any other case, this function returns a standard Go error.
    81  //
    82  // # Type mapping
    83  //
    84  // List of supported TOML types and their associated accepted Go types:
    85  //
    86  //	String           -> string
    87  //	Integer          -> uint*, int*, depending on size
    88  //	Float            -> float*, depending on size
    89  //	Boolean          -> bool
    90  //	Offset Date-Time -> time.Time
    91  //	Local Date-time  -> LocalDateTime, time.Time
    92  //	Local Date       -> LocalDate, time.Time
    93  //	Local Time       -> LocalTime, time.Time
    94  //	Array            -> slice and array, depending on elements types
    95  //	Table            -> map and struct
    96  //	Inline Table     -> same as Table
    97  //	Array of Tables  -> same as Array and Table
    98  func (d *Decoder) Decode(v interface{}) error {
    99  	b, err := ioutil.ReadAll(d.r)
   100  	if err != nil {
   101  		return fmt.Errorf("toml: %w", err)
   102  	}
   103  
   104  	p := unstable.Parser{}
   105  	p.Reset(b)
   106  	dec := decoder{
   107  		p: &p,
   108  		strict: strict{
   109  			Enabled: d.strict,
   110  		},
   111  	}
   112  
   113  	return dec.FromParser(v)
   114  }
   115  
   116  type decoder struct {
   117  	// Which parser instance in use for this decoding session.
   118  	p *unstable.Parser
   119  
   120  	// Flag indicating that the current expression is stashed.
   121  	// If set to true, calling nextExpr will not actually pull a new expression
   122  	// but turn off the flag instead.
   123  	stashedExpr bool
   124  
   125  	// Skip expressions until a table is found. This is set to true when a
   126  	// table could not be created (missing field in map), so all KV expressions
   127  	// need to be skipped.
   128  	skipUntilTable bool
   129  
   130  	// Tracks position in Go arrays.
   131  	// This is used when decoding [[array tables]] into Go arrays. Given array
   132  	// tables are separate TOML expression, we need to keep track of where we
   133  	// are at in the Go array, as we can't just introspect its size.
   134  	arrayIndexes map[reflect.Value]int
   135  
   136  	// Tracks keys that have been seen, with which type.
   137  	seen tracker.SeenTracker
   138  
   139  	// Strict mode
   140  	strict strict
   141  
   142  	// Current context for the error.
   143  	errorContext *errorContext
   144  }
   145  
   146  type errorContext struct {
   147  	Struct reflect.Type
   148  	Field  []int
   149  }
   150  
   151  func (d *decoder) typeMismatchError(toml string, target reflect.Type) error {
   152  	return fmt.Errorf("toml: %s", d.typeMismatchString(toml, target))
   153  }
   154  
   155  func (d *decoder) typeMismatchString(toml string, target reflect.Type) string {
   156  	if d.errorContext != nil && d.errorContext.Struct != nil {
   157  		ctx := d.errorContext
   158  		f := ctx.Struct.FieldByIndex(ctx.Field)
   159  		return fmt.Sprintf("cannot decode TOML %s into struct field %s.%s of type %s", toml, ctx.Struct, f.Name, f.Type)
   160  	}
   161  	return fmt.Sprintf("cannot decode TOML %s into a Go value of type %s", toml, target)
   162  }
   163  
   164  func (d *decoder) expr() *unstable.Node {
   165  	return d.p.Expression()
   166  }
   167  
   168  func (d *decoder) nextExpr() bool {
   169  	if d.stashedExpr {
   170  		d.stashedExpr = false
   171  		return true
   172  	}
   173  	return d.p.NextExpression()
   174  }
   175  
   176  func (d *decoder) stashExpr() {
   177  	d.stashedExpr = true
   178  }
   179  
   180  func (d *decoder) arrayIndex(shouldAppend bool, v reflect.Value) int {
   181  	if d.arrayIndexes == nil {
   182  		d.arrayIndexes = make(map[reflect.Value]int, 1)
   183  	}
   184  
   185  	idx, ok := d.arrayIndexes[v]
   186  
   187  	if !ok {
   188  		d.arrayIndexes[v] = 0
   189  	} else if shouldAppend {
   190  		idx++
   191  		d.arrayIndexes[v] = idx
   192  	}
   193  
   194  	return idx
   195  }
   196  
   197  func (d *decoder) FromParser(v interface{}) error {
   198  	r := reflect.ValueOf(v)
   199  	if r.Kind() != reflect.Ptr {
   200  		return fmt.Errorf("toml: decoding can only be performed into a pointer, not %s", r.Kind())
   201  	}
   202  
   203  	if r.IsNil() {
   204  		return fmt.Errorf("toml: decoding pointer target cannot be nil")
   205  	}
   206  
   207  	r = r.Elem()
   208  	if r.Kind() == reflect.Interface && r.IsNil() {
   209  		newMap := map[string]interface{}{}
   210  		r.Set(reflect.ValueOf(newMap))
   211  	}
   212  
   213  	err := d.fromParser(r)
   214  	if err == nil {
   215  		return d.strict.Error(d.p.Data())
   216  	}
   217  
   218  	var e *unstable.ParserError
   219  	if errors.As(err, &e) {
   220  		return wrapDecodeError(d.p.Data(), e)
   221  	}
   222  
   223  	return err
   224  }
   225  
   226  func (d *decoder) fromParser(root reflect.Value) error {
   227  	for d.nextExpr() {
   228  		err := d.handleRootExpression(d.expr(), root)
   229  		if err != nil {
   230  			return err
   231  		}
   232  	}
   233  
   234  	return d.p.Error()
   235  }
   236  
   237  /*
   238  Rules for the unmarshal code:
   239  
   240  - The stack is used to keep track of which values need to be set where.
   241  - handle* functions <=> switch on a given unstable.Kind.
   242  - unmarshalX* functions need to unmarshal a node of kind X.
   243  - An "object" is either a struct or a map.
   244  */
   245  
   246  func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
   247  	var x reflect.Value
   248  	var err error
   249  
   250  	if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
   251  		err = d.seen.CheckExpression(expr)
   252  		if err != nil {
   253  			return err
   254  		}
   255  	}
   256  
   257  	switch expr.Kind {
   258  	case unstable.KeyValue:
   259  		if d.skipUntilTable {
   260  			return nil
   261  		}
   262  		x, err = d.handleKeyValue(expr, v)
   263  	case unstable.Table:
   264  		d.skipUntilTable = false
   265  		d.strict.EnterTable(expr)
   266  		x, err = d.handleTable(expr.Key(), v)
   267  	case unstable.ArrayTable:
   268  		d.skipUntilTable = false
   269  		d.strict.EnterArrayTable(expr)
   270  		x, err = d.handleArrayTable(expr.Key(), v)
   271  	default:
   272  		panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind))
   273  	}
   274  
   275  	if d.skipUntilTable {
   276  		if expr.Kind == unstable.Table || expr.Kind == unstable.ArrayTable {
   277  			d.strict.MissingTable(expr)
   278  		}
   279  	} else if err == nil && x.IsValid() {
   280  		v.Set(x)
   281  	}
   282  
   283  	return err
   284  }
   285  
   286  func (d *decoder) handleArrayTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   287  	if key.Next() {
   288  		return d.handleArrayTablePart(key, v)
   289  	}
   290  	return d.handleKeyValues(v)
   291  }
   292  
   293  func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   294  	switch v.Kind() {
   295  	case reflect.Interface:
   296  		elem := v.Elem()
   297  		if !elem.IsValid() {
   298  			elem = reflect.New(sliceInterfaceType).Elem()
   299  			elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
   300  		} else if elem.Kind() == reflect.Slice {
   301  			if elem.Type() != sliceInterfaceType {
   302  				elem = reflect.New(sliceInterfaceType).Elem()
   303  				elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
   304  			} else if !elem.CanSet() {
   305  				nelem := reflect.New(sliceInterfaceType).Elem()
   306  				nelem.Set(reflect.MakeSlice(sliceInterfaceType, elem.Len(), elem.Cap()))
   307  				reflect.Copy(nelem, elem)
   308  				elem = nelem
   309  			}
   310  		}
   311  		return d.handleArrayTableCollectionLast(key, elem)
   312  	case reflect.Ptr:
   313  		elem := v.Elem()
   314  		if !elem.IsValid() {
   315  			ptr := reflect.New(v.Type().Elem())
   316  			v.Set(ptr)
   317  			elem = ptr.Elem()
   318  		}
   319  
   320  		elem, err := d.handleArrayTableCollectionLast(key, elem)
   321  		if err != nil {
   322  			return reflect.Value{}, err
   323  		}
   324  		v.Elem().Set(elem)
   325  
   326  		return v, nil
   327  	case reflect.Slice:
   328  		elemType := v.Type().Elem()
   329  		var elem reflect.Value
   330  		if elemType.Kind() == reflect.Interface {
   331  			elem = makeMapStringInterface()
   332  		} else {
   333  			elem = reflect.New(elemType).Elem()
   334  		}
   335  		elem2, err := d.handleArrayTable(key, elem)
   336  		if err != nil {
   337  			return reflect.Value{}, err
   338  		}
   339  		if elem2.IsValid() {
   340  			elem = elem2
   341  		}
   342  		return reflect.Append(v, elem), nil
   343  	case reflect.Array:
   344  		idx := d.arrayIndex(true, v)
   345  		if idx >= v.Len() {
   346  			return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
   347  		}
   348  		elem := v.Index(idx)
   349  		_, err := d.handleArrayTable(key, elem)
   350  		return v, err
   351  	default:
   352  		return reflect.Value{}, d.typeMismatchError("array table", v.Type())
   353  	}
   354  }
   355  
   356  // When parsing an array table expression, each part of the key needs to be
   357  // evaluated like a normal key, but if it returns a collection, it also needs to
   358  // point to the last element of the collection. Unless it is the last part of
   359  // the key, then it needs to create a new element at the end.
   360  func (d *decoder) handleArrayTableCollection(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   361  	if key.IsLast() {
   362  		return d.handleArrayTableCollectionLast(key, v)
   363  	}
   364  
   365  	switch v.Kind() {
   366  	case reflect.Ptr:
   367  		elem := v.Elem()
   368  		if !elem.IsValid() {
   369  			ptr := reflect.New(v.Type().Elem())
   370  			v.Set(ptr)
   371  			elem = ptr.Elem()
   372  		}
   373  
   374  		elem, err := d.handleArrayTableCollection(key, elem)
   375  		if err != nil {
   376  			return reflect.Value{}, err
   377  		}
   378  		if elem.IsValid() {
   379  			v.Elem().Set(elem)
   380  		}
   381  
   382  		return v, nil
   383  	case reflect.Slice:
   384  		elem := v.Index(v.Len() - 1)
   385  		x, err := d.handleArrayTable(key, elem)
   386  		if err != nil || d.skipUntilTable {
   387  			return reflect.Value{}, err
   388  		}
   389  		if x.IsValid() {
   390  			elem.Set(x)
   391  		}
   392  
   393  		return v, err
   394  	case reflect.Array:
   395  		idx := d.arrayIndex(false, v)
   396  		if idx >= v.Len() {
   397  			return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
   398  		}
   399  		elem := v.Index(idx)
   400  		_, err := d.handleArrayTable(key, elem)
   401  		return v, err
   402  	}
   403  
   404  	return d.handleArrayTable(key, v)
   405  }
   406  
   407  func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) {
   408  	var rv reflect.Value
   409  
   410  	// First, dispatch over v to make sure it is a valid object.
   411  	// There is no guarantee over what it could be.
   412  	switch v.Kind() {
   413  	case reflect.Ptr:
   414  		elem := v.Elem()
   415  		if !elem.IsValid() {
   416  			v.Set(reflect.New(v.Type().Elem()))
   417  		}
   418  		elem = v.Elem()
   419  		return d.handleKeyPart(key, elem, nextFn, makeFn)
   420  	case reflect.Map:
   421  		vt := v.Type()
   422  
   423  		// Create the key for the map element. Convert to key type.
   424  		mk, err := d.keyFromData(vt.Key(), key.Node().Data)
   425  		if err != nil {
   426  			return reflect.Value{}, err
   427  		}
   428  
   429  		// If the map does not exist, create it.
   430  		if v.IsNil() {
   431  			vt := v.Type()
   432  			v = reflect.MakeMap(vt)
   433  			rv = v
   434  		}
   435  
   436  		mv := v.MapIndex(mk)
   437  		set := false
   438  		if !mv.IsValid() {
   439  			// If there is no value in the map, create a new one according to
   440  			// the map type. If the element type is interface, create either a
   441  			// map[string]interface{} or a []interface{} depending on whether
   442  			// this is the last part of the array table key.
   443  
   444  			t := vt.Elem()
   445  			if t.Kind() == reflect.Interface {
   446  				mv = makeFn()
   447  			} else {
   448  				mv = reflect.New(t).Elem()
   449  			}
   450  			set = true
   451  		} else if mv.Kind() == reflect.Interface {
   452  			mv = mv.Elem()
   453  			if !mv.IsValid() {
   454  				mv = makeFn()
   455  			}
   456  			set = true
   457  		} else if !mv.CanAddr() {
   458  			vt := v.Type()
   459  			t := vt.Elem()
   460  			oldmv := mv
   461  			mv = reflect.New(t).Elem()
   462  			mv.Set(oldmv)
   463  			set = true
   464  		}
   465  
   466  		x, err := nextFn(key, mv)
   467  		if err != nil {
   468  			return reflect.Value{}, err
   469  		}
   470  
   471  		if x.IsValid() {
   472  			mv = x
   473  			set = true
   474  		}
   475  
   476  		if set {
   477  			v.SetMapIndex(mk, mv)
   478  		}
   479  	case reflect.Struct:
   480  		path, found := structFieldPath(v, string(key.Node().Data))
   481  		if !found {
   482  			d.skipUntilTable = true
   483  			return reflect.Value{}, nil
   484  		}
   485  
   486  		if d.errorContext == nil {
   487  			d.errorContext = new(errorContext)
   488  		}
   489  		t := v.Type()
   490  		d.errorContext.Struct = t
   491  		d.errorContext.Field = path
   492  
   493  		f := fieldByIndex(v, path)
   494  		x, err := nextFn(key, f)
   495  		if err != nil || d.skipUntilTable {
   496  			return reflect.Value{}, err
   497  		}
   498  		if x.IsValid() {
   499  			f.Set(x)
   500  		}
   501  		d.errorContext.Field = nil
   502  		d.errorContext.Struct = nil
   503  	case reflect.Interface:
   504  		if v.Elem().IsValid() {
   505  			v = v.Elem()
   506  		} else {
   507  			v = makeMapStringInterface()
   508  		}
   509  
   510  		x, err := d.handleKeyPart(key, v, nextFn, makeFn)
   511  		if err != nil {
   512  			return reflect.Value{}, err
   513  		}
   514  		if x.IsValid() {
   515  			v = x
   516  		}
   517  		rv = v
   518  	default:
   519  		panic(fmt.Errorf("unhandled part: %s", v.Kind()))
   520  	}
   521  
   522  	return rv, nil
   523  }
   524  
   525  // HandleArrayTablePart navigates the Go structure v using the key v. It is
   526  // only used for the prefix (non-last) parts of an array-table. When
   527  // encountering a collection, it should go to the last element.
   528  func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   529  	var makeFn valueMakerFn
   530  	if key.IsLast() {
   531  		makeFn = makeSliceInterface
   532  	} else {
   533  		makeFn = makeMapStringInterface
   534  	}
   535  	return d.handleKeyPart(key, v, d.handleArrayTableCollection, makeFn)
   536  }
   537  
   538  // HandleTable returns a reference when it has checked the next expression but
   539  // cannot handle it.
   540  func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   541  	if v.Kind() == reflect.Slice {
   542  		if v.Len() == 0 {
   543  			return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
   544  		}
   545  		elem := v.Index(v.Len() - 1)
   546  		x, err := d.handleTable(key, elem)
   547  		if err != nil {
   548  			return reflect.Value{}, err
   549  		}
   550  		if x.IsValid() {
   551  			elem.Set(x)
   552  		}
   553  		return reflect.Value{}, nil
   554  	}
   555  	if key.Next() {
   556  		// Still scoping the key
   557  		return d.handleTablePart(key, v)
   558  	}
   559  	// Done scoping the key.
   560  	// Now handle all the key-value expressions in this table.
   561  	return d.handleKeyValues(v)
   562  }
   563  
   564  // Handle root expressions until the end of the document or the next
   565  // non-key-value.
   566  func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
   567  	var rv reflect.Value
   568  	for d.nextExpr() {
   569  		expr := d.expr()
   570  		if expr.Kind != unstable.KeyValue {
   571  			// Stash the expression so that fromParser can just loop and use
   572  			// the right handler.
   573  			// We could just recurse ourselves here, but at least this gives a
   574  			// chance to pop the stack a bit.
   575  			d.stashExpr()
   576  			break
   577  		}
   578  
   579  		err := d.seen.CheckExpression(expr)
   580  		if err != nil {
   581  			return reflect.Value{}, err
   582  		}
   583  
   584  		x, err := d.handleKeyValue(expr, v)
   585  		if err != nil {
   586  			return reflect.Value{}, err
   587  		}
   588  		if x.IsValid() {
   589  			v = x
   590  			rv = x
   591  		}
   592  	}
   593  	return rv, nil
   594  }
   595  
   596  type (
   597  	handlerFn    func(key unstable.Iterator, v reflect.Value) (reflect.Value, error)
   598  	valueMakerFn func() reflect.Value
   599  )
   600  
   601  func makeMapStringInterface() reflect.Value {
   602  	return reflect.MakeMap(mapStringInterfaceType)
   603  }
   604  
   605  func makeSliceInterface() reflect.Value {
   606  	return reflect.MakeSlice(sliceInterfaceType, 0, 16)
   607  }
   608  
   609  func (d *decoder) handleTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
   610  	return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface)
   611  }
   612  
   613  func (d *decoder) tryTextUnmarshaler(node *unstable.Node, v reflect.Value) (bool, error) {
   614  	// Special case for time, because we allow to unmarshal to it from
   615  	// different kind of AST nodes.
   616  	if v.Type() == timeType {
   617  		return false, nil
   618  	}
   619  
   620  	if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
   621  		err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
   622  		if err != nil {
   623  			return false, unstable.NewParserError(d.p.Raw(node.Raw), "%w", err)
   624  		}
   625  
   626  		return true, nil
   627  	}
   628  
   629  	return false, nil
   630  }
   631  
   632  func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
   633  	for v.Kind() == reflect.Ptr {
   634  		v = initAndDereferencePointer(v)
   635  	}
   636  
   637  	ok, err := d.tryTextUnmarshaler(value, v)
   638  	if ok || err != nil {
   639  		return err
   640  	}
   641  
   642  	switch value.Kind {
   643  	case unstable.String:
   644  		return d.unmarshalString(value, v)
   645  	case unstable.Integer:
   646  		return d.unmarshalInteger(value, v)
   647  	case unstable.Float:
   648  		return d.unmarshalFloat(value, v)
   649  	case unstable.Bool:
   650  		return d.unmarshalBool(value, v)
   651  	case unstable.DateTime:
   652  		return d.unmarshalDateTime(value, v)
   653  	case unstable.LocalDate:
   654  		return d.unmarshalLocalDate(value, v)
   655  	case unstable.LocalTime:
   656  		return d.unmarshalLocalTime(value, v)
   657  	case unstable.LocalDateTime:
   658  		return d.unmarshalLocalDateTime(value, v)
   659  	case unstable.InlineTable:
   660  		return d.unmarshalInlineTable(value, v)
   661  	case unstable.Array:
   662  		return d.unmarshalArray(value, v)
   663  	default:
   664  		panic(fmt.Errorf("handleValue not implemented for %s", value.Kind))
   665  	}
   666  }
   667  
   668  func (d *decoder) unmarshalArray(array *unstable.Node, v reflect.Value) error {
   669  	switch v.Kind() {
   670  	case reflect.Slice:
   671  		if v.IsNil() {
   672  			v.Set(reflect.MakeSlice(v.Type(), 0, 16))
   673  		} else {
   674  			v.SetLen(0)
   675  		}
   676  	case reflect.Array:
   677  		// arrays are always initialized
   678  	case reflect.Interface:
   679  		elem := v.Elem()
   680  		if !elem.IsValid() {
   681  			elem = reflect.New(sliceInterfaceType).Elem()
   682  			elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
   683  		} else if elem.Kind() == reflect.Slice {
   684  			if elem.Type() != sliceInterfaceType {
   685  				elem = reflect.New(sliceInterfaceType).Elem()
   686  				elem.Set(reflect.MakeSlice(sliceInterfaceType, 0, 16))
   687  			} else if !elem.CanSet() {
   688  				nelem := reflect.New(sliceInterfaceType).Elem()
   689  				nelem.Set(reflect.MakeSlice(sliceInterfaceType, elem.Len(), elem.Cap()))
   690  				reflect.Copy(nelem, elem)
   691  				elem = nelem
   692  			}
   693  		}
   694  		err := d.unmarshalArray(array, elem)
   695  		if err != nil {
   696  			return err
   697  		}
   698  		v.Set(elem)
   699  		return nil
   700  	default:
   701  		// TODO: use newDecodeError, but first the parser needs to fill
   702  		//   array.Data.
   703  		return d.typeMismatchError("array", v.Type())
   704  	}
   705  
   706  	elemType := v.Type().Elem()
   707  
   708  	it := array.Children()
   709  	idx := 0
   710  	for it.Next() {
   711  		n := it.Node()
   712  
   713  		// TODO: optimize
   714  		if v.Kind() == reflect.Slice {
   715  			elem := reflect.New(elemType).Elem()
   716  
   717  			err := d.handleValue(n, elem)
   718  			if err != nil {
   719  				return err
   720  			}
   721  
   722  			v.Set(reflect.Append(v, elem))
   723  		} else { // array
   724  			if idx >= v.Len() {
   725  				return nil
   726  			}
   727  			elem := v.Index(idx)
   728  			err := d.handleValue(n, elem)
   729  			if err != nil {
   730  				return err
   731  			}
   732  			idx++
   733  		}
   734  	}
   735  
   736  	return nil
   737  }
   738  
   739  func (d *decoder) unmarshalInlineTable(itable *unstable.Node, v reflect.Value) error {
   740  	// Make sure v is an initialized object.
   741  	switch v.Kind() {
   742  	case reflect.Map:
   743  		if v.IsNil() {
   744  			v.Set(reflect.MakeMap(v.Type()))
   745  		}
   746  	case reflect.Struct:
   747  	// structs are always initialized.
   748  	case reflect.Interface:
   749  		elem := v.Elem()
   750  		if !elem.IsValid() {
   751  			elem = makeMapStringInterface()
   752  			v.Set(elem)
   753  		}
   754  		return d.unmarshalInlineTable(itable, elem)
   755  	default:
   756  		return unstable.NewParserError(d.p.Raw(itable.Raw), "cannot store inline table in Go type %s", v.Kind())
   757  	}
   758  
   759  	it := itable.Children()
   760  	for it.Next() {
   761  		n := it.Node()
   762  
   763  		x, err := d.handleKeyValue(n, v)
   764  		if err != nil {
   765  			return err
   766  		}
   767  		if x.IsValid() {
   768  			v = x
   769  		}
   770  	}
   771  
   772  	return nil
   773  }
   774  
   775  func (d *decoder) unmarshalDateTime(value *unstable.Node, v reflect.Value) error {
   776  	dt, err := parseDateTime(value.Data)
   777  	if err != nil {
   778  		return err
   779  	}
   780  
   781  	v.Set(reflect.ValueOf(dt))
   782  	return nil
   783  }
   784  
   785  func (d *decoder) unmarshalLocalDate(value *unstable.Node, v reflect.Value) error {
   786  	ld, err := parseLocalDate(value.Data)
   787  	if err != nil {
   788  		return err
   789  	}
   790  
   791  	if v.Type() == timeType {
   792  		cast := ld.AsTime(time.Local)
   793  		v.Set(reflect.ValueOf(cast))
   794  		return nil
   795  	}
   796  
   797  	v.Set(reflect.ValueOf(ld))
   798  
   799  	return nil
   800  }
   801  
   802  func (d *decoder) unmarshalLocalTime(value *unstable.Node, v reflect.Value) error {
   803  	lt, rest, err := parseLocalTime(value.Data)
   804  	if err != nil {
   805  		return err
   806  	}
   807  
   808  	if len(rest) > 0 {
   809  		return unstable.NewParserError(rest, "extra characters at the end of a local time")
   810  	}
   811  
   812  	v.Set(reflect.ValueOf(lt))
   813  	return nil
   814  }
   815  
   816  func (d *decoder) unmarshalLocalDateTime(value *unstable.Node, v reflect.Value) error {
   817  	ldt, rest, err := parseLocalDateTime(value.Data)
   818  	if err != nil {
   819  		return err
   820  	}
   821  
   822  	if len(rest) > 0 {
   823  		return unstable.NewParserError(rest, "extra characters at the end of a local date time")
   824  	}
   825  
   826  	if v.Type() == timeType {
   827  		cast := ldt.AsTime(time.Local)
   828  
   829  		v.Set(reflect.ValueOf(cast))
   830  		return nil
   831  	}
   832  
   833  	v.Set(reflect.ValueOf(ldt))
   834  
   835  	return nil
   836  }
   837  
   838  func (d *decoder) unmarshalBool(value *unstable.Node, v reflect.Value) error {
   839  	b := value.Data[0] == 't'
   840  
   841  	switch v.Kind() {
   842  	case reflect.Bool:
   843  		v.SetBool(b)
   844  	case reflect.Interface:
   845  		v.Set(reflect.ValueOf(b))
   846  	default:
   847  		return unstable.NewParserError(value.Data, "cannot assign boolean to a %t", b)
   848  	}
   849  
   850  	return nil
   851  }
   852  
   853  func (d *decoder) unmarshalFloat(value *unstable.Node, v reflect.Value) error {
   854  	f, err := parseFloat(value.Data)
   855  	if err != nil {
   856  		return err
   857  	}
   858  
   859  	switch v.Kind() {
   860  	case reflect.Float64:
   861  		v.SetFloat(f)
   862  	case reflect.Float32:
   863  		if f > math.MaxFloat32 {
   864  			return unstable.NewParserError(value.Data, "number %f does not fit in a float32", f)
   865  		}
   866  		v.SetFloat(f)
   867  	case reflect.Interface:
   868  		v.Set(reflect.ValueOf(f))
   869  	default:
   870  		return unstable.NewParserError(value.Data, "float cannot be assigned to %s", v.Kind())
   871  	}
   872  
   873  	return nil
   874  }
   875  
   876  const (
   877  	maxInt = int64(^uint(0) >> 1)
   878  	minInt = -maxInt - 1
   879  )
   880  
   881  // Maximum value of uint for decoding. Currently the decoder parses the integer
   882  // into an int64. As a result, on architectures where uint is 64 bits, the
   883  // effective maximum uint we can decode is the maximum of int64. On
   884  // architectures where uint is 32 bits, the maximum value we can decode is
   885  // lower: the maximum of uint32. I didn't find a way to figure out this value at
   886  // compile time, so it is computed during initialization.
   887  var maxUint int64 = math.MaxInt64
   888  
   889  func init() {
   890  	m := uint64(^uint(0))
   891  	if m < uint64(maxUint) {
   892  		maxUint = int64(m)
   893  	}
   894  }
   895  
   896  func (d *decoder) unmarshalInteger(value *unstable.Node, v reflect.Value) error {
   897  	kind := v.Kind()
   898  	if kind == reflect.Float32 || kind == reflect.Float64 {
   899  		return d.unmarshalFloat(value, v)
   900  	}
   901  
   902  	i, err := parseInteger(value.Data)
   903  	if err != nil {
   904  		return err
   905  	}
   906  
   907  	var r reflect.Value
   908  
   909  	switch kind {
   910  	case reflect.Int64:
   911  		v.SetInt(i)
   912  		return nil
   913  	case reflect.Int32:
   914  		if i < math.MinInt32 || i > math.MaxInt32 {
   915  			return fmt.Errorf("toml: number %d does not fit in an int32", i)
   916  		}
   917  
   918  		r = reflect.ValueOf(int32(i))
   919  	case reflect.Int16:
   920  		if i < math.MinInt16 || i > math.MaxInt16 {
   921  			return fmt.Errorf("toml: number %d does not fit in an int16", i)
   922  		}
   923  
   924  		r = reflect.ValueOf(int16(i))
   925  	case reflect.Int8:
   926  		if i < math.MinInt8 || i > math.MaxInt8 {
   927  			return fmt.Errorf("toml: number %d does not fit in an int8", i)
   928  		}
   929  
   930  		r = reflect.ValueOf(int8(i))
   931  	case reflect.Int:
   932  		if i < minInt || i > maxInt {
   933  			return fmt.Errorf("toml: number %d does not fit in an int", i)
   934  		}
   935  
   936  		r = reflect.ValueOf(int(i))
   937  	case reflect.Uint64:
   938  		if i < 0 {
   939  			return fmt.Errorf("toml: negative number %d does not fit in an uint64", i)
   940  		}
   941  
   942  		r = reflect.ValueOf(uint64(i))
   943  	case reflect.Uint32:
   944  		if i < 0 || i > math.MaxUint32 {
   945  			return fmt.Errorf("toml: negative number %d does not fit in an uint32", i)
   946  		}
   947  
   948  		r = reflect.ValueOf(uint32(i))
   949  	case reflect.Uint16:
   950  		if i < 0 || i > math.MaxUint16 {
   951  			return fmt.Errorf("toml: negative number %d does not fit in an uint16", i)
   952  		}
   953  
   954  		r = reflect.ValueOf(uint16(i))
   955  	case reflect.Uint8:
   956  		if i < 0 || i > math.MaxUint8 {
   957  			return fmt.Errorf("toml: negative number %d does not fit in an uint8", i)
   958  		}
   959  
   960  		r = reflect.ValueOf(uint8(i))
   961  	case reflect.Uint:
   962  		if i < 0 || i > maxUint {
   963  			return fmt.Errorf("toml: negative number %d does not fit in an uint", i)
   964  		}
   965  
   966  		r = reflect.ValueOf(uint(i))
   967  	case reflect.Interface:
   968  		r = reflect.ValueOf(i)
   969  	default:
   970  		return unstable.NewParserError(d.p.Raw(value.Raw), d.typeMismatchString("integer", v.Type()))
   971  	}
   972  
   973  	if !r.Type().AssignableTo(v.Type()) {
   974  		r = r.Convert(v.Type())
   975  	}
   976  
   977  	v.Set(r)
   978  
   979  	return nil
   980  }
   981  
   982  func (d *decoder) unmarshalString(value *unstable.Node, v reflect.Value) error {
   983  	switch v.Kind() {
   984  	case reflect.String:
   985  		v.SetString(string(value.Data))
   986  	case reflect.Interface:
   987  		v.Set(reflect.ValueOf(string(value.Data)))
   988  	default:
   989  		return unstable.NewParserError(d.p.Raw(value.Raw), d.typeMismatchString("string", v.Type()))
   990  	}
   991  
   992  	return nil
   993  }
   994  
   995  func (d *decoder) handleKeyValue(expr *unstable.Node, v reflect.Value) (reflect.Value, error) {
   996  	d.strict.EnterKeyValue(expr)
   997  
   998  	v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v)
   999  	if d.skipUntilTable {
  1000  		d.strict.MissingField(expr)
  1001  		d.skipUntilTable = false
  1002  	}
  1003  
  1004  	d.strict.ExitKeyValue(expr)
  1005  
  1006  	return v, err
  1007  }
  1008  
  1009  func (d *decoder) handleKeyValueInner(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
  1010  	if key.Next() {
  1011  		// Still scoping the key
  1012  		return d.handleKeyValuePart(key, value, v)
  1013  	}
  1014  	// Done scoping the key.
  1015  	// v is whatever Go value we need to fill.
  1016  	return reflect.Value{}, d.handleValue(value, v)
  1017  }
  1018  
  1019  func (d *decoder) keyFromData(keyType reflect.Type, data []byte) (reflect.Value, error) {
  1020  	switch {
  1021  	case stringType.AssignableTo(keyType):
  1022  		return reflect.ValueOf(string(data)), nil
  1023  
  1024  	case stringType.ConvertibleTo(keyType):
  1025  		return reflect.ValueOf(string(data)).Convert(keyType), nil
  1026  
  1027  	case keyType.Implements(textUnmarshalerType):
  1028  		mk := reflect.New(keyType.Elem())
  1029  		if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
  1030  			return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
  1031  		}
  1032  		return mk, nil
  1033  
  1034  	case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
  1035  		mk := reflect.New(keyType)
  1036  		if err := mk.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
  1037  			return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err)
  1038  		}
  1039  		return mk.Elem(), nil
  1040  	}
  1041  	return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType)
  1042  }
  1043  
  1044  func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
  1045  	// contains the replacement for v
  1046  	var rv reflect.Value
  1047  
  1048  	// First, dispatch over v to make sure it is a valid object.
  1049  	// There is no guarantee over what it could be.
  1050  	switch v.Kind() {
  1051  	case reflect.Map:
  1052  		vt := v.Type()
  1053  
  1054  		mk, err := d.keyFromData(vt.Key(), key.Node().Data)
  1055  		if err != nil {
  1056  			return reflect.Value{}, err
  1057  		}
  1058  
  1059  		// If the map does not exist, create it.
  1060  		if v.IsNil() {
  1061  			v = reflect.MakeMap(vt)
  1062  			rv = v
  1063  		}
  1064  
  1065  		mv := v.MapIndex(mk)
  1066  		set := false
  1067  		if !mv.IsValid() || key.IsLast() {
  1068  			set = true
  1069  			mv = reflect.New(v.Type().Elem()).Elem()
  1070  		}
  1071  
  1072  		nv, err := d.handleKeyValueInner(key, value, mv)
  1073  		if err != nil {
  1074  			return reflect.Value{}, err
  1075  		}
  1076  		if nv.IsValid() {
  1077  			mv = nv
  1078  			set = true
  1079  		}
  1080  
  1081  		if set {
  1082  			v.SetMapIndex(mk, mv)
  1083  		}
  1084  	case reflect.Struct:
  1085  		path, found := structFieldPath(v, string(key.Node().Data))
  1086  		if !found {
  1087  			d.skipUntilTable = true
  1088  			break
  1089  		}
  1090  
  1091  		if d.errorContext == nil {
  1092  			d.errorContext = new(errorContext)
  1093  		}
  1094  		t := v.Type()
  1095  		d.errorContext.Struct = t
  1096  		d.errorContext.Field = path
  1097  
  1098  		f := fieldByIndex(v, path)
  1099  
  1100  		if !f.CanAddr() {
  1101  			// If the field is not addressable, need to take a slower path and
  1102  			// make a copy of the struct itself to a new location.
  1103  			nvp := reflect.New(v.Type())
  1104  			nvp.Elem().Set(v)
  1105  			v = nvp.Elem()
  1106  			_, err := d.handleKeyValuePart(key, value, v)
  1107  			if err != nil {
  1108  				return reflect.Value{}, err
  1109  			}
  1110  			return nvp.Elem(), nil
  1111  		}
  1112  		x, err := d.handleKeyValueInner(key, value, f)
  1113  		if err != nil {
  1114  			return reflect.Value{}, err
  1115  		}
  1116  
  1117  		if x.IsValid() {
  1118  			f.Set(x)
  1119  		}
  1120  		d.errorContext.Struct = nil
  1121  		d.errorContext.Field = nil
  1122  	case reflect.Interface:
  1123  		v = v.Elem()
  1124  
  1125  		// Following encoding/json: decoding an object into an
  1126  		// interface{}, it needs to always hold a
  1127  		// map[string]interface{}. This is for the types to be
  1128  		// consistent whether a previous value was set or not.
  1129  		if !v.IsValid() || v.Type() != mapStringInterfaceType {
  1130  			v = makeMapStringInterface()
  1131  		}
  1132  
  1133  		x, err := d.handleKeyValuePart(key, value, v)
  1134  		if err != nil {
  1135  			return reflect.Value{}, err
  1136  		}
  1137  		if x.IsValid() {
  1138  			v = x
  1139  		}
  1140  		rv = v
  1141  	case reflect.Ptr:
  1142  		elem := v.Elem()
  1143  		if !elem.IsValid() {
  1144  			ptr := reflect.New(v.Type().Elem())
  1145  			v.Set(ptr)
  1146  			rv = v
  1147  			elem = ptr.Elem()
  1148  		}
  1149  
  1150  		elem2, err := d.handleKeyValuePart(key, value, elem)
  1151  		if err != nil {
  1152  			return reflect.Value{}, err
  1153  		}
  1154  		if elem2.IsValid() {
  1155  			elem = elem2
  1156  		}
  1157  		v.Elem().Set(elem)
  1158  	default:
  1159  		return reflect.Value{}, fmt.Errorf("unhandled kv part: %s", v.Kind())
  1160  	}
  1161  
  1162  	return rv, nil
  1163  }
  1164  
  1165  func initAndDereferencePointer(v reflect.Value) reflect.Value {
  1166  	var elem reflect.Value
  1167  	if v.IsNil() {
  1168  		ptr := reflect.New(v.Type().Elem())
  1169  		v.Set(ptr)
  1170  	}
  1171  	elem = v.Elem()
  1172  	return elem
  1173  }
  1174  
  1175  // Same as reflect.Value.FieldByIndex, but creates pointers if needed.
  1176  func fieldByIndex(v reflect.Value, path []int) reflect.Value {
  1177  	for _, x := range path {
  1178  		v = v.Field(x)
  1179  
  1180  		if v.Kind() == reflect.Ptr {
  1181  			if v.IsNil() {
  1182  				v.Set(reflect.New(v.Type().Elem()))
  1183  			}
  1184  			v = v.Elem()
  1185  		}
  1186  	}
  1187  	return v
  1188  }
  1189  
  1190  type fieldPathsMap = map[string][]int
  1191  
  1192  var globalFieldPathsCache atomic.Value // map[danger.TypeID]fieldPathsMap
  1193  
  1194  func structFieldPath(v reflect.Value, name string) ([]int, bool) {
  1195  	t := v.Type()
  1196  
  1197  	cache, _ := globalFieldPathsCache.Load().(map[danger.TypeID]fieldPathsMap)
  1198  	fieldPaths, ok := cache[danger.MakeTypeID(t)]
  1199  
  1200  	if !ok {
  1201  		fieldPaths = map[string][]int{}
  1202  
  1203  		forEachField(t, nil, func(name string, path []int) {
  1204  			fieldPaths[name] = path
  1205  			// extra copy for the case-insensitive match
  1206  			fieldPaths[strings.ToLower(name)] = path
  1207  		})
  1208  
  1209  		newCache := make(map[danger.TypeID]fieldPathsMap, len(cache)+1)
  1210  		newCache[danger.MakeTypeID(t)] = fieldPaths
  1211  		for k, v := range cache {
  1212  			newCache[k] = v
  1213  		}
  1214  		globalFieldPathsCache.Store(newCache)
  1215  	}
  1216  
  1217  	path, ok := fieldPaths[name]
  1218  	if !ok {
  1219  		path, ok = fieldPaths[strings.ToLower(name)]
  1220  	}
  1221  	return path, ok
  1222  }
  1223  
  1224  func forEachField(t reflect.Type, path []int, do func(name string, path []int)) {
  1225  	n := t.NumField()
  1226  	for i := 0; i < n; i++ {
  1227  		f := t.Field(i)
  1228  
  1229  		if !f.Anonymous && f.PkgPath != "" {
  1230  			// only consider exported fields.
  1231  			continue
  1232  		}
  1233  
  1234  		fieldPath := append(path, i)
  1235  		fieldPath = fieldPath[:len(fieldPath):len(fieldPath)]
  1236  
  1237  		name := f.Tag.Get("toml")
  1238  		if name == "-" {
  1239  			continue
  1240  		}
  1241  
  1242  		if i := strings.IndexByte(name, ','); i >= 0 {
  1243  			name = name[:i]
  1244  		}
  1245  
  1246  		if f.Anonymous && name == "" {
  1247  			t2 := f.Type
  1248  			if t2.Kind() == reflect.Ptr {
  1249  				t2 = t2.Elem()
  1250  			}
  1251  
  1252  			if t2.Kind() == reflect.Struct {
  1253  				forEachField(t2, fieldPath, do)
  1254  			}
  1255  			continue
  1256  		}
  1257  
  1258  		if name == "" {
  1259  			name = f.Name
  1260  		}
  1261  
  1262  		do(name, fieldPath)
  1263  	}
  1264  }
  1265  

View as plain text