prithivMLmods committed on
Commit 0b6db44 · verified · 1 Parent(s): 3fb8098

Update app.py

Files changed (1)
  1. app.py +61 -73
app.py CHANGED
@@ -23,9 +23,9 @@ from transformers import (
from transformers.image_utils import load_image
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

-
DESCRIPTION = """
# QwQ Edge 💬
+ **Note:** During image generation, a progress bar will appear both at the top of the interface and within the chat. For text generation, a loading animation will display until the response begins.
"""

css = '''
@@ -40,6 +40,34 @@ h1 {
  background: #1565c0;
  border-radius: 100vh;
}
+
+ /* Custom styling for progress bars within chat */
+ .progress-bar-container {
+   width: 100%;
+   margin-top: 5px;
+ }
+
+ .progress-bar {
+   width: 100%;
+   height: 4px;
+   background-color: #e0e0e0;
+   border-radius: 2px;
+ }
+
+ .progress-bar::-webkit-progress-bar {
+   background-color: #e0e0e0;
+   border-radius: 2px;
+ }
+
+ .progress-bar::-webkit-progress-value {
+   background-color: #90ee90; /* Light green */
+   border-radius: 2px;
+ }
+
+ .progress-bar::-moz-progress-bar {
+   background-color: #90ee90; /* Light green */
+   border-radius: 2px;
+ }
'''

MAX_MAX_NEW_TOKENS = 2048
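Note: the `.progress-bar` rules added above style the native `<progress>` elements that `generate()` later yields through `gr.HTML`. As a quick way to preview that styling in isolation, a minimal standalone sketch (everything below is illustrative and not part of app.py) could be:

```python
import gradio as gr

# Hypothetical preview of the new progress-bar styling; not part of app.py.
preview_css = """
.progress-bar { width: 100%; height: 4px; background-color: #e0e0e0; border-radius: 2px; }
.progress-bar::-webkit-progress-value { background-color: #90ee90; border-radius: 2px; }
.progress-bar::-moz-progress-bar { background-color: #90ee90; border-radius: 2px; }
"""

with gr.Blocks(css=preview_css) as demo:
    # A <progress> element without a value renders as an indeterminate (animated) bar.
    gr.HTML('<div>Generating response...</div>'
            '<progress class="progress-bar" style="width:100%; height:4px;"></progress>')

if __name__ == "__main__":
    demo.launch()
```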
@@ -63,7 +91,7 @@ TTS_VOICES = [
    "en-US-GuyNeural",  # @tts2
]

- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
@@ -78,24 +106,20 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    return output_file

def clean_chat_history(chat_history):
-     """
-     Filter out any chat entries whose "content" is not a string.
-     This helps prevent errors when concatenating previous messages.
-     """
+     """Filter out non-string content to prevent concatenation errors"""
    cleaned = []
    for msg in chat_history:
        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
            cleaned.append(msg)
    return cleaned

- # Environment variables and parameters for Stable Diffusion XL
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
+ # Stable Diffusion XL setup
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
- BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))

- # Load the SDXL pipeline
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID_SD,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -104,22 +128,19 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
).to(device)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)

- # Ensure that the text encoder is in half-precision if using CUDA.
if torch.cuda.is_available():
    sd_pipe.text_encoder = sd_pipe.text_encoder.half()

- # Optional: compile the model for speedup if enabled
if USE_TORCH_COMPILE:
    sd_pipe.compile()

- # Optional: offload parts of the model to CPU if needed
if ENABLE_CPU_OFFLOAD:
    sd_pipe.enable_model_cpu_offload()

MAX_SEED = np.iinfo(np.int32).max

def save_image(img: Image.Image) -> str:
-     """Save a PIL image with a unique filename and return the path."""
+     """Save a PIL image with a unique filename and return the path"""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name
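The pipeline setup above reads the checkpoint id from the `MODEL_VAL_PATH` environment variable, which is configured in the Space rather than in this diff. A minimal self-contained sketch of the same load path, using `stabilityai/stable-diffusion-xl-base-1.0` purely as an example value:

```python
import os

import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

# Example values only; in the Space these come from environment variables / secrets.
os.environ.setdefault("MODEL_VAL_PATH", "stabilityai/stable-diffusion-xl-base-1.0")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionXLPipeline.from_pretrained(
    os.environ["MODEL_VAL_PATH"],
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

# Mirrors the half-precision text encoder step above.
if torch.cuda.is_available():
    pipe.text_encoder = pipe.text_encoder.half()
```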
@@ -144,7 +165,7 @@ def generate_image_fn(
    num_images: int = 1,
    progress=gr.Progress(track_tqdm=True),
):
-     """Generate images using the SDXL pipeline."""
+     """Generate images using the SDXL pipeline"""
    seed = int(randomize_seed_fn(seed, randomize_seed))
    generator = torch.Generator(device=device).manual_seed(seed)

@@ -162,13 +183,11 @@
    options["use_resolution_binning"] = True

    images = []
-     # Process in batches
    for i in range(0, num_images, BATCH_SIZE):
        batch_options = options.copy()
        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
        if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-         # Wrap the pipeline call in autocast if using CUDA
        if device.type == "cuda":
            with torch.autocast("cuda", dtype=torch.float16):
                outputs = sd_pipe(**batch_options)
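A note on the loop above: slicing `options["prompt"][i:i+BATCH_SIZE]` assumes the prompt was already replicated into a list of length `num_images` earlier in `generate_image_fn` (that part of the function is outside this diff). The slicing itself behaves like this small illustration:

```python
# Illustration of the batch slicing used above; values are arbitrary examples.
BATCH_SIZE = 2
num_images = 5
prompts = ["a watercolor fox"] * num_images  # assumed replication into a list

for i in range(0, num_images, BATCH_SIZE):
    batch_prompts = prompts[i:i + BATCH_SIZE]
    print(i, batch_prompts)  # at most BATCH_SIZE prompts per pipeline call
```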
@@ -197,35 +216,14 @@ def generate(
    text = input_dict["text"]
    files = input_dict.get("files", [])

-     # Define an HTML template for the animated progress bar.
-     # The bar is a thin 5px line in light green with a simple opacity animation.
-     progress_bar_html = """
-         <div style="display: flex; align-items: center;">
-             <span>{message}</span>
-             <div style="flex-grow: 1; margin-left: 10px;">
-                 <div class="progress-bar"></div>
-             </div>
-         </div>
-         <style>
-         .progress-bar {{
-             width: 100%;
-             height: 5px;
-             background: lightgreen;
-             animation: progressAnim 2s infinite;
-         }}
-         @keyframes progressAnim {{
-             0% {{ opacity: 0.5; }}
-             50% {{ opacity: 1; }}
-             100% {{ opacity: 0.5; }}
-         }}
-         </style>
-     """
-
    if text.strip().lower().startswith("@image"):
-         # Remove the "@image" tag and use the rest as prompt.
        prompt = text[len("@image"):].strip()
-         # Yield progress bar for image generation.
-         yield gr.HTML(progress_bar_html.format(message="Generating Image..."))
+         # Initial message with progress bar at 0%
+         yield gr.HTML(
+             '<div>Generating Image...</div>'
+             '<progress class="progress-bar" value="0" max="100" '
+             'style="width:100%; height:4px; background-color:#e0e0e0;"></progress>'
+         )
        image_paths, used_seed = generate_image_fn(
            prompt=prompt,
            negative_prompt="",
@@ -239,9 +237,9 @@
            use_resolution_binning=True,
            num_images=1,
        )
-         # Once the image is generated, yield the image (thus replacing the progress bar).
+         # Final message with the image, progress bar at 100%
        yield gr.Image(image_paths[0])
-         return  # Exit early
+         return

    tts_prefix = "@tts"
    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
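For context, `generate()` routes a message by its prefix: `@image` goes to the SDXL branch above, `@tts1`/`@tts2` select an edge-tts voice, and everything else falls through to normal chat. A small illustration of that routing with example user inputs:

```python
# Example inputs only; the routing mirrors the checks in generate().
examples = [
    "@image a minimalist poster of a lighthouse at dusk",  # image generation
    "@tts1 read this answer aloud",                        # speech with the first voice
    "Explain the difference between a list and a tuple",   # plain chat
]

tts_prefix = "@tts"
for text in examples:
    lowered = text.strip().lower()
    if lowered.startswith("@image"):
        print("image prompt:", text[len("@image"):].strip())
    elif any(lowered.startswith(f"{tts_prefix}{i}") for i in range(1, 3)):
        print("tts request:", text)
    else:
        print("chat:", text)
```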
@@ -250,11 +248,9 @@
    if is_tts and voice_index:
        voice = TTS_VOICES[voice_index - 1]
        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-         # Clear previous chat history for a fresh TTS request.
        conversation = [{"role": "user", "content": text}]
    else:
        voice = None
-         # Remove any stray @tts tags and build the conversation history.
        text = text.replace(tts_prefix, "").strip()
        conversation = clean_chat_history(chat_history)
        conversation.append({"role": "user", "content": text})
@@ -280,21 +276,18 @@
        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
        thread.start()

-         # Yield progress bar for multimodal input processing.
-         yield gr.HTML(progress_bar_html.format(message="Thinking..."))
+         # Initial loading bar (indeterminate animation via CSS)
+         yield gr.HTML(
+             '<div>Generating response...</div>'
+             '<progress class="progress-bar" style="width:100%; height:4px; background-color:#e0e0e0;"></progress>'
+         )
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            buffer = buffer.replace("<|im_end|>", "")
            time.sleep(0.01)
-             # During streaming, update the progress UI (progress bar remains visible).
-             combined_html = f"""
-             <div style="display: flex; flex-direction: column;">
-                 {progress_bar_html.format(message="Thinking...")}
-                 <div style="margin-top: 10px;">{buffer}</div>
-             </div>
-             """
-             yield gr.HTML(combined_html)
+             # Yield only the text, replacing the loading bar
+             yield buffer
    else:
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
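Both branches stream tokens the same way: `model.generate` runs in a background `Thread` with a `TextIteratorStreamer`, and the Gradio generator yields the growing buffer. A minimal self-contained sketch of that pattern (the model id below is only an example, not the one used by app.py):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # example model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

conversation = [{"role": "user", "content": "Write one sentence about the sea."}]
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")

# The streamer yields decoded text chunks as generation proceeds in the background thread.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={"input_ids": input_ids, "max_new_tokens": 64, "streamer": streamer}).start()

buffer = ""
for new_text in streamer:
    buffer += new_text  # app.py yields this partial buffer back to the chat UI on every step
print(buffer)
```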
@@ -316,23 +309,18 @@
        t = Thread(target=model.generate, kwargs=generation_kwargs)
        t.start()

-         # Yield initial progress bar for text generation.
-         yield gr.HTML(progress_bar_html.format(message="Thinking..."))
-         outputs = []
+         # Initial loading bar
+         yield gr.HTML(
+             '<div>Generating response...</div>'
+             '<progress class="progress-bar" style="width:100%; height:4px; background-color:#e0e0e0;"></progress>'
+         )
+         buffer = ""
        for new_text in streamer:
-             outputs.append(new_text)
-             combined_html = f"""
-             <div style="display: flex; flex-direction: column;">
-                 {progress_bar_html.format(message="Thinking...")}
-                 <div style="margin-top: 10px;">{''.join(outputs)}</div>
-             </div>
-             """
-             yield gr.HTML(combined_html)
-         final_response = "".join(outputs)
-         # Final response: progress bar is removed and only the generated text is shown.
-         yield final_response
+             buffer += new_text
+             # Yield only the text, replacing the loading bar
+             yield buffer

-         # If TTS was requested, convert the final response to speech.
+         final_response = buffer
        if is_tts and voice:
            output_file = asyncio.run(text_to_speech(final_response, voice))
            yield gr.Audio(output_file, autoplay=True)
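The `text_to_speech` coroutine called here is defined near the top of app.py; only its signature (`async def text_to_speech(text, voice, output_file="output.mp3")`) appears in this diff. A typical edge-tts implementation matching that signature, offered as an assumption rather than the file's exact body:

```python
import asyncio

import edge_tts

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    # Assumed body: synthesize with an Edge neural voice and save to an mp3 file.
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

# Usage mirroring the call in generate(); "en-US-GuyNeural" is the @tts2 voice listed above.
print(asyncio.run(text_to_speech("Hello there", "en-US-GuyNeural")))
```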
 