Kims12 committed on
Commit
3e624fe
·
verified ·
1 Parent(s): 5dfaa58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -64
app.py CHANGED
@@ -4,7 +4,6 @@ from PIL import Image
4
  import gradio as gr
5
  import logging
6
  import re
7
- import io
8
 
9
  from google import genai
10
  from google.genai import types
@@ -25,6 +24,7 @@ def preprocess_prompt(prompt, image1, image2, image3):
25
  """
26
  ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  ๊ธฐ๋Šฅ ๋ช…๋ น์„ ํ•ด์„
27
  """
 
28
  # ์ด๋ฏธ์ง€ ์—†๋Š” ์ฐธ์กฐ ํ™•์ธ ๋ฐ ์ฒ˜๋ฆฌ
29
  has_img1 = image1 is not None
30
  has_img2 = image2 is not None
@@ -96,9 +96,9 @@ def preprocess_prompt(prompt, image1, image2, image3):
96
 
97
  return prompt
98
 
99
- def process_images_with_prompt(image1, image2, image3, prompt):
100
  """
101
- 3๊ฐœ์˜ ์ด๋ฏธ์ง€์™€ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜
102
  """
103
  try:
104
  # API ํ‚ค ํ™•์ธ
@@ -109,86 +109,91 @@ def process_images_with_prompt(image1, image2, image3, prompt):
109
  # Gemini ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™”
110
  client = genai.Client(api_key=api_key)
111
 
112
- # ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜ ํ™•์ธ
113
- images = [image1, image2, image3]
114
- valid_images = [img for img in images if img is not None]
115
 
116
- if not valid_images:
117
- return None, "์ ์–ด๋„ ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."
 
118
 
119
- # ํ”„๋กฌํ”„ํŠธ ์ฒ˜๋ฆฌ
120
- if not prompt or not prompt.strip():
121
- # ํ”„๋กฌํ”„ํŠธ๊ฐ€ ์—†์œผ๋ฉด ์—…๋กœ๋“œ๋œ ์ด๋ฏธ์ง€ ์ˆ˜์— ๋”ฐ๋ผ ์ž๋™ ํ•ฉ์„ฑ ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ
122
- if len(valid_images) == 1:
123
- prompt = "์ด ์ด๋ฏธ์ง€๋ฅผ ์ฐฝ์˜์ ์œผ๋กœ ๋ณ€ํ˜•ํ•ด์ฃผ์„ธ์š”. ๋” ์ƒ์ƒํ•˜๊ณ  ์˜ˆ์ˆ ์ ์ธ ๋ฒ„์ „์œผ๋กœ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”."
124
- logger.info("๋‹จ์ผ ์ด๋ฏธ์ง€ ํ”„๋กฌํ”„ํŠธ ์ž๋™ ์ƒ์„ฑ")
125
- elif len(valid_images) == 2:
126
- prompt = "์ด ๋‘ ์ด๋ฏธ์ง€๋ฅผ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํ•ฉ์„ฑํ•ด์ฃผ์„ธ์š”. ๋‘ ์ด๋ฏธ์ง€์˜ ์š”์†Œ๋ฅผ ์กฐํ™”๋กญ๊ฒŒ ํ†ตํ•ฉํ•˜์—ฌ ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€๋กœ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”."
127
- logger.info("๋‘ ์ด๋ฏธ์ง€ ํ•ฉ์„ฑ ํ”„๋กฌํ”„ํŠธ ์ž๋™ ์ƒ์„ฑ")
128
- else:
129
- prompt = "์ด ์„ธ ์ด๋ฏธ์ง€๋ฅผ ์ฐฝ์˜์ ์œผ๋กœ ํ•ฉ์„ฑํ•ด์ฃผ์„ธ์š”. ๋ชจ๋“  ์ด๋ฏธ์ง€์˜ ์ฃผ์š” ์š”์†Œ๋ฅผ ํฌํ•จํ•˜๋˜ ์ž์—ฐ์Šค๋Ÿฝ๊ณ  ์ผ๊ด€๋œ ํ•˜๋‚˜์˜ ์žฅ๋ฉด์œผ๋กœ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”."
130
- logger.info("์„ธ ์ด๋ฏธ์ง€ ํ•ฉ์„ฑ ํ”„๋กฌํ”„ํŠธ ์ž๋™ ์ƒ์„ฑ")
131
- else:
132
- # ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ ๋ฐ ๊ธฐ๋Šฅ ๋ช…๋ น ํ•ด์„
133
- prompt = preprocess_prompt(prompt, image1, image2, image3)
134
 
135
- # ์ปจํ…์ธ  ๋ฆฌ์ŠคํŠธ ์ƒ์„ฑ (์ด๋ฏธ์ง€์™€ ํ”„๋กฌํ”„ํŠธ ๊ฒฐํ•ฉ)
136
  parts = []
137
 
 
 
 
 
 
 
 
 
 
138
  # ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ ์ถ”๊ฐ€
139
- parts.append({
140
- "text": prompt
141
- })
142
 
143
- # ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์ด๋ฏธ์ง€ ์ถ”๊ฐ€
144
- for idx, img in enumerate(images, 1):
145
- if img is not None:
146
- # PIL ์ด๋ฏธ์ง€๋ฅผ ๋ฐ”์ดํŠธ๋กœ ๋ณ€ํ™˜
147
- img_byte_arr = io.BytesIO()
148
- img.save(img_byte_arr, format='PNG')
149
- img_bytes = img_byte_arr.getvalue()
150
-
151
- # ์ด๋ฏธ์ง€๋ฅผ ํŒŒํŠธ๋กœ ์ถ”๊ฐ€ (genai.types.Part ๋Œ€์‹  ์ง์ ‘ ๋”•์…”๋„ˆ๋ฆฌ ์‚ฌ์šฉ)
152
- parts.append({
153
- "inline_data": {
154
- "mime_type": "image/png",
155
- "data": img_bytes
156
- }
157
- })
158
- logger.info(f"์ด๋ฏธ์ง€ #{idx} ์ถ”๊ฐ€๋จ")
159
 
160
  # ์ƒ์„ฑ ์„ค์ •
161
- generate_content_config = {
162
- "temperature": 1.0,
163
- "response_modalities": ["image"]
164
- }
165
 
166
  # ์ž„์‹œ ํŒŒ์ผ ์ƒ์„ฑ
167
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
168
  temp_path = tmp.name
169
 
170
  # Gemini ๋ชจ๋ธ๋กœ ์š”์ฒญ ์ „์†ก
171
- logger.info(f"Gemini API ์š”์ฒญ ์‹œ์ž‘ - ํ”„๋กฌํ”„ํŠธ: {prompt}")
172
-
173
- # API ์š”์ฒญ ํ˜•์‹ ์ˆ˜์ •
174
- response = client.models.generate_content(
175
  model="gemini-2.0-flash-exp-image-generation",
176
- contents=[{
177
- "role": "user",
178
- "parts": parts
179
- }],
180
- generation_config=generate_content_config
181
  )
182
 
183
  # ์‘๋‹ต์—์„œ ์ด๋ฏธ์ง€ ์ถ”์ถœ
184
  image_found = False
 
 
 
 
 
 
 
 
 
 
185
 
186
- # ์‘๋‹ต ๊ตฌ์กฐ ์ฒ˜๋ฆฌ
187
- for part in response.candidates[0].content.parts:
188
- if hasattr(part, 'inline_data') and part.inline_data:
189
- save_binary_file(temp_path, part.inline_data.data)
190
- image_found = True
191
- logger.info("์‘๋‹ต์—์„œ ์ด๋ฏธ์ง€ ์ถ”์ถœ ์„ฑ๊ณต")
 
 
 
 
 
 
 
192
 
193
  if not image_found:
194
  return None, "API์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ๋กœ ์‹œ๋„ํ•ด๋ณด์„ธ์š”."
@@ -204,7 +209,42 @@ def process_images_with_prompt(image1, image2, image3, prompt):
204
  logger.exception("์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ:")
205
  return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
206
 
207
- # ๊ธฐ๋Šฅ ์„ ํƒ ์ฝœ๋ฐฑ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def update_prompt_from_function(function_choice, custom_text=""):
209
  function_templates = {
210
  "1. ์ด๋ฏธ์ง€ ๋ณ€๊ฒฝ": f'#1์„ "{custom_text if custom_text else "์›ํ•˜๋Š” ์„ค๋ช…"}"์œผ๋กœ ๋ฐ”๊ฟ”๋ผ',
@@ -218,7 +258,7 @@ def update_prompt_from_function(function_choice, custom_text=""):
218
 
219
  return function_templates.get(function_choice, "")
220
 
221
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค
222
  with gr.Blocks() as demo:
223
  gr.HTML(
224
  """
 
4
  import gradio as gr
5
  import logging
6
  import re
 
7
 
8
  from google import genai
9
  from google.genai import types
 
24
  """
25
  ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  ๊ธฐ๋Šฅ ๋ช…๋ น์„ ํ•ด์„
26
  """
27
+ # ๊ธฐ์กด preprocess_prompt ํ•จ์ˆ˜ ์ฝ”๋“œ ์œ ์ง€
28
  # ์ด๋ฏธ์ง€ ์—†๋Š” ์ฐธ์กฐ ํ™•์ธ ๋ฐ ์ฒ˜๋ฆฌ
29
  has_img1 = image1 is not None
30
  has_img2 = image2 is not None
 
96
 
97
  return prompt
98
 
99
+ def generate_with_images(prompt, images):
100
  """
101
+ ๋ฌธ์ œ๊ฐ€ ๋œ API ํ˜ธ์ถœ ๋ถ€๋ถ„์„ ์ƒˆ๋กœ์šด ๋ฐฉ์‹์œผ๋กœ ๊ตฌํ˜„
102
  """
103
  try:
104
  # API ํ‚ค ํ™•์ธ
 
109
  # Gemini ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™”
110
  client = genai.Client(api_key=api_key)
111
 
112
+ logger.info(f"Gemini API ์š”์ฒญ ์‹œ์ž‘ - ํ”„๋กฌํ”„ํŠธ: {prompt}")
 
 
113
 
114
+ # ์ž„์‹œ ํŒŒ์ผ ์ƒ์„ฑ ๋ฐ ์ฒ˜๋ฆฌ
115
+ temp_image_paths = []
116
+ file_uris = []
117
 
118
+ # ์ด๋ฏธ์ง€ ํŒŒ์ผ ์ €์žฅ ๋ฐ ์—…๋กœ๋“œ
119
+ for idx, img in enumerate(images, 1):
120
+ if img is not None:
121
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
122
+ temp_path = tmp.name
123
+ img.save(temp_path, format="PNG")
124
+ temp_image_paths.append(temp_path)
125
+
126
+ # ํŒŒ์ผ ์—…๋กœ๋“œ
127
+ uploaded_file = client.files.upload(file=temp_path)
128
+ file_uris.append(uploaded_file)
129
+ logger.info(f"์ด๋ฏธ์ง€ #{idx} ์—…๋กœ๋“œ๋จ: {uploaded_file.uri}")
 
 
 
130
 
131
+ # ์ปจํ…์ธ  ํŒŒํŠธ ๊ตฌ์„ฑ
132
  parts = []
133
 
134
+ # ์ด๋ฏธ์ง€ ํŒŒํŠธ ์ถ”๊ฐ€
135
+ for file in file_uris:
136
+ parts.append(
137
+ types.Part.from_uri(
138
+ file_uri=file.uri,
139
+ mime_type=file.mime_type,
140
+ )
141
+ )
142
+
143
  # ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ ์ถ”๊ฐ€
144
+ parts.append(types.Part.from_text(text=prompt))
 
 
145
 
146
+ # ์š”์ฒญ ๋ฉ”์‹œ์ง€ ๊ตฌ์„ฑ
147
+ contents = [
148
+ types.Content(
149
+ role="user",
150
+ parts=parts,
151
+ ),
152
+ ]
 
 
 
 
 
 
 
 
 
153
 
154
  # ์ƒ์„ฑ ์„ค์ •
155
+ generate_content_config = types.GenerateContentConfig(
156
+ temperature=1,
157
+ response_modalities=["image"],
158
+ )
159
 
160
  # ์ž„์‹œ ํŒŒ์ผ ์ƒ์„ฑ
161
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
162
  temp_path = tmp.name
163
 
164
  # Gemini ๋ชจ๋ธ๋กœ ์š”์ฒญ ์ „์†ก
165
+ response_stream = client.models.generate_content_stream(
 
 
 
166
  model="gemini-2.0-flash-exp-image-generation",
167
+ contents=contents,
168
+ config=generate_content_config,
 
 
 
169
  )
170
 
171
  # ์‘๋‹ต์—์„œ ์ด๋ฏธ์ง€ ์ถ”์ถœ
172
  image_found = False
173
+ for chunk in response_stream:
174
+ if not chunk.candidates or not chunk.candidates[0].content or not chunk.candidates[0].content.parts:
175
+ logger.warning("์ฒญํฌ์— ๋ฐ์ดํ„ฐ๊ฐ€ ์—†์Œ. ๊ฑด๋„ˆ๋œ€.")
176
+ continue
177
+
178
+ for part in chunk.candidates[0].content.parts:
179
+ if hasattr(part, 'inline_data') and part.inline_data:
180
+ save_binary_file(temp_path, part.inline_data.data)
181
+ image_found = True
182
+ logger.info("์‘๋‹ต์—์„œ ์ด๋ฏธ์ง€ ์ถ”์ถœ ์„ฑ๊ณต")
183
 
184
+ # ์ž„์‹œ ํŒŒ์ผ ์ •๋ฆฌ
185
+ for path in temp_image_paths:
186
+ try:
187
+ os.unlink(path)
188
+ except:
189
+ pass
190
+
191
+ # ํŒŒ์ผ URI ์ •๋ฆฌ
192
+ for file in file_uris:
193
+ try:
194
+ client.files.delete(fileLocation=file.uri)
195
+ except:
196
+ pass
197
 
198
  if not image_found:
199
  return None, "API์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค๋ฅธ ํ”„๋กฌํ”„ํŠธ๋กœ ์‹œ๋„ํ•ด๋ณด์„ธ์š”."
 
209
  logger.exception("์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ:")
210
  return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
211
 
212
+ def process_images_with_prompt(image1, image2, image3, prompt):
213
+ """
214
+ 3๊ฐœ์˜ ์ด๋ฏธ์ง€์™€ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜
215
+ """
216
+ try:
217
+ # ์ด๋ฏธ์ง€ ๊ฐœ์ˆ˜ ํ™•์ธ
218
+ images = [image1, image2, image3]
219
+ valid_images = [img for img in images if img is not None]
220
+
221
+ if not valid_images:
222
+ return None, "์ ์–ด๋„ ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."
223
+
224
+ # ํ”„๋กฌํ”„ํŠธ ์ฒ˜๋ฆฌ
225
+ if not prompt or not prompt.strip():
226
+ # ํ”„๋กฌํ”„ํŠธ๊ฐ€ ์—†์œผ๋ฉด ์—…๋กœ๋“œ๋œ ์ด๋ฏธ์ง€ ์ˆ˜์— ๋”ฐ๋ผ ์ž๋™ ํ•ฉ์„ฑ ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ
227
+ if len(valid_images) == 1:
228
+ prompt = "์ด ์ด๋ฏธ์ง€๋ฅผ ์ฐฝ์˜์ ์œผ๋กœ ๋ณ€ํ˜•ํ•ด์ฃผ์„ธ์š”. ๋” ์ƒ์ƒํ•˜๊ณ  ์˜ˆ์ˆ ์ ์ธ ๋ฒ„์ „์œผ๋กœ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”."
229
+ logger.info("๋‹จ์ผ ์ด๋ฏธ์ง€ ํ”„๋กฌํ”„ํŠธ ์ž๋™ ์ƒ์„ฑ")
230
+ elif len(valid_images) == 2:
231
+ prompt = "์ด ๋‘ ์ด๋ฏธ์ง€๋ฅผ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํ•ฉ์„ฑํ•ด์ฃผ์„ธ์š”. ๋‘ ์ด๋ฏธ์ง€์˜ ์š”์†Œ๋ฅผ ์กฐํ™”๋กญ๊ฒŒ ํ†ตํ•ฉํ•˜์—ฌ ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€๋กœ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”."
232
+ logger.info("๋‘ ์ด๋ฏธ์ง€ ํ•ฉ์„ฑ ํ”„๋กฌํ”„ํŠธ ์ž๋™ ์ƒ์„ฑ")
233
+ else:
234
+ prompt = "์ด ์„ธ ์ด๋ฏธ์ง€๋ฅผ ์ฐฝ์˜์ ์œผ๋กœ ํ•ฉ์„ฑํ•ด์ฃผ์„ธ์š”. ๋ชจ๋“  ์ด๋ฏธ์ง€์˜ ์ฃผ์š” ์š”์†Œ๋ฅผ ํฌํ•จํ•˜๋˜ ์ž์—ฐ์Šค๋Ÿฝ๊ณ  ์ผ๊ด€๋œ ํ•˜๋‚˜์˜ ์žฅ๋ฉด์œผ๋กœ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”."
235
+ logger.info("์„ธ ์ด๋ฏธ์ง€ ํ•ฉ์„ฑ ํ”„๋กฌํ”„ํŠธ ์ž๋™ ์ƒ์„ฑ")
236
+ else:
237
+ # ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ ๋ฐ ๊ธฐ๋Šฅ ๋ช…๋ น ํ•ด์„
238
+ prompt = preprocess_prompt(prompt, image1, image2, image3)
239
+
240
+ # ์ƒˆ๋กœ์šด API ํ˜ธ์ถœ ๋ฐฉ์‹ ์‚ฌ์šฉ
241
+ return generate_with_images(prompt, valid_images)
242
+
243
+ except Exception as e:
244
+ logger.exception("์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ:")
245
+ return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
246
+
247
+ # ๊ธฐ๋Šฅ ์„ ํƒ ์ฝœ๋ฐฑ (๊ธฐ์กด ์ฝ”๋“œ ์œ ์ง€)
248
  def update_prompt_from_function(function_choice, custom_text=""):
249
  function_templates = {
250
  "1. ์ด๋ฏธ์ง€ ๋ณ€๊ฒฝ": f'#1์„ "{custom_text if custom_text else "์›ํ•˜๋Š” ์„ค๋ช…"}"์œผ๋กœ ๋ฐ”๊ฟ”๋ผ',
 
258
 
259
  return function_templates.get(function_choice, "")
260
 
261
+ # Gradio 인터페이스 (기존 코드 유지)
262
  with gr.Blocks() as demo:
263
  gr.HTML(
264
  """