yuanphon committed on
Commit
5cc1c86
·
1 Parent(s): d444104

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  from PIL import Image
7
  from transformers import ViTForImageClassification, AutoImageProcessor, AdamW, ViTImageProcessor, VisionEncoderDecoderModel, AutoTokenizer
8
  from torch.utils.data import DataLoader, TensorDataset
 
9
 
10
  model_path = '/home/user/app'
11
  train_pickle_path = 'train_data.pickle'
@@ -152,7 +153,7 @@ def train_model():
152
 
153
  model.save_pretrained("model")
154
 
155
- def predict():
156
  # Load the model
157
  model = ViTForImageClassification.from_pretrained(model_path, num_labels=num_classes)
158
 
@@ -162,7 +163,8 @@ def predict():
162
  # Load the test data
163
  # Load the image
164
 
165
- img = cv2.imread(test_image_path)
 
166
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
167
 
168
  # Resize the image to 224x224 pixels
@@ -214,16 +216,28 @@ def output(predict_class, caption):
214
  conj = ['are', 'is', 'dog']
215
  if predict_class == '不是校狗' or caption.find('dog') == -1:
216
  print(f'{caption} ({predict_class})')
 
217
  else:
218
  for c in conj:
219
  if caption.find(c) != -1:
220
  print(f'{predict_class} is{caption[caption.find(c) + len(c):]}')
221
  return
222
  print(f'{caption} ({predict_class})')
 
223
 
224
 
225
  if __name__ == '__main__':
226
 
227
  if not os.path.exists(model_path):
228
  train_model()
229
- output(predict(), captioning())
 
 
 
 
 
 
 
 
 
 
 
6
  from PIL import Image
7
  from transformers import ViTForImageClassification, AutoImageProcessor, AdamW, ViTImageProcessor, VisionEncoderDecoderModel, AutoTokenizer
8
  from torch.utils.data import DataLoader, TensorDataset
9
+ import gradio as gr
10
 
11
  model_path = '/home/user/app'
12
  train_pickle_path = 'train_data.pickle'
 
153
 
154
  model.save_pretrained("model")
155
 
156
+ def predict(upload_image):
157
  # Load the model
158
  model = ViTForImageClassification.from_pretrained(model_path, num_labels=num_classes)
159
 
 
163
  # Load the test data
164
  # Load the image
165
 
166
+ # img = cv2.imread(test_image_path)
167
+ img = upload_image
168
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
169
 
170
  # Resize the image to 224x224 pixels
 
216
  conj = ['are', 'is', 'dog']
217
  if predict_class == '不是校狗' or caption.find('dog') == -1:
218
  print(f'{caption} ({predict_class})')
219
+ return (f'{caption} ({predict_class})')
220
  else:
221
  for c in conj:
222
  if caption.find(c) != -1:
223
  print(f'{predict_class} is{caption[caption.find(c) + len(c):]}')
224
  return
225
  print(f'{caption} ({predict_class})')
226
+
227
 
228
 
229
  if __name__ == '__main__':
230
 
231
  if not os.path.exists(model_path):
232
  train_model()
233
+ # output(predict(), captioning())
234
+
235
+
236
+ # def greet(name):
237
+ # return "Hello " + name + "!!"
238
+ def get_result(upload_image):
239
+ result = output(predict(upload_image), captioning())
240
+ return result
241
+
242
+ iface = gr.Interface(fn=get_result, inputs="image", outputs="text")
243
+ iface.launch()