File size: 2,236 Bytes
59be1d1 b294284 5287b40 59be1d1 5287b40 d32fde8 214f5df 59be1d1 5287b40 59be1d1 5287b40 b294284 5287b40 b294284 d32fde8 5287b40 b294284 5287b40 b294284 5287b40 b294284 5287b40 214f5df 5287b40 b294284 5287b40 b294284 5287b40 b294284 5287b40 b294284 5287b40 b294284 59be1d1 5287b40 b294284 5287b40 b294284 5287b40 214f5df b294284 5287b40 59be1d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection
import spaces
# Pick the compute device once at import time: the GPU when torch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Stable Diffusion inpainting pipeline. Half precision is requested only on
# CUDA; the CPU path keeps full float32 weights.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32 if device != "cuda" else torch.float16,
)
pipe = pipe.to(device)

# DETR object detector plus its matching image pre-/post-processor.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
detector = detector.to(device)
@spaces.GPU
def detect_and_remove(input_image, prompt):
    """Mask every detected person in an image and inpaint the masked area.

    Runs the module-level DETR ``detector`` on ``input_image``, paints the
    bounding box of each detection labelled "person" (confidence threshold
    0.9) into a mask, and hands the image, the mask and ``prompt`` to the
    module-level inpainting pipeline ``pipe``.

    Args:
        input_image: PIL image from the Gradio image component, or ``None``
            when nothing has been uploaded.
        prompt: Text describing what should be painted over the people.

    Returns:
        The inpainted PIL image at the same size as ``input_image``, or
        ``None`` when the image or the prompt is missing.

    Raises:
        gr.Error: If no person is detected. The output component is an
            image, so a plain string cannot be returned as a message;
            raising lets Gradio show the text to the user instead.
    """
    if input_image is None or not prompt:
        return None

    width, height = input_image.size

    # Detection is inference only, so skip autograd bookkeeping.
    inputs = processor(images=input_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = detector(**inputs)

    # Post-processing expects (height, width); PIL reports (width, height).
    target_sizes = torch.tensor([[height, width]], device=device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Build the inpainting mask: white rectangles over "person" boxes only.
    mask = Image.new("L", input_image.size, 0)
    draw = ImageDraw.Draw(mask)
    id2label = detector.config.id2label
    for label, box in zip(results["labels"], results["boxes"]):
        if id2label[label.item()] != "person":
            continue
        draw.rectangle([int(coord) for coord in box.tolist()], fill=255)

    # getbbox() is None exactly when the mask has no non-zero pixel.
    if mask.getbbox() is None:
        raise gr.Error("No human detected.")

    # Inpainting
    output = pipe(
        prompt=prompt,
        image=input_image,
        mask_image=mask,
    ).images[0]

    # No height/width is passed, so the pipeline renders at its own default
    # resolution; scale back so the result lines up with the input image.
    if output.size != input_image.size:
        output = output.resize(input_image.size, Image.LANCZOS)
    return output
# Gradio UI: an input image and an output image, a prompt box and a submit
# button wired to detect_and_remove.
# NOTE(review): the Row is assumed to hold the two image components side by
# side, with the prompt and button below it -- confirm the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Automatic Human Removal and Inpainting")
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")
    prompt_text = gr.Textbox(
        label="Prompt",
        placeholder="Example: Replace humans with cartoon background",
    )
    submit = gr.Button("Submit")
    # On click, pass the uploaded image and the prompt to the handler and
    # show whatever it returns in the output image component.
    submit.click(
        detect_and_remove,
        inputs=[input_image, prompt_text],
        outputs=output_image,
    )

demo.launch()