ziyadsuper2017 committed
Commit 6ae7b4c · 1 Parent(s): 5837eff

Update app.py

Files changed (1): app.py (+28 -24)
app.py CHANGED
@@ -18,7 +18,7 @@ generation_config = genai.GenerationConfig(
 
 safety_settings = []
 
-# Initialize session state for chat history and file uploader key
+# Initialize session state for chat history
 if 'chat_history' not in st.session_state:
     st.session_state['chat_history'] = []
 if 'file_uploader_key' not in st.session_state:
@@ -39,35 +39,31 @@ def get_image_base64(image):
 def send_message():
     user_input = st.session_state.user_input
     uploaded_files = st.session_state.uploaded_files
-
-    text_prompts = []
-    image_prompts = []
-
-    # Process text input for multi-turn conversation
+
     if user_input:
-        text_prompts.append({"role": "user", "parts": [{"text": user_input}]})
-        st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
+        # Send text to the gemini-pro model for text-based conversation
+        text_prompt = {"role": "user", "parts": [{"text": user_input}]}
+        st.session_state['chat_history'].append(text_prompt)
 
-    # Process uploaded images for single-turn conversation
-    if uploaded_files:
-        for uploaded_file in uploaded_files:
-            image = Image.open(uploaded_file).convert("RGB") # Ensure image is in RGB
-            image_base64 = get_image_base64(image)
-            image_prompts.append({"role": "user", "parts": [{"mime_type": uploaded_file.type, "data": image_base64}]})
-
-    # Generate text response if text input is provided
-    if text_prompts:
         model = genai.GenerativeModel(
             model_name='gemini-pro',
             generation_config=generation_config,
             safety_settings=safety_settings
         )
-        response = model.generate_content(st.session_state['chat_history'])
+        response = model.generate_content([text_prompt])
         response_text = response.text if hasattr(response, "text") else "No response text found."
         st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
-
-    # Generate image response if images are uploaded
-    if image_prompts:
+
+    if uploaded_files:
+        # Send images to the gemini-pro-vision model for image-based conversation
+        image_prompts = []
+        for uploaded_file in uploaded_files:
+            image = Image.open(uploaded_file).convert("RGB") # Ensure image is in RGB
+            image_base64 = get_image_base64(image)
+            image_prompt = {"role": "user", "parts": [{"mime_type": uploaded_file.type, "data": image_base64}]}
+            image_prompts.append(image_prompt)
+            st.session_state['chat_history'].append(image_prompt)
+
         model = genai.GenerativeModel(
             model_name='gemini-pro-vision',
             generation_config=generation_config,
@@ -75,16 +71,18 @@ def send_message():
         )
         response = model.generate_content(image_prompts)
         response_text = response.text if hasattr(response, "text") else "No response text found."
-        for prompt in image_prompts:
-            st.session_state['chat_history'].append(prompt) # Append images to history
         st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
 
     # Clear the user input and generate a new key for the file uploader widget to reset it
     st.session_state.user_input = ''
+    st.session_state.uploaded_files = None
     st.session_state.file_uploader_key = str(uuid.uuid4())
 
+def clear_conversation():
+    st.session_state.chat_history = []
+
 # Multiline text input for the user to send messages
-user_input = st.text_area("Enter your message here:", key="user_input", value="")
+user_input = st.text_area("Enter your message here:", key="user_input")
 
 # File uploader for images
 uploaded_files = st.file_uploader(
@@ -97,6 +95,9 @@ uploaded_files = st.file_uploader(
 # Button to send the message
 send_button = st.button("Send", on_click=send_message)
 
+# Button to clear the conversation
+clear_button = st.button("Clear Conversation", on_click=clear_conversation)
+
 # Display the chat history
 for entry in st.session_state['chat_history']:
     role = entry["role"]
@@ -104,7 +105,10 @@ for entry in st.session_state['chat_history']:
     if 'text' in parts:
         st.markdown(f"**{role.title()}**: {parts['text']}")
     elif 'data' in parts:
+        # Decode the base64 image data
+        base64_data = parts['data']
        st.markdown(f"**{role.title()}**: (Image)")
+        st.image(base64_data, caption=f"Image from {role}", use_column_width=True)
 
 # Ensure the file_uploader widget state is tied to the randomly generated key
 st.session_state.uploaded_files = uploaded_files
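
Note: the changed hunks call a get_image_base64 helper that is defined outside this diff; only its signature, def get_image_base64(image):, appears in a hunk header. As a reference point, here is a minimal sketch of what such a helper plausibly looks like, assuming the PIL image is serialized as JPEG and returned as a plain base64 string (an assumption based on how the "data" field is filled above, not the committed implementation):

import base64
import io

from PIL import Image

def get_image_base64(image: Image.Image) -> str:
    # Serialize the PIL image to JPEG bytes in memory,
    # then base64-encode those bytes into a UTF-8 string.
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")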
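
In the updated history loop, the base64 string stored in parts['data'] is passed to st.image. st.image also accepts raw image bytes or a PIL image directly, so a sketch of rendering a stored entry by first decoding the string back to bytes could look like the following (the render_image_entry helper name is hypothetical, not part of the commit):

import base64

import streamlit as st

def render_image_entry(parts: dict, role: str) -> None:
    # parts is one chat-history "parts" dict holding a base64-encoded "data" field.
    image_bytes = base64.b64decode(parts["data"])
    # st.image accepts raw bytes, so no temporary file or PIL round-trip is needed.
    st.image(image_bytes, caption=f"Image from {role}", use_column_width=True)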