# File size: 7,385 Bytes
# 4ebc565
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
from PIL import Image, ImageDraw, ImageFont
import os
import torch
import glob
import matplotlib.pyplot as plt

def read_images_in_path(path, size = (512,512)):
    """Load every PNG/JPEG file directly inside ``path`` as a resized RGB image.

    Files are selected by a case-sensitive suffix match on ".png", ".jpg" or
    ".jpeg". Only the directory itself is listed (no recursion), and the
    images are returned in the sorted order of their joined paths.

    Args:
        path (str): Directory to scan.
        size (tuple): ``(width, height)`` each image is resized to.

    Returns:
        list: One RGB ``PIL.Image`` per matching file.
    """
    extensions = (".png", ".jpg", ".jpeg")
    matches = sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.endswith(extensions)
    )

    loaded = []
    for file_path in matches:
        loaded.append(Image.open(file_path).convert("RGB").resize(size))
    return loaded

def concatenate_images(image_lists, return_list = False):
    """Tile images into a grid where each inner list forms one column.

    ``image_lists[j][i]`` is pasted at column ``j``, row ``i``. The tile size
    is taken from the very first image, and the row count from the length of
    the first inner list.

    Args:
        image_lists (list): List of columns, each a list of ``PIL.Image``.
        return_list (bool): When true, return one horizontal strip per row
            instead of a single combined grid.

    Returns:
        PIL.Image or list: The full grid, or a list of row strips when
        ``return_list`` is true.
    """
    n_rows = len(image_lists[0])
    n_cols = len(image_lists)
    first = image_lists[0][0]
    tile_w, tile_h = first.width, first.height
    total_w = n_cols * tile_w

    if return_list:
        # One canvas per row, each exactly one tile high.
        strips = [Image.new('RGB', (total_w, tile_h)) for _ in range(n_rows)]
        for r, strip in enumerate(strips):
            for c, column in enumerate(image_lists):
                strip.paste(column[r], (c * tile_w, 0))
        return strips

    # Single canvas holding every row stacked vertically.
    canvas = Image.new('RGB', (total_w, n_rows * tile_h))
    for r in range(n_rows):
        for c, column in enumerate(image_lists):
            canvas.paste(column[r], (c * tile_w, r * tile_h))
    return canvas

def concatenate_images_single(image_lists):
    """Paste a flat sequence of images side by side into one horizontal strip.

    Despite the parameter name, ``image_lists`` is indexed as a flat sequence
    of images. The tile width and height come from the first image; image
    ``j`` is pasted at x-offset ``j * width`` and y-offset 0.

    Args:
        image_lists (list): Sequence of ``PIL.Image`` objects.

    Returns:
        PIL.Image: A new RGB image as wide as all tiles together.
    """
    count = len(image_lists)
    first = image_lists[0]
    tile_w, tile_h = first.width, first.height

    strip = Image.new('RGB', (count * tile_w, tile_h))
    for index, tile in enumerate(image_lists):
        strip.paste(tile, (index * tile_w, 0))
    return strip

def get_captions_for_images(images, device):
    """Generate one caption per image with the BLIP-2 OPT-2.7b checkpoint.

    The processor and model are loaded on every call and released again
    before returning, so repeated calls pay the full load cost.

    Args:
        images: Iterable of images accepted by ``Blip2Processor``.
        device: Device the preprocessed inputs are moved to.

    Returns:
        list: The stripped decoded caption for each image, in input order.
    """
    import gc

    from transformers import Blip2Processor, Blip2ForConditionalGeneration

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # NOTE(review): device_map pins the weights to index 0 regardless of
    # `device`; the inputs below follow `device`, so callers presumably pass
    # the matching device -- confirm before changing either side.
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
    )

    inputs = generated_ids = None
    try:
        res = []
        for image in images:
            inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)

            generated_ids = model.generate(**inputs)
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            res.append(generated_text)
        return res
    finally:
        # Dropping the names alone leaves the memory in PyTorch's caching
        # allocator; collect garbage and empty the cache so the next model
        # loaded by the caller can reuse it.
        del processor, model, inputs, generated_ids
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def find_and_plot_images(directory, output_file, recursive=True, figsize=(15, 15), image_formats=("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tiff")):
    """
    Finds all images in the specified directory (optionally recursively)
    and saves them side by side in a single-row figure, each titled with
    its filename.

    Paths containing "noise.jpg" or "results.jpg" are skipped. Paths ending
    in "original.jpg" are placed first, then those ending in
    "reconstruction.jpg", then everything else in sorted order.

    Parameters:
        directory (str): Path to the directory.
        output_file (str): Path to save the resulting figure (e.g., 'output.png').
        recursive (bool): Whether to search directories recursively.
        figsize (tuple): Size of the resulting figure.
        image_formats (tuple): Glob patterns of the image files to look for.

    Returns:
        None. Prints "No images found!" and writes nothing when no image
        matches.
    """
    # Gather all image file paths
    pattern = "**/" if recursive else ""
    images = []
    for fmt in image_formats:
        images.extend(glob.glob(os.path.join(directory, pattern + fmt), recursive=recursive))

    # Filter out noise and result images
    images = [image for image in images if "noise.jpg" not in image and "results.jpg" not in image]
    # False sorts before True, so "original" comes first, then
    # "reconstruction", then the remaining paths alphabetically.
    images = sorted(
        images,
        key=lambda x: (not x.endswith("original.jpg"), not x.endswith("reconstruction.jpg"), x)
    )

    if not images:
        print("No images found!")
        return

    # One row with one column per image. squeeze=False always yields a 2-D
    # axes array, so the single-image case needs no special handling.
    fig, axs = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    try:
        for ax, image_path in zip(axs[0], images):
            # imshow copies the pixel data, so the file can be closed right away.
            with Image.open(image_path) as img:
                ax.imshow(img)
            ax.axis('off')  # Remove axes
            ax.set_title(os.path.basename(image_path), fontsize=8)  # Add filename

        fig.tight_layout()
        fig.savefig(output_file, bbox_inches='tight', dpi=300)  # Save the figure to the file
    finally:
        plt.close(fig)  # Always release the figure, even if plotting or saving failed
    print(f"Figure saved to {output_file}")


def add_label_to_image(image, label):
    """
    Adds a label to the lower-right corner of an image.

    The image is modified in place and also returned. The label is drawn in
    white on a black box; on RGB images the box is alpha-blended so the
    picture underneath stays visible.

    Args:
        image (PIL.Image): Image to add the label to.
        label (str): Text to add as a label.

    Returns:
        PIL.Image: The same image object, with the label drawn on it.
    """
    # Pillow blends RGBA fills into the picture only when the drawing mode is
    # "RGBA", and it accepts a mode different from the image's only for RGB
    # images. Every other image mode keeps the default context.
    if image.mode == "RGB":
        draw = ImageDraw.Draw(image, "RGBA")
    else:
        draw = ImageDraw.Draw(image)

    # Scale the font with the image, but never request a zero-size font
    # (images smaller than 20 px would otherwise yield size 0).
    font_size = max(1, int(min(image.size) * 0.05))
    try:
        font = ImageFont.truetype("fonts/arial.ttf", font_size)  # Replace with a font path if needed
    except IOError:
        font = ImageFont.load_default()  # Fallback to default font if arial.ttf is not found

    # Measure text size using textbbox: (left, top, right, bottom)
    left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    text_width = right - left
    text_height = bottom - top

    # Position the text in the lower-right corner with some padding
    padding = 10
    position = (image.width - text_width - padding, image.height - text_height - padding)

    # Add a semi-transparent background for the label
    draw.rectangle(
        [
            (position[0] - padding, position[1] - padding),
            (position[0] + text_width + padding, position[1] + text_height + padding)
        ],
        fill=(0, 0, 0, 150)  # Black with transparency
    )

    # The bounding box can start away from the drawing origin; shift the
    # origin by that offset so the glyphs land exactly on `position`.
    draw.text((position[0] - left, position[1] - top), label, fill="white", font=font)

    return image

def crop_center_square_and_resize(img, size, output_path=None):
    """
    Crops the largest centered square out of an image and resizes it.

    Args:
        img (PIL.Image): Image to crop.
        size (tuple): ``(width, height)`` passed to ``resize`` after cropping.
        output_path (str, optional): Path to save the result. If None (or
            otherwise falsy), nothing is saved.

    Returns:
        Image: The cropped and resized image.
    """
    full_w, full_h = img.size
    edge = min(full_w, full_h)  # the shorter side defines the square

    # Center the square along both axes.
    x0 = (full_w - edge) // 2
    y0 = (full_h - edge) // 2
    box = (x0, y0, x0 + edge, y0 + edge)

    result = img.crop(box).resize(size)

    if output_path:
        result.save(output_path)

    return result