import torch
import streamlit as st
from PIL import Image
from transformers import (
    AutoConfig,
    RobertaTokenizerFast,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)


def set_page_config():
    st.set_page_config(
        page_title='Caption a Cartoon Image',
        page_icon=':camera:',
        layout='wide',
    )

def initialize_model():
    # Load the fine-tuned ViT-encoder / RoBERTa-decoder captioning model on CPU.
    device = 'cpu'
    config = AutoConfig.from_pretrained("sourabhbargi11/Caption_generator_model")
    model = VisionEncoderDecoderModel.from_pretrained(
        "sourabhbargi11/Caption_generator_model", config=config
    ).to(device)
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    return image_processor, model, tokenizer, device
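
# Note: as written, main() reloads the model on every "Generate Caption" click.
# A possible refinement (sketch only, assumes a Streamlit version that provides
# st.cache_resource) would be to wrap the loader in a cached helper, e.g.:
#
#   @st.cache_resource
#   def load_captioning_model():
#       return initialize_model()
#
# and call load_captioning_model() from main() instead of initialize_model().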

def upload_image():
    return st.sidebar.file_uploader("Upload an image (we aren't storing anything)", type=["jpg", "jpeg", "png"])

def image_preprocess(image):
    # Resize to the ViT input resolution and make sure the image has 3 channels.
    image = image.resize((224, 224))
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image

def generate_caption(image_processor, tokenizer, model, device, image):
    inputs = image_processor(image, return_tensors='pt').to(device)
    model.eval()
    # Generate caption
    with torch.no_grad():
        output = model.generate(
            pixel_values=inputs.pixel_values,
            max_length=1000,      # Maximum length of the generated caption; adjust as needed
            num_beams=4,          # Number of beams for beam search decoding
            early_stopping=True,  # Stop generation once all beams have finished
        )

    # Decode the generated caption
    caption = tokenizer.decode(output[0], skip_special_tokens=True)
    return caption
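
# Example (hypothetical, not part of the app): generate_caption could also be
# exercised from a plain Python session, roughly like this; "example.jpg" is a
# made-up path used only for illustration:
#
#   image_processor, model, tokenizer, device = initialize_model()
#   img = image_preprocess(Image.open("example.jpg"))
#   print(generate_caption(image_processor, tokenizer, model, device, img))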


def main():
    set_page_config()
    st.header("Caption an Image :camera:")

    uploaded_image = upload_image()

    if uploaded_image is not None:
        image = Image.open(uploaded_image)
        image = image_preprocess(image)
        
        st.image(image, caption='Your image')
        
        with st.sidebar:
            st.divider() 
            if st.sidebar.button('Generate Caption'):
                with st.spinner('Generating caption...'):
                    image_processor, model, tokenizer, device = initialize_model()
                    caption = generate_caption(image_processor, tokenizer, model, device, image)
            
                    st.header("Caption:")
                    st.markdown(f'**{caption}**')


if __name__ == '__main__':
    main()

# st.markdown("""
#     ---
#    You are looking at partial tuned model , please JUDGE ME!!!  (I am Funny , Sensible , Creative )""")

st.markdown("""
    ---
   You are looking at a partially tuned model. Judge me! (I am Funny and Creative) πŸ˜„πŸŽ¨""")