VanguardAI committed
Commit e7e0762 · verified · 1 Parent(s): dff714c

Update app.py

Files changed (1)
app.py +105 -137
app.py CHANGED
@@ -8,21 +8,16 @@ from transformers import AutoModel, AutoTokenizer
 from diffusers import StableDiffusion3Pipeline
 from parler_tts import ParlerTTSForConditionalGeneration
 import soundfile as sf
-from langchain.agents import AgentExecutor, create_react_agent, initialize_agent, Tool
-from langchain.agents import AgentType
 from langchain_groq import ChatGroq
-from langchain.prompts import PromptTemplate
 from PIL import Image
 from tavily import TavilyClient
-import requests
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
 from langchain.schema import AIMessage
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 from langchain_community.document_loaders import TextLoader
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.chains import RetrievalQA
+import json
 
 # Initialize models and clients
 MODEL = 'llama3-groq-70b-8192-tool-use-preview'
@@ -53,54 +48,46 @@ def play_voice_output(response):
     sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
     return "output.wav"
 
-# NumPy Code Calculator Tool
-class NumpyCodeCalculator(Tool):
-    name = "Calculator"
-    description = "Useful only for performing numerical computations, not for general searches"
-
-    def _run(self, query: str) -> str:
-        print("Executing NumpyCodeCalculator tool")
-        try:
-            local_dict = {"np": np}
-            exec(query, local_dict)
-            result = local_dict.get("result", "No result found")
-            return str(result)
-        except Exception as e:
-            return f"Error: {e}"
-
-# Web Search Tool
-class WebSearch(Tool):
-    name = "Web"
-    description = "Useful for advanced web searching beyond general information"
-
-    def _run(self, query: str) -> str:
-        print("Executing WebSearch tool")
-        answer = tavily_client.qna_search(query=query)
-        return answer
-
-# Image Generation Tool
-class ImageGeneration(Tool):
-    name = "Image"
-    description = "Useful for generating images based on text descriptions"
-
-    def _run(self, query: str) -> str:
-        print("Executing ImageGeneration tool")
-        image = pipe(
-            query,
-            negative_prompt="",
-            num_inference_steps=15,
-            guidance_scale=7.0,
-        ).images[0]
-        image.save("output.jpg")
-        return "output.jpg"
-
-# Document Question Answering Tool
-class DocumentQuestionAnswering(Tool):
-    name = "Document"
-    description = "Useful for answering questions about a specific document"
-
+# Function to classify user input using LLM
+def classify_function(user_prompt):
+    prompt = f"""
+    You are a function classifier AI assistant. You are given a user input and you need to classify it into one of the following functions:
+
+    - `image_generation`: If the user wants to generate an image.
+    - `image_description`: If the user wants to describe an image.
+    - `document_summarization`: If the user wants to summarize a document.
+    - `text_to_text`: If the user wants a text-based response.
+
+    Respond with a JSON object containing only the chosen function. For example:
+
+    ```json
+    {{"function": "image_generation"}}
+    ```
+
+    User input: {user_prompt}
+    """
+
+    chat_completion = client.chat.completions.create(
+        messages=[
+            {
+                "role": "user",
+                "content": prompt,
+            }
+        ],
+        model="llama3-8b-8192",
+    )
+
+    try:
+        response = json.loads(chat_completion.choices[0].message.content)
+        function = response.get("function")
+        return function
+    except json.JSONDecodeError:
+        print(f"Error decoding JSON: {chat_completion.choices[0].message.content}")
+        return "text_to_text"  # Default to text-to-text if JSON parsing fails
+
+# Document Question Answering Tool
+class DocumentQuestionAnswering:
     def __init__(self, document):
-        super().__init__()
         self.document = document
         self.qa_chain = self._setup_qa_chain()
 
@@ -120,79 +107,94 @@ class DocumentQuestionAnswering(Tool):
         )
         return qa_chain
 
-    def _run(self, query: str) -> str:
+    def run(self, query: str) -> str:
         print("Executing DocumentQuestionAnswering tool")
         response = self.qa_chain.run(query)
         return str(response)
 
-
-# Function to handle different input types and choose the right tool
+# Function to handle different input types and choose the right pipeline
 def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
     print(f"Handling input: {user_prompt}")
 
     # Initialize the LLM
    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
 
-    # Define the tools
-    tools = []
-
-    # Add Image Generation Tool
-    tools.append(ImageGeneration())
-
-    # Add Calculator Tool
-    tools.append(NumpyCodeCalculator())
-
-    # Add Web Search Tool if enabled
-    if websearch:
-        tools.append(WebSearch())
-
-    # Add Document QA Tool if document is provided
-    if document:
-        tools.append(DocumentQuestionAnswering(document))
-
-    # Check if any tools are mentioned in the user prompt
-    requires_tool = any([tool.name.lower() in user_prompt.lower() for tool in tools])
-
-    # Handle different input scenarios
-    if image:
-        print("Processing image input")
-        image = Image.open(image).convert('RGB')
-        messages = [{"role": "user", "content": [image, user_prompt]}]
-        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
-    elif audio:
+    # Handle voice-only mode
+    if audio:
         print("Processing audio input")
         transcription = client.audio.transcriptions.create(
             file=(audio.name, audio.read()),
             model="whisper-large-v3"
         )
         user_prompt = transcription.text
-        # If tools are required, use an agent
-        if requires_tool:
-            agent = initialize_agent(
-                tools,
-                llm,
-                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
-                verbose=True
-            )
-            response = agent.run(user_prompt)
+        response = llm.call(query=user_prompt)
+        audio_output = play_voice_output(response)
+        return "Response generated.", audio_output
+
+    # Handle websearch mode
+    if websearch:
+        print("Executing Web Search")
+        answer = tavily_client.qna_search(query=user_prompt)
+        return answer, None
+
+    # Classify user input using LLM
+    function = classify_function(user_prompt)
+
+    # Handle different functions
+    if function == "image_generation":
+        print("Executing Image Generation")
+        image = pipe(
+            user_prompt,
+            negative_prompt="",
+            num_inference_steps=15,
+            guidance_scale=7.0,
+        ).images[0]
+        image.save("output.jpg")
+        return "output.jpg", None
+
+    elif function == "image_description":
+        print("Executing Image Description")
+        if image:
+            image = Image.open(image).convert('RGB')
+            messages = [{"role": "user", "content": [image, user_prompt]}]
+            response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
+            return response, None
         else:
-            response = llm.call(query=user_prompt)
-    elif requires_tool:
-        print("Using agent with tools")
-        agent = initialize_agent(
-            tools,
-            llm,
-            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
-            verbose=True
-        )
-        response = agent.run(user_prompt)
-    else:
-        print("Using LLM directly")
+            return "Please upload an image.", None
+
+    elif function == "document_summarization":
+        print("Executing Document Summarization")
+        if document:
+            document_qa = DocumentQuestionAnswering(document)
+            response = document_qa.run(user_prompt)
+            return response, None
+        else:
+            return "Please upload a document.", None
+
+    else:  # function == "text_to_text"
+        print("Executing Text-to-Text")
         response = llm.call(query=user_prompt)
+        return response, None
+
+# Main interface function
+@spaces.GPU(duration=720)
+def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
+    print("Starting main_interface function")
+    vqa_model.to(device='cuda', dtype=torch.bfloat16)
+    tts_model.to("cuda")
+    pipe.to("cuda")
+
+    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
+
+    try:
+        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
+        print("handle_input function executed successfully")
+    except Exception as e:
+        print(f"Error in handle_input: {e}")
+        response = "Error occurred during processing."
 
     return response
 
-
 def create_ui():
     with gr.Blocks(css="""
         /* Overall Styling */
@@ -403,40 +405,6 @@ def create_ui():
 
     return demo
 
-# Main interface function
-@spaces.GPU(duration=720)
-def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
-    print("Starting main_interface function")
-    vqa_model.to(device='cuda', dtype=torch.bfloat16)
-    tts_model.to("cuda")
-    pipe.to("cuda")
-
-    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
-
-    try:
-        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
-        print("handle_input function executed successfully")
-    except Exception as e:
-        print(f"Error in handle_input: {e}")
-        response = "Error occurred during processing."
-
-    if voice_only:
-        try:
-            transcription = client.audio.transcriptions.create(
-                file=("input.wav", open("input.wav", "rb").read()),
-                model="whisper-large-v3"
-            )
-            user_prompt = transcription.text
-            response = handle_input(user_prompt)
-            audio_output = play_voice_output(response)
-            print("play_voice_output function executed successfully")
-            return "Response generated.", audio_output
-        except Exception as e:
-            print(f"Error in play_voice_output: {e}")
-            return "Error occurred during voice output.", None
-    else:
-        return response, None
-
 # Launch the UI
 demo = create_ui()
 demo.launch()
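
Note on the new `classify_function`: its prompt instructs the model to wrap the answer in a fenced JSON code block, yet the code passes the raw completion straight to `json.loads`, so a reply that follows the prompt's own example raises `json.JSONDecodeError` and silently falls back to `text_to_text`. A fence-tolerant parse avoids that; the sketch below assumes the reply carries a single JSON object, and `parse_function_label` is a hypothetical helper, not part of this commit:

````python
import json
import re

def parse_function_label(raw: str, default: str = "text_to_text") -> str:
    """Extract {"function": ...} from an LLM reply that may wrap the JSON
    in a ```json ... ``` fence (sketch only, not part of this commit)."""
    # Prefer the contents of a fenced block; otherwise try the raw reply.
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
    candidate = match.group(1) if match else raw
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return default
    return parsed.get("function", default) if isinstance(parsed, dict) else default
````

Since the Groq endpoint is OpenAI-compatible, passing `response_format={"type": "json_object"}` to `client.chat.completions.create` (where the chosen model supports JSON mode) would be another way to guarantee bare JSON.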