OpenAI/Triton MLIR, Chapter 4: Configuring ROCm-triton
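While tidying up some Python-based benchmark code recently, I reinstalled Triton on an NVIDIA GPU and found that the Triton GitHub repo now lists the matching LLVM commit id together with the build details. I followed them and the install went through without trouble. On an NVIDIA GPU, the steps below are all that is needed: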
git clone https://github.com/openai/triton.git
cd triton
cd $HOME/llvm-project  # your clone of LLVM
git checkout 49af6502
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8
export LLVM_BUILD_DIR=$HOME/llvm-project/build
cd /path/to/triton  # return to the triton clone from the first step
LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib LLVM_SYSPATH=$LLVM_BUILD_DIR pip install -e python
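If the reported version is 3.0.0, Triton has been installed successfully. After installing Triton, be sure to install PyTorch as well. I use CUDA 12.1, so the command below can be run as-is.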
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
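One way to confirm what actually got installed is to query the package metadata from Python. This is only a small sketch: the helper name `installed_version` is made up for illustration, and it assumes pip registered the packages under the distribution names `triton` and `torch`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumed distribution names; adjust if your environment registers them differently.
    for name in ("triton", "torch"):
        print(name, installed_version(name))
```

A `None` for either package suggests the install step above did not land in the Python environment you are currently using.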
Installing and using Triton on NVIDIA GPUs is familiar territory by now. Next, let's look at how to install and configure Triton on AMD GPUs.
0x00 Software installation
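As for the AMD backend: although upstream Triton supports it as a third-party backend, I still recommend the Triton fork that AMD maintains, because on the upstream main branch at the time, a build with TRITON_CODEGEN_AMD_HIP_BACKEND=1 did not complete correctly. So I switched to AMD's fork.
Follow its installation procedure. I recommend the commands below, which I have tested myself: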
git clone https://github.com/ROCmSoftwarePlatform/triton.git
cd triton
git checkout triton-mlir
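That gives us the Triton source to build. Triton's backend is built on LLVM, though, so to generate code that runs on the target device we also need to build LLVM. This tutorial builds LLVM by hand; using a prebuilt LLVM is fine too. This Triton branch was developed against LLVM commit b1115f8c, so clone LLVM, check out that commit, and build it with the full set of commands below.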
git clone https://github.com/llvm/llvm-project
cd llvm-project
git checkout b1115f8c
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8
Once LLVM has finished building, add its bin directory to PATH in your bashrc:
export PATH=/home/llvm-project/build/bin:$PATH
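After reloading the shell, it can be worth checking that the freshly built tools are really the ones found on PATH. The sketch below uses only the standard library. The helper name `locate_tools` is hypothetical, and it assumes `mlir-opt` and `llvm-config` are among the binaries your build placed in build/bin.

```python
import shutil


def locate_tools(names):
    """Map each tool name to its absolute path on PATH, or None if not found."""
    return {name: shutil.which(name) for name in names}


if __name__ == "__main__":
    # Tool names are assumptions about what the LLVM/MLIR build produced.
    for tool, path in locate_tools(["mlir-opt", "llvm-config"]).items():
        print(f"{tool}: {path or 'not found on PATH'}")
```

If a path points somewhere other than your llvm-project/build/bin directory, an older system LLVM is probably shadowing the one you just built.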
Then go into the triton directory cloned earlier and run:
cd triton
vim CMakeLists.txt  # set: option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON)
mkdir build
cd build
cmake ..
make -j8
When the build completes without errors, a libtriton.so file appears in the current build directory. Move libtriton.so into triton/python/triton/_C, then add Triton's python path to your bashrc:
export TRITON_HOME=/home/Documents/compiler/triton
export PYTHONPATH=$TRITON_HOME/python:${PYTHONPATH}
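The two conditions described above (libtriton.so sitting under python/triton/_C, and the python directory being on PYTHONPATH) can be checked with a few lines of Python. This is a minimal sketch; `libtriton_location` and `check_setup` are illustrative names, not part of Triton.

```python
import os
from pathlib import Path


def libtriton_location(triton_home):
    """Path where the manually built shared library is expected to live."""
    return Path(triton_home) / "python" / "triton" / "_C" / "libtriton.so"


def check_setup(triton_home, pythonpath):
    """Return a list of problems found; an empty list means the layout looks right."""
    problems = []
    if not libtriton_location(triton_home).is_file():
        problems.append("libtriton.so is missing from python/triton/_C")
    python_dir = str(Path(triton_home) / "python")
    if python_dir not in pythonpath.split(os.pathsep):
        problems.append("PYTHONPATH does not contain " + python_dir)
    return problems


if __name__ == "__main__":
    home = os.environ.get("TRITON_HOME", "")
    issues = check_setup(home, os.environ.get("PYTHONPATH", ""))
    print(issues or "setup looks consistent")
```

The check compares path strings literally, so a PYTHONPATH entry with a trailing slash would be reported as missing; normalize the paths first if that matters in your setup.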
If Google Test cannot be found during the build, install it with the following commands:
git clone https://github.com/google/googletest
cd googletest
cmake CMakeLists.txt
make -j8
cp ./lib/libgtest*.a /usr/lib
cd googletest
cp -a include/gtest /usr/include
If pybind11 cannot be found during the build, install it with the following commands:
pip install pytest
git clone https://github.com/pybind/pybind11.git
cd pybind11
mkdir build
cd build
cmake ..
make check -j 8
sudo make install
For PyTorch on AMD GPUs, be sure to install a build that matches your ROCm version. My machine runs ROCm 5.6, so my install command is as follows (for reference only):
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/rocm5.6
You can query the ROCm version with:
dpkg -l | grep rocm
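Once you know the ROCm version, the wheel index in the pip command follows from its major and minor numbers. The sketch below simply mirrors the URL pattern used in the install command above. The function name is hypothetical, and it does not check that PyTorch actually publishes wheels for a given ROCm release.

```python
def rocm_wheel_index(rocm_version):
    """Build a PyTorch wheel index URL from a ROCm version string such as '5.6.0'."""
    major, minor = rocm_version.split(".")[:2]
    return f"https://download.pytorch.org/whl/rocm{major}.{minor}"


if __name__ == "__main__":
    print(rocm_wheel_index("5.6.0"))
```

Only the first two version components are used, so a longer string such as a full package version maps to the same index.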
Keep in mind that using PyTorch on AMD GPUs is very similar to using it on NVIDIA GPUs: you still call .cuda() to place a tensor on the device.
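A short sketch of what that looks like in practice: the same `.cuda()` call is used on both vendors, and `torch.version.hip` tells you which build you are running. `describe_backend` is an illustrative helper, and the torch import is guarded so the snippet still runs on a machine without PyTorch.

```python
def describe_backend(hip_version, cuda_version):
    """Summarize which GPU stack a PyTorch build targets."""
    if hip_version is not None:
        return f"ROCm/HIP {hip_version}"
    if cuda_version is not None:
        return f"CUDA {cuda_version}"
    return "CPU only"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        print(describe_backend(torch.version.hip, torch.version.cuda))
        if torch.cuda.is_available():
            # Same call on AMD and NVIDIA; on ROCm builds the device is still named 'cuda'.
            x = torch.ones(4).cuda()
            print(x.device)
```

This is also why the GEMM script can switch its autotune search space on `torch.version.hip is None` while leaving every `device='cuda'` argument unchanged.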
0x01 GEMM code example
With everything built, running the code below benchmarks GEMM on an AMD GPU, comparing Triton against rocBLAS.
import torch

import triton
import triton.language as tl
import sys
import argparse
import pytest


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
})
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetics` section for details
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
    x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.


def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        ACTIVATION=activation,
    )
    return c


# %%
# Unit Test
# ---------
#
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).
@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype",
[ (*shape, in_dtype, out_dtype)
    for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64),
                  (128, 128, 64), (64, 128, 128), (32, 128, 64),
                  (64, 64, 32), (32, 32, 128), (128, 128, 64),
                  (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
    for in_dtype, out_dtype in [('int8', 'int8'),
                                ('float16', 'float16'),
                                ('bfloat16', 'bfloat16'),
                                ('float16', 'float32'),
                                ('float32', 'float32')]]
)
def test_correctness(M, N, K, in_dtype, out_dtype):
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print(" Triton and Torch match")
    else:
        print(" Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)


# %%
# Benchmark
# ---------
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices,
# but feel free to arrange this script as you wish to benchmark any other matrix shape.

global verbose
verbose = False


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
        global verbose
        if verbose:
            print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )
    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()
    return args


def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)


if __name__ == '__main__':
    sys.exit(main())
0x10 GEMM code walkthrough
First comes the definition of the search space:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
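The branch taken when torch.version.hip is not None is the search space for AMD GPUs. Besides the usual tunable knobs BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K and GROUP_SIZE_M, it has a new one, waves_per_eu. The concept was unfamiliar to me at first, so I asked AMD engineers about it. In summary: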
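An AMD GPU is made up of compute units (CUs), the counterpart of streaming multiprocessors (SMs) on NVIDIA GPUs. Each CU has 4 SIMD units (also called execution engines or EUs). You can think of a SIMD unit as a vector execution unit with a certain number of registers and ALUs for computation. When you launch a compute grid, workgroups (the counterpart of thread blocks on NVIDIA GPUs) are scheduled onto CUs.
Within a CU, wavefronts (the counterpart of warps on NVIDIA GPUs) are scheduled onto the SIMD units. This is where occupancy comes in: it is the number of wavefronts that can run concurrently on each SIMD unit, and it depends on the resources each wavefront needs and the resources each SIMD unit has. The waves_per_eu parameter focuses on register usage. For example, say each SIMD (EU) has 512 registers.
If each wavefront needs 256 registers, occupancy is 2. If we set waves_per_eu=3, the compiler tries to cut each wavefront's register usage to 170 (512 / 3, rounded down) so that occupancy can reach 3. Raising waves_per_eu, however, carries the risk of register spilling and a resulting performance drop. So increasing waves_per_eu may increase occupancy, but it does not necessarily improve performance.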
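The register arithmetic in this example can be written out directly. The sketch below only models the simplified picture given above, in which registers are the sole limit on occupancy. The figure of 512 registers per SIMD is the number used in the example, not a hardware specification, and both function names are illustrative.

```python
REGISTERS_PER_SIMD = 512  # figure from the example above, not a hardware spec


def occupancy(registers_per_wave, registers_per_simd=REGISTERS_PER_SIMD):
    """Wavefronts that fit on one SIMD if registers are the only limiting resource."""
    return registers_per_simd // registers_per_wave


def register_budget(waves_per_eu, registers_per_simd=REGISTERS_PER_SIMD):
    """Largest per-wavefront register count that still allows `waves_per_eu` wavefronts."""
    return registers_per_simd // waves_per_eu


if __name__ == "__main__":
    print(occupancy(256))        # 2 wavefronts per SIMD
    print(register_budget(3))    # 170 registers per wavefront
    print(occupancy(171))        # 2: one register over budget loses a wavefront
```

The last line shows why the target is a hard threshold: a kernel that cannot be squeezed under the budget either stays at the lower occupancy or has to spill registers.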
Next is the kernel definition itself, which is essentially no different from how it is written for NVIDIA GPUs:
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetics` section for details
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
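The grouped program-id mapping at the top of the kernel is easier to follow when replayed on the host. The sketch below copies that index arithmetic into plain Python; the function name `pid_to_tile` is made up for illustration, and nothing here touches Triton or the GPU.

```python
def pid_to_tile(pid, num_pid_m, num_pid_n, group_size_m):
    """Replicate the kernel's mapping from a 1D program id to a (pid_m, pid_n) tile."""
    if group_size_m == 1:
        # Plain row-major ordering.
        return pid // num_pid_n, pid % num_pid_n
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may hold fewer rows than group_size_m.
    rows = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % rows)
    pid_n = (pid % num_pid_in_group) // rows
    return pid_m, pid_n


if __name__ == "__main__":
    # A 4 x 4 grid of output tiles, grouped two rows at a time.
    for pid in range(16):
        print(pid, pid_to_tile(pid, 4, 4, 2))
```

With a group size of 2, consecutive program ids walk down two rows before moving one column to the right, so programs launched close together read the same blocks of B and neighbouring blocks of A. That is the L2 reuse the kernel comment refers to.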
Next comes the unit test, which checks that Triton's output matches torch's output within the given tolerances:
def test_correctness(M, N, K, in_dtype, out_dtype):
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print(" Triton and Torch match")
    else:
        print(" Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)
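To see what the atol and rtol arguments mean in that test, here is the element-wise criterion that `torch.allclose` applies, written for a single pair of numbers. This is a plain-Python sketch for illustration; the helper `is_close` is not part of the benchmark script.

```python
def is_close(actual, expected, atol, rtol):
    """Element-wise criterion used by torch.allclose: |a - e| <= atol + rtol * |e|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


if __name__ == "__main__":
    # Near 100, adjacent fp16 values are 0.0625 apart, which already exceeds atol=1e-2.
    print(is_close(100.0625, 100.0, atol=1e-2, rtol=0))      # False
    print(is_close(100.0625, 100.0, atol=1e-2, rtol=1e-2))   # True
```

With rtol=0 the bound is a fixed 0.01 no matter how large the values are, so a one-step fp16 rounding difference on a large element fails the check. Adding rtol=1e-2, as the test does on the HIP path, lets the bound scale with the magnitude of the reference value.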
After that, you only need to specify the GEMM sizes you want; the inputs are given in M, N, K order by default. The rest is routine.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
        global verbose
        if verbose:
            print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )
    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()
    return args


def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)


if __name__ == '__main__':
    sys.exit(main())
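The `perf` lambda in the benchmark converts a measured time in milliseconds into TFLOPS. A GEMM of shape (M, K) x (K, N) performs 2 * M * N * K floating-point operations, one multiply and one add per inner-product term. The sketch below isolates that conversion. The function name `tflops` is illustrative, and the 1.0 ms timing is a made-up input, not a measurement.

```python
def tflops(M, N, K, ms):
    """Convert a GEMM runtime in milliseconds into TFLOPS (same formula as `perf`)."""
    flops = 2 * M * N * K          # one multiply and one add per (m, n, k) triple
    seconds = ms * 1e-3
    return flops * 1e-12 / seconds


if __name__ == "__main__":
    # Hypothetical timing: a 4096^3 GEMM finishing in 1.0 ms.
    print(round(tflops(4096, 4096, 4096, 1.0), 3))   # 137.439
```

Because throughput is inversely proportional to time, the benchmark returns `perf(max_ms)` as the low end of the reported range and `perf(min_ms)` as the high end.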
A more automated GEMM benchmark tuning script for AMD GPUs will be covered in a later chapter.
Reviewing editor: Liu Qing