Spaces:
Runtime error
Runtime error
import models | |
import torch | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from dataset import ImageDataset | |
from torch.utils.data import DataLoader | |
# initialize the computation device | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
#intialize the model | |
model = models.model(pretrained=False, requires_grad=False).to(device) | |
# load the model checkpoint | |
checkpoint = torch.load('../outputs/model.pth') | |
# load model weights state_dict | |
model.load_state_dict(checkpoint['model_state_dict']) | |
model.eval() | |
train_csv = pd.read_csv('../input/movie-classifier/Multi_Label_dataset/train.csv') | |
genres = train_csv.columns.values[2:] | |
print(genres) | |
# prepare the test dataset and dataloader | |
test_data = ImageDataset( | |
train_csv, train=False, test=True | |
) | |
test_loader = DataLoader( | |
test_data, | |
batch_size=1, | |
shuffle=False | |
) | |
for counter, data in enumerate(test_loader): | |
image, target = data['image'].to(device), data['label'] | |
# get all the index positions where value == 1 | |
target_indices = [i for i in range(len(target[0])) if target[0][i] == 1] | |
# get the predictions by passing the image through the model | |
print(image.shape) | |
outputs = model(image) | |
outputs = torch.sigmoid(outputs) | |
outputs = outputs.detach().cpu() | |
sorted_indices = np.argsort(outputs[0]) | |
best = sorted_indices[-3:] | |
string_predicted = '' | |
string_actual = '' | |
for i in range(len(best)): | |
string_predicted += f"{genres[best[i]]} " | |
for i in range(len(target_indices)): | |
string_actual += f"{genres[target_indices[i]]} " | |
image = image.squeeze(0) | |
image = image.detach().cpu().numpy() | |
image = np.transpose(image, (1, 2, 0)) | |
plt.imshow(image) | |
plt.axis('off') | |
plt.title(f"PREDICTED: {string_predicted}\nACTUAL: {string_actual}") | |
plt.savefig(f"../outputs/inference_{counter}.jpg") | |
plt.show() |