mp-02 committed on
Commit
9b8edeb
·
verified ·
1 Parent(s): 9c725ef

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +36 -21
inference.py CHANGED
@@ -16,24 +16,6 @@ model.to(device)
16
 
17
  import json
18
 
19
-
20
- # Funzione per creare l'output JSON in formato CORD-like
21
- def create_json_output(words, labels, boxes):
22
- output = []
23
-
24
- for word, label, box in zip(words, labels, boxes):
25
- # Considera solo le etichette rilevanti (escludendo "O")
26
- if label != "O":
27
- output.append({
28
- "text": word,
29
- "category": label, # la categoria predetta dal modello (es. "B-PRODUCT", "B-PRICE", "B-TOTAL")
30
- "bounding_box": box # le coordinate di bounding box per la parola
31
- })
32
-
33
- # Converti in JSON
34
- json_output = json.dumps(output, indent=4)
35
- return json_output
36
-
37
  def prediction(image):
38
 
39
  boxes, words = OCR(image)
@@ -55,8 +37,41 @@ def prediction(image):
55
  token_boxes = encoding.bbox.squeeze().tolist()
56
  probabilities = torch.softmax(outputs.logits, dim=-1)
57
  confidence_scores = probabilities.max(-1).values.squeeze().tolist()
58
- # Crea il JSON usando i risultati ottenuti
59
- json_result = create_json_output(words, labels, boxes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  draw = ImageDraw.Draw(image, "RGBA")
62
  font = ImageFont.load_default()
@@ -65,5 +80,5 @@ def prediction(image):
65
  draw.rectangle(box)
66
  draw.text((box[0]+10, box[1]-10), text=str(prediction)+ ", "+ str(confidence), font=font, fill="black", font_size="15")
67
 
68
- return image, json_result
69
 
 
16
 
17
  import json
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def prediction(image):
20
 
21
  boxes, words = OCR(image)
 
37
  token_boxes = encoding.bbox.squeeze().tolist()
38
  probabilities = torch.softmax(outputs.logits, dim=-1)
39
  confidence_scores = probabilities.max(-1).values.squeeze().tolist()
40
+
41
+ inp_ids = encoding.input_ids.squeeze().tolist()
42
+ inp_words = [tokenizer.decode(i) for i in inp_ids]
43
+
44
+ width, height = image.size
45
+ is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
46
+
47
+ true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
48
+ true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
49
+ true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
50
+ true_words = []
51
+
52
+ for id, i in enumerate(inp_words):
53
+ if not is_subword[id]:
54
+ true_words.append(i)
55
+ else:
56
+ true_words[-1] = true_words[-1]+i
57
+
58
+ true_predictions = true_predictions[1:-1]
59
+ true_boxes = true_boxes[1:-1]
60
+ true_words = true_words[1:-1]
61
+ true_confidence_scores = true_confidence_scores[1:-1]
62
+
63
+ d = {}
64
+ for id, i in enumerate(true_predictions):
65
+ #rimuovo i prefissi
66
+ if i != "O":
67
+ i = i[2:]
68
+ if i not in d.keys():
69
+ d[i] = true_words[id]
70
+ else:
71
+ d[i] = d[i] + ", " + true_words[id]
72
+ d = {k: v.strip() for (k, v) in d.items()}
73
+
74
+ if "O" in d: d.pop("O")
75
 
76
  draw = ImageDraw.Draw(image, "RGBA")
77
  font = ImageFont.load_default()
 
80
  draw.rectangle(box)
81
  draw.text((box[0]+10, box[1]-10), text=str(prediction)+ ", "+ str(confidence), font=font, fill="black", font_size="15")
82
 
83
+ return image, d
84