uyen23 committed on
Commit
6a39bcf
·
verified ·
1 Parent(s): 477aab8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import requests
7
+ import gradio as gr
8
+ import numpy as np
9
+
10
# Pick the torch backend in order of preference:
#   1. "cuda" - NVIDIA or AMD GPUs
#   2. "mps"  - Apple Silicon (Metal Performance Shaders)
#   3. "cpu"  - fallback when no accelerator is available
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
20
+
21
# Load the image processor and the segmentation model from the same
# pretrained checkpoint, then place the model on the selected device.
_CHECKPOINT = "jonathandinu/face-parsing"
image_processor = SegformerImageProcessor.from_pretrained(_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT).to(device)
25
+
26
+ # Inference function
27
def infer(image: Image.Image) -> np.ndarray:
    """Run face parsing on an image and return its per-pixel label map.

    Args:
        image: PIL image to segment.

    Returns:
        A 2-D array of shape (height, width) matching the input image, where
        each element is the index of the highest-scoring class at that pixel.
    """
    # Preprocess the image into model tensors on the same device as the model
    inputs = image_processor(images=image, return_tensors="pt").to(device)

    # Inference only: disable autograd so no computation graph is recorded
    # (and kept alive) for every request.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits  # shape (batch_size, num_labels, ~height/4, ~width/4)

        # Resize the logits back to the input resolution. PIL reports size as
        # (W, H) while interpolate expects (H, W), hence the reversal.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )

    # Collapse the class dimension and take the single image in the batch
    labels = upsampled_logits.argmax(dim=1)[0]

    # NOTE(review): the result holds raw class indices, not colors. When shown
    # directly as an image it may look almost uniformly dark - confirm whether
    # a color palette should be applied before display.
    return labels.cpu().numpy()
45
+
46
# Create the Gradio interface.
# Uses the top-level gr.Image component for both input and output: the legacy
# gr.inputs / gr.outputs namespaces were removed from Gradio, so referencing
# them raises AttributeError on current releases.
iface = gr.Interface(
    fn=infer,  # the function to be used for inference
    inputs=gr.Image(type="pil"),  # hand the upload to infer() as a PIL image
    outputs=gr.Image(type="numpy"),  # infer() returns a numpy array
    live=True,  # run inference live as the image is uploaded
    title="Face Parsing with Segformer",  # interface title
    description="Upload an image to perform face parsing using the Segformer model for semantic segmentation.",  # description
)

# Launch the interface
iface.launch()