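"""Dataset loaders and collaters for fairseq-based text recognition training.

Covers SROIE Task 2 receipts (with on-the-fly cropping of annotated text
regions), a generic tab-separated synthetic text recognition format, and
Receipt53K.
"""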
import glob
import logging
import os
import random
import torch
from fairseq.data import FairseqDataset, data_utils
from natsort import natsorted
from PIL import Image
from tqdm import tqdm
logger = logging.getLogger(__name__)
def default_collater(target_dict, samples, dataset=None):
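    """Collate ``samples`` into a fairseq-style batch.

    Samples that failed to load (``None``) are replaced by redrawing random
    items from ``dataset``; without a ``dataset`` to redraw from, the whole
    batch is dropped.
    """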
if not samples:
return None
    if any(sample is None for sample in samples):
if not dataset:
return None
len_batch = len(samples)
while True:
            samples.append(dataset[random.randrange(len(dataset))])
            samples = list(filter(lambda x: x is not None, samples))
if len(samples) == len_batch:
break
indices = []
    imgs = []  # bs, c, h, w
target_samples = []
target_ntokens = 0
for sample in samples:
index = sample['id']
indices.append(index)
imgs.append(sample['tfm_img'])
target_samples.append(sample['label_ids'].long())
target_ntokens += len(sample['label_ids'])
num_sentences = len(samples)
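    # target_batch keeps <eos> at the end of each sequence; rotate_batch moves
    # it to the front, the fairseq convention for prev_output_tokens used in
    # teacher forcing.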
target_batch = data_utils.collate_tokens(target_samples,
pad_idx=target_dict.pad(),
eos_idx=target_dict.eos(),
move_eos_to_beginning=False)
rotate_batch = data_utils.collate_tokens(target_samples,
pad_idx=target_dict.pad(),
eos_idx=target_dict.eos(),
move_eos_to_beginning=True)
indices = torch.tensor(indices, dtype=torch.long)
imgs = torch.stack(imgs, dim=0)
return {
'id': indices,
'net_input': {
'imgs': imgs,
'prev_output_tokens': rotate_batch
},
'ntokens': target_ntokens,
'nsentences': num_sentences,
'target': target_batch
}
def read_txt_and_tokenize(txt_path: str, bpe, target_dict):
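    """Parse one SROIE-style annotation file.

    Each line reads ``x1,y1,x2,y2,x3,y3,x4,y4,transcription``; the
    transcription may itself contain commas, hence ``maxsplit=8`` below.
    """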
annotations = []
with open(txt_path, 'r', encoding='utf8') as fp:
        for line in fp:
line = line.rstrip()
if not line:
continue
line_split = line.split(',', maxsplit=8)
quadrangle = list(map(int, line_split[:8]))
content = line_split[-1]
if bpe:
encoded_str = bpe.encode(content)
else:
encoded_str = content
xs = [quadrangle[i] for i in range(0, 8, 2)]
ys = [quadrangle[i] for i in range(1, 8, 2)]
bbox = [min(xs), min(ys), max(xs), max(ys)]
annotations.append({'bbox': bbox, 'encoded_str': encoded_str, 'category_id': 0, 'segmentation': [quadrangle]}) # 0 for text, 1 for background
return annotations
def SROIETask2(root_dir: str, bpe, target_dict, crop_img_output_dir=None):
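    """Collect ``*.jpg``/``*.txt`` pairs under ``root_dir``.

    Returns a list of full-image records and a list of cropped text-region
    records; crops are optionally saved to ``crop_img_output_dir``.
    """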
data = []
img_id = -1
crop_data = []
crop_img_id = -1
image_paths = natsorted(list(glob.glob(os.path.join(root_dir, '*.jpg'))))
for jpg_path in tqdm(image_paths):
im = Image.open(jpg_path).convert('RGB')
img_w, img_h = im.size
img_id += 1
txt_path = jpg_path.replace('.jpg', '.txt')
annotations = read_txt_and_tokenize(txt_path, bpe, target_dict)
        img_dict = {'file_name': jpg_path, 'width': img_w, 'height': img_h, 'image_id': img_id, 'annotations': annotations}
data.append(img_dict)
for ann in annotations:
crop_w = ann['bbox'][2] - ann['bbox'][0]
crop_h = ann['bbox'][3] - ann['bbox'][1]
if not (crop_w > 0 and crop_h > 0):
                logger.warning('Skipping crop in {}: bbox has non-positive area.'.format(os.path.basename(jpg_path)))
continue
crop_img_id += 1
crop_im = im.crop(ann['bbox'])
if crop_img_output_dir:
crop_im.save(os.path.join(crop_img_output_dir, '{:d}.jpg'.format(crop_img_id)))
            crop_img_dict = {'img': crop_im, 'file_name': jpg_path, 'width': crop_w, 'height': crop_h, 'image_id': crop_img_id, 'encoded_str': ann['encoded_str']}
crop_data.append(crop_img_dict)
return data, crop_data
class SROIETextRecognitionDataset(FairseqDataset):
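    """Recognition dataset over text-region crops from SROIE Task 2 receipts."""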
def __init__(self, root_dir, tfm, bpe_parser, target_dict, crop_img_output_dir=None):
self.root_dir = root_dir
self.tfm = tfm
self.target_dict = target_dict
# self.bpe_parser = bpe_parser
self.ori_data, self.data = SROIETask2(root_dir, bpe_parser, target_dict, crop_img_output_dir)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_dict = self.data[idx]
image = img_dict['img']
encoded_str = img_dict['encoded_str']
input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)
        tfm_img = self.tfm(image)  # c, h, w
return {'id': idx, 'tfm_img': tfm_img, 'label_ids': input_ids}
def size(self, idx):
img_dict = self.data[idx]
encoded_str = img_dict['encoded_str']
input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)
return len(input_ids)
def num_tokens(self, idx):
return self.size(idx)
def collater(self, samples):
return default_collater(self.target_dict, samples)
def STR(gt_path, bpe_parser):
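    """Load a synthetic-text ground-truth file.

    Each line holds a tab-separated image filename and transcription; images
    are expected under ``<root_dir>/image/``.
    """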
root_dir = os.path.dirname(gt_path)
data = []
img_id = 0
with open(gt_path, 'r') as fp:
        for line in tqdm(fp.readlines(), desc='Loading STR'):
            line = line.rstrip()
            if not line:
                continue
            img_file, text = line.split('\t', 1)
img_path = os.path.join(root_dir, 'image', img_file)
if not bpe_parser:
encoded_str = text
else:
encoded_str = bpe_parser.encode(text)
            data.append({'img_path': img_path, 'image_id': img_id, 'text': text, 'encoded_str': encoded_str})
img_id += 1
return data
class SyntheticTextRecognitionDataset(FairseqDataset):
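    """Recognition dataset over a synthetic-text ground-truth file (see ``STR``)."""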
def __init__(self, gt_path, tfm, bpe_parser, target_dict):
self.gt_path = gt_path
self.tfm = tfm
self.target_dict = target_dict
self.data = STR(gt_path, bpe_parser)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_dict = self.data[idx]
image = Image.open(img_dict['img_path']).convert('RGB')
encoded_str = img_dict['encoded_str']
input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)
        tfm_img = self.tfm(image)  # c, h, w
return {'id': idx, 'tfm_img': tfm_img, 'label_ids': input_ids}
def size(self, idx):
img_dict = self.data[idx]
encoded_str = img_dict['encoded_str']
input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)
return len(input_ids)
def num_tokens(self, idx):
return self.size(idx)
def collater(self, samples):
return default_collater(self.target_dict, samples)
def Receipt53K(gt_path):
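    """Load the Receipt53K ground truth: tab-separated relative image path and
    transcription per line."""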
root_dir = os.path.dirname(gt_path)
data = []
with open(gt_path, 'r', encoding='utf8') as fp:
        for line in tqdm(fp.readlines(), desc='Loading Receipt53K'):
            line = line.rstrip()
            if not line:
                continue
            img_file, text = line.split('\t', 1)
img_path = os.path.join(root_dir, img_file)
            data.append({'img_path': img_path, 'text': text})
return data
class Receipt53KDataset(FairseqDataset):
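    """Recognition dataset over Receipt53K; labels are BPE-encoded on the fly
    and unreadable images yield ``None`` (patched up by the collater)."""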
def __init__(self, gt_path, tfm, bpe_parser, target_dict):
self.gt_path = gt_path
self.tfm = tfm
self.target_dict = target_dict
self.bpe_parser = bpe_parser
self.data = Receipt53K(gt_path)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
img_dict = self.data[idx]
try:
image = Image.open(img_dict['img_path']).convert('RGB')
except Exception as e:
            logger.warning('Failed to load image {}: {}'.format(img_dict['img_path'], str(e)))
return None
encoded_str = self.bpe_parser.encode(img_dict['text'])
input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)
        tfm_img = self.tfm(image)  # c, h, w
        return {'id': idx, 'tfm_img': tfm_img, 'label_ids': input_ids}
def size(self, idx):
img_dict = self.data[idx]
return len(img_dict['text'])
# item = self[idx]
# return len(item['label_ids'])
def num_tokens(self, idx):
return self.size(idx)
def collater(self, samples):
return default_collater(self.target_dict, samples)
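

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline. The dictionary
    # file, data directory, and transform below are hypothetical placeholders;
    # it assumes torchvision is installed and a fairseq Dictionary matching
    # the label vocabulary exists at the given path.
    from fairseq.data import Dictionary
    from torchvision import transforms

    tfm = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),  # c, h, w float tensor in [0, 1]
    ])
    target_dict = Dictionary.load('dict.txt')  # hypothetical path
    dataset = SROIETextRecognitionDataset(
        root_dir='data/sroie_task2',  # hypothetical path
        tfm=tfm,
        bpe_parser=None,  # keep raw transcriptions, no BPE
        target_dict=target_dict,
    )
    batch = dataset.collater([dataset[i] for i in range(2)])
    print(batch['net_input']['imgs'].shape, batch['target'].shape)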