File size: 4,645 Bytes
8d852cf ddcfe75 8d852cf 77a62b5 e0e8bcb 77a62b5 8d852cf ddcfe75 77a62b5 8d852cf 77a62b5 0ae66bb 77a62b5 1e9bbd5 77a62b5 fd7e6af 1e9bbd5 8d852cf 77a62b5 8d852cf 77a62b5 8d852cf 77a62b5 8d852cf 1e9bbd5 77a62b5 fd7e6af 77a62b5 fd7e6af 77a62b5 8d852cf e0e8bcb 77a62b5 ddcfe75 77a62b5 5ff01f8 77a62b5 5ff01f8 77a62b5 5ff01f8 77a62b5 5ff01f8 77a62b5 5ff01f8 77a62b5 5ff01f8 77a62b5 5ff01f8 77a62b5 5ff01f8 77a62b5 5ff01f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import io
import logging
import os
import tempfile

import gradio as gr
from dotenv import load_dotenv
from google import genai
from google.genai import types
from PIL import Image

# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file.
load_dotenv()
# Basic logging configuration: INFO level with timestamp, level and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def save_binary_file(file_name, data):
    """Write ``data`` (bytes) to ``file_name``, replacing any existing content."""
    with open(file_name, mode="wb") as sink:
        sink.write(data)
def _encode_png(img):
    """Serialize a PIL image to PNG bytes entirely in memory."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()


def _extract_image_and_text(response):
    """Return ``(image_bytes, text)`` taken from the first candidate of ``response``.

    ``image_bytes`` is the data of the first part carrying inline data and
    ``text`` is the newline-joined text of all text parts. Either element is
    ``None`` when the response holds no such part (including the case where
    there are no candidates or the candidate has no content).
    """
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return None, None

    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None) or []

    image_bytes = None
    texts = []
    for part in parts:
        inline = getattr(part, "inline_data", None)
        if image_bytes is None and inline is not None:
            image_bytes = getattr(inline, "data", None) or None
        text = getattr(part, "text", None)
        if text:
            texts.append(text)
    return image_bytes, ("\n".join(texts) or None)


def process_images_with_prompt(image1, image2, image3, prompt):
    """Generate an image with Gemini from up to three images and a prompt.

    Args:
        image1: PIL image or ``None``; ``None`` inputs are skipped.
        image2: PIL image or ``None``; ``None`` inputs are skipped.
        image3: PIL image or ``None``; ``None`` inputs are skipped.
        prompt: Instruction text. ``None`` or a blank string falls back to a
            default prompt.

    Returns:
        A ``(image, message)`` tuple. ``image`` is the generated PIL image, or
        ``None`` when the API key is missing, the model returned no image, or
        an exception occurred. ``message`` is the user-facing status text.
        Exceptions raised while talking to the API or decoding the result are
        logged and reported through ``message`` rather than propagated.
    """
    try:
        # The API key is read from the environment on every call.
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            return None, "API 키가 설정되지 않았습니다. 환경변수를 확인해주세요."

        client = genai.Client(api_key=api_key)

        # Fall back to a generic instruction when no prompt was typed.
        if not prompt or not prompt.strip():
            prompt = "이 이미지들을 활용하여 새로운 이미지를 생성해주세요."

        # Request content: the text prompt first, then every supplied image.
        # Images are PNG-encoded in memory (no temp files), and attached with
        # Part.from_bytes, the google-genai constructor for raw bytes.
        parts = [types.Part.from_text(text=prompt)]
        parts.extend(
            types.Part.from_bytes(data=_encode_png(img), mime_type="image/png")
            for img in (image1, image2, image3)
            if img is not None
        )

        # Text is requested alongside the image: the Gemini image-generation
        # documentation states that image-only output is not supported.
        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            response_modalities=["image", "text"],
        )

        response = client.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[types.Content(role="user", parts=parts)],
            config=generate_content_config,
        )

        image_bytes, text = _extract_image_and_text(response)
        if image_bytes is None:
            # No image came back; surface any text the model returned so the
            # user can see why.
            message = "이미지가 생성되지 않았습니다."
            if text:
                message = f"{message} 모델 응답: {text}"
            return None, message

        # Decode in memory; load() forces decoding now so that a corrupt
        # payload is reported by the handler below instead of later in the UI.
        result_img = Image.open(io.BytesIO(image_bytes))
        result_img.load()
        if result_img.mode == "RGBA":
            result_img = result_img.convert("RGB")

        return result_img, "이미지가 성공적으로 생성되었습니다."
    except Exception as e:
        # UI callback boundary: log the traceback and show the error text.
        logger.exception("이미지 생성 중 오류 발생:")
        return None, f"오류 발생: {str(e)}"
# Simplified Gradio interface: three image inputs and a prompt on the left,
# the generated image and a status message on the right.
with gr.Blocks() as demo:
    gr.HTML("<h1>간단한 이미지 생성기</h1><p>이미지 3개와 프롬프트를 입력하세요</p>")

    with gr.Row():
        with gr.Column():
            # Three image inputs, delivered to the callback as PIL images.
            image1_input = gr.Image(type="pil", label="이미지 1", image_mode="RGB")
            image2_input = gr.Image(type="pil", label="이미지 2", image_mode="RGB")
            image3_input = gr.Image(type="pil", label="이미지 3", image_mode="RGB")

            # Prompt input.
            # NOTE(review): the placeholder was restored from mis-encoded text;
            # "변형" could not be told apart from "변환" - confirm the wording.
            prompt_input = gr.Textbox(
                lines=3,
                placeholder="이 이미지들을 어떻게 변형할지 설명해주세요",
                label="프롬프트"
            )

            # Generate button.
            submit_btn = gr.Button("이미지 생성")

        with gr.Column():
            # Result outputs.
            output_image = gr.Image(label="생성된 이미지")
            output_text = gr.Textbox(label="상태 메시지")

    # Button click: run the generator and fill both outputs.
    submit_btn.click(
        fn=process_images_with_prompt,
        inputs=[image1_input, image2_input, image3_input, prompt_input],
        outputs=[output_image, output_text],
    )
# Launch the application when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=True)