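"""Gradio Space: VLM Captioner + Prompt Enhancer + SD3 Image Generator.

Pipeline: an optional PaliGemma captioner describes an uploaded image, an
optional Lamini-based summarization model expands the prompt, and Stable
Diffusion 3 Medium renders the final image.
"""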
import spaces
import gradio as gr
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
from diffusers import StableDiffusion3Pipeline
import re
import random
import numpy as np
import os
from huggingface_hub import snapshot_download
# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
model_path = snapshot_download(
    repo_id="stabilityai/stable-diffusion-3-medium",
    revision="refs/pr/26",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="SD3",
    token=huggingface_token,
)
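# Note: stable-diffusion-3-medium is a gated repo, so HUGGINGFACE_TOKEN must
# belong to an account that has accepted the SD3 license; the pinned revision
# "refs/pr/26" selects a specific PR branch of the weights.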
# VLM Captioner
vlm_model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner").to(device).eval()
vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner")
# Prompt Enhancer
enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
# SD3
sd3_pipe = StableDiffusion3Pipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344
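# 1344 px per side mirrors the upper bound used in the reference SD3 demo;
# SD3 Medium targets roughly 1-megapixel outputs, so larger sizes tend to
# degrade quality.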
# VLM Captioner function
def create_captions_rich(image):
    prompt = "caption en"  # PaliGemma task prefix for English captioning
    model_inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(device)
    input_len = model_inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = vlm_model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
        generation = generation[0][input_len:]  # drop the prompt tokens, keep only the caption
        decoded = vlm_processor.decode(generation, skip_special_tokens=True)

    return modify_caption(decoded)
# Helper function for caption modification
def modify_caption(caption: str) -> str:
    # Strip boilerplate openings such as "captured from " / "captured at "
    prefix_substrings = [
        ('captured from ', ''),
        ('captured at ', ''),
    ]
    pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
    replacers = {opening.lower(): replacer for opening, replacer in prefix_substrings}

    def replace_fn(match):
        # Lower-case the lookup so case-insensitive matches don't raise KeyError
        return replacers[match.group(0).lower()]

    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
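# Example (illustrative): modify_caption("Captured from a low angle, a cat sits on a wall.")
# returns "a low angle, a cat sits on a wall." -- only the first match is removed.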
# Prompt Enhancer function
def enhance_prompt(input_prompt, model_choice):
    if model_choice == "Medium":
        result = enhancer_medium("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']

        # The Medium model tends to echo a "... of" preamble; strip it and
        # re-capitalize the first sentence it leaves behind.
        pattern = r'^.*?of\s+(.*?(?:\.|$))'
        match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)

        if match:
            remaining_text = enhanced_text[match.end():].strip()
            modified_sentence = match.group(1).capitalize()
            enhanced_text = modified_sentence + ' ' + remaining_text
    else:  # Long
        result = enhancer_long("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']

    return enhanced_text
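# Example (illustrative): if the Medium enhancer returned
# "An enhanced description of a misty forest at dawn. Light filters through the trees.",
# the cleanup above yields "A misty forest at dawn. Light filters through the trees."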
# SD3 Generation function
def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # torch.Generator() defaults to the CPU, which keeps seeded results device-independent
    generator = torch.Generator().manual_seed(seed)

    image = sd3_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    return image, seed
# Full workflow: optional caption -> optional enhancement -> generation
@spaces.GPU
def process_workflow(image, text_prompt, use_vlm, use_enhancer, model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    # Caption the image if requested, otherwise use the raw text prompt
    if use_vlm and image is not None:
        prompt = create_captions_rich(image)
    else:
        prompt = text_prompt

    if use_enhancer:
        prompt = enhance_prompt(prompt, model_choice)

    generated_image, used_seed = generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps)

    return generated_image, prompt, used_seed
css = """
body {
font-family: 'Arial', sans-serif;
background-color: #f0f4f8;
}
.container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
background-color: white;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
h1 {
color: #2c3e50;
text-align: center;
margin-bottom: 20px;
}
.input-group, .output-group {
border: 1px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin-bottom: 20px;
background-color: #f9f9f9;
}
.input-box, .output-box {
border: 1px solid #bdc3c7;
border-radius: 5px;
padding: 10px;
margin-bottom: 10px;
}
.input-box:focus, .output-box:focus {
border-color: #3498db;
box-shadow: 0 0 5px rgba(52, 152, 219, 0.5);
}
.submit-btn {
background-color: #2980b9;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
transition: background-color 0.3s;
}
.submit-btn:hover {
background-color: #3498db;
}
"""
# Gradio Interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# VLM Captioner + Prompt Enhancer + SD3 Image Generator")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="input-group"):
                input_image = gr.Image(label="Input Image for VLM", elem_classes="input-box")
                use_vlm = gr.Checkbox(label="Use VLM Captioner", value=False)

            with gr.Group(elem_classes="input-group"):
                text_prompt = gr.Textbox(label="Text Prompt", elem_classes="input-box")
                use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
                model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")

            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", elem_classes="input-box")
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=64, value=1024)
                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)

            generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")

        with gr.Column(scale=1):
            with gr.Group(elem_classes="output-group"):
                output_image = gr.Image(label="Generated Image", elem_classes="output-box")
                final_prompt = gr.Textbox(label="Final Prompt Used", elem_classes="output-box")
                used_seed = gr.Number(label="Seed Used", elem_classes="output-box")

    generate_btn.click(
        fn=process_workflow,
        inputs=[
            input_image, text_prompt, use_vlm, use_enhancer, model_choice,
            negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
        ],
        outputs=[output_image, final_prompt, used_seed]
    )
demo.launch()