有效率的 Python 數值運算

有效率的 Python 數值運算 #

機器學習需要大量的數值運算,然而純 Python 在數值運算上很慢,需要其它套件輔助。本文介紹如何善用各種套件寫出快速有效率的 Python 程式。

本文以機器學習常見的線性迴歸作為範例,介紹如何使用 NumpyNumba 以及 Cython 做加速。

\[f(X, W, Y) = ||XW - Y||_2\] \[X \in \R^{N \times D}, W \in \R^D, Y \in \R^N\]

import numpy as np

X = np.random.random((10000, 100))
W = np.random.random(100)
Y = np.random.random(10000)

def f_pure_python(X, W, Y):
    loss = 0
    for i in range(X.shape[0]):
        val = 0
        for j in range(X.shape[1]):
            val += X[i, j] * W[j]
        loss += (val - Y[i]) ** 2
    loss = math.sqrt(loss)
    return loss
%timeit f_pure_python(X, W, Y)
# 1 loop, best of 3: 521 ms per loop

Numpy #

Numpy vecotrzie 是很強大的功能,可以整個 array 一起計算,會有顯著加速。可以用 np.testing.assert_almost_equal 檢查,確實回傳幾乎一樣的結果。因為是浮點數運算,所以用 almost equal 檢查允許一點誤差。

def f_numpy(X, W, Y):
    loss = 0
    for i in range(X.shape[0]):
        val = (X[i] * W).sum()
        loss += np.power(val - Y[i], 2)
    return np.sqrt(loss)
%timeit f_numpy(X, W, Y)
# 10 loops, best of 3: 68.3 ms per loop

其實 Numpy 有提供矩陣乘法:

%timeit np.sqrt(np.sum(np.power(np.matmul(X,W[:, np.newaxis]).reshape([-1]) - Y, 2)))
# 1000 loops, best of 3: 594 µs per loop

Numba #

另外有 Numba 套件,可以用簡單的 decorater @numba.jit 加到要加速的 function 上,jit 是 just-in-time compiler 的意思,會在執行時即時將 Python code 編譯成機器碼,得到接近 C 或 Fortran 的速度。因為是即時編譯,所以第一次執行時需要編譯程式,會比較慢。

@numba.jit(nopython=True)
def f_numba(X, W, Y):
    loss = 0
    for i in range(X.shape[0]):
        val = 0
        for j in range(X.shape[1]):
            val += X[i, j] * W[j]
        loss += (val - Y[i]) ** 2
    loss = math.sqrt(loss)
    return loss
# 1000 loops, best of 3: 1.18 ms per loop

速度比完全使用 Numpy 慢一點點,Numba 有一個 fastmath 模式,犧牲一點點精確度來加速。這邊有簡單的 performance tips 可以參考。

@numba.jit(nopython=True, fastmath=True)
def f_numba_fastmath(X, W, Y):
    loss = 0
    for i in range(X.shape[0]):
        val = 0
        for j in range(X.shape[1]):
            val += X[i, j] * W[j]
        loss += (val - Y[i]) ** 2
    loss = math.sqrt(loss)
    return loss
# The slowest run took 566.47 times longer than the fastest. This could mean that an intermediate result is being cached.
# 1 loop, best of 3: 421 µs per loop

Cython #

另外也有 Cython,顧名思義,就是 C + Pyhon,可以 Python 跟 C/C++ 混合編寫,事先編譯後可以得到 C/C++ 的效能。在 IPyhon 環境終使用 Cython 必須先執行 %load_ext cython 載入 Cython 模組。載入模組後在要寫 Cython code 的 cell 開頭先寫 %%cython,表示這個 cell 會被當成獨立的 Cython 模組,沒辦法使用其他 cell 的變數。

%%cython -a
cimport cython
cimport numpy as np
from libc.math cimport sqrt, pow
from cython.parallel import prange

@cython.boundscheck(False)
@cython.wraparound(False)
def f_cython(double[:, :] X, double[:] W, double[:] Y):
    cdef double loss = 0
    cdef double val
    cdef Py_ssize_t n = X.shape[0]
    cdef Py_ssize_t d = X.shape[1]
    cdef Py_ssize_t i, j

    for i in prange(n, nogil=True):
        val = 0
        for j in range(d):
            val += X[i, j] * W[j]
        loss += pow(val - Y[i], 2)
    loss = sqrt(loss)
    return loss

在其他 cell 就可以使用 f_cython 這個函式,得到接近 fastmath=False 的 Numba 效率。

%timeit f_cython(X, W, Y)
# 1000 loops, best of 3: 1.29 ms per loop

如果不使用 IPython 的情況下要使用 Cython,必須將 Cython 程式碼寫在 .pyx 檔裡面,如 my_cython.pyx

cimport cython
cimport numpy as np
from libc.math cimport sqrt, pow
from cython.parallel import prange

@cython.boundscheck(False)
@cython.wraparound(False)
def f_cython(double[:, :] X, double[:] W, double[:] Y):
    cdef double loss = 0
    cdef double val
    cdef Py_ssize_t n = X.shape[0]
    cdef Py_ssize_t d = X.shape[1]
    cdef Py_ssize_t i, j

    for i in prange(n, nogil=True):
        val = 0
        for j in range(d):
            val += X[i, j] * W[j]
        loss += pow(val - Y[i], 2)
    loss = sqrt(loss)
    return loss

以及 setup.py

from setuptools import setup
from Cython.Build import cythonize

setup(
    ext_modules = cythonize("my_cython.pyx")
)

使用前先編譯:

python setup.py build_ext --inplace

最後,在程式碼加入 from my_cython import f_cython 即可使用。

版本執行時間
Python521ms
Numpy68.3 ms
Numpy 內建矩陣乘法594 µs
Numba1.18 ms
Numba fastmath421 µs
Cython1.29 ms

注意,本文程式使用 Google colab 做效率測試,不同機器、不同 Python 版本、不同套件版本、不同 C/C++ compiler 、甚至 Intel 或 AMD CPU 都會影響效率。在另一台機器上,Numba (無 fastmath 版本)可能不會比 Cython 快。不過 Numpy 使用內建矩陣乘法版本會比 Numpy 快,Numbafastmath 比沒有快是肯定的。

參考資料:

Numba performance tips

Cython tutorial

Cython for Numpy users

IPython Codebook Chapter 4

comments powered by Disqus