Kaushik Bar committed
Commit 523248f · 1 parent: a7e927b
zsic_unicl

Files changed:
- app.py (+143, -0)
- requirements.txt (+7, -0)
app.py
ADDED
@@ -0,0 +1,143 @@
import argparse
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from config import get_config
from model import build_model

# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

def parse_option():
    parser = argparse.ArgumentParser('UniCL demo script', add_help=False)
    parser.add_argument('--cfg', type=str, default="configs/unicl_swin_base.yaml", metavar="FILE", help='path to config file')
    args, unparsed = parser.parse_known_args()

    config = get_config(args)

    return args, config

def build_transforms(img_size, center_crop=True):
    # Evaluation transform: resize (+ optional center crop), then normalize.
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

def build_transforms4display(img_size, center_crop=True):
    # Same geometry as build_transforms, but without normalization,
    # so the result stays displayable.
    t = [transforms.ToPILImage()]
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    return transforms.Compose(t)

args, config = parse_option()

'''
build model
'''
model = build_model(config)

ckpt_path = './in21k_yfcc14m_gcc15m_swin_base.pth'
checkpoint = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()

'''
build data transform
'''
eval_transforms = build_transforms(224, center_crop=True)
display_transforms = build_transforms4display(224, center_crop=True)

'''
build upsampler
'''
# upsampler = nn.Upsample(scale_factor=16, mode='bilinear')

'''
borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
'''
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """This function overlays the CAM mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The CAM mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The image with the CAM overlay.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = 0.7 * heatmap + 0.3 * img
    # cam = cam / np.max(cam)
    return np.uint8(255 * cam)

def recognize_image(image, texts):
    img_t = eval_transforms(image)
    # Unnormalized copy, kept for a potential heatmap overlay display.
    img_d = display_transforms(image).permute(1, 2, 0).numpy()

    class_names = texts.split(';')
    with torch.no_grad():  # inference only; no gradients needed
        text_embeddings = model.get_text_embeddings(class_names)

        # compute output
        feat_img = model.encode_image(img_t.unsqueeze(0))
        output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
    prediction = output.softmax(-1).flatten()

    return {class_names[i]: float(prediction[i]) for i in range(len(class_names))}


image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=100)

gr.Interface(
    description="UniCL for Zero-shot Image Recognition Demo (https://github.com/microsoft/unicl)",
    fn=recognize_image,
    inputs=["image", "text"],
    outputs=[
        label,
    ],
    examples=[
        ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
        ["./apple_with_ipod.jpg", "an ipod; an apple with a written note 'ipod'; an apple"],
        ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
    ],
).launch()
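The heart of recognize_image is a CLIP-style zero-shot scoring step: one embedding per candidate caption is compared against the image embedding, scaled by the learned temperature, and softmaxed into a distribution over the captions. Below is a minimal, self-contained sketch of just that step; the embedding dimension, the logit_scale value, and the captions are illustrative stand-ins, not values taken from the UniCL checkpoint.

import torch
import torch.nn.functional as F

# Stand-ins for the model outputs (dimension 512 is assumed for illustration).
feat_img = F.normalize(torch.randn(1, 512), dim=-1)          # one image embedding
text_embeddings = F.normalize(torch.randn(3, 512), dim=-1)   # one embedding per caption
logit_scale = torch.tensor(100.0)                            # stand-in for model.logit_scale.exp()

# Same arithmetic as recognize_image: scaled similarities -> softmax.
output = logit_scale * feat_img @ text_embeddings.t()        # (1, 3) logits
prediction = output.softmax(-1).flatten()                    # probabilities over captions

captions = ["an elephant", "a street", "an ipod"]
print({c: float(p) for c, p in zip(captions, prediction)})

Because the embeddings are L2-normalized, the matrix product is a cosine similarity, and the exponentiated logit scale sharpens the resulting softmax.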
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch==1.10.1
torchvision==0.11.2
opencv-python-headless==4.5.3.56
timm==0.4.12
numpy
yacs
transformers
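One loose end in app.py: show_cam_on_image (and the commented-out upsampler) are defined but never called, apparently staged for a future heatmap view. A minimal usage sketch, assuming show_cam_on_image from app.py is in scope and substituting a synthetic center-blob mask for a real attention map (the file names are illustrative):

import cv2
import numpy as np

# Load an example image and scale to float32 in [0, 1], as the function requires.
img = cv2.imread("./elephants.png").astype(np.float32) / 255.0   # BGR, HxWx3

# Synthetic "attention" mask in [0, 1]; a real mask would come from the model.
mask = np.zeros(img.shape[:2], dtype=np.float32)
h, w = mask.shape
mask[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 1.0

overlay = show_cam_on_image(img, mask, use_rgb=False)  # BGR heatmap blended over BGR image
cv2.imwrite("overlay.png", overlay)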