# Hugging Face Space (status: Running on Zero)
import os | |
import random | |
import uuid | |
import json | |
import time | |
import asyncio | |
import re | |
from threading import Thread | |
import gradio as gr | |
import spaces | |
import torch | |
import numpy as np | |
from PIL import Image | |
import edge_tts | |
import subprocess | |
# Install flash-attn at startup, skipping the CUDA build via the environment flag.
# The flag is merged into a copy of the current environment: passing only the
# flag as `env` would replace the whole environment of the child process
# (dropping PATH, HOME, virtualenv settings, ...), so `pip` might not be found
# or might install into the wrong interpreter.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
# Torch backend switches used for Flux RealismLora inference:
# TF32 is allowed for CUDA matmuls, while cuDNN runs in deterministic mode
# with benchmark mode turned off.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# -------------------------------
# CONFIGURATION & UTILITY FUNCTIONS
# -------------------------------
# Inclusive upper bound used when drawing a random seed (2**32 - 1).
MAX_SEED = (1 << 32) - 1
def save_image(img: Image.Image) -> str:
    """Write *img* to a freshly named PNG file and return that file name.

    The name is a random UUID4 followed by ``.png``; it is passed to
    ``img.save`` as-is, so it is a relative path.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in ``[0, MAX_SEED]`` when requested, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def progress_bar_html(label: str) -> str:
    """Build the HTML/CSS markup for a labelled, endlessly looping progress bar.

    ``label`` is inserted into the markup verbatim (it is not HTML-escaped).
    The returned string starts and ends with a newline.
    """
    # Only the <span> line interpolates a value, so the CSS lines below are
    # plain strings and need no brace escaping.
    markup_lines = [
        "",
        '<div style="display: flex; align-items: center;">',
        f'<span style="margin-right: 10px; font-size: 14px;">{label}</span>',
        '<div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">',
        '<div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>',
        "</div>",
        "</div>",
        "<style>",
        "@keyframes loading {",
        "0% { transform: translateX(-100%); }",
        "100% { transform: translateX(100%); }",
        "}",
        "</style>",
        "",
    ]
    return "\n".join(markup_lines)
# -------------------------------
# FLUX REALISMLORA IMAGE GENERATION SETUP (New Implementation)
# -------------------------------
from diffusers import DiffusionPipeline

# Hub identifiers for the base checkpoint and the LoRA applied on top of it.
base_model = "black-forest-labs/FLUX.1-dev"
lora_repo = "XLabs-AI/flux-RealismLora"
trigger_word = ""  # No trigger word used.

# Load the base pipeline in bfloat16, attach the LoRA weights, then move it to CUDA.
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
pipe.load_lora_weights(lora_repo)
pipe.to("cuda")
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the Flux pipeline and its LoRA, then yield it.

    Args:
        prompt: Text prompt; ``trigger_word`` is appended after a space.
        cfg_scale: Passed to the pipeline as ``guidance_scale``.
        steps: Passed to the pipeline as ``num_inference_steps``.
        randomize_seed: When true, ``seed`` is replaced by a random value in
            ``[0, MAX_SEED]``.
        seed: Seed for the CUDA ``torch.Generator`` (ignored when randomized).
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        lora_scale: Passed as ``joint_attention_kwargs={"scale": ...}``.
        progress: Gradio progress tracker.

    Yields:
        A single ``(image, seed)`` tuple, where ``seed`` is the value actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # gr.Progress expects a completion fraction in [0, 1], not a percentage.
    # No simulated per-step updates are emitted here: they used to fire before
    # any work had started, and `track_tqdm=True` is the mechanism that reports
    # real step-level progress.
    progress(0, "Starting image generation...")

    image = pipe(
        prompt=f"{prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    progress(1.0, "Completed!")
    yield image, seed
# -------------------------------
# SMOLVLM2 SETUP (Default Text/Multimodal Model)
# -------------------------------
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer

# Processor and model are loaded from the same Hub checkpoint.
_SMOL_CHECKPOINT = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

smol_processor = AutoProcessor.from_pretrained(_SMOL_CHECKPOINT)

# float16 weights with the flash-attention-2 attention implementation, placed on cuda:0.
smol_model = AutoModelForImageTextToText.from_pretrained(
    _SMOL_CHECKPOINT,
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda:0")
# -------------------------------
# TTS UTILITY FUNCTIONS
# -------------------------------
# Edge TTS voice names; the "@ttsN" chat command selects the entry at index N-1.
TTS_VOICES = [
    "en-US-JennyNeural",  # @tts1
    "en-US-GuyNeural",  # @tts2
]
async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Synthesize *text* with the given Edge TTS *voice* and save it to *output_file*.

    Returns ``output_file`` once the save call has completed.
    """
    await edge_tts.Communicate(text, voice).save(output_file)
    return output_file
# -------------------------------
# CHAT / MULTIMODAL GENERATION FUNCTION
# -------------------------------
# Placeholders a user can put in the text to position uploaded media.
_MEDIA_TAG_PATTERN = re.compile(r'(<image>|<video>)')
_MEDIA_TAGS = ("<image>", "<video>")
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp")
_VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mkv", ".flv")


def _media_entry(path: str):
    """Return a chat-template media entry for *path*, or None for an unrecognised extension."""
    lowered = path.lower()  # extension match is case-insensitive ("photo.JPG")
    if lowered.endswith(_IMAGE_EXTENSIONS):
        return {"type": "image", "path": path}
    if lowered.endswith(_VIDEO_EXTENSIONS):
        return {"type": "video", "path": path}
    return None


def _interleave_media(text: str, media: list) -> list:
    """Turn *text* plus its *media* entries into one chat-template content list.

    Each ``<image>``/``<video>`` placeholder is replaced by the next queued
    media entry; a placeholder with no media left is kept as literal text.
    Media not claimed by a placeholder is appended after the text.
    """
    content = []
    queue = iter(media)
    for part in _MEDIA_TAG_PATTERN.split(text):
        if part in _MEDIA_TAGS:
            item = next(queue, None)
            if item is not None:
                content.append(item)
                continue
        if part.strip():
            content.append({"type": "text", "text": part.strip()})
    content.extend(queue)
    return content


def _build_messages(text: str, files: list, chat_history: list) -> list:
    """Assemble the message list for the chat template: past turns, then the current one."""
    messages = []
    turn_media = []    # media uploaded in the user turn being assembled
    turn_content = []  # content entries of that user turn

    for entry in chat_history:
        role = entry["role"]
        content = entry["content"]
        if role == "user":
            if isinstance(content, str):
                turn_content.extend(_interleave_media(content, turn_media))
                turn_media = []
            elif isinstance(content, (tuple, list)) and content and isinstance(content[0], str):
                # File uploads appear in history as a tuple whose first item is the path.
                media = _media_entry(content[0])
                if media is not None:
                    turn_media.append(media)
        elif role == "assistant":
            turn_content.extend(turn_media)
            # Keep only complete text exchanges. Replies that are not plain
            # text (generated images, TTS audio) cannot be fed back as text,
            # so that whole turn is left out.
            if isinstance(content, str) and turn_content:
                messages.append({"role": "user", "content": turn_content})
                messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": content}],
                })
            turn_media, turn_content = [], []

    # The current query always becomes the last user message, whether or not
    # there is history.
    current_media = [m for m in map(_media_entry, files) if m is not None]
    turn_content.extend(_interleave_media(text, turn_media + current_media))
    messages.append({"role": "user", "content": turn_content})
    return messages


def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
    """
    Generates chatbot responses using SmolVLM2 with support for multimodal inputs and TTS.

    Special commands (matched case-insensitively at the start of the message):
    - "@image": triggers image generation using the RealismLora flux implementation.
    - "@tts1" or "@tts2": triggers text-to-speech after generation.

    Args:
        input_dict: Multimodal textbox value with a "text" entry and an
            optional "files" list of file paths.
        chat_history: Previous messages as ``{"role": ..., "content": ...}`` dicts.
        max_tokens: Maximum number of new tokens to generate.

    Yields:
        Status strings, the progressively growing response text, and, for the
        special commands, a final ``gr.Image`` or ``gr.Audio`` component.
    """
    # NOTE(review): `spaces` is imported by this file but nothing here is
    # wrapped in `@spaces.GPU`; confirm whether the deployment needs it.
    torch.cuda.empty_cache()
    text = (input_dict["text"] or "").strip()
    files = input_dict.get("files") or []
    lowered = text.lower()

    # "@image <prompt>": generate an image with the Flux RealismLora pipeline.
    if lowered.startswith("@image"):
        prompt = text[len("@image"):].strip()
        yield progress_bar_html("Hold Tight Generating Flux RealismLora Image")
        # Default parameters for RealismLora generation
        default_cfg_scale = 3.2
        default_steps = 32
        default_width = 1152
        default_height = 896
        default_seed = 3981632454
        default_lora_scale = 0.85
        # Drain the generator and keep its last (image, seed) result.
        for result in run_lora(prompt, default_cfg_scale, default_steps, True, default_seed, default_width, default_height, default_lora_scale, progress=gr.Progress(track_tqdm=True)):
            final_result = result
        yield gr.Image(final_result[0])
        return

    # "@ttsN <query>": remember the voice and strip only the leading command.
    voice = None
    if lowered.startswith("@tts"):
        for index, candidate in enumerate(TTS_VOICES, start=1):
            prefix = f"@tts{index}"
            if lowered.startswith(prefix):
                voice = candidate
                text = text[len(prefix):].strip()
                break

    # Validate before doing any model work.
    if not text:
        if files:
            yield "Please input a text query along with the image(s)."
        else:
            yield "Please input a query and optionally image(s)."
        return

    yield "Processing with SmolVLM2"

    resulting_messages = _build_messages(text, files, chat_history)
    inputs = smol_processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    if "pixel_values" in inputs:
        # Cast image tensors to float16, the dtype the model was loaded with.
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
    inputs = inputs.to(smol_model.device)

    # Run generation in a background thread and stream tokens as they arrive.
    streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=int(max_tokens))
    thread = Thread(target=smol_model.generate, kwargs=generation_args)
    thread.start()
    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    thread.join()

    if voice and buffer.strip():
        # A unique file per request keeps concurrent users from overwriting
        # each other's audio.
        audio_path = f"{uuid.uuid4()}.mp3"
        output_file = asyncio.run(text_to_speech(buffer, voice, output_file=audio_path))
        yield gr.Audio(output_file, autoplay=True)
# -------------------------------
# GRADIO CHAT INTERFACE
# -------------------------------
# Page heading; a warning paragraph is appended when CUDA is not available.
DESCRIPTION = "# Flux RealismLora + SmolVLM2 Chat" + (
    ""
    if torch.cuda.is_available()
    else "\n<p>⚠️Running on CPU, this may not work as expected.</p>"
)

# Stylesheet handed to the chat interface: centred heading and the
# "#duplicate-button" element styling.
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: #fff;
background: #1565c0;
border-radius: 100vh;
}
"""
# Extra control shown next to the chat; its value is passed to `generate`
# as `max_tokens`.
_max_tokens_slider = gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")

# Example prompts covering image generation, image Q&A and TTS.
_example_prompts = [
    [{"text": "@image A futuristic cityscape at dusk in hyper-realistic style"}],
    [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
    [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
    [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
]

# Multimodal input box accepting text plus any number of image / .mp4 uploads.
_query_box = gr.MultimodalTextbox(
    label="Query Input",
    file_types=["image", ".mp4"],
    file_count="multiple",
    placeholder="Type text and/or upload media. Use '@image' for image gen, '@tts1' or '@tts2' for TTS.",
)

# Chat UI wired to `generate`; history is exchanged in the "messages" format.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[_max_tokens_slider],
    examples=_example_prompts,
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css=css,
    fill_height=True,
    textbox=_query_box,
    stop_btn="Stop Generation",
    multimodal=True,
)
if __name__ == "__main__":
    # Enable request queuing (at most 20 waiting requests), then start the
    # app with a share link requested.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(share=True)