AI-ANK committed
Commit a91a7bc · 1 Parent(s): 34026b8

Update app.py

Files changed (1)
  1. app.py +96 -93
app.py CHANGED
@@ -4,60 +4,52 @@ import requests
 from PIL import Image
 from transformers import AutoProcessor, AutoModelForVision2Seq
 from io import BytesIO
-#import replicate
+import replicate
 from llama_index.llms.palm import PaLM
 from llama_index import ServiceContext, VectorStoreIndex, Document
 from llama_index.memory import ChatMemoryBuffer
 import os
 import datetime
-from PIL import Image
-import io
 
 # Set up the title of the application
-st.title("Image Captioning and Chat")
+#st.title("PaLM-Kosmos-Vision")
+st.set_page_config(layout="wide")
+st.write("My version of ChatGPT vision. You can upload an image and start chatting with the LLM about the image")
+
+# Sidebar
+st.sidebar.markdown('## Created By')
+st.sidebar.markdown("""
+[Harshad Suryawanshi](https://www.linkedin.com/in/harshadsuryawanshi/)
+""")
+
+st.sidebar.markdown('## Other Projects')
+st.sidebar.markdown("""
+- [AI Equity Research Analyst](https://ai-eqty-rsrch-anlyst.streamlit.app/)
+- [Recasting "The Office" Scene](https://blackmirroroffice.streamlit.app/)
+- [Story Generator](https://appstorycombined-agaf9j4ceit.streamlit.app/)
+""")
+
+st.sidebar.markdown('## Disclaimer')
+st.sidebar.markdown("""
+This application is a conceptual prototype created to demonstrate the potential of Large Language Models (LLMs) in generating equity research reports. The contents generated by this application are purely illustrative and should not be construed as financial advice, endorsements, or recommendations. The author and the application do not provide any guarantee regarding the accuracy, completeness, or timeliness of the information provided.
+""")
 
 # Initialize the cookie manager
 cookie_manager = stx.CookieManager()
 
-@st.cache_resource
-def get_vision_model():
-    model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
-    processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
-    return model, processor
-
-model, processor = get_vision_model()
-
 # Function to get image caption via Kosmos2.
 @st.cache_data
 def get_image_caption(image_data):
-    # Ensure image_data is a bytes stream ready to be read by Image.open
-    if isinstance(image_data, io.BytesIO):
-        # If it's already a BytesIO, we need to seek to the beginning of the file
-        image_data.seek(0)
-        image = Image.open(image_data)
-    else:
-        # If image_data is not a BytesIO object, create one
-        image = Image.open(io.BytesIO(image_data.read()))
-
-    model, processor = get_vision_model()
-
-    prompt = "<grounding>An image of"
-    # Pass the PIL image to the processor
-    inputs = processor(text=prompt, images=image, return_tensors="pt")
-
-    generated_ids = model.generate(
-        pixel_values=inputs["pixel_values"],
-        input_ids=inputs["input_ids"][:, :-1],
-        attention_mask=inputs["attention_mask"][:, :-1],
-        img_features=None,
-        img_attn_mask=inputs["img_attn_mask"][:, :-1],
-        use_cache=True,
-        max_new_tokens=64,
+    input_data = {
+        "image": image_data,
+        "description_type": "Brief"
+    }
+    output = replicate.run(
+        "lucataco/kosmos-2:3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805",
+        input=input_data
     )
-    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-    text_description, entities = processor.post_process_generation(generated_text)
-
+    # Split the output string on the newline character and take the first item
+    text_description = output.split('\n\n')[0]
     return text_description
 
 # Function to create the chat engine.
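Note: this hunk swaps the local Kosmos-2 pipeline (AutoModelForVision2Seq / AutoProcessor) for the hosted Replicate model. Below is a minimal sketch of the same call outside Streamlit, assuming REPLICATE_API_TOKEN is available in the environment and a local file photo.jpg exists; the model slug and input keys are taken from the diff above.

import os
import replicate

# The replicate client reads its API token from the environment.
os.environ.setdefault("REPLICATE_API_TOKEN", "<your-replicate-token>")

with open("photo.jpg", "rb") as image_file:
    output = replicate.run(
        "lucataco/kosmos-2:3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805",
        input={"image": image_file, "description_type": "Brief"},
    )

# app.py treats the output as one string and keeps only the caption text
# before the grounding details, i.e. everything up to the first blank line.
caption = output.split("\n\n")[0]
print(caption)

Moving inference to Replicate removes the heavy transformers/torch model download from the app at the cost of one network call per caption.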
@@ -92,46 +84,49 @@ def clear_chat():
 def on_image_upload():
     clear_chat()
 
-# Add a clear chat button
-if st.button("Clear Chat"):
-    clear_chat()
-
-# Image upload section.
-image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], key="uploaded_image", on_change=on_image_upload)
-if image_file:
-    # Display the uploaded image at a standard width.
-    st.image(image_file, caption='Uploaded Image.', width=200)
-    # Process the uploaded image to get a caption.
-    image_data = BytesIO(image_file.getvalue())
-    img_desc = get_image_caption(image_data)
-    st.write(f"Image description: {img_desc}")
-
-    # Initialize the chat engine with the image description.
-    chat_engine = create_chat_engine(img_desc, os.environ["GOOGLE_API_KEY"])
-
-# Initialize session state for messages if it doesn't exist
-if "messages" not in st.session_state:
-    st.session_state.messages = []
-
-# Display previous messages
-for message in st.session_state.messages:
-    with st.chat_message(message["role"]):
-        st.markdown(message["content"])
-
-# Handle new user input
-user_input = st.chat_input("Ask me about the image:", key="chat_input")
-if user_input:
-    # Retrieve the message count from cookies
-    message_count = cookie_manager.get(cookie='message_count')
-    if message_count is None:
-        message_count = 0
-    else:
-        message_count = int(message_count)
-
-    # Check if the message limit has been reached
-    if message_count >= 20:
-        st.error("Notice: The maximum message limit for this demo version has been reached.")
-    else:
+# Retrieve the message count from cookies
+message_count = cookie_manager.get(cookie='message_count')
+if message_count is None:
+    message_count = 0
+else:
+    message_count = int(message_count)
+
+# If the message limit has been reached, disable the inputs
+if message_count >= 20:
+    st.error("Notice: The maximum message limit for this demo version has been reached.")
+    # Disabling the uploader and input by not displaying them
+    image_uploader_placeholder = st.empty()  # Placeholder for the uploader
+    chat_input_placeholder = st.empty()  # Placeholder for the chat input
+else:
+    # Add a clear chat button
+    if st.button("Clear Chat"):
+        clear_chat()
+
+    # Image upload section.
+    image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], key="uploaded_image", on_change=on_image_upload)
+    if image_file:
+        # Display the uploaded image at a standard width.
+        st.image(image_file, caption='Uploaded Image.', width=200)
+        # Process the uploaded image to get a caption.
+        image_data = BytesIO(image_file.getvalue())
+        img_desc = get_image_caption(image_data)
+        st.write("Image Uploaded Successfully. Ask me anything about it.")
+
+        # Initialize the chat engine with the image description.
+        chat_engine = create_chat_engine(img_desc, os.environ["GOOGLE_API_KEY"])
+
+    # Initialize session state for messages if it doesn't exist
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    # Display previous messages
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # Handle new user input
+    user_input = st.chat_input("Ask me about the image:", key="chat_input")
+    if user_input:
         # Append user message to the session state
         st.session_state.messages.append({"role": "user", "content": user_input})
 
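Note: the reworked gating above reads a message_count cookie once per run and only renders the Clear Chat button, the uploader, and the chat input while the demo limit of 20 messages has not been reached. A minimal sketch of that pattern in isolation, assuming stx is extra_streamlit_components (as the stx.CookieManager() call implies); the widget label is illustrative.

import datetime
import streamlit as st
import extra_streamlit_components as stx

cookie_manager = stx.CookieManager()
MESSAGE_LIMIT = 20  # demo cap used in app.py

# The cookie stores a string; it is None on a first visit.
raw_count = cookie_manager.get(cookie="message_count")
message_count = int(raw_count) if raw_count is not None else 0

if message_count >= MESSAGE_LIMIT:
    st.error("Notice: The maximum message limit for this demo version has been reached.")
else:
    if st.chat_input("Ask me something:"):
        message_count += 1
        # Persist the new count for 30 days, matching the expiry in app.py.
        cookie_manager.set(
            "message_count",
            str(message_count),
            expires_at=datetime.datetime.now() + datetime.timedelta(days=30),
        )

Because the count lives in a browser cookie, it survives reruns and page reloads for up to 30 days, but it is per browser and can be cleared by the user.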
@@ -140,23 +135,31 @@ if user_input:
             st.markdown(user_input)
 
         # Call the chat engine to get the response if an image has been uploaded
-        if image_file:
-            # Get the response from your chat engine
-            response = chat_engine.chat(user_input)
-
-            # Append assistant message to the session state
-            st.session_state.messages.append({"role": "assistant", "content": response})
-
-            # Display the assistant message
-            with st.chat_message("assistant"):
-                st.markdown(response)
+        if image_file and user_input:
+            try:
+                with st.spinner('Waiting for the chat engine to respond...'):
+                    # Get the response from your chat engine
+                    response = chat_engine.chat(user_input)
+
+                    # Append assistant message to the session state
+                    st.session_state.messages.append({"role": "assistant", "content": response})
 
-        # Increment the message count and update the cookie
-        message_count += 1
-        cookie_manager.set('message_count', str(message_count), expires_at=datetime.datetime.now() + datetime.timedelta(days=30))
+                    # Display the assistant message
+                    with st.chat_message("assistant"):
+                        st.markdown(response)
+
+            except Exception as e:
+                st.error(f'An error occurred: {e}')
+                # Optionally, you can choose to break the flow here if a critical error happens
+                # return
+
+        # Increment the message count and update the cookie
+        message_count += 1
+        cookie_manager.set('message_count', str(message_count), expires_at=datetime.datetime.now() + datetime.timedelta(days=30))
+
 
 
 
 # Set Replicate and Google API keys
-#os.environ['REPLICATE_API_TOKEN'] = st.secrets['REPLICATE_API_TOKEN']
+os.environ['REPLICATE_API_TOKEN'] = st.secrets['REPLICATE_API_TOKEN']
 os.environ["GOOGLE_API_KEY"] = st.secrets['GOOGLE_API_KEY']
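Note: create_chat_engine(img_desc, ...) is called above but defined outside the changed hunks, so its body is not part of this diff. Going only by the imports at the top of app.py (PaLM, ServiceContext, VectorStoreIndex, Document, ChatMemoryBuffer), here is a hypothetical sketch of what such a helper could look like with the legacy llama_index API; the chat mode, memory size, and system prompt are illustrative assumptions, not the author's actual code.

from llama_index import Document, ServiceContext, VectorStoreIndex
from llama_index.llms.palm import PaLM
from llama_index.memory import ChatMemoryBuffer


def create_chat_engine(img_desc, api_key):
    # Hypothetical reconstruction: wrap the image description in a Document,
    # index it, and expose a context chat engine backed by Google's PaLM model.
    llm = PaLM(api_key=api_key)
    service_context = ServiceContext.from_defaults(llm=llm)
    index = VectorStoreIndex.from_documents(
        [Document(text=img_desc)], service_context=service_context
    )
    memory = ChatMemoryBuffer.from_defaults(token_limit=1500)  # assumed limit
    return index.as_chat_engine(
        chat_mode="context",
        memory=memory,
        system_prompt=(
            "You are a chatbot answering questions about the uploaded image, "
            f"which is described as: {img_desc}"
        ),
    )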