Spaces:
Build error
Build error
import torch | |
from diffusers import AutoPipelineForImage2Image | |
from PIL import Image, ImageDraw, ImageFont | |
import requests | |
from io import BytesIO | |
import gradio as gr | |
import gc | |
import textwrap | |
# Log GPU availability at startup.
# The device name is only queried when CUDA is actually present: calling
# torch.cuda.current_device() on a CPU-only host raises and would abort the
# whole app before the interface is ever built.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA device: none detected")
def _measure_text(draw, text, font):
    """Return the (width, height) extent of *text* drawn with *font*."""
    if hasattr(draw, "textbbox"):
        # ImageDraw.textsize() no longer exists in recent Pillow releases;
        # textbbox() is the replacement. Measured from the origin, the
        # right/bottom edges are used as the extent, which is intended to keep
        # the layout close to what textsize() produced.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    # Older Pillow without textbbox().
    return draw.textsize(text, font=font)


def image_to_template(generated_image, logo, button_text, punchline, theme_color):
    """Compose an ad template from a generated image, a logo and texts.

    The layout is a single centered column on a white 540px-wide canvas:
    logo, generated image (with rounded corners), wrapped punchline and a
    rounded button, framed by a decoration bar on the top and bottom edges.

    Args:
        generated_image: PIL image to feature in the template.
        logo: PIL image of the logo, pasted above the generated image.
        button_text: Text drawn in white inside the button.
        punchline: Text drawn under the image, wrapped at 30 characters.
            May be empty, in which case the button follows the image directly.
        theme_color: Fill color for the decorations, punchline and button.

    Returns:
        A new RGBA PIL image containing the rendered template.
    """
    template_width = 540
    button_font_size = 10
    punchline_font_size = 30
    decoration_height = 10
    margin = 20

    # Wrap punchline text; treat a missing value as "no punchline".
    punchline_lines = textwrap.wrap(punchline or "", width=30)
    n_of_lines_punchline = len(punchline_lines)
    button_text = button_text or ""

    # convert() returns new images, so the caller's inputs are not modified
    # by the in-place thumbnail()/putalpha() calls below.
    generated_image = generated_image.convert("RGBA")
    logo = logo.convert("RGBA")

    # Image shape: half the template width, aspect ratio preserved.
    image_width = template_width // 2
    image_height = image_width * generated_image.height // generated_image.width
    image_shape = (image_width, image_height)

    # Logo shape: a third of the image width, aspect ratio preserved.
    logo_width = image_width // 3
    logo_height = logo_width * logo.height // logo.width
    logo_shape = (logo_width, logo_height)

    # Define fonts
    button_font = ImageFont.truetype("./assets/Montserrat-Bold.ttf", button_font_size)
    punchline_font = ImageFont.truetype(
        "./assets/Montserrat-Bold.ttf", punchline_font_size
    )

    # Button shape
    button_width = template_width // 3
    button_height = button_font_size * 3

    # Template height calculation
    template_height = (
        image_height
        + logo_height
        + button_height
        + n_of_lines_punchline * punchline_font_size
        + (5 * margin)
        + (2 * decoration_height)
    )

    # Positions for the centered layout
    logo_pos = ((template_width - logo_width) // 2, margin + decoration_height)
    image_pos = (
        (template_width - image_width) // 2,
        logo_pos[1] + logo_height + margin,
    )

    # Decoration bars straddle the top and bottom edges, so only half of each
    # rounded rectangle is visible on the canvas.
    top_decoration_pos = [
        margin,
        -decoration_height // 2,
        template_width - margin,
        decoration_height // 2,
    ]
    bottom_decoration_pos = [
        margin,
        template_height - decoration_height // 2,
        template_width - margin,
        template_height + decoration_height // 2,
    ]

    # Generate components. Image.ANTIALIAS was removed in Pillow 10;
    # Image.LANCZOS is the same filter under its current name.
    generated_image.thumbnail(image_shape, Image.LANCZOS)
    logo.thumbnail(logo_shape, Image.LANCZOS)
    background = Image.new("RGBA", (template_width, template_height), "WHITE")

    # Round the corners of the generated image via an alpha mask.
    mask = Image.new("L", generated_image.size, 0)
    draw = ImageDraw.Draw(mask)
    draw.rounded_rectangle((0, 0) + generated_image.size, 20, fill=255)
    generated_image.putalpha(mask)

    # Paste the logo and the generated image onto the background, using each
    # image's own alpha channel as the paste mask.
    background.paste(logo, logo_pos, logo)
    background.paste(generated_image, image_pos, generated_image)

    # Draw the decorations, punchline, and button
    draw = ImageDraw.Draw(background)

    # Decorations on top and bottom
    draw.rounded_rectangle(bottom_decoration_pos, radius=20, fill=theme_color)
    draw.rounded_rectangle(top_decoration_pos, radius=20, fill=theme_color)

    # Punchline text: a running y cursor stacks the lines and stays valid
    # even when there are no lines at all.
    text_y = image_pos[1] + generated_image.height + margin
    for line in punchline_lines:
        line_width, line_height = _measure_text(draw, line, punchline_font)
        draw.text(
            ((template_width - line_width) // 2, text_y),
            line,
            fill=theme_color,
            font=punchline_font,
        )
        text_y += line_height

    # Button with rounded corners, placed one margin below the last line.
    button_text_width, button_text_height = _measure_text(
        draw, button_text, button_font
    )
    button_top = text_y + margin
    button_shape = [
        ((template_width - button_width) // 2, button_top),
        ((template_width + button_width) // 2, button_top + button_height),
    ]
    draw.rounded_rectangle(button_shape, radius=20, fill=theme_color)

    # Button text, centered inside the button
    button_text_pos = (
        (template_width - button_text_width) // 2,
        button_top + (button_height - button_text_height) // 2,
    )
    draw.text(button_text_pos, button_text, fill="white", font=button_font)

    return background
def generate_template(
    initial_image, logo, prompt, button_text, punchline, image_color, theme_color
):
    """Generate an image from the inputs and render it into an ad template.

    A Kandinsky 2.2 image-to-image pipeline is loaded from
    ``./models/kandinsky-2-2-decoder`` for every call, run once on
    ``initial_image`` and released again before returning.

    Args:
        initial_image: PIL image used as the starting point for generation.
        logo: PIL image of the logo placed in the template.
        prompt: Text prompt; the requested image color is appended to it.
        button_text: Text for the template's button.
        punchline: Text drawn under the generated image.
        image_color: Color the prompt asks the model to include.
        theme_color: Color used for the template's decorations and texts.

    Returns:
        The rendered template image produced by ``image_to_template``.
    """
    pipeline = AutoPipelineForImage2Image.from_pretrained(
        "./models/kandinsky-2-2-decoder",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    try:
        pipeline.enable_model_cpu_offload()

        prompt = f"{prompt}, include the color {image_color}"
        negative_prompt = "low quality, bad quality, blurry, unprofessional"

        generated_image = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=initial_image,
            height=256,
            width=256,
        ).images[0]

        template_image = image_to_template(
            generated_image, logo, button_text, punchline, theme_color
        )
    finally:
        # Free CPU and GPU memory even when generation or rendering fails;
        # otherwise a failed request would leave the model's memory held.
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()

    return template_image
# Set up the Gradio interface. The order of `inputs` must match the positional
# parameters of generate_template, and each example row follows the same order.
iface = gr.Interface(
    fn=generate_template,
    inputs=[
        gr.Image(type="pil", label="Initial Image"),
        gr.Image(type="pil", label="Logo"),
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Button Text"),
        gr.Textbox(label="Punchline"),
        gr.ColorPicker(label="Image Color"),
        gr.ColorPicker(label="Theme Color"),
    ],
    outputs=[gr.Image(type="pil")],
    title="Ad Template Generation Using Diffusion Models Demo",
    description="Generate ad template based on your inputs using a trained model.",
    concurrency_limit=2,
    examples=[
        [
            "./assets/city_image.jpg",  # Initial Image
            "./assets/logo.png",  # Logo
            "Big bank building finance",  # Prompt
            "Discover More!",  # Button Text
            "We Maximize Risk-Adjusted Returns for Our Customers",  # Punchline
            "#00FF00",  # Image Color
            "#0000FF",  # Theme Color
        ]
    ],
)

# Run the interface
iface.launch(debug=True)