# Reciept_Analyzer/models.py
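"""End-to-end Vietnamese receipt analysis.

Pipeline: YOLOv8 detects text regions, VietOCR reads each region, and a
fine-tuned LayoutLMv3 classifies the recognized words into receipt fields;
results are returned as a DataFrame via utils.create_df.
"""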
import os

import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, AutoModelForTokenClassification
from ultralytics import YOLO
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor

from utils import normalize_box, unnormalize_box, draw_output, create_df


class Reciept_Analyzer:
def __init__(self,
processor_pretrained='microsoft/layoutlmv3-base',
layoutlm_pretrained=os.path.join(
'models', 'checkpoint'),
yolo_pretrained=os.path.join(
'models', 'best.pt'),
vietocr_pretrained=os.path.join(
'models', 'vietocr', 'vgg_seq2seq.pth')
):
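        """Load all model components.

        The LayoutLMv3 processor is created with apply_ocr=False because OCR is
        handled separately by VietOCR; the VietOCR CNN backbone is not
        re-downloaded since full weights are loaded from `vietocr_pretrained`.
        """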
print("Initializing processor")
self.processor = AutoProcessor.from_pretrained(
processor_pretrained, apply_ocr=False)
print("Finished initializing processor")
print("Initializing LayoutLM model")
self.lalm_model = AutoModelForTokenClassification.from_pretrained(
layoutlm_pretrained)
print("Finished initializing LayoutLM model")
if yolo_pretrained is not None:
print("Initializing YOLO model")
self.yolo_model = YOLO(yolo_pretrained)
print("Finished initializing YOLO model")
print("Initializing VietOCR model")
config = Cfg.load_config_from_name('vgg_seq2seq')
config['weights'] = vietocr_pretrained
        config['cnn']['pretrained'] = False
config['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
self.vietocr = Predictor(config)
print("Finished initializing VietOCR model")
def forward(self, img, output_path="output", is_save_cropped_img=False):
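        """Run the full pipeline on one receipt image.

        Steps: YOLOv8 box detection -> top-to-bottom box sort -> crop each box ->
        VietOCR recognition -> LayoutLMv3 token classification -> DataFrame of
        predicted texts and labels (via utils.create_df).
        """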
        input_image = Image.open(img).convert('RGB')  # ensure a 3-channel image for the downstream models
# detection with YOLOv8
bboxes = self.yolov8_det(input_image)
# sort
sorted_bboxes = self.sort_bboxes(bboxes)
# draw bbox
image_draw = input_image.copy()
self.draw_bbox(image_draw, sorted_bboxes, output_path)
# crop images
cropped_images, normalized_boxes = self.get_cropped_images(input_image, sorted_bboxes, is_save_cropped_img, output_path)
# recognition with VietOCR
texts, mapping_bbox_texts = self.ocr(cropped_images, normalized_boxes)
# KIE with LayoutLMv3
pred_texts, pred_label, boxes = self.kie(input_image, texts, normalized_boxes, mapping_bbox_texts, output_path)
# create dataframe
return create_df(pred_texts, pred_label)
def yolov8_det(self, img):
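        """Detect text regions with YOLOv8 and return integer xyxy boxes (conf >= 0.3, NMS IoU 0.1)."""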
return self.yolo_model.predict(source=img, conf=0.3, iou=0.1)[0].boxes.xyxy.int()
def sort_bboxes(self, bboxes):
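        """Sort boxes into rough reading order: top y first, ties broken by the right edge."""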
bbox_list = []
for box in bboxes:
tlx, tly, brx, bry = map(int, box)
bbox_list.append([tlx, tly, brx, bry])
bbox_list.sort(key=lambda x: (x[1], x[2]))
return bbox_list
def draw_bbox(self, image_draw, bboxes, output_path):
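        """Draw the detected boxes in red and save the visualization to <output_path>/bbox.jpg."""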
        os.makedirs(output_path, exist_ok=True)
        draw = ImageDraw.Draw(image_draw)
for box in bboxes:
draw.rectangle(box, outline='red', width=2)
image_draw.save(os.path.join(output_path, 'bbox.jpg'))
print(f"Exported image with bounding boxes to {os.path.join(output_path, 'bbox.jpg')}")
def get_cropped_images(self, input_image, bboxes, is_save_cropped=False, output_path="output"):
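        """Crop each detected box from the image and normalize its coordinates for LayoutLMv3.

        Optionally saves each crop to <output_path>/cropped/cropped_<i>.jpg.
        """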
normalized_boxes = []
cropped_images = []
        if is_save_cropped:
            cropped_folder = os.path.join(output_path, "cropped")
            os.makedirs(cropped_folder, exist_ok=True)
        for i, box in enumerate(bboxes):
            tlx, tly, brx, bry = map(int, box)
            normalized_box = normalize_box(box, input_image.width, input_image.height)
            normalized_boxes.append(normalized_box)
            cropped = input_image.crop((tlx, tly, brx, bry))
            if is_save_cropped:
                cropped.save(os.path.join(cropped_folder, f'cropped_{i}.jpg'))
            cropped_images.append(cropped)
return cropped_images, normalized_boxes
def ocr(self, cropped_images, normalized_boxes):
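        """Recognize the text in each crop with VietOCR.

        Returns the list of texts plus a mapping from each normalized box
        (serialized as 'x1,y1,x2,y2') back to its text, used later in kie().
        """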
mapping_bbox_texts = {}
texts = []
for img, normalized_box in zip(cropped_images, normalized_boxes):
result = self.vietocr.predict(img)
text = result.strip().replace('\n', ' ')
texts.append(text)
mapping_bbox_texts[','.join(map(str, normalized_box))] = text
return texts, mapping_bbox_texts
def kie(self, img, texts, boxes, mapping_bbox_texts, output_path):
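        """Key-information extraction with LayoutLMv3.

        Encodes the words and their normalized boxes, classifies each token,
        keeps only the first sub-token of every word (offset start == 0) whose
        box is not padding, and, when output_path is a string, saves a labeled
        visualization to <output_path>/result.jpg.
        """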
        encoding = self.processor(img, texts,
                                  boxes=boxes,
                                  return_offsets_mapping=True,
                                  return_tensors='pt',
                                  max_length=512,
                                  padding='max_length',
                                  truncation=True)  # truncate so receipts longer than 512 tokens don't raise
offset_mapping = encoding.pop('offset_mapping')
with torch.no_grad():
outputs = self.lalm_model(**encoding)
id2label = self.lalm_model.config.id2label
logits = outputs.logits
token_boxes = encoding.bbox.squeeze().tolist()
offset_mapping = offset_mapping.squeeze().tolist()
predictions = logits.argmax(-1).squeeze().tolist()
is_subword = np.array(offset_mapping)[:, 0] != 0
true_predictions = []
true_boxes = []
true_texts = []
for idx in range(len(predictions)):
if not is_subword[idx] and token_boxes[idx] != [0, 0, 0, 0]:
true_predictions.append(id2label[predictions[idx]])
true_boxes.append(unnormalize_box(
token_boxes[idx], img.width, img.height))
true_texts.append(mapping_bbox_texts.get(
','.join(map(str, token_boxes[idx])), ''))
if isinstance(output_path, str):
os.makedirs(output_path, exist_ok=True)
img_output = draw_output(
image=img,
true_predictions=true_predictions,
true_boxes=true_boxes
)
img_output.save(os.path.join(output_path, 'result.jpg'))
print(f"Exported result to {os.path.join(output_path, 'result.jpg')}")
return true_texts, true_predictions, true_boxes
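

# Minimal usage sketch. Assumes the pretrained weights referenced above exist
# under ./models; 'receipt.jpg' and the output directory are illustrative
# placeholders, not files shipped with this repo.
if __name__ == "__main__":
    analyzer = Reciept_Analyzer()
    df = analyzer.forward("receipt.jpg", output_path="output", is_save_cropped_img=True)
    print(df)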