|
import os |
|
import tempfile |
|
from PIL import Image |
|
import gradio as gr |
|
import logging |
|
|
|
from google import genai |
|
from google.genai import types |
|
|
|
|
|
from dotenv import load_dotenv |
|
# Run load_dotenv() at import time, before any code looks up GEMINI_API_KEY
# in os.environ.
load_dotenv()

# Root logging setup: INFO and above, each record prefixed with a timestamp
# and its level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Module-level logger used by the request handler for error reporting.
logger = logging.getLogger(__name__)
|
|
|
def save_binary_file(file_name, data):
    """Write ``data`` to ``file_name`` in binary mode.

    The file is opened with mode ``"wb"``, so an existing file at that path
    is truncated before the bytes are written.

    Args:
        file_name: Path of the file to create or overwrite.
        data: Bytes-like payload to write.
    """
    with open(file_name, mode="wb") as out_file:
        out_file.write(data)
|
|
|
def _png_part_from_image(image):
    """Encode a PIL image as PNG bytes in memory and wrap it as a request part."""
    import io  # local import keeps this block self-contained

    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return types.Part.from_bytes(data=buffer.getvalue(), mime_type="image/png")


def _last_inline_image_bytes(response):
    """Return the payload of the last inline-data part of the first candidate.

    Returns None when the response has no candidates, no content, or no part
    carrying inline data.
    """
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return None
    content = getattr(candidates[0], "content", None)
    payload = None
    for part in getattr(content, "parts", None) or []:
        inline = getattr(part, "inline_data", None)
        if inline and inline.data:
            # Later parts replace earlier ones, matching the previous
            # behaviour of overwriting a single output file per part.
            payload = inline.data
    return payload


def process_images_with_prompt(image1, image2, image3, prompt):
    """Send up to three images plus a text prompt to Gemini and return an image.

    Args:
        image1: First input image (PIL image) or None.
        image2: Second input image (PIL image) or None.
        image3: Third input image (PIL image) or None.
        prompt: Instruction text; a built-in default prompt is used when it is
            empty or whitespace only.

    Returns:
        A ``(image, message)`` tuple. On success ``image`` is a PIL image
        (RGBA results are converted to RGB) and ``message`` is a status
        string. On any failure ``image`` is None and ``message`` describes
        the problem; exceptions are logged and never propagate to the caller.
    """
    import io  # local import keeps this block self-contained

    # NOTE(review): the user-facing strings in this function look mis-encoded
    # in the source; they are kept verbatim here — restore them from a
    # correctly encoded copy of the file.
    try:
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            return None, "API ํค๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค. ํ๊ฒฝ๋ณ์๋ฅผ ํ์ธํด์ฃผ์ธ์."

        client = genai.Client(api_key=api_key)

        if not prompt or not prompt.strip():
            prompt = "์ด ์ด๋ฏธ์ง๋ค์ ํ์ฉํ์ฌ ์๋ก์ด ์ด๋ฏธ์ง๋ฅผ ์์ฑํด์ฃผ์ธ์."

        # The text prompt goes first, followed by every image that was supplied.
        parts = [types.Part.from_text(text=prompt)]
        parts.extend(
            _png_part_from_image(img)
            for img in (image1, image2, image3)
            if img is not None
        )

        # NOTE(review): some Gemini image models may require text alongside
        # image output — confirm ["image"] alone is accepted by this model.
        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            response_modalities=["image"],
        )

        response = client.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[types.Content(role="user", parts=parts)],
            config=generate_content_config,
        )

        image_bytes = _last_inline_image_bytes(response)
        if image_bytes is None:
            # Routed through the handler below so it is logged and reported
            # exactly like every other failure.
            raise ValueError("the model response did not contain an image")

        # Decode straight from memory: no temporary file is created, so
        # nothing is left behind on disk after the call.
        result_img = Image.open(io.BytesIO(image_bytes))
        result_img.load()  # force decoding now, while the buffer is in scope
        if result_img.mode == "RGBA":
            result_img = result_img.convert("RGB")

        return result_img, "์ด๋ฏธ์ง๊ฐ ์ฑ๊ณต์ ์ผ๋ก ์์ฑ๋์์ต๋๋ค."

    except Exception as e:
        # Top-level boundary for the UI callback: log the traceback and
        # surface the error text to the user instead of raising.
        logger.exception("์ด๋ฏธ์ง ์์ฑ ์ค ์ค๋ฅ ๋ฐ์:")
        return None, f"์ค๋ฅ ๋ฐ์: {e}"
|
|
|
|
|
# Gradio UI: three image inputs, a prompt box and a submit button feed
# process_images_with_prompt; its (image, message) result fills the outputs.
#
# NOTE(review): the string literals below look mis-encoded in the source and
# are kept as found. Two of them (the gr.HTML header and the Textbox
# placeholder) were split across physical lines, which is a syntax error;
# they are rejoined here — restore the text from a correctly encoded copy.
# NOTE(review): the Row/Column nesting is reconstructed from a source with no
# indentation (two columns in one row) — confirm the intended layout.
with gr.Blocks() as demo:
    gr.HTML("<h1>๊ฐ๋จํ ์ด๋ฏธ์ง ์์ฑ๊ธฐ</h1><p>์ด๋ฏธ์ง 3๊ฐ์ ํ๋กฌํํธ๋ฅผ ์๋ ฅํ์ธ์</p>")

    with gr.Row():
        # Input column: images, prompt and the trigger button.
        with gr.Column():
            image1_input = gr.Image(type="pil", label="์ด๋ฏธ์ง 1", image_mode="RGB")
            image2_input = gr.Image(type="pil", label="์ด๋ฏธ์ง 2", image_mode="RGB")
            image3_input = gr.Image(type="pil", label="์ด๋ฏธ์ง 3", image_mode="RGB")

            prompt_input = gr.Textbox(
                lines=3,
                placeholder="์ด ์ด๋ฏธ์ง๋ค์ ์ด๋ป๊ฒ ๋ณํํ ์ง ์ค๋ชํด์ฃผ์ธ์",
                label="ํ๋กฌํํธ",
            )

            submit_btn = gr.Button("์ด๋ฏธ์ง ์์ฑ")

        # Output column: generated image and the status/error message.
        with gr.Column():
            output_image = gr.Image(label="์์ฑ๋ ์ด๋ฏธ์ง")
            output_text = gr.Textbox(label="์ํ ๋ฉ์์ง")

    # Event wiring is done inside the Blocks context.
    submit_btn.click(
        fn=process_images_with_prompt,
        inputs=[image1_input, image2_input, image3_input, prompt_input],
        outputs=[output_image, output_text],
    )
|
|
|
|
|
# Start the Gradio server only when this file is executed as a script, not
# when it is imported. share=True is passed through to demo.launch().
if __name__ == "__main__":
    demo.launch(share=True)