# RobertaSpeak / LLMwithvoice.py
# Source: Hugging Face repository file view; last commit "Update LLMwithvoice.py"
# by ariankhalfani (9d4db1c, verified), file size 3.17 kB.
import requests
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from parler_tts import ParlerTTSForConditionalGeneration
from pydub import AudioSegment
import simpleaudio as sa
# Endpoint of the hosted Hugging Face Inference API serving the RoBERTa
# question-answering model.
API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Full precision on the CPU, half precision on any other device.
torch_dtype = torch.float32 if device.type == "cpu" else torch.float16

# Load the Parler-TTS weights, then move them to the selected device/dtype.
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1")
model = model.to(device, dtype=torch_dtype)

# Tokenizer that ships with the same Parler-TTS checkpoint.
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")
# Function to query the RoBERTa model
def query_roberta(api_token, prompt, context, *, timeout=30):
    """Ask the hosted RoBERTa question-answering model a question.

    Posts ``prompt`` (the question) and ``context`` to the Hugging Face
    Inference API endpoint stored in ``API_URL_ROBERTA``.

    Args:
        api_token: Hugging Face API token, sent as a Bearer token.
        prompt: The question to answer.
        context: The passage in which the answer should be looked up.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON response on success. On any failure, a dict of the
        form ``{"error": <message>}``; failures are reported through the
        return value instead of being raised.
    """
    payload = {
        "inputs": {
            "question": prompt,
            "context": context,
        }
    }
    headers = {"Authorization": f"Bearer {api_token}"}
    try:
        # The request itself is inside the try block so that connection
        # failures and timeouts are reported like every other error, and the
        # timeout keeps an unresponsive server from blocking forever.
        response = requests.post(
            API_URL_ROBERTA, headers=headers, json=payload, timeout=timeout
        )
        response_json = response.json()
        if 'error' in response_json:
            raise ValueError(response_json['error'])
        return response_json
    except ValueError as e:
        # Covers an error reported by the API as well as a response body
        # that could not be decoded as JSON.
        print(f"ValueError: {e}")
        return {"error": str(e)}
    except requests.RequestException as e:
        # Network-level problems: connection refused, DNS failure, timeout...
        print(f"Exception: {e}")
        return {"error": f"Request failed: {e}"}
    except Exception as e:
        print(f"Exception: {e}")
        return {"error": "An unexpected error occurred"}
# Function to generate speech from text
def _float_audio_to_pcm16(audio_arr):
    """Convert floating-point audio samples to a flat 16-bit signed PCM array."""
    samples = np.asarray(audio_arr, dtype=np.float32).reshape(-1)
    # Clip first so out-of-range samples saturate instead of wrapping around
    # when they are cast to int16.
    samples = np.clip(samples, -1.0, 1.0)
    return (samples * 32767.0).astype(np.int16)


def generate_speech(
    answer,
    description=(
        "A female speaker with a clear voice delivers her words at a moderate "
        "pace, in a quiet environment with very clear audio quality."
    ),
):
    """Synthesize ``answer`` with Parler-TTS and play it through simpleaudio.

    Args:
        answer: The text to be spoken. Nothing is generated or played when it
            is empty.
        description: Natural-language description of the voice. Parler-TTS
            takes the voice description as ``input_ids`` and the text to
            speak as ``prompt_input_ids``.

    Returns:
        None. Playback errors are printed, not raised.
    """
    if not answer:
        # Nothing to say; skip the (expensive) generation entirely.
        return

    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_ids = tokenizer(answer, return_tensors="pt").input_ids.to(device)
    generation = model.generate(input_ids=description_ids, prompt_input_ids=prompt_ids)

    # The model may run in half precision; widen to float32 before leaving
    # torch so the NumPy side always sees the same dtype.
    audio_arr = generation.cpu().float().numpy().squeeze()

    # NOTE(review): assumes the generated waveform is floating point in
    # [-1.0, 1.0] -- confirm against the Parler-TTS output. The raw float
    # bytes must not be relabelled as integer PCM, so convert explicitly.
    pcm = _float_audio_to_pcm16(audio_arr)

    # Play the audio using simpleaudio
    try:
        play_obj = sa.play_buffer(
            pcm.tobytes(),
            num_channels=1,
            bytes_per_sample=pcm.dtype.itemsize,
            sample_rate=model.config.sampling_rate,
        )
        play_obj.wait_done()  # Wait until the audio is done playing
    except Exception as e:
        print(f"Error playing audio: {e}")
# Function to interface with Gradio
def gradio_interface(api_token, prompt, context):
    """Answer a question with RoBERTa and read the answer aloud.

    Args:
        api_token: Hugging Face API token forwarded to ``query_roberta``.
        prompt: The question to answer.
        context: The passage in which the answer should be looked up.

    Returns:
        A ``(text, None)`` tuple. ``text`` is the answer, the error message
        reported by ``query_roberta``, or a fallback message when no usable
        answer came back. The second element is always ``None``.
    """
    answer = query_roberta(api_token, prompt, context)
    if not isinstance(answer, dict):
        # Anything but a dict cannot carry an 'answer' or 'error' key.
        return 'Unexpected response format', None
    if 'error' in answer:
        return answer['error'], None

    text = answer.get('answer', '')
    if not text:
        # Missing or empty answer: there is nothing to synthesize.
        return 'No answer found', None

    generate_speech(text)
    return text, None
# Example usage
if __name__ == "__main__":
    # Demo run: replace the placeholder with a real Hugging Face API token.
    api_token = "your_huggingface_api_token"
    prompt = "What is the capital of France?"
    context = (
        "France, in Western Europe, encompasses medieval cities, alpine villages, "
        "and Mediterranean beaches. Paris, its capital, is famed for its fashion "
        "houses, classical art museums including the Louvre, and monuments like "
        "the Eiffel Tower."
    )
    # Only the text part of the returned pair is shown.
    answer, _ = gradio_interface(api_token, prompt, context)
    print("Answer:", answer)