ahmetyaylalioglu's picture
Update app.py
ab8972c verified
raw
history blame
2.73 kB
import gradio as gr
from PIL import Image
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import AutoPipelineForInpainting
from diffusers.models.autoencoders.vq_model import VQEncoderOutput, VQModel
import torch
# Segmentation runs on the CPU; inputs for SAM are moved to this device.
device = "cpu"

# One checkpoint name feeds both the SAM network and its pre/post-processor.
model_name = "facebook/sam-vit-huge"
processor = SamProcessor.from_pretrained(model_name)
model = SamModel.from_pretrained(model_name)
model = model.to(device)
def mask_to_rgb(mask):
    """Render a mask as an RGBA overlay.

    Entries of ``mask`` equal to 1 become half-transparent green
    ``(0, 255, 0, 127)``; every other entry becomes fully transparent
    black. The result is a ``uint8`` array whose shape is the shape of
    ``mask`` with a trailing channel axis of length 4.
    """
    highlight = np.array([0, 255, 0, 127], dtype=np.uint8)
    selected = mask == 1
    # Broadcast the per-pixel selection across the four colour channels.
    return np.where(selected[..., np.newaxis], highlight, np.uint8(0))
def get_processed_inputs(image, points):
    """Segment ``image`` with SAM using ``points`` as the prompt.

    ``points`` is wrapped in a one-element list before being handed to the
    processor as ``input_points``. Of the candidate masks predicted for the
    prompt, the one with the highest predicted IoU score is selected, and
    its bitwise inverse is returned as a NumPy array, so the returned mask
    marks everything *outside* the selected region.
    """
    encoded = processor(
        image, input_points=[points], return_tensors="pt"
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(**encoded)

    # Map the predicted masks back to the size of the original image.
    resized_masks = processor.image_processor.post_process_masks(
        prediction.pred_masks.cpu(),
        encoded["original_sizes"].cpu(),
        encoded["reshaped_input_sizes"].cpu(),
    )

    top_scoring = prediction.iou_scores.argmax()
    selected_mask = resized_masks[0][0][top_scoring]
    return ~selected_mask.cpu().numpy()
# Checkpoint used for inpainting, and a cache so it is loaded only once.
_INPAINT_MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
_inpaint_pipeline_cache = {}


def _get_inpaint_pipeline():
    """Return the inpainting pipeline, loading it on first use only.

    Loading the checkpoint on every request is very expensive, so the
    pipeline is built once and reused. Half precision and model CPU
    offloading are only enabled when CUDA is available: offloading moves
    sub-models onto the accelerator, which cannot work on a CPU-only host.
    Without CUDA the pipeline is loaded in float32 and left on the CPU.

    NOTE(review): the cached pipeline is shared by all requests; if the app
    is served with concurrent workers, calls should be serialized (e.g. via
    the Gradio queue) -- confirm against the deployment.
    """
    pipeline = _inpaint_pipeline_cache.get(_INPAINT_MODEL_ID)
    if pipeline is None:
        has_cuda = torch.cuda.is_available()
        pipeline = AutoPipelineForInpainting.from_pretrained(
            _INPAINT_MODEL_ID,
            torch_dtype=torch.float16 if has_cuda else torch.float32,
        )
        if has_cuda:
            pipeline.enable_model_cpu_offload()
        _inpaint_pipeline_cache[_INPAINT_MODEL_ID] = pipeline
    return pipeline


def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    """Inpaint ``raw_image`` under ``input_mask`` according to ``prompt``.

    Args:
        raw_image: Image passed to the pipeline as ``image``.
        input_mask: NumPy array converted to a PIL image and passed to the
            pipeline as ``mask_image``.
        prompt: Text describing the desired content.
        negative_prompt: Optional text describing what to avoid.
        seed: Seed for the random generator, making results reproducible.
        cfgs: Value passed to the pipeline as ``guidance_scale``.

    Returns:
        The first image produced by the pipeline.
    """
    mask_image = Image.fromarray(input_mask)
    # torch.manual_seed seeds and returns the default generator, which is
    # then handed to the pipeline so the same seed gives the same result.
    rand_gen = torch.manual_seed(seed)
    pipeline = _get_inpaint_pipeline()
    result = pipeline(
        prompt=prompt,
        image=raw_image,
        mask_image=mask_image,
        guidance_scale=cfgs,
        negative_prompt=negative_prompt,
        generator=rand_gen,
    )
    return result.images[0]
def _sketch_to_points(points, target_size, max_points=32):
    """Convert the value of the point-selection input into SAM prompts.

    A list or tuple is treated as ready-made coordinates and returned
    unchanged. Otherwise the value is expected to be what a sketch-enabled
    image input delivers: either a dict carrying a ``"mask"`` array or the
    array itself, in which non-zero pixels are the user's marks. Marked
    pixels are converted to ``[x, y]`` pairs and rescaled from the sketch's
    own size to ``target_size`` (``(width, height)``). At most
    ``max_points`` evenly spaced marks are kept so a thick stroke does not
    produce thousands of prompts.

    NOTE(review): this assumes the sketched picture shows the same scene as
    the input image, since coordinates are only rescaled proportionally --
    confirm against the UI.
    """
    if isinstance(points, (list, tuple)):
        return points

    sketch = points.get("mask") if isinstance(points, dict) else points
    if sketch is None:
        raise gr.Error("Mark at least one point on the image before submitting.")

    sketch = np.asarray(sketch)
    if sketch.ndim == 3:
        # Collapse colour channels (ignoring a possible alpha channel).
        sketch = sketch[..., :3].max(axis=-1)
    if sketch.ndim != 2:
        raise gr.Error("The point selection could not be read as an image.")

    rows, cols = np.nonzero(sketch)
    if rows.size == 0:
        raise gr.Error("Mark at least one point on the image before submitting.")
    if rows.size > max_points:
        keep = np.linspace(0, rows.size - 1, max_points).astype(int)
        rows, cols = rows[keep], cols[keep]

    height, width = sketch.shape
    target_width, target_height = target_size
    return [
        [float(x) * target_width / width, float(y) * target_height / height]
        for x, y in zip(cols, rows)
    ]


def gradio_interface(image, points, positive_prompt, negative_prompt):
    """Run segmentation followed by inpainting for the Gradio UI.

    Args:
        image: Input picture as a NumPy array.
        points: The point selection: either explicit ``[x, y]`` coordinates
            or the sketch produced by the point-selection image input.
        positive_prompt: Prompt forwarded to ``inpaint``.
        negative_prompt: Negative prompt forwarded to ``inpaint``.

    Returns:
        A tuple ``(inpainted image, RGBA overlay of the mask)``.

    Raises:
        gr.Error: If no image was provided or no point was marked.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # All models work on a fixed 512x512 RGB version of the picture.
    raw_image = Image.fromarray(image).convert("RGB").resize((512, 512))
    point_prompts = _sketch_to_points(points, raw_image.size)
    mask = get_processed_inputs(raw_image, point_prompts)
    processed_image = inpaint(raw_image, mask, positive_prompt, negative_prompt)
    return processed_image, mask_to_rgb(mask)
# Input widgets, in the order of gradio_interface's parameters.
input_components = [
    gr.Image(type="numpy", label="Input Image"),
    gr.Image(type="numpy", label="Click to select points", tool="sketch", brush_radius=1),
    gr.Textbox(label="Positive Prompt", placeholder="Enter positive prompt here"),
    gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here"),
]

# Output widgets, in the order of gradio_interface's return values.
output_components = [
    gr.Image(label="Inpainted Image"),
    gr.Image(label="Segmentation Mask"),
]

iface = gr.Interface(
    fn=gradio_interface,
    inputs=input_components,
    outputs=output_components,
    title="Interactive Image Inpainting",
    description="Click on the image to select points for segmentation, provide prompts, and see the inpainted result.",
)

# Start the web app and request a public share link.
iface.launch(share=True)