1
2
3
4
5 package unify
6
import (
	"errors"
	"fmt"
	"io"
	"io/fs"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"strings"

	"gopkg.in/yaml.v3"
)
19
20
21
// ReadOpts holds options for Read and ReadFile.
type ReadOpts struct {
	// FS is the file system that !import glob patterns are matched and
	// opened against. When FS is nil, any !import is reported as an error.
	FS fs.FS
}
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73 func Read(r io.Reader, path string, opts ReadOpts) (Closure, error) {
74 dec := yamlDecoder{opts: opts, path: path, env: topEnv}
75 v, err := dec.read(r)
76 if err != nil {
77 return Closure{}, err
78 }
79 return dec.close(v), nil
80 }
81
82
83
84
85
86
87
88
89 func ReadFile(path string, opts ReadOpts) (Closure, error) {
90 f, err := os.Open(path)
91 if err != nil {
92 return Closure{}, err
93 }
94 defer f.Close()
95
96 if opts.FS == nil {
97 opts.FS = os.DirFS(filepath.Dir(path))
98 }
99
100 return Read(f, path, opts)
101 }
102
103
104
105
106
107 func (c *Closure) UnmarshalYAML(node *yaml.Node) error {
108 dec := yamlDecoder{path: "<yaml.Node>", env: topEnv}
109 v, err := dec.root(node)
110 if err != nil {
111 return err
112 }
113 *c = dec.close(v)
114 return nil
115 }
116
117 type yamlDecoder struct {
118 opts ReadOpts
119 path string
120
121 vars map[string]*ident
122 nSums int
123
124 env envSet
125 }
126
127 func (dec *yamlDecoder) read(r io.Reader) (*Value, error) {
128 n, err := readOneNode(r)
129 if err != nil {
130 return nil, fmt.Errorf("%s: %w", dec.path, err)
131 }
132
133
134 v, err := dec.root(n)
135 if err != nil {
136 return nil, fmt.Errorf("%s: %w", dec.path, err)
137 }
138
139 return v, nil
140 }
141
142
143
144 func readOneNode(r io.Reader) (*yaml.Node, error) {
145 yd := yaml.NewDecoder(r)
146
147
148 var node yaml.Node
149 if err := yd.Decode(&node); err != nil {
150 return nil, err
151 }
152 np := &node
153 if np.Kind == yaml.DocumentNode {
154 np = node.Content[0]
155 }
156
157
158 if err := yd.Decode(nil); err == nil {
159 return nil, fmt.Errorf("must not contain multiple documents")
160 } else if err != io.EOF {
161 return nil, err
162 }
163
164 return np, nil
165 }
166
167
168 func (dec *yamlDecoder) root(node *yaml.Node) (*Value, error) {
169
170
171 oldVars, oldNSums := dec.vars, dec.nSums
172 defer func() {
173 dec.vars, dec.nSums = oldVars, oldNSums
174 }()
175 dec.vars = make(map[string]*ident, 0)
176 dec.nSums = 0
177
178 return dec.value(node)
179 }
180
181
182 func (dec *yamlDecoder) close(v *Value) Closure {
183 return Closure{v, dec.env}
184 }
185
186 func (dec *yamlDecoder) value(node *yaml.Node) (vOut *Value, errOut error) {
187 pos := &Pos{Path: dec.path, Line: node.Line}
188
189
190 if node.Kind == yaml.AliasNode {
191 node = node.Alias
192 }
193
194 mk := func(d Domain) (*Value, error) {
195 v := &Value{Domain: d, pos: pos}
196 return v, nil
197 }
198 mk2 := func(d Domain, err error) (*Value, error) {
199 if err != nil {
200 return nil, err
201 }
202 return mk(d)
203 }
204
205
206 is := func(kind yaml.Kind, tag string) bool {
207 return node.Kind == kind && node.LongTag() == tag
208 }
209 isExact := func() bool {
210 if node.Kind != yaml.ScalarNode {
211 return false
212 }
213
214 switch node.LongTag() {
215 case "!string", "tag:yaml.org,2002:int", "tag:yaml.org,2002:float", "tag:yaml.org,2002:bool", "tag:yaml.org,2002:binary":
216 return true
217 }
218 return false
219 }
220
221
222
223
224 strVal := ""
225 isStr := is(yaml.ScalarNode, "tag:yaml.org,2002:str")
226 if isStr {
227 strVal = node.Value
228 }
229
230 switch {
231 case is(yaml.ScalarNode, "!var"):
232 strVal = "$" + node.Value
233 fallthrough
234 case strings.HasPrefix(strVal, "$"):
235 id, ok := dec.vars[strVal]
236 if !ok {
237
238
239
240 name, _, _ := strings.Cut(strVal, "#")
241 id = &ident{name: name}
242 dec.vars[strVal] = id
243 dec.env = dec.env.bind(id, topValue)
244 }
245 return mk(Var{id: id})
246
247 case strVal == "_" || is(yaml.ScalarNode, "!top"):
248 return mk(Top{})
249
250 case strVal == "_|_" || is(yaml.ScalarNode, "!bottom"):
251 return nil, errors.New("found bottom")
252
253 case isExact():
254 val := node.Value
255 return mk(NewStringExact(val))
256
257 case isStr || is(yaml.ScalarNode, "!regex"):
258
259
260 val := node.Value
261 return mk2(NewStringRegex(val))
262
263 case is(yaml.SequenceNode, "!regex"):
264 var vals []string
265 if err := node.Decode(&vals); err != nil {
266 return nil, err
267 }
268 return mk2(NewStringRegex(vals...))
269
270 case is(yaml.MappingNode, "tag:yaml.org,2002:map"):
271 var db DefBuilder
272 for i := 0; i < len(node.Content); i += 2 {
273 key := node.Content[i]
274 if key.Kind != yaml.ScalarNode {
275 return nil, fmt.Errorf("non-scalar key %q", key.Value)
276 }
277 val, err := dec.value(node.Content[i+1])
278 if err != nil {
279 return nil, err
280 }
281 db.Add(key.Value, val)
282 }
283 return mk(db.Build())
284
285 case is(yaml.SequenceNode, "tag:yaml.org,2002:seq"):
286 elts := node.Content
287 vs := make([]*Value, 0, len(elts))
288 for _, elt := range elts {
289 v, err := dec.value(elt)
290 if err != nil {
291 return nil, err
292 }
293 vs = append(vs, v)
294 }
295 return mk(NewTuple(vs...))
296
297 case is(yaml.SequenceNode, "!repeat") || is(yaml.SequenceNode, "!repeat-unify"):
298
299
300
301 if node.LongTag() == "!repeat" && len(node.Content) != 1 {
302 return nil, fmt.Errorf("!repeat must have exactly one child")
303 }
304
305
306
307
308 var gen []func(e envSet) (*Value, envSet)
309 origEnv := dec.env
310 elts := node.Content
311 for i, elt := range elts {
312 _, err := dec.value(elt)
313 if err != nil {
314 return nil, err
315 }
316
317
318
319
320
321
322 dec.env = origEnv
323
324 gen = append(gen, func(e envSet) (*Value, envSet) {
325 dec.env = e
326
327
328
329
330 v, err := dec.value(elts[i])
331 if err != nil {
332
333 panic("decoding repeat element failed")
334 }
335 return v, dec.env
336 })
337 }
338 return mk(NewRepeat(gen...))
339
340 case is(yaml.SequenceNode, "!sum"):
341 vs := make([]*Value, 0, len(node.Content))
342 for _, elt := range node.Content {
343 v, err := dec.value(elt)
344 if err != nil {
345 return nil, err
346 }
347 vs = append(vs, v)
348 }
349 if len(vs) == 1 {
350 return vs[0], nil
351 }
352
353
354
355 id := &ident{name: fmt.Sprintf("sum%d", dec.nSums)}
356 dec.nSums++
357 dec.env = dec.env.bind(id, vs...)
358 return mk(Var{id: id})
359
360 case is(yaml.ScalarNode, "!import"):
361 if dec.opts.FS == nil {
362 return nil, fmt.Errorf("!import not allowed (ReadOpts.FS not set)")
363 }
364 pat := node.Value
365
366 if !fs.ValidPath(pat) {
367
368
369 return nil, fmt.Errorf("!import path must not contain '.' or '..'")
370 }
371
372 ms, err := fs.Glob(dec.opts.FS, pat)
373 if err != nil {
374 return nil, fmt.Errorf("resolving !import: %w", err)
375 }
376 if len(ms) == 0 {
377 return nil, fmt.Errorf("!import did not match any files")
378 }
379
380
381 vs := make([]*Value, 0, len(ms))
382 for _, m := range ms {
383 v, err := dec.import1(m)
384 if err != nil {
385 return nil, err
386 }
387 vs = append(vs, v)
388 }
389
390
391 if len(vs) == 1 {
392 return vs[0], nil
393 }
394 id := &ident{name: "import"}
395 dec.env = dec.env.bind(id, vs...)
396 return mk(Var{id: id})
397 }
398
399 return nil, fmt.Errorf("unknown node kind %d %v", node.Kind, node.Tag)
400 }
401
402 func (dec *yamlDecoder) import1(path string) (*Value, error) {
403
404 f, err := dec.opts.FS.Open(path)
405 if err != nil {
406 return nil, fmt.Errorf("!import failed: %w", err)
407 }
408 defer f.Close()
409
410
411 oldFS, oldPath := dec.opts.FS, dec.path
412 defer func() {
413 dec.opts.FS, dec.path = oldFS, oldPath
414 }()
415
416
417 newPath := filepath.Join(filepath.Dir(dec.path), path)
418 subFS, err := fs.Sub(dec.opts.FS, filepath.Dir(path))
419 if err != nil {
420 return nil, err
421 }
422 dec.opts.FS, dec.path = subFS, newPath
423
424
425 return dec.read(f)
426 }
427
428 type yamlEncoder struct {
429 idp identPrinter
430 e envSet
431 }
432
433
434
435 func (c Closure) MarshalYAML() (any, error) {
436
437 enc := &yamlEncoder{}
438 return enc.closure(c), nil
439 }
440
441 func (c Closure) String() string {
442 b, err := yaml.Marshal(c)
443 if err != nil {
444 return fmt.Sprintf("marshal failed: %s", err)
445 }
446 return string(b)
447 }
448
449 func (v *Value) MarshalYAML() (any, error) {
450 enc := &yamlEncoder{e: topEnv}
451 return enc.value(v), nil
452 }
453
454 func (v *Value) String() string {
455 b, err := yaml.Marshal(v)
456 if err != nil {
457 return fmt.Sprintf("marshal failed: %s", err)
458 }
459 return string(b)
460 }
461
462 func (enc *yamlEncoder) closure(c Closure) *yaml.Node {
463 enc.e = c.env
464 var n yaml.Node
465 n.Kind = yaml.MappingNode
466 n.Tag = "!closure"
467 n.Content = make([]*yaml.Node, 4)
468 n.Content[0] = new(yaml.Node)
469 n.Content[0].SetString("env")
470 n.Content[2] = new(yaml.Node)
471 n.Content[2].SetString("in")
472 n.Content[3] = enc.value(c.val)
473
474
475 n.Content[1] = enc.env(enc.e)
476 enc.e = envSet{}
477 return &n
478 }
479
480 func (enc *yamlEncoder) env(e envSet) *yaml.Node {
481 var encode func(e *envExpr) *yaml.Node
482 encode = func(e *envExpr) *yaml.Node {
483 var n yaml.Node
484 switch e.kind {
485 default:
486 panic("bad kind")
487 case envZero:
488 n.SetString("0")
489 case envUnit:
490 n.SetString("1")
491 case envBinding:
492 var id yaml.Node
493 id.SetString(enc.idp.unique(e.id))
494 n.Kind = yaml.MappingNode
495 n.Content = []*yaml.Node{&id, enc.value(e.val)}
496 case envProduct, envSum:
497 n.Kind = yaml.SequenceNode
498 if e.kind == envProduct {
499 n.Tag = "!product"
500 } else {
501 n.Tag = "!sum"
502 }
503 for _, e2 := range e.operands {
504 n.Content = append(n.Content, encode(e2))
505 }
506 }
507 return &n
508 }
509 return encode(e.root)
510 }
511
// yamlIntRe matches exact strings that are optionally negative runs of ASCII
// decimal digits; the encoder tags such strings as YAML integers.
var yamlIntRe = regexp.MustCompile(`^-?\d+$`)
513
514 func (enc *yamlEncoder) value(v *Value) *yaml.Node {
515 var n yaml.Node
516 switch d := v.Domain.(type) {
517 case nil:
518
519
520
521
522
523 n.SetString("_|_")
524 return &n
525
526 case Top:
527 n.SetString("_")
528 return &n
529
530 case Def:
531 n.Kind = yaml.MappingNode
532 for k, elt := range d.All() {
533 var kn yaml.Node
534 kn.SetString(k)
535 n.Content = append(n.Content, &kn, enc.value(elt))
536 }
537 n.HeadComment = v.PosString()
538 return &n
539
540 case Tuple:
541 n.Kind = yaml.SequenceNode
542 if d.repeat == nil {
543 for _, elt := range d.vs {
544 n.Content = append(n.Content, enc.value(elt))
545 }
546 } else {
547 if len(d.repeat) == 1 {
548 n.Tag = "!repeat"
549 } else {
550 n.Tag = "!repeat-unify"
551 }
552
553 for _, gen := range d.repeat {
554 v, e := gen(enc.e)
555 enc.e = e
556 n.Content = append(n.Content, enc.value(v))
557 }
558 }
559 return &n
560
561 case String:
562 switch d.kind {
563 case stringExact:
564 n.SetString(d.exact)
565 switch {
566
567 case yamlIntRe.MatchString(d.exact):
568 n.Tag = "tag:yaml.org,2002:int"
569
570
571 case d.exact == "false" || d.exact == "true":
572 n.Tag = "tag:yaml.org,2002:bool"
573
574
575
576
577
578 case d.exact != regexp.QuoteMeta(d.exact):
579 n.Tag = "!string"
580 }
581 return &n
582 case stringRegex:
583 o := make([]string, 0, 1)
584 for _, re := range d.re {
585 s := re.String()
586 s = strings.TrimSuffix(strings.TrimPrefix(s, `\A(?:`), `)\z`)
587 o = append(o, s)
588 }
589 if len(o) == 1 {
590 n.SetString(o[0])
591 return &n
592 }
593 n.Encode(o)
594 n.Tag = "!regex"
595 return &n
596 }
597 panic("bad String kind")
598
599 case Var:
600
601
602
603 if false {
604 var vs []*Value
605 if len(vs) == 1 {
606 return enc.value(vs[0])
607 }
608 n.Kind = yaml.SequenceNode
609 n.Tag = "!sum"
610 for _, elt := range vs {
611 n.Content = append(n.Content, enc.value(elt))
612 }
613 return &n
614 }
615 n.SetString(enc.idp.unique(d.id))
616 if !strings.HasPrefix(d.id.name, "$") {
617 n.Tag = "!var"
618 }
619 return &n
620 }
621 panic(fmt.Sprintf("unknown domain type %T", v.Domain))
622 }
623
View as plain text