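"""Production img2img design model built on Stable Diffusion 2.1.

Wraps diffusers' StableDiffusionImg2ImgPipeline behind the DesignModel
interface, adding logging, prompt-length handling, parameter clamping,
and CUDA memory cleanup.
"""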
from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import time
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from transformers import CLIPTokenizer
import logging
import os
from datetime import datetime
# Set up logging to both a dated file and the console
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
class ProductionDesignModel(DesignModel):
    def __init__(self):
        super().__init__()
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using device: {self.device}")

            self.model_id = "stabilityai/stable-diffusion-2-1"
            logging.info(f"Loading model: {self.model_id}")

            # Initialize the pipeline with error handling
            try:
                self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    safety_checker=None  # Disable safety checker for performance
                )

                # Enable memory optimizations
                self.pipe.enable_attention_slicing()
                if self.device == "cuda":
                    # Model offloading manages device placement itself,
                    # so do not move the pipeline to CUDA beforehand
                    self.pipe.enable_model_cpu_offload()
                    self.pipe.enable_vae_slicing()
                else:
                    self.pipe = self.pipe.to(self.device)

                logging.info("Model loaded successfully")
            except Exception as e:
                logging.error(f"Error loading model: {e}")
                raise

            # Initialize the tokenizer; it lives in the model repo's
            # "tokenizer" subfolder rather than at the repo root
            self.tokenizer = CLIPTokenizer.from_pretrained(self.model_id, subfolder="tokenizer")

            # Set default prompts
            self.neg_prompt = ("blurry, low quality, distorted, deformed, disfigured, watermark, "
                               "text, bad proportions, duplicate, double, multiple, broken, cropped")
            self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
        except Exception as e:
            logging.error(f"Error in initialization: {e}")
            raise
    def _prepare_prompt(self, prompt: str) -> str:
        """Prepare the prompt by adding the quality suffix and enforcing CLIP's token limit."""
        try:
            full_prompt = f"{prompt}, {self.additional_quality_suffix}"
            # CLIP's 77-token limit includes the BOS/EOS special tokens,
            # so only model_max_length - 2 tokens of actual text fit
            max_tokens = self.tokenizer.model_max_length - 2
            tokens = self.tokenizer.tokenize(full_prompt)
            if len(tokens) > max_tokens:
                logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                tokens = tokens[:max_tokens]
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens)
            logging.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
        except Exception as e:
            logging.error(f"Error preparing prompt: {e}")
            return prompt  # Fall back to the original prompt if processing fails
    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate design variations with proper parameter handling."""
        generation_start = time.time()
        try:
            # Log input parameters
            logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")

            # Get parameters from kwargs with defaults
            prompt = kwargs.get('prompt', '')
            num_steps = int(kwargs.get('num_steps', 50))
            guidance_scale = float(kwargs.get('guidance_scale', 7.5))
            strength = float(kwargs.get('strength', 0.75))
            base_seed = kwargs.get('seed', int(time.time()))

            # Clamp parameters to safe ranges
            num_steps = max(20, min(100, num_steps))
            guidance_scale = max(1.0, min(20.0, guidance_scale))
            strength = max(0.1, min(1.0, strength))

            # Prepare the prompt
            full_prompt = self._prepare_prompt(prompt)

            # Generate distinct seeds so each variation differs
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logging.info(f"Using seeds: {seeds}")

            # The pipeline expects an RGB input image
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Generate variations
            variations = []
            generator = torch.Generator(device=self.device)
            for i, seed in enumerate(seeds):
                try:
                    variation_start = time.time()
                    generator.manual_seed(seed)

                    # Run the img2img pipeline
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator
                    ).images[0]

                    variations.append(np.array(output))
                    variation_time = time.time() - variation_start
                    logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
                except Exception as e:
                    logging.error(f"Error generating variation {i+1}: {e}")
                    if not variations:  # Fall back to the input image if nothing has succeeded yet
                        variations.append(np.array(image))

            total_time = time.time() - generation_start
            logging.info(f"Generation completed in {total_time:.2f}s")
            return variations
        except Exception as e:
            logging.error(f"Error in generate_design: {e}")
            import traceback
            logging.error(traceback.format_exc())
            return [np.array(image.convert('RGB'))]
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache")
    def __del__(self):
        """Clean up GPU memory when the model is deleted."""
        try:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Final CUDA cache cleanup")
        except Exception:
            # The interpreter may be shutting down; ignore cleanup errors
            pass
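
# Minimal usage sketch (illustrative only): "room.jpg" is a hypothetical input
# photo, and DesignModel's interface is assumed to match the methods above.
if __name__ == "__main__":
    model = ProductionDesignModel()
    source = Image.open("room.jpg")
    results = model.generate_design(
        source,
        num_variations=2,
        prompt="scandinavian living room, light oak furniture",
        num_steps=30,
        strength=0.6,
        seed=42,
    )
    # generate_design returns RGB uint8 arrays, so they round-trip through PIL
    for idx, arr in enumerate(results):
        Image.fromarray(arr).save(f"variation_{idx}.png")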