import io
from typing import Dict, Union

import torch


def get_target_dtype_ref(target_dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Resolve a dtype name (or a torch.dtype passed through) to a torch.dtype."""
    if isinstance(target_dtype, torch.dtype):
        return target_dtype
    if target_dtype == "float16":
        return torch.float16
    elif target_dtype == "float32":
        return torch.float32
    elif target_dtype == "bfloat16":
        return torch.bfloat16
    else:
        raise ValueError(f"Invalid target_dtype: {target_dtype}")


def convert_ckpt_to_safetensors(
    ckpt_upload: Union[io.BytesIO, bytes],
    target_dtype: Union[str, torch.dtype],
) -> Dict[str, torch.Tensor]:
    """Load a .ckpt payload and return a flat dict of tensors cast to target_dtype."""
    if isinstance(ckpt_upload, bytes):
        ckpt_upload = io.BytesIO(ckpt_upload)
    target_dtype = get_target_dtype_ref(target_dtype)

    # Load the checkpoint. Legacy .ckpt files are arbitrary pickles, so
    # weights_only=False is needed on PyTorch >= 2.6 (where the default
    # flipped to True); only load checkpoints from sources you trust.
    loaded_dict = torch.load(ckpt_upload, map_location="cpu", weights_only=False)

    tensor_dict = {}
    is_embedding = 'string_to_param' in loaded_dict
    if is_embedding:
        # Textual-inversion embeddings keep their weights under
        # string_to_param['*']; normalize to a single 'emb_params' entry.
        emb_tensor = loaded_dict.get('string_to_param', {}).get('*', None)
        if emb_tensor is not None:
            emb_tensor = emb_tensor.to(dtype=target_dtype)
            tensor_dict = {'emb_params': emb_tensor}
    else:
        # Convert the checkpoint's weights to a dictionary of tensors,
        # skipping non-tensor entries such as metadata.
        for key, val in loaded_dict.items():
            if isinstance(val, torch.Tensor):
                tensor_dict[key] = val.to(dtype=target_dtype)

    return tensor_dict


if __name__ == '__main__':
    print('__main__ not allowed in modules')
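
# Usage sketch: the returned dict is presumably meant to be written out with
# the `safetensors` library (this module only builds the tensor dict, it does
# not save a .safetensors file itself). The file paths below are hypothetical.
#
#     from safetensors.torch import save_file
#
#     with open("model.ckpt", "rb") as f:
#         tensors = convert_ckpt_to_safetensors(f.read(), "float16")
#     save_file(tensors, "model.safetensors")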