Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
|
|
|
|
4 |
|
5 |
# Carregue o modelo ViT
|
6 |
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
|
@@ -29,13 +31,13 @@ def classify_image(input_image):
|
|
29 |
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
|
30 |
# Converter o ID da classe em r贸tulo usando o mapeamento id2label
|
31 |
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
|
32 |
-
return predicted_class_label
|
33 |
|
34 |
# Criar uma interface Gradio
|
35 |
interface = gr.Interface(
|
36 |
fn=classify_image,
|
37 |
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
|
38 |
-
outputs="
|
39 |
title="Classificador de Imagem ViT",
|
40 |
description="Esta aplica莽茫o Gradio permite classificar imagens usando um modelo Vision Transformer (ViT)."
|
41 |
)
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
4 |
+
from PIL import Image
|
5 |
+
import io
|
6 |
|
7 |
# Carregue o modelo ViT
|
8 |
model_name = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
|
|
|
31 |
predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
|
32 |
# Converter o ID da classe em r贸tulo usando o mapeamento id2label
|
33 |
predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
|
34 |
+
return input_image, predicted_class_label
|
35 |
|
36 |
# Criar uma interface Gradio
|
37 |
interface = gr.Interface(
|
38 |
fn=classify_image,
|
39 |
inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
|
40 |
+
outputs=["image", "text"],
|
41 |
title="Classificador de Imagem ViT",
|
42 |
description="Esta aplica莽茫o Gradio permite classificar imagens usando um modelo Vision Transformer (ViT)."
|
43 |
)
|