import os
import random
import uuid
import json
import time
import asyncio
import re
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import edge_tts

import subprocess

# Install flash-attn with our environment flag (if needed).
# Merge os.environ so the subprocess keeps PATH and friends; replacing the
# whole environment would leave the shell unable to find pip.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
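# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells the flash-attn install to skip
# compiling its CUDA extension, which can fail or time out on build machines
# without a GPU toolchain (as on Spaces).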
# -------------------------------
# CONFIGURATION & UTILITY FUNCTIONS
# -------------------------------
MAX_SEED = np.iinfo(np.int32).max

def save_image(img: Image.Image) -> str:
    """Save a PIL image with a unique filename and return its path."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

# Determine the preferred torch dtype based on GPU support. Guard the bf16
# probe with is_available() so the module can still be imported without CUDA.
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
preferred_dtype = torch.bfloat16 if bf16_supported else torch.float16
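# bfloat16 needs Ampere-class hardware (compute capability 8.0+); older GPUs
# such as the T4 report False here and fall back to float16.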
# -------------------------------
# FLUX.1 IMAGE GENERATION SETUP
# -------------------------------
from diffusers import DiffusionPipeline

base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=preferred_dtype)

lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
trigger_word = "Super Realism"  # Leave blank if no trigger word is needed.
pipe.load_lora_weights(lora_repo)
pipe.to("cuda")
# Define style prompts for Flux.1.
style_list = [
    {
        "name": "3840 x 2160",
        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "2560 x 1440",
        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "HD+",
        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "Style Zero",
        "prompt": "{prompt}",
    },
]

styles = {s["name"]: s["prompt"] for s in style_list}
DEFAULT_STYLE_NAME = "3840 x 2160"
STYLE_NAMES = list(styles.keys())
def apply_style(style_name: str, positive: str) -> str:
    return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)
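# For example, apply_style("HD+", "a red fox in the snow") returns
# "hyper-realistic 2K image of a red fox in the snow. ultra-detailed, ...".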
@spaces.GPU  # Request a ZeroGPU slice for this call (the bare `import spaces` alone does nothing).
def generate_image_flux(
    prompt: str,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    style_name: str = DEFAULT_STYLE_NAME,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image using the Flux.1 pipeline with a chosen style."""
    torch.cuda.empty_cache()  # Clear unused GPU memory to prevent allocation errors.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    positive_prompt = apply_style(style_name, prompt)
    if trigger_word:
        positive_prompt = f"{trigger_word} {positive_prompt}"
    # Wrap the diffusion call in no_grad to avoid unnecessary gradient state.
    with torch.no_grad():
        images = pipe(
            prompt=positive_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=28,
            num_images_per_prompt=1,
            output_type="pil",
        ).images
    torch.cuda.synchronize()  # Ensure all CUDA operations have completed.
    image_paths = [save_image(img) for img in images]
    return image_paths, seed
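# Minimal standalone usage sketch (not executed by the app; the output
# filename comes from save_image's uuid):
#   paths, used_seed = generate_image_flux("portrait of an astronaut", randomize_seed=True)
#   Image.open(paths[0]).show()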
# -------------------------------
# SMOLVLM2 SETUP (Default Text/Multimodal Model)
# -------------------------------
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer

smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
smol_model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    _attn_implementation="flash_attention_2",
    torch_dtype=preferred_dtype,
).to("cuda:0")
# -------------------------------
# UTILITY FUNCTIONS
# -------------------------------
def progress_bar_html(label: str) -> str:
    """Return an HTML snippet for an animated progress bar with a given label."""
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
'''
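# Each yield from the chat generator replaces the previously streamed message,
# so this bar is shown while work runs and disappears on the next yield.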
TTS_VOICES = [
    "en-US-JennyNeural",  # @tts1
    "en-US-GuyNeural",    # @tts2
]

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Convert text to speech using Edge TTS and save the output as MP3."""
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file
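# Standalone usage sketch (the app itself calls this via asyncio.run):
#   asyncio.run(text_to_speech("Hello there!", "en-US-JennyNeural", "hello.mp3"))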
# -------------------------------
# CHAT / MULTIMODAL GENERATION FUNCTION
# -------------------------------
@spaces.GPU  # ZeroGPU: attach a GPU while the chat generator runs.
def generate(
    input_dict: dict,
    chat_history: list[dict],
    max_tokens: int = 200,
):
    """
    Generate chatbot responses with SmolVLM2 by default, with support for
    multimodal inputs and TTS.

    Special commands:
      - "@image": triggers image generation using the Flux.1 pipeline.
      - "@tts1" or "@tts2": triggers text-to-speech after generation.
    """
    torch.cuda.empty_cache()  # Clear unused GPU memory for consistency.
    text = input_dict["text"]
    files = input_dict.get("files", [])
    # If the query starts with "@image", use Flux.1 to generate an image.
    if text.strip().lower().startswith("@image"):
        prompt = text[len("@image"):].strip()
        yield progress_bar_html("Hold tight, generating a Flux.1 image")
        image_paths, used_seed = generate_image_flux(
            prompt=prompt,
            seed=1,
            width=1024,
            height=1024,
            guidance_scale=3,
            randomize_seed=True,
            style_name=DEFAULT_STYLE_NAME,
            progress=gr.Progress(track_tqdm=True),
        )
        yield gr.Image(image_paths[0])
        return
    # Handle TTS commands if present.
    tts_prefix = "@tts"
    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
    voice = None
    if is_tts:
        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
        if voice_index:
            voice = TTS_VOICES[voice_index - 1]
            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()

    # Use SmolVLM2 for chat/multimodal text generation.
    yield "Processing with SmolVLM2"
    # Build conversation messages based on input and history.
    user_content = []
    media_queue = []
    if chat_history == []:
        text = text.strip()
        for file in files:
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                media_queue.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                media_queue.append({"type": "video", "path": file})
        if "<image>" in text or "<video>" in text:
            # Interleave text with media at explicit <image>/<video> markers.
            parts = re.split(r'(<image>|<video>)', text)
            for part in parts:
                if part == "<image>" and media_queue:
                    user_content.append(media_queue.pop(0))
                elif part == "<video>" and media_queue:
                    user_content.append(media_queue.pop(0))
                elif part.strip():
                    user_content.append({"type": "text", "text": part.strip()})
        else:
            user_content.append({"type": "text", "text": text})
            for media in media_queue:
                user_content.append(media)
        resulting_messages = [{"role": "user", "content": user_content}]
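        # Shape sketch of what SmolVLM2's chat template consumes
        # ("cat.png" is a hypothetical upload):
        #   [{"role": "user",
        #     "content": [{"type": "image", "path": "cat.png"},
        #                 {"type": "text", "text": "Describe this image."}]}]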
    else:
        resulting_messages = []
        user_content = []
        media_queue = []
        # First pass: collect media files from earlier turns.
        for hist in chat_history:
            if hist["role"] == "user" and isinstance(hist["content"], tuple):
                file_name = hist["content"][0]
                if file_name.endswith((".png", ".jpg", ".jpeg")):
                    media_queue.append({"type": "image", "path": file_name})
                elif file_name.endswith(".mp4"):
                    media_queue.append({"type": "video", "path": file_name})
        # Second pass: rebuild alternating user/assistant turns.
        for hist in chat_history:
            if hist["role"] == "user" and isinstance(hist["content"], str):
                txt = hist["content"]
                parts = re.split(r'(<image>|<video>)', txt)
                for part in parts:
                    if part == "<image>" and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part == "<video>" and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part.strip():
                        user_content.append({"type": "text", "text": part.strip()})
            elif hist["role"] == "assistant":
                resulting_messages.append({
                    "role": "user",
                    "content": user_content
                })
                resulting_messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": hist["content"]}]
                })
                user_content = []
        # Append the current turn: Gradio's history does not include the new
        # message, so without this the model would see no fresh user input.
        for file in files:
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                user_content.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                user_content.append({"type": "video", "path": file})
        if text.strip():
            user_content.append({"type": "text", "text": text.strip()})
        resulting_messages.append({"role": "user", "content": user_content})

    if not resulting_messages:
        resulting_messages = [{"role": "user", "content": user_content}]
    if text == "" and not files:
        yield "Please enter a query (and optionally attach images)."
        return
    if text == "" and files:
        yield "Please enter a text query along with the image(s)."
        return

    print("resulting_messages", resulting_messages)
    inputs = smol_processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Explicitly cast pixel values to the preferred dtype to match model weights.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(preferred_dtype)
    inputs = inputs.to(smol_model.device)

    # Run generation on a background thread and stream tokens as they arrive.
    streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
    thread = Thread(target=smol_model.generate, kwargs=generation_args)
    thread.start()
yield "..." | |
buffer = "" | |
for new_text in streamer: | |
buffer += new_text | |
time.sleep(0.01) | |
yield buffer | |
if is_tts and voice: | |
final_response = buffer | |
output_file = asyncio.run(text_to_speech(final_response, voice)) | |
yield gr.Audio(output_file, autoplay=True) | |
# -------------------------------
# GRADIO CHAT INTERFACE
# -------------------------------
DESCRIPTION = "# Flux.1 Realism 🥖 + SmolVLM2 Chat"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>⚠️ Running on CPU; this demo may not work as expected.</p>"

css = '''
h1 {
    text-align: center;
    display: block;
}
#duplicate-button {
    margin: auto;
    color: #fff;
    background: #1565c0;
    border-radius: 100vh;
}
'''
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
    ],
    examples=[
        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic 8K"}],
        [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
        [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
        [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
    ],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css=css,
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", ".mp4"],
        file_count="multiple",
        placeholder="Type text and/or upload media. Use '@image' for Flux.1 image gen, '@tts1' or '@tts2' for TTS.",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)