AI-ANK committed on
Commit
0af05ea
·
1 Parent(s): a91a7bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -53
app.py CHANGED
@@ -2,28 +2,35 @@ import streamlit as st
2
  import extra_streamlit_components as stx
3
  import requests
4
  from PIL import Image
5
- from transformers import AutoProcessor, AutoModelForVision2Seq
6
  from io import BytesIO
7
- import replicate
8
  from llama_index.llms.palm import PaLM
9
- from llama_index import ServiceContext, VectorStoreIndex, Document
10
  from llama_index.memory import ChatMemoryBuffer
11
  import os
12
  import datetime
13
 
 
 
 
 
 
14
  # Set up the title of the application
15
- #st.title("PaLM-Kosmos-Vision")
16
- st.set_page_config(layout="wide")
17
- st.write("My version of ChatGPT vision. You can upload an image and start chatting with the LLM about the image")
18
 
19
  # Sidebar
20
  st.sidebar.markdown('## Created By')
21
  st.sidebar.markdown("""
22
- [Harshad Suryawanshi](https://www.linkedin.com/in/harshadsuryawanshi/)
 
 
23
  """)
24
 
 
25
  st.sidebar.markdown('## Other Projects')
26
  st.sidebar.markdown("""
 
27
  - [AI Equity Research Analyst](https://ai-eqty-rsrch-anlyst.streamlit.app/)
28
  - [Recasting "The Office" Scene](https://blackmirroroffice.streamlit.app/)
29
  - [Story Generator](https://appstorycombined-agaf9j4ceit.streamlit.app/)
@@ -31,54 +38,103 @@ st.sidebar.markdown("""
31
 
32
  st.sidebar.markdown('## Disclaimer')
33
  st.sidebar.markdown("""
34
- This application is a conceptual prototype created to demonstrate the potential of Large Language Models (LLMs) in generating equity research reports. The contents generated by this application are purely illustrative and should not be construed as financial advice, endorsements, or recommendations. The author and the application do not provide any guarantee regarding the accuracy, completeness, or timeliness of the information provided.
35
  """)
36
 
37
  # Initialize the cookie manager
38
  cookie_manager = stx.CookieManager()
39
 
40
- # Function to get image caption via Kosmos2.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  @st.cache_data
42
  def get_image_caption(image_data):
43
- input_data = {
44
- "image": image_data,
45
- "description_type": "Brief"
46
- }
47
- output = replicate.run(
48
- "lucataco/kosmos-2:3e7b211c29c092f4bcc8853922cc986baa52efe255876b80cac2c2fbb4aff805",
49
- input=input_data
50
- )
51
- # Split the output string on the newline character and take the first item
52
- text_description = output.split('\n\n')[0]
53
- return text_description
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  # Function to create the chat engine.
56
  @st.cache_resource
57
  def create_chat_engine(img_desc, api_key):
58
- llm = PaLM(api_key=api_key)
59
- service_context = ServiceContext.from_defaults(llm=llm)
 
60
  doc = Document(text=img_desc)
61
- index = VectorStoreIndex.from_documents([doc], service_context=service_context)
62
- chatmemory = ChatMemoryBuffer.from_defaults(token_limit=1500)
 
63
 
64
  chat_engine = index.as_chat_engine(
65
- chat_mode="context",
66
  system_prompt=(
67
- f"You are a chatbot, able to have normal interactions, as well as talk. "
68
- "You always answer in great detail and are polite. Your responses always descriptive. "
69
- "Your job is to talk about an image the user has uploaded. Image description: {img_desc}."
 
 
 
 
70
  ),
71
  verbose=True,
72
  memory=chatmemory
73
  )
74
- return chat_engine
 
 
 
75
 
76
  # Clear chat function
77
  def clear_chat():
78
  if "messages" in st.session_state:
79
  del st.session_state.messages
80
- if "image_file" in st.session_state:
81
- del st.session_state.image_file
82
 
83
  # Callback function to clear the chat when a new image is uploaded
84
  def on_image_upload():
@@ -92,11 +148,13 @@ else:
92
  message_count = int(message_count)
93
 
94
  # If the message limit has been reached, disable the inputs
95
- if message_count >= 20:
 
96
  st.error("Notice: The maximum message limit for this demo version has been reached.")
97
  # Disabling the uploader and input by not displaying them
98
  image_uploader_placeholder = st.empty() # Placeholder for the uploader
99
  chat_input_placeholder = st.empty() # Placeholder for the chat input
 
100
  else:
101
  # Add a clear chat button
102
  if st.button("Clear Chat"):
@@ -104,16 +162,38 @@ else:
104
 
105
  # Image upload section.
106
  image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], key="uploaded_image", on_change=on_image_upload)
107
- if image_file:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # Display the uploaded image at a standard width.
109
- st.image(image_file, caption='Uploaded Image.', width=200)
 
 
110
  # Process the uploaded image to get a caption.
111
- image_data = BytesIO(image_file.getvalue())
112
- img_desc = get_image_caption(image_data)
113
- st.write("Image Uploaded Successfully. Ask me anything about it.")
 
 
 
 
114
 
115
  # Initialize the chat engine with the image description.
116
- chat_engine = create_chat_engine(img_desc, os.environ["GOOGLE_API_KEY"])
 
 
117
 
118
  # Initialize session state for messages if it doesn't exist
119
  if "messages" not in st.session_state:
@@ -121,8 +201,9 @@ else:
121
 
122
  # Display previous messages
123
  for message in st.session_state.messages:
124
- with st.chat_message(message["role"]):
125
- st.markdown(message["content"])
 
126
 
127
  # Handle new user input
128
  user_input = st.chat_input("Ask me about the image:", key="chat_input")
@@ -132,24 +213,27 @@ else:
132
 
133
  # Display user message immediately
134
  with st.chat_message("user"):
135
- st.markdown(user_input)
136
 
137
  # Call the chat engine to get the response if an image has been uploaded
138
- if image_file and user_input:
139
  try:
140
  with st.spinner('Waiting for the chat engine to respond...'):
141
  # Get the response from your chat engine
142
- response = chat_engine.chat(user_input)
143
-
 
 
 
144
  # Append assistant message to the session state
145
- st.session_state.messages.append({"role": "assistant", "content": response})
146
 
147
  # Display the assistant message
148
  with st.chat_message("assistant"):
149
- st.markdown(response)
150
 
151
  except Exception as e:
152
- st.error(f'An error occurred: {e}')
153
  # Optionally, you can choose to break the flow here if a critical error happens
154
  # return
155
 
@@ -157,9 +241,3 @@ else:
157
  message_count += 1
158
  cookie_manager.set('message_count', str(message_count), expires_at=datetime.datetime.now() + datetime.timedelta(days=30))
159
 
160
-
161
-
162
-
163
- # Set Replicate and Google API keys
164
- os.environ['REPLICATE_API_TOKEN'] = st.secrets['REPLICATE_API_TOKEN']
165
- os.environ["GOOGLE_API_KEY"] = st.secrets['GOOGLE_API_KEY']
 
2
  import extra_streamlit_components as stx
3
  import requests
4
  from PIL import Image
 
5
  from io import BytesIO
 
6
  from llama_index.llms.palm import PaLM
7
+ from llama_index import ServiceContext, VectorStoreIndex, Document, StorageContext, load_index_from_storage
8
  from llama_index.memory import ChatMemoryBuffer
9
  import os
10
  import datetime
11
 
12
+ #imports for resnet
13
+ from transformers import AutoFeatureExtractor, ResNetForImageClassification
14
+ import torch
15
+ from io import BytesIO
16
+
17
  # Set up the title of the application
18
+ st.title("AInimal Go!")
19
+ #st.set_page_config(layout="wide")
20
+ st.write("My Pokemon Go inspired 'AInimal Go!' app. You can upload an image or snap a picture of an animal and start chatting with it")
21
 
22
  # Sidebar
23
  st.sidebar.markdown('## Created By')
24
  st.sidebar.markdown("""
25
+ Harshad Suryawanshi
26
+ - [Linkedin](https://www.linkedin.com/in/harshadsuryawanshi/)
27
+ - [Medium](https://harshadsuryawanshi.medium.com/)
28
  """)
29
 
30
+
31
  st.sidebar.markdown('## Other Projects')
32
  st.sidebar.markdown("""
33
+ - [Building My Own GPT4-V with PaLM and Kosmos](https://lnkd.in/dawgKZBP)
34
  - [AI Equity Research Analyst](https://ai-eqty-rsrch-anlyst.streamlit.app/)
35
  - [Recasting "The Office" Scene](https://blackmirroroffice.streamlit.app/)
36
  - [Story Generator](https://appstorycombined-agaf9j4ceit.streamlit.app/)
 
38
 
39
  st.sidebar.markdown('## Disclaimer')
40
  st.sidebar.markdown("""
41
+ This application, titled 'AInimal Go!', is a conceptual prototype designed to demonstrate the innovative use of Large Language Models (LLMs) in enabling interactive conversations with animals through images. While the concept is vaguely inspired by the interactive and augmented reality elements popularized by games like Pokemon Go, it does not use any assets, characters, or intellectual property from the Pokemon franchise. The interactions and conversations generated by this application are entirely fictional and created for entertainment and educational purposes. They should not be regarded as factual or accurate representations of animal behavior or communication. The author and the application do not hold any affiliation with the Pokemon brand or its creators, and no endorsement from them is implied. Users are encouraged to use this application responsibly and with an understanding of its purely illustrative nature.
42
  """)
43
 
44
  # Initialize the cookie manager
45
  cookie_manager = stx.CookieManager()
46
 
47
# Function to initialize the ResNet classifier and the animal-label lookup.
@st.cache_resource(show_spinner="Initializing ResNet model for image classification. Please wait...")
def load_model_and_labels():
    """Load the ResNet-18 image classifier and the animal-label subset.

    Returns:
        tuple: (feature_extractor, model, animal_labels_dict), where
        animal_labels_dict maps ImageNet class id (int) -> label name (str).
    """
    # Load animal labels as a dictionary; each data line looks like "12: 'label'".
    animal_labels_dict = {}
    with open('imagenet_animal_labels_subset.txt', 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                # Guard: a blank/trailing line would otherwise crash int('').
                continue
            # partition (not split) keeps label text intact even if it
            # contains a colon.
            class_id_part, _, label_part = line.partition(':')
            class_id = int(class_id_part.strip())
            animal_labels_dict[class_id] = label_part.strip().strip("'")

    # Initialize feature extractor and model. st.cache_resource keeps a single
    # instance alive for the whole session, so the download happens once.
    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18")
    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-18")

    return feature_extractor, model, animal_labels_dict

feature_extractor, model, animal_labels_dict = load_model_and_labels()
67
+
68
# Function to predict the image's ImageNet label with ResNet-18.
@st.cache_data
def get_image_caption(image_data):
    """Classify the uploaded image and return its predicted label.

    Args:
        image_data: Binary file-like object (e.g. BytesIO) holding the image.

    Returns:
        tuple: (predicted_label_name, predicted_label_id) taken from the
        model's ImageNet id2label mapping.
    """
    image = Image.open(image_data)
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only - skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_label_id = logits.argmax(-1).item()
    predicted_label_name = model.config.id2label[predicted_label_id]
    # NOTE: a debug st.write(predicted_label_name) was removed here. Streamlit
    # does not replay st.* calls made inside an @st.cache_data function on
    # cache hits, so the label only flashed on the first (uncached) call.
    return predicted_label_name, predicted_label_id
82
+
83
+
84
# Build the LLM and retrieval machinery once per session.
@st.cache_resource(show_spinner="Initializing LLM and setting up service context. Please wait...")
def init_llm(api_key):
    """Create the PaLM LLM, reload the persisted index, and set up chat memory.

    Returns:
        tuple: (llm, service_context, storage_context, index, chatmemory),
        unpacked into module-level names below.
    """
    palm_llm = PaLM(api_key=api_key)
    context = ServiceContext.from_defaults(llm=palm_llm, embed_model="local")

    # Conversation memory is capped so the prompt stays within budget.
    memory = ChatMemoryBuffer.from_defaults(token_limit=1500)

    # Rehydrate the pre-built index persisted in the local "storage" directory.
    storage = StorageContext.from_defaults(persist_dir="storage")
    loaded_index = load_index_from_storage(storage, index_id="index", service_context=context)

    return palm_llm, context, storage, loaded_index, memory

llm, service_context, storage_context, index, chatmemory = init_llm(st.secrets['GOOGLE_API_KEY'])
96
+
97
def is_animal(predicted_label_id):
    """Check whether the predicted label ID lies in the animal classes range (0-398)."""
    first_animal_id, last_animal_id = 0, 398
    return first_animal_id <= predicted_label_id <= last_animal_id
100
+
101
 
102
# Function to create the chat engine for the detected animal.
@st.cache_resource
def create_chat_engine(img_desc, api_key):
    """Build a cached chat engine that roleplays as the detected animal.

    Args:
        img_desc: Predicted animal label, interpolated into the system prompt.
        api_key: Unused here (the LLM is initialized once in init_llm); kept
            so existing call sites remain valid and the cache keys on it.

    Returns:
        A llama_index chat engine backed by the session-wide index and memory.
    """
    # Removed dead code: an unused Document(text=img_desc) and a debug print
    # that output the is_animal *function object* instead of a boolean.
    chat_engine = index.as_chat_engine(
        chat_mode="react",
        system_prompt=(
            f"""You are a chatbot, able to have normal interactions, as well as talk.
            You always answer in great detail and are polite. Your responses always descriptive.
            Your job is to roleplay as the animal that is mentioned in the image the user has uploaded. Image description: {img_desc}."""
        ),
        verbose=True,
        memory=chatmemory
    )

    return chat_engine
129
+
130
+
131
 
132
# Clear chat function
def clear_chat():
    """Drop the stored chat history and cached image from session state."""
    for key in ("messages", "image_data"):
        if key in st.session_state:
            del st.session_state[key]
138
 
139
  # Callback function to clear the chat when a new image is uploaded
140
  def on_image_upload():
 
148
  message_count = int(message_count)
149
 
150
  # If the message limit has been reached, disable the inputs
151
+ #if message_count <= 20:
152
+ if 0:
153
  st.error("Notice: The maximum message limit for this demo version has been reached.")
154
  # Disabling the uploader and input by not displaying them
155
  image_uploader_placeholder = st.empty() # Placeholder for the uploader
156
  chat_input_placeholder = st.empty() # Placeholder for the chat input
157
+ st.stop()
158
  else:
159
  # Add a clear chat button
160
  if st.button("Clear Chat"):
 
162
 
163
  # Image upload section.
164
  image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], key="uploaded_image", on_change=on_image_upload)
165
+
166
+ col1, col2, col3 = st.columns([1, 2, 1])
167
+ with col2: # Camera input will be in the middle column
168
+ camera_image = st.camera_input("Take a picture", on_change=on_image_upload)
169
+
170
+
171
+ # Determine the source of the image (upload or camera)
172
+ if image_file is not None:
173
+ image_data = BytesIO(image_file.getvalue())
174
+ elif camera_image is not None:
175
+ image_data = BytesIO(camera_image.getvalue())
176
+ else:
177
+ image_data = None
178
+
179
+ if image_data:
180
  # Display the uploaded image at a standard width.
181
+ st.session_state['assistant_avatar'] = image_data
182
+ st.image(image_data, caption='Uploaded Image.', width=200)
183
+
184
  # Process the uploaded image to get a caption.
185
+ #img_desc = get_image_caption(image_data)
186
+ img_desc, label_id = get_image_caption(image_data)
187
+
188
+ if not (is_animal(label_id)):
189
+ #st.error("Please upload image of an animal!")
190
+ st.error("Please upload image of an animal!")
191
+ st.stop()
192
 
193
  # Initialize the chat engine with the image description.
194
+ chat_engine = create_chat_engine(img_desc, st.secrets['GOOGLE_API_KEY'])
195
+ st.write("Image Uploaded Successfully. Ask me anything about it.")
196
+
197
 
198
  # Initialize session state for messages if it doesn't exist
199
  if "messages" not in st.session_state:
 
201
 
202
  # Display previous messages
203
  for message in st.session_state.messages:
204
+ avatar = st.session_state['assistant_avatar'] if message["role"] == "assistant" else None
205
+ with st.chat_message(message["role"], avatar = avatar):
206
+ st.write(message["content"])
207
 
208
  # Handle new user input
209
  user_input = st.chat_input("Ask me about the image:", key="chat_input")
 
213
 
214
  # Display user message immediately
215
  with st.chat_message("user"):
216
+ st.write(user_input)
217
 
218
  # Call the chat engine to get the response if an image has been uploaded
219
+ if image_data and user_input:
220
  try:
221
  with st.spinner('Waiting for the chat engine to respond...'):
222
  # Get the response from your chat engine
223
+ response = chat_engine.chat(f"""You are a chatbot that roleplays as an animal and also makes animal sounds when chatting.
224
+ You always answer in great detail and are polite. Your responses always descriptive.
225
+ Your job is to rolelpay as the animal that is mentioned in the image the user has uploaded. Image description: {img_desc}. User question
226
+ {user_input}""")
227
+
228
  # Append assistant message to the session state
229
+ st.session_state.messages.append({"role": "assistant", "content": response.response})
230
 
231
  # Display the assistant message
232
  with st.chat_message("assistant"):
233
+ st.write(response.response)
234
 
235
  except Exception as e:
236
+ st.error(f'An error occurred.')
237
  # Optionally, you can choose to break the flow here if a critical error happens
238
  # return
239
 
 
241
  message_count += 1
242
  cookie_manager.set('message_count', str(message_count), expires_at=datetime.datetime.now() + datetime.timedelta(days=30))
243