VizTracer实战,一小时性能提升900%

最近一直想在一个实际问题中用一下VizTracer试试,看看效果到底怎么样,今天逛stackexchange的时候发现有人发了一个问题,给了一段代码问哪里可以改进,于是我兴致勃勃地把他的代码拿下来开始研究,在VizTracer的帮助下,用了一个小时左右的时间,把速度提升了900%。

https://codereview.stackexchange.com/questions/248330/recursive-sudoku-solver-using-pythoncodereview.stackexchange.com

这里是英文的过程,在知乎在发一下中文的过程,和大家分享一下VizTracer的能力~

这个楼主自己写了一个sudoku solver,也就是解数独的一个程序。我的目的是用VizTracer,在尽量不读他的代码,不去理解他的思路的情况下,进行性能提升。这也是我觉得VizTracer强大的地方,他可以让你特别直观地看到整个程序的运行,不用去把程序全搞明白。

为了方便部分人阅读,我把它完整代码贴在下面,不感兴趣的请使劲往下翻,直接到实战部分。

from copy import deepcopy
import numpy as np


def create_grid(puzzle_str: str) -> np.ndarray:
    """Create a 9x9 Sudoku grid from a string of digits"""

    # Deleting whitespaces and newlines (\n)
    lines = puzzle_str.replace(' ','').replace('\n','')
    digits = list(map(int, lines))
    # Turning it to a 9x9 numpy array
    grid = np.array(digits).reshape(9,9)
    return grid


def get_subgrids(grid: np.ndarray) -> np.ndarray:
    """Divide the input grid into 9 3x3 sub-grids"""

    subgrids = []
    for box_i in range(3):
        for box_j in range(3):
            subgrid = []
            for i in range(3):
                for j in range(3):
                    subgrid.append(grid[3*box_i + i][3*box_j + j])
            subgrids.append(subgrid)
    return np.array(subgrids)


def get_candidates(grid : np.ndarray) -> list:
    """Get a list of candidates to fill empty cells of the input grid"""

    def subgrid_index(i, j):
        return (i//3) * 3 + j // 3

    subgrids = get_subgrids(grid)
    grid_candidates = []
    for i in range(9):
        row_candidates = []
        for j in range(9):
            # Row, column and subgrid digits
            row = set(grid[i])
            col = set(grid[:, j])
            sub = set(subgrids[subgrid_index(i, j)])
            common = row | col | sub
            candidates = set(range(10)) - common
            # If the case is filled take its value as the only candidate
            if not grid[i][j]:
                row_candidates.append(list(candidates))
            else:
                row_candidates.append([grid[i][j]])
        grid_candidates.append(row_candidates)
    return grid_candidates


def is_valid_grid(grid : np.ndarray) -> bool:
    """Verify the input grid has a possible solution"""

    candidates = get_candidates(grid)
    for i in range(9):
        for j in range(9):
            if len(candidates[i][j]) == 0:
                return False
    return True


def is_solution(grid : np.ndarray) -> bool:
    """Verify if the input grid is a solution"""

    if np.all(np.sum(grid, axis=1) == 45) and \
       np.all(np.sum(grid, axis=0) == 45) and \
       np.all(np.sum(get_subgrids(grid), axis=1) == 45):
        return True
    return False


def filter_candidates(grid : np.ndarray) -> list:
    """Filter input grid's list of candidates"""
    test_grid = grid.copy()
    candidates = get_candidates(grid)
    filtered_candidates = deepcopy(candidates)
    for i in range(9):
        for j in range(9):
            # Check for empty cells
            if grid[i][j] == 0:
                for candidate in candidates[i][j]:
                    # Use test candidate
                    test_grid[i][j] = candidate
                    # Remove candidate if it produces an invalid grid
                    if not is_valid_grid(fill_singles(test_grid)):
                        filtered_candidates[i][j].remove(candidate)
                    # Revert changes
                    test_grid[i][j] = 0
    return filtered_candidates


def merge(candidates_1 : list, candidates_2 : list) -> list:
    """Take shortest candidate list from inputs for each cell"""

    candidates_min = []
    for i in range(9):
        row = []
        for j in range(9):
            if len(candidates_1[i][j]) < len(candidates_2[i][j]):
                row.append(candidates_1[i][j][:])
            else:
                row.append(candidates_2[i][j][:])
        candidates_min.append(row)
    return candidates_min


def fill_singles(grid : np.ndarray, candidates=None) -> np.ndarray:
    """Fill input grid's cells with single candidates"""

    grid = grid.copy()
    if not candidates:
        candidates = get_candidates(grid)
    any_fill = True
    while any_fill:
        any_fill = False
        for i in range(9):
            for j in range(9):
                if len(candidates[i][j]) == 1 and grid[i][j] == 0:
                    grid[i][j] = candidates[i][j][0]
                    candidates = merge(get_candidates(grid), candidates)
                    any_fill = True
    return grid


def make_guess(grid : np.ndarray, candidates=None) -> np.ndarray:
    """Fill next empty cell with least candidates with first candidate"""

    grid = grid.copy()
    if not candidates:
        candidates = get_candidates(grid)
    # Getting the shortest number of candidates > 1:
    min_len = sorted(list(set(map(
       len, np.array(candidates).reshape(1,81)[0]))))[1]
    for i in range(9):
        for j in range(9):
            if len(candidates[i][j]) == min_len:
                for guess in candidates[i][j]:
                    grid[i][j] = guess
                    solution = solve(grid)
                    if solution is not None:
                        return solution
                    # Discarding a wrong guess
                    grid[i][j] = 0


def solve(grid : np.ndarray) -> np.ndarray:
    """Recursively find a solution filtering candidates and guessing values"""

    candidates = filter_candidates(grid)
    grid = fill_singles(grid, candidates)
    if is_solution(grid):
        return grid
    if not is_valid_grid(grid):
        return None
    return make_guess(grid, candidates)

# # Example usage

# puzzle = """100920000
#             524010000
#             000000070
#             050008102
#             000000000
#             402700090
#             060000000
#             000030945
#             000071006"""

# grid = create_grid(puzzle)
# solve(grid)
```

是不是翻了很久?这只是一个人随便写的一小段程序,不妨试想如果你要一头扎进比这个大得多的程序里去做优化,需要花多久去理解代码?

VizTracer说,不需要花时间,直接跑VizTracer!

当然了,我们需要把他程序跑一下先,看看用时。在我的电脑上,这个程序跑了4.5秒,这是我们的baseline

接着跑了一下VizTracer,大概花了7-8秒(WSL的overhead比正常大非常多)拿到下面这个图:

可以看到密密麻麻一大片(后面还有,这已经zoom了一下了)。但是大体结构一下就清晰了,蓝色的filter_candidates call了一大堆粉色的fill_singles,占据了绝大部分的时间线。让我们再zoom in一下。

粉色的fill_singles call了一大堆绿色的get_candidates。这个函数被疯狂地使用,占了绝大部分的run time。

于是我们第一个目标已经有了,什么都不管,直奔get_candidates

 def get_candidates(grid : np.ndarray) -> list:
    """Get a list of candidates to fill empty cells of the input grid"""

    def subgrid_index(i, j):
        return (i//3) * 3 + j // 3

    subgrids = get_subgrids(grid)
    grid_candidates = []
    for i in range(9):
        row_candidates = []
        for j in range(9):
            # Row, column and subgrid digits
            row = set(grid[i])
            col = set(grid[:, j])
            sub = set(subgrids[subgrid_index(i, j)])
            common = row | col | sub
            candidates = set(range(10)) - common
            # If the case is filled take its value as the only candidate
            if not grid[i][j]:
                row_candidates.append(list(candidates))
            else:
                row_candidates.append([grid[i][j]])
        grid_candidates.append(row_candidates)
    return grid_candidates

这个量级的代码是不是读起来就容易多了?

我第一个发现的是,在他for loop的最后有一个判断,如果grid[i][j]是空的,就把candidatesappendrow_candidates,否则直接把[grid[i][j]] 放进去。但是如果是else的情况,他for loop里面东西就白算了啊。所以尽管我压根不知道他算这些东西有啥用,我还是直接把else里的判断拿到前面,然后如果符合,跳过loop。

    for i in range(9):
        row_candidates = []
        for j in range(9):
            if grid[i][j]:
                row_candidates.append([grid[i][j]])
                continue
            # Row, column and subgrid digits
            row = set(grid[i])
            col = set(grid[:, j])
            sub = set(subgrids[subgrid_index(i, j)])
            common = row | col | sub
            candidates = set(range(10)) - common
            row_candidates.append(list(candidates)) 

这一个优化,直接把run time拉到了2.3s。

接下来,我发现在这个for loop里面,row col sub 每次都会被计算。然而他们实际上只需要算9次,并不需要算81次。有72次计算是完全重复的。

于是我就把这9次计算直接拿到了for loop外面,然后在里面直接用算出来的结果

    row_sets = [set(grid[i]) for i in range(9)]
    col_sets = [set(grid[:, j]) for j in range(9)]
    subgrid_sets = [set(subgrids[i]) for i in range(9)]
    total_sets = set(range(10))

    for i in range(9):
        row_candidates = []
        for j in range(9):
            if grid[i][j]:
                row_candidates.append([grid[i][j]])
                continue
            # Row, column and subgrid digits
            row = row_sets[i]
            col = col_sets[j]
            sub = subgrid_sets[subgrid_index(i, j)]
            common = row | col | sub
            candidates = total_sets - common
            # If the case is filled take its value as the only candidate
            row_candidates.append(list(candidates))
        grid_candidates.append(row_candidates)
    return grid_candidates

run time来到了1.5s。在15分钟之内,我把速度提升到了原来的300%。

这就是profiling的价值,我们老板说过,never optimize without profiling。你必须要知道你的程序到底是慢在哪儿,再去做优化。

我们通过VizTracer迅速定位了get_candidates这个函数,并且在完全不管它到底在做什么的情况下,通过相同逻辑的转换,争取了几倍的速度。

这时,VizTracer的overhead已经比较明显了,我就把c function的tracing给关掉了,因为那些也确实不本质,会影响我对函数占比的判断。

我又跑了一下,zoom之后结果大概是这样的:

在关c function之前,由于get_subgridsmerge用了大量的builtin function,所以显得占比很高(function entry/exit的instrumentation打来的影响),关掉之后发现其实这两个函数占比没那么大,优化它们价值不明显。

get_candidates明确的思路用的差不多之后,我开始关注函数之间的关系。这里就是VizTracer擅长而cProfile无能为力的地方了。我通过VizTracer观察到,每一个fill_singles都call了一大堆的get_candidates,所以我就去看看fill_singles里面,到底什么东西在做这件事,以及能不能让他少call几次。

 def fill_singles(grid : np.ndarray, candidates=None) -> np.ndarray:
    """Fill input grid's cells with single candidates"""

    grid = grid.copy()
    if not candidates:
        candidates = get_candidates(grid)
    any_fill = True
    while any_fill:
        any_fill = False
        for i in range(9):
            for j in range(9):
                if len(candidates[i][j]) == 1 and grid[i][j] == 0:
                    grid[i][j] = candidates[i][j][0]
                    candidates = merge(get_candidates(grid), candidates)
                    any_fill = True
    return grid

这段代码简单看一下,一个while loop里面套着一个81层for loop,每层如果这个点只有一个candidate并且没被填上,就填上,然后更新candidates。这逻辑本身肯定没有问题,但是老更新candidates这个代价太大了(我们从VizTracer上知道的)。这里我思考了一下,能不能不那么频繁地更新candidates?毕竟就填了一个格子。一个简单直接的优化就诞生了,因为一个格子填了一个数n之后,其他的只有一个candidate并且不是n的格子,是不会受到影响的。也就是说,在这一波81层循环里,我们可以尝试多填几个只有一个candidate的格子,只要数不重就不会有影响,然后出来了再更新一下candidate。我用了set来记录:

def fill_singles(grid : np.ndarray, candidates=None) -> np.ndarray:
    """Fill input grid's cells with single candidates"""

    grid = grid.copy()
    if not candidates:
        candidates = get_candidates(grid)
    any_fill = True
    while any_fill:
        any_fill = False
        filled_number = set()
        for i in range(9):
            for j in range(9):
                if len(candidates[i][j]) == 1 and grid[i][j] == 0 and candidates[i][j][0] not in filled_number:
                    grid[i][j] = candidates[i][j][0]
                    filled_number.add(candidates[i][j][0])
                    any_fill = True
        candidates = merge(get_candidates(grid), candidates)
    return grid

这个优化把run time带到了0.9s。这里需要稍稍理解一下这个函数是做什么的了,不再是无脑优化了。但是我们基本上还是盯着一个部分在理解,并不费劲。

再往上看一层,fill_singles是被filter_candidates使用的,而filter_candidates用它的时候,只在乎一件事,就是fill_singles结束之后,这个grid是不是valid的。

                    if not is_valid_grid(fill_singles(test_grid)):
                        filtered_candidates[i][j].remove(candidate)

我们回到fill_singles思考一下,这个函数其实很早就可以知道这个grid到底是不是valid的了。它在做循环的时候 ,一旦发现一个格子里没有任何candidate,就说明这个grid不valid了(我从is_valid_grid看的……)。也就是说,我们只需要判断这件事,就可以提前跳出这个万恶的fill_singles,从而节省我们在get_candidates里消耗的时间。

于是我稍稍调整了一下结构,让fill_singles在发现grid不valid的时候直接返回None,一下又节省了很多个get_candidates

最后的run time定格在0.5s,是原来程序的1/9。我从对这个程序一无所知,到完成优化,一共花了大约一个小时。

这个程序当然肯定还有优化的空间,但是我们回头想一下,我们完成的这些优化,都是极为“高效”的。我们修改的代码非常少,几乎没动原来的结构和算法,做的优化就是单纯地找到症结点,对症下药。直到优化结束,这个代码好多个函数我看都没看过一眼。

而这,是我们面对一个比较大的项目的时候希望的优化和debug方式。如果面对一个项目,你得把所有地方都搞明白才能开始优化和debug,那就天荒地老了。

我之所以写VizTracer,就是希望可以实现这种高效的优化和debug形式,完成cProfile完成不了的事情。

这是我第一篇实战记录,希望大家喜欢~最后还是放上VizTracer的github链接~走过路过星星点一个~

gaogaotiantian/viztracergithub.com图标

编辑于 10-25

文章被以下专栏收录