# File size: 6,990 Bytes
# 18c979d 452ea00 18c979d d633b07 18c979d d633b07 452ea00 d633b07 18c979d 452ea00 18c979d 452ea00 18c979d d633b07 18c979d d633b07 18c979d d633b07 18c979d |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
# app.py
import gradio as gr
from PIL import Image
import torch
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
import io
# Segment Anything (SAM): the processor prepares images/prompts, the model
# predicts candidate masks. Both are loaded from the same checkpoint and the
# model is kept on the CPU in float32.
SAM_CHECKPOINT = "facebook/sam-vit-huge"
sam_processor = SamProcessor.from_pretrained(SAM_CHECKPOINT)
sam_model = SamModel.from_pretrained(SAM_CHECKPOINT, torch_dtype=torch.float32)
sam_model = sam_model.to("cpu")

# Stable Diffusion 2 inpainting pipeline, likewise float32 on the CPU.
# model_cpu_offload is not needed because nothing is placed on a GPU.
INPAINT_CHECKPOINT = "stabilityai/stable-diffusion-2-inpainting"
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    INPAINT_CHECKPOINT,
    torch_dtype=torch.float32,
)
inpaint_pipeline = inpaint_pipeline.to("cpu")
def mask_to_rgba(mask):
    """
    Build an RGBA overlay image from a binary mask.

    Entries of ``mask`` equal to 1 become semi-transparent green
    (0, 255, 0, 127); every other entry stays fully transparent black.

    Args:
        mask (np.ndarray): Binary mask.

    Returns:
        np.ndarray: uint8 array of shape ``mask.shape + (4,)``.
    """
    highlight = np.array([0, 255, 0, 127], dtype=np.uint8)  # green, ~50% alpha
    rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    selected = mask == 1
    rgba[selected] = highlight
    return rgba
def generate_mask(image, input_points):
    """
    Generates a binary mask using SAM based on input points.

    Args:
        image (PIL.Image): The input image. ``None`` is tolerated (e.g. nothing
            uploaded yet) and yields ``None``.
        input_points (list of lists): ``[x, y]`` points selected by the user.

    Returns:
        np.ndarray: Integer mask where the object is marked with 1s and the
        background with 0s, or ``None`` when there is no image, no points, or
        SAM returned no masks.
    """
    # Nothing to segment without both an image and at least one point.
    if image is None or not input_points:
        return None

    # Convert image to RGB if not already
    image = image.convert("RGB")

    # SamProcessor takes the prompts through the ``input_points`` keyword,
    # nested as [image][point][x, y]; all clicks describe the same object,
    # so they form a single prompt for the single image.
    points = [[[float(point[0]), float(point[1])] for point in input_points]]

    # Prepare inputs for SAM
    inputs = sam_processor(image, input_points=points, return_tensors="pt").to("cpu")
    with torch.no_grad():
        outputs = sam_model(**inputs)

    # Post-process masks back to the original image resolution
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None

    # Select the candidate with the highest predicted IoU score. With one
    # image and one prompt the flattened argmax indexes the candidate axis
    # (assumes iou_scores is (1, 1, num_candidates) -- TODO confirm).
    best_index = int(outputs.iou_scores.argmax())
    best_mask = masks[0][0][best_index]

    # Keep SAM's polarity: object=1, background=0. The mask must NOT be
    # inverted, and applying ``~`` to an int array would give -1/-2 anyway.
    return best_mask.numpy().astype(bool).astype(int)
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the selected object in the image based on the prompt.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the selected object.
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt to refine generation.
        seed (int): Random seed for reproducibility. Floats with an integral
            value (as produced by a numeric UI field) are accepted; ``None``
            means "do not fix the seed".
        guidance_scale (float): Guidance scale for the inpainting model.

    Returns:
        PIL.Image: The augmented image with the object replaced. The input
        image is returned unchanged when there is no image, no mask, or the
        inpainting pipeline raises.
    """
    # Without an image or a mask there is nothing to inpaint.
    if image is None or mask is None:
        return image

    # Mask entries of 1 become 255 in the 8-bit mask image handed to the pipeline.
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))

    # torch.Generator.manual_seed only accepts Python ints, while numeric UI
    # inputs commonly deliver floats (or None when the field is emptied).
    if seed is None:
        generator = None
    else:
        generator = torch.Generator("cpu").manual_seed(int(seed))

    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
    except Exception as e:
        # Best effort: report the failure and fall back to the untouched image
        # so the UI still has something to display.
        print(f"Inpainting error: {e}")
        return image
    return result
def visualize_mask(image, mask):
    """
    Draw the selection mask on top of the image.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the selected object, or None.

    Returns:
        PIL.Image: RGB image with the mask overlay blended in; the input
        image itself when no mask is given.
    """
    if mask is not None:
        # Build the RGBA overlay first, then blend it over the RGBA base.
        overlay_layer = Image.fromarray(mask_to_rgba(mask))
        base_layer = image.convert("RGBA")
        blended = Image.alpha_composite(base_layer, overlay_layer)
        return blended.convert("RGB")
    return image
def process(image, points, prompt, negative_prompt, seed, guidance_scale):
    """
    Run the full pipeline: segment the selected object, then replace it.

    Args:
        image (PIL.Image): Uploaded image.
        points (list of lists): Points selected on the image.
        prompt (str): Text prompt for replacement.
        negative_prompt (str): Negative text prompt.
        seed (int): Seed for reproducibility.
        guidance_scale (float): Guidance scale.

    Returns:
        Tuple of images: the original with the mask overlay, and the
        augmented image.
    """
    object_mask = generate_mask(image, points)
    # The overlay is produced before inpainting, mirroring the output order.
    return (
        visualize_mask(image, object_mask),
        replace_object(image, object_mask, prompt, negative_prompt, seed, guidance_scale),
    )
# Define Gradio Interface
#
# Gradio has no "Points" component. Clicks are collected through the Image
# component's ``select`` event and accumulated in a ``gr.State`` list, which
# is then handed to ``process`` as the ``points`` argument.
def _add_point(points, evt: gr.SelectData):
    """
    Record the pixel the user clicked on the uploaded image.

    Args:
        points (list of lists): Points collected so far, each ``[x, y]``.
        evt (gr.SelectData): Click event injected by Gradio; ``evt.index``
            carries the ``[x, y]`` coordinates of the clicked pixel.

    Returns:
        tuple: The extended point list twice -- once for the state that feeds
        ``process`` and once for the read-only display.
    """
    x, y = evt.index[0], evt.index[1]
    # Build a new list instead of mutating the stored state in place.
    updated = list(points or []) + [[int(x), int(y)]]
    return updated, updated


def _clear_points():
    """Reset both the stored points and their display to an empty list."""
    return [], []


with gr.Blocks() as demo:
    gr.Markdown("# Object Replacement App")
    gr.Markdown(
        """
        Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
        """
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", interactive=True, elem_id="image")
            # Holds the clicked [x, y] points for the current image.
            points_input = gr.State([])
            points_display = gr.JSON(label="Select Points on the Object", value=[])
            clear_points_button = gr.Button("Clear Points")
            prompt_input = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
            negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
            # precision=0 makes the component deliver an int, which is what
            # torch's manual_seed requires.
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, value=7.5)
            process_button = gr.Button("Replace Object")
        with gr.Column():
            masked_output = gr.Image(label="Selected Object Mask Overlay")
            augmented_output = gr.Image(label="Augmented Image")

    # Each click on the image appends one point.
    image_input.select(
        fn=_add_point,
        inputs=[points_input],
        outputs=[points_input, points_display],
    )
    # Points belong to one image: drop them when the image changes or on request.
    image_input.change(
        fn=_clear_points,
        inputs=None,
        outputs=[points_input, points_display],
    )
    clear_points_button.click(
        fn=_clear_points,
        inputs=None,
        outputs=[points_input, points_display],
    )

    # Bind the process function to the button click
    process_button.click(
        fn=process,
        inputs=[image_input, points_input, prompt_input, negative_prompt_input, seed_input, guidance_scale_input],
        outputs=[masked_output, augmented_output],
    )
    gr.Markdown(
        """
        **Instructions:**
        1. **Upload Image:** Upload the image containing the object you want to replace.
        2. **Select Points:** Click on the image to select points on the object. Use multiple points for better mask accuracy.
        3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
        4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
        5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
        """
    )

# Launch the app
demo.launch()