yuanphon commited on
Commit
7bc99d1
·
1 Parent(s): aff216a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import pickle
3
  import cv2
@@ -8,11 +9,11 @@ from transformers import ViTForImageClassification, AutoImageProcessor, AdamW, V
8
  from torch.utils.data import DataLoader, TensorDataset
9
  import gradio as gr
10
 
11
- model_path = '/home/user/app'
12
  train_pickle_path = 'train_data.pickle'
13
  valid_pickle_path = 'valid_data.pickle'
14
  image_directory = 'images'
15
- test_image_path = '/home/user/app/test.jpg'
16
  num_epochs = 5 # Fine-tune the model
17
  label_list = ["小白", "巧巧", "冏媽", "乖狗", "花捲", "超人", "黑胖", "橘子"]
18
  label_dictionary = {"小白": 0, "巧巧": 1, "冏媽": 2, "乖狗": 3, "花捲": 4, "超人": 5, "黑胖": 6, "橘子": 7}
@@ -163,14 +164,21 @@ def predict(upload_image):
163
  # Load the test data
164
  # Load the image
165
 
166
- # img = cv2.imread(test_image_path)
 
 
167
  # img = upload_image
168
- # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
169
  pil_image = upload_image.convert('RGB')
170
  open_cv_image = np.array(pil_image)
171
  # Convert RGB to BGR
172
  img = open_cv_image[:, :, ::-1].copy()
173
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
 
 
 
 
174
 
175
  # Resize the image to 224x224 pixels
176
  img = Image.fromarray(img)
 
1
+
2
  import torch
3
  import pickle
4
  import cv2
 
9
  from torch.utils.data import DataLoader, TensorDataset
10
  import gradio as gr
11
 
12
+ model_path = '/home/usr/app'
13
  train_pickle_path = 'train_data.pickle'
14
  valid_pickle_path = 'valid_data.pickle'
15
  image_directory = 'images'
16
+ test_image_path = '/home/usr/app/test.jpg'
17
  num_epochs = 5 # Fine-tune the model
18
  label_list = ["小白", "巧巧", "冏媽", "乖狗", "花捲", "超人", "黑胖", "橘子"]
19
  label_dictionary = {"小白": 0, "巧巧": 1, "冏媽": 2, "乖狗": 3, "花捲": 4, "超人": 5, "黑胖": 6, "橘子": 7}
 
164
  # Load the test data
165
  # Load the image
166
 
167
+ # img2 = cv2.imread(test_image_path)
168
+ # print("cv2: ", img2)
169
+ # print("cv2 shape: ", img2.shape)
170
  # img = upload_image
171
+ # img = cv2.cvtColor((upload_image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
172
  pil_image = upload_image.convert('RGB')
173
  open_cv_image = np.array(pil_image)
174
  # Convert RGB to BGR
175
  img = open_cv_image[:, :, ::-1].copy()
176
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
177
+ # print("gradio: ", img)
178
+ # print("gradio shape: ", img.shape)
179
+
180
+
181
+ # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
182
 
183
  # Resize the image to 224x224 pixels
184
  img = Image.fromarray(img)