Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
|
|
6 |
from PIL import Image
|
7 |
from transformers import ViTForImageClassification, AutoImageProcessor, AdamW, ViTImageProcessor, VisionEncoderDecoderModel, AutoTokenizer
|
8 |
from torch.utils.data import DataLoader, TensorDataset
|
|
|
9 |
|
10 |
model_path = '/home/user/app'
|
11 |
train_pickle_path = 'train_data.pickle'
|
@@ -152,7 +153,7 @@ def train_model():
|
|
152 |
|
153 |
model.save_pretrained("model")
|
154 |
|
155 |
-
def predict():
|
156 |
# Load the model
|
157 |
model = ViTForImageClassification.from_pretrained(model_path, num_labels=num_classes)
|
158 |
|
@@ -162,7 +163,8 @@ def predict():
|
|
162 |
# Load the test data
|
163 |
# Load the image
|
164 |
|
165 |
-
img = cv2.imread(test_image_path)
|
|
|
166 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
167 |
|
168 |
# Resize the image to 224x224 pixels
|
@@ -214,16 +216,28 @@ def output(predict_class, caption):
|
|
214 |
conj = ['are', 'is', 'dog']
|
215 |
if predict_class == '不是校狗' or caption.find('dog') == -1:
|
216 |
print(f'{caption} ({predict_class})')
|
|
|
217 |
else:
|
218 |
for c in conj:
|
219 |
if caption.find(c) != -1:
|
220 |
print(f'{predict_class} is{caption[caption.find(c) + len(c):]}')
|
221 |
return
|
222 |
print(f'{caption} ({predict_class})')
|
|
|
223 |
|
224 |
|
225 |
if __name__ == '__main__':
|
226 |
|
227 |
if not os.path.exists(model_path):
|
228 |
train_model()
|
229 |
-
output(predict(), captioning())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from PIL import Image
|
7 |
from transformers import ViTForImageClassification, AutoImageProcessor, AdamW, ViTImageProcessor, VisionEncoderDecoderModel, AutoTokenizer
|
8 |
from torch.utils.data import DataLoader, TensorDataset
|
9 |
+
import gradio as gr
|
10 |
|
11 |
model_path = '/home/user/app'
|
12 |
train_pickle_path = 'train_data.pickle'
|
|
|
153 |
|
154 |
model.save_pretrained("model")
|
155 |
|
156 |
+
def predict(upload_image):
|
157 |
# Load the model
|
158 |
model = ViTForImageClassification.from_pretrained(model_path, num_labels=num_classes)
|
159 |
|
|
|
163 |
# Load the test data
|
164 |
# Load the image
|
165 |
|
166 |
+
# img = cv2.imread(test_image_path)
|
167 |
+
img = upload_image
|
168 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
169 |
|
170 |
# Resize the image to 224x224 pixels
|
|
|
216 |
conj = ['are', 'is', 'dog']
|
217 |
if predict_class == '不是校狗' or caption.find('dog') == -1:
|
218 |
print(f'{caption} ({predict_class})')
|
219 |
+
return (f'{caption} ({predict_class})')
|
220 |
else:
|
221 |
for c in conj:
|
222 |
if caption.find(c) != -1:
|
223 |
print(f'{predict_class} is{caption[caption.find(c) + len(c):]}')
|
224 |
return
|
225 |
print(f'{caption} ({predict_class})')
|
226 |
+
|
227 |
|
228 |
|
229 |
if __name__ == '__main__':
|
230 |
|
231 |
if not os.path.exists(model_path):
|
232 |
train_model()
|
233 |
+
# output(predict(), captioning())
|
234 |
+
|
235 |
+
|
236 |
+
# def greet(name):
|
237 |
+
# return "Hello " + name + "!!"
|
238 |
+
def get_result(upload_image):
    """Gradio handler: classify the uploaded image and describe it as text.

    Runs ``predict`` on ``upload_image``, then ``captioning``, and passes
    both results to ``output``; whatever ``output`` returns is handed back
    to the caller.
    """
    predicted_class = predict(upload_image)
    # NOTE(review): captioning() is called without the uploaded image, so
    # which image it captions cannot be determined from here -- confirm.
    caption = captioning()
    return output(predicted_class, caption)
|
241 |
+
|
242 |
+
# Expose get_result through a Gradio UI: one image input, one text output.
iface = gr.Interface(
    fn=get_result,
    inputs="image",
    outputs="text",
)
iface.launch()
|