VanguardAI committed
Commit 1197e50 · verified · 1 Parent(s): f7d8d6b

Update app.py

Files changed (1)
  1. app.py +112 -82
app.py CHANGED
@@ -7,10 +7,12 @@ from transformers import AutoModel, AutoTokenizer
 from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
 from parler_tts import ParlerTTSForConditionalGeneration
 import soundfile as sf
-from langchain_community.embeddings import OpenAIEmbeddings
-from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import OpenAIEmbeddings
+from langchain_community.vectorstores import Chroma
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains import RetrievalQA
+from langchain.agents import initialize_agent, Tool
+from langchain.llms import OpenAI
 from PIL import Image
 from decord import VideoReader, cpu
 from tavily import TavilyClient
@@ -18,31 +20,31 @@ import requests
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file

-# Initialize models
+# Initialize models and clients
 client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
 MODEL = 'llama3-groq-70b-8192-tool-use-preview'

-text_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
+vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                        device_map="auto", torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)

 tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
 tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")

-# Corrected image model and pipeline setup
+# Image generation model
 base = "stabilityai/stable-diffusion-xl-base-1.0"
 repo = "ByteDance/SDXL-Lightning"
 ckpt = "sdxl_lightning_4step_unet.safetensors"

 unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
 unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
-image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
+image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")
 image_pipe.scheduler = EulerDiscreteScheduler.from_config(image_pipe.scheduler.config, timestep_spacing="trailing")

-# Tavily Client
-tavily_client = TavilyClient(api_key="tvly-YOUR_API_KEY")
+# Tavily Client for web search
+tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))

-# Voice output function
+# Function to play voice output
 def play_voice_output(response):
     description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
     input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
@@ -52,50 +54,55 @@ def play_voice_output(response):
     sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
     return "output.wav"

-# NumPy Calculation function
-def numpy_calculate(code: str) -> str:
+# NumPy Code Calculator Tool
+def numpy_code_calculator(query):
+    """Generates and executes NumPy code for mathematical operations."""
     try:
-        local_dict = {}
-        exec(code, {"np": np}, local_dict)
+        # You might need to use a more sophisticated approach to generate NumPy code
+        # based on the user's query. This is a simple example.
+        llm_response = client.chat.completions.create(
+            model=MODEL,
+            messages=[
+                {"role": "user", "content": f"Write NumPy code to: {query}"}
+            ]
+        )
+        code = llm_response.choices[0].message.content
+        print(f"Generated NumPy code:\n{code}")  # Print the generated code
+
+        # Execute the code in a safe environment
+        local_dict = {"np": np}
+        exec(code, local_dict)
         result = local_dict.get("result", "No result found")
         return str(result)
     except Exception as e:
-        return f"An error occurred: {str(e)}"
+        return f"Error: {e}"

-# Function to use Langchain for RAG
-def use_langchain_rag(file_name, file_content, query):
-    # Split the document into chunks
+# Web Search Tool
+def web_search(query):
+    """Performs a web search using Tavily."""
+    answer = tavily_client.qna_search(query=query)
+    return answer
+
+# Image Generation Tool
+def image_generation(query):
+    """Generates an image based on the given prompt."""
+    image = image_pipe(prompt=query, num_inference_steps=20, guidance_scale=7.5).images[0]
+    image.save("output.jpg")
+    return "output.jpg"
+
+# Document Question Answering Tool
+def doc_question_answering(query, file_path):
+    """Answers questions based on the content of a document."""
+    with open(file_path, 'r') as f:
+        file_content = f.read()
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     docs = text_splitter.create_documents([file_content])
-
-    # Create embeddings and store in the vector database
     embeddings = OpenAIEmbeddings()
-    db = Chroma.from_documents(docs, embeddings, persist_directory=".chroma_db")  # Use a persistent directory
-
-    # Create a question-answering chain
+    db = Chroma.from_documents(docs, embeddings, persist_directory=".chroma_db")
     qa = RetrievalQA.from_chain_type(llm=OpenAI(), chain_type="stuff", retriever=db.as_retriever())
-
-    # Get the answer
     return qa.run(query)

-# Function to encode video
-def encode_video(video_path):
-    MAX_NUM_FRAMES = 64
-    vr = VideoReader(video_path, ctx=cpu(0))
-    sample_fps = round(vr.get_avg_fps() / 1)
-    frame_idx = [i for i in range(0, len(vr), sample_fps)]
-    if len(frame_idx) > MAX_NUM_FRAMES:
-        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
-    frames = vr.get_batch(frame_idx).asnumpy()
-    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
-    return frames
-
-# Web search function
-def web_search(query):
-    answer = tavily_client.qna_search(query=query)
-    return answer
-
-# Function to handle different input types
+# Function to handle different input types and choose the right tool
 def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
     # Voice input handling
     if audio:
@@ -105,50 +112,58 @@ def handle_input(user_prompt, image=None, video=None, audio=None, doc=None, websearch=False):
         )
         user_prompt = transcription.text

+    # Initialize tools
+    tools = [
+        Tool(
+            name="Numpy Code Calculator",
+            func=numpy_code_calculator,
+            description="Useful for when you need to perform mathematical calculations using NumPy. Provide the calculation you want to perform.",
+        ),
+        Tool(
+            name="Web Search",
+            func=web_search,
+            description="Useful for when you need to find information from the real world.",
+        ),
+        Tool(
+            name="Image Generation",
+            func=image_generation,
+            description="Useful for when you need to generate an image based on a description.",
+        ),
+    ]
+
+    # Add document Q&A tool if a document is provided
+    if doc:
+        tools.append(
+            Tool(
+                name="Document Question Answering",
+                func=lambda query: doc_question_answering(query, doc.name),
+                description="Useful for when you need to answer questions about the uploaded document.",
+            )
+        )
+
+    # Initialize agent
+    agent = initialize_agent(
+        tools,
+        client,
+        agent="zero-shot-react-description",
+        verbose=True,
+    )
+
     # If user uploaded an image and text, use MiniCPM model
     if image:
         image = Image.open(image).convert('RGB')
         messages = [{"role": "user", "content": [image, user_prompt]}]
-        response = text_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
+        response = vqa_model.chat(image=None, msgs=messages, tokenizer=tokenizer)
         return response

-    # Determine which tool to use
-    if doc:
-        file_content = doc.read().decode('utf-8')
-        response = use_langchain_rag(doc.name, file_content, user_prompt)
-    elif "calculate" in user_prompt.lower():
-        response = numpy_calculate(user_prompt)
-    elif "generate" in user_prompt.lower() and ("image" in user_prompt.lower() or "picture" in user_prompt.lower()):
-        response = image_pipe(prompt=user_prompt, num_inference_steps=20, guidance_scale=7.5)
-    elif websearch:
-        response = web_search(user_prompt)
+    # Use the agent to determine the best tool and get the response
+    if websearch:
+        response = agent.run(f"{user_prompt} Use the Web Search tool if necessary.")
     else:
-        chat_completion = client.chat.completions.create(
-            messages=[
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": user_prompt}
-            ],
-            model=MODEL,
-        )
-        response = chat_completion.choices[0].message.content
+        response = agent.run(user_prompt)

     return response

-@spaces.GPU()
-def main_interface(user_prompt, image=None, video=None, audio=None, doc=None, voice_only=False, websearch=False):
-    text_model.to(device='cuda', dtype=torch.bfloat16)
-    tts_model.to("cuda")
-    unet.to("cuda", torch.float16)
-    image_pipe.to("cuda")
-
-    response = handle_input(user_prompt, image=image, video=video, audio=audio, doc=doc, websearch=websearch)
-
-    if voice_only:
-        audio_file = play_voice_output(response)
-        return response, audio_file  # Return both text and audio outputs
-    else:
-        return response, None  # Return only the text output, no audio
-
 # Gradio UI Setup
 def create_ui():
     with gr.Blocks() as demo:
@@ -158,28 +173,27 @@ def create_ui():
                 user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
             with gr.Column(scale=1):
                 image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
-                video_input = gr.Video(label="Upload a video", elem_id="video-icon")
                 audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                 doc_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                 voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                 websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
             with gr.Column(scale=1):
                 submit = gr.Button("Submit")
-
+
         output_label = gr.Label(label="Output")
         audio_output = gr.Audio(label="Audio Output", visible=False)

         submit.click(
             fn=main_interface,
-            inputs=[user_prompt, image_input, video_input, audio_input, doc_input, voice_only_mode, websearch_mode],
-            outputs=[output_label, audio_output]  # Expecting a string and audio file
+            inputs=[user_prompt, image_input, audio_input, doc_input, voice_only_mode, websearch_mode],
+            outputs=[output_label, audio_output]
         )

         # Voice-only mode UI
         voice_only_mode.change(
             lambda x: gr.update(visible=not x),
             inputs=voice_only_mode,
-            outputs=[user_prompt, image_input, video_input, doc_input, websearch_mode, submit]
+            outputs=[user_prompt, image_input, doc_input, websearch_mode, submit]
         )
         voice_only_mode.change(
             lambda x: gr.update(visible=x),
@@ -189,6 +203,22 @@ def create_ui():

     return demo

+# Main interface function
+@spaces.GPU()
+def main_interface(user_prompt, image=None, audio=None, doc=None, voice_only=False, websearch=False):
+    vqa_model.to(device='cuda', dtype=torch.bfloat16)
+    tts_model.to("cuda")
+    unet.to("cuda", torch.float16)
+    image_pipe.to("cuda")
+
+    response = handle_input(user_prompt, image=image, audio=audio, doc=doc, websearch=websearch)
+
+    if voice_only:
+        audio_file = play_voice_output(response)
+        return response, audio_file
+    else:
+        return response, None
+
 # Launch the app
 demo = create_ui()
-demo.launch(inline=False)
+demo.launch(inline=False)
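
A few notes on the updated code, each with a hedged sketch; the helper names below are illustrative and not part of this commit. First, numpy_code_calculator runs model-generated code through a bare exec. The inline comment calls this a "safe environment", but nothing restricts what the generated code can do, and chat models frequently wrap code in Markdown fences, which would make exec raise a SyntaxError. A minimal, more defensive variant could strip fences and pass restricted globals; restricting builtins is a mitigation, not a sandbox, and it also blocks legitimate import or print calls inside the generated snippet.

import numpy as np

def run_generated_numpy(code: str) -> str:
    """Illustrative helper: strip Markdown fences, then exec with restricted globals."""
    # Drop ``` fence lines that chat models commonly wrap code in.
    code = "\n".join(line for line in code.splitlines() if not line.strip().startswith("```"))
    exec_globals = {"np": np, "__builtins__": {}}  # blocks open/__import__, but also print/range
    local_vars = {}
    try:
        exec(code, exec_globals, local_vars)  # still arbitrary code execution
        return str(local_vars.get("result", "No result found"))
    except Exception as e:
        return f"Error: {e}"

# Example: run_generated_numpy("result = np.mean(np.arange(10))") returns "4.5"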
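
Second, the pipeline loads the SDXL-Lightning 4-step UNet, yet image_generation calls it with num_inference_steps=20 and guidance_scale=7.5. The ByteDance/SDXL-Lightning model card pairs the 4-step checkpoint with 4 steps and CFG disabled (the trailing-timestep Euler scheduler is already configured above), so a call along these lines is probably closer to the intended usage; this sketch assumes the image_pipe object built in app.py.

def image_generation(query):
    """Sketch: generation call matching the 4-step SDXL-Lightning checkpoint loaded above."""
    image = image_pipe(
        prompt=query,
        num_inference_steps=4,  # distilled 4-step UNet
        guidance_scale=0,       # Lightning checkpoints are trained without CFG
    ).images[0]
    image.save("output.jpg")
    return "output.jpg"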
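
Finally, initialize_agent(tools, client, ...) hands the agent the raw Groq SDK client, while LangChain expects one of its own LLM or chat-model wrappers, so the agent construction is likely to fail validation at runtime. A minimal sketch of one way to wire it, assuming the separately installed langchain-groq package; the ChatGroq wrapper, the placeholder tool, and the example query are illustrative, not part of this commit.

import os
from langchain.agents import initialize_agent, Tool
from langchain_groq import ChatGroq  # assumption: pip install langchain-groq

def web_search(query: str) -> str:
    """Placeholder for the Tavily-backed web_search defined in app.py."""
    return f"search results for: {query}"

# Wrap Groq behind LangChain's chat-model interface so the agent can call it.
agent_llm = ChatGroq(
    model="llama3-groq-70b-8192-tool-use-preview",
    api_key=os.environ.get("GROQ_API_KEY"),
)

tools = [
    Tool(
        name="Web Search",
        func=web_search,
        description="Useful for when you need to find information from the real world.",
    ),
]

agent = initialize_agent(tools, agent_llm, agent="zero-shot-react-description", verbose=True)
# response = agent.run("What is the tallest building in the world?")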