DHEIVER commited on
Commit
68c4602
·
1 Parent(s): 5c84014

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -43
app.py CHANGED
@@ -1,45 +1,17 @@
1
- import gradio as gr
2
- from transformers import ViTFeatureExtractor, ViTForImageClassification
3
- from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
4
- from PIL import Image, ImageDraw, ImageFont
5
-
6
- # Load the pre-trained ViT model
7
- path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
8
- classifier = VisionClassifierInference(
9
- feature_extractor=ViTFeatureExtractor.from_pretrained(path),
10
- model=ViTForImageClassification.from_pretrained(path),
11
- )
12
-
13
-
14
- def classify_image(image_file):
15
- """Classify an image using a pre-trained ViT model."""
16
- label = classifier.predict(img=Image.open(image_file.name))
17
- # Add a confidence score to the output
18
- confidence = classifier.predict_proba(img=Image.open(image_file.name))[0][label]
19
-
20
- # Get the PIL Image object for the uploaded image
21
- image = Image.open(image_file)
22
-
23
- # Draw the predicted label on the image
24
- draw = ImageDraw.Draw(image)
25
- font = ImageFont.truetype("arial.ttf", 20)
26
- draw.text((10, 10), f"Predicted class: {label} (confidence: {confidence:.2f})", font=font, fill=(255, 255, 255))
27
-
28
- # Save the modified image to a BytesIO object
29
- output_image = BytesIO()
30
- image.save(output_image, format="JPEG")
31
- output_image.seek(0)
32
-
33
- return output_image, f"Predicted class: {label} (confidence: {confidence:.2f})"
34
-
35
-
36
-
37
- iface = gr.Interface(
38
- fn=classify_image,
39
- inputs=gr.inputs.Image(type="filepath", label="Upload an image"),
40
- outputs=[gr.outputs.Image(type="numpy"), "text"],
41
- title="Image Classifier",
42
- description="Classify images using a pre-trained ViT model",
43
  )
44
 
45
- iface.launch()
 
 
1
+ import gradio
2
+ from transformers import ViTForImageClassification
3
+
4
+ # Load the ViT model
5
+ model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
6
+
7
+ # Create a Gradio interface
8
+ interface = gradio.Interface(
9
+ fn=model,
10
+ inputs="image",
11
+ outputs=["label"],
12
+ title="ViT Image Classifier",
13
+ description="This Gradio app allows you to classify images using a Vision Transformer (ViT) model."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  )
15
 
16
+ # Launch the Gradio app
17
+ interface.launch()