import os
import json
import time
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from PIL import Image
import cv2

from models import *
from dataset import *
from loss import *
from build_tag import *
from build_vocab import *
class CaptionSampler(object):
    def __init__(self):
        # Default configuration values
        self.args = {
            "model_dir": "",
            "image_dir": "",
            "caption_json": "",
            "vocab_path": "vocab.pkl",
            "file_lists": "",
            "load_model_path": "train_best_loss.pth.tar",
            "resize": 224,
            "cam_size": 224,
            "generate_dir": "cam",
            "result_path": "results",
            "result_name": "debug",
            "momentum": 0.1,
            "visual_model_name": "densenet201",
            "pretrained": False,
            "classes": 210,
            "sementic_features_dim": 512,
            "k": 10,
            "attention_version": "v4",
            "embed_size": 512,
            "hidden_size": 512,
            "sent_version": "v1",
            "sentence_num_layers": 2,
            "dropout": 0.1,
            "word_num_layers": 1,
            "s_max": 10,
            "n_max": 30,
            "batch_size": 8,
            "lambda_tag": 10000,
            "lambda_stop": 10,
            "lambda_word": 1,
            "cuda": False  # Keep CUDA disabled by default
        }
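
        # Build the vocabulary, tagger and image transform, then restore every
        # sub-model (extractor, MLC, co-attention, sentence/word LSTMs) from
        # the loaded checkpoint.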
        self.vocab = self.__init_vocab()
        self.tagger = self.__init_tagger()
        self.transform = self.__init_transform()
        self.model_state_dict = self.__load_mode_state_dict()
        self.extractor = self.__init_visual_extractor()
        self.mlc = self.__init_mlc()
        self.co_attention = self.__init_co_attention()
        self.sentence_model = self.__init_sentence_model()
        self.word_model = self.__init_word_word()
        self.ce_criterion = self._init_ce_criterion()
        self.mse_criterion = self._init_mse_criterion()
    @staticmethod
    def _init_ce_criterion():
        # size_average=False, reduce=False is the legacy spelling of reduction='none'
        return nn.CrossEntropyLoss(size_average=False, reduce=False)

    @staticmethod
    def _init_mse_criterion():
        return nn.MSELoss()
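
    # Caption sampling pipeline: the visual extractor produces patch and average
    # features, the MLC predicts tags plus semantic features, co-attention fuses
    # both into a context vector, the sentence LSTM emits up to s_max topic
    # vectors (each with a stop signal), and the word LSTM decodes every topic
    # into a sentence of at most n_max words.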
    def sample(self, image_file):
        self.extractor.eval()
        self.mlc.eval()
        self.co_attention.eval()
        self.sentence_model.eval()
        self.word_model.eval()

        # image_file is expected to be a PIL image
        image_data = self.transform(image_file)
        image_data = image_data.unsqueeze_(0)
        image = self.__to_var(image_data, requires_grad=False)

        visual_features, avg_features = self.extractor.forward(image)
        tags, semantic_features = self.mlc(avg_features)

        sentence_states = None
        prev_hidden_states = self.__to_var(torch.zeros(image.shape[0], 1, self.args["hidden_size"]))
        pred_sentences = []

        for i in range(self.args["s_max"]):
            ctx, alpha_v, alpha_a = self.co_attention.forward(avg_features, semantic_features, prev_hidden_states)
            topic, p_stop, hidden_state, sentence_states = self.sentence_model.forward(ctx,
                                                                                       prev_hidden_states,
                                                                                       sentence_states)
            p_stop = p_stop.squeeze(1)
            p_stop = torch.max(p_stop, 1)[1].unsqueeze(1)

            start_tokens = np.zeros((topic.shape[0], 1))
            start_tokens[:, 0] = self.vocab('<start>')
            start_tokens = self.__to_var(torch.Tensor(start_tokens).long(), requires_grad=False)

            sampled_ids = self.word_model.sample(topic, start_tokens)
            prev_hidden_states = hidden_state
            # Zero out the sampled ids once the stop signal has been predicted
            sampled_ids = sampled_ids * p_stop.cpu().numpy()
            pred_sentences.append(self.__vec2sent(sampled_ids[0]))

        return pred_sentences
    def __init_cam_path(self, image_file):
        generate_dir = os.path.join(self.args["model_dir"], self.args["generate_dir"])
        if not os.path.exists(generate_dir):
            os.makedirs(generate_dir)

        image_dir = os.path.join(generate_dir, image_file)
        if not os.path.exists(image_dir):
            os.makedirs(image_dir)
        return image_dir

    def __save_json(self, result):
        result_path = os.path.join(self.args["model_dir"], self.args["result_path"])
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        with open(os.path.join(result_path, '{}.json'.format(self.args["result_name"])), 'w') as f:
            json.dump(result, f)

    def __load_mode_state_dict(self):
        try:
            model_state_dict = torch.load(os.path.join(self.args["model_dir"], self.args["load_model_path"]),
                                          map_location=torch.device('cpu'))
            print("[Load Model-{} Succeed!]".format(self.args["load_model_path"]))
            print("Load From Epoch {}".format(model_state_dict['epoch']))
            return model_state_dict
        except Exception as err:
            print("[Load Model Failed] {}".format(err))
            raise err

    def __init_tagger(self):
        return Tag()

    def __vec2sent(self, array):
        sampled_caption = []
        for word_id in array:
            word = self.vocab.get_word_by_id(word_id)
            if word == '<start>':
                continue
            if word == '<end>' or word == '<pad>':
                break
            sampled_caption.append(word)
        return ' '.join(sampled_caption)
    def __init_vocab(self):
        with open(self.args["vocab_path"], 'rb') as f:
            vocab = pickle.load(f)
        print(vocab)
        return vocab
    def __init_data_loader(self, file_list):
        data_loader = get_loader(image_dir=self.args["image_dir"],
                                 caption_json=self.args["caption_json"],
                                 file_list=file_list,
                                 vocabulary=self.vocab,
                                 transform=self.transform,
                                 batch_size=self.args["batch_size"],
                                 s_max=self.args["s_max"],
                                 n_max=self.args["n_max"],
                                 shuffle=False)
        return data_loader
    def __init_transform(self):
        transform = transforms.Compose([
            transforms.Resize((self.args["resize"], self.args["resize"])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        return transform

    def __to_var(self, x, requires_grad=True):
        if self.args["cuda"]:
            x = x.cuda()
        return Variable(x, requires_grad=requires_grad)
    def __init_visual_extractor(self):
        model = VisualFeatureExtractor(model_name=self.args["visual_model_name"],
                                       pretrained=self.args["pretrained"])
        if self.model_state_dict is not None:
            print("Visual Extractor Loaded!")
            model.load_state_dict(self.model_state_dict['extractor'])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_mlc(self):
        model = MLC(classes=self.args["classes"],
                    sementic_features_dim=self.args["sementic_features_dim"],
                    fc_in_features=self.extractor.out_features,
                    k=self.args["k"])
        if self.model_state_dict is not None:
            print("MLC Loaded!")
            model.load_state_dict(self.model_state_dict['mlc'])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_co_attention(self):
        model = CoAttention(version=self.args["attention_version"],
                            embed_size=self.args["embed_size"],
                            hidden_size=self.args["hidden_size"],
                            visual_size=self.extractor.out_features,
                            k=self.args["k"],
                            momentum=self.args["momentum"])
        if self.model_state_dict is not None:
            print("Co-Attention Loaded!")
            model.load_state_dict(self.model_state_dict['co_attention'])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_sentence_model(self):
        model = SentenceLSTM(version=self.args["sent_version"],
                             embed_size=self.args["embed_size"],
                             hidden_size=self.args["hidden_size"],
                             num_layers=self.args["sentence_num_layers"],
                             dropout=self.args["dropout"],
                             momentum=self.args["momentum"])
        if self.model_state_dict is not None:
            print("Sentence Model Loaded!")
            model.load_state_dict(self.model_state_dict['sentence_model'])
        if self.args["cuda"]:
            model = model.cuda()
        return model

    def __init_word_word(self):
        model = WordLSTM(vocab_size=len(self.vocab),
                         embed_size=self.args["embed_size"],
                         hidden_size=self.args["hidden_size"],
                         num_layers=self.args["word_num_layers"],
                         n_max=self.args["n_max"])
        if self.model_state_dict is not None:
            print("Word Model Loaded!")
            model.load_state_dict(self.model_state_dict['word_model'])
        if self.args["cuda"]:
            model = model.cuda()
        return model
def main(image):
    sampler = CaptionSampler()
    # image = 'sample_images/CXR195_IM-0618-1001.png'
    caption = sampler.sample(image)
    print(caption[0])
    return caption[0]
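
# Example usage (an illustrative sketch, not part of the original Space wiring):
# assumes 'vocab.pkl', the 'train_best_loss.pth.tar' checkpoint, and the sample
# image below are available next to this script. main() expects a PIL image.
if __name__ == '__main__':
    sample_image = Image.open('sample_images/CXR195_IM-0618-1001.png').convert('RGB')
    main(sample_image)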