anbucur committed on
Commit
d81760a
·
1 Parent(s): 8c93fbf

Enhance UI dropdown options and improve ProductionDesignModel initialization

Browse files

- Updated UI dropdowns in app.py to provide a comprehensive list of choices for room types, design styles, and color moods, enhancing user experience.
- Refactored layout for better organization of UI elements.
- Improved the ProductionDesignModel class in prod_model.py by implementing a more robust model initialization process, including advanced architecture setup and detailed logging for better error tracking.
- Added new model dependencies in requirements.txt to support the updated functionality.

Files changed (3) hide show
  1. app.py +86 -62
  2. prod_model.py +191 -130
  3. requirements.txt +7 -15
app.py CHANGED
@@ -256,29 +256,53 @@ def create_ui(model: DesignModel):
256
  with gr.Group():
257
  gr.Markdown("## 🏠 Basic Settings")
258
  with gr.Row():
259
- room_type = gr.Dropdown(
 
 
 
 
 
 
260
  label="Room Type",
261
- choices=["None"] + ["Living Room", "Bedroom", "Dining Room", "Kitchen", "Bathroom", "Home Office", "Master Bedroom", "Guest Room", "Study Room", "Game Room", "Media Room", "Nursery", "Gym", "Library"],
262
- value="None"
263
- )
264
- style_preset = gr.Dropdown(
 
 
 
 
 
 
 
 
 
265
  label="Design Style",
266
- choices=["None"] + ["Modern", "Contemporary", "Minimalist", "Industrial", "Scandinavian", "Mid-Century Modern", "Traditional", "Transitional", "Farmhouse", "Rustic", "Bohemian", "Art Deco", "Coastal", "Mediterranean", "Japanese", "French Country", "Victorian", "Colonial", "Gothic", "Baroque", "Rococo", "Neoclassical", "Eclectic", "Zen", "Tropical", "Shabby Chic", "Hollywood Regency", "Southwestern", "Asian Fusion", "Retro"],
267
- value="None"
268
- )
269
- color_scheme = gr.Dropdown(
 
 
 
 
 
 
 
 
 
 
270
  label="Color Mood",
271
- choices=["None"] + ["Neutral", "Monochromatic", "Minimalist White", "Warm Gray", "Cool Gray", "Earth Tones", "Pastel", "Bold Primary", "Jewel Tones", "Black and White", "Navy and Gold", "Forest Green", "Desert Sand", "Ocean Blue", "Sunset Orange", "Deep Purple", "Emerald Green", "Ruby Red", "Sapphire Blue", "Golden Yellow", "Sage Green", "Dusty Rose", "Charcoal", "Cream", "Burgundy", "Teal", "Copper", "Silver", "Bronze", "Slate"],
272
- value="None"
273
  )
274
 
275
  # Row 2 - Surface Finishes
276
- with gr.Row():
277
  # Floor Options
278
  with gr.Column(scale=1):
279
  with gr.Group():
280
  gr.Markdown("## 🎨 Floor Options")
281
- floor_type = gr.Dropdown(
282
  choices=[
283
  "Keep Existing", "Hardwood", "Stone Tiles", "Porcelain Tiles",
284
  "Soft Carpet", "Polished Concrete", "Marble", "Vinyl",
@@ -287,9 +311,9 @@ def create_ui(model: DesignModel):
287
  "Mosaic Tiles", "Luxury Vinyl Tiles", "Stained Concrete"
288
  ],
289
  label="Material",
290
- value="Keep Existing"
291
- )
292
- floor_color = gr.Dropdown(
293
  choices=[
294
  "Keep Existing", "Light Oak", "Rich Walnut", "Cool Gray",
295
  "Whitewashed", "Warm Cherry", "Deep Brown", "Classic Black",
@@ -298,10 +322,10 @@ def create_ui(model: DesignModel):
298
  "Cream Travertine", "Dark Slate", "Golden Teak",
299
  "Rustic Pine", "Ebony"
300
  ],
301
- label="Color",
302
- value="Keep Existing"
303
- )
304
- floor_pattern = gr.Dropdown(
305
  choices=[
306
  "Keep Existing", "Classic Straight", "Elegant Herringbone",
307
  "V-Pattern", "Decorative Parquet", "Diagonal Layout",
@@ -311,17 +335,17 @@ def create_ui(model: DesignModel):
311
  "Windmill Pattern", "Large Format", "Mixed Width"
312
  ],
313
  label="Pattern",
314
- value="Keep Existing"
315
- )
316
-
317
  # Wall Options
318
  with gr.Column(scale=1):
319
  with gr.Group():
320
  gr.Markdown("## 🎨 Wall Options")
321
- wall_type = gr.Dropdown(
322
  choices=[
323
  "Keep Existing", "Fresh Paint", "Designer Wallpaper",
324
- "Textured Finish", "Wood Panels", "Exposed Brick",
325
  "Natural Stone", "Wooden Planks", "Modern Concrete",
326
  "Venetian Plaster", "Wainscoting", "Shiplap",
327
  "3D Wall Panels", "Fabric Panels", "Metal Panels",
@@ -329,9 +353,9 @@ def create_ui(model: DesignModel):
329
  "Acoustic Panels", "Living Wall"
330
  ],
331
  label="Treatment",
332
- value="Keep Existing"
333
- )
334
- wall_color = gr.Dropdown(
335
  choices=[
336
  "Keep Existing", "Crisp White", "Soft White", "Warm Beige",
337
  "Gentle Gray", "Sky Blue", "Nature Green", "Sunny Yellow",
@@ -339,10 +363,10 @@ def create_ui(model: DesignModel):
339
  "Terracotta", "Navy Blue", "Charcoal Gray", "Lavender",
340
  "Olive Green", "Dusty Rose", "Teal", "Burgundy"
341
  ],
342
- label="Color",
343
- value="Keep Existing"
344
- )
345
- wall_finish = gr.Dropdown(
346
  choices=[
347
  "Keep Existing", "Soft Matte", "Subtle Eggshell",
348
  "Pearl Satin", "Sleek Semi-Gloss", "High Gloss",
@@ -351,9 +375,9 @@ def create_ui(model: DesignModel):
351
  "Venetian", "Lime Wash", "Concrete", "Rustic",
352
  "Lacquered", "Hammered", "Patina"
353
  ],
354
- label="Finish",
355
- value="Keep Existing"
356
- )
357
 
358
  # Row 3 - Wall Decorations and Special Requests
359
  with gr.Row(elem_classes="wall-decorations-row"):
@@ -367,7 +391,7 @@ def create_ui(model: DesignModel):
367
  with gr.Column():
368
  with gr.Row():
369
  art_print_enable = gr.Checkbox(label="Add Artwork", value=False)
370
- art_print_color = gr.Dropdown(
371
  choices=[
372
  "None", "Classic Black & White", "Vibrant Colors",
373
  "Single Color", "Soft Colors", "Modern Abstract",
@@ -378,8 +402,8 @@ def create_ui(model: DesignModel):
378
  ],
379
  label="Art Style",
380
  value="None"
381
- )
382
- art_print_size = gr.Dropdown(
383
  choices=[
384
  "None", "Modest", "Standard", "Statement", "Oversized",
385
  "Gallery Wall", "Diptych", "Triptych", "Mini Series",
@@ -391,9 +415,9 @@ def create_ui(model: DesignModel):
391
 
392
  # Mirror
393
  with gr.Column():
394
- with gr.Row():
395
  mirror_enable = gr.Checkbox(label="Add Mirror", value=False)
396
- mirror_frame = gr.Dropdown(
397
  choices=[
398
  "None", "Gold", "Silver", "Black", "White", "Wood",
399
  "Brass", "Bronze", "Copper", "Chrome", "Antique Gold",
@@ -402,8 +426,8 @@ def create_ui(model: DesignModel):
402
  ],
403
  label="Frame Style",
404
  value="None"
405
- )
406
- mirror_size = gr.Dropdown(
407
  choices=[
408
  "Small", "Medium", "Large", "Full Length",
409
  "Oversized", "Double Width", "Floor Mirror",
@@ -419,7 +443,7 @@ def create_ui(model: DesignModel):
419
  with gr.Column():
420
  with gr.Row():
421
  sconce_enable = gr.Checkbox(label="Add Wall Sconce", value=False)
422
- sconce_color = gr.Dropdown(
423
  choices=[
424
  "None", "Black", "Gold", "Silver", "Bronze", "White",
425
  "Brass", "Copper", "Chrome", "Antique Brass",
@@ -429,8 +453,8 @@ def create_ui(model: DesignModel):
429
  ],
430
  label="Sconce Color",
431
  value="None"
432
- )
433
- sconce_style = gr.Dropdown(
434
  choices=[
435
  "Modern", "Traditional", "Industrial", "Art Deco",
436
  "Minimalist", "Vintage", "Contemporary", "Rustic",
@@ -444,9 +468,9 @@ def create_ui(model: DesignModel):
444
 
445
  # Floating Shelves
446
  with gr.Column():
447
- with gr.Row():
448
  shelf_enable = gr.Checkbox(label="Add Floating Shelves", value=False)
449
- shelf_color = gr.Dropdown(
450
  choices=[
451
  "None", "White", "Black", "Natural Wood", "Glass",
452
  "Dark Wood", "Light Wood", "Metal", "Gold", "Silver",
@@ -456,8 +480,8 @@ def create_ui(model: DesignModel):
456
  ],
457
  label="Shelf Material",
458
  value="None"
459
- )
460
- shelf_size = gr.Dropdown(
461
  choices=[
462
  "Small", "Medium", "Large", "Set of 3",
463
  "Extra Long", "Corner Set", "Asymmetric Set",
@@ -466,13 +490,13 @@ def create_ui(model: DesignModel):
466
  ],
467
  label="Shelf Size",
468
  value="Medium"
469
- )
470
 
471
- # Plants
472
  with gr.Column():
473
- with gr.Row():
474
  plants_enable = gr.Checkbox(label="Add Plants", value=False)
475
- plants_type = gr.Dropdown(
476
  choices=[
477
  "None", "Hanging Plants", "Vertical Garden",
478
  "Plant Shelf", "Single Plant", "Climbing Vines",
@@ -483,8 +507,8 @@ def create_ui(model: DesignModel):
483
  ],
484
  label="Plant Type",
485
  value="None"
486
- )
487
- plants_size = gr.Dropdown(
488
  choices=[
489
  "Small", "Medium", "Large", "Mixed Sizes",
490
  "Full Wall", "Statement Piece", "Compact",
@@ -499,7 +523,7 @@ def create_ui(model: DesignModel):
499
  with gr.Column(scale=1):
500
  with gr.Group():
501
  gr.Markdown("## ✨ Special Requests")
502
- input_text = gr.Textbox(
503
  label="Additional Details",
504
  placeholder="Add any special requests or details here...",
505
  lines=3
@@ -517,14 +541,14 @@ def create_ui(model: DesignModel):
517
  step=1,
518
  label="Quality Steps"
519
  )
520
- guidance_scale = gr.Slider(
521
  minimum=1,
522
  maximum=20,
523
  value=7.5,
524
  step=0.1,
525
  label="Design Freedom"
526
- )
527
- strength = gr.Slider(
528
  minimum=0.1,
529
  maximum=1.0,
530
  value=0.75,
@@ -544,7 +568,7 @@ def create_ui(model: DesignModel):
544
  )
545
 
546
  # Row 4 - Current Prompts
547
- with gr.Row():
548
  with gr.Group():
549
  gr.Markdown("## 📝 Current Prompts")
550
  prompt_display = gr.TextArea(
@@ -858,9 +882,9 @@ def main():
858
  is_test_mode = "--test" in sys.argv
859
 
860
  if is_test_mode:
861
- print("Starting in TEST mode...")
862
  from mock_model import MockDesignModel
863
- model = MockDesignModel()
864
  else:
865
  print("Starting in PRODUCTION mode...")
866
  from prod_model import ProductionDesignModel
 
256
  with gr.Group():
257
  gr.Markdown("## 🏠 Basic Settings")
258
  with gr.Row():
259
+ room_type = gr.Dropdown(
260
+ choices=[
261
+ "Living Room", "Bedroom", "Kitchen", "Dining Room",
262
+ "Bathroom", "Home Office", "Kids Room", "Master Bedroom",
263
+ "Guest Room", "Studio Apartment", "Entryway", "Hallway",
264
+ "Game Room", "Library", "Home Theater", "Gym"
265
+ ],
266
  label="Room Type",
267
+ value="Living Room"
268
+ )
269
+ style_preset = gr.Dropdown(
270
+ choices=[
271
+ "Modern", "Contemporary", "Minimalist", "Industrial",
272
+ "Scandinavian", "Mid-Century Modern", "Traditional",
273
+ "Transitional", "Farmhouse", "Rustic", "Bohemian",
274
+ "Art Deco", "Coastal", "Mediterranean", "Japanese",
275
+ "French Country", "Victorian", "Colonial", "Gothic",
276
+ "Baroque", "Rococo", "Neoclassical", "Eclectic",
277
+ "Zen", "Tropical", "Shabby Chic", "Hollywood Regency",
278
+ "Southwestern", "Asian Fusion", "Retro"
279
+ ],
280
  label="Design Style",
281
+ value="Modern"
282
+ )
283
+ color_scheme = gr.Dropdown(
284
+ choices=[
285
+ "Neutral", "Monochromatic", "Minimalist White",
286
+ "Warm Gray", "Cool Gray", "Earth Tones",
287
+ "Pastel", "Bold Primary", "Jewel Tones",
288
+ "Black and White", "Navy and Gold", "Forest Green",
289
+ "Desert Sand", "Ocean Blue", "Sunset Orange",
290
+ "Deep Purple", "Emerald Green", "Ruby Red",
291
+ "Sapphire Blue", "Golden Yellow", "Sage Green",
292
+ "Dusty Rose", "Charcoal", "Cream", "Burgundy",
293
+ "Teal", "Copper", "Silver", "Bronze", "Slate"
294
+ ],
295
  label="Color Mood",
296
+ value="Neutral"
 
297
  )
298
 
299
  # Row 2 - Surface Finishes
300
+ with gr.Row():
301
  # Floor Options
302
  with gr.Column(scale=1):
303
  with gr.Group():
304
  gr.Markdown("## 🎨 Floor Options")
305
+ floor_type = gr.Dropdown(
306
  choices=[
307
  "Keep Existing", "Hardwood", "Stone Tiles", "Porcelain Tiles",
308
  "Soft Carpet", "Polished Concrete", "Marble", "Vinyl",
 
311
  "Mosaic Tiles", "Luxury Vinyl Tiles", "Stained Concrete"
312
  ],
313
  label="Material",
314
+ value="Keep Existing"
315
+ )
316
+ floor_color = gr.Dropdown(
317
  choices=[
318
  "Keep Existing", "Light Oak", "Rich Walnut", "Cool Gray",
319
  "Whitewashed", "Warm Cherry", "Deep Brown", "Classic Black",
 
322
  "Cream Travertine", "Dark Slate", "Golden Teak",
323
  "Rustic Pine", "Ebony"
324
  ],
325
+ label="Color",
326
+ value="Keep Existing"
327
+ )
328
+ floor_pattern = gr.Dropdown(
329
  choices=[
330
  "Keep Existing", "Classic Straight", "Elegant Herringbone",
331
  "V-Pattern", "Decorative Parquet", "Diagonal Layout",
 
335
  "Windmill Pattern", "Large Format", "Mixed Width"
336
  ],
337
  label="Pattern",
338
+ value="Keep Existing"
339
+ )
340
+
341
  # Wall Options
342
  with gr.Column(scale=1):
343
  with gr.Group():
344
  gr.Markdown("## 🎨 Wall Options")
345
+ wall_type = gr.Dropdown(
346
  choices=[
347
  "Keep Existing", "Fresh Paint", "Designer Wallpaper",
348
+ "Textured Finish", "Wood Panels", "Exposed Brick",
349
  "Natural Stone", "Wooden Planks", "Modern Concrete",
350
  "Venetian Plaster", "Wainscoting", "Shiplap",
351
  "3D Wall Panels", "Fabric Panels", "Metal Panels",
 
353
  "Acoustic Panels", "Living Wall"
354
  ],
355
  label="Treatment",
356
+ value="Keep Existing"
357
+ )
358
+ wall_color = gr.Dropdown(
359
  choices=[
360
  "Keep Existing", "Crisp White", "Soft White", "Warm Beige",
361
  "Gentle Gray", "Sky Blue", "Nature Green", "Sunny Yellow",
 
363
  "Terracotta", "Navy Blue", "Charcoal Gray", "Lavender",
364
  "Olive Green", "Dusty Rose", "Teal", "Burgundy"
365
  ],
366
+ label="Color",
367
+ value="Keep Existing"
368
+ )
369
+ wall_finish = gr.Dropdown(
370
  choices=[
371
  "Keep Existing", "Soft Matte", "Subtle Eggshell",
372
  "Pearl Satin", "Sleek Semi-Gloss", "High Gloss",
 
375
  "Venetian", "Lime Wash", "Concrete", "Rustic",
376
  "Lacquered", "Hammered", "Patina"
377
  ],
378
+ label="Finish",
379
+ value="Keep Existing"
380
+ )
381
 
382
  # Row 3 - Wall Decorations and Special Requests
383
  with gr.Row(elem_classes="wall-decorations-row"):
 
391
  with gr.Column():
392
  with gr.Row():
393
  art_print_enable = gr.Checkbox(label="Add Artwork", value=False)
394
+ art_print_color = gr.Dropdown(
395
  choices=[
396
  "None", "Classic Black & White", "Vibrant Colors",
397
  "Single Color", "Soft Colors", "Modern Abstract",
 
402
  ],
403
  label="Art Style",
404
  value="None"
405
+ )
406
+ art_print_size = gr.Dropdown(
407
  choices=[
408
  "None", "Modest", "Standard", "Statement", "Oversized",
409
  "Gallery Wall", "Diptych", "Triptych", "Mini Series",
 
415
 
416
  # Mirror
417
  with gr.Column():
418
+ with gr.Row():
419
  mirror_enable = gr.Checkbox(label="Add Mirror", value=False)
420
+ mirror_frame = gr.Dropdown(
421
  choices=[
422
  "None", "Gold", "Silver", "Black", "White", "Wood",
423
  "Brass", "Bronze", "Copper", "Chrome", "Antique Gold",
 
426
  ],
427
  label="Frame Style",
428
  value="None"
429
+ )
430
+ mirror_size = gr.Dropdown(
431
  choices=[
432
  "Small", "Medium", "Large", "Full Length",
433
  "Oversized", "Double Width", "Floor Mirror",
 
443
  with gr.Column():
444
  with gr.Row():
445
  sconce_enable = gr.Checkbox(label="Add Wall Sconce", value=False)
446
+ sconce_color = gr.Dropdown(
447
  choices=[
448
  "None", "Black", "Gold", "Silver", "Bronze", "White",
449
  "Brass", "Copper", "Chrome", "Antique Brass",
 
453
  ],
454
  label="Sconce Color",
455
  value="None"
456
+ )
457
+ sconce_style = gr.Dropdown(
458
  choices=[
459
  "Modern", "Traditional", "Industrial", "Art Deco",
460
  "Minimalist", "Vintage", "Contemporary", "Rustic",
 
468
 
469
  # Floating Shelves
470
  with gr.Column():
471
+ with gr.Row():
472
  shelf_enable = gr.Checkbox(label="Add Floating Shelves", value=False)
473
+ shelf_color = gr.Dropdown(
474
  choices=[
475
  "None", "White", "Black", "Natural Wood", "Glass",
476
  "Dark Wood", "Light Wood", "Metal", "Gold", "Silver",
 
480
  ],
481
  label="Shelf Material",
482
  value="None"
483
+ )
484
+ shelf_size = gr.Dropdown(
485
  choices=[
486
  "Small", "Medium", "Large", "Set of 3",
487
  "Extra Long", "Corner Set", "Asymmetric Set",
 
490
  ],
491
  label="Shelf Size",
492
  value="Medium"
493
+ )
494
 
495
+ # Plants
496
  with gr.Column():
497
+ with gr.Row():
498
  plants_enable = gr.Checkbox(label="Add Plants", value=False)
499
+ plants_type = gr.Dropdown(
500
  choices=[
501
  "None", "Hanging Plants", "Vertical Garden",
502
  "Plant Shelf", "Single Plant", "Climbing Vines",
 
507
  ],
508
  label="Plant Type",
509
  value="None"
510
+ )
511
+ plants_size = gr.Dropdown(
512
  choices=[
513
  "Small", "Medium", "Large", "Mixed Sizes",
514
  "Full Wall", "Statement Piece", "Compact",
 
523
  with gr.Column(scale=1):
524
  with gr.Group():
525
  gr.Markdown("## ✨ Special Requests")
526
+ input_text = gr.Textbox(
527
  label="Additional Details",
528
  placeholder="Add any special requests or details here...",
529
  lines=3
 
541
  step=1,
542
  label="Quality Steps"
543
  )
544
+ guidance_scale = gr.Slider(
545
  minimum=1,
546
  maximum=20,
547
  value=7.5,
548
  step=0.1,
549
  label="Design Freedom"
550
+ )
551
+ strength = gr.Slider(
552
  minimum=0.1,
553
  maximum=1.0,
554
  value=0.75,
 
568
  )
569
 
570
  # Row 4 - Current Prompts
571
+ with gr.Row():
572
  with gr.Group():
573
  gr.Markdown("## 📝 Current Prompts")
574
  prompt_display = gr.TextArea(
 
882
  is_test_mode = "--test" in sys.argv
883
 
884
  if is_test_mode:
885
+ print("Starting in TEST mode...")
886
  from mock_model import MockDesignModel
887
+ model = MockDesignModel()
888
  else:
889
  print("Starting in PRODUCTION mode...")
890
  from prod_model import ProductionDesignModel
prod_model.py CHANGED
@@ -5,11 +5,13 @@ from typing import List
5
  import random
6
  import time
7
  import torch
8
- from diffusers import StableDiffusionImg2ImgPipeline
9
- from transformers import CLIPTokenizer
 
10
  import logging
11
  import os
12
  from datetime import datetime
 
13
 
14
  # Set up logging
15
  log_dir = "logs"
@@ -27,158 +29,217 @@ logging.basicConfig(
27
 
28
  class ProductionDesignModel(DesignModel):
29
  def __init__(self):
30
- super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  try:
32
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
- logging.info(f"Using device: {self.device}")
34
-
35
- self.model_id = "stabilityai/stable-diffusion-2-1"
36
- self.tokenizer_id = "openai/clip-vit-large-patch14" # Correct tokenizer for SD 2.1
37
- logging.info(f"Loading model: {self.model_id}")
38
- logging.info(f"Loading tokenizer: {self.tokenizer_id}")
39
-
40
- # Initialize the pipeline with error handling
41
- try:
42
- self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
43
- self.model_id,
44
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
45
- safety_checker=None # Disable safety checker for performance
46
- ).to(self.device)
47
-
48
- # Enable optimizations
49
- self.pipe.enable_attention_slicing()
50
- if self.device == "cuda":
51
- self.pipe.enable_model_cpu_offload()
52
- self.pipe.enable_vae_slicing()
53
-
54
- logging.info("Model loaded successfully")
55
-
56
- except Exception as e:
57
- logging.error(f"Error loading model: {e}")
58
- raise
59
-
60
- # Initialize tokenizer with correct path
61
- try:
62
- self.tokenizer = CLIPTokenizer.from_pretrained(self.tokenizer_id)
63
- logging.info("Tokenizer loaded successfully")
64
- except Exception as e:
65
- logging.error(f"Error loading tokenizer: {e}")
66
- raise
67
-
68
- # Set default prompts
69
- self.neg_prompt = "blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped"
70
- self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
71
-
72
  except Exception as e:
73
- logging.error(f"Error in initialization: {e}")
74
  raise
75
 
76
- def _prepare_prompt(self, prompt: str) -> str:
77
- """Prepare the prompt by adding quality suffix and checking length"""
78
- try:
79
- full_prompt = f"{prompt}, {self.additional_quality_suffix}"
80
- tokens = self.tokenizer.tokenize(full_prompt)
81
-
82
- if len(tokens) > 77:
83
- logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
84
- tokens = tokens[:77]
85
- full_prompt = self.tokenizer.convert_tokens_to_string(tokens)
86
-
87
- logging.info(f"Prepared prompt: {full_prompt}")
88
- return full_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- except Exception as e:
91
- logging.error(f"Error preparing prompt: {e}")
92
- return prompt # Return original prompt if processing fails
93
 
94
- def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
95
- """Generate design variations with proper parameter handling"""
96
- generation_start = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  try:
98
- # Log input parameters
99
- logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")
100
-
101
- # Get parameters from kwargs with defaults
102
- prompt = kwargs.get('prompt', '')
103
- num_steps = int(kwargs.get('num_steps', 50))
104
- guidance_scale = float(kwargs.get('guidance_scale', 7.5))
105
- strength = float(kwargs.get('strength', 0.75))
106
-
107
- # Handle seed properly
108
  seed_param = kwargs.get('seed')
109
  base_seed = int(time.time()) if seed_param is None else int(seed_param)
110
- logging.info(f"Using base seed: {base_seed}")
111
-
112
- # Parameter validation
113
- num_steps = max(20, min(100, num_steps))
114
- guidance_scale = max(1, min(20, guidance_scale))
115
- strength = max(0.1, min(1.0, strength))
116
-
117
- # Log validated parameters
118
- logging.info(f"Validated parameters: steps={num_steps}, guidance={guidance_scale}, strength={strength}")
119
 
120
- # Prepare the prompt
121
- full_prompt = self._prepare_prompt(prompt)
122
-
123
- # Generate distinct seeds
124
- seeds = [base_seed + i * 10000 for i in range(num_variations)]
125
- logging.info(f"Using seeds: {seeds}")
 
 
 
 
 
 
 
 
 
 
126
 
127
- # Prepare the input image
128
- if image.mode != "RGB":
129
- image = image.convert("RGB")
130
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # Generate variations
132
  variations = []
133
- generator = torch.Generator(device=self.device)
134
-
135
- for i, seed in enumerate(seeds):
136
  try:
137
- variation_start = time.time()
138
- generator.manual_seed(seed)
139
-
140
- # Generate the image
141
- output = self.pipe(
142
- prompt=full_prompt,
143
  negative_prompt=self.neg_prompt,
144
- image=image,
145
  num_inference_steps=num_steps,
146
- guidance_scale=guidance_scale,
147
  strength=strength,
148
- generator=generator
 
 
 
 
 
149
  ).images[0]
150
 
151
- variations.append(np.array(output))
152
-
153
- variation_time = time.time() - variation_start
154
- logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
155
 
156
  except Exception as e:
157
- logging.error(f"Error generating variation {i+1}: {e}")
158
- if not variations: # If no successful variations yet
159
- variations.append(np.array(image.convert('RGB')))
160
-
161
- total_time = time.time() - generation_start
162
- logging.info(f"Generation completed in {total_time:.2f}s")
163
-
164
  return variations
165
-
166
  except Exception as e:
167
  logging.error(f"Error in generate_design: {e}")
168
- import traceback
169
- logging.error(traceback.format_exc())
170
- return [np.array(image.convert('RGB'))]
171
-
172
- finally:
173
- if self.device == "cuda":
174
- torch.cuda.empty_cache()
175
- logging.info("Cleared CUDA cache")
176
-
177
  def __del__(self):
178
  """Cleanup when the model is deleted"""
179
- try:
180
- if self.device == "cuda":
181
- torch.cuda.empty_cache()
182
- logging.info("Final CUDA cache cleanup")
183
- except:
184
- pass
 
5
  import random
6
  import time
7
  import torch
8
+ from diffusers.pipelines.controlnet import StableDiffusionControlNetInpaintPipeline
9
+ from diffusers import ControlNetModel, UniPCMultistepScheduler, AutoPipelineForText2Image
10
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, AutoModelForDepthEstimation
11
  import logging
12
  import os
13
  from datetime import datetime
14
+ import gc
15
 
16
  # Set up logging
17
  log_dir = "logs"
 
29
 
30
  class ProductionDesignModel(DesignModel):
31
  def __init__(self):
32
+ """Initialize the production model with advanced architecture"""
33
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ self.dtype = torch.float16 if self.device == "cuda" else torch.float32
35
+
36
+ # Setup logging
37
+ logging.basicConfig(filename=f'logs/prod_model_{time.strftime("%Y%m%d")}.log',
38
+ level=logging.INFO,
39
+ format='%(asctime)s - %(levelname)s - %(message)s')
40
+
41
+ self.seed = 323*111
42
+ self.neg_prompt = "window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner"
43
+ self.control_items = ["windowpane;window", "door;double;door"]
44
+ self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
45
+
46
  try:
47
+ logging.info(f"Initializing models on {self.device} with {self.dtype}")
48
+ self._initialize_models()
49
+ logging.info("Models initialized successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
+ logging.error(f"Error initializing models: {e}")
52
  raise
53
 
54
+ def _initialize_models(self):
55
+ """Initialize all required models and pipelines"""
56
+ # Initialize ControlNet models
57
+ self.controlnet_depth = ControlNetModel.from_pretrained(
58
+ "controlnet_depth", torch_dtype=self.dtype, use_safetensors=True
59
+ )
60
+ self.controlnet_seg = ControlNetModel.from_pretrained(
61
+ "own_controlnet", torch_dtype=self.dtype, use_safetensors=True
62
+ )
63
+
64
+ # Initialize main pipeline
65
+ self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
66
+ "SG161222/Realistic_Vision_V5.1_noVAE",
67
+ controlnet=[self.controlnet_depth, self.controlnet_seg],
68
+ safety_checker=None,
69
+ torch_dtype=self.dtype
70
+ )
71
+
72
+ # Setup IP-Adapter
73
+ self.pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models",
74
+ weight_name="ip-adapter_sd15.bin")
75
+ self.pipe.set_ip_adapter_scale(0.4)
76
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
77
+ self.pipe = self.pipe.to(self.device)
78
+
79
+ # Initialize guide pipeline
80
+ self.guide_pipe = AutoPipelineForText2Image.from_pretrained(
81
+ "segmind/SSD-1B",
82
+ torch_dtype=self.dtype,
83
+ use_safetensors=True,
84
+ variant="fp16"
85
+ ).to(self.device)
86
+
87
+ # Initialize segmentation and depth models
88
+ self.seg_processor, self.seg_model = self._init_segmentation()
89
+ self.depth_processor, self.depth_model = self._init_depth()
90
+ self.depth_model = self.depth_model.to(self.device)
91
+
92
+ def _init_segmentation(self):
93
+ """Initialize segmentation models"""
94
+ processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
95
+ model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
96
+ return processor, model
97
+
98
+ def _init_depth(self):
99
+ """Initialize depth estimation models"""
100
+ processor = AutoImageProcessor.from_pretrained(
101
+ "LiheYoung/depth-anything-large-hf",
102
+ torch_dtype=self.dtype
103
+ )
104
+ model = AutoModelForDepthEstimation.from_pretrained(
105
+ "LiheYoung/depth-anything-large-hf",
106
+ torch_dtype=self.dtype
107
+ )
108
+ return processor, model
109
+
110
+ def _get_depth_map(self, image: Image.Image) -> Image.Image:
111
+ """Generate depth map for input image"""
112
+ image_to_depth = self.depth_processor(images=image, return_tensors="pt").to(self.device)
113
+ with torch.inference_mode():
114
+ depth_map = self.depth_model(**image_to_depth).predicted_depth
115
+
116
+ width, height = image.size
117
+ depth_map = torch.nn.functional.interpolate(
118
+ depth_map.unsqueeze(1).float(),
119
+ size=(height, width),
120
+ mode="bicubic",
121
+ align_corners=False,
122
+ )
123
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
124
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
125
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
126
+ image = torch.cat([depth_map] * 3, dim=1)
127
+
128
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
129
+ return Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
130
+
131
+ def _segment_image(self, image: Image.Image) -> Image.Image:
132
+ """Generate segmentation map for input image"""
133
+ pixel_values = self.seg_processor(image, return_tensors="pt").pixel_values
134
+ with torch.inference_mode():
135
+ outputs = self.seg_model(pixel_values)
136
+
137
+ seg = self.seg_processor.post_process_semantic_segmentation(
138
+ outputs, target_sizes=[image.size[::-1]])[0]
139
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
140
+
141
+ # You'll need to implement the palette mapping here
142
+ # This is a placeholder - you should implement proper color mapping
143
+ for label in range(seg.max() + 1):
144
+ color_seg[seg == label, :] = [label * 30 % 255] * 3
145
 
146
+ return Image.fromarray(color_seg).convert('RGB')
 
 
147
 
148
+ def _resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
149
+ """Resize image while maintaining aspect ratio"""
150
+ width, height = image.size
151
+ if width > height:
152
+ new_width = target_size
153
+ new_height = int(height * (target_size / width))
154
+ else:
155
+ new_height = target_size
156
+ new_width = int(width * (target_size / height))
157
+ return image.resize((new_width, new_height), Image.LANCZOS)
158
+
159
+ def _flush(self):
160
+ """Clear CUDA cache"""
161
+ gc.collect()
162
+ if torch.cuda.is_available():
163
+ torch.cuda.empty_cache()
164
+
165
+ def generate_design(self, image: Image.Image, prompt: str, **kwargs) -> List[Image.Image]:
166
+ """
167
+ Generate design variations based on input image and prompt
168
+ """
169
  try:
170
+ # Set seed
 
 
 
 
 
 
 
 
 
171
  seed_param = kwargs.get('seed')
172
  base_seed = int(time.time()) if seed_param is None else int(seed_param)
173
+ self.generator = torch.Generator(device=self.device).manual_seed(base_seed)
 
 
 
 
 
 
 
 
174
 
175
+ # Get parameters
176
+ num_variations = kwargs.get('num_variations', 1)
177
+ guidance_scale = float(kwargs.get('guidance_scale', 10.0))
178
+ num_steps = int(kwargs.get('num_steps', 50))
179
+ strength = float(kwargs.get('strength', 0.9))
180
+ img_size = int(kwargs.get('img_size', 768))
181
+
182
+ logging.info(f"Generating design with parameters: guidance_scale={guidance_scale}, "
183
+ f"num_steps={num_steps}, strength={strength}, img_size={img_size}")
184
+
185
+ # Prepare prompt
186
+ pos_prompt = f"{prompt}, {self.additional_quality_suffix}"
187
+
188
+ # Process input image
189
+ orig_size = image.size
190
+ input_image = self._resize_image(image, img_size)
191
 
192
+ # Generate depth map
193
+ depth_map = self._get_depth_map(input_image)
 
194
 
195
+ # Generate segmentation
196
+ seg_map = self._segment_image(input_image)
197
+
198
+ # Generate IP-adapter reference image
199
+ self._flush()
200
+ ip_image = self.guide_pipe(
201
+ pos_prompt,
202
+ num_inference_steps=num_steps,
203
+ negative_prompt=self.neg_prompt,
204
+ generator=self.generator
205
+ ).images[0]
206
+
207
  # Generate variations
208
  variations = []
209
+ for i in range(num_variations):
 
 
210
  try:
211
+ self._flush()
212
+ variation = self.pipe(
213
+ prompt=pos_prompt,
 
 
 
214
  negative_prompt=self.neg_prompt,
 
215
  num_inference_steps=num_steps,
 
216
  strength=strength,
217
+ guidance_scale=guidance_scale,
218
+ generator=self.generator,
219
+ image=input_image,
220
+ ip_adapter_image=ip_image,
221
+ control_image=[depth_map, seg_map],
222
+ controlnet_conditioning_scale=[0.5, 0.5]
223
  ).images[0]
224
 
225
+ # Resize back to original size
226
+ variation = variation.resize(orig_size, Image.LANCZOS)
227
+ variations.append(variation)
 
228
 
229
  except Exception as e:
230
+ logging.error(f"Error generating variation {i}: {e}")
231
+ continue
232
+
233
+ if not variations:
234
+ logging.warning("No variations were generated successfully")
235
+ return [image] # Return original image if no variations were generated
236
+
237
  return variations
238
+
239
  except Exception as e:
240
  logging.error(f"Error in generate_design: {e}")
241
+ return [image] # Return original image in case of error
242
+
 
 
 
 
 
 
 
243
  def __del__(self):
244
  """Cleanup when the model is deleted"""
245
+ self._flush()
 
 
 
 
 
requirements.txt CHANGED
@@ -2,32 +2,24 @@
2
  gradio>=3.50.2
3
  Pillow>=10.0.0
4
  numpy>=1.24.0
5
-
6
- # Model dependencies
7
  torch>=2.0.0
8
  diffusers>=0.21.0
9
  transformers>=4.31.0
10
  accelerate>=0.21.0
 
11
 
12
  # Google Drive integration
13
- google-auth>=2.22.0
14
- google-auth-oauthlib>=1.0.0
15
  google-api-python-client>=2.95.0
 
 
16
 
17
  # Utility packages
18
  python-dateutil>=2.8.2
19
- tqdm>=4.65.0
20
  requests>=2.31.0
21
-
22
- # Optional but recommended
23
- opencv-python>=4.8.0 # For image processing
24
- safetensors>=0.3.1 # For faster model loading
25
 
26
  # Development tools
27
  pytest>=7.4.0
28
- black>=22.0.0
29
- flake8>=6.0.0
30
- isort>=5.12.0
31
-
32
- # Testing dependencies
33
- pytest-mock>=3.11.1
 
2
  gradio>=3.50.2
3
  Pillow>=10.0.0
4
  numpy>=1.24.0
 
 
5
  torch>=2.0.0
6
  diffusers>=0.21.0
7
  transformers>=4.31.0
8
  accelerate>=0.21.0
9
+ safetensors>=0.3.1
10
 
11
  # Google Drive integration
 
 
12
  google-api-python-client>=2.95.0
13
+ google-auth-oauthlib>=1.0.0
14
+ google-auth>=2.22.0
15
 
16
  # Utility packages
17
  python-dateutil>=2.8.2
 
18
  requests>=2.31.0
19
+ tqdm>=4.65.0
20
+ opencv-python>=4.8.0
 
 
21
 
22
  # Development tools
23
  pytest>=7.4.0
24
+ pytest-mock>=3.11.1
25
+ mock>=5.1.0