# test-100 / app.py — Hugging Face Space by Kims12
# Commit 77a62b5 ("Update app.py"), 4.65 kB
import io
import logging
import os
import tempfile

import gradio as gr
from dotenv import load_dotenv  # loads environment variables from a .env file
from google import genai
from google.genai import types
from PIL import Image
# Load variables from a local .env file (if one is found) into the process
# environment, e.g. GEMINI_API_KEY.
load_dotenv()

# Basic logging setup: INFO level, timestamped single-line records.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def save_binary_file(file_name, data):
    """Write the raw bytes ``data`` to ``file_name``, replacing any existing content."""
    with open(file_name, mode="wb") as out_file:
        out_file.write(data)
def _pil_to_png_bytes(img):
    """Encode a PIL image as PNG entirely in memory and return the bytes."""
    # An in-memory buffer avoids reopening a NamedTemporaryFile by name (not
    # possible on Windows while it is still open) and leaves nothing on disk.
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()


def _extract_response_image(response):
    """Pull image bytes and any accompanying text out of a generate_content response.

    Only the first candidate is inspected. Missing candidates, content or parts
    are treated as "no image" instead of raising.

    Returns:
        tuple: ``(image_bytes, text)``. ``image_bytes`` is the data of the last
        part carrying non-empty ``inline_data``, or ``None`` when there is none.
        ``text`` is every text part joined with spaces.
    """
    candidates = getattr(response, "candidates", None) or []
    content = candidates[0].content if candidates else None
    parts = getattr(content, "parts", None) or []

    image_bytes = None
    texts = []
    for part in parts:
        inline = getattr(part, "inline_data", None)
        if inline and inline.data:
            # The last image wins, matching the previous overwrite-per-part behavior.
            image_bytes = inline.data
        elif getattr(part, "text", None):
            texts.append(part.text)
    return image_bytes, " ".join(texts)


def process_images_with_prompt(image1, image2, image3, prompt):
    """
    Generate a new image from up to three input images and a text prompt.

    The prompt and every non-None image are sent together, as one user turn, to
    the Gemini image-generation model. The image found in the response is returned.
    Requires the GEMINI_API_KEY environment variable.

    Args:
        image1, image2, image3: PIL images from the Gradio inputs, or None for an
            empty slot.
        prompt: Instruction text. A built-in default prompt is used when it is
            None, empty or whitespace-only.

    Returns:
        tuple: ``(image, status)``. ``image`` is the generated PIL image (RGBA is
        converted to RGB), or None on failure. ``status`` is a user-facing message.
        Exceptions raised while generating are logged and reported through
        ``status`` rather than propagated.
    """
    try:
        # Check the API key before doing any work.
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            return None, "API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ํ™˜๊ฒฝ๋ณ€์ˆ˜๋ฅผ ํ™•์ธํ•ด์ฃผ์„ธ์š”."

        client = genai.Client(api_key=api_key)

        # Fall back to a default prompt when none was given.
        if not prompt or not prompt.strip():
            prompt = "์ด ์ด๋ฏธ์ง€๋“ค์„ ํ™œ์šฉํ•˜์—ฌ ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ด์ฃผ์„ธ์š”."

        # The text prompt goes first, followed by each provided image as inline PNG data.
        parts = [types.Part.from_text(text=prompt)]
        for img in (image1, image2, image3):
            if img is not None:
                # google-genai exposes Part.from_bytes; there is no Part.from_data.
                parts.append(
                    types.Part.from_bytes(data=_pil_to_png_bytes(img), mime_type="image/png")
                )

        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            # Gemini's image-generation guide requests TEXT and IMAGE together;
            # image-only requests may be rejected by this model. Text parts are
            # only used for diagnostics below.
            response_modalities=["TEXT", "IMAGE"],
        )

        response = client.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[types.Content(role="user", parts=parts)],
            config=generate_content_config,
        )

        image_bytes, model_text = _extract_response_image(response)
        if image_bytes is None:
            # Report this clearly instead of letting PIL fail on an empty file.
            detail = "The model did not return an image."
            if model_text:
                detail = f"{detail} Model response: {model_text}"
            logger.warning("Image generation returned no image: %s", detail)
            return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {detail}"

        # Decode from memory; load() forces decoding now so errors land in this try.
        result_img = Image.open(io.BytesIO(image_bytes))
        result_img.load()
        if result_img.mode == "RGBA":
            result_img = result_img.convert("RGB")
        return result_img, "์ด๋ฏธ์ง€๊ฐ€ ์„ฑ๊ณต์ ์œผ๋กœ ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค."
    except Exception as e:
        # UI boundary: log the full traceback and surface the message to the user.
        logger.exception("์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ:")
        return None, f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
# Simplified Gradio interface: inputs (three images, prompt, button) on the left,
# the generated image and a status message on the right.
with gr.Blocks() as demo:
    gr.HTML("<h1>๊ฐ„๋‹จํ•œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ๊ธฐ</h1><p>์ด๋ฏธ์ง€ 3๊ฐœ์™€ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”</p>")

    with gr.Row():
        with gr.Column():
            # Three image slots, delivered to the callback as RGB PIL images.
            # Components are created in label order, so the layout is top-to-bottom 1, 2, 3.
            image1_input, image2_input, image3_input = (
                gr.Image(type="pil", label=slot_label, image_mode="RGB")
                for slot_label in ("์ด๋ฏธ์ง€ 1", "์ด๋ฏธ์ง€ 2", "์ด๋ฏธ์ง€ 3")
            )
            # Free-form instruction for the model.
            prompt_input = gr.Textbox(
                label="ํ”„๋กฌํ”„ํŠธ",
                placeholder="์ด ์ด๋ฏธ์ง€๋“ค์„ ์–ด๋–ป๊ฒŒ ๋ณ€ํ™˜ํ• ์ง€ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”",
                lines=3,
            )
            submit_btn = gr.Button("์ด๋ฏธ์ง€ ์ƒ์„ฑ")

        with gr.Column():
            # Generation results.
            output_image = gr.Image(label="์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")
            output_text = gr.Textbox(label="์ƒํƒœ ๋ฉ”์‹œ์ง€")

    # Wire the button to the generation callback.
    input_components = [image1_input, image2_input, image3_input, prompt_input]
    output_components = [output_image, output_text]
    submit_btn.click(
        fn=process_images_with_prompt,
        inputs=input_components,
        outputs=output_components,
    )
def _main():
    """Run the application: launch the Gradio demo with Gradio's public share link enabled."""
    demo.launch(share=True)


if __name__ == "__main__":
    _main()