DHEIVER commited on
Commit
21e2e72
·
1 Parent(s): 64e5dc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -2,29 +2,43 @@ import gradio as gr
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
 
5
- # Load the ViT model
6
  model_name = "google/vit-base-patch16-224-in21k"
7
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
8
  model = ViTForImageClassification.from_pretrained(model_name)
9
 
10
- # Define a function to preprocess the image and classify it
 
 
 
 
 
 
 
 
 
 
 
 
11
  def classify_image(input_image):
12
- # Preprocess the image using the feature extractor
13
  inputs = feature_extractor(input_image, return_tensors="pt")
14
- # Perform inference with the model
15
  outputs = model(**inputs)
16
- # Get the predicted label
17
- predicted_class = torch.argmax(outputs.logits, dim=1).item()
18
- return predicted_class
 
 
19
 
20
- # Create a Gradio interface
21
  interface = gr.Interface(
22
  fn=classify_image,
23
- inputs=gr.inputs.Image(type="numpy", label="Upload an image"),
24
  outputs="label",
25
- title="ViT Image Classifier",
26
- description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
27
  )
28
 
29
- # Launch the Gradio app
30
  interface.launch()
 
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
 
5
+ # Carregue o modelo ViT
6
  model_name = "google/vit-base-patch16-224-in21k"
7
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
8
  model = ViTForImageClassification.from_pretrained(model_name)
9
 
10
+ # Mapeamento de classe ID para rótulo
11
+ id2label = {
12
+ "0": "dyed-lifted-polyps",
13
+ "1": "dyed-resection-margins",
14
+ "2": "esophagitis",
15
+ "3": "normal-cecum",
16
+ "4": "normal-pylorus",
17
+ "5": "normal-z-line",
18
+ "6": "polyps",
19
+ "7": "ulcerative-colitis"
20
+ }
21
+
22
+ # Função para classificar a imagem
23
  def classify_image(input_image):
24
+ # Pré-processar a imagem usando o extrator de características
25
  inputs = feature_extractor(input_image, return_tensors="pt")
26
+ # Realizar inferência com o modelo
27
  outputs = model(**inputs)
28
+ # Obter a classe prevista
29
+ predicted_class_id = torch.argmax(outputs.logits, dim=1).item()
30
+ # Converter o ID da classe em rótulo usando o mapeamento id2label
31
+ predicted_class_label = id2label.get(str(predicted_class_id), "Desconhecido")
32
+ return predicted_class_label
33
 
34
+ # Criar uma interface Gradio
35
  interface = gr.Interface(
36
  fn=classify_image,
37
+ inputs=gr.inputs.Image(type="numpy", label="Carregar uma imagem"),
38
  outputs="label",
39
+ title="Classificador de Imagem ViT",
40
+ description="Esta aplicação Gradio permite classificar imagens usando um modelo Vision Transformer (ViT)."
41
  )
42
 
43
+ # Iniciar a aplicação Gradio
44
  interface.launch()