|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
from diffusers import StableDiffusionInpaintPipeline |
|
from PIL import Image, ImageDraw |
|
from transformers import DetrImageProcessor, DetrForObjectDetection |
|
import spaces |
|
|
|
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only used on CUDA; CPU inference stays in float32.
_sd_dtype = torch.float16 if device == "cuda" else torch.float32

# Stable Diffusion 2 inpainting pipeline: repaints the white regions of a mask
# according to a text prompt.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=_sd_dtype,
).to(device)

# DETR object detector (ResNet-50 backbone) and its matching processor, which
# handles both input preprocessing and box post-processing. Used to find people.
_detr_checkpoint = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_detr_checkpoint)
detector = DetrForObjectDetection.from_pretrained(_detr_checkpoint).to(device)
|
|
|
@spaces.GPU
def detect_and_remove(input_image, prompt):
    """Detect people in an image and repaint them with Stable Diffusion inpainting.

    Every box that DETR labels "person" with score >= 0.9 is painted white in a
    mask. The inpainting pipeline then regenerates those regions from ``prompt``.

    Args:
        input_image: PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.
        prompt: Text describing what should replace the detected people.

    Returns:
        The inpainted PIL image at the same size as ``input_image``, or ``None``
        when the image or a non-blank prompt is missing.

    Raises:
        gr.Error: If no person is detected. Gradio displays the message to the
            user; the output component is an image, so a string cannot be shown there.
    """
    if input_image is None or not prompt or not prompt.strip():
        return None

    # Uploaded PNGs can be RGBA/palette and some images are grayscale; DETR and
    # the inpainting pipeline both expect 3-channel RGB input.
    image = input_image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping to save (GPU) memory.
    with torch.no_grad():
        outputs = detector(**inputs)

    # Target sizes are (height, width) per image so the boxes come back in
    # pixel coordinates of the original image.
    target_sizes = torch.tensor([[image.height, image.width]], device=device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # White (255) marks pixels to repaint; black (0) pixels are preserved.
    mask = Image.new("L", image.size, 0)
    draw = ImageDraw.Draw(mask)
    found_person = False
    for label, box in zip(results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] == "person":
            draw.rectangle([int(v) for v in box.tolist()], fill=255)
            found_person = True

    if not found_person:
        raise gr.Error("No human detected.")

    output = pipe(prompt=prompt, image=image, mask_image=mask).images[0]

    # The pipeline works at its own default resolution; restore the input's
    # size (and aspect ratio) so the result lines up with what was uploaded.
    if output.size != image.size:
        output = output.resize(image.size, Image.LANCZOS)
    return output
|
|
|
|
|
# Gradio UI: upload an image, type a prompt, and get the image back with the
# detected people inpainted according to that prompt.
with gr.Blocks() as demo:
    gr.Markdown("## Automatic Human Removal and Inpainting")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")

    prompt_text = gr.Textbox(
        label="Prompt",
        placeholder="Example: Replace humans with cartoon background",
    )
    submit = gr.Button("Submit")

    submit.click(
        detect_and_remove,
        inputs=[input_image, prompt_text],
        outputs=output_image,
    )

# Launch only when run as a script (e.g. `python app.py` on Spaces), so importing
# the module (Gradio reload mode, tests) does not start a server as a side effect.
if __name__ == "__main__":
    # Diffusion inference is slow; the queue keeps long requests from timing
    # out on older Gradio versions, where queuing is not on by default.
    demo.queue().launch()