vvaibhav's picture
Update app.py
a99346b verified
raw
history blame
11.6 kB
# app.py
import gradio as gr
from PIL import Image, ImageDraw
import torch
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
# Constants
IMG_SIZE = 512  # side length, in pixels, of the square working image

# Segment Anything (ViT-H): the processor handles pre/post-processing, the
# model is loaded in float32 and kept on the CPU.
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
sam_model = SamModel.from_pretrained(
    "facebook/sam-vit-huge",
    torch_dtype=torch.float32,
).to("cpu")

# Stable Diffusion 2 inpainting pipeline, also float32 on the CPU
# (no model CPU offloading is configured).
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
).to("cpu")

# Mutable UI state read and written by the Gradio handlers in this module.
# NOTE(review): these are process-wide globals, so every browser session
# presumably shares the same points and image - confirm whether the app is
# meant to serve more than one user at a time.
input_points = []  # [x, y] pixel coordinates clicked by the user
input_image = None  # clean copy of the image the points refer to
def mask_to_rgba(mask):
    """
    Turn a binary mask into a translucent green RGBA layer.

    Pixels where ``mask == 1`` (or ``True``) become ``(0, 255, 0, 127)``,
    i.e. green at roughly half opacity; every other pixel is fully
    transparent black.

    Args:
        mask (np.ndarray): Array of 0/1 or boolean values (H x W when used
            as an image overlay).

    Returns:
        np.ndarray: uint8 array of shape ``mask.shape + (4,)``.
    """
    selected = mask == 1
    layer = np.zeros((*mask.shape, 4), dtype=np.uint8)
    layer[selected] = (0, 255, 0, 127)
    return layer
def generate_mask(image, input_points):
    """
    Run SAM on the clicked points and return a boolean segmentation mask.

    Args:
        image (PIL.Image): The image the points were clicked on.
        input_points (list of lists): ``[x, y]`` pixel coordinates selected
            by the user. All points are used together as prompts for a
            single object.

    Returns:
        np.ndarray | None: Boolean array resized back to the image's original
        size by ``post_process_masks``. SAM's best mask is inverted, so the
        predicted object is ``False`` and everything else (the background)
        is ``True``. Returns ``None`` if no points were given or SAM
        returned no masks.
    """
    if not input_points:
        return None
    # Convert image to RGB if not already
    image = image.convert("RGB")
    # SamProcessor expects `input_points` nested as [image][point][x, y]:
    # one image whose clicks jointly prompt one object. (Passing them as
    # `points=` means they never reach the model.)
    points = [[float(x), float(y)] for x, y in input_points]
    inputs = sam_processor(image, input_points=[points], return_tensors="pt").to("cpu")
    with torch.no_grad():
        outputs = sam_model(**inputs)
    # Upscale the low-resolution predictions back to the input image size.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None
    # masks[0][0] holds the candidate masks for the first image / first point
    # batch; keep the candidate SAM scores highest.
    best_index = int(outputs.iou_scores[0, 0].argmax().item())
    best_mask = masks[0][0][best_index]
    # Invert as a *boolean* array (object=False, background=True). A boolean
    # result keeps downstream `~mask`, `mask == 1` and `mask * 255` correct;
    # inverting an int array would produce -1/-2 instead of 0/1.
    return ~best_mask.numpy().astype(bool)
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Inpaint the region of ``image`` selected by ``mask``, guided by ``prompt``.

    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray | None): 0/1 or boolean mask. Pixels equal to 1/True
            become 255 in the mask image (repainted by the pipeline); 0/False
            pixels become 0 (kept). ``None`` skips inpainting.
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt; an empty value is
            passed to the pipeline as ``None``.
        seed (int | float): Random seed for reproducibility; coerced to int.
        guidance_scale (float): Classifier-free guidance scale.

    Returns:
        PIL.Image: The inpainted image. The unmodified ``image`` is returned
        if ``mask`` is None or the pipeline raises.
    """
    if mask is None:
        return image
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    # torch.Generator.manual_seed only accepts integers; numeric UI widgets
    # or direct callers may hand over a float such as 42.0.
    generator = torch.Generator("cpu").manual_seed(int(seed))
    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
    except Exception as e:
        # Deliberate best effort: report the failure and fall back to the
        # unmodified image instead of failing the request.
        print(f"Inpainting error: {e}")
        return image
    return result
def visualize_mask(image, mask):
    """
    Return ``image`` with ``mask`` drawn on top as a translucent green layer.

    Args:
        image (PIL.Image): The base image.
        mask (np.ndarray | None): Mask whose pixels equal to 1 are tinted
            (see ``mask_to_rgba``); must match the image size.

    Returns:
        PIL.Image: An RGB composite, or ``image`` unchanged when ``mask``
        is None.
    """
    if mask is not None:
        tint = Image.fromarray(mask_to_rgba(mask))
        composite = Image.alpha_composite(image.convert("RGBA"), tint)
        return composite.convert("RGB")
    return image
def get_points(img, evt: gr.SelectData):
    """
    Gradio ``select`` handler: record a click, re-run SAM, redraw markers.

    On the first click after a reset, a clean copy of ``img`` is stored in
    the module-level ``input_image`` so that later markers never leak into
    the image given to SAM.

    Args:
        img (PIL.Image): The uploaded image (markers are drawn onto it).
        evt (gr.SelectData): Click event; ``evt.index`` holds (x, y).

    Returns:
        Tuple: (mask overlay image, ``img`` with a green cross per point)
    """
    global input_points, input_image
    if not input_points:
        input_image = img.copy()
    click_x, click_y = evt.index[0], evt.index[1]
    input_points.append([click_x, click_y])

    mask = generate_mask(input_image, input_points)

    arm = 10  # half-length of each cross arm, in pixels
    pen = ImageDraw.Draw(img)
    for px, py in input_points:
        pen.line((px - arm, py, px + arm, py), fill="green", width=5)
        pen.line((px, py - arm, px, py + arm), fill="green", width=5)

    return visualize_mask(input_image, mask), img
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Gradio handler for the "Replace Object" button.

    Re-runs SAM on the stored image and points, optionally inverts the mask,
    and inpaints the masked region.

    Args:
        prompt (str): Prompt for infill.
        negative_prompt (str): Negative prompt.
        cfg (float): Classifier-free guidance scale (passed as guidance_scale).
        seed (int): Random seed.
        invert (bool): If True, the mask from ``generate_mask`` is inverted
            with ``~`` before inpainting, so the opposite region is repainted
            (the UI labels this "Infill subject instead of background").

    Returns:
        PIL.Image: The inpainted image resized to IMG_SIZE x IMG_SIZE.

    Raises:
        gr.Error: If no image/points have been captured, SAM yields no mask,
            or ``replace_object`` raises.
    """
    # Only reads the module-level state, so no `global` declaration is needed.
    if input_image is None or len(input_points) == 0:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")
    mask = generate_mask(input_image, input_points)
    if mask is None:
        # Guard explicitly: `~None` below would raise an unexplained TypeError.
        raise gr.Error("SAM could not produce a mask for the selected points. Try selecting different points.")
    if invert:
        # `~` is a logical NOT for boolean masks (bitwise for integer ones).
        mask = ~mask
    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e)) from e
    return inpainted.resize((IMG_SIZE, IMG_SIZE))
def reset_points_func():
    """
    Forget every clicked point and the stored source image.

    Returns:
        Tuple: three ``None`` values that clear the mask overlay, the marked
        image and the inpainted output components.
    """
    global input_points, input_image
    input_points, input_image = [], None
    return (None,) * 3
def preprocess(input_img):
    """
    Normalize an uploaded image to an RGB square of IMG_SIZE x IMG_SIZE.

    Non-square images are centered on a white square canvas (padding, not
    cropping) before resizing.

    Args:
        input_img (PIL.Image | None): The uploaded image.

    Returns:
        PIL.Image | None: The RGB, square, resized image, or ``None`` if no
        image was given.
    """
    if input_img is None:
        return None
    # Convert up front. The padded path already produced RGB; square uploads
    # (RGBA PNGs, palette or grayscale images) used to keep their original
    # mode and be handed unchanged to the RGB-based SAM/inpainting code.
    input_img = input_img.convert("RGB")
    width, height = input_img.size
    if width != height:
        # Add white padding to make the image square
        new_size = max(width, height)
        new_image = Image.new("RGB", (new_size, new_size), "white")
        left = (new_size - width) // 2
        top = (new_size - height) // 2
        new_image.paste(input_img, (left, top))
        input_img = new_image
    return input_img.resize((IMG_SIZE, IMG_SIZE))
def build_app(get_processed_inputs, inpaint):
    """
    Build the Gradio Blocks UI, wire its events, and launch it.

    The request queue is limited to 10 pending events.

    Args:
        get_processed_inputs (function): Not used by this implementation; kept
            so existing callers keep working. Segmentation is done by
            ``generate_mask`` through ``get_points``/``run_inpaint``.
        inpaint (function): Not used by this implementation; kept for the same
            reason. Inpainting is done by ``run_inpaint``.

    Returns:
        None
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
# Object Replacement App
Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
**Instructions:**
1. **Upload Image:** Click on the first image box to upload your image.
2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
6. **Reset:** Click the "Reset" button to clear selections and start over.
""")
        with gr.Row():
            with gr.Column():
                # Image upload and point selection
                upload_image = gr.Image(label="Upload Image", type="pil", interactive=True)
                mask_visualization = gr.Image(label="Selected Object Mask Overlay", interactive=False)
                selected_image = gr.Image(label="Image with Selected Points", type="pil", interactive=False)
                # Capture points using the select event
                upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
                # Preprocess image on change
                upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
                # Text inputs and settings
                prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
                cfg = gr.Slider(
                    label="Classifier-Free Guidance Scale",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5,
                )
                seed = gr.Number(label="Seed", value=42, precision=0)
                # Checked by default so that, per this checkbox's label, the
                # selected subject is infilled - matching the "Replace Object"
                # instructions above.
                invert = gr.Checkbox(label="Infill subject instead of background", value=True)
                # Buttons
                replace_button = gr.Button("Replace Object")
                reset_button = gr.Button("Reset")
            with gr.Column():
                # Output images
                augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)
        # Define button actions
        replace_button.click(
            fn=run_inpaint,
            inputs=[prompt, negative_prompt, cfg, seed, invert],
            outputs=[augmented_image],
        )
        reset_button.click(
            fn=reset_points_func,
            inputs=[],
            outputs=[mask_visualization, selected_image, augmented_image],
        )
        # get_points only snapshots the source image on the first click after
        # a reset, so loading a new image must also clear the stored points;
        # otherwise SAM and the inpainter keep working on the previous image.
        upload_image.change(
            fn=reset_points_func,
            inputs=[],
            outputs=[mask_visualization, selected_image, augmented_image],
        )
        # Examples (optional)
        gr.Markdown(
            """
## EXAMPLES
Click on an example to load it. Then, follow the instructions above.
""")
        with gr.Row():
            examples = gr.Examples(
                examples=[
                    ["car.png", "a red sports car", "blurry, low quality", 42],
                    ["house.jpg", "a modern villa", "dark, overexposed", 123],
                    ["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999],
                ],
                inputs=[
                    upload_image,
                    prompt,
                    negative_prompt,
                    seed,
                ],
                label="Click to load examples",
                # These examples only prefill the inputs. Caching requires
                # `fn` and `outputs`, and gr.Examples rejects
                # cache_examples=True without them.
                cache_examples=False,
            )
    demo.queue(max_size=10).launch()
# Launch the app. Both callables are unused by build_app, hence None.
build_app(get_processed_inputs=None, inpaint=None)