File size: 3,760 Bytes
1d3f775
 
 
 
 
fefbab6
1d3f775
 
 
da2fab4
 
 
1d3f775
 
 
fefbab6
438b834
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37e59b4
438b834
 
 
 
 
 
 
 
 
 
 
 
 
fefbab6
1d3f775
438b834
1d3f775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fefbab6
1d3f775
 
 
 
438b834
1d3f775
 
 
 
 
 
 
 
 
305e358
1d3f775
 
 
305e358
1d3f775
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from ultralytics import YOLO
import torch

# Define available YOLO models.
# NOTE: all weights are loaded eagerly at import time — each YOLO(...) call
# reads its .pt checkpoint from the working directory, so startup fails fast
# if a weights file is missing.
available_models = {
    "X-ray": YOLO("xray.pt"),
    "CT scan": YOLO("CT.pt"),
    "Ultrasound": YOLO("ultrasound.pt"),
    # Add more models as needed
}


def segment_image(input_image, selected_model="X-ray"):
    """Segment a medical image with a YOLOv8 model, then refine with SAM.

    Runs the selected YOLO segmentation model on the image, overlays the
    first instance mask, then feeds the first detected bounding box to SAM
    to produce a refined second overlay.

    Args:
        input_image: PIL image uploaded via the Gradio interface.
        selected_model: Key into ``available_models`` ("X-ray", "CT scan",
            "Ultrasound"). Defaults to "X-ray".

    Returns:
        Tuple of two PIL images: (YOLO-mask overlay, SAM-mask overlay).

    Raises:
        ValueError: if the model detects no objects in the image.
        KeyError: if ``selected_model`` is not a known model key.
    """
    # Imported locally: segment_anything is only needed for this code path.
    # The original code referenced these names without any import (NameError).
    from segment_anything import sam_model_registry, SamPredictor

    img = np.array(input_image)
    # NOTE(review): input is PIL (RGB), so this swap actually produces BGR;
    # kept as-is because the model was presumably tuned against it — confirm.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Look up the requested model instead of relying on an undefined global.
    model = available_models[selected_model]
    results = model(img)

    result = results[0]
    if result.masks is None or result.boxes is None or len(result.boxes) == 0:
        raise ValueError("No objects detected in the input image.")

    # .cpu() so this also works when the model ran on a CUDA device.
    mask = result.masks.data.cpu().numpy()
    target_height = img.shape[0]
    target_width = img.shape[1]

    # Resize the first instance mask back to the original image resolution.
    resized_mask = cv2.resize(mask[0], (target_width, target_height))
    resized_mask = (resized_mask * 255).astype(np.uint8)

    # Overlay the YOLO mask on a copy of the original image
    # (solid dark-red fill in the masked region).
    overlay_image = img.copy()
    overlay_image[resized_mask > 0] = [50, 0, 0]
    overlay_pil = Image.fromarray(overlay_image)

    # First detected box of the (single-image) result, as [x1, y1, x2, y2].
    bbox = result.boxes.xyxy[0].tolist()

    # Build the SAM predictor and prompt it with the YOLO box.
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device='cpu')
    predictor = SamPredictor(sam)
    predictor.set_image(img)

    input_box = np.array(bbox)
    masks_, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box,
        multimask_output=False)

    # uint8 rather than int64: cv2.resize does not accept int64 arrays.
    fmask = masks_[0].astype(np.uint8)
    resized_mask1 = cv2.resize(fmask, (target_width, target_height))
    resized_mask1 = (resized_mask1 * 255).astype(np.uint8)

    # Overlay the SAM mask (solid dark-yellow fill in the masked region).
    overlay_image1 = img.copy()
    overlay_image1[resized_mask1 > 0] = [50, 50, 0]
    overlay_pil1 = Image.fromarray(overlay_image1)

    return overlay_pil, overlay_pil1



# (Removed: a commented-out, earlier draft of segment_image that duplicated
# the live implementation above. Dead code is preserved in version control.)

# Create the Gradio interface with a dropdown for model selection.
# segment_image returns TWO PIL images (YOLO overlay, SAM overlay), so the
# outputs list must declare two image components — the original declared a
# single numpy output and would fail to render the tuple.
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.inputs.Image(type="pil", label="Upload an image"),
        gr.inputs.Dropdown(
            choices=list(available_models.keys()),
            label="Select YOLO Model",
            default="X-ray"
        )
    ],
    outputs=[
        # type="pil" matches the PIL.Image objects the function returns.
        gr.outputs.Image(type="pil", label="YOLO Segmentation"),
        gr.outputs.Image(type="pil", label="SAM Segmentation")
    ],
    title="YOLOv8 with SAM \U0001F603",
    description='This software generates the segmentation mask for Aorta for Point of Care Ultrasound (POCUS) images'
)

iface.launch()