# Hugging Face Space status: Runtime error
import numpy as np
import requests
import sounddevice as sd
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
# Hugging Face Inference API URL for the deepset/roberta-base-squad2
# question-answering model.
API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"

# Parler-TTS checkpoint used for both the model weights and its tokenizer.
PARLER_TTS_MODEL_ID = "parler-tts/parler_tts_mini_v0.1"


def _select_device():
    """Return the torch device string to run Parler-TTS on.

    Priority is Intel XPU, then Apple MPS, then CUDA, falling back to CPU.
    Each backend is looked up with ``getattr`` first: ``torch.xpu`` does not
    exist on some PyTorch releases, and touching it unguarded raises
    AttributeError while the module is being imported.
    """
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"


device = _select_device()
# Half precision only on accelerators; CPU stays float32 (fp16 on CPU is
# typically slow or unsupported for some ops).
torch_dtype = torch.float16 if device != "cpu" else torch.float32
model = ParlerTTSForConditionalGeneration.from_pretrained(PARLER_TTS_MODEL_ID).to(
    device, dtype=torch_dtype
)
tokenizer = AutoTokenizer.from_pretrained(PARLER_TTS_MODEL_ID)
def query_roberta(api_token, prompt, context, timeout=30):
    """Ask the hosted RoBERTa QA model to answer ``prompt`` from ``context``.

    Args:
        api_token: Hugging Face API token, sent as a Bearer token.
        prompt: The question to answer.
        context: The passage the answer should be taken from.
        timeout: Passed to ``requests.post`` as its ``timeout`` (seconds,
            applied to connecting and to each read). Defaults to 30.

    Returns:
        The decoded JSON response on success. On failure it returns
        ``{"error": <message>}`` instead of raising. Failures include a
        network error, a timeout, a non-JSON body, or an ``"error"`` field
        in the response.
    """
    payload = {
        "inputs": {
            "question": prompt,
            "context": context,
        }
    }
    headers = {"Authorization": f"Bearer {api_token}"}
    try:
        # The request itself is inside the try so connection failures and
        # timeouts are reported through the same error-dict contract.
        response = requests.post(
            API_URL_ROBERTA, headers=headers, json=payload, timeout=timeout
        )
        response_json = response.json()
        if 'error' in response_json:
            raise ValueError(response_json['error'])
        return response_json
    except ValueError as e:
        # Also catches JSON decode failures of the response body.
        print(f"ValueError: {e}")
        return {"error": str(e)}
    except requests.RequestException as e:
        print(f"RequestException: {e}")
        return {"error": str(e)}
    except Exception as e:
        print(f"Exception: {e}")
        return {"error": "An unexpected error occurred"}
# Default natural-language voice description for Parler-TTS. Parler-TTS is
# conditioned on two inputs: a description of *how* to speak (`input_ids`)
# and the transcript of *what* to say (`prompt_input_ids`).
DEFAULT_SPEECH_DESCRIPTION = (
    "A female speaker delivers her words clearly at a moderate pace, "
    "in a quiet environment with very clear audio quality."
)


def generate_speech(answer, description=DEFAULT_SPEECH_DESCRIPTION):
    """Synthesize ``answer`` with Parler-TTS and play it on the local audio device.

    Args:
        answer: Text to be spoken.
        description: Natural-language description of the voice / recording
            style. Defaults to ``DEFAULT_SPEECH_DESCRIPTION``.

    Returns:
        A ``(sampling_rate, audio)`` tuple, where ``audio`` is the generated
        waveform as a squeezed float32 NumPy array. The tuple is returned even
        if local playback fails, so callers can route the audio elsewhere.
    """
    # The description goes in `input_ids` and the text to speak goes in
    # `prompt_input_ids`. Passing the text only as `input_ids` would leave
    # the model with no transcript.
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    # Cast to float32 before the NumPy conversion; the model may run in float16.
    audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
    sampling_rate = model.config.sampling_rate
    try:
        sd.play(audio_arr, samplerate=sampling_rate)
        sd.wait()  # Block until playback has finished.
    except Exception as e:
        # Playback is best-effort: report the failure and still return the
        # generated audio.
        print(f"Error playing audio: {e}")
    return sampling_rate, audio_arr
def gradio_interface(api_token, prompt, context):
    """Answer ``prompt`` from ``context`` via the QA API and speak the answer.

    Args:
        api_token: Hugging Face API token forwarded to ``query_roberta``.
        prompt: The question to answer.
        context: The passage to extract the answer from.

    Returns:
        ``(text, None)``. ``text`` is one of three things: the extracted
        answer, the API's error message, or ``'No answer found'``. Speech is
        played through ``generate_speech``; the second slot is always ``None``.
    """
    result = query_roberta(api_token, prompt, context)
    if not isinstance(result, dict):
        # Unexpected JSON shape (e.g. a list or string): nothing usable to
        # read or speak, so avoid crashing on `.get` / `['error']`.
        return 'No answer found', None
    if 'error' in result:
        return result['error'], None
    answer_text = result.get('answer', '')
    if not answer_text:
        # Skip the TTS pass entirely when there is nothing to say.
        return 'No answer found', None
    generate_speech(answer_text)
    # NOTE(review): the None slot looks intended for an audio output; consider
    # returning generate_speech(...)'s (sampling_rate, audio) tuple instead.
    return answer_text, None