File size: 2,849 Bytes
b795d51
9c7b939
b795d51
 
9c7b939
 
 
b795d51
9c7b939
 
b795d51
9c7b939
 
 
b795d51
9c7b939
 
 
b795d51
 
 
9c7b939
 
b795d51
9c7b939
 
 
 
 
 
 
b795d51
9c7b939
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b795d51
 
 
9c7b939
 
 
 
 
 
 
b795d51
 
9c7b939
b795d51
9c7b939
 
b795d51
 
 
9c7b939
b795d51
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from PIL import Image, ImageEnhance
import numpy as np
import cv2
from lang_sam import LangSAM
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer

# Load the LangSAM model once, at import time, into the module-level global
# ``model`` that extract_and_color_match_sky uses for segmentation.
# Uses the default model; a custom checkpoint can be given instead:
#     LangSAM("<model_type>", "<path/to/checkpoint>")
model = LangSAM()

def apply_color_matching(source_img_np, ref_img_np):
    """Return ``source_img_np`` recoloured to follow the palette of ``ref_img_np``.

    A fresh ``ColorMatcher`` performs the transfer with the ``'mkl'`` method.
    The output is then passed through ``Normalizer(...).uint8_norm()`` so the
    caller receives an 8-bit array. This is presumably a min/max rescale to
    0..255; confirm against the color_matcher docs.

    Args:
        source_img_np: Array whose colours are adjusted.
        ref_img_np: Array providing the target colour distribution.

    Returns:
        The colour-matched array produced by ``uint8_norm()``.
    """
    transferred = ColorMatcher().transfer(src=source_img_np, ref=ref_img_np, method='mkl')
    return Normalizer(transferred).uint8_norm()

def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
    """Segment the sky in ``image_pil`` and colour-match it to a reference image.

    Processing steps:

    1. Query LangSAM with ``text_prompt``.
    2. Merge every returned mask into one region.
    3. Transfer the colours of only the pixels inside that region towards
       ``reference_image_pil`` with ``apply_color_matching``.
    4. Write the transferred pixels back over the original ones.

    Pixels outside the region are left untouched.

    Args:
        image_pil: Source PIL image. Converted to RGB before processing.
        reference_image_pil: PIL image whose colours the sky should adopt.
            Converted to RGB before processing.
        text_prompt: Text query passed to LangSAM (default ``"sky"``).

    Returns:
        A new RGB PIL image. If LangSAM returns no mask, or the merged mask
        covers no pixels, the RGB-converted source image is returned unchanged.
    """
    # Normalise both inputs to 3-channel RGB so RGBA / greyscale uploads do
    # not break the (H, W, 3) pixel indexing below.
    image_pil = image_pil.convert("RGB")
    ref_img_np = np.array(reference_image_pil.convert("RGB"))

    masks, _boxes, _phrases, _logits = model.predict(image_pil, text_prompt)

    # Nothing detected: indexing masks[0] would raise IndexError here.
    # len() is used instead of truthiness because masks may be a tensor/array.
    if len(masks) == 0:
        return image_pil

    # Union of all detected masks, so every detected sky region is matched,
    # not just the first. np.asarray accepts numpy arrays and CPU torch
    # tensors alike; torch tensors have no .astype.
    # NOTE(review): assumes masks stack to (num_masks, H, W) on the CPU.
    # Confirm against the installed LangSAM version.
    sky_mask = np.asarray(masks).astype(bool).any(axis=0)
    if not sky_mask.any():
        return image_pil

    img_np = np.array(image_pil)

    # Pass ONLY the sky pixels, shaped as an (N, 1, 3) "image", to the colour
    # transfer. Otherwise the zeroed-out non-sky pixels are counted as part of
    # the source colour distribution and of the uint8 normalisation range.
    sky_pixels = img_np[sky_mask].reshape(-1, 1, 3)
    matched = apply_color_matching(sky_pixels, ref_img_np)

    result_img_np = img_np.copy()
    result_img_np[sky_mask] = np.asarray(matched, dtype=np.uint8).reshape(-1, 3)
    return Image.fromarray(result_img_np)

def gradio_interface():
    """Build and launch the Gradio UI for sky extraction and colour matching.

    The interface takes two PIL images (source and reference) and runs
    ``extract_and_color_match_sky`` on them. It shows the resulting PIL image.
    ``launch()`` is called before this function returns.
    """

    def process_image(source_img, ref_img):
        """Gradio callback: colour-match the sky of ``source_img`` to ``ref_img``.

        Raises:
            gr.Error: If either image is missing. The UI then shows a readable
                message instead of a traceback from inside the pipeline.
        """
        # Gradio passes None when an image input is left empty.
        if source_img is None or ref_img is None:
            raise gr.Error("Please upload both a source image and a reference image.")
        return extract_and_color_match_sky(source_img, ref_img)

    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),  # colours are taken from this image
    ]
    outputs = gr.Image(type="pil", label="Resulting Image")

    gr.Interface(
        fn=process_image,
        inputs=inputs,
        outputs=outputs,
        title="Sky Extraction and Color Matching",
    ).launch()

# Launch the Gradio interface only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    gradio_interface()