Spaces:
Build error
| import gradio as gr | |
| from PIL import Image, ImageEnhance | |
| import numpy as np | |
| import cv2 | |
| from lang_sam import LangSAM | |
| from color_matcher import ColorMatcher | |
| from color_matcher.normalizer import Normalizer | |
# Single LangSAM instance created at import time and shared by every request.
# LangSAM() uses the library's default model; a custom checkpoint can be
# selected with LangSAM("<model_type>", "<path/to/checkpoint>").
model: LangSAM = LangSAM()
# Function to apply color matching based on reference image
def apply_color_matching(source_img_np, ref_img_np, method="mkl"):
    """Transfer the colour distribution of *ref_img_np* onto *source_img_np*.

    Args:
        source_img_np: Array whose colours are adjusted. Presumably an
            (..., 3) RGB array; TODO confirm the shapes color_matcher accepts.
        ref_img_np: Array providing the target colour statistics.
        method: Transfer method name forwarded to ``ColorMatcher.transfer``.
            Defaults to ``"mkl"``, the value this function always used before.

    Returns:
        The colour-matched array, converted to uint8 by
        ``Normalizer(...).uint8_norm()``.
    """
    matched = ColorMatcher().transfer(src=source_img_np, ref=ref_img_np, method=method)
    # The transfer result is not uint8; normalise it back to a displayable range.
    return Normalizer(matched).uint8_norm()
def _masks_to_union(masks, shape):
    """Collapse LangSAM's predicted mask(s) into one (H, W) boolean mask."""
    # NOTE(review): lang_sam may hand back a torch.Tensor (possibly on GPU),
    # which has no `.astype`; move it to a host NumPy array first.
    if hasattr(masks, "detach"):
        masks = masks.detach().cpu().numpy()
    masks = np.asarray(masks, dtype=bool)
    if masks.size == 0:
        # No detections: an all-False mask of the image's height/width.
        return np.zeros(shape, dtype=bool)
    # A single (H, W) mask is used as-is; a (K, H, W) stack is OR-ed so every
    # detected region is recoloured, not only the first one.
    return masks if masks.ndim == 2 else masks.any(axis=0)


# Function to extract sky and apply color matching using a reference image
def extract_and_color_match_sky(image_pil, reference_image_pil, text_prompt="sky"):
    """Recolour the region of *image_pil* matching *text_prompt* after a reference.

    LangSAM segments the region described by *text_prompt*. The colours of
    only those pixels are then matched to *reference_image_pil* through
    ``apply_color_matching``. Pixels outside the region are left unchanged.

    Args:
        image_pil: Source PIL image.
        reference_image_pil: PIL image supplying the target colours.
        text_prompt: Text prompt passed to ``model.predict``.

    Returns:
        An RGB PIL image. If nothing matches the prompt, this is the source
        image converted to RGB and otherwise unchanged.
    """
    # Uploads may be RGBA (PNG with alpha) or grayscale. Everything below
    # assumes 3-channel RGB.
    image_rgb = image_pil.convert("RGB")
    reference_rgb = reference_image_pil.convert("RGB")

    masks, _boxes, _phrases, _logits = model.predict(image_rgb, text_prompt)
    img_np = np.array(image_rgb)
    sky = _masks_to_union(masks, img_np.shape[:2])
    if not sky.any():
        # Nothing matched the prompt: return the image rather than failing.
        return image_rgb

    # Match colours using the sky pixels only. Passing the whole masked image
    # would let the zeroed non-sky pixels skew the source colour statistics.
    sky_pixels = img_np[sky].reshape(-1, 1, 3)
    matched = apply_color_matching(sky_pixels, np.array(reference_rgb))

    result_np = img_np.copy()
    result_np[sky] = np.asarray(matched).reshape(-1, 3)
    return Image.fromarray(result_np)
# Gradio Interface
def gradio_interface():
    """Build the two-image Gradio UI (source + reference) and launch it."""

    def process_image(source_img, ref_img):
        """Gradio callback: recolour the sky of *source_img* to match *ref_img*."""
        # An empty image slot arrives as None. Show a readable message in the
        # UI instead of a traceback from inside the segmentation model.
        if source_img is None or ref_img is None:
            raise gr.Error("Please upload both a source image and a reference image.")
        return extract_and_color_match_sky(source_img, ref_img)

    inputs = [
        gr.Image(type="pil", label="Source Image"),
        gr.Image(type="pil", label="Reference Image"),  # colours are taken from this image
    ]
    outputs = gr.Image(type="pil", label="Resulting Image")

    demo = gr.Interface(
        fn=process_image,
        inputs=inputs,
        outputs=outputs,
        title="Sky Extraction and Color Matching",
    )
    demo.launch()


# Run the Gradio Interface
if __name__ == "__main__":
    gradio_interface()