File size: 2,347 Bytes
9b2337b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
from PIL import Image
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import AutoPipelineForInpainting
import torch

# Inference is deliberately pinned to the CPU, even when CUDA is available.
device = "cpu"

# One checkpoint name is shared by the SAM model and its processor so the two
# always come from the same pretrained weights/configuration.
model_name = "facebook/sam-vit-huge"
model = SamModel.from_pretrained(model_name)
model = model.to(device)
processor = SamProcessor.from_pretrained(model_name)

def mask_to_rgb(mask):
    """Render a mask as an RGBA overlay image.

    Entries of ``mask`` that compare equal to 1 are painted semi-transparent
    green ``(0, 255, 0, 127)``; every other entry stays ``(0, 0, 0, 0)``.

    Args:
        mask: Array-like with a ``shape`` attribute and element-wise ``==``.

    Returns:
        A ``uint8`` NumPy array of shape ``mask.shape + (4,)``.
    """
    highlight = [0, 255, 0, 127]
    overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
    selected = mask == 1
    overlay[selected] = highlight
    return overlay

def _parse_points(points):
    """Parse a ``"x1,y1|x2,y2"`` string into ``[[x1, y1], [x2, y2]]`` (ints)."""
    parsed = []
    for chunk in points.split('|'):
        chunk = chunk.strip()
        if not chunk:
            # Tolerate stray separators such as a trailing "|".
            continue
        coords = [int(value) for value in chunk.split(',')]
        if len(coords) != 2:
            raise ValueError(
                f"Each point must be 'x,y' with exactly two integers, got {chunk!r}"
            )
        parsed.append(coords)
    return parsed


def get_processed_inputs(image, points):
    """Segment ``image`` with SAM using point prompts and return the inverted mask.

    Args:
        image: The image handed to the SAM processor.
        points: Prompt points encoded as ``"x1,y1|x2,y2|..."``. Empty segments
            are ignored.

    Returns:
        A NumPy array holding the bitwise inverse (``~``) of the mask with the
        highest predicted IoU score, i.e. it marks everything *outside* the
        selected object (presumably a boolean array; depends on
        ``post_process_masks`` output).

    Raises:
        ValueError: If ``points`` contains no point, a point that does not
            have exactly two coordinates, or a non-integer coordinate.
    """
    prompt_points = _parse_points(points)
    if not prompt_points:
        raise ValueError("At least one 'x,y' point is required")

    # The processor takes one list of points per image. All prompt points
    # belong to the single input image, so they are wrapped in ONE outer list
    # rather than one list per point (which would look like a batch of several
    # images). ``input_points`` is passed by keyword so it cannot be mistaken
    # for another positional parameter of the processor.
    inputs = processor(
        image, input_points=[prompt_points], return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0][0] holds the candidate masks for the first image / first prompt;
    # keep the candidate with the best predicted IoU.
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    return ~best_mask.cpu().numpy()

_INPAINT_CHECKPOINT = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
_inpaint_pipeline = None  # built lazily by _load_inpaint_pipeline()


def _load_inpaint_pipeline():
    """Build the inpainting pipeline once and reuse it on later calls."""
    global _inpaint_pipeline
    if _inpaint_pipeline is None:
        if torch.cuda.is_available():
            # Half precision plus model offload keeps GPU memory usage down.
            pipeline = AutoPipelineForInpainting.from_pretrained(
                _INPAINT_CHECKPOINT, torch_dtype=torch.float16
            )
            pipeline.enable_model_cpu_offload()
        else:
            # Model offload needs an accelerator to move sub-models onto, and
            # half precision is poorly supported on CPU, so run in float32.
            pipeline = AutoPipelineForInpainting.from_pretrained(
                _INPAINT_CHECKPOINT, torch_dtype=torch.float32
            )
        _inpaint_pipeline = pipeline
    return _inpaint_pipeline


def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    """Inpaint ``raw_image`` under ``input_mask`` according to ``prompt``.

    Args:
        raw_image: Source image passed to the pipeline as ``image``.
        input_mask: Array converted with ``Image.fromarray`` and passed to the
            pipeline as ``mask_image``.
        prompt: Text prompt for the generated content.
        negative_prompt: Optional text describing what to avoid.
        seed: Seed for the random generator, for reproducible output.
        cfgs: Value forwarded as the pipeline's ``guidance_scale``.

    Returns:
        The first image produced by the pipeline.
    """
    # Loading the checkpoint is expensive; the pipeline is cached across calls.
    pipeline = _load_inpaint_pipeline()
    mask_image = Image.fromarray(input_mask)
    # Seed only after the pipeline exists so that any randomness consumed
    # while loading cannot make the first call differ from later ones.
    rand_gen = torch.manual_seed(seed)
    image = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen,
    ).images[0]
    return image

# Gradio Interface with Click Events
def gradio_interface(image, points):
    """Run segmentation and inpainting for one Gradio request.

    The input array is converted to a 512x512 RGB image, segmented with
    ``get_processed_inputs`` using ``points``, and then inpainted with a fixed
    prompt / negative prompt.

    Returns:
        A tuple ``(inpainted_image, mask_overlay)`` where ``mask_overlay`` is
        the output of ``mask_to_rgb`` for the computed mask.
    """
    prompt = "a car driving on Mars. Studio lights, 1970s"
    negative_prompt = "artifacts, low quality, distortion"

    pil_image = Image.fromarray(image)
    raw_image = pil_image.convert("RGB").resize((512, 512))

    mask = get_processed_inputs(raw_image, points)
    processed_image = inpaint(raw_image, mask, prompt, negative_prompt)
    overlay = mask_to_rgb(mask)
    return processed_image, overlay

# The second input is a text field: gradio_interface forwards it to
# get_processed_inputs, which parses an "x,y|x,y" string. A sketch canvas
# would deliver image data instead and fail on ``.split``. Coordinates refer
# to the 512x512 image that gradio_interface resizes the upload to.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        "image",
        gr.Textbox(
            label="Points (x,y|x,y) in 512x512 image coordinates",
            placeholder="256,256",
        ),
    ],
    outputs=["image", "image"],
)
iface.launch(share=True)