ziyadsuper2017 commited on
Commit
5837eff
·
1 Parent(s): eabca2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -23
app.py CHANGED
@@ -40,37 +40,48 @@ def send_message():
40
  user_input = st.session_state.user_input
41
  uploaded_files = st.session_state.uploaded_files
42
 
43
- if user_input or uploaded_files:
44
- # Save user input to the chat history
45
- if user_input:
46
- st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
47
-
48
- # Process uploaded images
49
- image_prompts = []
50
- if uploaded_files:
51
- for uploaded_file in uploaded_files:
52
- image = Image.open(uploaded_file).convert("RGB") # Ensure image is in RGB
53
- image_base64 = get_image_base64(image)
54
- image_prompt = {"role": "user", "parts": [{"mime_type": uploaded_file.type, "data": image_base64}]}
55
- image_prompts.append(image_prompt)
56
- st.session_state['chat_history'].extend(image_prompts)
57
-
58
- # Choose the appropriate model based on the input type
59
- model_name = 'gemini-pro-vision' if uploaded_files else 'gemini-pro'
60
  model = genai.GenerativeModel(
61
- model_name=model_name,
62
  generation_config=generation_config,
63
  safety_settings=safety_settings
64
  )
65
-
66
- # Generate the response
67
  response = model.generate_content(st.session_state['chat_history'])
68
  response_text = response.text if hasattr(response, "text") else "No response text found."
69
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- # Clear the user input and generate a new key for the file uploader widget to reset it
72
- st.session_state.user_input = ''
73
- st.session_state.file_uploader_key = str(uuid.uuid4())
74
 
75
  # Multiline text input for the user to send messages
76
  user_input = st.text_area("Enter your message here:", key="user_input", value="")
 
40
  user_input = st.session_state.user_input
41
  uploaded_files = st.session_state.uploaded_files
42
 
43
+ text_prompts = []
44
+ image_prompts = []
45
+
46
+ # Process text input for multi-turn conversation
47
+ if user_input:
48
+ text_prompts.append({"role": "user", "parts": [{"text": user_input}]})
49
+ st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
50
+
51
+ # Process uploaded images for single-turn conversation
52
+ if uploaded_files:
53
+ for uploaded_file in uploaded_files:
54
+ image = Image.open(uploaded_file).convert("RGB") # Ensure image is in RGB
55
+ image_base64 = get_image_base64(image)
56
+ image_prompts.append({"role": "user", "parts": [{"mime_type": uploaded_file.type, "data": image_base64}]})
57
+
58
+ # Generate text response if text input is provided
59
+ if text_prompts:
60
  model = genai.GenerativeModel(
61
+ model_name='gemini-pro',
62
  generation_config=generation_config,
63
  safety_settings=safety_settings
64
  )
 
 
65
  response = model.generate_content(st.session_state['chat_history'])
66
  response_text = response.text if hasattr(response, "text") else "No response text found."
67
  st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
68
+
69
+ # Generate image response if images are uploaded
70
+ if image_prompts:
71
+ model = genai.GenerativeModel(
72
+ model_name='gemini-pro-vision',
73
+ generation_config=generation_config,
74
+ safety_settings=safety_settings
75
+ )
76
+ response = model.generate_content(image_prompts)
77
+ response_text = response.text if hasattr(response, "text") else "No response text found."
78
+ for prompt in image_prompts:
79
+ st.session_state['chat_history'].append(prompt) # Append images to history
80
+ st.session_state['chat_history'].append({"role": "model", "parts": [{"text": response_text}]})
81
 
82
+ # Clear the user input and generate a new key for the file uploader widget to reset it
83
+ st.session_state.user_input = ''
84
+ st.session_state.file_uploader_key = str(uuid.uuid4())
85
 
86
  # Multiline text input for the user to send messages
87
  user_input = st.text_area("Enter your message here:", key="user_input", value="")