import os
import random
import uuid
import json
import time
import asyncio
import re
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import edge_tts
import subprocess

# Install flash-attn with our environment flag (if needed).
# Merge os.environ so pip still sees PATH and friends alongside the skip flag.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)

# -------------------------------
# CONFIGURATION & UTILITY FUNCTIONS
# -------------------------------
MAX_SEED = np.iinfo(np.int32).max

def save_image(img: Image.Image) -> str:
    """Save a PIL image with a unique filename and return its path."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

# Determine the preferred torch dtype based on GPU support.
bf16_supported = torch.cuda.is_bf16_supported()
preferred_dtype = torch.bfloat16 if bf16_supported else torch.float16

# -------------------------------
# FLUX.1 IMAGE GENERATION SETUP
# -------------------------------
from diffusers import DiffusionPipeline

base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=preferred_dtype)

lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
trigger_word = "Super Realism"  # Leave blank if no trigger word is needed.
pipe.load_lora_weights(lora_repo)
pipe.to("cuda")

# Define style prompts for Flux.1
style_list = [
    {
        "name": "3840 x 2160",
        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "2560 x 1440",
        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "HD+",
        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
    },
    {
        "name": "Style Zero",
        "prompt": "{prompt}",
    },
]
styles = {s["name"]: s["prompt"] for s in style_list}

DEFAULT_STYLE_NAME = "3840 x 2160"
STYLE_NAMES = list(styles.keys())

def apply_style(style_name: str, positive: str) -> str:
    return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)

@spaces.GPU(duration=60, enable_queue=True)
def generate_image_flux(
    prompt: str,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    style_name: str = DEFAULT_STYLE_NAME,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image using the Flux.1 pipeline with a chosen style."""
    torch.cuda.empty_cache()  # Clear unused GPU memory to prevent allocation errors
    seed = int(randomize_seed_fn(seed, randomize_seed))
    positive_prompt = apply_style(style_name, prompt)
    if trigger_word:
        positive_prompt = f"{trigger_word} {positive_prompt}"
    # Wrap the diffusion call in no_grad to avoid unnecessary gradient state.
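    # Inference only: skipping autograd bookkeeping in the block below also reduces
    # peak VRAM, which matters with the LoRA-patched FLUX.1 weights on the same GPU.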
    with torch.no_grad():
        images = pipe(
            prompt=positive_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=28,
            num_images_per_prompt=1,
            output_type="pil",
        ).images
    torch.cuda.synchronize()  # Ensure all CUDA operations have completed
    image_paths = [save_image(img) for img in images]
    return image_paths, seed

# -------------------------------
# SMOLVLM2 SETUP (Default Text/Multimodal Model)
# -------------------------------
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer

smol_processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
smol_model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    _attn_implementation="flash_attention_2",
    torch_dtype=preferred_dtype
).to("cuda:0")

# -------------------------------
# UTILITY FUNCTIONS
# -------------------------------
def progress_bar_html(label: str) -> str:
    """
    Returns an HTML snippet for an animated progress bar with a given label.
    """
    return f'''
    <div style="display: flex; align-items: center;">
        <span style="margin-right: 10px; font-size: 14px;">{label}</span>
        <div style="width: 110px; height: 5px; background-color: #eee; border-radius: 2px; overflow: hidden;">
            <div style="width: 100%; height: 100%; background-color: #1565c0; animation: loading 1.5s linear infinite;"></div>
        </div>
    </div>
    <style>
    @keyframes loading {{
        0% {{ transform: translateX(-100%); }}
        100% {{ transform: translateX(100%); }}
    }}
    </style>
'''

TTS_VOICES = [
    "en-US-JennyNeural",  # @tts1
    "en-US-GuyNeural",    # @tts2
]

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Convert text to speech using Edge TTS and save the output as MP3."""
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

# -------------------------------
# CHAT / MULTIMODAL GENERATION FUNCTION
# -------------------------------
@spaces.GPU
def generate(
    input_dict: dict,
    chat_history: list[dict],
    max_tokens: int = 200,
):
    """
    Generates chatbot responses using SmolVLM2 by default, with support for
    multimodal inputs and TTS.

    Special commands:
      - "@image": triggers image generation using the Flux.1 pipeline.
      - "@tts1" or "@tts2": triggers text-to-speech after generation.
    """
    torch.cuda.empty_cache()  # Clear unused GPU memory for consistency
    text = input_dict["text"]
    files = input_dict.get("files", [])

    # If the query starts with "@image", use Flux.1 to generate an image.
    if text.strip().lower().startswith("@image"):
        prompt = text[len("@image"):].strip()
        yield progress_bar_html("Hold Tight, Generating Flux.1 Image")
        image_paths, used_seed = generate_image_flux(
            prompt=prompt,
            seed=1,
            width=1024,
            height=1024,
            guidance_scale=3,
            randomize_seed=True,
            style_name=DEFAULT_STYLE_NAME,
            progress=gr.Progress(track_tqdm=True),
        )
        yield gr.Image(image_paths[0])
        return

    # Handle TTS commands if present.
    tts_prefix = "@tts"
    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
    voice = None
    if is_tts:
        voice_index = next(
            (i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")),
            None,
        )
        if voice_index:
            voice = TTS_VOICES[voice_index - 1]
            text = text.replace(f"{tts_prefix}{voice_index}", "").strip()

    # Use SmolVLM2 for chat/multimodal text generation.
    yield "Processing with SmolVLM2"

    # Build conversation messages based on the input and history.
    user_content = []
    media_queue = []
    if chat_history == []:
        text = text.strip()
        for file in files:
            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
                media_queue.append({"type": "image", "path": file})
            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
                media_queue.append({"type": "video", "path": file})
        if "<image>" in text or "<video>" in text:
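            # These placeholders are expected to pair positionally with the uploads:
            # each <image>/<video> tag in the prompt consumes the next entry in media_queue.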