File size: 5,147 Bytes
5b280f8
d088d6c
435210d
d088d6c
435210d
5b280f8
 
 
 
d088d6c
 
 
 
 
 
 
435210d
5b280f8
 
 
 
d088d6c
435210d
d088d6c
 
 
435210d
 
 
 
5b280f8
 
 
 
 
 
 
 
 
 
 
 
435210d
 
 
 
 
d088d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435210d
 
 
 
 
 
 
 
 
d088d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435210d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.error
import urllib.request

import streamlit as st
from huggingface_hub import HfApi

# Hugging Face credentials: Streamlit secrets are consulted first, then the
# process environment. HF_USERNAME additionally falls back to the
# SPACE_AUTHOR_NAME environment variable.
HF_TOKEN = st.secrets.get("HF_TOKEN") or os.environ.get("HF_TOKEN")
HF_USERNAME = (
    st.secrets.get("HF_USERNAME")
    or os.environ.get("HF_USERNAME")
    or os.environ.get("SPACE_AUTHOR_NAME")
)

TRANSFORMERS_BASE_URL = "https://github.com/xenova/transformers.js/archive/refs"
TRANSFORMERS_REPOSITORY_REVISION = "3.0.0"


def _resolve_ref_type(base_url, revision):
    """Return "tags" if ``revision`` has a tag archive, otherwise "heads".

    ``urlopen`` raises ``HTTPError`` for 4xx/5xx responses instead of
    returning them, so a missing tag has to be handled as an exception;
    comparing the status code alone never reaches the "heads" fallback.
    Other failures (e.g. no network) are left to propagate.
    """
    probe_url = f"{base_url}/tags/{revision}.tar.gz"
    try:
        # The context manager closes the connection without reading the body.
        with urllib.request.urlopen(probe_url) as response:
            return "tags" if response.getcode() == 200 else "heads"
    except urllib.error.HTTPError:
        return "heads"


TRANSFORMERS_REF_TYPE = _resolve_ref_type(
    TRANSFORMERS_BASE_URL, TRANSFORMERS_REPOSITORY_REVISION
)
TRANSFORMERS_REPOSITORY_URL = f"{TRANSFORMERS_BASE_URL}/{TRANSFORMERS_REF_TYPE}/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
HF_BASE_URL = "https://huggingface.co"

# One-time bootstrap: fetch and unpack the transformers.js sources when no
# local checkout exists yet.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    try:
        # Download the .tar.gz file
        print(f"Downloading the repository from {TRANSFORMERS_REPOSITORY_URL}...")
        urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)

        # Extract into a scratch directory first so a failed extraction never
        # leaves a half-populated checkout at the final path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            print(f"Extracting the archive {ARCHIVE_PATH}...")
            with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would escape tmp_dir (absolute
                    # paths, "..", links pointing outside the destination).
                    tar.extractall(tmp_dir, filter="data")
                else:
                    # Interpreter without extraction-filter support.
                    tar.extractall(tmp_dir)

            # Get the extracted folder name (there should be only one)
            extracted_folder = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])

            # shutil.move falls back to copy + delete when the temp directory
            # lives on a different filesystem, where os.rename would fail.
            shutil.move(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Remove the downloaded archive whether or not the steps above worked.
        if os.path.exists(ARCHIVE_PATH):
            os.remove(ARCHIVE_PATH)

    print("Repository downloaded and extracted successfully.")
    
# Streamlit page: ask for a model ID, convert it to ONNX with the
# transformers.js conversion script, then upload the result to the Hub.
st.write("## Convert a HuggingFace model to ONNX")

input_model_id = st.text_input(
    "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)

# Normalise the raw text once so that the converter, the on-disk folder path
# and the output repository name all refer to the same bare model ID (the
# user may paste a full URL and/or surrounding whitespace).
source_model_id = (
    (input_model_id or "").strip().replace(f"{HF_BASE_URL}/", "").strip("/")
)

# Accept only `name` or `owner/name`. Rejecting empty, "." and ".." segments
# keeps the folder path below (which is later deleted) inside models/.
model_id_parts = source_model_id.split("/")
is_valid_model_id = len(model_id_parts) <= 2 and all(
    part not in ("", ".", "..") for part in model_id_parts
)

if source_model_id and not is_valid_model_id:
    st.error("Please enter a model ID of the form `name` or `owner/name`.")
elif source_model_id:
    model_name = source_model_id.replace("/", "-").replace(f"{HF_USERNAME}-", "")
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)
    repo_exists = api.repo_exists(output_model_id)

    if repo_exists:
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")
        start_conversion = st.button(label="Proceed", type="primary")

        if start_conversion:
            model_folder_path = (
                f"{TRANSFORMERS_REPOSITORY_PATH}/models/{source_model_id}"
            )

            try:
                with st.spinner("Converting model..."):
                    output = subprocess.run(
                        [
                            # Same interpreter (and installed packages) as
                            # the app itself.
                            sys.executable or "python",
                            "-m",
                            "scripts.convert",
                            "--quantize",
                            "--model_id",
                            source_model_id,
                        ],
                        cwd=TRANSFORMERS_REPOSITORY_PATH,
                        capture_output=True,
                        text=True,
                    )

                    # Log the script output
                    print("### Script Output ###")
                    print(output.stdout)

                    # Log any errors
                    if output.stderr:
                        print("### Script Errors ###")
                        print(output.stderr)

                if output.returncode != 0:
                    # Report the converter's own failure instead of crashing
                    # later on files it never produced.
                    st.error(f"Conversion failed (exit code {output.returncode}).")
                    st.code(output.stderr)
                else:
                    # Rename the exported files to the decoder_model_merged*
                    # names; skip any file the converter did not produce.
                    renames = (
                        ("model.onnx", "decoder_model_merged.onnx"),
                        ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
                    )
                    for exported_name, merged_name in renames:
                        exported_path = f"{model_folder_path}/onnx/{exported_name}"
                        if os.path.exists(exported_path):
                            os.rename(
                                exported_path,
                                f"{model_folder_path}/onnx/{merged_name}",
                            )

                    st.success("Conversion successful!")

                    st.code(output.stderr)

                    upload_error_message = None

                    with st.spinner("Uploading model..."):
                        # Repository creation and upload share one handler so
                        # either failure is shown to the user.
                        try:
                            repository = api.create_repo(
                                f"{output_model_id}", exist_ok=True, private=False
                            )
                            api.upload_folder(
                                folder_path=model_folder_path,
                                repo_id=repository.repo_id,
                            )
                        except Exception as e:
                            upload_error_message = str(e)

                    if upload_error_message:
                        st.error(f"Upload failed: {upload_error_message}")
                    else:
                        st.success("Upload successful!")
                        st.write("You can now go and view the model on HuggingFace!")
                        st.link_button(
                            f"Go to {output_model_id}", output_model_url, type="primary"
                        )
            finally:
                # Delete the local artefacts without going through a shell:
                # the path contains user-supplied text.
                shutil.rmtree(model_folder_path, ignore_errors=True)