File size: 3,701 Bytes
d088d6c
be78a51
d088d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import re
import shutil
import subprocess
import sys

import streamlit as st
from huggingface_hub import HfApi

# Fixed endpoints and the pinned transformers.js checkout used for conversion.
HF_BASE_URL = "https://huggingface.co"
TRANSFORMERS_REPOSITORY_URL = "https://github.com/xenova/transformers.js"
TRANSFORMERS_REPOSITORY_REVISION = "2.16.0"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"

# Credentials: a value from Streamlit secrets wins; environment variables are
# the fallback when the secret is missing or empty.
HF_TOKEN = st.secrets.get("HF_TOKEN") or os.getenv("HF_TOKEN")

# SPACE_AUTHOR_NAME is the last resort for the account name
# (presumably provided by the hosting Space - TODO confirm).
_secret_username = st.secrets.get("HF_USERNAME")
HF_USERNAME = (
    _secret_username or os.getenv("HF_USERNAME") or os.getenv("SPACE_AUTHOR_NAME")
)

# Fetch the transformers.js sources the first time the app runs. The
# conversion step further down runs `scripts.convert` from this directory.
# List-form arguments avoid going through a shell, and check=True surfaces a
# failed clone immediately instead of letting later steps fail obscurely.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    subprocess.run(
        ["git", "clone", TRANSFORMERS_REPOSITORY_URL, TRANSFORMERS_REPOSITORY_PATH],
        check=True,
    )

# Pin the working tree to the known revision. This runs on every execution of
# the script, so an already-present clone is also moved to that revision.
subprocess.run(
    ["git", "checkout", TRANSFORMERS_REPOSITORY_REVISION],
    cwd=TRANSFORMERS_REPOSITORY_PATH,
    check=True,
)

# Page heading, then the single text field whose value drives the rest of the
# script (an empty field means nothing else is rendered).
st.write("## Convert a HuggingFace model to ONNX")

_MODEL_ID_PROMPT = (
    "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)
input_model_id = st.text_input(label=_MODEL_ID_PROMPT)

# One or two path segments made of letters, digits, "_", "." and "-", each
# starting with a letter, digit or "_". This rejects path traversal ("..")
# and values that would be parsed as command-line options ("--foo").
_MODEL_ID_PATTERN = re.compile(
    r"[A-Za-z0-9_][A-Za-z0-9._-]*(?:/[A-Za-z0-9_][A-Za-z0-9._-]*)?"
)

# Files written by the converter and the names they are uploaded under
# (presumably the names transformers.js expects for decoder-only models -
# TODO confirm).
_DECODER_FILE_RENAMES = (
    ("model.onnx", "decoder_model_merged.onnx"),
    ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
)


def _normalize_model_id(raw_model_id):
    """Return the bare model ID for whatever the user typed.

    Surrounding whitespace and the Hugging Face base URL are removed so that
    the same ID is used for the converter, the output folder and the name of
    the target repository.
    """
    return raw_model_id.strip().replace(f"{HF_BASE_URL}/", "").strip()


def _run_conversion(model_id):
    """Run the conversion script for *model_id* and return the finished process.

    The script is started with the interpreter running this app (falling back
    to ``python`` if that is unknown), inside the transformers.js checkout.
    stdout and stderr are captured as text; the caller checks ``returncode``.
    """
    return subprocess.run(
        [
            sys.executable or "python",
            "-m",
            "scripts.convert",
            "--quantize",
            "--model_id",
            model_id,
        ],
        cwd=TRANSFORMERS_REPOSITORY_PATH,
        capture_output=True,
        text=True,
    )


def _rename_decoder_files(model_folder_path):
    """Rename the generated ONNX files that exist to their upload names."""
    onnx_folder_path = os.path.join(model_folder_path, "onnx")
    for current_name, upload_name in _DECODER_FILE_RENAMES:
        current_path = os.path.join(onnx_folder_path, current_name)
        # Skip files the converter did not produce instead of crashing.
        if os.path.exists(current_path):
            os.rename(current_path, os.path.join(onnx_folder_path, upload_name))


model_id = _normalize_model_id(input_model_id) if input_model_id else ""

if model_id and not _MODEL_ID_PATTERN.fullmatch(model_id):
    # The ID ends up in a command line and in a path that is deleted
    # afterwards, so anything that is not a plain model ID is refused.
    st.error(f"`{model_id}` is not a valid HuggingFace model ID.")
elif model_id:
    # Repository name: "owner/name" becomes "owner-name", and the uploading
    # account's own prefix is dropped.
    model_name = (
        model_id.replace("/", "-").replace(f"{HF_USERNAME}-", "").strip()
    )
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")

        if st.button(label="Proceed", type="primary"):
            model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{model_id}"

            try:
                with st.spinner("Converting model..."):
                    output = _run_conversion(model_id)

                if output.returncode != 0:
                    # Report the converter's own error output rather than
                    # failing later on files that were never written.
                    st.error("Conversion failed!")
                    st.code(output.stderr)
                else:
                    _rename_decoder_files(model_folder_path)

                    st.success("Conversion successful!")
                    st.code(output.stderr)

                    upload_error_message = None

                    with st.spinner("Uploading model..."):
                        try:
                            repository = api.create_repo(
                                output_model_id, exist_ok=True, private=False
                            )
                            api.upload_folder(
                                folder_path=model_folder_path,
                                repo_id=repository.repo_id,
                            )
                        except Exception as e:
                            # Top-level UI boundary: show the failure to the
                            # user instead of crashing the app.
                            upload_error_message = str(e)

                    if upload_error_message:
                        st.error(f"Upload failed: {upload_error_message}")
                    else:
                        st.success("Upload successful!")
                        st.write("You can now go and view the model on HuggingFace!")
                        st.link_button(
                            f"Go to {output_model_id}",
                            output_model_url,
                            type="primary",
                        )
            finally:
                # Always remove the local conversion output, whether the
                # conversion or upload succeeded or not. No shell is involved,
                # so the model ID is never interpreted as a command.
                shutil.rmtree(model_folder_path, ignore_errors=True)