prithivMLmods committed
Commit 5b985be · verified · 1 Parent(s): ecce109

Update app.py

Files changed (1):
  1. app.py +39 -43
app.py CHANGED
@@ -42,6 +42,24 @@ h1 {
 }
 '''

+def progress_bar_html(label):
+    """Returns an HTML snippet with a label and an animated thin progress bar."""
+    return f"""
+    <div style="display: flex; align-items: center;">
+        <span style="margin-right: 10px;">{label}</span>
+        <div style="position: relative; width: 110px; height: 5px; background-color: #e0e0e0; border-radius: 2.5px; overflow: hidden;">
+            <div style="width: 100%; height: 100%; background-color: #90ee90; animation: progressAnimation 2s infinite;"></div>
+        </div>
+        <style>
+            @keyframes progressAnimation {{
+                0% {{ opacity: 1; }}
+                50% {{ opacity: 0.5; }}
+                100% {{ opacity: 1; }}
+            }}
+        </style>
+    </div>
+    """
+
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -88,23 +106,6 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned

-# Helper: returns HTML code for a thin light-green animated progress bar with a label.
-def progress_bar_html(label: str) -> str:
-    return f'''
-    <div style="display: flex; align-items: center;">
-        <span>{label}</span>
-        <div style="flex-grow: 1; margin-left: 8px; height: 5px; background-color: lightgreen; overflow: hidden; position: relative;">
-            <div style="width: 100%; height: 100%; background: linear-gradient(90deg, rgba(255,255,255,0) 0%, rgba(255,255,255,0.5) 50%, rgba(255,255,255,0) 100%); animation: progressAnim 1s linear infinite;"></div>
-        </div>
-    </div>
-    <style>
-        @keyframes progressAnim {{
-            0% {{ transform: translateX(-100%); }}
-            100% {{ transform: translateX(100%); }}
-        }}
-    </style>
-    '''
-
 # Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
@@ -185,7 +186,6 @@ def generate_image_fn(
             batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
             if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
                 batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-            # Wrap the pipeline call in autocast if using CUDA
             if device.type == "cuda":
                 with torch.autocast("cuda", dtype=torch.float16):
                     outputs = sd_pipe(**batch_options)
@@ -214,13 +214,12 @@ def generate(
     text = input_dict["text"]
     files = input_dict.get("files", [])

-    # For image generation triggered by "@image"
     if text.strip().lower().startswith("@image"):
         # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
-        # Yield a progress bar with label "Generating Image"
-        progress_component = gr.HTML(progress_bar_html("Generating Image"))
-        yield progress_component
+        # Show a progress bar for image generation
+        progress_html = progress_bar_html("Generating Image")
+        yield gr.HTML(progress_html)
         image_paths, used_seed = generate_image_fn(
             prompt=prompt,
             negative_prompt="",
@@ -234,7 +233,7 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
-        # Clear the progress bar (replace with empty HTML) and then yield the image
+        # Remove the progress bar and then yield the generated image
         yield gr.HTML.update(value="")
         yield gr.Image(image_paths[0])
         return  # Exit early
@@ -255,7 +254,6 @@ def generate(
     conversation = clean_chat_history(chat_history)
     conversation.append({"role": "user", "content": text})

-    # If there are attached image files, use multimodal processing
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -277,19 +275,17 @@ def generate(
         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
         thread.start()

+        # Show a progress bar while processing the multimodal input
+        progress_html = progress_bar_html("Thinking...")
+        yield gr.HTML(progress_html)
         buffer = ""
-        # Yield a progress bar with label "Thinking..."
-        progress_component = gr.HTML(progress_bar_html("Thinking..."))
-        yield progress_component
         for new_text in streamer:
             buffer += new_text
             buffer = buffer.replace("<|im_end|>", "")
             time.sleep(0.01)
-        # Clear the progress bar and yield the final result text.
-        yield gr.HTML.update(value="")
-        yield buffer
+        # Update the same message to display the final result (removing the progress bar)
+        yield gr.HTML.update(value=buffer)
     else:
-        # For pure text responses:
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -307,23 +303,23 @@ def generate(
             "num_beams": 1,
             "repetition_penalty": repetition_penalty,
         }
-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()

-        outputs = []
-        # Yield a progress bar with label "Thinking..."
-        progress_component = gr.HTML(progress_bar_html("Thinking..."))
-        yield progress_component
+        # Show a progress bar for text generation
+        progress_html = progress_bar_html("Thinking...")
+        yield gr.HTML(progress_html)
+        buffer = ""
         for new_text in streamer:
-            outputs.append(new_text)
-        final_response = "".join(outputs)
-        # Clear the progress bar and yield the final plain text result.
-        yield gr.HTML.update(value="")
-        yield final_response
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+        # Replace the progress bar with the final text response
+        yield gr.HTML.update(value=buffer)

         # If TTS was requested, convert the final response to speech.
         if is_tts and voice:
-            output_file = asyncio.run(text_to_speech(final_response, voice))
+            output_file = asyncio.run(text_to_speech(buffer, voice))
             yield gr.Audio(output_file, autoplay=True)

 demo = gr.ChatInterface(
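The progress-bar flow the commit settles on (yield the bar, then replace it with the answer) can be exercised outside the full app. A minimal sketch, assuming a recent gradio with multimodal ChatInterface support; the echo logic and the trimmed bar markup stand in for the real model call and the committed helper:

import time
import gradio as gr

def progress_bar_html(label):
    # Trimmed stand-in for the committed helper (same idea, simpler markup).
    return f'<div style="display:flex;align-items:center;">{label}&nbsp;<progress style="height:5px;"></progress></div>'

def respond(message, history):
    # First yield: show the animated bar, as the committed generate() does.
    yield gr.HTML(progress_bar_html("Thinking..."))
    time.sleep(1)  # stands in for the threaded model.generate call
    # Second yield: replace the bar with the final answer.
    yield "You said: " + message["text"]

demo = gr.ChatInterface(respond, multimodal=True)

if __name__ == "__main__":
    demo.launch()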
 
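The background-thread streaming that feeds both loops in the diff is the usual transformers pattern: generate() runs in a Thread and a TextIteratorStreamer yields text pieces as they decode. A minimal sketch, assuming transformers is installed and using gpt2 as a small placeholder checkpoint (not the app's actual models):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs={
    **inputs,               # input_ids and attention_mask
    "streamer": streamer,
    "max_new_tokens": 32,
})
thread.start()

buffer = ""
for new_text in streamer:   # pieces arrive while generate() runs in the thread
    buffer += new_text
print(buffer)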