face / app.py
uyen23's picture
Create app.py
6a39bcf verified
raw
history blame
1.94 kB
import torch
from torch import nn
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import matplotlib.pyplot as plt
import requests
import gradio as gr
import numpy as np
# Pick the compute device automatically, preferring GPUs over the CPU:
#   1. "cuda" - NVIDIA or AMD GPUs
#   2. "mps"  - Apple Silicon (Metal Performance Shaders)
#   3. "cpu"  - fallback when no accelerator is available
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# Load the face-parsing SegFormer checkpoint together with its matching image
# preprocessor (fetched via `from_pretrained`), then place the model on `device`.
_CHECKPOINT = "jonathandinu/face-parsing"
image_processor = SegformerImageProcessor.from_pretrained(_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT).to(device)
# Inference function
def infer(image: Image.Image) -> np.ndarray:
    """Predict a per-pixel face-parsing label map for ``image``.

    Args:
        image: Input PIL image. Images in any non-RGB mode (e.g. RGBA, "L",
            "P") are converted to RGB first so the processor always receives
            3-channel input.

    Returns:
        A 2-D integer NumPy array of shape ``(height, width)`` (the reverse of
        PIL's ``image.size``). Each entry is the arg-max class index predicted
        by ``model`` for that pixel.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Pure inference: disable autograd tracking so intermediate activations
    # are not kept alive for a backward pass that never happens.
    with torch.inference_mode():
        inputs = image_processor(images=image, return_tensors="pt").to(device)
        logits = model(**inputs).logits  # shape (batch_size, num_labels, ~height/4, ~width/4)

        # Logits come out at reduced resolution; upsample them back to the
        # input size before taking the arg-max. PIL's size is (W, H), so it is
        # reversed to the (H, W) order `interpolate` expects.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        # First (and only) item of the batch -> (H, W) tensor of class ids.
        labels = upsampled_logits.argmax(dim=1)[0]

        # Move to CPU so the result can be returned as a NumPy array.
        return labels.cpu().numpy()
# Create Gradio interface
# One fixed RGB colour per class id, used to render `infer`'s label map.
# The generator is seeded so each class always gets the same colour across
# calls and restarts. NOTE(review): random colours can occasionally be close
# to each other; swap in a curated palette if classes become hard to tell apart.
_LABEL_PALETTE = np.random.default_rng(0).integers(
    0, 256, size=(model.config.num_labels, 3), dtype=np.uint8
)


def _segment_for_display(image: "Image.Image | None") -> "np.ndarray | None":
    """Run ``infer`` on ``image`` and colourise the result for display.

    ``infer`` returns raw int64 class ids, which are small integers. Gradio
    converts non-uint8 arrays to uint8 by rescaling from the dtype's full
    range, so those ids would show up as an (almost) black image. NOTE(review):
    confirm this against the installed Gradio version. Looking each id up in
    ``_LABEL_PALETTE`` yields a visible uint8 ``(H, W, 3)`` RGB image instead.

    Returns None when there is no image (e.g. the input was cleared).
    """
    if image is None:
        return None
    return _LABEL_PALETTE[infer(image)]


iface = gr.Interface(
    fn=_segment_for_display,  # colourised wrapper around `infer`
    # `gr.inputs` / `gr.outputs` were deprecated in Gradio 3 and removed in
    # Gradio 4; the top-level `gr.Image` component works in both.
    inputs=gr.Image(type="pil"),  # input type (PIL image)
    outputs=gr.Image(type="numpy"),  # output type (uint8 RGB numpy array)
    live=True,  # re-run inference whenever the input image changes
    title="Face Parsing with Segformer",  # interface title
    description="Upload an image to perform face parsing using the Segformer model for semantic segmentation.",  # description
)
# Launch the interface
iface.launch()