DHEIVER commited on
Commit
e352c58
·
1 Parent(s): c26ef06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -1,5 +1,6 @@
1
- import gradio
2
  import numpy as np
 
3
  from transformers import ViTForImageClassification
4
 
5
  # Load the ViT model
@@ -7,19 +8,20 @@ model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-i
7
 
8
  # Define a function to convert a NumPy array to a torch tensor
9
  def numpy_to_tensor(array):
10
- return torch.from_numpy(array).float()
 
 
 
 
11
 
12
  # Create a Gradio interface
13
- interface = gradio.Interface(
14
  fn=model,
15
- inputs="image",
16
- outputs=["label"],
17
  title="ViT Image Classifier",
18
  description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
19
  )
20
 
21
- # Set the input block to handle NumPy arrays
22
- interface.inputs[0].type = numpy_to_tensor
23
-
24
  # Launch the Gradio app
25
  interface.launch()
 
1
+ import gradio as gr
2
  import numpy as np
3
+ import torch
4
  from transformers import ViTForImageClassification
5
 
6
  # Load the ViT model
 
8
 
9
  # Define a function to convert a NumPy array to a torch tensor
10
  def numpy_to_tensor(array):
11
+ return torch.from_numpy(array).float()
12
+
13
+ # Create input and output components
14
+ input_component = gr.inputs.Image()
15
+ output_component = gr.outputs.Label()
16
 
17
  # Create a Gradio interface
18
+ interface = gr.Interface(
19
  fn=model,
20
+ inputs=input_component,
21
+ outputs=output_component,
22
  title="ViT Image Classifier",
23
  description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
24
  )
25
 
 
 
 
26
  # Launch the Gradio app
27
  interface.launch()