Spaces:
Sleeping
Sleeping
Andre
commited on
Commit
·
22d47f1
1
Parent(s):
e1fd06b
Update
Browse files- colab.ipynb +5 -0
- img_gen_logic_colab.py +89 -0
colab.ipynb
CHANGED
@@ -8,12 +8,17 @@
|
|
8 |
"source": [
|
9 |
"# colab.ipynb\n",
|
10 |
"# Import necessary libraries\n",
|
|
|
11 |
"import random\n",
|
12 |
"import ipywidgets as widgets\n",
|
13 |
"from huggingface_hub import InferenceClient\n",
|
14 |
"from IPython.display import display, clear_output\n",
|
15 |
"from img_gen_logic import generate_image\n",
|
16 |
"from config_colab import models, prompts, api_token\n",
|
|
|
|
|
|
|
|
|
17 |
"\n",
|
18 |
"# Initialize the InferenceClient with the default model\n",
|
19 |
"client = InferenceClient(models[0][\"name\"], token=api_token)\n",
|
|
|
8 |
"source": [
|
9 |
"# colab.ipynb\n",
|
10 |
"# Import necessary libraries\n",
|
11 |
+
"import os\n",
|
12 |
"import random\n",
|
13 |
"import ipywidgets as widgets\n",
|
14 |
"from huggingface_hub import InferenceClient\n",
|
15 |
"from IPython.display import display, clear_output\n",
|
16 |
"from img_gen_logic import generate_image\n",
|
17 |
"from config_colab import models, prompts, api_token\n",
|
18 |
+
"from PIL import Image\n",
|
19 |
+
"from google.colab import userdata\n",
|
20 |
+
"from datetime import datetime\n",
|
21 |
+
"from img_gen_logic_colab import generate_image, save_image\n",
|
22 |
"\n",
|
23 |
"# Initialize the InferenceClient with the default model\n",
|
24 |
"client = InferenceClient(models[0][\"name\"], token=api_token)\n",
|
img_gen_logic_colab.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# img_gen_logic_colab.py
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
+
from PIL import Image
|
4 |
+
import random
|
5 |
+
from datetime import datetime
|
6 |
+
|
7 |
+
# Function to generate images based on the selected prompt, team, and model
|
8 |
+
def generate_image(prompt, team, model_name, height, width, num_inference_steps, guidance_scale, seed, custom_prompt, api_token, randomize_seed=True):
|
9 |
+
"""
|
10 |
+
Generate an image using the Hugging Face Inference API.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
prompt (str): The base prompt for image generation.
|
14 |
+
team (str): The selected team ("Red" or "Blue").
|
15 |
+
model_name (str): The name of the model to use.
|
16 |
+
height (int): The height of the generated image.
|
17 |
+
width (int): The width of the generated image.
|
18 |
+
num_inference_steps (int): The number of inference steps.
|
19 |
+
guidance_scale (float): The guidance scale for the model.
|
20 |
+
seed (int): The seed for random generation.
|
21 |
+
custom_prompt (str): Additional custom prompt text.
|
22 |
+
api_token (str): The Hugging Face API token.
|
23 |
+
randomize_seed (bool): Whether to randomize the seed.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
PIL.Image.Image or str: The generated image or an error message.
|
27 |
+
"""
|
28 |
+
# Determine the enemy color
|
29 |
+
enemy_color = "blue" if team.lower() == "red" else "red"
|
30 |
+
|
31 |
+
# Replace {enemy_color} in the prompt
|
32 |
+
prompt = prompt.format(enemy_color=enemy_color)
|
33 |
+
|
34 |
+
# Add team-specific details to the prompt
|
35 |
+
if team.lower() == "red":
|
36 |
+
prompt += " The winning army is dressed in red armor and banners."
|
37 |
+
elif team.lower() == "blue":
|
38 |
+
prompt += " The winning army is dressed in blue armor and banners."
|
39 |
+
else:
|
40 |
+
return "Invalid team selection. Please choose 'Red' or 'Blue'."
|
41 |
+
|
42 |
+
# Append the custom prompt if provided
|
43 |
+
if custom_prompt.strip():
|
44 |
+
prompt += " " + custom_prompt.strip()
|
45 |
+
|
46 |
+
try:
|
47 |
+
# Randomize the seed if the checkbox is checked
|
48 |
+
if randomize_seed:
|
49 |
+
seed = random.randint(0, 1000000)
|
50 |
+
|
51 |
+
print(f"Using seed: {seed}")
|
52 |
+
|
53 |
+
# Debug: Indicate that the image is being generated
|
54 |
+
print("Generating image... Please wait.")
|
55 |
+
|
56 |
+
# Initialize the InferenceClient with the selected model
|
57 |
+
client = InferenceClient(model_name, token=api_token)
|
58 |
+
|
59 |
+
# Generate the image using the Inference API with parameters
|
60 |
+
image = client.text_to_image(
|
61 |
+
prompt,
|
62 |
+
guidance_scale=guidance_scale, # Guidance scale
|
63 |
+
num_inference_steps=num_inference_steps, # Number of inference steps
|
64 |
+
width=width, # Width
|
65 |
+
height=height, # Height
|
66 |
+
seed=seed # Random seed
|
67 |
+
)
|
68 |
+
return image
|
69 |
+
except Exception as e:
|
70 |
+
return f"An error occurred: {e}"
|
71 |
+
|
72 |
+
# Function to save the generated image with a timestamped filename
|
73 |
+
def save_image(image, model_label, prompt_label, team):
|
74 |
+
"""
|
75 |
+
Save the generated image with a timestamped filename.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
image (PIL.Image.Image): The generated image.
|
79 |
+
model_label (str): The label of the selected model.
|
80 |
+
prompt_label (str): The label of the selected prompt.
|
81 |
+
team (str): The selected team ("Red" or "Blue").
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
str: The filename of the saved image.
|
85 |
+
"""
|
86 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
87 |
+
output_filename = f"{timestamp}_{model_label.replace(' ', '_').lower()}_{prompt_label.replace(' ', '_').lower()}_{team.lower()}.png"
|
88 |
+
image.save(output_filename)
|
89 |
+
return output_filename
|