"""Streamlit app that converts a Hugging Face model to ONNX and uploads it.

The app downloads a pinned copy of the transformers.js repository, runs its
``scripts.convert`` module on the requested model (with debug output and
attention outputs enabled) and uploads the resulting folder to the Hub.
"""

import logging
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from urllib.request import urlopen, urlretrieve

import streamlit as st
from huggingface_hub import HfApi, whoami
from torch.jit import TracerWarning
from transformers import AutoConfig, GenerationConfig

# Suppress local TorchScript TracerWarnings
warnings.filterwarnings("ignore", category=TracerWarning)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class Config:
    """Runtime settings for the converter app."""

    # Token used for every Hub API call.
    hf_token: str
    # Hub account name used as the namespace of the output repository.
    hf_username: str
    # Tag (or branch) of transformers.js whose archive is downloaded.
    transformers_version: str = "3.5.0"
    hf_base_url: str = "https://huggingface.co"
    transformers_base_url: str = (
        "https://github.com/huggingface/transformers.js/archive/refs"
    )
    # Local directory the transformers.js archive is extracted into.
    repo_path: Path = Path("./transformers.js")

    @classmethod
    def from_env(cls) -> "Config":
        """Build a ``Config`` from Streamlit secrets and session state.

        A token entered by the user (session key ``user_hf_token``) takes
        precedence over the ``HF_TOKEN`` secret. With a user token the
        username comes from ``whoami``; otherwise ``SPACE_AUTHOR_NAME`` is
        used when set, falling back to ``whoami`` on the system token.

        Raises:
            ValueError: If neither a user token nor ``HF_TOKEN`` is set.
        """
        system_token = st.secrets.get("HF_TOKEN")
        user_token = st.session_state.get("user_hf_token")

        # Validate before any Hub call so that a missing token is reported
        # as such instead of surfacing as a whoami(token=None) failure.
        hf_token = user_token or system_token
        if not hf_token:
            raise ValueError("HF_TOKEN must be set")

        if user_token:
            hf_username = whoami(token=user_token)["name"]
        else:
            hf_username = (
                os.getenv("SPACE_AUTHOR_NAME")
                or whoami(token=system_token)["name"]
            )

        return cls(hf_token=hf_token, hf_username=hf_username)


class ModelConverter:
    """Convert a Hub model with transformers.js scripts and upload the result."""

    def __init__(self, config: Config):
        """Store *config* and create a Hub API client with its token."""
        self.config = config
        self.api = HfApi(token=config.hf_token)

    def _get_ref_type(self) -> str:
        """Return ``"tags"`` if the configured version exists as a tag archive.

        Any failure while probing the tag URL (including HTTP errors raised
        by ``urlopen``) is logged and treated as ``"heads"``.
        """
        url = (
            f"{self.config.transformers_base_url}/tags/"
            f"{self.config.transformers_version}.tar.gz"
        )
        try:
            # Close the response explicitly; only the status code is needed.
            with urlopen(url) as response:
                return "tags" if response.getcode() == 200 else "heads"
        except Exception as e:
            logger.warning("Failed to check tags, defaulting to heads: %s", e)
            return "heads"

    def setup_repository(self) -> None:
        """Download and extract transformers.js unless ``repo_path`` exists.

        The downloaded archive is always removed afterwards.

        Raises:
            RuntimeError: If the download or extraction fails; the original
                exception is attached as ``__cause__``.
        """
        if self.config.repo_path.exists():
            return

        ref_type = self._get_ref_type()
        version = self.config.transformers_version
        archive_url = (
            f"{self.config.transformers_base_url}/{ref_type}/{version}.tar.gz"
        )
        archive_path = Path(f"./transformers_{version}.tar.gz")

        try:
            urlretrieve(archive_url, archive_path)
            self._extract_archive(archive_path)
            logger.info("Repository downloaded and extracted successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to setup repository: {e}") from e
        finally:
            archive_path.unlink(missing_ok=True)

    def _extract_archive(self, archive_path: Path) -> None:
        """Extract *archive_path* and move its first entry to ``repo_path``.

        Raises:
            RuntimeError: If the archive extracts to nothing.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(tmp_dir)

            extracted = next(Path(tmp_dir).iterdir(), None)
            if extracted is None:
                raise RuntimeError(f"Archive {archive_path} is empty")

            # shutil.move instead of Path.rename: the temp directory may live
            # on a different filesystem than repo_path, where rename fails.
            shutil.move(str(extracted), str(self.config.repo_path))

    def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
        """Run the transformers.js conversion script for *input_model_id*.

        Returns:
            ``(ok, log)`` where ``ok`` is True when the subprocess exits with
            code 0 and ``log`` is its filtered stderr. If anything raises,
            returns ``(False, str(exception))``.
        """
        try:
            # Prepare model dir
            model_dir = self.config.repo_path / "models" / input_model_id
            model_dir.mkdir(parents=True, exist_ok=True)

            # Relocate generation params: every key of the generation config
            # that also exists on the base config is cleared on the base
            # config, then both are saved next to each other.
            base_cfg = AutoConfig.from_pretrained(input_model_id)
            gen_cfg = GenerationConfig.from_model_config(base_cfg)
            for key in gen_cfg.to_dict():
                if hasattr(base_cfg, key):
                    setattr(base_cfg, key, None)
            base_cfg.save_pretrained(model_dir)
            gen_cfg.save_pretrained(model_dir)

            # Set verbose logging
            env = os.environ.copy()
            env["TRANSFORMERS_VERBOSITY"] = "debug"

            # Build command with debug
            cmd = [
                sys.executable,
                "-m",
                "scripts.convert",
                "--quantize",
                "--trust_remote_code",
                "--model_id",
                input_model_id,
                "--output_attentions",
                "--debug",
            ]
            result = subprocess.run(
                cmd,
                cwd=self.config.repo_path,
                capture_output=True,
                text=True,
                env=env,
            )

            # Filter warnings
            filtered = [
                line
                for line in result.stderr.splitlines()
                if not line.startswith("Moving the following attributes")
                and "TracerWarning" not in line
            ]
            stderr = "\n".join(filtered)

            return result.returncode == 0, stderr
        except Exception as e:
            # Keep the tuple contract for the caller, but record the
            # traceback instead of discarding it.
            logger.exception("Conversion failed for %s", input_model_id)
            return False, str(e)

    def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
        """Upload the converted model folder to *output_model_id*.

        A README is generated when the folder has none. The local model
        folder is removed afterwards whether or not the upload succeeded.

        Returns:
            ``None`` on success, otherwise the error message.
        """
        model_folder = self.config.repo_path / "models" / input_model_id
        try:
            self.api.create_repo(output_model_id, exist_ok=True, private=False)

            readme = model_folder / "README.md"
            if not readme.exists():
                # Explicit encoding so the file does not depend on the
                # platform's default locale.
                readme.write_text(
                    self.generate_readme(input_model_id), encoding="utf-8"
                )

            self.api.upload_folder(
                folder_path=str(model_folder), repo_id=output_model_id
            )
            return None
        except Exception as e:
            return str(e)
        finally:
            shutil.rmtree(model_folder, ignore_errors=True)

    def generate_readme(self, imi: str) -> str:
        """Return the model-card text for the input model id *imi*."""
        return (
            "---\n"
            "library_name: transformers.js\n"
            "base_model:\n"
            f"- {imi}\n"
            "---\n\n"
            f"# {imi.split('/')[-1]} (ONNX)\n\n"
            f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
            "Converted with debug logs and attention maps.\n"
        )


def main():
    """Render the Streamlit page and drive conversion and upload."""
    st.write("## Convert a Hugging Face model to ONNX (with debug)")
    try:
        config = Config.from_env()
        conv = ModelConverter(config)
        conv.setup_repository()

        input_id = st.text_input("Model ID e.g. EleutherAI/pythia-14m")
        if not input_id:
            return

        # The value is stored under the session key that Config.from_env reads.
        st.text_input(
            "HF write token (optional)", type="password", key="user_hf_token"
        )

        # Uploading to the same repo is only offered when the model's
        # namespace matches the resolved username.
        same = (
            st.checkbox("Upload to same repo?", value=False)
            if config.hf_username == input_id.split("/")[0]
            else False
        )
        name = input_id.split("/")[-1]
        out = f"{config.hf_username}/{name}" + ("" if same else "-ONNX")
        url = f"{config.hf_base_url}/{out}"
        st.code(url)

        if not st.button("Proceed"):
            return

        with st.spinner("Converting (debug)..."):
            ok, err = conv.convert_model(input_id)
            if not ok:
                st.error(f"Conversion failed: {err}")
                return
            st.success("Conversion successful!")
            st.code(err)

        with st.spinner("Uploading..."):
            err2 = conv.upload_model(input_id, out)
            if err2:
                st.error(f"Upload failed: {err2}")
                return
            st.success("Upload successful!")
            st.link_button(f"Go to {out}", url)
    except Exception as e:
        logger.exception(e)
        st.error(f"Error: {e}")


if __name__ == "__main__":
    main()