Spaces:
Runtime error
Runtime error
File size: 1,899 Bytes
cb8043e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import models
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dataset import ImageDataset
from torch.utils.data import DataLoader
# initialize the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialize the model; its weights are overwritten from the checkpoint below
model = models.model(pretrained=False, requires_grad=False).to(device)
# load the model checkpoint. map_location remaps any tensors that were saved
# from a GPU run onto the current device, so the checkpoint also loads on
# CPU-only machines (without it torch.load raises when CUDA is unavailable).
checkpoint = torch.load('../outputs/model.pth', map_location=device)
# load model weights state_dict
model.load_state_dict(checkpoint['model_state_dict'])
# switch to evaluation mode for inference
model.eval()
train_csv = pd.read_csv('../input/movie-classifier/Multi_Label_dataset/train.csv')
# genre names are taken from every column after the first two
# NOTE(review): presumably the first two columns are id/metadata and the rest
# are one-hot genre labels in model-output order — confirm against the dataset
genres = train_csv.columns.values[2:]
print(genres)
# prepare the test dataset and dataloader
# NOTE(review): the train/test flags select ImageDataset's test split — see
# dataset.py for exactly which rows are used
test_data = ImageDataset(
    train_csv, train=False, test=True
)
# batch_size=1 so each iteration yields one image to plot and annotate
test_loader = DataLoader(
    test_data,
    batch_size=1,
    shuffle=False
)
# Run the model on every test image and save an annotated figure per image.
# Gradient tracking is disabled: this loop only evaluates the model, so
# building autograd graphs would just waste memory.
with torch.no_grad():
    for counter, data in enumerate(test_loader):
        image, target = data['image'].to(device), data['label']
        # index positions of the genres labelled 1 for this (batch_size=1) sample
        target_indices = [i for i, value in enumerate(target[0]) if value == 1]
        # per-genre probabilities: sigmoid over the model's raw outputs
        outputs = torch.sigmoid(model(image)).cpu()
        probs = outputs[0]
        # the (up to) three most probable genres, highest probability first
        best = torch.topk(probs, k=min(3, probs.numel())).indices.tolist()
        string_predicted = ' '.join(genres[i] for i in best)
        string_actual = ' '.join(genres[i] for i in target_indices)
        # drop the batch axis and move channels last (C, H, W) -> (H, W, C),
        # the layout matplotlib's imshow expects
        image = image.squeeze(0).cpu().numpy()
        image = np.transpose(image, (1, 2, 0))
        # a fresh figure per image, closed afterwards, so images and figures
        # do not accumulate across iterations
        fig = plt.figure()
        plt.imshow(image)
        plt.axis('off')
        plt.title(f"PREDICTED: {string_predicted}\nACTUAL: {string_actual}")
        plt.savefig(f"../outputs/inference_{counter}.jpg")
        plt.show()
        plt.close(fig)