File size: 3,311 Bytes
9b2337b
 
 
 
 
 
 
c636056
 
 
9b2337b
8ccf912
9b2337b
 
 
 
 
c636056
9b2337b
c636056
9b2337b
 
e2a5dbf
 
 
 
 
9b2337b
 
 
 
 
 
 
 
 
 
 
c636056
9b2337b
 
 
c636056
 
 
 
 
9b2337b
 
 
 
 
 
 
 
 
 
e2a5dbf
 
9b2337b
e2a5dbf
ed05f1f
9b2337b
 
 
 
8ccf912
ab8972c
e2a5dbf
ed05f1f
 
8ccf912
ed05f1f
 
 
 
 
2e6d89c
9b2337b
ed4535e
7558e02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import functools

import gradio as gr
import numpy as np
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image
from transformers import SamModel, SamProcessor

# Choose the compute device once, at import time: CUDA when PyTorch can see a
# GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Segment Anything checkpoint plus its matching pre/post-processor; both are
# module-level so every request reuses the already-loaded weights.
model_name = "facebook/sam-vit-huge"
model = SamModel.from_pretrained(model_name)
model = model.to(device)
processor = SamProcessor.from_pretrained(model_name)

def mask_to_rgb(mask):
    """Render a mask as an RGBA overlay array.

    Pixels where ``mask == 1`` (this includes ``True`` in a boolean mask) are
    painted semi-transparent green ``(0, 255, 0, 127)``; every other pixel is
    left fully transparent ``(0, 0, 0, 0)``. The result has shape
    ``mask.shape + (4,)`` and dtype ``uint8``.
    """
    selected = mask == 1
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    overlay[selected] = (0, 255, 0, 127)
    return overlay

def _annotation_to_points(annotation, image_size, max_points=8):
    """Convert the user's drawing into SAM point prompts in image pixel space.

    Args:
        annotation: array from the drawing widget. Strokes are read from the
            alpha channel of an (H, W, 4) RGBA array (alpha > 128).
        image_size: PIL ``(width, height)`` of the image being segmented.
        max_points: upper bound on how many stroke pixels become prompts.

    Returns:
        A list of ``[x, y]`` float pairs, or ``None`` when no usable stroke is
        found (no annotation, no alpha channel, nothing drawn, or every pixel
        set, which indicates an opaque picture rather than a drawing).
    """
    if annotation is None:
        return None
    arr = np.asarray(annotation)
    # Without an alpha channel there is no way to tell strokes apart from the
    # underlying picture, so no prompts can be derived.
    if arr.ndim != 3 or arr.shape[2] != 4:
        return None
    strokes = arr[:, :, 3] > 128
    if not strokes.any() or strokes.all():
        return None
    rows, cols = np.nonzero(strokes)
    # Evenly subsample the stroke so SAM gets a handful of prompts, not thousands.
    count = min(max_points, rows.size)
    picks = np.linspace(0, rows.size - 1, num=count).round().astype(int)
    # Rescale from annotation coordinates to image coordinates; PIL sizes are
    # (width, height) while numpy shapes are (rows, cols).
    width, height = image_size
    scale_x = width / arr.shape[1]
    scale_y = height / arr.shape[0]
    return [[float(cols[i] * scale_x), float(rows[i] * scale_y)] for i in picks]


def get_processed_inputs(image, annotation):
    """Segment ``image`` with SAM, prompted by the user's drawing.

    Args:
        image: PIL image to segment.
        annotation: array from the drawing widget. Strokes are taken from the
            alpha channel of an RGBA array (alpha > 128). When no strokes can be
            located, SAM runs without point prompts.

    Returns:
        The logical inverse (``~``) of the SAM mask with the highest predicted
        IoU, as a numpy array sized like ``image``. It is presumably boolean,
        since ``post_process_masks`` binarizes by default, so True marks pixels
        OUTSIDE the segmented object.
    """
    points = _annotation_to_points(annotation, image.size)
    if points is None:
        inputs = processor(images=image, return_tensors="pt")
    else:
        # One image, one prompt made of several foreground points.
        inputs = processor(images=image, input_points=[points], return_tensors="pt")
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # SAM proposes several candidate masks; keep the one it scores highest.
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    return ~best_mask.cpu().numpy()

@functools.lru_cache(maxsize=1)
def _load_inpaint_pipeline():
    """Load the SDXL inpainting pipeline once and keep it resident on ``device``."""
    pipeline = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    # No enable_model_cpu_offload(): it pages sub-models between CPU and an
    # accelerator, so it is meaningless (and fails) on a CPU-only host.
    return pipeline.to(device)


def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    """Repaint the masked region of ``raw_image`` from a text prompt.

    Args:
        raw_image: PIL image to edit.
        input_mask: array-like mask; nonzero pixels are repainted.
        prompt: text describing what to paint into the masked area.
        negative_prompt: optional text describing what to avoid.
        seed: seed for the generator, for reproducible results.
        cfgs: classifier-free guidance scale passed as ``guidance_scale``.

    Returns:
        The first image produced by the inpainting pipeline.
    """
    # Normalize to a 0/255 grayscale mask so a 0/1 integer mask is not read as
    # an almost-black ("repaint nothing") image.
    mask_array = np.where(np.asarray(input_mask) != 0, 255, 0).astype(np.uint8)
    mask_image = Image.fromarray(mask_array)
    rand_gen = torch.manual_seed(seed)
    # Cached: loading the SDXL weights on every request is very slow.
    pipeline = _load_inpaint_pipeline()
    image = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen,
    ).images[0]
    return image

def gradio_interface(image, annotation, positive_prompt, negative_prompt):
    """Handle one Gradio submission: segment with SAM, then inpaint.

    Args:
        image: numpy array from the "Input Image" widget (None if nothing was
            uploaded).
        annotation: output of the drawing widget, forwarded to
            ``get_processed_inputs``.
        positive_prompt: text forwarded to ``inpaint`` as the prompt.
        negative_prompt: text forwarded to ``inpaint`` as the negative prompt.

    Returns:
        A tuple ``(inpainted_image, mask_overlay)``, where the overlay is the
        RGBA array produced by ``mask_to_rgb``.

    Raises:
        gr.Error: when no input image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a PIL AttributeError.
        raise gr.Error("Please upload an input image.")
    # Note: the fixed 512x512 resize does not preserve aspect ratio.
    raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
    mask = get_processed_inputs(raw_image, annotation)
    processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
    return processed_image, mask_to_rgb(mask)

# Input widgets, in the same order as gradio_interface's parameters
# (image, annotation, positive_prompt, negative_prompt).
# NOTE(review): `tool=` and `shape=` are Gradio 3.x gr.Image arguments (removed
# in Gradio 4), and `output=` does not appear to be a documented gr.Image
# argument — confirm against the pinned Gradio version.
# NOTE(review): gr.Image defaults to image_mode="RGB", while the annotation is
# consumed as if strokes live in an alpha channel — confirm the drawing widget
# really delivers RGBA.
_input_components = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Image(tool="editor", label="Draw on the image", output="png", shape=(512, 512)),
    gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
    gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here"),
]

# Output widgets, matching gradio_interface's (inpainted image, mask overlay) return.
_output_components = [
    gr.Image(label="Inpainted Image"),
    gr.Image(label="Segmentation Mask"),
]

iface = gr.Interface(
    fn=gradio_interface,
    inputs=_input_components,
    outputs=_output_components,
    title="Interactive Image Inpainting",
    description="Draw on the image to select areas for segmentation, provide prompts, and see the inpainted result.",
)

# share=True asks Gradio for a public share link in addition to the local server.
iface.launch(share=True)