Source file src/slices/sort_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package slices_test
     6  
     7  import (
     8  	"cmp"
     9  	"fmt"
    10  	"math"
    11  	"math/rand"
    12  	. "slices"
    13  	"strconv"
    14  	"strings"
    15  	"testing"
    16  )
    17  
    18  var ints = [...]int{74, 59, 238, -784, 9845, 959, 905, 0, 0, 42, 7586, -5467984, 7586}
    19  var float64s = [...]float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8, 74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3}
    20  var strs = [...]string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"}
    21  
    22  func TestSortIntSlice(t *testing.T) {
    23  	data := Clone(ints[:])
    24  	Sort(data)
    25  	if !IsSorted(data) {
    26  		t.Errorf("sorted %v", ints)
    27  		t.Errorf("   got %v", data)
    28  	}
    29  }
    30  
    31  func TestSortFuncIntSlice(t *testing.T) {
    32  	data := Clone(ints[:])
    33  	SortFunc(data, func(a, b int) int { return a - b })
    34  	if !IsSorted(data) {
    35  		t.Errorf("sorted %v", ints)
    36  		t.Errorf("   got %v", data)
    37  	}
    38  }
    39  
    40  func TestSortFloat64Slice(t *testing.T) {
    41  	data := Clone(float64s[:])
    42  	Sort(data)
    43  	if !IsSorted(data) {
    44  		t.Errorf("sorted %v", float64s)
    45  		t.Errorf("   got %v", data)
    46  	}
    47  }
    48  
    49  func TestSortStringSlice(t *testing.T) {
    50  	data := Clone(strs[:])
    51  	Sort(data)
    52  	if !IsSorted(data) {
    53  		t.Errorf("sorted %v", strs)
    54  		t.Errorf("   got %v", data)
    55  	}
    56  }
    57  
    58  func TestSortLarge_Random(t *testing.T) {
    59  	n := 1000000
    60  	if testing.Short() {
    61  		n /= 100
    62  	}
    63  	data := make([]int, n)
    64  	for i := 0; i < len(data); i++ {
    65  		data[i] = rand.Intn(100)
    66  	}
    67  	if IsSorted(data) {
    68  		t.Fatalf("terrible rand.rand")
    69  	}
    70  	Sort(data)
    71  	if !IsSorted(data) {
    72  		t.Errorf("sort didn't sort - 1M ints")
    73  	}
    74  }
    75  
    76  type intPair struct {
    77  	a, b int
    78  }
    79  
    80  type intPairs []intPair
    81  
    82  // Pairs compare on a only.
    83  func intPairCmp(x, y intPair) int {
    84  	return x.a - y.a
    85  }
    86  
    87  // Record initial order in B.
    88  func (d intPairs) initB() {
    89  	for i := range d {
    90  		d[i].b = i
    91  	}
    92  }
    93  
    94  // InOrder checks if a-equal elements were not reordered.
    95  func (d intPairs) inOrder() bool {
    96  	lastA, lastB := -1, 0
    97  	for i := 0; i < len(d); i++ {
    98  		if lastA != d[i].a {
    99  			lastA = d[i].a
   100  			lastB = d[i].b
   101  			continue
   102  		}
   103  		if d[i].b <= lastB {
   104  			return false
   105  		}
   106  		lastB = d[i].b
   107  	}
   108  	return true
   109  }
   110  
   111  func TestStability(t *testing.T) {
   112  	n, m := 100000, 1000
   113  	if testing.Short() {
   114  		n, m = 1000, 100
   115  	}
   116  	data := make(intPairs, n)
   117  
   118  	// random distribution
   119  	for i := 0; i < len(data); i++ {
   120  		data[i].a = rand.Intn(m)
   121  	}
   122  	if IsSortedFunc(data, intPairCmp) {
   123  		t.Fatalf("terrible rand.rand")
   124  	}
   125  	data.initB()
   126  	SortStableFunc(data, intPairCmp)
   127  	if !IsSortedFunc(data, intPairCmp) {
   128  		t.Errorf("Stable didn't sort %d ints", n)
   129  	}
   130  	if !data.inOrder() {
   131  		t.Errorf("Stable wasn't stable on %d ints", n)
   132  	}
   133  
   134  	// already sorted
   135  	data.initB()
   136  	SortStableFunc(data, intPairCmp)
   137  	if !IsSortedFunc(data, intPairCmp) {
   138  		t.Errorf("Stable shuffled sorted %d ints (order)", n)
   139  	}
   140  	if !data.inOrder() {
   141  		t.Errorf("Stable shuffled sorted %d ints (stability)", n)
   142  	}
   143  
   144  	// sorted reversed
   145  	for i := 0; i < len(data); i++ {
   146  		data[i].a = len(data) - i
   147  	}
   148  	data.initB()
   149  	SortStableFunc(data, intPairCmp)
   150  	if !IsSortedFunc(data, intPairCmp) {
   151  		t.Errorf("Stable didn't sort %d ints", n)
   152  	}
   153  	if !data.inOrder() {
   154  		t.Errorf("Stable wasn't stable on %d ints", n)
   155  	}
   156  }
   157  
   158  type S struct {
   159  	a int
   160  	b string
   161  }
   162  
   163  func cmpS(s1, s2 S) int {
   164  	return cmp.Compare(s1.a, s2.a)
   165  }
   166  
   167  func TestMinMax(t *testing.T) {
   168  	intCmp := func(a, b int) int { return a - b }
   169  
   170  	tests := []struct {
   171  		data    []int
   172  		wantMin int
   173  		wantMax int
   174  	}{
   175  		{[]int{7}, 7, 7},
   176  		{[]int{1, 2}, 1, 2},
   177  		{[]int{2, 1}, 1, 2},
   178  		{[]int{1, 2, 3}, 1, 3},
   179  		{[]int{3, 2, 1}, 1, 3},
   180  		{[]int{2, 1, 3}, 1, 3},
   181  		{[]int{2, 2, 3}, 2, 3},
   182  		{[]int{3, 2, 3}, 2, 3},
   183  		{[]int{0, 2, -9}, -9, 2},
   184  	}
   185  	for _, tt := range tests {
   186  		t.Run(fmt.Sprintf("%v", tt.data), func(t *testing.T) {
   187  			gotMin := Min(tt.data)
   188  			if gotMin != tt.wantMin {
   189  				t.Errorf("Min got %v, want %v", gotMin, tt.wantMin)
   190  			}
   191  
   192  			gotMinFunc := MinFunc(tt.data, intCmp)
   193  			if gotMinFunc != tt.wantMin {
   194  				t.Errorf("MinFunc got %v, want %v", gotMinFunc, tt.wantMin)
   195  			}
   196  
   197  			gotMax := Max(tt.data)
   198  			if gotMax != tt.wantMax {
   199  				t.Errorf("Max got %v, want %v", gotMax, tt.wantMax)
   200  			}
   201  
   202  			gotMaxFunc := MaxFunc(tt.data, intCmp)
   203  			if gotMaxFunc != tt.wantMax {
   204  				t.Errorf("MaxFunc got %v, want %v", gotMaxFunc, tt.wantMax)
   205  			}
   206  		})
   207  	}
   208  
   209  	svals := []S{
   210  		{1, "a"},
   211  		{2, "a"},
   212  		{1, "b"},
   213  		{2, "b"},
   214  	}
   215  
   216  	gotMin := MinFunc(svals, cmpS)
   217  	wantMin := S{1, "a"}
   218  	if gotMin != wantMin {
   219  		t.Errorf("MinFunc(%v) = %v, want %v", svals, gotMin, wantMin)
   220  	}
   221  
   222  	gotMax := MaxFunc(svals, cmpS)
   223  	wantMax := S{2, "a"}
   224  	if gotMax != wantMax {
   225  		t.Errorf("MaxFunc(%v) = %v, want %v", svals, gotMax, wantMax)
   226  	}
   227  }
   228  
   229  func TestMinMaxNaNs(t *testing.T) {
   230  	fs := []float64{1.0, 999.9, 3.14, -400.4, -5.14}
   231  	if Min(fs) != -400.4 {
   232  		t.Errorf("got min %v, want -400.4", Min(fs))
   233  	}
   234  	if Max(fs) != 999.9 {
   235  		t.Errorf("got max %v, want 999.9", Max(fs))
   236  	}
   237  
   238  	// No matter which element of fs is replaced with a NaN, both Min and Max
   239  	// should propagate the NaN to their output.
   240  	for i := 0; i < len(fs); i++ {
   241  		testfs := Clone(fs)
   242  		testfs[i] = math.NaN()
   243  
   244  		fmin := Min(testfs)
   245  		if !math.IsNaN(fmin) {
   246  			t.Errorf("got min %v, want NaN", fmin)
   247  		}
   248  
   249  		fmax := Max(testfs)
   250  		if !math.IsNaN(fmax) {
   251  			t.Errorf("got max %v, want NaN", fmax)
   252  		}
   253  	}
   254  }
   255  
   256  func TestMinMaxPanics(t *testing.T) {
   257  	intCmp := func(a, b int) int { return a - b }
   258  	emptySlice := []int{}
   259  
   260  	if !panics(func() { Min(emptySlice) }) {
   261  		t.Errorf("Min([]): got no panic, want panic")
   262  	}
   263  
   264  	if !panics(func() { Max(emptySlice) }) {
   265  		t.Errorf("Max([]): got no panic, want panic")
   266  	}
   267  
   268  	if !panics(func() { MinFunc(emptySlice, intCmp) }) {
   269  		t.Errorf("MinFunc([]): got no panic, want panic")
   270  	}
   271  
   272  	if !panics(func() { MaxFunc(emptySlice, intCmp) }) {
   273  		t.Errorf("MaxFunc([]): got no panic, want panic")
   274  	}
   275  }
   276  
   277  func TestBinarySearch(t *testing.T) {
   278  	str1 := []string{"foo"}
   279  	str2 := []string{"ab", "ca"}
   280  	str3 := []string{"mo", "qo", "vo"}
   281  	str4 := []string{"ab", "ad", "ca", "xy"}
   282  
   283  	// slice with repeating elements
   284  	strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"}
   285  
   286  	// slice with all element equal
   287  	strSame := []string{"xx", "xx", "xx"}
   288  
   289  	tests := []struct {
   290  		data      []string
   291  		target    string
   292  		wantPos   int
   293  		wantFound bool
   294  	}{
   295  		{[]string{}, "foo", 0, false},
   296  		{[]string{}, "", 0, false},
   297  
   298  		{str1, "foo", 0, true},
   299  		{str1, "bar", 0, false},
   300  		{str1, "zx", 1, false},
   301  
   302  		{str2, "aa", 0, false},
   303  		{str2, "ab", 0, true},
   304  		{str2, "ad", 1, false},
   305  		{str2, "ca", 1, true},
   306  		{str2, "ra", 2, false},
   307  
   308  		{str3, "bb", 0, false},
   309  		{str3, "mo", 0, true},
   310  		{str3, "nb", 1, false},
   311  		{str3, "qo", 1, true},
   312  		{str3, "tr", 2, false},
   313  		{str3, "vo", 2, true},
   314  		{str3, "xr", 3, false},
   315  
   316  		{str4, "aa", 0, false},
   317  		{str4, "ab", 0, true},
   318  		{str4, "ac", 1, false},
   319  		{str4, "ad", 1, true},
   320  		{str4, "ax", 2, false},
   321  		{str4, "ca", 2, true},
   322  		{str4, "cc", 3, false},
   323  		{str4, "dd", 3, false},
   324  		{str4, "xy", 3, true},
   325  		{str4, "zz", 4, false},
   326  
   327  		{strRepeats, "da", 2, true},
   328  		{strRepeats, "db", 5, false},
   329  		{strRepeats, "ma", 6, true},
   330  		{strRepeats, "mb", 8, false},
   331  
   332  		{strSame, "xx", 0, true},
   333  		{strSame, "ab", 0, false},
   334  		{strSame, "zz", 3, false},
   335  	}
   336  	for _, tt := range tests {
   337  		t.Run(tt.target, func(t *testing.T) {
   338  			{
   339  				pos, found := BinarySearch(tt.data, tt.target)
   340  				if pos != tt.wantPos || found != tt.wantFound {
   341  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   342  				}
   343  			}
   344  
   345  			{
   346  				pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare)
   347  				if pos != tt.wantPos || found != tt.wantFound {
   348  					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   349  				}
   350  			}
   351  		})
   352  	}
   353  }
   354  
   355  func TestBinarySearchInts(t *testing.T) {
   356  	data := []int{20, 30, 40, 50, 60, 70, 80, 90}
   357  	tests := []struct {
   358  		target    int
   359  		wantPos   int
   360  		wantFound bool
   361  	}{
   362  		{20, 0, true},
   363  		{23, 1, false},
   364  		{43, 3, false},
   365  		{80, 6, true},
   366  	}
   367  	for _, tt := range tests {
   368  		t.Run(strconv.Itoa(tt.target), func(t *testing.T) {
   369  			{
   370  				pos, found := BinarySearch(data, tt.target)
   371  				if pos != tt.wantPos || found != tt.wantFound {
   372  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   373  				}
   374  			}
   375  
   376  			{
   377  				cmp := func(a, b int) int {
   378  					return a - b
   379  				}
   380  				pos, found := BinarySearchFunc(data, tt.target, cmp)
   381  				if pos != tt.wantPos || found != tt.wantFound {
   382  					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   383  				}
   384  			}
   385  		})
   386  	}
   387  }
   388  
   389  func TestBinarySearchFloats(t *testing.T) {
   390  	data := []float64{math.NaN(), -0.25, 0.0, 1.4}
   391  	tests := []struct {
   392  		target    float64
   393  		wantPos   int
   394  		wantFound bool
   395  	}{
   396  		{math.NaN(), 0, true},
   397  		{math.Inf(-1), 1, false},
   398  		{-0.25, 1, true},
   399  		{0.0, 2, true},
   400  		{1.4, 3, true},
   401  		{1.5, 4, false},
   402  	}
   403  	for _, tt := range tests {
   404  		t.Run(fmt.Sprintf("%v", tt.target), func(t *testing.T) {
   405  			{
   406  				pos, found := BinarySearch(data, tt.target)
   407  				if pos != tt.wantPos || found != tt.wantFound {
   408  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   409  				}
   410  			}
   411  		})
   412  	}
   413  }
   414  
   415  func TestBinarySearchFunc(t *testing.T) {
   416  	data := []int{1, 10, 11, 2} // sorted lexicographically
   417  	cmp := func(a int, b string) int {
   418  		return strings.Compare(strconv.Itoa(a), b)
   419  	}
   420  	pos, found := BinarySearchFunc(data, "2", cmp)
   421  	if pos != 3 || !found {
   422  		t.Errorf("BinarySearchFunc(%v, %q, cmp) = %v, %v, want %v, %v", data, "2", pos, found, 3, true)
   423  	}
   424  }
   425  

View as plain text