File size: 919 Bytes
aa5d6d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import io

import torch

def get_target_dtype_ref(target_dtype: "str | torch.dtype") -> torch.dtype:
	"""Resolve a dtype name, or an existing ``torch.dtype``, to a ``torch.dtype``.

	Args:
		target_dtype: Either a ``torch.dtype`` (returned unchanged) or one of
			the names ``"float16"``, ``"float32"`` or ``"bfloat16"``.

	Returns:
		The corresponding ``torch.dtype``.

	Raises:
		ValueError: If ``target_dtype`` is not a ``torch.dtype`` and not one of
			the supported names.
	"""
	if isinstance(target_dtype, torch.dtype):
		return target_dtype

	# Name -> dtype table; the set of supported names matches the original chain.
	dtype_by_name = {
		"float16": torch.float16,
		"float32": torch.float32,
		"bfloat16": torch.bfloat16,
	}
	try:
		return dtype_by_name[target_dtype]
	except (KeyError, TypeError):
		# TypeError covers unhashable inputs (e.g. a list), which previously
		# also ended in ValueError via the failed string comparisons.
		raise ValueError(f"Invalid target_dtype: {target_dtype}") from None

def convert_ckpt_to_safetensors(ckpt_upload: io.BytesIO, target_dtype) -> dict:
	"""Load a PyTorch checkpoint from an in-memory buffer and cast its tensors.

	The returned mapping contains only tensors, so it is suitable for a
	safetensors writer.

	Args:
		ckpt_upload: Buffer holding the serialized checkpoint. Its full
			contents are read via ``getvalue()``, regardless of the current
			stream position; the buffer's cursor is not moved.
		target_dtype: A ``torch.dtype`` or a name accepted by
			``get_target_dtype_ref`` (``"float16"``, ``"float32"``, ``"bfloat16"``).

	Returns:
		A dict mapping keys to contiguous tensors. Floating-point tensors are
		cast to ``target_dtype``; other tensors keep their dtype. Non-tensor
		entries are dropped.

	Raises:
		ValueError: If ``target_dtype`` is not recognised.
	"""
	target_dtype = get_target_dtype_ref(target_dtype)

	# torch.load needs a path or a seekable file-like object; raw bytes are
	# neither, so wrap them in a fresh buffer positioned at offset 0.
	# SECURITY NOTE(review): torch.load unpickles its input. If the upload can
	# come from untrusted users, this may execute arbitrary code unless
	# weights-only loading is in effect (``weights_only=True``, the default in
	# recent torch releases). Confirm the trust model for this entry point.
	checkpoint = torch.load(io.BytesIO(ckpt_upload.getvalue()), map_location="cpu")

	# Many training frameworks store weights as {"state_dict": {...}, ...};
	# unwrap that nesting when present, otherwise treat the checkpoint itself
	# as the state dict.
	state_dict = checkpoint
	nested = checkpoint.get("state_dict") if isinstance(checkpoint, dict) else None
	if isinstance(nested, dict):
		state_dict = nested

	tensor_dict = {}
	for key, val in state_dict.items():
		# safetensors can only store tensors; skip metadata such as step
		# counters or other Python objects instead of crashing on `.to()`.
		if not isinstance(val, torch.Tensor):
			continue
		# Only cast floating-point tensors: integer/bool buffers (e.g. index
		# tensors) would be changed or made unusable by a float cast.
		if val.is_floating_point():
			val = val.to(dtype=target_dtype)
		# safetensors refuses to serialize non-contiguous tensors.
		tensor_dict[key] = val.contiguous()

	return tensor_dict

if __name__ == "__main__":
	# This file only provides helpers meant to be imported; running it
	# directly just prints a notice.
	print("__main__ not allowed in modules")