Sutirtha commited on
Commit
9c7b939
·
verified ·
1 Parent(s): cfc6ece

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -22
app.py CHANGED
@@ -1,45 +1,77 @@
1
  import gradio as gr
2
- from color_matcher import ColorMatcher
3
- from color_matcher.normalizer import Normalizer
4
  import numpy as np
5
  import cv2
6
- from PIL import Image
 
 
7
 
8
- # Function to apply color correction
9
- def color_match(source_img, reference_img):
10
- # Convert PIL images to OpenCV format (numpy arrays)
11
- img_src = np.array(source_img)
12
- img_ref = np.array(reference_img)
13
-
14
- # Ensure images are in RGB format (3 channels)
15
- if img_src.shape[2] == 4:
16
- img_src = cv2.cvtColor(img_src, cv2.COLOR_RGBA2RGB)
17
- if img_ref.shape[2] == 4:
18
- img_ref = cv2.cvtColor(img_ref, cv2.COLOR_RGBA2RGB)
19
 
20
- # Apply color matching
 
 
21
  cm = ColorMatcher()
22
- img_res = cm.transfer(src=img_src, ref=img_ref, method='mkl')
 
 
23
 
24
  # Normalize the result
25
  img_res = Normalizer(img_res).uint8_norm()
 
 
26
 
27
- # Convert back to PIL for displaying in Gradio
28
- img_res_pil = Image.fromarray(img_res)
 
 
 
 
 
29
 
30
- return img_res_pil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # Gradio Interface
33
  def gradio_interface():
34
- # Define input and output components
 
 
 
 
 
 
35
  inputs = [
36
  gr.Image(type="pil", label="Source Image"),
37
- gr.Image(type="pil", label="Reference Image")
38
  ]
 
 
39
  outputs = gr.Image(type="pil", label="Resulting Image")
40
 
41
  # Launch Gradio app
42
- gr.Interface(fn=color_match, inputs=inputs, outputs=outputs, title="Color Matching Tool").launch()
43
 
44
  # Run the Gradio Interface
45
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from PIL import Image, ImageEnhance
 
3
  import numpy as np
4
  import cv2
5
+ from lang_sam import LangSAM
6
+ from color_matcher import ColorMatcher
7
+ from color_matcher.normalizer import Normalizer
8
 
9
+ # Load the LangSAM model
10
+ model = LangSAM() # Use the default model or specify custom checkpoint: LangSAM("<model_type>", "<path/to/checkpoint>")
 
 
 
 
 
 
 
 
 
11
 
12
+ # Function to apply color matching based on reference image
13
+ def apply_color_matching(source_img_np, ref_img_np):
14
+ # Initialize ColorMatcher
15
  cm = ColorMatcher()
16
+
17
+ # Apply color matching
18
+ img_res = cm.transfer(src=source_img_np, ref=ref_img_np, method='mkl')
19
 
20
  # Normalize the result
21
  img_res = Normalizer(img_res).uint8_norm()
22
+
23
+ return img_res
24
 
25
+ # Function to extract sky and apply color matching using a reference image
26
+ def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
27
+ # Use LangSAM to predict the mask for the sky
28
+ masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
29
+
30
+ # Convert the mask to a binary format and create a mask image
31
+ sky_mask = masks[0].astype(np.uint8) * 255
32
 
33
+ # Convert PIL image to numpy array for processing
34
+ img_np = np.array(image_pil)
35
+
36
+ # Convert sky mask to 3-channel format to blend with the original image
37
+ sky_mask_3ch = cv2.merge([sky_mask, sky_mask, sky_mask])
38
+
39
+ # Extract the sky region
40
+ sky_region = cv2.bitwise_and(img_np, sky_mask_3ch)
41
+
42
+ # Convert the reference image to a numpy array
43
+ ref_img_np = np.array(reference_image_pil)
44
+
45
+ # Apply color matching using the reference image to the extracted sky region
46
+ sky_region_color_matched = apply_color_matching(sky_region, ref_img_np)
47
+
48
+ # Combine the color-matched sky region back into the original image
49
+ result_img_np = np.where(sky_mask_3ch > 0, sky_region_color_matched, img_np)
50
+
51
+ # Convert the result back to PIL Image for final output
52
+ result_img_pil = Image.fromarray(result_img_np)
53
+
54
+ return result_img_pil
55
 
56
  # Gradio Interface
57
  def gradio_interface():
58
+ # Gradio function to be called on input
59
+ def process_image(source_img, ref_img):
60
+ # Extract sky and apply color matching using reference image
61
+ result_img_pil = extract_and_color_match_sky(source_img, ref_img)
62
+ return result_img_pil
63
+
64
+ # Define Gradio input components
65
  inputs = [
66
  gr.Image(type="pil", label="Source Image"),
67
+ gr.Image(type="pil", label="Reference Image") # Second input for reference image
68
  ]
69
+
70
+ # Define Gradio output component
71
  outputs = gr.Image(type="pil", label="Resulting Image")
72
 
73
  # Launch Gradio app
74
+ gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="Sky Extraction and Color Matching").launch()
75
 
76
  # Run the Gradio Interface
77
  if __name__ == "__main__":