import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection
import spaces
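
# Gradio demo: detect people with DETR and remove them with Stable Diffusion inpainting.
# facebook/detr-resnet-50 produces "person" bounding boxes, which are drawn into a
# binary mask; stabilityai/stable-diffusion-2-inpainting then fills the masked regions
# according to the user's prompt.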

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Stable Diffusion Inpainting model
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)

# Load the DETR object detection model
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)

# On Hugging Face ZeroGPU Spaces, this decorator allocates a GPU for the duration of each call
@spaces.GPU
def detect_and_remove(input_image, prompt):
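    """Mask every detected person in input_image and inpaint the masked regions using prompt."""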
    # Nothing to do without both an image and a prompt
    if input_image is None or not prompt:
        return None

    # Keep a NumPy copy of the image; its (height, width) is needed to rescale the predicted boxes
    image_np = np.array(input_image)
    inputs = processor(images=input_image, return_tensors="pt").to(device)

    # Run DETR inference without tracking gradients
    with torch.no_grad():
        outputs = detector(**inputs)
    target_sizes = torch.tensor([image_np.shape[:2]]).to(device)

    # Keep detections above a 0.9 confidence score, rescaled to the original image size
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

    # Grayscale mask: white (255) marks the regions to be inpainted
    mask = Image.new("L", input_image.size, 0)
    draw = ImageDraw.Draw(mask)

    # Draw boxes for "person" class only
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] == "person":
            box = [int(i) for i in box.tolist()]
            draw.rectangle(box, fill=255)

    # Raise a UI error instead of returning a string, which a gr.Image output cannot display
    if np.array(mask).sum() == 0:
        raise gr.Error("No human detected in the image.")

    # Fill the masked regions with content matching the prompt
    output = pipe(
        prompt=prompt,
        image=input_image,
        mask_image=mask
    ).images[0]

    return output

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Automatic Human Removal and Inpainting")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Output Image")

    prompt_text = gr.Textbox(label="Prompt", placeholder="Describe what should fill the removed area, e.g. an empty park, green grass")
    submit = gr.Button("Submit")

    submit.click(detect_and_remove, inputs=[input_image, prompt_text], outputs=output_image)

demo.launch()