LayoutLMv3_for_recepits2 / inference.py
mp-02's picture
Update inference.py
f3df04e verified
raw
history blame
4.17 kB
import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box
# Hugging Face Hub id of the fine-tuned LayoutLMv3 token-classification checkpoint.
MODEL_ID = "mp-02/layoutlmv3-finetuned-cord-sroie"

# OCR is performed separately (utils.OCR), so the processor's built-in OCR is disabled.
processor = LayoutLMv3Processor.from_pretrained(MODEL_ID, apply_ocr=False)
# Reuse the processor's own tokenizer instead of loading a second copy: decoding then uses
# exactly the tokenizer that encoded the inputs. (`apply_ocr` is an image-processor option
# and has no meaning for a tokenizer.)
tokenizer = processor.tokenizer
model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_ID)
id2label = model.config.id2label
label2id = model.config.label2id
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: make evaluation mode (no dropout) explicit.
model.eval()
import json
def token2json(words, labels):
    """Group word-level BIO labels into entities and serialise them as JSON.

    Words tagged ``B-X`` followed by ``I-X`` are joined with single spaces into
    one value stored under the lower-cased key ``x``. An ``I-X`` that does not
    continue an open ``X`` entity starts a new one, so sequences with a missing
    ``B-`` are tolerated. Any other label (e.g. ``"O"``) closes the open entity.
    When the same key occurs more than once, the values are joined with
    ``", "`` rather than the later one silently replacing the earlier one.

    Args:
        words: Word strings aligned one-to-one with ``labels``. Surrounding
            whitespace on each word is stripped before joining.
        labels: Label strings such as ``"B-TOTAL"``, ``"I-TOTAL"`` or ``"O"``.
            Extra items in the longer sequence are ignored (``zip`` semantics).

    Returns:
        A JSON object string (``ensure_ascii=False``, indent 2) mapping
        lower-cased entity names to their text.
    """
    result = {}

    def _flush(entity, parts):
        # Store a finished entity, merging with any earlier value of the same key.
        if not entity:
            return
        text = " ".join(part.strip() for part in parts if part.strip())
        previous = result.get(entity)
        if previous and text:
            result[entity] = previous + ", " + text
        else:
            result[entity] = previous or text

    current_entity = None
    current_text = []
    for word, label in zip(words, labels):
        if label.startswith("B-"):
            _flush(current_entity, current_text)
            current_entity, current_text = label[2:].lower(), [word]
        elif label.startswith("I-"):
            entity = label[2:].lower()
            if entity == current_entity:
                current_text.append(word)
            else:
                # Handle I- sequences that are not preceded by a matching B-.
                _flush(current_entity, current_text)
                current_entity, current_text = entity, [word]
        else:  # "O" (or any non-BIO label) closes the open entity
            _flush(current_entity, current_text)
            current_entity, current_text = None, []

    # Add the last entity, if one is still open.
    _flush(current_entity, current_text)
    return json.dumps(result, ensure_ascii=False, indent=2)
def _draw_word_predictions(image, labels, boxes, scores):
    """Return a copy of ``image`` with a rectangle and a "label, score" caption per word."""
    annotated = image.copy()  # leave the caller's image untouched
    draw = ImageDraw.Draw(annotated, "RGBA")
    try:
        font = ImageFont.load_default(size=15)  # ``size`` is accepted by Pillow >= 10.1
    except TypeError:
        font = ImageFont.load_default()  # older Pillow: fixed-size default font
    for label, box, score in zip(labels, boxes, scores):
        draw.rectangle(box)
        draw.text((box[0] + 10, box[1] - 10), text=label + ", " + str(score), font=font, fill="black")
    return annotated


def prediction(image, confidence_threshold=0.6):
    """Extract receipt fields from ``image`` and draw the per-word predictions.

    Runs ``OCR`` on the image, classifies each OCR word with the LayoutLMv3
    model and converts the word labels to JSON with ``token2json``. Each word's
    label, box and confidence come from its first sub-token. A word whose
    confidence is below ``confidence_threshold`` is relabelled ``"O"``.

    Args:
        image: A PIL image (presumably RGB — TODO confirm what ``OCR`` and the
            processor accept). It is not modified.
        confidence_threshold: Minimum softmax probability for a word to keep its
            predicted label. Defaults to 0.6, the previously hard-coded value.

    Returns:
        ``(annotated_image, fields_json)``: a copy of ``image`` with one
        rectangle and "label, confidence" caption per word, and the JSON
        string produced by ``token2json``.
    """
    boxes, words = OCR(image)
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True,
                         return_tensors="pt", truncation=True)
    # The model does not accept offset_mapping; keep it for sub-word detection.
    offset_mapping = encoding.pop("offset_mapping")
    inputs = {key: value.to(device) for key, value in encoding.items()}

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # the batch holds this single image

    predictions = logits.argmax(-1).tolist()
    confidence_scores = torch.softmax(logits, dim=-1).max(-1).values.tolist()
    token_boxes = inputs["bbox"][0].tolist()
    token_texts = [tokenizer.decode(token_id) for token_id in inputs["input_ids"][0].tolist()]
    # A non-zero start offset marks a continuation piece of the previous word;
    # special tokens have start offset 0 and are treated as word starts.
    is_subword = (offset_mapping[0, :, 0] != 0).tolist()

    width, height = image.size
    labels, word_boxes, word_texts, word_scores = [], [], [], []
    for pred, box, text, score, subword in zip(predictions, token_boxes, token_texts,
                                                confidence_scores, is_subword):
        if subword:
            word_texts[-1] += text  # glue the piece onto the word it continues
            continue
        labels.append(id2label[pred])
        word_boxes.append(unnormalize_box(box, width, height))
        word_texts.append(text)
        word_scores.append(score)

    # Drop the first and last positions: the special start/end tokens added by the tokenizer.
    labels = labels[1:-1]
    word_boxes = word_boxes[1:-1]
    word_texts = word_texts[1:-1]
    word_scores = word_scores[1:-1]

    # Low-confidence words are treated as "outside any entity".
    labels = ["O" if score < confidence_threshold else label
              for label, score in zip(labels, word_scores)]

    fields_json = token2json(word_texts, labels)
    annotated = _draw_word_predictions(image, labels, word_boxes, word_scores)
    return annotated, fields_json