saba000farahani commited on
Commit
431f73f
·
verified ·
1 Parent(s): 5b84368

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -56
app.py CHANGED
@@ -1,70 +1,41 @@
1
  import os
2
  import gradio as gr
 
 
3
  import numpy as np
4
- import tensorflow as tf
5
  import cv2
6
- from mrcnn.config import Config
7
- from mrcnn import model as modellib
8
 
9
- # Set environment variable to avoid floating-point errors
10
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
 
11
 
12
- # Define Mask R-CNN configuration
13
- class InferenceConfig(Config):
14
- NAME = "mask_rcnn"
15
- NUM_CLASSES = 1 + 80 # Update according to your dataset
16
- GPU_COUNT = 1
17
- IMAGES_PER_GPU = 1
18
 
19
- config = InferenceConfig()
20
-
21
- # Initialize the Mask R-CNN model
22
- model = modellib.MaskRCNN(mode="inference", config=config, model_dir=os.getcwd())
23
-
24
- # Load the Mask R-CNN model weights
25
- model_path = os.path.join('toolkit', 'condmodel_100.h5')
26
- model.load_weights(model_path, by_name=True)
27
- print("Mask R-CNN model loaded successfully with weights.")
28
-
29
- # Function to apply Mask R-CNN for image segmentation
30
- def apply_mask_rcnn(image):
31
  try:
32
- # Convert image to RGB (in case of RGBA or grayscale)
33
- if image.shape[2] == 4: # Convert RGBA to RGB
34
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
35
-
36
- # Resize the image to match the model input size (for inference)
37
- resized_image = cv2.resize(image, (1024, 1024)) # Adjust based on model input requirements
38
- input_image = np.expand_dims(resized_image, axis=0)
39
-
40
- # Use Mask R-CNN to predict
41
- result = model.detect(input_image, verbose=0)
42
- r = result[0]
43
-
44
- # Create a mask for the detected objects
45
- mask = r['masks']
46
- mask = np.sum(mask, axis=-1) # Combine masks for all objects
47
-
48
- # Resize mask back to the original image size
49
- mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
50
-
51
- # Create a segmentation overlay on the original image
52
- mask_overlay = np.zeros_like(image)
53
- mask_overlay[mask > 0.5] = [0, 255, 0] # Green mask
54
-
55
- # Combine the original image with the mask
56
- segmented_image = cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)
57
 
58
- return segmented_image
 
 
 
 
 
 
59
 
60
  except Exception as e:
61
- print(f"Error in segmentation: {e}")
62
- return image # Return original image if segmentation fails
63
 
64
- # Gradio interface definition
65
  with gr.Blocks() as demo:
66
- gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with Mask R-CNN</h1>")
67
- gr.Markdown("Upload an image to see segmentation results using the Mask R-CNN model.")
68
 
69
  # Input and output layout
70
  with gr.Row():
@@ -79,8 +50,8 @@ with gr.Blocks() as demo:
79
  output_image = gr.Image(type="numpy", label="Segmented Image")
80
 
81
  # Set up button functionality
82
- submit_btn.click(fn=apply_mask_rcnn, inputs=input_image, outputs=output_image)
83
- clear_btn.click(fn=lambda: None)
84
 
85
  # Launch the Gradio app
86
  demo.launch()
 
1
  import os
2
  import gradio as gr
3
+ from ultralytics import YOLO
4
+ from PIL import Image
5
  import numpy as np
 
6
  import cv2
 
 
7
 
8
+ # Define the directory paths for the model and assets based on the current script location
9
+ script_dir = os.path.dirname(os.path.abspath(__file__)) # Current script directory
10
+ yolo_weights_path = os.path.join(script_dir, 'toolkit', 'ALL_best.pt') # Path to YOLO weights
11
 
12
+ # Load the YOLO model
13
+ model = YOLO(yolo_weights_path)
14
+ model.fuse() # Optional for optimization
15
+ print("YOLO model loaded successfully with weights.")
 
 
16
 
17
+ # Function to perform detection and segmentation with YOLO
18
+ def apply_yolo_segmentation(image):
 
 
 
 
 
 
 
 
 
 
19
  try:
20
+ # Run YOLO on the image and get the results
21
+ results = model.predict(source=image, save=False) # Use save=False to keep results in memory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # Retrieve the annotated image from YOLO results
24
+ result_image = results[0].plot() # `plot()` returns the image with annotations
25
+
26
+ # Convert to RGB for Gradio output
27
+ result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
28
+
29
+ return result_image_rgb
30
 
31
  except Exception as e:
32
+ print(f"Error in YOLO segmentation: {e}")
33
+ return image # Return the original image if segmentation fails
34
 
35
+ # Define the Gradio interface
36
  with gr.Blocks() as demo:
37
+ gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with YOLO</h1>")
38
+ gr.Markdown("Upload an image to see segmentation results using the YOLO model.")
39
 
40
  # Input and output layout
41
  with gr.Row():
 
50
  output_image = gr.Image(type="numpy", label="Segmented Image")
51
 
52
  # Set up button functionality
53
+ submit_btn.click(fn=apply_yolo_segmentation, inputs=input_image, outputs=output_image)
54
+ clear_btn.click(fn=lambda: None, outputs=output_image)
55
 
56
  # Launch the Gradio app
57
  demo.launch()