prithivMLmods committed on
Commit 8a2ba41 · verified · 1 Parent(s): 5663d15

Update app.py

Files changed (1):
  1. app.py +154 -190

app.py CHANGED
@@ -6,6 +6,8 @@ import time
 import asyncio
 import re
 from threading import Thread
+from io import BytesIO
+import subprocess
 
 import gradio as gr
 import spaces
@@ -13,27 +15,40 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
-import subprocess
 
-# Install flash-attn with our environment flag (if needed)
+# Install flash-attn without building CUDA kernels (if needed)
 subprocess.run(
     'pip install flash-attn --no-build-isolation',
     env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
     shell=True
 )
 
-# Set torch backend configurations for Flux RealismLora
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
-torch.backends.cuda.matmul.allow_tf32 = True
+from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
+from diffusers import DiffusionPipeline
+
+# ------------------------------------------------------------------------------
+# Global Configurations
+# ------------------------------------------------------------------------------
+DESCRIPTION = "# SmolVLM2 with Flux.1 Integration 📺"
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>⚠️ Running on CPU; this may not work as expected.</p>"
+
+css = '''
+h1 {
+    text-align: center;
+    display: block;
+}
+'''
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# -------------------------------
-# CONFIGURATION & UTILITY FUNCTIONS
-# -------------------------------
-MAX_SEED = 2**32 - 1
+# ------------------------------------------------------------------------------
+# FLUX.1 IMAGE GENERATION SETUP
+# ------------------------------------------------------------------------------
+MAX_SEED = np.iinfo(np.int32).max
 
 def save_image(img: Image.Image) -> str:
-    """Save a PIL image with a unique filename and return its path."""
+    """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
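Reviewer note on the install block kept by this hunk: passing a bare `env=` dict to `subprocess.run` replaces the entire environment, so `pip` loses `PATH`. A minimal sketch of a guarded variant, assuming the import check and the `os.environ` merge are wanted (neither is in the commit):

import importlib.util
import os
import subprocess

# Only install when flash-attn is actually missing, and merge os.environ so
# pip still sees PATH; FLASH_ATTENTION_SKIP_CUDA_BUILD skips compiling kernels.
if importlib.util.find_spec("flash_attn") is None:
    subprocess.run(
        'pip install flash-attn --no-build-isolation',
        env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
        shell=True,
        check=False,  # keep the Space booting even if the wheel build fails
    )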
@@ -43,141 +58,116 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
-def progress_bar_html(label: str) -> str:
-    """
-    Returns an HTML snippet for an animated progress bar with a given label.
-    """
-    return f'''
-<div style="display: flex; align-items: center;">
-    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-    <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
-        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
-    </div>
-</div>
-<style>
-@keyframes loading {{
-    0% {{ transform: translateX(-100%); }}
-    100% {{ transform: translateX(100%); }}
-}}
-</style>
-'''
-
-# -------------------------------
-# FLUX REALISMLORA IMAGE GENERATION SETUP (New Implementation)
-# -------------------------------
-from diffusers import DiffusionPipeline
-
+# Initialize Flux.1 pipeline
 base_model = "black-forest-labs/FLUX.1-dev"
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
-lora_repo = "XLabs-AI/flux-RealismLora"
-trigger_word = ""  # No trigger word used.
+lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
+trigger_word = "Super Realism"  # Leave blank if no trigger word is needed.
 pipe.load_lora_weights(lora_repo)
 pipe.to("cuda")
 
-@spaces.GPU()
-def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
-    # Set random seed for reproducibility
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    generator = torch.Generator(device="cuda").manual_seed(seed)
-
-    # Update progress bar (0% at start)
-    progress(0, "Starting image generation...")
-
-    # Simulate progress updates during the steps
-    for i in range(1, steps + 1):
-        if steps >= 10 and i % (steps // 10) == 0:
-            progress(i / steps * 100, f"Processing step {i} of {steps}...")
-
-    # Generate image using the pipeline
-    image = pipe(
-        prompt=f"{prompt} {trigger_word}",
-        num_inference_steps=steps,
-        guidance_scale=cfg_scale,
+# Define style prompts for Flux.1
+style_list = [
+    {
+        "name": "3840 x 2160",
+        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "2560 x 1440",
+        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "HD+",
+        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
+    },
+    {
+        "name": "Style Zero",
+        "prompt": "{prompt}",
+    },
+]
+styles = {s["name"]: s["prompt"] for s in style_list}
+DEFAULT_STYLE_NAME = "3840 x 2160"
+STYLE_NAMES = list(styles.keys())
+
+def apply_style(style_name: str, positive: str) -> str:
+    return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
+
+def generate_image_flux(
+    prompt: str,
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    randomize_seed: bool = False,
+    style_name: str = DEFAULT_STYLE_NAME,
+):
+    """Generate an image using the Flux.1 pipeline with style prompts."""
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    positive_prompt = apply_style(style_name, prompt)
+    if trigger_word:
+        positive_prompt = f"{trigger_word} {positive_prompt}"
+    images = pipe(
+        prompt=positive_prompt,
         width=width,
         height=height,
-        generator=generator,
-        joint_attention_kwargs={"scale": lora_scale},
-    ).images[0]
-
-    # Final progress update (100%)
-    progress(100, "Completed!")
-    yield image, seed
-
-# -------------------------------
-# SMOLVLM2 SETUP (Default Text/Multimodal Model)
-# -------------------------------
-from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
-
-smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
-smol_model = AutoModelForImageTextToText.from_pretrained(
-    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+        guidance_scale=guidance_scale,
+        num_inference_steps=28,
+        num_images_per_prompt=1,
+        output_type="pil",
+    ).images
+    image_paths = [save_image(img) for img in images]
+    return image_paths, seed
+
+# ------------------------------------------------------------------------------
+# SMOLVLM2 MODEL SETUP
+# ------------------------------------------------------------------------------
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
+model = AutoModelForImageTextToText.from_pretrained(
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
     _attn_implementation="flash_attention_2",
-    torch_dtype=torch.float16
+    torch_dtype=torch.bfloat16
 ).to("cuda:0")
 
-# -------------------------------
-# TTS UTILITY FUNCTIONS
-# -------------------------------
-TTS_VOICES = [
-    "en-US-JennyNeural",  # @tts1
-    "en-US-GuyNeural",    # @tts2
-]
-
-async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-    """Convert text to speech using Edge TTS and save the output as MP3."""
-    communicate = edge_tts.Communicate(text, voice)
-    await communicate.save(output_file)
-    return output_file
-
-# -------------------------------
-# CHAT / MULTIMODAL GENERATION FUNCTION
-# -------------------------------
+# ------------------------------------------------------------------------------
+# CHAT / INFERENCE FUNCTION
+# ------------------------------------------------------------------------------
 @spaces.GPU
-def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
+def model_inference(input_dict, history, max_tokens):
     """
-    Generates chatbot responses using SmolVLM2 with support for multimodal inputs and TTS.
-    Special commands:
-    - "@image": triggers image generation using the RealismLora flux implementation.
-    - "@tts1" or "@tts2": triggers text-to-speech after generation.
+    Implements a chat interface using SmolVLM2.
+
+    Special behavior:
+    - If the query text starts with "@image", the Flux.1 pipeline is used to generate an image.
+    - Otherwise, the query is processed with SmolVLM2.
+    - In the SmolVLM2 branch, a progress message "Processing with SmolVLM2..." is yielded.
     """
-    torch.cuda.empty_cache()
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # If the query starts with "@image", use RealismLora to generate an image.
+    # If the text begins with "@image", use Flux.1 image generation.
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        yield progress_bar_html("Hold Tight Generating Flux RealismLora Image")
-        # Default parameters for RealismLora generation
-        default_cfg_scale = 3.2
-        default_steps = 32
-        default_width = 1152
-        default_height = 896
-        default_seed = 3981632454
-        default_lora_scale = 0.85
-        # Call the new run_lora function and yield its final result
-        for result in run_lora(prompt, default_cfg_scale, default_steps, True, default_seed, default_width, default_height, default_lora_scale, progress=gr.Progress(track_tqdm=True)):
-            final_result = result
-        yield gr.Image(final_result[0])
+        yield "Hold tight, generating a Flux.1 image..."
+        image_paths, used_seed = generate_image_flux(
+            prompt=prompt,
+            seed=1,
+            width=1024,
+            height=1024,
+            guidance_scale=3,
+            randomize_seed=True,
+            style_name=DEFAULT_STYLE_NAME,
+        )
+        yield gr.Image(image_paths[0])
         return
 
-    # Handle TTS commands if present.
-    tts_prefix = "@tts"
-    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-    voice = None
-    if is_tts:
-        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-        if voice_index:
-            voice = TTS_VOICES[voice_index - 1]
-            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-
-    yield "Processing with SmolVLM2"
-
-    # Build conversation messages based on input and history.
+    # Default: use SmolVLM2 inference.
+    yield "Processing with SmolVLM2..."
+
     user_content = []
     media_queue = []
-    if chat_history == []:
+
+    # If no conversation history, process current input.
+    if not history:
         text = text.strip()
         for file in files:
             if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
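For review context, this is how the new helper chain is meant to be exercised; the call below is illustrative only (the prompt and the print are not in the commit), but the names and defaults match the hunk above:

# apply_style() wraps the prompt in the "3840 x 2160" template, the trigger
# word "Super Realism" is prepended, and 28 inference steps are used.
image_paths, used_seed = generate_image_flux(
    prompt="portrait of an elderly fisherman, golden hour",
    randomize_seed=True,
    style_name="3840 x 2160",
)
print(f"saved {image_paths[0]} (seed={used_seed})")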
@@ -202,17 +192,17 @@ def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
        resulting_messages = []
        user_content = []
        media_queue = []
-        for hist in chat_history:
+        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], tuple):
                file_name = hist["content"][0]
                if file_name.endswith((".png", ".jpg", ".jpeg")):
                    media_queue.append({"type": "image", "path": file_name})
                elif file_name.endswith(".mp4"):
                    media_queue.append({"type": "video", "path": file_name})
-        for hist in chat_history:
+        for hist in history:
            if hist["role"] == "user" and isinstance(hist["content"], str):
-                txt = hist["content"]
-                parts = re.split(r'(<image>|<video>)', txt)
+                text = hist["content"]
+                parts = re.split(r'(<image>|<video>)', text)
                for part in parts:
                    if part == "<image>" and media_queue:
                        user_content.append(media_queue.pop(0))
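The `<image>`/`<video>` replay in this hunk relies on `re.split` keeping the delimiters when the pattern contains a capturing group; a toy input (not from the app) shows the behavior:

import re

parts = re.split(r'(<image>|<video>)', "Compare <image> with <video> please")
print(parts)
# ['Compare ', '<image>', ' with ', '<video>', ' please']
# Each captured marker is later swapped for the next item in media_queue.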
@@ -230,89 +220,63 @@ def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
                "content": [{"type": "text", "text": hist["content"]}]
            })
            user_content = []
-    if not resulting_messages:
-        resulting_messages = [{"role": "user", "content": user_content}]
+    if user_content:
+        resulting_messages.append({"role": "user", "content": user_content})
 
    if text == "" and not files:
-        yield "Please input a query and optionally image(s)."
+        yield gr.Error("Please input a query and optionally image(s).")
        return
    if text == "" and files:
-        yield "Please input a text query along with the image(s)."
+        yield gr.Error("Please input a text query along with the image(s).")
        return
 
-    inputs = smol_processor.apply_chat_template(
+    print("resulting_messages", resulting_messages)
+    inputs = processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
-    if "pixel_values" in inputs:
-        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
-    inputs = inputs.to(smol_model.device)
-
-    streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
+    inputs = inputs.to(model.device)
+
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
-    thread = Thread(target=smol_model.generate, kwargs=generation_args)
+
+    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()
-
-    yield "..."
+
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
 
-    if is_tts and voice:
-        final_response = buffer
-        output_file = asyncio.run(text_to_speech(final_response, voice))
-        yield gr.Audio(output_file, autoplay=True)
-
-# -------------------------------
+# ------------------------------------------------------------------------------
 # GRADIO CHAT INTERFACE
-# -------------------------------
-DESCRIPTION = "# Flux RealismLora + SmolVLM2 Chat"
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>⚠️ Running on CPU; this may not work as expected.</p>"
-
-css = '''
-h1 {
-    text-align: center;
-    display: block;
-}
-#duplicate-button {
-    margin: auto;
-    color: #fff;
-    background: #1565c0;
-    border-radius: 100vh;
-}
-'''
+# ------------------------------------------------------------------------------
+examples = [
+    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
+    [{"text": "What art era do this artpiece <image> and this artpiece <image> belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}],
+    [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
+    [{"text": "When was this purchase made and how much did it cost?", "files": ["example_images/fiche.jpg"]}],
+    [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
+    [{"text": "What is happening in the video?", "files": ["example_images/short.mp4"]}],
+    [{"text": "@image A futuristic cityscape with vibrant neon lights"}],
+]
 
 demo = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
-    ],
-    examples=[
-        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic style"}],
-        [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
-        [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
-        [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
-    ],
-    cache_examples=False,
-    type="messages",
-    description=DESCRIPTION,
-    css=css,
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(
-        label="Query Input",
-        file_types=["image", ".mp4"],
-        file_count="multiple",
-        placeholder="Type text and/or upload media. Use '@image' for image generation, '@tts1' or '@tts2' for TTS."
-    ),
+    fn=model_inference,
+    title="SmolVLM2 with Flux.1 Integration 📺",
+    description="Play with SmolVLM2 (HuggingFaceTB/SmolVLM2-2.2B-Instruct) with integrated Flux.1 image generation. Use the '@image' prefix to generate images with Flux.1.",
+    examples=examples,
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
+    cache_examples=False,
+    additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
+    type="messages"
 )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch(share=True)
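The generation path kept by this hunk is the standard transformers streaming pattern: `generate()` runs on a background thread while `TextIteratorStreamer` is consumed as an iterator on the main thread. A self-contained sketch with a small causal LM (the model name is illustrative, not the app's):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

inputs = tok("The quick brown fox", return_tensors="pt")
# generate() blocks, so it runs on a worker thread; the streamer yields
# decoded text fragments here as tokens are produced.
Thread(target=lm.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=20)).start()
for piece in streamer:
    print(piece, end="", flush=True)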
 
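One follow-up worth flagging: the old entry point used `demo.queue(max_size=20).launch(share=True)`, and the new one drops the queue entirely. If bounded request queuing is still wanted on a GPU Space, a variant like this (an assumption, not part of the commit) keeps both:

if __name__ == "__main__":
    # Bounded queue from the old entry point, debug flag from the new one.
    demo.queue(max_size=20).launch(debug=True)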