Kaushik Bar commited on
Commit
523248f
·
1 Parent(s): a7e927b

zsic_unicl

Browse files
Files changed (2) hide show
  1. app.py +143 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import requests
3
+ import gradio as gr
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.data import create_transform
12
+ from config import get_config
13
+ from model import build_model
14
+
15
+ # Download human-readable labels for ImageNet.
16
+ response = requests.get("https://git.io/JJkYN")
17
+ labels = response.text.split("\n")
18
+
19
+ def parse_option():
20
+ parser = argparse.ArgumentParser('UniCL demo script', add_help=False)
21
+ parser.add_argument('--cfg', type=str, default="configs/unicl_swin_base.yaml", metavar="FILE", help='path to config file', )
22
+ args, unparsed = parser.parse_known_args()
23
+
24
+ config = get_config(args)
25
+
26
+ return args, config
27
+
28
+ def build_transforms(img_size, center_crop=True):
29
+ t = [transforms.ToPILImage()]
30
+ if center_crop:
31
+ size = int((256 / 224) * img_size)
32
+ t.append(
33
+ transforms.Resize(size)
34
+ )
35
+ t.append(
36
+ transforms.CenterCrop(img_size)
37
+ )
38
+ else:
39
+ t.append(
40
+ transforms.Resize(img_size)
41
+ )
42
+ t.append(transforms.ToTensor())
43
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
44
+ return transforms.Compose(t)
45
+
46
+ def build_transforms4display(img_size, center_crop=True):
47
+ t = [transforms.ToPILImage()]
48
+ if center_crop:
49
+ size = int((256 / 224) * img_size)
50
+ t.append(
51
+ transforms.Resize(size)
52
+ )
53
+ t.append(
54
+ transforms.CenterCrop(img_size)
55
+ )
56
+ else:
57
+ t.append(
58
+ transforms.Resize(img_size)
59
+ )
60
+ t.append(transforms.ToTensor())
61
+ return transforms.Compose(t)
62
+
63
+ args, config = parse_option()
64
+
65
+ '''
66
+ build model
67
+ '''
68
+ model = build_model(config)
69
+
70
+ url = './in21k_yfcc14m_gcc15m_swin_base.pth'
71
+ checkpoint = torch.load(url, map_location="cpu")
72
+ model.load_state_dict(checkpoint["model"])
73
+ model.eval()
74
+
75
+ '''
76
+ build data transform
77
+ '''
78
+ eval_transforms = build_transforms(224, center_crop=True)
79
+ display_transforms = build_transforms4display(224, center_crop=True)
80
+
81
+ '''
82
+ build upsampler
83
+ '''
84
+ # upsampler = nn.Upsample(scale_factor=16, mode='bilinear')
85
+
86
+ '''
87
+ borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
88
+ '''
89
+ def show_cam_on_image(img: np.ndarray,
90
+ mask: np.ndarray,
91
+ use_rgb: bool = False,
92
+ colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
93
+ """ This function overlays the cam mask on the image as an heatmap.
94
+ By default the heatmap is in BGR format.
95
+ :param img: The base image in RGB or BGR format.
96
+ :param mask: The cam mask.
97
+ :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
98
+ :param colormap: The OpenCV colormap to be used.
99
+ :returns: The default image with the cam overlay.
100
+ """
101
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
102
+ if use_rgb:
103
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
104
+ heatmap = np.float32(heatmap) / 255
105
+
106
+ if np.max(img) > 1:
107
+ raise Exception(
108
+ "The input image should np.float32 in the range [0, 1]")
109
+
110
+ cam = 0.7*heatmap + 0.3*img
111
+ # cam = cam / np.max(cam)
112
+ return np.uint8(255 * cam)
113
+
114
+ def recognize_image(image, texts):
115
+ img_t = eval_transforms(image)
116
+ img_d = display_transforms(image).permute(1, 2, 0).numpy()
117
+
118
+ text_embeddings = model.get_text_embeddings(texts.split(';'))
119
+
120
+ # compute output
121
+ feat_img = model.encode_image(img_t.unsqueeze(0))
122
+ output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
123
+ prediction = output.softmax(-1).flatten()
124
+
125
+ return {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
126
+
127
+
128
+ image = gr.inputs.Image()
129
+ label = gr.outputs.Label(num_top_classes=100)
130
+
131
+ gr.Interface(
132
+ description="UniCL for Zero-shot Image Recognition Demo (https://github.com/microsoft/unicl)",
133
+ fn=recognize_image,
134
+ inputs=["image", "text"],
135
+ outputs=[
136
+ label,
137
+ ],
138
+ examples=[
139
+ ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
140
+ ["./apple_with_ipod.jpg", "an ipod; an apple with a write note 'ipod'; an apple"],
141
+ ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
142
+ ],
143
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==1.10.1
2
+ torchvision==0.11.2
3
+ opencv-python-headless==4.5.3.56
4
+ timm==0.4.12
5
+ numpy
6
+ yacs
7
+ transformers