File size: 1,811 Bytes
a87c3be
0462cba
172362f
5b982d8
a87c3be
9378acd
a87c3be
5b982d8
 
 
172362f
d7acb8a
 
 
 
 
5b982d8
 
a36cae8
5b982d8
 
 
172362f
5b982d8
 
 
 
 
 
9378acd
d7acb8a
 
 
 
5b982d8
d7acb8a
172362f
5b982d8
3763dfd
5b982d8
 
d7acb8a
5b982d8
 
 
d7acb8a
 
5b982d8
a87c3be
 
 
d7acb8a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image
import numpy as np
import torch

# λͺ¨λΈκ³Ό feature extractor λ‘œλ“œ
model_name = "nvidia/segformer-b0-finetuned-ade-512-512"
model = SegformerForSemanticSegmentation.from_pretrained(model_name)
feature_extractor = SegformerFeatureExtractor.from_pretrained(model_name)

def create_color_map(num_classes, seed=42):
    """Build a reproducible random RGB color for each class id.

    Args:
        num_classes: Number of class ids. The keys of the result are
            ``0 .. num_classes - 1``.
        seed: Seed for the private random generator. It defaults to 42, the
            value previously hard-coded, so the same ids always get the same
            colors.

    Returns:
        dict mapping each class id to a length-3 integer numpy array whose
        components lie in ``[0, 255]``.
    """
    # Use a private RandomState instead of np.random.seed(). It produces the
    # same sequence as seeding the legacy global generator, but it does not
    # reset the global RNG state that the rest of the program may rely on.
    rng = np.random.RandomState(seed)
    return {i: rng.randint(0, 256, 3) for i in range(num_classes)}

def segment_image(image):
    """Segment an image with SegFormer and return a color-coded class mask.

    The image is resized to 512x512 before inference, so the returned mask is
    512x512 regardless of the uploaded image's size.

    Args:
        image: ``PIL.Image.Image`` supplied by the Gradio input (``type="pil"``),
            or ``None`` when the user submits without uploading anything.

    Returns:
        An RGB ``PIL.Image.Image`` in which each pixel is colored by its
        predicted class id, or ``None`` if no image was given.
    """
    if image is None:
        return None

    # Preprocess. PIL's resize() takes the target size as ONE (width, height)
    # tuple. resize(512, 512) would treat the second 512 as the resample
    # filter and fail. convert("RGB") makes RGBA / grayscale / palette uploads
    # reach the feature extractor as plain 3-channel RGB.
    image = image.convert("RGB").resize((512, 512))
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Upsample the logits to the image's (height, width) and pick the best
    # class per pixel. PIL reports size as (width, height), hence the reversal.
    upsampled_logits = torch.nn.functional.interpolate(
        outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    # Index [0] drops the batch dimension (a single image is passed in).
    mask = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # Colorize. The class count comes from the model config rather than a
    # hard-coded 150. Build a (num_classes, 3) palette from the seeded color
    # map, then index it with the whole (H, W) mask in one vectorized lookup
    # instead of a per-pixel Python loop.
    num_classes = model.config.num_labels
    color_map = create_color_map(num_classes)
    palette = np.stack(
        [color_map[class_id] for class_id in range(num_classes)]
    ).astype(np.uint8)
    colored_mask = palette[mask]

    return Image.fromarray(colored_mask)

# Example images listed under the interface.
# NOTE(review): these are relative paths; presumably the files ship next to
# this script -- confirm they exist.
example_images = ["image1.jpg", "image2.jpg", "image3.jpg"]

# Gradio interface: one PIL image in, one image out.
# gr.Image replaces gr.inputs.Image. The gr.inputs namespace was deprecated in
# Gradio 3.x and removed in Gradio 4.0.
iface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="Image Segmentation with SegFormer",
    description="Upload an image to segment it using SegFormer model.",
    examples=example_images,
)

# Start the Gradio web app.
iface.launch()