Sutirtha commited on
Commit
d16663b
·
verified ·
1 Parent(s): 861b0a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -73
app.py CHANGED
@@ -10,58 +10,62 @@ import torch
10
  # Load the LangSAM model
11
  model = LangSAM() # Use the default model or specify custom checkpoint if necessary
12
 
13
- def extract_mask(image_pil, text_prompt):
14
- masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
15
- masks_np = masks[0].cpu().numpy()
16
- mask = (masks_np > 0).astype(np.uint8) * 255 # Binary mask
17
- return mask
 
 
 
 
 
18
 
19
  def apply_color_matching(source_img_np, ref_img_np):
20
  # Initialize ColorMatcher
21
  cm = ColorMatcher()
22
-
23
  # Apply color matching
24
  img_res = cm.transfer(src=source_img_np, ref=ref_img_np, method='mkl')
25
-
26
  # Normalize the result
27
  img_res = Normalizer(img_res).uint8_norm()
28
-
29
  return img_res
30
 
31
- def process_image(current_image_pil, prompt, replacement_image_pil, color_ref_image_pil, apply_replacement, apply_color_grading, blending_amount, image_history):
32
  # Check if current_image_pil is None
33
  if current_image_pil is None:
34
  return None, "No current image to edit.", image_history, None
35
-
36
  if not apply_replacement and not apply_color_grading:
37
  return current_image_pil, "No changes applied. Please select at least one operation.", image_history, current_image_pil
38
-
39
  if apply_replacement and replacement_image_pil is None:
40
  return current_image_pil, "Replacement image not provided.", image_history, current_image_pil
41
 
42
  if apply_color_grading and color_ref_image_pil is None:
43
  return current_image_pil, "Color reference image not provided.", image_history, current_image_pil
44
-
 
 
 
 
 
 
45
  # Save current image to history for undo
46
  if image_history is None:
47
  image_history = []
48
  image_history.append(current_image_pil.copy())
49
-
50
- # Extract mask
51
- mask = extract_mask(current_image_pil, prompt)
52
-
53
- # Check if mask is valid
54
- if mask.sum() == 0:
55
- return current_image_pil, f"No mask detected for prompt: {prompt}", image_history, current_image_pil
56
-
57
  # Proceed with replacement or color matching
58
  current_image_np = np.array(current_image_pil)
59
  result_image_np = current_image_np.copy()
60
-
61
  # Create mask with blending
62
  # First, normalize mask to range [0,1]
63
  mask_normalized = mask.astype(np.float32) / 255.0
64
-
65
  # Apply blending by blurring the mask
66
  if blending_amount > 0:
67
  # The kernel size for blurring; larger blending_amount means more blur
@@ -71,56 +75,44 @@ def process_image(current_image_pil, prompt, replacement_image_pil, color_ref_im
71
  mask_blurred = cv2.GaussianBlur(mask_normalized, (kernel_size, kernel_size), 0)
72
  else:
73
  mask_blurred = mask_normalized
74
-
75
  # Convert mask to 3 channels
76
  mask_blurred_3ch = cv2.merge([mask_blurred, mask_blurred, mask_blurred])
77
-
78
  # If apply replacement
79
  if apply_replacement:
80
- # Resize replacement image to fit the mask area
81
- # Get bounding box of the mask
82
- y_indices, x_indices = np.where(mask > 0)
83
- if y_indices.size == 0 or x_indices.size == 0:
84
- # No mask detected
85
- return current_image_pil, f"No mask detected for prompt: {prompt}", image_history, current_image_pil
86
- y_min, y_max = y_indices.min(), y_indices.max()
87
- x_min, x_max = x_indices.min(), x_indices.max()
88
-
89
- # Extract the region of interest
90
- mask_height = y_max - y_min + 1
91
- mask_width = x_max - x_min + 1
92
-
93
- # Resize replacement image to fit mask area
94
- replacement_image_resized = replacement_image_pil.resize((mask_width, mask_height))
95
  replacement_image_np = np.array(replacement_image_resized)
96
-
97
- # Create a mask for the ROI
98
- mask_roi = mask_blurred[y_min:y_max+1, x_min:x_max+1]
99
- mask_roi_3ch = cv2.merge([mask_roi, mask_roi, mask_roi])
100
-
101
- # Replace the masked area with the replacement image using blending
102
- region_to_replace = result_image_np[y_min:y_max+1, x_min:x_max+1]
103
- blended_region = (replacement_image_np.astype(np.float32) * mask_roi_3ch + region_to_replace.astype(np.float32) * (1 - mask_roi_3ch)).astype(np.uint8)
104
- result_image_np[y_min:y_max+1, x_min:x_max+1] = blended_region
105
-
106
  # If apply color grading
107
  if apply_color_grading:
108
- # Extract the masked area
109
- masked_region = (result_image_np.astype(np.float32) * mask_blurred_3ch).astype(np.uint8)
110
  # Convert color reference image to numpy
111
  color_ref_image_np = np.array(color_ref_image_pil)
112
- # Apply color matching
113
- color_matched_region = apply_color_matching(masked_region, color_ref_image_np)
114
- # Blend the color matched region back into the result image
115
- result_image_np = (color_matched_region.astype(np.float32) * mask_blurred_3ch + result_image_np.astype(np.float32) * (1 - mask_blurred_3ch)).astype(np.uint8)
116
-
 
 
 
 
 
 
 
 
 
117
  # Convert result back to PIL Image
118
  result_image_pil = Image.fromarray(result_image_np)
119
-
120
  # Update current_image_pil
121
  current_image_pil = result_image_pil
122
-
123
- return current_image_pil, f"Applied changes for prompt: {prompt}", image_history, current_image_pil
124
 
125
  def undo(image_history):
126
  if image_history and len(image_history) > 1:
@@ -141,46 +133,62 @@ def gradio_interface():
141
  # Define the state variables
142
  image_history = gr.State([])
143
  current_image_pil = gr.State(None)
144
-
 
145
  gr.Markdown("## Continuous Image Editing with LangSAM")
146
-
147
  with gr.Row():
148
  with gr.Column():
149
  initial_image = gr.Image(type="pil", label="Upload Image")
150
- prompt = gr.Textbox(lines=1, placeholder="Enter prompt for object detection", label="Prompt")
 
 
151
  replacement_image = gr.Image(type="pil", label="Replacement Image (optional)")
152
  color_ref_image = gr.Image(type="pil", label="Color Reference Image (optional)")
153
  apply_replacement = gr.Checkbox(label="Apply Replacement", value=False)
154
  apply_color_grading = gr.Checkbox(label="Apply Color Grading", value=False)
 
155
  blending_amount = gr.Slider(minimum=0, maximum=50, step=1, label="Blending Amount", value=0)
156
  apply_button = gr.Button("Apply Changes")
157
  undo_button = gr.Button("Undo")
158
  with gr.Column():
159
  current_image_display = gr.Image(type="pil", label="Edited Image", interactive=False)
160
  status = gr.Textbox(lines=2, interactive=False, label="Status")
161
-
162
  def initialize_image(initial_image_pil):
163
  # Initialize image history with the initial image
164
  if initial_image_pil is not None:
165
  image_history = [initial_image_pil]
166
  current_image_pil = initial_image_pil
167
- return current_image_pil, image_history, initial_image_pil
168
  else:
169
- return None, [], None
170
-
171
  # When the initial image is uploaded, initialize the image history
172
- initial_image.upload(fn=initialize_image, inputs=initial_image, outputs=[current_image_pil, image_history, current_image_display])
173
-
 
 
 
 
 
 
 
 
 
 
 
 
174
  # Apply button click
175
- apply_button.click(fn=process_image,
176
- inputs=[current_image_pil, prompt, replacement_image, color_ref_image, apply_replacement, apply_color_grading, blending_amount, image_history],
177
  outputs=[current_image_pil, status, image_history, current_image_display])
178
-
179
  # Undo button click
180
  undo_button.click(fn=undo, inputs=image_history, outputs=[current_image_pil, image_history, current_image_display])
181
-
182
  demo.launch(share=True)
183
-
184
  # Run the Gradio Interface
185
  if __name__ == "__main__":
186
  gradio_interface()
 
10
  # Load the LangSAM model
11
  model = LangSAM() # Use the default model or specify custom checkpoint if necessary
12
 
13
+ def extract_masks(image_pil, prompts):
14
+ prompts_list = [p.strip() for p in prompts.split(',') if p.strip()]
15
+ masks_dict = {}
16
+ for prompt in prompts_list:
17
+ masks, boxes, phrases, logits = model.predict(image_pil, prompt)
18
+ if masks:
19
+ masks_np = masks[0].cpu().numpy()
20
+ mask = (masks_np > 0).astype(np.uint8) * 255 # Binary mask
21
+ masks_dict[prompt] = mask
22
+ return masks_dict
23
 
24
  def apply_color_matching(source_img_np, ref_img_np):
25
  # Initialize ColorMatcher
26
  cm = ColorMatcher()
27
+
28
  # Apply color matching
29
  img_res = cm.transfer(src=source_img_np, ref=ref_img_np, method='mkl')
30
+
31
  # Normalize the result
32
  img_res = Normalizer(img_res).uint8_norm()
33
+
34
  return img_res
35
 
36
+ def process_image(current_image_pil, selected_prompt, masks_dict, replacement_image_pil, color_ref_image_pil, apply_replacement, apply_color_grading, apply_color_to_full_image, blending_amount, image_history):
37
  # Check if current_image_pil is None
38
  if current_image_pil is None:
39
  return None, "No current image to edit.", image_history, None
40
+
41
  if not apply_replacement and not apply_color_grading:
42
  return current_image_pil, "No changes applied. Please select at least one operation.", image_history, current_image_pil
43
+
44
  if apply_replacement and replacement_image_pil is None:
45
  return current_image_pil, "Replacement image not provided.", image_history, current_image_pil
46
 
47
  if apply_color_grading and color_ref_image_pil is None:
48
  return current_image_pil, "Color reference image not provided.", image_history, current_image_pil
49
+
50
+ # Get the mask from masks_dict
51
+ if selected_prompt not in masks_dict:
52
+ return current_image_pil, f"No mask available for selected segment: {selected_prompt}", image_history, current_image_pil
53
+
54
+ mask = masks_dict[selected_prompt]
55
+
56
  # Save current image to history for undo
57
  if image_history is None:
58
  image_history = []
59
  image_history.append(current_image_pil.copy())
60
+
 
 
 
 
 
 
 
61
  # Proceed with replacement or color matching
62
  current_image_np = np.array(current_image_pil)
63
  result_image_np = current_image_np.copy()
64
+
65
  # Create mask with blending
66
  # First, normalize mask to range [0,1]
67
  mask_normalized = mask.astype(np.float32) / 255.0
68
+
69
  # Apply blending by blurring the mask
70
  if blending_amount > 0:
71
  # The kernel size for blurring; larger blending_amount means more blur
 
75
  mask_blurred = cv2.GaussianBlur(mask_normalized, (kernel_size, kernel_size), 0)
76
  else:
77
  mask_blurred = mask_normalized
78
+
79
  # Convert mask to 3 channels
80
  mask_blurred_3ch = cv2.merge([mask_blurred, mask_blurred, mask_blurred])
81
+
82
  # If apply replacement
83
  if apply_replacement:
84
+ # Resize replacement image to match current image
85
+ replacement_image_resized = replacement_image_pil.resize(current_image_pil.size)
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  replacement_image_np = np.array(replacement_image_resized)
87
+
88
+ # Blend the replacement image with the current image using the mask
89
+ result_image_np = (replacement_image_np.astype(np.float32) * mask_blurred_3ch + result_image_np.astype(np.float32) * (1 - mask_blurred_3ch)).astype(np.uint8)
90
+
 
 
 
 
 
 
91
  # If apply color grading
92
  if apply_color_grading:
 
 
93
  # Convert color reference image to numpy
94
  color_ref_image_np = np.array(color_ref_image_pil)
95
+
96
+ if apply_color_to_full_image:
97
+ # Apply color matching to the full image
98
+ color_matched_image = apply_color_matching(result_image_np, color_ref_image_np)
99
+ result_image_np = color_matched_image
100
+ else:
101
+ # Apply color matching only to the masked area
102
+ # Extract the masked area
103
+ masked_region = (result_image_np.astype(np.float32) * mask_blurred_3ch).astype(np.uint8)
104
+ # Apply color matching
105
+ color_matched_region = apply_color_matching(masked_region, color_ref_image_np)
106
+ # Blend the color matched region back into the result image
107
+ result_image_np = (color_matched_region.astype(np.float32) * mask_blurred_3ch + result_image_np.astype(np.float32) * (1 - mask_blurred_3ch)).astype(np.uint8)
108
+
109
  # Convert result back to PIL Image
110
  result_image_pil = Image.fromarray(result_image_np)
111
+
112
  # Update current_image_pil
113
  current_image_pil = result_image_pil
114
+
115
+ return current_image_pil, f"Applied changes to '{selected_prompt}'", image_history, current_image_pil
116
 
117
  def undo(image_history):
118
  if image_history and len(image_history) > 1:
 
133
  # Define the state variables
134
  image_history = gr.State([])
135
  current_image_pil = gr.State(None)
136
+ masks_dict = gr.State({}) # Store masks for each prompt
137
+
138
  gr.Markdown("## Continuous Image Editing with LangSAM")
139
+
140
  with gr.Row():
141
  with gr.Column():
142
  initial_image = gr.Image(type="pil", label="Upload Image")
143
+ prompts = gr.Textbox(lines=1, placeholder="Enter prompts separated by commas (e.g., sky, grass)", label="Prompts")
144
+ segment_button = gr.Button("Segment Image")
145
+ segment_dropdown = gr.Dropdown(label="Select Segment", choices=[])
146
  replacement_image = gr.Image(type="pil", label="Replacement Image (optional)")
147
  color_ref_image = gr.Image(type="pil", label="Color Reference Image (optional)")
148
  apply_replacement = gr.Checkbox(label="Apply Replacement", value=False)
149
  apply_color_grading = gr.Checkbox(label="Apply Color Grading", value=False)
150
+ apply_color_to_full_image = gr.Checkbox(label="Apply Color Correction to Full Image", value=False)
151
  blending_amount = gr.Slider(minimum=0, maximum=50, step=1, label="Blending Amount", value=0)
152
  apply_button = gr.Button("Apply Changes")
153
  undo_button = gr.Button("Undo")
154
  with gr.Column():
155
  current_image_display = gr.Image(type="pil", label="Edited Image", interactive=False)
156
  status = gr.Textbox(lines=2, interactive=False, label="Status")
157
+
158
  def initialize_image(initial_image_pil):
159
  # Initialize image history with the initial image
160
  if initial_image_pil is not None:
161
  image_history = [initial_image_pil]
162
  current_image_pil = initial_image_pil
163
+ return current_image_pil, image_history, initial_image_pil, {}, [], "Image loaded."
164
  else:
165
+ return None, [], None, {}, [], "No image loaded."
166
+
167
  # When the initial image is uploaded, initialize the image history
168
+ initial_image.upload(fn=initialize_image, inputs=initial_image, outputs=[current_image_pil, image_history, current_image_display, masks_dict, segment_dropdown, status])
169
+
170
+ # Segment button click
171
+ def segment_image_wrapper(current_image_pil, prompts):
172
+ if current_image_pil is None:
173
+ return "No image uploaded.", {}, []
174
+ masks = extract_masks(current_image_pil, prompts)
175
+ if not masks:
176
+ return "No masks detected for the given prompts.", {}, []
177
+ dropdown_choices = list(masks.keys())
178
+ return "Segmentation completed.", masks, gr.Dropdown.update(choices=dropdown_choices, value=dropdown_choices[0])
179
+
180
+ segment_button.click(fn=segment_image_wrapper, inputs=[current_image_pil, prompts], outputs=[status, masks_dict, segment_dropdown])
181
+
182
  # Apply button click
183
+ apply_button.click(fn=process_image,
184
+ inputs=[current_image_pil, segment_dropdown, masks_dict, replacement_image, color_ref_image, apply_replacement, apply_color_grading, apply_color_to_full_image, blending_amount, image_history],
185
  outputs=[current_image_pil, status, image_history, current_image_display])
186
+
187
  # Undo button click
188
  undo_button.click(fn=undo, inputs=image_history, outputs=[current_image_pil, image_history, current_image_display])
189
+
190
  demo.launch(share=True)
191
+
192
  # Run the Gradio Interface
193
  if __name__ == "__main__":
194
  gradio_interface()