ariankhalfani commited on
Commit
0f1374b
·
verified ·
1 Parent(s): 5530ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -54
app.py CHANGED
@@ -1,79 +1,133 @@
1
  import requests
2
- from pydub import AudioSegment
3
- from io import BytesIO
4
  import gradio as gr
5
  import os
 
 
 
6
 
7
  # Hugging Face API URLs
8
  API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
9
  API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Function to query the RoBERTa model
12
- def query_roberta(api_token, prompt, context):
13
- payload = {
14
- "inputs": {
15
- "question": prompt,
16
- "context": context
17
- }
18
- }
19
- headers = {"Authorization": f"Bearer {api_token}"}
20
- response = requests.post(API_URL_ROBERTA, headers=headers, json=payload)
21
- try:
22
- response.raise_for_status() # Raise an error for bad responses
23
- return response.json()
24
- except requests.exceptions.HTTPError as e:
25
- return {"error": f"HTTP error occurred: {e}"}
26
- except ValueError as e:
27
- return {"error": f"Value error occurred: {e}"}
28
- except Exception as e:
29
- return {"error": f"An unexpected error occurred: {e}"}
30
 
31
  # Function to generate speech from text using ESPnet TTS
32
- def generate_speech(api_token, answer):
33
- payload = {
34
- "inputs": answer,
35
- }
36
- headers = {"Authorization": f"Bearer {api_token}"}
37
- response = requests.post(API_URL_TTS, headers=headers, json=payload)
38
- try:
39
- response.raise_for_status() # Raise an error for bad responses
40
- audio = response.content
41
- audio_segment = AudioSegment.from_file(BytesIO(audio), format="flac")
42
- audio_file_path = "/tmp/answer.wav"
43
- audio_segment.export(audio_file_path, format="wav")
44
- return audio_file_path
45
- except requests.exceptions.HTTPError as e:
46
- print(f"HTTP error occurred: {e}")
47
- return None
48
- except Exception as e:
49
- print(f"An unexpected error occurred: {e}")
50
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Function to interface with Gradio
53
- def gradio_interface(api_token, context, prompt):
54
- answer = query_roberta(api_token, prompt, context)
55
- if 'error' in answer:
56
- return answer['error'], None
57
- answer_text = answer.get('answer', 'No answer found')
58
- audio_file_path = generate_speech(api_token, answer_text)
59
- return answer_text, audio_file_path
60
 
61
  # Define the Gradio interface
62
  iface = gr.Interface(
63
- fn=gradio_interface,
64
  inputs=[
65
- gr.Textbox(type="password", lines=1, label="Hugging Face API Token", placeholder="Enter your Hugging Face API token here..."),
66
  gr.Textbox(lines=2, label="Context", placeholder="Enter the context here..."),
67
- gr.Textbox(lines=1, label="Question", placeholder="Enter your question here...")
68
  ],
69
  outputs=[
70
  gr.Textbox(label="Answer"),
71
- gr.Audio(label="Answer as Speech", type="filepath") # Changed to filepath type
72
  ],
73
  title="Chat with Roberta with Voice",
74
- description="Ask questions based on a provided context using the Roberta model and hear the response via text-to-speech."
75
  )
76
 
77
  # Launch the Gradio app
78
- if __name__ == "__main__":
79
- iface.launch(share=True)
 
1
  import requests
 
 
2
  import gradio as gr
3
  import os
4
+ from pydub import AudioSegment
5
+ from io import BytesIO
6
+ import time
7
 
8
  # Hugging Face API URLs
9
  API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
10
  API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"
11
+ API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
12
+
13
+ # Hugging Face API Token from environment variable
14
+ API_TOKEN = os.getenv("API_KEY")
15
+ HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
16
+
17
+ # Retry settings
18
+ MAX_RETRIES = 5
19
+ RETRY_DELAY = 1 # seconds
20
+
21
+ # Function to query the Whisper model for audio transcription
22
+ def query_whisper(audio_path):
23
+ for attempt in range(MAX_RETRIES):
24
+ try:
25
+ if not audio_path:
26
+ raise ValueError("Audio file path is None")
27
+ if not os.path.exists(audio_path):
28
+ raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
29
+
30
+ with open(audio_path, "rb") as f:
31
+ data = f.read()
32
+
33
+ response = requests.post(API_URL_WHISPER, headers=HEADERS, data=data)
34
+ response.raise_for_status()
35
+ return response.json()
36
+
37
+ except Exception as e:
38
+ print(f"Whisper model query failed: {e}")
39
+ if attempt < MAX_RETRIES - 1:
40
+ print(f"Retrying Whisper model query ({attempt + 1}/{MAX_RETRIES})...")
41
+ time.sleep(RETRY_DELAY)
42
+ else:
43
+ return {"error": str(e)}
44
 
45
  # Function to query the RoBERTa model
46
+ def query_roberta(prompt, context):
47
+ payload = {"inputs": {"question": prompt, "context": context}}
48
+
49
+ for attempt in range(MAX_RETRIES):
50
+ try:
51
+ response = requests.post(API_URL_ROBERTA, headers=HEADERS, json=payload)
52
+ response.raise_for_status()
53
+ return response.json()
54
+ except Exception as e:
55
+ print(f"RoBERTa model query failed: {e}")
56
+ if attempt < MAX_RETRIES - 1:
57
+ print(f"Retrying RoBERTa model query ({attempt + 1}/{MAX_RETRIES})...")
58
+ time.sleep(RETRY_DELAY)
59
+ else:
60
+ return {"error": str(e)}
 
 
 
61
 
62
  # Function to generate speech from text using ESPnet TTS
63
+ def generate_speech(answer):
64
+ payload = {"inputs": answer}
65
+
66
+ for attempt in range(MAX_RETRIES):
67
+ try:
68
+ response = requests.post(API_URL_TTS, headers=HEADERS, json=payload)
69
+ response.raise_for_status()
70
+ audio = response.content
71
+
72
+ audio_segment = AudioSegment.from_file(BytesIO(audio), format="flac")
73
+ audio_file_path = "/tmp/answer.wav"
74
+ audio_segment.export(audio_file_path, format="wav")
75
+ return audio_file_path
76
+ except Exception as e:
77
+ print(f"ESPnet TTS query failed: {e}")
78
+ if attempt < MAX_RETRIES - 1:
79
+ print(f"Retrying ESPnet TTS query ({attempt + 1}/{MAX_RETRIES})...")
80
+ time.sleep(RETRY_DELAY)
81
+ else:
82
+ return {"error": str(e)}
83
+
84
+ # Function to handle the entire process
85
+ def handle_all(context, audio):
86
+ for attempt in range(MAX_RETRIES):
87
+ try:
88
+ # Step 1: Transcribe audio
89
+ transcription = query_whisper(audio)
90
+ if 'error' in transcription:
91
+ raise Exception(transcription['error'])
92
+
93
+ question = transcription.get("text", "No transcription found")
94
+
95
+ # Step 2: Get answer from RoBERTa
96
+ answer = query_roberta(question, context)
97
+ if 'error' in answer:
98
+ raise Exception(answer['error'])
99
+
100
+ answer_text = answer.get('answer', 'No answer found')
101
+
102
+ # Step 3: Generate speech from answer
103
+ audio_file_path = generate_speech(answer_text)
104
+ if 'error' in audio_file_path:
105
+ raise Exception(audio_file_path['error'])
106
+
107
+ return answer_text, audio_file_path
108
 
109
+ except Exception as e:
110
+ print(f"Process failed: {e}")
111
+ if attempt < MAX_RETRIES - 1:
112
+ print(f"Retrying entire process ({attempt + 1}/{MAX_RETRIES})...")
113
+ time.sleep(RETRY_DELAY)
114
+ else:
115
+ return str(e), None
 
116
 
117
  # Define the Gradio interface
118
  iface = gr.Interface(
119
+ fn=handle_all,
120
  inputs=[
 
121
  gr.Textbox(lines=2, label="Context", placeholder="Enter the context here..."),
122
+ gr.Audio(type="filepath", label="Record your voice")
123
  ],
124
  outputs=[
125
  gr.Textbox(label="Answer"),
126
+ gr.Audio(label="Answer as Speech", type="filepath")
127
  ],
128
  title="Chat with Roberta with Voice",
129
+ description="Record your voice, get the transcription, use it as a question for the Roberta model, and hear the response via text-to-speech."
130
  )
131
 
132
  # Launch the Gradio app
133
+ iface.launch()