Source file src/simd/_gen/unify/closure.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package unify
     6  
     7  import (
     8  	"fmt"
     9  	"iter"
    10  	"maps"
    11  	"slices"
    12  )
    13  
// Closure pairs a Value with the environment that binds the variables (Var
// domains) reachable from it. A single variable may be bound to several
// alternative values in env; operations such as All and Summands walk those
// alternatives.
type Closure struct {
	// val is the root value; it may itself be a Var resolved through env.
	val *Value
	// env supplies the bindings for the Vars that appear in val.
	env envSet
}
    18  
    19  func NewSum(vs ...*Value) Closure {
    20  	id := &ident{name: "sum"}
    21  	return Closure{NewValue(Var{id}), topEnv.bind(id, vs...)}
    22  }
    23  
    24  // IsBottom returns whether c consists of no values.
    25  func (c Closure) IsBottom() bool {
    26  	return c.val.Domain == nil
    27  }
    28  
    29  // Summands returns the top-level Values of c. This assumes the top-level of c
    30  // was constructed as a sum, and is mostly useful for debugging.
    31  func (c Closure) Summands() iter.Seq[*Value] {
    32  	return func(yield func(*Value) bool) {
    33  		var rec func(v *Value, env envSet) bool
    34  		rec = func(v *Value, env envSet) bool {
    35  			switch d := v.Domain.(type) {
    36  			case Var:
    37  				parts := env.partitionBy(d.id)
    38  				for _, part := range parts {
    39  					// It may be a sum of sums. Walk into this value.
    40  					if !rec(part.value, part.env) {
    41  						return false
    42  					}
    43  				}
    44  				return true
    45  			default:
    46  				return yield(v)
    47  			}
    48  		}
    49  		rec(c.val, c.env)
    50  	}
    51  }
    52  
    53  // All enumerates all possible concrete values of c by substituting variables
    54  // from the environment.
    55  //
    56  // E.g., enumerating this Value
    57  //
    58  //	a: !sum [1, 2]
    59  //	b: !sum [3, 4]
    60  //
    61  // results in
    62  //
    63  //   - {a: 1, b: 3}
    64  //   - {a: 1, b: 4}
    65  //   - {a: 2, b: 3}
    66  //   - {a: 2, b: 4}
    67  func (c Closure) All() iter.Seq[*Value] {
    68  	// In order to enumerate all concrete values under all possible variable
    69  	// bindings, we use a "non-deterministic continuation passing style" to
    70  	// implement this. We use CPS to traverse the Value tree, threading the
    71  	// (possibly narrowing) environment through that CPS following an Euler
    72  	// tour. Where the environment permits multiple choices, we invoke the same
    73  	// continuation for each choice. Similar to a yield function, the
    74  	// continuation can return false to stop the non-deterministic walk.
    75  	return func(yield func(*Value) bool) {
    76  		c.val.all1(c.env, func(v *Value, e envSet) bool {
    77  			return yield(v)
    78  		})
    79  	}
    80  }
    81  
    82  func (v *Value) all1(e envSet, cont func(*Value, envSet) bool) bool {
    83  	switch d := v.Domain.(type) {
    84  	default:
    85  		panic(fmt.Sprintf("unknown domain type %T", d))
    86  
    87  	case nil:
    88  		return true
    89  
    90  	case Top, String:
    91  		return cont(v, e)
    92  
    93  	case Def:
    94  		fields := d.keys()
    95  		// We can reuse this parts slice because we're doing a DFS through the
    96  		// state space. (Otherwise, we'd have to do some messy threading of an
    97  		// immutable slice-like value through allElt.)
    98  		parts := make(map[string]*Value, len(fields))
    99  
   100  		// TODO: If there are no Vars or Sums under this Def, then nothing can
   101  		// change the Value or env, so we could just cont(v, e).
   102  		var allElt func(elt int, e envSet) bool
   103  		allElt = func(elt int, e envSet) bool {
   104  			if elt == len(fields) {
   105  				// Build a new Def from the concrete parts. Clone parts because
   106  				// we may reuse it on other non-deterministic branches.
   107  				nVal := newValueFrom(Def{maps.Clone(parts)}, v)
   108  				return cont(nVal, e)
   109  			}
   110  
   111  			return d.fields[fields[elt]].all1(e, func(v *Value, e envSet) bool {
   112  				parts[fields[elt]] = v
   113  				return allElt(elt+1, e)
   114  			})
   115  		}
   116  		return allElt(0, e)
   117  
   118  	case Tuple:
   119  		// Essentially the same as Def.
   120  		if d.repeat != nil {
   121  			// There's nothing we can do with this.
   122  			return cont(v, e)
   123  		}
   124  		parts := make([]*Value, len(d.vs))
   125  		var allElt func(elt int, e envSet) bool
   126  		allElt = func(elt int, e envSet) bool {
   127  			if elt == len(d.vs) {
   128  				// Build a new tuple from the concrete parts. Clone parts because
   129  				// we may reuse it on other non-deterministic branches.
   130  				nVal := newValueFrom(Tuple{vs: slices.Clone(parts)}, v)
   131  				return cont(nVal, e)
   132  			}
   133  
   134  			return d.vs[elt].all1(e, func(v *Value, e envSet) bool {
   135  				parts[elt] = v
   136  				return allElt(elt+1, e)
   137  			})
   138  		}
   139  		return allElt(0, e)
   140  
   141  	case Var:
   142  		// Go each way this variable can be bound.
   143  		for _, ePart := range e.partitionBy(d.id) {
   144  			// d.id is no longer bound in this environment partition. We'll may
   145  			// need it later in the Euler tour, so bind it back to this single
   146  			// value.
   147  			env := ePart.env.bind(d.id, ePart.value)
   148  			if !ePart.value.all1(env, cont) {
   149  				return false
   150  			}
   151  		}
   152  		return true
   153  	}
   154  }
   155  

View as plain text