Add image processing functions and plot image grid
Browse files- src/utils.py +68 -1
 
    	
        src/utils.py
    CHANGED
    
    | 
         @@ -1,4 +1,8 @@ 
     | 
|
| 1 | 
         
             
            import os
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 2 | 
         
             
            import numpy as np
         
     | 
| 3 | 
         
             
            from PIL import Image, ImageOps
         
     | 
| 4 | 
         | 
| 
         @@ -76,4 +80,67 @@ def track_files(folder_path, extensions=('.jpg', '.jpeg', '.png')): 
     | 
|
| 76 | 
         
             
                        if extension.lower() in extensions:
         
     | 
| 77 | 
         
             
                            file_list.append(file_path)
         
     | 
| 78 | 
         | 
| 79 | 
         
            -
                return file_list
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import cv2
         
     | 
| 4 | 
         
            +
            import matplotlib.image as mpimg
         
     | 
| 5 | 
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 6 | 
         
             
            import numpy as np
         
     | 
| 7 | 
         
             
            from PIL import Image, ImageOps
         
     | 
| 8 | 
         | 
| 
         | 
|
| 80 | 
         
             
                        if extension.lower() in extensions:
         
     | 
| 81 | 
         
             
                            file_list.append(file_path)
         
     | 
| 82 | 
         | 
| 83 | 
         
            +
                return file_list
         
     | 
| 84 | 
         
            +
             
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
            def crop_circle_roi(image_path):
         
     | 
| 88 | 
         
            +
                """
         
     | 
| 89 | 
         
            +
                Crop the circular Region of Interest (ROI) from a fundus image.
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                Args:
         
     | 
| 92 | 
         
            +
                - image_path (str): Path to the fundus image.
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                Returns:
         
     | 
| 95 | 
         
            +
                - cropped_roi (numpy.ndarray): The cropped circular Region of Interest.
         
     | 
| 96 | 
         
            +
                """
         
     | 
| 97 | 
         
            +
                # Read the image
         
     | 
| 98 | 
         
            +
                image = cv2.imread(image_path, cv2.IMREAD_COLOR)
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
                # Convert the image to grayscale
         
     | 
| 101 | 
         
            +
                gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                # Apply thresholding to binarize the image
         
     | 
| 104 | 
         
            +
                _, thresholded_image = cv2.threshold(gray_image, 50, 255, cv2.THRESH_BINARY)
         
     | 
| 105 | 
         
            +
             
     | 
| 106 | 
         
            +
                # Find contours in the binary image
         
     | 
| 107 | 
         
            +
                contours, _ = cv2.findContours(thresholded_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
                # Assuming the largest contour corresponds to the ROI
         
     | 
| 110 | 
         
            +
                contour = max(contours, key=cv2.contourArea)
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                # Get the bounding rectangle of the contour
         
     | 
| 113 | 
         
            +
                x, y, w, h = cv2.boundingRect(contour)
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
                # Crop the circular ROI using the bounding rectangle
         
     | 
| 116 | 
         
            +
                cropped_roi = image[y:y+h, x:x+w]
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                return cropped_roi
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
            def plot_image_grid(image_paths, roi_crop=False):
         
     | 
| 121 | 
         
            +
                """
         
     | 
| 122 | 
         
            +
                Create a grid plot with a maximum of 16 images.
         
     | 
| 123 | 
         
            +
             
     | 
| 124 | 
         
            +
                Args:
         
     | 
| 125 | 
         
            +
                - image_paths (list): A list of image paths to be plotted.
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                Returns:
         
     | 
| 128 | 
         
            +
                - None
         
     | 
| 129 | 
         
            +
                """
         
     | 
| 130 | 
         
            +
                num_images = min(len(image_paths), 16)
         
     | 
| 131 | 
         
            +
                num_rows = (num_images - 1) // 4 + 1
         
     | 
| 132 | 
         
            +
                fig, axes = plt.subplots(num_rows, 4, figsize=(12, 3 * num_rows))
         
     | 
| 133 | 
         
            +
             
     | 
| 134 | 
         
            +
                for i, ax in enumerate(axes.flat):
         
     | 
| 135 | 
         
            +
                    if i < num_images:
         
     | 
| 136 | 
         
            +
                        if roi_crop:
         
     | 
| 137 | 
         
            +
                            img = crop_and_pad_image(image_paths[i])
         
     | 
| 138 | 
         
            +
                        else:
         
     | 
| 139 | 
         
            +
                            img = mpimg.imread(image_paths[i])
         
     | 
| 140 | 
         
            +
                        ax.imshow(img)
         
     | 
| 141 | 
         
            +
                        ax.axis('off')
         
     | 
| 142 | 
         
            +
                    else:
         
     | 
| 143 | 
         
            +
                        ax.axis('off')
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                plt.tight_layout()
         
     | 
| 146 | 
         
            +
                plt.show()
         
     |