Andre committed on
Commit
22d47f1
·
1 Parent(s): e1fd06b
Files changed (2) hide show
  1. colab.ipynb +5 -0
  2. img_gen_logic_colab.py +89 -0
colab.ipynb CHANGED
@@ -8,12 +8,17 @@
8
  "source": [
9
  "# colab.ipynb\n",
10
  "# Import necessary libraries\n",
 
11
  "import random\n",
12
  "import ipywidgets as widgets\n",
13
  "from huggingface_hub import InferenceClient\n",
14
  "from IPython.display import display, clear_output\n",
15
  "from img_gen_logic import generate_image\n",
16
  "from config_colab import models, prompts, api_token\n",
 
 
 
 
17
  "\n",
18
  "# Initialize the InferenceClient with the default model\n",
19
  "client = InferenceClient(models[0][\"name\"], token=api_token)\n",
 
8
  "source": [
9
  "# colab.ipynb\n",
10
  "# Import necessary libraries\n",
11
+ "import os\n",
12
  "import random\n",
13
  "import ipywidgets as widgets\n",
14
  "from huggingface_hub import InferenceClient\n",
15
  "from IPython.display import display, clear_output\n",
16
  "from img_gen_logic import generate_image\n",
17
  "from config_colab import models, prompts, api_token\n",
18
+ "from PIL import Image\n",
19
+ "from google.colab import userdata\n",
20
+ "from datetime import datetime\n",
21
+ "from img_gen_logic_colab import generate_image, save_image\n",
22
  "\n",
23
  "# Initialize the InferenceClient with the default model\n",
24
  "client = InferenceClient(models[0][\"name\"], token=api_token)\n",
img_gen_logic_colab.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # img_gen_logic_colab.py
2
+ from huggingface_hub import InferenceClient
3
+ from PIL import Image
4
+ import random
5
+ from datetime import datetime
6
+
7
# Function to generate images based on the selected prompt, team, and model
def generate_image(prompt, team, model_name, height, width, num_inference_steps, guidance_scale, seed, custom_prompt, api_token, randomize_seed=True):
    """
    Generate an image using the Hugging Face Inference API.

    Args:
        prompt (str): The base prompt for image generation; may contain an
            ``{enemy_color}`` placeholder that is filled in from the team.
        team (str): The selected team ("Red" or "Blue"), case-insensitive.
        model_name (str): The name of the model to use.
        height (int): The height of the generated image.
        width (int): The width of the generated image.
        num_inference_steps (int): The number of inference steps.
        guidance_scale (float): The guidance scale for the model.
        seed (int): The seed for random generation (ignored when
            ``randomize_seed`` is True).
        custom_prompt (str or None): Additional custom prompt text; empty or
            None values are ignored.
        api_token (str): The Hugging Face API token.
        randomize_seed (bool): Whether to replace ``seed`` with a random one.

    Returns:
        PIL.Image.Image or str: The generated image on success, otherwise a
        human-readable error message string (this function never raises).
    """
    # Validate the team FIRST: the original code formatted the prompt before
    # this check, so an invalid team could crash on str.format instead of
    # returning the intended error message.
    team_key = team.lower()
    if team_key not in ("red", "blue"):
        return "Invalid team selection. Please choose 'Red' or 'Blue'."

    # Determine the enemy color and substitute it into the prompt template.
    enemy_color = "blue" if team_key == "red" else "red"
    prompt = prompt.format(enemy_color=enemy_color)

    # Add team-specific details to the prompt (team_key is "red" or "blue",
    # matching the literal strings used by the original branches).
    prompt += f" The winning army is dressed in {team_key} armor and banners."

    # Append the custom prompt if provided; guard against None so callers
    # that pass no custom prompt don't trigger an AttributeError.
    if custom_prompt and custom_prompt.strip():
        prompt += " " + custom_prompt.strip()

    try:
        # Randomize the seed if requested (mirrors the UI checkbox).
        if randomize_seed:
            seed = random.randint(0, 1000000)

        print(f"Using seed: {seed}")

        # Debug: Indicate that the image is being generated
        print("Generating image... Please wait.")

        # Initialize the InferenceClient with the selected model.
        client = InferenceClient(model_name, token=api_token)

        # Generate the image using the Inference API with parameters.
        image = client.text_to_image(
            prompt,
            guidance_scale=guidance_scale,        # Guidance scale
            num_inference_steps=num_inference_steps,  # Number of inference steps
            width=width,                          # Width
            height=height,                        # Height
            seed=seed,                            # Random seed
        )
        return image
    except Exception as e:
        # Best-effort API: surface failures as a string, never an exception.
        return f"An error occurred: {e}"
71
+
72
# Function to save the generated image with a timestamped filename
def save_image(image, model_label, prompt_label, team):
    """
    Save the generated image under a timestamped, slug-style filename.

    The filename has the shape
    ``YYYYMMDD_HHMMSS_<model>_<prompt>_<team>.png`` where each label is
    lowercased with spaces replaced by underscores.

    Args:
        image (PIL.Image.Image): The generated image.
        model_label (str): The label of the selected model.
        prompt_label (str): The label of the selected prompt.
        team (str): The selected team ("Red" or "Blue").

    Returns:
        str: The filename of the saved image.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Turn each human-readable label into a filename-friendly slug.
    slugs = [label.replace(" ", "_").lower() for label in (model_label, prompt_label)]
    filename = "_".join([stamp, *slugs, team.lower()]) + ".png"
    image.save(filename)
    return filename