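"""Flask API for image caption generation.

A client POSTs a base64-encoded image to /api; image features are extracted with a
pre-trained Xception model and decoded into a caption by a trained Keras model.
"""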
import os

# Disable oneDNN custom ops; this must be set before TensorFlow/Keras is imported to take effect
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import io
import base64
from pickle import load

import numpy as np
from PIL import Image
from flask import Flask, request, jsonify
from keras.applications.xception import Xception  # pre-trained Xception model for feature extraction
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences

app = Flask(__name__)

# Maximum caption length (in tokens) the captioning model was trained with
# MAX_LENGTH = 34
MAX_LENGTH = 38

def extract_features(image_data, model):
    """Decode raw image bytes and return the Xception feature vector for the image."""
    try:
        image = Image.open(io.BytesIO(image_data))
    except Exception as e:
        print("ERROR: Can't open image! Ensure that image data is correct and in the expected format")
        print(str(e))
        return None
    # Xception expects 299x299 RGB input
    image = image.resize((299, 299))
    image = np.array(image)
    # For 4-channel (RGBA) images, drop the alpha channel and keep the 3 colour channels
    if image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]
    image = np.expand_dims(image, axis=0)
    # Scale pixel values to [-1, 1], as Xception expects
    image = image / 127.5
    image = image - 1.0
    feature = model.predict(image)
    return feature

def word_for_id(integer, tokenizer):
    """Map a predicted token index back to its word in the tokenizer vocabulary."""
    for word, index in tokenizer.word_index.items():
        if index == integer:
            return word
    return None

def generate_desc(model, tokenizer, photo, max_length):
    """Greedily generate a caption for the extracted photo features, one word at a time."""
    in_text = 'start'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        pred = model.predict([photo, sequence], verbose=0)
        pred = np.argmax(pred)
        word = word_for_id(pred, tokenizer)
        # Stop when the model predicts an unknown token or the 'end' marker
        if word is None or word == 'end':
            break
        in_text += ' ' + word
    # Strip the leading 'start' token from the generated caption
    return in_text.replace('start ', '')

# API endpoint: receives a base64-encoded image and returns a generated caption
@app.route('/api', methods=['POST'])
def generate_caption():
    try:
        base64_image_data = request.form['image']
        # If "+" characters were turned into spaces by form encoding, they can be
        # restored with: base64_image_data = base64_image_data.replace(" ", "+")
        # Decode the Base64 string into binary image data
        image_data = base64.b64decode(base64_image_data)
        # Note: the tokenizer, caption model and Xception are loaded on every request;
        # loading them once at startup would make the endpoint considerably faster.
        tokenizer = load(open("tokenizer.p", "rb"))
        model = load_model('models/model_9.keras')
        xception_model = Xception(include_top=False, pooling="avg")
        photo = extract_features(image_data, xception_model)
        if photo is None:
            return jsonify({'error': 'Failed to extract features from the image'}), 400
        caption = generate_desc(model, tokenizer, photo, MAX_LENGTH)
        # Return the generated caption
        return jsonify({'caption': caption}), 200
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0')
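
# Example client call (a minimal sketch, not part of the app itself). It assumes the server
# is running locally on Flask's default port 5000 and that an "example.jpg" file exists:
#
#     import base64, requests
#
#     with open("example.jpg", "rb") as f:
#         encoded = base64.b64encode(f.read()).decode("utf-8")
#     response = requests.post("http://localhost:5000/api", data={"image": encoded})
#     print(response.json())  # e.g. {"caption": "..."}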