Sort函数
Sort函数只会调用一次Len,并且调用 O(n*log(n)) 次Less和Swap。使用快速排序稳定性是一件尽力而为的事情,如果需要绝对的稳定应该使用Stable函数。
type Interface interface {
Len() int
Less(i, j int) bool
Swap(i, j int)
}
func Sort(data Interface) {
n := data.Len()
quickSort(data, 0, n, maxDepth(n))
}
在输入元素在12个以内时,使用ShellSort。超出12个元素则使用快速排序。
func quickSort(data Interface, a, b, maxDepth int) {
for b-a > 12 { // Use ShellSort for slices <= 12 elements
if maxDepth == 0 {
heapSort(data, a, b)
return
}
maxDepth--
mlo, mhi := doPivot(data, a, b)
// Avoiding recursion on the larger subproblem guarantees
// a stack depth of at most lg(b-a).
if mlo-a < b-mhi {
quickSort(data, a, mlo, maxDepth)
a = mhi // i.e., quickSort(data, mhi, b)
} else {
quickSort(data, mhi, b, maxDepth)
b = mlo // i.e., quickSort(data, a, mlo)
}
}
if b-a > 1 {
// Do ShellSort pass with gap 6
// It could be written in this simplified form cause b-a <= 12
for i := a + 6; i < b; i++ {
if data.Less(i, i-6) {
data.Swap(i, i-6)
}
}
insertionSort(data, a, b)
}
}
快速排序最好情况下快速排序的调用栈深度是 lg(n),此时算法时间复杂度为nlog(n)。
最坏情况下是 n,此时算法时间复杂度为 n2 。因为快速排序借助递归,所以甚至比复杂度为 n2 的算法还要慢。例如简单地选取首元素或尾元素作为枢纽元,同时遇到本身有序的输入。
于是Golang把最大深度限制为 2ceil(lg(n+1)) ,当超过这个深度则转而使用堆排序。同时选择适当的枢纽元以降低栈深度:选择首元素、中间元素、尾元素这三个元素(当输入元素数量超过40则采样9个元素)的中位数作为枢纽元。
此外,还可以优化元素交换次数:
- 将枢纽元放到首位(位置0)
- 然后从位置1开始向尾元素寻找大于枢纽元值的元素a,从尾元素开始向首元素寻找小于等于枢纽元的元素b,然后交换元素a和b
- 重复步骤2直到位置a位于位置b的后面
- 将枢纽元与元素b交换,并返回枢纽元现在的位置。
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
if hi-lo > 40 {
// Tukey's ``Ninther,'' median of three medians of three.
s := (hi - lo) / 8
medianOfThree(data, lo, lo+s, lo+2*s)
medianOfThree(data, m, m-s, m+s)
medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
}
medianOfThree(data, lo, m, hi-1)
// Invariants are:
// data[lo] = pivot (set up by ChoosePivot)
// data[lo < i < a] < pivot
// data[a <= i < b] <= pivot
// data[b <= i < c] unexamined
// data[c <= i < hi-1] > pivot
// data[hi-1] >= pivot
pivot := lo
a, c := lo+1, hi-1
for ; a < c && data.Less(a, pivot); a++ {
}
b := a
for {
for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
}
for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
}
if b >= c {
break
}
// data[b] > pivot; data[c-1] <= pivot
data.Swap(b, c-1)
b++
c--
}
// If hi-c<3 then there are duplicates (by property of median of nine).
// Let's be a bit more conservative, and set border to 5.
protect := hi-c < 5
if !protect && hi-c < (hi-lo)/4 {
// Lets test some points for equality to pivot
dups := 0
if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
data.Swap(c, hi-1)
c++
dups++
}
if !data.Less(b-1, pivot) { // data[b-1] = pivot
b--
dups++
}
// m-lo = (hi-lo)/2 > 6
// b-lo > (hi-lo)*3/4-1 > 8
// ==> m < b ==> data[m] <= pivot
if !data.Less(m, pivot) { // data[m] = pivot
data.Swap(m, b-1)
b--
dups++
}
// if at least 2 points are equal to pivot, assume skewed distribution
protect = dups > 1
}
if protect {
// Protect against a lot of duplicates
// Add invariant:
// data[a <= i < b] unexamined
// data[b <= i < c] = pivot
for {
for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
}
for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
}
if a >= b {
break
}
// data[a] == pivot; data[b-1] < pivot
data.Swap(a, b-1)
a++
b--
}
}
// Swap pivot into middle
data.Swap(pivot, b-1)
return b - 1, c
}
Stable函数
Stable函数会调用一次Len,O(nlog(n)) 次Less,以及 O(nlog(n)*log(n)) 次Swap。
- 首先,将输入元素分为多个block,每个block大小20,末尾一个block可能小于20。对每个block执行插入排序
- 然后每两个block一组进行归并排序,block大小变为之前2倍
重复步骤2直到只剩一个block
func stable(data Interface, n int) { blockSize := 20 // must be > 0 a, b := 0, blockSize for b <= n { insertionSort(data, a, b) a = b b += blockSize } insertionSort(data, a, n) for blockSize < n { a, b = 0, 2*blockSize for b <= n { symMerge(data, a, a+blockSize, b) a = b b += 2 * blockSize } if m := a + blockSize; m < n { symMerge(data, a, m, n) } blockSize *= 2 } }