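"""Production implementation of DesignModel backed by Stable Diffusion img2img.

Wraps diffusers' StableDiffusionImg2ImgPipeline to generate interior-design
variations of an input image, handling prompt preparation, parameter
validation, seeded generation, logging, and CUDA memory cleanup.
"""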
from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import time
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPTokenizer
import logging
import os
from datetime import datetime

# Set up logging
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)

class ProductionDesignModel(DesignModel):
    def __init__(self):
        super().__init__()
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using device: {self.device}")
            
            self.model_id = "stabilityai/stable-diffusion-2-1"
            logging.info(f"Loading model: {self.model_id}")
            
            # Initialize the pipeline with error handling
            try:
                self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    safety_checker=None  # Disable safety checker for performance
                )
                
                # Enable memory optimizations. enable_model_cpu_offload() manages
                # device placement itself, so the pipeline must not be moved to
                # CUDA manually when offloading is used.
                self.pipe.enable_attention_slicing()
                if self.device == "cuda":
                    self.pipe.enable_model_cpu_offload()
                    self.pipe.enable_vae_slicing()
                else:
                    self.pipe.to(self.device)
                
                logging.info("Model loaded successfully")
                
            except Exception as e:
                logging.error(f"Error loading model: {e}")
                raise
            
            # Initialize tokenizer (it lives in the "tokenizer" subfolder of
            # Stable Diffusion model repos, not at the repo root)
            self.tokenizer = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer")
            
            # Set default prompts
            self.neg_prompt = "blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped"
            self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
            
        except Exception as e:
            logging.error(f"Error in initialization: {e}")
            raise

    def _prepare_prompt(self, prompt: str) -> str:
        """Prepare the prompt by adding quality suffix and checking length"""
        try:
            full_prompt = f"{prompt}, {self.additional_quality_suffix}"
            tokens = self.tokenizer.tokenize(full_prompt)
            
            # CLIP's 77-token context window includes the BOS/EOS special tokens,
            # so only model_max_length - 2 tokens are available for the prompt
            max_tokens = self.tokenizer.model_max_length - 2
            if len(tokens) > max_tokens:
                logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                tokens = tokens[:max_tokens]
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens)
            
            logging.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
            
        except Exception as e:
            logging.error(f"Error preparing prompt: {e}")
            return prompt  # Return original prompt if processing fails

    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate design variations with proper parameter handling"""
        generation_start = time.time()
        try:
            # Log input parameters
            logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")
            
            # Get parameters from kwargs with defaults
            prompt = kwargs.get('prompt', '')
            num_steps = int(kwargs.get('num_steps', 50))
            guidance_scale = float(kwargs.get('guidance_scale', 7.5))
            strength = float(kwargs.get('strength', 0.75))
            base_seed = int(kwargs.get('seed', time.time()))
            
            # Parameter validation
            num_steps = max(20, min(100, num_steps))
            guidance_scale = max(1, min(20, guidance_scale))
            strength = max(0.1, min(1.0, strength))
            
            # Prepare the prompt
            full_prompt = self._prepare_prompt(prompt)
            
            # Generate distinct seeds
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logging.info(f"Using seeds: {seeds}")
            
            # Prepare the input image
            if image.mode != "RGB":
                image = image.convert("RGB")
            
            # Generate variations
            variations = []
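            # A single generator is reused across variations; reseeding it each
            # iteration keeps every output individually reproducible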
            generator = torch.Generator(device=self.device)
            
            for i, seed in enumerate(seeds):
                try:
                    variation_start = time.time()
                    generator.manual_seed(seed)
                    
                    # Generate the image
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator
                    ).images[0]
                    
                    variations.append(np.array(output))
                    
                    variation_time = time.time() - variation_start
                    logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
                    
                except Exception as e:
                    logging.error(f"Error generating variation {i+1}: {e}")
                    if not variations:  # If no successful variations yet
                        variations.append(np.array(image.convert('RGB')))
            
            total_time = time.time() - generation_start
            logging.info(f"Generation completed in {total_time:.2f}s")
            
            return variations
            
        except Exception as e:
            logging.error(f"Error in generate_design: {e}")
            import traceback
            logging.error(traceback.format_exc())
            return [np.array(image.convert('RGB'))]
        
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache")
        
    def __del__(self):
        """Cleanup when the model is deleted"""
        try:
            # getattr guards against __init__ having failed before self.device was set
            if getattr(self, "device", None) == "cuda":
                torch.cuda.empty_cache()
                logging.info("Final CUDA cache cleanup")
        except Exception:
            pass
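

if __name__ == "__main__":
    # Usage sketch, not part of the original interface: assumes a local input
    # image "room.jpg" exists; the prompt and parameter values below are
    # illustrative placeholders only.
    model = ProductionDesignModel()
    source = Image.open("room.jpg")
    results = model.generate_design(
        source,
        num_variations=2,
        prompt="modern scandinavian living room",
        num_steps=40,
        guidance_scale=7.0,
        strength=0.6,
        seed=1234,
    )
    for idx, arr in enumerate(results):
        Image.fromarray(arr).save(f"variation_{idx}.png")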