vvaibhav's picture
Update app.py
a05e0f7 verified
raw
history blame
9.61 kB
import functools

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from transformers import SamModel, SamProcessor
# Side length in pixels: uploads are squared to this size, results are
# resized to it, and the image widgets are drawn at it.
IMG_SIZE = 512

# Module-level selection state read and written by the event handlers:
# the clicked [x, y] points, and an unmarked copy of the image they refer to.
input_points, input_image = [], None
@functools.lru_cache(maxsize=1)
def _load_sam():
    """Load the SAM model and processor once; later calls reuse the same objects."""
    model = SamModel.from_pretrained(
        "facebook/sam-vit-huge", torch_dtype=torch.float32
    ).to("cpu")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    return model, processor


def _background_mask(object_mask):
    """Return a bool array that is True wherever *object_mask* is False/0."""
    return np.logical_not(np.asarray(object_mask, dtype=bool))


def generate_mask(image, points):
    """
    Segment the object under the clicked points with SAM.

    Args:
        image: PIL image the points refer to (converted to RGB before use).
        points: list of [x, y] pixel coordinates; all of them prompt the
            same single mask.

    Returns:
        A boolean array, post-processed back to the image's original size,
        that is True OUTSIDE the segmented object (the background) and False
        on it. Returns None when no points are given. Flip it with ``~`` to
        target the object instead (as run_inpaint's invert option does).
    """
    if not points:
        return None
    image = image.convert("RGB")
    model, processor = _load_sam()

    # SamProcessor takes `input_points` nested as [image][points][x, y]; it
    # adds a point-batch axis of size 1, so every click prompts one mask.
    # (The original passed `points=`, which SamProcessor does not use as
    # prompts, so the clicks never reached the model.)
    sam_points = [[[float(x), float(y)] for x, y in points]]
    inputs = processor(image, input_points=sam_points, return_tensors="pt").to("cpu")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None
    # masks[0][0] holds SAM's candidate masks for the single image and point
    # batch; keep the candidate with the highest predicted IoU score.
    best_idx = outputs.iou_scores[0, 0].argmax().item()
    best_mask = masks[0][0][best_idx]
    # Keep a boolean result: `~` on the previous int array produced -1/-2
    # instead of an inverted 0/1 mask.
    return _background_mask(best_mask.numpy())
@functools.lru_cache(maxsize=1)
def _load_inpaint_pipeline():
    """Load the Stable Diffusion 2 inpainting pipeline once and reuse it."""
    return StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float32,
    ).to("cpu")


def _mask_to_uint8(mask):
    """Map a truthy/falsy mask to uint8: 255 where it is set, 0 elsewhere."""
    return np.where(np.asarray(mask, dtype=bool), 255, 0).astype(np.uint8)


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Repaint the masked region of *image* with Stable Diffusion inpainting.

    Args:
        image: PIL image to edit.
        mask: array (assumed to match *image*'s size) whose nonzero/True
            pixels are regenerated and zero/False pixels kept; None skips
            inpainting entirely.
        prompt: text describing what to paint into the masked region.
        negative_prompt: optional text to steer away from; empty means none.
        seed: seed for the CPU random generator (coerced to int).
        guidance_scale: classifier-free guidance strength.

    Returns:
        The first generated PIL image. Returns *image* unchanged when *mask*
        is None, or when the pipeline call raises (the error is printed, not
        re-raised). Errors while loading the pipeline do propagate.
    """
    if mask is None:
        return image
    # Cached: reloading the multi-GB checkpoint on every request was the
    # dominant cost of this function.
    inpaint_pipeline = _load_inpaint_pipeline()
    # Normalise to 0/255 so bool or 0/1 masks never wrap when cast to uint8.
    mask_image = Image.fromarray(_mask_to_uint8(mask))
    # torch.Generator.manual_seed expects an int; guard against float seeds
    # coming from a numeric UI field.
    generator = torch.Generator("cpu").manual_seed(int(seed))
    try:
        return inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt or None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
    except Exception as e:
        # Deliberate best effort: report and fall back to the input image.
        print(f"Inpainting error: {e}")
        return image
def visualize_mask(image, mask):
    """
    Blend a translucent green layer over every pixel where ``mask == 1``.

    Returns ``image`` itself when ``mask`` is None; otherwise a new RGB image
    with the overlay composited on top of ``image``.
    """
    if mask is not None:
        rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
        # Alpha 127 is roughly half opacity, so the photo stays visible.
        rgba[mask == 1] = [0, 255, 0, 127]
        layer = Image.fromarray(rgba)
        return Image.alpha_composite(image.convert("RGBA"), layer).convert("RGB")
    return image
def get_points(img, evt: gr.SelectData):
    """
    Record a clicked point, recompute the SAM mask, and mark all clicks.

    Gradio injects ``evt`` because of the ``gr.SelectData`` annotation (keep
    it); ``evt.index`` holds the clicked x and y pixel. The first click of a
    selection stores an unmarked copy of ``img`` in ``input_image``.

    Returns:
        ``(mask overlay on input_image, img with a green cross at each point)``
    """
    global input_points, input_image

    if not input_points:
        # First click of a new selection: snapshot the clean image before
        # any crosses are drawn onto `img` below.
        input_image = img.copy()

    click_x = evt.index[0]
    click_y = evt.index[1]
    input_points.append([click_x, click_y])

    mask = generate_mask(input_image, input_points)

    pen = ImageDraw.Draw(img)
    arm = 10
    for px, py in input_points:
        pen.line((px - arm, py, px + arm, py), fill="green", width=5)
        pen.line((px, py - arm, px, py + arm), fill="green", width=5)

    return visualize_mask(input_image, mask), img
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Inpaint the stored image using the SAM mask built from the clicked points.

    Reads the module-level ``input_image`` and ``input_points`` set by the
    click handler.

    Args:
        prompt: text prompt forwarded to the inpainting model.
        negative_prompt: negative text prompt (may be empty).
        cfg: classifier-free guidance scale.
        seed: random seed for reproducible generations.
        invert: False repaints the background (the region generate_mask
            marks); True flips the mask so the selected subject is repainted.

    Returns:
        The inpainted image resized to IMG_SIZE x IMG_SIZE.

    Raises:
        gr.Error: when no points were selected, no mask could be produced,
            or inpainting raised.
    """
    if input_image is None or not input_points:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    mask = generate_mask(input_image, input_points)
    if mask is None:
        # Previously `~None` raised a bare TypeError here when inverting.
        raise gr.Error("Could not compute a mask from the selected points.")
    if invert:
        # Flip the mask so the subject, not the background, is repainted.
        mask = ~mask

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(str(e)) from e
    return inpainted.resize((IMG_SIZE, IMG_SIZE))
def reset_points_func():
    """
    Forget every clicked point and the stored source image.

    Rebinds both module-level selection globals to fresh values and returns
    three ``None`` values, one for each output component being cleared.
    """
    global input_points, input_image
    input_points, input_image = [], None
    return (None,) * 3
def preprocess(input_img):
    """
    Letterbox an uploaded image onto a white square and scale it to IMG_SIZE.

    Non-square images are centred on a white RGB canvas whose side equals
    their longer edge; square images go straight to the resize. Returns None
    when no image is given.
    """
    if input_img is None:
        return None
    w, h = input_img.size
    if w == h:
        square = input_img
    else:
        side = max(w, h)
        square = Image.new("RGB", (side, side), 'white')
        # Floor-division centring: any odd leftover pixel ends up right/bottom.
        square.paste(input_img, ((side - w) // 2, (side - h) // 2))
    return square.resize((IMG_SIZE, IMG_SIZE))
# --- UI ----------------------------------------------------------------------
# Components appear in the order they are created inside each gr.Row /
# gr.Column context, so statement order below *is* the page layout.
# NOTE(review): the click handlers keep the selection in module globals
# (input_points / input_image), so all concurrent users of one server share
# it; per-session gr.State would isolate them — confirm whether that matters.
with gr.Blocks() as demo:
    gr.Markdown(
"""
# Object Replacement App
Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
**Instructions:**
1. **Upload Image:** Click on the first image box to upload your image.
2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
6. **Reset:** Click the "Reset" button to clear selections and start over.
""")
    with gr.Row():
        # Left column: inputs, previews, settings and buttons.
        with gr.Column():
            # Image upload and point selection (delivered to handlers as PIL).
            upload_image = gr.Image(
                label="Upload Image",
                type="pil",
                interactive=True,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            # Read-only preview of the mask overlay returned by get_points.
            mask_visualization = gr.Image(
                label="Selected Object Mask Overlay",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            # Read-only preview of the upload with a cross at each click.
            selected_image = gr.Image(
                label="Image with Selected Points",
                type="pil",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            # Each click on the upload box runs get_points, which appends the
            # point to the global list and refreshes both previews.
            upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
            # Square-pad and resize each new image, writing it back in place.
            # NOTE(review): this handler updates the same component that
            # triggers it; confirm Gradio does not re-fire .change on that
            # update in a loop.
            upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
            # Text inputs and generation settings, all forwarded to run_inpaint.
            prompt = gr.Textbox(
                label="Replacement Prompt",
                placeholder="e.g., a red sports car",
                lines=2
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="e.g., blurry, low quality",
                lines=2
            )
            cfg = gr.Slider(
                label="Classifier-Free Guidance Scale",
                minimum=1.0,
                maximum=20.0,
                value=7.5,
                step=0.5
            )
            # precision=0 asks for whole-number seeds.
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0
            )
            # Passed to run_inpaint as `invert`: when checked the SAM mask is
            # flipped so the selected subject is repainted, not the background.
            invert = gr.Checkbox(
                label="Infill subject instead of background"
            )
            # Buttons
            replace_button = gr.Button("Replace Object")
            reset_button = gr.Button("Reset")
        # Right column: the generated result.
        with gr.Column():
            augmented_image = gr.Image(
                label="Augmented Image",
                type="pil",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
    # Button wiring: input order must match run_inpaint's parameter order
    # (prompt, negative_prompt, cfg, seed, invert).
    replace_button.click(
        fn=run_inpaint,
        inputs=[prompt, negative_prompt, cfg, seed, invert],
        outputs=[augmented_image]
    )
    # reset_points_func returns three Nones, one per cleared output.
    reset_button.click(
        fn=reset_points_func,
        inputs=[],
        outputs=[mask_visualization, selected_image, augmented_image]
    )
    # Examples (optional)
    gr.Markdown(
"""
## EXAMPLES
Click on an example to load it. Then, follow the instructions above.
""")
    with gr.Row():
        # Each row fills (image, prompt, negative prompt, seed); the image
        # files are assumed to ship alongside this script.
        examples = gr.Examples(
            examples=[
                [
                    "car.png",
                    "a red sports car",
                    "blurry, low quality",
                    42
                ],
                [
                    "monalisa.png",
                    "a rockstar",
                    "dark, overexposed",
                    123
                ],
            ],
            inputs=[
                upload_image,
                prompt,
                negative_prompt,
                seed
            ],
            label="Click to load examples",
            # No `fn` is given to gr.Examples, so there is nothing to
            # pre-compute; presumably the "error" the author hit came from
            # enabling caching without one — confirm.
            cache_examples=False
        )
# Enable Gradio's request queue (at most 10 pending jobs) and start serving.
demo.queue(max_size=10).launch()