
Triton Puzzles

Abstract

For an introduction to Triton Puzzles, see: [MLSys 入门向] 做12道题,快速上手Triton! (an introductory write-up: solve 12 puzzles to get started with Triton quickly).

Preliminary

CUDA follows the traditional "single program, multiple data" GPU execution model and is programmed at the granularity of individual threads, whereas Triton is programmed at the granularity of blocks (tiles).
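
A minimal sketch of what block-level programming looks like (the kernel and names below are illustrative, not taken from the puzzles; run it on GPU tensors, or under the Triton interpreter as the puzzles do): each program instance loads, adds, and stores a whole block of elements at once, instead of one element per thread as in CUDA.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, z_ptr, N, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)  # this program instance handles one block of elements
        mask = offs < N                           # guard against the last partial block
        x = tl.load(x_ptr + offs, mask)
        y = tl.load(y_ptr + offs, mask)
        tl.store(z_ptr + offs, x + y, mask)

    def add(x, y):
        z = torch.empty_like(x)
        N = x.numel()
        BLOCK = 128
        add_kernel[(triton.cdiv(N, BLOCK), 1, 1)](x, y, z, N, BLOCK=BLOCK)
        return z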

  • Triton kernels are launched much like CUDA kernels:

    @triton.jit
    def kernel(x_ptr, ...):
        ...
    
    def run_kernel(x, ...):
        kernel[(1, 1, 1)](...)          # grid size
    

  • tl.load(ptr, mask, other) loads the elements starting at address ptr; mask selects which positions are actually read, and masked-off positions are filled with other.

    1D

    range = tl.arange(0, 8)                   # [0, 1, 2, 3, 4, 5, 6, 7]
    x = tl.load(x_ptr + range, range < 5, 0)  # mask: [1, 1, 1, 1, 1, 0, 0, 0]
    print(x)
    

    2D

    i_range = tl.arange(0, 8)[:, None] # 8 * 1
    j_range = tl.arange(0, 4)[None, :] # 1 * 4  for broadcasting 
    range = i_range * 4 + j_range
    x = tl.load(x_ptr + range, (i_range < 4) & (j_range < 3), 0)
    
  • tl.store(ptr, value, mask) is the counterpart of load: it writes value to memory starting at address ptr; mask selects which elements are actually stored, and the rest are left untouched. For example:
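
    A minimal sketch in the style of the load examples above (demo_store and run_demo_store are illustrative names, assuming the same interpreter setup as the demos in this section):

    @triton.jit
    def demo_store(z_ptr):
        range = tl.arange(0, 8)
        vals = tl.full([8], 10.0, dtype=tl.float32)
        tl.store(z_ptr + range, vals, range < 5)  # only the first 5 elements are written

    def run_demo_store():
        z = torch.zeros(8)
        demo_store[(1, 1, 1)](z)
        print(z)  # tensor([10., 10., 10., 10., 10., 0., 0., 0.])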

  • Multiple blocks of data can be processed at the same time via program ids: tl.program_id(dim) returns the id of the current program instance (block) along dimension dim.

    Example

    """
    Print for each [0] [1. 1. 1. 1. 1. 1. 1. 1.]
    Print for each [1] [1. 1. 1. 1. 1. 1. 1. 1.]
    Print for each [2] [1. 1. 1. 1. 0. 0. 0. 0.]
    """
    @triton.jit
    def demo4(x_ptr):
        pid = tl.program_id(0)
        range = tl.arange(0, 8) + pid * 8       # 1 block for 8 elements
        x = tl.load(x_ptr + range, range < 20)  # flatten
        print("Print for each", pid, x)
    
    def run_demo4():
        print("Demo4 Output: ")
        x = torch.ones(2, 4, 4)
        demo4[(3, 1, 1)](x)     # 3 blocks
        print_end_line()
    

Puzzles

  • Puzzle 6: Fused Outer Multiplication - Backwards

    \[ \begin{aligned} f(x, y) &= \text{relu}(x_{j, i} \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1\\ dx_{j, i} &= f_x'(x, y)_{j, i} \times dz_{j, i} \end{aligned} \]
    @triton.jit
    def mul_relu_block_back_kernel(
        x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
    ):
        block_id_i = tl.program_id(0)
        block_id_j = tl.program_id(1)
        x_off = block_id_i * B0 + tl.arange(0, B0)
        y_off = block_id_j * B1 + tl.arange(0, B1)
        xji_off = x_off[None, :] + y_off[:, None] * N0
        mask_x = x_off < N0
        mask_y = y_off < N1
        mask_xji = mask_x[None, :] & mask_y[:, None]
        x = tl.load(x_ptr + xji_off, mask_xji)
        y = tl.load(y_ptr + y_off, mask_y)
        z = tl.load(dz_ptr + xji_off, mask_xji)
        # The [:, None] below must not be omitted.
        # Note: y has shape [B1] and x has shape [B1, B0]. With x * y, y would broadcast as a
        # row vector ([1, B1]) and be multiplied along the columns of x, which is not what we want.
        # Expanded to a column vector [B1, 1], y is multiplied against each row of x (y_j scales row j).
        y_ext = y[:, None]
        fxy = tl.where(x * y_ext > 0, y_ext, 0)
        ans = fxy * z
        tl.store(dx_ptr + xji_off, ans, mask_xji)
        return
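
    A quick PyTorch reference for the same backward pass, useful for checking the kernel's output (mul_relu_back_ref is an illustrative name, not part of the puzzle code):

    import torch

    def mul_relu_back_ref(x, y, dz):
        # x, dz: [N1, N0]; y: [N1]
        # d/dx relu(x * y_j) = y_j where x * y_j > 0, else 0
        return torch.where(x * y[:, None] > 0, y[:, None], torch.zeros_like(x)) * dz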
    
  • Puzzle 9: Simple FlashAttention

    \[ z_{i} = \sum_{j=1}^{T} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0 \]

    @triton.jit
    def flashatt_kernel(
        q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr, B1: tl.constexpr
    ):
        block_id_i = tl.program_id(0)
        log2_e = 1.44269504
        myexp = lambda x: tl.exp2(log2_e * x)
        i_off = block_id_i * B0 + tl.arange(0, B0)
        mask_i = i_off < N0
        q = tl.load(q_ptr + i_off, mask_i)
        max = tl.full([B0], -10, dtype=tl.float32)  # running max (used only for numerical stability)
        sum = tl.zeros([B0], dtype=tl.float32)      # running exp-sum (softmax denominator)
        z = tl.zeros([B0], dtype=tl.float32)        # running numerator of the output
        for j in tl.range(0, T, B1):
            j_off = j + tl.arange(0, B1)
            mask_j = j_off < T
            k = tl.load(k_ptr + j_off, mask_j)
            mask_ij = mask_i[:, None] & mask_j[None, :]
            tmp = q[:, None] * k[None, :] + tl.where(mask_ij, 0, -1.0e6)
            x_max = tl.max(tmp, axis=1)                             # local max of this block
            new_max = tl.maximum(max, x_max)                        # global (running) max
            # Rescale the running sums by exp(old_max - new_max) before adding this block's contribution.
            sum = sum * myexp(max - new_max) + tl.sum(myexp(tmp - new_max[:, None]), axis=1)  # running exp-sum
            v = tl.load(v_ptr + j_off, mask_j)
            # z accumulates the numerator of the output (not yet divided by the softmax denominator)
            z = z * myexp(max - new_max) + tl.sum(myexp(tmp - new_max[:, None]) * v[None, :], axis=1)
            max = new_max
    
        # The division must happen after the loop, once sum is final; doing it inside the loop
        # (where sum starts at 0) would give NaN / incorrect results.
        z /= sum
        tl.store(z_ptr + i_off, z, mask_i)
        return
    
    The idea here comes from FlashAttention:

    • Compute \(\sum e^{x-\max x}\) with an online update. Suppose the first i blocks have already been processed, giving the running max \(Q_i = \max x\) and the running sum \(P_i = \sum e^{x-Q_i}\); then for block i+1, \(P_{i+1} = P_i \times e^{Q_i - Q_{i+1}} + \sum e^{x-Q_{i+1}}\), where the last sum runs over the new block only.
    • Let \(O_j = \text{softmax}(x)_j\); we also need to compute \(\sum_j O_j \cdot v_j\). Setting aside the final division of the softmax for now, we accumulate only the numerator, written \(\hat O_j\). Suppose after the first i blocks we have \(S_i = \sum_j \hat O_j \cdot v_j\); then for block i+1, \(S_{i+1} = S_i \times e^{Q_i - Q_{i+1}} + \sum_j \hat O_j \cdot v_j\), where the last sum runs over the new block and \(\hat O\) is computed with respect to \(Q_{i+1}\). A small PyTorch check of this recurrence is sketched below.
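
    A sketch of that check outside Triton (the sizes T, B and the tensor names here are arbitrary; this is not part of the puzzle code):

    import torch

    torch.manual_seed(0)
    T, B = 16, 4
    s = torch.randn(T)   # scores q_i * k_j for one fixed query i
    v = torch.randn(T)

    # direct computation
    z_ref = (torch.softmax(s, dim=0) * v).sum()

    # online computation over blocks of size B
    m = torch.tensor(float("-inf"))   # running max Q
    p = torch.tensor(0.0)             # running exp-sum P
    acc = torch.tensor(0.0)           # running numerator S
    for j in range(0, T, B):
        blk = s[j:j + B]
        m_new = torch.maximum(m, blk.max())
        scale = torch.exp(m - m_new)  # rescale the old running sums
        p = p * scale + torch.exp(blk - m_new).sum()
        acc = acc * scale + (torch.exp(blk - m_new) * v[j:j + B]).sum()
        m = m_new

    assert torch.allclose(z_ref, acc / p)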
  • Puzzle 11: Matrix Multiplication

    The computation is split into many small tiles (i.e., tiling), and the partial results of the tiles are accumulated. A PyTorch check of this tiled accumulation is sketched after the kernel.

    \[ z_{i, j, k} = \sum_{l} x_{i,j, l} \times y_{i, l, k} \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1 \]
    @triton.jit
    def dot_kernel(
        x_ptr,
        y_ptr,
        z_ptr,
        N0,
        N1,
        N2,
        MID,
        B0: tl.constexpr,
        B1: tl.constexpr,
        B2: tl.constexpr,
        B_MID: tl.constexpr,
    ):
        block_id_j = tl.program_id(0)
        block_id_k = tl.program_id(1)
        block_id_i = tl.program_id(2)
        i_off = block_id_i * B2 + tl.arange(0, B2)
        j_off = block_id_j * B0 + tl.arange(0, B0)
        k_off = block_id_k * B1 + tl.arange(0, B1)
        mask_i = i_off < N2
        mask_j = j_off < N0
        mask_k = k_off < N1
    
        z = tl.zeros([B2, B0, B1], dtype=tl.float32)
    
        # MID is the size of the reduction (inner) dimension, passed in when the kernel is launched
        for l in tl.range(0, MID, B_MID):
            l_off = l + tl.arange(0, B_MID)
            mask_l = l_off < MID
            # x has shape [N2, N0, MID], so its strides are (N0 * MID, MID, 1)
            xjl_off = i_off[:, None, None] * N0 * MID + j_off[None, :, None] * MID + l_off[None, None, :]
            mask_xjl = mask_i[:, None, None] & mask_j[None, :, None] & mask_l[None, None, :]
            xjl = tl.load(x_ptr + xjl_off, mask_xjl)
            # y has shape [N2, MID, N1], so its strides are (MID * N1, N1, 1)
            ylk_off = i_off[:, None, None] * MID * N1 + l_off[None, :, None] * N1 + k_off[None, None, :]
            mask_ylk = mask_i[:, None, None] & mask_l[None, :, None] & mask_k[None, None, :]
            ylk = tl.load(y_ptr + ylk_off, mask_ylk)
    
            z += tl.dot(xjl, ylk)
    
        z_off = i_off[:, None, None] * N0 * N1 + j_off[None, :, None] * N1 + k_off[None, None, :]
        mask_z = mask_i[:, None, None] & mask_j[None, :, None] & mask_k[None, None, :]
        tl.store(z_ptr + z_off, z, mask_z)
    
        return
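
    The tiled accumulation can be checked against a plain batched matmul (a sketch with arbitrary sizes; shapes follow the puzzle's layout x: [N2, N0, MID], y: [N2, MID, N1]):

    import torch

    N2, N0, N1, MID, B_MID = 2, 8, 8, 32, 16
    x = torch.randn(N2, N0, MID)
    y = torch.randn(N2, MID, N1)

    # accumulate over tiles of the reduction dimension, as the kernel does
    z = torch.zeros(N2, N0, N1)
    for l in range(0, MID, B_MID):
        z += x[:, :, l:l + B_MID] @ y[:, l:l + B_MID, :]

    assert torch.allclose(z, x @ y, atol=1e-5)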
    
