DHEIVER commited on
Commit
c26ef06
·
1 Parent(s): 68c4602

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -1,9 +1,14 @@
1
  import gradio
 
2
  from transformers import ViTForImageClassification
3
 
4
  # Load the ViT model
5
  model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
6
 
 
 
 
 
7
  # Create a Gradio interface
8
  interface = gradio.Interface(
9
  fn=model,
@@ -13,5 +18,8 @@ interface = gradio.Interface(
13
  description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
14
  )
15
 
 
 
 
16
  # Launch the Gradio app
17
  interface.launch()
 
1
  import gradio
2
+ import numpy as np
3
  from transformers import ViTForImageClassification
4
 
5
  # Load the ViT model
6
  model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
7
 
8
+ # Define a function to convert a NumPy array to a torch tensor
9
+ def numpy_to_tensor(array):
10
+ return torch.from_numpy(array).float()
11
+
12
  # Create a Gradio interface
13
  interface = gradio.Interface(
14
  fn=model,
 
18
  description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
19
  )
20
 
21
+ # Set the input block to handle NumPy arrays
22
+ interface.inputs[0].type = numpy_to_tensor
23
+
24
  # Launch the Gradio app
25
  interface.launch()