瀏覽代碼

Merge pull request #193 from enochkan/main

Add docstrings to functions in inference modules for better clarity
Huang Panpan 1 年之前
父節點
當前提交
fdbd5be754
共有 6 個文件被更改,包括 563 次插入1 次删除
  1. 5 1
      .gitignore
  2. 12 0
      inference/convert.py
  3. 31 0
      inference/fp8_cast_bf16.py
  4. 48 0
      inference/generate.py
  5. 83 0
      inference/kernel.py
  6. 384 0
      inference/model.py

+ 5 - 1
.gitignore

@@ -165,4 +165,8 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+#.idea/
+
+.vscode/*
+
+.DS_Store

+ 12 - 0
inference/convert.py

@@ -31,6 +31,18 @@ mapping = {
 
 
 def main(hf_ckpt_path, save_path, n_experts, mp):
+    """
+    Converts and saves model checkpoint files into a specified format.
+
+    Args:
+        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
+        save_path (str): Path to the directory where the converted checkpoint files will be saved.
+        n_experts (int): Total number of experts in the model.
+        mp (int): Model parallelism factor.
+        
+    Returns:
+        None
+    """
     torch.set_num_threads(8)
     n_local_experts = n_experts // mp
     state_dicts = [{} for _ in range(mp)]

+ 31 - 0
inference/fp8_cast_bf16.py

@@ -10,6 +10,25 @@ from safetensors.torch import load_file, save_file
 from kernel import weight_dequant
 
 def main(fp8_path, bf16_path):
+    """
+    Converts FP8 weights to BF16 and saves the converted weights.
+
+    This function reads FP8 weights from the specified directory, converts them to BF16,
+    and saves the converted weights to another specified directory. It also updates the
+    model index file to reflect the changes.
+
+    Args:
+    fp8_path (str): The path to the directory containing the FP8 weights and model index file.
+    bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
+
+    Raises:
+    KeyError: If a required scale_inv tensor is missing for a weight.
+
+    Notes:
+    - The function assumes that the FP8 weights are stored in safetensor files.
+    - The function caches loaded safetensor files to optimize memory usage.
+    - The function updates the model index file to remove references to scale_inv tensors.
+    """
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
@@ -23,6 +42,18 @@ def main(fp8_path, bf16_path):
 
     # Helper function to get tensor from the correct file
     def get_tensor(tensor_name):
+        """
+        Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
+
+        Args:
+            tensor_name (str): The name of the tensor to retrieve.
+
+        Returns:
+            torch.Tensor: The retrieved tensor.
+
+        Raises:
+            KeyError: If the tensor does not exist in the safetensor file.
+        """
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)

+ 48 - 0
inference/generate.py

@@ -12,6 +12,16 @@ from model import Transformer, ModelArgs
 
 
 def sample(logits, temperature: float = 1.0):
+    """
+    Samples a token from the logits using temperature scaling.
+
+    Args:
+        logits (torch.Tensor): The logits tensor for token predictions.
+        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
+
+    Returns:
+        torch.Tensor: The sampled token.
+    """
     logits = logits / max(temperature, 1e-5)
     probs = torch.softmax(logits, dim=-1)
     return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
@@ -25,6 +35,19 @@ def generate(
     eos_id: int,
     temperature: float = 1.0
 ) -> List[List[int]]:
+    """
+    Generates new tokens based on the given prompt tokens using the specified model.
+
+    Args:
+        model (Transformer): The transformer model used for token generation.
+        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
+        max_new_tokens (int): The maximum number of new tokens to generate.
+        eos_id (int): The end-of-sequence token ID.
+        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
+
+    Returns:
+        List[List[int]]: A list of lists containing the generated tokens for each sequence.
+    """
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
@@ -63,6 +86,17 @@ def main(
     max_new_tokens: int = 100,
     temperature: float = 1.0,
 ) -> None:
+    """
+    Main function to load the model and perform interactive or batch text generation.
+
+    Args:
+        ckpt_path (str): Path to the model checkpoint directory.
+        config (str): Path to the model configuration file.
+        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
+        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
+        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
+        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
+    """
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -125,6 +159,20 @@ def main(
 
 
 if __name__ == "__main__":
+    """
+    Command-line interface for distributed text generation.
+
+    Arguments:
+        --ckpt-path (str): Path to the model checkpoint directory.
+        --config (str): Path to the model configuration file.
+        --input-file (str, optional): File containing prompts for batch processing.
+        --interactive (bool, optional): Enable interactive mode for generating text.
+        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
+        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.
+
+    Raises:
+        AssertionError: If neither input-file nor interactive mode is specified.
+    """
     parser = ArgumentParser()
     parser.add_argument("--ckpt-path", type=str, required=True)
     parser.add_argument("--config", type=str, required=True)

+ 83 - 0
inference/kernel.py

@@ -8,6 +8,18 @@ from triton import Config
 
 @triton.jit
 def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+    """
+    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
+
+    Args:
+        x_ptr (triton.Pointer): Pointer to the input tensor.
+        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
+        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
+        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
+
+    Returns:
+        None
+    """
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
@@ -19,6 +31,18 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
 
 
 def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes the input tensor `x` using block-wise quantization.
+
+    Args:
+        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
+        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - The quantized tensor with dtype `torch.float8_e4m3fn`.
+            - A tensor of scaling factors with dtype `torch.float32`.
+    """
     assert x.is_contiguous()
     assert x.size(-1) % block_size == 0
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
@@ -30,6 +54,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
 
 @triton.jit
 def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    """
+    Dequantizes weights using the provided scaling factors and stores the result.
+
+    Args:
+        x_ptr (tl.pointer): Pointer to the quantized weights.
+        s_ptr (tl.pointer): Pointer to the scaling factors.
+        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
+        M (int): Number of rows in the weight matrix.
+        N (int): Number of columns in the weight matrix.
+        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+    Returns:
+        None
+    """
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     n = tl.cdiv(N, BLOCK_SIZE)
@@ -44,6 +82,20 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
 
 
 def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    Dequantizes the given weight tensor using the provided scale tensor.
+
+    Args:
+        x (torch.Tensor): The quantized weight tensor of shape (M, N).
+        s (torch.Tensor): The scale tensor of shape (M, N).
+        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+
+    Returns:
+        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+    Raises:
+        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+    """
     assert x.is_contiguous() and s.is_contiguous()
     assert x.dim() == 2 and s.dim() == 2
     M, N = x.size()
@@ -66,6 +118,25 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                     BLOCK_SIZE_M: tl.constexpr,
                     BLOCK_SIZE_N: tl.constexpr,
                     BLOCK_SIZE_K: tl.constexpr):
+    """
+    Performs a matrix multiplication operation on FP8 matrices with scaling factors.
+
+    Args:
+        a_ptr (tl.tensor): Pointer to the first input matrix A.
+        b_ptr (tl.tensor): Pointer to the second input matrix B.
+        c_ptr (tl.tensor): Pointer to the output matrix C.
+        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
+        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
+        M (int): Number of rows in matrix A and C.
+        N (tl.constexpr): Number of columns in matrix B and C.
+        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
+        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
+        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
+        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
+
+    Returns:
+        None
+    """
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     k = tl.cdiv(K, BLOCK_SIZE_K)
@@ -97,6 +168,18 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
 
 
 def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+    """
+    Perform a matrix multiplication using FP8 precision.
+
+    Args:
+        a (torch.Tensor): The first input matrix, must be contiguous.
+        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
+        b (torch.Tensor): The second input matrix, must be contiguous.
+        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
+
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+    """
     assert a.is_contiguous() and b.is_contiguous()
     assert a_s.is_contiguous() and b_s.is_contiguous()
     K = a.size(-1)

+ 384 - 0
inference/model.py

@@ -18,6 +18,39 @@ attn_impl: Literal["naive", "absorb"] = "absorb"
 
 @dataclass
 class ModelArgs:
+    """
+    Data class for defining model arguments and hyperparameters.
+
+    Attributes:
+        max_batch_size (int): Maximum batch size.
+        max_seq_len (int): Maximum sequence length.
+        dtype (Literal["bf16", "fp8"]): Data type for computations.
+        vocab_size (int): Vocabulary size.
+        dim (int): Model dimension.
+        inter_dim (int): Intermediate dimension for MLP layers.
+        moe_inter_dim (int): Intermediate dimension for MoE layers.
+        n_layers (int): Number of transformer layers.
+        n_dense_layers (int): Number of dense layers in the model.
+        n_heads (int): Number of attention heads.
+        n_routed_experts (int): Number of routed experts for MoE layers.
+        n_shared_experts (int): Number of shared experts for MoE layers.
+        n_activated_experts (int): Number of activated experts in MoE layers.
+        n_expert_groups (int): Number of expert groups.
+        n_limited_groups (int): Number of limited groups for MoE routing.
+        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
+        route_scale (float): Scaling factor for routing scores.
+        q_lora_rank (int): LoRA rank for query projections.
+        kv_lora_rank (int): LoRA rank for key-value projections.
+        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
+        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
+        v_head_dim (int): Dimension for value projections.
+        original_seq_len (int): Original sequence length.
+        rope_theta (float): Base for rotary positional encoding.
+        rope_factor (float): Scaling factor for extended sequence lengths.
+        beta_fast (int): Fast beta correction factor.
+        beta_slow (int): Slow beta correction factor.
+        mscale (float): Scaling factor for extended attention.
+    """
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
@@ -52,6 +85,13 @@ class ModelArgs:
 
 
 class ParallelEmbedding(nn.Module):
+    """
+    Embedding layer with parallelism support across distributed processes.
+
+    Args:
+        vocab_size (int): Vocabulary size.
+        dim (int): Embedding dimension.
+    """
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
@@ -63,6 +103,18 @@ class ParallelEmbedding(nn.Module):
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for parallel embedding layer.
+
+        Args:
+            x (torch.Tensor): Input tensor containing token indices.
+
+        Returns:
+            torch.Tensor: Embedded representations.
+
+        Raises:
+            ValueError: If `world_size` is not defined.
+        """
         if world_size > 1:
             mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
             x = x - self.vocab_start_idx
@@ -75,6 +127,27 @@ class ParallelEmbedding(nn.Module):
 
 
 def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    Applies a linear transformation to the incoming data: y = xA^T + b.
+    This function supports specialized implementations based on quantization
+    and tensor formats.
+
+    Args:
+        x (torch.Tensor): The input tensor.
+        weight (torch.Tensor): The weight tensor. It may be quantized and 
+            requires dequantization for certain cases.
+        bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
+
+    Returns:
+        torch.Tensor: The result of the linear transformation, which may involve 
+        quantization-aware computations depending on the input parameters.
+
+    Notes:
+        - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version 
+          is used for computation.
+        - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
+        - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
+    """
     if weight.element_size() > 1:
         return F.linear(x, weight, bias)
     elif gemm_impl == "bf16":
@@ -89,6 +162,15 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
 
 
 class Linear(nn.Module):
+    """
+    Custom linear layer with support for quantized weights and optional bias.
+
+    Args:
+        in_features (int): Number of input features.
+        out_features (int): Number of output features.
+        bias (bool): Whether to include a bias term. Defaults to False.
+        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
+    """
     dtype = torch.bfloat16
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -108,27 +190,72 @@ class Linear(nn.Module):
             self.register_parameter("bias", None)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the custom linear layer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Transformed tensor after linear computation.
+        """
         return linear(x, self.weight, self.bias)
 
 
 class ColumnParallelLinear(Linear):
+    """
+    Linear layer with column parallelism, splitting output features across distributed processes.
+
+    Args:
+        in_features (int): Number of input features.
+        out_features (int): Total number of output features.
+        bias (bool): Whether to include a bias term. Defaults to False.
+        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
+    """
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
         assert out_features % world_size == 0
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for column parallel linear layer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Transformed tensor with column-parallel computation.
+        """
         y = linear(x, self.weight, self.bias)
         return y
 
 
 class RowParallelLinear(Linear):
+    """
+    Linear layer with row parallelism, splitting input features across distributed processes.
+
+    Args:
+        in_features (int): Total number of input features.
+        out_features (int): Number of output features.
+        bias (bool): Whether to include a bias term. Defaults to False.
+        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
+    """
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
         assert in_features % world_size == 0
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for row parallel linear layer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Transformed tensor with row-parallel computation.
+        """
         y = linear(x, self.weight)
         if world_size > 1:
             dist.all_reduce(y)
@@ -138,6 +265,13 @@ class RowParallelLinear(Linear):
 
 
 class RMSNorm(nn.Module):
+    """
+    Root Mean Square Layer Normalization (RMSNorm).
+
+    Args:
+        dim (int): Dimension of the input tensor.
+        eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
+    """
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
@@ -145,10 +279,28 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(dim))
 
     def forward(self, x: torch.Tensor):
+        """
+        Forward pass for RMSNorm.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Normalized tensor with the same shape as input.
+        """
         return F.rms_norm(x, (self.dim,), self.weight, self.eps)
 
 
 def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
+    """
+    Precomputes frequency-based complex exponential values for rotary positional embeddings.
+
+    Args:
+        args (ModelArgs): Model arguments containing positional embedding parameters.
+
+    Returns:
+        torch.Tensor: Precomputed complex exponential values for positional embeddings.
+    """
     dim = args.qk_rope_head_dim
     seqlen = args.max_seq_len
     beta_fast = args.beta_fast
@@ -157,14 +309,51 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
     factor = args.rope_factor
 
     def find_correction_dim(num_rotations, dim, base, max_seq_len):
+        """
+        Computes the correction dimension for a given number of rotations in the rotary positional embedding.
+
+        Args:
+            num_rotations (float): Number of rotations to compute the correction for.
+            dim (int): Dimensionality of the embedding space.
+            base (float): Base value for the exponential computation.
+            max_seq_len (int): Maximum sequence length.
+
+        Returns:
+            float: The correction dimension based on the input parameters.
+        """
         return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
 
     def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
+        """
+        Computes the range of correction dimensions for rotary positional embeddings.
+
+        Args:
+            low_rot (float): Lower bound for the number of rotations.
+            high_rot (float): Upper bound for the number of rotations.
+            dim (int): Dimensionality of the embedding space.
+            base (float): Base value for the exponential computation.
+            max_seq_len (int): Maximum sequence length.
+
+        Returns:
+            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
+        """
         low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
         return max(low, 0), min(high, dim-1)
 
     def linear_ramp_factor(min, max, dim):
+        """
+        Computes a linear ramp function used to smooth values between a minimum and maximum range.
+
+        Args:
+            min (float): Minimum value for the ramp function.
+            max (float): Maximum value for the ramp function.
+            dim (int): Dimensionality of the ramp tensor.
+
+        Returns:
+            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
+                clamped to the range [0, 1].
+        """
         if min == max:
             max += 0.001
         linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
@@ -184,6 +373,16 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
 
 
 def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+    """
+    Applies rotary positional embeddings to the input tensor.
+
+    Args:
+        x (torch.Tensor): Input tensor with positional embeddings to be applied.
+        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
+
+    Returns:
+        torch.Tensor: Tensor with rotary embeddings applied.
+    """
     dtype = x.dtype
     x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
     freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
@@ -192,6 +391,21 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
 
 
 class MLA(nn.Module):
+    """
+    Multi-Headed Attention Layer (MLA).
+
+    Attributes:
+        dim (int): Dimensionality of the input features.
+        n_heads (int): Number of attention heads.
+        n_local_heads (int): Number of local attention heads for distributed systems.
+        q_lora_rank (int): Rank for low-rank query projection.
+        kv_lora_rank (int): Rank for low-rank key/value projection.
+        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
+        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
+        qk_head_dim (int): Total dimensionality of query/key projections.
+        v_head_dim (int): Dimensionality of value projections.
+        softmax_scale (float): Scaling factor for softmax in attention computation.
+    """
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.dim = args.dim
@@ -227,6 +441,18 @@ class MLA(nn.Module):
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+        """
+        Forward pass for the Multi-Headed Attention Layer (MLA).
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
+            start_pos (int): Starting position in the sequence for caching.
+            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
+
+        Returns:
+            torch.Tensor: Output tensor with the same shape as the input.
+        """
         bsz, seqlen, _ = x.size()
         end_pos = start_pos + seqlen
         if self.q_lora_rank == 0:
@@ -269,18 +495,61 @@ class MLA(nn.Module):
 
 
 class MLP(nn.Module):
+    """
+    Multi-Layer Perceptron (MLP) used as a feed-forward layer.
+
+    Attributes:
+        w1 (nn.Module): Linear layer for input-to-hidden transformation.
+        w2 (nn.Module): Linear layer for hidden-to-output transformation.
+        w3 (nn.Module): Additional linear layer for feature transformation.
+    """
     def __init__(self, dim: int, inter_dim: int):
+        """
+        Initializes the MLP layer.
+
+        Args:
+            dim (int): Input and output dimensionality.
+            inter_dim (int): Hidden layer dimensionality.
+        """
         super().__init__()
         self.w1 = ColumnParallelLinear(dim, inter_dim)
         self.w2 = RowParallelLinear(inter_dim, dim)
         self.w3 = ColumnParallelLinear(dim, inter_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the MLP layer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after MLP computation.
+        """
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
 class Gate(nn.Module):
+    """
+    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
+
+    Attributes:
+        dim (int): Dimensionality of input features.
+        topk (int): Number of top experts activated for each input.
+        n_groups (int): Number of groups for routing.
+        topk_groups (int): Number of groups to route inputs to.
+        score_func (str): Scoring function ('softmax' or 'sigmoid').
+        route_scale (float): Scaling factor for routing weights.
+        weight (torch.nn.Parameter): Learnable weights for the gate.
+        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
+    """
     def __init__(self, args: ModelArgs):
+        """
+        Initializes the Gate module.
+
+        Args:
+            args (ModelArgs): Model arguments containing gating parameters.
+        """
         super().__init__()
         self.dim = args.dim
         self.topk = args.n_activated_experts
@@ -292,6 +561,15 @@ class Gate(nn.Module):
         self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for the gating mechanism.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
+        """
         scores = linear(x, self.weight)
         if self.score_func == "softmax":
             scores = scores.softmax(dim=-1, dtype=torch.float32)
@@ -318,18 +596,60 @@ class Gate(nn.Module):
 
 
 class Expert(nn.Module):
+    """
+    Expert layer for Mixture-of-Experts (MoE) models.
+
+    Attributes:
+        w1 (nn.Module): Linear layer for input-to-hidden transformation.
+        w2 (nn.Module): Linear layer for hidden-to-output transformation.
+        w3 (nn.Module): Additional linear layer for feature transformation.
+    """
     def __init__(self, dim: int, inter_dim: int):
+        """
+        Initializes the Expert layer.
+
+        Args:
+            dim (int): Input and output dimensionality.
+            inter_dim (int): Hidden layer dimensionality.
+        """
         super().__init__()
         self.w1 = Linear(dim, inter_dim)
         self.w2 = Linear(inter_dim, dim)
         self.w3 = Linear(dim, inter_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the Expert layer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after expert computation.
+        """
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
 class MoE(nn.Module):
+    """
+    Mixture-of-Experts (MoE) module.
+
+    Attributes:
+        dim (int): Dimensionality of input features.
+        n_routed_experts (int): Total number of experts in the model.
+        n_local_experts (int): Number of experts handled locally in distributed systems.
+        n_activated_experts (int): Number of experts activated for each input.
+        gate (nn.Module): Gating mechanism to route inputs to experts.
+        experts (nn.ModuleList): List of expert modules.
+        shared_experts (nn.Module): Shared experts applied to all inputs.
+    """
     def __init__(self, args: ModelArgs):
+        """
+        Initializes the MoE module.
+
+        Args:
+            args (ModelArgs): Model arguments containing MoE parameters.
+        """
         super().__init__()
         self.dim = args.dim
         assert args.n_routed_experts % world_size == 0
@@ -344,6 +664,15 @@ class MoE(nn.Module):
         self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the MoE module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after expert routing and computation.
+        """
         shape = x.size()
         x = x.view(-1, self.dim)
         weights, indices = self.gate(x)
@@ -362,7 +691,23 @@ class MoE(nn.Module):
 
 
 class Block(nn.Module):
+    """
+    Transformer block combining attention and feed-forward layers.
+
+    Attributes:
+        attn (nn.Module): Attention layer (MLA).
+        ffn (nn.Module): Feed-forward network (MLP or MoE).
+        attn_norm (nn.Module): Layer normalization for attention.
+        ffn_norm (nn.Module): Layer normalization for feed-forward network.
+    """
     def __init__(self, layer_id: int, args: ModelArgs):
+        """
+        Initializes the Transformer block.
+
+        Args:
+            layer_id (int): Layer index in the transformer.
+            args (ModelArgs): Model arguments containing block parameters.
+        """
         super().__init__()
         self.attn = MLA(args)
         self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
@@ -370,13 +715,42 @@ class Block(nn.Module):
         self.ffn_norm = RMSNorm(args.dim)
 
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
+        """
+        Forward pass for the Transformer block.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position in the sequence.
+            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
+
+        Returns:
+            torch.Tensor: Output tensor after block computation.
+        """
         x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
         x = x + self.ffn(self.ffn_norm(x))
         return x
 
 
 class Transformer(nn.Module):
+    """
+    Transformer model with positional embeddings, multiple layers, and output projection.
+
+    Attributes:
+        max_seq_len (int): Maximum sequence length for the transformer.
+        embed (nn.Module): Embedding layer for input tokens.
+        layers (torch.nn.ModuleList): List of transformer blocks.
+        norm (nn.Module): Layer normalization applied after all blocks.
+        head (nn.Module): Output projection layer mapping to vocabulary size.
+        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
+    """
     def __init__(self, args: ModelArgs):
+        """
+        Initializes the Transformer model.
+
+        Args:
+            args (ModelArgs): Model arguments containing transformer parameters.
+        """
         global world_size, rank
         world_size = dist.get_world_size() if dist.is_initialized() else 1
         rank = dist.get_rank() if dist.is_initialized() else 0
@@ -393,6 +767,16 @@ class Transformer(nn.Module):
 
     @torch.inference_mode()
     def forward(self, tokens: torch.Tensor, start_pos: int = 0):
+        """
+        Forward pass for the Transformer model.
+
+        Args:
+            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
+            start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
+
+        Returns:
+            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
+        """
         seqlen = tokens.size(1)
         h = self.embed(tokens)
         freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]