urroxyz's picture
Update app.py
b2ef551 verified
raw
history blame
7.02 kB
import logging
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from urllib.request import urlopen, urlretrieve

import streamlit as st
from huggingface_hub import HfApi, whoami
from torch.jit import TracerWarning
from transformers import AutoConfig, GenerationConfig
# Hide TorchScript TracerWarning messages raised inside this process.
# Warning filters are per-interpreter: child Python processes started via
# subprocess do not inherit this filter.
warnings.filterwarnings(action="ignore", category=TracerWarning)

# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
@dataclass
class Config:
    """Runtime settings for the ONNX converter app.

    Instances are normally built with ``Config.from_env()``; the defaulted
    fields pin which transformers.js source archive is downloaded and where
    it is unpacked.
    """

    hf_token: str  # token used for all Hub API calls
    hf_username: str  # namespace the converted model is uploaded under
    transformers_version: str = "3.5.0"  # transformers.js tag (or branch) to fetch
    hf_base_url: str = "https://huggingface.co"
    transformers_base_url: str = (
        "https://github.com/huggingface/transformers.js/archive/refs"
    )
    repo_path: Path = Path("./transformers.js")  # local checkout location

    @classmethod
    def from_env(cls) -> "Config":
        """Build a Config from Streamlit session state, secrets and the environment.

        Token precedence: the token typed by the user
        (``st.session_state["user_hf_token"]``), then the ``HF_TOKEN`` Streamlit
        secret, then the ``HF_TOKEN`` environment variable.

        Username: resolved with ``whoami`` for a user token. With the app's own
        token, ``SPACE_AUTHOR_NAME`` is used when set, otherwise ``whoami``.

        Raises:
            ValueError: if no token is available from any source.
        """
        system_token = st.secrets.get("HF_TOKEN") or os.getenv("HF_TOKEN")
        user_token = st.session_state.get("user_hf_token")
        hf_token = user_token or system_token
        # Check before contacting the Hub. Calling whoami() without a token
        # would fail with a less helpful error, or could silently use a locally
        # cached credential.
        if not hf_token:
            raise ValueError("HF_TOKEN must be set")

        if user_token:
            hf_username = whoami(token=user_token)["name"]
        else:
            hf_username = (
                os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
            )
        return cls(hf_token=hf_token, hf_username=hf_username)
class ModelConverter:
    """Fetch transformers.js, run its ONNX conversion script, and upload results."""

    # Seconds to wait when probing GitHub for the archive ref type.
    _URL_TIMEOUT_S = 30

    def __init__(self, config: "Config") -> None:
        self.config = config
        self.api = HfApi(token=config.hf_token)

    def _model_dir(self, model_id: str) -> Path:
        """Return the local output folder for *model_id* under ``<repo_path>/models``.

        ``model_id`` comes from user input, and the folder is later removed
        with ``shutil.rmtree``. Ids that resolve outside the models directory
        are therefore rejected: ``..`` segments, absolute paths, or an empty id.

        Raises:
            ValueError: if the id does not resolve to a sub-path of ``models``.
        """
        models_root = (self.config.repo_path / "models").resolve()
        target = (models_root / model_id).resolve()
        if models_root not in target.parents:
            raise ValueError(f"Invalid model id: {model_id!r}")
        return target

    def _get_ref_type(self) -> str:
        """Return ``"tags"`` if the configured version is a git tag, else ``"heads"``.

        Any error while probing falls back to ``"heads"``, including an HTTP
        404, which ``urlopen`` raises as ``HTTPError``.
        """
        url = (
            f"{self.config.transformers_base_url}/tags/"
            f"{self.config.transformers_version}.tar.gz"
        )
        try:
            # Only the status is needed; closing the response avoids leaking
            # the connection that is streaming the tarball body.
            with urlopen(url, timeout=self._URL_TIMEOUT_S) as response:
                status = response.getcode()
        except Exception as e:
            logger.warning("Failed to check tags, defaulting to heads: %s", e)
            return "heads"
        return "tags" if status == 200 else "heads"

    def setup_repository(self) -> None:
        """Download and unpack transformers.js into ``config.repo_path``.

        Does nothing if ``repo_path`` already exists. The downloaded archive is
        always deleted afterwards.

        Raises:
            RuntimeError: if downloading or extracting fails.
        """
        if self.config.repo_path.exists():
            return
        version = self.config.transformers_version
        ref_type = self._get_ref_type()
        archive_url = f"{self.config.transformers_base_url}/{ref_type}/{version}.tar.gz"
        archive_path = Path(f"./transformers_{version}.tar.gz")
        try:
            urlretrieve(archive_url, archive_path)
            self._extract_archive(archive_path)
            logger.info("Repository downloaded and extracted successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to setup repository: {e}") from e
        finally:
            archive_path.unlink(missing_ok=True)

    def _extract_archive(self, archive_path: Path) -> None:
        """Unpack a source tarball and move its first top-level entry to ``repo_path``.

        Raises:
            RuntimeError: if the archive contains no entries.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(archive_path, "r:gz") as tar:
                # The "data" filter (Python 3.12+ and security backports)
                # refuses absolute paths and members escaping the target dir.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(tmp_dir, filter="data")
                else:
                    tar.extractall(tmp_dir)
            top_level = next(Path(tmp_dir).iterdir(), None)
            if top_level is None:
                raise RuntimeError(f"Archive {archive_path} is empty")
            # shutil.move rather than Path.rename: the temp dir may be on a
            # different filesystem, where rename() fails with EXDEV.
            shutil.move(str(top_level), str(self.config.repo_path))

    def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
        """Convert *input_model_id* to ONNX with transformers.js ``scripts.convert``.

        Returns:
            ``(success, log)``. ``log`` is the conversion's stderr with
            "Moving the following attributes..." and TracerWarning lines removed,
            or the exception message if preparation failed.
        """
        try:
            model_dir = self._model_dir(input_model_id)
            model_dir.mkdir(parents=True, exist_ok=True)

            # Save the model config and a generation config derived from it.
            base_cfg = AutoConfig.from_pretrained(input_model_id)
            gen_cfg = GenerationConfig.from_model_config(base_cfg)
            # NOTE(review): this sets to None *every* model-config attribute that
            # also appears in the generation config dict (presumably including
            # token ids such as eos_token_id), not only relocated ones; confirm
            # this is intended.
            for key in gen_cfg.to_dict():
                if hasattr(base_cfg, key):
                    setattr(base_cfg, key, None)
            base_cfg.save_pretrained(model_dir)
            gen_cfg.save_pretrained(model_dir)

            env = os.environ.copy()
            env["TRANSFORMERS_VERBOSITY"] = "debug"
            # NOTE(review): confirm scripts.convert at the pinned version
            # accepts --debug; unknown flags may abort argument parsing.
            cmd = [
                sys.executable,
                "-m", "scripts.convert",
                "--quantize",
                "--trust_remote_code",
                "--model_id", input_model_id,
                "--output_attentions",
                "--debug",
            ]
            result = subprocess.run(
                cmd,
                cwd=self.config.repo_path,
                capture_output=True,
                text=True,
                env=env,
            )
            log = "\n".join(
                line
                for line in result.stderr.splitlines()
                if not line.startswith("Moving the following attributes")
                and "TracerWarning" not in line
            )
            return result.returncode == 0, log
        except Exception as e:
            return False, str(e)

    def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
        """Upload the converted folder to *output_model_id*, then delete it locally.

        A README is generated if the folder does not already contain one.

        Returns:
            ``None`` on success, otherwise the error message.
        """
        try:
            model_folder = self._model_dir(input_model_id)
        except ValueError as e:
            # Return before the cleanup below, so nothing outside models/ is removed.
            return str(e)
        try:
            self.api.create_repo(output_model_id, exist_ok=True, private=False)
            readme = model_folder / "README.md"
            if not readme.exists():
                readme.write_text(
                    self.generate_readme(input_model_id), encoding="utf-8"
                )
            self.api.upload_folder(folder_path=str(model_folder), repo_id=output_model_id)
            return None
        except Exception as e:
            return str(e)
        finally:
            shutil.rmtree(model_folder, ignore_errors=True)

    def generate_readme(self, imi: str) -> str:
        """Return a model-card README with YAML front matter for model id *imi*."""
        return (
            "---\n"
            "library_name: transformers.js\n"
            "base_model:\n"
            f"- {imi}\n"
            "---\n\n"
            f"# {imi.split('/')[-1]} (ONNX)\n\n"
            f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
            "Converted with debug logs and attention maps.\n"
        )
def main():
    """Streamlit entry point: collect a model id, convert it to ONNX, and upload it."""
    st.write("## Convert a Hugging Face model to ONNX (with debug)")
    try:
        config = Config.from_env()
        conv = ModelConverter(config)
        conv.setup_repository()

        # Strip surrounding whitespace so pasted ids like " org/model " work.
        input_id = st.text_input("Model ID e.g. EleutherAI/pythia-14m").strip()
        if not input_id:
            return
        # Stored in st.session_state["user_hf_token"], which Config.from_env() reads.
        st.text_input("HF write token (optional)", type="password", key="user_hf_token")

        # Offer uploading into the source repo only when the user owns it.
        if config.hf_username == input_id.split("/")[0]:
            same = st.checkbox("Upload to same repo?", value=False)
        else:
            same = False
        name = input_id.split("/")[-1]
        out = f"{config.hf_username}/{name}" + ("" if same else "-ONNX")
        url = f"{config.hf_base_url}/{out}"
        st.code(url)

        if not st.button("Proceed"):
            return

        with st.spinner("Converting (debug)..."):
            ok, log = conv.convert_model(input_id)
        if not ok:
            st.error(f"Conversion failed: {log}")
            return
        st.success("Conversion successful!")
        if log:
            st.code(log)

        with st.spinner("Uploading..."):
            upload_err = conv.upload_model(input_id, out)
        if upload_err:
            st.error(f"Upload failed: {upload_err}")
            return
        st.success("Upload successful!")
        st.link_button(f"Go to {out}", url)
    except Exception as e:
        logger.exception("Unexpected error while converting model")
        st.error(f"Error: {e}")
# Streamlit executes this file as the __main__ module, so the app starts here.
if __name__ == "__main__":
    main()