DHEIVER commited on
Commit
64e5dc2
·
1 Parent(s): e352c58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -1,24 +1,27 @@
1
  import gradio as gr
2
- import numpy as np
3
  import torch
4
- from transformers import ViTForImageClassification
5
 
6
  # Load the ViT model
7
- model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
 
 
8
 
9
- # Define a function to convert a NumPy array to a torch tensor
10
- def numpy_to_tensor(array):
11
- return torch.from_numpy(array).float()
12
-
13
- # Create input and output components
14
- input_component = gr.inputs.Image()
15
- output_component = gr.outputs.Label()
 
 
16
 
17
  # Create a Gradio interface
18
  interface = gr.Interface(
19
- fn=model,
20
- inputs=input_component,
21
- outputs=output_component,
22
  title="ViT Image Classifier",
23
  description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
24
  )
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
4
 
5
  # Load the ViT model
6
+ model_name = "google/vit-base-patch16-224-in21k"
7
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
8
+ model = ViTForImageClassification.from_pretrained(model_name)
9
 
10
+ # Define a function to preprocess the image and classify it
11
+ def classify_image(input_image):
12
+ # Preprocess the image using the feature extractor
13
+ inputs = feature_extractor(input_image, return_tensors="pt")
14
+ # Perform inference with the model
15
+ outputs = model(**inputs)
16
+ # Get the predicted label
17
+ predicted_class = torch.argmax(outputs.logits, dim=1).item()
18
+ return predicted_class
19
 
20
  # Create a Gradio interface
21
  interface = gr.Interface(
22
+ fn=classify_image,
23
+ inputs=gr.inputs.Image(type="numpy", label="Upload an image"),
24
+ outputs="label",
25
  title="ViT Image Classifier",
26
  description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
27
  )