import spaces
import gradio as gr
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
from diffusers import StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
import re
import random
import numpy as np
import os
from huggingface_hub import snapshot_download
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# The SD3 repo is gated; downloading requires an HF access token with the license accepted.
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="SD3",
    token=huggingface_token,
)
# VLM Captioner
vlm_model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to(device).eval()
vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
# Prompt Enhancer
enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
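# The heavy SD3 pipelines are built on demand inside each request rather than at
# import time; on a ZeroGPU Space this presumably avoids holding GPU memory
# between calls, at the cost of reloading weights per generation.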
def load_pipeline(pipeline_type):
    if pipeline_type == "text2img":
        return StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
    elif pipeline_type == "img2img":
        return StableDiffusion3Img2ImgPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344
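# UI slider bounds; 1344 matches the resolution cap commonly used in SD3 medium demos.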
# VLM Captioner function
def create_captions_rich(image):
    prompt = "caption en"
    model_inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(device)
    input_len = model_inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = vlm_model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
        generation = generation[0][input_len:]
        decoded = vlm_processor.decode(generation, skip_special_tokens=True)
    return modify_caption(decoded)
# Helper function for caption modification
def modify_caption(caption: str) -> str:
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', '')
    ]
    pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
    replacers = {opening: replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        # Lower-case the match so case-insensitive hits still key into the dict.
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
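# Illustrative example (not executed):
#   modify_caption("captured from a low angle, a cat sits on a wall")
#   -> "a low angle, a cat sits on a wall"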
# Prompt Enhancer function
def enhance_prompt(input_prompt, model_choice):
    if model_choice == "Medium":
        result = enhancer_medium("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']

        # Drop the model's leading framing up to the first "of" (e.g. "A description of ...")
        # and re-capitalize what remains as the opening sentence.
        pattern = r'^.*?of\s+(.*?(?:\.|$))'
        match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
        if match:
            remaining_text = enhanced_text[match.end():].strip()
            modified_sentence = match.group(1).capitalize()
            enhanced_text = modified_sentence + ' ' + remaining_text
    else:  # Long
        result = enhancer_long("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']
    return enhanced_text
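# Illustrative example for the "Medium" branch (not executed): a summary such as
#   "a digital painting of a red fox in the snow. It looks alert."
# becomes
#   "A red fox in the snow. It looks alert."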
# SD3 Generation function
def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # A CPU-side generator keeps seeding independent of the device the pipeline runs on.
    generator = torch.Generator().manual_seed(seed)

    pipe = load_pipeline("text2img")
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator
    ).images[0]
    return image, seed
# End-to-end workflow: optional VLM caption -> optional prompt enhancement -> SD3 generation
@spaces.GPU
def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if use_vlm and image is not None:
        prompt = create_captions_rich(image)
    else:
        prompt = text_prompt

    if use_enhancer:
        prompt = enhance_prompt(prompt, model_choice)

    generated_image, used_seed = generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps)
    return generated_image, prompt, used_seed
@spaces.GPU
def img2img_generate(
    prompt: str,
    init_image: gr.Image,
    use_vlm: bool,
    use_enhancer: bool,
    model_choice: str,
    negative_prompt: str = "",
    seed: int = 0,
    randomize_seed: bool = False,
    guidance_scale: float = 7,
    num_inference_steps: int = 30,
    strength: float = 0.8,
):
    if use_vlm and init_image is not None:
        prompt = create_captions_rich(init_image)

    if use_enhancer:
        prompt = enhance_prompt(prompt, model_choice)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    img2img_pipe = load_pipeline("img2img")

    # NOTE: the input is forced to 768x768, so non-square sources are distorted.
    init_image = init_image.resize((768, 768))

    image = img2img_pipe(
        prompt=prompt,
        image=init_image,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
        strength=strength,
    ).images[0]
    return image, prompt, seed
custom_css = """
.input-group, .output-group {
    border: 1px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
}
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""
# Gradio Interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
    gr.Markdown("# VLM Captioner + Prompt Enhancer + SD3 Image Generator")

    with gr.Tab(label="Text to Image"):
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(elem_classes="input-group"):
                    input_image = gr.Image(label="Input Image for VLM")
                    use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)

                with gr.Group(elem_classes="input-group"):
                    text_prompt = gr.Textbox(label="Text Prompt")
                    use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
                    model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")

                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt = gr.Textbox(label="Negative Prompt")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)

                generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")

            with gr.Column(scale=1):
                with gr.Group(elem_classes="output-group"):
                    output_image = gr.Image(label="Generated Image")
                    final_prompt = gr.Textbox(label="Final Prompt Used")
                    used_seed = gr.Number(label="Seed Used")

        generate_btn.click(
            fn=process_workflow,
            inputs=[
                input_image, text_prompt, use_vlm, use_enhancer, model_choice,
                negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
            ],
            outputs=[output_image, final_prompt, used_seed]
        )
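    # Note: the tab below rebinds use_vlm, use_enhancer, model_choice, negative_prompt,
    # seed, and randomize_seed. This is safe only because the text-to-image click
    # handler above was wired up before the names are reassigned.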
    with gr.Tab(label="Image to Image"):
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(elem_classes="input-group"):
                    init_image = gr.Image(label="Input Image", type="pil")
                    use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)

                with gr.Group(elem_classes="input-group"):
                    img2img_prompt = gr.Textbox(label="Text Prompt")
                    use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
                    model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")

                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt = gr.Textbox(label="Negative Prompt")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=10.0, step=0.1, value=5.0)
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
                    strength = gr.Slider(label="Img2Img Strength", minimum=0.0, maximum=1.0, step=0.01, value=0.5)

                img2img_generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")

            with gr.Column(scale=1):
                with gr.Group(elem_classes="output-group"):
                    img2img_output = gr.Image(label="Generated Image")
                    img2img_final_prompt = gr.Textbox(label="Final Prompt Used")
                    img2img_used_seed = gr.Number(label="Seed Used")

        img2img_generate_btn.click(
            fn=img2img_generate,
            inputs=[
                img2img_prompt, init_image, use_vlm, use_enhancer, model_choice,
                negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, strength
            ],
            outputs=[img2img_output, img2img_final_prompt, img2img_used_seed]
        )
demo.launch(debug=True)