Sutirtha committed on
Commit
6227239
·
verified ·
1 Parent(s): 2a7a5d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -9
app.py CHANGED
@@ -11,21 +11,50 @@ import warnings
11
  # Suppress specific warnings if desired
12
  warnings.filterwarnings("ignore", category=UserWarning)
13
 
14
- # Load the LangSAM model
15
- model = LangSAM() # Use the default model or specify custom checkpoint if necessary
 
 
 
 
 
 
16
 
17
  def extract_masks(image_pil, prompts):
 
 
 
 
 
 
 
 
 
 
18
  prompts_list = [p.strip() for p in prompts.split(',') if p.strip()]
19
  masks_dict = {}
20
- for prompt in prompts_list:
21
- masks, boxes, phrases, logits = model.predict(image_pil, prompt)
22
- if masks is not None and len(masks) > 0:
23
- masks_np = masks[0].cpu().numpy()
24
- mask = (masks_np > 0).astype(np.uint8) * 255 # Binary mask
25
- masks_dict[prompt] = mask
 
 
 
26
  return masks_dict
27
 
28
  def apply_color_matching(source_img_np, ref_img_np):
 
 
 
 
 
 
 
 
 
 
29
  # Initialize ColorMatcher
30
  cm = ColorMatcher()
31
 
@@ -38,6 +67,24 @@ def apply_color_matching(source_img_np, ref_img_np):
38
  return img_res
39
 
40
  def process_image(current_image_pil, selected_prompt, masks_dict, replacement_image_pil, color_ref_image_pil, apply_replacement, apply_color_grading, apply_color_to_full_image, blending_amount, image_history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Check if current_image_pil is None
42
  if current_image_pil is None:
43
  return None, "No current image to edit.", image_history, None
@@ -119,6 +166,15 @@ def process_image(current_image_pil, selected_prompt, masks_dict, replacement_im
119
  return current_image_pil, f"Applied changes to '{selected_prompt}'", image_history, current_image_pil
120
 
121
  def undo(image_history):
 
 
 
 
 
 
 
 
 
122
  if image_history and len(image_history) > 1:
123
  # Pop the last image
124
  image_history.pop()
@@ -133,6 +189,9 @@ def undo(image_history):
133
  return None, [], None
134
 
135
  def gradio_interface():
 
 
 
136
  with gr.Blocks() as demo:
137
  # Define the state variables
138
  image_history = gr.State([])
@@ -160,7 +219,15 @@ def gradio_interface():
160
  status = gr.Textbox(lines=2, interactive=False, label="Status")
161
 
162
  def initialize_image(initial_image_pil):
163
- # Initialize image history with the initial image
 
 
 
 
 
 
 
 
164
  if initial_image_pil is not None:
165
  image_history = [initial_image_pil]
166
  current_image_pil = initial_image_pil
@@ -177,6 +244,16 @@ def gradio_interface():
177
 
178
  # Segment button click
179
  def segment_image_wrapper(current_image_pil, prompts):
 
 
 
 
 
 
 
 
 
 
180
  if current_image_pil is None:
181
  return "No image uploaded.", {}, gr.update(choices=[], value=None)
182
  masks = extract_masks(current_image_pil, prompts)
 
11
  # Suppress specific warnings if desired
12
  warnings.filterwarnings("ignore", category=UserWarning)
13
 
14
+ # Device configuration: Use CUDA if available, else CPU
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ print(f"Using device: {device}")
17
+
18
+ # Load the LangSAM model and move it to the selected device
19
+ model = LangSAM()
20
+ model.to(device)
21
+ model.eval() # Set model to evaluation mode
22
 
23
  def extract_masks(image_pil, prompts):
24
+ """
25
+ Extracts masks for each prompt using the LangSAM model.
26
+
27
+ Args:
28
+ image_pil (PIL.Image): The input image.
29
+ prompts (str): Comma-separated prompts for segmentation.
30
+
31
+ Returns:
32
+ dict: A dictionary mapping each prompt to its corresponding binary mask.
33
+ """
34
  prompts_list = [p.strip() for p in prompts.split(',') if p.strip()]
35
  masks_dict = {}
36
+ with torch.no_grad(): # Disable gradient computation for inference
37
+ for prompt in prompts_list:
38
+ # Ensure the model uses the correct device
39
+ masks, boxes, phrases, logits = model.predict(image_pil, prompt)
40
+ if masks is not None and len(masks) > 0:
41
+ # Move masks to CPU and convert to numpy
42
+ masks_np = masks[0].cpu().numpy()
43
+ mask = (masks_np > 0).astype(np.uint8) * 255 # Binary mask
44
+ masks_dict[prompt] = mask
45
  return masks_dict
46
 
47
  def apply_color_matching(source_img_np, ref_img_np):
48
+ """
49
+ Applies color matching from the reference image to the source image.
50
+
51
+ Args:
52
+ source_img_np (numpy.ndarray): Source image in NumPy array format.
53
+ ref_img_np (numpy.ndarray): Reference image in NumPy array format.
54
+
55
+ Returns:
56
+ numpy.ndarray: Color-matched image.
57
+ """
58
  # Initialize ColorMatcher
59
  cm = ColorMatcher()
60
 
 
67
  return img_res
68
 
69
  def process_image(current_image_pil, selected_prompt, masks_dict, replacement_image_pil, color_ref_image_pil, apply_replacement, apply_color_grading, apply_color_to_full_image, blending_amount, image_history):
70
+ """
71
+ Processes the image by applying replacement and/or color grading based on user input.
72
+
73
+ Args:
74
+ current_image_pil (PIL.Image): The current image to be edited.
75
+ selected_prompt (str): The selected segment prompt.
76
+ masks_dict (dict): Dictionary of masks for each prompt.
77
+ replacement_image_pil (PIL.Image): Replacement image (optional).
78
+ color_ref_image_pil (PIL.Image): Color reference image (optional).
79
+ apply_replacement (bool): Flag to apply replacement.
80
+ apply_color_grading (bool): Flag to apply color grading.
81
+ apply_color_to_full_image (bool): Flag to apply color grading to the full image.
82
+ blending_amount (int): Amount for blending the mask.
83
+ image_history (list): History of images for undo functionality.
84
+
85
+ Returns:
86
+ tuple: Updated image, status message, updated history, and image display.
87
+ """
88
  # Check if current_image_pil is None
89
  if current_image_pil is None:
90
  return None, "No current image to edit.", image_history, None
 
166
  return current_image_pil, f"Applied changes to '{selected_prompt}'", image_history, current_image_pil
167
 
168
  def undo(image_history):
169
+ """
170
+ Undoes the last image edit by reverting to the previous image in the history.
171
+
172
+ Args:
173
+ image_history (list): History of images.
174
+
175
+ Returns:
176
+ tuple: Reverted image, updated history, and image display.
177
+ """
178
  if image_history and len(image_history) > 1:
179
  # Pop the last image
180
  image_history.pop()
 
189
  return None, [], None
190
 
191
  def gradio_interface():
192
+ """
193
+ Defines and launches the Gradio interface for continuous image editing.
194
+ """
195
  with gr.Blocks() as demo:
196
  # Define the state variables
197
  image_history = gr.State([])
 
219
  status = gr.Textbox(lines=2, interactive=False, label="Status")
220
 
221
  def initialize_image(initial_image_pil):
222
+ """
223
+ Initializes the image history and sets up the initial image.
224
+
225
+ Args:
226
+ initial_image_pil (PIL.Image): The uploaded initial image.
227
+
228
+ Returns:
229
+ tuple: Updated states and status message.
230
+ """
231
  if initial_image_pil is not None:
232
  image_history = [initial_image_pil]
233
  current_image_pil = initial_image_pil
 
244
 
245
  # Segment button click
246
  def segment_image_wrapper(current_image_pil, prompts):
247
+ """
248
+ Handles the segmentation of the image based on user prompts.
249
+
250
+ Args:
251
+ current_image_pil (PIL.Image): The current image.
252
+ prompts (str): Comma-separated prompts.
253
+
254
+ Returns:
255
+ tuple: Status message, updated masks, and dropdown updates.
256
+ """
257
  if current_image_pil is None:
258
  return "No image uploaded.", {}, gr.update(choices=[], value=None)
259
  masks = extract_masks(current_image_pil, prompts)