File size: 5,147 Bytes
5b280f8 d088d6c 435210d d088d6c 435210d 5b280f8 d088d6c 435210d 5b280f8 d088d6c 435210d d088d6c 435210d 5b280f8 435210d d088d6c 435210d d088d6c 435210d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.error
import urllib.request

import streamlit as st
from huggingface_hub import HfApi
# Credentials: a Streamlit secret takes precedence over the environment.
HF_TOKEN = st.secrets.get("HF_TOKEN") or os.environ.get("HF_TOKEN")
HF_USERNAME = (
    st.secrets.get("HF_USERNAME")
    or os.environ.get("HF_USERNAME")
    or os.environ.get("SPACE_AUTHOR_NAME")
)

TRANSFORMERS_BASE_URL = "https://github.com/xenova/transformers.js/archive/refs"
TRANSFORMERS_REPOSITORY_REVISION = "3.0.0"


def _resolve_ref_type(base_url, revision):
    """Return ``"tags"`` if *revision* is downloadable as a tag, else ``"heads"``.

    ``urllib.request.urlopen`` raises ``HTTPError`` for error statuses such
    as 404 instead of returning a response, so the fallback to ``"heads"``
    has to live in the ``except`` branch; comparing ``getcode()`` alone never
    reaches it. Non-HTTP failures (e.g. no network) still propagate.
    """
    probe_url = f"{base_url}/tags/{revision}.tar.gz"
    try:
        # The context manager closes the connection once the status is known.
        with urllib.request.urlopen(probe_url) as response:
            return "tags" if response.getcode() == 200 else "heads"
    except urllib.error.HTTPError:
        return "heads"


TRANSFORMERS_REF_TYPE = _resolve_ref_type(
    TRANSFORMERS_BASE_URL, TRANSFORMERS_REPOSITORY_REVISION
)
TRANSFORMERS_REPOSITORY_URL = f"{TRANSFORMERS_BASE_URL}/{TRANSFORMERS_REF_TYPE}/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
HF_BASE_URL = "https://huggingface.co"
# One-time bootstrap: fetch and unpack the transformers.js sources unless a
# checkout is already present at TRANSFORMERS_REPOSITORY_PATH.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    try:
        # Download the .tar.gz file
        print(f"Downloading the repository from {TRANSFORMERS_REPOSITORY_URL}...")
        urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)

        # Extract into a temporary directory first so a failed extraction
        # never leaves a half-populated repository folder behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            print(f"Extracting the archive {ARCHIVE_PATH}...")
            with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Extraction filters are available: refuse members that
                    # would escape the destination or create special files.
                    tar.extractall(tmp_dir, filter="data")
                else:
                    tar.extractall(tmp_dir)

            # Get the extracted folder name (there should be only one)
            extracted_folder = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])

            # shutil.move falls back to copy + delete when the temporary
            # directory is on a different filesystem, where os.rename fails.
            shutil.move(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Remove the downloaded archive whether or not the steps above
        # succeeded, so a failed run does not leak it.
        if os.path.exists(ARCHIVE_PATH):
            os.remove(ARCHIVE_PATH)

    print("Repository downloaded and extracted successfully.")
# Streamlit page: convert a Hub model to ONNX with the transformers.js
# conversion script, then upload the result to a "<username>/<name>-ONNX" repo.
st.write("## Convert a HuggingFace model to ONNX")

input_model_id = st.text_input(
    "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)

# Accept a bare "owner/name" ID or a full Hub URL, with or without stray
# whitespace. The same cleaned ID is then used for the output name, the
# conversion script argument and the on-disk output path.
source_model_id = (input_model_id or "").strip().replace(f"{HF_BASE_URL}/", "").strip()

if source_model_id:
    model_name = source_model_id.replace("/", "-").replace(f"{HF_USERNAME}-", "")
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")

        if st.button(label="Proceed", type="primary"):
            model_folder_path = (
                f"{TRANSFORMERS_REPOSITORY_PATH}/models/{source_model_id}"
            )
            conversion_error = None

            with st.spinner("Converting model..."):
                output = subprocess.run(
                    [
                        # Run the script with the interpreter running this app.
                        sys.executable or "python",
                        "-m",
                        "scripts.convert",
                        "--quantize",
                        "--model_id",
                        source_model_id,
                    ],
                    cwd=TRANSFORMERS_REPOSITORY_PATH,
                    capture_output=True,
                    text=True,
                )

                # Log the script output
                print("### Script Output ###")
                print(output.stdout)

                # Log any errors
                if output.stderr:
                    print("### Script Errors ###")
                    print(output.stderr)

                if output.returncode != 0:
                    # Without this check a failed conversion only surfaced
                    # later as a missing-file error from os.rename.
                    conversion_error = (
                        f"the conversion script exited with code {output.returncode}"
                    )
                else:
                    try:
                        os.rename(
                            f"{model_folder_path}/onnx/model.onnx",
                            f"{model_folder_path}/onnx/decoder_model_merged.onnx",
                        )
                        os.rename(
                            f"{model_folder_path}/onnx/model_quantized.onnx",
                            f"{model_folder_path}/onnx/decoder_model_merged_quantized.onnx",
                        )
                    except OSError as e:
                        conversion_error = f"expected ONNX output not found ({e})"

            if conversion_error:
                st.error(f"Conversion failed: {conversion_error}")
                st.code(output.stderr)
            else:
                st.success("Conversion successful!")
                st.code(output.stderr)

                with st.spinner("Uploading model..."):
                    upload_error_message = None
                    try:
                        repository = api.create_repo(
                            output_model_id, exist_ok=True, private=False
                        )
                        api.upload_folder(
                            folder_path=model_folder_path, repo_id=repository.repo_id
                        )
                    except Exception as e:
                        # Top-level UI boundary: report the failure to the
                        # user instead of showing a traceback.
                        upload_error_message = str(e)
                    finally:
                        # Delete the local artefacts without a shell, so the
                        # user-derived path is never interpreted as a command.
                        shutil.rmtree(model_folder_path, ignore_errors=True)

                if upload_error_message:
                    st.error(f"Upload failed: {upload_error_message}")
                else:
                    st.success("Upload successful!")
                    st.write("You can now go and view the model on HuggingFace!")
                    st.link_button(
                        f"Go to {output_model_id}", output_model_url, type="primary"
                    )