model.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. import math
  2. from dataclasses import dataclass
  3. from typing import Tuple, Optional, Literal
  4. import torch
  5. from torch import nn
  6. import torch.nn.functional as F
  7. import torch.distributed as dist
  8. from kernel import act_quant, weight_dequant, fp8_gemm
# Number of distributed processes; reassigned from torch.distributed in Transformer.__init__.
world_size = 1
# Index of this process; reassigned from torch.distributed in Transformer.__init__.
rank = 0
# Block edge length used to size per-block weight scales and passed to the quantization kernels.
block_size = 128
# GEMM path that `linear` takes for 1-byte (quantized) weights.
gemm_impl: Literal["bf16", "fp8"] = "bf16"
# Attention formulation used by MLA: "naive" caches per-head keys/values,
# any other value caches the compressed kv and the rotary key instead.
attn_impl: Literal["naive", "absorb"] = "absorb"
  14. @dataclass
  15. class ModelArgs:
  16. """
  17. Data class for defining model arguments and hyperparameters.
  18. Attributes:
  19. max_batch_size (int): Maximum batch size.
  20. max_seq_len (int): Maximum sequence length.
  21. dtype (Literal["bf16", "fp8"]): Data type for computations.
  22. vocab_size (int): Vocabulary size.
  23. dim (int): Model dimension.
  24. inter_dim (int): Intermediate dimension for MLP layers.
  25. moe_inter_dim (int): Intermediate dimension for MoE layers.
  26. n_layers (int): Number of transformer layers.
  27. n_dense_layers (int): Number of dense layers in the model.
  28. n_heads (int): Number of attention heads.
  29. n_routed_experts (int): Number of routed experts for MoE layers.
  30. n_shared_experts (int): Number of shared experts for MoE layers.
  31. n_activated_experts (int): Number of activated experts in MoE layers.
  32. n_expert_groups (int): Number of expert groups.
  33. n_limited_groups (int): Number of limited groups for MoE routing.
  34. score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
  35. route_scale (float): Scaling factor for routing scores.
  36. q_lora_rank (int): LoRA rank for query projections.
  37. kv_lora_rank (int): LoRA rank for key-value projections.
  38. qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
  39. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
  40. v_head_dim (int): Dimension for value projections.
  41. original_seq_len (int): Original sequence length.
  42. rope_theta (float): Base for rotary positional encoding.
  43. rope_factor (float): Scaling factor for extended sequence lengths.
  44. beta_fast (int): Fast beta correction factor.
  45. beta_slow (int): Slow beta correction factor.
  46. mscale (float): Scaling factor for extended attention.
  47. """
  48. max_batch_size: int = 8
  49. max_seq_len: int = 4096 * 4
  50. dtype: Literal["bf16", "fp8"] = "bf16"
  51. vocab_size: int = 102400
  52. dim: int = 2048
  53. inter_dim: int = 10944
  54. moe_inter_dim: int = 1408
  55. n_layers: int = 27
  56. n_dense_layers: int = 1
  57. n_heads: int = 16
  58. # moe
  59. n_routed_experts: int = 64
  60. n_shared_experts: int = 2
  61. n_activated_experts: int = 6
  62. n_expert_groups: int = 1
  63. n_limited_groups: int = 1
  64. score_func: Literal["softmax", "sigmoid"] = "softmax"
  65. route_scale: float = 1.
  66. # mla
  67. q_lora_rank: int = 0
  68. kv_lora_rank: int = 512
  69. qk_nope_head_dim: int = 128
  70. qk_rope_head_dim: int = 64
  71. v_head_dim: int = 128
  72. # yarn
  73. original_seq_len: int = 4096
  74. rope_theta: float = 10000.0
  75. rope_factor: float = 40
  76. beta_fast: int = 32
  77. beta_slow: int = 1
  78. mscale: float = 1.
  79. class ParallelEmbedding(nn.Module):
  80. """
  81. Embedding layer with parallelism support across distributed processes.
  82. Args:
  83. vocab_size (int): Vocabulary size.
  84. dim (int): Embedding dimension.
  85. """
  86. def __init__(self, vocab_size: int, dim: int):
  87. super().__init__()
  88. self.vocab_size = vocab_size
  89. self.dim = dim
  90. assert vocab_size % world_size == 0
  91. self.part_vocab_size = (vocab_size // world_size)
  92. self.vocab_start_idx = rank * self.part_vocab_size
  93. self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
  94. self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
  95. def forward(self, x: torch.Tensor) -> torch.Tensor:
  96. """
  97. Forward pass for parallel embedding layer.
  98. Args:
  99. x (torch.Tensor): Input tensor containing token indices.
  100. Returns:
  101. torch.Tensor: Embedded representations.
  102. Raises:
  103. ValueError: If `world_size` is not defined.
  104. """
  105. if world_size > 1:
  106. mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
  107. x = x - self.vocab_start_idx
  108. x[mask] = 0
  109. y = F.embedding(x, self.weight)
  110. if world_size > 1:
  111. y[mask] = 0
  112. dist.all_reduce(y)
  113. return y
  114. def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  115. """
  116. Applies a linear transformation to the incoming data: y = xA^T + b.
  117. This function supports specialized implementations based on quantization
  118. and tensor formats.
  119. Args:
  120. x (torch.Tensor): The input tensor.
  121. weight (torch.Tensor): The weight tensor. It may be quantized and
  122. requires dequantization for certain cases.
  123. bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
  124. Returns:
  125. torch.Tensor: The result of the linear transformation, which may involve
  126. quantization-aware computations depending on the input parameters.
  127. Notes:
  128. - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
  129. is used for computation.
  130. - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
  131. - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
  132. """
  133. if weight.element_size() > 1:
  134. return F.linear(x, weight, bias)
  135. elif gemm_impl == "bf16":
  136. weight = weight_dequant(weight, weight.scale)
  137. return F.linear(x, weight, bias)
  138. else:
  139. x, scale = act_quant(x, block_size)
  140. y = fp8_gemm(x, scale, weight, weight.scale)
  141. if bias is not None:
  142. y += bias
  143. return y
  144. class Linear(nn.Module):
  145. """
  146. Custom linear layer with support for quantized weights and optional bias.
  147. Args:
  148. in_features (int): Number of input features.
  149. out_features (int): Number of output features.
  150. bias (bool): Whether to include a bias term. Defaults to False.
  151. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
  152. """
  153. dtype = torch.bfloat16
  154. def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
  155. super().__init__()
  156. self.in_features = in_features
  157. self.out_features = out_features
  158. self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
  159. if self.weight.element_size() == 1:
  160. scale_out_features = (out_features + block_size - 1) // block_size
  161. scale_in_features = (in_features + block_size - 1) // block_size
  162. self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
  163. else:
  164. self.register_parameter("scale", None)
  165. if bias:
  166. self.bias = nn.Parameter(torch.empty(self.part_out_features))
  167. else:
  168. self.register_parameter("bias", None)
  169. def forward(self, x: torch.Tensor) -> torch.Tensor:
  170. """
  171. Forward pass for the custom linear layer.
  172. Args:
  173. x (torch.Tensor): Input tensor.
  174. Returns:
  175. torch.Tensor: Transformed tensor after linear computation.
  176. """
  177. return linear(x, self.weight, self.bias)
  178. class ColumnParallelLinear(Linear):
  179. """
  180. Linear layer with column parallelism, splitting output features across distributed processes.
  181. Args:
  182. in_features (int): Number of input features.
  183. out_features (int): Total number of output features.
  184. bias (bool): Whether to include a bias term. Defaults to False.
  185. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
  186. """
  187. def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
  188. assert out_features % world_size == 0
  189. self.part_out_features = out_features // world_size
  190. super().__init__(in_features, self.part_out_features, bias, dtype)
  191. def forward(self, x: torch.Tensor) -> torch.Tensor:
  192. """
  193. Forward pass for column parallel linear layer.
  194. Args:
  195. x (torch.Tensor): Input tensor.
  196. Returns:
  197. torch.Tensor: Transformed tensor with column-parallel computation.
  198. """
  199. y = linear(x, self.weight, self.bias)
  200. return y
  201. class RowParallelLinear(Linear):
  202. """
  203. Linear layer with row parallelism, splitting input features across distributed processes.
  204. Args:
  205. in_features (int): Total number of input features.
  206. out_features (int): Number of output features.
  207. bias (bool): Whether to include a bias term. Defaults to False.
  208. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
  209. """
  210. def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
  211. assert in_features % world_size == 0
  212. self.part_in_features = in_features // world_size
  213. super().__init__(self.part_in_features, out_features, bias, dtype)
  214. def forward(self, x: torch.Tensor) -> torch.Tensor:
  215. """
  216. Forward pass for row parallel linear layer.
  217. Args:
  218. x (torch.Tensor): Input tensor.
  219. Returns:
  220. torch.Tensor: Transformed tensor with row-parallel computation.
  221. """
  222. y = linear(x, self.weight)
  223. if world_size > 1:
  224. dist.all_reduce(y)
  225. if self.bias is not None:
  226. y += self.bias
  227. return y
  228. class RMSNorm(nn.Module):
  229. """
  230. Root Mean Square Layer Normalization (RMSNorm).
  231. Args:
  232. dim (int): Dimension of the input tensor.
  233. eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
  234. """
  235. def __init__(self, dim: int, eps: float = 1e-6):
  236. super().__init__()
  237. self.dim = dim
  238. self.eps = eps
  239. self.weight = nn.Parameter(torch.ones(dim))
  240. def forward(self, x: torch.Tensor):
  241. """
  242. Forward pass for RMSNorm.
  243. Args:
  244. x (torch.Tensor): Input tensor.
  245. Returns:
  246. torch.Tensor: Normalized tensor with the same shape as input.
  247. """
  248. return F.rms_norm(x, (self.dim,), self.weight, self.eps)
  249. def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
  250. """
  251. Precomputes frequency-based complex exponential values for rotary positional embeddings.
  252. Args:
  253. args (ModelArgs): Model arguments containing positional embedding parameters.
  254. Returns:
  255. torch.Tensor: Precomputed complex exponential values for positional embeddings.
  256. """
  257. dim = args.qk_rope_head_dim
  258. seqlen = args.max_seq_len
  259. beta_fast = args.beta_fast
  260. beta_slow = args.beta_slow
  261. base = args.rope_theta
  262. factor = args.rope_factor
  263. def find_correction_dim(num_rotations, dim, base, max_seq_len):
  264. """
  265. Computes the correction dimension for a given number of rotations in the rotary positional embedding.
  266. Args:
  267. num_rotations (float): Number of rotations to compute the correction for.
  268. dim (int): Dimensionality of the embedding space.
  269. base (float): Base value for the exponential computation.
  270. max_seq_len (int): Maximum sequence length.
  271. Returns:
  272. float: The correction dimension based on the input parameters.
  273. """
  274. return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
  275. def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
  276. """
  277. Computes the range of correction dimensions for rotary positional embeddings.
  278. Args:
  279. low_rot (float): Lower bound for the number of rotations.
  280. high_rot (float): Upper bound for the number of rotations.
  281. dim (int): Dimensionality of the embedding space.
  282. base (float): Base value for the exponential computation.
  283. max_seq_len (int): Maximum sequence length.
  284. Returns:
  285. Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
  286. """
  287. low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
  288. high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
  289. return max(low, 0), min(high, dim-1)
  290. def linear_ramp_factor(min, max, dim):
  291. """
  292. Computes a linear ramp function used to smooth values between a minimum and maximum range.
  293. Args:
  294. min (float): Minimum value for the ramp function.
  295. max (float): Maximum value for the ramp function.
  296. dim (int): Dimensionality of the ramp tensor.
  297. Returns:
  298. torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
  299. clamped to the range [0, 1].
  300. """
  301. if min == max:
  302. max += 0.001
  303. linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
  304. ramp_func = torch.clamp(linear_func, 0, 1)
  305. return ramp_func
  306. freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
  307. if seqlen > args.original_seq_len:
  308. low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
  309. smooth = 1 - linear_ramp_factor(low, high, dim // 2)
  310. freqs = freqs / factor * (1 - smooth) + freqs * smooth
  311. t = torch.arange(seqlen)
  312. freqs = torch.outer(t, freqs)
  313. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  314. return freqs_cis
  315. def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
  316. """
  317. Applies rotary positional embeddings to the input tensor.
  318. Args:
  319. x (torch.Tensor): Input tensor with positional embeddings to be applied.
  320. freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
  321. Returns:
  322. torch.Tensor: Tensor with rotary embeddings applied.
  323. """
  324. dtype = x.dtype
  325. x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
  326. freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
  327. y = torch.view_as_real(x * freqs_cis).flatten(3)
  328. return y.to(dtype)
class MLA(nn.Module):
    """
    Attention layer (MLA) that projects keys/values through a low-rank bottleneck
    (`kv_lora_rank`) and splits each query/key head into a non-rotary part and a
    rotary part. Keeps inference caches as non-persistent buffers.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        """
        Builds the projection layers and allocates the attention caches.

        Args:
            args (ModelArgs): Model arguments; reads the model/head dimensions, LoRA
                ranks, sequence-length settings and maximum batch size.
        """
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        # Heads handled by this process.
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        if self.q_lora_rank == 0:
            # No query bottleneck: one projection straight to all heads.
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            # Query bottleneck: down-project, normalize, then up-project to all heads.
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        # Produces the compressed kv (kv_lora_rank) plus one rotary key shared by all heads.
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        # Expands the compressed kv into per-head non-rotary keys and values.
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            # Extra attention scaling, applied only when running past the original sequence length.
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale

        if attn_impl == "naive":
            # Cache full per-head keys and values.
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            # Cache only the normalized compressed kv and the shared rotary key.
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Computes attention for `x`, writing the new positions into the caches and
        attending over cache positions [0, start_pos + seq_len).

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor added to the attention scores before softmax.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        # Only the rotary slice of each query head receives positional rotation.
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # The rotary key has no head dimension; add a singleton one for apply_rotary_emb.
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":
            # Materialize full per-head keys/values and store them in the caches.
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            # Use wkv_b's weight directly (dequantized when it carries a scale).
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            # Fold the key part of wkv_b into the query so scores can be taken
            # against the compressed kv cache without expanding keys per head.
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            # Score = non-rotary term (against compressed kv) + rotary term (against cached rotary keys).
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            # unsqueeze(1) lets the mask broadcast over the head dimension.
            scores += mask.unsqueeze(1)
        # Softmax is computed in float32, then cast back to the input dtype.
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            # Attend over the compressed kv, then apply the value part of wkv_b per head.
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x
  425. class MLP(nn.Module):
  426. """
  427. Multi-Layer Perceptron (MLP) used as a feed-forward layer.
  428. Attributes:
  429. w1 (nn.Module): Linear layer for input-to-hidden transformation.
  430. w2 (nn.Module): Linear layer for hidden-to-output transformation.
  431. w3 (nn.Module): Additional linear layer for feature transformation.
  432. """
  433. def __init__(self, dim: int, inter_dim: int):
  434. """
  435. Initializes the MLP layer.
  436. Args:
  437. dim (int): Input and output dimensionality.
  438. inter_dim (int): Hidden layer dimensionality.
  439. """
  440. super().__init__()
  441. self.w1 = ColumnParallelLinear(dim, inter_dim)
  442. self.w2 = RowParallelLinear(inter_dim, dim)
  443. self.w3 = ColumnParallelLinear(dim, inter_dim)
  444. def forward(self, x: torch.Tensor) -> torch.Tensor:
  445. """
  446. Forward pass for the MLP layer.
  447. Args:
  448. x (torch.Tensor): Input tensor.
  449. Returns:
  450. torch.Tensor: Output tensor after MLP computation.
  451. """
  452. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  453. class Gate(nn.Module):
  454. """
  455. Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
  456. Attributes:
  457. dim (int): Dimensionality of input features.
  458. topk (int): Number of top experts activated for each input.
  459. n_groups (int): Number of groups for routing.
  460. topk_groups (int): Number of groups to route inputs to.
  461. score_func (str): Scoring function ('softmax' or 'sigmoid').
  462. route_scale (float): Scaling factor for routing weights.
  463. weight (torch.nn.Parameter): Learnable weights for the gate.
  464. bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
  465. """
  466. def __init__(self, args: ModelArgs):
  467. """
  468. Initializes the Gate module.
  469. Args:
  470. args (ModelArgs): Model arguments containing gating parameters.
  471. """
  472. super().__init__()
  473. self.dim = args.dim
  474. self.topk = args.n_activated_experts
  475. self.n_groups = args.n_expert_groups
  476. self.topk_groups = args.n_limited_groups
  477. self.score_func = args.score_func
  478. self.route_scale = args.route_scale
  479. self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
  480. self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
  481. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  482. """
  483. Forward pass for the gating mechanism.
  484. Args:
  485. x (torch.Tensor): Input tensor.
  486. Returns:
  487. Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
  488. """
  489. scores = linear(x, self.weight)
  490. if self.score_func == "softmax":
  491. scores = scores.softmax(dim=-1, dtype=torch.float32)
  492. else:
  493. scores = scores.sigmoid()
  494. original_scores = scores
  495. if self.bias is not None:
  496. scores = scores + self.bias
  497. if self.n_groups > 1:
  498. scores = scores.view(x.size(0), self.n_groups, -1)
  499. if self.bias is None:
  500. group_scores = scores.amax(dim=-1)
  501. else:
  502. group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
  503. indices = group_scores.topk(self.topk_groups, dim=-1)[1]
  504. mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
  505. scores = (scores * mask.unsqueeze(-1)).flatten(1)
  506. indices = torch.topk(scores, self.topk, dim=-1)[1]
  507. weights = original_scores.gather(1, indices)
  508. if self.score_func == "sigmoid":
  509. weights /= weights.sum(dim=-1, keepdim=True)
  510. weights *= self.route_scale
  511. return weights.type_as(x), indices
  512. class Expert(nn.Module):
  513. """
  514. Expert layer for Mixture-of-Experts (MoE) models.
  515. Attributes:
  516. w1 (nn.Module): Linear layer for input-to-hidden transformation.
  517. w2 (nn.Module): Linear layer for hidden-to-output transformation.
  518. w3 (nn.Module): Additional linear layer for feature transformation.
  519. """
  520. def __init__(self, dim: int, inter_dim: int):
  521. """
  522. Initializes the Expert layer.
  523. Args:
  524. dim (int): Input and output dimensionality.
  525. inter_dim (int): Hidden layer dimensionality.
  526. """
  527. super().__init__()
  528. self.w1 = Linear(dim, inter_dim)
  529. self.w2 = Linear(inter_dim, dim)
  530. self.w3 = Linear(dim, inter_dim)
  531. def forward(self, x: torch.Tensor) -> torch.Tensor:
  532. """
  533. Forward pass for the Expert layer.
  534. Args:
  535. x (torch.Tensor): Input tensor.
  536. Returns:
  537. torch.Tensor: Output tensor after expert computation.
  538. """
  539. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  540. class MoE(nn.Module):
  541. """
  542. Mixture-of-Experts (MoE) module.
  543. Attributes:
  544. dim (int): Dimensionality of input features.
  545. n_routed_experts (int): Total number of experts in the model.
  546. n_local_experts (int): Number of experts handled locally in distributed systems.
  547. n_activated_experts (int): Number of experts activated for each input.
  548. gate (nn.Module): Gating mechanism to route inputs to experts.
  549. experts (nn.ModuleList): List of expert modules.
  550. shared_experts (nn.Module): Shared experts applied to all inputs.
  551. """
  552. def __init__(self, args: ModelArgs):
  553. """
  554. Initializes the MoE module.
  555. Args:
  556. args (ModelArgs): Model arguments containing MoE parameters.
  557. """
  558. super().__init__()
  559. self.dim = args.dim
  560. assert args.n_routed_experts % world_size == 0
  561. self.n_routed_experts = args.n_routed_experts
  562. self.n_local_experts = args.n_routed_experts // world_size
  563. self.n_activated_experts = args.n_activated_experts
  564. self.experts_start_idx = rank * self.n_local_experts
  565. self.experts_end_idx = self.experts_start_idx + self.n_local_experts
  566. self.gate = Gate(args)
  567. self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
  568. for i in range(self.n_routed_experts)])
  569. self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
  570. def forward(self, x: torch.Tensor) -> torch.Tensor:
  571. """
  572. Forward pass for the MoE module.
  573. Args:
  574. x (torch.Tensor): Input tensor.
  575. Returns:
  576. torch.Tensor: Output tensor after expert routing and computation.
  577. """
  578. shape = x.size()
  579. x = x.view(-1, self.dim)
  580. weights, indices = self.gate(x)
  581. y = torch.zeros_like(x)
  582. counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
  583. for i in range(self.experts_start_idx, self.experts_end_idx):
  584. if counts[i] == 0:
  585. continue
  586. expert = self.experts[i]
  587. idx, top = torch.where(indices == i)
  588. y[idx] += expert(x[idx]) * weights[idx, top, None]
  589. z = self.shared_experts(x)
  590. if world_size > 1:
  591. dist.all_reduce(y)
  592. return (y + z).view(shape)
  593. class Block(nn.Module):
  594. """
  595. Transformer block combining attention and feed-forward layers.
  596. Attributes:
  597. attn (nn.Module): Attention layer (MLA).
  598. ffn (nn.Module): Feed-forward network (MLP or MoE).
  599. attn_norm (nn.Module): Layer normalization for attention.
  600. ffn_norm (nn.Module): Layer normalization for feed-forward network.
  601. """
  602. def __init__(self, layer_id: int, args: ModelArgs):
  603. """
  604. Initializes the Transformer block.
  605. Args:
  606. layer_id (int): Layer index in the transformer.
  607. args (ModelArgs): Model arguments containing block parameters.
  608. """
  609. super().__init__()
  610. self.attn = MLA(args)
  611. self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
  612. self.attn_norm = RMSNorm(args.dim)
  613. self.ffn_norm = RMSNorm(args.dim)
  614. def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
  615. """
  616. Forward pass for the Transformer block.
  617. Args:
  618. x (torch.Tensor): Input tensor.
  619. start_pos (int): Starting position in the sequence.
  620. freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
  621. mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
  622. Returns:
  623. torch.Tensor: Output tensor after block computation.
  624. """
  625. x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
  626. x = x + self.ffn(self.ffn_norm(x))
  627. return x
  628. class Transformer(nn.Module):
  629. """
  630. Transformer model with positional embeddings, multiple layers, and output projection.
  631. Attributes:
  632. max_seq_len (int): Maximum sequence length for the transformer.
  633. embed (nn.Module): Embedding layer for input tokens.
  634. layers (torch.nn.ModuleList): List of transformer blocks.
  635. norm (nn.Module): Layer normalization applied after all blocks.
  636. head (nn.Module): Output projection layer mapping to vocabulary size.
  637. freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
  638. """
  639. def __init__(self, args: ModelArgs):
  640. """
  641. Initializes the Transformer model.
  642. Args:
  643. args (ModelArgs): Model arguments containing transformer parameters.
  644. """
  645. global world_size, rank
  646. world_size = dist.get_world_size() if dist.is_initialized() else 1
  647. rank = dist.get_rank() if dist.is_initialized() else 0
  648. Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
  649. super().__init__()
  650. self.max_seq_len = args.max_seq_len
  651. self.embed = ParallelEmbedding(args.vocab_size, args.dim)
  652. self.layers = torch.nn.ModuleList()
  653. for layer_id in range(args.n_layers):
  654. self.layers.append(Block(layer_id, args))
  655. self.norm = RMSNorm(args.dim)
  656. self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
  657. self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
  658. @torch.inference_mode()
  659. def forward(self, tokens: torch.Tensor, start_pos: int = 0):
  660. """
  661. Forward pass for the Transformer model.
  662. Args:
  663. tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
  664. start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
  665. Returns:
  666. torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
  667. """
  668. seqlen = tokens.size(1)
  669. h = self.embed(tokens)
  670. freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
  671. mask = None
  672. if seqlen > 1:
  673. mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
  674. for layer in self.layers:
  675. h = layer(h, start_pos, freqs_cis, mask)
  676. h = self.norm(h)[:, -1]
  677. logits = self.head(h)
  678. if world_size > 1:
  679. all_logits = [torch.empty_like(logits) for _ in range(world_size)]
  680. dist.all_gather(all_logits, logits)
  681. logits = torch.cat(all_logits, dim=-1)
  682. return logits
  683. if __name__ == "__main__":
  684. torch.set_default_dtype(torch.bfloat16)
  685. torch.set_default_device("cuda")
  686. torch.manual_seed(0)
  687. args = ModelArgs()
  688. x = torch.randint(0, args.vocab_size, (2, 128))
  689. model = Transformer(args)
  690. print(model(x).size())