兰湾
首发于兰湾

半精度浮点数实验

IEEE-754规定的单精度浮点数是4个字节32位,包括1位符号、8位指数和23位尾数。它能表示的动态范围是 2^{-126}\sim2^{127} ,也就是 10^{-38} \sim 10^{38} ;精度是 \rm{lg} 2^{24} ,大约7个十进制有效数字。(另有其他不同规格的浮点数,不过估计现在极少见了。)

有时需要存储或传输大量对精度要求不高、但动态范围相对较大的数据,2字节的整型最大只能到32767(有符号)或65535(无符号),范围不够,4字节的单精度浮点数存储和传输效率又偏低,怎么办呢?3个字节在大部分场合都得考虑对齐访问,更麻烦,所以要减就减到2字节。

NVidia在2002年提出了半精度浮点数,只使用2个字节16位,包括1位符号、5位指数和10位尾数,动态范围是 2^{-30}\sim 2^{31} 也就是 10^{-9}\sim 10^9 ,精度是 \rm lg2^{11} ,大约3个十进制有效数字。NVidia的方案已经被IEEE-754采纳。Google的TensorFlow则比较简单粗暴,把单精度的后16位砍掉,也就是1位符号、8位指数和7位尾数。动态范围和单精度相同,精度只有 \rm lg 2^8 ,2个有效数字。

这两种半精度浮点数的实际使用效果如何呢?可以写个程序验证一下。

typedef unsigned short half;       // 先定义unsigned short为half

按NVidia方案,把单精度浮点数转成半精度,可以这么做:

half Float2Half(float m)
{
    unsigned long m2 = *(unsigned long*)(&m);    // 强制把float转为unsigned long
    // 截取后23位尾数,右移13位,剩余10位;符号位直接右移16位;
    // 指数位麻烦一些,截取指数的8位先右移13位(左边多出3位不管了)
    // 之前是0~255表示-127~128, 调整之后变成0~31表示-15~16
    // 因此要减去127-15=112(在左移10位的位置).
    unsigned short t = ((m2 & 0x007fffff) >> 13) | ((m2 & 0x80000000) >> 16) 
        | (((m2 & 0x7f800000) >> 13) - (112 << 10));           
    if(m2 & 0x1000) 
        t++;                       // 四舍五入(尾数被截掉部分的最高位为1, 则尾数剩余部分+1)
    half h = *(half*)(&t);     // 强制转为half
    return h ;
}

从半精度转回单精度比较好办, 按格式取出符号位、指数和尾数,再按定义计算,结果保存为float即可。

float Half2Float(half n)
{
    unsigned short frac = (n & 0x3ff) | 0x400;
    int exp = ((n & 0x7c00) >> 10) - 25;
    float m;

    if(frac == 0 && exp == 0x1f)
        m = INFINITY;
    else if (frac || exp)
        m = frac * pow(2, exp);
    else
        m = 0;

    return (n & 0x8000) ? -m : m;
}

最后把实际数据从单精度转成半精度,再转回单精度,计算误差:

    for(float n = 4e-5; n < 6e4; n *= 1.001) { 
        printf("%f, %f,   %.4f\n", n, Half2Float(Float2Half(n)), 
                ((double)n - Half2Float(Float2Half(n))) / n * 100.0); 
    }

实测最大误差0.048%(也就是1/2048),平均绝对误差0.018%,似乎还不错。

Google TensorFlow的方案验证起来就非常简单了,砍掉后16位即可,四舍五入还是要的。

#include <stdio.h>

int main(void)
{
    for(float n = 1e-8; n < 1e8; n *= 1.001) {
        unsigned long k, l;
        k = *(unsigned long*)(&n);
        l = k & 0xffff0000;
        if(k & 0x8000)
            l += 0x10000;                  // 四舍五入
        float m = *(float*)(&l);

        printf("%f, %f, %f\n", n, m, (n - m) / n);
    }
    return 0;
}

最大误差0.39%(也就是1/256),平均绝对误差0.14%。许多场合其实主要关心的只是数量级,用这个也不错。

ps. 补充一下,AMD在显卡上确实用过3个字节/24位的浮点数,哦,那会儿还是ATI。微软在Direct3D API里也提供了支持。现在不知道还有没有了。

编辑于 2018-06-01

文章被以下专栏收录