TypeScript类型元编程:实现8位数的算术运算

失业中在 github 闲逛,看到有人用类型实现的一个4位虚拟机,为什么是4位呢,因为 TypeScript 的类型实例化深度有限制,没法实现太大的数字计算。说到用类型实现数字计算,一大堆邱奇数就冒出来了,但是因为这个限制,这种方法通常只能用于很小的数字,今天我们尝试一种不同的思路,用二进制实现8位的数字计算。

为了实现计算我们要先把数字类型转换成二进制,比如一个0和1的数组类型,但是进制转换又需要数学计算,显然我们是没办法实现的,那么怎么办呢? 很简单,硬编码一个映射就好了。但是对于8位数来说,255个长度为8的数组类型是一段很长的代码,为了减少代码长度,我们可以只编码一个二进制 trie 结构,然后通过搜索路径得到二进制编码。

我们用数组来表示二进制 trie,以3位数(0-7)为例,就是这个样子:

type binaryTrie = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]];

数字的访问路径就是对应的二进制,比如binaryTrie[0][1][1]就是3,也就是二进制的 011。然后我们写一个简单的脚本来生成8位的 trie:

JSON.stringify(
    (function it(n, acc) {
        return n > 0
            ? [it(n - 1, acc), it(n - 1, acc + 2 ** n)]
            : [acc + 0, acc + 1];
    })(7, 0)
);

看一下在编辑器里的样子:

还好,不算太长,但如果是16位数字的话那出来的就是1M多字节了,所以我们只实现8位数的就好了。然后我们来定义一个用来查找 trie 的类型:

type SearchInTrie<Num, Node, Digits> = {
    1: Node extends [infer A, infer B]
        ? Num extends A ? Push<Digits, 0>
        : Num extends B ? Push<Digits, 1>
        : never : never;

    0: Node extends [infer A, infer B]
        ? SearchInTrie<Num, A, Push<Digits, 0>>
        | SearchInTrie<Num, B, Push<Digits, 1>>
        : never;
}[Node extends [number, number] ? 1 : 0];

这是一个递归的类型,三个参数分别是 要查找的数字,当前查找节点,当前查找路径。我们先判断当前节点:如果是叶子节点(1)判断是左右节点的哪一个,把相应的二进制值加入路径并返回,都不是返回never。如果不是叶子节点(0),分别对左右节点进行查找,将结果 union 在一起,由于其中一边肯定查找不到,结果会是never,所以 union 的结果将会是最终查找到的路径。

这里用到了一个Push类型来将二进制位加入数组,它的完整定义如下:

type Copy<T, S extends any> = { [P in keyof T]: P extends keyof S ? S[P] : never };

type Unshift<T, A> = (
    (a: A, ...b: T extends any[] ? T : never) => void
) extends (...a: infer R) => void ? R : never;

type Push<T, A> =
    Copy<Unshift<T, any>, T & Record<string, A>>;

你们可能已经见过Unshift这种用法了,它就是用 conditional type 和函数类型的可变参数去推断出一个在开头插入了新元素的数组类型。目前只有这种办法可以向元组类型中加入元素,而且只能在开头加入。

那么这个Push是怎么实现在末尾加入元素的呢,这就涉及到另一个东西了,就是 mapped type,这个东西有个没有在文档中说明的特性,当用于元组类型时,具体来说,in关键字后面是一个元组类型的 key 时,mapped type的结果也是一个相同长度的元组类型。TS 的隐式规则越来越多了,很多特性都是糊出来的。

所以我们先定义一个Copy类型将T上的属性值覆盖为S的,然后在Push中,我们先往数组开头随意插入一个元素,然后从原来的数组T复制属性过来,但是因为多了一个元素,最后一个元素在T上是没有的,所以我们加一个& Record<string, A>,无论这最后一个元素的 key 是多少,最后肯定会从这个Record上取到A,也就是我们要加入的元素。

回到主题上来,我们现在可以把数字转成二进制了,同时这个 trie 也可以由二进制得到数字:

type Digit = 0 | 1;
type Bits = 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7;
type Uint8 = Record<Bits, Digit>; // 也可以定义成8个Digit的数组,这样写比较简短

// 数字转二进制表示
type ToUint8<A extends number> =
    SearchInTrie<A, BinaryTrie, []>;

// 二进制表示转数字
type ToNumber<A extends Uint8> =
    BinaryTrie[A[0]][A[1]][A[2]][A[3]][A[4]][A[5]][A[6]][A[7]];

那么我们就可以开始实现计算了,最简单的也是最基础的就是加法了,我们要用类型实现一个全加器。让我们先来实现1位的加法器:

// 两个1 bit数相加,C 表示进位
type BitAdd<A extends Digit, B extends Digit, C extends Digit> = [
    [[[0, 0], [1, 0]], [[1, 0], [0, 1]]],
    [[[1, 0], [0, 1]], [[0, 1], [1, 1]]]
][A][B][C];

非常简单,AB是相加的两个1位数字,C是进位标志,返回的类型是两个值的数组,第一个值是和,第二个值是进位标志。

然后把8个1位加法器级联起来就组成了一个8位全加器:

type AsDigit<T> = T extends Digit ? T : never;
type AsUint8<T> = T extends Uint8 ? T : never;

type Uint8Add<A extends Uint8, B extends Uint8> =
    BitAdd< A[7], B[7], 0> extends [infer S7, infer C]
    ? BitAdd<A[6], B[6], AsDigit<C>> extends [infer S6, infer C]
    ? BitAdd<A[5], B[5], AsDigit<C>> extends [infer S5, infer C]
    ? BitAdd<A[4], B[4], AsDigit<C>> extends [infer S4, infer C]
    ? BitAdd<A[3], B[3], AsDigit<C>> extends [infer S3, infer C]
    ? BitAdd<A[2], B[2], AsDigit<C>> extends [infer S2, infer C]
    ? BitAdd<A[1], B[1], AsDigit<C>> extends [infer S1, infer C]
    ? BitAdd<A[0], B[0], AsDigit<C>> extends [infer S0, infer C]
    // ? C extends 1 ? "overflow" :
    ? AsUint8<[S0, S1, S2, S3, S4, S5, S6, S7]>
    : never : never : never : never : never : never : never : never;

这里我们用 infer 语法当作变量保存每一步的结果,这个语法直接把“函数式”变成了“过程式”(笑),不过 infer 出来的类型变量是没有类型约束的,所以我们额外定义了两个类型AsDigitAsUint8来把结果 assert 成期望的类型,主要是为了满足泛型参数的约束检查。从中间的注释可以看到,我们没有对加法溢出进行处理,溢出会得到循环后的结果,这也是实现减法的原理。

我们来看一下,8 位二进制数能表示的最大数字,8 位全是 1,十进制是 255,如果我们对它加 1 会怎样?因为每位都是 1,所以会不断的进位最终得到全 0。如果加 2 呢,其实就相当于加 1 再加 1,第一次加 1 得到 0,第二次加 1 就得到 1。所以我们可以简单的认识到,有限位的二进制数字是循环的,因为最高位溢出后就回归到 0 了。再举一个例子,如果 255 加 256 呢?由于 256 是超出 8 位数范围的,我们拆开成加 1 + 255。加 1 得到 0,再加 255 又得到 255。因为 0-255 共 256 个数,所以任何数加 256 相当于循环一圈,得到的还是原来的数。

我们再来看对一个数字减 1 会如何。以0 - 1为例,根据前面的推论我们知道,任何 8 位数字加 256 结果不变,我们把0 - 1表示为0 + 256 - 1,就变成了计算0 + (256 - 1),再进一步地,由于 256 超出了 8 位数范围,我们改写为0 + (255 - 1 + 1),我们知道 255 的二进制 8 位全是 1,它减任何一个 8 位数,我们逐位相减可以发现,如果被减数字的某一位为 0,那么结果为1 - 0 = 1,如果为 1 结果为1 - 1 = 0,所以我们可以简单地把被减数每一位取反,得到的就是 255 减去它的结果。所以0 + (255 - 1 + 1)就变成了0 + (对1按位取反 + 1),所以0 - 1就等于“对 1 按位取反再加 1”(这个1是我们从256里拿出来的,不是减数1,不要搞混了),结果是 8 位全 1 也就是 255。和加法溢出类似,减法也产生了溢出循环的效果。这其实就是-1 的二进制表示,“按位取反再加 1”就是所谓的“补码”。在有符号的数字系统里,数字表示范围的后一半就是用来表示负数(这就是为什么有符号数上溢会产生负数,但实际上-1 + 1 = 0 才是二进制溢出的结果),对于计算机来说正数负数都是一样的,都是在一个环上,是我们通过指令的语义或语言的类型赋予了数字正负的含义。

再回到我们的主题,要实现A - B的计算,其实就是实现A + (-B),而负数我们上面已经说了就是补码,只不过我们的数字系统里没有负数,但是对于二进制来说都是一样的。

我们先来实现取反:

type Reverse = [1, 0];

type Uint8Reverse<A extends Uint8> = [
    Reverse[A[0]],
    Reverse[A[1]],
    Reverse[A[2]],
    Reverse[A[3]],
    Reverse[A[4]],
    Reverse[A[5]],
    Reverse[A[6]],
    Reverse[A[7]]
];

然后实现补码,先取反再加 1:

type ONE = [0, 0, 0, 0, 0, 0, 0, 1];

type Uint8Negate<A extends Uint8> =
    Uint8Add<Uint8Reverse<A>, ONE>;

减法就是加上被减数的补码:

type Uint8Sub<A extends Uint8, B extends Uint8> = 
    Uint8Add<A, Uint8Negate<B>>;

回顾一下,我们通过 1 位数的加法实现了 8 位数的加法,又通过加法实现减法。接下来我们又会用加法和减法实现除法,不过让我们先来看一下如何实现乘法。根据小学的数学知识,只要把被乘数和乘数的每一位相乘再相加就可以了,这里面其实还会有进位,还有与被乘数每一位相乘的操作其实还乘了位权,比如乘以 21,这个 2 其实是 20,要补上与位权相应的 0。

示例:

\begin{array}{rll} \phantom{\times00}11\\ \underline{\times\phantom{00}21}\\ \phantom{\times00}11\\ \underline{\phantom{\times0}220}\\ \phantom{\times0}231\\ \end{array}

二进制和十进制的算法是一样的,只是更简单,因为被乘数字只会有 0 或 1 两种,为 0 时结果也是 0,为 1 时数字不变,就只要补 0 就可以了,也就是左移操作。

示例:

\begin{array}{rll} \phantom{\times000}111\\ \underline{\times\phantom{000}101}\\ \phantom{\times000}111\\ \phantom{\times00}0000\\ \underline{\phantom{\times0}11100}\\ \phantom{\times}100011\\ \end{array}

我们先来实现左移,为了方便使用,我们用一个额外的参数指定左移时填补的数字:

type LShift<A extends Uint8, B extends number, P extends Digit> =
    B extends 1 ? [A[1], A[2], A[3], A[4], A[5], A[6], A[7], P]
    : B extends 2 ? [A[2], A[3], A[4], A[5], A[6], A[7], P, P]
    : B extends 3 ? [A[3], A[4], A[5], A[6], A[7], P, P, P]
    : B extends 4 ? [A[4], A[5], A[6], A[7], P, P, P, P]
    : B extends 5 ? [A[5], A[6], A[7], P, P, P, P, P]
    : B extends 6 ? [A[6], A[7], P, P, P, P, P, P]
    : B extends 7 ? [A[7], P, P, P, P, P, P, P]
    : B extends 0 ? A : [P, P, P, P, P, P, P, P];

然后实现逐位的乘法,额外提供一个参数指示位移长度,这里左移填充0,但是后面我们还会用到第三个参数:

type ZERO = [0, 0, 0, 0, 0, 0, 0, 0];

type BitMul<A extends Uint8, B extends Digit, C extends Bits> =
    B extends 1 ? LShift<A, C, 0> : ZERO;

最后实现完整的乘法,对每一位运算的结果进行累加就是乘积了:

type Uint8Mul<A extends Uint8, B extends Uint8> = 
    Uint8Add<ZERO, BitMul<A, B[7], 0>> extends infer S
    ? Uint8Add<AsUint8<S>, BitMul<A, B[6], 1>> extends infer S
    ? Uint8Add<AsUint8<S>, BitMul<A, B[5], 2>> extends infer S
    ? Uint8Add<AsUint8<S>, BitMul<A, B[4], 3>> extends infer S
    ? Uint8Add<AsUint8<S>, BitMul<A, B[3], 4>> extends infer S
    ? Uint8Add<AsUint8<S>, BitMul<A, B[2], 5>> extends infer S
    ? Uint8Add<AsUint8<S>, BitMul<A, B[1], 6>> extends infer S
    ? Uint8Add<AsUint8<S>, BitMul<A, B[0], 7>>
    : never : never : never : never : never : never : never;

除法的算法也不过是小学知识,就是从被除数的最高位开始与除数比较,大于就相除并记录商和余数,不断地将被除数的下一位补到余数的末位来,直到除完所有数位。先来看 10 进制的:

\begin{array}{rll}     \phantom{00}031\phantom{0} && \hbox{(步骤)} \\[-3pt]    4 \enclose{longdiv}{\phantom{00}125\phantom{0}}\kern-.2ex \\[-3pt]       \underline{\phantom{00}0\phantom{000}} && \hbox{($1 \div 4 = 0 \phantom0余1$)} \\[-3pt]       12\phantom{00} && \hbox{($追加下一位数$)} \\[-3pt]       \underline{\phantom{00}12\phantom{00}} && \hbox{($12 \div 4 = 3 \phantom0余0$)} \\[-3pt]       05\phantom{0} && \hbox{($追加下一位数$)} \\[-3pt]       \underline{\phantom{0000}4\phantom{0}} && \hbox{($5 \div 4 = 1\phantom0余1$)} \\[-3pt]       \phantom{000}1\phantom{0} && \hbox{($最终得到31余1$)} \\[-3pt] \end{array}

二进制的更简单了,因为每次求余的结果,商要么为 0 要么为 1,也就是说不会超过 2 倍,只要比较大小相减即可:

\begin{array}{rll}     \phantom{0}0100\phantom{0} && \hbox{(步骤)} \\[-3pt]     11 \enclose{longdiv}{\phantom{00}1101\phantom{0}}\kern-.2ex \\[-3pt]     \underline{\phantom{0}0\phantom{0000}} && \hbox{($1 \div 11 = 0 \phantom0余1$)} \\[-3pt]     \phantom{0}11\phantom{000} && \hbox{($追加下一位数$)} \\[-3pt]     \underline{\phantom{0}11\phantom{000}} && \hbox{($11 \div 11 = 1 \phantom0余0$)} \\[-3pt]     \phantom{0}00\phantom{00} && \hbox{($追加下一位数$)} \\[-3pt]     \underline{\phantom{0}0\phantom{00}} && \hbox{($0 \div 11 = 0\phantom0余0$)} \\[-3pt]     \phantom{0}01\phantom{0} && \hbox{($追加下一位数$)} \\[-3pt]     \underline{\phantom{0}0\phantom{0}} && \hbox{($1 \div 11 = 0 \phantom0余1$)} \\[-3pt]     \phantom{000}1\phantom{0} && \hbox{($最终得到100余1$)} \\[-3pt] \end{array}

先来实现一个比较器,从高位到低位比较,不相等就返回比较结果,相等就继续比较下一位:

type EQ = 0;
type GT = 1;
type LT = 2;

type BitCMP<A extends Digit, B extends Digit> =
    [[EQ, LT], [GT, EQ]][A][B];

type Uint8CMP<A extends Uint8, B extends Uint8> =
    BitCMP<A[0], B[0]> extends GT | LT ? BitCMP<A[0], B[0]>
    : BitCMP<A[1], B[1]> extends GT | LT ? BitCMP<A[1], B[1]>
    : BitCMP<A[2], B[2]> extends GT | LT ? BitCMP<A[2], B[2]>
    : BitCMP<A[3], B[3]> extends GT | LT ? BitCMP<A[3], B[3]>
    : BitCMP<A[4], B[4]> extends GT | LT ? BitCMP<A[4], B[4]>
    : BitCMP<A[5], B[5]> extends GT | LT ? BitCMP<A[5], B[5]>
    : BitCMP<A[6], B[6]> extends GT | LT ? BitCMP<A[6], B[6]>
    : BitCMP<A[7], B[7]>;

再实现用于迭代的简单求余器,返回商和余数两个值,被除数小于除数则商为 0 余数为被除数,否则商为 1 余数为两数之差,这里用到了我们前面实现的减法:

type Remainder<A extends Uint8, B extends Uint8> =
    Uint8CMP<A, B> extends LT ? [0, A] : [1, Uint8Sub<A, B>];

最后,让我们来实现完整的除法运算:

type Uint8Div<A extends Uint8, B extends Uint8> =
    Remainder<LShift<ZERO, 1, A[0]>, B> extends [infer Q0, infer R]
    ? Remainder<LShift<AsUint8<R>, 1, A[1]>, B> extends [infer Q1, infer R]
    ? Remainder<LShift<AsUint8<R>, 1, A[2]>, B> extends [infer Q2, infer R]
    ? Remainder<LShift<AsUint8<R>, 1, A[3]>, B> extends [infer Q3, infer R]
    ? Remainder<LShift<AsUint8<R>, 1, A[4]>, B> extends [infer Q4, infer R]
    ? Remainder<LShift<AsUint8<R>, 1, A[5]>, B> extends [infer Q5, infer R]
    ? Remainder<LShift<AsUint8<R>, 1, A[6]>, B> extends [infer Q6, infer R]
    ? Remainder<LShift<AsUint8<R>, 1, A[7]>, B> extends [infer Q7, infer R]
    ? [AsUint8<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>, AsUint8<R>]
    : never : never : never : never : never : never : never : never;

我们把被除数从高位到低位逐位后缀到每一步的余数后面,这里用到了我们左移操作的第三个参数。然后不断对新的数求余,并保存每一位得到的商,然后返回最终的商和最终的余数。

最后,我们来定义几个便于使用的类型:

// 加
type Add<A extends number, B extends number> =
        ToNumber<Uint8Add<ToUint8<A>, ToUint8<B>>>;
// 减
type Sub<A extends number, B extends number> =
    ToNumber<Uint8Sub<ToUint8<A>, ToUint8<B>>>;
// 乘
type Mul<A extends number, B extends number> =
    ToNumber<Uint8Mul<ToUint8<A>, ToUint8<B>>>;
// 除
type Div<A extends number, B extends number> =
    B extends 0 ? never :
    ToNumber<Uint8Div<ToUint8<A>, ToUint8<B>>[0]>;
// 取余
type Mod<A extends number, B extends number> =
    B extends 0 ? never :
    ToNumber<Uint8Div<ToUint8<A>, ToUint8<B>>[1]>;

然后我们简单的测试一下:

type case1_ShouldBe99 = Add<33, 66>;    // 33 + 66 = 99
type case2_ShouldBe0 = Add<255, 1>;     // 255 + 1 = 0 (overflow)

type case3_ShouldBe99 = Sub<123, 24>;   // 123 - 24 = 99
type case4_ShouldBe255 = Sub<0, 1>;     // 0 - 1 = 255 (overflow)

type case5_ShouldBe153 = Mul<17, 9>;    // 17 x 9 = 153
type case6_ShouldBe253 = Mul<255, 3>;   // 255 x 3 = 253 (overflow)

type case7_ShouldBe33 = Div<100, 3>;    // 100 / 3 = 33
type case8_ShouldBeNever = Div<1, 0>;   // 1 / 0 = error (divide by 0)

type case9_ShouldBe1 = Mod<100, 3>;     // 100 % 3 = 1
type case10_ShouldBeNever = Mod<1, 0>;  // 1 % 0 = error (divide by 0)

最后放上 Playground 可以在线试一下。下一篇我再说一下如何用类型实现一个 parser 来从一个 token 数组中解析表达式和语法,会用到更多奇技淫巧,如果有人想看的话。

编辑于 04-12

文章被以下专栏收录