Spaces:
Sleeping
Sleeping
File size: 2,823 Bytes
9522bcd 57231ec c749f9d 57231ec 08ec34d 57231ec 7e847dc 82b37de 9522bcd 420f5bf 57231ec 08ec34d 5f69d7b 7e847dc 57231ec 1af1fd7 0959b2d 57231ec 0959b2d 57231ec 08ec34d 57231ec 08ec34d b26fc84 57231ec 27a4a72 57231ec 27a4a72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from transformers import AutoTokenizer, AutoModel ,AutoConfig
import torch
from transformers import ViTImageProcessor, VisionEncoderDecoderModel,RobertaTokenizerFast
import PIL
import streamlit as st
from PIL import Image
def set_page_config():
    """Apply the app's page title, icon and wide layout via Streamlit."""
    page_options = {
        'page_title': 'Caption an Cartoon Image',
        'page_icon': ':camera:',
        'layout': 'wide',
    }
    st.set_page_config(**page_options)
@st.cache_resource(show_spinner=False)
def initialize_model():
    """Load the captioning model, its tokenizer and the image processor on CPU.

    The result is cached with ``st.cache_resource`` so the weights are loaded
    once per server process instead of on every button click / script rerun.
    The cache's own spinner is disabled because the caller already wraps this
    call in ``st.spinner``.

    Returns:
        tuple: ``(image_processor, model, tokenizer, device)`` where
        ``device`` is the string ``'cpu'``.
    """
    device = 'cpu'
    checkpoint = "sourabhbargi11/Caption_generator_model"

    config = AutoConfig.from_pretrained(checkpoint)
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint, config=config).to(device)
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224", device=device)
    return image_processor, model, tokenizer, device
def upload_image():
    """Show the sidebar file uploader (jpg/jpeg/png) and return its value."""
    prompt = "Upload an image (we aren't storing anything)"
    accepted_types = ["jpg", "jpeg", "png"]
    return st.sidebar.file_uploader(prompt, type=accepted_types)
def image_preprocess(image):
    """Return *image* as a 224x224 RGB image.

    Args:
        image: A PIL-style image exposing ``mode``, ``convert`` and ``resize``.

    Returns:
        The converted and resized image.
    """
    # Convert every non-RGB mode, not just greyscale ("L"): PNG uploads are
    # frequently RGBA or palette ("P") images, which would otherwise be passed
    # on with the wrong number of channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image.resize((224, 224))
def generate_caption(processor, model, device, image, tokenizer=None):
    """Generate a text caption for *image* using beam search.

    Args:
        processor: Callable image processor; invoked as
            ``processor(image, return_tensors='pt')``.
        model: Model exposing ``eval()`` and ``generate(...)``.
        device: Device the processed inputs are moved to.
        image: The image to caption.
        tokenizer: Optional tokenizer used to decode the generated ids.
            When omitted, the ``roberta-base`` tokenizer is loaded, which is
            the same tokenizer ``initialize_model`` returns. Pass it
            explicitly to avoid reloading it on every call.

    Returns:
        str: The decoded caption with special tokens removed.
    """
    if tokenizer is None:
        # Existing callers pass only four arguments; keep them working.
        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    # Use the `processor` argument (the body previously referenced an
    # undefined `image_processor`) and hand the model the pixel tensor itself
    # rather than the whole processor output object.
    inputs = processor(image, return_tensors='pt').to(device)
    pixel_values = inputs.pixel_values

    model.eval()
    with torch.no_grad():
        output = model.generate(
            pixel_values=pixel_values,
            max_length=1000,      # upper bound on the generated caption length
            num_beams=4,          # beam width for beam-search decoding
            early_stopping=True,  # stop once all beams have finished
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)
def main():
    """Run the app: page setup, image upload, preview and caption generation."""
    set_page_config()
    st.header("Caption an Image :camera:")

    uploaded_image = upload_image()
    if uploaded_image is None:
        # Nothing uploaded yet; only the header and uploader are shown.
        return

    image = image_preprocess(Image.open(uploaded_image))
    st.image(image, caption='Your image')

    # NOTE(review): nesting reconstructed from a listing with no indentation;
    # confirm the caption is meant to render inside the sidebar context.
    with st.sidebar:
        st.divider()
        if not st.sidebar.button('Generate Caption'):
            return
        with st.spinner('Generating caption...'):
            image_processor, model, _tokenizer, device = initialize_model()
            caption = generate_caption(image_processor, model, device, image)
            st.header("Caption:")
            st.markdown(f'**{caption}**')
# Footer markdown, emitted at module level after the entry-point guard.
_FOOTER_MARKDOWN = """
---
You are looking at a partially tuned model. Judge me! (I am Funny and Creative) ππ¨"""

if __name__ == '__main__':
    main()

st.markdown(_FOOTER_MARKDOWN)
|