Andre commited on
Commit
ca8223a
·
1 Parent(s): 1c2cfc0
Files changed (2) hide show
  1. colab.ipynb +13 -13
  2. img_gen_logic_colab.py +4 -6
colab.ipynb CHANGED
@@ -13,12 +13,11 @@
13
  "import ipywidgets as widgets\n",
14
  "from huggingface_hub import InferenceClient\n",
15
  "from IPython.display import display, clear_output\n",
16
- "from img_gen_logic import generate_image\n",
17
  "from config_colab import models, prompts, api_token\n",
18
  "from PIL import Image\n",
19
  "from google.colab import userdata\n",
20
  "from datetime import datetime\n",
21
- "from img_gen_logic_colab import generate_image, save_image\n",
22
  "\n",
23
  "# Initialize the InferenceClient with the default model\n",
24
  "client = InferenceClient(models[0][\"name\"], token=api_token)\n",
@@ -110,6 +109,7 @@
110
  "# Output area to display the image\n",
111
  "output = widgets.Output()\n",
112
  "\n",
 
113
  "def on_generate_button_clicked(b):\n",
114
  " with output:\n",
115
  " clear_output(wait=True) # Clear previous output\n",
@@ -142,30 +142,30 @@
142
  " print(\"=== Debug: Calling generate_image ===\")\n",
143
  " image, message = generate_image(\n",
144
  " selected_prompt, selected_team, selected_model, height, width,\n",
145
- " num_inference_steps, guidance_scale, seed, custom_prompt, api_token\n",
 
146
  " )\n",
147
  "\n",
148
  " # Debug: Check the output of generate_image\n",
149
  " print(\"=== Debug: generate_image Output ===\")\n",
150
- " print(f\"Image: {image}\")\n",
151
  " print(f\"Message: {message}\")\n",
152
  " print(\"====================================\")\n",
153
  "\n",
154
- " if isinstance(image, str):\n",
155
- " print(\"=== Debug: Error ===\")\n",
156
- " print(image)\n",
157
- " else:\n",
158
  " # Debug: Indicate that the image is being displayed and saved\n",
159
  " print(\"=== Debug: Image Generation ===\")\n",
160
  " print(\"Image generated successfully!\")\n",
161
  " print(\"Displaying image...\")\n",
162
  "\n",
163
  " # Display the image in the notebook\n",
164
- " if image is not None:\n",
165
- " display(image)\n",
166
- " else:\n",
167
- " print(\"=== Debug: Error ===\")\n",
168
- " print(\"No image was returned by generate_image.\")\n",
 
 
 
169
  "\n",
170
  "# Attach the button click event handler\n",
171
  "generate_button.on_click(on_generate_button_clicked)\n",
 
13
  "import ipywidgets as widgets\n",
14
  "from huggingface_hub import InferenceClient\n",
15
  "from IPython.display import display, clear_output\n",
16
+ "from img_gen_logic_colab import generate_image, save_image\n",
17
  "from config_colab import models, prompts, api_token\n",
18
  "from PIL import Image\n",
19
  "from google.colab import userdata\n",
20
  "from datetime import datetime\n",
 
21
  "\n",
22
  "# Initialize the InferenceClient with the default model\n",
23
  "client = InferenceClient(models[0][\"name\"], token=api_token)\n",
 
109
  "# Output area to display the image\n",
110
  "output = widgets.Output()\n",
111
  "\n",
112
+ "# Function to handle button click event\n",
113
  "def on_generate_button_clicked(b):\n",
114
  " with output:\n",
115
  " clear_output(wait=True) # Clear previous output\n",
 
142
  " print(\"=== Debug: Calling generate_image ===\")\n",
143
  " image, message = generate_image(\n",
144
  " selected_prompt, selected_team, selected_model, height, width,\n",
145
+ " num_inference_steps, guidance_scale, seed, custom_prompt, api_token,\n",
146
+ " randomize_seed=randomize_seed_checkbox.value\n",
147
  " )\n",
148
  "\n",
149
  " # Debug: Check the output of generate_image\n",
150
  " print(\"=== Debug: generate_image Output ===\")\n",
 
151
  " print(f\"Message: {message}\")\n",
152
  " print(\"====================================\")\n",
153
  "\n",
154
+ " if image is not None:\n",
 
 
 
155
  " # Debug: Indicate that the image is being displayed and saved\n",
156
  " print(\"=== Debug: Image Generation ===\")\n",
157
  " print(\"Image generated successfully!\")\n",
158
  " print(\"Displaying image...\")\n",
159
  "\n",
160
  " # Display the image in the notebook\n",
161
+ " display(image)\n",
162
+ "\n",
163
+ " # Save the image with a timestamped filename\n",
164
+ " output_filename = save_image(image, model_dropdown.label, prompt_dropdown.label, selected_team)\n",
165
+ " print(f\"Image saved as {output_filename}\")\n",
166
+ " else:\n",
167
+ " print(\"=== Debug: Error ===\")\n",
168
+ " print(message)\n",
169
  "\n",
170
  "# Attach the button click event handler\n",
171
  "generate_button.on_click(on_generate_button_clicked)\n",
img_gen_logic_colab.py CHANGED
@@ -4,7 +4,6 @@ from PIL import Image
4
  import random
5
  from datetime import datetime
6
 
7
- # Function to generate images based on the selected prompt, team, and model
8
  def generate_image(prompt, team, model_name, height, width, num_inference_steps, guidance_scale, seed, custom_prompt, api_token, randomize_seed=True):
9
  """
10
  Generate an image using the Hugging Face Inference API.
@@ -23,7 +22,7 @@ def generate_image(prompt, team, model_name, height, width, num_inference_steps,
23
  randomize_seed (bool): Whether to randomize the seed.
24
 
25
  Returns:
26
- PIL.Image.Image or str: The generated image or an error message.
27
  """
28
  # Determine the enemy color
29
  enemy_color = "blue" if team.lower() == "red" else "red"
@@ -37,7 +36,7 @@ def generate_image(prompt, team, model_name, height, width, num_inference_steps,
37
  elif team.lower() == "blue":
38
  prompt += " The winning army is dressed in blue armor and banners."
39
  else:
40
- return "Invalid team selection. Please choose 'Red' or 'Blue'."
41
 
42
  # Append the custom prompt if provided
43
  if custom_prompt.strip():
@@ -65,11 +64,10 @@ def generate_image(prompt, team, model_name, height, width, num_inference_steps,
65
  height=height, # Height
66
  seed=seed # Random seed
67
  )
68
- return image
69
  except Exception as e:
70
- return f"An error occurred: {e}"
71
 
72
- # Function to save the generated image with a timestamped filename
73
  def save_image(image, model_label, prompt_label, team):
74
  """
75
  Save the generated image with a timestamped filename.
 
4
  import random
5
  from datetime import datetime
6
 
 
7
  def generate_image(prompt, team, model_name, height, width, num_inference_steps, guidance_scale, seed, custom_prompt, api_token, randomize_seed=True):
8
  """
9
  Generate an image using the Hugging Face Inference API.
 
22
  randomize_seed (bool): Whether to randomize the seed.
23
 
24
  Returns:
25
+ tuple: A tuple containing the generated image (PIL.Image.Image) and a message (str).
26
  """
27
  # Determine the enemy color
28
  enemy_color = "blue" if team.lower() == "red" else "red"
 
36
  elif team.lower() == "blue":
37
  prompt += " The winning army is dressed in blue armor and banners."
38
  else:
39
+ return None, "Invalid team selection. Please choose 'Red' or 'Blue'."
40
 
41
  # Append the custom prompt if provided
42
  if custom_prompt.strip():
 
64
  height=height, # Height
65
  seed=seed # Random seed
66
  )
67
+ return image, "Image generated successfully."
68
  except Exception as e:
69
+ return None, f"An error occurred: {e}"
70
 
 
71
  def save_image(image, model_label, prompt_label, team):
72
  """
73
  Save the generated image with a timestamped filename.