MNCJihun commited on
Commit
547f1df
·
1 Parent(s): 31d48c4

add requirements

Browse files
Files changed (2) hide show
  1. app.py +92 -1
  2. requirements.txt +6 -2
app.py CHANGED
@@ -1,3 +1,94 @@
1
  import os
 
2
 
3
- print(os.getcwd())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import sys
3
 
4
+ import matplotlib.pyplot as plt
5
+ from pandas.core.common import flatten
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch import optim
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from torchvision import datasets, transforms, models
15
+ import albumentations as A
16
+ from albumentations.pytorch import ToTensorV2
17
+
18
+ from tqdm import tqdm
19
+ import random
20
+ import cv2
21
+
22
+ import gradio as gr
23
+
24
+ sys.path.append('/workspace')
25
+ import dataset
26
+ import argparse
27
+
28
+ def parse_args():
29
+ parser = argparse.ArgumentParser(description='MiSLAS training (Stage-2)')
30
+ parser.add_argument('--input',
31
+ help='test image path',
32
+ required=True,
33
+ type=str)
34
+ args = parser.parse_args()
35
+ return args
36
+
37
+ classes = ('no_trunk', 'trunk')
38
+
39
+ test_transforms = A.Compose(
40
+ [
41
+ A.SmallestMaxSize(max_size=350),
42
+ A.CenterCrop(height=256, width=256),
43
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
44
+ ToTensorV2(),
45
+ ]
46
+ )
47
+
48
+ def main():
49
+ args = parse_args()
50
+ assert os.path.exists(args.input)
51
+ device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")
52
+
53
+ model = models.resnet50(pretrained=True)
54
+ model.fc = nn.Sequential(
55
+ nn.Dropout(0.5),
56
+ nn.Linear(model.fc.in_features, 2)
57
+ )
58
+
59
+ state_dict = torch.load('./result/best_model.pth')
60
+ model.load_state_dict(state_dict)
61
+
62
+ for _, p in model.named_parameters():
63
+ p.requires_grad = False
64
+
65
+ model.to(device)
66
+ model.eval()
67
+
68
+ test_transforms = A.Compose(
69
+ [
70
+ A.SmallestMaxSize(max_size=350),
71
+ A.CenterCrop(height=256, width=256),
72
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
73
+ ToTensorV2(),
74
+ ]
75
+ )
76
+
77
+ image = cv2.imread(args.input)
78
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
79
+ image = test_transforms(image=image)["image"]
80
+ image = torch.unsqueeze(image, 0).to(device)
81
+
82
+ output = model(image)
83
+ _, preds = output.max(1)
84
+
85
+ input_cls = 'trunk' if 't_' in args.input else 'no_trunk'
86
+
87
+ print("input: %s \n" %(input_cls))
88
+ print("output: %s" %(classes[preds.item()]))
89
+
90
+ if __name__ == '__main__':
91
+ main()
92
+
93
+
94
+ with gr.Blocks()
requirements.txt CHANGED
@@ -1,7 +1,11 @@
 
 
 
 
 
1
  pandas
2
  numpy
3
  albumentations
4
- opencv-python
5
  tqdm
6
  matplotlib
7
- jupyter
 
1
+ torch==1.10.0
2
+ torchvision==0.11.0
3
+ torchtext==0.11.0
4
+ opencv-pytho==4.7.9.72
5
+ opencv-python-headless==4.7.9.72
6
  pandas
7
  numpy
8
  albumentations
 
9
  tqdm
10
  matplotlib
11
+ jupyter