Spaces:
Sleeping
Sleeping
File size: 5,489 Bytes
4f5a895 62257cb a65cb5d 62257cb a65cb5d 62257cb a65cb5d 62257cb 0af05ea 62257cb a65cb5d a91a7bc 62257cb a65cb5d 62257cb a65cb5d 62257cb a65cb5d 4f5a895 62257cb a65cb5d 4f5a895 a65cb5d 62257cb a65cb5d 62257cb a65cb5d 62257cb 4f5a895 a91a7bc 65efe6b a91a7bc 4f5a895 a91a7bc a65cb5d a91a7bc a65cb5d a91a7bc a65cb5d a91a7bc f7038a5 a91a7bc 4f5a895 a91a7bc a65cb5d a91a7bc 4f5a895 62257cb a65cb5d 62257cb a65cb5d a91a7bc a65cb5d a91a7bc a65cb5d 62257cb a91a7bc a65cb5d a91a7bc 4f5a895 a91a7bc 4f5a895 a91a7bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import datetime
import logging
import os
from io import BytesIO

import extra_streamlit_components as stx
import replicate
import requests
import streamlit as st
from llama_index import ServiceContext, VectorStoreIndex, Document
from llama_index.llms.palm import PaLM
from llama_index.memory import ChatMemoryBuffer
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
# Page setup: wide layout plus a short introduction shown at the top of the app.
st.set_page_config(layout="wide")
st.write(
    "My version of ChatGPT vision. "
    "You can upload an image and start chatting with the LLM about the image"
)

# Cookie manager; the script reads and writes the "message_count" cookie through it.
cookie_manager = stx.CookieManager()
@st.cache_data
def get_image_caption(image_data):
    """Return a brief Kosmos-2 description of an image, obtained via Replicate.

    Decorated with ``st.cache_data``, so repeated calls with the same image
    data can reuse a cached caption instead of calling Replicate again.

    Args:
        image_data: The image to describe; forwarded as Replicate's ``image``
            input (the caller in this file passes a ``BytesIO``).

    Returns:
        The model output up to (not including) its first blank line.
    """
    model_ref = (
        "lucataco/kosmos-2:"
        "3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805"
    )
    raw_output = replicate.run(
        model_ref,
        input={"image": image_data, "description_type": "Brief"},
    )
    # Keep only the first paragraph; everything after the first blank line is dropped.
    first_paragraph, *_ = raw_output.split('\n\n')
    return first_paragraph
@st.cache_resource
def create_chat_engine(img_desc, api_key):
    """Build a PaLM-backed "context" chat engine grounded on an image description.

    The description is indexed as a single ``Document`` in a
    ``VectorStoreIndex`` (with ``embed_model="local"``) and is also
    interpolated into the system prompt, so the LLM knows what the uploaded
    image shows.

    Args:
        img_desc: Text description of the image (in this app, the Kosmos-2
            caption from ``get_image_caption``).
        api_key: Google API key passed to the ``PaLM`` LLM wrapper.

    Returns:
        The engine from ``index.as_chat_engine`` in ``"context"`` mode, using a
        ``ChatMemoryBuffer`` with a 1500-token limit.

    NOTE(review): ``st.cache_resource`` caches the returned object, so the
    same engine -- including its chat memory -- appears to be reused by every
    session that passes the same ``img_desc``/``api_key``. Clearing the chat
    in the UI only touches session state, not this memory. Confirm this is
    intended.
    """
    llm = PaLM(api_key=api_key)
    service_context = ServiceContext.from_defaults(llm=llm, embed_model="local")
    index = VectorStoreIndex.from_documents(
        [Document(text=img_desc)], service_context=service_context
    )
    memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
    # The ``f`` prefix must sit on the literal that contains ``{img_desc}``:
    # adjacent literals are concatenated, but only f-prefixed ones interpolate.
    # Previously the prefix was on the first literal, so the prompt contained
    # the text "{img_desc}" instead of the actual description.
    system_prompt = (
        "You are a chatbot, able to have normal interactions, as well as talk. "
        "You always answer in great detail and are polite. Your responses always descriptive. "
        f"Your job is to talk about an image the user has uploaded. Image description: {img_desc}."
    )
    return index.as_chat_engine(
        chat_mode="context",
        system_prompt=system_prompt,
        verbose=True,
        memory=memory,
    )
def clear_chat():
    """Remove the chat history and the stored image entry from session state.

    Keys that are not present are skipped, so this is safe to call at any time.

    NOTE(review): no visible code stores ``"image_file"`` in session state (the
    uploader uses ``key="uploaded_image"``), so that deletion looks like a
    no-op -- confirm which key was intended.
    """
    for state_key in ("messages", "image_file"):
        if state_key in st.session_state:
            del st.session_state[state_key]
def on_image_upload():
    """``on_change`` hook for the image uploader: start over with an empty chat."""
    return clear_chat()
# Per-browser message counter persisted in the "message_count" cookie.
# Cookie values come from the client, so anything that does not convert to an
# int (missing, empty, or tampered) starts a fresh count instead of raising
# ValueError and crashing the app.
_raw_message_count = cookie_manager.get(cookie='message_count')
try:
    message_count = int(_raw_message_count)
except (TypeError, ValueError):  # TypeError covers a missing cookie (None)
    message_count = 0
# Optional cap on messages per browser; None disables the limit. (The check
# used to be hard-wired as ``if 0:``, i.e. never enforced -- still the default.)
MAX_MESSAGES = None

if MAX_MESSAGES is not None and message_count >= MAX_MESSAGES:
    st.error("Notice: The maximum message limit for this demo version has been reached.")
    # Disable the uploader and input by rendering empty placeholders instead.
    image_uploader_placeholder = st.empty()  # Placeholder for the uploader
    chat_input_placeholder = st.empty()  # Placeholder for the chat input
else:
    if st.button("Clear Chat"):
        clear_chat()

    image_file = st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png"],
        key="uploaded_image",
        on_change=on_image_upload,
    )
    if image_file:
        st.image(image_file, caption='Uploaded Image.', width=200)
        # Caption the image with Kosmos-2, then ground the chat engine on that caption.
        image_data = BytesIO(image_file.getvalue())
        img_desc = get_image_caption(image_data)
        st.write("Image Uploaded Successfully. Ask me anything about it.")
        chat_engine = create_chat_engine(img_desc, st.secrets['GOOGLE_API_KEY'])

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the stored conversation before handling any new input.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me about the image:", key="chat_input")
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

    # ``chat_engine`` is only bound when an image was uploaded on this run.
    if image_file and user_input:
        try:
            with st.spinner('Waiting for the chat engine to respond...'):
                response = chat_engine.chat(user_input)
        except Exception:
            # Keep the full traceback in the server log; show the user a generic message.
            logging.exception("Chat engine request failed")
            st.error('An error occurred.')
        else:
            # Store plain text so session state only ever holds strings.
            answer = str(response)
            st.session_state.messages.append({"role": "assistant", "content": answer})
            with st.chat_message("assistant"):
                st.markdown(answer)

        # Count the message and persist the new total in the cookie for 30 days.
        message_count += 1
        cookie_manager.set(
            'message_count',
            str(message_count),
            expires_at=datetime.datetime.now() + datetime.timedelta(days=30),
        )
|