convert.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. import shutil
  3. from argparse import ArgumentParser
  4. from glob import glob
  5. from tqdm import tqdm, trange
  6. import torch
  7. from safetensors.torch import safe_open, save_file
  8. mapping = {
  9. "embed_tokens": ("embed", 0),
  10. "input_layernorm": ("attn_norm", None),
  11. "post_attention_layernorm": ("ffn_norm", None),
  12. "q_proj": ("wq", 0),
  13. "q_a_proj": ("wq_a", None),
  14. "q_a_layernorm": ("q_norm", None),
  15. "q_b_proj": ("wq_b", 0),
  16. "kv_a_proj_with_mqa": ("wkv_a", None),
  17. "kv_a_layernorm": ("kv_norm", None),
  18. "kv_b_proj": ("wkv_b", 0),
  19. "o_proj": ("wo", 1),
  20. "gate": ("gate", None),
  21. "gate_proj": ("w1", 0),
  22. "down_proj": ("w2", 1),
  23. "up_proj": ("w3", 0),
  24. "norm": ("norm", None),
  25. "lm_head": ("head", 0),
  26. "scale": ("scale", None),
  27. }
  28. def main(hf_ckpt_path, save_path, n_experts, mp):
  29. """
  30. Converts and saves model checkpoint files into a specified format.
  31. Args:
  32. hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
  33. save_path (str): Path to the directory where the converted checkpoint files will be saved.
  34. n_experts (int): Total number of experts in the model.
  35. mp (int): Model parallelism factor.
  36. Returns:
  37. None
  38. """
  39. torch.set_num_threads(8)
  40. n_local_experts = n_experts // mp
  41. state_dicts = [{} for _ in range(mp)]
  42. for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
  43. with safe_open(file_path, framework="pt", device="cpu") as f:
  44. for name in f.keys():
  45. if "model.layers.61" in name:
  46. continue
  47. param: torch.Tensor = f.get_tensor(name)
  48. if name.startswith("model."):
  49. name = name[len("model."):]
  50. name = name.replace("self_attn", "attn")
  51. name = name.replace("mlp", "ffn")
  52. name = name.replace("weight_scale_inv", "scale")
  53. name = name.replace("e_score_correction_bias", "bias")
  54. key = name.split(".")[-2]
  55. assert key in mapping
  56. new_key, dim = mapping[key]
  57. name = name.replace(key, new_key)
  58. for i in range(mp):
  59. new_param = param
  60. if "experts" in name and "shared_experts" not in name:
  61. idx = int(name.split(".")[-3])
  62. if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
  63. continue
  64. elif dim is not None:
  65. assert param.size(dim) % mp == 0
  66. shard_size = param.size(dim) // mp
  67. new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
  68. state_dicts[i][name] = new_param
  69. os.makedirs(save_path, exist_ok=True)
  70. for i in trange(mp):
  71. save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
  72. for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
  73. new_file_path = os.path.join(save_path, os.path.basename(file_path))
  74. shutil.copyfile(file_path, new_file_path)
  75. if __name__ == "__main__":
  76. parser = ArgumentParser()
  77. parser.add_argument("--hf-ckpt-path", type=str, required=True)
  78. parser.add_argument("--save-path", type=str, required=True)
  79. parser.add_argument("--n-experts", type=int, required=True)
  80. parser.add_argument("--model-parallel", type=int, required=True)
  81. args = parser.parse_args()
  82. assert args.n_experts % args.model_parallel == 0
  83. main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)