Update app.py
app.py
CHANGED
@@ -1,3 +1,4 @@
+
 import torch
 import pickle
 import cv2
@@ -8,11 +9,11 @@ from transformers import ViTForImageClassification, AutoImageProcessor, AdamW, V
 from torch.utils.data import DataLoader, TensorDataset
 import gradio as gr
 
-model_path = '/home/
+model_path = '/home/usr/app'
 train_pickle_path = 'train_data.pickle'
 valid_pickle_path = 'valid_data.pickle'
 image_directory = 'images'
-test_image_path = '/home/
+test_image_path = '/home/usr/app/test.jpg'
 num_epochs = 5 # Fine-tune the model
 label_list = ["小白", "巧巧", "冏媽", "乖狗", "花捲", "超人", "黑胖", "橘子"]
 label_dictionary = {"小白": 0, "巧巧": 1, "冏媽": 2, "乖狗": 3, "花捲": 4, "超人": 5, "黑胖": 6, "橘子": 7}
@@ -163,14 +164,21 @@ def predict(upload_image):
     # Load the test data
     # Load the image
 
-    #
+    # img2 = cv2.imread(test_image_path)
+    # print("cv2: ", img2)
+    # print("cv2 shape: ", img2.shape)
     # img = upload_image
-    # img = cv2.cvtColor(
+    # img = cv2.cvtColor((upload_image * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
     pil_image = upload_image.convert('RGB')
     open_cv_image = np.array(pil_image)
     # Convert RGB to BGR
     img = open_cv_image[:, :, ::-1].copy()
     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    # print("gradio: ", img)
+    # print("gradio shape: ", img.shape)
+
+
+    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
     # Resize the image to 224x224 pixels
     img = Image.fromarray(img)
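For reference, a minimal self-contained sketch of the conversion predict() applies to the Gradio upload before the 224x224 resize, assuming upload_image is the PIL.Image object Gradio passes in; the helper name to_model_input is illustrative and not part of app.py.

import numpy as np
import cv2
from PIL import Image

def to_model_input(upload_image):
    # Force a 3-channel RGB PIL image, then view it as an HxWx3 uint8 numpy array.
    pil_image = upload_image.convert('RGB')
    open_cv_image = np.array(pil_image)
    # RGB -> BGR (OpenCV channel order), then back to RGB with cv2.cvtColor.
    img = open_cv_image[:, :, ::-1].copy()
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # The two channel swaps cancel out, so img equals open_cv_image here;
    # the BGR step would only matter if OpenCV routines ran in between.
    return Image.fromarray(img)

Since the round trip leaves the pixel data unchanged, np.array(pil_image) alone would yield the same array. The debug prints added in this commit appear to compare that array against cv2.imread(test_image_path), which returns pixels in BGR order.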