saq1b commited on
Commit
3566679
·
verified ·
1 Parent(s): 6849ccd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +193 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import base64
3
+ import io
4
+ from PIL import Image
5
+ import json
6
+ import os
7
+ import asyncio
8
+ from google import genai
9
+ from google.genai import types
10
+
11
+ # Function to convert PIL Image to bytes
12
+ def pil_to_bytes(img, format="PNG"):
13
+ img_byte_arr = io.BytesIO()
14
+ img.save(img_byte_arr, format=format)
15
+ return img_byte_arr.getvalue()
16
+
17
+ # Function to save API key in browser local storage
18
+ def save_api_key(api_key):
19
+ return api_key
20
+
21
+ # Function to load image as base64
22
+ async def load_image_base64(img):
23
+ if isinstance(img, str):
24
+ # If image is a URL or file path, load it
25
+ raise ValueError("URL loading not implemented in this version")
26
+ else:
27
+ # If image is already a PIL Image
28
+ return pil_to_bytes(img)
29
+
30
+ # Main function to generate edited image using Gemini
31
+ async def generate_image_gemini(prompt, image, api_key):
32
+ SAFETY_SETTINGS = {
33
+ types.HarmCategory.HARM_CATEGORY_HARASSMENT: types.HarmBlockThreshold.BLOCK_NONE,
34
+ types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: types.HarmBlockThreshold.BLOCK_NONE,
35
+ types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: types.HarmBlockThreshold.BLOCK_NONE,
36
+ types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: types.HarmBlockThreshold.BLOCK_NONE,
37
+ }
38
+
39
+ try:
40
+ # Initialize Gemini client with API key
41
+ client = genai.Client(api_key=api_key)
42
+
43
+ # Convert PIL image to bytes
44
+ image_bytes = await load_image_base64(image)
45
+
46
+ contents = []
47
+
48
+ # Add the image to the contents
49
+ contents.append(
50
+ types.Content(
51
+ role="user",
52
+ parts=[
53
+ types.Part.from_bytes(
54
+ data=image_bytes,
55
+ mime_type="image/png",
56
+ )
57
+ ],
58
+ )
59
+ )
60
+
61
+ # Add the prompt to the contents
62
+ edit_prompt = f"Edit this image: {prompt}"
63
+ contents.append(
64
+ types.Content(
65
+ role="user",
66
+ parts=[
67
+ types.Part.from_text(text=edit_prompt),
68
+ ],
69
+ )
70
+ )
71
+
72
+ response = await client.aio.models.generate_content(
73
+ model="gemini-2.0-flash-exp",
74
+ contents=contents,
75
+ config=types.GenerateContentConfig(
76
+ safety_settings=[
77
+ types.SafetySetting(
78
+ category=category, threshold=threshold
79
+ ) for category, threshold in SAFETY_SETTINGS.items()
80
+ ],
81
+ response_modalities=['Text', 'Image']
82
+ )
83
+ )
84
+
85
+ edited_images = []
86
+ for part in response.candidates[0].content.parts:
87
+ if part.inline_data is not None:
88
+ image_bytes = part.inline_data.data
89
+ edited_images.append(image_bytes)
90
+
91
+ # Convert the first returned image bytes to PIL image
92
+ if edited_images:
93
+ result_image = Image.open(io.BytesIO(edited_images[0]))
94
+ return result_image
95
+ else:
96
+ return None
97
+
98
+ except Exception as e:
99
+ print(f"Google GenAI client failed with error: {e}")
100
+ return None
101
+
102
+ # Function to process the image edit
103
+ def process_image_edit(image, prompt, api_key, image_history):
104
+ if not image or not prompt or not api_key:
105
+ return None, image_history, "Please provide an image, prompt, and API key"
106
+
107
+ # Store current image in history if not empty
108
+ if image is not None and image_history is None:
109
+ image_history = []
110
+ if image is not None:
111
+ image_history.append(image)
112
+
113
+ # Run the async function to edit the image
114
+ try:
115
+ edited_image = asyncio.run(generate_image_gemini(prompt, image, api_key))
116
+ if edited_image:
117
+ return edited_image, image_history, "Image edited successfully"
118
+ else:
119
+ return image, image_history, "Failed to edit image. Please try again."
120
+ except Exception as e:
121
+ return image, image_history, f"Error: {str(e)}"
122
+
123
+ # Function to undo the last edit
124
+ def undo_edit(image_history):
125
+ if image_history and len(image_history) > 1:
126
+ # Remove current image
127
+ image_history.pop()
128
+ # Return the previous image
129
+ return image_history[-1], image_history, "Reverted to previous image"
130
+ else:
131
+ return None, [], "No previous version available"
132
+
133
+ # Create Gradio UI
134
+ def create_ui():
135
+ with gr.Blocks(title="Gemini Image Editor") as app:
136
+ gr.Markdown("# Gemini Image Editor")
137
+ gr.Markdown("Upload an image, enter a description of the edit you want, and let Gemini do the rest!")
138
+
139
+ # Store image history in state
140
+ image_history = gr.State([])
141
+
142
+ with gr.Row():
143
+ with gr.Column():
144
+ input_image = gr.Image(type="pil", label="Upload Image")
145
+ prompt = gr.Textbox(label="Edit Description", placeholder="Describe the edit you want...")
146
+ api_key = gr.Textbox(label="Gemini API Key", placeholder="Enter your Gemini API key", type="password")
147
+ save_key = gr.Checkbox(label="Save API key in browser", value=True)
148
+
149
+ with gr.Row():
150
+ edit_btn = gr.Button("Edit Image")
151
+ undo_btn = gr.Button("Undo Last Edit")
152
+
153
+ with gr.Column():
154
+ output_image = gr.Image(type="pil", label="Edited Image")
155
+ status = gr.Textbox(label="Status", interactive=False)
156
+
157
+ # Set up event handlers
158
+ edit_btn.click(
159
+ fn=process_image_edit,
160
+ inputs=[input_image, prompt, api_key, image_history],
161
+ outputs=[output_image, image_history, status]
162
+ )
163
+
164
+ undo_btn.click(
165
+ fn=undo_edit,
166
+ inputs=[image_history],
167
+ outputs=[output_image, image_history, status]
168
+ )
169
+
170
+ # JavaScript for saving API key in local storage
171
+ app.load(None, None, None, _js="""
172
+ function() {
173
+ // Try to load saved API key from localStorage
174
+ const savedKey = localStorage.getItem('gemini_api_key');
175
+ if (savedKey) {
176
+ document.querySelector('input[data-testid="textbox"]#api_key').value = savedKey;
177
+ }
178
+
179
+ // Add event listener to save API key
180
+ document.querySelector('input[data-testid="textbox"]#api_key').addEventListener('change', function(e) {
181
+ if (document.querySelector('input[data-testid="checkbox"]#save_key').checked) {
182
+ localStorage.setItem('gemini_api_key', e.target.value);
183
+ }
184
+ });
185
+ }
186
+ """)
187
+
188
+ return app
189
+
190
+ # Launch the app
191
+ if __name__ == "__main__":
192
+ app = create_ui()
193
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ google-genai
3
+ Pillow
4
+ asyncio