from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import random
import time
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPTokenizer
import logging
import os
from datetime import datetime

# Set up logging: one log file per day under ./logs, mirrored to the console.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler(),
    ],
)

# Module logger; records propagate to the root handlers configured above.
logger = logging.getLogger(__name__)


class ProductionDesignModel(DesignModel):
    """Image-to-image design generator backed by Stable Diffusion 2.1.

    Loads ``stabilityai/stable-diffusion-2-1`` through diffusers'
    ``StableDiffusionImg2ImgPipeline`` (fp16 on CUDA, fp32 on CPU) and exposes
    ``generate_design`` to produce one or more variations of an input image.
    """

    # Upper bound on CLIP's text context window, including the BOS/EOS tokens.
    _CLIP_CONTEXT_TOKENS = 77

    def __init__(self):
        """Select the device, load the pipeline and set the default prompts.

        Raises:
            Exception: any error raised while loading the pipeline is logged
                and re-raised.
        """
        super().__init__()
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")

            self.model_id = "stabilityai/stable-diffusion-2-1"
            logger.info(f"Loading model: {self.model_id}")
            self.pipe = self._load_pipeline()

            # Reuse the tokenizer the pipeline already loaded. Loading it via
            # CLIPTokenizer.from_pretrained(model_id) fails for diffusers repos,
            # which keep the tokenizer in a "tokenizer/" subfolder; this also
            # guarantees the tokenizer matches the pipeline's text encoder.
            self.tokenizer: CLIPTokenizer = self.pipe.tokenizer

            # Set default prompts
            self.neg_prompt = "blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped"
            self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
        except Exception as e:
            logger.error(f"Error in initialization: {e}")
            raise

    def _load_pipeline(self) -> StableDiffusionImg2ImgPipeline:
        """Load the img2img pipeline for ``self.model_id`` and enable memory optimizations."""
        try:
            pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None,  # Disable safety checker for performance
            )

            # Enable optimizations
            pipe.enable_attention_slicing()
            if self.device == "cuda":
                # Model CPU offload moves each sub-model to the GPU only while it
                # runs. The pipeline must NOT be moved to CUDA beforehand, or the
                # whole model is first loaded onto the GPU, which defeats the
                # memory savings.
                pipe.enable_model_cpu_offload()
                pipe.enable_vae_slicing()
            else:
                pipe = pipe.to(self.device)

            logger.info("Model loaded successfully")
            return pipe
        except Exception as e:
            logger.exception(f"Error loading model: {e}")
            raise

    def _prepare_prompt(self, prompt: str) -> str:
        """Append the quality suffix to ``prompt`` and truncate it to CLIP's window.

        Args:
            prompt: user prompt; ``None`` or blank is treated as empty.

        Returns:
            The prompt to send to the pipeline. If processing fails, the
            original ``prompt`` is returned unchanged.
        """
        try:
            text = "" if prompt is None else str(prompt).strip()
            # Avoid a dangling ", " prefix when no user prompt was given.
            full_prompt = (
                f"{text}, {self.additional_quality_suffix}"
                if text
                else self.additional_quality_suffix
            )

            tokens = self.tokenizer.tokenize(full_prompt)
            # tokenize() does not add BOS/EOS, but the encoder's context window
            # includes them, so only (window - 2) content tokens fit.
            window = min(
                getattr(self.tokenizer, "model_max_length", self._CLIP_CONTEXT_TOKENS),
                self._CLIP_CONTEXT_TOKENS,
            )
            max_content_tokens = window - 2
            if len(tokens) > max_content_tokens:
                logger.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                tokens = tokens[:max_content_tokens]
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens)

            logger.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
        except Exception as e:
            logger.error(f"Error preparing prompt: {e}")
            return prompt  # Return original prompt if processing fails

    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate ``num_variations`` img2img variations of ``image``.

        Recognised kwargs (values are clamped to the ranges shown):
            prompt (str): text prompt, default ``''``.
            num_steps (int): inference steps, default 50, clamped to [20, 100].
            guidance_scale (float): CFG scale, default 7.5, clamped to [1, 20].
            strength (float): img2img strength, default 0.75, clamped to [0.1, 1.0].
            seed (int | None): base seed; ``None``/missing uses the current time.
                Variation ``i`` uses ``seed + i * 10000``.

        Returns:
            A list of RGB images as numpy arrays. If every variation fails, or
            setup fails, the list contains only the (RGB-converted) input image.
        """
        generation_start = time.time()
        try:
            # Log input parameters
            logger.info(f"Generating {num_variations} variations with parameters: {kwargs}")

            # Get parameters from kwargs with defaults
            prompt = kwargs.get('prompt', '')
            num_steps = int(kwargs.get('num_steps', 50))
            guidance_scale = float(kwargs.get('guidance_scale', 7.5))
            strength = float(kwargs.get('strength', 0.75))
            # An explicit seed=None means "unseeded"; previously it crashed on
            # None + int and silently returned the input image.
            requested_seed = kwargs.get('seed')
            base_seed = int(requested_seed) if requested_seed is not None else int(time.time())

            # Parameter validation
            num_steps = max(20, min(100, num_steps))
            guidance_scale = max(1.0, min(20.0, guidance_scale))
            strength = max(0.1, min(1.0, strength))

            full_prompt = self._prepare_prompt(prompt)

            # Generate distinct seeds
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logger.info(f"Using seeds: {seeds}")

            # Prepare the input image
            if image.mode != "RGB":
                image = image.convert("RGB")

            variations = []
            generator = torch.Generator(device=self.device)
            for index, variation_seed in enumerate(seeds, start=1):
                try:
                    variation_start = time.time()
                    generator.manual_seed(variation_seed)
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator,
                    ).images[0]
                    variations.append(np.array(output))
                    variation_time = time.time() - variation_start
                    logger.info(f"Generated variation {index}/{num_variations} in {variation_time:.2f}s")
                except Exception as e:
                    logger.error(f"Error generating variation {index}: {e}")

            # Fall back to the input image only when every requested variation
            # failed, so it is never mixed in with successful outputs.
            if seeds and not variations:
                variations.append(np.array(image))

            total_time = time.time() - generation_start
            logger.info(f"Generation completed in {total_time:.2f}s")
            return variations
        except Exception as e:
            logger.exception(f"Error in generate_design: {e}")
            return [np.array(image.convert('RGB'))]
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logger.info("Cleared CUDA cache")

    def __del__(self):
        """Best-effort release of cached CUDA memory when the model is collected."""
        try:
            # getattr: __init__ may have failed before self.device was set.
            if getattr(self, "device", None) == "cuda":
                torch.cuda.empty_cache()
                logger.info("Final CUDA cache cleanup")
        except Exception:
            # Deliberately best-effort: module globals (torch, logging) may
            # already be torn down at interpreter shutdown.
            pass