fadzwan commited on
Commit
2411469
·
verified ·
1 Parent(s): 447c034

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
2
  import gradio as gr
 
3
  from PIL import Image
4
 
5
  # Load the model and feature extractor
@@ -7,13 +8,33 @@ model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b
7
  feature_extractor = SegformerFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
8
 
9
  def predict(image):
 
10
  inputs = feature_extractor(images=image, return_tensors="pt")
 
11
  outputs = model(**inputs)
12
- # Decode outputs and return results as needed
13
- return "Segmentation output placeholder" # Replace with actual processing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def segmentation_interface(image):
16
  return predict(image)
17
 
18
  # Create a Gradio interface for image segmentation
19
- gr.Interface(fn=segmentation_interface, inputs="image", outputs="text").launch()
 
1
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
2
  import gradio as gr
3
+ import numpy as np
4
  from PIL import Image
5
 
6
  # Load the model and feature extractor
 
8
  feature_extractor = SegformerFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
9
 
10
  def predict(image):
11
+ # Prepare the image for the model
12
  inputs = feature_extractor(images=image, return_tensors="pt")
13
+ # Get model outputs
14
  outputs = model(**inputs)
15
+
16
+ # Get the segmentation logits
17
+ logits = outputs.logits
18
+ # Apply softmax to get probabilities
19
+ probabilities = logits.softmax(dim=1) # shape: (batch_size, num_classes, height, width)
20
+
21
+ # Get the predicted class for each pixel
22
+ predicted_class = probabilities.argmax(dim=1).squeeze().cpu().numpy() # shape: (height, width)
23
+
24
+ # Create a color map (you can define your own color mapping for different classes)
25
+ color_map = np.array([[0, 0, 0], # Class 0 - background
26
+ [255, 0, 0], # Class 1 - red
27
+ [0, 255, 0], # Class 2 - green
28
+ [0, 0, 255]]) # Class 3 - blue
29
+
30
+ # Create an output mask image
31
+ mask_image = color_map[predicted_class] # Map class indices to colors
32
+ mask_image = Image.fromarray(mask_image.astype('uint8')) # Convert to PIL Image
33
+
34
+ return mask_image
35
 
36
  def segmentation_interface(image):
37
  return predict(image)
38
 
39
  # Create a Gradio interface for image segmentation
40
+ gr.Interface(fn=segmentation_interface, inputs="image", outputs="image").launch()