DHEIVER commited on
Commit
59a3399
·
1 Parent(s): 452a079

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import ViTFeatureExtractor, ViTForImageClassification
3
  from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
 
4
 
5
  # Load the pre-trained ViT model
6
  path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
@@ -9,7 +10,7 @@ classifier = VisionClassifierInference(
9
  model=ViTForImageClassification.from_pretrained(path),
10
  )
11
 
12
- # Define a Gradio interface
13
  def classify_image(image_file):
14
  """Classify an image using a pre-trained ViT model."""
15
  label = classifier.predict(img_path=image_file.name)
@@ -17,9 +18,19 @@ def classify_image(image_file):
17
  confidence = classifier.predict_proba(img_path=image_file.name)[0][label]
18
 
19
  # Get the PIL Image object for the uploaded image
20
- image = image_file.read()
 
 
 
 
 
 
 
 
 
 
21
 
22
- return image, f"Predicted class: {label} (confidence: {confidence:.2f})"
23
 
24
  iface = gr.Interface(
25
  fn=classify_image,
 
1
  import gradio as gr
2
  from transformers import ViTFeatureExtractor, ViTForImageClassification
3
  from hugsvision.inference.VisionClassifierInference import VisionClassifierInference
4
+ from PIL import Image, ImageDraw, ImageFont
5
 
6
  # Load the pre-trained ViT model
7
  path = "mrm8488/vit-base-patch16-224_finetuned-kvasirv2-colonoscopy"
 
10
  model=ViTForImageClassification.from_pretrained(path),
11
  )
12
 
13
+
14
  def classify_image(image_file):
15
  """Classify an image using a pre-trained ViT model."""
16
  label = classifier.predict(img_path=image_file.name)
 
18
  confidence = classifier.predict_proba(img_path=image_file.name)[0][label]
19
 
20
  # Get the PIL Image object for the uploaded image
21
+ image = Image.open(image_file)
22
+
23
+ # Draw the predicted label on the image
24
+ draw = ImageDraw.Draw(image)
25
+ font = ImageFont.truetype("arial.ttf", 20)
26
+ draw.text((10, 10), f"Predicted class: {label} (confidence: {confidence:.2f})", font=font, fill=(255, 255, 255))
27
+
28
+ # Save the modified image to a BytesIO object
29
+ output_image = BytesIO()
30
+ image.save(output_image, format="JPEG")
31
+ output_image.seek(0)
32
 
33
+ return output_image, f"Predicted class: {label} (confidence: {confidence:.2f})"
34
 
35
  iface = gr.Interface(
36
  fn=classify_image,