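"""Production implementation of DesignModel: wraps the Stable Diffusion 2.1
img2img pipeline to generate interior design variations of an input image."""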
import logging
import os
import time
from datetime import datetime
from typing import List

import numpy as np
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

from model import DesignModel
# Set up logging to both a dated file and the console
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
class ProductionDesignModel(DesignModel):
    def __init__(self):
        super().__init__()
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using device: {self.device}")

            self.model_id = "stabilityai/stable-diffusion-2-1"
            logging.info(f"Loading model: {self.model_id}")

            # Initialize the pipeline with error handling
            try:
                self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    self.model_id,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    safety_checker=None  # Disable safety checker for performance
                )

                # Enable memory optimizations
                self.pipe.enable_attention_slicing()
                if self.device == "cuda":
                    # Model CPU offload manages device placement itself, so the
                    # pipeline must not be moved to CUDA with .to() beforehand.
                    self.pipe.enable_model_cpu_offload()
                    self.pipe.enable_vae_slicing()
                else:
                    self.pipe = self.pipe.to(self.device)

                logging.info("Model loaded successfully")
            except Exception as e:
                logging.error(f"Error loading model: {e}")
                raise

            # Reuse the pipeline's tokenizer; loading CLIPTokenizer directly from
            # the repo root would fail because the tokenizer files live in the
            # "tokenizer" subfolder of the model repository.
            self.tokenizer = self.pipe.tokenizer

            # Set default prompts
            self.neg_prompt = (
                "blurry, low quality, distorted, deformed, disfigured, watermark, "
                "text, bad proportions, duplicate, double, multiple, broken, cropped"
            )
            self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
        except Exception as e:
            logging.error(f"Error in initialization: {e}")
            raise
    def _prepare_prompt(self, prompt: str) -> str:
        """Prepare the prompt by adding the quality suffix and enforcing CLIP's token limit"""
        try:
            full_prompt = f"{prompt}, {self.additional_quality_suffix}"
            # CLIP's limit (model_max_length, 77) includes the BOS/EOS special
            # tokens, so only model_max_length - 2 content tokens actually fit.
            max_tokens = self.tokenizer.model_max_length - 2
            tokens = self.tokenizer.tokenize(full_prompt)
            if len(tokens) > max_tokens:
                logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                tokens = tokens[:max_tokens]
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens)
            logging.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
        except Exception as e:
            logging.error(f"Error preparing prompt: {e}")
            return prompt  # Fall back to the original prompt if processing fails
    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate design variations with proper parameter handling"""
        generation_start = time.time()
        try:
            # Log input parameters
            logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")

            # Get parameters from kwargs with defaults
            prompt = kwargs.get('prompt', '')
            num_steps = int(kwargs.get('num_steps', 50))
            guidance_scale = float(kwargs.get('guidance_scale', 7.5))
            strength = float(kwargs.get('strength', 0.75))
            base_seed = int(kwargs.get('seed', time.time()))

            # Clamp parameters to safe ranges
            num_steps = max(20, min(100, num_steps))
            guidance_scale = max(1.0, min(20.0, guidance_scale))
            strength = max(0.1, min(1.0, strength))

            # Prepare the prompt
            full_prompt = self._prepare_prompt(prompt)

            # Derive a distinct seed per variation from the base seed
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logging.info(f"Using seeds: {seeds}")

            # Prepare the input image
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Generate variations
            variations = []
            generator = torch.Generator(device=self.device)
            for i, seed in enumerate(seeds):
                try:
                    variation_start = time.time()
                    generator.manual_seed(seed)

                    # Run the img2img pipeline
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator
                    ).images[0]

                    variations.append(np.array(output))
                    variation_time = time.time() - variation_start
                    logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
                except Exception as e:
                    logging.error(f"Error generating variation {i+1}: {e}")
                    if not variations:  # Fall back to the input image if nothing has succeeded yet
                        variations.append(np.array(image))

            total_time = time.time() - generation_start
            logging.info(f"Generation completed in {total_time:.2f}s")
            return variations
        except Exception as e:
            logging.error(f"Error in generate_design: {e}")
            import traceback
            logging.error(traceback.format_exc())
            return [np.array(image.convert('RGB'))]
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache")
    def __del__(self):
        """Cleanup when the model is deleted"""
        try:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Final CUDA cache cleanup")
        except Exception:
            # __del__ may run during interpreter shutdown when attributes or
            # modules are already gone; never let cleanup raise.
            pass
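

# Minimal usage sketch (not part of the original module): assumes a local
# "room.jpg" exists and that DesignModel imposes no extra constructor
# arguments. It illustrates the kwargs generate_design() actually reads.
if __name__ == "__main__":
    model = ProductionDesignModel()
    source = Image.open("room.jpg")  # hypothetical input image
    results = model.generate_design(
        source,
        num_variations=2,
        prompt="scandinavian living room, light oak floor",
        num_steps=40,
        guidance_scale=7.5,
        strength=0.6,
        seed=42,
    )
    for idx, arr in enumerate(results):
        Image.fromarray(arr).save(f"variation_{idx}.png")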