# Hugging Face Space page header (captured Space status: Runtime error)
import requests | |
import torch | |
import numpy as np | |
from transformers import AutoTokenizer | |
from parler_tts import ParlerTTSForConditionalGeneration | |
from pydub import AudioSegment | |
import simpleaudio as sa | |
# Endpoint of the hosted Hugging Face Inference API for the
# deepset/roberta-base-squad2 extractive question-answering model.
API_URL_ROBERTA = (
    "https://api-inference.huggingface.co/models/"
    "deepset/roberta-base-squad2"
)
# Choose the inference device. Each later check overrides the earlier ones, so
# the effective priority is XPU > MPS > CUDA > CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
# torch.backends.mps and torch.xpu only exist in newer PyTorch releases; probing
# them unguarded raises AttributeError at import time on older builds.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
_xpu_backend = getattr(torch, "xpu", None)
if _xpu_backend is not None and _xpu_backend.is_available():
    device = "xpu"
# Accelerators run the model in half precision; the CPU path stays in float32.
torch_dtype = torch.float16 if device != "cpu" else torch.float32
# Parler-TTS checkpoint shared by the speech model and its tokenizer.
_PARLER_CHECKPOINT = "parler-tts/parler_tts_mini_v0.1"

# Load the model first, then place it on the selected device in the chosen dtype.
model = ParlerTTSForConditionalGeneration.from_pretrained(_PARLER_CHECKPOINT)
model = model.to(device, dtype=torch_dtype)
tokenizer = AutoTokenizer.from_pretrained(_PARLER_CHECKPOINT)
def query_roberta(api_token, prompt, context, timeout=30):
    """Ask the hosted RoBERTa SQuAD2 model to answer *prompt* using *context*.

    Args:
        api_token: Hugging Face API token, sent as a Bearer credential.
        prompt: The question to answer.
        context: The passage the answer should be extracted from.
        timeout: Seconds to wait for the HTTP exchange before giving up.
            Without a timeout ``requests`` may block indefinitely.

    Returns:
        The decoded JSON response on success. On any failure, a dict of the
        form ``{"error": message}``; this function does not raise.
    """
    payload = {"inputs": {"question": prompt, "context": context}}
    headers = {"Authorization": f"Bearer {api_token}"}
    try:
        # The request itself lives inside the try so connection failures and
        # timeouts are reported as {"error": ...} like every other failure.
        response = requests.post(
            API_URL_ROBERTA, headers=headers, json=payload, timeout=timeout
        )
        response_json = response.json()
        if 'error' in response_json:
            # An "error" key in the body is treated as an API-reported failure.
            raise ValueError(response_json['error'])
        return response_json
    except ValueError as e:
        # Also reached when the body is not valid JSON: the JSON decode error
        # raised by response.json() is a ValueError subclass.
        print(f"ValueError: {e}")
        return {"error": str(e)}
    except requests.RequestException as e:
        print(f"RequestException: {e}")
        return {"error": f"Request to the question-answering API failed: {e}"}
    except Exception as e:
        print(f"Exception: {e}")
        return {"error": "An unexpected error occurred"}
# Voice/style description fed to Parler-TTS alongside the text to speak.
DEFAULT_VOICE_DESCRIPTION = (
    "A female speaker delivers her words clearly at a moderate speed and pitch. "
    "The recording is of very high quality, with no background noise."
)


def generate_speech(answer, description=DEFAULT_VOICE_DESCRIPTION):
    """Synthesize *answer* as speech with Parler-TTS and play it locally.

    Args:
        answer: The text to be spoken.
        description: Natural-language description of the desired voice.

    Returns:
        ``(sampling_rate, pcm)``, where ``pcm`` is a ``numpy.int16`` array of
        mono samples. Returns ``None`` without synthesizing anything when
        *answer* is empty, ``None`` or whitespace-only. Playback errors are
        printed, not raised, so the audio is still returned.
    """
    if not answer or not answer.strip():
        return None

    # Parler-TTS takes two inputs: the voice description as ``input_ids`` and
    # the transcript to speak as ``prompt_input_ids``. Passing only the answer
    # as ``input_ids`` describes a voice but gives it nothing to say.
    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
    generation = model.generate(input_ids=description_ids, prompt_input_ids=prompt_ids)
    # The model may run in float16 on accelerators; upcast before NumPy.
    audio = generation.to(torch.float32).cpu().numpy().squeeze()

    # simpleaudio interprets raw bytes as signed integer PCM. Writing float32
    # samples with a 4-byte width would therefore play as noise. Convert to
    # 16-bit PCM; this presumes samples lie roughly in [-1, 1], and clipping
    # guards against overshoot.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    sampling_rate = model.config.sampling_rate

    try:
        play_obj = sa.play_buffer(
            pcm.tobytes(),
            num_channels=1,
            bytes_per_sample=pcm.dtype.itemsize,
            sample_rate=sampling_rate,
        )
        play_obj.wait_done()  # Block until playback finishes
    except Exception as e:
        # Playback is best-effort (e.g. no audio device); keep the audio anyway.
        print(f"Error playing audio: {e}")

    return sampling_rate, pcm
def gradio_interface(api_token, prompt, context):
    """Answer *prompt* from *context* and speak the answer aloud.

    Args:
        api_token: Hugging Face API token forwarded to ``query_roberta``.
        prompt: The question to answer.
        context: The passage the answer is extracted from.

    Returns:
        A ``(text, None)`` pair. *text* is one of: the API error message, the
        extracted answer, or ``'No answer found'`` when the answer is missing
        or blank. Speech is only generated for a non-blank answer.
        NOTE(review): the second element is always None; presumably it fills
        an audio output slot in the Gradio interface defined elsewhere —
        confirm before wiring generated audio into it.
    """
    answer = query_roberta(api_token, prompt, context)
    if 'error' in answer:
        return answer['error'], None

    # Read the answer once; a missing, null, or blank answer is "no answer".
    # This also avoids synthesizing speech for an empty string.
    answer_text = answer.get('answer') or ''
    if not answer_text.strip():
        return 'No answer found', None

    generate_speech(answer_text)
    return answer_text, None