Kims12 commited on
Commit
efacf1e
ยท
verified ยท
1 Parent(s): 6d7a6b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -4
app.py CHANGED
@@ -88,6 +88,122 @@ def preprocess_prompt(prompt, image1, image2, image3):
88
 
89
  return prompt
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  def process_images_with_prompt(image1, image2, image3, prompt):
93
  """
@@ -119,14 +235,14 @@ def process_images_with_prompt(image1, image2, image3, prompt):
119
  # ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ ๋ฐ ๊ธฐ๋Šฅ ๋ช…๋ น ํ•ด์„
120
  prompt = preprocess_prompt(prompt, image1, image2, image3)
121
 
122
- # ์ƒˆ๋กœ์šด API ํ˜ธ์ถœ ๋ฐฉ์‹ ์‚ฌ์šฉ (์žฌ์‹œ๋„ ๊ธฐ๋Šฅ ํฌํ•จ)
123
  return generate_with_images(prompt, valid_images)
124
 
125
  except Exception as e:
126
  logger.exception("์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ:")
127
  return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
128
 
129
- # ๊ธฐ๋Šฅ ์„ ํƒ ์ฝœ๋ฐฑ (์ˆ˜์ •๋จ)
130
  def update_prompt_from_function(function_choice):
131
  function_templates = {
132
  "1. ์ด๋ฏธ์ง€ ๋ณ€๊ฒฝ": "#1์„ ์ฐฝ์˜์ ์œผ๋กœ ๋ฐ”๊ฟ”๋ผ",
@@ -140,7 +256,7 @@ def update_prompt_from_function(function_choice):
140
 
141
  return function_templates.get(function_choice, "")
142
 
143
- # Gradio ์ธํ„ฐํŽ˜์ด์Šค (์ˆ˜์ •๋จ)
144
  with gr.Blocks() as demo:
145
  gr.HTML(
146
  """
@@ -159,7 +275,7 @@ with gr.Blocks() as demo:
159
  image2_input = gr.Image(type="pil", label="#2", image_mode="RGB")
160
  image3_input = gr.Image(type="pil", label="#3", image_mode="RGB")
161
 
162
- # ๊ธฐ๋Šฅ ์„ ํƒ ๋“œ๋กญ๋‹ค์šด (์ปค์Šคํ…€ ํ…์ŠคํŠธ ์ž…๋ ฅ ์ œ๊ฑฐ)
163
  function_dropdown = gr.Dropdown(
164
  choices=[
165
  "1. ์ด๋ฏธ์ง€ ๋ณ€๊ฒฝ",
 
88
 
89
  return prompt
90
 
91
+ def generate_with_images(prompt, images, max_retries=2):
92
+ """
93
+ ๊ณต์‹ ๋ฌธ์„œ์— ๊ธฐ๋ฐ˜ํ•œ ์˜ฌ๋ฐ”๋ฅธ API ํ˜ธ์ถœ ๋ฐฉ์‹ ๊ตฌํ˜„
94
+ ์‹คํŒจ ์‹œ ์žฌ์‹œ๋„ ๊ธฐ๋Šฅ ์ถ”๊ฐ€
95
+ """
96
+ retries = 0
97
+ last_error = None
98
+
99
+ while retries <= max_retries:
100
+ try:
101
+ # API ํ‚ค ํ™•์ธ
102
+ api_key = os.environ.get("GEMINI_API_KEY")
103
+ if not api_key:
104
+ return None, "API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ํ™˜๊ฒฝ๋ณ€์ˆ˜๋ฅผ ํ™•์ธํ•ด์ฃผ์„ธ์š”."
105
+
106
+ # Gemini ํด๋ผ์ด์–ธํŠธ ์ดˆ๊ธฐํ™”
107
+ client = genai.Client(api_key=api_key)
108
+
109
+ logger.info(f"Gemini API ์š”์ฒญ ์‹œ์ž‘ - ํ”„๋กฌํ”„ํŠธ: {prompt} (์‹œ๋„: {retries+1}/{max_retries+1})")
110
+
111
+ # ์ด๋ฏธ์ง€ ์ถ”๊ฐ€ (์ด๋ฏธ์ง€๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ๋งŒ)
112
+ image_parts = []
113
+ for idx, img in enumerate(images, 1):
114
+ if img is not None:
115
+ # ์ด๋ฏธ์ง€๋ฅผ ๋ฐ”์ดํŠธ๋กœ ๋ณ€ํ™˜
116
+ buffered = BytesIO()
117
+ img.save(buffered, format="PNG")
118
+ image_bytes = buffered.getvalue()
119
+
120
+ # ์ด๋ฏธ์ง€ ํŒŒํŠธ ์ƒ์„ฑ
121
+ image_part = types.Content({
122
+ 'inline_data': {
123
+ 'mime_type': 'image/png',
124
+ 'data': image_bytes
125
+ }
126
+ })
127
+ image_parts.append(image_part)
128
+ logger.info(f"์ด๋ฏธ์ง€ #{idx} ์ถ”๊ฐ€๋จ")
129
+
130
+ # ์ด๋ฏธ์ง€์™€ ํ…์ŠคํŠธ ์ปจํ…์ธ  ๊ฒฐํ•ฉ
131
+ contents = image_parts + [types.Content({'text': prompt})]
132
+
133
+ # ์ƒ์„ฑ ์„ค์ • - ๊ณต์‹ ๋ฌธ์„œ์— ๋”ฐ๋ผ responseModalities ์„ค์ •
134
+ response = client.models.generate_content(
135
+ model="gemini-2.0-flash-exp-image-generation",
136
+ contents=contents,
137
+ config=types.GenerateContentConfig(
138
+ response_modalities=['Text', 'Image'],
139
+ temperature=1,
140
+ top_p=0.95,
141
+ top_k=40,
142
+ max_output_tokens=8192
143
+ )
144
+ )
145
+
146
+ # ์‘๋‹ต ์œ ํšจ์„ฑ ํ™•์ธ
147
+ if not response or not response.candidates:
148
+ if retries < max_retries:
149
+ retries += 1
150
+ logger.warning(f"์œ ํšจํ•œ ์‘๋‹ต์„ ๋ฐ›์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์žฌ์‹œ๋„ ์ค‘... ({retries}/{max_retries})")
151
+ continue
152
+ return None, "์ด๋ฏธ์ง€ ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค. ์œ ํšจํ•œ ์‘๋‹ต์„ ๋ฐ›์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."
153
+
154
+ # ์ž„์‹œ ํŒŒ์ผ ์ƒ์„ฑ
155
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
156
+ temp_path = tmp.name
157
+
158
+ result_text = ""
159
+ image_found = False
160
+
161
+ # ์‘๋‹ต ์ฒ˜๋ฆฌ
162
+ for candidate in response.candidates:
163
+ if not candidate.content:
164
+ continue
165
+
166
+ # ํ…์ŠคํŠธ ์ถ”์ถœ
167
+ if candidate.content.text:
168
+ result_text += candidate.content.text
169
+ logger.info(f"์‘๋‹ต ํ…์ŠคํŠธ: {candidate.content.text}")
170
+
171
+ # ์ด๋ฏธ์ง€ ์ถ”์ถœ
172
+ for part in candidate.content.parts:
173
+ if hasattr(part, 'inline_data') and part.inline_data:
174
+ save_binary_file(temp_path, part.inline_data.data)
175
+ image_found = True
176
+ logger.info("์‘๋‹ต์—์„œ ์ด๋ฏธ์ง€ ์ถ”์ถœ ์„ฑ๊ณต")
177
+ break
178
+
179
+ if image_found:
180
+ break
181
+
182
+ if not image_found:
183
+ if retries < max_retries:
184
+ retries += 1
185
+ logger.warning(f"API์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์žฌ์‹œ๋„ ์ค‘... ({retries}/{max_retries})")
186
+ continue
187
+ return None, f"API์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์‘๋‹ต ํ…์ŠคํŠธ: {result_text}"
188
+
189
+ # ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€ ๋ฐ˜ํ™˜
190
+ result_img = Image.open(temp_path)
191
+ if result_img.mode == "RGBA":
192
+ result_img = result_img.convert("RGB")
193
+
194
+ return result_img, f"์ด๋ฏธ์ง€๊ฐ€ ์„ฑ๊ณต์ ์œผ๋กœ ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. {result_text}"
195
+
196
+ except Exception as e:
197
+ last_error = str(e)
198
+ logger.exception(f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ (์‹œ๋„: {retries+1}/{max_retries+1}):")
199
+ if retries < max_retries:
200
+ retries += 1
201
+ logger.info(f"์žฌ์‹œ๋„ ์ค‘... ({retries}/{max_retries})")
202
+ continue
203
+ return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {last_error}"
204
+
205
+ return None, f"์ตœ๋Œ€ ์žฌ์‹œ๋„ ํšŸ์ˆ˜ ์ดˆ๊ณผ. ๋งˆ์ง€๋ง‰ ์˜ค๋ฅ˜: {last_error}"
206
+
207
 
208
  def process_images_with_prompt(image1, image2, image3, prompt):
209
  """
 
235
  # ํ”„๋กฌํ”„ํŠธ ์ „์ฒ˜๋ฆฌ ๋ฐ ๊ธฐ๋Šฅ ๋ช…๋ น ํ•ด์„
236
  prompt = preprocess_prompt(prompt, image1, image2, image3)
237
 
238
+ # ์ด๋ฏธ์ง€ ์ƒ์„ฑ
239
  return generate_with_images(prompt, valid_images)
240
 
241
  except Exception as e:
242
  logger.exception("์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ:")
243
  return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
244
 
245
+ # ๊ธฐ๋Šฅ ์„ ํƒ ์ฝœ๋ฐฑ
246
  def update_prompt_from_function(function_choice):
247
  function_templates = {
248
  "1. ์ด๋ฏธ์ง€ ๋ณ€๊ฒฝ": "#1์„ ์ฐฝ์˜์ ์œผ๋กœ ๋ฐ”๊ฟ”๋ผ",
 
256
 
257
  return function_templates.get(function_choice, "")
258
 
259
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค
260
  with gr.Blocks() as demo:
261
  gr.HTML(
262
  """
 
275
  image2_input = gr.Image(type="pil", label="#2", image_mode="RGB")
276
  image3_input = gr.Image(type="pil", label="#3", image_mode="RGB")
277
 
278
+ # ๊ธฐ๋Šฅ ์„ ํƒ ๋“œ๋กญ๋‹ค์šด
279
  function_dropdown = gr.Dropdown(
280
  choices=[
281
  "1. ์ด๋ฏธ์ง€ ๋ณ€๊ฒฝ",