# fp8_cast_bf16.py: convert an FP8 safetensors checkpoint to BF16.
  1. import os
  2. import json
  3. from argparse import ArgumentParser
  4. from glob import glob
  5. from tqdm import tqdm
  6. import torch
  7. from safetensors.torch import load_file, save_file
  8. from kernel import weight_dequant
  9. def main(fp8_path, bf16_path):
  10. """
  11. Converts FP8 weights to BF16 and saves the converted weights.
  12. This function reads FP8 weights from the specified directory, converts them to BF16,
  13. and saves the converted weights to another specified directory. It also updates the
  14. model index file to reflect the changes.
  15. Args:
  16. fp8_path (str): The path to the directory containing the FP8 weights and model index file.
  17. bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
  18. Raises:
  19. KeyError: If a required scale_inv tensor is missing for a weight.
  20. Notes:
  21. - The function assumes that the FP8 weights are stored in safetensor files.
  22. - The function caches loaded safetensor files to optimize memory usage.
  23. - The function updates the model index file to remove references to scale_inv tensors.
  24. """
  25. torch.set_default_dtype(torch.bfloat16)
  26. os.makedirs(bf16_path, exist_ok=True)
  27. model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
  28. with open(model_index_file, "r") as f:
  29. model_index = json.load(f)
  30. weight_map = model_index["weight_map"]
  31. # Cache for loaded safetensor files
  32. loaded_files = {}
  33. fp8_weight_names = []
  34. # Helper function to get tensor from the correct file
  35. def get_tensor(tensor_name):
  36. """
  37. Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
  38. Args:
  39. tensor_name (str): The name of the tensor to retrieve.
  40. Returns:
  41. torch.Tensor: The retrieved tensor.
  42. Raises:
  43. KeyError: If the tensor does not exist in the safetensor file.
  44. """
  45. file_name = weight_map[tensor_name]
  46. if file_name not in loaded_files:
  47. file_path = os.path.join(fp8_path, file_name)
  48. loaded_files[file_name] = load_file(file_path, device="cuda")
  49. return loaded_files[file_name][tensor_name]
  50. safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
  51. safetensor_files.sort()
  52. for safetensor_file in tqdm(safetensor_files):
  53. file_name = os.path.basename(safetensor_file)
  54. current_state_dict = load_file(safetensor_file, device="cuda")
  55. loaded_files[file_name] = current_state_dict
  56. new_state_dict = {}
  57. for weight_name, weight in current_state_dict.items():
  58. if weight_name.endswith("_scale_inv"):
  59. continue
  60. elif weight.element_size() == 1: # FP8 weight
  61. scale_inv_name = f"{weight_name}_scale_inv"
  62. try:
  63. # Get scale_inv from the correct file
  64. scale_inv = get_tensor(scale_inv_name)
  65. fp8_weight_names.append(weight_name)
  66. new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
  67. except KeyError:
  68. print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
  69. new_state_dict[weight_name] = weight
  70. else:
  71. new_state_dict[weight_name] = weight
  72. new_safetensor_file = os.path.join(bf16_path, file_name)
  73. save_file(new_state_dict, new_safetensor_file)
  74. # Memory management: keep only the 2 most recently used files
  75. if len(loaded_files) > 2:
  76. oldest_file = next(iter(loaded_files))
  77. del loaded_files[oldest_file]
  78. torch.cuda.empty_cache()
  79. # Update model index
  80. new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
  81. for weight_name in fp8_weight_names:
  82. scale_inv_name = f"{weight_name}_scale_inv"
  83. if scale_inv_name in weight_map:
  84. weight_map.pop(scale_inv_name)
  85. with open(new_model_index_file, "w") as f:
  86. json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
  87. if __name__ == "__main__":
  88. parser = ArgumentParser()
  89. parser.add_argument("--input-fp8-hf-path", type=str, required=True)
  90. parser.add_argument("--output-bf16-hf-path", type=str, required=True)
  91. args = parser.parse_args()
  92. main(args.input_fp8_hf_path, args.output_bf16_hf_path)