model.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. import math
  2. from dataclasses import dataclass
  3. from typing import Tuple, Optional, Literal
  4. import torch
  5. from torch import nn
  6. import torch.nn.functional as F
  7. import torch.distributed as dist
  8. from kernel import act_quant, weight_dequant, fp8_gemm
  9. world_size = 1
  10. rank = 0
  11. block_size = 128
  12. gemm_impl: Literal["bf16", "fp8"] = "bf16"
  13. attn_impl: Literal["naive", "absorb"] = "absorb"
  14. @dataclass
  15. class ModelArgs:
  16. max_batch_size: int = 8
  17. max_seq_len: int = 4096 * 4
  18. dtype: Literal["bf16", "fp8"] = "bf16"
  19. vocab_size: int = 102400
  20. dim: int = 2048
  21. inter_dim: int = 10944
  22. moe_inter_dim: int = 1408
  23. n_layers: int = 27
  24. n_dense_layers: int = 1
  25. n_heads: int = 16
  26. # moe
  27. n_routed_experts: int = 64
  28. n_shared_experts: int = 2
  29. n_activated_experts: int = 6
  30. n_expert_groups: int = 1
  31. n_limited_groups: int = 1
  32. score_func: Literal["softmax", "sigmoid"] = "softmax"
  33. route_scale: float = 1.
  34. # mla
  35. q_lora_rank: int = 0
  36. kv_lora_rank: int = 512
  37. qk_nope_head_dim: int = 128
  38. qk_rope_head_dim: int = 64
  39. v_head_dim: int = 128
  40. # yarn
  41. original_seq_len: int = 4096
  42. rope_theta: float = 10000.0
  43. rope_factor: float = 40
  44. beta_fast: int = 32
  45. beta_slow: int = 1
  46. mscale: float = 1.
  47. class ParallelEmbedding(nn.Module):
  48. def __init__(self, vocab_size: int, dim: int):
  49. super().__init__()
  50. self.vocab_size = vocab_size
  51. self.dim = dim
  52. assert vocab_size % world_size == 0
  53. self.part_vocab_size = (vocab_size // world_size)
  54. self.vocab_start_idx = rank * self.part_vocab_size
  55. self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
  56. self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
  57. def forward(self, x: torch.Tensor) -> torch.Tensor:
  58. if world_size > 1:
  59. mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
  60. x = x - self.vocab_start_idx
  61. x[mask] = 0
  62. y = F.embedding(x, self.weight)
  63. if world_size > 1:
  64. y[mask] = 0
  65. dist.all_reduce(y)
  66. return y
  67. def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  68. if weight.element_size() > 1:
  69. return F.linear(x, weight, bias)
  70. elif gemm_impl == "bf16":
  71. weight = weight_dequant(weight, weight.scale)
  72. return F.linear(x, weight, bias)
  73. else:
  74. x, scale = act_quant(x, block_size)
  75. y = fp8_gemm(x, scale, weight, weight.scale)
  76. if bias is not None:
  77. y += bias
  78. return y
  79. class Linear(nn.Module):
  80. dtype = torch.bfloat16
  81. def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
  82. super().__init__()
  83. self.in_features = in_features
  84. self.out_features = out_features
  85. self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
  86. if self.weight.element_size() == 1:
  87. scale_out_features = (out_features + block_size - 1) // block_size
  88. scale_in_features = (in_features + block_size - 1) // block_size
  89. self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
  90. else:
  91. self.register_parameter("scale", None)
  92. if bias:
  93. self.bias = nn.Parameter(torch.empty(self.part_out_features))
  94. else:
  95. self.register_parameter("bias", None)
  96. def forward(self, x: torch.Tensor) -> torch.Tensor:
  97. return linear(x, self.weight, self.bias)
  98. class ColumnParallelLinear(Linear):
  99. def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
  100. assert out_features % world_size == 0
  101. self.part_out_features = out_features // world_size
  102. super().__init__(in_features, self.part_out_features, bias, dtype)
  103. def forward(self, x: torch.Tensor) -> torch.Tensor:
  104. y = linear(x, self.weight, self.bias)
  105. return y
  106. class RowParallelLinear(Linear):
  107. def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
  108. assert in_features % world_size == 0
  109. self.part_in_features = in_features // world_size
  110. super().__init__(self.part_in_features, out_features, bias, dtype)
  111. def forward(self, x: torch.Tensor) -> torch.Tensor:
  112. y = linear(x, self.weight)
  113. if world_size > 1:
  114. dist.all_reduce(y)
  115. if self.bias is not None:
  116. y += self.bias
  117. return y
  118. class RMSNorm(nn.Module):
  119. def __init__(self, dim: int, eps: float = 1e-6):
  120. super().__init__()
  121. self.eps = eps
  122. self.weight = nn.Parameter(torch.ones(dim))
  123. def forward(self, x: torch.Tensor):
  124. x = x.float()
  125. y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  126. return y.type_as(self.weight) * self.weight
  127. def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
  128. dim = args.qk_rope_head_dim
  129. seqlen = args.max_seq_len
  130. beta_fast = args.beta_fast
  131. beta_slow = args.beta_slow
  132. base = args.rope_theta
  133. factor = args.rope_factor
  134. def find_correction_dim(num_rotations, dim, base, max_seq_len):
  135. return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
  136. def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
  137. low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
  138. high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
  139. return max(low, 0), min(high, dim-1)
  140. def linear_ramp_factor(min, max, dim):
  141. if min == max:
  142. max += 0.001
  143. linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
  144. ramp_func = torch.clamp(linear_func, 0, 1)
  145. return ramp_func
  146. freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
  147. if seqlen > args.original_seq_len:
  148. low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
  149. smooth = 1 - linear_ramp_factor(low, high, dim // 2)
  150. freqs = freqs / factor * (1 - smooth) + freqs * smooth
  151. t = torch.arange(seqlen)
  152. freqs = torch.outer(t, freqs)
  153. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  154. return freqs_cis
  155. def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
  156. dtype = x.dtype
  157. x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
  158. freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
  159. y = torch.view_as_real(x * freqs_cis).flatten(3)
  160. return y.to(dtype)
  161. class MLA(nn.Module):
  162. def __init__(self, args: ModelArgs):
  163. super().__init__()
  164. self.dim = args.dim
  165. self.n_heads = args.n_heads
  166. self.n_local_heads = args.n_heads // world_size
  167. self.q_lora_rank = args.q_lora_rank
  168. self.kv_lora_rank = args.kv_lora_rank
  169. self.qk_nope_head_dim = args.qk_nope_head_dim
  170. self.qk_rope_head_dim = args.qk_rope_head_dim
  171. self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
  172. self.v_head_dim = args.v_head_dim
  173. if self.q_lora_rank == 0:
  174. self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
  175. else:
  176. self.wq_a = Linear(self.dim, self.q_lora_rank)
  177. self.q_norm = RMSNorm(self.q_lora_rank)
  178. self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
  179. self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
  180. self.kv_norm = RMSNorm(self.kv_lora_rank)
  181. self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
  182. self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
  183. self.softmax_scale = self.qk_head_dim ** -0.5
  184. if args.max_seq_len > args.original_seq_len:
  185. mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
  186. self.softmax_scale = self.softmax_scale * mscale * mscale
  187. if attn_impl == "naive":
  188. self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
  189. self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
  190. else:
  191. self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
  192. self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
  193. def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
  194. bsz, seqlen, _ = x.size()
  195. end_pos = start_pos + seqlen
  196. if self.q_lora_rank == 0:
  197. q = self.wq(x)
  198. else:
  199. q = self.wq_b(self.q_norm(self.wq_a(x)))
  200. q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
  201. q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
  202. q_pe = apply_rotary_emb(q_pe, freqs_cis)
  203. kv = self.wkv_a(x)
  204. kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
  205. k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
  206. if attn_impl == "naive":
  207. q = torch.cat([q_nope, q_pe], dim=-1)
  208. kv = self.wkv_b(self.kv_norm(kv))
  209. kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
  210. k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
  211. k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
  212. self.k_cache[:bsz, start_pos:end_pos] = k
  213. self.v_cache[:bsz, start_pos:end_pos] = v
  214. scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
  215. else:
  216. wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
  217. wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
  218. q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
  219. self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
  220. self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
  221. scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
  222. torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
  223. if mask is not None:
  224. scores += mask.unsqueeze(1)
  225. scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
  226. if attn_impl == "naive":
  227. x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
  228. else:
  229. x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
  230. x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
  231. x = self.wo(x.flatten(2))
  232. return x
  233. class MLP(nn.Module):
  234. def __init__(self, dim: int, inter_dim: int):
  235. super().__init__()
  236. self.w1 = ColumnParallelLinear(dim, inter_dim)
  237. self.w2 = RowParallelLinear(inter_dim, dim)
  238. self.w3 = ColumnParallelLinear(dim, inter_dim)
  239. def forward(self, x: torch.Tensor) -> torch.Tensor:
  240. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  241. class Gate(nn.Module):
  242. def __init__(self, args: ModelArgs):
  243. super().__init__()
  244. self.dim = args.dim
  245. self.topk = args.n_activated_experts
  246. self.n_groups = args.n_expert_groups
  247. self.topk_groups = args.n_limited_groups
  248. self.score_func = args.score_func
  249. self.route_scale = args.route_scale
  250. self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
  251. self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
  252. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  253. scores = linear(x, self.weight)
  254. if self.score_func == "softmax":
  255. scores = scores.softmax(dim=-1, dtype=torch.float32)
  256. else:
  257. scores = scores.sigmoid()
  258. original_scores = scores
  259. if self.bias is not None:
  260. scores = scores + self.bias
  261. if self.n_groups > 1:
  262. scores = scores.view(x.size(0), self.n_groups, -1)
  263. if self.bias is None:
  264. group_scores = scores.amax(dim=-1)
  265. else:
  266. group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
  267. indices = group_scores.topk(self.topk_groups, dim=-1)[1]
  268. mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
  269. scores = (scores * mask.unsqueeze(-1)).flatten(1)
  270. indices = torch.topk(scores, self.topk, dim=-1)[1]
  271. weights = original_scores.gather(1, indices)
  272. if self.score_func == "sigmoid":
  273. weights /= weights.sum(dim=-1, keepdim=True)
  274. weights *= self.route_scale
  275. return weights.type_as(x), indices
  276. class Expert(nn.Module):
  277. def __init__(self, dim: int, inter_dim: int):
  278. super().__init__()
  279. self.w1 = Linear(dim, inter_dim)
  280. self.w2 = Linear(inter_dim, dim)
  281. self.w3 = Linear(dim, inter_dim)
  282. def forward(self, x: torch.Tensor) -> torch.Tensor:
  283. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  284. class MoE(nn.Module):
  285. def __init__(self, args: ModelArgs):
  286. super().__init__()
  287. self.dim = args.dim
  288. assert args.n_routed_experts % world_size == 0
  289. self.n_routed_experts = args.n_routed_experts
  290. self.n_local_experts = args.n_routed_experts // world_size
  291. self.n_activated_experts = args.n_activated_experts
  292. self.experts_start_idx = rank * self.n_local_experts
  293. self.experts_end_idx = self.experts_start_idx + self.n_local_experts
  294. self.gate = Gate(args)
  295. self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
  296. for i in range(self.n_routed_experts)])
  297. self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
  298. def forward(self, x: torch.Tensor) -> torch.Tensor:
  299. shape = x.size()
  300. x = x.view(-1, self.dim)
  301. weights, indices = self.gate(x)
  302. y = torch.zeros_like(x)
  303. counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
  304. for i in range(self.experts_start_idx, self.experts_end_idx):
  305. if counts[i] == 0:
  306. continue
  307. expert = self.experts[i]
  308. idx, top = torch.where(indices == i)
  309. y[idx] += expert(x[idx]) * weights[idx, top, None]
  310. z = self.shared_experts(x)
  311. if world_size > 1:
  312. dist.all_reduce(y)
  313. return (y + z).view(shape)
  314. class Block(nn.Module):
  315. def __init__(self, layer_id: int, args: ModelArgs):
  316. super().__init__()
  317. self.attn = MLA(args)
  318. self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
  319. self.attn_norm = RMSNorm(args.dim)
  320. self.ffn_norm = RMSNorm(args.dim)
  321. def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
  322. x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
  323. x = x + self.ffn(self.ffn_norm(x))
  324. return x
  325. class Transformer(nn.Module):
  326. def __init__(self, args: ModelArgs):
  327. global world_size, rank
  328. world_size = dist.get_world_size() if dist.is_initialized() else 1
  329. rank = dist.get_rank() if dist.is_initialized() else 0
  330. Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
  331. super().__init__()
  332. self.max_seq_len = args.max_seq_len
  333. self.embed = ParallelEmbedding(args.vocab_size, args.dim)
  334. self.layers = torch.nn.ModuleList()
  335. for layer_id in range(args.n_layers):
  336. self.layers.append(Block(layer_id, args))
  337. self.norm = RMSNorm(args.dim)
  338. self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
  339. self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
  340. @torch.inference_mode()
  341. def forward(self, tokens: torch.Tensor, start_pos: int = 0):
  342. seqlen = tokens.size(1)
  343. h = self.embed(tokens)
  344. freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
  345. mask = None
  346. if seqlen > 1:
  347. mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
  348. for layer in self.layers:
  349. h = layer(h, start_pos, freqs_cis, mask)
  350. h = self.norm(h)[:, -1]
  351. logits = self.head(h)
  352. if world_size > 1:
  353. all_logits = [torch.empty_like(logits) for _ in range(world_size)]
  354. dist.all_gather(all_logits, logits)
  355. logits = torch.cat(all_logits, dim=-1)
  356. return logits
  357. if __name__ == "__main__":
  358. torch.set_default_dtype(torch.bfloat16)
  359. torch.set_default_device("cuda")
  360. torch.manual_seed(0)
  361. args = ModelArgs()
  362. x = torch.randint(0, args.vocab_size, (2, 128))
  363. model = Transformer(args)
  364. print(model(x).size())