File size: 14,572 Bytes
53bf77d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 |
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import image_mask_gen
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import os
import io
import warnings
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
import streamlit as st
import base64
# Overlay SAM2 prompt points on a matplotlib axis: green stars for
# inclusive points (label 1), red stars for exclusive points (label 0).
def show_points(coords, labels, ax, marker_size=375):
    """Draw positive (green) and negative (red) star markers on *ax*.

    coords is an (N, 2) array of x/y positions and labels an (N,) array of
    1/0 flags selecting the marker colour.
    """
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    for points, colour in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(points[:, 0], points[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)
def remove_duplicates(coords, labels):
    """Drop repeated coordinates, keeping the first occurrence and its label.

    Returns a pair of lists (unique coordinates, matching labels) in the
    original order.
    """
    seen = set()
    kept_coords, kept_labels = [], []
    for point, tag in zip(coords, labels):
        key = tuple(point)  # lists aren't hashable; use a tuple for the set
        if key in seen:
            continue
        seen.add(key)
        kept_coords.append(point)
        kept_labels.append(tag)
    return kept_coords, kept_labels
def image_augmentation_page():
    """Render the Streamlit "Image Augmentation" page.

    Flow: the user uploads an image, clicks inclusive (foreground) and
    exclusive (background) prompt points on two drawable canvases, SAM2
    predicts a segmentation mask from those points, and the masked region
    can then be pixelated, hue-shifted, replaced with another image, or
    re-generated through the Stability AI img2img API.
    """
    st.title("Image Augmentation")
    st.write("Upload an image to apply augmentation techniques.")

    # Point lists must survive Streamlit reruns, so keep them in session state.
    if "inclusive_points" not in st.session_state:
        st.session_state.inclusive_points = []
    if "exclusive_points" not in st.session_state:
        st.session_state.exclusive_points = []

    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)

        # The canvases show a scaled-down copy; scale_factor maps display
        # coordinates back to original-image coordinates.
        max_display_width = 700
        scale_factor = min(max_display_width / image.size[0], 1)
        display_width = int(image.size[0] * scale_factor)
        display_height = int(image.size[1] * scale_factor)
        resized_image = image.resize((display_width, display_height))

        # ---- Inclusive (label 1) points ----
        st.subheader("Select Inclusive Points (Green)")
        canvas_inclusive = st_canvas(
            fill_color="rgba(0, 0, 0, 0)",  # transparent fill
            stroke_width=1,
            stroke_color="blue",
            background_image=resized_image,
            update_streamlit=True,
            height=display_height,
            width=display_width,
            drawing_mode="circle",  # each click is recorded as a circle object
            point_display_radius=3,
            key="canvas_inclusive"
        )
        if canvas_inclusive.json_data is not None:
            objects = canvas_inclusive.json_data["objects"]
            # A circle's (left, top) is its bounding-box corner; adding the
            # radius yields the click centre, then undo the display scaling.
            new_clicks = [[(obj["left"] + obj["radius"]) / scale_factor,
                           (obj["top"] + obj["radius"]) / scale_factor]
                          for obj in objects]
            st.session_state.inclusive_points.extend(new_clicks)

        # Preview the accumulated inclusive points on the full-size image.
        fig_inclusive, ax = plt.subplots()
        ax.imshow(image)
        ax.axis('off')
        inclusive_points = np.array(st.session_state.inclusive_points)
        labels_inclusive = np.array([1] * len(st.session_state.inclusive_points))
        if len(inclusive_points) > 0:
            show_points(inclusive_points, labels_inclusive, ax)
        st.pyplot(fig_inclusive)

        st.divider()

        # ---- Exclusive (label 0) points ----
        st.subheader("Select Exclusive Points (Red)")
        canvas_exclusive = st_canvas(
            fill_color="rgba(0, 0, 0, 0)",  # transparent fill
            stroke_width=1,
            stroke_color="blue",
            background_image=resized_image,
            update_streamlit=True,
            height=display_height,
            width=display_width,
            drawing_mode="circle",
            point_display_radius=3,
            key="canvas_exclusive"
        )
        if canvas_exclusive.json_data is not None:
            objects = canvas_exclusive.json_data["objects"]
            new_clicks = [[(obj["left"] + obj["radius"]) / scale_factor,
                           (obj["top"] + obj["radius"]) / scale_factor]
                          for obj in objects]
            st.session_state.exclusive_points.extend(new_clicks)

        # Preview the accumulated exclusive points on the full-size image.
        fig_exclusive, ax = plt.subplots()
        ax.imshow(image)
        ax.axis('off')
        exclusive_points = np.array(st.session_state.exclusive_points)
        labels_exclusive = np.array([0] * len(st.session_state.exclusive_points))
        if len(exclusive_points) > 0:
            show_points(exclusive_points, labels_exclusive, ax)
        st.pyplot(fig_exclusive)

        # SAM2 prompt: all points, labelled 1 (include) / 0 (exclude).
        coordinates = st.session_state.inclusive_points + st.session_state.exclusive_points
        labels = [1] * len(st.session_state.inclusive_points) + [0] * len(st.session_state.exclusive_points)

        # Clearing takes effect on the next rerun; this pass still uses the
        # coordinates gathered above.
        if st.button("Clear All Points"):
            st.session_state.inclusive_points = []
            st.session_state.exclusive_points = []

        unique_coordinates, unique_labels = remove_duplicates(coordinates, labels)
        st.write("Unique Coordinates:", tuple(unique_coordinates))
        st.write("Unique Labels:", tuple(unique_labels))

        # NOTE(review): the SAM2 model is rebuilt on every Streamlit rerun;
        # caching it (e.g. st.cache_resource) would speed this page up.
        # NOTE(review): predict() is called even with zero points — confirm
        # whether an empty prompt should be guarded against.
        sam2_checkpoint = "sam2_hiera_base_plus.pt"
        model_cfg = "sam2_hiera_b+.yaml"
        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
        predictor = SAM2ImagePredictor(sam2_model)
        predictor.set_image(image)

        input_point = np.array(unique_coordinates)
        input_label = np.array(unique_labels)
        # First pass: ask for several candidate masks.
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        # Order the candidates best-first.
        sorted_ind = np.argsort(scores)[::-1]
        masks = masks[sorted_ind]
        scores = scores[sorted_ind]
        logits = logits[sorted_ind]

        # Second pass: refine using the best candidate's logits as a mask prompt.
        mask_input = logits[np.argmax(scores), :, :]
        masks, scores, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )
        image_mask_gen.show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)

        # Show the masked image and its inverse side by side.
        original_image = Image.open(uploaded_file)
        with st.container(border=True):
            col1, col2 = st.columns(2)
            with col1:
                mask_images = image_mask_gen.show_masks_1(original_image, masks, scores)
                for idx, (img, score) in enumerate(mask_images):
                    st.image(img, caption=f'Mask {idx+1}, Score: {score:.3f}', use_column_width=True)
            with col2:
                inverse_mask_images = image_mask_gen.show_inverse_masks(original_image, masks, scores)
                for idx, (img, score) in enumerate(inverse_mask_images):
                    st.image(img, caption=f'Inverse Mask {idx+1}, Score: {score:.3f}', use_column_width=True)

        if st.checkbox("Proceed to Image Augmentation"):
            image_aug_select = st.sidebar.selectbox("Select Augmentation for Mask", ["Pixelate", "Hue Change", "Mask Replacement", "Generative Img2Img"])
            if image_aug_select == "Pixelate":
                if st.sidebar.toggle("Proceed to Pixelate Mask"):
                    pixelation_level = st.slider("Select Pixelation Level", min_value=5, max_value=50, value=10)
                    combined_image = image_mask_gen.combine_pixelated_mask(original_image, masks[0], pixelation_level)
                    st.image(combined_image, caption="Combined Pixelated Image", use_column_width=True)
            elif image_aug_select == "Hue Change":
                if st.sidebar.toggle("Proceed to Hue Change"):
                    hue_shift = st.slider("Select Hue Shift", min_value=-180, max_value=180, value=0)
                    # Assumes a single mask (multimask_output=False above).
                    combined_image = image_mask_gen.combine_hue_changed_mask(original_image, masks[0], hue_shift)
                    st.image(combined_image, caption="Combined Hue Changed Image", use_column_width=True)
            elif image_aug_select == "Mask Replacement":
                if st.sidebar.toggle("Proceed to replace Mask"):
                    replacement_file = st.file_uploader("Upload the replacement image", type=["png", "jpg", "jpeg"])
                    if replacement_file is not None:
                        replacement_image = Image.open(replacement_file)
                        combined_image = image_mask_gen.combine_mask_replaced_image(original_image, replacement_image, masks[0])
                        st.image(combined_image, caption="Masked Area Replaced Image", use_column_width=True)
            elif image_aug_select == "Generative Img2Img":
                # Use the last rendered mask image as the img2img init image.
                msk_img = None
                mask_images_x = image_mask_gen.show_masks_1(original_image, masks, scores)
                for img, _score in mask_images_x:
                    msk_img = img
                rgb_image = msk_img.convert("RGB")
                # The Stability engine requires specific dimensions; resize first.
                resized_image = image_mask_gen.resize_image(rgb_image)
                width, height = resized_image.size

                prompt = st.text_input("Enter your prompt:", "A Beautiful day, in the style reference of starry night by vincent van gogh")
                api_key = st.text_input("Enter your Stability AI API key:")
                if prompt and api_key:
                    # Set up the connection to the Stability API.
                    os.environ['STABILITY_KEY'] = api_key
                    stability_api = client.StabilityInference(
                        key=os.environ['STABILITY_KEY'],  # API key reference
                        verbose=True,                     # print debug messages
                        engine="stable-diffusion-xl-1024-v1-0",
                    )
                    style_preset_selector = st.sidebar.selectbox(
                        "Select Style Preset",
                        ["3d-model", "analog-film", "anime", "cinematic", "comic-book", "digital-art", "enhance", "fantasy-art", "isometric", "line-art", "low-poly", "modeling-compound", "neon-punk",
                         "origami", "photographic", "pixel-art", "tile-texture"],
                        index=5)
                    if st.sidebar.toggle("Proceed to Generate Image"):
                        answers2 = stability_api.generate(
                            prompt=prompt,
                            init_image=resized_image,  # masked image drives the transformation
                            start_schedule=0.6,        # how strongly to transform the init image
                            steps=250,
                            cfg_scale=10.0,
                            width=width,
                            height=height,
                            sampler=generation.SAMPLER_K_DPMPP_SDE,
                            style_preset=style_preset_selector
                        )
                        # Stream the API response and display each image artifact.
                        for resp in answers2:
                            for artifact in resp.artifacts:
                                if artifact.finish_reason == generation.FILTER:
                                    warnings.warn(
                                        "Your request activated the API's safety filters and could not be processed."
                                        "Please modify the prompt and try again.")
                                if artifact.type == generation.ARTIFACT_IMAGE:
                                    img2 = Image.open(io.BytesIO(artifact.binary))
                                    st.image(img2, caption="Generated Image", use_column_width=True)
                                    # Blend the generated image back into the
                                    # original through the SAM2 mask.
                                    combined_img = image_mask_gen.combine_mask_and_inverse_gen(original_image, img2, masks[0])
                                    st.image(combined_img, caption="Combined Image", use_column_width=True)