kai-sheng committed on
Commit
8384356
·
verified ·
1 Parent(s): d7b2ea0

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +15 -3
main.py CHANGED
@@ -3,9 +3,10 @@ from flask import Flask, request, jsonify
3
  import base64
4
  import pytesseract
5
  import numpy as np
 
6
  from pickle import load
7
  from PIL import Image
8
- from keras.applications.xception import Xception #to get pre-trained model Xception
9
  from keras.models import load_model
10
  from keras.preprocessing.sequence import pad_sequences
11
 
@@ -13,6 +14,17 @@ app = Flask(__name__)
13
 
14
  MAX_LENGTH = 38
15
 
 
 
 
 
 
 
 
 
 
 
 
16
  def format_tesseract_output(output_text):
17
  formatted_text = ""
18
  lines = output_text.strip().split("\n")
@@ -34,7 +46,7 @@ def extract_features(image_data, model):
34
  image = image.resize((299,299))
35
  image = np.array(image)
36
 
37
- # for 4 channels images, we need to convert them into 3 channels
38
  if image.shape[2] == 4:
39
  image = image[..., :3]
40
 
@@ -77,7 +89,7 @@ def generate_caption():
77
  image_data = base64.b64decode(base64_image_data)
78
 
79
  # Convert the image data to a PIL image object
80
- pil_image = Image.open(io.BytesIO(img_path))
81
 
82
  extracted_text = pytesseract.image_to_string(pil_image, lang="eng+chi_sim+msa")
83
  hasText = bool(extracted_text.strip())
 
3
  import base64
4
  import pytesseract
5
  import numpy as np
6
+ import tensorflow as tf
7
  from pickle import load
8
  from PIL import Image
9
+ from keras.applications.xception import Xception # to get pre-trained model Xception
10
  from keras.models import load_model
11
  from keras.preprocessing.sequence import pad_sequences
12
 
 
14
 
15
  MAX_LENGTH = 38
16
 
17
+ # Set up GPU memory growth
18
+ physical_devices = tf.config.list_physical_devices('GPU')
19
+ if physical_devices:
20
+ try:
21
+ # Allow memory growth for all GPUs
22
+ for gpu in physical_devices:
23
+ tf.config.experimental.set_memory_growth(gpu, True)
24
+ print("GPU(s) memory growth set to True")
25
+ except RuntimeError as e:
26
+ print(e)
27
+
28
  def format_tesseract_output(output_text):
29
  formatted_text = ""
30
  lines = output_text.strip().split("\n")
 
46
  image = image.resize((299,299))
47
  image = np.array(image)
48
 
49
+ # convert 4 channels image into 3 channels
50
  if image.shape[2] == 4:
51
  image = image[..., :3]
52
 
 
89
  image_data = base64.b64decode(base64_image_data)
90
 
91
  # Convert the image data to a PIL image object
92
+ pil_image = Image.open(io.BytesIO(image_data))
93
 
94
  extracted_text = pytesseract.image_to_string(pil_image, lang="eng+chi_sim+msa")
95
  hasText = bool(extracted_text.strip())