# kernel.py: Triton kernels for block-wise FP8 activation quantization, weight dequantization, and FP8 GEMM.
  1. from typing import Tuple
  2. import torch
  3. import triton
  4. import triton.language as tl
  5. from triton import Config
  6. @triton.jit
  7. def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
  8. """
  9. Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
  10. Args:
  11. x_ptr (triton.Pointer): Pointer to the input tensor.
  12. y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
  13. s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
  14. BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
  15. Returns:
  16. None
  17. """
  18. pid = tl.program_id(axis=0)
  19. offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  20. x = tl.load(x_ptr + offs).to(tl.float32)
  21. s = tl.max(tl.abs(x)) / 448.
  22. y = x / s
  23. y = y.to(y_ptr.dtype.element_ty)
  24. tl.store(y_ptr + offs, y)
  25. tl.store(s_ptr + pid, s)
  26. def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
  27. """
  28. Quantizes the input tensor `x` using block-wise quantization.
  29. Args:
  30. x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
  31. block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
  32. Returns:
  33. Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
  34. - The quantized tensor with dtype `torch.float8_e4m3fn`.
  35. - A tensor of scaling factors with dtype `torch.float32`.
  36. """
  37. assert x.is_contiguous()
  38. assert x.size(-1) % block_size == 0
  39. y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
  40. s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
  41. grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
  42. act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
  43. return y, s
  44. @triton.jit
  45. def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
  46. """
  47. Dequantizes weights using the provided scaling factors and stores the result.
  48. Args:
  49. x_ptr (tl.pointer): Pointer to the quantized weights.
  50. s_ptr (tl.pointer): Pointer to the scaling factors.
  51. y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
  52. M (int): Number of rows in the weight matrix.
  53. N (int): Number of columns in the weight matrix.
  54. BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
  55. Returns:
  56. None
  57. """
  58. pid_m = tl.program_id(axis=0)
  59. pid_n = tl.program_id(axis=1)
  60. n = tl.cdiv(N, BLOCK_SIZE)
  61. offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  62. offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  63. offs = offs_m[:, None] * N + offs_n[None, :]
  64. mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
  65. x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
  66. s = tl.load(s_ptr + pid_m * n + pid_n)
  67. y = x * s
  68. tl.store(y_ptr + offs, y, mask=mask)
  69. def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
  70. """
  71. Dequantizes the given weight tensor using the provided scale tensor.
  72. Args:
  73. x (torch.Tensor): The quantized weight tensor of shape (M, N).
  74. s (torch.Tensor): The scale tensor of shape (M, N).
  75. block_size (int, optional): The block size to use for dequantization. Defaults to 128.
  76. Returns:
  77. torch.Tensor: The dequantized weight tensor of the same shape as `x`.
  78. Raises:
  79. AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
  80. """
  81. assert x.is_contiguous() and s.is_contiguous()
  82. assert x.dim() == 2 and s.dim() == 2
  83. M, N = x.size()
  84. y = torch.empty_like(x, dtype=torch.get_default_dtype())
  85. grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
  86. weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
  87. return y
  88. fp8_gemm_configs = [
  89. Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
  90. for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
  91. ]
  92. @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
  93. @triton.jit
  94. def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
  95. a_s_ptr, b_s_ptr,
  96. M, N: tl.constexpr, K: tl.constexpr,
  97. BLOCK_SIZE_M: tl.constexpr,
  98. BLOCK_SIZE_N: tl.constexpr,
  99. BLOCK_SIZE_K: tl.constexpr):
  100. """
  101. Performs a matrix multiplication operation on FP8 matrices with scaling factors.
  102. Args:
  103. a_ptr (tl.tensor): Pointer to the first input matrix A.
  104. b_ptr (tl.tensor): Pointer to the second input matrix B.
  105. c_ptr (tl.tensor): Pointer to the output matrix C.
  106. a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
  107. b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
  108. M (int): Number of rows in matrix A and C.
  109. N (tl.constexpr): Number of columns in matrix B and C.
  110. K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
  111. BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
  112. BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
  113. BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
  114. Returns:
  115. None
  116. """
  117. pid_m = tl.program_id(axis=0)
  118. pid_n = tl.program_id(axis=1)
  119. k = tl.cdiv(K, BLOCK_SIZE_K)
  120. offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
  121. offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
  122. offs_k = tl.arange(0, BLOCK_SIZE_K)
  123. a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
  124. b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
  125. a_s_ptrs = a_s_ptr + offs_m * k
  126. b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
  127. accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  128. for i in range(k):
  129. a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
  130. b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
  131. a_s = tl.load(a_s_ptrs)
  132. b_s = tl.load(b_s_ptrs)
  133. accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
  134. a_ptrs += BLOCK_SIZE_K
  135. b_ptrs += BLOCK_SIZE_K
  136. a_s_ptrs += 1
  137. b_s_ptrs += 1
  138. c = accumulator.to(c_ptr.dtype.element_ty)
  139. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  140. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  141. c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
  142. mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
  143. tl.store(c_ptrs, c, mask=mask)
  144. def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
  145. """
  146. Perform a matrix multiplication using FP8 precision.
  147. Args:
  148. a (torch.Tensor): The first input matrix, must be contiguous.
  149. a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
  150. b (torch.Tensor): The second input matrix, must be contiguous.
  151. b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
  152. Returns:
  153. torch.Tensor: The result of the matrix multiplication.
  154. """
  155. assert a.is_contiguous() and b.is_contiguous()
  156. assert a_s.is_contiguous() and b_s.is_contiguous()
  157. K = a.size(-1)
  158. M = a.numel() // K
  159. N = b.size(0)
  160. c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
  161. grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
  162. fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
  163. return c