# Hugging Face Spaces page status at capture time: Sleeping
import cv2
import gradio as gr
import numpy as np
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
from lang_sam import LangSAM
from PIL import Image, ImageEnhance
# LangSAM segmentation model, created once when the module is imported and
# reused by every call to extract_and_color_match_sky.
# A custom checkpoint can be given as LangSAM("<model_type>", "<path/to/checkpoint>").
model = LangSAM()
# Function to apply color matching based on reference image
def apply_color_matching(source_img_np, ref_img_np, method="mkl"):
    """Transfer the color distribution of a reference image onto a source image.

    Args:
        source_img_np: Source image array whose colors are changed.
        ref_img_np: Reference image array providing the target colors.
        method: Transfer method name forwarded to ``ColorMatcher.transfer``.
            Defaults to ``"mkl"``, the method this app originally hard-coded.

    Returns:
        The color-matched array, passed through
        ``Normalizer(...).uint8_norm()`` to obtain a ``uint8`` result.
    """
    matcher = ColorMatcher()
    matched = matcher.transfer(src=source_img_np, ref=ref_img_np, method=method)
    # Normalize the (possibly out-of-range, floating-point) transfer result to uint8.
    return Normalizer(matched).uint8_norm()
# Convert the first mask returned by LangSAM into a boolean NumPy array.
def _first_mask(masks):
    """Return ``masks[0]`` as a boolean NumPy array, or ``None`` if ``masks`` is empty.

    ``masks`` is whatever ``model.predict`` returned. Torch tensors (which have
    no ``.astype``) are detached and moved to CPU first; anything else is
    converted with ``np.asarray``.
    """
    if len(masks) == 0:
        return None
    mask = masks[0]
    if hasattr(mask, "detach"):
        mask = mask.detach().cpu().numpy()
    return np.asarray(mask).astype(bool)


# Function to extract sky and apply color matching using a reference image
def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
    """Recolor the region matching ``text_prompt`` so it follows a reference image.

    LangSAM segments ``image_pil`` with ``text_prompt``, and the first returned
    mask is used. Only the pixels inside that mask are color-matched against
    the whole reference image via ``apply_color_matching``. Every other pixel
    is copied unchanged.

    Args:
        image_pil: Source PIL image; converted to RGB before processing.
        reference_image_pil: Reference PIL image supplying the target colors;
            converted to RGB before processing.
        text_prompt: Text prompt passed to LangSAM (default ``"sky"``).

    Returns:
        A new RGB PIL image. If LangSAM returns no mask, or an empty one, the
        RGB-converted source image is returned unchanged.
    """
    # Force 3 channels so RGBA / greyscale uploads don't break the per-pixel ops.
    image_pil = image_pil.convert("RGB")
    reference_image_pil = reference_image_pil.convert("RGB")

    masks, _boxes, _phrases, _logits = model.predict(image_pil, text_prompt)
    sky_mask = _first_mask(masks)
    if sky_mask is None or not sky_mask.any():
        # Nothing was segmented, so there is no region to recolor.
        return image_pil

    img_np = np.array(image_pil)
    ref_img_np = np.array(reference_image_pil)

    # Color-match only the masked pixels. Zeroing the rest of the image (the
    # old bitwise_and approach) fed black pixels into the source statistics
    # and skewed the transfer. The (N, 1, 3) shape keeps an image-like
    # (H, W, C) layout for the color matcher.
    sky_pixels = img_np[sky_mask].reshape(-1, 1, 3)
    matched = apply_color_matching(sky_pixels, ref_img_np)

    # Write the recolored pixels back into a copy of the original image.
    result_img_np = img_np.copy()
    result_img_np[sky_mask] = np.asarray(matched, dtype=np.uint8).reshape(-1, 3)
    return Image.fromarray(result_img_np)
# Gradio Interface
def gradio_interface():
    """Build the two-image Gradio app for sky color matching and launch it."""

    def process_image(source_img, ref_img):
        """Gradio callback: recolor the sky of ``source_img`` using ``ref_img``."""
        # An empty image slot arrives as None; report it clearly instead of
        # failing with an AttributeError deep inside the pipeline.
        if source_img is None or ref_img is None:
            raise gr.Error("Please upload both a source image and a reference image.")
        return extract_and_color_match_sky(source_img, ref_img)

    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),
    ]
    outputs = gr.Image(type="pil", label="Resulting Image")

    gr.Interface(
        fn=process_image,
        inputs=inputs,
        outputs=outputs,
        title="Sky Extraction and Color Matching",
    ).launch()
# Script entry point: start the web UI only when this file is executed directly.
# Importing the module still constructs the LangSAM model defined above.
if __name__ == "__main__":
    gradio_interface()