Pijush2023 committed on
Commit 9965e97 · verified · 1 Parent(s): 21fc9ad

Update app.py

Files changed (1)
  1. app.py +317 -46
app.py CHANGED
@@ -1,51 +1,322 @@
- # app.py

  import os
- import sys
- import pathlib
- import subprocess
  import gradio as gr
- from fam.llm.fast_inference import TTS
-
- # Clone the repository
- if not os.path.exists("metavoice-src"):
-     subprocess.run(["git", "clone", "https://github.com/metavoiceio/metavoice-src.git"])
- os.chdir("metavoice-src")
-
- # Install dependencies
- subprocess.run(["sudo", "apt", "install", "pipx", "-y"])
- subprocess.run(["pipx", "install", "poetry"])
- subprocess.run(["pipx", "run", "poetry", "install"])
- subprocess.run(["pipx", "run", "poetry", "run", "pip", "install", "torch==2.2.1", "torchaudio==2.2.1"])
-
- # Get the poetry environment path
- result = subprocess.run(["pipx", "run", "poetry", "env", "list"], capture_output=True, text=True)
- venv = result.stdout.split()[0]
- with open("poetry_env.txt", "w") as f:
-     f.write(venv)
-
- # Add the virtual environment to the system path
- venv_path = pathlib.Path("poetry_env.txt").read_text().strip("\n")
- sys.path.append(f"{venv_path}/lib/python3.10/site-packages")
-
- # Initialize TTS
- tts = TTS()
-
- def text_to_speech(text):
-     wav_file = tts.synthesise(
-         text=text,
-         spk_ref_path="assets/bria.mp3"  # Specify your speaker reference file path
-     )
-     return wav_file
-
- # Create Gradio interface
- interface = gr.Interface(
-     fn=text_to_speech,
-     inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
-     outputs=gr.Audio(type="numpy", label="Generated Audio"),
-     title="MetaVoice-1B Text to Speech",
-     description="Enter text to convert it into speech using the MetaVoice-1B model."
  )

- # Launch the Gradio interface
- interface.launch(share=True)
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.DEBUG)
+ from langchain_openai import OpenAIEmbeddings
  import os
+ import re
+ import folium
  import gradio as gr
+ import time
+ import requests
+ from googlemaps import Client as GoogleMapsClient
+ from gtts import gTTS
+ import tempfile
+ import string
+
+ embeddings = OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY'])
+
+ from pinecone import Pinecone, ServerlessSpec
+ pc = Pinecone(api_key=os.environ['PINECONE_API_KEY'])
+
+ index_name = "omaha-details"
+
+ from langchain_pinecone import PineconeVectorStore
+
+ vectorstore = PineconeVectorStore(index_name=index_name, embedding=embeddings)
+ retriever = vectorstore.as_retriever(search_kwargs={'k': 5})
+
+ from langchain_openai import ChatOpenAI
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import RetrievalQA
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory
+ from langchain.agents import Tool, initialize_agent
+
+ # Build prompts
+ template1 = """You are an expert concierge who is helpful and a renowned guide for Omaha, Nebraska. Use the following pieces of context,
+ memory, and message history, along with your knowledge of perennial events in Omaha, Nebraska, to answer the question at the end.
+ If you don't know the answer, just say "Homie, I need to get more data for this," and don't try to make up an answer.
+ Use fifteen sentences maximum. Keep the answer as detailed as possible. Always include the address, time, date, and
+ event type and description. Always say "It was my pleasure!" at the end of the answer.
+ {context}
+ Question: {question}
+ Helpful Answer:"""
+
+ template2 = """You are an expert guide to Omaha, Nebraska's perennial events.
+ With the context, memory, and message history provided, answer the question as crisply as possible. Include only the time, date,
+ event type, and description; apart from that, don't give any other details. Always say "It was my pleasure!" at the end of the answer.
+ If you don't know the answer, simply say, "Homie, I need to get more data for this," without making up an answer.
+
+ {context}
+ Question: {question}
+ Helpful Answer:"""
+
+ QA_CHAIN_PROMPT_1 = PromptTemplate(input_variables=["context", "question"], template=template1)
+ QA_CHAIN_PROMPT_2 = PromptTemplate(input_variables=["context", "question"], template=template2)
+
+ chat_model = ChatOpenAI(api_key=os.environ['OPENAI_API_KEY'],
+                         temperature=0, model='gpt-4o')
+
+ conversational_memory = ConversationBufferWindowMemory(
+     memory_key='chat_history',
+     k=10,
+     return_messages=True
  )

+ # Define the retrieval QA chain
+ def build_qa_chain(prompt_template):
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=chat_model,
+         chain_type="stuff",
+         retriever=retriever,
+         chain_type_kwargs={"prompt": prompt_template}
+     )
+     tools = [
+         Tool(
+             name='Knowledge Base',
+             func=qa_chain,
+             description='use this tool when answering general knowledge queries to get more information about the topic'
+         )
+     ]
+     return qa_chain, tools
+
+ # Define the agent initializer
+ def initialize_agent_with_prompt(prompt_template):
+     qa_chain, tools = build_qa_chain(prompt_template)
+     agent = initialize_agent(
+         agent='chat-conversational-react-description',
+         tools=tools,
+         llm=chat_model,
+         verbose=False,
+         max_iterations=5,
+         early_stopping_method='generate',
+         memory=conversational_memory
+     )
+     return agent
+
+ # Define the function to generate answers
+ def generate_answer(message, choice):
+     logging.debug(f"generate_answer called with prompt_choice: {choice}")
+     if choice == "Details":
+         agent = initialize_agent_with_prompt(QA_CHAIN_PROMPT_1)
+     elif choice == "Conversational":
+         agent = initialize_agent_with_prompt(QA_CHAIN_PROMPT_2)
+     else:
+         logging.error(f"Invalid prompt_choice: {choice}. Defaulting to 'Details'")
+         agent = initialize_agent_with_prompt(QA_CHAIN_PROMPT_1)
+
+     response = agent(message)
+     return response['output']
+
+ def bot(history, choice):
+     if not history:
+         return history
+     response = generate_answer(history[-1][0], choice)
+     history[-1][1] = ""
+     for character in response:
+         history[-1][1] += character
+         time.sleep(0.05)
+         yield history
+
+ def add_message(history, message):
+     history.append([message, None])  # use a list so the reply slot can be filled in during streaming
+     return history, gr.Textbox(value="", interactive=True, placeholder="Enter message or upload file...", show_label=False)
+
+ def print_like_dislike(x: gr.LikeData):
+     print(x.index, x.value, x.liked)
+
+ # Function to extract addresses from the chatbot's response
+ def extract_addresses(response):
+     address_pattern_1 = r'([A-Z].*,\sOmaha,\sNE\s\d{5})'
+     address_pattern_2 = r'(\d{4}\s.*,\sOmaha,\sNE\s\d{5})'
+     address_pattern_3 = r'([A-Z].*,\sNE\s\d{5})'
+     address_pattern_4 = r'([A-Z].*,.*\sSt,\sOmaha,\sNE\s\d{5})'
+     address_pattern_5 = r'([A-Z].*,.*\sStreets,\sOmaha,\sNE\s\d{5})'
+     address_pattern_6 = r'(\d{2}.*\sStreets)'
+     address_pattern_7 = r'([A-Z].*\s\d{2},\sOmaha,\sNE\s\d{5})'
+     addresses = re.findall(address_pattern_1, response) + re.findall(address_pattern_2, response) + \
+                 re.findall(address_pattern_3, response) + re.findall(address_pattern_4, response) + \
+                 re.findall(address_pattern_5, response) + re.findall(address_pattern_6, response) + \
+                 re.findall(address_pattern_7, response)
+     return addresses
+
+ # Store all found addresses
+ all_addresses = []
+
+ # Map generation function using Google Maps Geocoding API
+ def generate_map(location_names):
+     global all_addresses
+     all_addresses.extend(location_names)
+
+     api_key = os.environ['GOOGLEMAPS_API_KEY']
+     gmaps = GoogleMapsClient(key=api_key)
+
+     m = folium.Map(location=[41.2565, -95.9345], zoom_start=12)
+
+     for location_name in all_addresses:
+         geocode_result = gmaps.geocode(location_name)
+         if geocode_result:
+             location = geocode_result[0]['geometry']['location']
+             folium.Marker(
+                 [location['lat'], location['lng']],
+                 tooltip=f"{geocode_result[0]['formatted_address']}"
+             ).add_to(m)
+
+     map_html = m._repr_html_()
+     return map_html
+
+ # Function to fetch local news
+ def fetch_local_news():
+     api_key = os.environ['SERP_API']
+     url = f'https://serpapi.com/search.json?engine=google_news&q=omaha headline&api_key={api_key}'
+
+     response = requests.get(url)
+     if response.status_code == 200:
+         results = response.json().get("news_results", [])
+         news_html = "<h2>Omaha Today Headlines</h2>"
+         for index, result in enumerate(results[:10]):
+             title = result.get("title", "No title")
+             link = result.get("link", "#")
+             snippet = result.get("snippet", "")
+             news_html += f"<p>{index + 1}. <a href='{link}' target='_blank'>{title}</a><br>{snippet}</p>"
+         return news_html
+     else:
+         return "<p>Failed to fetch local news</p>"
+
+ # Function to fetch local events
+ def fetch_local_events():
+     api_key = os.environ['SERP_API']
+     url = f'https://serpapi.com/search.json?engine=google_events&q=Events+in+Omaha&hl=en&gl=us&api_key={api_key}'
+
+     response = requests.get(url)
+     if response.status_code == 200:
+         events_results = response.json().get("events_results", [])
+         events_text = "<h2>Local Events</h2>"
+         for index, event in enumerate(events_results):
+             title = event.get("title", "No title")
+             date = event.get("date", "No date")
+             location = event.get("address", "No location")
+             link = event.get("link", "#")
+             events_text += f"<p>{index + 1}. {title}<br> Date: {date}<br> Location: {location}<br> <a href='{link}' target='_blank'>Link</a></p>"
+         return events_text
+     else:
+         return "Failed to fetch local events"
+
+ # Function to fetch local weather
+ def fetch_local_weather():
+     try:
+         api_key = os.environ['WEATHER_API']
+         url = f'https://weather.visualcrossing.com/VisualCrossingWebServices/rest/services/timeline/omaha?unitGroup=metric&include=events%2Calerts%2Chours%2Cdays%2Ccurrent&key={api_key}'
+         response = requests.get(url)
+         response.raise_for_status()
+         jsonData = response.json()
+
+         current_conditions = jsonData.get("currentConditions", {})
+         temp = current_conditions.get("temp", "N/A")
+         condition = current_conditions.get("conditions", "N/A")
+         humidity = current_conditions.get("humidity", "N/A")
+
+         weather_html = "<h2>Local Weather</h2>"
+         weather_html += f"<p>Temperature: {temp}°C</p>"
+         weather_html += f"<p>Condition: {condition}</p>"
+         weather_html += f"<p>Humidity: {humidity}%</p>"
+
+         return weather_html
+     except requests.exceptions.RequestException as e:
+         return f"<p>Failed to fetch local weather: {e}</p>"
+
+ # Voice Control
+ import numpy as np
+ import torch
+ from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
+
+ model_id = 'openai/whisper-large-v3'
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype,
+                                                   #low_cpu_mem_usage=True,
+                                                   use_safetensors=True).to(device)
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ # Optimized ASR pipeline
+ pipe_asr = pipeline(
+     "automatic-speech-recognition",
+     model=model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     max_new_tokens=128,
+     chunk_length_s=15,
+     batch_size=16,
+     torch_dtype=torch_dtype,
+     device=device,
+     return_timestamps=True
+ )
+
+ base_audio_drive = "/data/audio"
+
+ def transcribe_function(stream, new_chunk):
+     try:
+         sr, y = new_chunk[0], new_chunk[1]
+     except TypeError:
+         print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
+         return stream, ""
+
+     y = y.astype(np.float32) / np.max(np.abs(y))
+
+     if stream is not None:
+         stream = np.concatenate([stream, y])
+     else:
+         stream = y
+
+     result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
+
+     full_text = result.get("text", "")
+
+     # Return the accumulated audio (state) and the running transcription,
+     # matching the two output components wired to this stream handler.
+     return stream, full_text
+
+ # Map Retrieval Function for location finder
+ def update_map_with_response(history):
+     if not history:
+         return ""
+     response = history[-1][1]
+     addresses = extract_addresses(response)
+     return generate_map(addresses)
+
+ def clear_textbox():
+     return ""
+
+ # Gradio Blocks interface
+ with gr.Blocks(theme='rawrsor1/Everforest') as demo:
+     with gr.Row():
+         with gr.Column():
+             chatbot = gr.Chatbot([], elem_id="chatbot", bubble_full_width=False)
+
+         with gr.Column():
+             weather_output = gr.HTML(value=fetch_local_weather())
+
+         with gr.Column():
+             news_output = gr.HTML(value=fetch_local_news())
+
+     def setup_ui():
+         state = gr.State()
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("Choose the prompt")
+                 choice = gr.Radio(label="Choose a prompt", choices=["Details", "Conversational"], value="Details")
+
+             with gr.Column():  # Larger scale for the right column
+                 gr.Markdown("Enter the query / Voice Output")
+                 chat_input = gr.Textbox(show_copy_button=True, interactive=True, show_label=False, label="Transcription")
+                 chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
+                 bot_msg = chat_msg.then(bot, [chatbot, choice], chatbot, api_name="bot_response")
+                 bot_msg.then(lambda: gr.Textbox(value="", interactive=True, placeholder="Enter message or upload file...", show_label=False), None, [chat_input])
+                 chatbot.like(print_like_dislike, None, None)
+                 clear_button = gr.Button("Clear")
+                 clear_button.click(fn=clear_textbox, inputs=None, outputs=chat_input)
+
+             with gr.Column():  # Smaller scale for the left column
+                 gr.Markdown("Stream your Voice")
+                 audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy')
+                 audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, chat_input], api_name="SAMLOne_real_time")
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("Locate the Events")
+                 location_output = gr.HTML()
+                 bot_msg.then(update_map_with_response, chatbot, location_output)
+
+     setup_ui()
+
+ demo.queue()
+ demo.launch(share=True)