File size: 2,061 Bytes
aa5d6d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import io
import json

import safetensors
import torch
from safetensors.torch import serialize

from .torch_tools import get_target_dtype_ref

def read_safetensors_metadata(lora_upload: io.BytesIO) -> dict:
	"""Return the ``__metadata__`` mapping from a safetensors buffer's header.

	A safetensors file starts with an 8-byte little-endian unsigned integer
	giving the byte length of a UTF-8 JSON header, followed by that header.
	Only the header is read here; tensor data is not touched.

	Args:
		lora_upload: Seekable binary buffer holding a safetensors file.

	Returns:
		The header's ``__metadata__`` entry, or ``{}`` when the header has none.

	Raises:
		ValueError: If the buffer is too short for the length prefix or for the
			declared header, or if the header is not a UTF-8 JSON object
			(``json.JSONDecodeError`` and ``UnicodeDecodeError`` are both
			``ValueError`` subclasses).

	The buffer is always rewound to position 0 before returning or raising,
	so callers can read it again from the start.
	"""
	lora_upload.seek(0)
	try:
		length_prefix = lora_upload.read(8)
		if len(length_prefix) < 8:
			raise ValueError("safetensors buffer is too short to contain a header length")
		metadata_length = int.from_bytes(length_prefix, byteorder='little')

		# The header immediately follows the 8-byte length prefix.
		metadata_raw = lora_upload.read(metadata_length)
		if len(metadata_raw) < metadata_length:
			raise ValueError(
				"safetensors header is truncated: expected %d bytes, got %d"
				% (metadata_length, len(metadata_raw))
			)

		# Strip surrounding whitespace: the header may be space-padded.
		header = json.loads(metadata_raw.decode("utf-8").strip())
	finally:
		# We were only peeking at the header; always hand the buffer back rewound.
		lora_upload.seek(0)

	if not isinstance(header, dict):
		raise ValueError("safetensors header must be a JSON object")
	return header.get('__metadata__', {})

def rescale_lora_alpha(lora_upload: io.BytesIO, output_dtype, target_weight: float = 1.0) -> dict:
	"""Multiply every ``.alpha`` tensor of a LoRA by ``target_weight``.

	Every tensor is upcast to float32. Tensors whose key ends in ``.alpha`` are
	then multiplied by ``target_weight``, and all tensors are cast to the dtype
	resolved from ``output_dtype``.

	Args:
		lora_upload: Buffer holding a safetensors file. The whole buffer is
			read via ``getvalue()``, regardless of its current position.
		output_dtype: Dtype spec resolved by ``get_target_dtype_ref``; the
			accepted forms are defined by that helper.
		target_weight: Multiplier applied to the alpha tensors only.

	Returns:
		Dict mapping each tensor name to its (possibly scaled) converted tensor.
	"""
	output_dtype = get_target_dtype_ref(output_dtype)

	loaded_tensors = safetensors.torch.load(lora_upload.getvalue())

	# Convert in a single pass so only one float32 temporary exists at a time,
	# rather than holding a float32 copy of the whole LoRA next to the originals.
	new_tensors = {}
	for key, tensor in loaded_tensors.items():
		as_float = tensor.to(dtype=torch.float32)
		if key.endswith(".alpha"):
			# Scale in float32, before the cast to the output dtype. Out-of-place,
			# because .to() returns the loaded tensor itself when it is already float32.
			as_float = as_float * target_weight
		new_tensors[key] = as_float.to(dtype=output_dtype)

	return new_tensors

def remove_clip_weights(lora_upload: io.BytesIO, output_dtype, *, prefixes=("lora_te1", "lora_te2")) -> dict:
	"""Return a LoRA's tensors with the text-encoder entries dropped.

	Tensors whose key starts with any of ``prefixes`` are discarded. Every
	remaining tensor is cast via float32 to the dtype resolved from
	``output_dtype``.

	Args:
		lora_upload: Buffer holding a safetensors file. The whole buffer is
			read via ``getvalue()``, regardless of its current position.
		output_dtype: Dtype spec resolved by ``get_target_dtype_ref``; the
			accepted forms are defined by that helper.
		prefixes: A key prefix string, or a tuple of prefixes, marking tensors
			to drop. The default removes keys starting with ``lora_te1`` or
			``lora_te2``. These are presumably the two text encoders of an
			SDXL-style LoRA; confirm against the key naming of your files.

	Returns:
		Dict mapping each kept tensor name to its converted tensor.
	"""
	output_dtype = get_target_dtype_ref(output_dtype)

	loaded_tensors = safetensors.torch.load(lora_upload.getvalue())

	# Filter before converting, so dropped tensors never get a float32 copy.
	filtered_tensors = {}
	for key, tensor in loaded_tensors.items():
		if key.startswith(prefixes):
			continue
		filtered_tensors[key] = tensor.to(dtype=torch.float32).to(dtype=output_dtype)

	return filtered_tensors

if "__main__" == __name__:
	# This file only provides helper functions for import; running it directly
	# does nothing except print this notice.
	print('__main__ not allowed in modules')