File size: 4,740 Bytes
d088d6c 58a0ecb 1b48b29 58a0ecb 1b48b29 d088d6c 1b48b29 58a0ecb 1b48b29 d088d6c 1b48b29 d088d6c 1b48b29 58a0ecb 1b48b29 58a0ecb 1b48b29 58a0ecb 1b48b29 58a0ecb 1b48b29 58a0ecb d088d6c 6cbd7e2 d088d6c 6cbd7e2 d088d6c 58a0ecb 1b48b29 58a0ecb 1b48b29 d088d6c 58a0ecb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import urllib.error
import urllib.request

import streamlit as st
from huggingface_hub import HfApi
# Credentials: Streamlit secrets take precedence over environment variables.
HF_TOKEN = st.secrets.get("HF_TOKEN") or os.environ.get("HF_TOKEN")
# Namespace the converted models are uploaded to. SPACE_AUTHOR_NAME is the last
# fallback (presumably set by the Hugging Face Spaces runtime — confirm).
HF_USERNAME = (
    st.secrets.get("HF_USERNAME")
    or os.environ.get("HF_USERNAME")
    or os.environ.get("SPACE_AUTHOR_NAME")
)

TRANSFORMERS_BASE_URL = "https://github.com/xenova/transformers.js/archive/refs"
TRANSFORMERS_REPOSITORY_REVISION = "3.0.0"


def _resolve_ref_type(base_url, revision):
    """Return ``"tags"`` if *revision* is available as a tag archive, else ``"heads"``.

    ``urllib.request.urlopen`` raises ``HTTPError`` for error responses rather
    than returning them, so a missing tag (404) is detected through the
    exception and mapped to the branch ("heads") archive. Any other HTTP or
    network error is re-raised.

    Args:
        base_url: Archive base URL ending in ``/archive/refs``.
        revision: Tag or branch name to look up.

    Returns:
        ``"tags"`` or ``"heads"``, the ref type to use in the archive URL.
    """
    tag_archive_url = f"{base_url}/tags/{revision}.tar.gz"
    try:
        # ``with`` closes the connection; only the status is needed, not the body.
        with urllib.request.urlopen(tag_archive_url) as response:
            return "tags" if response.getcode() == 200 else "heads"
    except urllib.error.HTTPError as err:
        if err.code == 404:
            return "heads"
        raise


TRANSFORMERS_REF_TYPE = _resolve_ref_type(
    TRANSFORMERS_BASE_URL, TRANSFORMERS_REPOSITORY_REVISION
)
TRANSFORMERS_REPOSITORY_URL = f"{TRANSFORMERS_BASE_URL}/{TRANSFORMERS_REF_TYPE}/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
# Local checkout of transformers.js; its ``scripts.convert`` module does the conversion.
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
HF_BASE_URL = "https://huggingface.co"
# Download and unpack the transformers.js sources only when the checkout is
# missing; once the folder exists this whole step is skipped.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    destination_parent = os.path.dirname(os.path.abspath(TRANSFORMERS_REPOSITORY_PATH))
    try:
        urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)
        # Extract next to the destination (same filesystem) so the final
        # os.rename cannot fail with a cross-device error, and is atomic: an
        # interrupted extraction never leaves a half-populated checkout behind
        # that would make the existence check above skip the download forever.
        with tempfile.TemporaryDirectory(dir=destination_parent) as tmp_dir:
            with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
                # The "data" filter (Python versions that provide it) rejects
                # absolute paths and links escaping the target directory.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(tmp_dir, filter="data")
                else:
                    tar.extractall(tmp_dir)
            # Assumes the archive has exactly one top-level directory.
            extracted_folder = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
            os.rename(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Remove the (possibly partial) archive whether or not extraction worked.
        if os.path.exists(ARCHIVE_PATH):
            os.remove(ARCHIVE_PATH)
    print("Repository downloaded and extracted successfully.")
# A model ID is "name" or "namespace/name" made of ASCII letters, digits, "_",
# "." and "-", with each part starting with a letter, digit or "_". Validating
# it matters because the ID becomes part of a filesystem path that is later
# deleted recursively.
_MODEL_ID_PATTERN = re.compile(
    r"[A-Za-z0-9_][A-Za-z0-9_.-]*(?:/[A-Za-z0-9_][A-Za-z0-9_.-]*)?"
)


def _normalize_model_id(raw_model_id):
    """Return the model ID with whitespace and a leading hub URL removed.

    Accepts either ``org/model`` or ``https://huggingface.co/org/model``.
    """
    model_id = raw_model_id.strip()
    url_prefix = f"{HF_BASE_URL}/"
    if model_id.startswith(url_prefix):
        model_id = model_id[len(url_prefix):].strip("/")
    return model_id


def _is_valid_model_id(model_id):
    """Return True if *model_id* is safe to use as a hub ID and path component."""
    return bool(_MODEL_ID_PATTERN.fullmatch(model_id)) and ".." not in model_id


def _output_model_name(model_id, username):
    """Derive the output repository name (without the ``-ONNX`` suffix).

    ``org/model`` becomes ``org-model``. A leading ``<username>-`` is removed
    (prefix only) so converting one of the user's own models does not repeat
    the namespace.
    """
    name = model_id.replace("/", "-")
    own_prefix = f"{username}-"
    if name.startswith(own_prefix):
        name = name[len(own_prefix):]
    return name.strip()


def _rename_onnx_outputs(model_folder_path):
    """Rename ``model*.onnx`` outputs to the ``decoder_model_merged*.onnx`` names.

    Presumably done so transformers.js loads them under its decoder-only file
    names — confirm. Files the converter did not produce are skipped instead of
    raising.
    """
    onnx_dir = os.path.join(model_folder_path, "onnx")
    renames = (
        ("model.onnx", "decoder_model_merged.onnx"),
        ("model_quantized.onnx", "decoder_model_merged_quantized.onnx"),
    )
    for source_name, target_name in renames:
        source_path = os.path.join(onnx_dir, source_name)
        if os.path.exists(source_path):
            os.rename(source_path, os.path.join(onnx_dir, target_name))


def _convert_and_upload(api, model_id, output_model_id, output_model_url):
    """Run the transformers.js converter on *model_id* and upload the result.

    Failures are reported in the UI. The local output folder is always deleted
    afterwards.

    Args:
        api: Authenticated ``HfApi`` client.
        model_id: Validated source model ID.
        output_model_id: Target ``namespace/name`` repository ID.
        output_model_url: Browser URL of the target repository.
    """
    # Where scripts.convert writes its output, relative to the checkout.
    model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{model_id}"

    with st.spinner("Converting model..."):
        output = subprocess.run(
            [
                sys.executable,
                "-m",
                "scripts.convert",
                "--quantize",
                "--model_id",
                model_id,
            ],
            cwd=TRANSFORMERS_REPOSITORY_PATH,
            capture_output=True,
            text=True,
            # The converter gets an empty environment, so it does not inherit
            # HF_TOKEN or anything else from this process.
            # NOTE(review): this also prevents downloading gated/private
            # models — confirm this is intended.
            env={},
        )
        # Log the script output
        print("### Script Output ###")
        print(output.stdout)
        # Log any errors
        if output.stderr:
            print("### Script Errors ###")
            print(output.stderr)

    if output.returncode != 0:
        st.error(f"Conversion failed (exit code {output.returncode}).")
        st.code(output.stderr)
        # Safe: model_id was validated, so this path stays under models/.
        shutil.rmtree(model_folder_path, ignore_errors=True)
        return

    _rename_onnx_outputs(model_folder_path)
    st.success("Conversion successful!")
    # stderr may still hold warnings worth showing after a successful run.
    st.code(output.stderr)

    upload_error_message = None
    with st.spinner("Uploading model..."):
        try:
            repository = api.create_repo(
                output_model_id, exist_ok=True, private=False
            )
            api.upload_folder(
                folder_path=model_folder_path, repo_id=repository.repo_id
            )
        except Exception as e:  # UI boundary: report any hub/network failure.
            # Fall back to the type name so an empty message is not read as success.
            upload_error_message = str(e) or type(e).__name__
        finally:
            # rmtree instead of shelling out: no shell parsing of user input.
            shutil.rmtree(model_folder_path, ignore_errors=True)

    if upload_error_message:
        st.error(f"Upload failed: {upload_error_message}")
    else:
        st.success("Upload successful!")
        st.write("You can now go and view the model on HuggingFace!")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")


st.write("## Convert a HuggingFace model to ONNX")
input_model_id = st.text_input(
    "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)
# Normalise once so the output name, the converter argument and the output
# folder path all use the same ID.
input_model_id = _normalize_model_id(input_model_id)

if input_model_id:
    if not _is_valid_model_id(input_model_id):
        st.error(f"`{input_model_id}` is not a valid Hugging Face model ID.")
    elif not HF_USERNAME:
        st.error(
            "No Hugging Face username is configured (HF_USERNAME), "
            "so the upload destination cannot be determined."
        )
    else:
        model_name = _output_model_name(input_model_id, HF_USERNAME)
        output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
        output_model_url = f"{HF_BASE_URL}/{output_model_id}"
        api = HfApi(token=HF_TOKEN)
        if api.repo_exists(output_model_id):
            st.write("This model has already been converted! 🎉")
            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
        else:
            st.write("This model will be converted and uploaded to the following URL:")
            st.code(output_model_url, language="plaintext")
            start_conversion = st.button(label="Proceed", type="primary")
            if start_conversion:
                _convert_and_upload(
                    api, input_model_id, output_model_id, output_model_url
                )
|