|
|
|
|
|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
from transformers import SamModel, SamProcessor |
|
from diffusers import StableDiffusionInpaintPipeline |
|
import io |
|
|
|
|
|
# Segment Anything: load the float32 checkpoint, then place it on the CPU.
sam_model = SamModel.from_pretrained(
    "facebook/sam-vit-huge",
    torch_dtype=torch.float32,
)
sam_model = sam_model.to("cpu")

# Processor paired with the same SAM checkpoint (prepares inputs and
# post-processes predicted masks).
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Stable Diffusion 2 inpainting pipeline, also float32 on the CPU.
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
)
inpaint_pipeline = inpaint_pipeline.to("cpu")
|
|
|
|
|
def mask_to_rgba(mask):
    """
    Build an RGBA overlay from a binary mask for visualization.

    Pixels where ``mask == 1`` become half-transparent green
    ``(0, 255, 0, 127)``; every other pixel stays fully transparent
    ``(0, 0, 0, 0)``.

    Args:
        mask (np.ndarray): Array whose entries equal to 1 mark the object.

    Returns:
        np.ndarray: ``uint8`` array of shape ``mask.shape + (4,)``.
    """
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    selected = mask == 1
    overlay[selected] = (0, 255, 0, 127)
    return overlay
|
|
|
def generate_mask(image, input_points):
    """
    Generates a binary mask using SAM based on input points.

    Args:
        image (PIL.Image): The input image.
        input_points (list of lists): ``[x, y]`` points selected by the user,
            all describing the same object.

    Returns:
        np.ndarray or None: Integer mask where the object is marked with 1s
        and everything else with 0s. ``None`` when there is no image, no
        points, or SAM returns no masks.
    """
    # Nothing to segment without both an image and at least one prompt point.
    if image is None or not input_points:
        return None

    image = image.convert("RGB")

    # SamProcessor receives prompts through ``input_points``, nested as
    # one list of [x, y] points per image; there is a single image here.
    prompt_points = [[list(point) for point in input_points]]

    inputs = sam_processor(
        image,
        input_points=prompt_points,
        return_tensors="pt",
    ).to("cpu")

    with torch.no_grad():
        outputs = sam_model(**inputs)

    # Rescale the predicted masks back to the original image resolution.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    if not masks:
        return None

    # Keep the candidate mask with the highest predicted IoU score.
    best_index = int(outputs.iou_scores.argmax())
    best_mask = masks[0][0][best_index]

    # Object pixels -> 1, background -> 0. The mask is deliberately not
    # inverted: applying ``~`` to the int array would yield -1/-2, which
    # breaks both the ``mask == 1`` overlay and the ``mask * 255`` inpainting
    # mask built from this result.
    return best_mask.numpy().astype(bool).astype(int)
|
|
|
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the selected object in the image based on the prompt.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the selected object (1 = repaint).
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt to refine generation.
        seed (int): Random seed for reproducibility. Float values (as a
            numeric UI field may supply) are truncated to int.
        guidance_scale (float): Guidance scale for the inpainting model.

    Returns:
        PIL.Image: The augmented image with the object replaced. The input
        image is returned unchanged when ``mask`` is None or when the
        inpainting call raises.
    """
    if mask is None:
        return image

    # Scale the 0/1 mask to the 0/255 range of an 8-bit grayscale image.
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    # ``manual_seed`` only accepts an int, so coerce e.g. 42.0 -> 42.
    generator = torch.Generator("cpu").manual_seed(int(seed))

    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            # An empty negative prompt is passed as None.
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
    except Exception as e:
        # Best effort: report the failure and fall back to the original image
        # so the caller still has something to display.
        print(f"Inpainting error: {e}")
        return image
    return result
|
|
|
def visualize_mask(image, mask):
    """
    Draw the mask on top of the image for visualization.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the selected object, or None.

    Returns:
        PIL.Image: RGB image with the mask overlay composited on it, or the
        input image unchanged when ``mask`` is None.
    """
    if mask is not None:
        # RGBA tint built from the mask, alpha-composited over the image.
        tint = Image.fromarray(mask_to_rgba(mask))
        base = image.convert("RGBA")
        return Image.alpha_composite(base, tint).convert("RGB")
    return image
|
|
|
def process(image, points, prompt, negative_prompt, seed, guidance_scale):
    """
    Run the full pipeline: segment the selected object, then replace it.

    Args:
        image (PIL.Image): Uploaded image.
        points (list of lists): Points selected on the image.
        prompt (str): Text prompt for replacement.
        negative_prompt (str): Negative text prompt.
        seed (int): Seed for reproducibility.
        guidance_scale (float): Guidance scale.

    Returns:
        Tuple of images: the original with the mask overlay, and the
        augmented image.
    """
    mask = generate_mask(image, points)
    return (
        visualize_mask(image, mask),
        replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale),
    )
|
|
|
|
|
def _append_point(points, evt: gr.SelectData):
    """Add the clicked image pixel to the running list of prompt points."""
    x, y = evt.index[0], evt.index[1]
    updated = list(points or []) + [[int(x), int(y)]]
    # Returned twice: once for the hidden state, once for the visible read-out.
    return updated, updated


def _reset_points():
    """Empty both the stored prompt points and their read-out."""
    return [], []


with gr.Blocks() as demo:
    gr.Markdown("# Object Replacement App")
    gr.Markdown(
        """
        Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
        """
    )

    # NOTE(review): this replaces a ``gr.Points`` component, which does not
    # appear to exist in Gradio. Clicks are instead captured through the
    # image's ``select`` event and accumulated in this per-session state
    # as [[x, y], ...].
    points_input = gr.State([])

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", interactive=True, elem_id="image")
            points_display = gr.JSON(label="Selected Points on the Object", value=[])
            clear_points_button = gr.Button("Clear Points")
            prompt_input = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
            negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
            # precision=0 makes the field yield an int, as the seed requires.
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, value=7.5)
            process_button = gr.Button("Replace Object")
        with gr.Column():
            masked_output = gr.Image(label="Selected Object Mask Overlay")
            augmented_output = gr.Image(label="Augmented Image")

    # Each click on the image adds one prompt point.
    image_input.select(
        fn=_append_point,
        inputs=[points_input],
        outputs=[points_input, points_display],
    )

    # Points belong to one image: drop them when the image changes or on request.
    image_input.change(
        fn=_reset_points,
        inputs=None,
        outputs=[points_input, points_display],
    )
    clear_points_button.click(
        fn=_reset_points,
        inputs=None,
        outputs=[points_input, points_display],
    )

    process_button.click(
        fn=process,
        inputs=[image_input, points_input, prompt_input, negative_prompt_input, seed_input, guidance_scale_input],
        outputs=[masked_output, augmented_output],
    )

    gr.Markdown(
        """
        **Instructions:**
        1. **Upload Image:** Upload the image containing the object you want to replace.
        2. **Select Points:** Click on the image to select points on the object. Use multiple points for better mask accuracy.
        3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
        4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
        5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
        """
    )

demo.launch()