import glob
import logging
import os
import random

import torch
from fairseq.data import FairseqDataset, data_utils
from natsort import natsorted
from PIL import Image
from tqdm import tqdm

logger = logging.getLogger(__name__)

def default_collater(target_dict, samples, dataset=None):
    if not samples:
        return None
    if any(sample is None for sample in samples):
        if not dataset:
            return None
        # Replace failed (None) samples with randomly re-drawn valid ones so
        # the batch keeps its original size.
        len_batch = len(samples)
        while True:
            samples.append(dataset[random.choice(range(len(dataset)))])
            samples = list(filter(lambda x: x is not None, samples))
            if len(samples) == len_batch:
                break

    indices = []
    imgs = []  # bs, c, h, w
    target_samples = []
    target_ntokens = 0

    for sample in samples:
        index = sample['id']
        indices.append(index)
        imgs.append(sample['tfm_img'])
        target_samples.append(sample['label_ids'].long())
        target_ntokens += len(sample['label_ids'])

    num_sentences = len(samples)

    # `target` keeps the labels as-is; `prev_output_tokens` rotates eos to the
    # front, the usual fairseq teacher-forcing input.
    target_batch = data_utils.collate_tokens(target_samples,
                                             pad_idx=target_dict.pad(),
                                             eos_idx=target_dict.eos(),
                                             move_eos_to_beginning=False)
    rotate_batch = data_utils.collate_tokens(target_samples,
                                             pad_idx=target_dict.pad(),
                                             eos_idx=target_dict.eos(),
                                             move_eos_to_beginning=True)

    indices = torch.tensor(indices, dtype=torch.long)
    imgs = torch.stack(imgs, dim=0)

    return {
        'id': indices,
        'net_input': {
            'imgs': imgs,
            'prev_output_tokens': rotate_batch
        },
        'ntokens': target_ntokens,
        'nsentences': num_sentences,
        'target': target_batch
    }
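
# Minimal usage sketch for default_collater (a toy fairseq Dictionary; the
# 3x32x128 image size is illustrative, not required by the collater):
#
#   from fairseq.data import Dictionary
#   d = Dictionary()  # pad()=1, eos()=2 by default
#   w = d.add_symbol('hello')
#   samples = [
#       {'id': 0, 'tfm_img': torch.zeros(3, 32, 128), 'label_ids': torch.tensor([w, d.eos()])},
#       {'id': 1, 'tfm_img': torch.zeros(3, 32, 128), 'label_ids': torch.tensor([w, w, d.eos()])},
#   ]
#   batch = default_collater(d, samples)
#   batch['net_input']['imgs'].shape          # (2, 3, 32, 128)
#   batch['target']                           # labels, right-padded with d.pad()
#   batch['net_input']['prev_output_tokens']  # eos moved to the front (teacher forcing)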

def read_txt_and_tokenize(txt_path: str, bpe, target_dict):
    annotations = []

    with open(txt_path, 'r', encoding='utf8') as fp:
        for line in fp.readlines():
            line = line.rstrip()
            if not line:
                continue
            # Each line is "x1,y1,x2,y2,x3,y3,x4,y4,transcript"; the transcript
            # may itself contain commas, hence maxsplit=8.
            line_split = line.split(',', maxsplit=8)
            quadrangle = list(map(int, line_split[:8]))
            content = line_split[-1]

            if bpe:
                encoded_str = bpe.encode(content)
            else:
                encoded_str = content

            xs = [quadrangle[i] for i in range(0, 8, 2)]
            ys = [quadrangle[i] for i in range(1, 8, 2)]
            bbox = [min(xs), min(ys), max(xs), max(ys)]
            annotations.append({'bbox': bbox,
                                'encoded_str': encoded_str,
                                'category_id': 0,  # 0 for text, 1 for background
                                'segmentation': [quadrangle]})

    return annotations
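
# Example (illustrative coordinates, bpe=None): the ground-truth line
#   "72,25,326,25,326,64,72,64,TAN WOON YANN"
# becomes
#   {'bbox': [72, 25, 326, 64], 'encoded_str': 'TAN WOON YANN',
#    'category_id': 0, 'segmentation': [[72, 25, 326, 25, 326, 64, 72, 64]]}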

def SROIETask2(root_dir: str, bpe, target_dict, crop_img_output_dir=None):
    data = []
    img_id = -1

    crop_data = []
    crop_img_id = -1

    image_paths = natsorted(list(glob.glob(os.path.join(root_dir, '*.jpg'))))
    for jpg_path in tqdm(image_paths):
        im = Image.open(jpg_path).convert('RGB')

        img_w, img_h = im.size
        img_id += 1

        txt_path = jpg_path.replace('.jpg', '.txt')
        annotations = read_txt_and_tokenize(txt_path, bpe, target_dict)
        img_dict = {'file_name': jpg_path, 'width': img_w, 'height': img_h,
                    'image_id': img_id, 'annotations': annotations}
        data.append(img_dict)

        # Crop each annotated text line out of the full receipt image.
        for ann in annotations:
            crop_w = ann['bbox'][2] - ann['bbox'][0]
            crop_h = ann['bbox'][3] - ann['bbox'][1]

            if not (crop_w > 0 and crop_h > 0):
                logger.warning('Error occurred during image cropping: {} has a zero-area bbox.'.format(
                    os.path.basename(jpg_path)))
                continue

            crop_img_id += 1
            crop_im = im.crop(ann['bbox'])
            if crop_img_output_dir:
                crop_im.save(os.path.join(crop_img_output_dir, '{:d}.jpg'.format(crop_img_id)))
            crop_img_dict = {'img': crop_im, 'file_name': jpg_path, 'width': crop_w, 'height': crop_h,
                            'image_id': crop_img_id, 'encoded_str': ann['encoded_str']}
            crop_data.append(crop_img_dict)

    return data, crop_data
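
# Sketch of the expected layout (paths are assumptions): root_dir holds paired
# "<name>.jpg" / "<name>.txt" files. The first return value keeps one entry per
# full receipt with its annotations; the second keeps one entry per text-line
# crop, which is what the recognition dataset below actually indexes.
#
#   data, crop_data = SROIETask2('/path/to/sroie/train', bpe=None, target_dict=None)
#   crop_data[0]['img']          # PIL.Image of a single text line
#   crop_data[0]['encoded_str']  # its (optionally BPE-encoded) transcript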

class SROIETextRecognitionDataset(FairseqDataset):
    def __init__(self, root_dir, tfm, bpe_parser, target_dict, crop_img_output_dir=None):
        self.root_dir = root_dir
        self.tfm = tfm
        self.target_dict = target_dict
        # self.bpe_parser = bpe_parser
        self.ori_data, self.data = SROIETask2(root_dir, bpe_parser, target_dict, crop_img_output_dir)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_dict = self.data[idx]

        image = img_dict['img']
        encoded_str = img_dict['encoded_str']
        input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)

        tfm_img = self.tfm(image)  # h, w, c
        return {'id': idx, 'tfm_img': tfm_img, 'label_ids': input_ids}

    def size(self, idx):
        img_dict = self.data[idx]
        encoded_str = img_dict['encoded_str']
        input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)
        return len(input_ids)

    def num_tokens(self, idx):
        return self.size(idx)

    def collater(self, samples):
        return default_collater(self.target_dict, samples)
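
# Usage sketch (the transform pipeline and dictionary file are assumptions):
#
#   from torchvision import transforms
#   from fairseq.data import Dictionary
#   tfm = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])
#   target_dict = Dictionary.load('dict.txt')  # hypothetical dictionary file
#   ds = SROIETextRecognitionDataset('/path/to/sroie/train', tfm,
#                                    bpe_parser=None, target_dict=target_dict)
#   batch = ds.collater([ds[0], ds[1]])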

def STR(gt_path, bpe_parser):
    root_dir = os.path.dirname(gt_path)
    data = []
    img_id = 0

    with open(gt_path, 'r') as fp:
        for line in tqdm(list(fp.readlines()), desc='Loading STR:'):
            line = line.rstrip()
            temp = line.split('\t', 1)
            img_file = temp[0]
            text = temp[1]

            img_path = os.path.join(root_dir, 'image', img_file)

            if not bpe_parser:
                encoded_str = text
            else:
                encoded_str = bpe_parser.encode(text)

            data.append({'img_path': img_path, 'image_id': img_id, 'text': text, 'encoded_str': encoded_str})
            img_id += 1

    return data
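
# Expected ground-truth format (filenames are illustrative): one tab-separated
# record per line, with the images stored in an "image/" directory next to
# gt_path.
#
#   0001.jpg<TAB>HELLO
#   0002.jpg<TAB>WORLD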

class SyntheticTextRecognitionDataset(FairseqDataset):
    def __init__(self, gt_path, tfm, bpe_parser, target_dict):
        self.gt_path = gt_path
        self.tfm = tfm
        self.target_dict = target_dict
        self.data = STR(gt_path, bpe_parser)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_dict = self.data[idx]

        image = Image.open(img_dict['img_path']).convert('RGB')
        encoded_str = img_dict['encoded_str']
        input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)

        tfm_img = self.tfm(image)  # h, w, c
        return {'id': idx, 'tfm_img': tfm_img, 'label_ids': input_ids}

    def size(self, idx):
        img_dict = self.data[idx]
        encoded_str = img_dict['encoded_str']
        input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)
        return len(input_ids)

    def num_tokens(self, idx):
        return self.size(idx)

    def collater(self, samples):
        return default_collater(self.target_dict, samples)

def Receipt53K(gt_path):
    root_dir = os.path.dirname(gt_path)
    data = []

    with open(gt_path, 'r', encoding='utf8') as fp:
        for line in tqdm(list(fp.readlines()), desc='Loading Receipt53K:'):
            line = line.rstrip()
            temp = line.split('\t', 1)
            img_file = temp[0]
            text = temp[1]

            img_path = os.path.join(root_dir, img_file)
            data.append({'img_path': img_path, 'text': text})

    return data
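
# Same tab-separated format as STR above, except img_file is resolved relative
# to the directory of gt_path itself, so it may carry its own subdirectory,
# e.g. "images/0001.jpg<TAB>Total: $12.00" (illustrative values).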

class Receipt53KDataset(FairseqDataset):
    def __init__(self, gt_path, tfm, bpe_parser, target_dict):
        self.gt_path = gt_path
        self.tfm = tfm
        self.target_dict = target_dict
        self.bpe_parser = bpe_parser
        self.data = Receipt53K(gt_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_dict = self.data[idx]

        try:
            image = Image.open(img_dict['img_path']).convert('RGB')
        except Exception as e:
            logger.warning('Failed to load image: {}, since {}'.format(img_dict['img_path'], str(e)))
            return None

        encoded_str = self.bpe_parser.encode(img_dict['text'])
        input_ids = self.target_dict.encode_line(encoded_str, add_if_not_exist=False)

        tfm_img = self.tfm(image)  # h, w, c
        return {'id': idx, 'tfm_img': tfm_img, 'label_ids': input_ids}

    def size(self, idx):
        img_dict = self.data[idx]
        return len(img_dict['text'])
        # item = self[idx]
        # return len(item['label_ids'])

    def num_tokens(self, idx):
        return self.size(idx)

    def collater(self, samples):
        # Pass the dataset through so default_collater can re-draw replacements
        # for samples whose images failed to load (__getitem__ returned None);
        # without it, such batches would collate to None.
        return default_collater(self.target_dict, samples, dataset=self)
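
# Usage sketch (paths, transform, and BPE object are assumptions): because
# collater() passes the dataset through, batches stay full-sized even when
# some images fail to load.
#
#   from torch.utils.data import DataLoader
#   ds = Receipt53KDataset('/path/to/gt.txt', tfm, bpe_parser, target_dict)
#   loader = DataLoader(ds, batch_size=8, collate_fn=ds.collater)
#   batch = next(iter(loader))
#   batch['net_input']['imgs'].shape  # (8, C, H, W)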