# Spaces: Runtime error
# (page-header residue captured with the file; not part of the module)
import os | |
import numpy as np | |
import PIL | |
from PIL import Image | |
from torch.utils.data import Dataset | |
import random | |
from PIL import Image | |
from PIL import ImageDraw | |
from PIL import ImageFont | |
from .font_list import font_list | |
from .font_list_single import font_list_single | |
import warnings | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | |
import torchvision.transforms as T | |
# Hex colour strings; Rasterizer.__getitem__ picks one at random as the text
# fill colour when it renders the full word.
color_multi_font = [
    '#a3001b',
    '#175c7a',
    '#5cd6ce',
    '#d1440c',
    '#1f1775',
    '#9e1e46',
    '#9b100d',
    '#d3760c',
    '#e8bb06',
    '#5135ad',
    '#366993',
    '#470e7c',
    '#070707',
    '#053d22',
    '#7a354c',
    '#7c0b03',
]
# Uppercase Latin letters A-Z, in alphabetical order.
alphabets = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L',
    'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
]
class Rasterizer(Dataset):
    """Dataset pairing a text/glyph image with a style reference image.

    Each item is a dict with the keys ``"base"`` (text image + caption),
    ``"style"`` (style image + caption), ``"font"``, ``"number"``,
    ``"epochs"`` and ``"cond"``. Images are returned as float32 arrays
    scaled to ``[-1, 1]`` after conversion to RGB and resizing to
    ``img_size`` x ``img_size``.

    Args:
        text: Text to render, or the glyph sub-directory name to load from.
        style_word: Style caption; its first space-separated token selects
            the ``data_style/<token>`` directory of style images.
        data_path: Stored, but overwritten by :meth:`load_back` with
            ``data_style/<first token of style_word>``.
        alternate_glyph: Stored on the instance; not read by this class.
        img_size: Side length (pixels) of the returned square images.
        num_samples: Value reported by ``len()``.
        make_black: Stored on the instance; not read by this class.
        one_font: If true, always use the font fixed at construction time
            instead of a random entry of ``font_list`` per item.
        full_word: If true, render ``text`` with the font at item time;
            otherwise load a pre-rendered PNG from ``data_fonts``.
        font_name: Font to pin; when ``None`` one is drawn at random from
            ``font_list_single``.
        just_use_style: If true, ``"cond"`` is the style word alone.
        use_alt: Stored on the instance; not read by this class.
    """

    def __init__(self,
                 text="R",
                 style_word="DRAGON",
                 data_path="data/DRAGON",
                 alternate_glyph=None,
                 img_size=256,
                 num_samples=1001,
                 make_black=False,
                 one_font=True,
                 full_word=False,
                 font_name=None,
                 just_use_style=False,
                 use_alt=False):
        self.text = text
        self.img_size = img_size
        self.interpolation = PIL.Image.BILINEAR
        self.num_samples = num_samples
        self.style_word = style_word
        self.data_path = data_path
        self.alternate_glyph = alternate_glyph
        self.dict = {}
        self.data_img = []
        self.alt_gly = []
        self.make_black = make_black
        self.one_font = one_font
        self.classes = []
        if font_name is not None:
            self.fontname = font_name
        else:
            self.fontname = random.choice(font_list_single)
        self.full_word = full_word
        self.just_use_style = just_use_style
        self.use_alt = use_alt
        self.load_back()

    def load_back(self):
        """Collect the style image paths for the current style word.

        Sets ``self.data_path`` to ``data_style/<first token of style_word>``
        and fills ``self.data_img`` with the path of every entry in it.
        """
        style_only = self.style_word.split(" ")[0]
        self.data_path = f"data_style/{style_only}"
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same file on every run and machine.
        self.data_img = [
            os.path.join(self.data_path, name)
            for name in sorted(os.listdir(self.data_path))
        ]

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def getSize(self, txt, font):
        """Return ``(width, height)`` in pixels needed to draw ``txt``.

        ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is
        used when available and ``textsize`` only on older Pillow versions.
        """
        testImg = Image.new('RGB', (1, 1))
        testDraw = ImageDraw.Draw(testImg)
        if hasattr(testDraw, "textbbox"):
            left, _top, right, bottom = testDraw.textbbox((0, 0), txt, font=font)
            # NOTE(review): ink width plus height measured from the drawing
            # origin is intended to match what the legacy textsize reported;
            # confirm against the Pillow version in use. ceil + int keeps the
            # result usable as an image size if the bbox is fractional.
            return int(np.ceil(right - left)), int(np.ceil(bottom))
        return testDraw.textsize(txt, font)

    def _to_model_array(self, image):
        """Convert a PIL image to an RGB float32 array scaled to [-1, 1]."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = image.resize((self.img_size, self.img_size),
                             resample=self.interpolation)
        pixels = np.array(image).astype(np.uint8)
        return (pixels / 127.5 - 1.0).astype(np.float32)

    def _load_image(self, path):
        """Open ``path`` and return it as a normalised array, closing the file."""
        with Image.open(path) as image:
            return self._to_model_array(image)

    def _style_index(self, i):
        """Map sample index ``i`` onto the list of style images (wrapping)."""
        count = len(self.data_img)
        if count == 0:
            raise IndexError(f"no style images found in {self.data_path!r}")
        # Wrap over *all* images; the previous `count - 1` modulus skipped the
        # last image and divided by zero when only one image was present.
        return i % count

    def _render_text(self, fontname):
        """Draw ``self.text`` in a random palette colour on a white canvas."""
        font = ImageFont.truetype(fontname, 256)
        width, height = self.getSize(self.text, font)
        # Canvas leaves a 16 px horizontal and 64 px vertical margin.
        image = Image.new('RGB', (width + 32, height + 128), "white")
        d = ImageDraw.Draw(image)
        colorFont = random.choice(color_multi_font)
        d.text((16, 64), self.text, fill=colorFont, font=font)
        return self._to_model_array(image)

    def _load_glyph(self, fontname):
        """Load one of the pre-rendered glyph PNGs (variants 0..15) for the font."""
        rand_img = str(random.randint(0, 15))
        fontname_t = fontname.split(".")[0]
        dir_font = f"data_fonts/{fontname_t}/{self.text}/{rand_img}" + ".png"
        return self._load_image(dir_font)

    def __getitem__(self, i):
        """Build the training batch dict for sample index ``i``.

        Raises:
            IndexError: if no style images were found by :meth:`load_back`.
        """
        # A random font is always drawn first (even when one_font overrides
        # it) so the sequence of random draws is unchanged from before.
        fontname = random.choice(font_list)
        if self.one_font:
            fontname = self.fontname

        if self.full_word:
            image_text = self._render_text(fontname)
        else:
            image_text = self._load_glyph(fontname)
        output = {"image": image_text, "caption": self.text}

        ##################################################################################
        style_path = self.data_img[self._style_index(i)]
        output2 = {
            "image": self._load_image(style_path),
            "caption": self.style_word,
        }

        batch = {}
        batch["base"] = output
        batch["style"] = output2
        batch["font"] = fontname
        batch["number"] = 0
        batch["epochs"] = 800*2 if self.one_font else 1000*2
        if self.full_word or self.just_use_style:
            batch["cond"] = self.style_word
        else:
            batch["cond"] = self.style_word + " " + self.text
        return batch