File size: 2,073 Bytes
8d9306e
6f0178d
8d9306e
 
 
c951094
374fa3e
8d9306e
0bb133b
f884ea7
6f0178d
 
 
c951094
 
6f0178d
f884ea7
 
 
d28411b
8d9306e
 
 
6f0178d
 
 
 
f705683
8d9306e
6f0178d
 
 
 
 
 
 
 
 
 
9a6a97f
8d9306e
6f0178d
 
 
 
 
 
 
 
 
 
8d9306e
 
 
 
686f21e
8d9306e
d28411b
c951094
 
6f0178d
 
8d9306e
686f21e
c951094
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import streamlit as st
import requests


# --- Page header and sidebar description ---------------------------------
st.title("🖼️ Image Captioning Demo 📝")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

# Model description shown in the sidebar.  Kept as a named constant so the
# layout code below stays short.
_SIDEBAR_DESCRIPTION = """
    An image captioning model by combining ViT model with GPT2 model.
    The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder
    framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
    The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
    The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
    [Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
    """

st.sidebar.markdown(_SIDEBAR_DESCRIPTION)

# Importing `model` is slow (it loads the ViT-GPT2 weights and compiles the
# inference function), so show a spinner while the import runs.
# NOTE(review): the wildcard import is what brings `os`, `Image`,
# `sample_dir`, `sample_image_ids`, `get_random_image_id` and `predict`
# into this module's namespace — presumably re-exported by model.py;
# confirm and prefer explicit imports if possible.
with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    from model import *


# --- Sample selection ----------------------------------------------------
# Dropdown of curated sample ids; a button press overrides the dropdown
# choice with a random id from the COCO 2017 validation split.
st.sidebar.title("Select a sample image")
image_id = st.sidebar.selectbox(
    "Please choose a sample image",
    options=sample_image_ids,
)

picked_at_random = (
    get_random_image_id()
    if st.sidebar.button("Random COCO 2017 (val) images")
    else None
)
if picked_at_random is not None:
    image_id = picked_at_random

st.write(image_id)

# --- Fetch and display the selected image --------------------------------
# COCO file names are the image id zero-padded to 12 digits.  Use the
# locally cached sample when present; otherwise download from the COCO
# image server.
sample_name = f"COCO_val2017_{str(image_id).zfill(12)}.jpg"
sample_path = os.path.join(sample_dir, sample_name)

if os.path.isfile(sample_path):
    image = Image.open(sample_path)
else:
    url = f"http://images.cocodataset.org/val2017/{str(image_id).zfill(12)}.jpg"
    # Bound the wait so a slow server can't hang the app, and fail fast on
    # HTTP errors instead of handing an HTML error page to PIL (which would
    # raise a cryptic UnidentifiedImageError).
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw)

resized = image.resize(size=(384, 384))
show = st.image(resized, width=384)
show.image(resized, '\n\nSelected Image', width=384)
# The resized copy is only needed for display; release it.  `image` stays
# open because `predict` consumes it below.
resized.close()

# For newline
st.sidebar.write('\n')

# Run inference and show the predicted caption both in the main pane and in
# the sidebar.  (Fixed: dropped the dead `caption_en` alias and the
# placeholder-less f-string — the displayed text is unchanged.)
with st.spinner('Generating image caption ...'):

    caption = predict(image)

    st.header('Predicted caption:\n\n')
    st.subheader(caption)

st.sidebar.header("ViT-GPT2 predicts:")
st.sidebar.write(f"**English**: {caption}")

image.close()