from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import time
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPTokenizer
import logging
import os
from datetime import datetime

# Set up logging
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)


class ProductionDesignModel(DesignModel):
    def __init__(self):
        super().__init__()
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using device: {self.device}")

            self.model_id = "stabilityai/stable-diffusion-2-1"
            logging.info(f"Loading model: {self.model_id}")

            # Initialize the pipeline with error handling
            try:
                self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    safety_checker=None  # Disable safety checker for performance
                )

                # Enable memory optimizations
                self.pipe.enable_attention_slicing()
                if self.device == "cuda":
                    # enable_model_cpu_offload() manages device placement itself,
                    # so we must not also call .to("cuda") on the pipeline
                    self.pipe.enable_model_cpu_offload()
                    self.pipe.enable_vae_slicing()
                else:
                    self.pipe = self.pipe.to(self.device)
                logging.info("Model loaded successfully")
            except Exception as e:
                logging.error(f"Error loading model: {e}")
                raise

            # Load the tokenizer that ships with the checkpoint. SD 2.1 uses an
            # OpenCLIP text encoder, so openai/clip-vit-large-patch14 is not the
            # matching tokenizer; reading the model repo's "tokenizer" subfolder
            # guarantees consistency with the pipeline.
            try:
                self.tokenizer = CLIPTokenizer.from_pretrained(
                    self.model_id, subfolder="tokenizer"
                )
                logging.info("Tokenizer loaded successfully")
            except Exception as e:
                logging.error(f"Error loading tokenizer: {e}")
                raise

            # Set default prompts
            self.neg_prompt = (
                "blurry, low quality, distorted, deformed, disfigured, watermark, "
                "text, bad proportions, duplicate, double, multiple, broken, cropped"
            )
            self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
        except Exception as e:
            logging.error(f"Error in initialization: {e}")
            raise

    def _prepare_prompt(self, prompt: str) -> str:
        """Prepare the prompt by adding the quality suffix and enforcing CLIP's token limit."""
        try:
            full_prompt = f"{prompt}, {self.additional_quality_suffix}"
            tokens = self.tokenizer.tokenize(full_prompt)
            # CLIP's 77-token context includes the BOS/EOS special tokens,
            # leaving 75 usable text tokens
            max_tokens = self.tokenizer.model_max_length - 2
            if len(tokens) > max_tokens:
                logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                tokens = tokens[:max_tokens]
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens)
            logging.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
        except Exception as e:
            logging.error(f"Error preparing prompt: {e}")
            return prompt  # Fall back to the original prompt if processing fails

    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate design variations with proper parameter handling."""
        generation_start = time.time()
        try:
            # Log input parameters
            logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")

            # Get parameters from kwargs with defaults
            prompt = kwargs.get('prompt', '')
            num_steps = int(kwargs.get('num_steps', 50))
            guidance_scale = float(kwargs.get('guidance_scale', 7.5))
            strength = float(kwargs.get('strength', 0.75))

            # Handle the seed: fall back to a time-based seed when none is given
            seed_param = kwargs.get('seed')
            base_seed = int(time.time()) if seed_param is None else int(seed_param)
            logging.info(f"Using base seed: {base_seed}")

            # Clamp parameters to safe ranges
            num_steps = max(20, min(100, num_steps))
            guidance_scale = max(1.0, min(20.0, guidance_scale))
            strength = max(0.1, min(1.0, strength))
            logging.info(f"Validated parameters: steps={num_steps}, guidance={guidance_scale}, strength={strength}")

            # Prepare the prompt
            full_prompt = self._prepare_prompt(prompt)

            # Derive a distinct, widely spaced seed for each variation
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logging.info(f"Using seeds: {seeds}")

            # The pipeline expects an RGB input image
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Generate variations
            variations = []
            generator = torch.Generator(device=self.device)
            for i, seed in enumerate(seeds):
                try:
                    variation_start = time.time()
                    generator.manual_seed(seed)

                    # Generate the image
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator
                    ).images[0]

                    variations.append(np.array(output))
                    variation_time = time.time() - variation_start
                    logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
                except Exception as e:
                    logging.error(f"Error generating variation {i+1}: {e}")

            # If every variation failed, fall back to the input image so the
            # caller always receives at least one array
            if not variations:
                variations.append(np.array(image))

            total_time = time.time() - generation_start
            logging.info(f"Generation completed in {total_time:.2f}s")
            return variations
        except Exception as e:
            logging.error(f"Error in generate_design: {e}")
            import traceback
            logging.error(traceback.format_exc())
            return [np.array(image.convert('RGB'))]
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache")

    def __del__(self):
        """Cleanup when the model is deleted."""
        try:
            # __init__ may have failed before self.device was set
            if getattr(self, "device", None) == "cuda":
                torch.cuda.empty_cache()
                logging.info("Final CUDA cache cleanup")
        except Exception:
            pass