andresampa committed on
Commit a58cf7d · verified · 1 Parent(s): 3e55389

Update app.py

Files changed (1)
  1. app.py +72 -182
app.py CHANGED
@@ -1,35 +1,21 @@
-
  import os
  import random
  from huggingface_hub import InferenceClient
  from PIL import Image
- from IPython.display import display, clear_output
- import ipywidgets as widgets
+ import gradio as gr
  from datetime import datetime

- # Retrieve the Hugging Face token from Colab secrets
- api_token = os.environ.get("HF_CTB_TOKEN")
+ # Retrieve the Hugging Face token from environment variables
+ api_token = os.getenv("HF_TOKEN")

  # List of models with aliases
  models = [
-     {
-         "alias": "FLUX.1-dev",
-         "name": "black-forest-labs/FLUX.1-dev"
-     },
-     {
-         "alias": "Stable Diffusion 3.5 turbo",
-         "name": "stabilityai/stable-diffusion-3.5-large-turbo"
-     },
-     {
-         "alias": "Midjourney",
-         "name": "strangerzonehf/Flux-Midjourney-Mix2-LoRA"
-     }
+     {"alias": "FLUX.1-dev", "name": "black-forest-labs/FLUX.1-dev"},
+     {"alias": "Stable Diffusion 3.5 turbo", "name": "stabilityai/stable-diffusion-3.5-large-turbo"},
+     {"alias": "Midjourney", "name": "strangerzonehf/Flux-Midjourney-Mix2-LoRA"}
  ]

- # Initialize the InferenceClient with the default model
- client = InferenceClient(models[0]["name"], token=api_token)
-
- # List of 10 prompts with intense combat
+ # List of prompts with intense combat
  prompts = [
      {
          "alias": "Castle Siege",
@@ -73,173 +59,77 @@ prompts = [
      }
  ]

- # Dropdown menu for model selection
- model_dropdown = widgets.Dropdown(
-     options=[(model["alias"], model["name"]) for model in models],
-     description="Select Model:",
-     style={"description_width": "initial"}
- )
-
- # Dropdown menu for prompt selection
- prompt_dropdown = widgets.Dropdown(
-     options=[(prompt["alias"], prompt["text"]) for prompt in prompts],
-     description="Select Prompt:",
-     style={"description_width": "initial"}
- )
-
- # Dropdown menu for team selection
- team_dropdown = widgets.Dropdown(
-     options=["Red", "Blue"],
-     description="Select Team:",
-     style={"description_width": "initial"}
- )
-
- # Input for height
- height_input = widgets.IntText(
-     value=360,
-     description="Height:",
-     style={"description_width": "initial"}
- )
-
- # Input for width
- width_input = widgets.IntText(
-     value=640,
-     description="Width:",
-     style={"description_width": "initial"}
- )
-
- # Input for number of inference steps
- num_inference_steps_input = widgets.IntSlider(
-     value=20,
-     min=10,
-     max=100,
-     step=1,
-     description="Inference Steps:",
-     style={"description_width": "initial"}
- )
-
- # Input for guidance scale
- guidance_scale_input = widgets.FloatSlider(
-     value=2,
-     min=1.0,
-     max=20.0,
-     step=0.5,
-     description="Guidance Scale:",
-     style={"description_width": "initial"}
- )
-
- # Input for seed
- seed_input = widgets.IntText(
-     value=random.randint(0, 1000000),
-     description="Seed:",
-     style={"description_width": "initial"}
- )
-
- # Checkbox to randomize seed
- randomize_seed_checkbox = widgets.Checkbox(
-     value=True,
-     description="Randomize Seed",
-     style={"description_width": "initial"}
- )
+ # Function to generate images
+ def generate_image(prompt_alias, team, model_alias, height, width, num_inference_steps, guidance_scale, seed):
+     # Find the selected prompt and model
+     prompt = next(p for p in prompts if p["alias"] == prompt_alias)["text"]
+     model_name = next(m for m in models if m["alias"] == model_alias)["name"]

- # Button to generate image
- generate_button = widgets.Button(
-     description="Generate Image",
-     button_style="success"
- )
-
- # Output area to display the image
- output = widgets.Output()
-
- # Function to generate images based on the selected prompt, team, and model
- def generate_image(prompt, team, model_name, height, width, num_inference_steps, guidance_scale, seed):
      # Determine the enemy color
      enemy_color = "blue" if team.lower() == "red" else "red"
-
-     # Replace {enemy_color} in the prompt
      prompt = prompt.format(enemy_color=enemy_color)
-
+
      if team.lower() == "red":
          prompt += " The winning army is dressed in red armor and banners."
      elif team.lower() == "blue":
          prompt += " The winning army is dressed in blue armor and banners."
-     else:
-         return "Invalid team selection. Please choose 'Red' or 'Blue'."
-
-     try:
-         # Randomize the seed if the checkbox is checked
-         if randomize_seed_checkbox.value:
-             seed = random.randint(0, 1000000)
-             seed_input.value = seed # Update the seed input box
-
-         print(f"Using seed: {seed}")
-
-         # Debug: Indicate that the image is being generated
-         print("Generating image... Please wait.")
-
-         # Initialize the InferenceClient with the selected model
-         client = InferenceClient(model_name, token=api_token)
-
-         # Generate the image using the Inference API with parameters
-         image = client.text_to_image(
-             prompt,
-             guidance_scale=guidance_scale, # Guidance scale
-             num_inference_steps=num_inference_steps, # Number of inference steps
-             width=width, # Width
-             height=height, # Height
-             seed=seed # Random seed
-         )
-         return image
-     except Exception as e:
-         return f"An error occurred: {e}"
-
- # Function to handle button click event
- def on_generate_button_clicked(b):
-     with output:
-         clear_output(wait=True) # Clear previous output
-         selected_prompt = prompt_dropdown.value
-         selected_team = team_dropdown.value
-         selected_model = model_dropdown.value
-         height = height_input.value
-         width = width_input.value
-         num_inference_steps = num_inference_steps_input.value
-         guidance_scale = guidance_scale_input.value
-         seed = seed_input.value
-
-         # Debug: Show selected parameters
-         print(f"Selected Model: {model_dropdown.label}")
-         print(f"Selected Prompt: {prompt_dropdown.label}")
-         print(f"Selected Team: {selected_team}")
-         print(f"Height: {height}")
-         print(f"Width: {width}")
-         print(f"Inference Steps: {num_inference_steps}")
-         print(f"Guidance Scale: {guidance_scale}")
-         print(f"Seed: {seed}")
-
-         # Generate the image
-         image = generate_image(selected_prompt, selected_team, selected_model, height, width, num_inference_steps, guidance_scale, seed)
-
-         if isinstance(image, str):
-             print(image)
-         else:
-             # Debug: Indicate that the image is being displayed and saved
-             print("Image generated successfully!")
-             print("Displaying image...")
-
-             # Display the image in the notebook
-             display(image)
-
-             # Save the image with a timestamped filename
-             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-             output_filename = f"{timestamp}_{model_dropdown.label.replace(' ', '_').lower()}_{prompt_dropdown.label.replace(' ', '_').lower()}_{selected_team.lower()}.png"
-             print(f"Saving image as {output_filename}...")
-             image.save(output_filename)
-             print(f"Image saved as {output_filename}")
-
- # Attach the button click event handler
- generate_button.on_click(on_generate_button_clicked)
-
- # Display the widgets
- #display(model_dropdown, prompt_dropdown, team_dropdown, height_input, width_input, num_inference_steps_input, guidance_scale_input, seed_input, randomize_seed_checkbox, generate_button, output)

- display(prompt_dropdown, team_dropdown, generate_button, output)
+     # Randomize the seed if needed
+     if seed == -1:
+         seed = random.randint(0, 1000000)
+
+     # Initialize the InferenceClient
+     client = InferenceClient(model_name, token=api_token)
+
+     # Generate the image
+     image = client.text_to_image(
+         prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         seed=seed
+     )
+
+     # Save the image with a timestamped filename
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     output_filename = f"{timestamp}_{model_alias.replace(' ', '_').lower()}_{prompt_alias.replace(' ', '_').lower()}_{team.lower()}.png"
+     image.save(output_filename)
+
+     return output_filename
+
+ # Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# CtB AI Image Generator")
+     with gr.Row():
+         prompt_dropdown = gr.Dropdown(choices=[p["alias"] for p in prompts], label="Select Prompt")
+         team_dropdown = gr.Dropdown(choices=["Red", "Blue"], label="Select Team")
+         model_dropdown = gr.Dropdown(choices=[m["alias"] for m in models], label="Select Model")
+     with gr.Row():
+         height_input = gr.Number(value=360, label="Height")
+         width_input = gr.Number(value=640, label="Width")
+         num_inference_steps_input = gr.Slider(minimum=10, maximum=100, value=20, label="Inference Steps")
+         guidance_scale_input = gr.Slider(minimum=1.0, maximum=20.0, value=2.0, step=0.5, label="Guidance Scale")
+         seed_input = gr.Number(value=-1, label="Seed (-1 for random)")
+     with gr.Row():
+         generate_button = gr.Button("Generate Image")
+     with gr.Row():
+         output_image = gr.Image(label="Generated Image")
+
+     # Function to handle button click
+     def generate(prompt_alias, team, model_alias, height, width, num_inference_steps, guidance_scale, seed):
+         try:
+             image_path = generate_image(prompt_alias, team, model_alias, height, width, num_inference_steps, guidance_scale, seed)
+             return image_path
+         except Exception as e:
+             return f"An error occurred: {e}"
+
+     # Connect the button to the function
+     generate_button.click(
+         generate,
+         inputs=[prompt_dropdown, team_dropdown, model_dropdown, height_input, width_input, num_inference_steps_input, guidance_scale_input, seed_input],
+         outputs=output_image
+     )
+
+ # Launch the Gradio app
+ demo.launch()
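
For context on the refactor, the sketch below shows one way the new generate_image function could be exercised directly, without going through the Gradio UI. It is an illustration only, not part of the commit: it assumes HF_TOKEN is exported in the environment and that the generate_image, prompts, and models definitions from the new app.py are already in scope (importing app.py as committed would also run demo.launch() at module level, since there is no __main__ guard).

# Illustrative sketch only; values are the defaults from the committed app.py.
# Assumes HF_TOKEN is set so InferenceClient can authenticate, and that
# generate_image / prompts / models are in scope.
image_path = generate_image(
    prompt_alias="Castle Siege",   # an alias defined in the prompts list
    team="Red",                    # "Red" or "Blue"
    model_alias="FLUX.1-dev",      # an alias defined in the models list
    height=360,
    width=640,
    num_inference_steps=20,
    guidance_scale=2.0,
    seed=-1,                       # -1 makes generate_image pick a random seed
)
print(f"Image saved to {image_path}")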