|
|
|
|
|
import gradio as gr |
|
from PIL import Image, ImageDraw |
|
import torch |
|
import numpy as np |
|
from transformers import SamModel, SamProcessor |
|
from diffusers import StableDiffusionInpaintPipeline |
|
|
|
|
|
# Side length, in pixels, of the square images produced by preprocess() and
# run_inpaint().
IMG_SIZE = 512

# Segment Anything (SAM) model and its processor, loaded in float32 and kept
# on the CPU.
sam_model = SamModel.from_pretrained(
    "facebook/sam-vit-huge",
    torch_dtype=torch.float32,
)
sam_model = sam_model.to("cpu")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Stable Diffusion 2 inpainting pipeline, also float32 on the CPU.
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
)
inpaint_pipeline = inpaint_pipeline.to("cpu")

# Module-level state shared by the Gradio callbacks.
input_points = []   # [x, y] pairs appended by get_points()
input_image = None  # copy of the image taken on the first click
|
|
|
def mask_to_rgba(mask):
    """
    Turns a binary mask into an RGBA overlay for visualization.

    Args:
        mask (np.ndarray): Array whose entries equal to 1 are highlighted.

    Returns:
        np.ndarray: uint8 array of shape ``mask.shape + (4,)``. Highlighted
        entries are semi-transparent green; all others are fully transparent.
    """
    highlight = (0, 255, 0, 127)

    # Start fully transparent, then paint only the selected entries.
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    selected = mask == 1
    overlay[selected] = highlight
    return overlay
|
|
|
def generate_mask(image, input_points):
    """
    Generates a binary mask using SAM based on input points.

    Args:
        image (PIL.Image): The input image.
        input_points (list of lists): [x, y] points selected by the user.

    Returns:
        np.ndarray or None: Boolean mask. The highest-scoring SAM mask is
        inverted before it is returned, so True marks the pixels OUTSIDE the
        selected object. None is returned when no points are given or SAM
        yields no masks.
    """
    if not input_points:
        return None

    image = image.convert("RGB")

    # SamProcessor takes point prompts through `input_points`, nested as
    # [image][point][x, y]; a single image is prompted here.
    points = [[list(point) for point in input_points]]

    inputs = sam_processor(image, input_points=points, return_tensors="pt").to("cpu")

    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    if len(masks) == 0:
        return None

    # Keep the candidate mask with the highest predicted IoU score.
    best_mask = masks[0][0][int(outputs.iou_scores.argmax())]

    # Invert while the mask is still boolean. Applying `~` after a cast to
    # int is a bitwise NOT that yields -1/-2 instead of a usable binary mask.
    return np.logical_not(best_mask.numpy().astype(bool))
|
|
|
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Inpaints the masked region of the image according to the prompt.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the region to repaint.
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt; an empty value is
            passed to the pipeline as None.
        seed (int): Random seed for reproducibility.
        guidance_scale (float): Guidance scale for the inpainting model.

    Returns:
        PIL.Image: The inpainted image. The input image is returned
        unchanged when there is no mask or the pipeline raises.
    """
    if mask is None:
        return image

    # Scale the binary mask to the 0/255 range of an 8-bit image.
    mask_pixels = (mask * 255).astype(np.uint8)
    mask_image = Image.fromarray(mask_pixels)

    rng = torch.Generator("cpu").manual_seed(seed)

    pipeline_kwargs = {
        "prompt": prompt,
        "image": image,
        "mask_image": mask_image,
        "negative_prompt": negative_prompt or None,
        "generator": rng,
        "guidance_scale": guidance_scale,
    }

    try:
        output = inpaint_pipeline(**pipeline_kwargs)
        return output.images[0]
    except Exception as e:
        # Best effort: report the failure and hand back the original image.
        print(f"Inpainting error: {e}")
        return image
|
|
|
def visualize_mask(image, mask):
    """
    Blends a translucent rendering of the mask over the image.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask to display.

    Returns:
        PIL.Image: RGB image with the mask overlay, or the input image
        itself when no mask is given.
    """
    if mask is None:
        return image

    tint = Image.fromarray(mask_to_rgba(mask))
    base = image.convert("RGBA")
    # alpha_composite works on RGBA images; convert back to RGB afterwards.
    return Image.alpha_composite(base, tint).convert("RGB")
|
|
|
def get_points(img, evt: gr.SelectData):
    """
    Records a clicked point and refreshes the mask preview.

    Args:
        img (PIL.Image): The uploaded image.
        evt (gr.SelectData): Event data containing the point coordinates.

    Returns:
        Tuple: (Image with the mask overlay, image with crossmarks drawn on
        every selected point)
    """
    global input_points, input_image

    # On the first click since the last reset, keep a clean copy of the
    # image; the crossmarks below are drawn on `img` itself.
    if not input_points:
        input_image = img.copy()

    input_points.append([evt.index[0], evt.index[1]])

    mask = generate_mask(input_image, input_points)

    # Mark every selected point with a green cross.
    pen = ImageDraw.Draw(img)
    arm = 10
    for px, py in input_points:
        pen.line((px - arm, py, px + arm, py), fill="green", width=5)
        pen.line((px, py - arm, px, py + arm), fill="green", width=5)

    return visualize_mask(input_image, mask), img
|
|
|
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Runs the inpainting process based on user inputs.

    Reads the module-level ``input_image`` and ``input_points`` that
    get_points() fills in.

    Args:
        prompt (str): Prompt for infill.
        negative_prompt (str): Negative prompt.
        cfg (float): Classifier-Free Guidance Scale.
        seed (int): Random seed.
        invert (bool): Whether to infill the subject instead of the background.

    Returns:
        PIL.Image: The inpainted image, resized to IMG_SIZE x IMG_SIZE.

    Raises:
        gr.Error: If no points have been selected yet, or if inpainting
            raises.
    """
    if input_image is None or not input_points:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    mask = generate_mask(input_image, input_points)

    # generate_mask may return None. Leave it untouched so replace_object can
    # fall back to the unmodified image instead of failing on `~None`.
    if invert and mask is not None:
        # NOTE(review): `~` is a clean logical inversion only for boolean
        # arrays - confirm the dtype generate_mask returns.
        mask = ~mask

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(str(e)) from e

    return inpainted.resize((IMG_SIZE, IMG_SIZE))
|
|
|
def reset_points_func():
    """
    Clears the selected points and the stored input image.

    Returns:
        Tuple: Three None values, one for each output the reset button
        clears (mask visualization, image with points, inpainted image).
    """
    global input_points, input_image

    input_points, input_image = [], None
    return (None, None, None)
|
|
|
def preprocess(input_img):
    """
    Makes the uploaded image square and resizes it to IMG_SIZE x IMG_SIZE.

    Non-square images are centred on a white square canvas whose side is
    the longer edge of the image.

    Args:
        input_img (PIL.Image): The uploaded image, or None.

    Returns:
        PIL.Image: The preprocessed image, or None when no image is given.
    """
    if input_img is None:
        return None

    width, height = input_img.size

    if width != height:
        side = max(width, height)
        canvas = Image.new("RGB", (side, side), 'white')
        # Centre the original on the canvas; the leftover border stays white.
        offset = ((side - width) // 2, (side - height) // 2)
        canvas.paste(input_img, offset)
        input_img = canvas

    return input_img.resize((IMG_SIZE, IMG_SIZE))
|
|
|
def build_app(get_processed_inputs, inpaint):
    """
    Builds and launches the Gradio app.

    Args:
        get_processed_inputs (function): Function to process inputs for SAM.
            Currently unused: the callbacks are wired to the module-level
            functions directly.
        inpaint (function): Function to perform inpainting. Currently unused
            for the same reason.

    Returns:
        None
    """
    with gr.Blocks() as demo:

        gr.Markdown(
            """
            # Object Replacement App
            Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.

            **Instructions:**
            1. **Upload Image:** Click on the first image box to upload your image.
            2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
            3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
            4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
            5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
            6. **Reset:** Click the "Reset" button to clear selections and start over.
            """)

        with gr.Row():
            with gr.Column():

                upload_image = gr.Image(label="Upload Image", type="pil", interactive=True)
                mask_visualization = gr.Image(label="Selected Object Mask Overlay", interactive=False)
                selected_image = gr.Image(label="Image with Selected Points", type="pil", interactive=False)

                # Each click on the uploaded image adds a point and refreshes
                # both previews.
                upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])

                # NOTE(review): preprocess writes back to the component whose
                # change event triggers it - confirm this cannot re-trigger
                # itself in the Gradio version in use.
                upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])

                prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
                cfg = gr.Slider(
                    label="Classifier-Free Guidance Scale",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5,
                )
                seed = gr.Number(label="Seed", value=42, precision=0)
                invert = gr.Checkbox(label="Infill subject instead of background")

                replace_button = gr.Button("Replace Object")
                reset_button = gr.Button("Reset")

            with gr.Column():

                augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)

        replace_button.click(
            fn=run_inpaint,
            inputs=[prompt, negative_prompt, cfg, seed, invert],
            outputs=[augmented_image],
        )

        reset_button.click(
            fn=reset_points_func,
            inputs=[],
            outputs=[mask_visualization, selected_image, augmented_image],
        )

        gr.Markdown(
            """
            ## EXAMPLES
            Click on an example to load it. Then, follow the instructions above.
            """)

        with gr.Row():
            gr.Examples(
                examples=[
                    ["car.png", "a red sports car", "blurry, low quality", 42],
                    ["house.jpg", "a modern villa", "dark, overexposed", 123],
                    ["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999],
                ],
                inputs=[
                    upload_image,
                    prompt,
                    negative_prompt,
                    seed,
                ],
                label="Click to load examples",
                # These examples only pre-fill the inputs. Caching needs an
                # `fn` and `outputs` to precompute, and gr.Examples raises
                # ValueError when asked to cache without them.
                cache_examples=False,
            )

    demo.queue(max_size=10).launch()
|
|
|
|
|
# Build and launch the UI. Neither callback argument is supplied.
build_app(get_processed_inputs=None, inpaint=None)