File size: 6,007 Bytes
62257cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import streamlit as st
import extra_streamlit_components as stx
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
from io import BytesIO
import replicate
from llama_index.llms.palm import PaLM
from llama_index import ServiceContext, VectorStoreIndex, Document
from llama_index.memory import ChatMemoryBuffer
import os
import datetime

# Page header rendered at the top of every script run.
st.title("Image Captioning and Chat")

# Browser-cookie store; used further down to persist the per-browser
# message counter that enforces the demo's message limit.
cookie_manager = stx.CookieManager()

@st.cache_resource
def get_vision_model():
    """Load the Kosmos-2 vision-to-text model and its matching processor.

    Wrapped in ``st.cache_resource`` so reruns reuse the loaded objects
    instead of calling ``from_pretrained`` again.

    Returns:
        A ``(model, processor)`` tuple for the
        ``ydshieh/kosmos-2-patch14-224`` checkpoint.
    """
    checkpoint = "ydshieh/kosmos-2-patch14-224"
    # The checkpoint ships custom modelling code, hence trust_remote_code.
    vision_model = AutoModelForVision2Seq.from_pretrained(checkpoint, trust_remote_code=True)
    vision_processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    return vision_model, vision_processor

def _to_rgb_image(image_data):
    """Decode *image_data* (file-like, bytes, or PIL image) into an RGB PIL image."""
    if isinstance(image_data, Image.Image):
        return image_data.convert("RGB")
    if isinstance(image_data, (bytes, bytearray)):
        image_data = BytesIO(image_data)
    # convert("RGB") forces the lazy decode and normalises PNGs with an alpha
    # channel or palette to the 3-channel input the processor expects.
    return Image.open(image_data).convert("RGB")


# Function to get image caption via Kosmos-2.
@st.cache_data
def get_image_caption(image_data):
    """Generate a short grounded caption for an image with Kosmos-2.

    Args:
        image_data: The image to caption. A binary file-like object (e.g.
            ``io.BytesIO``) holding encoded image bytes, raw ``bytes``, or an
            already-decoded ``PIL.Image.Image``.

    Returns:
        The caption text produced by ``processor.post_process_generation``.
        The entity list it also returns is discarded.
    """
    model, processor = get_vision_model()

    # The HF processor needs a decoded image, not an encoded byte stream.
    image = _to_rgb_image(image_data)

    # Kosmos-2 grounding prompt; the model continues the sentence.
    prompt = "<grounding>An image of"
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # The trailing position of each sequence tensor is dropped (presumably the
    # end-of-sequence token the processor appends) so that generation
    # continues the prompt rather than starting after it — confirm against
    # the checkpoint's example usage.
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"][:, :-1],
        attention_mask=inputs["attention_mask"][:, :-1],
        img_features=None,
        img_attn_mask=inputs["img_attn_mask"][:, :-1],
        use_cache=True,
        max_new_tokens=64,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    text_description, _entities = processor.post_process_generation(generated_text)
    return text_description

# Function to create the chat engine.
@st.cache_resource
def create_chat_engine(img_desc, api_key):
    """Build a PaLM-backed "context" chat engine grounded on an image description.

    Args:
        img_desc: Caption text for the uploaded image. It is indexed as the
            sole document and also embedded in the system prompt.
        api_key: Google API key handed to the PaLM LLM.

    Returns:
        The chat engine produced by ``VectorStoreIndex.as_chat_engine`` in
        ``"context"`` mode, with a 1500-token chat memory buffer.
    """
    # NOTE(review): st.cache_resource hands back the same engine object for
    # every call with the same (img_desc, api_key). Per Streamlit's docs that
    # cache is global, so the engine's ChatMemoryBuffer is presumably shared
    # across reruns and sessions, and is not reset by clear_chat() — confirm
    # this is intended.
    llm = PaLM(api_key=api_key)
    service_context = ServiceContext.from_defaults(llm=llm)
    index = VectorStoreIndex.from_documents(
        [Document(text=img_desc)], service_context=service_context
    )
    chat_memory = ChatMemoryBuffer.from_defaults(token_limit=1500)

    system_prompt = (
        "You are a chatbot, able to have normal interactions, as well as talk. "
        "You always answer in great detail and are polite. Your responses always descriptive. "
        # The f-prefix belongs on this piece: it is the only one with a
        # placeholder, so the real description gets interpolated.
        f"Your job is to talk about an image the user has uploaded. Image description: {img_desc}."
    )

    return index.as_chat_engine(
        chat_mode="context",
        system_prompt=system_prompt,
        verbose=True,
        memory=chat_memory,
    )

def clear_chat():
    """Drop the chat history (and any stored ``image_file``) from session state.

    Keys that are not present are skipped, so this is safe to call at any time.
    """
    for state_key in ("messages", "image_file"):
        if state_key in st.session_state:
            del st.session_state[state_key]

def on_image_upload():
    """Reset the conversation when the image uploader's value changes.

    Used as the ``on_change`` callback of the image ``st.file_uploader``, so
    each newly uploaded image starts from an empty chat history.
    """
    clear_chat()

# Set Replicate and Google API keys first: the chat-engine setup below reads
# GOOGLE_API_KEY from the environment, so it must be populated before use.
os.environ['REPLICATE_API_TOKEN'] = st.secrets['REPLICATE_API_TOKEN']
os.environ["GOOGLE_API_KEY"] = st.secrets['GOOGLE_API_KEY']

# Maximum number of chat messages one browser may send to this demo.
# NOTE(review): the count lives in a client-side cookie, so it is a soft limit
# that a user can reset by clearing cookies.
MAX_DEMO_MESSAGES = 20

# Add a clear chat button
if st.button("Clear Chat"):
    clear_chat()

# Image upload section.
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], key="uploaded_image", on_change=on_image_upload)
chat_engine = None
if image_file:
    # Display the uploaded image at a standard width.
    st.image(image_file, caption='Uploaded Image.', width=200)
    # Process the uploaded image to get a caption.
    image_data = BytesIO(image_file.getvalue())
    img_desc = get_image_caption(image_data)
    st.write(f"Image description: {img_desc}")

    # Initialize the chat engine with the image description.
    chat_engine = create_chat_engine(img_desc, os.environ["GOOGLE_API_KEY"])

# Initialize session state for messages if it doesn't exist
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display previous messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Handle new user input
user_input = st.chat_input("Ask me about the image:", key="chat_input")
if user_input:
    # Retrieve the message count from cookies; the value is client-controlled,
    # so a missing or malformed value is treated as zero instead of crashing.
    stored_count = cookie_manager.get(cookie='message_count')
    try:
        message_count = 0 if stored_count is None else int(stored_count)
    except (TypeError, ValueError):
        message_count = 0

    # Check if the message limit has been reached
    if message_count >= MAX_DEMO_MESSAGES:
        st.error("Notice: The maximum message limit for this demo version has been reached.")
    else:
        # Append user message to the session state and display it immediately.
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        if chat_engine is not None:
            # chat() returns a response object; keep only its text so session
            # state holds plain strings.
            answer = str(chat_engine.chat(user_input))
            st.session_state.messages.append({"role": "assistant", "content": answer})
            with st.chat_message("assistant"):
                st.markdown(answer)
        else:
            st.warning("Please upload an image before asking questions about it.")

        # Increment the message count and update the cookie
        message_count += 1
        cookie_manager.set('message_count', str(message_count), expires_at=datetime.datetime.now() + datetime.timedelta(days=30))