# File size: 8,853 Bytes (commit e8dca02)
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link:
# Code Link:
# Copyright (c) Microsoft Corporation.
# This file aims to predict the layout of keywords in user prompts.
# ------------------------------------------
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import re
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from PIL import Image, ImageDraw, ImageFont
from util import get_width, get_key_words, adjust_overlap_box, shrink_box, adjust_font_size, alphabet_dic
from model.layout_transformer import LayoutTransformer, TextConditioner
from termcolor import colored
# Layout transformer that predicts the keyword boxes; built once at import time,
# moved to the GPU and switched to inference mode (nn.Module.cuda/eval return self).
model = LayoutTransformer().cuda()
model.eval()
# Encoder that embeds the prompt for the layout transformer, and the CLIP
# tokenizer that process_caption uses to split the prompt into tokens.
text_encoder = TextConditioner().cuda()
text_encoder.eval()
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
def process_caption(font_path, caption, keywords):
# remove punctuations. please remove this statement if you want to paint punctuations
caption = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039])", " ", caption)
# tokenize it into ids and get length
caption_words = tokenizer([caption], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
caption_words_ids = caption_words['input_ids'] # (1, 77)
length = caption_words['length'] # (1, )
# convert id to words
words = tokenizer.convert_ids_to_tokens(caption_words_ids.view(-1).tolist())
words = [i.replace('</w>', '') for i in words]
words_valid = words[:int(length)]
# store the box coordinates and state of each token
info_array = np.zeros((77,5)) # (77, 5)
# split the caption into words and convert them into lower case
caption_split = caption.split()
caption_split = [i.lower() for i in caption_split]
start_dic = {} # get the start index of each word
state_list = [] # 0: start, 1: middle, 2: special token
word_match_list = [] # the index of the word in the caption
current_caption_index = 0
current_match = ''
for i in range(length):
# the first and last token are special tokens
if i == 0 or i == length-1:
if current_match == '':
start_dic[current_caption_index] = i
current_match += words_valid[i]
if current_match == caption_split[current_caption_index]:
current_match = ''
current_caption_index += 1
while len(state_list) < 77:
while len(word_match_list) < 77:
length_list = []
width_list =[]
for i in range(len(word_match_list)):
if word_match_list[i] == 127:
width_list.append(get_width(font_path, caption.split()[word_match_list[i]]))
while len(length_list) < 77:
length_list = torch.Tensor(length_list).long() # (77, )
width_list = torch.Tensor(width_list).long() # (77, )
boxes = []
duplicate_dict = {} # some words may appear more than once
for keyword in keywords:
keyword = keyword.lower()
if keyword in caption_split:
if keyword not in duplicate_dict:
duplicate_dict[keyword] = caption_split.index(keyword)
index = caption_split.index(keyword)
if duplicate_dict[keyword]+1 < len(caption_split) and keyword in caption_split[duplicate_dict[keyword]+1:]:
index = duplicate_dict[keyword] + caption_split[duplicate_dict[keyword]+1:].index(keyword)
duplicate_dict[keyword] = index
index = caption_split.index(keyword)
index = start_dic[index]
info_array[index][0] = 1
box = [0,0,0,0]
info_array[index][1:] = box
boxes_length = len(boxes)
if boxes_length > 8:
boxes = boxes[:8]
while len(boxes) < 8:
return caption, length_list, width_list, torch.from_numpy(info_array), words, torch.Tensor(state_list).long(), torch.Tensor(word_match_list).long(), torch.Tensor(boxes), boxes_length
def get_layout_from_prompt(args):
# prompt = args.prompt
font_path = args.font_path
keywords = get_key_words(args.prompt)
print(f'{colored("[!]", "red")} Detected keywords: {keywords} from prompt {args.prompt}')
text_embedding, mask = text_encoder(args.prompt) # (1, 77 768) / (1, 77)
# process all relevant info
caption, length_list, width_list, target, words, state_list, word_match_list, boxes, boxes_length = process_caption(font_path, args.prompt, keywords)
target = target.cuda().unsqueeze(0) # (77, 5)
width_list = width_list.cuda().unsqueeze(0) # (77, )
length_list = length_list.cuda().unsqueeze(0) # (77, )
state_list = state_list.cuda().unsqueeze(0) # (77, )
word_match_list = word_match_list.cuda().unsqueeze(0) # (77, )
padding = torch.zeros(1, 1, 4).cuda()
boxes = boxes.unsqueeze(0).cuda()
right_shifted_boxes =[padding, boxes[:,0:-1,:]],1) # (1, 8, 4)
# inference
return_boxes= []
with torch.no_grad():
for box_index in range(boxes_length):
if box_index == 0:
encoder_embedding = None
output, encoder_embedding = model(text_embedding, length_list, width_list, mask, state_list, word_match_list, target, right_shifted_boxes, train=False, encoder_embedding=encoder_embedding)
output = torch.clamp(output, min=0, max=1) # (1, 8, 4)
# add overlap detection
output = adjust_overlap_box(output, box_index) # (1, 8, 4)
right_shifted_boxes[:,box_index+1,:] = output[:,box_index,:]
xmin, ymin, xmax, ymax = output[0, box_index, :].tolist()
return_boxes.append([xmin, ymin, xmax, ymax])
# print the location of keywords
for index, keyword in enumerate(keywords):
x_min = int(return_boxes[index][0] * 512)
y_min = int(return_boxes[index][1] * 512)
x_max = int(return_boxes[index][2] * 512)
y_max = int(return_boxes[index][3] * 512)
# paint the layout
render_image ='RGB', (512, 512), (255, 255, 255))
draw = ImageDraw.Draw(render_image)
segmentation_mask ="L", (512,512), 0)
segmentation_mask_draw = ImageDraw.Draw(segmentation_mask)
for index, box in enumerate(return_boxes):
box = [int(i*512) for i in box]
xmin, ymin, xmax, ymax = box
width = xmax - xmin
height = ymax - ymin
text = keywords[index]
font_size = adjust_font_size(args, width, height, draw, text)
font = ImageFont.truetype(args.font_path, font_size)
# draw.rectangle([xmin, ymin, xmax,ymax], outline=(255,0,0))
draw.text((xmin, ymin), text, font=font, fill=(0, 0, 0))
boxes = []
for i, char in enumerate(text):
# paint character-level segmentation masks
bottom_1 = font.getsize(text[i])[1]
right, bottom_2 = font.getsize(text[:i+1])
bottom = bottom_1 if bottom_1 < bottom_2 else bottom_2
width, height = font.getmask(char).size
right += xmin
bottom += ymin
top = bottom - height
left = right - width
char_box = (left, top, right, bottom)
char_index = alphabet_dic[char]
segmentation_mask_draw.rectangle(shrink_box(char_box, scale_factor = 0.9), fill=char_index)
print(f'{colored("[√]", "green")} Layout is successfully generated')
return render_image, segmentation_mask