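# Interactive image inpainting demo: SAM (facebook/sam-vit-huge) segments a
# subject from user-supplied click points, then an SDXL inpainting pipeline
# repaints the background from a text prompt, served via a Gradio interface.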
import gradio as gr
from PIL import Image
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import AutoPipelineForInpainting
import torch

# Check if GPU is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model and processor setup: SAM for point-prompted segmentation
model_name = "facebook/sam-vit-huge"
model = SamModel.from_pretrained(model_name).to(device)
processor = SamProcessor.from_pretrained(model_name)

# Load the SDXL inpainting pipeline once at startup; reloading it on every
# request inside inpaint() would add several seconds per call. Use fp16 on
# GPU and fp32 on CPU, since half precision is poorly supported on CPU.
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)

def mask_to_rgb(mask):
    # Render the boolean mask as a semi-transparent green RGBA overlay
    # on an otherwise fully transparent background.
    bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]
    return bg_transparent
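# Illustrative: mask_to_rgb(np.array([[True, False]])) returns a (1, 2, 4)
# uint8 array with one semi-transparent green pixel and one transparent pixel.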

def get_processed_inputs(image, points_str):
    # Parse "x1,y1,x2,y2,..." into integers. Coordinates refer to the
    # 512x512 image that gradio_interface passes in.
    points = list(map(int, points_str.split(',')))
    # Group the flat list into [x, y] pairs, nested to match SAM's batch
    # format: one image, one prompt group, N points.
    input_points = [[[x, y] for x, y in zip(points[::2], points[1::2])]]
    
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Resize the predicted masks back to the original image resolution.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), 
        inputs["original_sizes"].cpu(), 
        inputs["reshaped_input_sizes"].cpu()
    )
    # SAM returns three candidate masks per prompt; keep the one with the
    # highest predicted IoU, then invert it so True marks the background
    # (the region the inpainting step will repaint).
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    return ~best_mask.cpu().numpy()
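# Illustrative: points_str "256,256,300,200" marks two foreground clicks on a
# 512x512 image; the function returns a (512, 512) boolean array in which
# True marks the background that inpaint() will repaint.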

def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    # PIL cannot build an image directly from a boolean array, so convert the
    # mask to uint8 first (255 marks the region to be repainted).
    mask_image = Image.fromarray(input_mask.astype(np.uint8) * 255)
    # Seed a generator on the pipeline's device for reproducible output.
    rand_gen = torch.Generator(device=device).manual_seed(seed)
    image = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen
    ).images[0]
    return image
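# Illustrative standalone use (hypothetical values, not run by the app):
#   img = Image.open("photo.jpg").convert("RGB").resize((512, 512))
#   bg_mask = get_processed_inputs(img, "256,256,300,200")
#   result = inpaint(img, bg_mask, "a sandy beach at sunset",
#                    negative_prompt="blurry, low quality")
#   result.save("inpainted.png")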

def gradio_interface(image, points, positive_prompt, negative_prompt):
    # Work at a fixed 512x512 resolution so the user's point coordinates and
    # the diffusion pipeline both see the same image size.
    raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
    mask = get_processed_inputs(raw_image, points)
    processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
    return processed_image, mask_to_rgb(mask)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Textbox(label="Points (format: x1,y1,x2,y2,...)", placeholder="e.g., 100,100,200,200"),
        gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
        gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here")
    ],
    outputs=[
        gr.Image(label="Inpainted Image"),
        gr.Image(label="Segmentation Mask")
    ],
    title="Interactive Image Inpainting",
    description="Enter points as 'x1,y1,x2,y2,...' for segmentation, provide prompts, and see the inpainted result."
)

# share=True additionally exposes a temporary public Gradio URL; set it to
# False to keep the demo local-only.
iface.launch(share=True)