Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +36 -21
inference.py
CHANGED
@@ -16,24 +16,6 @@ model.to(device)
|
|
16 |
|
17 |
import json
|
18 |
|
19 |
-
|
20 |
-
# Funzione per creare l'output JSON in formato CORD-like
|
21 |
-
def create_json_output(words, labels, boxes):
|
22 |
-
output = []
|
23 |
-
|
24 |
-
for word, label, box in zip(words, labels, boxes):
|
25 |
-
# Considera solo le etichette rilevanti (escludendo "O")
|
26 |
-
if label != "O":
|
27 |
-
output.append({
|
28 |
-
"text": word,
|
29 |
-
"category": label, # la categoria predetta dal modello (es. "B-PRODUCT", "B-PRICE", "B-TOTAL")
|
30 |
-
"bounding_box": box # le coordinate di bounding box per la parola
|
31 |
-
})
|
32 |
-
|
33 |
-
# Converti in JSON
|
34 |
-
json_output = json.dumps(output, indent=4)
|
35 |
-
return json_output
|
36 |
-
|
37 |
def prediction(image):
|
38 |
|
39 |
boxes, words = OCR(image)
|
@@ -55,8 +37,41 @@ def prediction(image):
|
|
55 |
token_boxes = encoding.bbox.squeeze().tolist()
|
56 |
probabilities = torch.softmax(outputs.logits, dim=-1)
|
57 |
confidence_scores = probabilities.max(-1).values.squeeze().tolist()
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
draw = ImageDraw.Draw(image, "RGBA")
|
62 |
font = ImageFont.load_default()
|
@@ -65,5 +80,5 @@ def prediction(image):
|
|
65 |
draw.rectangle(box)
|
66 |
draw.text((box[0]+10, box[1]-10), text=str(prediction)+ ", "+ str(confidence), font=font, fill="black", font_size="15")
|
67 |
|
68 |
-
return image,
|
69 |
|
|
|
16 |
|
17 |
import json
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def prediction(image):
|
20 |
|
21 |
boxes, words = OCR(image)
|
|
|
37 |
token_boxes = encoding.bbox.squeeze().tolist()
|
38 |
probabilities = torch.softmax(outputs.logits, dim=-1)
|
39 |
confidence_scores = probabilities.max(-1).values.squeeze().tolist()
|
40 |
+
|
41 |
+
inp_ids = encoding.input_ids.squeeze().tolist()
|
42 |
+
inp_words = [tokenizer.decode(i) for i in inp_ids]
|
43 |
+
|
44 |
+
width, height = image.size
|
45 |
+
is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
|
46 |
+
|
47 |
+
true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
|
48 |
+
true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
|
49 |
+
true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
|
50 |
+
true_words = []
|
51 |
+
|
52 |
+
for id, i in enumerate(inp_words):
|
53 |
+
if not is_subword[id]:
|
54 |
+
true_words.append(i)
|
55 |
+
else:
|
56 |
+
true_words[-1] = true_words[-1]+i
|
57 |
+
|
58 |
+
true_predictions = true_predictions[1:-1]
|
59 |
+
true_boxes = true_boxes[1:-1]
|
60 |
+
true_words = true_words[1:-1]
|
61 |
+
true_confidence_scores = true_confidence_scores[1:-1]
|
62 |
+
|
63 |
+
d = {}
|
64 |
+
for id, i in enumerate(true_predictions):
|
65 |
+
#rimuovo i prefissi
|
66 |
+
if i != "O":
|
67 |
+
i = i[2:]
|
68 |
+
if i not in d.keys():
|
69 |
+
d[i] = true_words[id]
|
70 |
+
else:
|
71 |
+
d[i] = d[i] + ", " + true_words[id]
|
72 |
+
d = {k: v.strip() for (k, v) in d.items()}
|
73 |
+
|
74 |
+
if "O" in d: d.pop("O")
|
75 |
|
76 |
draw = ImageDraw.Draw(image, "RGBA")
|
77 |
font = ImageFont.load_default()
|
|
|
80 |
draw.rectangle(box)
|
81 |
draw.text((box[0]+10, box[1]-10), text=str(prediction)+ ", "+ str(confidence), font=font, fill="black", font_size="15")
|
82 |
|
83 |
+
return image, d
|
84 |
|