import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
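# image_mask_gen is a local helper module (mask visualization, compositing, resizing)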
import image_mask_gen
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import os
import io
import warnings
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation


# Function to display points on the image using matplotlib
def show_points(coords, labels, ax, marker_size=375):
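    # label 1 = inclusive/foreground point (green stars); label 0 = exclusive/background (red stars)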
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def remove_duplicates(coords, labels):
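    """Remove duplicate coordinates, keeping first-seen order and the matching labels."""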
    unique_coords = []
    unique_labels = []
    seen = set()

    for coord, label in zip(coords, labels):
        coord_tuple = tuple(coord)
        if coord_tuple not in seen:
            seen.add(coord_tuple)
            unique_coords.append(coord)
            unique_labels.append(label)
            
    return unique_coords, unique_labels


def image_augmentation_page():
    st.title("Image Augmentation")
    st.write("Upload an image to apply augmentation techniques.")

    # Initialize session state variables
    if "inclusive_points" not in st.session_state:
        st.session_state.inclusive_points = []
    if "exclusive_points" not in st.session_state:
        st.session_state.exclusive_points = []
    
    # Upload an image
    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        # Open the uploaded image
        image = Image.open(uploaded_file)

        # Set the maximum width for display
        max_display_width = 700  # You can adjust this value

        # Calculate the scaling factor
        scale_factor = min(max_display_width / image.size[0], 1)

        # Resize the image for display
        display_width = int(image.size[0] * scale_factor)
        display_height = int(image.size[1] * scale_factor)
        resized_image = image.resize((display_width, display_height))

        # Inclusive Points Phase
        st.subheader("Select Inclusive Points (Green)")
        canvas_inclusive = st_canvas(
            fill_color="rgba(0, 0, 0, 0)",  # Transparent fill
            stroke_width=1,                # Stroke width for drawing
            stroke_color="blue",           # Color for the outline of clicks
            background_image=resized_image,
            update_streamlit=True,
            height=display_height,
            width=display_width,
            drawing_mode="circle",         # Drawing mode to capture clicks as circles
            point_display_radius=3,        # Radius of the circle that represents a click
            key="canvas_inclusive"
        )

        # Process inclusive clicks. The canvas reports every drawn object on each rerun,
        # so assign (rather than extend) to avoid accumulating duplicate points.
        # Each click is a fabric.js circle whose centre is (left + radius, top + radius),
        # rescaled back into original-image coordinates.
        if canvas_inclusive.json_data is not None:
            objects = canvas_inclusive.json_data["objects"]
            st.session_state.inclusive_points = [
                [(obj["left"] + obj["radius"]) / scale_factor, (obj["top"] + obj["radius"]) / scale_factor]
                for obj in objects
            ]

        # Plot the inclusive points on the original image using Matplotlib
        fig_inclusive, ax = plt.subplots()
        ax.imshow(image)
        ax.axis('off')  # Hide the axes

        # Prepare data for plotting
        inclusive_points = np.array(st.session_state.inclusive_points)
        labels_inclusive = np.array([1] * len(st.session_state.inclusive_points))

        # Call the function to show inclusive points
        if len(inclusive_points) > 0:
            show_points(inclusive_points, labels_inclusive, ax)

        st.pyplot(fig_inclusive)

        # Divider
        st.divider()

        # Exclusive Points Phase
        st.subheader("Select Exclusive Points (Red)")
        canvas_exclusive = st_canvas(
            fill_color="rgba(0, 0, 0, 0)",  # Transparent fill
            stroke_width=1,                # Stroke width for drawing
            stroke_color="blue",           # Color for the outline of clicks
            background_image=resized_image,
            update_streamlit=True,
            height=display_height,
            width=display_width,
            drawing_mode="circle",         # Drawing mode to capture clicks as circles
            point_display_radius=3,        # Radius of the circle that represents a click
            key="canvas_exclusive"
        )

        # Process exclusive clicks (same duplicate-safe assignment as above)
        if canvas_exclusive.json_data is not None:
            objects = canvas_exclusive.json_data["objects"]
            st.session_state.exclusive_points = [
                [(obj["left"] + obj["radius"]) / scale_factor, (obj["top"] + obj["radius"]) / scale_factor]
                for obj in objects
            ]

        # Plot the exclusive points on the original image using Matplotlib
        fig_exclusive, ax = plt.subplots()
        ax.imshow(image)
        ax.axis('off')  # Hide the axes

        # Prepare data for plotting
        exclusive_points = np.array(st.session_state.exclusive_points)
        labels_exclusive = np.array([0] * len(st.session_state.exclusive_points))

        # Call the function to show exclusive points
        if len(exclusive_points) > 0:
            show_points(exclusive_points, labels_exclusive, ax)

        st.pyplot(fig_exclusive)

        # Grouping coordinates and labels
        coordinates = st.session_state.inclusive_points + st.session_state.exclusive_points
        labels = [1] * len(st.session_state.inclusive_points) + [0] * len(st.session_state.exclusive_points)

        # # Display grouped coordinates and labels
        # st.subheader("Coordinates and Labels")
        # st.write("Coordinates: ", tuple(coordinates))
        # st.write("Labels: ", labels)

        # Provide an option to clear the coordinates
        if st.button("Clear All Points"):
            st.session_state.inclusive_points = []
            st.session_state.exclusive_points = []
        unique_coordinates, unique_labels = remove_duplicates(coordinates, labels)

        st.write("Unique Coordinates:", tuple(unique_coordinates))
        st.write("Unique Labels:", tuple(unique_labels))

        # Run SAM 2 only when at least one prompt point is available
        if len(unique_coordinates) == 0:
            st.info("Add at least one point before running segmentation.")
            return

        sam2_checkpoint = "sam2_hiera_base_plus.pt"
        model_cfg = "sam2_hiera_b+.yaml"

        # Note: building the model on every rerun is slow; wrapping this in a
        # cached helper (e.g. st.cache_resource) would avoid repeated loads.
        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")

        predictor = SAM2ImagePredictor(sam2_model)

        predictor.set_image(image)

        input_point = np.array(unique_coordinates)
        input_label = np.array(unique_labels)

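        # First pass: request multiple candidate masks and sort them by predicted quality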
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        sorted_ind = np.argsort(scores)[::-1]
        masks = masks[sorted_ind]
        scores = scores[sorted_ind]
        logits = logits[sorted_ind]

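        # Second pass: feed the best candidate's low-resolution logits back as a mask
        # prior and request a single refined mask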
        mask_input = logits[np.argmax(scores), :, :]

        masks, scores, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )
        image_mask_gen.show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)

        
        # Get masked images
        original_image = Image.open(uploaded_file)
        # st.image(original_image, caption='Original Image', use_column_width=True)

        # Display masked and inverse-masked images side by side
        with st.container(border=True):
            col1, col2 = st.columns(2)
            with col1:
                mask_images = image_mask_gen.show_masks_1(original_image, masks, scores)
                for idx, (img, score) in enumerate(mask_images):
                    st.image(img, caption=f'Mask {idx+1}, Score: {score:.3f}', use_column_width=True)
            with col2:
                inverse_mask_images = image_mask_gen.show_inverse_masks(original_image, masks, scores)
                for idx, (img, score) in enumerate(inverse_mask_images):
                    st.image(img, caption=f'Inverse Mask {idx+1}, Score: {score:.3f}', use_column_width=True)
        
        if st.checkbox("Proceed to Image Augmentation"):
        
            image_aug_select = st.sidebar.selectbox(
                "Select Augmentation for Mask",
                ["Pixelate", "Hue Change", "Mask Replacement", "Generative Img2Img"],
            )
            if image_aug_select == "Pixelate":
                
                if st.sidebar.toggle("Proceed to Pixelate Mask"):
                    pixelation_level = st.slider("Select Pixelation Level", min_value=5, max_value=50, value=10)
                    combined_image = image_mask_gen.combine_pixelated_mask(original_image, masks[0], pixelation_level)
                    st.image(combined_image, caption="Combined Pixelated Image", use_column_width=True)
            elif image_aug_select == "Hue Change":

                if st.sidebar.toggle("Proceed to Hue Change"):
                    # Hue shift slider
                    hue_shift = st.slider("Select Hue Shift", min_value=-180, max_value=180, value=0)
                    # Apply hue change and show the result
                    combined_image = image_mask_gen.combine_hue_changed_mask(original_image, masks[0], hue_shift)  # Assuming single mask
                    st.image(combined_image, caption="Combined Hue Changed Image", use_column_width=True)
            elif image_aug_select == "Mask Replacement":

                if st.sidebar.toggle("Proceed to replace Mask"):
                    replacement_file = st.file_uploader("Upload the replacement image", type=["png", "jpg", "jpeg"])
                    if replacement_file is not None:
                        replacement_image = Image.open(replacement_file) #.convert("RGBA")
                        combined_image = image_mask_gen.combine_mask_replaced_image(original_image, replacement_image, masks[0])  # Assuming single mask
                        st.image(combined_image, caption="Masked Area Replaced Image", use_column_width=True)
            elif image_aug_select == "Generative Img2Img":
        
                # The refinement pass used multimask_output=False, so there is a single
                # masked image; keep it for the img2img step
                msk_img = None
                mask_images_x = image_mask_gen.show_masks_1(original_image, masks, scores)
                for img, score in mask_images_x:
                    msk_img = img

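                # Convert to RGB and resize before sending to the API
                # (image_mask_gen.resize_image is assumed to snap to an SDXL-supported resolution)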
                rgb_image = msk_img.convert("RGB")
                # st.image(rgb_image)
                resized_image = image_mask_gen.resize_image(rgb_image)
                # st.image(resized_image, caption=f"Resized size: {resized_image.size[0]}x{resized_image.size[1]}", use_column_width=True)
                width, height = resized_image.size
                
                # User input for the prompt and API key
                prompt = st.text_input("Enter your prompt:", "A Beautiful day, in the style reference of starry night by vincent van gogh")
                api_key = st.text_input("Enter your Stability AI API key:")

                if prompt and api_key:
                    # Set up our connection to the API.
                    os.environ['STABILITY_KEY'] = api_key
                    stability_api = client.StabilityInference(
                        key=os.environ['STABILITY_KEY'], # API Key reference.
                        verbose=True, # Print debug messages.
                        engine="stable-diffusion-xl-1024-v1-0", # Set the engine to use for generation.
                    )
                    style_preset_selector = st.sidebar.selectbox(
                        "Select Style Preset",
                        ["3d-model", "analog-film", "anime", "cinematic", "comic-book",
                         "digital-art", "enhance", "fantasy-art", "isometric", "line-art",
                         "low-poly", "modeling-compound", "neon-punk", "origami",
                         "photographic", "pixel-art", "tile-texture"],
                        index=5,  # default: "digital-art"
                    )
                    if st.sidebar.toggle("Proceed to Generate Image"):
                        # Set up our initial generation parameters.
                        answers2 = stability_api.generate(
                            prompt=prompt,
                            init_image=resized_image, # Assign our uploaded image as our Initial Image for transformation.
                            start_schedule=0.6,
                            steps=250,
                            cfg_scale=10.0,
                            width=width,
                            height=height,
                            sampler=generation.SAMPLER_K_DPMPP_SDE,
                            style_preset=style_preset_selector
                        )

                        # Process the response from the API
                        for resp in answers2:
                            for artifact in resp.artifacts:
                                if artifact.finish_reason == generation.FILTER:
                                    warnings.warn(
                                        "Your request activated the API's safety filters and could not be processed. "
                                        "Please modify the prompt and try again.")
                                if artifact.type == generation.ARTIFACT_IMAGE:
                                    img2 = Image.open(io.BytesIO(artifact.binary))
                                    # Display the generated image
                                    st.image(img2, caption="Generated Image", use_column_width=True)

                                    # Combine the generated image with the original image using the mask
                                    combined_img = image_mask_gen.combine_mask_and_inverse_gen(original_image, img2, masks[0])
                                    st.image(combined_img, caption="Combined Image", use_column_width=True)
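

# Allow running this page on its own with `streamlit run`; in the full app,
# image_augmentation_page() is presumably invoked from a page router elsewhere.
if __name__ == "__main__":
    image_augmentation_page()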