# LaMa-Demo-ONNX / app.py
import os
import imageio
from PIL import Image
import gradio as gr
import cv2
import numpy as np
import paddlehub as hub
import onnxruntime
# Download the LaMa ONNX model once; runtime dependencies (onnxruntime, imageio)
# must already be installed, e.g. via requirements.txt, since they are imported above.
if not os.path.exists("lama_fp32.onnx"):
    os.system("wget https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx")
os.makedirs("data", exist_ok=True)
os.makedirs("dataout", exist_ok=True)
# Load LaMa ONNX model
sess_options = onnxruntime.SessionOptions()
lama_model = onnxruntime.InferenceSession('lama_fp32.onnx', sess_options=sess_options)
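# The session exposes two inputs, 'image' and 'mask', fed as float32 tensors by the
# helpers below (image: (1, 3, 512, 512), mask: (1, 1, 512, 512)); the input names
# can be verified with lama_model.get_inputs().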
# Load U^2-Net model for automatic masking
u2net_model = hub.Module(name='U2Net')
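# U^2-Net performs salient-object segmentation; its output is used as the region
# to inpaint when the "automatic" masking option is selected.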
# --- Helper Functions ---
def prepare_image(image, target_size=(512, 512)):
    """Resizes an image and converts it to a normalized NCHW float32 tensor for LaMa."""
    if isinstance(image, Image.Image):
        image = image.resize(target_size)
        image = np.array(image)
    elif isinstance(image, np.ndarray):
        image = cv2.resize(image, target_size)
    else:
        raise ValueError("Input image should be either a PIL Image or a numpy array!")
    # Normalize to [0, 1] and convert HWC -> CHW (grayscale gets a single channel)
    image = image.astype(np.float32) / 255.0
    if image.ndim == 3:
        image = np.transpose(image, (2, 0, 1))
    elif image.ndim == 2:
        image = image[np.newaxis, ...]
    return image[np.newaxis, ...]  # Add batch dimension -> (1, C, H, W)
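# Example (shapes only): prepare_image(np.zeros((600, 800, 3), dtype=np.uint8)) yields a
# float32 array of shape (1, 3, 512, 512); a grayscale PIL mask becomes (1, 1, 512, 512).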
def generate_mask(image, method="automatic"):
    """Generates an inpainting mask with U^2-Net (automatic) or loads a saved one (manual)."""
    if method == "automatic":
        input_size = 320  # Adjust based on U^2-Net requirements
        result = u2net_model.Segmentation(
            images=[cv2.cvtColor(image, cv2.COLOR_RGB2BGR)],  # PaddleHub expects BGR
            paths=None,
            batch_size=1,
            input_size=input_size,
            output_dir='output',
            visualization=False
        )
        mask = Image.fromarray(result[0]['mask'])
        mask = mask.resize((512, 512))  # Resize to match the LaMa input resolution
        mask.save("./data/data_mask.png")
    else:  # "manual": reuse a previously saved mask
        mask = imageio.imread("./data/data_mask.png")
        mask = Image.fromarray(mask).convert("L")  # Ensure grayscale
        mask = mask.resize((512, 512))
    return prepare_image(mask, (512, 512))
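# The mask comes back as a (1, 1, 512, 512) float32 tensor in [0, 1]; nonzero pixels
# mark the region LaMa is asked to fill (assumption based on the usual LaMa mask convention).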
def inpaint_image(image, mask):
    """Performs inpainting using the LaMa ONNX model."""
    outputs = lama_model.run(None, {'image': image, 'mask': mask})
    output = outputs[0][0]              # Drop the batch dimension -> (3, 512, 512)
    output = output.transpose(1, 2, 0)  # CHW -> HWC
    output = np.clip(output * 255, 0, 255).astype(np.uint8)  # Clip before casting to avoid overflow
    return Image.fromarray(output)
# --- Gradio Interface ---
def process_image(input_image, mask_option):
    """Main function for the Gradio interface."""
    imageio.imwrite("./data/data.png", input_image)  # Keep the original for restoring its size
    image = prepare_image(input_image)
    mask = generate_mask(input_image, method=mask_option)
    inpainted_image = inpaint_image(image, mask)
    # Resize the 512x512 result back to the original image size
    inpainted_image = inpainted_image.resize(Image.open("./data/data.png").size)
    inpainted_image.save("./dataout/data_out.png")
    return "./dataout/data_out.png", "./data/data_mask.png"
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image", type="numpy"),
        gr.Radio(choices=["automatic"], value="automatic",
                 type="value", label="Masking Option")
    ],
    outputs=[
        gr.Image(type="filepath", label="Inpainted Image"),
        gr.Image(type="filepath", label="Generated Mask")
    ],
    title="LaMa Image Inpainting",
    description="Image inpainting with LaMa and U^2-Net. Upload an image and select the automatic masking option.",
)
iface.launch()