|
|
@@ -18,6 +18,39 @@ attn_impl: Literal["naive", "absorb"] = "absorb"
|
|
|
|
|
|
@dataclass
|
|
|
class ModelArgs:
|
|
|
+ """
|
|
|
+ Data class for defining model arguments and hyperparameters.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ max_batch_size (int): Maximum batch size.
|
|
|
+ max_seq_len (int): Maximum sequence length.
|
|
|
+ dtype (Literal["bf16", "fp8"]): Data type for computations.
|
|
|
+ vocab_size (int): Vocabulary size.
|
|
|
+ dim (int): Model dimension.
|
|
|
+ inter_dim (int): Intermediate dimension for MLP layers.
|
|
|
+ moe_inter_dim (int): Intermediate dimension for MoE layers.
|
|
|
+ n_layers (int): Number of transformer layers.
|
|
|
+ n_dense_layers (int): Number of dense layers in the model.
|
|
|
+ n_heads (int): Number of attention heads.
|
|
|
+ n_routed_experts (int): Number of routed experts for MoE layers.
|
|
|
+ n_shared_experts (int): Number of shared experts for MoE layers.
|
|
|
+ n_activated_experts (int): Number of activated experts in MoE layers.
|
|
|
+ n_expert_groups (int): Number of expert groups.
|
|
|
+ n_limited_groups (int): Number of limited groups for MoE routing.
|
|
|
+ score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
|
|
|
+ route_scale (float): Scaling factor for routing scores.
|
|
|
+ q_lora_rank (int): LoRA rank for query projections.
|
|
|
+ kv_lora_rank (int): LoRA rank for key-value projections.
|
|
|
+ qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
|
|
|
+ qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
|
|
|
+ v_head_dim (int): Dimension for value projections.
|
|
|
+ original_seq_len (int): Original sequence length.
|
|
|
+ rope_theta (float): Base for rotary positional encoding.
|
|
|
+ rope_factor (float): Scaling factor for extended sequence lengths.
|
|
|
+ beta_fast (int): Fast beta correction factor.
|
|
|
+ beta_slow (int): Slow beta correction factor.
|
|
|
+ mscale (float): Scaling factor for extended attention.
|
|
|
+ """
|
|
|
max_batch_size: int = 8
|
|
|
max_seq_len: int = 4096 * 4
|
|
|
dtype: Literal["bf16", "fp8"] = "bf16"
|
|
|
@@ -52,6 +85,13 @@ class ModelArgs:
|
|
|
|
|
|
|
|
|
class ParallelEmbedding(nn.Module):
|
|
|
+ """
|
|
|
+ Embedding layer with parallelism support across distributed processes.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ vocab_size (int): Vocabulary size.
|
|
|
+ dim (int): Embedding dimension.
|
|
|
+ """
|
|
|
def __init__(self, vocab_size: int, dim: int):
|
|
|
super().__init__()
|
|
|
self.vocab_size = vocab_size
|
|
|
@@ -63,6 +103,18 @@ class ParallelEmbedding(nn.Module):
|
|
|
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for parallel embedding layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor containing token indices.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Embedded representations.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ ValueError: If `world_size` is not defined.
|
|
|
+ """
|
|
|
if world_size > 1:
|
|
|
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
|
|
x = x - self.vocab_start_idx
|
|
|
@@ -75,6 +127,27 @@ class ParallelEmbedding(nn.Module):
|
|
|
|
|
|
|
|
|
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Applies a linear transformation to the incoming data: y = xA^T + b.
|
|
|
+ This function supports specialized implementations based on quantization
|
|
|
+ and tensor formats.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): The input tensor.
|
|
|
+ weight (torch.Tensor): The weight tensor. It may be quantized and
|
|
|
+ requires dequantization for certain cases.
|
|
|
+ bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: The result of the linear transformation, which may involve
|
|
|
+ quantization-aware computations depending on the input parameters.
|
|
|
+
|
|
|
+ Notes:
|
|
|
+ - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
|
|
|
+ is used for computation.
|
|
|
+ - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
|
|
|
+ - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
|
|
|
+ """
|
|
|
if weight.element_size() > 1:
|
|
|
return F.linear(x, weight, bias)
|
|
|
elif gemm_impl == "bf16":
|
|
|
@@ -89,6 +162,15 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
|
|
|
|
|
|
|
|
|
class Linear(nn.Module):
|
|
|
+ """
|
|
|
+ Custom linear layer with support for quantized weights and optional bias.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ in_features (int): Number of input features.
|
|
|
+ out_features (int): Number of output features.
|
|
|
+ bias (bool): Whether to include a bias term. Defaults to False.
|
|
|
+ dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
|
|
+ """
|
|
|
dtype = torch.bfloat16
|
|
|
|
|
|
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
|
|
@@ -108,27 +190,72 @@ class Linear(nn.Module):
|
|
|
self.register_parameter("bias", None)
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for the custom linear layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Transformed tensor after linear computation.
|
|
|
+ """
|
|
|
return linear(x, self.weight, self.bias)
|
|
|
|
|
|
|
|
|
class ColumnParallelLinear(Linear):
|
|
|
+ """
|
|
|
+ Linear layer with column parallelism, splitting output features across distributed processes.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ in_features (int): Number of input features.
|
|
|
+ out_features (int): Total number of output features.
|
|
|
+ bias (bool): Whether to include a bias term. Defaults to False.
|
|
|
+ dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
|
|
+ """
|
|
|
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
|
|
assert out_features % world_size == 0
|
|
|
self.part_out_features = out_features // world_size
|
|
|
super().__init__(in_features, self.part_out_features, bias, dtype)
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for column parallel linear layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Transformed tensor with column-parallel computation.
|
|
|
+ """
|
|
|
y = linear(x, self.weight, self.bias)
|
|
|
return y
|
|
|
|
|
|
|
|
|
class RowParallelLinear(Linear):
|
|
|
+ """
|
|
|
+ Linear layer with row parallelism, splitting input features across distributed processes.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ in_features (int): Total number of input features.
|
|
|
+ out_features (int): Number of output features.
|
|
|
+ bias (bool): Whether to include a bias term. Defaults to False.
|
|
|
+ dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
|
|
|
+ """
|
|
|
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
|
|
|
assert in_features % world_size == 0
|
|
|
self.part_in_features = in_features // world_size
|
|
|
super().__init__(self.part_in_features, out_features, bias, dtype)
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for row parallel linear layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Transformed tensor with row-parallel computation.
|
|
|
+ """
|
|
|
y = linear(x, self.weight)
|
|
|
if world_size > 1:
|
|
|
dist.all_reduce(y)
|
|
|
@@ -138,6 +265,13 @@ class RowParallelLinear(Linear):
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
|
|
|
+ """
|
|
|
+ Root Mean Square Layer Normalization (RMSNorm).
|
|
|
+
|
|
|
+ Args:
|
|
|
+ dim (int): Dimension of the input tensor.
|
|
|
+ eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
|
|
|
+ """
|
|
|
def __init__(self, dim: int, eps: float = 1e-6):
|
|
|
super().__init__()
|
|
|
self.dim = dim
|
|
|
@@ -145,10 +279,28 @@ class RMSNorm(nn.Module):
|
|
|
self.weight = nn.Parameter(torch.ones(dim))
|
|
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
|
+ """
|
|
|
+ Forward pass for RMSNorm.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Normalized tensor with the same shape as input.
|
|
|
+ """
|
|
|
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
|
|
|
|
|
|
|
|
|
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Precomputes frequency-based complex exponential values for rotary positional embeddings.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ args (ModelArgs): Model arguments containing positional embedding parameters.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Precomputed complex exponential values for positional embeddings.
|
|
|
+ """
|
|
|
dim = args.qk_rope_head_dim
|
|
|
seqlen = args.max_seq_len
|
|
|
beta_fast = args.beta_fast
|
|
|
@@ -157,14 +309,51 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
|
|
factor = args.rope_factor
|
|
|
|
|
|
def find_correction_dim(num_rotations, dim, base, max_seq_len):
|
|
|
+ """
|
|
|
+ Computes the correction dimension for a given number of rotations in the rotary positional embedding.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ num_rotations (float): Number of rotations to compute the correction for.
|
|
|
+ dim (int): Dimensionality of the embedding space.
|
|
|
+ base (float): Base value for the exponential computation.
|
|
|
+ max_seq_len (int): Maximum sequence length.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ float: The correction dimension based on the input parameters.
|
|
|
+ """
|
|
|
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
|
|
|
|
|
|
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
|
|
|
+ """
|
|
|
+ Computes the range of correction dimensions for rotary positional embeddings.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ low_rot (float): Lower bound for the number of rotations.
|
|
|
+ high_rot (float): Upper bound for the number of rotations.
|
|
|
+ dim (int): Dimensionality of the embedding space.
|
|
|
+ base (float): Base value for the exponential computation.
|
|
|
+ max_seq_len (int): Maximum sequence length.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
|
|
|
+ """
|
|
|
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
|
|
|
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
|
|
|
return max(low, 0), min(high, dim-1)
|
|
|
|
|
|
def linear_ramp_factor(min, max, dim):
|
|
|
+ """
|
|
|
+ Computes a linear ramp function used to smooth values between a minimum and maximum range.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ min (float): Minimum value for the ramp function.
|
|
|
+ max (float): Maximum value for the ramp function.
|
|
|
+ dim (int): Dimensionality of the ramp tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
|
|
|
+ clamped to the range [0, 1].
|
|
|
+ """
|
|
|
if min == max:
|
|
|
max += 0.001
|
|
|
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
|
|
|
@@ -184,6 +373,16 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Applies rotary positional embeddings to the input tensor.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor with positional embeddings to be applied.
|
|
|
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Tensor with rotary embeddings applied.
|
|
|
+ """
|
|
|
dtype = x.dtype
|
|
|
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
|
|
|
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
|
|
|
@@ -192,6 +391,21 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
class MLA(nn.Module):
|
|
|
+ """
|
|
|
+ Multi-Headed Attention Layer (MLA).
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ dim (int): Dimensionality of the input features.
|
|
|
+ n_heads (int): Number of attention heads.
|
|
|
+ n_local_heads (int): Number of local attention heads for distributed systems.
|
|
|
+ q_lora_rank (int): Rank for low-rank query projection.
|
|
|
+ kv_lora_rank (int): Rank for low-rank key/value projection.
|
|
|
+ qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
|
|
|
+ qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
|
|
|
+ qk_head_dim (int): Total dimensionality of query/key projections.
|
|
|
+ v_head_dim (int): Dimensionality of value projections.
|
|
|
+ softmax_scale (float): Scaling factor for softmax in attention computation.
|
|
|
+ """
|
|
|
def __init__(self, args: ModelArgs):
|
|
|
super().__init__()
|
|
|
self.dim = args.dim
|
|
|
@@ -227,6 +441,18 @@ class MLA(nn.Module):
|
|
|
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
|
|
|
|
|
|
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
|
|
|
+ """
|
|
|
+ Forward pass for the Multi-Headed Attention Layer (MLA).
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
|
|
|
+ start_pos (int): Starting position in the sequence for caching.
|
|
|
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
|
|
+ mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Output tensor with the same shape as the input.
|
|
|
+ """
|
|
|
bsz, seqlen, _ = x.size()
|
|
|
end_pos = start_pos + seqlen
|
|
|
if self.q_lora_rank == 0:
|
|
|
@@ -269,18 +495,61 @@ class MLA(nn.Module):
|
|
|
|
|
|
|
|
|
class MLP(nn.Module):
|
|
|
+ """
|
|
|
+ Multi-Layer Perceptron (MLP) used as a feed-forward layer.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ w1 (nn.Module): Linear layer for input-to-hidden transformation.
|
|
|
+ w2 (nn.Module): Linear layer for hidden-to-output transformation.
|
|
|
+ w3 (nn.Module): Additional linear layer for feature transformation.
|
|
|
+ """
|
|
|
def __init__(self, dim: int, inter_dim: int):
|
|
|
+ """
|
|
|
+ Initializes the MLP layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ dim (int): Input and output dimensionality.
|
|
|
+ inter_dim (int): Hidden layer dimensionality.
|
|
|
+ """
|
|
|
super().__init__()
|
|
|
self.w1 = ColumnParallelLinear(dim, inter_dim)
|
|
|
self.w2 = RowParallelLinear(inter_dim, dim)
|
|
|
self.w3 = ColumnParallelLinear(dim, inter_dim)
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for the MLP layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Output tensor after MLP computation.
|
|
|
+ """
|
|
|
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
|
|
|
|
|
|
|
|
class Gate(nn.Module):
|
|
|
+ """
|
|
|
+ Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ dim (int): Dimensionality of input features.
|
|
|
+ topk (int): Number of top experts activated for each input.
|
|
|
+ n_groups (int): Number of groups for routing.
|
|
|
+ topk_groups (int): Number of groups to route inputs to.
|
|
|
+ score_func (str): Scoring function ('softmax' or 'sigmoid').
|
|
|
+ route_scale (float): Scaling factor for routing weights.
|
|
|
+ weight (torch.nn.Parameter): Learnable weights for the gate.
|
|
|
+ bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
|
|
|
+ """
|
|
|
def __init__(self, args: ModelArgs):
|
|
|
+ """
|
|
|
+ Initializes the Gate module.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ args (ModelArgs): Model arguments containing gating parameters.
|
|
|
+ """
|
|
|
super().__init__()
|
|
|
self.dim = args.dim
|
|
|
self.topk = args.n_activated_experts
|
|
|
@@ -292,6 +561,15 @@ class Gate(nn.Module):
|
|
|
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
+ """
|
|
|
+ Forward pass for the gating mechanism.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
|
|
|
+ """
|
|
|
scores = linear(x, self.weight)
|
|
|
if self.score_func == "softmax":
|
|
|
scores = scores.softmax(dim=-1, dtype=torch.float32)
|
|
|
@@ -318,18 +596,60 @@ class Gate(nn.Module):
|
|
|
|
|
|
|
|
|
class Expert(nn.Module):
|
|
|
+ """
|
|
|
+ Expert layer for Mixture-of-Experts (MoE) models.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ w1 (nn.Module): Linear layer for input-to-hidden transformation.
|
|
|
+ w2 (nn.Module): Linear layer for hidden-to-output transformation.
|
|
|
+ w3 (nn.Module): Additional linear layer for feature transformation.
|
|
|
+ """
|
|
|
def __init__(self, dim: int, inter_dim: int):
|
|
|
+ """
|
|
|
+ Initializes the Expert layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ dim (int): Input and output dimensionality.
|
|
|
+ inter_dim (int): Hidden layer dimensionality.
|
|
|
+ """
|
|
|
super().__init__()
|
|
|
self.w1 = Linear(dim, inter_dim)
|
|
|
self.w2 = Linear(inter_dim, dim)
|
|
|
self.w3 = Linear(dim, inter_dim)
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for the Expert layer.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Output tensor after expert computation.
|
|
|
+ """
|
|
|
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
|
|
|
|
|
|
|
|
class MoE(nn.Module):
|
|
|
+ """
|
|
|
+ Mixture-of-Experts (MoE) module.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ dim (int): Dimensionality of input features.
|
|
|
+ n_routed_experts (int): Total number of experts in the model.
|
|
|
+ n_local_experts (int): Number of experts handled locally in distributed systems.
|
|
|
+ n_activated_experts (int): Number of experts activated for each input.
|
|
|
+ gate (nn.Module): Gating mechanism to route inputs to experts.
|
|
|
+ experts (nn.ModuleList): List of expert modules.
|
|
|
+ shared_experts (nn.Module): Shared experts applied to all inputs.
|
|
|
+ """
|
|
|
def __init__(self, args: ModelArgs):
|
|
|
+ """
|
|
|
+ Initializes the MoE module.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ args (ModelArgs): Model arguments containing MoE parameters.
|
|
|
+ """
|
|
|
super().__init__()
|
|
|
self.dim = args.dim
|
|
|
assert args.n_routed_experts % world_size == 0
|
|
|
@@ -344,6 +664,15 @@ class MoE(nn.Module):
|
|
|
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for the MoE module.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Output tensor after expert routing and computation.
|
|
|
+ """
|
|
|
shape = x.size()
|
|
|
x = x.view(-1, self.dim)
|
|
|
weights, indices = self.gate(x)
|
|
|
@@ -362,7 +691,23 @@ class MoE(nn.Module):
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
|
|
|
+ """
|
|
|
+ Transformer block combining attention and feed-forward layers.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ attn (nn.Module): Attention layer (MLA).
|
|
|
+ ffn (nn.Module): Feed-forward network (MLP or MoE).
|
|
|
+ attn_norm (nn.Module): Layer normalization for attention.
|
|
|
+ ffn_norm (nn.Module): Layer normalization for feed-forward network.
|
|
|
+ """
|
|
|
def __init__(self, layer_id: int, args: ModelArgs):
|
|
|
+ """
|
|
|
+ Initializes the Transformer block.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ layer_id (int): Layer index in the transformer.
|
|
|
+ args (ModelArgs): Model arguments containing block parameters.
|
|
|
+ """
|
|
|
super().__init__()
|
|
|
self.attn = MLA(args)
|
|
|
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
|
|
|
@@ -370,13 +715,42 @@ class Block(nn.Module):
|
|
|
self.ffn_norm = RMSNorm(args.dim)
|
|
|
|
|
|
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ Forward pass for the Transformer block.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x (torch.Tensor): Input tensor.
|
|
|
+ start_pos (int): Starting position in the sequence.
|
|
|
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
|
|
+ mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Output tensor after block computation.
|
|
|
+ """
|
|
|
x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
|
|
|
x = x + self.ffn(self.ffn_norm(x))
|
|
|
return x
|
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
|
|
|
+ """
|
|
|
+ Transformer model with positional embeddings, multiple layers, and output projection.
|
|
|
+
|
|
|
+ Attributes:
|
|
|
+ max_seq_len (int): Maximum sequence length for the transformer.
|
|
|
+ embed (nn.Module): Embedding layer for input tokens.
|
|
|
+ layers (torch.nn.ModuleList): List of transformer blocks.
|
|
|
+ norm (nn.Module): Layer normalization applied after all blocks.
|
|
|
+ head (nn.Module): Output projection layer mapping to vocabulary size.
|
|
|
+ freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
|
|
|
+ """
|
|
|
def __init__(self, args: ModelArgs):
|
|
|
+ """
|
|
|
+ Initializes the Transformer model.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ args (ModelArgs): Model arguments containing transformer parameters.
|
|
|
+ """
|
|
|
global world_size, rank
|
|
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
|
|
rank = dist.get_rank() if dist.is_initialized() else 0
|
|
|
@@ -393,6 +767,16 @@ class Transformer(nn.Module):
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
|
|
|
+ """
|
|
|
+ Forward pass for the Transformer model.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
|
|
|
+ start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
|
|
|
+ """
|
|
|
seqlen = tokens.size(1)
|
|
|
h = self.embed(tokens)
|
|
|
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|