File size: 4,763 Bytes
d088d6c
 
 
435210d
d088d6c
435210d
d088d6c
 
 
 
 
 
 
435210d
 
 
d088d6c
435210d
d088d6c
 
 
435210d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d088d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435210d
 
 
 
 
 
 
 
 
d088d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435210d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import re
import shutil
import subprocess
import tarfile
import urllib.request

import streamlit as st
from huggingface_hub import HfApi

# --- Hugging Face account settings -----------------------------------------
# Each value is looked up in the Streamlit secrets first and then in the
# process environment; the first truthy value wins.
HF_TOKEN = st.secrets.get("HF_TOKEN") or os.getenv("HF_TOKEN")
# The username additionally falls back to the SPACE_AUTHOR_NAME variable.
HF_USERNAME = (
    st.secrets.get("HF_USERNAME")
    or os.getenv("HF_USERNAME")
    or os.getenv("SPACE_AUTHOR_NAME")
)
HF_BASE_URL = "https://huggingface.co"

# --- transformers.js checkout ----------------------------------------------
# Branch of the transformers.js repository whose sources are downloaded.
TRANSFORMERS_REPOSITORY_REVISION = "v3"
# Local directory the extracted sources end up in.
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
# GitHub tarball of that branch, and where the tarball is stored temporarily.
TRANSFORMERS_REPOSITORY_URL = (
    "https://github.com/xenova/transformers.js/archive/refs/heads/"
    f"{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
)
ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"

# One-time bootstrap: fetch the transformers.js sources unless a checkout is
# already present at TRANSFORMERS_REPOSITORY_PATH. The conversion step later
# in this script runs `scripts.convert` from inside that directory.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    try:
        # Download the .tar.gz file
        print(f"Downloading the repository from {TRANSFORMERS_REPOSITORY_URL}...")
        urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)

        # Extract the .tar.gz file
        print(f"Extracting the archive {ARCHIVE_PATH}...")
        with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The "data" filter refuses members that would be written
                # outside the destination (absolute paths, "..", escaping
                # links), so a tampered archive cannot overwrite other files.
                tar.extractall(filter="data")
            else:
                # Interpreter without extraction-filter support: keep the
                # original unfiltered behavior.
                tar.extractall()

        # Rename the extracted folder to match the expected path
        # GitHub strips the leading 'v', so we handle that here
        extracted_folder = (
            f"./transformers.js-{TRANSFORMERS_REPOSITORY_REVISION.lstrip('v')}"
        )
        os.rename(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Remove the downloaded .tar.gz archive, also when a step above
        # failed, so a broken or partial download is never left behind.
        if os.path.exists(ARCHIVE_PATH):
            os.remove(ARCHIVE_PATH)

    print("Repository downloaded and extracted successfully.")
    
# ---------------------------------------------------------------------------
# Streamlit page: ask for a model ID, convert it to ONNX with the
# transformers.js conversion script, and upload the result to
# "<HF_USERNAME>/<model name>-ONNX".
# ---------------------------------------------------------------------------
st.write("## Convert a HuggingFace model to ONNX")

input_model_id = st.text_input(
    "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)

if input_model_id:
    # Accept either a bare repo ID or a pasted model URL. The normalized ID
    # is what the conversion script and the output folder path must use.
    model_id = input_model_id.strip().replace(f"{HF_BASE_URL}/", "").strip("/")

    if not HF_USERNAME:
        # Without a username the target repository would be "None/...".
        st.error(
            "No HuggingFace username is configured. "
            "Set `HF_USERNAME` in the secrets or the environment."
        )
        st.stop()

    # The ID is typed by the user and ends up in a filesystem path that is
    # deleted afterwards, so only plain `name` / `owner/name` IDs without
    # ".." segments or a leading "-" are let through.
    is_valid_model_id = ".." not in model_id and re.fullmatch(
        r"[A-Za-z0-9_][A-Za-z0-9_.\-]*(?:/[A-Za-z0-9_][A-Za-z0-9_.\-]*)?",
        model_id,
    )
    if not is_valid_model_id:
        st.error(
            f"`{input_model_id}` is not a valid HuggingFace model ID. "
            "Expected `name` or `owner/name`."
        )
        st.stop()

    # "owner/name" becomes "owner-name"; the configured username is dropped
    # from the result so it is not repeated in the output repository name.
    model_name = model_id.replace("/", "-").replace(f"{HF_USERNAME}-", "")
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"
    api = HfApi(token=HF_TOKEN)

    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")

        if st.button(label="Proceed", type="primary"):
            # Folder the conversion script is expected to write its output to.
            model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{model_id}"

            try:
                with st.spinner("Converting model..."):
                    output = subprocess.run(
                        [
                            "python",
                            "-m",
                            "scripts.convert",
                            "--quantize",
                            "--model_id",
                            model_id,
                        ],
                        cwd=TRANSFORMERS_REPOSITORY_PATH,
                        capture_output=True,
                        text=True,
                    )

                    # Log the script output
                    print("### Script Output ###")
                    print(output.stdout)

                    # Log any errors
                    if output.stderr:
                        print("### Script Errors ###")
                        print(output.stderr)

                if output.returncode != 0:
                    # The script failed: show its stderr instead of going on
                    # to rename files that were never produced.
                    st.error(f"Conversion failed (exit code {output.returncode}).")
                    st.code(output.stderr)
                else:
                    # Publish the two exported files under the
                    # "decoder_model_merged" file names.
                    os.rename(
                        f"{model_folder_path}/onnx/model.onnx",
                        f"{model_folder_path}/onnx/decoder_model_merged.onnx",
                    )
                    os.rename(
                        f"{model_folder_path}/onnx/model_quantized.onnx",
                        f"{model_folder_path}/onnx/decoder_model_merged_quantized.onnx",
                    )

                    st.success("Conversion successful!")

                    st.code(output.stderr)

                    upload_error_message = None

                    with st.spinner("Uploading model..."):
                        try:
                            # Repository creation is inside the try so that
                            # its failures are reported like upload failures.
                            repository = api.create_repo(
                                output_model_id, exist_ok=True, private=False
                            )
                            api.upload_folder(
                                folder_path=model_folder_path,
                                repo_id=repository.repo_id,
                            )
                        except Exception as e:
                            upload_error_message = str(e)

                    if upload_error_message:
                        st.error(f"Upload failed: {upload_error_message}")
                    else:
                        st.success("Upload successful!")
                        st.write("You can now go and view the model on HuggingFace!")
                        st.link_button(
                            f"Go to {output_model_id}", output_model_url, type="primary"
                        )
            finally:
                # Always discard the local conversion output. shutil.rmtree
                # replaces the former shell `rm -rf`, so the user-supplied ID
                # is never interpreted by a shell.
                shutil.rmtree(model_folder_path, ignore_errors=True)