# StableDesign2 / prod_model.py
# Author: anbucur
# Commit 4e4b650: "Added functionality" (file size 6.75 kB)
from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import random
import time
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPTokenizer
import logging
import os
from datetime import datetime
# Set up logging: INFO and above go both to a per-day file under "logs"
# (relative to the current working directory) and to the console.
# Note: logging.basicConfig only takes effect if the root logger has no
# handlers yet; if something configured logging earlier, this is a no-op.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: user prompts are logged verbatim and may contain
        # characters the platform default encoding (e.g. cp1252) cannot encode.
        logging.FileHandler(log_file, encoding="utf-8"),
        logging.StreamHandler(),
    ],
)
class ProductionDesignModel(DesignModel):
    """Interior-design img2img model backed by Stable Diffusion 2.1.

    On construction, loads ``stabilityai/stable-diffusion-2-1`` into a diffusers
    ``StableDiffusionImg2ImgPipeline`` (float16 on CUDA, float32 on CPU) with the
    safety checker disabled. :meth:`generate_design` restyles an input image
    into one or more variations.
    """

    # Position limit of the CLIP text encoder. Two of these slots are taken by
    # the begin/end-of-text tokens added when the prompt is encoded.
    _CLIP_MAX_POSITIONS = 77

    def __init__(self):
        """Load the pipeline and set the default negative prompt and quality suffix.

        Raises:
            Exception: any error raised while loading the pipeline is logged and
                re-raised unchanged.
        """
        super().__init__()
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using device: {self.device}")
            self.model_id = "stabilityai/stable-diffusion-2-1"
            logging.info(f"Loading model: {self.model_id}")
            # Initialize the pipeline with error handling
            try:
                self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    safety_checker=None,  # Disable safety checker for performance
                )
                # Enable optimizations
                self.pipe.enable_attention_slicing()
                if self.device == "cuda":
                    # Model CPU offload manages GPU placement itself, moving each
                    # sub-model onto the GPU only while it runs. Moving the whole
                    # pipeline to CUDA first would load every weight into VRAM
                    # just for offloading to move it back out.
                    self.pipe.enable_model_cpu_offload()
                    self.pipe.enable_vae_slicing()
                else:
                    self.pipe.to(self.device)
                logging.info("Model loaded successfully")
            except Exception as e:
                logging.error(f"Error loading model: {e}")
                raise
            # Reuse the pipeline's own tokenizer. The SD 2.1 repo keeps its
            # tokenizer files in a "tokenizer" subfolder, so loading a
            # CLIPTokenizer from the repo root does not find them. Reusing it
            # also guarantees the length check below uses the tokenizer that
            # actually encodes the prompt.
            self.tokenizer = self.pipe.tokenizer
            # Set default prompts
            self.neg_prompt = "blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped"
            self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
        except Exception as e:
            logging.error(f"Error in initialization: {e}")
            raise

    def _prepare_prompt(self, prompt: str) -> str:
        """Append the quality suffix to *prompt* and fit it to CLIP's length limit.

        Args:
            prompt: User prompt. ``None`` or a blank string is treated as empty.

        Returns:
            ``"<prompt>, <quality suffix>"``, or just the suffix when the prompt
            is empty. The result is truncated to the number of content tokens
            the text encoder can see. The suffix comes last, so it is what gets
            cut first. If tokenization fails, *prompt* is returned unchanged.
        """
        try:
            parts = [p for p in ((prompt or "").strip(), self.additional_quality_suffix) if p]
            full_prompt = ", ".join(parts)
            tokens = self.tokenizer.tokenize(full_prompt)
            # Reserve two positions for the BOS/EOS tokens added at encode time.
            # Otherwise the pipeline silently truncates a second time.
            max_positions = min(
                getattr(self.tokenizer, "model_max_length", self._CLIP_MAX_POSITIONS),
                self._CLIP_MAX_POSITIONS,
            )
            limit = max_positions - 2
            if len(tokens) > limit:
                logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens[:limit])
            logging.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
        except Exception as e:
            logging.error(f"Error preparing prompt: {e}")
            return prompt  # Return original prompt if processing fails

    @staticmethod
    def _parse_generation_kwargs(kwargs: dict) -> tuple:
        """Read and clamp the generation settings passed to ``generate_design``.

        Returns:
            ``(prompt, num_steps, guidance_scale, strength, base_seed)``, where:
            - ``num_steps`` is clamped to [20, 100] (default 50).
            - ``guidance_scale`` is clamped to [1, 20] (default 7.5).
            - ``strength`` is clamped to [0.1, 1.0] (default 0.75).
            - ``base_seed`` defaults to the current Unix time.

        Raises:
            TypeError, ValueError: if a supplied value cannot be converted.
        """
        prompt = kwargs.get('prompt') or ''
        num_steps = max(20, min(100, int(kwargs.get('num_steps', 50))))
        guidance_scale = max(1.0, min(20.0, float(kwargs.get('guidance_scale', 7.5))))
        strength = max(0.1, min(1.0, float(kwargs.get('strength', 0.75))))
        seed = kwargs.get('seed')
        # A missing or None seed (e.g. an empty UI field) means "pick one".
        # Anything else is coerced so seed arithmetic and manual_seed get an int.
        base_seed = int(seed) if seed is not None else int(time.time())
        return prompt, num_steps, guidance_scale, strength, base_seed

    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate *num_variations* restyled versions of *image*.

        Args:
            image: Input PIL image. It is converted to RGB if needed.
            num_variations: Number of variations to generate. Variation ``i``
                uses seed ``base_seed + i * 10000``.
            **kwargs: Optional ``prompt``, ``num_steps``, ``guidance_scale``,
                ``strength`` and ``seed``. These are clamped as described in
                ``_parse_generation_kwargs``.

        Returns:
            One ``np.array`` per successfully generated variation, built from
            the pipeline's first output image. If no variation succeeds, or
            setup fails, the result is a single-element list holding the RGB
            input image. The failure is logged, not raised.
        """
        generation_start = time.time()
        try:
            # Log input parameters
            logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")
            prompt, num_steps, guidance_scale, strength, base_seed = self._parse_generation_kwargs(kwargs)
            full_prompt = self._prepare_prompt(prompt)
            # Spread seeds apart so variations are visibly distinct.
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logging.info(f"Using seeds: {seeds}")
            # Prepare the input image
            if image.mode != "RGB":
                image = image.convert("RGB")
            variations = []
            generator = torch.Generator(device=self.device)
            for i, seed in enumerate(seeds):
                try:
                    variation_start = time.time()
                    generator.manual_seed(seed)
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator,
                    ).images[0]
                    variations.append(np.array(output))
                    variation_time = time.time() - variation_start
                    logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
                except Exception as e:
                    logging.error(f"Error generating variation {i+1}: {e}")
            # Fall back to the input only when nothing succeeded, so the
            # original photo is never mixed in among generated designs.
            if not variations:
                variations.append(np.array(image))
            total_time = time.time() - generation_start
            logging.info(f"Generation completed in {total_time:.2f}s")
            return variations
        except Exception as e:
            logging.exception(f"Error in generate_design: {e}")
            return [np.array(image.convert('RGB'))]
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache")

    def __del__(self):
        """Best-effort release of cached CUDA memory when the object is collected.

        ``__init__`` may have failed before ``self.device`` was set, and torch
        or logging may already be torn down at interpreter exit. Errors are
        therefore swallowed rather than raised from a finalizer.
        """
        try:
            if getattr(self, "device", None) == "cuda":
                torch.cuda.empty_cache()
                logging.info("Final CUDA cache cleanup")
        except Exception:
            pass