|
@@ -8,6 +8,18 @@ from triton import Config
|
|
|
|
|
|
|
|
@triton.jit
|
|
@triton.jit
|
|
|
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
|
|
|
+ """
|
|
|
|
|
+ Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ x_ptr (triton.Pointer): Pointer to the input tensor.
|
|
|
|
|
+ y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
|
|
|
|
|
+ s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
|
|
|
|
|
+ BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ None
|
|
|
|
|
+ """
|
|
|
pid = tl.program_id(axis=0)
|
|
pid = tl.program_id(axis=0)
|
|
|
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
|
x = tl.load(x_ptr + offs).to(tl.float32)
|
|
x = tl.load(x_ptr + offs).to(tl.float32)
|
|
@@ -19,6 +31,18 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
|
|
|
|
|
|
|
|
|
|
|
|
|
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
+ """
|
|
|
|
|
+ Quantizes the input tensor `x` using block-wise quantization.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
|
|
|
|
|
+ block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
|
|
|
|
+ - The quantized tensor with dtype `torch.float8_e4m3fn`.
|
|
|
|
|
+ - A tensor of scaling factors with dtype `torch.float32`.
|
|
|
|
|
+ """
|
|
|
assert x.is_contiguous()
|
|
assert x.is_contiguous()
|
|
|
assert x.size(-1) % block_size == 0
|
|
assert x.size(-1) % block_size == 0
|
|
|
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
|
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
|
|
@@ -30,6 +54,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
|
|
|
|
|
|
|
|
@triton.jit
|
|
@triton.jit
|
|
|
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
|
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
|
|
|
|
+ """
|
|
|
|
|
+ Dequantizes weights using the provided scaling factors and stores the result.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ x_ptr (tl.pointer): Pointer to the quantized weights.
|
|
|
|
|
+ s_ptr (tl.pointer): Pointer to the scaling factors.
|
|
|
|
|
+ y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
|
|
|
|
|
+ M (int): Number of rows in the weight matrix.
|
|
|
|
|
+ N (int): Number of columns in the weight matrix.
|
|
|
|
|
+ BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ None
|
|
|
|
|
+ """
|
|
|
pid_m = tl.program_id(axis=0)
|
|
pid_m = tl.program_id(axis=0)
|
|
|
pid_n = tl.program_id(axis=1)
|
|
pid_n = tl.program_id(axis=1)
|
|
|
n = tl.cdiv(N, BLOCK_SIZE)
|
|
n = tl.cdiv(N, BLOCK_SIZE)
|
|
@@ -44,6 +82,20 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
|
|
|
|
|
|
|
|
|
|
|
|
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
|
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
|
|
|
|
|
+ """
|
|
|
|
|
+ Dequantizes the given weight tensor using the provided scale tensor.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ x (torch.Tensor): The quantized weight tensor of shape (M, N).
|
|
|
|
|
+ s (torch.Tensor): The scale tensor of shape (M, N).
|
|
|
|
|
+ block_size (int, optional): The block size to use for dequantization. Defaults to 128.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ torch.Tensor: The dequantized weight tensor of the same shape as `x`.
|
|
|
|
|
+
|
|
|
|
|
+ Raises:
|
|
|
|
|
+ AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
|
|
|
|
|
+ """
|
|
|
assert x.is_contiguous() and s.is_contiguous()
|
|
assert x.is_contiguous() and s.is_contiguous()
|
|
|
assert x.dim() == 2 and s.dim() == 2
|
|
assert x.dim() == 2 and s.dim() == 2
|
|
|
M, N = x.size()
|
|
M, N = x.size()
|
|
@@ -66,6 +118,25 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
|
|
|
BLOCK_SIZE_M: tl.constexpr,
|
|
BLOCK_SIZE_M: tl.constexpr,
|
|
|
BLOCK_SIZE_N: tl.constexpr,
|
|
BLOCK_SIZE_N: tl.constexpr,
|
|
|
BLOCK_SIZE_K: tl.constexpr):
|
|
BLOCK_SIZE_K: tl.constexpr):
|
|
|
|
|
+ """
|
|
|
|
|
+ Performs a matrix multiplication operation on FP8 matrices with scaling factors.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ a_ptr (tl.tensor): Pointer to the first input matrix A.
|
|
|
|
|
+ b_ptr (tl.tensor): Pointer to the second input matrix B.
|
|
|
|
|
+ c_ptr (tl.tensor): Pointer to the output matrix C.
|
|
|
|
|
+ a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
|
|
|
|
|
+ b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
|
|
|
|
|
+ M (int): Number of rows in matrix A and C.
|
|
|
|
|
+ N (tl.constexpr): Number of columns in matrix B and C.
|
|
|
|
|
+ K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
|
|
|
|
|
+ BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
|
|
|
|
|
+ BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
|
|
|
|
|
+ BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ None
|
|
|
|
|
+ """
|
|
|
pid_m = tl.program_id(axis=0)
|
|
pid_m = tl.program_id(axis=0)
|
|
|
pid_n = tl.program_id(axis=1)
|
|
pid_n = tl.program_id(axis=1)
|
|
|
k = tl.cdiv(K, BLOCK_SIZE_K)
|
|
k = tl.cdiv(K, BLOCK_SIZE_K)
|
|
@@ -97,6 +168,18 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
|
|
|
|
|
|
|
|
|
|
|
|
|
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
|
|
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
|
|
|
|
|
+ """
|
|
|
|
|
+ Perform a matrix multiplication using FP8 precision.
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ a (torch.Tensor): The first input matrix, must be contiguous.
|
|
|
|
|
+ a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
|
|
|
|
|
+ b (torch.Tensor): The second input matrix, must be contiguous.
|
|
|
|
|
+ b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ torch.Tensor: The result of the matrix multiplication.
|
|
|
|
|
+ """
|
|
|
assert a.is_contiguous() and b.is_contiguous()
|
|
assert a.is_contiguous() and b.is_contiguous()
|
|
|
assert a_s.is_contiguous() and b_s.is_contiguous()
|
|
assert a_s.is_contiguous() and b_s.is_contiguous()
|
|
|
K = a.size(-1)
|
|
K = a.size(-1)
|