replacebg / app.py
Munaf1987's picture
Update app.py
5287b40 verified
raw
history blame
2.24 kB
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection
import spaces
# Use CUDA when PyTorch reports a usable GPU; otherwise run everything on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stable Diffusion inpainting pipeline: float16 weights on CUDA, float32 on CPU.
if device == "cuda":
    _pipe_dtype = torch.float16
else:
    _pipe_dtype = torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=_pipe_dtype,
)
pipe = pipe.to(device)

# DETR object detector plus its matching image pre/post-processor.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
detector = detector.to(device)
@spaces.GPU
def detect_and_remove(input_image, prompt):
    """Mask every detected person in an image and inpaint the masked area.

    DETR is run on the image, a filled rectangle is drawn into a mask for
    each detection labelled "person" (score threshold 0.9), and the Stable
    Diffusion inpainting pipeline regenerates the masked pixels guided by
    ``prompt``.

    Args:
        input_image: PIL image supplied by the UI, or ``None`` if nothing
            was uploaded.
        prompt: Text prompt passed to the inpainting pipeline.

    Returns:
        The first image produced by the inpainting pipeline, or ``None``
        when the image is missing or the prompt is empty/whitespace.

    Raises:
        gr.Error: If no "person" detection was drawn into the mask. An
            error is raised instead of returning a message string because
            the output component is an image, which cannot render text.
    """
    if input_image is None or not prompt or not prompt.strip():
        return None

    # Normalise to RGB so grayscale/RGBA/palette inputs reach both models
    # with three channels.
    rgb_image = input_image.convert("RGB")

    inputs = processor(images=rgb_image, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping for the detector pass.
    with torch.no_grad():
        outputs = detector(**inputs)

    # PIL reports (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([rgb_image.size[::-1]], device=device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    mask = Image.new("L", rgb_image.size, 0)
    draw = ImageDraw.Draw(mask)

    # Draw boxes for the "person" class only.
    for label, box in zip(results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] != "person":
            continue
        draw.rectangle([int(coord) for coord in box.tolist()], fill=255)

    # getbbox() is None exactly when every mask pixel is still zero.
    if mask.getbbox() is None:
        raise gr.Error("No human detected.")

    # Inpainting
    return pipe(
        prompt=prompt,
        image=rgb_image,
        mask_image=mask,
    ).images[0]
# Gradio UI: two image panes side by side, then the prompt box and the
# submit button that triggers detection + inpainting.
with gr.Blocks() as demo:
    gr.Markdown("## Automatic Human Removal and Inpainting")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")

    prompt_text = gr.Textbox(
        label="Prompt",
        placeholder="Example: Replace humans with cartoon background",
    )
    submit = gr.Button("Submit")

    # Wire the button to the processing function.
    submit.click(
        fn=detect_and_remove,
        inputs=[input_image, prompt_text],
        outputs=output_image,
    )

demo.launch()