// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package unify import ( "errors" "fmt" "io" "io/fs" "os" "path/filepath" "regexp" "strings" "gopkg.in/yaml.v3" ) // ReadOpts provides options to [Read] and related functions. The zero value is // the default options. type ReadOpts struct { // FS, if non-nil, is the file system from which to resolve !import file // names. FS fs.FS } // Read reads a [Closure] in YAML format from r, using path for error messages. // // It maps YAML nodes into terminal Values as follows: // // - "_" or !top _ is the top value ([Top]). // // - "_|_" or !bottom _ is the bottom value. This is an error during // unmarshaling, but can appear in marshaled values. // // - "$" or !var is a variable ([Var]). Everywhere the same name // appears within a single unmarshal operation, it is mapped to the same // variable. Different unmarshal operations get different variables, even if // they have the same string name. // // - !regex "x" is a regular expression ([String]), as is any string that // doesn't match "_", "_|_", or "$...". Regular expressions are implicitly // anchored at the beginning and end. If the string doesn't contain any // meta-characters (that is, it's a "literal" regular expression), then it's // treated as an exact string. // // - !string "x", or any int, float, bool, or binary value is an exact string // ([String]). // // - !regex [x, y, ...] is an intersection of regular expressions ([String]). // // It maps YAML nodes into non-terminal Values as follows: // // - Sequence nodes like [x, y, z] are tuples ([Tuple]). // // - !repeat [x] is a repeated tuple ([Tuple]), which is 0 or more instances of // x. There must be exactly one element in the list. // // - Mapping nodes like {a: x, b: y} are defs ([Def]). Any fields not listed are // implicitly top. // // - !sum [x, y, z] is a sum of its children. This can be thought of as a union // of the values x, y, and z, or as a non-deterministic choice between x, y, and // z. If a variable appears both inside the sum and outside of it, only the // non-deterministic choice view really works. The unifier does not directly // implement sums; instead, this is decoded as a fresh variable that's // simultaneously bound to x, y, and z. // // - !import glob is like a !sum, but its children are read from all files // matching the given glob pattern, which is interpreted relative to the current // file path. Each file gets its own variable scope. func Read(r io.Reader, path string, opts ReadOpts) (Closure, error) { dec := yamlDecoder{opts: opts, path: path, env: topEnv} v, err := dec.read(r) if err != nil { return Closure{}, err } return dec.close(v), nil } // ReadFile reads a [Closure] in YAML format from a file. // // The file must consist of a single YAML document. // // If opts.FS is not set, this sets it to a FS rooted at path's directory. // // See [Read] for details. func ReadFile(path string, opts ReadOpts) (Closure, error) { f, err := os.Open(path) if err != nil { return Closure{}, err } defer f.Close() if opts.FS == nil { opts.FS = os.DirFS(filepath.Dir(path)) } return Read(f, path, opts) } // UnmarshalYAML implements [yaml.Unmarshaler]. // // Since there is no way to pass [ReadOpts] to this function, it assumes default // options. func (c *Closure) UnmarshalYAML(node *yaml.Node) error { dec := yamlDecoder{path: "", env: topEnv} v, err := dec.root(node) if err != nil { return err } *c = dec.close(v) return nil } type yamlDecoder struct { opts ReadOpts path string vars map[string]*ident nSums int env envSet } func (dec *yamlDecoder) read(r io.Reader) (*Value, error) { n, err := readOneNode(r) if err != nil { return nil, fmt.Errorf("%s: %w", dec.path, err) } // Decode YAML node to a Value v, err := dec.root(n) if err != nil { return nil, fmt.Errorf("%s: %w", dec.path, err) } return v, nil } // readOneNode reads a single YAML document from r and returns an error if there // are more documents in r. func readOneNode(r io.Reader) (*yaml.Node, error) { yd := yaml.NewDecoder(r) // Decode as a YAML node var node yaml.Node if err := yd.Decode(&node); err != nil { return nil, err } np := &node if np.Kind == yaml.DocumentNode { np = node.Content[0] } // Ensure there are no more YAML docs in this file if err := yd.Decode(nil); err == nil { return nil, fmt.Errorf("must not contain multiple documents") } else if err != io.EOF { return nil, err } return np, nil } // root parses the root of a file. func (dec *yamlDecoder) root(node *yaml.Node) (*Value, error) { // Prepare for variable name resolution in this file. This may be a nested // root, so restore the current values when we're done. oldVars, oldNSums := dec.vars, dec.nSums defer func() { dec.vars, dec.nSums = oldVars, oldNSums }() dec.vars = make(map[string]*ident, 0) dec.nSums = 0 return dec.value(node) } // close wraps a decoded [Value] into a [Closure]. func (dec *yamlDecoder) close(v *Value) Closure { return Closure{v, dec.env} } func (dec *yamlDecoder) value(node *yaml.Node) (vOut *Value, errOut error) { pos := &Pos{Path: dec.path, Line: node.Line} // Resolve alias nodes. if node.Kind == yaml.AliasNode { node = node.Alias } mk := func(d Domain) (*Value, error) { v := &Value{Domain: d, pos: pos} return v, nil } mk2 := func(d Domain, err error) (*Value, error) { if err != nil { return nil, err } return mk(d) } // is tests the kind and long tag of node. is := func(kind yaml.Kind, tag string) bool { return node.Kind == kind && node.LongTag() == tag } isExact := func() bool { if node.Kind != yaml.ScalarNode { return false } // We treat any string-ish YAML node as a string. switch node.LongTag() { case "!string", "tag:yaml.org,2002:int", "tag:yaml.org,2002:float", "tag:yaml.org,2002:bool", "tag:yaml.org,2002:binary": return true } return false } // !!str nodes provide a short-hand syntax for several leaf domains that are // also available under explicit tags. To simplify checking below, we set // strVal to non-"" only for !!str nodes. strVal := "" isStr := is(yaml.ScalarNode, "tag:yaml.org,2002:str") if isStr { strVal = node.Value } switch { case is(yaml.ScalarNode, "!var"): strVal = "$" + node.Value fallthrough case strings.HasPrefix(strVal, "$"): id, ok := dec.vars[strVal] if !ok { // We encode different idents with the same string name by adding a // #N suffix. Strip that off so it doesn't accumulate. This isn't // meant to be used in user-written input, though nothing stops that. name, _, _ := strings.Cut(strVal, "#") id = &ident{name: name} dec.vars[strVal] = id dec.env = dec.env.bind(id, topValue) } return mk(Var{id: id}) case strVal == "_" || is(yaml.ScalarNode, "!top"): return mk(Top{}) case strVal == "_|_" || is(yaml.ScalarNode, "!bottom"): return nil, errors.New("found bottom") case isExact(): val := node.Value return mk(NewStringExact(val)) case isStr || is(yaml.ScalarNode, "!regex"): // Any other string we treat as a regex. This will produce an exact // string anyway if the regex is literal. val := node.Value return mk2(NewStringRegex(val)) case is(yaml.SequenceNode, "!regex"): var vals []string if err := node.Decode(&vals); err != nil { return nil, err } return mk2(NewStringRegex(vals...)) case is(yaml.MappingNode, "tag:yaml.org,2002:map"): var db DefBuilder for i := 0; i < len(node.Content); i += 2 { key := node.Content[i] if key.Kind != yaml.ScalarNode { return nil, fmt.Errorf("non-scalar key %q", key.Value) } val, err := dec.value(node.Content[i+1]) if err != nil { return nil, err } db.Add(key.Value, val) } return mk(db.Build()) case is(yaml.SequenceNode, "tag:yaml.org,2002:seq"): elts := node.Content vs := make([]*Value, 0, len(elts)) for _, elt := range elts { v, err := dec.value(elt) if err != nil { return nil, err } vs = append(vs, v) } return mk(NewTuple(vs...)) case is(yaml.SequenceNode, "!repeat") || is(yaml.SequenceNode, "!repeat-unify"): // !repeat must have one child. !repeat-unify is used internally for // delayed unification, and is the same, it's just allowed to have more // than one child. if node.LongTag() == "!repeat" && len(node.Content) != 1 { return nil, fmt.Errorf("!repeat must have exactly one child") } // Decode the children to make sure they're well-formed, but otherwise // discard that decoding and do it again every time we need a new // element. var gen []func(e envSet) (*Value, envSet) origEnv := dec.env elts := node.Content for i, elt := range elts { _, err := dec.value(elt) if err != nil { return nil, err } // Undo any effects on the environment. We *do* keep any named // variables that were added to the vars map in case they were // introduced within the element. // // TODO: If we change how we implement repeat nodes, we might be // able to drop yamlEncoder.env and yamlDecoder.env. dec.env = origEnv // Add a generator function gen = append(gen, func(e envSet) (*Value, envSet) { dec.env = e // TODO: If this is in a sum, this tends to generate a ton of // fresh variables that are different on each branch of the // parent sum. Does it make sense to hold on to the i'th value // of the tuple after we've generated it? v, err := dec.value(elts[i]) if err != nil { // It worked the first time, so this really shouldn't hapen. panic("decoding repeat element failed") } return v, dec.env }) } return mk(NewRepeat(gen...)) case is(yaml.SequenceNode, "!sum"): vs := make([]*Value, 0, len(node.Content)) for _, elt := range node.Content { v, err := dec.value(elt) if err != nil { return nil, err } vs = append(vs, v) } if len(vs) == 1 { return vs[0], nil } // A sum is implemented as a fresh variable that's simultaneously bound // to each of the descendants. id := &ident{name: fmt.Sprintf("sum%d", dec.nSums)} dec.nSums++ dec.env = dec.env.bind(id, vs...) return mk(Var{id: id}) case is(yaml.ScalarNode, "!import"): if dec.opts.FS == nil { return nil, fmt.Errorf("!import not allowed (ReadOpts.FS not set)") } pat := node.Value if !fs.ValidPath(pat) { // This will result in Glob returning no results. Give a more useful // error message for this case. return nil, fmt.Errorf("!import path must not contain '.' or '..'") } ms, err := fs.Glob(dec.opts.FS, pat) if err != nil { return nil, fmt.Errorf("resolving !import: %w", err) } if len(ms) == 0 { return nil, fmt.Errorf("!import did not match any files") } // Parse each file vs := make([]*Value, 0, len(ms)) for _, m := range ms { v, err := dec.import1(m) if err != nil { return nil, err } vs = append(vs, v) } // Create a sum. if len(vs) == 1 { return vs[0], nil } id := &ident{name: "import"} dec.env = dec.env.bind(id, vs...) return mk(Var{id: id}) } return nil, fmt.Errorf("unknown node kind %d %v", node.Kind, node.Tag) } func (dec *yamlDecoder) import1(path string) (*Value, error) { // Make sure we can open the path first. f, err := dec.opts.FS.Open(path) if err != nil { return nil, fmt.Errorf("!import failed: %w", err) } defer f.Close() // Prepare the enter path. oldFS, oldPath := dec.opts.FS, dec.path defer func() { dec.opts.FS, dec.path = oldFS, oldPath }() // Enter path, which is relative to the current path's directory. newPath := filepath.Join(filepath.Dir(dec.path), path) subFS, err := fs.Sub(dec.opts.FS, filepath.Dir(path)) if err != nil { return nil, err } dec.opts.FS, dec.path = subFS, newPath // Parse the file. return dec.read(f) } type yamlEncoder struct { idp identPrinter e envSet // We track the environment for !repeat nodes. } // TODO: Switch some Value marshaling to Closure? func (c Closure) MarshalYAML() (any, error) { // TODO: If the environment is trivial, just marshal the value. enc := &yamlEncoder{} return enc.closure(c), nil } func (c Closure) String() string { b, err := yaml.Marshal(c) if err != nil { return fmt.Sprintf("marshal failed: %s", err) } return string(b) } func (v *Value) MarshalYAML() (any, error) { enc := &yamlEncoder{e: topEnv} return enc.value(v), nil } func (v *Value) String() string { b, err := yaml.Marshal(v) if err != nil { return fmt.Sprintf("marshal failed: %s", err) } return string(b) } func (enc *yamlEncoder) closure(c Closure) *yaml.Node { enc.e = c.env var n yaml.Node n.Kind = yaml.MappingNode n.Tag = "!closure" n.Content = make([]*yaml.Node, 4) n.Content[0] = new(yaml.Node) n.Content[0].SetString("env") n.Content[2] = new(yaml.Node) n.Content[2].SetString("in") n.Content[3] = enc.value(c.val) // Fill in the env after we've written the value in case value encoding // affects the env. n.Content[1] = enc.env(enc.e) enc.e = envSet{} // Allow GC'ing the env return &n } func (enc *yamlEncoder) env(e envSet) *yaml.Node { var encode func(e *envExpr) *yaml.Node encode = func(e *envExpr) *yaml.Node { var n yaml.Node switch e.kind { default: panic("bad kind") case envZero: n.SetString("0") case envUnit: n.SetString("1") case envBinding: var id yaml.Node id.SetString(enc.idp.unique(e.id)) n.Kind = yaml.MappingNode n.Content = []*yaml.Node{&id, enc.value(e.val)} case envProduct, envSum: n.Kind = yaml.SequenceNode if e.kind == envProduct { n.Tag = "!product" } else { n.Tag = "!sum" } for _, e2 := range e.operands { n.Content = append(n.Content, encode(e2)) } } return &n } return encode(e.root) } var yamlIntRe = regexp.MustCompile(`^-?[0-9]+$`) func (enc *yamlEncoder) value(v *Value) *yaml.Node { var n yaml.Node switch d := v.Domain.(type) { case nil: // Not allowed by unmarshaler, but useful for understanding when // something goes horribly wrong. // // TODO: We might be able to track useful provenance for this, which // would really help with debugging unexpected bottoms. n.SetString("_|_") return &n case Top: n.SetString("_") return &n case Def: n.Kind = yaml.MappingNode for k, elt := range d.All() { var kn yaml.Node kn.SetString(k) n.Content = append(n.Content, &kn, enc.value(elt)) } n.HeadComment = v.PosString() return &n case Tuple: n.Kind = yaml.SequenceNode if d.repeat == nil { for _, elt := range d.vs { n.Content = append(n.Content, enc.value(elt)) } } else { if len(d.repeat) == 1 { n.Tag = "!repeat" } else { n.Tag = "!repeat-unify" } // TODO: I'm not positive this will round-trip everything correctly. for _, gen := range d.repeat { v, e := gen(enc.e) enc.e = e n.Content = append(n.Content, enc.value(v)) } } return &n case String: switch d.kind { case stringExact: n.SetString(d.exact) switch { // Make this into a "nice" !!int node if I can. case yamlIntRe.MatchString(d.exact): n.Tag = "tag:yaml.org,2002:int" // Or a "nice" !!bool node. case d.exact == "false" || d.exact == "true": n.Tag = "tag:yaml.org,2002:bool" // If this doesn't require escaping, leave it as a str node to avoid // the annoying YAML tags. Otherwise, mark it as an exact string. // Alternatively, we could always emit a str node with regexp // quoting. case d.exact != regexp.QuoteMeta(d.exact): n.Tag = "!string" } return &n case stringRegex: o := make([]string, 0, 1) for _, re := range d.re { s := re.String() s = strings.TrimSuffix(strings.TrimPrefix(s, `\A(?:`), `)\z`) o = append(o, s) } if len(o) == 1 { n.SetString(o[0]) return &n } n.Encode(o) n.Tag = "!regex" return &n } panic("bad String kind") case Var: // TODO: If Var only appears once in the whole Value and is independent // in the environment (part of a term that is only over Var), then emit // this as a !sum instead. if false { var vs []*Value // TODO: Get values of this var. if len(vs) == 1 { return enc.value(vs[0]) } n.Kind = yaml.SequenceNode n.Tag = "!sum" for _, elt := range vs { n.Content = append(n.Content, enc.value(elt)) } return &n } n.SetString(enc.idp.unique(d.id)) if !strings.HasPrefix(d.id.name, "$") { n.Tag = "!var" } return &n } panic(fmt.Sprintf("unknown domain type %T", v.Domain)) }