FridaKahlosArtCenter
final commit
ce9973a
import torch
from diffusers import AutoPipelineForImage2Image
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
import gradio as gr
import gc
import textwrap
# log gpu availability
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
def image_to_template(generated_image, logo, button_text, punchline, theme_color):
template_width = 540
button_font_size = 10
punchline_font_size = 30
decoration_height = 10
margin = 20
# wrap punchline text
punchline = textwrap.wrap(punchline, width=30)
n_of_lines_punchline = len(punchline)
generated_image = generated_image.convert("RGBA")
logo = logo.convert("RGBA")
# image shape
image_width = template_width // 2
image_height = image_width * generated_image.height // generated_image.width
image_shape = (image_width, image_height)
# logo shape
logo_width = image_width // 3
logo_height = logo_width * logo.height // logo.width
logo_shape = (logo_width, logo_height)
# Define fonts
button_font = ImageFont.truetype("./assets/Montserrat-Bold.ttf", button_font_size)
punchline_font = ImageFont.truetype("./assets/Montserrat-Bold.ttf", punchline_font_size)
# button shape
button_width = template_width // 3
button_height = button_font_size * 3
# template height calculation
template_height = (
image_height
+ logo_height
+ button_height
+ n_of_lines_punchline * punchline_font_size
+ (5 * margin)
+ (2 * decoration_height)
)
# Calculate positions for the centered layout
logo_pos = ((template_width - logo_width) // 2, margin + decoration_height)
image_pos = (
(template_width - image_width) // 2,
logo_pos[1] + logo_height + margin,
)
# Decoration positions
top_decoration_pos = [
margin,
-decoration_height // 2,
template_width - margin,
decoration_height // 2,
]
bottom_decoration_pos = [
margin,
template_height - decoration_height // 2,
template_width - margin,
template_height + decoration_height // 2,
]
# Generate Components
generated_image.thumbnail(image_shape, Image.ANTIALIAS)
logo.thumbnail(logo_shape, Image.ANTIALIAS)
background = Image.new("RGBA", (template_width, template_height), "WHITE")
# round the corners of generated image
mask = Image.new("L", generated_image.size, 0)
draw = ImageDraw.Draw(mask)
draw.rounded_rectangle((0, 0) + generated_image.size, 20, fill=255)
generated_image.putalpha(mask)
# Paste the logo and the generated image onto the background
background.paste(logo, logo_pos, logo)
background.paste(generated_image, image_pos, generated_image)
# Draw the decorations, punchline, and button
draw = ImageDraw.Draw(background)
# Decorations on top and bottom
draw.rounded_rectangle(bottom_decoration_pos, radius=20, fill=theme_color)
draw.rounded_rectangle(top_decoration_pos, radius=20, fill=theme_color)
# Punchline text
text_heights = []
for line in punchline:
text_width, text_height = draw.textsize(line, font=punchline_font)
punchline_pos = (
(template_width - text_width) // 2,
image_pos[1] + generated_image.height + margin + sum(text_heights),
)
draw.text(punchline_pos, line, fill=theme_color, font=punchline_font)
text_heights.append(text_height)
# Button with rounded corners
button_text_width, button_text_height = draw.textsize(button_text, font=button_font)
button_shape = [
((template_width - button_width) // 2, punchline_pos[1] + text_height + margin),
(
(template_width + button_width) // 2,
punchline_pos[1] + text_height + margin + button_height,
),
]
draw.rounded_rectangle(button_shape, radius=20, fill=theme_color)
# Button text
button_text_pos = (
(template_width - button_text_width) // 2,
button_shape[0][1] + (button_height - button_text_height) // 2,
)
draw.text(button_text_pos, button_text, fill="white", font=button_font)
return background
def generate_template(
initial_image, logo, prompt, button_text, punchline, image_color, theme_color
):
pipeline = AutoPipelineForImage2Image.from_pretrained(
"./models/kandinsky-2-2-decoder",
torch_dtype=torch.float16,
use_safetensors=True,
)
# pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
pipeline.enable_model_cpu_offload()
prompt = f"{prompt}, include the color {image_color}"
negative_prompt = "low quality, bad quality, blurry, unprofessional"
generated_image = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=initial_image,
height=256,
width=256,
).images[0]
template_image = image_to_template(
generated_image, logo, button_text, punchline, theme_color
)
# free cpu and gpu memory
del pipeline
gc.collect()
torch.cuda.empty_cache()
return template_image
# Set up Gradio interface
iface = gr.Interface(
fn=generate_template,
inputs=[
gr.Image(type="pil", label="Initial Image"),
gr.Image(type="pil", label="Logo"),
gr.Textbox(label="Prompt"),
gr.Textbox(label="Button Text"),
gr.Textbox(label="Punchline"),
gr.ColorPicker(label="Image Color"),
gr.ColorPicker(label="Theme Color"),
],
outputs=[gr.Image(type="pil")],
title="Ad Template Generation Using Diffusion Models Demo",
description="Generate ad template based on your inputs using a trained model.",
concurrency_limit=2,
examples=[
[
"./assets/city_image.jpg", # Initial Image
"./assets/logo.png", # Logo
"Big bank building finance", # Prompt
"Discover More!", # Button Text
"We Maximize Risk-Adusted Returns for Our Customers", # Punchline
"#00FF00", # Image Color
"#0000FF", # Theme Color
]
],
)
# Run the interface
iface.launch(debug=True)