# pth-to-st / app.py
# Bradarr's picture
# Create app.py
# d4365d2 verified
import gradio as gr
import collections
import numpy as np
import os
import torch
from safetensors.torch import serialize_file
import requests
import tempfile
def download_file(url, local_path, timeout=30):
    """Download a file from a URL to a local path.

    The body is streamed to disk in 8 KiB chunks, so the whole file is never
    held in memory at once.

    Args:
        url: Address of the file to fetch.
        local_path: Destination path. It is opened in binary write mode, so
            any existing content is overwritten.
        timeout: Seconds to wait for the server to connect / send data before
            giving up. Passed straight to ``requests.get``; ``None`` waits
            forever (the previous behaviour).

    Returns:
        ``local_path``, unchanged, for convenient chaining.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status
            (raised by ``raise_for_status``).
        requests.RequestException: On connection failures or timeouts.
    """
    # Using the response as a context manager guarantees the underlying
    # connection is released even if writing to disk fails part-way.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_path
def rename_key(rename, name):
    """Return ``name`` with each substring listed in ``rename`` substituted.

    ``rename`` maps old substrings to their replacements. The substitutions
    are applied one after another in the mapping's iteration order, so a
    later rule sees the result of the earlier ones.
    """
    renamed = name
    for old, new in rename.items():
        # ``replace`` leaves the string untouched when ``old`` is absent,
        # so no separate membership test is required.
        renamed = renamed.replace(old, new)
    return renamed
def convert_file(pt_filename: str, sf_filename: str, rename=None, transpose_names=None):
    """Convert a PyTorch checkpoint file into a half-precision safetensors file.

    Args:
        pt_filename: Path of the checkpoint to read with ``torch.load``.
        sf_filename: Path of the safetensors file to write. Its parent
            directory is created when it does not exist yet.
        rename: Mapping of key substrings to their replacements, applied to
            every tensor name via ``rename_key``. ``None`` means no renaming.
        transpose_names: Substrings of (renamed, lower-cased) tensor names
            whose last two dimensions must be swapped. ``None`` means no
            tensor is transposed.

    Returns:
        ``sf_filename``.
    """
    rename = {} if rename is None else rename
    transpose_names = [] if transpose_names is None else transpose_names

    # SECURITY NOTE: torch.load unpickles the file, and unpickling data from
    # an untrusted source can execute arbitrary code. Callers feeding
    # downloaded files into this function should consider restricting the
    # load (e.g. torch.load(..., weights_only=True)) or sandboxing it.
    loaded: collections.OrderedDict = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        # Some checkpoints wrap the weights in a {"state_dict": ...} envelope.
        loaded = loaded["state_dict"]

    # Snapshot the key list up front: the dict is modified while converting.
    kk = list(loaded.keys())

    # Guess the model version from marker substrings in the tensor names;
    # each marker can only raise the detected version, never lower it.
    version = 4
    for x in kk:
        if "ln_x" in x:
            version = max(5, version)
        if "gate.weight" in x:
            version = max(5.1, version)
        if int(version) == 5 and "att.time_decay" in x:
            # A time_decay tensor with a second dimension wider than 1
            # marks the 5.2 layout.
            if len(loaded[x].shape) > 1:
                if loaded[x].shape[1] > 1:
                    version = max(5.2, version)
        if "time_maa" in x:
            version = max(6, version)
    print(f"Model detected: v{version:.1f}")

    if version == 5.1:
        # Expand the v5.1 time_decay / time_faaaa tensors by adding a second
        # dimension of size n_emb // shape[0].
        _, n_emb = loaded["emb.weight"].shape
        for k in kk:
            if "time_decay" in k or "time_faaaa" in k:
                loaded[k] = loaded[k].unsqueeze(1).repeat(1, n_emb // loaded[k].shape[0])

    # Converted entries go into their own dict so that a renamed/lower-cased
    # key can never overwrite a tensor that has not been converted yet.
    tensors = {}
    with torch.no_grad():
        for k in kk:
            new_k = rename_key(rename, k).lower()
            # pop() drops the source tensor from ``loaded`` right away so its
            # memory can be reclaimed as soon as the fp16 copy exists.
            v = loaded.pop(k).half()
            for transpose_name in transpose_names:
                if transpose_name in new_k:
                    dims = len(v.shape)
                    v = v.transpose(dims - 2, dims - 1)
                    # Transpose at most once, even if several patterns match
                    # (e.g. both "time_mix_w1" and "w1").
                    break
            print(f"{new_k}\t{v.shape}\t{v.dtype}")
            tensors[new_k] = {
                "dtype": str(v.dtype).split(".")[-1],
                "shape": v.shape,
                # tobytes() emits C-order bytes, so a transposed view is
                # written out in its transposed layout, matching "shape".
                "data": v.numpy().tobytes(),
            }

    # dirname is "" for a bare file name, and os.makedirs("") raises, so only
    # create the directory when the path actually has one.
    out_dir = os.path.dirname(sf_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    serialize_file(tensors, sf_filename, metadata={"format": "pt"})
    return sf_filename
def process_model(url):
    """Process the model URL and return a downloadable safetensors file.

    Downloads the checkpoint at ``url`` into a temporary ``.pth`` file,
    converts it with ``convert_file`` and returns the path of the resulting
    temporary ``.safetensors`` file.

    Args:
        url: Location of the ``.pth`` checkpoint to download.

    Returns:
        Path of the converted safetensors file.

    Raises:
        gr.Error: If any step fails; the message carries the text of the
            underlying exception, which is also chained as ``__cause__``.
    """
    # Bind both names before the try block so the cleanup code below can run
    # safely even when creating a temporary file is what failed.
    pth_path = None
    sf_path = None
    try:
        # Create temporary files (delete=False: they must outlive the
        # ``with`` blocks, which are only used to obtain unique names).
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as temp_pth:
            pth_path = temp_pth.name
        with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as temp_sf:
            sf_path = temp_sf.name

        # Download the .pth file from the URL
        download_file(url, pth_path)

        # Conversion parameters
        rename = {"time_faaaa": "time_first", "time_maa": "time_mix", "lora_A": "lora.0", "lora_B": "lora.1"}
        transpose_names = [
            "time_mix_w1", "time_mix_w2", "time_decay_w1", "time_decay_w2",
            "w1", "w2", "a1", "a2", "g1", "g2", "v1", "v2",
            "time_state", "lora.0",
        ]

        # Convert the file
        return convert_file(pth_path, sf_path, rename, transpose_names)
    except Exception as e:
        # The output file is useless after a failure; remove it if it exists.
        if sf_path is not None and os.path.exists(sf_path):
            os.remove(sf_path)
        raise gr.Error(f"Error processing the model: {str(e)}") from e
    finally:
        # The downloaded .pth is never needed once conversion has finished
        # or failed, so it is removed on every path.
        if pth_path is not None and os.path.exists(pth_path):
            os.remove(pth_path)
# Gradio interface: one URL text box, a button that triggers the conversion,
# and a file component through which the converted model is offered.
with gr.Blocks(title="PTH to Safetensors Converter") as demo:
    # Page heading and usage hint.
    gr.Markdown("# PTH to Safetensors Converter")
    gr.Markdown(
        "Enter the URL to a `.pth` model file hosted on Hugging Face to convert it to `.safetensors` format."
    )

    # Input, trigger and output widgets, in display order.
    url_input = gr.Textbox(
        label="Model URL",
        placeholder="https://huggingface.co/.../model.pth",
    )
    convert_btn = gr.Button("Convert")
    output_file = gr.File(label="Download Converted Safetensors File")

    # Clicking the button runs process_model on the text box content and
    # routes its return value (a file path) into the file component.
    convert_btn.click(fn=process_model, inputs=url_input, outputs=output_file)

# Start the web server for the interface defined above.
demo.launch()