Spaces:
Sleeping
Sleeping
import io | |
import json | |
import safetensors | |
import torch | |
from safetensors.torch import serialize | |
from .torch_tools import get_target_dtype_ref | |
def read_safetensors_metadata(lora_upload: io.BytesIO) -> dict:
    """Peek at the ``__metadata__`` entry of an in-memory safetensors file.

    The buffer is expected to start with an 8-byte little-endian integer
    giving the size of a JSON header, followed directly by that header.
    Only the header is decoded; the rest of the buffer is not read.

    Args:
        lora_upload: Seekable in-memory buffer holding the file bytes.

    Returns:
        The header's ``__metadata__`` value, or an empty dict when the
        header has no such key.

    Raises:
        UnicodeDecodeError: If the header bytes are not valid UTF-8.
        json.JSONDecodeError: If the header text is not valid JSON
            (for example when the buffer is empty).
    """
    lora_upload.seek(0)
    try:
        header_size = int.from_bytes(lora_upload.read(8), byteorder='little')
        # After the 8-byte length read the cursor already sits on the first
        # header byte, so no extra seek is needed before reading the header.
        header_text = lora_upload.read(header_size).decode("utf-8").strip()
        header = json.loads(header_text)
    finally:
        # This is only a peek: rewind even when decoding or parsing fails so
        # the caller can still consume the buffer from the beginning.
        lora_upload.seek(0)
    return header.get('__metadata__', {})
def rescale_lora_alpha(lora_upload: io.BytesIO, output_dtype, target_weight: float = 1.0) -> dict:
    """Scale every ``.alpha`` tensor in a safetensors upload by ``target_weight``.

    Args:
        lora_upload: In-memory buffer holding the safetensors file bytes.
        output_dtype: Dtype selector, resolved through
            ``get_target_dtype_ref`` (imported from ``.torch_tools``;
            presumably yields a torch dtype - confirm against that helper).
        target_weight: Factor applied to tensors whose key ends in
            ``.alpha``. All other tensors are left unscaled.

    Returns:
        A dict mapping each original key to its tensor, converted to the
        resolved output dtype.
    """
    target_dtype = get_target_dtype_ref(output_dtype)
    rescaled = {}
    for name, stored in safetensors.torch.load(lora_upload.getvalue()).items():
        # Every tensor passes through float32 first, so the scaling below is
        # done in float32 regardless of the stored dtype.
        working = stored.to(dtype=torch.float32)
        if name.endswith(".alpha"):
            working *= target_weight
        rescaled[name] = working.to(dtype=target_dtype)
    return rescaled
def remove_clip_weights(lora_upload: io.BytesIO, output_dtype) -> dict:
    """Return the tensors of a safetensors upload minus the ``lora_te1``/``lora_te2`` ones.

    Args:
        lora_upload: In-memory buffer holding the safetensors file bytes.
        output_dtype: Dtype selector, resolved through
            ``get_target_dtype_ref`` (imported from ``.torch_tools``;
            presumably yields a torch dtype - confirm against that helper).

    Returns:
        A dict with every tensor whose key does not start with ``lora_te1``
        or ``lora_te2`` (going by the function name, these are the CLIP
        text-encoder weights), converted to the resolved output dtype.
    """
    target_dtype = get_target_dtype_ref(output_dtype)
    clip_prefixes = ("lora_te1", "lora_te2")
    return {
        # The float32 hop is kept so the values match what a
        # cast-to-float32-then-convert pipeline produces.
        name: tensor.to(dtype=torch.float32).to(dtype=target_dtype)
        for name, tensor in safetensors.torch.load(lora_upload.getvalue()).items()
        # Filter before casting so dropped tensors are never converted.
        if not name.startswith(clip_prefixes)
    }
# This file is meant to be imported; when executed directly it only prints
# a notice and does nothing else.
if __name__ == "__main__":
    print('__main__ not allowed in modules')