Source file src/simd/_gen/unify/yaml.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package unify
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/fs"
    12  	"os"
    13  	"path/filepath"
    14  	"regexp"
    15  	"strings"
    16  
    17  	"gopkg.in/yaml.v3"
    18  )
    19  
    20  // ReadOpts provides options to [Read] and related functions. The zero value is
    21  // the default options.
    22  type ReadOpts struct {
    23  	// FS, if non-nil, is the file system from which to resolve !import file
    24  	// names.
    25  	FS fs.FS
    26  }
    27  
    28  // Read reads a [Closure] in YAML format from r, using path for error messages.
    29  //
    30  // It maps YAML nodes into terminal Values as follows:
    31  //
    32  // - "_" or !top _ is the top value ([Top]).
    33  //
    34  // - "_|_" or !bottom _ is the bottom value. This is an error during
    35  // unmarshaling, but can appear in marshaled values.
    36  //
    37  // - "$<name>" or !var <name> is a variable ([Var]). Everywhere the same name
    38  // appears within a single unmarshal operation, it is mapped to the same
    39  // variable. Different unmarshal operations get different variables, even if
    40  // they have the same string name.
    41  //
    42  // - !regex "x" is a regular expression ([String]), as is any string that
    43  // doesn't match "_", "_|_", or "$...". Regular expressions are implicitly
    44  // anchored at the beginning and end. If the string doesn't contain any
    45  // meta-characters (that is, it's a "literal" regular expression), then it's
    46  // treated as an exact string.
    47  //
    48  // - !string "x", or any int, float, bool, or binary value is an exact string
    49  // ([String]).
    50  //
    51  // - !regex [x, y, ...] is an intersection of regular expressions ([String]).
    52  //
    53  // It maps YAML nodes into non-terminal Values as follows:
    54  //
    55  // - Sequence nodes like [x, y, z] are tuples ([Tuple]).
    56  //
    57  // - !repeat [x] is a repeated tuple ([Tuple]), which is 0 or more instances of
    58  // x. There must be exactly one element in the list.
    59  //
    60  // - Mapping nodes like {a: x, b: y} are defs ([Def]). Any fields not listed are
    61  // implicitly top.
    62  //
    63  // - !sum [x, y, z] is a sum of its children. This can be thought of as a union
    64  // of the values x, y, and z, or as a non-deterministic choice between x, y, and
    65  // z. If a variable appears both inside the sum and outside of it, only the
    66  // non-deterministic choice view really works. The unifier does not directly
    67  // implement sums; instead, this is decoded as a fresh variable that's
    68  // simultaneously bound to x, y, and z.
    69  //
    70  // - !import glob is like a !sum, but its children are read from all files
    71  // matching the given glob pattern, which is interpreted relative to the current
    72  // file path. Each file gets its own variable scope.
    73  func Read(r io.Reader, path string, opts ReadOpts) (Closure, error) {
    74  	dec := yamlDecoder{opts: opts, path: path, env: topEnv}
    75  	v, err := dec.read(r)
    76  	if err != nil {
    77  		return Closure{}, err
    78  	}
    79  	return dec.close(v), nil
    80  }
    81  
    82  // ReadFile reads a [Closure] in YAML format from a file.
    83  //
    84  // The file must consist of a single YAML document.
    85  //
    86  // If opts.FS is not set, this sets it to a FS rooted at path's directory.
    87  //
    88  // See [Read] for details.
    89  func ReadFile(path string, opts ReadOpts) (Closure, error) {
    90  	f, err := os.Open(path)
    91  	if err != nil {
    92  		return Closure{}, err
    93  	}
    94  	defer f.Close()
    95  
    96  	if opts.FS == nil {
    97  		opts.FS = os.DirFS(filepath.Dir(path))
    98  	}
    99  
   100  	return Read(f, path, opts)
   101  }
   102  
   103  // UnmarshalYAML implements [yaml.Unmarshaler].
   104  //
   105  // Since there is no way to pass [ReadOpts] to this function, it assumes default
   106  // options.
   107  func (c *Closure) UnmarshalYAML(node *yaml.Node) error {
   108  	dec := yamlDecoder{path: "<yaml.Node>", env: topEnv}
   109  	v, err := dec.root(node)
   110  	if err != nil {
   111  		return err
   112  	}
   113  	*c = dec.close(v)
   114  	return nil
   115  }
   116  
   117  type yamlDecoder struct {
   118  	opts ReadOpts
   119  	path string
   120  
   121  	vars  map[string]*ident
   122  	nSums int
   123  
   124  	env envSet
   125  }
   126  
   127  func (dec *yamlDecoder) read(r io.Reader) (*Value, error) {
   128  	n, err := readOneNode(r)
   129  	if err != nil {
   130  		return nil, fmt.Errorf("%s: %w", dec.path, err)
   131  	}
   132  
   133  	// Decode YAML node to a Value
   134  	v, err := dec.root(n)
   135  	if err != nil {
   136  		return nil, fmt.Errorf("%s: %w", dec.path, err)
   137  	}
   138  
   139  	return v, nil
   140  }
   141  
   142  // readOneNode reads a single YAML document from r and returns an error if there
   143  // are more documents in r.
   144  func readOneNode(r io.Reader) (*yaml.Node, error) {
   145  	yd := yaml.NewDecoder(r)
   146  
   147  	// Decode as a YAML node
   148  	var node yaml.Node
   149  	if err := yd.Decode(&node); err != nil {
   150  		return nil, err
   151  	}
   152  	np := &node
   153  	if np.Kind == yaml.DocumentNode {
   154  		np = node.Content[0]
   155  	}
   156  
   157  	// Ensure there are no more YAML docs in this file
   158  	if err := yd.Decode(nil); err == nil {
   159  		return nil, fmt.Errorf("must not contain multiple documents")
   160  	} else if err != io.EOF {
   161  		return nil, err
   162  	}
   163  
   164  	return np, nil
   165  }
   166  
   167  // root parses the root of a file.
   168  func (dec *yamlDecoder) root(node *yaml.Node) (*Value, error) {
   169  	// Prepare for variable name resolution in this file. This may be a nested
   170  	// root, so restore the current values when we're done.
   171  	oldVars, oldNSums := dec.vars, dec.nSums
   172  	defer func() {
   173  		dec.vars, dec.nSums = oldVars, oldNSums
   174  	}()
   175  	dec.vars = make(map[string]*ident, 0)
   176  	dec.nSums = 0
   177  
   178  	return dec.value(node)
   179  }
   180  
   181  // close wraps a decoded [Value] into a [Closure].
   182  func (dec *yamlDecoder) close(v *Value) Closure {
   183  	return Closure{v, dec.env}
   184  }
   185  
   186  func (dec *yamlDecoder) value(node *yaml.Node) (vOut *Value, errOut error) {
   187  	pos := &Pos{Path: dec.path, Line: node.Line}
   188  
   189  	// Resolve alias nodes.
   190  	if node.Kind == yaml.AliasNode {
   191  		node = node.Alias
   192  	}
   193  
   194  	mk := func(d Domain) (*Value, error) {
   195  		v := &Value{Domain: d, pos: pos}
   196  		return v, nil
   197  	}
   198  	mk2 := func(d Domain, err error) (*Value, error) {
   199  		if err != nil {
   200  			return nil, err
   201  		}
   202  		return mk(d)
   203  	}
   204  
   205  	// is tests the kind and long tag of node.
   206  	is := func(kind yaml.Kind, tag string) bool {
   207  		return node.Kind == kind && node.LongTag() == tag
   208  	}
   209  	isExact := func() bool {
   210  		if node.Kind != yaml.ScalarNode {
   211  			return false
   212  		}
   213  		// We treat any string-ish YAML node as a string.
   214  		switch node.LongTag() {
   215  		case "!string", "tag:yaml.org,2002:int", "tag:yaml.org,2002:float", "tag:yaml.org,2002:bool", "tag:yaml.org,2002:binary":
   216  			return true
   217  		}
   218  		return false
   219  	}
   220  
   221  	// !!str nodes provide a short-hand syntax for several leaf domains that are
   222  	// also available under explicit tags. To simplify checking below, we set
   223  	// strVal to non-"" only for !!str nodes.
   224  	strVal := ""
   225  	isStr := is(yaml.ScalarNode, "tag:yaml.org,2002:str")
   226  	if isStr {
   227  		strVal = node.Value
   228  	}
   229  
   230  	switch {
   231  	case is(yaml.ScalarNode, "!var"):
   232  		strVal = "$" + node.Value
   233  		fallthrough
   234  	case strings.HasPrefix(strVal, "$"):
   235  		id, ok := dec.vars[strVal]
   236  		if !ok {
   237  			// We encode different idents with the same string name by adding a
   238  			// #N suffix. Strip that off so it doesn't accumulate. This isn't
   239  			// meant to be used in user-written input, though nothing stops that.
   240  			name, _, _ := strings.Cut(strVal, "#")
   241  			id = &ident{name: name}
   242  			dec.vars[strVal] = id
   243  			dec.env = dec.env.bind(id, topValue)
   244  		}
   245  		return mk(Var{id: id})
   246  
   247  	case strVal == "_" || is(yaml.ScalarNode, "!top"):
   248  		return mk(Top{})
   249  
   250  	case strVal == "_|_" || is(yaml.ScalarNode, "!bottom"):
   251  		return nil, errors.New("found bottom")
   252  
   253  	case isExact():
   254  		val := node.Value
   255  		return mk(NewStringExact(val))
   256  
   257  	case isStr || is(yaml.ScalarNode, "!regex"):
   258  		// Any other string we treat as a regex. This will produce an exact
   259  		// string anyway if the regex is literal.
   260  		val := node.Value
   261  		return mk2(NewStringRegex(val))
   262  
   263  	case is(yaml.SequenceNode, "!regex"):
   264  		var vals []string
   265  		if err := node.Decode(&vals); err != nil {
   266  			return nil, err
   267  		}
   268  		return mk2(NewStringRegex(vals...))
   269  
   270  	case is(yaml.MappingNode, "tag:yaml.org,2002:map"):
   271  		var db DefBuilder
   272  		for i := 0; i < len(node.Content); i += 2 {
   273  			key := node.Content[i]
   274  			if key.Kind != yaml.ScalarNode {
   275  				return nil, fmt.Errorf("non-scalar key %q", key.Value)
   276  			}
   277  			val, err := dec.value(node.Content[i+1])
   278  			if err != nil {
   279  				return nil, err
   280  			}
   281  			db.Add(key.Value, val)
   282  		}
   283  		return mk(db.Build())
   284  
   285  	case is(yaml.SequenceNode, "tag:yaml.org,2002:seq"):
   286  		elts := node.Content
   287  		vs := make([]*Value, 0, len(elts))
   288  		for _, elt := range elts {
   289  			v, err := dec.value(elt)
   290  			if err != nil {
   291  				return nil, err
   292  			}
   293  			vs = append(vs, v)
   294  		}
   295  		return mk(NewTuple(vs...))
   296  
   297  	case is(yaml.SequenceNode, "!repeat") || is(yaml.SequenceNode, "!repeat-unify"):
   298  		// !repeat must have one child. !repeat-unify is used internally for
   299  		// delayed unification, and is the same, it's just allowed to have more
   300  		// than one child.
   301  		if node.LongTag() == "!repeat" && len(node.Content) != 1 {
   302  			return nil, fmt.Errorf("!repeat must have exactly one child")
   303  		}
   304  
   305  		// Decode the children to make sure they're well-formed, but otherwise
   306  		// discard that decoding and do it again every time we need a new
   307  		// element.
   308  		var gen []func(e envSet) (*Value, envSet)
   309  		origEnv := dec.env
   310  		elts := node.Content
   311  		for i, elt := range elts {
   312  			_, err := dec.value(elt)
   313  			if err != nil {
   314  				return nil, err
   315  			}
   316  			// Undo any effects on the environment. We *do* keep any named
   317  			// variables that were added to the vars map in case they were
   318  			// introduced within the element.
   319  			//
   320  			// TODO: If we change how we implement repeat nodes, we might be
   321  			// able to drop yamlEncoder.env and yamlDecoder.env.
   322  			dec.env = origEnv
   323  			// Add a generator function
   324  			gen = append(gen, func(e envSet) (*Value, envSet) {
   325  				dec.env = e
   326  				// TODO: If this is in a sum, this tends to generate a ton of
   327  				// fresh variables that are different on each branch of the
   328  				// parent sum. Does it make sense to hold on to the i'th value
   329  				// of the tuple after we've generated it?
   330  				v, err := dec.value(elts[i])
   331  				if err != nil {
   332  					// It worked the first time, so this really shouldn't hapen.
   333  					panic("decoding repeat element failed")
   334  				}
   335  				return v, dec.env
   336  			})
   337  		}
   338  		return mk(NewRepeat(gen...))
   339  
   340  	case is(yaml.SequenceNode, "!sum"):
   341  		vs := make([]*Value, 0, len(node.Content))
   342  		for _, elt := range node.Content {
   343  			v, err := dec.value(elt)
   344  			if err != nil {
   345  				return nil, err
   346  			}
   347  			vs = append(vs, v)
   348  		}
   349  		if len(vs) == 1 {
   350  			return vs[0], nil
   351  		}
   352  
   353  		// A sum is implemented as a fresh variable that's simultaneously bound
   354  		// to each of the descendants.
   355  		id := &ident{name: fmt.Sprintf("sum%d", dec.nSums)}
   356  		dec.nSums++
   357  		dec.env = dec.env.bind(id, vs...)
   358  		return mk(Var{id: id})
   359  
   360  	case is(yaml.ScalarNode, "!import"):
   361  		if dec.opts.FS == nil {
   362  			return nil, fmt.Errorf("!import not allowed (ReadOpts.FS not set)")
   363  		}
   364  		pat := node.Value
   365  
   366  		if !fs.ValidPath(pat) {
   367  			// This will result in Glob returning no results. Give a more useful
   368  			// error message for this case.
   369  			return nil, fmt.Errorf("!import path must not contain '.' or '..'")
   370  		}
   371  
   372  		ms, err := fs.Glob(dec.opts.FS, pat)
   373  		if err != nil {
   374  			return nil, fmt.Errorf("resolving !import: %w", err)
   375  		}
   376  		if len(ms) == 0 {
   377  			return nil, fmt.Errorf("!import did not match any files")
   378  		}
   379  
   380  		// Parse each file
   381  		vs := make([]*Value, 0, len(ms))
   382  		for _, m := range ms {
   383  			v, err := dec.import1(m)
   384  			if err != nil {
   385  				return nil, err
   386  			}
   387  			vs = append(vs, v)
   388  		}
   389  
   390  		// Create a sum.
   391  		if len(vs) == 1 {
   392  			return vs[0], nil
   393  		}
   394  		id := &ident{name: "import"}
   395  		dec.env = dec.env.bind(id, vs...)
   396  		return mk(Var{id: id})
   397  	}
   398  
   399  	return nil, fmt.Errorf("unknown node kind %d %v", node.Kind, node.Tag)
   400  }
   401  
   402  func (dec *yamlDecoder) import1(path string) (*Value, error) {
   403  	// Make sure we can open the path first.
   404  	f, err := dec.opts.FS.Open(path)
   405  	if err != nil {
   406  		return nil, fmt.Errorf("!import failed: %w", err)
   407  	}
   408  	defer f.Close()
   409  
   410  	// Prepare the enter path.
   411  	oldFS, oldPath := dec.opts.FS, dec.path
   412  	defer func() {
   413  		dec.opts.FS, dec.path = oldFS, oldPath
   414  	}()
   415  
   416  	// Enter path, which is relative to the current path's directory.
   417  	newPath := filepath.Join(filepath.Dir(dec.path), path)
   418  	subFS, err := fs.Sub(dec.opts.FS, filepath.Dir(path))
   419  	if err != nil {
   420  		return nil, err
   421  	}
   422  	dec.opts.FS, dec.path = subFS, newPath
   423  
   424  	// Parse the file.
   425  	return dec.read(f)
   426  }
   427  
   428  type yamlEncoder struct {
   429  	idp identPrinter
   430  	e   envSet // We track the environment for !repeat nodes.
   431  }
   432  
   433  // TODO: Switch some Value marshaling to Closure?
   434  
   435  func (c Closure) MarshalYAML() (any, error) {
   436  	// TODO: If the environment is trivial, just marshal the value.
   437  	enc := &yamlEncoder{}
   438  	return enc.closure(c), nil
   439  }
   440  
   441  func (c Closure) String() string {
   442  	b, err := yaml.Marshal(c)
   443  	if err != nil {
   444  		return fmt.Sprintf("marshal failed: %s", err)
   445  	}
   446  	return string(b)
   447  }
   448  
   449  func (v *Value) MarshalYAML() (any, error) {
   450  	enc := &yamlEncoder{e: topEnv}
   451  	return enc.value(v), nil
   452  }
   453  
   454  func (v *Value) String() string {
   455  	b, err := yaml.Marshal(v)
   456  	if err != nil {
   457  		return fmt.Sprintf("marshal failed: %s", err)
   458  	}
   459  	return string(b)
   460  }
   461  
   462  func (enc *yamlEncoder) closure(c Closure) *yaml.Node {
   463  	enc.e = c.env
   464  	var n yaml.Node
   465  	n.Kind = yaml.MappingNode
   466  	n.Tag = "!closure"
   467  	n.Content = make([]*yaml.Node, 4)
   468  	n.Content[0] = new(yaml.Node)
   469  	n.Content[0].SetString("env")
   470  	n.Content[2] = new(yaml.Node)
   471  	n.Content[2].SetString("in")
   472  	n.Content[3] = enc.value(c.val)
   473  	// Fill in the env after we've written the value in case value encoding
   474  	// affects the env.
   475  	n.Content[1] = enc.env(enc.e)
   476  	enc.e = envSet{} // Allow GC'ing the env
   477  	return &n
   478  }
   479  
   480  func (enc *yamlEncoder) env(e envSet) *yaml.Node {
   481  	var encode func(e *envExpr) *yaml.Node
   482  	encode = func(e *envExpr) *yaml.Node {
   483  		var n yaml.Node
   484  		switch e.kind {
   485  		default:
   486  			panic("bad kind")
   487  		case envZero:
   488  			n.SetString("0")
   489  		case envUnit:
   490  			n.SetString("1")
   491  		case envBinding:
   492  			var id yaml.Node
   493  			id.SetString(enc.idp.unique(e.id))
   494  			n.Kind = yaml.MappingNode
   495  			n.Content = []*yaml.Node{&id, enc.value(e.val)}
   496  		case envProduct, envSum:
   497  			n.Kind = yaml.SequenceNode
   498  			if e.kind == envProduct {
   499  				n.Tag = "!product"
   500  			} else {
   501  				n.Tag = "!sum"
   502  			}
   503  			for _, e2 := range e.operands {
   504  				n.Content = append(n.Content, encode(e2))
   505  			}
   506  		}
   507  		return &n
   508  	}
   509  	return encode(e.root)
   510  }
   511  
   512  var yamlIntRe = regexp.MustCompile(`^-?[0-9]+$`)
   513  
   514  func (enc *yamlEncoder) value(v *Value) *yaml.Node {
   515  	var n yaml.Node
   516  	switch d := v.Domain.(type) {
   517  	case nil:
   518  		// Not allowed by unmarshaler, but useful for understanding when
   519  		// something goes horribly wrong.
   520  		//
   521  		// TODO: We might be able to track useful provenance for this, which
   522  		// would really help with debugging unexpected bottoms.
   523  		n.SetString("_|_")
   524  		return &n
   525  
   526  	case Top:
   527  		n.SetString("_")
   528  		return &n
   529  
   530  	case Def:
   531  		n.Kind = yaml.MappingNode
   532  		for k, elt := range d.All() {
   533  			var kn yaml.Node
   534  			kn.SetString(k)
   535  			n.Content = append(n.Content, &kn, enc.value(elt))
   536  		}
   537  		n.HeadComment = v.PosString()
   538  		return &n
   539  
   540  	case Tuple:
   541  		n.Kind = yaml.SequenceNode
   542  		if d.repeat == nil {
   543  			for _, elt := range d.vs {
   544  				n.Content = append(n.Content, enc.value(elt))
   545  			}
   546  		} else {
   547  			if len(d.repeat) == 1 {
   548  				n.Tag = "!repeat"
   549  			} else {
   550  				n.Tag = "!repeat-unify"
   551  			}
   552  			// TODO: I'm not positive this will round-trip everything correctly.
   553  			for _, gen := range d.repeat {
   554  				v, e := gen(enc.e)
   555  				enc.e = e
   556  				n.Content = append(n.Content, enc.value(v))
   557  			}
   558  		}
   559  		return &n
   560  
   561  	case String:
   562  		switch d.kind {
   563  		case stringExact:
   564  			n.SetString(d.exact)
   565  			switch {
   566  			// Make this into a "nice" !!int node if I can.
   567  			case yamlIntRe.MatchString(d.exact):
   568  				n.Tag = "tag:yaml.org,2002:int"
   569  
   570  			// Or a "nice" !!bool node.
   571  			case d.exact == "false" || d.exact == "true":
   572  				n.Tag = "tag:yaml.org,2002:bool"
   573  
   574  			// If this doesn't require escaping, leave it as a str node to avoid
   575  			// the annoying YAML tags. Otherwise, mark it as an exact string.
   576  			// Alternatively, we could always emit a str node with regexp
   577  			// quoting.
   578  			case d.exact != regexp.QuoteMeta(d.exact):
   579  				n.Tag = "!string"
   580  			}
   581  			return &n
   582  		case stringRegex:
   583  			o := make([]string, 0, 1)
   584  			for _, re := range d.re {
   585  				s := re.String()
   586  				s = strings.TrimSuffix(strings.TrimPrefix(s, `\A(?:`), `)\z`)
   587  				o = append(o, s)
   588  			}
   589  			if len(o) == 1 {
   590  				n.SetString(o[0])
   591  				return &n
   592  			}
   593  			n.Encode(o)
   594  			n.Tag = "!regex"
   595  			return &n
   596  		}
   597  		panic("bad String kind")
   598  
   599  	case Var:
   600  		// TODO: If Var only appears once in the whole Value and is independent
   601  		// in the environment (part of a term that is only over Var), then emit
   602  		// this as a !sum instead.
   603  		if false {
   604  			var vs []*Value // TODO: Get values of this var.
   605  			if len(vs) == 1 {
   606  				return enc.value(vs[0])
   607  			}
   608  			n.Kind = yaml.SequenceNode
   609  			n.Tag = "!sum"
   610  			for _, elt := range vs {
   611  				n.Content = append(n.Content, enc.value(elt))
   612  			}
   613  			return &n
   614  		}
   615  		n.SetString(enc.idp.unique(d.id))
   616  		if !strings.HasPrefix(d.id.name, "$") {
   617  			n.Tag = "!var"
   618  		}
   619  		return &n
   620  	}
   621  	panic(fmt.Sprintf("unknown domain type %T", v.Domain))
   622  }
   623  

View as plain text