# StableDesign2 / prod_model.py
# anbucur — "Enhance seed handling and logging in ProductionDesignModel class" (9cd6532)
from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import random
import time
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPTokenizer
import logging
import os
from datetime import datetime
# Set up logging at import time: records at INFO and above go both to a
# per-day file ``logs/prod_model_YYYYMMDD.log`` (relative to the current
# working directory, created if missing) and to the console.
# NOTE: ``logging.basicConfig`` is a no-op if the root logger already has
# handlers, so an application that configures logging first keeps its setup.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: user prompts are logged verbatim and may contain
        # characters the platform's default locale encoding cannot represent.
        logging.FileHandler(log_file, encoding="utf-8"),
        logging.StreamHandler(),
    ],
)
class ProductionDesignModel(DesignModel):
    """Interior-design generator backed by the Stable Diffusion 2.1 img2img pipeline.

    ``generate_design`` restyles an input image into ``num_variations`` outputs
    (RGB numpy arrays), one pipeline call per variation, each with its own seed.
    If generation fails, the input image is returned instead so callers still
    receive an image.
    """

    def __init__(self):
        """Load the img2img pipeline and set the default prompt fragments.

        Raises:
            Exception: any error from loading/configuring the pipeline is
                logged and re-raised unchanged.
        """
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {self.device}")
        self.model_id = "stabilityai/stable-diffusion-2-1"
        # SD 2.x uses an OpenCLIP text encoder whose tokenizer ships inside the
        # model repo (``tokenizer/`` subfolder); ``openai/clip-vit-large-patch14``
        # is the SD 1.x tokenizer. The pipeline already loads the right one, so
        # it is reused below. The attribute is kept for callers that read it.
        self.tokenizer_id = self.model_id
        logging.info(f"Loading model: {self.model_id}")
        try:
            self.pipe = self._load_pipeline()
            logging.info("Model loaded successfully")
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise
        self.tokenizer = self.pipe.tokenizer
        logging.info("Tokenizer loaded successfully")
        # Default prompt fragments applied to every generation.
        self.neg_prompt = "blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped"
        self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"

    def _load_pipeline(self) -> StableDiffusionImg2ImgPipeline:
        """Load the img2img pipeline for ``self.model_id`` and apply memory optimisations."""
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            self.model_id,
            torch_dtype=dtype,
            safety_checker=None,  # Disable safety checker for performance
        )
        pipe.enable_attention_slicing()
        if self.device == "cuda":
            # Model CPU offload manages device placement itself. Moving the
            # pipeline to CUDA first would put every sub-model on the GPU up
            # front and largely defeat the memory saving.
            pipe.enable_model_cpu_offload()
            pipe.enable_vae_slicing()
        else:
            pipe.to(self.device)
        return pipe

    @staticmethod
    def _kwarg(kwargs: dict, key: str, default):
        """Return ``kwargs[key]``, or *default* when the key is missing or ``None``."""
        value = kwargs.get(key)
        return default if value is None else value

    def _prepare_prompt(self, prompt: str) -> str:
        """Append the quality suffix to *prompt* and fit it into the text encoder window.

        The tokenizer's ``model_max_length`` (77 for Stable Diffusion's CLIP
        tokenizers) counts the BOS and EOS special tokens the pipeline adds.
        ``tokenize`` returns content tokens only, so the usable budget is
        ``model_max_length - 2``. Longer prompts are truncated here with a
        warning rather than later by the pipeline.

        Returns the original *prompt* unchanged if tokenization fails.
        """
        try:
            base = (prompt or "").strip()
            # Avoid a dangling ", interior design, ..." when no prompt is given.
            if base:
                full_prompt = f"{base}, {self.additional_quality_suffix}"
            else:
                full_prompt = self.additional_quality_suffix
            tokens = self.tokenizer.tokenize(full_prompt)
            budget = getattr(self.tokenizer, "model_max_length", 77) - 2
            if len(tokens) > budget:
                logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens[:budget])
            logging.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
        except Exception as e:
            logging.error(f"Error preparing prompt: {e}")
            return prompt  # Return original prompt if processing fails

    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate ``num_variations`` img2img variations of *image*.

        Recognised keyword arguments (a missing key or ``None`` means the default):
            prompt (str, ''): user prompt; the quality suffix is appended.
            num_steps (int, 50): clamped to [20, 100].
            guidance_scale (float, 7.5): clamped to [1, 20].
            strength (float, 0.75): clamped to [0.1, 1.0].
            seed (int, current Unix time): base seed. Variation ``i`` (0-based)
                uses ``seed + i * 10000``.

        Returns:
            One RGB ``np.ndarray`` per successful variation. If every attempted
            variation fails, or parameter handling itself raises, a
            single-element list holding the (RGB) input image is returned.
        """
        generation_start = time.time()
        try:
            logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")
            num_variations = int(num_variations)

            prompt = self._kwarg(kwargs, 'prompt', '')
            num_steps = int(self._kwarg(kwargs, 'num_steps', 50))
            guidance_scale = float(self._kwarg(kwargs, 'guidance_scale', 7.5))
            strength = float(self._kwarg(kwargs, 'strength', 0.75))

            # UI widgets may deliver numbers as floats/strings, but
            # torch.Generator.manual_seed requires an int.
            seed_param = kwargs.get('seed')
            base_seed = int(time.time()) if seed_param is None else int(seed_param)
            logging.info(f"Using base seed: {base_seed}")

            # Parameter validation: clamp into supported ranges.
            num_steps = max(20, min(100, num_steps))
            guidance_scale = max(1, min(20, guidance_scale))
            strength = max(0.1, min(1.0, strength))
            logging.info(f"Validated parameters: steps={num_steps}, guidance={guidance_scale}, strength={strength}")

            full_prompt = self._prepare_prompt(prompt)

            # Widely spaced seeds so each variation gets a distinct noise pattern.
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logging.info(f"Using seeds: {seeds}")

            if image.mode != "RGB":
                image = image.convert("RGB")

            variations = []
            failures = 0
            # One generator, re-seeded per variation so each result is
            # reproducible from its own seed.
            generator = torch.Generator(device=self.device)
            for i, seed in enumerate(seeds, start=1):
                variation_start = time.time()
                try:
                    generator.manual_seed(seed)
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator,
                    ).images[0]
                except Exception as e:
                    failures += 1
                    logging.error(f"Error generating variation {i}: {e}")
                    continue
                variations.append(np.array(output))
                variation_time = time.time() - variation_start
                logging.info(f"Generated variation {i}/{num_variations} in {variation_time:.2f}s")

            # Fall back to the input only when nothing succeeded. Doing this
            # inside the loop would mix the untouched input in with real
            # variations whenever an early attempt failed.
            if failures and not variations:
                variations.append(np.array(image))

            total_time = time.time() - generation_start
            logging.info(f"Generation completed in {total_time:.2f}s")
            return variations
        except Exception as e:
            # logging.exception attaches the current traceback to the record.
            logging.exception(f"Error in generate_design: {e}")
            return [np.array(image.convert('RGB'))]
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache")

    def __del__(self):
        """Best-effort release of cached CUDA memory when the model is collected."""
        # getattr: __init__ may have raised before self.device was set.
        # Errors are swallowed on purpose, because __del__ can run during
        # interpreter shutdown when torch/logging are partly torn down.
        try:
            if getattr(self, "device", None) == "cuda":
                torch.cuda.empty_cache()
                logging.info("Final CUDA cache cleanup")
        except Exception:
            pass