urroxyz commited on
Commit
00c972d
·
verified ·
1 Parent(s): cc245ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -22
app.py CHANGED
@@ -13,7 +13,7 @@ from huggingface_hub import HfApi, whoami
13
  from torch.jit import TracerWarning
14
  from transformers import AutoConfig, GenerationConfig
15
 
16
- # Suppress TorchScript tracer warnings in parent process
17
  warnings.filterwarnings("ignore", category=TracerWarning)
18
 
19
  logging.basicConfig(level=logging.INFO)
@@ -100,50 +100,56 @@ class ModelConverter:
100
  def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
101
  """
102
  Convert the model to ONNX format, always exporting attention maps.
103
- Relocate generation parameters and suppress warnings in subprocess.
104
- Strip out static relocation and tracer warnings from stderr.
105
  """
106
  try:
107
- # Prepare local model folder for config edits
108
  model_dir = self.config.repo_path / "models" / input_model_id
109
  model_dir.mkdir(parents=True, exist_ok=True)
110
 
111
- # Reload and relocate generation parameters
112
- config = AutoConfig.from_pretrained(input_model_id)
113
- gen_cfg = GenerationConfig.from_model_config(config)
114
- for key in gen_cfg.to_dict().keys():
115
- if hasattr(config, key):
116
- setattr(config, key, None)
117
- config.save_pretrained(model_dir)
118
- gen_cfg.save_pretrained(model_dir)
119
-
120
- # Build command with warning suppression flag
 
 
 
 
 
121
  cmd = [
122
  sys.executable,
123
- "-W", "ignore::torch.jit.TracerWarning",
124
  "-m", "scripts.convert",
125
  "--quantize",
126
  "--trust_remote_code",
127
  "--model_id", input_model_id,
128
  "--output_attentions",
129
  ]
 
 
130
  result = subprocess.run(
131
  cmd,
132
  cwd=self.config.repo_path,
133
  capture_output=True,
134
  text=True,
135
- env=os.environ.copy(),
136
  )
137
 
138
- # Filter stderr lines
139
- filtered = []
140
  for ln in result.stderr.splitlines():
141
- if ln.startswith("Moving the following attributes"): # relocation warning
142
  continue
143
- if "TracerWarning" in ln: # any tracer warnings
144
  continue
145
- filtered.append(ln)
146
- stderr = "\n".join(filtered)
147
 
148
  if result.returncode != 0:
149
  return False, stderr
 
13
  from torch.jit import TracerWarning
14
  from transformers import AutoConfig, GenerationConfig
15
 
16
+ # Suppress TorchScript tracer warnings in this process
17
  warnings.filterwarnings("ignore", category=TracerWarning)
18
 
19
  logging.basicConfig(level=logging.INFO)
 
100
  def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
101
  """
102
  Convert the model to ONNX format, always exporting attention maps.
103
+ Relocate generation parameters, suppress tracer warnings, and
104
+ strip out both relocation and tracer warnings from stderr.
105
  """
106
  try:
107
+ # Prepare local directory for config edits
108
  model_dir = self.config.repo_path / "models" / input_model_id
109
  model_dir.mkdir(parents=True, exist_ok=True)
110
 
111
+ # Load and relocate generation parameters
112
+ base_config = AutoConfig.from_pretrained(input_model_id)
113
+ gen_config = GenerationConfig.from_model_config(base_config)
114
+ # Remove generation params from base config
115
+ for key in gen_config.to_dict():
116
+ if hasattr(base_config, key):
117
+ setattr(base_config, key, None)
118
+ base_config.save_pretrained(model_dir)
119
+ gen_config.save_pretrained(model_dir)
120
+
121
+ # Set up env to suppress tracer warnings in subprocess
122
+ env = os.environ.copy()
123
+ env["PYTHONWARNINGS"] = "ignore::torch.jit.TracerWarning"
124
+
125
+ # Build conversion command
126
  cmd = [
127
  sys.executable,
 
128
  "-m", "scripts.convert",
129
  "--quantize",
130
  "--trust_remote_code",
131
  "--model_id", input_model_id,
132
  "--output_attentions",
133
  ]
134
+
135
+ # Execute conversion
136
  result = subprocess.run(
137
  cmd,
138
  cwd=self.config.repo_path,
139
  capture_output=True,
140
  text=True,
141
+ env=env,
142
  )
143
 
144
+ # Filter out relocation and tracer warnings
145
+ lines = []
146
  for ln in result.stderr.splitlines():
147
+ if ln.startswith("Moving the following attributes"):
148
  continue
149
+ if "TracerWarning" in ln:
150
  continue
151
+ lines.append(ln)
152
+ stderr = "\n".join(lines)
153
 
154
  if result.returncode != 0:
155
  return False, stderr