File size: 4,763 Bytes
d088d6c 435210d d088d6c 435210d d088d6c 435210d d088d6c 435210d d088d6c 435210d d088d6c 435210d d088d6c 435210d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import shutil
import subprocess
import sys
import tarfile
import urllib.request

import streamlit as st
from huggingface_hub import HfApi
def _secret(name):
    """Return the Streamlit secret *name*, or None if it cannot be read.

    Reading ``st.secrets`` raises (rather than returning the ``.get`` default)
    when no secrets file is present. Returning None instead lets the
    environment-variable fallbacks below take effect in that case.
    """
    try:
        return st.secrets.get(name)
    except Exception:  # NOTE(review): exception type for "no secrets file" differs across Streamlit versions
        return None


# Hub credentials: a Streamlit secret wins, then environment variables
# (SPACE_AUTHOR_NAME is the final fallback for the username).
HF_TOKEN = _secret("HF_TOKEN") or os.environ.get("HF_TOKEN")
HF_USERNAME = (
    _secret("HF_USERNAME")
    or os.environ.get("HF_USERNAME")
    or os.environ.get("SPACE_AUTHOR_NAME")
)

# transformers.js branch whose conversion script (scripts.convert) is used.
TRANSFORMERS_REPOSITORY_REVISION = "v3"
TRANSFORMERS_REPOSITORY_URL = f"https://github.com/xenova/transformers.js/archive/refs/heads/{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
TRANSFORMERS_REPOSITORY_PATH = "./transformers.js"
ARCHIVE_PATH = f"./transformers_{TRANSFORMERS_REPOSITORY_REVISION}.tar.gz"
HF_BASE_URL = "https://huggingface.co"
# Fetch and unpack the transformers.js sources only when the checkout is
# missing, so later runs of this script reuse the existing directory.
if not os.path.exists(TRANSFORMERS_REPOSITORY_PATH):
    print(f"Downloading the repository from {TRANSFORMERS_REPOSITORY_URL}...")
    urllib.request.urlretrieve(TRANSFORMERS_REPOSITORY_URL, ARCHIVE_PATH)
    try:
        print(f"Extracting the archive {ARCHIVE_PATH}...")
        with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: refuse absolute paths,
                # ".." traversal, and special files when the interpreter
                # supports extraction filters.
                tar.extractall(filter="data")
            else:
                tar.extractall()
        # Rename the extracted folder to match the expected path.
        # GitHub strips the leading 'v', so we handle that here.
        extracted_folder = (
            f"./transformers.js-{TRANSFORMERS_REPOSITORY_REVISION.lstrip('v')}"
        )
        os.rename(extracted_folder, TRANSFORMERS_REPOSITORY_PATH)
    finally:
        # Always delete the downloaded archive, even if extraction failed.
        os.remove(ARCHIVE_PATH)
    print("Repository downloaded and extracted successfully.")
# Page heading, then the single text field whose value drives the rest of the app.
_MODEL_ID_PROMPT = (
    "Enter the HuggingFace model ID to convert. Example: `EleutherAI/pythia-14m`"
)
st.write("## Convert a HuggingFace model to ONNX")
input_model_id = st.text_input(label=_MODEL_ID_PROMPT)
def _normalize_model_id(raw_model_id):
    """Return the bare model ID from what the user typed.

    Surrounding whitespace is removed, and a leading ``HF_BASE_URL + "/"`` is
    dropped so a pasted model URL works as well. Only the prefix is removed,
    and stray leading/trailing slashes are trimmed.
    """
    model_id = raw_model_id.strip()
    url_prefix = f"{HF_BASE_URL}/"
    if model_id.startswith(url_prefix):
        model_id = model_id[len(url_prefix):]
    return model_id.strip("/")


def _output_model_name(model_id):
    """Return the repo name (without owner) for the converted copy of *model_id*.

    ``owner/name`` becomes ``owner-name``. If the owner is HF_USERNAME itself,
    only ``name`` is returned, so the target is not ``user/user-name-ONNX``.
    """
    owner, _, name = model_id.partition("/")
    if name and owner == HF_USERNAME:
        return name
    return model_id.replace("/", "-")


def _rename_if_present(src, dst):
    """Rename *src* to *dst* when *src* exists; otherwise leave things as they are."""
    if os.path.exists(src):
        os.rename(src, dst)


if input_model_id:
    model_id = _normalize_model_id(input_model_id)
    if not model_id:
        st.error("Please enter a model ID, e.g. `EleutherAI/pythia-14m`.")
        st.stop()
    if not HF_USERNAME:
        # Without an owner the target repository would be "None/...".
        st.error(
            "No Hugging Face username is configured: set `HF_USERNAME` as a "
            "Streamlit secret or environment variable."
        )
        st.stop()

    model_name = _output_model_name(model_id)
    output_model_id = f"{HF_USERNAME}/{model_name}-ONNX"
    output_model_url = f"{HF_BASE_URL}/{output_model_id}"

    api = HfApi(token=HF_TOKEN)
    if api.repo_exists(output_model_id):
        st.write("This model has already been converted! 🎉")
        st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
    else:
        st.write("This model will be converted and uploaded to the following URL:")
        st.code(output_model_url, language="plaintext")
        start_conversion = st.button(label="Proceed", type="primary")
        if start_conversion:
            # The rename and upload steps below expect the converter to write here.
            model_folder_path = f"{TRANSFORMERS_REPOSITORY_PATH}/models/{model_id}"

            with st.spinner("Converting model..."):
                # sys.executable runs the converter with this app's interpreter
                # (and its installed packages) instead of whatever "python" is on PATH.
                output = subprocess.run(
                    [
                        sys.executable,
                        "-m",
                        "scripts.convert",
                        "--quantize",
                        "--model_id",
                        model_id,
                    ],
                    cwd=TRANSFORMERS_REPOSITORY_PATH,
                    capture_output=True,
                    text=True,
                )
                # Log the script output to the server console.
                print("### Script Output ###")
                print(output.stdout)
                if output.stderr:
                    print("### Script Errors ###")
                    print(output.stderr)

                conversion_failed = output.returncode != 0
                if not conversion_failed:
                    # NOTE(review): these target names assume a single decoder-only
                    # export; files the converter did not produce are skipped.
                    onnx_dir = f"{model_folder_path}/onnx"
                    _rename_if_present(
                        f"{onnx_dir}/model.onnx",
                        f"{onnx_dir}/decoder_model_merged.onnx",
                    )
                    _rename_if_present(
                        f"{onnx_dir}/model_quantized.onnx",
                        f"{onnx_dir}/decoder_model_merged_quantized.onnx",
                    )

            if conversion_failed:
                # Drop any partial output so it is never uploaded later.
                shutil.rmtree(model_folder_path, ignore_errors=True)
                st.error(f"Conversion failed (exit code {output.returncode}).")
                st.code(output.stderr)
                st.stop()

            st.success("Conversion successful!")
            st.code(output.stderr)

            with st.spinner("Uploading model..."):
                upload_error_message = None
                try:
                    repository = api.create_repo(
                        output_model_id, exist_ok=True, private=False
                    )
                    api.upload_folder(
                        folder_path=model_folder_path, repo_id=repository.repo_id
                    )
                except Exception as e:
                    # UI boundary: show Hub/network failures instead of crashing.
                    upload_error_message = str(e)
                finally:
                    # Always remove the local export. rmtree avoids building a
                    # shell command from a user-supplied model ID.
                    shutil.rmtree(model_folder_path, ignore_errors=True)

            if upload_error_message:
                st.error(f"Upload failed: {upload_error_message}")
            else:
                st.success("Upload successful!")
                st.write("You can now go and view the model on HuggingFace!")
                st.link_button(
                    f"Go to {output_model_id}", output_model_url, type="primary"
                )