JeCabrera committed on
Commit
bc54a0a
·
verified ·
1 Parent(s): 8daaad5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -48
app.py CHANGED
@@ -1,58 +1,211 @@
1
- import gradio as gr
2
- import google.generativeai as genai
3
  import os
4
- from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Cargar variables de entorno
7
- load_dotenv()
 
8
 
9
- # Configurar la API de Google Gemini
10
- genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
11
 
12
- def respond(
13
- message,
14
- history,
15
- system_message,
16
- max_tokens,
17
- temperature,
18
- top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  ):
20
- messages = [{"role": "system", "content": system_message}]
21
-
22
- for val in history:
23
- if val[0]:
24
- messages.append({"role": "user", "content": val[0]})
25
- if val[1]:
26
- messages.append({"role": "assistant", "content": val[1]})
27
-
28
- messages.append({"role": "user", "content": message})
29
-
30
- # Configurar el modelo de Gemini
31
- model_name = "gemini-1.5-pro" # Ajusta el modelo según sea necesario
32
- generation_config = {
33
- "temperature": temperature,
34
- "top_p": top_p,
35
- "max_output_tokens": max_tokens,
36
- "response_mime_type": "text/plain",
37
- }
38
- model = genai.GenerativeModel(model_name=model_name, generation_config=generation_config)
39
- chat_session = model.start_chat(
40
- history=messages
41
  )
42
- response = chat_session.send_message(message)
43
- return response.text
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Crear la interfaz de Gradio
47
- demo = gr.ChatInterface(
48
- respond,
49
- additional_inputs=[
50
- gr.Textbox(value="You are a helpful assistant.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
54
- ],
 
 
 
 
 
 
 
55
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- if __name__ == "__main__":
58
- demo.launch()
 
 
 
1
  import os
2
+ import time
3
+ import uuid
4
+ from typing import List, Tuple, Optional, Dict, Union
5
+
6
+ import google.generativeai as genai
7
+ import gradio as gr
8
+ from PIL import Image
9
+
10
# Log the installed SDK version at startup so deployment issues are easier to trace.
print(f"google-generativeai: {genai.__version__}")

# Server-side API key. When it is unset, the key textbox in the UI is shown
# and `bot` falls back to whatever the user typed there.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Static HTML fragments rendered at the top of the page.
TITLE = '<h1 align="center">Gemini Playground 💬</h1>'
SUBTITLE = '<h2 align="center">Play with Gemini Pro and Gemini Pro Vision</h2>'
DES = """
<div style="text-align: center; display: flex; justify-content: center; align-items: center;">
<span>Run with your
<a href="https://makersuite.google.com/app/apikey">GOOGLE API KEY</a>.
</span>
</div>
"""

# Directory where uploaded images are re-saved as JPEG files.
IMAGE_CACHE_DIRECTORY = "/tmp"
# Target width, in pixels, used by `preprocess_image`.
IMAGE_WIDTH = 512

# One chat turn is (user_content, bot_reply). `user_content` is either the
# prompt text or a 1-tuple holding the path of an uploaded image; either side
# may be None.
CHAT_HISTORY = List[
    Tuple[
        Optional[Union[Tuple[str], str]],
        Optional[str],
    ]
]
27
 
28
def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
    """Parse a comma-separated string into a list of stop sequences.

    Each entry is stripped of surrounding whitespace. Entries that are empty
    after stripping (from doubled or trailing commas, or whitespace-only
    input) are dropped so that no empty stop sequence is produced.

    Args:
        stop_sequences: Raw text such as ``"STOP, END"``; may be empty.

    Returns:
        The cleaned list of stop sequences, or ``None`` when the input is
        empty or contains no non-blank entry.
    """
    if not stop_sequences:
        return None
    stripped = (part.strip() for part in stop_sequences.split(","))
    cleaned = [part for part in stripped if part]
    # Collapse "nothing usable" to None, the same value returned for empty input.
    return cleaned or None
30
 
31
def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
    """Resize an image to ``IMAGE_WIDTH`` pixels wide, keeping its aspect ratio.

    Args:
        image: The image to scale, or ``None``.

    Returns:
        A resized copy of ``image``, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None
    # Clamp to at least 1 px: for extremely wide images the scaled height
    # truncates to 0, which `resize` does not accept.
    image_height = max(1, int(image.height * IMAGE_WIDTH / image.width))
    return image.resize((IMAGE_WIDTH, image_height))
35
+
36
def cache_pil_image(image: Image.Image) -> str:
    """Save an image as a uniquely named JPEG in ``IMAGE_CACHE_DIRECTORY``.

    The directory is created if it does not exist. The file name is a random
    UUID with a ``.jpeg`` extension.

    Args:
        image: The image to store.

    Returns:
        The path of the written JPEG file.
    """
    os.makedirs(IMAGE_CACHE_DIRECTORY, exist_ok=True)
    image_path = os.path.join(IMAGE_CACHE_DIRECTORY, f"{uuid.uuid4()}.jpeg")
    # JPEG cannot store alpha or palette data; convert such images first so
    # that saving does not fail for callers that did not pre-convert to RGB.
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")
    image.save(image_path, "JPEG")
    return image_path
42
+
43
def upload(files: Optional[List[str]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
    """Add each uploaded image to the chat history as its own user turn.

    Every file is opened, converted to RGB, re-saved through
    ``cache_pil_image`` and appended to ``chatbot`` as
    ``((image_path,), None)``.

    The previous version also built a ``gr.Image`` preview here and called
    ``.render()`` on it. A component created inside an event handler is not
    part of the page layout, so that work is dropped.

    Args:
        files: Uploaded files from the upload button, or ``None``.
        chatbot: Current chat history; it is mutated in place.

    Returns:
        The same ``chatbot`` list with one new turn per uploaded file.
    """
    # `files` is Optional: treat a missing upload as "nothing to add".
    for file in files or []:
        # The context manager closes the source file once the RGB copy exists.
        with Image.open(file) as source:
            image = source.convert('RGB')
        image_path = cache_pil_image(image)
        chatbot.append(((image_path,), None))
    return chatbot
53
+
54
def user(text_prompt: str, chatbot: CHAT_HISTORY):
    """Record the submitted prompt as a new, unanswered chat turn.

    Args:
        text_prompt: Text taken from the prompt textbox.
        chatbot: Current chat history; it is mutated in place.

    Returns:
        A pair ``("", chatbot)``: the empty string clears the textbox and the
        history carries the new ``(text_prompt, None)`` turn when the prompt
        was non-empty.
    """
    if not text_prompt:
        # Empty submission: leave the history untouched.
        return "", chatbot
    pending_turn = (text_prompt, None)
    chatbot.append(pending_turn)
    return "", chatbot
58
+
59
def bot(
    google_key: str,
    files: Optional[List[str]],
    temperature: float,
    max_output_tokens: int,
    stop_sequences: str,
    top_k: int,
    top_p: float,
    chatbot: CHAT_HISTORY
):
    """Stream a Gemini reply into the last turn of the chat history.

    The prompt is the text of the last turn (when it is a string) followed by
    every uploaded image, each resized with ``preprocess_image``.
    ``gemini-pro-vision`` is used when files are present, ``gemini-pro``
    otherwise.

    Args:
        google_key: API key typed by the user; falls back to the
            module-level ``GOOGLE_API_KEY`` when empty.
        files: Uploaded image files, or ``None``.
        temperature: Sampling temperature passed to the generation config.
        max_output_tokens: Token limit passed to the generation config.
        stop_sequences: Comma-separated stop sequences (see
            ``preprocess_stop_sequences``).
        top_k: Top-K value passed to the generation config.
        top_p: Top-P value passed to the generation config.
        chatbot: Current chat history; its last turn receives the reply.

    Yields:
        The chat history after each slice of at most 10 characters is
        appended to the reply, which gives a typing effect. When there is
        nothing to send, the unchanged history is yielded once.

    Raises:
        ValueError: If neither ``google_key`` nor ``GOOGLE_API_KEY`` is set.
    """
    api_key = google_key or GOOGLE_API_KEY
    if not api_key:
        raise ValueError("GOOGLE_API_KEY is not set.")

    genai.configure(api_key=api_key)
    generation_config = genai.types.GenerationConfig(
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
        top_k=top_k,
        top_p=top_p
    )

    last_content = chatbot[-1][0] if chatbot else None
    text_prompt = [last_content] if last_content and isinstance(last_content, str) else []

    image_prompt = []
    for file in files or []:
        # Close each source file as soon as its resized RGB copy is built.
        with Image.open(file) as source:
            image_prompt.append(preprocess_image(source.convert('RGB')))

    # Nothing to answer (empty history, or no text and no images): skip the
    # API call instead of sending an empty prompt or indexing chatbot[-1].
    if not chatbot or not (text_prompt or image_prompt):
        yield chatbot
        return

    model_name = 'gemini-pro-vision' if files else 'gemini-pro'
    model = genai.GenerativeModel(model_name)
    response = model.generate_content(text_prompt + image_prompt, stream=True, generation_config=generation_config)

    # Rebuild the last turn as a list so the reply slot can be written even
    # when the turn was stored as an immutable tuple.
    chatbot[-1] = [chatbot[-1][0], ""]
    for chunk in response:
        chunk_text = chunk.text
        for start in range(0, len(chunk_text), 10):
            chatbot[-1][1] += chunk_text[start:start + 10]
            # Short pause between slices to produce the typing effect.
            time.sleep(0.01)
            yield chatbot
94
 
95
# UI components. They are created up front and placed into the page layout
# later through their .render() calls.

# API-key field; hidden when a server-side GOOGLE_API_KEY is configured.
google_key_component = gr.Textbox(
    label="GOOGLE API KEY", value="", type="password", placeholder="...",
    info="Please provide your own GOOGLE_API_KEY for this app",
    visible=GOOGLE_API_KEY is None,
)

# Conversation view.
chatbot_component = gr.Chatbot(label='Gemini', bubble_full_width=False, scale=2, height=600)

# Prompt row: text input, image upload and submit button.
text_prompt_component = gr.Textbox(
    placeholder="Message...",
    show_label=False,
    autofocus=True,
    scale=8,
)
upload_button_component = gr.UploadButton(
    label="Upload Images",
    file_count="multiple",
    file_types=["image"],
    scale=1,
)
run_button_component = gr.Button(value="Run", variant="primary", scale=1)

# Generation parameters.
temperature_component = gr.Slider(minimum=0, maximum=1.0, value=0.4, step=0.05, label="Temperature")
max_output_tokens_component = gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Token limit")
stop_sequences_component = gr.Textbox(label="Add stop sequence", value="", type="text", placeholder="STOP, END")
top_k_component = gr.Slider(minimum=1, maximum=40, value=32, step=1, label="Top-K")
top_p_component = gr.Slider(minimum=0, maximum=1, value=1, step=0.01, label="Top-P")

# Inputs for `user`, in the order of its parameters.
user_inputs = [text_prompt_component, chatbot_component]

# Inputs for `bot`, in the order of its parameters.
bot_inputs = [
    google_key_component,
    upload_button_component,
    temperature_component,
    max_output_tokens_component,
    stop_sequences_component,
    top_k_component,
    top_p_component,
    chatbot_component,
]
166
+
167
# Page layout and event wiring.
# NOTE(review): container nesting below is reconstructed; the original
# indentation was not preserved in the source being edited. Confirm.
with gr.Blocks() as demo:
    for header_html in (TITLE, SUBTITLE, DES):
        gr.HTML(header_html)

    with gr.Column():
        google_key_component.render()
        chatbot_component.render()
        with gr.Row():
            text_prompt_component.render()
            upload_button_component.render()
            run_button_component.render()
        with gr.Accordion("Parameters", open=False):
            temperature_component.render()
            max_output_tokens_component.render()
            stop_sequences_component.render()
            with gr.Accordion("Advanced", open=False):
                top_k_component.render()
                top_p_component.render()

    # Clicking "Run" and pressing Enter in the prompt box do the same thing:
    # record the prompt with `user`, then stream the reply with `bot`.
    for register_trigger in (run_button_component.click, text_prompt_component.submit):
        register_trigger(
            fn=user,
            inputs=user_inputs,
            outputs=[text_prompt_component, chatbot_component],
            queue=False,
        ).then(
            fn=bot,
            inputs=bot_inputs,
            outputs=[chatbot_component],
        )

    # Uploading images adds them to the chat history right away.
    upload_button_component.upload(
        fn=upload,
        inputs=[upload_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

demo.queue(max_size=99).launch(debug=False, show_error=True)