Spaces:
Sleeping
Sleeping
File size: 919 Bytes
aa5d6d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import io
import torch
def get_target_dtype_ref(target_dtype: "str | torch.dtype") -> torch.dtype:
    """Resolve a dtype name (or an existing ``torch.dtype``) to a ``torch.dtype``.

    Args:
        target_dtype: Either a ``torch.dtype`` (returned unchanged) or one of
            the strings ``"float16"``, ``"float32"`` or ``"bfloat16"``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``target_dtype`` is neither a ``torch.dtype`` nor one
            of the supported names.
    """
    if isinstance(target_dtype, torch.dtype):
        return target_dtype
    dtype_by_name = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    # Only strings are looked up: an unhashable argument (e.g. a list) must
    # still end in the ValueError below rather than a TypeError from the dict.
    resolved = (
        dtype_by_name.get(target_dtype) if isinstance(target_dtype, str) else None
    )
    if resolved is None:
        raise ValueError(f"Invalid target_dtype: {target_dtype}")
    return resolved
def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
    """Load a PyTorch checkpoint from an in-memory upload and cast its tensors.

    Args:
        ckpt_upload: A ``BytesIO`` holding a serialized checkpoint. Only its
            contents are read (via ``getvalue()``); its stream position is
            left untouched.
        target_dtype: A ``torch.dtype`` or a dtype name understood by
            ``get_target_dtype_ref`` (``"float16"``, ``"float32"``,
            ``"bfloat16"``).

    Returns:
        A dict mapping each tensor's key to the tensor. Floating-point tensors
        are cast to ``target_dtype``; other tensors keep their original dtype.
        Entries whose value is not a tensor are omitted.

    Raises:
        ValueError: If ``target_dtype`` is not recognised.
    """
    target_dtype = get_target_dtype_ref(target_dtype)

    # torch.load requires a path or a seekable file-like object; the raw bytes
    # returned by getvalue() are neither. A fresh BytesIO also avoids moving
    # the caller's stream position.
    ckpt_buffer = io.BytesIO(ckpt_upload.getvalue())

    # SECURITY(review): torch.load unpickles its input, which can execute
    # arbitrary code embedded in the file. If uploads come from untrusted
    # users, consider torch.load(..., weights_only=True) — TODO confirm the
    # minimum supported torch version accepts that argument.
    checkpoint = torch.load(ckpt_buffer, map_location="cpu")

    # Some checkpoints nest the weights under a "state_dict" entry next to
    # other (non-tensor) data; unwrap it when it is present and is a dict.
    nested = checkpoint.get("state_dict")
    if isinstance(nested, dict):
        checkpoint = nested

    tensor_dict = {}
    for key, value in checkpoint.items():
        if not isinstance(value, torch.Tensor):
            # Only tensors can be stored in the result; skip anything else.
            continue
        # Casting integer/bool tensors to a float dtype would change their
        # values' meaning (and can lose precision), so only floating-point
        # tensors are converted.
        if value.is_floating_point():
            value = value.to(dtype=target_dtype)
        tensor_dict[key] = value
    return tensor_dict
if __name__ == "__main__":
    # This file is intended to be imported as a library module; executing it
    # directly does nothing except print a notice.
    notice = '__main__ not allowed in modules'
    print(notice)
|