# Triton Puzzles
## Preliminary
CUDA follows the traditional "single program, multiple data" GPU execution model and is programmed at the granularity of individual threads; Triton is programmed at the granularity of blocks (tiles).
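For intuition, here is a toy sketch of that difference (the kernel below is my own example, not part of the puzzles): a CUDA thread would compute one element from its own index, whereas one Triton program instance owns a whole block of offsets. The later sketches in this note reuse the same imports.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # CUDA (per thread):  i = blockIdx.x * blockDim.x + threadIdx.x; out[i] = x[i] + 1
    # Triton (per block): one program instance handles BLOCK contiguous elements at once
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask)
    tl.store(out_ptr + offs, x + 1, mask)


x = torch.randn(100, device="cuda")
out = torch.empty_like(x)
add_one_kernel[(triton.cdiv(100, 32),)](x, out, 100, BLOCK=32)
```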
- Triton kernels are launched in much the same way as CUDA kernels.
- `tl.load(ptr, mask, other)`
    - https://triton-lang.org/main/python-api/generated/triton.language.load.html
    - Loads data starting at address `ptr`. `mask` marks which elements are valid (these get loaded); the remaining elements are filled with `other`. That is, if `mask[i]` is false, then `x[i] = other[i]`. (The sketch after the example below also exercises `other`.)
- `tl.store(ptr, value, mask)`
    - Like load, but in the other direction: stores `value` to memory starting at address `ptr`. `mask` marks which elements are valid (these get stored); the remaining elements are left untouched.
- Program ids let us process several blocks of data at the same time. `tl.program_id(dim)` returns the id of the current block along dimension `dim`. Example:
""" Print for each [0] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [1] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [2] [1. 1. 1. 1. 0. 0. 0. 0.] """ @triton.jit def demo4(x_ptr): pid = tl.program_id(0) range = tl.arange(0, 8) + pid * 8 # 1 block for 8 elements x = tl.load(x_ptr + range, range < 20) # flatten print("Print for each", pid, x) def run_demo4(): print("Demo4 Output: ") x = torch.ones(2, 4, 4) demo4[(3, 1, 1)](x) # 3 blocks print_end_line()
## Puzzles
- Puzzle 6: Fused Outer Multiplication - Backwards
\[
\begin{aligned}
f(x, y) &= \text{relu}(x_{j, i} \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1\\
dx_{j, i} &= f_x'(x, y)_{j, i} \times dz_{j, i}
\end{aligned}
\]

```python
@triton.jit
def mul_relu_block_back_kernel(
    x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    block_id_j = tl.program_id(1)
    x_off = block_id_i * B0 + tl.arange(0, B0)
    y_off = block_id_j * B1 + tl.arange(0, B1)
    xji_off = x_off[None, :] + y_off[:, None] * N0
    mask_x = x_off < N0
    mask_y = y_off < N1
    mask_xji = mask_x[None, :] & mask_y[:, None]
    x = tl.load(x_ptr + xji_off, mask_xji)
    y = tl.load(y_ptr + y_off, mask_y)
    z = tl.load(dz_ptr + xji_off, mask_xji)
    # The [:, None] here must not be omitted.
    # Note the shapes: y is [B1] and x is [B1, B0]. With x * y, y acts as a row vector
    # (broadcast expands it to [1, B1]) and would multiply each column of x.
    # Expanded to a column vector [B1, 1], it multiplies each row of x instead.
    y_ext = y[:, None]
    fxy = tl.where(x * y_ext > 0, y_ext, 0)
    ans = fxy * z
    tl.store(dx_ptr + xji_off, ans, mask_xji)
    return
```
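To see why the `[:, None]` matters, a small PyTorch sketch of the broadcasting behaviour (shapes are arbitrary, with B0 == B1 so that both expressions broadcast without an error):

```python
import torch

B = 3
x = torch.arange(9.0).reshape(B, B)    # stands in for the [B1, B0] tile
y = torch.tensor([1.0, 10.0, 100.0])   # stands in for the [B1] tile

print(x * y)           # y broadcasts to [1, B1]: every column i is scaled by y[i] (wrong here)
print(x * y[:, None])  # y[:, None] is [B1, 1]: every row j is scaled by y[j] (what the kernel needs)
```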
- Puzzle 9: Simple FlashAttention
\[
z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0
\]

The idea here comes from FlashAttention:

```python
@triton.jit
def flashatt_kernel(
    q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_i = tl.program_id(0)
    log2_e = 1.44269504
    myexp = lambda x: tl.exp2(log2_e * x)
    i_off = block_id_i * B0 + tl.arange(0, B0)
    mask_i = i_off < N0
    q = tl.load(q_ptr + i_off, mask_i)
    max = tl.full([B0], -10, dtype=tl.float32)
    sum = tl.zeros([B0], dtype=tl.float32)
    z = tl.zeros([B0], dtype=tl.float32)
    for j in tl.range(0, T, B1):
        j_off = j + tl.arange(0, B1)
        mask_j = j_off < T
        k = tl.load(k_ptr + j_off, mask_j)
        mask_ij = mask_i[:, None] & mask_j[None, :]
        tmp = q[:, None] * k[None, :] + tl.where(mask_ij, 0, -1.0e6)
        x_max = tl.max(tmp, axis=1)  # local max
        new_max = tl.maximum(max, x_max)  # global max
        # update the running exp-sum
        sum = sum * myexp(max - new_max) + tl.sum(myexp(tmp - new_max[:, None]), axis=1)
        v = tl.load(v_ptr + j_off, mask_j)
        # z is the final output here (just not yet divided by the softmax denominator)
        z = z * myexp(max - new_max) + tl.sum(myexp(tmp - new_max[:, None]) * v[None, :], axis=1)
        max = new_max
    # must not go inside the loop, otherwise the initial sum = 0 gives nan
    z /= sum
    tl.store(z_ptr + i_off, z, mask_i)
    return
```
- Computing \(\sum e^{x-\max{x_i}}\): we update it online. Suppose the first i blocks have already been processed and we have \(Q_i=\max{x_i}\) and \(P_i = \sum e^{x-Q_i}\); then after block i+1 we get \(P_{i+1} = P_i \times e^{Q_i - Q_{i+1}} + \sum e^{x-Q_{i+1}}\), where the last sum runs over block i+1.
- Let \(O_i = \text{softmax}(x_i)\); we also need \(\sum_j O_i\cdot v_j\). Set aside the final division in softmax for now and compute only the numerator (written \(\hat O_i\)). Suppose after the first i blocks we have \(S_i = \sum_j \hat O_i\cdot v_j\); then after block i+1 we get \(S_{i+1} = S_i \times e^{Q_i - Q_{i+1}} + \sum_j \hat O_{i+1}\cdot v_j\). The sketch below checks both recurrences numerically.
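A quick numerical check of the two recurrences (my own sketch, unrelated to the puzzle's test harness), comparing the blockwise online update against a direct softmax:

```python
import torch

torch.manual_seed(0)
T, B = 20, 6                        # sequence length and block size (arbitrary)
x = torch.randn(T)                  # stands in for one query's scores q_i * k_j
v = torch.randn(T)

ref = (torch.softmax(x, dim=0) * v).sum()   # direct reference: softmax(x) · v

Q = torch.tensor(float("-inf"))     # running max
P = torch.tensor(0.0)               # running sum of exp(x - Q)
S = torch.tensor(0.0)               # running sum of exp(x - Q) * v
for start in range(0, T, B):
    xb, vb = x[start:start + B], v[start:start + B]
    Q_new = torch.maximum(Q, xb.max())
    P = P * torch.exp(Q - Q_new) + torch.exp(xb - Q_new).sum()
    S = S * torch.exp(Q - Q_new) + (torch.exp(xb - Q_new) * vb).sum()
    Q = Q_new

print(ref.item(), (S / P).item())   # the two numbers should agree
```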
- Puzzle 11: Matrix Multiplication
The computation is split into many small tiles (a tiling scheme), and the partial results of the tiles are accumulated.
\[
z_{i, j, k} = \sum_{l} x_{i,j, l} \times y_{i, l, k} \text{ for } i = 1\ldots N_2,\ j = 1\ldots N_0,\ k = 1\ldots N_1
\]

```python
@triton.jit
def dot_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, N2, MID,
    B0: tl.constexpr, B1: tl.constexpr, B2: tl.constexpr, B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)
    block_id_i = tl.program_id(2)
    i_off = block_id_i * B2 + tl.arange(0, B2)
    j_off = block_id_j * B0 + tl.arange(0, B0)
    k_off = block_id_k * B1 + tl.arange(0, B1)
    mask_i = i_off < N2
    mask_j = j_off < N0
    mask_k = k_off < N1
    z = tl.zeros([B2, B0, B1], dtype=tl.float32)
    # MID is passed in by the caller when the kernel is launched
    for l in tl.range(0, MID, B_MID):
        l_off = l + tl.arange(0, B_MID)
        mask_l = l_off < MID
        # x has shape [N2, N0, MID]
        xjl_off = i_off[:, None, None] * N0 * MID + j_off[None, :, None] * MID + l_off[None, None, :]
        mask_xjl = mask_i[:, None, None] & mask_j[None, :, None] & mask_l[None, None, :]
        xjl = tl.load(x_ptr + xjl_off, mask_xjl)
        # y has shape [N2, MID, N1]
        ylk_off = i_off[:, None, None] * MID * N1 + l_off[None, :, None] * N1 + k_off[None, None, :]
        mask_ylk = mask_i[:, None, None] & mask_l[None, :, None] & mask_k[None, None, :]
        ylk = tl.load(y_ptr + ylk_off, mask_ylk)
        z += tl.dot(xjl, ylk)
    # z has shape [N2, N0, N1]
    z_off = i_off[:, None, None] * N0 * N1 + j_off[None, :, None] * N1 + k_off[None, None, :]
    mask_z = mask_i[:, None, None] & mask_j[None, :, None] & mask_k[None, None, :]
    tl.store(z_ptr + z_off, z, mask_z)
    return
```
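A possible way to launch and sanity-check this kernel (a sketch under my own assumptions about the sizes and the contiguous row-major layout implied by the formula; the puzzle's own harness may differ), comparing against `torch.bmm`:

```python
def run_dot():
    N0, N1, N2, MID = 32, 32, 4, 32
    B0, B1, B2, B_MID = 16, 16, 1, 16      # tl.dot generally wants the non-batch tile dims >= 16
    x = torch.randn(N2, N0, MID, device="cuda")
    y = torch.randn(N2, MID, N1, device="cuda")
    z = torch.empty(N2, N0, N1, device="cuda")
    grid = (triton.cdiv(N0, B0), triton.cdiv(N1, B1), triton.cdiv(N2, B2))
    dot_kernel[grid](x, y, z, N0, N1, N2, MID, B0=B0, B1=B1, B2=B2, B_MID=B_MID)
    print((z - torch.bmm(x, y)).abs().max())   # should be small (tl.dot may use tf32 internally)
```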