urroxyz commited on
Commit
de5f9d5
·
verified ·
1 Parent(s): b0efbfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -10
app.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import os
3
  import subprocess
4
  import sys
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
  from typing import Optional, Tuple
@@ -9,6 +10,11 @@ from urllib.request import urlopen, urlretrieve
9
 
10
  import streamlit as st
11
  from huggingface_hub import HfApi, whoami
 
 
 
 
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
@@ -88,13 +94,31 @@ class ModelConverter:
88
  with tempfile.TemporaryDirectory() as tmp_dir:
89
  with tarfile.open(archive_path, "r:gz") as tar:
90
  tar.extractall(tmp_dir)
91
-
92
  extracted_folder = next(Path(tmp_dir).iterdir())
93
  extracted_folder.rename(self.config.repo_path)
94
 
95
  def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
96
- """Convert the model to ONNX format, always exporting attention maps."""
 
 
 
 
97
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  cmd = [
99
  sys.executable,
100
  "-m", "scripts.convert",
@@ -103,12 +127,14 @@ class ModelConverter:
103
  "--model_id", input_model_id,
104
  "--output_attentions",
105
  ]
 
 
106
  result = subprocess.run(
107
  cmd,
108
  cwd=self.config.repo_path,
109
  capture_output=True,
110
  text=True,
111
- env={},
112
  )
113
 
114
  if result.returncode != 0:
@@ -127,13 +153,13 @@ class ModelConverter:
127
  self.api.create_repo(output_model_id, exist_ok=True, private=False)
128
 
129
  readme_path = f"{model_folder_path}/README.md"
130
-
131
  if not os.path.exists(readme_path):
132
  with open(readme_path, "w") as file:
133
  file.write(self.generate_readme(input_model_id))
134
 
135
  self.api.upload_folder(
136
- folder_path=str(model_folder_path), repo_id=output_model_id
 
137
  )
138
  return None
139
  except Exception as e:
@@ -142,7 +168,7 @@ class ModelConverter:
142
  import shutil
143
  shutil.rmtree(model_folder_path, ignore_errors=True)
144
 
145
- def generate_readme(self, imi: str):
146
  return (
147
  "---\n"
148
  "library_name: transformers.js\n"
@@ -178,9 +204,7 @@ def main():
178
  )
179
 
180
  if config.hf_username == input_model_id.split("/")[0]:
181
- same_repo = st.checkbox(
182
- "Upload ONNX weights to the same repository?"
183
- )
184
  else:
185
  same_repo = False
186
 
@@ -226,4 +250,3 @@ def main():
226
 
227
  if __name__ == "__main__":
228
  main()
229
-
 
2
  import os
3
  import subprocess
4
  import sys
5
+ import warnings
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
  from typing import Optional, Tuple
 
10
 
11
  import streamlit as st
12
  from huggingface_hub import HfApi, whoami
13
+ from torch.jit import TracerWarning
14
+ from transformers import AutoConfig, GenerationConfig
15
+
16
+ # Suppress TorchScript tracer warnings globally
17
+ warnings.filterwarnings("ignore", category=TracerWarning)
18
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
 
94
  with tempfile.TemporaryDirectory() as tmp_dir:
95
  with tarfile.open(archive_path, "r:gz") as tar:
96
  tar.extractall(tmp_dir)
 
97
  extracted_folder = next(Path(tmp_dir).iterdir())
98
  extracted_folder.rename(self.config.repo_path)
99
 
100
  def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
101
+ """
102
+ Convert the model to ONNX format, always exporting attention maps.
103
+ Also relocate any generation parameters into generation_config.json
104
+ and suppress JIT tracer warnings.
105
+ """
106
  try:
107
+ # 1. Clone or prepare a local copy of the model to adjust configs
108
+ model_dir = self.config.repo_path / "models" / input_model_id
109
+ model_dir.mkdir(parents=True, exist_ok=True)
110
+
111
+ # 2. Load and relocate generation parameters
112
+ config = AutoConfig.from_pretrained(input_model_id)
113
+ gen_cfg = GenerationConfig.from_model_config(config)
114
+ # Remove generation-specific keys from model config
115
+ for key in gen_cfg.to_dict().keys():
116
+ if hasattr(config, key):
117
+ setattr(config, key, None)
118
+ config.save_pretrained(model_dir)
119
+ gen_cfg.save_pretrained(model_dir)
120
+
121
+ # 3. Build the conversion command
122
  cmd = [
123
  sys.executable,
124
  "-m", "scripts.convert",
 
127
  "--model_id", input_model_id,
128
  "--output_attentions",
129
  ]
130
+
131
+ # 4. Run the conversion
132
  result = subprocess.run(
133
  cmd,
134
  cwd=self.config.repo_path,
135
  capture_output=True,
136
  text=True,
137
+ env=os.environ.copy(),
138
  )
139
 
140
  if result.returncode != 0:
 
153
  self.api.create_repo(output_model_id, exist_ok=True, private=False)
154
 
155
  readme_path = f"{model_folder_path}/README.md"
 
156
  if not os.path.exists(readme_path):
157
  with open(readme_path, "w") as file:
158
  file.write(self.generate_readme(input_model_id))
159
 
160
  self.api.upload_folder(
161
+ folder_path=str(model_folder_path),
162
+ repo_id=output_model_id
163
  )
164
  return None
165
  except Exception as e:
 
168
  import shutil
169
  shutil.rmtree(model_folder_path, ignore_errors=True)
170
 
171
+ def generate_readme(self, imi: str) -> str:
172
  return (
173
  "---\n"
174
  "library_name: transformers.js\n"
 
204
  )
205
 
206
  if config.hf_username == input_model_id.split("/")[0]:
207
+ same_repo = st.checkbox("Upload ONNX weights to the same repository?")
 
 
208
  else:
209
  same_repo = False
210
 
 
250
 
251
  if __name__ == "__main__":
252
  main()