Pārlūkot izejas kodu

handle missing scale_inv_name (#2)

* handle missing scale_inv_name

Fixed an issue where `weight` and `weight_scale_inv` (e.g. `model.layers.39.mlp.experts.92.gate_proj.weight` and `model.layers.39.mlp.experts.92.gate_proj.weight_scale_inv`) were not in the same SafeTensor, causing an assertion error due to scale_inv_name not being in the state_dict.

* sort filename to reduce memory costs

* Add CUDA cache clearing in memory management

Added torch.cuda.empty_cache() to free up unused memory on the GPU,
Yang Wang 1 gadu atpakaļ
vecāks
revīzija
8f1c9488b5
1 mainītis faili ar 37 papildinājumiem un 11 dzēšanām
  1. 37 11
      inference/fp8_cast_bf16.py

+ 37 - 11
inference/fp8_cast_bf16.py

@@ -16,32 +16,58 @@ def main(fp8_path, bf16_path):
     with open(model_index_file, "r") as f:
         model_index = json.load(f)
     weight_map = model_index["weight_map"]
-    fp8_weight_names = []
     
+    # Cache for loaded safetensor files
+    loaded_files = {}
+    fp8_weight_names = []
+
+    # Helper function to get tensor from the correct file
+    def get_tensor(tensor_name):
+        file_name = weight_map[tensor_name]
+        if file_name not in loaded_files:
+            file_path = os.path.join(fp8_path, file_name)
+            loaded_files[file_name] = load_file(file_path, device="cuda")
+        return loaded_files[file_name][tensor_name]
+
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
+    safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device="cuda")
+        loaded_files[file_name] = current_state_dict
+        
         new_state_dict = {}
-        for weight_name, weight in state_dict.items():
+        for weight_name, weight in current_state_dict.items():
             if weight_name.endswith("_scale_inv"):
                 continue
-            elif weight.element_size() == 1:
+            elif weight.element_size() == 1:  # FP8 weight
                 scale_inv_name = f"{weight_name}_scale_inv"
-                assert scale_inv_name in state_dict
-                fp8_weight_names.append(weight_name)
-                scale_inv = state_dict[scale_inv_name]
-                new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
+                try:
+                    # Get scale_inv from the correct file
+                    scale_inv = get_tensor(scale_inv_name)
+                    fp8_weight_names.append(weight_name)
+                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
+                except KeyError:
+                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    new_state_dict[weight_name] = weight
             else:
                 new_state_dict[weight_name] = weight
+                
         new_safetensor_file = os.path.join(bf16_path, file_name)
         save_file(new_state_dict, new_safetensor_file)
+        
+        # Memory management: keep only the 2 most recently used files
+        if len(loaded_files) > 2:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+            torch.cuda.empty_cache()
     
+    # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
         scale_inv_name = f"{weight_name}_scale_inv"
-        assert scale_inv_name in weight_map
-        weight_map.pop(scale_inv_name)
+        if scale_inv_name in weight_map:
+            weight_map.pop(scale_inv_name)
     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
         
@@ -52,4 +78,4 @@ if __name__ == "__main__":
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
     main(args.input_fp8_hf_path, args.output_bf16_hf_path)
-    
+