# File-viewer header captured with this source (not part of the program):
# author Sutirtha; last commit "Update app.py" (9c7b939, verified); size 2.85 kB.
import gradio as gr
from PIL import Image, ImageEnhance
import numpy as np
import cv2
from lang_sam import LangSAM
from color_matcher import ColorMatcher
from color_matcher.normalizer import Normalizer
# Load the LangSAM model once, at import time, so every call to
# extract_and_color_match_sky reuses the same instance via the global `model`.
# Uses the default model; per the original author's note a custom checkpoint
# can be given as LangSAM("<model_type>", "<path/to/checkpoint>") — TODO
# confirm that signature against the installed lang_sam version.
model = LangSAM()
def apply_color_matching(source_img_np, ref_img_np, method='mkl'):
    """Transfer the colour distribution of ``ref_img_np`` onto ``source_img_np``.

    Args:
        source_img_np: Source image/pixels as a NumPy array (presumably
            (H, W, 3) RGB — TODO confirm for new callers).
        ref_img_np: Reference image as a NumPy array whose colours are copied.
        method: Colour-transfer method forwarded to ``ColorMatcher.transfer``.
            Defaults to ``'mkl'``, the method this helper previously hard-coded,
            so existing callers get identical results.

    Returns:
        The transferred image passed through ``Normalizer(...).uint8_norm()``,
        i.e. rescaled into the 8-bit range.
    """
    matcher = ColorMatcher()
    transferred = matcher.transfer(src=source_img_np, ref=ref_img_np, method=method)
    # The raw transfer result is not guaranteed to lie in [0, 255];
    # Normalizer rescales it into an 8-bit image.
    return Normalizer(transferred).uint8_norm()
def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
    """Recolour the region of ``image_pil`` described by ``text_prompt``.

    LangSAM segments every region matching ``text_prompt`` (the sky by
    default). The colour statistics of those pixels are matched to
    ``reference_image_pil`` and written back. All other pixels are left
    untouched.

    Args:
        image_pil: Source PIL image; converted to RGB before processing.
        reference_image_pil: PIL image whose colours the region should adopt;
            converted to RGB before processing.
        text_prompt: Text prompt passed to LangSAM to select the region.

    Returns:
        A new RGB PIL image with the region recoloured, or ``image_pil``
        itself unchanged when LangSAM detects no matching region.
    """
    masks, _boxes, _phrases, _logits = model.predict(image_pil, text_prompt)

    # np.asarray accepts a NumPy array, a list of arrays, or a CPU torch
    # tensor (which has no ``.astype``), whichever LangSAM hands back.
    mask_stack = np.asarray(masks)
    if mask_stack.size == 0:
        # Nothing detected: return the input rather than failing on masks[0].
        return image_pil

    # The sky is often split into several detections (e.g. by a tree or a
    # building), so merge all of them instead of keeping only the first.
    mask_stack = mask_stack.astype(bool).reshape((-1,) + mask_stack.shape[-2:])
    sky_mask = mask_stack.any(axis=0)
    if not sky_mask.any():
        return image_pil

    # Force 3-channel RGB so RGBA / grayscale / palette uploads line up with
    # the 3-channel colour transfer.
    img_np = np.array(image_pil.convert("RGB"))
    ref_img_np = np.array(reference_image_pil.convert("RGB"))

    # Match ONLY the sky pixels. Passing the whole image with non-sky pixels
    # blacked out would let those black pixels dominate the source statistics
    # that the colour transfer is computed from.
    sky_pixels = img_np[sky_mask]  # (N, 3)
    matched = apply_color_matching(sky_pixels[:, np.newaxis, :], ref_img_np)

    # Write the recoloured pixels back; non-sky pixels keep their values.
    result_img_np = img_np.copy()
    result_img_np[sky_mask] = np.asarray(matched, dtype=np.uint8).reshape(-1, 3)
    return Image.fromarray(result_img_np)
def gradio_interface():
    """Build the sky colour-matching Gradio UI and launch it.

    The interface takes a source image and a reference image (both as PIL
    images). It shows the source with its sky recoloured to match the
    reference, as produced by ``extract_and_color_match_sky``.
    """

    def process_image(source_img, ref_img):
        """Gradio callback: recolour the sky of ``source_img`` using ``ref_img``."""
        # An empty image input arrives as None; surface a readable message in
        # the UI instead of an AttributeError deep inside the pipeline.
        if source_img is None or ref_img is None:
            raise gr.Error("Please upload both a source image and a reference image.")
        return extract_and_color_match_sky(source_img, ref_img)

    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),  # colour reference
    ]
    outputs = gr.Image(type="pil", label="Resulting Image")

    gr.Interface(
        fn=process_image,
        inputs=inputs,
        outputs=outputs,
        title="Sky Extraction and Color Matching",
    ).launch()
# Build and launch the Gradio app only when run as a script, not on import.
if __name__ == "__main__":
    gradio_interface()