VanguardAI committed
Commit df220f6 · verified · 1 parent: d70fa2a

Update app.py

Files changed (1)
  app.py +25 -11
app.py CHANGED

@@ -17,6 +17,7 @@ import requests
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 from llama_index.core.chat_engine.types import AgentChatResponse
+from llama_index.core import VectorStoreIndex
 
 # Initialize models and clients
 MODEL = 'llama3-groq-70b-8192-tool-use-preview'
@@ -73,8 +74,15 @@ def image_generation(query):
     image.save("output.jpg")
     return "output.jpg"
 
+# Document Question Answering Tool
+def document_question_answering(query, docs):
+    index = VectorStoreIndex.from_documents(docs)
+    query_engine = index.as_query_engine(similarity_top_k=3)
+    response = query_engine.query(query)
+    return str(response)
+
 # Function to handle different input types and choose the right tool
-def handle_input(user_prompt, image=None, audio=None, websearch=False):
+def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
     if audio:
         if isinstance(audio, str):
             audio = open(audio, "rb")
@@ -88,11 +96,16 @@ def handle_input(user_prompt, image=None, audio=None, websearch=False):
         FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy"),
         FunctionTool.from_defaults(fn=image_generation, name="Image"),
     ]
-
+
     # Add the web search tool only if websearch mode is enabled
     if websearch:
         tools.append(FunctionTool.from_defaults(fn=web_search, name="Web"))
 
+    # Add the document question answering tool only if a document is provided
+    if document:
+        docs = LlamaParse(result_type="text").load_data(document)
+        tools.append(FunctionTool.from_defaults(fn=document_question_answering, name="Document", docs=docs))
+
     llm = Groq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
     agent = ReActAgent.from_tools(tools, llm=llm, verbose=True)
 
@@ -102,11 +115,11 @@ def handle_input(user_prompt, image=None, audio=None, websearch=False):
         response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
     else:
         response = agent.chat(user_prompt)
-
+
     # Extract the content from AgentChatResponse to return as a string
     if isinstance(response, AgentChatResponse):
         response = response.response
-
+
     return response
 
 
@@ -120,6 +133,7 @@ def create_ui():
         with gr.Column(scale=1):
             image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
             audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
+            document_input = gr.File(type="file", label="Upload a document", elem_id="document-icon")
             voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
             websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
         with gr.Column(scale=1):
@@ -130,14 +144,14 @@ def create_ui():
 
         submit.click(
             fn=main_interface,
-            inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode],
+            inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode, document_input],
            outputs=[output_label, audio_output]
         )
 
         voice_only_mode.change(
             lambda x: gr.update(visible=not x),
             inputs=voice_only_mode,
-            outputs=[user_prompt, image_input, websearch_mode, submit]
+            outputs=[user_prompt, image_input, websearch_mode, document_input, submit]
         )
         voice_only_mode.change(
             lambda x: gr.update(visible=x),
@@ -149,16 +163,16 @@ def create_ui():
 
 # Main interface function
 @spaces.GPU()
-def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False):
+def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
     print("Starting main_interface function")
     vqa_model.to(device='cuda', dtype=torch.bfloat16)
     tts_model.to("cuda")
     pipe.to("cuda")
 
-    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}")
-
+    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")
+
     try:
-        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch)
+        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
         print("handle_input function executed successfully")
     except Exception as e:
         print(f"Error in handle_input: {e}")
@@ -178,4 +192,4 @@ def main_interface(user_prompt, image=None, audio=None, voice_only=False, websea
 
 # Launch the UI
 demo = create_ui()
-demo.launch()
+demo.launch()
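
The new document path parses the upload with LlamaParse, embeds the parsed pages into a VectorStoreIndex, and exposes the resulting query engine to the ReActAgent as a tool. One caveat: FunctionTool.from_defaults does not accept arbitrary extra keyword arguments such as docs=docs for the wrapped function in current llama-index releases, so binding the parsed documents with a closure (or functools.partial) is the more reliable pattern. Below is a minimal standalone sketch of that flow, not the committed code: the file name "report.pdf" is hypothetical, and it assumes llama-index, llama-parse, and the Groq integration are installed, that GROQ_API_KEY and LLAMA_CLOUD_API_KEY are set, and that an embedding model is configured (llama-index defaults to OpenAI embeddings unless Settings.embed_model is overridden).

import os

from llama_parse import LlamaParse
from llama_index.core import VectorStoreIndex
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
from llama_index.llms.groq import Groq

def build_document_tool(document_path):
    # Parse the uploaded file into llama-index Document objects
    # (LlamaParse reads LLAMA_CLOUD_API_KEY from the environment).
    docs = LlamaParse(result_type="text").load_data(document_path)
    # Embed and index once per upload rather than once per question.
    index = VectorStoreIndex.from_documents(docs)
    query_engine = index.as_query_engine(similarity_top_k=3)

    def document_question_answering(query: str) -> str:
        """Answer a question from the uploaded document."""
        return str(query_engine.query(query))

    # The closure captures query_engine, so the agent only has to
    # supply `query`; no extra kwargs are passed to from_defaults.
    return FunctionTool.from_defaults(fn=document_question_answering, name="Document")

llm = Groq(model="llama3-groq-70b-8192-tool-use-preview",
           api_key=os.environ.get("GROQ_API_KEY"))
agent = ReActAgent.from_tools([build_document_tool("report.pdf")], llm=llm, verbose=True)
print(agent.chat("Summarize the key findings in the uploaded document."))

Indexing once per upload also avoids re-embedding the document on every tool call; in the committed version, the VectorStoreIndex is rebuilt inside each document_question_answering invocation.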