Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,45 +1,77 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
from color_matcher.normalizer import Normalizer
|
4 |
import numpy as np
|
5 |
import cv2
|
6 |
-
from
|
|
|
|
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
# Convert PIL images to OpenCV format (numpy arrays)
|
11 |
-
img_src = np.array(source_img)
|
12 |
-
img_ref = np.array(reference_img)
|
13 |
-
|
14 |
-
# Ensure images are in RGB format (3 channels)
|
15 |
-
if img_src.shape[2] == 4:
|
16 |
-
img_src = cv2.cvtColor(img_src, cv2.COLOR_RGBA2RGB)
|
17 |
-
if img_ref.shape[2] == 4:
|
18 |
-
img_ref = cv2.cvtColor(img_ref, cv2.COLOR_RGBA2RGB)
|
19 |
|
20 |
-
|
|
|
|
|
21 |
cm = ColorMatcher()
|
22 |
-
|
|
|
|
|
23 |
|
24 |
# Normalize the result
|
25 |
img_res = Normalizer(img_res).uint8_norm()
|
|
|
|
|
26 |
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# Gradio Interface
|
33 |
def gradio_interface():
|
34 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
inputs = [
|
36 |
gr.Image(type="pil", label="Source Image"),
|
37 |
-
gr.Image(type="pil", label="Reference Image")
|
38 |
]
|
|
|
|
|
39 |
outputs = gr.Image(type="pil", label="Resulting Image")
|
40 |
|
41 |
# Launch Gradio app
|
42 |
-
gr.Interface(fn=
|
43 |
|
44 |
# Run the Gradio Interface
|
45 |
if __name__ == "__main__":
|
|
|
1 |
import gradio as gr
|
2 |
+
from PIL import Image, ImageEnhance
|
|
|
3 |
import numpy as np
|
4 |
import cv2
|
5 |
+
from lang_sam import LangSAM
|
6 |
+
from color_matcher import ColorMatcher
|
7 |
+
from color_matcher.normalizer import Normalizer
|
8 |
|
9 |
+
# Load the LangSAM model
|
10 |
+
model = LangSAM() # Use the default model or specify custom checkpoint: LangSAM("<model_type>", "<path/to/checkpoint>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
# Function to apply color matching based on reference image
|
13 |
+
def apply_color_matching(source_img_np, ref_img_np):
|
14 |
+
# Initialize ColorMatcher
|
15 |
cm = ColorMatcher()
|
16 |
+
|
17 |
+
# Apply color matching
|
18 |
+
img_res = cm.transfer(src=source_img_np, ref=ref_img_np, method='mkl')
|
19 |
|
20 |
# Normalize the result
|
21 |
img_res = Normalizer(img_res).uint8_norm()
|
22 |
+
|
23 |
+
return img_res
|
24 |
|
25 |
+
# Function to extract sky and apply color matching using a reference image
|
26 |
+
def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
|
27 |
+
# Use LangSAM to predict the mask for the sky
|
28 |
+
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
|
29 |
+
|
30 |
+
# Convert the mask to a binary format and create a mask image
|
31 |
+
sky_mask = masks[0].astype(np.uint8) * 255
|
32 |
|
33 |
+
# Convert PIL image to numpy array for processing
|
34 |
+
img_np = np.array(image_pil)
|
35 |
+
|
36 |
+
# Convert sky mask to 3-channel format to blend with the original image
|
37 |
+
sky_mask_3ch = cv2.merge([sky_mask, sky_mask, sky_mask])
|
38 |
+
|
39 |
+
# Extract the sky region
|
40 |
+
sky_region = cv2.bitwise_and(img_np, sky_mask_3ch)
|
41 |
+
|
42 |
+
# Convert the reference image to a numpy array
|
43 |
+
ref_img_np = np.array(reference_image_pil)
|
44 |
+
|
45 |
+
# Apply color matching using the reference image to the extracted sky region
|
46 |
+
sky_region_color_matched = apply_color_matching(sky_region, ref_img_np)
|
47 |
+
|
48 |
+
# Combine the color-matched sky region back into the original image
|
49 |
+
result_img_np = np.where(sky_mask_3ch > 0, sky_region_color_matched, img_np)
|
50 |
+
|
51 |
+
# Convert the result back to PIL Image for final output
|
52 |
+
result_img_pil = Image.fromarray(result_img_np)
|
53 |
+
|
54 |
+
return result_img_pil
|
55 |
|
56 |
# Gradio Interface
|
57 |
def gradio_interface():
|
58 |
+
# Gradio function to be called on input
|
59 |
+
def process_image(source_img, ref_img):
|
60 |
+
# Extract sky and apply color matching using reference image
|
61 |
+
result_img_pil = extract_and_color_match_sky(source_img, ref_img)
|
62 |
+
return result_img_pil
|
63 |
+
|
64 |
+
# Define Gradio input components
|
65 |
inputs = [
|
66 |
gr.Image(type="pil", label="Source Image"),
|
67 |
+
gr.Image(type="pil", label="Reference Image") # Second input for reference image
|
68 |
]
|
69 |
+
|
70 |
+
# Define Gradio output component
|
71 |
outputs = gr.Image(type="pil", label="Resulting Image")
|
72 |
|
73 |
# Launch Gradio app
|
74 |
+
gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, title="Sky Extraction and Color Matching").launch()
|
75 |
|
76 |
# Run the Gradio Interface
|
77 |
if __name__ == "__main__":
|