Sort函数

Sort函数只会调用一次Len,并且调用 O(n*log(n)) 次Less和Swap。使用快速排序稳定性是一件尽力而为的事情,如果需要绝对的稳定应该使用Stable函数。

  1. type Interface interface {
  2. Len() int
  3. Less(i, j int) bool
  4. Swap(i, j int)
  5. }
  6. func Sort(data Interface) {
  7. n := data.Len()
  8. quickSort(data, 0, n, maxDepth(n))
  9. }

在输入元素在12个以内时,使用ShellSort。超出12个元素则使用快速排序。

func quickSort(data Interface, a, b, maxDepth int) {
    for b-a > 12 { // Use ShellSort for slices <= 12 elements
        if maxDepth == 0 {
            heapSort(data, a, b)
            return
        }
        maxDepth--
        mlo, mhi := doPivot(data, a, b)
        // Avoiding recursion on the larger subproblem guarantees
        // a stack depth of at most lg(b-a).
        if mlo-a < b-mhi {
            quickSort(data, a, mlo, maxDepth)
            a = mhi // i.e., quickSort(data, mhi, b)
        } else {
            quickSort(data, mhi, b, maxDepth)
            b = mlo // i.e., quickSort(data, a, mlo)
        }
    }
    if b-a > 1 {
        // Do ShellSort pass with gap 6
        // It could be written in this simplified form cause b-a <= 12
        for i := a + 6; i < b; i++ {
            if data.Less(i, i-6) {
                data.Swap(i, i-6)
            }
        }
        insertionSort(data, a, b)
    }
}

快速排序最好情况下快速排序的调用栈深度是 lg(n),此时算法时间复杂度为nlog(n)。
最坏情况下是 n,此时算法时间复杂度为 n2 。因为快速排序借助递归,所以甚至比复杂度为 n2 的算法还要慢。例如简单地选取首元素或尾元素作为枢纽元,同时遇到本身有序的输入。
于是Golang把最大深度限制为 2
ceil(lg(n+1)) ,当超过这个深度则转而使用堆排序。同时选择适当的枢纽元以降低栈深度:选择首元素、中间元素、尾元素这三个元素(当输入元素数量超过40则采样9个元素)的中位数作为枢纽元。

此外,还可以优化元素交换次数:

  1. 将枢纽元放到首位(位置0)
  2. 然后从位置1开始向尾元素寻找大于枢纽元值的元素a,从尾元素开始向首元素寻找小于等于枢纽元的元素b,然后交换元素a和b
  3. 重复步骤2直到位置a位于位置b的后面
  4. 将枢纽元与元素b交换,并返回枢纽元现在的位置。
func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
    m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
    if hi-lo > 40 {
        // Tukey's ``Ninther,'' median of three medians of three.
        s := (hi - lo) / 8
        medianOfThree(data, lo, lo+s, lo+2*s)
        medianOfThree(data, m, m-s, m+s)
        medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
    }
    medianOfThree(data, lo, m, hi-1)

    // Invariants are:
    //    data[lo] = pivot (set up by ChoosePivot)
    //    data[lo < i < a] < pivot
    //    data[a <= i < b] <= pivot
    //    data[b <= i < c] unexamined
    //    data[c <= i < hi-1] > pivot
    //    data[hi-1] >= pivot
    pivot := lo
    a, c := lo+1, hi-1

    for ; a < c && data.Less(a, pivot); a++ {
    }
    b := a
    for {
        for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
        }
        for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
        }
        if b >= c {
            break
        }
        // data[b] > pivot; data[c-1] <= pivot
        data.Swap(b, c-1)
        b++
        c--
    }
    // If hi-c<3 then there are duplicates (by property of median of nine).
    // Let's be a bit more conservative, and set border to 5.
    protect := hi-c < 5
    if !protect && hi-c < (hi-lo)/4 {
        // Lets test some points for equality to pivot
        dups := 0
        if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
            data.Swap(c, hi-1)
            c++
            dups++
        }
        if !data.Less(b-1, pivot) { // data[b-1] = pivot
            b--
            dups++
        }
        // m-lo = (hi-lo)/2 > 6
        // b-lo > (hi-lo)*3/4-1 > 8
        // ==> m < b ==> data[m] <= pivot
        if !data.Less(m, pivot) { // data[m] = pivot
            data.Swap(m, b-1)
            b--
            dups++
        }
        // if at least 2 points are equal to pivot, assume skewed distribution
        protect = dups > 1
    }
    if protect {
        // Protect against a lot of duplicates
        // Add invariant:
        //    data[a <= i < b] unexamined
        //    data[b <= i < c] = pivot
        for {
            for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
            }
            for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
            }
            if a >= b {
                break
            }
            // data[a] == pivot; data[b-1] < pivot
            data.Swap(a, b-1)
            a++
            b--
        }
    }
    // Swap pivot into middle
    data.Swap(pivot, b-1)
    return b - 1, c
}

Stable函数

Stable函数会调用一次Len,O(nlog(n)) 次Less,以及 O(nlog(n)*log(n)) 次Swap。

  1. 首先,将输入元素分为多个block,每个block大小20,末尾一个block可能小于20。对每个block执行插入排序
  2. 然后每两个block一组进行归并排序,block大小变为之前2倍
  3. 重复步骤2直到只剩一个block

    func stable(data Interface, n int) {
     blockSize := 20 // must be > 0
     a, b := 0, blockSize
     for b <= n {
         insertionSort(data, a, b)
         a = b
         b += blockSize
     }
     insertionSort(data, a, n)
    
     for blockSize < n {
         a, b = 0, 2*blockSize
         for b <= n {
             symMerge(data, a, a+blockSize, b)
             a = b
             b += 2 * blockSize
         }
         if m := a + blockSize; m < n {
             symMerge(data, a, m, n)
         }
         blockSize *= 2
     }
    }