import io
import json

import safetensors
import torch
from safetensors.torch import serialize

from .torch_tools import get_target_dtype_ref

# Size in bytes of the little-endian integer at the start of a safetensors
# buffer that gives the length of the JSON header which follows it.
_HEADER_LENGTH_PREFIX_SIZE = 8

# Key prefixes treated as text-encoder (CLIP) LoRA weights.
# NOTE(review): follows the kohya-style naming convention, where "lora_te_" is
# used when there is a single text encoder and "lora_te1"/"lora_te2" when there
# are two. Confirm against the LoRA formats this service accepts.
_CLIP_KEY_PREFIXES = ("lora_te_", "lora_te1", "lora_te2")


def read_safetensors_metadata(lora_upload: io.BytesIO) -> dict:
    """Return the ``__metadata__`` section of a safetensors header.

    Only the header is read; no tensors are deserialized. The buffer layout is
    an 8-byte little-endian length N, followed by N bytes of UTF-8 JSON.

    The buffer is rewound to position 0 before this function returns or
    raises, so callers can hand it straight to a full load afterwards.

    Args:
        lora_upload: In-memory safetensors file.

    Returns:
        The ``__metadata__`` dict from the header, or ``{}`` if the header has
        no such entry.

    Raises:
        ValueError: If the buffer is too short for the length prefix or for
            the declared header, or if the header is not a JSON object.
            Invalid UTF-8 or invalid JSON also raise ``ValueError``
            subclasses (``UnicodeDecodeError`` / ``json.JSONDecodeError``).
    """
    try:
        lora_upload.seek(0)
        length_prefix = lora_upload.read(_HEADER_LENGTH_PREFIX_SIZE)
        if len(length_prefix) < _HEADER_LENGTH_PREFIX_SIZE:
            raise ValueError(
                "Buffer is too short to be a safetensors file: "
                "missing the 8-byte header length prefix."
            )
        metadata_length = int.from_bytes(length_prefix, byteorder='little')

        # read() returns fewer bytes than requested on a truncated buffer, so
        # compare lengths instead of trusting the declared size.
        metadata_raw = lora_upload.read(metadata_length)
        if len(metadata_raw) < metadata_length:
            raise ValueError(
                f"Truncated safetensors header: expected {metadata_length} "
                f"bytes, got {len(metadata_raw)}."
            )

        # strip() tolerates whitespace padding after the JSON text.
        metadata_dict = json.loads(metadata_raw.decode("utf-8").strip())
    finally:
        # We were only peeking at the header; always leave the buffer rewound.
        lora_upload.seek(0)

    if not isinstance(metadata_dict, dict):
        raise ValueError("Safetensors header is not a JSON object.")
    return metadata_dict.get('__metadata__', {})


def _load_tensors(lora_upload: io.BytesIO) -> dict:
    """Deserialize every tensor in the in-memory safetensors file *lora_upload*."""
    return safetensors.torch.load(lora_upload.getvalue())


def rescale_lora_alpha(lora_upload: io.BytesIO, output_dtype, target_weight: float = 1.0) -> dict:
    """Scale every ``.alpha`` tensor by *target_weight* and recast all tensors.

    Each tensor is cast to float32, multiplied by *target_weight* if its key
    ends with ``".alpha"``, and then cast to the resolved output dtype. All
    other tensors are only recast.

    Args:
        lora_upload: In-memory safetensors file.
        output_dtype: Passed to ``get_target_dtype_ref`` to obtain the dtype
            of the returned tensors. Presumably a dtype name or a
            ``torch.dtype``; confirm against ``torch_tools``.
        target_weight: Factor applied to the ``.alpha`` tensors.

    Returns:
        Mapping of every key in the file to its converted tensor, in the
        file's key order.
    """
    output_dtype = get_target_dtype_ref(output_dtype)

    new_tensors = {}
    for key, tensor in _load_tensors(lora_upload).items():
        # Convert one tensor at a time rather than building a full float32
        # copy first. Going through float32 keeps the scaling at full
        # precision and the final cast identical to the two-step conversion.
        value = tensor.to(dtype=torch.float32)
        if key.endswith(".alpha"):
            # Use a new tensor rather than an in-place multiply: .to() returns
            # the loaded tensor itself when it is already float32.
            value = value * target_weight
        new_tensors[key] = value.to(dtype=output_dtype)
    return new_tensors


def remove_clip_weights(lora_upload: io.BytesIO, output_dtype, *, clip_prefixes=_CLIP_KEY_PREFIXES) -> dict:
    """Drop text-encoder (CLIP) LoRA tensors and recast the remaining ones.

    Args:
        lora_upload: In-memory safetensors file.
        output_dtype: Passed to ``get_target_dtype_ref`` to obtain the dtype
            of the returned tensors.
        clip_prefixes: Key prefixes that identify the tensors to drop. The
            default covers ``lora_te_``, ``lora_te1`` and ``lora_te2``.

    Returns:
        Mapping of each kept key to its tensor cast (via float32) to the
        resolved output dtype, in the file's key order.
    """
    output_dtype = get_target_dtype_ref(output_dtype)
    # str.startswith accepts only a str or a tuple, so normalize once.
    prefixes = tuple(clip_prefixes)

    return {
        # Keep the float32 intermediate so results match the original
        # conversion exactly. Dropped tensors are never converted.
        key: tensor.to(dtype=torch.float32).to(dtype=output_dtype)
        for key, tensor in _load_tensors(lora_upload).items()
        if not key.startswith(prefixes)
    }


if __name__ == '__main__':
    print('__main__ not allowed in modules')