1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package zip_sum_test
17
18 import (
19 "context"
20 "crypto/sha256"
21 "encoding/csv"
22 "encoding/hex"
23 "flag"
24 "fmt"
25 "internal/testenv"
26 "io"
27 "os"
28 "path/filepath"
29 "strings"
30 "testing"
31
32 "cmd/go/internal/cfg"
33 "cmd/go/internal/modfetch"
34
35 "golang.org/x/mod/module"
36 )
37
// Command-line flags that configure TestZipSums.
var (
	// updateTestData (-u) switches TestZipSums from reporting mismatches as
	// failures to recording the newly computed values (and dropping modules
	// that cannot be downloaded) so that the test data may be rewritten.
	updateTestData = flag.Bool("u", false, "when set, tests may update files in testdata instead of failing")
	// enableZipSum (-zipsum) must be set for TestZipSums to run; without it
	// the test is skipped.
	enableZipSum = flag.Bool("zipsum", false, "enable TestZipSums")
	// debugZipSum (-testwork) keeps the temporary module cache directory and
	// prints its location to standard error instead of removing it.
	debugZipSum = flag.Bool("testwork", false, "when set, TestZipSums will preserve its test directory")
	// modCacheDir (-zipsumcache), when non-empty, is used as GOPATH in place
	// of a freshly created temporary directory.
	modCacheDir = flag.String("zipsumcache", "", "module cache to use instead of temp directory")
	// shardCount (-zipsumshardcount) is the number of shards the test list is
	// divided into; it must be at least 1.
	shardCount = flag.Int("zipsumshardcount", 1, "number of shards to divide TestZipSums into")
	// shardIndex (-zipsumshard) selects which shard this run executes; it
	// must satisfy 0 <= shardIndex < shardCount.
	shardIndex = flag.Int("zipsumshard", 0, "index of TestZipSums shard to test (0 <= zipsumshard < zipsumshardcount)")
)
46
// zipSumsPath is the slash-separated path of the CSV file holding the
// expected sums. It is relative, so it is resolved against the test's
// working directory. Each record has four fields: module path, version,
// content sum, and zip file hash.
const zipSumsPath = "testdata/zip_sums.csv"
48
// zipSumTest is one record of the zip sums test data: a module version
// together with the values TestZipSums expects to compute for it.
type zipSumTest struct {
	// m is the module path and version to download.
	m module.Version
	// wantSum is the expected content sum, compared against modfetch.Sum.
	// wantFileHash is the expected hex-encoded SHA-256 of the zip file itself.
	wantSum, wantFileHash string
}
53
54 func TestZipSums(t *testing.T) {
55 if !*enableZipSum {
56
57
58 t.Skip("TestZipSum not enabled with -zipsum")
59 }
60 if *shardCount < 1 {
61 t.Fatal("-zipsumshardcount must be a positive integer")
62 }
63 if *shardIndex < 0 || *shardCount <= *shardIndex {
64 t.Fatal("-zipsumshard must be between 0 and -zipsumshardcount")
65 }
66
67 testenv.MustHaveGoBuild(t)
68 testenv.MustHaveExternalNetwork(t)
69 testenv.MustHaveExecPath(t, "bzr")
70 testenv.MustHaveExecPath(t, "git")
71
72
73
74 tests, err := readZipSumTests()
75 if err != nil {
76 t.Fatal(err)
77 }
78
79 if *modCacheDir != "" {
80 cfg.BuildContext.GOPATH = *modCacheDir
81 } else {
82 tmpDir, err := os.MkdirTemp("", "TestZipSums")
83 if err != nil {
84 t.Fatal(err)
85 }
86 if *debugZipSum {
87 fmt.Fprintf(os.Stderr, "TestZipSums: modCacheDir: %s\n", tmpDir)
88 } else {
89 defer os.RemoveAll(tmpDir)
90 }
91 cfg.BuildContext.GOPATH = tmpDir
92 }
93
94 cfg.GOPROXY = "direct"
95 cfg.GOSUMDB = "off"
96
97
98
99
100
101 if *shardCount > 1 {
102 r := *shardIndex
103 w := 0
104 for r < len(tests) {
105 tests[w] = tests[r]
106 w++
107 r += *shardCount
108 }
109 tests = tests[:w]
110 }
111
112
113
114 needUpdate := false
115 fetcher := modfetch.NewFetcher()
116 for i := range tests {
117 test := &tests[i]
118 name := fmt.Sprintf("%s@%s", strings.ReplaceAll(test.m.Path, "/", "_"), test.m.Version)
119 t.Run(name, func(t *testing.T) {
120 t.Parallel()
121 ctx := context.Background()
122
123 zipPath, err := fetcher.DownloadZip(ctx, test.m)
124 if err != nil {
125 if *updateTestData {
126 t.Logf("%s: could not download module: %s (will remove from testdata)", test.m, err)
127 test.m.Path = ""
128 needUpdate = true
129 } else {
130 t.Errorf("%s: could not download module: %s", test.m, err)
131 }
132 return
133 }
134
135 sum := modfetch.Sum(ctx, test.m)
136 if sum != test.wantSum {
137 if *updateTestData {
138 t.Logf("%s: updating content sum to %s", test.m, sum)
139 test.wantSum = sum
140 needUpdate = true
141 } else {
142 t.Errorf("%s: got content sum %s; want sum %s", test.m, sum, test.wantSum)
143 return
144 }
145 }
146
147 h := sha256.New()
148 f, err := os.Open(zipPath)
149 if err != nil {
150 t.Errorf("%s: %v", test.m, err)
151 }
152 defer f.Close()
153 if _, err := io.Copy(h, f); err != nil {
154 t.Errorf("%s: %v", test.m, err)
155 }
156 zipHash := hex.EncodeToString(h.Sum(nil))
157 if zipHash != test.wantFileHash {
158 if *updateTestData {
159 t.Logf("%s: updating zip file hash to %s", test.m, zipHash)
160 test.wantFileHash = zipHash
161 needUpdate = true
162 } else {
163 t.Errorf("%s: got zip file hash %s; want hash %s (but content sum matches)", test.m, zipHash, test.wantFileHash)
164 }
165 }
166 })
167 }
168
169 if needUpdate {
170
171 r, w := 0, 0
172 for r < len(tests) {
173 if tests[r].m.Path != "" {
174 tests[w] = tests[r]
175 w++
176 }
177 r++
178 }
179 tests = tests[:w]
180
181 if err := writeZipSumTests(tests); err != nil {
182 t.Error(err)
183 }
184 }
185 }
186
187 func readZipSumTests() ([]zipSumTest, error) {
188 f, err := os.Open(filepath.FromSlash(zipSumsPath))
189 if err != nil {
190 return nil, err
191 }
192 defer f.Close()
193 r := csv.NewReader(f)
194
195 var tests []zipSumTest
196 for {
197 line, err := r.Read()
198 if err == io.EOF {
199 break
200 } else if err != nil {
201 return nil, err
202 } else if len(line) != 4 {
203 return nil, fmt.Errorf("%s:%d: malformed line", f.Name(), len(tests)+1)
204 }
205 test := zipSumTest{m: module.Version{Path: line[0], Version: line[1]}, wantSum: line[2], wantFileHash: line[3]}
206 tests = append(tests, test)
207 }
208 return tests, nil
209 }
210
211 func writeZipSumTests(tests []zipSumTest) (err error) {
212 f, err := os.Create(filepath.FromSlash(zipSumsPath))
213 if err != nil {
214 return err
215 }
216 defer func() {
217 if cerr := f.Close(); err == nil && cerr != nil {
218 err = cerr
219 }
220 }()
221 w := csv.NewWriter(f)
222 line := make([]string, 0, 4)
223 for _, test := range tests {
224 line = append(line[:0], test.m.Path, test.m.Version, test.wantSum, test.wantFileHash)
225 if err := w.Write(line); err != nil {
226 return err
227 }
228 }
229 w.Flush()
230 return nil
231 }
232
View as plain text