Spaces:
Sleeping
Sleeping
File size: 7,413 Bytes
e1eaf41 22d47f1 9e5d569 fbb01db 465c4f7 e1eaf41 6dd0a15 e1eaf41 3ece4ac 1b95506 22d47f1 1b95506 05873ed 0f4c106 e1eaf41 4f48282 e1eaf41 0fe0f24 e1eaf41 f601bfb 1c2cfc0 f601bfb 0fe0f24 e1eaf41 0fe0f24 4dea78d 4f48282 ca8223a 6524712 e1eaf41 0fe0f24 b6b11b7 0fe0f24 4dea78d e1eaf41 0fe0f24 e1eaf41 ca8223a 71aa0b4 ca8223a 9c69778 e1eaf41 bd50b77 e1eaf41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"# Make the project root (the parent of the current working directory) importable.\n",
"# Assigning os.environ[\"PYTHONPATH\"] only affects child processes started later;\n",
"# the running kernel resolves imports through sys.path, so the root must also be\n",
"# inserted there for the `src` and `config` imports below to resolve.\n",
"project_root = os.path.abspath(\"..\")\n",
"os.environ[\"PYTHONPATH\"] = project_root\n",
"if project_root not in sys.path:\n",
"    sys.path.insert(0, project_root)\n",
"\n",
"\n",
"# Import necessary libraries\n",
"import random\n",
"import ipywidgets as widgets\n",
"from huggingface_hub import InferenceClient\n",
"from IPython.display import display, clear_output\n",
"from src.img_gen_colab import generate_image, save_image\n",
"from config.config_colab import prompts, api_token # Import from config folder\n",
"from PIL import Image\n",
"from google.colab import userdata\n",
"from datetime import datetime\n",
"from config.models import models\n",
"\n",
"# Default InferenceClient for the first model listed in config.models.\n",
"# NOTE(review): `client` is not referenced anywhere else in this notebook (generation\n",
"# goes through generate_image with api_token), so it may be removable -- confirm.\n",
"client = InferenceClient(models[0][\"name\"], token=api_token)\n",
"\n",
"# Shared description style for every labelled control: \"initial\" lets the\n",
"# description use its natural width instead of the default fixed width.\n",
"label_style = {\"description_width\": \"initial\"}\n",
"\n",
"# --- Selection dropdowns: (label, value) pairs show the alias, yield the payload ---\n",
"model_dropdown = widgets.Dropdown(\n",
"    options=[(m[\"alias\"], m[\"name\"]) for m in models],\n",
"    description=\"Select Model:\",\n",
"    style=label_style,\n",
")\n",
"prompt_dropdown = widgets.Dropdown(\n",
"    options=[(p[\"alias\"], p[\"text\"]) for p in prompts],\n",
"    description=\"Select Prompt:\",\n",
"    style=label_style,\n",
")\n",
"team_dropdown = widgets.Dropdown(\n",
"    options=[\"Red\", \"Blue\"],\n",
"    description=\"Select Team:\",\n",
"    style=label_style,\n",
")\n",
"\n",
"# --- Image dimensions passed to generate_image ---\n",
"width_input = widgets.IntText(value=640, description=\"Width:\", style=label_style)\n",
"height_input = widgets.IntText(value=360, description=\"Height:\", style=label_style)\n",
"\n",
"# --- Sampling controls ---\n",
"num_inference_steps_input = widgets.IntSlider(\n",
"    value=20, min=10, max=100, step=1,\n",
"    description=\"Inference Steps:\",\n",
"    style=label_style,\n",
")\n",
"guidance_scale_input = widgets.FloatSlider(\n",
"    value=2, min=1.0, max=20.0, step=0.5,\n",
"    description=\"Guidance Scale:\",\n",
"    style=label_style,\n",
")\n",
"\n",
"# --- Seed: starts at a random value; the checkbox value is forwarded to\n",
"# generate_image as randomize_seed ---\n",
"seed_input = widgets.IntText(\n",
"    value=random.randint(0, 1000000),\n",
"    description=\"Seed:\",\n",
"    style=label_style,\n",
")\n",
"randomize_seed_checkbox = widgets.Checkbox(\n",
"    value=True,\n",
"    description=\"Randomize Seed\",\n",
"    style=label_style,\n",
")\n",
"\n",
"# --- Free-text prompt ---\n",
"# NOTE(review): the placeholder promises a 200-character limit, but nothing in this\n",
"# cell enforces it -- confirm whether generate_image does.\n",
"custom_prompt_input = widgets.Textarea(\n",
"    value=\"\",\n",
"    placeholder=\"Enter your custom prompt (up to 200 characters)...\",\n",
"    description=\"Custom Prompt:\",\n",
"    style=label_style,\n",
"    layout=widgets.Layout(width=\"500px\", height=\"80px\"),\n",
")\n",
"\n",
"# --- Action button and the area the click handler writes into ---\n",
"generate_button = widgets.Button(description=\"Generate Image\", button_style=\"success\")\n",
"output = widgets.Output()\n",
"\n",
"def on_generate_button_clicked(b):\n",
"    \"\"\"Generate an image from the current widget values, then display and save it.\n",
"\n",
"    Args:\n",
"        b: the Button instance passed in by ``generate_button.on_click``.\n",
"\n",
"    All progress messages go to the ``output`` widget. A ``str`` result from\n",
"    generate_image is printed as an error; any other result is displayed and\n",
"    saved via save_image. The button is disabled while a request runs, so extra\n",
"    clicks cannot queue duplicate generations, and it is re-enabled in ``finally``\n",
"    even if generation raises.\n",
"    \"\"\"\n",
"    b.disabled = True\n",
"    try:\n",
"        with output:\n",
"            clear_output(wait=True)  # Clear previous output\n",
"\n",
"            # Get selected values from widgets\n",
"            selected_prompt = prompt_dropdown.value\n",
"            selected_team = team_dropdown.value\n",
"            selected_model = model_dropdown.value\n",
"            height = height_input.value\n",
"            width = width_input.value\n",
"            num_inference_steps = num_inference_steps_input.value\n",
"            guidance_scale = guidance_scale_input.value\n",
"            seed = seed_input.value\n",
"            custom_prompt = custom_prompt_input.value\n",
"            randomize_seed = randomize_seed_checkbox.value\n",
"            # Dropdown labels (aliases) are used for logging and the saved filename\n",
"            model_label = model_dropdown.label\n",
"            prompt_label = prompt_dropdown.label\n",
"\n",
"            # Debug: Show selected parameters\n",
"            print(\"=== Debug: Selected Parameters ===\")\n",
"            print(f\"Selected Model: {model_label}\")\n",
"            print(f\"Selected Prompt: {prompt_label}\")\n",
"            print(f\"Selected Team: {selected_team}\")\n",
"            print(f\"Height: {height}\")\n",
"            print(f\"Width: {width}\")\n",
"            print(f\"Inference Steps: {num_inference_steps}\")\n",
"            print(f\"Guidance Scale: {guidance_scale}\")\n",
"            print(f\"Seed: {seed}\")\n",
"            print(f\"Custom Prompt: {custom_prompt}\")\n",
"            print(\"==================================\")\n",
"\n",
"            # NOTE(review): with randomize_seed=True, generate_image may choose its own\n",
"            # seed; the value printed above and passed to save_image is the widget's,\n",
"            # which could then differ from the seed actually used -- confirm in\n",
"            # src.img_gen_colab.\n",
"            print(\"=== Debug: Calling generate_image ===\")\n",
"            image = generate_image(\n",
"                selected_prompt, selected_team, selected_model, width, height,\n",
"                num_inference_steps, guidance_scale, seed, custom_prompt, api_token,\n",
"                randomize_seed=randomize_seed,\n",
"            )\n",
"\n",
"            if isinstance(image, str):\n",
"                # A string result is treated as an error message\n",
"                print(\"=== Debug: Error ===\")\n",
"                print(image)\n",
"            else:\n",
"                print(\"=== Debug: Image Generation ===\")\n",
"                print(\"Image generated successfully!\")\n",
"                print(\"Displaying image...\")\n",
"                display(image)\n",
"\n",
"                # save_image returns the filename it wrote, which is reported back\n",
"                output_filename = save_image(image, model_label, seed, prompt_label, selected_team)\n",
"                print(f\"Image saved as {output_filename}\")\n",
"    finally:\n",
"        b.disabled = False\n",
"\n",
"# Run on_generate_button_clicked whenever the Generate button is pressed.\n",
"generate_button.on_click(on_generate_button_clicked)\n",
"\n",
"# Controls rendered to the user, top to bottom, followed by the output area.\n",
"# NOTE(review): width_input, height_input, num_inference_steps_input and\n",
"# guidance_scale_input are read by the click handler but not shown here, so their\n",
"# default values are always used -- confirm this is intentional.\n",
"visible_controls = (\n",
"    prompt_dropdown,\n",
"    team_dropdown,\n",
"    model_dropdown,\n",
"    custom_prompt_input,\n",
"    seed_input,\n",
"    randomize_seed_checkbox,\n",
"    generate_button,\n",
"    output,\n",
")\n",
"display(*visible_controls)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|