{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image generation (Colab)\n",
    "\n",
    "Pick a model, a predefined prompt (or write a custom one), a team and the generation settings, then press **Generate Image**. The image is produced by `generate_image`, displayed below the controls, and saved with `save_image`.\n",
    "\n",
    "The notebook imports `google.colab`, so it is meant to run in Google Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the project root importable so `src.*` and `config.*` resolve.\n",
    "# Assumes the kernel's working directory is one level below the project root.\n",
    "# Setting PYTHONPATH at runtime only affects child processes; this kernel\n",
    "# resolves imports through sys.path, so the root is added there as well.\n",
    "PROJECT_ROOT = os.path.abspath(\"..\")\n",
    "if PROJECT_ROOT not in sys.path:\n",
    "    sys.path.insert(0, PROJECT_ROOT)\n",
    "os.environ[\"PYTHONPATH\"] = PROJECT_ROOT\n",
    "\n",
    "import random\n",
    "from datetime import datetime\n",
    "\n",
    "import ipywidgets as widgets\n",
    "from google.colab import userdata\n",
    "from huggingface_hub import InferenceClient\n",
    "from IPython.display import clear_output, display\n",
    "from PIL import Image\n",
    "\n",
    "from config.config_colab import api_token, prompts\n",
    "from config.models import models\n",
    "from src.img_gen_colab import generate_image, save_image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Controls\n",
    "\n",
    "Builds the input widgets, the click handler that calls `generate_image`, and the output area. Re-running this cell recreates the button, so exactly one handler is ever attached."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference client for the default (first) model.\n",
    "# NOTE(review): the handler below never uses `client`; generate_image receives the\n",
    "# model name and api_token directly. Kept to preserve the global -- confirm and drop.\n",
    "client = InferenceClient(models[0][\"name\"], token=api_token)\n",
    "\n",
    "# One style for every control so each description label is shown in full\n",
    "LABEL_STYLE = {\"description_width\": \"initial\"}\n",
    "\n",
    "# Model, prompt and team selection\n",
    "model_dropdown = widgets.Dropdown(\n",
    "    options=[(model[\"alias\"], model[\"name\"]) for model in models],\n",
    "    description=\"Select Model:\",\n",
    "    style=LABEL_STYLE,\n",
    ")\n",
    "prompt_dropdown = widgets.Dropdown(\n",
    "    options=[(prompt[\"alias\"], prompt[\"text\"]) for prompt in prompts],\n",
    "    description=\"Select Prompt:\",\n",
    "    style=LABEL_STYLE,\n",
    ")\n",
    "team_dropdown = widgets.Dropdown(\n",
    "    options=[\"Red\", \"Blue\"],\n",
    "    description=\"Select Team:\",\n",
    "    style=LABEL_STYLE,\n",
    ")\n",
    "\n",
    "# Generation settings\n",
    "width_input = widgets.IntText(value=640, description=\"Width:\", style=LABEL_STYLE)\n",
    "height_input = widgets.IntText(value=360, description=\"Height:\", style=LABEL_STYLE)\n",
    "num_inference_steps_input = widgets.IntSlider(\n",
    "    value=20, min=10, max=100, step=1,\n",
    "    description=\"Inference Steps:\", style=LABEL_STYLE,\n",
    ")\n",
    "guidance_scale_input = widgets.FloatSlider(\n",
    "    value=2, min=1.0, max=20.0, step=0.5,\n",
    "    description=\"Guidance Scale:\", style=LABEL_STYLE,\n",
    ")\n",
    "\n",
    "# Seed starts at a random value; the checkbox state is forwarded to generate_image\n",
    "seed_input = widgets.IntText(\n",
    "    value=random.randint(0, 1000000), description=\"Seed:\", style=LABEL_STYLE\n",
    ")\n",
    "randomize_seed_checkbox = widgets.Checkbox(\n",
    "    value=True, description=\"Randomize Seed\", style=LABEL_STYLE\n",
    ")\n",
    "\n",
    "# Free-text prompt.\n",
    "# NOTE(review): the 200-character limit in the placeholder is not enforced here.\n",
    "custom_prompt_input = widgets.Textarea(\n",
    "    value=\"\",\n",
    "    placeholder=\"Enter your custom prompt (up to 200 characters)...\",\n",
    "    description=\"Custom Prompt:\",\n",
    "    style=LABEL_STYLE,\n",
    "    layout=widgets.Layout(width=\"500px\", height=\"80px\"),\n",
    ")\n",
    "\n",
    "generate_button = widgets.Button(description=\"Generate Image\", button_style=\"success\")\n",
    "\n",
    "# Everything the handler prints, plus the generated image, is rendered here\n",
    "output = widgets.Output()\n",
    "\n",
    "\n",
    "def on_generate_button_clicked(b):\n",
    "    \"\"\"Generate an image from the current control values, then display and save it.\n",
    "\n",
    "    Registered with ``generate_button.on_click``, so ipywidgets passes the clicked\n",
    "    Button as ``b`` (unused). All messages and the image go to ``output``.\n",
    "    A ``str`` returned by ``generate_image`` is treated as an error message and is\n",
    "    printed instead of being displayed or saved.\n",
    "    \"\"\"\n",
    "    with output:\n",
    "        clear_output(wait=True)  # Replace the previous run's output\n",
    "\n",
    "        # Read the current control values\n",
    "        selected_prompt = prompt_dropdown.value\n",
    "        selected_team = team_dropdown.value\n",
    "        selected_model = model_dropdown.value\n",
    "        height = height_input.value\n",
    "        width = width_input.value\n",
    "        num_inference_steps = num_inference_steps_input.value\n",
    "        guidance_scale = guidance_scale_input.value\n",
    "        seed = seed_input.value\n",
    "        custom_prompt = custom_prompt_input.value\n",
    "\n",
    "        print(\"=== Debug: Selected Parameters ===\")\n",
    "        print(f\"Selected Model: {model_dropdown.label}\")\n",
    "        print(f\"Selected Prompt: {prompt_dropdown.label}\")\n",
    "        print(f\"Selected Team: {selected_team}\")\n",
    "        print(f\"Height: {height}\")\n",
    "        print(f\"Width: {width}\")\n",
    "        print(f\"Inference Steps: {num_inference_steps}\")\n",
    "        print(f\"Guidance Scale: {guidance_scale}\")\n",
    "        print(f\"Seed: {seed}\")\n",
    "        print(f\"Custom Prompt: {custom_prompt}\")\n",
    "        print(\"==================================\")\n",
    "\n",
    "        # IntText accepts any integer, so reject sizes no image can have\n",
    "        if width <= 0 or height <= 0:\n",
    "            print(\"=== Debug: Error ===\")\n",
    "            print(f\"Width and height must be positive integers (got {width}x{height}).\")\n",
    "            return\n",
    "\n",
    "        print(\"=== Debug: Calling generate_image ===\")\n",
    "        image = generate_image(\n",
    "            selected_prompt, selected_team, selected_model, width, height,\n",
    "            num_inference_steps, guidance_scale, seed, custom_prompt, api_token,\n",
    "            randomize_seed=randomize_seed_checkbox.value,\n",
    "        )\n",
    "\n",
    "        if isinstance(image, str):\n",
    "            print(\"=== Debug: Error ===\")\n",
    "            print(image)\n",
    "            return\n",
    "\n",
    "        print(\"=== Debug: Image Generation ===\")\n",
    "        print(\"Image generated successfully!\")\n",
    "        print(\"Displaying image...\")\n",
    "        display(image)\n",
    "\n",
    "        # NOTE(review): with randomize_seed=True, generate_image may use a seed other than\n",
    "        # the widget value passed to save_image here -- confirm the saved seed is accurate.\n",
    "        output_filename = save_image(\n",
    "            image, model_dropdown.label, seed, prompt_dropdown.label, selected_team\n",
    "        )\n",
    "        print(f\"Image saved as {output_filename}\")\n",
    "\n",
    "\n",
    "# Register the handler in the same cell that creates the button, so re-running\n",
    "# the cell never attaches a second handler to an existing button.\n",
    "generate_button.on_click(on_generate_button_clicked)\n",
    "\n",
    "# Show every control the handler reads, followed by the output area\n",
    "display(\n",
    "    prompt_dropdown, team_dropdown, model_dropdown, custom_prompt_input,\n",
    "    width_input, height_input, num_inference_steps_input, guidance_scale_input,\n",
    "    seed_input, randomize_seed_checkbox, generate_button, output,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}