File size: 4,740 Bytes
d088d6c
 
58a0ecb
1b48b29
 
58a0ecb
1b48b29
 
 
d088d6c
 
 
 
 
 
 
1b48b29
 
 
58a0ecb
 
 
 
 
 
 
 
1b48b29
d088d6c
1b48b29
d088d6c
 
 
1b48b29
58a0ecb
1b48b29
 
 
58a0ecb
1b48b29
58a0ecb
1b48b29
58a0ecb
1b48b29
 
58a0ecb
d088d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cbd7e2
d088d6c
 
 
 
 
 
 
 
 
6cbd7e2
d088d6c
58a0ecb
1b48b29
 
 
58a0ecb
1b48b29
 
 
 
d088d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58a0ecb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.error
import urllib.request

import streamlit as st
from huggingface_hub import HfApi

# Hugging Face credentials: a value from Streamlit secrets wins; environment
# variables are consulted only when the previous source gave nothing truthy.
HF_TOKEN = st.secrets.get("HF_TOKEN")
if not HF_TOKEN:
    HF_TOKEN = os.environ.get("HF_TOKEN")

# Account that will own the converted repositories. SPACE_AUTHOR_NAME is the
# last fallback after the explicit HF_USERNAME settings.
HF_USERNAME = st.secrets.get("HF_USERNAME")
if not HF_USERNAME:
    HF_USERNAME = os.environ.get("HF_USERNAME")
if not HF_USERNAME:
    HF_USERNAME = os.environ.get("SPACE_AUTHOR_NAME")

TRANSFORMERS_BASE_URL = "https://github.com/xenova/transformers.js/archive/refs"
TRANSFORMERS_REPOSITORY_REVISION = "3.0.0"


def _resolve_ref_type(base_url, revision):
    """Return ``"tags"`` if *revision* is downloadable as a tag, else ``"heads"``.

    ``urllib.request.urlopen`` raises ``HTTPError`` for HTTP error statuses
    (for example 404 when the tag does not exist) instead of returning a
    response, so a plain status comparison can never reach the branch
    fallback. The error is caught here and mapped to ``"heads"``.

    Other failures (for example ``URLError`` when the host is unreachable)
    are not swallowed and propagate to the caller.
    """
    tag_archive_url = f"{base_url}/tags/{revision}.tar.gz"
    try:
        # The context manager closes the connection as soon as the status is
        # known; the archive body is not read here.
        with urllib.request.urlopen(tag_archive_url) as response:
            return "tags" if response.getcode() == 200 else "heads"
    except urllib.error.HTTPError as error:
        # HTTPError wraps the underlying response and must be closed as well.
        error.close()
        return "heads"


TRANSFORMERS_REF_TYPE = _resolve_ref_type(
    TRANSFORMERS_BASE_URL, TRANSFORMERS_REPOSITORY_REVISION
)
TRANSFORMERS_REPOSITORY_URL = f"{TRANSFORMERS_BASE_URL}/{TRANSFORMERS_REF_TYPE}/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
HF_BASE_URL = "https://huggingface.co"

# One-time bootstrap: fetch the transformers.js sources whose conversion
# script is run later. Skipped when the checkout is already present.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)

    try:
        # Extract next to the final destination rather than in the system
        # temp directory: os.rename() cannot move a directory across
        # filesystems, and the two locations may be on different mounts.
        destination_parent = os.path.dirname(
            os.path.abspath(TRANSFORMERS_REPOSITORY_PATH)
        )

        with tempfile.TemporaryDirectory(dir=destination_parent) as tmp_dir:
            with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would escape tmp_dir (absolute
                    # paths, "..", links pointing outside) on Python
                    # versions that support extraction filters.
                    tar.extractall(tmp_dir, filter="data")
                else:
                    tar.extractall(tmp_dir)

            # The archive is expected to contain a single top-level folder.
            extracted_folder = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])

            os.rename(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Remove the downloaded archive even when extraction fails, so a
        # broken download does not linger on disk.
        os.remove(ARCHIVE_PATH)

    print("Repository downloaded and extracted successfully.")

# Page heading, then the single free-text field that drives the rest of the app.
st.write("## Convert a HuggingFace model to ONNX")

input_model_id = st.text_input(
    label="Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)

def _normalize_model_id(raw_model_id):
    """Return the bare model ID for user-supplied text.

    Accepts either a model ID or a full model URL under ``HF_BASE_URL`` and
    drops surrounding whitespace, so the same cleaned value is used for the
    output name, the conversion script and the on-disk folder.
    """
    return raw_model_id.replace(f"{HF_BASE_URL}/", "").strip()


def _is_valid_model_id(model_id):
    """Return True if *model_id* looks like ``name`` or ``owner/name``.

    The ID is later embedded in a filesystem path that gets deleted, so
    anything that could change the meaning of that path (empty segments,
    ``..``, extra separators, characters outside letters, digits, ``-``,
    ``_`` and ``.``) is rejected up front.
    """
    segments = model_id.split("/")
    if len(segments) > 2:
        return False
    return all(
        segment
        and not segment.startswith(("-", "."))
        and ".." not in segment
        and all(char.isalnum() or char in "-_." for char in segment)
        for segment in segments
    )


def _run_conversion(model_id):
    """Run the transformers.js conversion script and return the completed process.

    The script's stdout and stderr are captured and echoed to the server log.
    """
    output = subprocess.run(
        [
            sys.executable,
            "-m",
            "scripts.convert",
            "--quantize",
            "--model_id",
            model_id,
        ],
        cwd=TRANSFORMERS_REPOSITORY_PATH,
        capture_output=True,
        text=True,
        # NOTE(review): the empty environment presumably keeps HF_TOKEN and
        # other secrets away from the conversion script - confirm before changing.
        env={},
    )

    # Log the script output
    print("### Script Output ###")
    print(output.stdout)

    # Log any errors
    if output.stderr:
        print("### Script Errors ###")
        print(output.stderr)

    return output


def _rename_decoder_files(model_folder_path):
    """Rename the exported ONNX files to the ``decoder_model_merged`` names.

    Files that the conversion did not produce are skipped instead of raising.
    """
    renames = (
        ("model.onnx", "decoder_model_merged.onnx"),
        ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
    )
    for source_name, target_name in renames:
        source_path = f"{model_folder_path}/onnx/{source_name}"
        if os.path.exists(source_path):
            os.rename(source_path, f"{model_folder_path}/onnx/{target_name}")


def _upload_model(api, output_model_id, model_folder_path):
    """Create the target repository and upload the converted model folder.

    Returns None on success or the error message on failure. The local model
    folder is removed in both cases.
    """
    try:
        repository = api.create_repo(output_model_id, exist_ok=True, private=False)
        api.upload_folder(folder_path=model_folder_path, repo_id=repository.repo_id)
    except Exception as error:  # UI boundary: report the failure, do not crash the page.
        return str(error)
    finally:
        # shutil.rmtree instead of a shell "rm -rf": the path contains user
        # input and must never be interpreted by a shell.
        shutil.rmtree(model_folder_path, ignore_errors=True)
    return None


def _convert_and_upload(api, model_id, output_model_id, output_model_url):
    """Convert *model_id* to ONNX, upload it and report progress in the UI."""
    model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{model_id}"

    with st.spinner("Converting model..."):
        output = _run_conversion(model_id)

    if output.returncode != 0:
        # The script failed: discard partial output and show why, instead of
        # crashing later on files that were never written.
        shutil.rmtree(model_folder_path, ignore_errors=True)
        st.error("Conversion failed!")
        st.code(output.stderr)
        return

    _rename_decoder_files(model_folder_path)

    st.success("Conversion successful!")

    st.code(output.stderr)

    with st.spinner("Uploading model..."):
        upload_error_message = _upload_model(api, output_model_id, model_folder_path)

    if upload_error_message:
        st.error(f"Upload failed: {upload_error_message}")
    else:
        st.success("Upload successful!")
        st.write("You can now go and view the model on HuggingFace!")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")


# Cleaned model ID; empty when nothing (or only whitespace) was entered.
model_id = _normalize_model_id(input_model_id) if input_model_id else ""

if model_id and not _is_valid_model_id(model_id):
    st.error("Invalid model ID. Expected `name` or `owner/name`.")
elif model_id:
    # Flatten "owner/name" into one repository name and drop the user's own
    # name so their models are not double-prefixed.
    model_name = model_id.replace("/", "-").replace(f"{HF_USERNAME}-", "")
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")

        if st.button(label="Proceed", type="primary"):
            _convert_and_upload(api, model_id, output_model_id, output_model_url)