| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- from typing import Tuple
- import torch
- import triton
- import triton.language as tl
- from triton import Config
- @triton.jit
- def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
- """
- Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
- Args:
- x_ptr (triton.Pointer): Pointer to the input tensor.
- y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
- s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
- BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
- Returns:
- None
- """
- pid = tl.program_id(axis=0)
- offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- x = tl.load(x_ptr + offs).to(tl.float32)
- s = tl.max(tl.abs(x)) / 448.
- y = x / s
- y = y.to(y_ptr.dtype.element_ty)
- tl.store(y_ptr + offs, y)
- tl.store(s_ptr + pid, s)
- def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Quantizes the input tensor `x` using block-wise quantization.
- Args:
- x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
- block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
- Returns:
- Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- - The quantized tensor with dtype `torch.float8_e4m3fn`.
- - A tensor of scaling factors with dtype `torch.float32`.
- """
- assert x.is_contiguous()
- assert x.size(-1) % block_size == 0
- y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
- s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
- grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
- act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
- return y, s
- @triton.jit
- def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
- """
- Dequantizes weights using the provided scaling factors and stores the result.
- Args:
- x_ptr (tl.pointer): Pointer to the quantized weights.
- s_ptr (tl.pointer): Pointer to the scaling factors.
- y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
- M (int): Number of rows in the weight matrix.
- N (int): Number of columns in the weight matrix.
- BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
- Returns:
- None
- """
- pid_m = tl.program_id(axis=0)
- pid_n = tl.program_id(axis=1)
- n = tl.cdiv(N, BLOCK_SIZE)
- offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
- offs = offs_m[:, None] * N + offs_n[None, :]
- mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
- x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
- s = tl.load(s_ptr + pid_m * n + pid_n)
- y = x * s
- tl.store(y_ptr + offs, y, mask=mask)
- def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
- """
- Dequantizes the given weight tensor using the provided scale tensor.
- Args:
- x (torch.Tensor): The quantized weight tensor of shape (M, N).
- s (torch.Tensor): The scale tensor of shape (M, N).
- block_size (int, optional): The block size to use for dequantization. Defaults to 128.
- Returns:
- torch.Tensor: The dequantized weight tensor of the same shape as `x`.
- Raises:
- AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
- """
- assert x.is_contiguous() and s.is_contiguous()
- assert x.dim() == 2 and s.dim() == 2
- M, N = x.size()
- y = torch.empty_like(x, dtype=torch.get_default_dtype())
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
- weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
- return y
- fp8_gemm_configs = [
- Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
- for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
- ]
- @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
- @triton.jit
- def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
- a_s_ptr, b_s_ptr,
- M, N: tl.constexpr, K: tl.constexpr,
- BLOCK_SIZE_M: tl.constexpr,
- BLOCK_SIZE_N: tl.constexpr,
- BLOCK_SIZE_K: tl.constexpr):
- """
- Performs a matrix multiplication operation on FP8 matrices with scaling factors.
- Args:
- a_ptr (tl.tensor): Pointer to the first input matrix A.
- b_ptr (tl.tensor): Pointer to the second input matrix B.
- c_ptr (tl.tensor): Pointer to the output matrix C.
- a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
- b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
- M (int): Number of rows in matrix A and C.
- N (tl.constexpr): Number of columns in matrix B and C.
- K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
- BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
- BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
- BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
- Returns:
- None
- """
- pid_m = tl.program_id(axis=0)
- pid_n = tl.program_id(axis=1)
- k = tl.cdiv(K, BLOCK_SIZE_K)
- offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
- offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
- offs_k = tl.arange(0, BLOCK_SIZE_K)
- a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
- b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
- a_s_ptrs = a_s_ptr + offs_m * k
- b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- for i in range(k):
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
- a_s = tl.load(a_s_ptrs)
- b_s = tl.load(b_s_ptrs)
- accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
- a_ptrs += BLOCK_SIZE_K
- b_ptrs += BLOCK_SIZE_K
- a_s_ptrs += 1
- b_s_ptrs += 1
- c = accumulator.to(c_ptr.dtype.element_ty)
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
- mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
- tl.store(c_ptrs, c, mask=mask)
- def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
- """
- Perform a matrix multiplication using FP8 precision.
- Args:
- a (torch.Tensor): The first input matrix, must be contiguous.
- a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
- b (torch.Tensor): The second input matrix, must be contiguous.
- b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
- Returns:
- torch.Tensor: The result of the matrix multiplication.
- """
- assert a.is_contiguous() and b.is_contiguous()
- assert a_s.is_contiguous() and b_s.is_contiguous()
- K = a.size(-1)
- M = a.numel() // K
- N = b.size(0)
- c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
- grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
- fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
- return c
|