Source file src/slices/sort_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package slices_test
     6  
     7  import (
     8  	"cmp"
     9  	"fmt"
    10  	"math"
    11  	"math/rand"
    12  	. "slices"
    13  	"strconv"
    14  	"strings"
    15  	"testing"
    16  )
    17  
    18  var ints = [...]int{74, 59, 238, -784, 9845, 959, 905, 0, 0, 42, 7586, -5467984, 7586}
    19  var float64s = [...]float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8, 74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3}
    20  var strs = [...]string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"}
    21  
    22  func TestSortIntSlice(t *testing.T) {
    23  	data := Clone(ints[:])
    24  	Sort(data)
    25  	if !IsSorted(data) {
    26  		t.Errorf("sorted %v", ints)
    27  		t.Errorf("   got %v", data)
    28  	}
    29  }
    30  
    31  func TestSortFuncIntSlice(t *testing.T) {
    32  	data := Clone(ints[:])
    33  	SortFunc(data, func(a, b int) int { return a - b })
    34  	if !IsSorted(data) {
    35  		t.Errorf("sorted %v", ints)
    36  		t.Errorf("   got %v", data)
    37  	}
    38  }
    39  
    40  func TestSortFloat64Slice(t *testing.T) {
    41  	data := Clone(float64s[:])
    42  	Sort(data)
    43  	if !IsSorted(data) {
    44  		t.Errorf("sorted %v", float64s)
    45  		t.Errorf("   got %v", data)
    46  	}
    47  }
    48  
    49  func TestSortStringSlice(t *testing.T) {
    50  	data := Clone(strs[:])
    51  	Sort(data)
    52  	if !IsSorted(data) {
    53  		t.Errorf("sorted %v", strs)
    54  		t.Errorf("   got %v", data)
    55  	}
    56  }
    57  
    58  func TestSortLarge_Random(t *testing.T) {
    59  	n := 1000000
    60  	if testing.Short() {
    61  		n /= 100
    62  	}
    63  	data := make([]int, n)
    64  	for i := 0; i < len(data); i++ {
    65  		data[i] = rand.Intn(100)
    66  	}
    67  	if IsSorted(data) {
    68  		t.Fatalf("terrible rand.rand")
    69  	}
    70  	Sort(data)
    71  	if !IsSorted(data) {
    72  		t.Errorf("sort didn't sort - 1M ints")
    73  	}
    74  }
    75  
    76  type intPair struct {
    77  	a, b int
    78  }
    79  
    80  type intPairs []intPair
    81  
    82  // Pairs compare on a only.
    83  func intPairCmp(x, y intPair) int {
    84  	return x.a - y.a
    85  }
    86  
    87  // Record initial order in B.
    88  func (d intPairs) initB() {
    89  	for i := range d {
    90  		d[i].b = i
    91  	}
    92  }
    93  
    94  // InOrder checks if a-equal elements were not reordered.
    95  // If reversed is true, expect reverse ordering.
    96  func (d intPairs) inOrder(reversed bool) bool {
    97  	lastA, lastB := -1, 0
    98  	for i := 0; i < len(d); i++ {
    99  		if lastA != d[i].a {
   100  			lastA = d[i].a
   101  			lastB = d[i].b
   102  			continue
   103  		}
   104  		if !reversed {
   105  			if d[i].b <= lastB {
   106  				return false
   107  			}
   108  		} else {
   109  			if d[i].b >= lastB {
   110  				return false
   111  			}
   112  		}
   113  		lastB = d[i].b
   114  	}
   115  	return true
   116  }
   117  
   118  func TestStability(t *testing.T) {
   119  	n, m := 100000, 1000
   120  	if testing.Short() {
   121  		n, m = 1000, 100
   122  	}
   123  	data := make(intPairs, n)
   124  
   125  	// random distribution
   126  	for i := 0; i < len(data); i++ {
   127  		data[i].a = rand.Intn(m)
   128  	}
   129  	if IsSortedFunc(data, intPairCmp) {
   130  		t.Fatalf("terrible rand.rand")
   131  	}
   132  	data.initB()
   133  	SortStableFunc(data, intPairCmp)
   134  	if !IsSortedFunc(data, intPairCmp) {
   135  		t.Errorf("Stable didn't sort %d ints", n)
   136  	}
   137  	if !data.inOrder(false) {
   138  		t.Errorf("Stable wasn't stable on %d ints", n)
   139  	}
   140  
   141  	// already sorted
   142  	data.initB()
   143  	SortStableFunc(data, intPairCmp)
   144  	if !IsSortedFunc(data, intPairCmp) {
   145  		t.Errorf("Stable shuffled sorted %d ints (order)", n)
   146  	}
   147  	if !data.inOrder(false) {
   148  		t.Errorf("Stable shuffled sorted %d ints (stability)", n)
   149  	}
   150  
   151  	// sorted reversed
   152  	for i := 0; i < len(data); i++ {
   153  		data[i].a = len(data) - i
   154  	}
   155  	data.initB()
   156  	SortStableFunc(data, intPairCmp)
   157  	if !IsSortedFunc(data, intPairCmp) {
   158  		t.Errorf("Stable didn't sort %d ints", n)
   159  	}
   160  	if !data.inOrder(false) {
   161  		t.Errorf("Stable wasn't stable on %d ints", n)
   162  	}
   163  }
   164  
   165  type S struct {
   166  	a int
   167  	b string
   168  }
   169  
   170  func cmpS(s1, s2 S) int {
   171  	return cmp.Compare(s1.a, s2.a)
   172  }
   173  
   174  func TestMinMax(t *testing.T) {
   175  	intCmp := func(a, b int) int { return a - b }
   176  
   177  	tests := []struct {
   178  		data    []int
   179  		wantMin int
   180  		wantMax int
   181  	}{
   182  		{[]int{7}, 7, 7},
   183  		{[]int{1, 2}, 1, 2},
   184  		{[]int{2, 1}, 1, 2},
   185  		{[]int{1, 2, 3}, 1, 3},
   186  		{[]int{3, 2, 1}, 1, 3},
   187  		{[]int{2, 1, 3}, 1, 3},
   188  		{[]int{2, 2, 3}, 2, 3},
   189  		{[]int{3, 2, 3}, 2, 3},
   190  		{[]int{0, 2, -9}, -9, 2},
   191  	}
   192  	for _, tt := range tests {
   193  		t.Run(fmt.Sprintf("%v", tt.data), func(t *testing.T) {
   194  			gotMin := Min(tt.data)
   195  			if gotMin != tt.wantMin {
   196  				t.Errorf("Min got %v, want %v", gotMin, tt.wantMin)
   197  			}
   198  
   199  			gotMinFunc := MinFunc(tt.data, intCmp)
   200  			if gotMinFunc != tt.wantMin {
   201  				t.Errorf("MinFunc got %v, want %v", gotMinFunc, tt.wantMin)
   202  			}
   203  
   204  			gotMax := Max(tt.data)
   205  			if gotMax != tt.wantMax {
   206  				t.Errorf("Max got %v, want %v", gotMax, tt.wantMax)
   207  			}
   208  
   209  			gotMaxFunc := MaxFunc(tt.data, intCmp)
   210  			if gotMaxFunc != tt.wantMax {
   211  				t.Errorf("MaxFunc got %v, want %v", gotMaxFunc, tt.wantMax)
   212  			}
   213  		})
   214  	}
   215  
   216  	svals := []S{
   217  		{1, "a"},
   218  		{2, "a"},
   219  		{1, "b"},
   220  		{2, "b"},
   221  	}
   222  
   223  	gotMin := MinFunc(svals, cmpS)
   224  	wantMin := S{1, "a"}
   225  	if gotMin != wantMin {
   226  		t.Errorf("MinFunc(%v) = %v, want %v", svals, gotMin, wantMin)
   227  	}
   228  
   229  	gotMax := MaxFunc(svals, cmpS)
   230  	wantMax := S{2, "a"}
   231  	if gotMax != wantMax {
   232  		t.Errorf("MaxFunc(%v) = %v, want %v", svals, gotMax, wantMax)
   233  	}
   234  }
   235  
   236  func TestMinMaxNaNs(t *testing.T) {
   237  	fs := []float64{1.0, 999.9, 3.14, -400.4, -5.14}
   238  	if Min(fs) != -400.4 {
   239  		t.Errorf("got min %v, want -400.4", Min(fs))
   240  	}
   241  	if Max(fs) != 999.9 {
   242  		t.Errorf("got max %v, want 999.9", Max(fs))
   243  	}
   244  
   245  	// No matter which element of fs is replaced with a NaN, both Min and Max
   246  	// should propagate the NaN to their output.
   247  	for i := 0; i < len(fs); i++ {
   248  		testfs := Clone(fs)
   249  		testfs[i] = math.NaN()
   250  
   251  		fmin := Min(testfs)
   252  		if !math.IsNaN(fmin) {
   253  			t.Errorf("got min %v, want NaN", fmin)
   254  		}
   255  
   256  		fmax := Max(testfs)
   257  		if !math.IsNaN(fmax) {
   258  			t.Errorf("got max %v, want NaN", fmax)
   259  		}
   260  	}
   261  }
   262  
   263  func TestMinMaxPanics(t *testing.T) {
   264  	intCmp := func(a, b int) int { return a - b }
   265  	emptySlice := []int{}
   266  
   267  	if !panics(func() { Min(emptySlice) }) {
   268  		t.Errorf("Min([]): got no panic, want panic")
   269  	}
   270  
   271  	if !panics(func() { Max(emptySlice) }) {
   272  		t.Errorf("Max([]): got no panic, want panic")
   273  	}
   274  
   275  	if !panics(func() { MinFunc(emptySlice, intCmp) }) {
   276  		t.Errorf("MinFunc([]): got no panic, want panic")
   277  	}
   278  
   279  	if !panics(func() { MaxFunc(emptySlice, intCmp) }) {
   280  		t.Errorf("MaxFunc([]): got no panic, want panic")
   281  	}
   282  }
   283  
   284  func TestBinarySearch(t *testing.T) {
   285  	str1 := []string{"foo"}
   286  	str2 := []string{"ab", "ca"}
   287  	str3 := []string{"mo", "qo", "vo"}
   288  	str4 := []string{"ab", "ad", "ca", "xy"}
   289  
   290  	// slice with repeating elements
   291  	strRepeats := []string{"ba", "ca", "da", "da", "da", "ka", "ma", "ma", "ta"}
   292  
   293  	// slice with all element equal
   294  	strSame := []string{"xx", "xx", "xx"}
   295  
   296  	tests := []struct {
   297  		data      []string
   298  		target    string
   299  		wantPos   int
   300  		wantFound bool
   301  	}{
   302  		{[]string{}, "foo", 0, false},
   303  		{[]string{}, "", 0, false},
   304  
   305  		{str1, "foo", 0, true},
   306  		{str1, "bar", 0, false},
   307  		{str1, "zx", 1, false},
   308  
   309  		{str2, "aa", 0, false},
   310  		{str2, "ab", 0, true},
   311  		{str2, "ad", 1, false},
   312  		{str2, "ca", 1, true},
   313  		{str2, "ra", 2, false},
   314  
   315  		{str3, "bb", 0, false},
   316  		{str3, "mo", 0, true},
   317  		{str3, "nb", 1, false},
   318  		{str3, "qo", 1, true},
   319  		{str3, "tr", 2, false},
   320  		{str3, "vo", 2, true},
   321  		{str3, "xr", 3, false},
   322  
   323  		{str4, "aa", 0, false},
   324  		{str4, "ab", 0, true},
   325  		{str4, "ac", 1, false},
   326  		{str4, "ad", 1, true},
   327  		{str4, "ax", 2, false},
   328  		{str4, "ca", 2, true},
   329  		{str4, "cc", 3, false},
   330  		{str4, "dd", 3, false},
   331  		{str4, "xy", 3, true},
   332  		{str4, "zz", 4, false},
   333  
   334  		{strRepeats, "da", 2, true},
   335  		{strRepeats, "db", 5, false},
   336  		{strRepeats, "ma", 6, true},
   337  		{strRepeats, "mb", 8, false},
   338  
   339  		{strSame, "xx", 0, true},
   340  		{strSame, "ab", 0, false},
   341  		{strSame, "zz", 3, false},
   342  	}
   343  	for _, tt := range tests {
   344  		t.Run(tt.target, func(t *testing.T) {
   345  			{
   346  				pos, found := BinarySearch(tt.data, tt.target)
   347  				if pos != tt.wantPos || found != tt.wantFound {
   348  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   349  				}
   350  			}
   351  
   352  			{
   353  				pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare)
   354  				if pos != tt.wantPos || found != tt.wantFound {
   355  					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   356  				}
   357  			}
   358  		})
   359  	}
   360  }
   361  
   362  func TestBinarySearchInts(t *testing.T) {
   363  	data := []int{20, 30, 40, 50, 60, 70, 80, 90}
   364  	tests := []struct {
   365  		target    int
   366  		wantPos   int
   367  		wantFound bool
   368  	}{
   369  		{20, 0, true},
   370  		{23, 1, false},
   371  		{43, 3, false},
   372  		{80, 6, true},
   373  	}
   374  	for _, tt := range tests {
   375  		t.Run(strconv.Itoa(tt.target), func(t *testing.T) {
   376  			{
   377  				pos, found := BinarySearch(data, tt.target)
   378  				if pos != tt.wantPos || found != tt.wantFound {
   379  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   380  				}
   381  			}
   382  
   383  			{
   384  				cmp := func(a, b int) int {
   385  					return a - b
   386  				}
   387  				pos, found := BinarySearchFunc(data, tt.target, cmp)
   388  				if pos != tt.wantPos || found != tt.wantFound {
   389  					t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   390  				}
   391  			}
   392  		})
   393  	}
   394  }
   395  
   396  func TestBinarySearchFloats(t *testing.T) {
   397  	data := []float64{math.NaN(), -0.25, 0.0, 1.4}
   398  	tests := []struct {
   399  		target    float64
   400  		wantPos   int
   401  		wantFound bool
   402  	}{
   403  		{math.NaN(), 0, true},
   404  		{math.Inf(-1), 1, false},
   405  		{-0.25, 1, true},
   406  		{0.0, 2, true},
   407  		{1.4, 3, true},
   408  		{1.5, 4, false},
   409  	}
   410  	for _, tt := range tests {
   411  		t.Run(fmt.Sprintf("%v", tt.target), func(t *testing.T) {
   412  			{
   413  				pos, found := BinarySearch(data, tt.target)
   414  				if pos != tt.wantPos || found != tt.wantFound {
   415  					t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound)
   416  				}
   417  			}
   418  		})
   419  	}
   420  }
   421  
   422  func TestBinarySearchFunc(t *testing.T) {
   423  	data := []int{1, 10, 11, 2} // sorted lexicographically
   424  	cmp := func(a int, b string) int {
   425  		return strings.Compare(strconv.Itoa(a), b)
   426  	}
   427  	pos, found := BinarySearchFunc(data, "2", cmp)
   428  	if pos != 3 || !found {
   429  		t.Errorf("BinarySearchFunc(%v, %q, cmp) = %v, %v, want %v, %v", data, "2", pos, found, 3, true)
   430  	}
   431  }
   432  

View as plain text