VanguardAI committed on
Commit 724aed2 · verified · 1 parent: 64725d2

Update app.py

Files changed (1)
  1. app.py +13 -16
app.py CHANGED
@@ -5,7 +5,6 @@ import numpy as np
 from groq import Groq
 import spaces
 from transformers import AutoModel, AutoTokenizer
-from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
 from parler_tts import ParlerTTSForConditionalGeneration
 import soundfile as sf
 from llama_index.core.agent import ReActAgent
@@ -16,12 +15,12 @@ from tavily import TavilyClient
 import requests
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
+from diffusers import StableDiffusion3Pipeline
 
 # Initialize models and clients
 MODEL = 'llama3-groq-70b-8192-tool-use-preview'
 client = Groq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
 
-
 vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                       device_map="auto", torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)
@@ -29,15 +28,9 @@ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)
 tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
 tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")
 
-# Image generation model
-base = "stabilityai/stable-diffusion-xl-base-1.0"
-repo = "ByteDance/SDXL-Lightning"
-ckpt = "sdxl_lightning_4step_unet.safetensors"
-
-unet = UNet2DConditionModel.from_config(base, subfolder="unet")
-unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
-image_pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16")
-image_pipe.scheduler = EulerDiscreteScheduler.from_config(image_pipe.scheduler.config, timestep_spacing="trailing")
+# Updated Image Generation Model
+pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
 
 # Tavily Client for web search
 tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))
@@ -79,7 +72,12 @@ def web_search(query):
 
 # Image Generation Tool
 def image_generation(query):
-    image = image_pipe(prompt=query, num_inference_steps=20, guidance_scale=7.5).images[0]
+    image = pipe(
+        query,
+        negative_prompt="",
+        num_inference_steps=28,
+        guidance_scale=7.0,
+    ).images[0]
     image.save("output.jpg")
     return "output.jpg"
 
@@ -97,7 +95,7 @@ def handle_input(user_prompt, image=None, audio=None, websearch=False):
         user_prompt = transcription.text
 
     tools = [
-        FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy Code Calculator"),
+        FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy"),
         FunctionTool.from_defaults(fn=image_generation, name="Image"),
     ]
 
@@ -166,8 +164,7 @@ def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False):
     print("Starting main_interface function")
     vqa_model.to(device='cuda', dtype=torch.bfloat16)
     tts_model.to("cuda")
-    unet.to("cuda")
-    image_pipe.to("cuda")
+    pipe.to("cuda")
 
     print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}")
 
@@ -192,4 +189,4 @@
 
 # Launch the UI
 demo = create_ui()
-demo.launch()
+demo.launch()
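For reference, the new generation path can be exercised as a standalone script. This is a minimal sketch that mirrors the model ID and sampling parameters from the diff: the prompt is a made-up example, and it assumes access to the gated stabilityai/stable-diffusion-3-medium-diffusers checkpoint (license accepted, HF token configured) plus a CUDA device.

import torch
from diffusers import StableDiffusion3Pipeline

# Load SD3 medium in half precision, as the commit does at module level.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Sampling parameters copied from the updated image_generation tool.
image = pipe(
    "a lighthouse on a cliff at dusk",  # hypothetical example prompt
    negative_prompt="",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("output.jpg")

Two caveats worth noting. The commit now moves the pipeline to CUDA twice, once at import time (pipe = pipe.to("cuda")) and again inside main_interface (pipe.to("cuda")); since app.py imports spaces, this is likely a ZeroGPU Space, where CUDA is generally available only inside @spaces.GPU-decorated functions, so the module-level .to("cuda") may fail there. And if VRAM is tight, pipe.enable_model_cpu_offload() is the usual diffusers alternative to an explicit .to("cuda").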
 
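The tool rename from "Numpy Code Calculator" to "Numpy" is likely more than cosmetic: the Groq tool-use model named at the top suggests native function calling, and function-calling APIs commonly restrict tool names to letters, digits, underscores, and hyphens, so a name containing spaces can be rejected outright. Below is a hedged sketch of how the renamed tools plug into the ReActAgent imported at the top of app.py; the agent construction sits outside the hunks shown, so the from_tools wiring and the placeholder tool bodies are assumptions, not the app's actual code.

from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool

def numpy_code_calculator(query: str) -> str:
    # Placeholder body; the real implementation is not part of this diff.
    return "42"

def image_generation(query: str) -> str:
    # Placeholder body; see the image_generation hunk above.
    return "output.jpg"

tools = [
    FunctionTool.from_defaults(fn=numpy_code_calculator, name="Numpy"),
    FunctionTool.from_defaults(fn=image_generation, name="Image"),
]

# Assumes an LLM is already configured via llama_index Settings.llm;
# ReActAgent.from_tools falls back to that global default.
agent = ReActAgent.from_tools(tools, verbose=True)
print(agent.chat("Generate an image of a sunset over the ocean"))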