File size: 4,645 Bytes
8d852cf
 
 
 
 
 
ddcfe75
 
8d852cf
77a62b5
e0e8bcb
 
 
77a62b5
 
8d852cf
 
 
 
 
ddcfe75
77a62b5
 
 
 
8d852cf
77a62b5
 
 
 
 
 
 
 
 
0ae66bb
77a62b5
1e9bbd5
77a62b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd7e6af
1e9bbd5
8d852cf
 
77a62b5
8d852cf
77a62b5
 
8d852cf
 
77a62b5
 
 
 
 
8d852cf
 
1e9bbd5
77a62b5
 
 
 
fd7e6af
77a62b5
 
 
 
fd7e6af
77a62b5
 
8d852cf
e0e8bcb
77a62b5
ddcfe75
77a62b5
5ff01f8
77a62b5
 
5ff01f8
 
77a62b5
 
 
 
5ff01f8
77a62b5
5ff01f8
 
77a62b5
 
5ff01f8
77a62b5
 
5ff01f8
 
 
77a62b5
 
 
 
 
5ff01f8
77a62b5
 
 
5ff01f8
 
77a62b5
5ff01f8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import io
import logging
import os
import tempfile

import gradio as gr
from PIL import Image

from google import genai
from google.genai import types

# ํ™˜๊ฒฝ๋ณ€์ˆ˜ ๋กœ๋“œ
from dotenv import load_dotenv
load_dotenv()

# ๊ฐ„๋‹จํ•œ ๋กœ๊น… ์„ค์ •
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def save_binary_file(file_name, data):
    """Write the bytes ``data`` to ``file_name``, truncating any existing file.

    Returns ``None``.
    """
    with open(file_name, mode="wb") as handle:
        handle.write(data)

def _image_to_png_bytes(img):
    """Encode a PIL image as PNG entirely in memory and return the bytes."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()


def _extract_image_and_text(response):
    """Return ``(image_bytes, text)`` from a ``generate_content`` response.

    ``image_bytes`` is the payload of the last part carrying inline data (the
    previous implementation overwrote its output for each such part, so the
    last one won), or ``None`` if no such part exists. ``text`` joins all text
    parts with newlines, or is ``""``. Missing candidates/content/parts are
    treated as an empty response instead of raising.
    """
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return None, ""
    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None) or []

    image_bytes = None
    texts = []
    for part in parts:
        inline = getattr(part, "inline_data", None)
        if inline is not None and inline.data:
            image_bytes = inline.data
        elif getattr(part, "text", None):
            texts.append(part.text)
    return image_bytes, "\n".join(texts)


def process_images_with_prompt(image1, image2, image3, prompt):
    """Generate a new image from up to three input images and a text prompt.

    Args:
        image1, image2, image3: PIL images from the Gradio inputs, or ``None``
            for empty slots (``None`` slots are skipped).
        prompt: Instruction text. A default Korean prompt is used when it is
            empty or whitespace-only.

    Returns:
        ``(image, status)``: ``image`` is the generated PIL image (RGBA is
        converted to RGB) or ``None`` on failure; ``status`` is a user-facing
        message. Errors are logged and reported via ``status``, never raised.
    """
    try:
        # The key comes from the environment (optionally populated from .env).
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            return None, "API 키가 설정되지 않았습니다. 환경변수를 확인해주세요."

        client = genai.Client(api_key=api_key)

        # Fall back to a generic instruction when no prompt was typed.
        if not prompt or not prompt.strip():
            prompt = "이 이미지들을 활용하여 새로운 이미지를 생성해주세요."

        # Prompt first, then every provided image as an inline PNG part.
        # Encoding goes through BytesIO: no temp files to leak, and no
        # reopen-by-name of a still-open NamedTemporaryFile (fails on Windows).
        parts = [types.Part.from_text(text=prompt)]
        parts.extend(
            types.Part.from_bytes(data=_image_to_png_bytes(img), mime_type="image/png")
            for img in (image1, image2, image3)
            if img is not None
        )

        # This model does not support image-only output; it must be allowed
        # to return text alongside the image.
        generate_content_config = types.GenerateContentConfig(
            temperature=1,
            response_modalities=["TEXT", "IMAGE"],
        )

        response = client.models.generate_content(
            model="gemini-2.0-flash-exp-image-generation",
            contents=[types.Content(role="user", parts=parts)],
            config=generate_content_config,
        )

        image_bytes, response_text = _extract_image_and_text(response)
        if image_bytes is None:
            # Surface any text the model returned (e.g. a refusal) instead of
            # failing later with an opaque "cannot identify image" error.
            detail = "모델 응답에 이미지가 없습니다."
            if response_text:
                detail += f" ({response_text})"
            logger.warning("No image in model response: %s", response_text)
            return None, f"오류 발생: {detail}"

        result_img = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so corrupt data fails inside this try.
        result_img.load()
        if result_img.mode == "RGBA":
            result_img = result_img.convert("RGB")

        return result_img, "이미지가 성공적으로 생성되었습니다."

    except Exception as e:
        # UI boundary: log the traceback and report the error to the user.
        logger.exception("이미지 생성 중 오류 발생:")
        return None, f"오류 발생: {str(e)}"

# Minimal Gradio interface: three optional image inputs and a prompt on the
# left; the generated image and a status message on the right.
with gr.Blocks() as demo:
    gr.HTML("<h1>간단한 이미지 생성기</h1><p>이미지 3개와 프롬프트를 입력하세요</p>")

    with gr.Row():
        with gr.Column():
            # Three image inputs, passed to the handler as PIL images in RGB mode.
            image1_input = gr.Image(type="pil", label="이미지 1", image_mode="RGB")
            image2_input = gr.Image(type="pil", label="이미지 2", image_mode="RGB")
            image3_input = gr.Image(type="pil", label="이미지 3", image_mode="RGB")

            # Free-text prompt; process_images_with_prompt substitutes a
            # default instruction when it is left blank.
            prompt_input = gr.Textbox(
                lines=3,
                placeholder="이 이미지들을 어떻게 변환할지 설명해주세요",
                label="프롬프트",
            )

            submit_btn = gr.Button("이미지 생성")

        with gr.Column():
            # Outputs: generated image plus a status/error message.
            output_image = gr.Image(label="생성된 이미지")
            output_text = gr.Textbox(label="상태 메시지")

    # The handler returns (image, status), mapped onto the two outputs in order.
    submit_btn.click(
        fn=process_images_with_prompt,
        inputs=[image1_input, image2_input, image3_input, prompt_input],
        outputs=[output_image, output_text],
    )

# Run the application when executed as a script.
if __name__ == "__main__":
    # NOTE(review): share=True creates a publicly reachable share link, so
    # anyone with the URL can trigger generations billed to GEMINI_API_KEY.
    demo.launch(share=True)