File size: 7,022 Bytes
1364cbe
5a87d26
ccc9d98
58a0ecb
de5f9d5
1364cbe
 
 
 
1b48b29
 
8fbbd7f
de5f9d5
 
 
b2ef551
de5f9d5
d088d6c
1364cbe
 
 
 
 
 
 
 
c951ed8
1364cbe
 
1e3c553
d088d6c
1364cbe
 
 
 
8fbbd7f
5a87d26
 
 
 
 
 
 
2af96bc
 
 
1364cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2ef551
1364cbe
 
 
b2ef551
1364cbe
 
 
b2ef551
de5f9d5
 
b2ef551
 
 
 
 
 
 
 
 
 
 
b0efbfd
 
 
 
 
 
 
b2ef551
b0efbfd
1364cbe
b0efbfd
1364cbe
 
 
b2ef551
d088d6c
b2ef551
 
 
1364cbe
d10f691
 
1364cbe
 
 
 
b2ef551
1364cbe
 
b2ef551
 
 
 
1364cbe
 
 
 
b2ef551
d088d6c
de5f9d5
afb14c6
 
 
 
 
 
 
 
b2ef551
afb14c6
 
1364cbe
b2ef551
1364cbe
 
b2ef551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1364cbe
b2ef551
d088d6c
b2ef551
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import logging
import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
from urllib.request import urlopen, urlretrieve

import streamlit as st
from huggingface_hub import HfApi, whoami
from torch.jit import TracerWarning
from transformers import AutoConfig, GenerationConfig

# Silence TorchScript TracerWarning messages raised inside this process.
warnings.simplefilter("ignore", category=TracerWarning)

# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)


@dataclass
class Config:
    """Runtime settings for the converter app.

    ``hf_token`` and ``hf_username`` are resolved by :meth:`from_env`; the
    remaining fields have fixed defaults.
    """

    hf_token: str  # token used for every Hub API call
    hf_username: str  # namespace the converted model is uploaded under
    transformers_version: str = "3.5.0"  # transformers.js ref to download
    hf_base_url: str = "https://huggingface.co"
    transformers_base_url: str = (
        "https://github.com/huggingface/transformers.js/archive/refs"
    )
    repo_path: Path = Path("./transformers.js")  # local transformers.js checkout

    @staticmethod
    def _pick_token(user_token: Optional[str], system_token: Optional[str]) -> str:
        """Return the user's token if non-empty, else the app's; raise if neither is set.

        Raises:
            ValueError: if both tokens are missing or empty.
        """
        token = user_token or system_token
        if not token:
            raise ValueError("HF_TOKEN must be set")
        return token

    @classmethod
    def from_env(cls) -> "Config":
        """Build a Config from Streamlit secrets and the current session state.

        A token entered by the user (session key ``user_hf_token``) takes
        precedence over the app's ``HF_TOKEN`` secret, and the upload namespace
        is then that user's name. Otherwise the namespace is
        ``SPACE_AUTHOR_NAME`` if set, else the owner of the app token.

        Raises:
            ValueError: if no token is available. This is checked before any
                ``whoami`` call so the user sees this message rather than an
                authentication error.
        """
        system_token = st.secrets.get("HF_TOKEN")
        user_token = st.session_state.get("user_hf_token")
        hf_token = cls._pick_token(user_token, system_token)
        if user_token:
            hf_username = whoami(token=user_token)["name"]
        else:
            hf_username = (
                os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
            )
        return cls(hf_token=hf_token, hf_username=hf_username)


class ModelConverter:
    """Fetch transformers.js, convert a Hub model to ONNX, and upload the result.

    Conversion is delegated to the ``scripts.convert`` module of the
    transformers.js source tree stored at ``config.repo_path``.
    """

    # Seconds to wait for the GitHub tag probe before falling back to "heads".
    _HTTP_TIMEOUT = 30

    # stderr lines that are pure noise and are hidden from the conversion log.
    _NOISE_PREFIX = "Moving the following attributes"
    _NOISE_MARKER = "TracerWarning"

    def __init__(self, config: "Config"):
        self.config = config
        self.api = HfApi(token=config.hf_token)

    def _get_ref_type(self) -> str:
        """Return ``"tags"`` if the configured version downloads as a tag, else ``"heads"``."""
        url = (
            f"{self.config.transformers_base_url}/tags/"
            f"{self.config.transformers_version}.tar.gz"
        )
        try:
            # Only the status code matters; the context manager closes the
            # connection without reading the archive body.
            with urlopen(url, timeout=self._HTTP_TIMEOUT) as response:
                return "tags" if response.getcode() == 200 else "heads"
        except Exception as e:
            # urlopen raises HTTPError for 4xx/5xx, e.g. when the ref is not a tag.
            logger.warning(f"Failed to check tags, defaulting to heads: {e}")
            return "heads"

    def setup_repository(self) -> None:
        """Download and unpack transformers.js into ``config.repo_path``.

        Does nothing if ``repo_path`` already exists. The downloaded archive is
        always removed afterwards.

        Raises:
            RuntimeError: if the download or extraction fails (original error chained).
        """
        if self.config.repo_path.exists():
            return
        ref_type = self._get_ref_type()
        version = self.config.transformers_version
        archive_url = f"{self.config.transformers_base_url}/{ref_type}/{version}.tar.gz"
        archive_path = Path(f"./transformers_{version}.tar.gz")
        try:
            urlretrieve(archive_url, archive_path)
            self._extract_archive(archive_path)
            logger.info("Repository downloaded and extracted successfully")
        except Exception as e:
            raise RuntimeError(f"Failed to setup repository: {e}") from e
        finally:
            archive_path.unlink(missing_ok=True)

    def _extract_archive(self, archive_path: Path) -> None:
        """Unpack a ``.tar.gz`` and move its first top-level entry to ``repo_path``.

        GitHub source archives are expected to contain one top-level folder.

        Raises:
            RuntimeError: if the archive has no entries.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(archive_path, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    # PEP 706 filter: rejects absolute paths, ".." members, devices.
                    tar.extractall(tmp_dir, filter="data")
                else:  # interpreter predates extraction filters
                    tar.extractall(tmp_dir)
            extracted_root = next(Path(tmp_dir).iterdir(), None)
            if extracted_root is None:
                raise RuntimeError(f"Archive {archive_path} is empty")
            # The temp dir may be on another filesystem, where Path.rename fails
            # with EXDEV; shutil.move falls back to copy + delete.
            shutil.move(str(extracted_root), str(self.config.repo_path))

    def _model_dir(self, input_model_id: str) -> Path:
        """Return the local output folder for *input_model_id* under ``repo_path/models``.

        Raises:
            ValueError: if the id resolves to ``models`` itself or outside it
                (e.g. ``../..`` or an absolute path). The folder is later
                created and recursively deleted, so it must stay contained.
        """
        models_root = (self.config.repo_path / "models").resolve()
        model_dir = (models_root / input_model_id).resolve()
        if models_root not in model_dir.parents:
            raise ValueError(f"Invalid model id: {input_model_id!r}")
        return model_dir

    @classmethod
    def _filter_stderr(cls, stderr: str) -> str:
        """Drop known-noisy warning lines from the conversion script's stderr."""
        return "\n".join(
            line
            for line in stderr.splitlines()
            if not line.startswith(cls._NOISE_PREFIX) and cls._NOISE_MARKER not in line
        )

    def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
        """Convert *input_model_id* to ONNX with the transformers.js convert script.

        Returns:
            ``(success, log)``. On a finished run, *success* reflects the exit
            code and *log* is the filtered stderr. If preparation fails, the
            result is ``(False, <exception message>)``.
        """
        try:
            model_dir = self._model_dir(input_model_id)
            model_dir.mkdir(parents=True, exist_ok=True)

            # Relocate generation params: clear them on the model config and
            # save them separately as a generation config.
            base_cfg = AutoConfig.from_pretrained(input_model_id)
            gen_cfg = GenerationConfig.from_model_config(base_cfg)
            for key in gen_cfg.to_dict():
                if hasattr(base_cfg, key):
                    setattr(base_cfg, key, None)
            base_cfg.save_pretrained(model_dir)
            gen_cfg.save_pretrained(model_dir)

            env = os.environ.copy()
            env["TRANSFORMERS_VERBOSITY"] = "debug"
            # NOTE(review): --trust_remote_code makes the script execute code
            # shipped in the model repo.
            cmd = [
                sys.executable,
                "-m", "scripts.convert",
                "--quantize",
                "--trust_remote_code",
                "--model_id", input_model_id,
                "--output_attentions",
                "--debug",
            ]
            result = subprocess.run(
                cmd,
                cwd=self.config.repo_path,
                capture_output=True,
                text=True,
                # Undecodable bytes must not turn a finished run into a failure.
                errors="replace",
                env=env,
            )
            return result.returncode == 0, self._filter_stderr(result.stderr)
        except Exception as e:
            return False, str(e)

    def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
        """Upload the converted model folder to the Hub repo *output_model_id*.

        A README is generated if the folder has none. The local folder is
        deleted afterwards whether or not the upload succeeded.

        Returns:
            ``None`` on success, otherwise the error message.
        """
        try:
            model_folder = self._model_dir(input_model_id)
        except ValueError as e:
            # Refuse before the cleanup below could rmtree an arbitrary path.
            return str(e)
        try:
            self.api.create_repo(output_model_id, exist_ok=True, private=False)
            readme = model_folder / "README.md"
            if not readme.exists():
                readme.write_text(self.generate_readme(input_model_id), encoding="utf-8")
            self.api.upload_folder(folder_path=str(model_folder), repo_id=output_model_id)
            return None
        except Exception as e:
            return str(e)
        finally:
            shutil.rmtree(model_folder, ignore_errors=True)

    def generate_readme(self, imi: str) -> str:
        """Return a model-card README linking the ONNX export to base model *imi*."""
        return (
            "---\n"
            "library_name: transformers.js\n"
            "base_model:\n"
            f"- {imi}\n"
            "---\n\n"
            f"# {imi.split('/')[-1]} (ONNX)\n\n"
            f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
            "Converted with debug logs and attention maps.\n"
        )

def main():
    """Render the converter UI: choose a model, convert it to ONNX, upload it.

    Each early ``return`` stops rendering until the user provides the next
    input (a model id, then a click on "Proceed"). Any unexpected error is
    logged and shown in the page.
    """
    st.write("## Convert a Hugging Face model to ONNX (with debug)")
    try:
        config = Config.from_env()
        converter = ModelConverter(config)
        converter.setup_repository()

        # Pasted ids often carry stray whitespace, which would break the Hub
        # lookup and the derived output repo name.
        input_id = st.text_input("Model ID e.g. EleutherAI/pythia-14m").strip()
        if not input_id:
            return
        st.text_input("HF write token (optional)", type="password", key="user_hf_token")

        # Only offer overwriting the source repo when it is in the user's namespace.
        owner = input_id.split("/")[0]
        same = (
            st.checkbox("Upload to same repo?", value=False)
            if config.hf_username == owner
            else False
        )
        name = input_id.split("/")[-1]
        output_id = f"{config.hf_username}/{name}" + ("" if same else "-ONNX")
        output_url = f"{config.hf_base_url}/{output_id}"
        st.code(output_url)

        if not st.button("Proceed"):
            return

        with st.spinner("Converting (debug)..."):
            ok, log = converter.convert_model(input_id)
            if not ok:
                st.error(f"Conversion failed: {log}")
                return
            st.success("Conversion successful!")
            st.code(log)

        with st.spinner("Uploading..."):
            upload_error = converter.upload_model(input_id, output_id)
            if upload_error:
                st.error(f"Upload failed: {upload_error}")
                return
            st.success("Upload successful!")
            st.link_button(f"Go to {output_id}", output_url)
    except Exception as e:
        logger.exception(e)
        st.error(f"Error: {e}")


if __name__ == "__main__":
    main()