anbucur committed
Commit 4e4b650 · 1 Parent(s): 39751c2

Added functionality

Files changed (8)
  1. README.md +107 -0
  2. app.py +888 -364
  3. credentials.json +1 -0
  4. mock_model.py +83 -0
  5. model.py +14 -0
  6. prod_model.py +170 -0
  7. requirements.txt +33 -10
  8. test_prompt.py +320 -0
README.md CHANGED
@@ -11,3 +11,110 @@ license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Interior Design Assistant
+
+ An AI-powered interior design assistant that helps transform room photos with customizable design variations.
+
+ ## Features
+
+ - Upload room photos and generate design variations
+ - Customize room type, style, and color schemes
+ - Adjust floor and wall finishes
+ - Add wall decorations (art, mirrors, sconces)
+ - Control generation parameters (quality, creativity, etc.)
+ - Save results to Google Drive
+ - Production and test modes
+
+ ## Requirements
+
+ - Python 3.8 or higher
+ - CUDA-capable GPU with 8GB+ VRAM (recommended)
+ - CPU-only mode supported but slower
+
+ ## Installation
+
+ 1. Clone the repository:
+ ```bash
+ git clone [your-repo-url]
+ cd StableDesign2
+ ```
+
+ 2. Create and activate a virtual environment:
+ ```bash
+ python -m venv venv
+ source venv/bin/activate  # Linux/Mac
+ # or
+ .\venv\Scripts\activate  # Windows
+ ```
+
+ 3. Install dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 4. Set up Google Drive integration (optional):
+ - Create a project in Google Cloud Console
+ - Enable the Google Drive API
+ - Create OAuth 2.0 credentials
+ - Download credentials and save as `credentials.json` in the project root (a quick credential check is sketched just below)
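The sketch below mirrors how `app.py` later consumes `credentials.json` (an installed-app OAuth flow with the `drive.file` scope); running it once is a quick way to confirm the credentials work. The `check_drive_credentials` helper is illustrative and not part of the app.

```python
# Minimal sanity check for the optional Google Drive setup (mirrors app.py's upload helper).
# Assumes credentials.json sits in the project root, as described above.
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build

SCOPES = ["https://www.googleapis.com/auth/drive.file"]  # same scope the app requests

def check_drive_credentials() -> bool:
    """Run the installed-app OAuth flow once and make a trivial Drive call to confirm access."""
    flow = InstalledAppFlow.from_client_secrets_file("credentials.json", SCOPES)
    creds = flow.run_local_server(port=0)                 # opens a browser for consent
    service = build("drive", "v3", credentials=creds)
    files = service.files().list(pageSize=1).execute()    # any successful call proves access
    print("Drive access OK; sample response keys:", list(files.keys()))
    return True

if __name__ == "__main__":
    check_drive_credentials()
```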
+
+ ## Usage
+
+ ### Production Mode
+ ```bash
+ python app.py
+ ```
+
+ ### Test Mode (for development)
+ ```bash
+ python app.py --test
+ ```
+
+ The interface will be available at `http://localhost:7860`
+
+ ## Configuration
+
+ ### Model Settings
+ - Quality Steps: 20-100 (default: 50)
+ - Design Freedom: 1-20 (default: 7.5)
+ - Change Amount: 0.1-1.0 (default: 0.75)
+ - Number of Variations: 1-4 (these settings map onto the model's `generate_design` keyword arguments, sketched below)
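These settings are forwarded unchanged from the UI to the model (see `on_submit` in `app.py`). A sketch of calling the model directly with the same parameters, using the mock backend so it runs without a GPU:

```python
# Hedged sketch: drive the model API directly with the settings listed above.
# Uses MockDesignModel so it runs without a GPU; swap in ProductionDesignModel for real output.
from PIL import Image
from mock_model import MockDesignModel

model = MockDesignModel()
room_photo = Image.new("RGB", (640, 480), "lightgray")  # stand-in for an uploaded photo

variations = model.generate_design(
    image=room_photo,
    num_variations=2,        # "Number of Variations"
    prompt="Design a Modern living room with a Neutral color scheme",
    num_steps=50,            # "Quality Steps"
    guidance_scale=7.5,      # "Design Freedom"
    strength=0.75,           # "Change Amount"
    seed=None,               # random seed
)
print(len(variations), "variations, each a numpy array of shape", variations[0].shape)
```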
+
+ ### Design Options
+ - Room Types: 16 options
+ - Design Styles: 30 options
+ - Color Schemes: 30 options
+ - Floor & Wall Options: Multiple materials, colors, patterns
+ - Wall Decorations: Art, mirrors, sconces, shelves, plants
+
+ ## Error Handling
+
+ The application includes comprehensive error handling:
+ - Input validation
+ - Model generation fallbacks (the fallback pattern is sketched below)
+ - Google Drive upload retries
+ - Detailed error logging
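Concretely, "model generation fallbacks" means generation is wrapped so a failure never reaches the UI; both the mock and the production model return the original photo instead. A condensed sketch of that pattern (the `generate_with_fallback` helper is illustrative, not part of the code):

```python
# Condensed sketch of the fallback pattern used by the models: never raise to the UI,
# return the unmodified input photo instead.
import traceback
from typing import List
import numpy as np
from PIL import Image

def generate_with_fallback(generate_fn, image: Image.Image, **kwargs) -> List[np.ndarray]:
    try:
        return generate_fn(image=image, **kwargs)
    except Exception as exc:
        print(f"Error in generate_design: {exc}")
        traceback.print_exc()
        return [np.array(image.convert("RGB"))]  # graceful fallback: the original photo
```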
+
+ ## Development
+
+ ### Running Tests
+ ```bash
+ python -m pytest
+ ```
+
+ ### Code Style
+ ```bash
+ black .
+ flake8
+ isort .
+ ```
+
+ ## License
+
+ [Your License]
+
+ ## Credits
+
+ - Built with [Gradio](https://gradio.app/)
+ - Powered by [Stable Diffusion](https://stability.ai/)
app.py CHANGED
@@ -1,377 +1,901 @@
1
- import spaces
2
- from typing import Tuple, Union, List
3
- import os
4
-
5
- import numpy as np
6
- from PIL import Image
7
-
8
- import torch
9
- from diffusers.pipelines.controlnet import StableDiffusionControlNetInpaintPipeline
10
- from diffusers import ControlNetModel, UniPCMultistepScheduler, AutoPipelineForText2Image
11
- from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, AutoModelForDepthEstimation
12
- from colors import ade_palette
13
- from utils import map_colors_rgb
14
- from diffusers import StableDiffusionXLPipeline
15
  import gradio as gr
16
- import gc
17
-
18
- device = "cuda"
19
- dtype = torch.float16
20
-
21
-
22
- css = """
23
- #img-display-container {
24
- max-height: 50vh;
25
- }
26
- #img-display-input {
27
- max-height: 40vh;
28
- }
29
- #img-display-output {
30
- max-height: 40vh;
31
- }
32
-
33
- """
34
-
35
-
36
- def filter_items(
37
- colors_list: Union[List, np.ndarray],
38
- items_list: Union[List, np.ndarray],
39
- items_to_remove: Union[List, np.ndarray]
40
- ) -> Tuple[Union[List, np.ndarray], Union[List, np.ndarray]]:
41
- """
42
- Filters items and their corresponding colors from given lists, excluding
43
- specified items.
44
-
45
- Args:
46
- colors_list: A list or numpy array of colors corresponding to items.
47
- items_list: A list or numpy array of items.
48
- items_to_remove: A list or numpy array of items to be removed.
49
-
50
- Returns:
51
- A tuple of two lists or numpy arrays: filtered colors and filtered
52
- items.
53
- """
54
- filtered_colors = []
55
- filtered_items = []
56
- for color, item in zip(colors_list, items_list):
57
- if item not in items_to_remove:
58
- filtered_colors.append(color)
59
- filtered_items.append(item)
60
- return filtered_colors, filtered_items
61
-
62
- def get_segmentation_pipeline(
63
- ) -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
64
- """Method to load the segmentation pipeline
65
- Returns:
66
- Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
67
- """
68
- image_processor = AutoImageProcessor.from_pretrained(
69
- "openmmlab/upernet-convnext-small"
70
- )
71
- image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
72
- "openmmlab/upernet-convnext-small"
73
- )
74
- return image_processor, image_segmentor
75
-
76
-
77
- @torch.inference_mode()
78
- @spaces.GPU
79
- def segment_image(
80
- image: Image,
81
- image_processor: AutoImageProcessor,
82
- image_segmentor: UperNetForSemanticSegmentation
83
- ) -> Image:
84
- """
85
- Segments an image using a semantic segmentation model.
86
-
87
- Args:
88
- image (Image): The input image to be segmented.
89
- image_processor (AutoImageProcessor): The processor to prepare the
90
- image for segmentation.
91
- image_segmentor (UperNetForSemanticSegmentation): The semantic
92
- segmentation model used to identify different segments in the image.
93
-
94
- Returns:
95
- Image: The segmented image with each segment colored differently based
96
- on its identified class.
97
- """
98
- # image_processor, image_segmentor = get_segmentation_pipeline()
99
- pixel_values = image_processor(image, return_tensors="pt").pixel_values
100
- with torch.no_grad():
101
- outputs = image_segmentor(pixel_values)
102
-
103
- seg = image_processor.post_process_semantic_segmentation(
104
- outputs, target_sizes=[image.size[::-1]])[0]
105
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
106
- palette = np.array(ade_palette())
107
- for label, color in enumerate(palette):
108
- color_seg[seg == label, :] = color
109
- color_seg = color_seg.astype(np.uint8)
110
- seg_image = Image.fromarray(color_seg).convert('RGB')
111
- return seg_image
112
-
113
-
114
- def get_depth_pipeline():
115
- feature_extractor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-large-hf",
116
- torch_dtype=dtype)
117
- depth_estimator = AutoModelForDepthEstimation.from_pretrained("LiheYoung/depth-anything-large-hf",
118
- torch_dtype=dtype)
119
- return feature_extractor, depth_estimator
120
-
121
-
122
- @torch.inference_mode()
123
- @spaces.GPU
124
- def get_depth_image(
125
- image: Image,
126
- feature_extractor: AutoImageProcessor,
127
- depth_estimator: AutoModelForDepthEstimation
128
- ) -> Image:
129
- image_to_depth = feature_extractor(images=image, return_tensors="pt").to(device)
130
- with torch.no_grad():
131
- depth_map = depth_estimator(**image_to_depth).predicted_depth
132
-
133
- width, height = image.size
134
- depth_map = torch.nn.functional.interpolate(
135
- depth_map.unsqueeze(1).float(),
136
- size=(height, width),
137
- mode="bicubic",
138
- align_corners=False,
139
- )
140
- depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
141
- depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
142
- depth_map = (depth_map - depth_min) / (depth_max - depth_min)
143
- image = torch.cat([depth_map] * 3, dim=1)
144
-
145
- image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
146
- image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
147
- return image
148
-
149
-
150
- def resize_dimensions(dimensions, target_size):
151
- """
152
- Resize PIL to target size while maintaining aspect ratio
153
- If smaller than target size leave it as is
154
- """
155
- width, height = dimensions
156
-
157
- # Check if both dimensions are smaller than the target size
158
- if width < target_size and height < target_size:
159
- return dimensions
160
-
161
- # Determine the larger side
162
- if width > height:
163
- # Calculate the aspect ratio
164
- aspect_ratio = height / width
165
- # Resize dimensions
166
- return (target_size, int(target_size * aspect_ratio))
167
- else:
168
- # Calculate the aspect ratio
169
- aspect_ratio = width / height
170
- # Resize dimensions
171
- return (int(target_size * aspect_ratio), target_size)
172
-
173
-
174
- def flush():
175
- gc.collect()
176
- torch.cuda.empty_cache()
177
-
178
-
179
- class ControlNetDepthDesignModelMulti:
180
- """ Produces random noise images """
181
 
182
- def __init__(self):
183
- """ Initialize your model(s) here """
184
- #os.environ['HF_HUB_OFFLINE'] = "True"
185
 
186
- self.seed = 323*111
187
- self.neg_prompt = "window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner"
188
- self.control_items = ["windowpane;window", "door;double;door"]
189
- self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
190
 
191
- @spaces.GPU
192
- def generate_design(self, empty_room_image: Image, prompt: str, guidance_scale: int = 10, num_steps: int = 50, strength: float =0.9, img_size: int = 640) -> Image:
193
- """
194
- Given an image of an empty room and a prompt
195
- generate the designed room according to the prompt
196
- Inputs -
197
- empty_room_image - An RGB PIL Image of the empty room
198
- prompt - Text describing the target design elements of the room
199
- Returns -
200
- design_image - PIL Image of the same size as the empty room image
201
- If the size is not the same the submission will fail.
202
- """
203
- print(prompt)
204
- flush()
205
- self.generator = torch.Generator(device=device).manual_seed(self.seed)
206
-
207
- pos_prompt = prompt + f', {self.additional_quality_suffix}'
208
-
209
- orig_w, orig_h = empty_room_image.size
210
- new_width, new_height = resize_dimensions(empty_room_image.size, img_size)
211
- input_image = empty_room_image.resize((new_width, new_height))
212
- real_seg = np.array(segment_image(input_image,
213
- seg_image_processor,
214
- image_segmentor))
215
- unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
216
- unique_colors = [tuple(color) for color in unique_colors]
217
- segment_items = [map_colors_rgb(i) for i in unique_colors]
218
- chosen_colors, segment_items = filter_items(
219
- colors_list=unique_colors,
220
- items_list=segment_items,
221
- items_to_remove=self.control_items
222
  )
223
- mask = np.zeros_like(real_seg)
224
- for color in chosen_colors:
225
- color_matches = (real_seg == color).all(axis=2)
226
- mask[color_matches] = 1
227
-
228
- image_np = np.array(input_image)
229
- image = Image.fromarray(image_np).convert("RGB")
230
- mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("RGB")
231
- segmentation_cond_image = Image.fromarray(real_seg).convert("RGB")
232
-
233
- image_depth = get_depth_image(image, depth_feature_extractor, depth_estimator)
234
-
235
- # generate image that would be used as IP-adapter
236
- flush()
237
- new_width_ip = int(new_width / 8) * 8
238
- new_height_ip = int(new_height / 8) * 8
239
- ip_image = guide_pipe(pos_prompt,
240
- num_inference_steps=num_steps,
241
- negative_prompt=self.neg_prompt,
242
- height=new_height_ip,
243
- width=new_width_ip,
244
- generator=[self.generator]).images[0]
245
-
246
- flush()
247
- generated_image = pipe(
248
- prompt=pos_prompt,
249
- negative_prompt=self.neg_prompt,
250
- num_inference_steps=num_steps,
251
- strength=strength,
252
- guidance_scale=guidance_scale,
253
- generator=[self.generator],
254
- image=image,
255
- mask_image=mask_image,
256
- ip_adapter_image=ip_image,
257
- control_image=[image_depth, segmentation_cond_image],
258
- controlnet_conditioning_scale=[0.5, 0.5]
259
- ).images[0]
260
-
261
- flush()
262
- design_image = generated_image.resize(
263
- (orig_w, orig_h), Image.Resampling.LANCZOS
264
  )
265
-
266
- return design_image
267
-
268
-
269
- def create_demo(model):
270
- gr.Markdown("### Stable Design demo")
271
- with gr.Row():
272
- with gr.Column():
273
- input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
274
- input_text = gr.Textbox(label='Prompt', placeholder='Please upload your image first', lines=2)
275
- with gr.Accordion('Advanced options', open=False):
276
- num_steps = gr.Slider(label='Steps',
277
- minimum=1,
278
- maximum=50,
279
- value=50,
280
- step=1)
281
- img_size = gr.Slider(label='Image size',
282
- minimum=256,
283
- maximum=768,
284
- value=768,
285
- step=64)
286
- guidance_scale = gr.Slider(label='Guidance Scale',
287
- minimum=0.1,
288
- maximum=30.0,
289
- value=10.0,
290
- step=0.1)
291
- seed = gr.Slider(label='Seed',
292
- minimum=-1,
293
- maximum=2147483647,
294
- value=323*111,
295
- step=1,
296
- randomize=True)
297
- strength = gr.Slider(label='Strength',
298
- minimum=0.1,
299
- maximum=1.0,
300
- value=0.9,
301
- step=0.1)
302
- a_prompt = gr.Textbox(
303
- label='Added Prompt',
304
- value="interior design, 4K, high resolution, photorealistic")
305
- n_prompt = gr.Textbox(
306
- label='Negative Prompt',
307
- value="window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner")
308
- submit = gr.Button("Submit")
309
-
310
- with gr.Column():
311
- design_image = gr.Image(label="Output Mask", elem_id='img-display-output')
312
 
313
 
314
- def on_submit(image, text, num_steps, guidance_scale, seed, strength, a_prompt, n_prompt, img_size):
315
- model.seed = seed
316
- model.neg_prompt = n_prompt
317
- model.additional_quality_suffix = a_prompt
318
-
319
- with torch.no_grad():
320
- out_img = model.generate_design(image, text, guidance_scale=guidance_scale, num_steps=num_steps, strength=strength, img_size=img_size)
321
-
322
- return out_img
323
-
324
- submit.click(on_submit, inputs=[input_image, input_text, num_steps, guidance_scale, seed, strength, a_prompt, n_prompt, img_size], outputs=design_image)
325
- examples = gr.Examples(examples=[["imgs/bedroom_1.jpg", "An elegantly appointed bedroom in the Art Deco style, featuring a grand king-size bed with geometric bedding, a luxurious velvet armchair, and a mirrored nightstand that reflects the room's opulence. Art Deco-inspired artwork adds a touch of glamour"], ["imgs/bedroom_2.jpg", "A bedroom that exudes French country charm with a soft upholstered bed, walls adorned with floral wallpaper, and a vintage wooden wardrobe. A crystal chandelier casts a warm, inviting glow over the space"], ["imgs/dinning_room_1.jpg", "A cozy dining room that captures the essence of rustic charm with a solid wooden farmhouse table at its core, surrounded by an eclectic mix of mismatched chairs. An antique sideboard serves as a statement piece, and the ambiance is warmly lit by a series of quaint Edison bulbs dangling from the ceiling"], ["imgs/dinning_room_3.jpg", "A dining room that epitomizes contemporary elegance, anchored by a sleek, minimalist dining table paired with stylish modern chairs. Artistic lighting fixtures create a focal point above, while the surrounding minimalist decor ensures the space feels open, airy, and utterly modern"], ["imgs/image_1.jpg", "A glamorous master bedroom in Hollywood Regency style, boasting a plush tufted headboard, mirrored furniture reflecting elegance, luxurious fabrics in rich textures, and opulent gold accents for a touch of luxury."], ["imgs/image_2.jpg", "A vibrant living room with a tropical theme, complete with comfortable rattan furniture, large leafy plants bringing the outdoors in, bright cushions adding pops of color, and bamboo blinds for natural light control."], ["imgs/living_room_1.jpg", "A stylish living room embracing mid-century modern aesthetics, featuring a vintage teak coffee table at its center, complemented by a classic sunburst clock on the wall and a cozy shag rug underfoot, creating a warm and inviting atmosphere"]],
326
- inputs=[input_image, input_text], cache_examples=False)
327
-
328
-
329
- controlnet_depth= ControlNetModel.from_pretrained(
330
- "controlnet_depth", torch_dtype=dtype, use_safetensors=True)
331
- controlnet_seg = ControlNetModel.from_pretrained(
332
- "own_controlnet", torch_dtype=dtype, use_safetensors=True)
333
-
334
- pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
335
- "SG161222/Realistic_Vision_V5.1_noVAE",
336
- #"models/runwayml--stable-diffusion-inpainting",
337
- controlnet=[controlnet_depth, controlnet_seg],
338
- safety_checker=None,
339
- torch_dtype=dtype
340
- )
341
-
342
- pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models",
343
- weight_name="ip-adapter_sd15.bin")
344
- pipe.set_ip_adapter_scale(0.4)
345
- pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
346
- pipe = pipe.to(device)
347
- guide_pipe = StableDiffusionXLPipeline.from_pretrained("segmind/SSD-1B",
348
- torch_dtype=dtype, use_safetensors=True, variant="fp16")
349
- guide_pipe = guide_pipe.to(device)
350
-
351
- seg_image_processor, image_segmentor = get_segmentation_pipeline()
352
- depth_feature_extractor, depth_estimator = get_depth_pipeline()
353
- depth_estimator = depth_estimator.to(device)
354
-
355
 
356
  def main():
357
- model = ControlNetDepthDesignModelMulti()
358
- print('Models uploaded successfully')
359
 
360
- title = "# StableDesign"
361
- description = """
362
- <p style='font-size: 14px; margin-bottom: 10px;'><a href='https://www.linkedin.com/in/mykola-lavreniuk/'>Mykola Lavreniuk</a>, <a href='https://www.linkedin.com/in/bartosz-ludwiczuk-a677a760/'>Bartosz Ludwiczuk</a></p>
363
- <p style='font-size: 16px; margin-bottom: 0px; margin-top=0px;'>Official demo for <strong>StableDesign:</strong> 2nd place solution for the Generative Interior Design 2024 <a href='https://www.aicrowd.com/challenges/generative-interior-design-challenge-2024/leaderboards?challenge_round_id=1314'>competition</a>. StableDesign is a deep learning model designed to harness the power of AI, providing innovative and creative tools for designers. Using our algorithms, images of empty rooms can be transformed into fully furnished spaces based on text descriptions. Please refer to our <a href='https://github.com/Lavreniuk/generative-interior-design'>GitHub</a> for more details.</p>
364
- """
365
- with gr.Blocks() as demo:
366
- gr.Markdown(title)
367
- gr.Markdown(description)
368
-
369
- create_demo(model)
370
- gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/MykolaL/StableDesign?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
371
- <p><img src="https://visitor-badge.glitch.me/badge?page_id=MykolaL/StableDesign" alt="visitors"></p></center>''')
372
-
373
- demo.queue().launch(share=False)
374
-
375
 
376
- if __name__ == '__main__':
377
  main()
1
  import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import time
5
+ import os
6
+ import random
7
+ from typing import List
8
+ import traceback
9
+ from google.oauth2.credentials import Credentials
10
+ from google_auth_oauthlib.flow import InstalledAppFlow
11
+ from googleapiclient.discovery import build
12
+ from googleapiclient.http import MediaIoBaseUpload
13
+ from io import BytesIO
14
+ import datetime
15
+
16
+ # Import the model interface
17
+ from model import DesignModel
18
+
19
+ # For testing, import the mock model
20
+ from mock_model import MockDesignModel
21
+
22
+ def create_ui(model: DesignModel):
23
+ """Create the main UI interface with all components"""
24
+ # Store current variations at UI level
25
+ current_variations = []
26
 
27
+ with gr.Blocks(css="""
28
+ /* Base styles */
29
+ :root {
30
+ --section-title-size: 1.2rem;
31
+ --section-spacing: var(--spacing-lg);
32
+ --panel-min-height: auto;
33
+ }
34
+
35
+ /* Row styling for equal heights */
36
+ .gr-row {
37
+ margin-bottom: var(--spacing-md);
38
+ display: flex;
39
+ align-items: stretch;
40
+ }
41
+
42
+ .gr-row > .gr-column {
43
+ display: flex;
44
+ flex-direction: column;
45
+ }
46
+
47
+ .gr-row > .gr-column > .gr-group {
48
+ flex: 1;
49
+ display: flex;
50
+ flex-direction: column;
51
+ }
52
+
53
+ /* Consistent title styling */
54
+ .gr-markdown h2 {
55
+ font-size: var(--section-title-size) !important;
56
+ font-weight: 600 !important;
57
+ margin: var(--spacing-sm) 0 var(--spacing-md) !important;
58
+ padding: 0 !important;
59
+ color: var(--body-text-color) !important;
60
+ }
61
+
62
+ .gr-markdown h3 {
63
+ font-size: 1rem !important;
64
+ font-weight: 500 !important;
65
+ margin: var(--spacing-sm) 0 !important;
66
+ padding: 0 !important;
67
+ color: var(--body-text-color) !important;
68
+ }
69
+
70
+ /* Panel styling */
71
+ .gr-group {
72
+ border: 1px solid var(--border-color-primary);
73
+ border-radius: var(--radius-lg);
74
+ padding: var(--section-spacing);
75
+ margin: var(--spacing-sm) 0;
76
+ background: var(--background-fill-primary);
77
+ min-height: var(--panel-min-height);
78
+ height: auto !important;
79
+ display: flex;
80
+ flex-direction: column;
81
+ }
82
+
83
+ /* Form and input spacing */
84
+ .gr-form {
85
+ gap: var(--spacing-sm);
86
+ flex-grow: 1;
87
+ display: flex;
88
+ flex-direction: column;
89
+ }
90
+
91
+ .gr-form > div {
92
+ gap: var(--spacing-sm);
93
+ }
94
+
95
+ /* Dropdown styling */
96
+ .gr-dropdown {
97
+ margin-bottom: var(--spacing-sm);
98
+ }
99
+
100
+ /* Gallery improvements */
101
+ #gallery {
102
+ margin-top: 0;
103
+ height: 300px !important;
104
+ }
105
+
106
+ #gallery img {
107
+ object-fit: contain !important;
108
+ width: 100% !important;
109
+ height: 100% !important;
110
+ max-height: none !important;
111
+ }
112
+
113
+ /* Button styling */
114
+ .button-row {
115
+ display: flex;
116
+ justify-content: center;
117
+ padding: var(--spacing-xl) 0;
118
+ }
119
+
120
+ .button-row button {
121
+ min-width: 200px;
122
+ font-size: 1.1em;
123
+ font-weight: 600;
124
+ }
125
+
126
+ /* Text areas */
127
+ .gr-textarea {
128
+ font-family: monospace;
129
+ line-height: 1.4;
130
+ }
131
+
132
+ .gr-textarea:disabled {
133
+ opacity: 0.9;
134
+ background-color: var(--background-fill-secondary);
135
+ }
136
+
137
+ /* Progress indicator */
138
+ .progress-bar {
139
+ margin: var(--spacing-sm) 0;
140
+ }
141
+
142
+ /* Make dropdowns always visible */
143
+ .gr-dropdown {
144
+ display: block !important;
145
+ visibility: visible !important;
146
+ }
147
+
148
+ /* Upload area */
149
+ .upload-group {
150
+ height: 100%;
151
+ }
152
+
153
+ .upload-group .gr-image {
154
+ min-height: 300px;
155
+ }
156
+
157
+ /* Checkbox alignment */
158
+ .gr-checkbox-row {
159
+ display: flex !important;
160
+ align-items: center !important;
161
+ min-height: 4.5rem !important;
162
+ }
163
+
164
+ .gr-checkbox-row .gr-checkbox {
165
+ margin: auto 0 !important;
166
+ }
167
+
168
+ /* Remove any fixed heights from groups */
169
+ .gr-group > div {
170
+ height: auto !important;
171
+ min-height: unset !important;
172
+ }
173
+
174
+ /* Ensure consistent spacing in all panels */
175
+ .gr-group > div:not(:last-child) {
176
+ margin-bottom: var(--spacing-sm);
177
+ }
178
+
179
+ /* Make surface finishes more compact */
180
+ .surface-finishes {
181
+ padding: var(--spacing-xs) !important;
182
+ margin: 0 !important;
183
+ border: 1px solid var(--border-color-primary) !important;
184
+ border-radius: var(--radius-lg) !important;
185
+ background: var(--background-fill-primary) !important;
186
+ }
187
 
188
+ .surface-finishes .gr-form {
189
+ gap: var(--spacing-xs) !important;
190
+ margin: 0 !important;
191
+ padding: 0 !important;
192
+ flex-grow: 0 !important;
193
+ }
194
+
195
+ .surface-finishes .gr-dropdown {
196
+ margin: 0 0 var(--spacing-xs) 0 !important;
197
+ }
198
+
199
+ .surface-finishes .gr-row {
200
+ margin: 0 !important;
201
+ gap: var(--spacing-sm) !important;
202
+ }
203
+
204
+ .surface-finishes .gr-group {
205
+ padding: 0 !important;
206
+ margin: 0 !important;
207
+ border: none !important;
208
+ background: none !important;
209
+ box-shadow: none !important;
210
+ min-height: 0 !important;
211
+ }
212
+
213
+ .surface-finishes .gr-markdown {
214
+ margin: 0 0 var(--spacing-xs) 0 !important;
215
+ }
216
+
217
+ .surface-finishes .gr-form > div {
218
+ gap: var(--spacing-xs) !important;
219
+ margin: 0 !important;
220
+ }
221
+
222
+ .surface-finishes .gr-column {
223
+ flex-grow: 0 !important;
224
+ padding: 0 !important;
225
+ }
226
+
227
+ /* Remove any minimum heights */
228
+ .surface-finishes .gr-group > div {
229
+ min-height: 0 !important;
230
+ height: auto !important;
231
+ margin: 0 !important;
232
+ }
233
+
234
+ /* Override any flex growth */
235
+ .surface-finishes .gr-form,
236
+ .surface-finishes .gr-column,
237
+ .surface-finishes .gr-group {
238
+ flex: 0 0 auto !important;
239
+ }
240
+
241
+ /* Wall decorations and special requests */
242
+ .wall-decorations-row > .gr-column > .gr-group {
243
+ height: 100%;
244
+ }
245
+
246
+ /* Upload and gallery row */
247
+ .upload-gallery-row > .gr-column > .gr-group {
248
+ height: 100%;
249
+ }
250
+ """) as interface:
251
+ gr.Markdown("### Interior Design Assistant")
252
 
253
+ with gr.Blocks():
254
+ # Row 1 - Basic Settings
255
+ with gr.Row():
256
+ with gr.Group():
257
+ gr.Markdown("## 🏠 Basic Settings")
258
+ with gr.Row():
259
+ room_type = gr.Dropdown(
260
+ choices=[
261
+ "Living Room", "Bedroom", "Kitchen", "Dining Room",
262
+ "Bathroom", "Home Office", "Kids Room", "Master Bedroom",
263
+ "Guest Room", "Studio Apartment", "Entryway", "Hallway",
264
+ "Game Room", "Library", "Home Theater", "Gym"
265
+ ],
266
+ label="Room Type",
267
+ value="Living Room"
268
+ )
269
+ style_preset = gr.Dropdown(
270
+ choices=[
271
+ "Modern", "Contemporary", "Minimalist", "Industrial",
272
+ "Scandinavian", "Mid-Century Modern", "Traditional",
273
+ "Transitional", "Farmhouse", "Rustic", "Bohemian",
274
+ "Art Deco", "Coastal", "Mediterranean", "Japanese",
275
+ "French Country", "Victorian", "Colonial", "Gothic",
276
+ "Baroque", "Rococo", "Neoclassical", "Eclectic",
277
+ "Zen", "Tropical", "Shabby Chic", "Hollywood Regency",
278
+ "Southwestern", "Asian Fusion", "Retro"
279
+ ],
280
+ label="Design Style",
281
+ value="Modern"
282
+ )
283
+ color_scheme = gr.Dropdown(
284
+ choices=[
285
+ "Neutral", "Monochromatic", "Minimalist White",
286
+ "Warm Gray", "Cool Gray", "Earth Tones",
287
+ "Pastel", "Bold Primary", "Jewel Tones",
288
+ "Black and White", "Navy and Gold", "Forest Green",
289
+ "Desert Sand", "Ocean Blue", "Sunset Orange",
290
+ "Deep Purple", "Emerald Green", "Ruby Red",
291
+ "Sapphire Blue", "Golden Yellow", "Sage Green",
292
+ "Dusty Rose", "Charcoal", "Cream", "Burgundy",
293
+ "Teal", "Copper", "Silver", "Bronze", "Slate"
294
+ ],
295
+ label="Color Mood",
296
+ value="Neutral"
297
+ )
298
+
299
+ # Row 2 - Surface Finishes
300
+ with gr.Row():
301
+ # Floor Options
302
+ with gr.Column(scale=1):
303
+ with gr.Group():
304
+ gr.Markdown("## 🎨 Floor Options")
305
+ floor_type = gr.Dropdown(
306
+ choices=[
307
+ "Keep Existing", "Hardwood", "Stone Tiles", "Porcelain Tiles",
308
+ "Soft Carpet", "Polished Concrete", "Marble", "Vinyl",
309
+ "Natural Bamboo", "Cork", "Ceramic Tiles", "Terrazzo",
310
+ "Slate", "Travertine", "Laminate", "Engineered Wood",
311
+ "Mosaic Tiles", "Luxury Vinyl Tiles", "Stained Concrete"
312
+ ],
313
+ label="Material",
314
+ value="Keep Existing"
315
+ )
316
+ floor_color = gr.Dropdown(
317
+ choices=[
318
+ "Keep Existing", "Light Oak", "Rich Walnut", "Cool Gray",
319
+ "Whitewashed", "Warm Cherry", "Deep Brown", "Classic Black",
320
+ "Natural", "Sandy Beige", "Chocolate", "Espresso",
321
+ "Honey Oak", "Weathered Gray", "White Marble",
322
+ "Cream Travertine", "Dark Slate", "Golden Teak",
323
+ "Rustic Pine", "Ebony"
324
+ ],
325
+ label="Color",
326
+ value="Keep Existing"
327
+ )
328
+ floor_pattern = gr.Dropdown(
329
+ choices=[
330
+ "Keep Existing", "Classic Straight", "Elegant Herringbone",
331
+ "V-Pattern", "Decorative Parquet", "Diagonal Layout",
332
+ "Basketweave", "Chevron", "Random Length", "Grid Pattern",
333
+ "Versailles Pattern", "Running Bond", "Hexagonal",
334
+ "Moroccan Pattern", "Brick Layout", "Diamond Pattern",
335
+ "Windmill Pattern", "Large Format", "Mixed Width"
336
+ ],
337
+ label="Pattern",
338
+ value="Keep Existing"
339
+ )
340
+
341
+ # Wall Options
342
+ with gr.Column(scale=1):
343
+ with gr.Group():
344
+ gr.Markdown("## 🎨 Wall Options")
345
+ wall_type = gr.Dropdown(
346
+ choices=[
347
+ "Keep Existing", "Fresh Paint", "Designer Wallpaper",
348
+ "Textured Finish", "Wood Panels", "Exposed Brick",
349
+ "Natural Stone", "Wooden Planks", "Modern Concrete",
350
+ "Venetian Plaster", "Wainscoting", "Shiplap",
351
+ "3D Wall Panels", "Fabric Panels", "Metal Panels",
352
+ "Cork Wall", "Tile Feature", "Glass Panels",
353
+ "Acoustic Panels", "Living Wall"
354
+ ],
355
+ label="Treatment",
356
+ value="Keep Existing"
357
+ )
358
+ wall_color = gr.Dropdown(
359
+ choices=[
360
+ "Keep Existing", "Crisp White", "Soft White", "Warm Beige",
361
+ "Gentle Gray", "Sky Blue", "Nature Green", "Sunny Yellow",
362
+ "Blush Pink", "Deep Blue", "Bold Black", "Sage Green",
363
+ "Terracotta", "Navy Blue", "Charcoal Gray", "Lavender",
364
+ "Olive Green", "Dusty Rose", "Teal", "Burgundy"
365
+ ],
366
+ label="Color",
367
+ value="Keep Existing"
368
+ )
369
+ wall_finish = gr.Dropdown(
370
+ choices=[
371
+ "Keep Existing", "Soft Matte", "Subtle Eggshell",
372
+ "Pearl Satin", "Sleek Semi-Gloss", "High Gloss",
373
+ "Suede Texture", "Metallic", "Chalk Finish",
374
+ "Distressed", "Brushed", "Smooth", "Textured",
375
+ "Venetian", "Lime Wash", "Concrete", "Rustic",
376
+ "Lacquered", "Hammered", "Patina"
377
+ ],
378
+ label="Finish",
379
+ value="Keep Existing"
380
+ )
381
+
382
+ # Row 3 - Wall Decorations and Special Requests
383
+ with gr.Row(elem_classes="wall-decorations-row"):
384
+ # Wall Decorations
385
+ with gr.Column(scale=2):
386
+ with gr.Group():
387
+ gr.Markdown("## 🖼️ Wall Decorations")
388
+ # Art and Mirror
389
+ with gr.Row():
390
+ # Art Print
391
+ with gr.Column():
392
+ with gr.Row():
393
+ art_print_enable = gr.Checkbox(label="Add Artwork", value=False)
394
+ art_print_color = gr.Dropdown(
395
+ choices=[
396
+ "None", "Classic Black & White", "Vibrant Colors",
397
+ "Single Color", "Soft Colors", "Modern Abstract",
398
+ "Earth Tones", "Pastel Palette", "Bold Primary Colors",
399
+ "Metallic Accents", "Monochromatic", "Jewel Tones",
400
+ "Watercolor", "Vintage Colors", "Neon Accents",
401
+ "Natural Hues", "Ocean Colors", "Desert Palette"
402
+ ],
403
+ label="Art Style",
404
+ value="None"
405
+ )
406
+ art_print_size = gr.Dropdown(
407
+ choices=[
408
+ "None", "Modest", "Standard", "Statement", "Oversized",
409
+ "Gallery Wall", "Diptych", "Triptych", "Mini Series",
410
+ "Floor to Ceiling", "Custom Size"
411
+ ],
412
+ label="Art Size",
413
+ value="None"
414
+ )
415
+
416
+ # Mirror
417
+ with gr.Column():
418
+ with gr.Row():
419
+ mirror_enable = gr.Checkbox(label="Add Mirror", value=False)
420
+ mirror_frame = gr.Dropdown(
421
+ choices=[
422
+ "None", "Gold", "Silver", "Black", "White", "Wood",
423
+ "Brass", "Bronze", "Copper", "Chrome", "Antique Gold",
424
+ "Brushed Nickel", "Rustic Wood", "Ornate", "Minimalist",
425
+ "LED Backlit", "Bamboo", "Rattan", "Leather Wrapped"
426
+ ],
427
+ label="Frame Style",
428
+ value="None"
429
+ )
430
+ mirror_size = gr.Dropdown(
431
+ choices=[
432
+ "Small", "Medium", "Large", "Full Length",
433
+ "Oversized", "Double Width", "Floor Mirror",
434
+ "Vanity Size", "Statement Piece", "Custom Size"
435
+ ],
436
+ label="Mirror Size",
437
+ value="Medium"
438
+ )
439
+
440
+ # Sconce, Shelf, and Plants
441
+ with gr.Row():
442
+ # Sconce
443
+ with gr.Column():
444
+ with gr.Row():
445
+ sconce_enable = gr.Checkbox(label="Add Wall Sconce", value=False)
446
+ sconce_color = gr.Dropdown(
447
+ choices=[
448
+ "None", "Black", "Gold", "Silver", "Bronze", "White",
449
+ "Brass", "Copper", "Chrome", "Antique Brass",
450
+ "Brushed Nickel", "Oil-Rubbed Bronze", "Pewter",
451
+ "Rose Gold", "Matte Black", "Polished Nickel",
452
+ "Aged Brass", "Champagne", "Gunmetal"
453
+ ],
454
+ label="Sconce Color",
455
+ value="None"
456
+ )
457
+ sconce_style = gr.Dropdown(
458
+ choices=[
459
+ "Modern", "Traditional", "Industrial", "Art Deco",
460
+ "Minimalist", "Vintage", "Contemporary", "Rustic",
461
+ "Coastal", "Farmhouse", "Mid-Century", "Bohemian",
462
+ "Scandinavian", "Asian", "Mediterranean", "Gothic",
463
+ "Transitional", "Eclectic", "Victorian"
464
+ ],
465
+ label="Sconce Style",
466
+ value="Modern"
467
+ )
468
+
469
+ # Floating Shelves
470
+ with gr.Column():
471
+ with gr.Row():
472
+ shelf_enable = gr.Checkbox(label="Add Floating Shelves", value=False)
473
+ shelf_color = gr.Dropdown(
474
+ choices=[
475
+ "None", "White", "Black", "Natural Wood", "Glass",
476
+ "Dark Wood", "Light Wood", "Metal", "Gold", "Silver",
477
+ "Bronze", "Reclaimed Wood", "Bamboo", "Marble",
478
+ "Industrial Metal", "Two-Tone", "Concrete",
479
+ "Acrylic", "Copper", "Brass"
480
+ ],
481
+ label="Shelf Material",
482
+ value="None"
483
+ )
484
+ shelf_size = gr.Dropdown(
485
+ choices=[
486
+ "Small", "Medium", "Large", "Set of 3",
487
+ "Extra Long", "Corner Set", "Asymmetric Set",
488
+ "Graduated Sizes", "Custom Length", "Mini Cubes",
489
+ "Full Wall", "Mixed Sizes", "Modular System"
490
+ ],
491
+ label="Shelf Size",
492
+ value="Medium"
493
+ )
494
+
495
+ # Plants
496
+ with gr.Column():
497
+ with gr.Row():
498
+ plants_enable = gr.Checkbox(label="Add Plants", value=False)
499
+ plants_type = gr.Dropdown(
500
+ choices=[
501
+ "None", "Hanging Plants", "Vertical Garden",
502
+ "Plant Shelf", "Single Plant", "Climbing Vines",
503
+ "Air Plants", "Succulent Wall", "Herb Garden",
504
+ "Mixed Tropical", "Fern Collection", "Living Wall",
505
+ "Moss Wall", "Potted Arrangement", "Plant Corner",
506
+ "Cascading Plants", "Bamboo Screen", "Terrarium Wall"
507
+ ],
508
+ label="Plant Type",
509
+ value="None"
510
+ )
511
+ plants_size = gr.Dropdown(
512
+ choices=[
513
+ "Small", "Medium", "Large", "Mixed Sizes",
514
+ "Full Wall", "Statement Piece", "Compact",
515
+ "Expansive", "Accent", "Floor to Ceiling",
516
+ "Window Height", "Custom Size", "Modular"
517
+ ],
518
+ label="Plant Coverage",
519
+ value="Medium"
520
+ )
521
+
522
+ # Special Requests and Advanced Settings
523
+ with gr.Column(scale=1):
524
+ with gr.Group():
525
+ gr.Markdown("## ✨ Special Requests")
526
+ input_text = gr.Textbox(
527
+ label="Additional Details",
528
+ placeholder="Add any special requests or details here...",
529
+ lines=3
530
+ )
531
+ num_outputs = gr.Slider(
532
+ minimum=1, maximum=50, value=1, step=1,
533
+ label="Number of Variations"
534
+ )
535
+
536
+ gr.Markdown("### Advanced Settings")
537
+ num_steps = gr.Slider(
538
+ minimum=20,
539
+ maximum=100,
540
+ value=50,
541
+ step=1,
542
+ label="Quality Steps"
543
+ )
544
+ guidance_scale = gr.Slider(
545
+ minimum=1,
546
+ maximum=20,
547
+ value=7.5,
548
+ step=0.1,
549
+ label="Design Freedom"
550
+ )
551
+ strength = gr.Slider(
552
+ minimum=0.1,
553
+ maximum=1.0,
554
+ value=0.75,
555
+ step=0.05,
556
+ label="Change Amount"
557
+ )
558
+ seed = gr.Number(
559
+ label="Seed (leave empty for random)",
560
+ value=-1,
561
+ precision=0
562
+ )
563
+ with gr.Row():
564
+ save_to_drive = gr.Checkbox(label="Save to Google Drive")
565
+ drive_url = gr.Textbox(
566
+ label="Drive Folder URL",
567
+ placeholder="https://drive.google.com/drive/folders/..."
568
+ )
569
+
570
+ # Row 4 - Current Prompts
571
+ with gr.Row():
572
+ with gr.Group():
573
+ gr.Markdown("## 📝 Current Prompts")
574
+ prompt_display = gr.TextArea(
575
+ label="Positive Prompt",
576
+ interactive=False,
577
+ lines=3,
578
+ value="Your design prompt will appear here..."
579
+ )
580
+ negative_prompt = gr.TextArea(
581
+ label="Negative Prompt",
582
+ value="blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped",
583
+ lines=2,
584
+ interactive=False
585
+ )
586
+
587
+ # Row 5 - Upload and Gallery
588
+ with gr.Row(elem_classes="upload-gallery-row"):
589
+ # Upload Area
590
+ with gr.Column(scale=1):
591
+ with gr.Group():
592
+ gr.Markdown("## 📸 Upload Photo")
593
+ input_image = gr.Image(
594
+ label="Upload a photo of your room",
595
+ type='pil'
596
+ )
597
+
598
+ # Gallery Area
599
+ with gr.Column(scale=2):
600
+ with gr.Group():
601
+ gr.Markdown("## 🖼️ Generated Variations")
602
+ gallery = gr.Gallery(
603
+ show_label=False,
604
+ elem_id="gallery",
605
+ columns=4,
606
+ rows=1,
607
+ height="300px",
608
+ object_fit="contain",
609
+ preview=True,
610
+ show_share_button=False
611
+ )
612
+
613
+ # Row 6 - Create Button
614
+ with gr.Row(elem_classes="button-row"):
615
+ submit = gr.Button("✨ Create My Design", variant="primary", size="lg")
616
+
617
+ # Progress indicator
618
+ progress = gr.Progress(track_tqdm=True)
619
+
620
+ def upload_to_drive(image, folder_id):
621
+ """Upload an image to Google Drive folder"""
622
+ try:
623
+ # OAuth 2.0 scopes
624
+ SCOPES = ['https://www.googleapis.com/auth/drive.file']
625
+
626
+ # Start OAuth 2.0 flow
627
+ flow = InstalledAppFlow.from_client_secrets_file(
628
+ 'credentials.json',
629
+ SCOPES
630
+ )
631
+ creds = flow.run_local_server(port=0)
632
+
633
+ # Build the Drive API service
634
+ service = build('drive', 'v3', credentials=creds)
635
+
636
+ # Convert numpy array to bytes
637
+ img = Image.fromarray(image)
638
+ img_byte_arr = BytesIO()
639
+ img.save(img_byte_arr, format='PNG')
640
+ img_byte_arr.seek(0)
641
+
642
+ # Prepare the file metadata
643
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
644
+ file_metadata = {
645
+ 'name': f'design_variation_{timestamp}.png',
646
+ 'parents': [folder_id]
647
+ }
648
+
649
+ # Create media
650
+ media = MediaIoBaseUpload(
651
+ img_byte_arr,
652
+ mimetype='image/png',
653
+ resumable=True
654
+ )
655
+
656
+ # Execute the upload
657
+ file = service.files().create(
658
+ body=file_metadata,
659
+ media_body=media,
660
+ fields='id'
661
+ ).execute()
662
+
663
+ print(f"File uploaded successfully. File ID: {file.get('id')}")
664
+ return True
665
+
666
+ except Exception as e:
667
+ print(f"Error uploading to Drive: {e}")
668
+ return False
669
+
670
+ def extract_folder_id(drive_url):
671
+ """Extract folder ID from Google Drive URL"""
672
+ try:
673
+ # Handle different URL formats
674
+ if 'folders/' in drive_url:
675
+ folder_id = drive_url.split('folders/')[-1].split('?')[0]
676
+ return folder_id
677
+ except Exception as e:
678
+ print(f"Error extracting folder ID: {e}")
679
+ return None
680
+
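For reference, the helper above keeps whatever follows `folders/` and drops any query string; a self-contained check of that behaviour (the URL and folder ID are made up):

```python
# Quick check of the folder-ID extraction logic above (URL and ID are made up).
def extract_folder_id(drive_url: str):
    if "folders/" in drive_url:
        return drive_url.split("folders/")[-1].split("?")[0]
    return None

assert extract_folder_id(
    "https://drive.google.com/drive/folders/1AbCdEfG?usp=sharing"
) == "1AbCdEfG"
assert extract_folder_id("https://example.com/no-folder") is None
```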
681
+ def on_submit(image, room, style, colors, floor_t, floor_c, floor_p,
682
+ wall_t, wall_c, wall_f, custom_text,
683
+ art_en, art_col, art_size,
684
+ mirror_en, mirror_fr, mirror_size,
685
+ sconce_en, sconce_col, sconce_style,
686
+ shelf_en, shelf_col, shelf_size,
687
+ plants_en, plants_type, plants_size,
688
+ num_outputs, save_to_drive, drive_url, num_steps,
689
+ guidance_scale, seed, strength):
690
+
691
+ if image is None:
692
+ return []
693
+
694
+ try:
695
+ nonlocal current_variations
696
+
697
+ # Generate the prompt
698
+ prompt = update_prompt(
699
+ room, style, colors, floor_t, floor_c, floor_p,
700
+ wall_t, wall_c, wall_f, custom_text,
701
+ art_en, art_col, art_size,
702
+ mirror_en, mirror_fr, mirror_size,
703
+ sconce_en, sconce_col, sconce_style,
704
+ shelf_en, shelf_col, shelf_size,
705
+ plants_en, plants_type, plants_size
706
+ )
707
+
708
+ # Generate variations
709
+ variations = model.generate_design(
710
+ image=image,
711
+ num_variations=max(1, int(num_outputs)),
712
+ prompt=prompt,
713
+ num_steps=int(num_steps),
714
+ guidance_scale=float(guidance_scale),
715
+ strength=float(strength),
716
+ seed=int(seed) if seed != -1 else None
717
+ )
718
+
719
+ # Store variations
720
+ current_variations = variations
721
+
722
+ # Handle Google Drive upload if enabled
723
+ if save_to_drive and drive_url:
724
+ folder_id = extract_folder_id(drive_url)
725
+ if folder_id:
726
+ for variation in variations:
727
+ upload_to_drive(variation, folder_id)
728
+
729
+ # Convert variations to gallery format
730
+ gallery_images = [(v, None) for v in variations]
731
+ return gallery_images
732
+
733
+ except Exception as e:
734
+ print(f"Error in generation: {e}")
735
+ current_variations = []
736
+ return []
737
+
738
+ submit.click(
739
+ on_submit,
740
+ inputs=[
741
+ input_image, room_type, style_preset, color_scheme,
742
+ floor_type, floor_color, floor_pattern,
743
+ wall_type, wall_color, wall_finish,
744
+ input_text,
745
+ art_print_enable, art_print_color, art_print_size,
746
+ mirror_enable, mirror_frame, mirror_size,
747
+ sconce_enable, sconce_color, sconce_style,
748
+ shelf_enable, shelf_color, shelf_size,
749
+ plants_enable, plants_type, plants_size,
750
+ num_outputs, save_to_drive, drive_url, num_steps,
751
+ guidance_scale, seed, strength
752
+ ],
753
+ outputs=[gallery]
754
  )
755
+
756
+ # Update prompt display when any input changes
757
+ def update_prompt_display(*args):
758
+ try:
759
+ prompt = update_prompt(*args)
760
+ return [prompt, negative_prompt.value] # Return both prompts
761
+ except Exception as e:
762
+ print(f"Error updating prompt: {e}")
763
+ return ["Error generating prompt", negative_prompt.value]
764
+
765
+ # List of all inputs that should trigger prompt updates
766
+ prompt_inputs = [
767
+ room_type, style_preset, color_scheme,
768
+ floor_type, floor_color, floor_pattern,
769
+ wall_type, wall_color, wall_finish,
770
+ input_text,
771
+ art_print_enable, art_print_color, art_print_size,
772
+ mirror_enable, mirror_frame, mirror_size,
773
+ sconce_enable, sconce_color, sconce_style,
774
+ shelf_enable, shelf_color, shelf_size,
775
+ plants_enable, plants_type, plants_size
776
+ ]
777
+
778
+ # Connect all inputs to prompt update
779
+ for input_component in prompt_inputs:
780
+ input_component.change(
781
+ fn=update_prompt_display,
782
+ inputs=prompt_inputs,
783
+ outputs=[prompt_display, negative_prompt]
784
+ )
785
+
786
+ # Gallery click handler
787
+ def on_select(evt):
788
+ nonlocal current_variations
789
+ try:
790
+ if isinstance(evt, list) and len(evt) > 0:
791
+ # Get the clicked file path
792
+ clicked_path = evt[0][0] if isinstance(evt[0], tuple) else evt[0]
793
+
794
+ # Get all file paths from the gallery
795
+ gallery_paths = []
796
+ for item in evt:
797
+ path = item[0] if isinstance(item, tuple) else item
798
+ gallery_paths.append(path)
799
+
800
+ # Find which image was clicked by comparing paths
801
+ selected_index = gallery_paths.index(clicked_path)
802
+ if 0 <= selected_index < len(current_variations):
803
+ return current_variations[selected_index]
804
+ except Exception as e:
805
+ print(f"Gallery selection error: {e}")
806
+ return None
807
+
808
+ gallery.select(
809
+ fn=on_select,
810
+ inputs=gallery,
811
+ outputs=[]
812
  )
813
 
814
+ return interface
815
+
816
+ def update_prompt(room, style, colors, floor_t, floor_c, floor_p,
817
+ wall_t, wall_c, wall_f, custom_text,
818
+ art_en, art_col, art_size,
819
+ mirror_en, mirror_fr, mirror_size,
820
+ sconce_en, sconce_col, sconce_style,
821
+ shelf_en, shelf_col, shelf_size,
822
+ plants_en, plants_type, plants_size):
823
+ # Start with basic room and style
824
+ prompt_parts = [f"Design a {style} {room.lower()} with a {colors} color scheme"]
825
 
826
+ # Add floor details only if type is not "Keep Existing"
827
+ if floor_t and floor_t != "Keep Existing":
828
+ floor_desc = floor_t
829
+ if floor_c and floor_c != "Keep Existing":
830
+ floor_desc += f" in {floor_c}"
831
+ if floor_p and floor_p != "Keep Existing":
832
+ floor_desc += f" with {floor_p} pattern"
833
+ prompt_parts.append(f"featuring {floor_desc} flooring")
834
+
835
+ # Add wall details only if type is not "Keep Existing"
836
+ if wall_t and wall_t != "Keep Existing":
837
+ wall_desc = wall_t
838
+ if wall_c and wall_c != "Keep Existing":
839
+ wall_desc += f" in {wall_c}"
840
+ if wall_f and wall_f != "Keep Existing":
841
+ wall_desc += f" with {wall_f} finish"
842
+ prompt_parts.append(f"with {wall_desc} walls")
843
+
844
+ # Add accessories only if enabled AND properties are selected and not "Keep Existing" or "None"
845
+ accessories = []
846
+
847
+ # Art Print
848
+ if art_en and art_col and art_col not in ["Keep Existing", "None"]:
849
+ accessories.append(f"{art_size} {art_col} Art Print")
850
+
851
+ # Mirror
852
+ if mirror_en and mirror_fr and mirror_fr not in ["Keep Existing", "None"]:
853
+ accessories.append(f"{mirror_size} Mirror with {mirror_fr} frame")
854
+
855
+ # Wall Sconce
856
+ if sconce_en and sconce_col and sconce_col not in ["Keep Existing", "None"]:
857
+ accessories.append(f"{sconce_style} {sconce_col} Wall Sconce")
858
+
859
+ # Floating Shelves
860
+ if shelf_en and shelf_col and shelf_col not in ["Keep Existing", "None"]:
861
+ accessories.append(f"{shelf_size} {shelf_col} Floating Shelves")
862
+
863
+ # Wall Plants
864
+ if plants_en and plants_type and plants_type not in ["Keep Existing", "None"]:
865
+ accessories.append(f"{plants_size} {plants_type}")
866
+
867
+ # Only add accessories section if there are any accessories
868
+ if accessories:
869
+ prompt_parts.append("decorated with " + ", ".join(accessories))
870
+
871
+ # Add custom text only if provided and non-empty
872
+ if custom_text and custom_text.strip():
873
+ prompt_parts.append(custom_text.strip())
874
+
875
+ return ", ".join(prompt_parts)
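To make the prompt assembly above concrete, here is roughly what it produces for one illustrative selection (the expected string in the comment follows directly from the concatenation logic above):

```python
# Worked example of the prompt assembly above (illustrative values; the import works because
# update_prompt is a module-level function in app.py).
from app import update_prompt

prompt = update_prompt(
    room="Living Room", style="Modern", colors="Neutral",
    floor_t="Hardwood", floor_c="Light Oak", floor_p="Keep Existing",
    wall_t="Fresh Paint", wall_c="Crisp White", wall_f="Soft Matte",
    custom_text="large windows with natural light",
    art_en=True, art_col="Modern Abstract", art_size="Statement",
    mirror_en=False, mirror_fr="None", mirror_size="Medium",
    sconce_en=False, sconce_col="None", sconce_style="Modern",
    shelf_en=False, shelf_col="None", shelf_size="Medium",
    plants_en=False, plants_type="None", plants_size="Medium",
)
# One comma-separated string:
# "Design a Modern living room with a Neutral color scheme,
#  featuring Hardwood in Light Oak flooring,
#  with Fresh Paint in Crisp White with Soft Matte finish walls,
#  decorated with Statement Modern Abstract Art Print,
#  large windows with natural light"
print(prompt)
```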
876
 
877
  def main():
878
+ """Main entry point for the application"""
879
+ import sys
880
 
881
+ # Check if we're in test mode
882
+ is_test_mode = "--test" in sys.argv
883
+
884
+ if is_test_mode:
885
+ print("Starting in TEST mode...")
886
+ from mock_model import MockDesignModel
887
+ model = MockDesignModel()
888
+ else:
889
+ print("Starting in PRODUCTION mode...")
890
+ from prod_model import ProductionDesignModel
891
+ model = ProductionDesignModel()
892
+
893
+ interface = create_ui(model)
894
+ interface.launch(
895
+ share=False,
896
+ show_api=False, # Hide API docs
897
+ show_error=True # Show errors for debugging
898
+ )
899
 
900
+ if __name__ == "__main__":
901
  main()
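Because `create_ui` only depends on the `DesignModel` interface, the UI can also be embedded from another script; a minimal sketch using the mock backend (assumes the files from this commit are importable):

```python
# Sketch: embed the UI in another script using the mock backend (no GPU needed).
from app import create_ui
from mock_model import MockDesignModel

interface = create_ui(MockDesignModel())        # any DesignModel implementation works here
interface.launch(share=False, show_error=True)  # serves on http://localhost:7860 by default
```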
credentials.json ADDED
@@ -0,0 +1 @@
1
+ {"installed":{"client_id":"924767856297-cdnm065phq1cq7ncfgh33bau61hjrpr8.apps.googleusercontent.com","project_id":"stabledesign","auth_uri":"https://accounts.google.com/o/oauth2/auth","token_uri":"https://oauth2.googleapis.com/token","auth_provider_x509_cert_url":"https://www.googleapis.com/oauth2/v1/certs","client_secret":"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","redirect_uris":["http://localhost"]}}
mock_model.py ADDED
@@ -0,0 +1,83 @@
1
+ from model import DesignModel
2
+ from PIL import Image
3
+ import numpy as np
4
+ import random
5
+ from typing import List
6
+ import traceback
7
+
8
+ class MockDesignModel(DesignModel):
9
+ def __init__(self):
10
+ super().__init__()
11
+ # Define base colors with proper alpha values
12
+ self.base_colors = [
13
+ (255, 0, 0), # Red
14
+ (0, 255, 0), # Green
15
+ (0, 0, 255), # Blue
16
+ (255, 255, 0), # Yellow
17
+ (255, 0, 255), # Magenta
18
+ (0, 255, 255), # Cyan
19
+ (128, 0, 0), # Maroon
20
+ (0, 128, 0), # Dark Green
21
+ (0, 0, 128), # Navy
22
+ ]
23
+ # Add test-specific attributes
24
+ self.seed = 323*111
25
+ self.neg_prompt = "window, door, low resolution, banner, logo, watermark, text"
26
+ self.additional_quality_suffix = "interior design, 4K, high resolution"
27
+
28
+ def apply_tint(self, img_array: np.ndarray, color: tuple) -> np.ndarray:
29
+ """Apply a color tint to an image array"""
30
+ # Create tint array
31
+ tint = np.array(color, dtype=np.float32) / 255.0
32
+
33
+ # Apply tint with alpha blending
34
+ alpha = 0.3 # 30% tint strength
35
+ tinted = img_array * (1 - alpha) + (img_array * tint) * alpha
36
+
37
+ # Ensure values are within valid range
38
+ return np.clip(tinted, 0, 255).astype(np.uint8)
39
+
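For a single pixel the blend above works out as 70 % of the original plus 30 % of the original scaled by the tint; a tiny worked example:

```python
# Worked example of the tint blend above for one mid-gray pixel and a pure-red tint.
import numpy as np

pixel = np.array([200.0, 200.0, 200.0])                    # original RGB value
tint = np.array([255, 0, 0], dtype=np.float32) / 255.0     # red tint -> [1, 0, 0]
alpha = 0.3

tinted = pixel * (1 - alpha) + (pixel * tint) * alpha
print(tinted)   # [200. 140. 140.]: red channel untouched, green/blue pulled down by 30 %
```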
40
+ def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
41
+ """Generate multiple variations of the input image with different color tints"""
42
+ try:
43
+ print(f"Starting generation of {num_variations} variations")
44
+
45
+ # Convert image to numpy array once
46
+ img_array = np.array(image.convert('RGB'))
47
+
48
+ # Generate base colors for all variations
49
+ colors_needed = max(1, int(num_variations))
50
+ colors = []
51
+
52
+ # Add base colors first
53
+ colors.extend(self.base_colors)
54
+
55
+ # Generate additional random colors if needed
56
+ while len(colors) < colors_needed:
57
+ new_color = (
58
+ random.randint(0, 255),
59
+ random.randint(0, 255),
60
+ random.randint(0, 255)
61
+ )
62
+ if new_color not in colors:
63
+ colors.append(new_color)
64
+
65
+ # Use only the number of colors we need
66
+ selected_colors = random.sample(colors, colors_needed)
67
+
68
+ # Generate variations
69
+ variations = []
70
+ for i, color in enumerate(selected_colors):
71
+ # Apply tint to numpy array
72
+ tinted_array = self.apply_tint(img_array.copy(), color)
73
+ variations.append(tinted_array)
74
+ print(f"Created variation {i+1}/{colors_needed}")
75
+
76
+ print(f"Successfully generated {len(variations)} variations")
77
+ return variations
78
+
79
+ except Exception as e:
80
+ print(f"Error in generate_design: {e}")
81
+ traceback.print_exc()
82
+ # Return the original image array if there's an error
83
+ return [np.array(image.convert('RGB'))]
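Since `generate_design` returns numpy arrays rather than PIL images, saving the results takes one conversion step; a small end-to-end sketch with the mock model (file names are arbitrary):

```python
# Sketch: run the mock model end to end and save its tinted variations (paths are arbitrary).
from PIL import Image
from mock_model import MockDesignModel

photo = Image.new("RGB", (512, 384), "beige")   # stand-in for a real room photo
variations = MockDesignModel().generate_design(photo, num_variations=3)

for i, array in enumerate(variations):
    Image.fromarray(array).save(f"variation_{i}.png")  # arrays are uint8 RGB, so this round-trips
print(f"Saved {len(variations)} variations")
```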
model.py ADDED
@@ -0,0 +1,14 @@
+ from typing import List
+ from PIL import Image
+ import numpy as np
+
+ class DesignModel:
+     """Interface for the design model"""
+     def __init__(self):
+         self.seed = None
+         self.neg_prompt = None
+         self.additional_quality_suffix = None
+
+     def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
+         """Generate design variations from the input image"""
+         raise NotImplementedError("This method should be implemented by concrete model classes")
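Any backend only has to subclass this interface and return a list of RGB numpy arrays. A hypothetical minimal implementation (not part of this commit) that simply echoes the input photo, which is enough for wiring up and testing the UI:

```python
# Hypothetical minimal implementation of the DesignModel interface (not part of this commit):
# it just echoes the input photo, which is enough to exercise the UI wiring.
from typing import List
import numpy as np
from PIL import Image
from model import DesignModel

class EchoDesignModel(DesignModel):
    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        rgb = np.array(image.convert("RGB"))
        return [rgb.copy() for _ in range(max(1, int(num_variations)))]
```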
prod_model.py ADDED
@@ -0,0 +1,170 @@
+ from model import DesignModel
+ from PIL import Image
+ import numpy as np
+ from typing import List
+ import random
+ import time
+ import torch
+ from diffusers import StableDiffusionImg2ImgPipeline
+ from transformers import CLIPTokenizer
+ import logging
+ import os
+ from datetime import datetime
+
+ # Set up logging
+ log_dir = "logs"
+ os.makedirs(log_dir, exist_ok=True)
+ log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler(log_file),
+         logging.StreamHandler()
+     ]
+ )
+
+ class ProductionDesignModel(DesignModel):
+     def __init__(self):
+         super().__init__()
+         try:
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+             logging.info(f"Using device: {self.device}")
+
+             self.model_id = "stabilityai/stable-diffusion-2-1"
+             logging.info(f"Loading model: {self.model_id}")
+
+             # Initialize the pipeline with error handling
+             try:
+                 self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+                     self.model_id,
+                     torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+                     safety_checker=None  # Disable safety checker for performance
+                 ).to(self.device)
+
+                 # Enable optimizations
+                 self.pipe.enable_attention_slicing()
+                 if self.device == "cuda":
+                     self.pipe.enable_model_cpu_offload()
+                     self.pipe.enable_vae_slicing()
+
+                 logging.info("Model loaded successfully")
+
+             except Exception as e:
+                 logging.error(f"Error loading model: {e}")
+                 raise
+
+             # Initialize tokenizer (it lives in the repo's "tokenizer" subfolder)
+             self.tokenizer = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer")
+
+             # Set default prompts
+             self.neg_prompt = "blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped"
+             self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
+
+         except Exception as e:
+             logging.error(f"Error in initialization: {e}")
+             raise
+
+     def _prepare_prompt(self, prompt: str) -> str:
+         """Prepare the prompt by adding the quality suffix and checking length"""
+         try:
+             full_prompt = f"{prompt}, {self.additional_quality_suffix}"
+             tokens = self.tokenizer.tokenize(full_prompt)
+
+             # CLIP's 77-token limit includes the BOS/EOS specials, so keep at most 75 prompt tokens
+             if len(tokens) > 75:
+                 logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
+                 tokens = tokens[:75]
+                 full_prompt = self.tokenizer.convert_tokens_to_string(tokens)
+
+             logging.info(f"Prepared prompt: {full_prompt}")
+             return full_prompt
+
+         except Exception as e:
+             logging.error(f"Error preparing prompt: {e}")
+             return prompt  # Return the original prompt if processing fails
+
+     def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
+         """Generate design variations with proper parameter handling"""
+         generation_start = time.time()
+         try:
+             # Log input parameters
+             logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")
+
+             # Get parameters from kwargs with defaults
+             prompt = kwargs.get('prompt', '')
+             num_steps = int(kwargs.get('num_steps', 50))
+             guidance_scale = float(kwargs.get('guidance_scale', 7.5))
+             strength = float(kwargs.get('strength', 0.75))
+             base_seed = kwargs.get('seed', int(time.time()))
+
+             # Parameter validation
+             num_variations = max(1, min(50, int(num_variations)))  # clamp to the 1-50 range the tests expect
+             num_steps = max(20, min(100, num_steps))
+             guidance_scale = max(1, min(20, guidance_scale))
+             strength = max(0.1, min(1.0, strength))
+
+             # Prepare the prompt
+             full_prompt = self._prepare_prompt(prompt)
+
+             # Generate distinct seeds
+             seeds = [base_seed + i * 10000 for i in range(num_variations)]
+             logging.info(f"Using seeds: {seeds}")
+
+             # Prepare the input image
+             if image.mode != "RGB":
+                 image = image.convert("RGB")
+
+             # Generate variations
+             variations = []
+             generator = torch.Generator(device=self.device)
+
+             for i, seed in enumerate(seeds):
+                 try:
+                     variation_start = time.time()
+                     generator.manual_seed(seed)
+
+                     # Generate the image
+                     output = self.pipe(
+                         prompt=full_prompt,
+                         negative_prompt=self.neg_prompt,
+                         image=image,
+                         num_inference_steps=num_steps,
+                         guidance_scale=guidance_scale,
+                         strength=strength,
+                         generator=generator
+                     ).images[0]
+
+                     variations.append(np.array(output))
+
+                     variation_time = time.time() - variation_start
+                     logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
+
+                 except Exception as e:
+                     logging.error(f"Error generating variation {i+1}: {e}")
+                     if not variations:  # If no successful variations yet
+                         variations.append(np.array(image.convert('RGB')))
+
+             total_time = time.time() - generation_start
+             logging.info(f"Generation completed in {total_time:.2f}s")
+
+             return variations
+
+         except Exception as e:
+             logging.error(f"Error in generate_design: {e}")
+             import traceback
+             logging.error(traceback.format_exc())
+             return [np.array(image.convert('RGB'))]
+
+         finally:
+             if self.device == "cuda":
+                 torch.cuda.empty_cache()
+                 logging.info("Cleared CUDA cache")
+
+     def __del__(self):
+         """Cleanup when the model is deleted"""
+         try:
+             if self.device == "cuda":
+                 torch.cuda.empty_cache()
+                 logging.info("Final CUDA cache cleanup")
+         except Exception:
+             pass
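
Reviewer note (not part of the commit): the sketch below shows one way the new `ProductionDesignModel` could be driven; the input photo `room.jpg` and the output file names are purely illustrative, and the keyword arguments simply use the defaults and validation ranges defined in `generate_design` above.

```python
from PIL import Image
from prod_model import ProductionDesignModel

model = ProductionDesignModel()          # loads stabilityai/stable-diffusion-2-1 at construction
room = Image.open("room.jpg")            # hypothetical source photo

variations = model.generate_design(
    image=room,
    num_variations=2,                    # clamped to 1-50
    prompt="Design a Modern living room with a Neutral color scheme",
    num_steps=50,                        # validated to 20-100
    guidance_scale=7.5,                  # validated to 1-20
    strength=0.75,                       # validated to 0.1-1.0
)

for i, arr in enumerate(variations):     # each variation is returned as a NumPy array
    Image.fromarray(arr).save(f"variation_{i}.png")
```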
requirements.txt CHANGED
@@ -1,10 +1,33 @@
- diffusers==0.25.0
- xformers==0.0.23.post1
- transformers==4.39.1
- torchvision
- accelerate==0.26.1
- opencv-python==4.9.0.80
- scipy==1.11.4
- triton==2.1.0
- altair==4.1.0
- pandas==2.1.4
+ # Core dependencies
+ gradio>=3.50.2
+ Pillow>=10.0.0
+ numpy>=1.24.0
+
+ # Model dependencies
+ torch>=2.0.0
+ diffusers>=0.21.0
+ transformers>=4.31.0
+ accelerate>=0.21.0
+
+ # Google Drive integration
+ google-auth>=2.22.0
+ google-auth-oauthlib>=1.0.0
+ google-api-python-client>=2.95.0
+
+ # Utility packages
+ python-dateutil>=2.8.2
+ tqdm>=4.65.0
+ requests>=2.31.0
+
+ # Optional but recommended
+ opencv-python>=4.8.0 # For image processing
+ safetensors>=0.3.1 # For faster model loading
+
+ # Development tools
+ pytest>=7.4.0
+ black>=22.0.0
+ flake8>=6.0.0
+ isort>=5.12.0
+
+ # Testing dependencies
+ pytest-mock>=3.11.1
test_prompt.py ADDED
@@ -0,0 +1,320 @@
+ import unittest
+ from app import update_prompt
+ from prod_model import ProductionDesignModel
+ from PIL import Image
+ import numpy as np
+
+ class TestPromptGeneration(unittest.TestCase):
+     def setUp(self):
+         """Set up default values for tests"""
+         self.default_params = {
+             "room": "Living Room",
+             "style": "Modern",
+             "colors": "Neutral",
+             "floor_t": "Keep Existing",
+             "floor_c": "Keep Existing",
+             "floor_p": "Keep Existing",
+             "wall_t": "Keep Existing",
+             "wall_c": "Keep Existing",
+             "wall_f": "Keep Existing",
+             "custom_text": "",
+             "art_en": False,
+             "art_col": "None",
+             "art_size": "None",
+             "mirror_en": False,
+             "mirror_fr": "None",
+             "mirror_size": "Medium",
+             "sconce_en": False,
+             "sconce_col": "None",
+             "sconce_style": "Modern",
+             "shelf_en": False,
+             "shelf_col": "None",
+             "shelf_size": "Medium",
+             "plants_en": False,
+             "plants_type": "None",
+             "plants_size": "Medium"
+         }
+
+     def test_basic_room_style(self):
+         """Test basic room and style prompt generation"""
+         prompt = update_prompt(**self.default_params)
+         expected = "Design a Modern living room with a Neutral color scheme"
+         self.assertEqual(prompt, expected)
+
+     def test_all_room_types(self):
+         """Test all room types"""
+         room_types = [
+             "Living Room", "Bedroom", "Kitchen", "Dining Room",
+             "Bathroom", "Home Office", "Kids Room", "Master Bedroom",
+             "Guest Room", "Studio Apartment", "Entryway", "Hallway",
+             "Game Room", "Library", "Home Theater", "Gym"
+         ]
+         for room in room_types:
+             params = self.default_params.copy()
+             params["room"] = room
+             prompt = update_prompt(**params)
+             expected = f"Design a Modern {room.lower()} with a Neutral color scheme"
+             self.assertEqual(prompt, expected)
+
+     def test_all_styles(self):
+         """Test all style presets"""
+         styles = [
+             "Modern", "Contemporary", "Minimalist", "Industrial",
+             "Scandinavian", "Mid-Century Modern", "Traditional",
+             "Transitional", "Farmhouse", "Rustic", "Bohemian",
+             "Art Deco", "Coastal", "Mediterranean", "Japanese",
+             "French Country", "Victorian", "Colonial", "Gothic",
+             "Baroque", "Rococo", "Neoclassical", "Eclectic",
+             "Zen", "Tropical", "Shabby Chic", "Hollywood Regency",
+             "Southwestern", "Asian Fusion", "Retro"
+         ]
+         for style in styles:
+             params = self.default_params.copy()
+             params["style"] = style
+             prompt = update_prompt(**params)
+             expected = f"Design a {style} living room with a Neutral color scheme"
+             self.assertEqual(prompt, expected)
+
+     def test_all_color_schemes(self):
+         """Test all color schemes"""
+         color_schemes = [
+             "Neutral", "Monochromatic", "Minimalist White",
+             "Warm Gray", "Cool Gray", "Earth Tones",
+             "Pastel", "Bold Primary", "Jewel Tones",
+             "Black and White", "Navy and Gold", "Forest Green",
+             "Desert Sand", "Ocean Blue", "Sunset Orange",
+             "Deep Purple", "Emerald Green", "Ruby Red",
+             "Sapphire Blue", "Golden Yellow", "Sage Green",
+             "Dusty Rose", "Charcoal", "Cream", "Burgundy",
+             "Teal", "Copper", "Silver", "Bronze", "Slate"
+         ]
+         for color in color_schemes:
+             params = self.default_params.copy()
+             params["colors"] = color
+             prompt = update_prompt(**params)
+             expected = f"Design a Modern living room with a {color} color scheme"
+             self.assertEqual(prompt, expected)
+
+     def test_floor_combinations(self):
+         """Test various floor combinations"""
+         test_cases = [
+             {
+                 "floor_t": "Hardwood",
+                 "floor_c": "Keep Existing",
+                 "floor_p": "Keep Existing",
+                 "expected": "featuring Hardwood flooring"
+             },
+             {
+                 "floor_t": "Hardwood",
+                 "floor_c": "Light Oak",
+                 "floor_p": "Keep Existing",
+                 "expected": "featuring Hardwood in Light Oak flooring"
+             },
+             {
+                 "floor_t": "Hardwood",
+                 "floor_c": "Light Oak",
+                 "floor_p": "Elegant Herringbone",
+                 "expected": "featuring Hardwood in Light Oak with Elegant Herringbone pattern flooring"
+             }
+         ]
+         for case in test_cases:
+             params = self.default_params.copy()
+             params.update({
+                 "floor_t": case["floor_t"],
+                 "floor_c": case["floor_c"],
+                 "floor_p": case["floor_p"]
+             })
+             prompt = update_prompt(**params)
+             expected = f"Design a Modern living room with a Neutral color scheme, {case['expected']}"
+             self.assertEqual(prompt, expected)
+
+     def test_wall_combinations(self):
+         """Test various wall combinations"""
+         test_cases = [
+             {
+                 "wall_t": "Fresh Paint",
+                 "wall_c": "Keep Existing",
+                 "wall_f": "Keep Existing",
+                 "expected": "with Fresh Paint walls"
+             },
+             {
+                 "wall_t": "Fresh Paint",
+                 "wall_c": "Crisp White",
+                 "wall_f": "Keep Existing",
+                 "expected": "with Fresh Paint in Crisp White walls"
+             },
+             {
+                 "wall_t": "Fresh Paint",
+                 "wall_c": "Crisp White",
+                 "wall_f": "Pearl Satin",
+                 "expected": "with Fresh Paint in Crisp White with Pearl Satin finish walls"
+             }
+         ]
+         for case in test_cases:
+             params = self.default_params.copy()
+             params.update({
+                 "wall_t": case["wall_t"],
+                 "wall_c": case["wall_c"],
+                 "wall_f": case["wall_f"]
+             })
+             prompt = update_prompt(**params)
+             expected = f"Design a Modern living room with a Neutral color scheme, {case['expected']}"
+             self.assertEqual(prompt, expected)
+
+     def test_accessories_individual(self):
+         """Test each accessory individually"""
+         test_cases = [
+             {
+                 "name": "art",
+                 "params": {"art_en": True, "art_col": "Vibrant Colors", "art_size": "Oversized"},
+                 "expected": "decorated with Oversized Vibrant Colors Art Print"
+             },
+             {
+                 "name": "mirror",
+                 "params": {"mirror_en": True, "mirror_fr": "Gold", "mirror_size": "Large"},
+                 "expected": "decorated with Large Mirror with Gold frame"
+             },
+             {
+                 "name": "sconce",
+                 "params": {"sconce_en": True, "sconce_col": "Brass", "sconce_style": "Art Deco"},
+                 "expected": "decorated with Art Deco Brass Wall Sconce"
+             },
+             {
+                 "name": "shelf",
+                 "params": {"shelf_en": True, "shelf_col": "Natural Wood", "shelf_size": "Set of 3"},
+                 "expected": "decorated with Set of 3 Natural Wood Floating Shelves"
+             },
+             {
+                 "name": "plants",
+                 "params": {"plants_en": True, "plants_type": "Hanging Plants", "plants_size": "Medium"},
+                 "expected": "decorated with Medium Hanging Plants"
+             }
+         ]
+         for case in test_cases:
+             params = self.default_params.copy()
+             params.update(case["params"])
+             prompt = update_prompt(**params)
+             expected = f"Design a Modern living room with a Neutral color scheme, {case['expected']}"
+             self.assertEqual(prompt, expected)
+
+     def test_custom_text_variations(self):
+         """Test custom text handling"""
+         test_cases = [
+             {"text": "", "should_include": False},
+             {"text": " ", "should_include": False},
+             {"text": "Add plants", "should_include": True},
+             {"text": "Make it cozy and warm", "should_include": True},
+             {"text": "Multiple\nlines", "should_include": True}
+         ]
+         for case in test_cases:
+             params = self.default_params.copy()
+             params["custom_text"] = case["text"]
+             prompt = update_prompt(**params)
+             base = "Design a Modern living room with a Neutral color scheme"
+             if case["should_include"]:
+                 expected = f"{base}, {case['text'].strip()}"
+             else:
+                 expected = base
+             self.assertEqual(prompt, expected)
+
+     def test_complex_combinations(self):
+         """Test complex combinations of all features"""
+         test_cases = [
+             {
+                 "name": "full_living_room",
+                 "params": {
+                     "room": "Living Room",
+                     "style": "Modern",
+                     "colors": "Warm Gray",
+                     "floor_t": "Hardwood",
+                     "floor_c": "Light Oak",
+                     "floor_p": "Elegant Herringbone",
+                     "wall_t": "Fresh Paint",
+                     "wall_c": "Crisp White",
+                     "wall_f": "Pearl Satin",
+                     "custom_text": "Make it perfect for entertaining",
+                     "art_en": True,
+                     "art_col": "Modern Abstract",
+                     "art_size": "Statement",
+                     "mirror_en": True,
+                     "mirror_fr": "Gold",
+                     "mirror_size": "Large",
+                     "sconce_en": True,
+                     "sconce_col": "Brass",
+                     "sconce_style": "Art Deco",
+                     "shelf_en": True,
+                     "shelf_col": "Natural Wood",
+                     "shelf_size": "Set of 3",
+                     "plants_en": True,
+                     "plants_type": "Hanging Plants",
+                     "plants_size": "Medium"
+                 }
+             },
+             {
+                 "name": "minimal_bedroom",
+                 "params": {
+                     "room": "Bedroom",
+                     "style": "Japanese",
+                     "colors": "Minimalist White",
+                     "floor_t": "Natural Bamboo",
+                     "floor_c": "Keep Existing",
+                     "floor_p": "Keep Existing",
+                     "wall_t": "Fresh Paint",
+                     "wall_c": "Soft White",
+                     "wall_f": "Keep Existing",
+                     "custom_text": "Focus on minimalism and zen aesthetics"
+                 }
+             }
+         ]
+         for case in test_cases:
+             params = self.default_params.copy()
+             params.update(case["params"])
+             prompt = update_prompt(**params)
+             self.assertTrue(len(prompt) > 0)
+             self.assertTrue(prompt.startswith("Design a"))
+
+ class TestProductionModel(unittest.TestCase):
+     def setUp(self):
+         """Set up test environment"""
+         self.model = ProductionDesignModel()
+         # Create a simple test image
+         self.test_image = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
+
+     def test_number_of_variations(self):
+         """Test that the model correctly handles different numbers of variations"""
+         test_cases = [1, 3, 10, 25, 50]  # Test various numbers of variations
+         for num_variations in test_cases:
+             variations = self.model.generate_design(
+                 image=self.test_image,
+                 num_variations=num_variations,
+                 prompt="Test prompt",
+                 num_steps=20,  # Minimum steps for faster testing
+                 guidance_scale=7.5,
+                 strength=0.75
+             )
+             self.assertEqual(
+                 len(variations),
+                 num_variations,
+                 f"Expected {num_variations} variations, got {len(variations)}"
+             )
+
+     def test_invalid_variation_numbers(self):
+         """Test handling of invalid numbers of variations"""
+         test_cases = [-1, 0, 51, 100]  # Test invalid numbers
+         for num_variations in test_cases:
+             variations = self.model.generate_design(
+                 image=self.test_image,
+                 num_variations=num_variations,
+                 prompt="Test prompt",
+                 num_steps=20,
+                 guidance_scale=7.5,
+                 strength=0.75
+             )
+             # Should clamp to valid range (1-50)
+             self.assertTrue(
+                 1 <= len(variations) <= 50,
+                 f"Number of variations {len(variations)} outside valid range 1-50"
+             )
+
+ if __name__ == '__main__':
+     unittest.main()
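
Reviewer note (not part of the commit): `TestPromptGeneration` is pure string logic, while `TestProductionModel` instantiates the full Stable Diffusion pipeline and is only practical on a machine with the model weights (ideally a GPU). Assuming the standard unittest and pytest command-line interfaces, the two suites can be run separately like this:

```bash
# prompt-building tests only (fast, no model download)
python -m unittest test_prompt.TestPromptGeneration -v

# full-model tests (slow; downloads stabilityai/stable-diffusion-2-1)
python -m pytest test_prompt.py::TestProductionModel -v
```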