首发于算法之美

BFPRT算法原理

通常我们需要在一大堆数中求前 k 大的数。比如在搜索引擎中求当天用户点击次数排名前10000的热词,在文本特征选择中求 tf-idf 值按从大到小排名前 k 等问题,都涉及到一个核心问题,即TOP-K问题


求TOP-K问题最简单的方式为快速排序后取前K大的即可。但是这样做有两个问题

1. 快速排序的平均复杂度为 O(n\log(n)) ,但最坏时间复杂度为 O(n^2)

2. 我们只需要前 k 大的,而对其余不需要的数也进行了排序,浪费了大量排序时间。

而堆排序也是一个较好的方法,维护一个大小为 k 的堆,时间复杂度为 O(n\log(k))


这里介绍一个比较好的算法,叫做BFPTR算法,又称为中位数的中位数算法,它的最坏时间复杂度为 O(n) ,它是由Blum、Floyd、Pratt、Rivest、Tarjan提出。该算法的思想是修改快速选择算法的主元选取方法,提高算法在最坏情况下的时间复杂度。


一. 快速排序原理

先来看看快速排序是如何进行的,一趟快速排序的过程如下

  1. 先从序列中选取一个数作为基准数
  2. 将比这个数大的数全部放到它的右边,把小于或者等于它的数全部放到它的左边

一趟快速排序也叫做Partion,即将序列划分为两部分,一部分比基准数小,另一部分比基准数大,然后再进行分治过程,每一次Partion不一定都能保证划分得很均匀,所以最坏情况下的时间复杂度不能保证总是为 O(n\log(n)) 。对于Partion过程,通常有两种方法

1. 两个指针从首尾向中间扫描(双向扫描)

这种方法可以用挖坑填数来形容,比如

初始化:i = 0; j = 9; pivot = a[0];

现在a[0]保存到了变量pivot中了,相当于在数组a[0]处挖了个坑,那么可以将其它的数填到这里来。从j开始向前找一个小于或者等于pivot的数,即将a[8]填入a[0],但a[8]又形成了一个新坑,再从i开始向后找一个大于pivot的数,即a[3]填入a[8],那么a[3]又形成了一个新坑......

就这样,直到i==j才停止,最终得到结果如下

上述过程就是一趟快速排序

#include <iostream>
#include <string.h>
#include <stdio.h>
#include <algorithm>
#include <time.h>
 
using namespace std;
const int N = 10005;
 
int Partion(int a[], int l, int r)
{
    int i = l;
    int j = r;
    int pivot = a[l];
    while(i < j)
    {
        while(a[j] >= pivot && i < j)
            j--;
        a[i] = a[j];
        while(a[i] <= pivot && i < j)
            i++;
        a[j] = a[i];
    }
    a[i] = pivot;
    return i;
}
 
void QuickSort(int a[], int l, int r)
{
    if(l < r)
    {
        int k = Partion(a, l, r);
        QuickSort(a, l, k - 1);
        QuickSort(a, k + 1, r);
    }
}
 
int a[N];
 
int main()
{
    int n;
    while(cin >> n)
    {
        for(int i = 0; i < n; i++)
            cin >> a[i];
        QuickSort(a, 0, n - 1);
        for(int i = 0; i < n; i++)
            cout << a[i] << " ";
        cout << endl;
    }
    return 0;
}

2. 两个指针一前一后逐步向前扫描(单向扫描)

#include <iostream>
#include <string.h>
#include <stdio.h>
 
using namespace std;
const int N = 10005;
 
int Partion(int a[], int l, int r)
{
    int i = l - 1;
    int pivot = a[r];
    for(int j = l; j < r; j++)
    {
        if(a[j] <= pivot)
        {
            i++;
            swap(a[i], a[j]);
        }
    }
    swap(a[i + 1], a[r]);
    return i + 1;
}
 
void QuickSort(int a[], int l, int r)
{
    if(l < r)
    {
        int k = Partion(a, l, r);
        QuickSort(a, l, k - 1);
        QuickSort(a, k + 1, r);
    }
}
 
int a[N];
 
int main()
{
    int n;
    while(cin >> n)
    {
        for(int i = 0; i < n; i++)
            cin >> a[i];
        QuickSort(a, 0, n - 1);
        for(int i = 0; i < n; i++)
            cout << a[i] << " ";
        cout << endl;
    }
    return 0;
}

基于双向扫描的快速排序要比基于单向扫描的快速排序算法快很多。


二. BFPRT算法原理

在BFPTR算法中,仅仅是改变了快速排序Partion中的pivot值的选取,在快速排序中,我们始终选择第一个元素或者最后一个元素作为pivot,而在BFPTR算法中,每次选择五分中位数的中位数作为pivot,这样做的目的就是使得划分比较合理,从而避免了最坏情况的发生。算法步骤如下

1. 将 n 个元素划为 \lfloor n/5\rfloor 组,每组5个,至多只有一组由 n\bmod5 个元素组成。
2. 寻找这 \lceil n/5\rceil 个组中每一个组的中位数,这个过程可以用插入排序。
3. 对步骤2中的 \lceil n/5\rceil 个中位数,重复步骤1和步骤2,递归下去,直到剩下一个数字。
4. 最终剩下的数字即为pivot,把大于它的数全放左边,小于等于它的数全放右边。
5. 判断pivot的位置与k的大小,有选择的对左边或右边递归。

求第 k 大就是求第 n-k+1 小,这两者等价。

#include <iostream>
#include <string.h>
#include <stdio.h>
#include <time.h>
#include <algorithm>
 
using namespace std;
const int N = 10005;
 
int a[N];
 
//插入排序
void InsertSort(int a[], int l, int r)
{
    for(int i = l + 1; i <= r; i++)
    {
        if(a[i - 1] > a[i])
        {
            int t = a[i];
            int j = i;
            while(j > l && a[j - 1] > t)
            {
                a[j] = a[j - 1];
                j--;
            }
            a[j] = t;
        }
    }
}
 
//寻找中位数的中位数
int FindMid(int a[], int l, int r)
{
    if(l == r) return l;
    int i = 0;
    int n = 0;
    for(i = l; i < r - 5; i += 5)
    {
        InsertSort(a, i, i + 4);
        n = i - l;
        swap(a[l + n / 5], a[i + 2]);
    }
 
    //处理剩余元素
    int num = r - i + 1;
    if(num > 0)
    {
        InsertSort(a, i, i + num - 1);
        n = i - l;
        swap(a[l + n / 5], a[i + num / 2]);
    }
    n /= 5;
    if(n == l) return l;
    return FindMid(a, l, l + n);
}
 
//进行划分过程
int Partion(int a[], int l, int r, int p)
{
    swap(a[p], a[l]);
    int i = l;
    int j = r;
    int pivot = a[l];
    while(i < j)
    {
        while(a[j] >= pivot && i < j)
            j--;
        a[i] = a[j];
        while(a[i] <= pivot && i < j)
            i++;
        a[j] = a[i];
    }
    a[i] = pivot;
    return i;
}
 
int BFPRT(int a[], int l, int r, int k)
{
    int p = FindMid(a, l, r);    //寻找中位数的中位数
    int i = Partion(a, l, r, p);
 
    int m = i - l + 1;
    if(m == k) return a[i];
    if(m > k)  return BFPRT(a, l, i - 1, k);
    return BFPRT(a, i + 1, r, k - m);
}
 
int main()
{
    int n, k;
    scanf("%d", &n);
    for(int i = 0; i < n; i++)
        scanf("%d", &a[i]);
    scanf("%d", &k);
    printf("The %d th number is : %d\n", k, BFPRT(a, 0, n - 1, k));
    for(int i = 0; i < n; i++)
        printf("%d ", a[i]);
    puts("");
    return 0;
}
 
/**
10
72 6 57 88 60 42 83 73 48 85
5
*/

三. 时间复杂度分析

BFPRT算法的最坏时间复杂度为 O(n) 。设 T(n) 为时间复杂度,那么很容易有如下公式

\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ T(n)\leq T(\frac{n}{5})+T(\frac{7n}{10})+c\cdot n

  • T(\frac{n}{5}) 来自FindMid(), n 个元素,5个一组,共有 \lceil \frac{n}{5}\rceil 个中位数。
  • T(\frac{7n}{10}) 来自BFPRT(),在 \lceil \frac{n}{5}\rceil 个中位数中,主元pivot大于其中 \frac{1}{2}\cdot \frac{n}{5}=\frac{n}{10} 个中位数,而每个中位数在本来5个数的小组中又大于或等于其中的3个数,所以主元pivot至少大于所有数中的 \frac{n}{10}\cdot 3=\frac{3n}{10}个。即划分之后任意一边的长度至少为 \frac{3}{10} ,在最坏情况下,每次选择都选到了 \frac{7}{10} 的那一部分。
  • c\cdot n 来自其它地方,例如插入排序等其它的额外操作。

证明:T(n)=t\cdot n ,其中 t 可能是一个常数,也可能是关于 n 的函数。带入上式

\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ t\cdot n\leq t\cdot\frac{n}{5}+t\cdot\frac{7n}{10}+c\cdot n\ \ \Rightarrow \ \ t\leq 10c

其中 c 是常数,所以 t 也是常数,即 T(n)\leq 10c\cdot n ,所以 T(n)=O(n)


四. 为什么是5?

在BFPRT算法中,为什么是选5个作为分组?

首先,偶数排除,因为对于奇数来说,中位数更容易计算。

如果选用3,有 T(n)=T(n/3)+T(2n/3)+c\cdot n ,其操作元素个数还是 n

如果选取7,9或者更大,在插入排序时耗时增加,常数 c 会很大,有些得不偿失。

编辑于 2019-01-02