Selaa lähdekoodia

Enhance documentation and update .gitignore for model conversion scripts

enoch kan 1 vuosi sitten
vanhempi
commit
a1296f099e
5 muutettua tiedostoa jossa 179 lisäystä ja 1 poistoa
  1. 5 1
      .gitignore
  2. 12 0
      inference/convert.py
  3. 31 0
      inference/fp8_cast_bf16.py
  4. 48 0
      inference/generate.py
  5. 83 0
      inference/kernel.py

+ 5 - 1
.gitignore

@@ -165,4 +165,8 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+#.idea/
+
+.vscode/*
+
+.DS_Store

+ 12 - 0
inference/convert.py

@@ -31,6 +31,18 @@ mapping = {
 
 
 def main(hf_ckpt_path, save_path, n_experts, mp):
+    """
+    Converts and saves model checkpoint files into a specified format.
+
+    Args:
+        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
+        save_path (str): Path to the directory where the converted checkpoint files will be saved.
+        n_experts (int): Total number of experts in the model.
+        mp (int): Model parallelism factor.
+        
+    Returns:
+        None
+    """
     torch.set_num_threads(8)
     n_local_experts = n_experts // mp
     state_dicts = [{} for _ in range(mp)]

+ 31 - 0
inference/fp8_cast_bf16.py

@@ -10,6 +10,25 @@ from safetensors.torch import load_file, save_file
 from kernel import weight_dequant
 
 def main(fp8_path, bf16_path):
+    """
+    Converts FP8 weights to BF16 and saves the converted weights.
+
+    This function reads FP8 weights from the specified directory, converts them to BF16,
+    and saves the converted weights to another specified directory. It also updates the
+    model index file to reflect the changes.
+
+    Args:
+    fp8_path (str): The path to the directory containing the FP8 weights and model index file.
+    bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
+
+    Raises:
+    KeyError: If a required scale_inv tensor is missing for a weight.
+
+    Notes:
+    - The function assumes that the FP8 weights are stored in safetensor files.
+    - The function caches loaded safetensor files to optimize memory usage.
+    - The function updates the model index file to remove references to scale_inv tensors.
+    """
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
@@ -23,6 +42,18 @@ def main(fp8_path, bf16_path):
 
     # Helper function to get tensor from the correct file
     def get_tensor(tensor_name):
+        """
+        Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
+
+        Args:
+            tensor_name (str): The name of the tensor to retrieve.
+
+        Returns:
+            torch.Tensor: The retrieved tensor.
+
+        Raises:
+            KeyError: If the tensor does not exist in the safetensor file.
+        """
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)

+ 48 - 0
inference/generate.py

@@ -12,6 +12,16 @@ from model import Transformer, ModelArgs
 
 
 def sample(logits, temperature: float = 1.0):
+    """
+    Samples a token from the logits using temperature scaling.
+
+    Args:
+        logits (torch.Tensor): The logits tensor for token predictions.
+        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
+
+    Returns:
+        torch.Tensor: The sampled token.
+    """
     logits = logits / max(temperature, 1e-5)
     probs = torch.softmax(logits, dim=-1)
     return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
@@ -25,6 +35,19 @@ def generate(
     eos_id: int,
     temperature: float = 1.0
 ) -> List[List[int]]:
+    """
+    Generates new tokens based on the given prompt tokens using the specified model.
+
+    Args:
+        model (Transformer): The transformer model used for token generation.
+        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
+        max_new_tokens (int): The maximum number of new tokens to generate.
+        eos_id (int): The end-of-sequence token ID.
+        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
+
+    Returns:
+        List[List[int]]: A list of lists containing the generated tokens for each sequence.
+    """
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
@@ -63,6 +86,17 @@ def main(
     max_new_tokens: int = 100,
     temperature: float = 1.0,
 ) -> None:
+    """
+    Main function to load the model and perform interactive or batch text generation.
+
+    Args:
+        ckpt_path (str): Path to the model checkpoint directory.
+        config (str): Path to the model configuration file.
+        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
+        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
+        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
+        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
+    """
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -125,6 +159,20 @@ def main(
 
 
 if __name__ == "__main__":
+    """
+    Command-line interface for distributed text generation.
+
+    Arguments:
+        --ckpt-path (str): Path to the model checkpoint directory.
+        --config (str): Path to the model configuration file.
+        --input-file (str, optional): File containing prompts for batch processing.
+        --interactive (bool, optional): Enable interactive mode for generating text.
+        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
+        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.
+
+    Raises:
+        AssertionError: If neither input-file nor interactive mode is specified.
+    """
     parser = ArgumentParser()
     parser.add_argument("--ckpt-path", type=str, required=True)
     parser.add_argument("--config", type=str, required=True)

+ 83 - 0
inference/kernel.py

@@ -8,6 +8,18 @@ from triton import Config
 
 @triton.jit
 def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+    """
+    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
+
+    Args:
+        x_ptr (triton.Pointer): Pointer to the input tensor.
+        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
+        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
+        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
+
+    Returns:
+        None
+    """
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
@@ -19,6 +31,18 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
 
 
 def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes the input tensor `x` using block-wise quantization.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
+        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - The quantized tensor with dtype `torch.float8_e4m3fn`.
+            - A tensor of scaling factors with dtype `torch.float32`.
+    """
     assert x.is_contiguous()
     assert x.size(-1) % block_size == 0
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
@@ -30,6 +54,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
 
 @triton.jit
 def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    """
+    Dequantizes weights using the provided scaling factors and stores the result.
+
+    Args:
+        x_ptr (tl.pointer): Pointer to the quantized weights.
+        s_ptr (tl.pointer): Pointer to the scaling factors.
+        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
+        M (int): Number of rows in the weight matrix.
+        N (int): Number of columns in the weight matrix.
+        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+    Returns:
+        None
+    """
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     n = tl.cdiv(N, BLOCK_SIZE)
@@ -44,6 +82,20 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
 
 
 def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    Dequantizes the given weight tensor using the provided scale tensor.
+
+    Args:
+        x (torch.Tensor): The quantized weight tensor of shape (M, N).
+        s (torch.Tensor): The scale tensor of shape (M, N).
+        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+
+    Returns:
+        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+    Raises:
+        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+    """
     assert x.is_contiguous() and s.is_contiguous()
     assert x.dim() == 2 and s.dim() == 2
     M, N = x.size()
@@ -66,6 +118,25 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                     BLOCK_SIZE_M: tl.constexpr,
                     BLOCK_SIZE_N: tl.constexpr,
                     BLOCK_SIZE_K: tl.constexpr):
+    """
+    Performs a matrix multiplication operation on FP8 matrices with scaling factors.
+
+    Args:
+        a_ptr (tl.tensor): Pointer to the first input matrix A.
+        b_ptr (tl.tensor): Pointer to the second input matrix B.
+        c_ptr (tl.tensor): Pointer to the output matrix C.
+        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
+        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
+        M (int): Number of rows in matrix A and C.
+        N (tl.constexpr): Number of columns in matrix B and C.
+        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
+        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
+        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
+        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
+
+    Returns:
+        None
+    """
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     k = tl.cdiv(K, BLOCK_SIZE_K)
@@ -97,6 +168,18 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
 
 
 def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+    """
+    Perform a matrix multiplication using FP8 precision.
+
+    Args:
+        a (torch.Tensor): The first input matrix, must be contiguous.
+        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
+        b (torch.Tensor): The second input matrix, must be contiguous.
+        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
+
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+    """
     assert a.is_contiguous() and b.is_contiguous()
     assert a_s.is_contiguous() and b_s.is_contiguous()
     K = a.size(-1)