import spaces
import gradio as gr
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
from diffusers import StableDiffusion3Pipeline
import re
import random
import numpy as np
import os
from huggingface_hub import snapshot_download
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="SD3",
    token=huggingface_token,  # gated repo: the token must belong to an account that has accepted the SD3 license
)
# VLM Captioner
vlm_model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to(device).eval()
vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
# Prompt Enhancer
enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
# SD3
sd3_pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344
# VLM Captioner function
def create_captions_rich(image):
    prompt = "caption en"
    model_inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = vlm_model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
        generation = generation[0][input_len:]
        decoded = vlm_processor.decode(generation, skip_special_tokens=True)
    return modify_caption(decoded)
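
# Minimal usage sketch (assumes a PIL image; the filename is illustrative):
#   from PIL import Image
#   caption = create_captions_rich(Image.open("example.jpg"))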
# Helper function for caption modification
def modify_caption(caption: str) -> str:
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', '')
    ]
    pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}
    def replace_fn(match):
        # Lowercase the match so case-insensitive hits (e.g. "Captured from ")
        # still resolve to a key in `replacers` instead of raising KeyError.
        return replacers[match.group(0).lower()]
    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
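
# Illustrative behavior of modify_caption (the caption text is hypothetical):
#   modify_caption("Captured from a low angle, a cat sits on a wall.")
#   -> "a low angle, a cat sits on a wall."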
# Prompt Enhancer function
def enhance_prompt(input_prompt, model_choice):
    if model_choice == "Medium":
        result = enhancer_medium("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']
        # The medium model tends to open with "...description of X"; strip that
        # lead-in and promote the first real sentence to the front.
        pattern = r'^.*?of\s+(.*?(?:\.|$))'
        match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
        if match:
            remaining_text = enhanced_text[match.end():].strip()
            modified_sentence = match.group(1).capitalize()
            enhanced_text = modified_sentence + ' ' + remaining_text
    else:  # Long
        result = enhancer_long("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']
    return enhanced_text
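
# Worked example of the "Medium" cleanup above (the summarizer output is illustrative):
#   "Enhanced description of a red barn at sunset. The sky glows orange."
#   -> group(1) = "a red barn at sunset." -> "A red barn at sunset. The sky glows orange."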
# SD3 Generation function
def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = sd3_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator
    ).images[0]
    return image, seed
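
# Reproducibility sketch: with randomize_seed=False, the same arguments should
# regenerate the same image on repeated calls (assuming deterministic GPU kernels):
#   image, used = generate_image("a red barn", "", 42, False, 1024, 1024, 5.0, 28)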
# End-to-end workflow: optional VLM caption, optional enhancement, then SD3
@spaces.GPU
def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if use_vlm and image is not None:
        prompt = create_captions_rich(image)
    else:
        prompt = text_prompt
    if use_enhancer:
        prompt = enhance_prompt(prompt, model_choice)
    generated_image, used_seed = generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps)
    return generated_image, prompt, used_seed
custom_css = """
.input-group, .output-group {
    border: 1px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
}
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""
# Gradio Interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
    gr.Markdown("# VLM Captioner + Prompt Enhancer + SD3 Image Generator")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="input-group"):
                input_image = gr.Image(label="Input Image for VLM")
                use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)
            with gr.Group(elem_classes="input-group"):
                text_prompt = gr.Textbox(label="Text Prompt")
                use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
                model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt")
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
            generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-group"):
                output_image = gr.Image(label="Generated Image")
                final_prompt = gr.Textbox(label="Final Prompt Used")
                used_seed = gr.Number(label="Seed Used")
    generate_btn.click(
        fn=process_workflow,
        inputs=[
            input_image, text_prompt, use_vlm, use_enhancer, model_choice,
            negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
        ],
        outputs=[output_image, final_prompt, used_seed]
    )
demo.launch(debug=True)