"""Helpers for turning a PyTorch checkpoint into a flat dict of tensors.

The resulting dict is intended to be handed to a safetensors writer.
"""

import io
from typing import Dict, Union

import torch

# Accepted string names for ``target_dtype`` and the torch dtype each maps to.
_DTYPES_BY_NAME = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def get_target_dtype_ref(target_dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Resolve *target_dtype* to a ``torch.dtype``.

    Args:
        target_dtype: Either a ``torch.dtype`` (returned unchanged) or one of
            the names ``"float16"``, ``"float32"`` or ``"bfloat16"``.

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: If *target_dtype* is not a ``torch.dtype`` and not one of
            the supported names.
    """
    if isinstance(target_dtype, torch.dtype):
        return target_dtype
    try:
        return _DTYPES_BY_NAME[target_dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list), which the original
        # if/elif chain also rejected with ValueError.
        raise ValueError(f"Invalid target_dtype: {target_dtype}") from None


def convert_ckpt_to_safetensors(
    ckpt_upload: io.BytesIO, target_dtype: Union[str, torch.dtype]
) -> Dict[str, torch.Tensor]:
    """Load a serialized checkpoint and return its tensors cast to *target_dtype*.

    Args:
        ckpt_upload: In-memory buffer holding data written by ``torch.save``.
            The whole buffer is read regardless of its current position, and
            the caller's stream position is not modified.
        target_dtype: Anything accepted by :func:`get_target_dtype_ref`.

    Returns:
        A new dict mapping each tensor entry's key to a contiguous tensor.
        Floating-point tensors are cast to *target_dtype*. Non-floating
        tensors (integer/bool buffers) keep their dtype, so their values
        are preserved. Non-tensor entries are dropped. If the checkpoint is
        a dict whose ``"state_dict"`` entry is itself a dict, that nested
        dict is the one converted.

    Raises:
        ValueError: If *target_dtype* is invalid.
    """
    dtype = get_target_dtype_ref(target_dtype)

    # torch.load needs a path or file-like object, not raw bytes. Wrapping a
    # copy of the buffer's contents also leaves the caller's stream untouched.
    # SECURITY: torch.load unpickles its input, which can execute arbitrary
    # code unless weights-only loading is in effect (it is the default only on
    # newer torch versions). Only feed this trusted uploads.
    checkpoint = torch.load(io.BytesIO(ckpt_upload.getvalue()), map_location="cpu")

    # Training checkpoints commonly wrap the weights alongside metadata
    # (epoch, optimizer state, ...) under a "state_dict" key.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
        checkpoint = checkpoint["state_dict"]

    tensors: Dict[str, torch.Tensor] = {}
    for key, value in checkpoint.items():
        if not isinstance(value, torch.Tensor):
            # Non-tensor entries cannot be stored as tensors; skip them.
            continue
        if value.is_floating_point():
            value = value.to(dtype=dtype)
        # Contiguous layout is required by tensor serializers such as safetensors.
        tensors[key] = value.contiguous()
    return tensors


if __name__ == '__main__':
    print('__main__ not allowed in modules')