|
import gradio as gr |
|
import collections |
|
import numpy as np |
|
import os |
|
import torch |
|
from safetensors.torch import serialize_file |
|
import requests |
|
import tempfile |
|
|
|
def download_file(url, local_path, timeout=60):
    """Stream a file from *url* to *local_path* on disk.

    Args:
        url: HTTP(S) URL of the file to fetch.
        local_path: Destination file path; overwritten if it exists.
        timeout: Connect/read timeout in seconds (default 60) so a stalled
            server cannot hang the app indefinitely.

    Returns:
        local_path, for convenient chaining.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
    """
    # stream=True keeps large checkpoints out of memory; the context manager
    # guarantees the HTTP connection is released even on error.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_path
|
|
|
def rename_key(rename, name):
    """Return *name* with every substring mapping in *rename* applied.

    Each key of *rename* found inside *name* is replaced by its value;
    substitutions are applied in the dict's iteration order.
    """
    result = name
    for old, new in rename.items():
        if old in result:
            result = result.replace(old, new)
    return result
|
|
|
def convert_file(pt_filename: str, sf_filename: str, rename=None, transpose_names=None):
    """Convert a PyTorch checkpoint to a .safetensors file.

    Loads the state dict, auto-detects the model version from key names
    (printed for the user), lowercases/renames keys, casts tensors to fp16,
    transposes selected weights, and serializes the result.

    Args:
        pt_filename: Path to the source ``.pth`` checkpoint.
        sf_filename: Path where the ``.safetensors`` output is written.
        rename: Optional mapping of key substrings to their replacements
            (applied via :func:`rename_key`).
        transpose_names: Optional list of key substrings; a matching tensor
            has its last two dimensions transposed.

    Returns:
        sf_filename, the path of the written file.
    """
    # Mutable defaults are shared across calls in Python — use None sentinels.
    rename = {} if rename is None else rename
    transpose_names = [] if transpose_names is None else transpose_names

    # NOTE(security): torch.load unpickles arbitrary objects; only feed it
    # checkpoints from sources you trust.
    loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    kk = list(loaded.keys())

    # Version heuristics based on which keys are present; max() keeps the
    # highest version implied by any key.
    version = 4
    for x in kk:
        if "ln_x" in x:
            version = max(5, version)
        if "gate.weight" in x:
            version = max(5.1, version)
        if int(version) == 5 and "att.time_decay" in x:
            if len(loaded[x].shape) > 1:
                if loaded[x].shape[1] > 1:
                    version = max(5.2, version)
        if "time_maa" in x:
            version = max(6, version)

    print(f"Model detected: v{version:.1f}")

    if version == 5.1:
        _, n_emb = loaded["emb.weight"].shape
        for k in kk:
            if "time_decay" in k or "time_faaaa" in k:
                # v5.1 stores these as 1-D; broadcast them to the width the
                # v5.2-style consumers expect.
                loaded[k] = loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])

    with torch.no_grad():
        for k in kk:
            new_k = rename_key(rename, k).lower()
            v = loaded[k].half()
            # Drop the original entry so only the renamed key survives.
            del loaded[k]
            for transpose_name in transpose_names:
                if transpose_name in new_k:
                    dims = len(v.shape)
                    v = v.transpose(dims - 2, dims - 1)
                    break
            print(f"{new_k}\t{v.shape}\t{v.dtype}")
            # serialize_file consumes raw dtype/shape/bytes dicts.
            loaded[new_k] = {
                "dtype": str(v.dtype).split(".")[-1],
                "shape": v.shape,
                "data": v.numpy().tobytes(),
            }

    # os.makedirs("") raises, so only create the directory when there is one
    # (a bare filename like "model.safetensors" has an empty dirname).
    out_dir = os.path.dirname(sf_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    serialize_file(loaded, sf_filename, metadata={"format": "pt"})
    return sf_filename
|
|
|
def process_model(url):
    """Download a .pth model from *url*, convert it, and return the
    resulting .safetensors path for Gradio to offer as a download.

    Raises:
        gr.Error: wrapping any download/conversion failure, after removing
            any temporary files that were created.
    """
    # Pre-initialize so the except-block cleanup never hits a NameError,
    # even if the very first NamedTemporaryFile call raises.
    pth_path = None
    sf_path = None
    try:
        # delete=False: we only want the reserved paths; the files are
        # written to after the handles are closed.
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as temp_pth:
            pth_path = temp_pth.name
        with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as temp_sf:
            sf_path = temp_sf.name

        download_file(url, pth_path)

        # RWKV-style key renames and the weights whose last two dims must be
        # transposed for the target format.
        rename = {"time_faaaa": "time_first", "time_maa": "time_mix", "lora_A": "lora.0", "lora_B": "lora.1"}
        transpose_names = [
            "time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
            "w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
            "time_state", "lora.0"
        ]

        converted_file = convert_file(pth_path, sf_path, rename, transpose_names)

        # The source checkpoint is no longer needed; the converted file is
        # returned (and thus kept) for Gradio to serve.
        os.remove(pth_path)
        return converted_file
    except Exception as e:
        # Best-effort cleanup of whichever temp files were actually created.
        for path in (pth_path, sf_path):
            if path and os.path.exists(path):
                os.remove(path)
        # Surface the failure in the UI, preserving the original traceback.
        raise gr.Error(f"Error processing the model: {str(e)}") from e
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Single-page app: a URL textbox, a Convert button, and a file component that
# serves the converted .safetensors produced by process_model.
with gr.Blocks(title="PTH to Safetensors Converter") as demo:

    gr.Markdown("# PTH to Safetensors Converter")

    gr.Markdown("Enter the URL to a `.pth` model file hosted on Hugging Face to convert it to `.safetensors` format.")



    url_input = gr.Textbox(label="Model URL", placeholder="https://huggingface.co/.../model.pth")

    convert_btn = gr.Button("Convert")

    output_file = gr.File(label="Download Converted Safetensors File")



    # Wire the button: process_model receives the textbox value and returns a
    # local file path, which gr.File exposes as a download link.
    convert_btn.click(

        fn=process_model,

        inputs=url_input,

        outputs=output_file

    )



# Launches the web server; blocks until the app is stopped.
demo.launch()