Spaces:
Runtime error
Runtime error
| from flask import Flask,request | |
| import google.generativeai as palm | |
| import re | |
| import pickle | |
| import numpy as np | |
| import requests | |
| from PIL import Image | |
| from io import BytesIO | |
| from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input | |
| from tensorflow.keras.preprocessing.image import load_img, img_to_array | |
| from tensorflow.keras.preprocessing.text import Tokenizer | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| from tensorflow.keras.models import Model | |
| from tensorflow.keras.utils import to_categorical, plot_model | |
| from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, add | |
| from tensorflow.keras.models import load_model | |
# Trained captioning network, restored from disk once at import time.
model = load_model('best_model.h5')
# Passed to predict_captions, where it is used both as the padding length for
# the token sequence and as the cap on the number of generated words.
max_len = 35

# Read the caption file. The first line is skipped (presumably a header row);
# `None` as the default keeps an empty file from raising StopIteration.
# The encoding is explicit so decoding does not depend on the host locale.
with open('captions.txt', 'r', encoding='utf-8') as f:
    next(f, None)
    caption_file = f.read()

# Map image id (file name up to its first dot) -> list of raw captions.
captions = {}
for line in caption_file.split('\n'):
    # Blank or one-character lines carry no usable record.
    if len(line) < 2:
        continue
    # Everything before the first comma is the image file name; the rest is
    # the caption. Commas inside the caption are turned into spaces, which is
    # what splitting on "," and re-joining with " " produced.
    image_name, _, caption = line.partition(',')
    image_id = image_name.split('.')[0]
    captions.setdefault(image_id, []).append(caption.replace(',', ' '))
def clean(captions):
    """Normalise every caption in ``captions`` in place.

    Each caption is lower-cased, every character outside ``a-z``/``A-Z`` is
    replaced by a space, single-letter words are dropped, and the result is
    wrapped as ``"startseq <words> endseq"``.

    Args:
        captions: dict mapping a key to a mutable list of caption strings.
            The lists are modified in place.

    Returns:
        None.
    """
    # Raw string, compiled once: avoids the invalid-escape warning a plain
    # '\s' literal triggers and keeps the pattern out of the inner loop.
    non_letters = re.compile(r'[^a-zA-Z]')
    for caption_list in captions.values():
        for i, raw in enumerate(caption_list):
            # After the substitution only letters and spaces remain, and
            # split() already collapses runs of whitespace, so no separate
            # whitespace-squeezing pass is needed.
            words = non_letters.sub(' ', raw.lower()).split()
            kept = " ".join(word for word in words if len(word) > 1)
            caption_list[i] = "startseq " + kept + " endseq"
# Normalise all loaded captions (the lists inside `captions` are edited in place).
clean(captions)

# Flatten the per-image caption lists into one corpus for the tokenizer.
all_captions = [text for group in captions.values() for text in group]

# Build the vocabulary from the cleaned captions.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(all_captions)

# Instantiate VGG16, then rewire it so the model's output is the output of its
# second-to-last layer instead of the final layer.
vgg_model = VGG16()
vgg_model = Model(
    inputs=vgg_model.inputs,
    outputs=vgg_model.layers[-2].output,
)
def index_to_word(indx, tokenizer):
    """Return the word whose vocabulary index equals ``indx``.

    Args:
        indx: integer index to look up.
        tokenizer: object exposing a ``word_index`` mapping of word -> index.

    Returns:
        The first matching word in ``word_index`` iteration order, or ``None``
        when no word has that index.
    """
    matches = (
        candidate
        for candidate, position in tokenizer.word_index.items()
        if position == indx
    )
    return next(matches, None)
def predict_captions(model, image, tokenizer, max_len):
    """Greedily decode a caption for ``image`` one word at a time.

    Starting from ``'startseq'``, the text generated so far is tokenised,
    padded to ``max_len`` and fed to ``model`` together with ``image``; the
    highest-scoring index is mapped back to a word and appended. Decoding
    stops after ``max_len`` words, when ``'endseq'`` is produced, or when the
    predicted index has no word in the tokenizer.

    Args:
        model: object whose ``predict([image, sequence], verbose=0)`` returns
            scores over the vocabulary.
        image: array input for the model; a 3-D array gets a leading batch
            axis added.
        tokenizer: tokenizer providing ``texts_to_sequences`` and ``word_index``.
        max_len: padding length and maximum number of generated words.

    Returns:
        The generated text, including the leading ``'startseq'`` and, when
        reached, the trailing ``'endseq'``.
    """
    # Loop-invariant: add the batch axis once instead of on every iteration.
    if len(image.shape) == 3:
        image = np.expand_dims(image, axis=0)

    in_text = 'startseq'
    for _ in range(max_len):
        seq = tokenizer.texts_to_sequences([in_text])[0]
        # pad_sequences on a one-element list already yields a (1, max_len)
        # batch, so no extra indexing/expand_dims round trip is needed.
        seq_batch = pad_sequences([seq], max_len)
        y_pred = model.predict([image, seq_batch], verbose=0)
        word = index_to_word(np.argmax(y_pred), tokenizer)
        # An index with no vocabulary entry ends decoding.
        if word is None:
            break
        in_text += " " + word
        if word == 'endseq':
            break
    return in_text
def caption_generator(url):
    """Download the image at ``url`` and return the caption decoded for it.

    The image is fetched over HTTP, converted to RGB, resized to 224x224,
    preprocessed for VGG16, passed through ``vgg_model`` to get a feature
    vector, and that vector is decoded by ``predict_captions``.

    Args:
        url: address of the image to caption.

    Returns:
        The string produced by ``predict_captions`` (it still contains the
        ``startseq``/``endseq`` markers).

    Raises:
        requests.RequestException: on network failure, timeout, or a non-2xx
            response.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    # A timeout keeps a slow or hanging host from blocking the worker forever;
    # raise_for_status surfaces HTTP errors instead of handing an error page
    # to the image decoder.
    response = requests.get(url, timeout=10)
    response.raise_for_status()

    # Force three channels: RGBA, grayscale and palette images would otherwise
    # produce arrays whose channel count VGG16 does not accept.
    image = Image.open(BytesIO(response.content)).convert('RGB')
    image = image.resize((224, 224))

    # Convert to an array, add the batch axis, and apply VGG16 preprocessing.
    pixels = img_to_array(image)
    batch = preprocess_input(np.expand_dims(pixels, axis=0))

    # Extract image features, then decode them into a caption.
    feature = vgg_model.predict(batch, verbose=0)
    return predict_captions(model, feature, tokenizer, max_len)
app = Flask(__name__)


# Without a route registration the view is never reachable and the server
# answers every request with 404.
@app.route('/')
def home():
    """Return a fixed greeting; usable as a liveness check."""
    return "HELLO WORLD"
# NOTE(review): the view had no route registration, so it was unreachable.
# The '/predict' path is an assumption -- confirm against the client.
@app.route('/predict', methods=['POST'])
def predict():
    """Caption the image at the posted URL and rewrite it as an Instagram caption.

    Expects a JSON body of the form ``{"url": "<image url>"}``. The image is
    captioned by ``caption_generator`` and the caption is sent to the PaLM
    text model to be rewritten.

    Returns:
        The generated caption text. If the text model returns no result, the
        plain image description is returned instead. On a malformed request
        the response is ``({"error": ...}, 400)``; if the API key is not
        configured it is ``({"error": ...}, 500)``.
    """
    import os  # local import so this block does not depend on file-level changes

    # silent=True returns None instead of raising on a missing/invalid body.
    payload = request.get_json(silent=True)
    print(payload)
    if not isinstance(payload, dict) or not payload.get('url'):
        return {'error': "Request body must be JSON with a non-empty 'url' field"}, 400

    # The key is read from the environment rather than embedded in source.
    # A key that was previously committed should be treated as leaked and
    # revoked.
    api_key = os.environ.get('PALM_API_KEY')
    if not api_key:
        return {'error': 'PALM_API_KEY is not configured on the server'}, 500

    # NOTE(review): the URL is fetched server-side without restriction;
    # consider validating it against an allow-list.
    result = caption_generator(payload['url'])

    # Drop the decoder's start/end markers so they do not leak into the prompt.
    description = " ".join(
        word for word in str(result).split()
        if word not in ('startseq', 'endseq')
    )

    palm.configure(api_key=api_key)
    prompt = (
        "Generate a creative & attractive instagram caption of 10-30 words for "
        + description
    )
    completion = palm.generate_text(
        model="models/text-bison-001",
        prompt=prompt,
        temperature=0,
        # The maximum length of the response
        max_output_tokens=100,
    )
    # result is None when no candidate is returned; a Flask view must not
    # return None, so fall back to the plain description.
    return completion.result or description
if __name__ == '__main__':
    import os

    # Debug mode enables the interactive debugger, which allows arbitrary code
    # execution from the browser, so it is opt-in via FLASK_DEBUG instead of
    # always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(debug=debug_enabled)