VanguardAI committed
Commit e1310ff · verified · 1 Parent(s): cffc6a3

Update app.py

Files changed (1): app.py +54 -53
app.py CHANGED
@@ -44,6 +44,7 @@ tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))
 
 # Function to play voice output
 def play_voice_output(response):
+    print("Executing play_voice_output function")
     description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
     input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
     prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
@@ -58,6 +59,7 @@ class NumpyCodeCalculator(Tool):
     description = "Useful only for performing numerical computations, not for general searches"
 
     def _run(self, query: str) -> str:
+        print("Executing NumpyCodeCalculator tool")
         try:
             local_dict = {"np": np}
             exec(query, local_dict)
@@ -72,6 +74,7 @@ class WebSearch(Tool):
     description = "Useful for advanced web searching beyond general information"
 
     def _run(self, query: str) -> str:
+        print("Executing WebSearch tool")
         answer = tavily_client.qna_search(query=query)
         return answer
 
@@ -81,6 +84,7 @@ class ImageGeneration(Tool):
     description = "Useful for generating images based on text descriptions"
 
     def _run(self, query: str) -> str:
+        print("Executing ImageGeneration tool")
         image = pipe(
             query,
             negative_prompt="",
@@ -101,6 +105,7 @@ class DocumentQuestionAnswering(Tool):
         self.qa_chain = self._setup_qa_chain()
 
     def _setup_qa_chain(self):
+        print("Setting up DocumentQuestionAnswering tool")
         loader = TextLoader(self.document)
         documents = loader.load()
         text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
@@ -116,77 +121,73 @@ class DocumentQuestionAnswering(Tool):
         return qa_chain
 
     def _run(self, query: str) -> str:
+        print("Executing DocumentQuestionAnswering tool")
         response = self.qa_chain.run(query)
         return str(response)
 
 
 # Function to handle different input types and choose the right tool
 def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
+    print(f"Handling input: {user_prompt}")
 
-    tools = [
-        Tool(
-            name="Image",
-            func=ImageGeneration(),  # Pass the class instance, not ImageGeneration()._run
-            description="Useful for generating images based on text descriptions"
-        ),
-    ]
-
-    # Add the numpy tool, but with a more specific description
-    tools.append(Tool(
-        name="Calculator",
-        func=NumpyCodeCalculator(),  # Pass the class instance, not NumpyCodeCalculator()._run
-        description="Useful only for performing numerical computations, not for general searches"
-    ))
-
-    # Add the web search tool only if websearch mode is enabled
-    if websearch:
-        tools.append(Tool(
-            name="Web",
-            func=WebSearch(),  # Pass the class instance, not WebSearch()._run
-            description="Useful for advanced web searching beyond general information"
-        ))
+    # Initialize the LLM
+    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
 
-    # Add the document question answering tool only if a document is provided
-    if document:
-        tools.append(Tool(
-            name="Document",
-            func=DocumentQuestionAnswering(document),  # This is already correct
-            description="Useful for answering questions about a specific document"
-        ))
+    # Define the tools
+    tools = []
 
-    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
+    # Add Image Generation Tool
+    tools.append(ImageGeneration())
+
+    # Add Calculator Tool
+    tools.append(NumpyCodeCalculator())
 
-    # Check if the input requires any tools
-    requires_tool = False
-    for tool in tools:
-        if tool.name.lower() in user_prompt.lower():
-            requires_tool = True
-            break
+    # Add Web Search Tool if enabled
+    if websearch:
+        tools.append(WebSearch())
 
-    if image or audio or requires_tool:
-        # Initialize the agent
+    # Add Document QA Tool if document is provided
+    if document:
+        tools.append(DocumentQuestionAnswering(document))
+
+    # Check if any tools are mentioned in the user prompt
+    requires_tool = any([tool.name.lower() in user_prompt.lower() for tool in tools])
+
+    # Handle different input scenarios
+    if image:
+        print("Processing image input")
+        image = Image.open(image).convert('RGB')
+        messages = [{"role": "user", "content": [image, user_prompt]}]
+        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
+    elif audio:
+        print("Processing audio input")
+        transcription = client.audio.transcriptions.create(
+            file=(audio.name, audio.read()),
+            model="whisper-large-v3"
+        )
+        user_prompt = transcription.text
+        # If tools are required, use an agent
+        if requires_tool:
+            agent = initialize_agent(
+                tools,
+                llm,
+                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+                verbose=True
+            )
+            response = agent.run(user_prompt)
+        else:
+            response = llm.call(query=user_prompt)
+    elif requires_tool:
+        print("Using agent with tools")
         agent = initialize_agent(
             tools,
             llm,
             agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
             verbose=True
         )
-
-        if image:
-            image = Image.open(image).convert('RGB')
-            messages = [{"role": "user", "content": [image, user_prompt]}]
-            response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
-        elif audio:
-            transcription = client.audio.transcriptions.create(
-                file=(audio.name, audio.read()),
-                model="whisper-large-v3"
-            )
-            user_prompt = transcription.text
-            response = agent.run(user_prompt)
-        else:
-            response = agent.run(user_prompt)
+        response = agent.run(user_prompt)
     else:
-        # If no tools are required, use the LLM directly
+        print("Using LLM directly")
         response = llm.call(query=user_prompt)
 
     return response
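
For anyone reviewing this change, below is a minimal smoke-test sketch of the refactored handle_input, meant to run in app.py's own context. It assumes the module globals (the tool classes, MODEL, the Groq client, etc.) are already initialized, that GROQ_API_KEY and TAVILY_API hold real keys, that the tool classes expose name attributes such as "Calculator" and "Document" (those definitions are outside the hunks shown here), and that notes.txt is a hypothetical placeholder file.

# Hypothetical smoke test for handle_input; run with app.py's globals loaded.

# No tool name appears in the prompt, so requires_tool stays False and the
# response comes from llm.call directly.
print(handle_input("Summarize the benefits of unit testing in two sentences."))

# "calculator" matches the Calculator tool's assumed name, so the
# ZERO_SHOT_REACT_DESCRIPTION agent is initialized with the tool list.
print(handle_input("Use the calculator to evaluate np.mean([1, 2, 3])"))

# Passing a document path switches in the DocumentQuestionAnswering tool.
print(handle_input("What does the document say about pricing?", document="notes.txt"))

One small observation on the new control flow: the audio branch and the requires_tool branch each build an identical agent via initialize_agent, so hoisting that call above the if/elif chain would remove the duplication without changing behavior.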