In [None]:
import os
import sys

# Make the project root (the parent of the current working directory)
# importable so the `src.*` and `config.*` imports below resolve.
# NOTE: assigning os.environ["PYTHONPATH"] does not affect the running
# interpreter -- PYTHONPATH is only read at interpreter startup -- so the root
# must also be inserted into sys.path directly. The environment variable is
# still set so any child processes launched from this notebook inherit it.
PROJECT_ROOT = os.path.abspath("..")
os.environ["PYTHONPATH"] = PROJECT_ROOT
if PROJECT_ROOT not in sys.path:  # avoid duplicates when the cell is re-run
    sys.path.insert(0, PROJECT_ROOT)


# Import necessary libraries
import random
import ipywidgets as widgets
from huggingface_hub import InferenceClient
from IPython.display import display, clear_output
from src.img_gen_colab import generate_image, save_image
from config.config_colab import prompts, api_token # Import from config folder
from PIL import Image
from google.colab import userdata
from datetime import datetime
from config.models import models

# Inference client bound to the first model listed in config.models.
client = InferenceClient(models[0]["name"], token=api_token)

# Label style passed as `style=` to every control below.
_LABEL_STYLE = {"description_width": "initial"}

# --- Selection dropdowns ---------------------------------------------------
# Options are (label shown to the user, value read by the click handler).
model_dropdown = widgets.Dropdown(
    options=[(entry["alias"], entry["name"]) for entry in models],
    description="Select Model:",
    style=_LABEL_STYLE,
)
prompt_dropdown = widgets.Dropdown(
    options=[(entry["alias"], entry["text"]) for entry in prompts],
    description="Select Prompt:",
    style=_LABEL_STYLE,
)
team_dropdown = widgets.Dropdown(
    options=["Red", "Blue"],
    description="Select Team:",
    style=_LABEL_STYLE,
)

# --- Generation parameters -------------------------------------------------
width_input = widgets.IntText(value=640, description="Width:", style=_LABEL_STYLE)
height_input = widgets.IntText(value=360, description="Height:", style=_LABEL_STYLE)
num_inference_steps_input = widgets.IntSlider(
    value=20, min=10, max=100, step=1,
    description="Inference Steps:",
    style=_LABEL_STYLE,
)
guidance_scale_input = widgets.FloatSlider(
    value=2.0, min=1.0, max=20.0, step=0.5,
    description="Guidance Scale:",
    style=_LABEL_STYLE,
)

# --- Seed controls -----------------------------------------------------------
# The initial seed is drawn once, when this cell runs.
seed_input = widgets.IntText(
    value=random.randint(0, 1_000_000),
    description="Seed:",
    style=_LABEL_STYLE,
)
randomize_seed_checkbox = widgets.Checkbox(
    value=True,
    description="Randomize Seed",
    style=_LABEL_STYLE,
)

# --- Free-text prompt ----------------------------------------------------------
# NOTE(review): the placeholder mentions a 200-character limit, but nothing in
# this cell enforces it.
custom_prompt_input = widgets.Textarea(
    value="",
    placeholder="Enter your custom prompt (up to 200 characters)...",
    description="Custom Prompt:",
    style=_LABEL_STYLE,
    layout=widgets.Layout(width="500px", height="80px"),
)

# --- Action button and result area -------------------------------------------
generate_button = widgets.Button(description="Generate Image", button_style="success")
output = widgets.Output()  # the click handler writes all of its output here

def on_generate_button_clicked(b):
    """Generate an image from the current widget values and render it in `output`.

    Used as the click callback of `generate_button`.

    Args:
        b: The widget that fired the click (supplied by ipywidgets; unused).

    Flow: snapshot every input widget, print the parameters for debugging,
    call `generate_image`, then either print the result as an error message
    (when it is a `str`) or display the image and persist it with `save_image`.
    """
    with output:
        # Clear the previous run's output; wait=True defers the clear until
        # new output is produced.
        clear_output(wait=True)

        # Snapshot widget state once so every use below sees the same values.
        selected_prompt = prompt_dropdown.value
        selected_team = team_dropdown.value
        selected_model = model_dropdown.value
        model_label = model_dropdown.label
        prompt_label = prompt_dropdown.label
        height = height_input.value
        width = width_input.value
        num_inference_steps = num_inference_steps_input.value
        guidance_scale = guidance_scale_input.value
        seed = seed_input.value
        randomize_seed = randomize_seed_checkbox.value
        custom_prompt = custom_prompt_input.value

        print("=== Debug: Selected Parameters ===")
        print(f"Selected Model: {model_label}")
        print(f"Selected Prompt: {prompt_label}")
        print(f"Selected Team: {selected_team}")
        print(f"Height: {height}")
        print(f"Width: {width}")
        print(f"Inference Steps: {num_inference_steps}")
        print(f"Guidance Scale: {guidance_scale}")
        print(f"Seed: {seed}")
        print(f"Randomize Seed: {randomize_seed}")
        print(f"Custom Prompt: {custom_prompt}")
        print("==================================")

        print("=== Debug: Calling generate_image ===")
        image = generate_image(
            selected_prompt, selected_team, selected_model, width, height,
            num_inference_steps, guidance_scale, seed, custom_prompt, api_token,
            randomize_seed=randomize_seed,
        )

        # A str result is treated as an error message rather than an image.
        if isinstance(image, str):
            print("=== Debug: Error ===")
            print(image)
            return

        print("=== Debug: Image Generation ===")
        print("Image generated successfully!")
        print("Displaying image...")
        display(image)

        # NOTE(review): when randomize_seed is enabled, generate_image may use a
        # seed other than `seed`, yet the value passed to save_image is still
        # the widget's seed -- confirm against src.img_gen_colab.
        output_filename = save_image(image, model_label, seed, prompt_label, selected_team)
        print(f"Image saved as {output_filename}")

# Wire the click handler first so the button is live as soon as it is shown.
generate_button.on_click(on_generate_button_clicked)

# Render the controls top to bottom, followed by the output area.
# NOTE(review): width_input, height_input, num_inference_steps_input and
# guidance_scale_input are not displayed, so the handler always reads their
# initial values -- confirm hiding them is intentional.
display(
    prompt_dropdown,
    team_dropdown,
    model_dropdown,
    custom_prompt_input,
    seed_input,
    randomize_seed_checkbox,
    generate_button,
    output,
)