Update LLMwithvoice.py
LLMwithvoice.py  +19 -11
CHANGED
@@ -1,7 +1,7 @@
 import requests
 import torch
 import numpy as np
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 from parler_tts import ParlerTTSForConditionalGeneration
 from pydub import AudioSegment
 import simpleaudio as sa
@@ -9,18 +9,15 @@ import simpleaudio as sa
 # Hugging Face API URL for Roberta model
 API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
 
-device
-if torch.cuda.is_available()
-
-if torch.backends.mps.is_available():
-    device = "mps"
-if torch.xpu.is_available():
-    device = "xpu"
-torch_dtype = torch.float16 if device != "cpu" else torch.float32
+# Determine the device to run the models
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch_dtype = torch.float16 if device.type != "cpu" else torch.float32
 
+# Load the ParlerTTS model and tokenizer
 model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device, dtype=torch_dtype)
 tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
 
+# Function to query the RoBERTa model
 def query_roberta(api_token, prompt, context):
     payload = {
         "inputs": {
@@ -42,10 +39,11 @@ def query_roberta(api_token, prompt, context):
         print(f"Exception: {e}")
         return {"error": "An unexpected error occurred"}
 
+# Function to generate speech from text
 def generate_speech(answer):
     input_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
 
-    generation = model.generate(input_ids=input_ids)
+    generation = model.generate(input_ids=input_ids)
     audio_arr = generation.cpu().numpy().squeeze()
 
     # Convert numpy array to audio segment
@@ -68,9 +66,19 @@ def generate_speech(answer):
     except Exception as e:
         print(f"Error playing audio: {e}")
 
+# Function to interface with Gradio
 def gradio_interface(api_token, prompt, context):
     answer = query_roberta(api_token, prompt, context)
     if 'error' in answer:
         return answer['error'], None
     generate_speech(answer.get('answer', ''))
-    return answer.get('answer', 'No answer found'), None
+    return answer.get('answer', 'No answer found'), None
+
+# Example usage
+if __name__ == "__main__":
+    api_token = "your_huggingface_api_token"
+    prompt = "What is the capital of France?"
+    context = "France, in Western Europe, encompasses medieval cities, alpine villages, and Mediterranean beaches. Paris, its capital, is famed for its fashion houses, classical art museums including the Louvre, and monuments like the Eiffel Tower."
+
+    answer, _ = gradio_interface(api_token, prompt, context)
+    print("Answer:", answer)
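
Notes on this change:

The removed device block was not valid Python: a bare device name, an if torch.cuda.is_available() missing its colon and body, and fallthrough assignments, which is the likely source of the Space's runtime error. The replacement only selects between CUDA and CPU. If the MPS/XPU fallbacks the old lines were attempting are still wanted, a working version of that chain might look like the sketch below (the hasattr guard is our addition, since torch.xpu only exists in recent PyTorch builds):

import torch

# Prefer CUDA, then Apple MPS, then Intel XPU, falling back to CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = "xpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32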
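
The diff elides the middle of query_roberta (old lines 27-41). Judging from the visible payload = { "inputs": { opening and the question-answering schema of the hosted Inference API, the body presumably builds the question/context payload and POSTs it with a bearer token. A minimal sketch, not the file's exact code (the timeout is our addition):

def query_roberta(api_token, prompt, context):
    payload = {
        "inputs": {
            "question": prompt,
            "context": context
        }
    }
    headers = {"Authorization": f"Bearer {api_token}"}
    try:
        response = requests.post(API_URL_ROBERTA, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        # Successful responses look like {"answer": "Paris", "score": 0.98, "start": ..., "end": ...}
        return response.json()
    except Exception as e:
        print(f"Exception: {e}")
        return {"error": "An unexpected error occurred"}

Separately, the newly imported AutoModelForQuestionAnswering is never used: answers come from the remote API rather than a local model, so that import could be dropped again.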
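
One probable remaining bug: generate_speech feeds the answer text to model.generate as input_ids alone. Parler-TTS conditions on two inputs, a voice description and the text to speak, tokenized separately as input_ids and prompt_input_ids. A sketch of the usual call, with a placeholder description of our own:

description = "A clear, calm female voice with no background noise."  # placeholder, not from the original file
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()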
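
The numpy-to-audio conversion and playback (old lines 52-67) also fall outside the diff. Given the pydub and simpleaudio imports, that section presumably scales the float waveform to 16-bit PCM and plays it, roughly like this (the helper name is ours; the sampling rate would come from the model config):

def play_audio_array(audio_arr, sampling_rate):
    # Scale float audio in [-1, 1] to 16-bit PCM
    pcm = (np.clip(audio_arr, -1.0, 1.0) * 32767).astype(np.int16)
    segment = AudioSegment(pcm.tobytes(), frame_rate=sampling_rate, sample_width=2, channels=1)
    try:
        play_obj = sa.play_buffer(segment.raw_data, num_channels=1, bytes_per_sample=2, sample_rate=sampling_rate)
        play_obj.wait_done()
    except Exception as e:
        print(f"Error playing audio: {e}")

# e.g. play_audio_array(audio_arr, model.config.sampling_rate)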