Add `output_attentions` flag to the conversion of the Whisper model
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import tarfile
|
|
7 |
import shutil
|
8 |
from dataclasses import dataclass
|
9 |
from pathlib import Path
|
10 |
-
from typing import
|
11 |
from urllib.request import urlopen, urlretrieve
|
12 |
|
13 |
import streamlit as st
|
@@ -128,21 +128,27 @@ class ModelConverter:
|
|
128 |
)
|
129 |
|
130 |
def convert_model(
|
131 |
-
self,
|
|
|
|
|
|
|
132 |
) -> Tuple[bool, Optional[str]]:
|
133 |
"""Convert the model to ONNX format."""
|
134 |
try:
|
|
|
135 |
if trust_remote_code:
|
136 |
if not self.config.is_using_user_token:
|
137 |
raise Exception(
|
138 |
"Trust Remote Code requires your own HuggingFace token."
|
139 |
)
|
|
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
146 |
|
147 |
if result.returncode != 0:
|
148 |
return False, result.stderr
|
@@ -215,6 +221,12 @@ def main():
|
|
215 |
"This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
|
216 |
)
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
if config.hf_username == input_model_id.split("/")[0]:
|
219 |
same_repo = st.checkbox(
|
220 |
"Do you want to upload the ONNX weights to the same repository?"
|
@@ -244,7 +256,9 @@ def main():
|
|
244 |
|
245 |
with st.spinner("Converting model..."):
|
246 |
success, stderr = converter.convert_model(
|
247 |
-
input_model_id,
|
|
|
|
|
248 |
)
|
249 |
if not success:
|
250 |
st.error(f"Conversion failed: {stderr}")
|
|
|
7 |
import shutil
|
8 |
from dataclasses import dataclass
|
9 |
from pathlib import Path
|
10 |
+
from typing import List, Optional, Tuple
|
11 |
from urllib.request import urlopen, urlretrieve
|
12 |
|
13 |
import streamlit as st
|
|
|
128 |
)
|
129 |
|
130 |
def convert_model(
|
131 |
+
self,
|
132 |
+
input_model_id: str,
|
133 |
+
trust_remote_code=False,
|
134 |
+
output_attentions=False,
|
135 |
) -> Tuple[bool, Optional[str]]:
|
136 |
"""Convert the model to ONNX format."""
|
137 |
try:
|
138 |
+
extra_args: List[str] = []
|
139 |
if trust_remote_code:
|
140 |
if not self.config.is_using_user_token:
|
141 |
raise Exception(
|
142 |
"Trust Remote Code requires your own HuggingFace token."
|
143 |
)
|
144 |
+
extra_args.append("--trust_remote_code")
|
145 |
|
146 |
+
if output_attentions:
|
147 |
+
extra_args.append("--output_attentions")
|
148 |
+
|
149 |
+
result = self._run_conversion_subprocess(
|
150 |
+
input_model_id, extra_args=extra_args or None
|
151 |
+
)
|
152 |
|
153 |
if result.returncode != 0:
|
154 |
return False, result.stderr
|
|
|
221 |
"This option should only be enabled for repositories you trust and in which you have read the code, as it will execute arbitrary code present in the model repository. When this option is enabled, you must use your own Hugging Face write token."
|
222 |
)
|
223 |
|
224 |
+
output_attentions = False
|
225 |
+
if "whisper" in input_model_id.lower():
|
226 |
+
output_attentions = st.toggle(
|
227 |
+
"Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
|
228 |
+
)
|
229 |
+
|
230 |
if config.hf_username == input_model_id.split("/")[0]:
|
231 |
same_repo = st.checkbox(
|
232 |
"Do you want to upload the ONNX weights to the same repository?"
|
|
|
256 |
|
257 |
with st.spinner("Converting model..."):
|
258 |
success, stderr = converter.convert_model(
|
259 |
+
input_model_id,
|
260 |
+
trust_remote_code=trust_remote_code,
|
261 |
+
output_attentions=output_attentions,
|
262 |
)
|
263 |
if not success:
|
264 |
st.error(f"Conversion failed: {stderr}")
|