File size: 3,260 Bytes
9bdd97c
 
 
 
 
 
 
 
f57e36f
9bdd97c
 
 
 
 
 
 
 
 
 
 
7479a3a
9bdd97c
7479a3a
9bdd97c
 
f57e36f
 
 
 
9bdd97c
f57e36f
7479a3a
f57e36f
 
 
 
 
 
7479a3a
f57e36f
 
 
 
 
 
 
 
 
7479a3a
f57e36f
 
 
9bdd97c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76221dd
9bdd97c
 
 
 
 
 
ed7463d
 
 
 
 
 
 
7479a3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python

from __future__ import annotations

import functools
import os
import pathlib
import sys
import tarfile

import gradio as gr
import huggingface_hub
import PIL.Image
import torch
import torchvision

# Put the 'bizarre-pose-estimator' directory at the front of the module search
# path so the `_util` import below can be resolved from it.
# NOTE(review): presumably a local checkout of the upstream repository -
# confirm how it is provisioned in the deployment.
sys.path.insert(0, 'bizarre-pose-estimator')

# Must stay below the sys.path tweak above; `I` is aliased to a more
# descriptive name for use in `predict`.
from _util.twodee_v0 import I as ImageWrapper

# Markdown heading rendered at the top of the demo page.
DESCRIPTION = '# [ShuhongChen/bizarre-pose-estimator (tagger)](https://github.com/ShuhongChen/bizarre-pose-estimator)'

# Hugging Face Hub repository that hosts 'tagger.pth' and 'tags.txt'.
MODEL_REPO = 'public-data/bizarre-pose-estimator-models'


def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the entries of the local ``images`` directory, sorted by path.

    If ``images`` does not exist in the current working directory, the
    ``images.tar.gz`` archive is downloaded from the
    ``hysts/sample-images-TADNE`` dataset repository and extracted into the
    current working directory before the directory is listed.
    """
    sample_dir = pathlib.Path('images')
    if sample_dir.exists():
        # Already present: nothing to download.
        return sorted(sample_dir.glob('*'))

    archive_path = huggingface_hub.hf_hub_download(
        'hysts/sample-images-TADNE',
        'images.tar.gz',
        repo_type='dataset',
    )
    with tarfile.open(archive_path) as archive:
        archive.extractall()
    return sorted(sample_dir.glob('*'))


def load_model(device: torch.device) -> torch.nn.Module:
    """Build the tagger network and load its pretrained weights.

    Downloads ``tagger.pth`` from ``MODEL_REPO``, loads it into a torchvision
    ResNet-50 configured with 1062 output classes, moves the model to
    ``device`` and switches it to evaluation mode.

    Args:
        device: Device the returned model is moved to.

    Returns:
        The ResNet-50 model, in eval mode, on ``device``.
    """
    path = huggingface_hub.hf_hub_download(MODEL_REPO, 'tagger.pth')
    # Deserialize onto the CPU no matter which device the checkpoint was saved
    # from: without map_location, a checkpoint that stores CUDA tensors cannot
    # be loaded on a CPU-only host. The model is moved to `device` below.
    state_dict = torch.load(path, map_location='cpu')
    model = torchvision.models.resnet50(num_classes=1062)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def load_labels() -> list[str]:
    """Return the tag names listed in ``tags.txt`` of ``MODEL_REPO``.

    The file is downloaded through the Hugging Face Hub and read line by
    line; each line is stripped of surrounding whitespace and file order is
    preserved.

    Returns:
        One label string per line of the file.
    """
    label_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'tags.txt')
    # Pin the encoding so decoding does not depend on the host locale.
    # NOTE(review): assumes tags.txt is UTF-8 (ASCII-compatible) - confirm.
    with open(label_path, encoding='utf-8') as f:
        return [line.strip() for line in f]


@torch.inference_mode()
def predict(image: PIL.Image.Image, score_threshold: float,
            device: torch.device, model: torch.nn.Module,
            labels: list[str]) -> dict[str, float]:
    """Score ``image`` against every label and keep the confident ones.

    The image is wrapped in ``ImageWrapper``, passed through
    ``resize_square(256)``, ``alpha_bg(c='w')`` and ``convert('RGB')``, turned
    into a tensor, and fed to ``model`` as a batch of one. The model outputs
    are passed through a sigmoid and paired with ``labels`` in order.

    Args:
        image: Input picture.
        score_threshold: Labels whose score is below this value are dropped.
        device: Device the input tensor is moved to before the forward pass.
        model: Network called on the batched input.
        labels: Label names, zipped positionally with the model outputs.

    Returns:
        Mapping from label to its sigmoid score, for every label whose score
        is not below ``score_threshold``.
    """
    # NOTE(review): alpha_bg(c='w') presumably composites transparency onto a
    # white background - confirm against the _util.twodee_v0 helper.
    wrapped = ImageWrapper(image)
    prepared = wrapped.resize_square(256).alpha_bg(c='w').convert('RGB')
    batch = prepared.tensor().to(device).unsqueeze(0)

    logits = model(batch)[0]
    scores = torch.sigmoid(logits).cpu().numpy().astype(float).tolist()

    # `not score < threshold` (rather than `>=`) mirrors the skip condition
    # exactly, including for non-comparable values such as NaN.
    return {
        label: score
        for score, label in zip(scores, labels)
        if not score < score_threshold
    }


# Sample images offered as clickable examples; each row is
# [image path, score threshold], matching the order of the UI inputs below.
image_paths = load_sample_image_paths()
examples = [[path.as_posix(), 0.5] for path in image_paths]

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = load_model(device)
labels = load_labels()

# Bind the heavyweight objects once so the UI callback only takes
# (image, score_threshold).
fn = functools.partial(predict, device=device, model=model, labels=labels)

# Build the demo page: inputs in the left column, predicted labels on the
# right. Components are laid out in the order they are created here.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Input', type='pil')
            threshold = gr.Slider(label='Score Threshold',
                                  minimum=0,
                                  maximum=1,
                                  step=0.05,
                                  value=0.5)
            run_button = gr.Button('Run')
        with gr.Column():
            result = gr.Label(label='Output')

    # Order must match the positional parameters of `fn`
    # (image, score_threshold) and the rows in `examples`.
    inputs = [image, threshold]
    # Example caching is enabled only when the CACHE_EXAMPLES environment
    # variable is exactly '1'.
    gr.Examples(examples=examples,
                inputs=inputs,
                outputs=result,
                fn=fn,
                cache_examples=os.getenv('CACHE_EXAMPLES') == '1')
    run_button.click(fn=fn, inputs=inputs, outputs=result, api_name='predict')
# Enable the request queue (max_size=15) and start the app.
demo.queue(max_size=15).launch()