# ------------------------------------------ | |
# TextDiffuser: Diffusion Models as Text Painters | |
# Paper Link: | |
# Code Link: | |
# Copyright (c) Microsoft Corporation. | |
# This file aims to predict the layout of keywords in user prompts. | |
# ------------------------------------------ | |
import warnings | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | |
import re | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from transformers import CLIPTokenizer | |
from PIL import Image, ImageDraw, ImageFont | |
from util import get_width, get_key_words, adjust_overlap_box, shrink_box, adjust_font_size, alphabet_dic | |
from model.layout_transformer import LayoutTransformer, TextConditioner | |
from termcolor import colored | |
# Layout transformer: instantiate, move to the GPU, switch to inference mode,
# then restore the pretrained weights from the checkpoint file.
model = LayoutTransformer()
model = model.cuda()
model = model.eval()
model.load_state_dict(
    torch.load('textdiffuser-ckpt/layout_transformer.pth')
)

# Text encoder (GPU, inference mode) and the CLIP tokenizer used to split
# prompts into tokens.
text_encoder = TextConditioner()
text_encoder = text_encoder.cuda()
text_encoder = text_encoder.eval()
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
def process_caption(font_path, caption, keywords):
    """Tokenize a prompt and build the per-token inputs for layout prediction.

    Args:
        font_path: Path of the font file; forwarded to ``get_width`` to
            measure each word of the caption.
        caption: The user prompt. Every character that is not an ASCII
            letter or digit is replaced by a space before tokenization.
        keywords: Words of the prompt that should receive a box. Matching
            against the caption is case-insensitive; a keyword listed
            several times is matched to successive occurrences.

    Returns:
        A 9-tuple ``(caption, length_list, width_list, info_array, words,
        state_list, word_match_list, boxes, boxes_length)``:

        * ``caption``: the prompt after punctuation removal.
        * ``length_list``: long tensor (77,), character count of the word
          each token belongs to (0 for special / padding positions).
        * ``width_list``: long tensor (77,), ``get_width`` of that word
          (0 for special / padding positions).
        * ``info_array``: tensor (77, 5); column 0 is 1 at the first token
          of every located keyword, columns 1-4 hold its (all-zero) box.
        * ``words``: the 77 token strings with ``</w>`` removed.
        * ``state_list``: long tensor (77,); 0 = first token of a word,
          1 = continuation token, 2 = first/last special token,
          127 = padding.
        * ``word_match_list``: long tensor (77,), index of the caption word
          each token belongs to (127 for special / padding positions).
        * ``boxes``: float tensor (8, 4) of zero-initialised boxes.
        * ``boxes_length``: number of keywords located, at most 8.
    """
    max_tokens = 77  # sequence length requested from the tokenizer
    pad_value = 127  # filler for special-token and padding positions
    max_boxes = 8    # number of box rows returned

    # remove punctuations. please remove this statement if you want to paint punctuations
    caption = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039])", " ", caption)

    # tokenize it into ids and get length
    caption_words = tokenizer(
        [caption],
        truncation=True,
        max_length=max_tokens,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    caption_words_ids = caption_words['input_ids']  # (1, 77)
    num_tokens = int(caption_words['length'])  # 'length' is a (1, ) tensor

    # convert id to words
    words = tokenizer.convert_ids_to_tokens(caption_words_ids.view(-1).tolist())
    words = [token.replace('</w>', '') for token in words]
    words_valid = words[:num_tokens]

    # store the box coordinates and state of each token
    info_array = np.zeros((max_tokens, 5))  # (77, 5)

    # split the caption once; the lower-cased copy is used for matching,
    # the original-case words for measuring
    original_words = caption.split()
    caption_split = [word.lower() for word in original_words]

    start_dic = {}        # word index -> index of the word's first token
    state_list = []       # 0: start, 1: middle, 2: special token
    word_match_list = []  # the index of the word in the caption
    current_caption_index = 0
    current_match = ''
    for i in range(num_tokens):
        # the first and last token are special tokens
        if i == 0 or i == num_tokens - 1:
            state_list.append(2)
            word_match_list.append(pad_value)
            continue

        if current_match == '':
            state_list.append(0)
            start_dic[current_caption_index] = i
        else:
            state_list.append(1)

        # accumulate sub-word tokens until they spell the current word
        current_match += words_valid[i]
        word_match_list.append(current_caption_index)
        if current_match == caption_split[current_caption_index]:
            current_match = ''
            current_caption_index += 1

    state_list += [pad_value] * (max_tokens - len(state_list))
    word_match_list += [pad_value] * (max_tokens - len(word_match_list))

    # per-token character count and rendered width of the owning word;
    # each word is measured only once even when it spans several tokens
    width_cache = {}
    length_list = []
    width_list = []
    for word_index in word_match_list:
        if word_index == pad_value:
            length_list.append(0)
            width_list.append(0)
            continue
        word = original_words[word_index]
        if word_index not in width_cache:
            width_cache[word_index] = get_width(font_path, word)
        length_list.append(len(word))
        width_list.append(width_cache[word_index])

    length_list = torch.Tensor(length_list).long()  # (77, )
    width_list = torch.Tensor(width_list).long()  # (77, )

    boxes = []
    next_search_start = {}  # keyword -> position after its last match
    for keyword in keywords:
        keyword = keyword.lower()
        try:
            # some words may appear more than once: resume the search
            # right after the previous match of the same keyword
            word_index = caption_split.index(keyword, next_search_start.get(keyword, 0))
        except ValueError:
            # keyword absent from the caption, or no further occurrence
            continue
        next_search_start[keyword] = word_index + 1

        token_index = start_dic.get(word_index)
        if token_index is None:
            # the word was cut off by the 77-token truncation
            continue

        box = [0, 0, 0, 0]
        info_array[token_index][0] = 1
        info_array[token_index][1:] = box
        boxes.append(list(box))

    # report at most as many boxes as the returned tensor has rows
    boxes_length = min(len(boxes), max_boxes)
    boxes = boxes[:max_boxes]
    boxes += [[0, 0, 0, 0] for _ in range(max_boxes - len(boxes))]

    return (
        caption,
        length_list,
        width_list,
        torch.from_numpy(info_array),
        words,
        torch.Tensor(state_list).long(),
        torch.Tensor(word_match_list).long(),
        torch.Tensor(boxes),
        boxes_length,
    )
def get_layout_from_prompt(args):
    """Predict one box per prompt keyword and render the resulting layout.

    Args:
        args: Object providing ``prompt`` and ``font_path``; it is also
            passed unchanged to ``adjust_font_size``.

    Returns:
        ``(render_image, segmentation_mask)``: a 512x512 white RGB image
        with each keyword drawn in black at its predicted box, and a
        512x512 "L" image in which every character's shrunken box is
        filled with ``alphabet_dic[char]``.
    """
    canvas_size = 512

    font_path = args.font_path
    keywords = get_key_words(args.prompt)

    print(f'{colored("[!]", "red")} Detected keywords: {keywords} from prompt {args.prompt}')

    text_embedding, mask = text_encoder(args.prompt)  # (1, 77 768) / (1, 77)

    # process all relevant info
    (_, length_list, width_list, target, _, state_list,
     word_match_list, boxes, boxes_length) = process_caption(font_path, args.prompt, keywords)
    target = target.cuda().unsqueeze(0)  # (1, 77, 5)
    width_list = width_list.cuda().unsqueeze(0)  # (1, 77)
    length_list = length_list.cuda().unsqueeze(0)  # (1, 77)
    state_list = state_list.cuda().unsqueeze(0)  # (1, 77)
    word_match_list = word_match_list.cuda().unsqueeze(0)  # (1, 77)

    # decoder input: the boxes shifted right by one slot behind a zero box
    padding = torch.zeros(1, 1, 4).cuda()
    boxes = boxes.unsqueeze(0).cuda()
    right_shifted_boxes = torch.cat([padding, boxes[:, 0:-1, :]], 1)  # (1, 8, 4)

    # never predict more boxes than there are slots
    num_slots = right_shifted_boxes.shape[1]
    num_boxes = min(boxes_length, num_slots)

    # inference: predict the boxes one at a time, feeding each result back
    return_boxes = []
    encoder_embedding = None  # produced by the first call, reused afterwards
    with torch.no_grad():
        for box_index in range(num_boxes):
            output, encoder_embedding = model(
                text_embedding, length_list, width_list, mask, state_list,
                word_match_list, target, right_shifted_boxes,
                train=False, encoder_embedding=encoder_embedding,
            )
            output = torch.clamp(output, min=0, max=1)  # (1, 8, 4)

            # add overlap detection
            output = adjust_overlap_box(output, box_index)  # (1, 8, 4)

            # the last slot has no successor to feed
            if box_index + 1 < num_slots:
                right_shifted_boxes[:, box_index + 1, :] = output[:, box_index, :]
            xmin, ymin, xmax, ymax = output[0, box_index, :].tolist()
            return_boxes.append([xmin, ymin, xmax, ymax])

    if len(return_boxes) != len(keywords):
        # NOTE(review): boxes are paired with keywords by position, so the
        # labels below may be shifted when a keyword was skipped.
        print(f'{colored("[!]", "red")} Only {len(return_boxes)} of {len(keywords)} keywords received a box')

    # print the location of keywords
    print(f'index\tkeyword\tx_min\ty_min\tx_max\ty_max')
    for index, (keyword, box) in enumerate(zip(keywords, return_boxes)):
        x_min, y_min, x_max, y_max = (int(value * canvas_size) for value in box)
        print(f'{index}\t{keyword}\t{x_min}\t{y_min}\t{x_max}\t{y_max}')

    # paint the layout
    render_image = Image.new('RGB', (canvas_size, canvas_size), (255, 255, 255))
    draw = ImageDraw.Draw(render_image)
    segmentation_mask = Image.new("L", (canvas_size, canvas_size), 0)
    segmentation_mask_draw = ImageDraw.Draw(segmentation_mask)

    for text, box in zip(keywords, return_boxes):
        xmin, ymin, xmax, ymax = [int(value * canvas_size) for value in box]
        box_width = xmax - xmin
        box_height = ymax - ymin

        font_size = adjust_font_size(args, box_width, box_height, draw, text)
        font = ImageFont.truetype(args.font_path, font_size)

        draw.text((xmin, ymin), text, font=font, fill=(0, 0, 0))

        # paint character-level segmentation masks
        # NOTE(review): FreeTypeFont.getsize is deprecated in recent Pillow
        # releases -- confirm the pinned Pillow version still provides it.
        for i, char in enumerate(text):
            # right/bottom edge of the character, measured on the text
            # prefix ending at it and then moved to the box origin
            bottom_1 = font.getsize(char)[1]
            right, bottom_2 = font.getsize(text[:i + 1])
            bottom = min(bottom_1, bottom_2)
            char_width, char_height = font.getmask(char).size
            right += xmin
            bottom += ymin
            top = bottom - char_height
            left = right - char_width

            char_box = (left, top, right, bottom)
            char_index = alphabet_dic[char]
            segmentation_mask_draw.rectangle(shrink_box(char_box, scale_factor=0.9), fill=char_index)

    print(f'{colored("[√]", "green")} Layout is successfully generated')

    return render_image, segmentation_mask