Source file src/cmd/compile/internal/types2/struct.go

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package types2
     6  
     7  import (
     8  	"cmd/compile/internal/syntax"
     9  	. "internal/types/errors"
    10  	"strconv"
    11  )
    12  
    13  // ----------------------------------------------------------------------------
    14  // API
    15  
    16  // A Struct represents a struct type.
    17  type Struct struct {
    18  	fields []*Var   // fields != nil indicates the struct is set up (possibly with len(fields) == 0)
    19  	tags   []string // field tags; nil if there are no tags
    20  }
    21  
    22  // NewStruct returns a new struct with the given fields and corresponding field tags.
    23  // If a field with index i has a tag, tags[i] must be that tag, but len(tags) may be
    24  // only as long as required to hold the tag with the largest index i. Consequently,
    25  // if no field has a tag, tags may be nil.
    26  func NewStruct(fields []*Var, tags []string) *Struct {
    27  	var fset objset
    28  	for _, f := range fields {
    29  		if f.name != "_" && fset.insert(f) != nil {
    30  			panic("multiple fields with the same name")
    31  		}
    32  	}
    33  	if len(tags) > len(fields) {
    34  		panic("more tags than fields")
    35  	}
    36  	s := &Struct{fields: fields, tags: tags}
    37  	s.markComplete()
    38  	return s
    39  }
    40  
    41  // NumFields returns the number of fields in the struct (including blank and embedded fields).
    42  func (s *Struct) NumFields() int { return len(s.fields) }
    43  
    44  // Field returns the i'th field for 0 <= i < NumFields().
    45  func (s *Struct) Field(i int) *Var { return s.fields[i] }
    46  
    47  // Tag returns the i'th field tag for 0 <= i < NumFields().
    48  func (s *Struct) Tag(i int) string {
    49  	if i < len(s.tags) {
    50  		return s.tags[i]
    51  	}
    52  	return ""
    53  }
    54  
    55  func (s *Struct) Underlying() Type { return s }
    56  func (s *Struct) String() string   { return TypeString(s, nil) }
    57  
    58  // ----------------------------------------------------------------------------
    59  // Implementation
    60  
    61  func (s *Struct) markComplete() {
    62  	if s.fields == nil {
    63  		s.fields = make([]*Var, 0)
    64  	}
    65  }
    66  
    67  func (check *Checker) structType(styp *Struct, e *syntax.StructType) {
    68  	if e.FieldList == nil {
    69  		styp.markComplete()
    70  		return
    71  	}
    72  
    73  	// struct fields and tags
    74  	var fields []*Var
    75  	var tags []string
    76  
    77  	// for double-declaration checks
    78  	var fset objset
    79  
    80  	// current field typ and tag
    81  	var typ Type
    82  	var tag string
    83  	add := func(ident *syntax.Name, embedded bool) {
    84  		if tag != "" && tags == nil {
    85  			tags = make([]string, len(fields))
    86  		}
    87  		if tags != nil {
    88  			tags = append(tags, tag)
    89  		}
    90  
    91  		pos := ident.Pos()
    92  		name := ident.Value
    93  		fld := NewField(pos, check.pkg, name, typ, embedded)
    94  		// spec: "Within a struct, non-blank field names must be unique."
    95  		if name == "_" || check.declareInSet(&fset, pos, fld) {
    96  			fields = append(fields, fld)
    97  			check.recordDef(ident, fld)
    98  		}
    99  	}
   100  
   101  	// addInvalid adds an embedded field of invalid type to the struct for
   102  	// fields with errors; this keeps the number of struct fields in sync
   103  	// with the source as long as the fields are _ or have different names
   104  	// (go.dev/issue/25627).
   105  	addInvalid := func(ident *syntax.Name) {
   106  		typ = Typ[Invalid]
   107  		tag = ""
   108  		add(ident, true)
   109  	}
   110  
   111  	var prev syntax.Expr
   112  	for i, f := range e.FieldList {
   113  		// Fields declared syntactically with the same type (e.g.: a, b, c T)
   114  		// share the same type expression. Only check type if it's a new type.
   115  		if i == 0 || f.Type != prev {
   116  			typ = check.varType(f.Type)
   117  			prev = f.Type
   118  		}
   119  		tag = ""
   120  		if i < len(e.TagList) {
   121  			tag = check.tag(e.TagList[i])
   122  		}
   123  		if f.Name != nil {
   124  			// named field
   125  			add(f.Name, false)
   126  		} else {
   127  			// embedded field
   128  			// spec: "An embedded type must be specified as a type name T or as a
   129  			// pointer to a non-interface type name *T, and T itself may not be a
   130  			// pointer type."
   131  			pos := syntax.StartPos(f.Type) // position of type, for errors
   132  			name := embeddedFieldIdent(f.Type)
   133  			if name == nil {
   134  				check.errorf(pos, InvalidSyntaxTree, "invalid embedded field type %s", f.Type)
   135  				name = syntax.NewName(pos, "_")
   136  				addInvalid(name)
   137  				continue
   138  			}
   139  			add(name, true) // struct{p.T} field has position of T
   140  
   141  			// Because we have a name, typ must be of the form T or *T, where T is the name
   142  			// of a (named or alias) type, and t (= deref(typ)) must be the type of T.
   143  			// We must delay this check to the end because we don't want to instantiate
   144  			// (via under(t)) a possibly incomplete type.
   145  			embeddedTyp := typ // for closure below
   146  			embeddedPos := pos
   147  			check.later(func() {
   148  				t, isPtr := deref(embeddedTyp)
   149  				switch u := under(t).(type) {
   150  				case *Basic:
   151  					if !isValid(t) {
   152  						// error was reported before
   153  						return
   154  					}
   155  					// unsafe.Pointer is treated like a regular pointer
   156  					if u.kind == UnsafePointer {
   157  						check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be unsafe.Pointer")
   158  					}
   159  				case *Pointer:
   160  					check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be a pointer")
   161  				case *Interface:
   162  					if isTypeParam(t) {
   163  						// The error code here is inconsistent with other error codes for
   164  						// invalid embedding, because this restriction may be relaxed in the
   165  						// future, and so it did not warrant a new error code.
   166  						check.error(embeddedPos, MisplacedTypeParam, "embedded field type cannot be a (pointer to a) type parameter")
   167  						break
   168  					}
   169  					if isPtr {
   170  						check.error(embeddedPos, InvalidPtrEmbed, "embedded field type cannot be a pointer to an interface")
   171  					}
   172  				}
   173  			}).describef(embeddedPos, "check embedded type %s", embeddedTyp)
   174  		}
   175  	}
   176  
   177  	styp.fields = fields
   178  	styp.tags = tags
   179  	styp.markComplete()
   180  }
   181  
   182  func embeddedFieldIdent(e syntax.Expr) *syntax.Name {
   183  	switch e := e.(type) {
   184  	case *syntax.Name:
   185  		return e
   186  	case *syntax.Operation:
   187  		if base := ptrBase(e); base != nil {
   188  			// *T is valid, but **T is not
   189  			if op, _ := base.(*syntax.Operation); op == nil || ptrBase(op) == nil {
   190  				return embeddedFieldIdent(e.X)
   191  			}
   192  		}
   193  	case *syntax.SelectorExpr:
   194  		return e.Sel
   195  	case *syntax.IndexExpr:
   196  		return embeddedFieldIdent(e.X)
   197  	}
   198  	return nil // invalid embedded field
   199  }
   200  
   201  func (check *Checker) declareInSet(oset *objset, pos syntax.Pos, obj Object) bool {
   202  	if alt := oset.insert(obj); alt != nil {
   203  		err := check.newError(DuplicateDecl)
   204  		err.addf(pos, "%s redeclared", obj.Name())
   205  		err.addAltDecl(alt)
   206  		err.report()
   207  		return false
   208  	}
   209  	return true
   210  }
   211  
   212  func (check *Checker) tag(t *syntax.BasicLit) string {
   213  	// If t.Bad, an error was reported during parsing.
   214  	if t != nil && !t.Bad {
   215  		if t.Kind == syntax.StringLit {
   216  			if val, err := strconv.Unquote(t.Value); err == nil {
   217  				return val
   218  			}
   219  		}
   220  		check.errorf(t, InvalidSyntaxTree, "incorrect tag syntax: %q", t.Value)
   221  	}
   222  	return ""
   223  }
   224  
   225  func ptrBase(x *syntax.Operation) syntax.Expr {
   226  	if x.Op == syntax.Mul && x.Y == nil {
   227  		return x.X
   228  	}
   229  	return nil
   230  }
   231  

View as plain text