lora_tools / tools /torch_tools.py
kjerk
Add initial tools, layout, and config.
aa5d6d0
raw
history blame
919 Bytes
import io
import torch
# Supported string spellings accepted by get_target_dtype_ref.
_DTYPE_BY_NAME = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def get_target_dtype_ref(target_dtype: "str | torch.dtype") -> torch.dtype:
    """Resolve a dtype name (or an existing ``torch.dtype``) to a ``torch.dtype``.

    Args:
        target_dtype: One of ``"float16"``, ``"float32"``, ``"bfloat16"``, or a
            ``torch.dtype``, which is returned unchanged.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: if *target_dtype* is neither a ``torch.dtype`` nor one of
            the supported names.
    """
    if isinstance(target_dtype, torch.dtype):
        return target_dtype
    # Only strings are looked up in the table: an unhashable argument would
    # otherwise raise TypeError instead of the documented ValueError.
    resolved = _DTYPE_BY_NAME.get(target_dtype) if isinstance(target_dtype, str) else None
    if resolved is None:
        raise ValueError(f"Invalid target_dtype: {target_dtype}")
    return resolved
def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
    """Load a pickled PyTorch checkpoint and return its tensors for safetensors.

    Args:
        ckpt_upload: In-memory checkpoint file. Only ``getvalue()`` is used, so
            the caller's stream position is left untouched.
        target_dtype: ``"float16"``, ``"float32"``, ``"bfloat16"`` or a
            ``torch.dtype`` (resolved via ``get_target_dtype_ref``).

    Returns:
        A flat ``{name: tensor}`` dict. Floating-point tensors are cast to
        *target_dtype*; integer/bool/complex tensors keep their dtype; every
        tensor is contiguous; non-tensor entries are dropped.

    Raises:
        ValueError: if *target_dtype* is not recognized.
        TypeError: if the checkpoint does not deserialize to a dict.
    """
    target_dtype = get_target_dtype_ref(target_dtype)

    # torch.load needs a path or a seekable file-like object (raw bytes are
    # rejected), so wrap the payload in a fresh buffer of our own.
    # SECURITY(review): torch.load unpickles arbitrary objects; if ckpt_upload
    # comes from an untrusted user this can execute code. Consider
    # weights_only=True once the supported torch version allows it.
    checkpoint = torch.load(io.BytesIO(ckpt_upload.getvalue()), map_location="cpu")

    # Many training checkpoints nest the weights under "state_dict" next to
    # non-tensor metadata; unwrap that layer when present.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
        checkpoint = checkpoint["state_dict"]
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected checkpoint to be a dict of tensors, got {type(checkpoint).__name__}"
        )

    tensor_dict = {}
    for key, val in checkpoint.items():
        if not isinstance(val, torch.Tensor):
            continue  # safetensors can only store tensors
        # Casting index/count tensors (e.g. int64 ids) to a float dtype would
        # change their meaning, so only floating-point tensors are converted.
        if val.is_floating_point():
            val = val.to(dtype=target_dtype)
        # safetensors refuses non-contiguous tensors; no-op if already contiguous.
        # NOTE(review): save_file also rejects tensors that share storage; if a
        # caller hits that, clone here.
        tensor_dict[key] = val.contiguous()
    return tensor_dict
if __name__ == "__main__":
    # This file only provides helper functions; running it directly just
    # prints a notice and does nothing else.
    _notice = "__main__ not allowed in modules"
    print(_notice)