Spaces:
Running
Running
anbucur
Refactor prompt generation logic in app.py to enhance validation and user experience
25e8e9c
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
import time | |
import os | |
import random | |
from typing import List | |
import traceback | |
from google.oauth2.credentials import Credentials | |
from google_auth_oauthlib.flow import InstalledAppFlow | |
from googleapiclient.discovery import build | |
from googleapiclient.http import MediaIoBaseUpload | |
from io import BytesIO | |
import datetime | |
# Import the model interface | |
from model import DesignModel | |
# For testing, import the mock model | |
from mock_model import MockDesignModel | |
def create_ui(model: DesignModel): | |
"""Create the user interface for the application""" | |
with gr.Blocks() as interface: | |
gr.Markdown("## 🏠 Basic Settings") | |
with gr.Row(): | |
with gr.Column(): | |
room_type = gr.Dropdown( | |
choices=[ | |
"None", "Living Room", "Bedroom", "Dining Room", "Kitchen", | |
"Bathroom", "Home Office", "Master Bedroom", "Guest Room", | |
"Family Room", "Study", "Game Room", "Library", | |
"Nursery", "Craft Room" | |
], | |
label="Room Type", | |
value="None" | |
) | |
style_preset = gr.Dropdown( | |
choices=[ | |
"None", "Modern", "Contemporary", "Minimalist", "Industrial", | |
"Scandinavian", "Mid-Century Modern", "Traditional", | |
"Transitional", "Farmhouse", "Rustic", "Bohemian", | |
"Art Deco", "Coastal", "Mediterranean", "Japanese", | |
"French Country", "Victorian", "Colonial", "Gothic", | |
"Baroque", "Rococo", "Neoclassical", "Eclectic", | |
"Zen", "Tropical", "Shabby Chic", "Hollywood Regency", | |
"Southwestern", "Asian Fusion", "Retro" | |
], | |
label="Design Style", | |
value="None" | |
) | |
color_scheme = gr.Dropdown( | |
choices=[ | |
"None", "Neutral", "Monochromatic", "Minimalist White", | |
"Warm Gray", "Cool Gray", "Earth Tones", | |
"Pastel", "Bold Primary", "Jewel Tones", | |
"Black and White", "Navy and Gold", "Forest Green", | |
"Desert Sand", "Ocean Blue", "Sunset Orange", | |
"Deep Purple", "Emerald Green", "Ruby Red", | |
"Sapphire Blue", "Golden Yellow", "Sage Green", | |
"Dusty Rose", "Charcoal", "Cream", "Burgundy", | |
"Teal", "Copper", "Silver", "Bronze", "Slate" | |
], | |
label="Color Mood", | |
value="None" | |
) | |
# Row 2 - Surface Finishes | |
with gr.Row(): | |
# Floor Options | |
with gr.Column(scale=1): | |
with gr.Group(): | |
gr.Markdown("## 🎨 Floor Options") | |
floor_type = gr.Dropdown( | |
choices=[ | |
"Keep Existing", "Hardwood", "Stone Tiles", "Porcelain Tiles", | |
"Soft Carpet", "Polished Concrete", "Marble", "Vinyl", | |
"Natural Bamboo", "Cork", "Ceramic Tiles", "Terrazzo", | |
"Slate", "Travertine", "Laminate", "Engineered Wood", | |
"Mosaic Tiles", "Luxury Vinyl Tiles", "Stained Concrete" | |
], | |
label="Material", | |
value="Keep Existing" | |
) | |
floor_color = gr.Dropdown( | |
choices=[ | |
"Keep Existing", "Light Oak", "Rich Walnut", "Cool Gray", | |
"Whitewashed", "Warm Cherry", "Deep Brown", "Classic Black", | |
"Natural", "Sandy Beige", "Chocolate", "Espresso", | |
"Honey Oak", "Weathered Gray", "White Marble", | |
"Cream Travertine", "Dark Slate", "Golden Teak", | |
"Rustic Pine", "Ebony" | |
], | |
label="Color", | |
value="Keep Existing" | |
) | |
floor_pattern = gr.Dropdown( | |
choices=[ | |
"Keep Existing", "Classic Straight", "Elegant Herringbone", | |
"V-Pattern", "Decorative Parquet", "Diagonal Layout", | |
"Basketweave", "Chevron", "Random Length", "Grid Pattern", | |
"Versailles Pattern", "Running Bond", "Hexagonal", | |
"Moroccan Pattern", "Brick Layout", "Diamond Pattern", | |
"Windmill Pattern", "Large Format", "Mixed Width" | |
], | |
label="Pattern", | |
value="Keep Existing" | |
) | |
# Wall Options | |
with gr.Column(scale=1): | |
with gr.Group(): | |
gr.Markdown("## 🎨 Wall Options") | |
wall_type = gr.Dropdown( | |
choices=[ | |
"Keep Existing", "Fresh Paint", "Designer Wallpaper", | |
"Textured Finish", "Wood Panels", "Exposed Brick", | |
"Natural Stone", "Wooden Planks", "Modern Concrete", | |
"Venetian Plaster", "Wainscoting", "Shiplap", | |
"3D Wall Panels", "Fabric Panels", "Metal Panels", | |
"Cork Wall", "Tile Feature", "Glass Panels", | |
"Acoustic Panels", "Living Wall" | |
], | |
label="Treatment", | |
value="Keep Existing" | |
) | |
wall_color = gr.Dropdown( | |
choices=[ | |
"Keep Existing", "Crisp White", "Soft White", "Warm Beige", | |
"Gentle Gray", "Sky Blue", "Nature Green", "Sunny Yellow", | |
"Blush Pink", "Deep Blue", "Bold Black", "Sage Green", | |
"Terracotta", "Navy Blue", "Charcoal Gray", "Lavender", | |
"Olive Green", "Dusty Rose", "Teal", "Burgundy" | |
], | |
label="Color", | |
value="Keep Existing" | |
) | |
wall_finish = gr.Dropdown( | |
choices=[ | |
"Keep Existing", "Soft Matte", "Subtle Eggshell", | |
"Pearl Satin", "Sleek Semi-Gloss", "High Gloss", | |
"Suede Texture", "Metallic", "Chalk Finish", | |
"Distressed", "Brushed", "Smooth", "Textured", | |
"Venetian", "Lime Wash", "Concrete", "Rustic", | |
"Lacquered", "Hammered", "Patina" | |
], | |
label="Finish", | |
value="Keep Existing" | |
) | |
# Row 3 - Wall Decorations and Special Requests | |
with gr.Row(elem_classes="wall-decorations-row"): | |
# Wall Decorations | |
with gr.Column(scale=2): | |
with gr.Group(): | |
gr.Markdown("## 🖼️ Wall Decorations") | |
# Art and Mirror | |
with gr.Row(): | |
# Art Print | |
with gr.Column(): | |
with gr.Row(): | |
art_print_enable = gr.Checkbox(label="Add Artwork", value=False) | |
art_print_color = gr.Dropdown( | |
choices=[ | |
"None", "Classic Black & White", "Vibrant Colors", | |
"Single Color", "Soft Colors", "Modern Abstract", | |
"Earth Tones", "Pastel Palette", "Bold Primary Colors", | |
"Metallic Accents", "Monochromatic", "Jewel Tones", | |
"Watercolor", "Vintage Colors", "Neon Accents", | |
"Natural Hues", "Ocean Colors", "Desert Palette" | |
], | |
label="Art Style", | |
value="None" | |
) | |
art_print_size = gr.Dropdown( | |
choices=[ | |
"None", "Modest", "Standard", "Statement", "Oversized", | |
"Gallery Wall", "Diptych", "Triptych", "Mini Series", | |
"Floor to Ceiling", "Custom Size" | |
], | |
label="Art Size", | |
value="None" | |
) | |
# Mirror | |
with gr.Column(): | |
with gr.Row(): | |
mirror_enable = gr.Checkbox(label="Add Mirror", value=False) | |
mirror_frame = gr.Dropdown( | |
choices=[ | |
"None", "Gold", "Silver", "Black", "White", "Wood", | |
"Brass", "Bronze", "Copper", "Chrome", "Antique Gold", | |
"Brushed Nickel", "Rustic Wood", "Ornate", "Minimalist", | |
"LED Backlit", "Bamboo", "Rattan", "Leather Wrapped" | |
], | |
label="Frame Style", | |
value="None" | |
) | |
mirror_size = gr.Dropdown( | |
choices=[ | |
"Small", "Medium", "Large", "Full Length", | |
"Oversized", "Double Width", "Floor Mirror", | |
"Vanity Size", "Statement Piece", "Custom Size" | |
], | |
label="Mirror Size", | |
value="Medium" | |
) | |
# Sconce, Shelf, and Plants | |
with gr.Row(): | |
# Sconce | |
with gr.Column(): | |
with gr.Row(): | |
sconce_enable = gr.Checkbox(label="Add Wall Sconce", value=False) | |
sconce_color = gr.Dropdown( | |
choices=[ | |
"None", "Black", "Gold", "Silver", "Bronze", "White", | |
"Brass", "Copper", "Chrome", "Antique Brass", | |
"Brushed Nickel", "Oil-Rubbed Bronze", "Pewter", | |
"Rose Gold", "Matte Black", "Polished Nickel", | |
"Aged Brass", "Champagne", "Gunmetal" | |
], | |
label="Sconce Color", | |
value="None" | |
) | |
sconce_style = gr.Dropdown( | |
choices=[ | |
"Modern", "Traditional", "Industrial", "Art Deco", | |
"Minimalist", "Vintage", "Contemporary", "Rustic", | |
"Coastal", "Farmhouse", "Mid-Century", "Bohemian", | |
"Scandinavian", "Asian", "Mediterranean", "Gothic", | |
"Transitional", "Eclectic", "Victorian" | |
], | |
label="Sconce Style", | |
value="Modern" | |
) | |
# Floating Shelves | |
with gr.Column(): | |
with gr.Row(): | |
shelf_enable = gr.Checkbox(label="Add Floating Shelves", value=False) | |
shelf_color = gr.Dropdown( | |
choices=[ | |
"None", "White", "Black", "Natural Wood", "Glass", | |
"Dark Wood", "Light Wood", "Metal", "Gold", "Silver", | |
"Bronze", "Reclaimed Wood", "Bamboo", "Marble", | |
"Industrial Metal", "Two-Tone", "Concrete", | |
"Acrylic", "Copper", "Brass" | |
], | |
label="Shelf Material", | |
value="None" | |
) | |
shelf_size = gr.Dropdown( | |
choices=[ | |
"Small", "Medium", "Large", "Set of 3", | |
"Extra Long", "Corner Set", "Asymmetric Set", | |
"Graduated Sizes", "Custom Length", "Mini Cubes", | |
"Full Wall", "Mixed Sizes", "Modular System" | |
], | |
label="Shelf Size", | |
value="Medium" | |
) | |
# Plants | |
with gr.Column(): | |
with gr.Row(): | |
plants_enable = gr.Checkbox(label="Add Plants", value=False) | |
plants_type = gr.Dropdown( | |
choices=[ | |
"None", "Hanging Plants", "Vertical Garden", | |
"Plant Shelf", "Single Plant", "Climbing Vines", | |
"Air Plants", "Succulent Wall", "Herb Garden", | |
"Mixed Tropical", "Fern Collection", "Living Wall", | |
"Moss Wall", "Potted Arrangement", "Plant Corner", | |
"Cascading Plants", "Bamboo Screen", "Terrarium Wall" | |
], | |
label="Plant Type", | |
value="None" | |
) | |
plants_size = gr.Dropdown( | |
choices=[ | |
"Small", "Medium", "Large", "Mixed Sizes", | |
"Full Wall", "Statement Piece", "Compact", | |
"Expansive", "Accent", "Floor to Ceiling", | |
"Window Height", "Custom Size", "Modular" | |
], | |
label="Plant Coverage", | |
value="Medium" | |
) | |
# Special Requests and Advanced Settings | |
with gr.Column(scale=1): | |
with gr.Group(): | |
gr.Markdown("## ✨ Special Requests") | |
input_text = gr.Textbox( | |
label="Additional Details", | |
placeholder="Add any special requests or details here...", | |
lines=3 | |
) | |
num_outputs = gr.Slider( | |
minimum=1, maximum=50, value=1, step=1, | |
label="Number of Variations" | |
) | |
gr.Markdown("### Advanced Settings") | |
num_steps = gr.Slider( | |
minimum=20, | |
maximum=100, | |
value=50, | |
step=1, | |
label="Quality Steps" | |
) | |
guidance_scale = gr.Slider( | |
minimum=1, | |
maximum=20, | |
value=10.0, | |
step=0.1, | |
label="Design Freedom" | |
) | |
strength = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.9, | |
step=0.05, | |
label="Change Amount" | |
) | |
seed = gr.Number( | |
label="Seed (leave empty for random)", | |
value=-1, | |
precision=0 | |
) | |
with gr.Row(): | |
save_to_drive = gr.Checkbox(label="Save to Google Drive") | |
drive_url = gr.Textbox( | |
label="Drive Folder URL", | |
placeholder="https://drive.google.com/drive/folders/..." | |
) | |
# Row 4 - Current Prompts | |
with gr.Row(): | |
with gr.Group(): | |
gr.Markdown("## 📝 Current Prompts") | |
prompt_display = gr.TextArea( | |
label="Positive Prompt", | |
interactive=False, | |
lines=3, | |
value="Your design prompt will appear here..." | |
) | |
negative_prompt = gr.TextArea( | |
label="Negative Prompt", | |
value="blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped", | |
lines=2, | |
interactive=False | |
) | |
# Row 5 - Upload and Gallery | |
with gr.Row(elem_classes="upload-gallery-row"): | |
# Upload Area | |
with gr.Column(scale=1): | |
with gr.Group(): | |
gr.Markdown("## 📸 Upload Photo") | |
input_image = gr.Image( | |
label="Upload a photo of your room", | |
type='pil' | |
) | |
# Gallery Area | |
with gr.Column(scale=2): | |
with gr.Group(): | |
gr.Markdown("## 🖼️ Generated Variations") | |
gallery = gr.Gallery( | |
show_label=False, | |
elem_id="gallery", | |
columns=4, | |
rows=1, | |
height="300px", | |
object_fit="contain", | |
preview=True, | |
show_share_button=False | |
) | |
# Row 6 - Create Button | |
with gr.Row(elem_classes="button-row"): | |
submit = gr.Button("✨ Create My Design", variant="primary", size="lg") | |
def is_valid_value(*values): | |
"""Helper function to check if values are valid (not None or Keep Existing)""" | |
return all(v not in ["None", "Keep Existing"] for v in values) | |
def update_prompt(room, style, colors, floor_t, floor_c, floor_p, | |
wall_t, wall_c, wall_f, custom_text, | |
art_en, art_col, art_size, | |
mirror_en, mirror_fr, mirror_size, | |
sconce_en, sconce_col, sconce_style, | |
shelf_en, shelf_col, shelf_size, | |
plants_en, plants_type, plants_size): | |
"""Generate a prompt for the design, skipping any None or Keep Existing values.""" | |
invalid_values = {"None", "Keep Existing", None, ""} | |
prompt_parts = [] | |
# Basic room settings - each part separate for maximum flexibility | |
if all(x not in invalid_values for x in [style, room, colors]): | |
base_parts = [] | |
if style not in invalid_values: | |
base_parts.append(style) | |
if room not in invalid_values: | |
base_parts.append(room.lower()) | |
if colors not in invalid_values: | |
base_parts.append(f"with a {colors} color scheme") | |
if base_parts: | |
prompt_parts.append("Design a " + " ".join(base_parts)) | |
# Floor details - build up piece by piece | |
if floor_t not in invalid_values: | |
floor_parts = [floor_t] | |
if floor_c not in invalid_values: | |
floor_parts.append(f"in {floor_c}") | |
if floor_p not in invalid_values: | |
floor_parts.append(f"with {floor_p} pattern") | |
prompt_parts.append(f"featuring {' '.join(floor_parts)} flooring") | |
# Wall details - build up piece by piece | |
if wall_t not in invalid_values: | |
wall_parts = [wall_t] | |
if wall_c not in invalid_values: | |
wall_parts.append(f"in {wall_c}") | |
if wall_f not in invalid_values: | |
wall_parts.append(f"with {wall_f} finish") | |
prompt_parts.append(f"with {' '.join(wall_parts)} walls") | |
# Accessories - only add valid combinations | |
accessories = [] | |
# Art Print | |
if art_en and art_col not in invalid_values and art_size not in invalid_values: | |
accessories.append(f"{art_size} {art_col} Art Print") | |
# Mirror | |
if mirror_en and mirror_fr not in invalid_values and mirror_size not in invalid_values: | |
accessories.append(f"{mirror_size} Mirror with {mirror_fr} frame") | |
# Wall Sconce | |
if sconce_en and sconce_col not in invalid_values and sconce_style not in invalid_values: | |
accessories.append(f"{sconce_style} {sconce_col} Wall Sconce") | |
# Floating Shelves | |
if shelf_en and shelf_col not in invalid_values and shelf_size not in invalid_values: | |
accessories.append(f"{shelf_size} {shelf_col} Floating Shelves") | |
# Wall Plants | |
if plants_en and plants_type not in invalid_values and plants_size not in invalid_values: | |
accessories.append(f"{plants_size} {plants_type}") | |
if accessories: | |
prompt_parts.append("decorated with " + ", ".join(accessories)) | |
# Add custom text if present | |
if custom_text and custom_text.strip() and custom_text not in invalid_values: | |
prompt_parts.append(custom_text.strip()) | |
return ", ".join(prompt_parts) if prompt_parts else "Please select room type, style, and color scheme" | |
def on_submit(image, room, style, colors, floor_t, floor_c, floor_p, | |
wall_t, wall_c, wall_f, custom_text, | |
art_en, art_col, art_size, | |
mirror_en, mirror_fr, mirror_size, | |
sconce_en, sconce_col, sconce_style, | |
shelf_en, shelf_col, shelf_size, | |
plants_en, plants_type, plants_size, | |
num_outputs, save_to_drive, drive_url, num_steps, | |
guidance_scale, seed, strength): | |
if image is None: | |
return [] | |
try: | |
# Generate the prompt | |
prompt = update_prompt( | |
room, style, colors, floor_t, floor_c, floor_p, | |
wall_t, wall_c, wall_f, custom_text, | |
art_en, art_col, art_size, | |
mirror_en, mirror_fr, mirror_size, | |
sconce_en, sconce_col, sconce_style, | |
shelf_en, shelf_col, shelf_size, | |
plants_en, plants_type, plants_size | |
) | |
# Generate variations | |
variations = model.generate_design( | |
image=image, | |
prompt=prompt, | |
num_variations=max(1, int(num_outputs)), | |
num_steps=int(num_steps), | |
guidance_scale=float(guidance_scale), | |
strength=float(strength), | |
seed=int(seed) if seed != -1 else None | |
) | |
# Handle Google Drive upload if enabled | |
if save_to_drive and drive_url: | |
folder_id = extract_folder_id(drive_url) | |
if folder_id: | |
for variation in variations: | |
upload_to_drive(variation, folder_id) | |
# Convert variations to gallery format | |
gallery_images = [(v, None) for v in variations] | |
return gallery_images | |
except Exception as e: | |
print(f"Error in generation: {e}") | |
return [] | |
submit.click( | |
on_submit, | |
inputs=[ | |
input_image, room_type, style_preset, color_scheme, | |
floor_type, floor_color, floor_pattern, | |
wall_type, wall_color, wall_finish, | |
input_text, | |
art_print_enable, art_print_color, art_print_size, | |
mirror_enable, mirror_frame, mirror_size, | |
sconce_enable, sconce_color, sconce_style, | |
shelf_enable, shelf_color, shelf_size, | |
plants_enable, plants_type, plants_size, | |
num_outputs, save_to_drive, drive_url, num_steps, | |
guidance_scale, seed, strength | |
], | |
outputs=[gallery] | |
) | |
# Update prompt display when any input changes | |
def update_prompt_display(*args): | |
try: | |
prompt = update_prompt(*args) | |
return [prompt, negative_prompt.value] # Return both prompts | |
except Exception as e: | |
print(f"Error updating prompt: {e}") | |
return ["Error generating prompt", negative_prompt.value] | |
# List of all inputs that should trigger prompt updates | |
prompt_inputs = [ | |
room_type, style_preset, color_scheme, | |
floor_type, floor_color, floor_pattern, | |
wall_type, wall_color, wall_finish, | |
input_text, | |
art_print_enable, art_print_color, art_print_size, | |
mirror_enable, mirror_frame, mirror_size, | |
sconce_enable, sconce_color, sconce_style, | |
shelf_enable, shelf_color, shelf_size, | |
plants_enable, plants_type, plants_size | |
] | |
# Connect all inputs to prompt update | |
for input_component in prompt_inputs: | |
input_component.change( | |
fn=update_prompt_display, | |
inputs=prompt_inputs, | |
outputs=[prompt_display, negative_prompt] | |
) | |
return interface | |
def main():
    """Build the UI around the appropriate design model and launch it.

    Passing ``--test`` on the command line selects ``MockDesignModel``;
    otherwise ``ProductionDesignModel`` from ``prod_model`` is used.
    """
    import sys

    if "--test" in sys.argv:
        print("Starting in TEST mode...")
        from mock_model import MockDesignModel
        design_model = MockDesignModel()
    else:
        print("Starting in PRODUCTION mode...")
        # Imported only on this path, so test mode never loads prod_model.
        from prod_model import ProductionDesignModel
        design_model = ProductionDesignModel()

    app = create_ui(design_model)
    app.launch(
        share=False,
        show_api=False,   # Hide API docs
        show_error=True,  # Show errors for debugging
    )


if __name__ == "__main__":
    main()