|
import os
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.error
import urllib.request

import streamlit as st
from huggingface_hub import HfApi
|
|
|
def _read_secret(name: str):
    """Return the Streamlit secret *name*, or None when it is not available.

    NOTE(review): ``st.secrets`` raises ``FileNotFoundError`` (a subclass of it
    in newer Streamlit releases) when no secrets file exists at all, instead of
    letting ``.get`` return its default. That is treated as "not set" here, so
    the environment-variable fallbacks below are still consulted.
    """
    try:
        return st.secrets.get(name)
    except FileNotFoundError:
        return None


# Access token for the Hub API; Streamlit secret first, then the environment.
HF_TOKEN = _read_secret("HF_TOKEN") or os.environ.get("HF_TOKEN")

# Namespace the converted models are uploaded to. The first non-empty source
# wins: Streamlit secret, HF_USERNAME env var, then SPACE_AUTHOR_NAME
# (presumably provided by the Hugging Face Spaces runtime — verify).
HF_USERNAME = (
    _read_secret("HF_USERNAME")
    or os.environ.get("HF_USERNAME")
    or os.environ.get("SPACE_AUTHOR_NAME")
)
|
|
|
TRANSFORMERS_BASE_URL = "https://github.com/xenova/transformers.js/archive/refs"

TRANSFORMERS_REPOSITORY_REVISION = "3.0.0"


def _resolve_ref_type(base_url: str, revision: str, timeout: float = 30.0) -> str:
    """Return ``"tags"`` if a tag archive exists for *revision*, else ``"heads"``.

    ``urlopen`` reports an HTTP 404 by raising ``HTTPError`` rather than by
    returning a response with a non-200 code, so the 404 must be caught for the
    branch ("heads") fallback to ever be taken. Other HTTP errors propagate.

    Args:
        base_url: ``.../archive/refs`` URL of the GitHub repository.
        revision: Tag or branch name to look up.
        timeout: Seconds to wait for the server before giving up.
    """
    tag_archive_url = f"{base_url}/tags/{revision}.tar.gz"
    try:
        # Only the status code matters here; the context manager closes the
        # connection without reading the archive body.
        with urllib.request.urlopen(tag_archive_url, timeout=timeout) as response:
            return "tags" if response.getcode() == 200 else "heads"
    except urllib.error.HTTPError as err:
        if err.code == 404:
            return "heads"
        raise


# NOTE(review): this probe runs at module level, i.e. on every Streamlit
# rerun; consider caching it if the extra request becomes a problem.
TRANSFORMERS_REF_TYPE = _resolve_ref_type(
    TRANSFORMERS_BASE_URL, TRANSFORMERS_REPOSITORY_REVISION
)

TRANSFORMERS_REPOSITORY_URL = (
    f"{TRANSFORMERS_BASE_URL}/{TRANSFORMERS_REF_TYPE}/"
    f"{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
)

TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"

ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"

HF_BASE_URL = "https://huggingface.co"
|
|
|
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    # One-time bootstrap: fetch the transformers.js source archive and unpack
    # it to TRANSFORMERS_REPOSITORY_PATH. Later runs find the folder and skip this.
    print(f"Downloading the repository from {TRANSFORMERS_REPOSITORY_URL}...")
    try:
        urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)

        # Unpack beside the destination rather than in the system temp dir so
        # the final os.rename never crosses filesystems (a cross-device rename
        # fails with EXDEV, e.g. when /tmp is a separate mount in a container).
        destination_parent = os.path.dirname(
            os.path.abspath(TRANSFORMERS_REPOSITORY_PATH)
        )
        with tempfile.TemporaryDirectory(dir=destination_parent) as tmp_dir:
            print(f"Extracting the archive {ARCHIVE_PATH}...")
            with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
                # The archive comes from the network: where the interpreter
                # supports it, the "data" filter rejects absolute paths, ".."
                # members and links that would escape tmp_dir.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(tmp_dir, filter="data")
                else:
                    tar.extractall(tmp_dir)

            # Assumes the archive holds a single top-level folder (GitHub's
            # "<repo>-<ref>/" layout) — TODO confirm if the source changes.
            extracted_folder = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
            os.rename(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Remove the archive even if the download or extraction failed, so no
        # partial file is left behind.
        if os.path.exists(ARCHIVE_PATH):
            os.remove(ARCHIVE_PATH)

    print("Repository downloaded and extracted successfully.")
|
|
|
# Page title followed by the one text field the rest of the app reacts to.
st.write("## Convert a HuggingFace model to ONNX")

_MODEL_ID_PROMPT = (
    "Enter the HuggingFace model ID to convert. "
    "Example: `EleutherAI/pythia-14m`"
)
input_model_id = st.text_input(_MODEL_ID_PROMPT)
|
|
|
# Model ids accepted from the user: an optional "namespace/" plus a name, each
# component starting with a letter or digit and otherwise made of letters,
# digits, "_", "." and "-". Requiring an alphanumeric first character also
# rules out "." and ".." components, which matters because the id is used to
# build a local folder path that is deleted after the upload.
_MODEL_ID_PATTERN = re.compile(
    r"[A-Za-z0-9][A-Za-z0-9_.-]*(?:/[A-Za-z0-9][A-Za-z0-9_.-]*)?"
)


def _normalize_model_id(raw_model_id: str) -> str:
    """Turn user input into a bare ``namespace/name`` id.

    Strips surrounding whitespace, a leading ``https://huggingface.co/`` (a
    pasted model URL) and any trailing slash.
    """
    model_id = raw_model_id.strip()
    url_prefix = f"{HF_BASE_URL}/"
    if model_id.startswith(url_prefix):
        model_id = model_id[len(url_prefix):]
    return model_id.rstrip("/")


def _output_model_name(model_id: str, username: str) -> str:
    """Derive the output repo name (before the ``-ONNX`` suffix) from *model_id*.

    ``namespace/name`` becomes ``namespace-name``. When the source already lives
    in *username*'s namespace, that leading ``<username>-`` is dropped so the
    output is not doubly prefixed. Only a leading occurrence is removed.
    """
    name = model_id.replace("/", "-")
    own_prefix = f"{username}-"
    if name.startswith(own_prefix):
        name = name[len(own_prefix):]
    return name


def _rename_onnx_outputs(onnx_dir: str) -> None:
    """Rename the exported ONNX files in *onnx_dir* to the decoder_model_merged names.

    ``model.onnx`` -> ``decoder_model_merged.onnx`` and ``model_quantized.onnx``
    -> ``decoder_model_merged_quantized.onnx``. A file the converter did not
    produce is skipped instead of raising ``FileNotFoundError``.
    NOTE(review): presumably these are the names transformers.js loads for
    decoder models — confirm for other architectures.
    """
    for source_name, target_name in (
        ("model.onnx", "decoder_model_merged.onnx"),
        ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
    ):
        source_path = os.path.join(onnx_dir, source_name)
        if os.path.exists(source_path):
            os.rename(source_path, os.path.join(onnx_dir, target_name))


if input_model_id:
    model_id = _normalize_model_id(input_model_id)

    # Without a username the target would silently become "None/...".
    if not HF_USERNAME:
        st.error(
            "No Hugging Face username configured. "
            "Set `HF_USERNAME` as a secret or environment variable."
        )
        st.stop()

    # The id reaches the converter's argv and a local path that is later
    # deleted, so reject anything that is not a plain repo id.
    if not _MODEL_ID_PATTERN.fullmatch(model_id):
        st.error(
            f"`{model_id}` is not a valid model ID. "
            "Example: `EleutherAI/pythia-14m`"
        )
        st.stop()

    model_name = _output_model_name(model_id, HF_USERNAME)
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")
        start_conversion = st.button(label="Proceed", type="primary")

        if start_conversion:
            # Folder the converter writes its output into; removed after the
            # upload attempt (or after a failed conversion).
            model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{model_id}"

            with st.spinner("Converting model..."):
                # sys.executable runs the converter with this app's own
                # interpreter (and its installed packages) rather than whatever
                # "python" is first on PATH. The argv list form means no shell.
                output = subprocess.run(
                    [
                        sys.executable,
                        "-m",
                        "scripts.convert",
                        "--quantize",
                        "--model_id",
                        model_id,
                    ],
                    cwd=TRANSFORMERS_REPOSITORY_PATH,
                    capture_output=True,
                    text=True,
                )

                print("### Script Output ###")
                print(output.stdout)
                if output.stderr:
                    print("### Script Errors ###")
                    print(output.stderr)

            if output.returncode != 0:
                # Nothing to rename or upload: report the converter's stderr
                # and drop whatever partial output it left behind.
                st.error(f"Conversion failed (exit code {output.returncode}).")
                st.code(output.stderr)
                shutil.rmtree(model_folder_path, ignore_errors=True)
            else:
                _rename_onnx_outputs(os.path.join(model_folder_path, "onnx"))

                st.success("Conversion successful!")
                st.code(output.stderr)

                with st.spinner("Uploading model..."):
                    upload_error_message = None
                    try:
                        repository = api.create_repo(
                            output_model_id, exist_ok=True, private=False
                        )
                        api.upload_folder(
                            folder_path=model_folder_path, repo_id=repository.repo_id
                        )
                    except Exception as e:  # UI boundary: surfaced to the user below
                        upload_error_message = str(e)
                    finally:
                        # Local copy is no longer needed whether or not the
                        # upload worked; no shell is involved in the removal.
                        shutil.rmtree(model_folder_path, ignore_errors=True)

                    if upload_error_message:
                        st.error(f"Upload failed: {upload_error_message}")
                    else:
                        st.success("Upload successful!")
                        st.write("You can now go and view the model on HuggingFace!")
                        st.link_button(
                            f"Go to {output_model_id}", output_model_url, type="primary"
                        )