"""Single-image inference for the ResNet-50 ``trunk`` / ``no_trunk`` classifier.

Loads a fine-tuned ResNet-50 state_dict, classifies one image and prints the
label derived from the image's file name next to the predicted label.
"""
import os
import sys
import argparse
import random

import matplotlib.pyplot as plt
from pandas.core.common import flatten
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import cv2

# Project-local module; importable only after /workspace is on the path.
sys.path.append('/workspace')
import dataset  # noqa: E402

# Default locations kept identical to the previously hard-coded values.
DEFAULT_CHECKPOINT = './result/best_model.pth'
DEFAULT_CUDA_DEVICE = 'cuda:3'


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``input``, ``checkpoint`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description='Trunk classifier: single-image inference')
    parser.add_argument('--input', help='test image path', required=True, type=str)
    parser.add_argument(
        '--checkpoint', default=DEFAULT_CHECKPOINT, type=str,
        help='path to the fine-tuned model state_dict (default: %(default)s)')
    parser.add_argument(
        '--device', default=None, type=str,
        help='torch device such as "cuda:0" or "cpu" '
             '(default: cuda:3 if CUDA is available, otherwise cpu)')
    args = parser.parse_args(argv)
    return args


# Index i of this tuple is the label for model output logit i.
classes = ('no_trunk', 'trunk')

# Evaluation preprocessing: resize the shortest side to 350 px, center-crop
# 256x256, normalize with the standard ImageNet channel statistics, and
# convert the HWC numpy array to a CHW torch tensor.
test_transforms = A.Compose(
    [
        A.SmallestMaxSize(max_size=350),
        A.CenterCrop(height=256, width=256),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)


def _select_device(requested=None):
    """Return the explicitly requested device, else cuda:3 when CUDA exists, else CPU."""
    if requested is not None:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device(DEFAULT_CUDA_DEVICE)
    return torch.device('cpu')


def _build_model(checkpoint_path, device):
    """Build the ResNet-50 classifier, load its weights and put it in eval mode on *device*."""
    # No ImageNet download: load_state_dict below is strict, so every
    # parameter and buffer is replaced by the checkpoint anyway.
    model = models.resnet50(pretrained=False)
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.fc.in_features, len(classes)),
    )
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    for param in model.parameters():
        param.requires_grad = False
    model.to(device)
    model.eval()
    return model


def _load_image(path):
    """Read *path* with OpenCV, convert BGR to RGB and apply ``test_transforms``.

    Raises:
        ValueError: if OpenCV cannot decode the file (``cv2.imread`` returns None).
    """
    image = cv2.imread(path)
    if image is None:
        raise ValueError(f'could not read image: {path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return test_transforms(image=image)["image"]


def _label_from_filename(path):
    """Derive the ground-truth label from the image file name.

    A base name containing ``'t_'`` is treated as ``'trunk'``. Only the base
    name is inspected so directory names such as ``test_images/`` cannot match.
    NOTE(review): naming convention inferred from the original check; confirm.
    """
    return 'trunk' if 't_' in os.path.basename(path) else 'no_trunk'


def main():
    """Classify the image given by ``--input`` and print expected vs predicted label.

    Raises:
        FileNotFoundError: if the input image or checkpoint file does not exist.
    """
    args = parse_args()
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(args.input):
        raise FileNotFoundError(f'input image not found: {args.input}')
    if not os.path.isfile(args.checkpoint):
        raise FileNotFoundError(f'checkpoint not found: {args.checkpoint}')

    device = _select_device(args.device)
    model = _build_model(args.checkpoint, device)

    # Add the batch dimension expected by the model.
    image = torch.unsqueeze(_load_image(args.input), 0).to(device)
    with torch.no_grad():
        output = model(image)
    _, preds = output.max(1)

    input_cls = _label_from_filename(args.input)
    print("input: %s \n" % (input_cls))
    print("output: %s" % (classes[preds.item()]))


if __name__ == '__main__':
    main()