import logging
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from urllib.request import urlopen, urlretrieve

import streamlit as st
from huggingface_hub import HfApi, whoami
from torch.jit import TracerWarning
from transformers import AutoConfig, GenerationConfig

# Suppress TorchScript tracer warnings emitted during the local export
warnings.filterwarnings("ignore", category=TracerWarning)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class Config:
    """Application configuration."""

    hf_token: str
    hf_username: str
    transformers_version: str = "3.5.0"
    hf_base_url: str = "https://huggingface.co"
    transformers_base_url: str = (
        "https://github.com/huggingface/transformers.js/archive/refs"
    )
    repo_path: Path = Path("./transformers.js")

    @classmethod
    def from_env(cls) -> "Config":
        """Create a config from environment variables, secrets, and session state."""
        system_token = st.secrets.get("HF_TOKEN")
        user_token = st.session_state.get("user_hf_token")
        if user_token:
            hf_username = whoami(token=user_token)["name"]
        else:
            hf_username = (
                os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
            )
        hf_token = user_token or system_token
        if not hf_token:
            raise ValueError("HF_TOKEN must be set")
        return cls(hf_token=hf_token, hf_username=hf_username)


class ModelConverter:
    """Handles model conversion and upload operations."""

    def __init__(self, config: Config):
        self.config = config
        self.api = HfApi(token=config.hf_token)

    def _get_ref_type(self) -> str:
        """Determine the git ref type (tags vs. heads) for the transformers.js archive."""
        url = f"{self.config.transformers_base_url}/tags/{self.config.transformers_version}.tar.gz"
        try:
            return "tags" if urlopen(url).getcode() == 200 else "heads"
        except Exception as e:
            logger.warning(f"Failed to check tags, defaulting to heads: {e}")
            return "heads"

    def setup_repository(self) -> None:
        """Download and set up the transformers.js repo if it is not already present."""
        if self.config.repo_path.exists():
            return

        ref_type = self._get_ref_type()
        archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
        archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")

        try:
            urlretrieve(archive_url, archive_path)
            self._extract_archive(archive_path)
            logger.info("Repository downloaded and extracted successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to set up repository: {e}")
        finally:
            archive_path.unlink(missing_ok=True)

    def _extract_archive(self, archive_path: Path) -> None:
        """Extract the downloaded archive into the configured repo path."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(tmp_dir)
            extracted_folder = next(Path(tmp_dir).iterdir())
            # shutil.move works across filesystems, unlike Path.rename, which can
            # fail when the temp dir lives on a different mount than the repo path.
            shutil.move(str(extracted_folder), str(self.config.repo_path))

    def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
        """
        Convert the model to ONNX, always exporting attention maps.

        Relocates generation params, suppresses tracer warnings, and filters
        relocation/tracer warnings out of the captured stderr.
        """
        try:
            # 1. Prepare a local folder for config tweaks
            model_dir = self.config.repo_path / "models" / input_model_id
            model_dir.mkdir(parents=True, exist_ok=True)
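
            # Why relocate generation parameters: recent transformers releases emit a
            # "Moving the following attributes ..." warning when sampling settings
            # (temperature, top_p, etc.) live in config.json rather than
            # generation_config.json. Nulling them on the base config and saving both
            # files keeps that noise out of the export logs (see the stderr filter in
            # step 5 below). The exact warning text is a transformers implementation
            # detail and may change between versions.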
            # 2. Move any generation parameters into generation_config.json
            base_cfg = AutoConfig.from_pretrained(input_model_id)
            gen_cfg = GenerationConfig.from_model_config(base_cfg)
            for k in gen_cfg.to_dict():
                if hasattr(base_cfg, k):
                    setattr(base_cfg, k, None)
            base_cfg.save_pretrained(model_dir)
            gen_cfg.save_pretrained(model_dir)

            # 3. Enable verbose logging via an env var (the convert script has no --debug flag)
            env = os.environ.copy()
            env["TRANSFORMERS_VERBOSITY"] = "debug"

            # 4. Build and run the conversion command
            cmd = [
                sys.executable,
                "-m",
                "scripts.convert",
                "--quantize",
                "--trust_remote_code",
                "--model_id",
                input_model_id,
                "--output_attentions",
            ]
            result = subprocess.run(
                cmd,
                cwd=self.config.repo_path,
                capture_output=True,
                text=True,
                env=env,
            )

            # 5. Filter spurious relocation/tracer warnings out of stderr
            filtered = []
            for ln in result.stderr.splitlines():
                if ln.startswith("Moving the following attributes"):
                    continue
                if "TracerWarning" in ln:
                    continue
                filtered.append(ln)
            stderr = "\n".join(filtered)

            if result.returncode != 0:
                return False, stderr
            return True, stderr
        except Exception as e:
            return False, str(e)

    def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
        """Upload the converted model to the Hugging Face Hub."""
        model_folder = self.config.repo_path / "models" / input_model_id
        try:
            self.api.create_repo(output_model_id, exist_ok=True, private=False)

            readme_path = model_folder / "README.md"
            if not readme_path.exists():
                readme_path.write_text(self.generate_readme(input_model_id))

            self.api.upload_folder(folder_path=str(model_folder), repo_id=output_model_id)
            return None
        except Exception as e:
            return str(e)
        finally:
            shutil.rmtree(model_folder, ignore_errors=True)

    def generate_readme(self, input_model_id: str) -> str:
        """Generate a minimal model card for the converted repository."""
        return (
            "---\n"
            "library_name: transformers.js\n"
            "base_model:\n"
            f"- {input_model_id}\n"
            "---\n\n"
            f"# {input_model_id.split('/')[-1]} (ONNX)\n\n"
            f"This is an ONNX version of [{input_model_id}](https://huggingface.co/{input_model_id}). "
            "Converted with attention maps and verbose export logs.\n"
        )


def main():
    """Streamlit application entry point."""
    st.write("## Convert a Hugging Face model to ONNX (with attentions & debug logs)")

    try:
        config = Config.from_env()
        converter = ModelConverter(config)
        converter.setup_repository()
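
        # Note: Config.from_env() above reads st.session_state["user_hf_token"]
        # before the token text_input below is rendered. This still works because
        # Streamlit reruns the whole script on every widget interaction, so a
        # pasted token is picked up on the next rerun.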

        input_model_id = st.text_input(
            "Enter the Hugging Face model ID to convert, e.g. `EleutherAI/pythia-14m`"
        )
        if not input_model_id:
            return

        st.text_input(
            "Optional: Your Hugging Face write token (for uploading to your namespace).",
            type="password",
            key="user_hf_token",
        )

        if config.hf_username == input_model_id.split("/")[0]:
            same_repo = st.checkbox("Upload ONNX weights to the same repository?")
        else:
            same_repo = False

        model_name = input_model_id.split("/")[-1]
        output_model_id = f"{config.hf_username}/{model_name}"
        if not same_repo:
            output_model_id += "-ONNX"
        output_url = f"{config.hf_base_url}/{output_model_id}"

        st.write("Destination repository:")
        st.code(output_url, language="plaintext")

        if not st.button("Proceed", type="primary"):
            return

        with st.spinner("Converting model…"):
            success, stderr = converter.convert_model(input_model_id)
            if not success:
                st.error(f"Conversion failed: {stderr}")
                return
            st.success("Conversion successful!")
            st.code(stderr)

        with st.spinner("Uploading model…"):
            error = converter.upload_model(input_model_id, output_model_id)
            if error:
                st.error(f"Upload failed: {error}")
                return
            st.success("Upload successful!")
            st.link_button(f"Go to {output_model_id}", output_url, type="primary")

    except Exception as e:
        logger.exception("Application error")
        st.error(f"An error occurred: {e}")


if __name__ == "__main__":
    main()
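
# Usage note (the filename app.py is an assumption; this is a Streamlit app,
# commonly deployed as app.py in a Hugging Face Space):
#   streamlit run app.py
# A system token can be supplied via Streamlit secrets, e.g. in
# .streamlit/secrets.toml:
#   HF_TOKEN = "hf_..."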