"""Convert an FP8 safetensors checkpoint to BF16.

Reads a Hugging Face-style sharded checkpoint, dequantizes every FP8 weight
using its `_scale_inv` tensor, writes each converted tensor to its own
.safetensors file, and emits an updated model.safetensors.index.json.
"""
import os
import json
from argparse import ArgumentParser
from glob import glob

import torch
from tqdm import tqdm
from safetensors.torch import save_file, safe_open

from kernel import weight_dequant
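
# For orientation: `weight_dequant` is the checkpoint's dequantization kernel.
# The sketch below is a rough pure-PyTorch equivalent, assuming per-128x128-block
# inverse scales (DeepSeek-V3-style layout); the block size and layout are
# assumptions here. It is illustrative only and not called anywhere in this script.
def weight_dequant_ref(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Multiply an FP8 weight by its block-wise inverse scales, returning BF16."""
    m, n = x.shape
    # Broadcast each scale over its block, then trim to the weight's shape.
    s_full = s.repeat_interleave(block_size, dim=0)[:m, :]
    s_full = s_full.repeat_interleave(block_size, dim=1)[:, :n]
    return (x.to(torch.float32) * s_full.to(torch.float32)).to(torch.bfloat16)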

def main(fp8_path, bf16_path):
    """Dequantize the FP8 checkpoint at `fp8_path` and write BF16 shards to `bf16_path`."""
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]
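    # `weight_map` maps each tensor name to the shard that stores it, per the HF
    # sharded-checkpoint convention; the names below are illustrative:
    #   {"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00009.safetensors", ...}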
    
    safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
    
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)

        with safe_open(safetensor_file, framework="pt", device="cuda") as f:
            for weight_name, weight_in_file in weight_map.items():
                if weight_in_file != file_name:
                    continue
                if weight_name.endswith("_scale_inv"):
                    continue  # scale tensors are consumed alongside their weights

                weight = f.get_tensor(weight_name)
                if "weight" in weight_name and weight.element_size() == 1:  # FP8 weight
                    # The inverse scale is assumed to live in the same shard as its weight.
                    scale_inv_name = f"{weight_name}_scale_inv"
                    if scale_inv_name in f.keys():
                        scale_inv = f.get_tensor(scale_inv_name)
                        new_weight = weight_dequant(weight, scale_inv)
                        del scale_inv
                    else:
                        print(f"Warning: missing scale_inv tensor for {weight_name}, keeping it in FP8")
                        new_weight = weight
                else:
                    # Non-FP8 tensors pass through unchanged.
                    new_weight = weight

                # Write each tensor to its own file immediately to bound peak memory.
                new_safetensor_file = os.path.join(
                    bf16_path, file_name.replace(".safetensors", f".{weight_name}.safetensors")
                )
                save_file({weight_name: new_weight}, new_safetensor_file)

                # Explicitly release GPU memory before the next tensor.
                del weight, new_weight
                torch.cuda.empty_cache()
                    
    # Rewrite the model index: drop the consumed `_scale_inv` entries and point
    # each remaining weight at its per-tensor output file.
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    new_weight_map = {}
    for original_weight_name, original_file_name in weight_map.items():
        if original_weight_name.endswith("_scale_inv"):
            continue
        new_file_name = original_file_name.replace(".safetensors", f".{original_weight_name}.safetensors")
        new_weight_map[original_weight_name] = new_file_name
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": new_weight_map}, f, indent=2)
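
    # Optional sanity check (an addition, not part of the original flow): verify
    # that every shard referenced by the rewritten index was actually written.
    missing = [fn for fn in sorted(set(new_weight_map.values()))
               if not os.path.isfile(os.path.join(bf16_path, fn))]
    if missing:
        print(f"Warning: {len(missing)} shard(s) referenced by the index are missing, e.g. {missing[:3]}")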

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
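
# Example invocation (the script name and paths below are illustrative):
#   python fp8_cast_bf16.py \
#       --input-fp8-hf-path /models/model-fp8 \
#       --output-bf16-hf-path /models/model-bf16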