Update app.py
Browse files
app.py
CHANGED
|
@@ -39,8 +39,7 @@ from transformers import pipeline, AutoModelForAudioClassification, AutoProcesso
|
|
| 39 |
|
| 40 |
|
| 41 |
# Set device and dtype
|
| 42 |
-
|
| 43 |
-
device="cpu"
|
| 44 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 45 |
lid_model_id = "facebook/mms-lid-126"
|
| 46 |
lid_pipeline = pipeline("audio-classification", model=lid_model_id,device=device)
|
|
@@ -186,7 +185,7 @@ def generate_response(transcription):
|
|
| 186 |
{"role": "user", "content": transcription},
|
| 187 |
]
|
| 188 |
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
| 189 |
-
input_ids = tokenized_chat[0].to(
|
| 190 |
if len(input_ids.shape) == 1:
|
| 191 |
input_ids = input_ids.unsqueeze(0)
|
| 192 |
with torch.no_grad():
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
# Set device and dtype
|
| 42 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 43 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 44 |
lid_model_id = "facebook/mms-lid-126"
|
| 45 |
lid_pipeline = pipeline("audio-classification", model=lid_model_id,device=device)
|
|
|
|
| 185 |
{"role": "user", "content": transcription},
|
| 186 |
]
|
| 187 |
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
|
| 188 |
+
input_ids = tokenized_chat[0].to(device)
|
| 189 |
if len(input_ids.shape) == 1:
|
| 190 |
input_ids = input_ids.unsqueeze(0)
|
| 191 |
with torch.no_grad():
|