Update app.py
Browse files
app.py
CHANGED
@@ -29,25 +29,36 @@ def classify_image(input_image):
|
|
29 |
outputs = model(**inputs)
|
30 |
# Obter a classe prevista
|
31 |
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
|
|
|
|
|
32 |
# Converter o ID da classe em r贸tulo usando o mapeamento id2label
|
33 |
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
|
34 |
|
35 |
# Converter a imagem de numpy para BGR (formato OpenCV)
|
36 |
input_image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
# Adicionar o r贸tulo da previs茫o na imagem
|
39 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
40 |
-
text = f'
|
41 |
text_size = cv2.getTextSize(text, font, 0.7, 2)[0]
|
42 |
text_x = (input_image_bgr.shape[1] - text_size[0]) // 2
|
43 |
text_y = input_image_bgr.shape[0] - 20
|
44 |
-
cv2.putText(input_image_bgr, text, (text_x, text_y), font, 0.7, (255, 255, 255), 2)
|
45 |
|
46 |
# Converter a imagem resultante de volta para RGB (formato Pillow)
|
47 |
result_image = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB)
|
48 |
return result_image
|
49 |
|
50 |
|
|
|
51 |
# Criar uma interface Gradio
|
52 |
interface = gr.Interface(
|
53 |
fn=classify_image,
|
|
|
29 |
outputs = model(**inputs)
|
30 |
# Obter a classe prevista
|
31 |
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
|
32 |
+
# Obter a probabilidade da classe prevista
|
33 |
+
predicted_class_prob = torch.softmax(outputs.logits, dim=1)[0, predicted_class_id].item()
|
34 |
# Converter o ID da classe em r贸tulo usando o mapeamento id2label
|
35 |
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
|
36 |
|
37 |
# Converter a imagem de numpy para BGR (formato OpenCV)
|
38 |
input_image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
|
39 |
|
40 |
+
# Definir cores de borda para cada classe (aqui, cores aleat贸rias)
|
41 |
+
class_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
|
42 |
+
(255, 0, 255), (0, 255, 255), (128, 128, 128), (0, 0, 0)]
|
43 |
+
|
44 |
+
# Adicionar uma borda colorida 脿 imagem
|
45 |
+
border_color = class_colors[predicted_class_id]
|
46 |
+
input_image_bgr = cv2.copyMakeBorder(input_image_bgr, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=border_color)
|
47 |
+
|
48 |
# Adicionar o r贸tulo da previs茫o na imagem
|
49 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
50 |
+
text = f'Classe: {predicted_class_label} ({predicted_class_prob:.2f})'
|
51 |
text_size = cv2.getTextSize(text, font, 0.7, 2)[0]
|
52 |
text_x = (input_image_bgr.shape[1] - text_size[0]) // 2
|
53 |
text_y = input_image_bgr.shape[0] - 20
|
54 |
+
cv2.putText(input_image_bgr, text, (text_x, text_y), font, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
55 |
|
56 |
# Converter a imagem resultante de volta para RGB (formato Pillow)
|
57 |
result_image = cv2.cvtColor(input_image_bgr, cv2.COLOR_BGR2RGB)
|
58 |
return result_image
|
59 |
|
60 |
|
61 |
+
|
62 |
# Criar uma interface Gradio
|
63 |
interface = gr.Interface(
|
64 |
fn=classify_image,
|