import os
import random
import uuid
import json
import time
import asyncio
import re
import subprocess
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import edge_tts

# Install flash-attn without compiling its CUDA extension from source
# (FLASH_ATTENTION_SKIP_CUDA_BUILD skips the build step); keep the rest of
# the environment so pip and friends stay on PATH.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# Set torch backend configurations for Flux RealismLora
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

# -------------------------------
# CONFIGURATION & UTILITY FUNCTIONS
# -------------------------------
MAX_SEED = 2**32 - 1

def save_image(img: Image.Image) -> str:
    """Save a PIL image with a unique filename and return its path."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

def progress_bar_html(label: str) -> str:
    """Return an HTML snippet for an animated progress bar with a given label."""
    return f'''
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
    <div style="width: 110px; height: 5px; background-color: #FFC0CB; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''
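
# The HTML snippet above is yielded as an interim chat message, so the user sees an
# animated bar while a long-running step (here, image generation) is still executing.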

# -------------------------------
# FLUX REALISMLORA IMAGE GENERATION SETUP (New Implementation)
# -------------------------------
from diffusers import DiffusionPipeline

base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)

lora_repo = "XLabs-AI/flux-RealismLora"
trigger_word = ""  # No trigger word used.
pipe.load_lora_weights(lora_repo)
pipe.to("cuda")
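# NOTE: FLUX.1-dev plus its text encoders is a large pipeline; in bfloat16 it generally
# needs a high-VRAM GPU. If memory is tight, pipe.enable_model_cpu_offload() (instead of
# pipe.to("cuda")) is a common fallback, at the cost of slower generation.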


def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale,
             progress=gr.Progress(track_tqdm=True)):
    # Set the random seed for reproducibility.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # gr.Progress expects a fraction in [0, 1].
    progress(0, desc="Starting image generation...")

    # Cosmetic progress updates: the pipe() call below is blocking, so these fire before
    # generation rather than during it (the real per-step bar comes from tqdm tracking).
    for i in range(1, steps + 1):
        if steps >= 10 and i % (steps // 10) == 0:
            progress(i / steps, desc=f"Processing step {i} of {steps}...")

    # Generate the image with the FLUX pipeline and the RealismLora weights.
    image = pipe(
        prompt=f"{prompt} {trigger_word}",
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]

    # Final progress update.
    progress(1.0, desc="Completed!")
    yield image, seed
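
# For reference only (not executed here): run_lora can also be called directly, e.g.
#   for img, used_seed in run_lora("a candid portrait photo", cfg_scale=3.5, steps=28,
#                                  randomize_seed=True, seed=0, width=1024, height=1024,
#                                  lora_scale=0.8):
#       img.save("sample.png")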

# -------------------------------
# SMOLVLM2 SETUP (Default Text/Multimodal Model)
# -------------------------------
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer

smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
smol_model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    _attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda:0")
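# flash_attention_2 relies on the flash-attn package installed at the top of this file;
# if that install fails, switching _attn_implementation to "eager" is a safe fallback.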

# -------------------------------
# TTS UTILITY FUNCTIONS
# -------------------------------
TTS_VOICES = [
    "en-US-JennyNeural",  # @tts1
    "en-US-GuyNeural",    # @tts2
]
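# More Edge TTS voice names can be listed with `edge-tts --list-voices` if additional
# @tts options are wanted (assumes the edge-tts CLI is available in the environment).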

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Convert text to speech using Edge TTS and save the output as MP3."""
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file
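
# This coroutine is driven with asyncio.run() inside generate(), after the full text
# response has finished streaming.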

# -------------------------------
# CHAT / MULTIMODAL GENERATION FUNCTION
# -------------------------------
@spaces.GPU  # Request a GPU for this call on ZeroGPU Spaces (the `spaces` import above is otherwise unused).
def generate(input_dict: dict, chat_history: list[dict], max_tokens: int = 200):
    """
    Generate chatbot responses using SmolVLM2, with support for multimodal inputs and TTS.

    Special commands:
      - "@image": triggers image generation using the Flux RealismLora implementation.
      - "@tts1" or "@tts2": triggers text-to-speech after generation.
    """
    torch.cuda.empty_cache()
    text = input_dict["text"]
    files = input_dict.get("files", [])

    # If the query starts with "@image", use RealismLora to generate an image.
    if text.strip().lower().startswith("@image"):
        prompt = text[len("@image"):].strip()
        yield progress_bar_html("Hold Tight, Generating Flux RealismLora Image")
        # Default parameters for RealismLora generation.
        default_cfg_scale = 3.2
        default_steps = 32
        default_width = 1152
        default_height = 896
        default_seed = 3981632454
        default_lora_scale = 0.85
        # Run the generator and keep its final (image, seed) result.
        final_result = None
        for result in run_lora(prompt, default_cfg_scale, default_steps, True, default_seed,
                               default_width, default_height, default_lora_scale,
                               progress=gr.Progress(track_tqdm=True)):
            final_result = result
        yield gr.Image(final_result[0])
        return

    # Handle TTS commands if present.
    tts_prefix = "@tts"
    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
    voice = None
    if is_tts:
        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
        if voice_index:
            voice = TTS_VOICES[voice_index - 1]
            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()

    yield "Processing with SmolVLM2"

    # Build conversation messages based on input and history.
    user_content = []
    media_queue = []
    if not chat_history:
        text = text.strip()
        for file in files:
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                media_queue.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                media_queue.append({"type": "video", "path": file})
        if "<image>" in text or "<video>" in text:
            # Interleave text with media in the order the placeholders appear.
            parts = re.split(r'(<image>|<video>)', text)
            for part in parts:
                if part == "<image>" and media_queue:
                    user_content.append(media_queue.pop(0))
                elif part == "<video>" and media_queue:
                    user_content.append(media_queue.pop(0))
                elif part.strip():
                    user_content.append({"type": "text", "text": part.strip()})
        else:
            user_content.append({"type": "text", "text": text})
            for media in media_queue:
                user_content.append(media)
        resulting_messages = [{"role": "user", "content": user_content}]
    else:
        # Rebuild the conversation from the chat history, pairing queued media files
        # with the user turns that reference them.
        resulting_messages = []
        user_content = []
        media_queue = []
        for hist in chat_history:
            if hist["role"] == "user" and isinstance(hist["content"], tuple):
                file_name = hist["content"][0]
                if file_name.endswith((".png", ".jpg", ".jpeg")):
                    media_queue.append({"type": "image", "path": file_name})
                elif file_name.endswith(".mp4"):
                    media_queue.append({"type": "video", "path": file_name})
        for hist in chat_history:
            if hist["role"] == "user" and isinstance(hist["content"], str):
                txt = hist["content"]
                parts = re.split(r'(<image>|<video>)', txt)
                for part in parts:
                    if part == "<image>" and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part == "<video>" and media_queue:
                        user_content.append(media_queue.pop(0))
                    elif part.strip():
                        user_content.append({"type": "text", "text": part.strip()})
            elif hist["role"] == "assistant":
                resulting_messages.append({"role": "user", "content": user_content})
                resulting_messages.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": hist["content"]}]
                })
                user_content = []
        # The new query is not part of chat_history, so append it as the final user turn
        # along with any newly uploaded files.
        current_content = []
        for file in files:
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                current_content.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                current_content.append({"type": "video", "path": file})
        if text.strip():
            current_content.append({"type": "text", "text": text.strip()})
        if current_content:
            resulting_messages.append({"role": "user", "content": current_content})
        if not resulting_messages:
            resulting_messages = [{"role": "user", "content": user_content}]

    if text == "" and not files:
        yield "Please input a query and optionally image(s)."
        return
    if text == "" and files:
        yield "Please input a text query along with the image(s)."
        return

    # Tokenize the conversation with the SmolVLM2 chat template.
    inputs = smol_processor.apply_chat_template(
        resulting_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)
    inputs = inputs.to(smol_model.device)

    # Run generation in a background thread and stream partial text back to the chat.
    streamer = TextIteratorStreamer(smol_processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
    thread = Thread(target=smol_model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

    # If a TTS command was used, synthesize the final response and return it as audio.
    if is_tts and voice:
        final_response = buffer
        output_file = asyncio.run(text_to_speech(final_response, voice))
        yield gr.Audio(output_file, autoplay=True)

# -------------------------------
# GRADIO CHAT INTERFACE
# -------------------------------
DESCRIPTION = "# Flux RealismLora + SmolVLM2 Chat"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>⚠️ Running on CPU, this may not work as expected.</p>"

css = '''
h1 {
    text-align: center;
    display: block;
}
#duplicate-button {
    margin: auto;
    color: #fff;
    background: #1565c0;
    border-radius: 100vh;
}
'''

demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens"),
    ],
    examples=[
        [{"text": "@image A futuristic cityscape at dusk in hyper-realistic style"}],
        [{"text": "Describe this image.", "files": ["example_images/mosque.jpg"]}],
        [{"text": "What does this document say?", "files": ["example_images/document.jpg"]}],
        [{"text": "@tts1 Explain the weather patterns shown in this diagram.", "files": ["example_images/examples_weather_events.png"]}],
    ],
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css=css,
    fill_height=True,
    textbox=gr.MultimodalTextbox(
        label="Query Input",
        file_types=["image", ".mp4"],
        file_count="multiple",
        placeholder="Type text and/or upload media. Use '@image' for image gen, '@tts1' or '@tts2' for TTS.",
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)