"""Gradio demo: VLM captioner + prompt enhancer + Stable Diffusion 3 generator.

The app exposes two tabs:

* "Text to Image": optionally captions an input image with a PaliGemma model
  and/or rewrites the prompt with a summarization-based enhancer, then runs
  the SD3 text-to-image pipeline.
* "Image to Image": same prompt preparation, then runs the SD3 img2img
  pipeline on the uploaded image.
"""

# NOTE: ``spaces`` is deliberately imported before ``torch`` (original order
# preserved).
import spaces
import gradio as gr
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
import re
import random
import numpy as np
import os
from huggingface_hub import snapshot_download

# ---------------------------------------------------------------------------
# Model initialisation
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

# Access token for the SD3 repository, read from the environment.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    # The original second pattern was "*..gitattributes" (double dot), which
    # never matches the ".gitattributes" file.
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="SD3",
    token=huggingface_token,  # set HUGGINGFACE_TOKEN to a valid access token
)

# VLM Captioner
vlm_model = (
    PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner")
    .to(device)
    .eval()
)
vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")

# Prompt Enhancer
enhancer_medium = pipeline(
    "summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device
)
enhancer_long = pipeline(
    "summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device
)


def load_pipeline(pipeline_type):
    """Load an SD3 pipeline from the local snapshot and move it to ``device``.

    Args:
        pipeline_type: ``"text2img"`` or ``"img2img"``.

    Returns:
        A ``StableDiffusion3Pipeline`` or ``StableDiffusion3Img2ImgPipeline``.

    Raises:
        ValueError: If ``pipeline_type`` is not one of the supported values
            (the original silently returned ``None`` in that case).
    """
    if pipeline_type == "text2img":
        return StableDiffusion3Pipeline.from_pretrained(
            model_path, torch_dtype=dtype
        ).to(device)
    if pipeline_type == "img2img":
        return StableDiffusion3Img2ImgPipeline.from_pretrained(
            model_path, torch_dtype=dtype
        ).to(device)
    raise ValueError(
        f"Unknown pipeline_type {pipeline_type!r}; expected 'text2img' or 'img2img'"
    )


MAX_SEED = np.iinfo(np.int32).max  # largest seed offered by the UI slider
MAX_IMAGE_SIZE = 1344  # upper bound for the width/height sliders


# VLM Captioner function
def create_captions_rich(image):
    """Caption ``image`` with the PaliGemma model and clean up the result.

    Args:
        image: Image passed straight to ``vlm_processor``.

    Returns:
        The decoded caption after ``modify_caption`` post-processing.
    """
    prompt = "caption en"
    model_inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(
        device
    )
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = vlm_model.generate(
            **model_inputs, max_new_tokens=256, do_sample=False
        )
        # Drop the echoed prompt tokens; keep only the newly generated ones.
        generation = generation[0][input_len:]
        decoded = vlm_processor.decode(generation, skip_special_tokens=True)

    return modify_caption(decoded)


# Helper function for caption modification
def modify_caption(caption: str) -> str:
    """Remove the first "captured from "/"captured at " phrase from a caption.

    Matching is case-insensitive and only the first occurrence is replaced.

    Args:
        caption: Raw caption text.

    Returns:
        The caption with the first matching phrase replaced.
    """
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    pattern = '|'.join(re.escape(opening) for opening, _ in prefix_substrings)
    # Keys are lower-cased because the search below is case-insensitive.
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        # Normalise case before the lookup: with re.IGNORECASE the matched
        # text may be e.g. "Captured from ", which previously raised KeyError.
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)


# Prompt Enhancer function
def enhance_prompt(input_prompt, model_choice):
    """Rewrite ``input_prompt`` using one of the prompt-enhancer models.

    Args:
        input_prompt: Prompt text to enhance.
        model_choice: ``"Medium"`` selects ``enhancer_medium``; any other
            value selects ``enhancer_long``.

    Returns:
        The enhanced prompt text.
    """
    if model_choice == "Medium":
        result = enhancer_medium("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']

        # Strip a leading "... of " lead-in from the first sentence and
        # capitalise what remains.
        pattern = r'^.*?of\s+(.*?(?:\.|$))'
        match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
        if match:
            remaining_text = enhanced_text[match.end():].strip()
            modified_sentence = match.group(1).capitalize()
            # strip() avoids a dangling space when nothing follows the
            # first sentence.
            enhanced_text = (modified_sentence + ' ' + remaining_text).strip()
    else:  # Long
        result = enhancer_long("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']

    return enhanced_text


# SD3 Generation function
def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height,
                   guidance_scale, num_inference_steps):
    """Run the SD3 text-to-image pipeline.

    Args:
        prompt: Positive prompt.
        negative_prompt: Negative prompt.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: If true, draw a random seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.

    Returns:
        Tuple ``(image, seed)`` with the first generated image and the seed
        actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver numeric values as floats; the generator and the
    # pipeline expect integers here.
    seed = int(seed)

    generator = torch.Generator().manual_seed(seed)

    # The pipeline is loaded per request, as in the original implementation.
    pipe = load_pipeline("text2img")

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    return image, seed


# Gradio Interface
@spaces.GPU
def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice,
                     negative_prompt, seed, randomize_seed, width, height,
                     guidance_scale, num_inference_steps):
    """Handle a "Text to Image" request from the UI.

    The prompt is the VLM caption of ``image`` when ``use_vlm`` is set and an
    image was provided, otherwise ``text_prompt``. It is optionally passed
    through the enhancer before generation.

    Returns:
        Tuple ``(generated_image, prompt, used_seed)``.
    """
    if use_vlm and image is not None:
        prompt = create_captions_rich(image)
    else:
        prompt = text_prompt

    if use_enhancer:
        prompt = enhance_prompt(prompt, model_choice)

    generated_image, used_seed = generate_image(
        prompt, negative_prompt, seed, randomize_seed, width, height,
        guidance_scale, num_inference_steps
    )

    return generated_image, prompt, used_seed


@spaces.GPU
def img2img_generate(
    prompt: str,
    init_image,
    use_vlm: bool,
    use_enhancer: bool,
    model_choice: str,
    negative_prompt: str = "",
    seed: int = 0,
    randomize_seed: bool = False,
    guidance_scale: float = 7,
    num_inference_steps: int = 30,
    strength: float = 0.8,
):
    """Handle an "Image to Image" request from the UI.

    Args:
        prompt: Text prompt; replaced by the VLM caption when ``use_vlm`` is
            set and an image is provided.
        init_image: Source image (the UI component is configured with
            ``type="pil"``); it is resized to 768x768 before generation.
        use_vlm: Whether to caption ``init_image`` for the prompt.
        use_enhancer: Whether to run the prompt enhancer.
        model_choice: Enhancer model choice (``"Medium"`` or ``"Long"``).
        negative_prompt: Negative prompt.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: If true, draw a random seed in ``[0, MAX_SEED]``.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        strength: Img2img strength passed to the pipeline.

    Returns:
        Tuple ``(image, prompt, seed)``.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail early with a readable message instead of an AttributeError on
    # ``None.resize`` further down.
    if init_image is None:
        raise gr.Error("Please provide an input image for image-to-image generation.")

    if use_vlm:
        prompt = create_captions_rich(init_image)

    if use_enhancer:
        prompt = enhance_prompt(prompt, model_choice)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver numeric values as floats.
    seed = int(seed)

    generator = torch.Generator().manual_seed(seed)

    # The pipeline is loaded per request, as in the original implementation.
    img2img_pipe = load_pipeline("img2img")

    init_image = init_image.resize((768, 768))
    image = img2img_pipe(
        prompt=prompt,
        image=init_image,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        strength=strength,
    ).images[0]

    return image, prompt, seed


custom_css = """
.input-group, .output-group {
    border: 1px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
}
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

# Gradio Interface
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"),
) as demo:
    gr.Markdown("# VLM Captioner + Prompt Enhancer + SD3 Image Generator")

    with gr.Tab(label="Text to Image"):
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(elem_classes="input-group"):
                    input_image = gr.Image(label="Input Image for VLM")
                    use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)

                with gr.Group(elem_classes="input-group"):
                    text_prompt = gr.Textbox(label="Text Prompt")
                    use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
                    model_choice = gr.Radio(
                        ["Medium", "Long"], label="Enhancer Model", value="Long"
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt = gr.Textbox(label="Negative Prompt")
                    seed = gr.Slider(
                        label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
                    )
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    width = gr.Slider(
                        label="Width", minimum=256, maximum=MAX_IMAGE_SIZE,
                        step=64, value=1024
                    )
                    height = gr.Slider(
                        label="Height", minimum=256, maximum=MAX_IMAGE_SIZE,
                        step=64, value=1024
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=0.0, maximum=10.0,
                        step=0.1, value=5.0
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=1, maximum=50,
                        step=1, value=28
                    )

                generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")

            with gr.Column(scale=1):
                with gr.Group(elem_classes="output-group"):
                    output_image = gr.Image(label="Generated Image")
                    final_prompt = gr.Textbox(label="Final Prompt Used")
                    used_seed = gr.Number(label="Seed Used")

        generate_btn.click(
            fn=process_workflow,
            inputs=[
                input_image, text_prompt, use_vlm, use_enhancer, model_choice,
                negative_prompt, seed, randomize_seed, width, height,
                guidance_scale, num_inference_steps
            ],
            outputs=[output_image, final_prompt, used_seed]
        )

    # The components of this tab use distinct ``img2img_*`` names so they do
    # not rebind the variables that hold the "Text to Image" components.
    with gr.Tab(label="Image to Image"):
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(elem_classes="input-group"):
                    init_image = gr.Image(label="Input Image", type="pil")
                    img2img_use_vlm = gr.Checkbox(
                        label="Use VLM Captioner", value=False
                    )

                with gr.Group(elem_classes="input-group"):
                    img2img_prompt = gr.Textbox(label="Text Prompt")
                    img2img_use_enhancer = gr.Checkbox(
                        label="Use Prompt Enhancer", value=False
                    )
                    img2img_model_choice = gr.Radio(
                        ["Medium", "Long"], label="Enhancer Model", value="Long"
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    img2img_negative_prompt = gr.Textbox(label="Negative Prompt")
                    img2img_seed = gr.Slider(
                        label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0
                    )
                    img2img_randomize_seed = gr.Checkbox(
                        label="Randomize Seed", value=True
                    )
                    img2img_guidance_scale = gr.Slider(
                        label="Guidance Scale", minimum=0.1, maximum=10.0,
                        step=0.1, value=5
                    )
                    img2img_num_inference_steps = gr.Slider(
                        label="Inference Steps", minimum=1, maximum=50,
                        step=1, value=28
                    )
                    strength = gr.Slider(
                        label="Img2Img Strength", minimum=0.0, maximum=1.0,
                        step=0.01, value=0.5
                    )

                img2img_generate_btn = gr.Button(
                    "Generate Image", elem_classes="submit-btn"
                )

            with gr.Column(scale=1):
                with gr.Group(elem_classes="output-group"):
                    img2img_output = gr.Image(label="Generated Image")
                    img2img_final_prompt = gr.Textbox(label="Final Prompt Used")
                    img2img_used_seed = gr.Number(label="Seed Used")

        img2img_generate_btn.click(
            fn=img2img_generate,
            inputs=[
                img2img_prompt, init_image, img2img_use_vlm, img2img_use_enhancer,
                img2img_model_choice, img2img_negative_prompt, img2img_seed,
                img2img_randomize_seed, img2img_guidance_scale,
                img2img_num_inference_steps, strength
            ],
            outputs=[img2img_output, img2img_final_prompt, img2img_used_seed]
        )

demo.launch(debug=True)