urroxyz committed
Commit b2ef551 · verified · 1 Parent(s): ea01d50

Update app.py

Files changed (1):
  1. app.py +47 -133
app.py CHANGED
@@ -13,7 +13,7 @@ from huggingface_hub import HfApi, whoami
 from torch.jit import TracerWarning
 from transformers import AutoConfig, GenerationConfig
 
-# Suppress TorchScript tracer warnings in this process
+# Suppress local TorchScript TracerWarnings
 warnings.filterwarnings("ignore", category=TracerWarning)
 
 logging.basicConfig(level=logging.INFO)
@@ -22,8 +22,6 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class Config:
-    """Application configuration."""
-
     hf_token: str
     hf_username: str
     transformers_version: str = "3.5.0"
@@ -35,7 +33,6 @@ class Config:
 
     @classmethod
     def from_env(cls) -> "Config":
-        """Create config from environment variables and secrets."""
        system_token = st.secrets.get("HF_TOKEN")
        user_token = st.session_state.get("user_hf_token")
        if user_token:
@@ -45,22 +42,17 @@ class Config:
             os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
         )
         hf_token = user_token or system_token
-
         if not hf_token:
             raise ValueError("HF_TOKEN must be set")
-
         return cls(hf_token=hf_token, hf_username=hf_username)
 
 
 class ModelConverter:
-    """Handles model conversion and upload operations."""
-
     def __init__(self, config: Config):
         self.config = config
         self.api = HfApi(token=config.hf_token)
 
     def _get_ref_type(self) -> str:
-        """Determine the reference type for the transformers repository."""
         url = f"{self.config.transformers_base_url}/tags/{self.config.transformers_version}.tar.gz"
         try:
             return "tags" if urlopen(url).getcode() == 200 else "heads"
@@ -69,14 +61,11 @@ class ModelConverter:
             return "heads"
 
     def setup_repository(self) -> None:
-        """Download and setup transformers repository if needed."""
         if self.config.repo_path.exists():
             return
-
         ref_type = self._get_ref_type()
         archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
         archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")
-
         try:
             urlretrieve(archive_url, archive_path)
             self._extract_archive(archive_path)
@@ -87,96 +76,66 @@ class ModelConverter:
             archive_path.unlink(missing_ok=True)
 
     def _extract_archive(self, archive_path: Path) -> None:
-        """Extract the downloaded archive."""
-        import tarfile
-        import tempfile
-
+        import tarfile, tempfile
         with tempfile.TemporaryDirectory() as tmp_dir:
             with tarfile.open(archive_path, "r:gz") as tar:
                 tar.extractall(tmp_dir)
-            extracted_folder = next(Path(tmp_dir).iterdir())
-            extracted_folder.rename(self.config.repo_path)
+            next(Path(tmp_dir).iterdir()).rename(self.config.repo_path)
 
     def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
-        """
-        Convert the model to ONNX format, always exporting attention maps.
-        Relocate generation parameters, suppress tracer warnings, and
-        strip out both relocation and tracer warnings from stderr.
-        """
         try:
-            # Prepare local directory for config edits
+            # Prepare model dir
             model_dir = self.config.repo_path / "models" / input_model_id
             model_dir.mkdir(parents=True, exist_ok=True)
-
-            # Load and relocate generation parameters
-            base_config = AutoConfig.from_pretrained(input_model_id)
-            gen_config = GenerationConfig.from_model_config(base_config)
-            # Remove generation params from base config
-            for key in gen_config.to_dict():
-                if hasattr(base_config, key):
-                    setattr(base_config, key, None)
-            base_config.save_pretrained(model_dir)
-            gen_config.save_pretrained(model_dir)
-
-            # Build conversion command with global warning ignore
+            # Relocate generation params
+            base_cfg = AutoConfig.from_pretrained(input_model_id)
+            gen_cfg = GenerationConfig.from_model_config(base_cfg)
+            for k in gen_cfg.to_dict():
+                if hasattr(base_cfg, k): setattr(base_cfg, k, None)
+            base_cfg.save_pretrained(model_dir)
+            gen_cfg.save_pretrained(model_dir)
+            # Set verbose logging
+            env = os.environ.copy()
+            env["TRANSFORMERS_VERBOSITY"] = "debug"
+            # Build command with debug
             cmd = [
                 sys.executable,
-                "-W", "ignore",
                 "-m", "scripts.convert",
                 "--quantize",
                 "--trust_remote_code",
                 "--model_id", input_model_id,
                 "--output_attentions",
+                "--debug"
             ]
-
             result = subprocess.run(
                 cmd,
                 cwd=self.config.repo_path,
                 capture_output=True,
                 text=True,
-                env=os.environ.copy(),
+                env=env,
             )
-
-            # Filter out relocation and tracer warnings
-            lines = []
-            for ln in result.stderr.splitlines():
-                if ln.startswith("Moving the following attributes"):
-                    continue
-                if "TracerWarning" in ln:
-                    continue
-                lines.append(ln)
-            stderr = "\n".join(lines)
-
+            # Filter warnings
+            filtered = [ln for ln in result.stderr.splitlines() if not ln.startswith("Moving the following attributes") and "TracerWarning" not in ln]
+            stderr = "\n".join(filtered)
             if result.returncode != 0:
                 return False, stderr
-
             return True, stderr
-
         except Exception as e:
             return False, str(e)
 
     def upload_model(self, input_model_id: str, output_model_id: str) -> Optional[str]:
-        """Upload the converted model to Hugging Face."""
-        model_folder_path = self.config.repo_path / "models" / input_model_id
-
+        model_folder = self.config.repo_path / "models" / input_model_id
         try:
             self.api.create_repo(output_model_id, exist_ok=True, private=False)
-
-            readme_path = f"{model_folder_path}/README.md"
-            if not os.path.exists(readme_path):
-                with open(readme_path, "w") as file:
-                    file.write(self.generate_readme(input_model_id))
-
-            self.api.upload_folder(
-                folder_path=str(model_folder_path),
-                repo_id=output_model_id
-            )
+            readme = model_folder / "README.md"
+            if not readme.exists():
+                readme.write_text(self.generate_readme(input_model_id))
+            self.api.upload_folder(folder_path=str(model_folder), repo_id=output_model_id)
             return None
         except Exception as e:
             return str(e)
         finally:
-            import shutil
-            shutil.rmtree(model_folder_path, ignore_errors=True)
+            import shutil; shutil.rmtree(model_folder, ignore_errors=True)
 
     def generate_readme(self, imi: str) -> str:
         return (
@@ -187,76 +146,31 @@ class ModelConverter:
             "---\n\n"
             f"# {imi.split('/')[-1]} (ONNX)\n\n"
             f"This is an ONNX version of [{imi}](https://huggingface.co/{imi}). "
-            "It was automatically converted and uploaded using "
-            "[this space](https://huggingface.co/spaces/onnx-community/convert-to-onnx).\n"
+            "Converted with debug logs and attention maps.\n"
         )
 
-
 def main():
-    """Main application entry point."""
-    st.write("## Convert a Hugging Face model to ONNX (with attentions)")
-
+    st.write("## Convert a Hugging Face model to ONNX (with debug)")
     try:
         config = Config.from_env()
-        converter = ModelConverter(config)
-        converter.setup_repository()
-
-        input_model_id = st.text_input(
-            "Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"
-        )
-        if not input_model_id:
-            return
-
-        st.text_input(
-            "Optional: Your Hugging Face write token. Fill it if you want to upload under your account.",
-            type="password",
-            key="user_hf_token",
-        )
-
-        if config.hf_username == input_model_id.split("/")[0]:
-            same_repo = st.checkbox("Upload ONNX weights to the same repository?")
-        else:
-            same_repo = False
-
-        model_name = input_model_id.split("/")[-1]
-        output_model_id = f"{config.hf_username}/{model_name}"
-        if not same_repo:
-            output_model_id += "-ONNX"
-
-        output_model_url = f"{config.hf_base_url}/{output_model_id}"
-
-        if not same_repo and converter.api.repo_exists(output_model_id):
-            st.write("This model has already been converted! 🎉")
-            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
-            return
-
-        st.write("Destination repository:")
-        st.code(output_model_url, language="plaintext")
-
-        if not st.button(label="Proceed", type="primary"):
-            return
-
-        with st.spinner("Converting model (including attention maps)…"):
-            success, stderr = converter.convert_model(input_model_id)
-            if not success:
-                st.error(f"Conversion failed: {stderr}")
-                return
-            st.success("Conversion successful!")
-            st.code(stderr)
-
-        with st.spinner("Uploading model…"):
-            error = converter.upload_model(input_model_id, output_model_id)
-            if error:
-                st.error(f"Upload failed: {error}")
-                return
-            st.success("Upload successful!")
-            st.write("You can now view the model on Hugging Face:")
-            st.link_button(f"Go to {output_model_id}", output_model_url, type="primary")
-
+        conv = ModelConverter(config)
+        conv.setup_repository()
+        input_id = st.text_input("Model ID e.g. EleutherAI/pythia-14m")
+        if not input_id: return
+        st.text_input("HF write token (optional)", type="password", key="user_hf_token")
+        same = st.checkbox("Upload to same repo?", value=False) if config.hf_username == input_id.split("/")[0] else False
+        name = input_id.split("/")[-1]; out = f"{config.hf_username}/{name}" + ("" if same else "-ONNX")
+        url = f"{config.hf_base_url}/{out}"; st.code(url)
+        if not st.button("Proceed"): return
+        with st.spinner("Converting (debug)..."):
+            ok, err = conv.convert_model(input_id)
+            if not ok: st.error(f"Conversion failed: {err}"); return
+            st.success("Conversion successful!"); st.code(err)
+        with st.spinner("Uploading..."):
+            err2 = conv.upload_model(input_id, out)
+            if err2: st.error(f"Upload failed: {err2}"); return
+            st.success("Upload successful!"); st.link_button(f"Go to {out}", url)
     except Exception as e:
-        logger.exception("Application error")
-        st.error(f"An error occurred: {str(e)}")
-
+        logger.exception(e); st.error(f"Error: {e}")
 
-if __name__ == "__main__":
-    main()
+if __name__ == "__main__": main()
 
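Note on the surviving relocation step: generation parameters are copied out of config.json into a standalone generation_config.json before conversion, so each value ends up in exactly one file. A minimal standalone sketch of the same pattern; model_id reuses the EleutherAI/pythia-14m example from the removed prompt text, and ./out is a hypothetical scratch directory:

from transformers import AutoConfig, GenerationConfig

model_id = "EleutherAI/pythia-14m"  # example ID from the app; any HF model works
out_dir = "./out"                   # hypothetical scratch directory

base_cfg = AutoConfig.from_pretrained(model_id)
# Derive a GenerationConfig from the model config, then null the matching
# attributes on the base config so each parameter lives in exactly one file.
gen_cfg = GenerationConfig.from_model_config(base_cfg)
for key in gen_cfg.to_dict():
    if hasattr(base_cfg, key):
        setattr(base_cfg, key, None)

base_cfg.save_pretrained(out_dir)  # writes config.json
gen_cfg.save_pretrained(out_dir)   # writes generation_config.json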
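Note on the logging change: the blanket "-W ignore" interpreter flag is replaced by debug-level logging, with the subprocess inheriting a copied environment that sets TRANSFORMERS_VERBOSITY=debug and known-noisy stderr lines filtered afterwards. A runnable sketch of that run-then-filter pattern, with a stand-in command in place of the Space's scripts.convert module:

import os
import subprocess
import sys

env = os.environ.copy()
env["TRANSFORMERS_VERBOSITY"] = "debug"  # transformers reads this to set its log level

# Stand-in for "python -m scripts.convert ..."; it just emits one noisy stderr line.
cmd = [sys.executable, "-c",
       "import sys; print('x.py:1: TracerWarning: traced', file=sys.stderr)"]

result = subprocess.run(cmd, capture_output=True, text=True, env=env)

# Keep stderr except relocation notices and TracerWarnings, as the app does.
kept = [
    ln for ln in result.stderr.splitlines()
    if not ln.startswith("Moving the following attributes")
    and "TracerWarning" not in ln
]
print("\n".join(kept))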