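# Spaces: Running
# NOTE(review): the line above appears to be page-header residue captured with
# the file; it is kept only as a comment so the module remains importable.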
import unittest

import numpy as np
from PIL import Image

from app import update_prompt
from prod_model import ProductionDesignModel
class TestPromptGeneration(unittest.TestCase):
    """Check the prompt text produced by ``update_prompt``.

    Every test starts from ``self.default_params`` (built in ``setUp``),
    overrides a few keys, calls ``update_prompt(**params)`` and compares the
    returned string with the expected prompt. Multi-case tests wrap each case
    in ``self.subTest`` so that one failing case neither hides the remaining
    cases nor leaves the failing input unidentified in the report.
    """

    # Prompt expected when every option is left at the defaults from setUp.
    _BASE_PROMPT = "Design a Modern living room with a Neutral color scheme"

    def setUp(self):
        """Build the baseline keyword arguments shared by all tests."""
        self.default_params = {
            "room": "Living Room",
            "style": "Modern",
            "colors": "Neutral",
            "floor_t": "Keep Existing",
            "floor_c": "Keep Existing",
            "floor_p": "Keep Existing",
            "wall_t": "Keep Existing",
            "wall_c": "Keep Existing",
            "wall_f": "Keep Existing",
            "custom_text": "",
            "art_en": False,
            "art_col": "None",
            "art_size": "None",
            "mirror_en": False,
            "mirror_fr": "None",
            "mirror_size": "Medium",
            "sconce_en": False,
            "sconce_col": "None",
            "sconce_style": "Modern",
            "shelf_en": False,
            "shelf_col": "None",
            "shelf_size": "Medium",
            "plants_en": False,
            "plants_type": "None",
            "plants_size": "Medium",
        }

    def test_basic_room_style(self):
        """The untouched defaults yield exactly the base prompt."""
        prompt = update_prompt(**self.default_params)
        self.assertEqual(prompt, self._BASE_PROMPT)

    def test_all_room_types(self):
        """Each room type appears lower-cased in the prompt."""
        room_types = [
            "Living Room", "Bedroom", "Kitchen", "Dining Room",
            "Bathroom", "Home Office", "Kids Room", "Master Bedroom",
            "Guest Room", "Studio Apartment", "Entryway", "Hallway",
            "Game Room", "Library", "Home Theater", "Gym",
        ]
        for room in room_types:
            with self.subTest(room=room):
                params = {**self.default_params, "room": room}
                prompt = update_prompt(**params)
                expected = (
                    f"Design a Modern {room.lower()} with a Neutral color scheme"
                )
                self.assertEqual(prompt, expected)

    def test_all_styles(self):
        """Each style preset appears verbatim in the prompt."""
        styles = [
            "Modern", "Contemporary", "Minimalist", "Industrial",
            "Scandinavian", "Mid-Century Modern", "Traditional",
            "Transitional", "Farmhouse", "Rustic", "Bohemian",
            "Art Deco", "Coastal", "Mediterranean", "Japanese",
            "French Country", "Victorian", "Colonial", "Gothic",
            "Baroque", "Rococo", "Neoclassical", "Eclectic",
            "Zen", "Tropical", "Shabby Chic", "Hollywood Regency",
            "Southwestern", "Asian Fusion", "Retro",
        ]
        for style in styles:
            with self.subTest(style=style):
                params = {**self.default_params, "style": style}
                prompt = update_prompt(**params)
                expected = (
                    f"Design a {style} living room with a Neutral color scheme"
                )
                self.assertEqual(prompt, expected)

    def test_all_color_schemes(self):
        """Each colour scheme appears verbatim in the prompt."""
        color_schemes = [
            "Neutral", "Monochromatic", "Minimalist White",
            "Warm Gray", "Cool Gray", "Earth Tones",
            "Pastel", "Bold Primary", "Jewel Tones",
            "Black and White", "Navy and Gold", "Forest Green",
            "Desert Sand", "Ocean Blue", "Sunset Orange",
            "Deep Purple", "Emerald Green", "Ruby Red",
            "Sapphire Blue", "Golden Yellow", "Sage Green",
            "Dusty Rose", "Charcoal", "Cream", "Burgundy",
            "Teal", "Copper", "Silver", "Bronze", "Slate",
        ]
        for color in color_schemes:
            with self.subTest(color=color):
                params = {**self.default_params, "colors": color}
                prompt = update_prompt(**params)
                expected = (
                    f"Design a Modern living room with a {color} color scheme"
                )
                self.assertEqual(prompt, expected)

    def test_floor_combinations(self):
        """Floor type, colour and pattern build up the flooring clause."""
        # (floor_t, floor_c, floor_p, expected clause after the base prompt)
        test_cases = [
            (
                "Hardwood", "Keep Existing", "Keep Existing",
                "featuring Hardwood flooring",
            ),
            (
                "Hardwood", "Light Oak", "Keep Existing",
                "featuring Hardwood in Light Oak flooring",
            ),
            (
                "Hardwood", "Light Oak", "Elegant Herringbone",
                "featuring Hardwood in Light Oak with Elegant Herringbone pattern flooring",
            ),
        ]
        for floor_t, floor_c, floor_p, clause in test_cases:
            with self.subTest(floor_t=floor_t, floor_c=floor_c, floor_p=floor_p):
                params = {
                    **self.default_params,
                    "floor_t": floor_t,
                    "floor_c": floor_c,
                    "floor_p": floor_p,
                }
                prompt = update_prompt(**params)
                self.assertEqual(prompt, f"{self._BASE_PROMPT}, {clause}")

    def test_wall_combinations(self):
        """Wall treatment, colour and finish build up the walls clause."""
        # (wall_t, wall_c, wall_f, expected clause after the base prompt)
        test_cases = [
            (
                "Fresh Paint", "Keep Existing", "Keep Existing",
                "with Fresh Paint walls",
            ),
            (
                "Fresh Paint", "Crisp White", "Keep Existing",
                "with Fresh Paint in Crisp White walls",
            ),
            (
                "Fresh Paint", "Crisp White", "Pearl Satin",
                "with Fresh Paint in Crisp White with Pearl Satin finish walls",
            ),
        ]
        for wall_t, wall_c, wall_f, clause in test_cases:
            with self.subTest(wall_t=wall_t, wall_c=wall_c, wall_f=wall_f):
                params = {
                    **self.default_params,
                    "wall_t": wall_t,
                    "wall_c": wall_c,
                    "wall_f": wall_f,
                }
                prompt = update_prompt(**params)
                self.assertEqual(prompt, f"{self._BASE_PROMPT}, {clause}")

    def test_accessories_individual(self):
        """Enabling a single accessory adds its own 'decorated with' clause."""
        # (label for reporting, parameter overrides, expected clause)
        test_cases = [
            (
                "art",
                {"art_en": True, "art_col": "Vibrant Colors", "art_size": "Oversized"},
                "decorated with Oversized Vibrant Colors Art Print",
            ),
            (
                "mirror",
                {"mirror_en": True, "mirror_fr": "Gold", "mirror_size": "Large"},
                "decorated with Large Mirror with Gold frame",
            ),
            (
                "sconce",
                {"sconce_en": True, "sconce_col": "Brass", "sconce_style": "Art Deco"},
                "decorated with Art Deco Brass Wall Sconce",
            ),
            (
                "shelf",
                {"shelf_en": True, "shelf_col": "Natural Wood", "shelf_size": "Set of 3"},
                "decorated with Set of 3 Natural Wood Floating Shelves",
            ),
            (
                "plants",
                {"plants_en": True, "plants_type": "Hanging Plants", "plants_size": "Medium"},
                "decorated with Medium Hanging Plants",
            ),
        ]
        for name, overrides, clause in test_cases:
            with self.subTest(accessory=name):
                params = {**self.default_params, **overrides}
                prompt = update_prompt(**params)
                self.assertEqual(prompt, f"{self._BASE_PROMPT}, {clause}")

    def test_custom_text_variations(self):
        """Blank custom text is ignored; other text is appended stripped."""
        # (custom_text value, whether it is expected to show up in the prompt)
        test_cases = [
            ("", False),
            (" ", False),
            ("Add plants", True),
            ("Make it cozy and warm", True),
            ("Multiple\nlines", True),
        ]
        for text, should_include in test_cases:
            with self.subTest(custom_text=text):
                params = {**self.default_params, "custom_text": text}
                prompt = update_prompt(**params)
                if should_include:
                    expected = f"{self._BASE_PROMPT}, {text.strip()}"
                else:
                    expected = self._BASE_PROMPT
                self.assertEqual(prompt, expected)

    def test_complex_combinations(self):
        """Smoke-test feature combinations: non-empty, starts with 'Design a'."""
        test_cases = {
            "full_living_room": {
                "room": "Living Room",
                "style": "Modern",
                "colors": "Warm Gray",
                "floor_t": "Hardwood",
                "floor_c": "Light Oak",
                "floor_p": "Elegant Herringbone",
                "wall_t": "Fresh Paint",
                "wall_c": "Crisp White",
                "wall_f": "Pearl Satin",
                "custom_text": "Make it perfect for entertaining",
                "art_en": True,
                "art_col": "Modern Abstract",
                "art_size": "Statement",
                "mirror_en": True,
                "mirror_fr": "Gold",
                "mirror_size": "Large",
                "sconce_en": True,
                "sconce_col": "Brass",
                "sconce_style": "Art Deco",
                "shelf_en": True,
                "shelf_col": "Natural Wood",
                "shelf_size": "Set of 3",
                "plants_en": True,
                "plants_type": "Hanging Plants",
                "plants_size": "Medium",
            },
            "minimal_bedroom": {
                "room": "Bedroom",
                "style": "Japanese",
                "colors": "Minimalist White",
                "floor_t": "Natural Bamboo",
                "floor_c": "Keep Existing",
                "floor_p": "Keep Existing",
                "wall_t": "Fresh Paint",
                "wall_c": "Soft White",
                "wall_f": "Keep Existing",
                "custom_text": "Focus on minimalism and zen aesthetics",
            },
        }
        for name, overrides in test_cases.items():
            with self.subTest(case=name):
                params = {**self.default_params, **overrides}
                prompt = update_prompt(**params)
                # Only the overall shape is checked here, not the exact text.
                self.assertGreater(len(prompt), 0)
                self.assertTrue(prompt.startswith("Design a"))
class TestProductionModel(unittest.TestCase):
    """Check how many variations ``ProductionDesignModel.generate_design`` returns.

    Each requested count is exercised inside ``self.subTest`` so that a
    failure for one count is reported with that count and does not stop the
    remaining counts from being checked.
    """

    def setUp(self):
        """Create a fresh model and a small all-black RGB test image."""
        self.model = ProductionDesignModel()
        # 64x64 zero-filled uint8 array -> simple placeholder input image.
        self.test_image = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))

    def _generate(self, num_variations):
        """Call ``generate_design`` with the fixed test settings.

        Only ``num_variations`` varies between cases; the prompt, step count,
        guidance scale and strength are the same for every call.
        """
        return self.model.generate_design(
            image=self.test_image,
            num_variations=num_variations,
            prompt="Test prompt",
            num_steps=20,  # per the original note: minimum steps, for speed
            guidance_scale=7.5,
            strength=0.75,
        )

    def test_number_of_variations(self):
        """A valid request returns exactly the requested number of variations."""
        for num_variations in (1, 3, 10, 25, 50):
            with self.subTest(num_variations=num_variations):
                variations = self._generate(num_variations)
                self.assertEqual(
                    len(variations),
                    num_variations,
                    f"Expected {num_variations} variations, got {len(variations)}",
                )

    def test_invalid_variation_numbers(self):
        """Out-of-range requests still return between 1 and 50 variations."""
        for num_variations in (-1, 0, 51, 100):
            with self.subTest(num_variations=num_variations):
                variations = self._generate(num_variations)
                # The model is expected to clamp the count to the range 1-50.
                self.assertTrue(
                    1 <= len(variations) <= 50,
                    f"Number of variations {len(variations)} outside valid range 1-50",
                )
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()