convert.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. import shutil
  3. from argparse import ArgumentParser
  4. from glob import glob
  5. from tqdm import tqdm, trange
  6. import torch
  7. from safetensors.torch import safe_open, save_file
  8. mapping = {
  9. "embed_tokens": ("embed", 0),
  10. "input_layernorm": ("attn_norm", None),
  11. "post_attention_layernorm": ("ffn_norm", None),
  12. "q_proj": ("wq", 0),
  13. "q_a_proj": ("wq_a", None),
  14. "q_a_layernorm": ("q_norm", None),
  15. "q_b_proj": ("wq_b", 0),
  16. "kv_a_proj_with_mqa": ("wkv_a", None),
  17. "kv_a_layernorm": ("kv_norm", None),
  18. "kv_b_proj": ("wkv_b", 0),
  19. "o_proj": ("wo", 1),
  20. "gate": ("gate", None),
  21. "gate_proj": ("w1", 0),
  22. "down_proj": ("w2", 1),
  23. "up_proj": ("w3", 0),
  24. "norm": ("norm", None),
  25. "lm_head": ("head", 0),
  26. "scale": ("scale", None),
  27. }
  28. def main(hf_ckpt_path, save_path, n_experts, mp):
  29. torch.set_num_threads(8)
  30. n_local_experts = n_experts // mp
  31. state_dicts = [{} for _ in range(mp)]
  32. for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
  33. with safe_open(file_path, framework="pt", device="cpu") as f:
  34. for name in f.keys():
  35. if "model.layers.61" in name:
  36. continue
  37. param: torch.Tensor = f.get_tensor(name)
  38. if name.startswith("model."):
  39. name = name[len("model."):]
  40. name = name.replace("self_attn", "attn")
  41. name = name.replace("mlp", "ffn")
  42. name = name.replace("weight_scale_inv", "scale")
  43. name = name.replace("e_score_correction_bias", "bias")
  44. key = name.split(".")[-2]
  45. assert key in mapping
  46. new_key, dim = mapping[key]
  47. name = name.replace(key, new_key)
  48. for i in range(mp):
  49. new_param = param
  50. if "experts" in name and "shared_experts" not in name:
  51. idx = int(name.split(".")[-3])
  52. if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
  53. continue
  54. elif dim is not None:
  55. assert param.size(dim) % mp == 0
  56. shard_size = param.size(dim) // mp
  57. new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
  58. state_dicts[i][name] = new_param
  59. os.makedirs(save_path, exist_ok=True)
  60. for i in trange(mp):
  61. save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
  62. for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
  63. new_file_path = os.path.join(save_path, os.path.basename(file_path))
  64. shutil.copyfile(file_path, new_file_path)
  65. if __name__ == "__main__":
  66. parser = ArgumentParser()
  67. parser.add_argument("--hf-ckpt-path", type=str, required=True)
  68. parser.add_argument("--save-path", type=str, required=True)
  69. parser.add_argument("--n-experts", type=int, required=True)
  70. parser.add_argument("--model-parallel", type=int, required=True)
  71. args = parser.parse_args()
  72. assert args.n_experts % args.model_parallel == 0
  73. main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)