from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import time
import torch
from diffusers.pipelines.controlnet import StableDiffusionControlNetInpaintPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler, AutoPipelineForText2Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, AutoModelForDepthEstimation
import logging
import os
from datetime import datetime
import gc

# Set up logging
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)


class ProductionDesignModel(DesignModel):
    def __init__(self):
        """Initialize the production model with advanced architecture"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Generation defaults. (Logging is already configured at module level;
        # the redundant logging.basicConfig call that used to live here was a
        # no-op once handlers were installed.)
        self.seed = 323 * 111
        self.neg_prompt = "window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner"
        self.control_items = ["windowpane;window", "door;double;door"]
        self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"

        try:
            logging.info(f"Initializing models on {self.device} with {self.dtype}")
            self._initialize_models()
            logging.info("Models initialized successfully")
        except Exception as e:
            logging.error(f"Error initializing models: {e}")
            raise

    def _initialize_models(self):
        """Initialize all required models and pipelines"""
        # Initialize ControlNet models
        self.controlnet_depth = ControlNetModel.from_pretrained(
            "controlnet_depth", torch_dtype=self.dtype, use_safetensors=True
        )
        self.controlnet_seg = ControlNetModel.from_pretrained(
            "own_controlnet", torch_dtype=self.dtype, use_safetensors=True
        )

        # Initialize main pipeline
        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V5.1_noVAE",
            controlnet=[self.controlnet_depth, self.controlnet_seg],
            safety_checker=None,
            torch_dtype=self.dtype
        )

        # Setup IP-Adapter
        self.pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models",
                                  weight_name="ip-adapter_sd15.bin")
        self.pipe.set_ip_adapter_scale(0.4)
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(self.device)

        # Initialize guide pipeline
        self.guide_pipe = AutoPipelineForText2Image.from_pretrained(
            "segmind/SSD-1B",
            torch_dtype=self.dtype,
            use_safetensors=True,
            variant="fp16"
        ).to(self.device)

        # Initialize segmentation and depth models
        self.seg_processor, self.seg_model = self._init_segmentation()
        self.depth_processor, self.depth_model = self._init_depth()
        self.depth_model = self.depth_model.to(self.device)
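        # NOTE (assumption, not in the original code): on GPUs with limited
        # VRAM, diffusers' model offloading could be used instead of keeping
        # both pipelines resident on the device, e.g.
        #     self.pipe.enable_model_cpu_offload()
        #     self.guide_pipe.enable_model_cpu_offload()
        # This trades throughput for memory headroom.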

    def _init_segmentation(self):
        """Initialize segmentation models"""
        processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
        return processor, model

    def _init_depth(self):
        """Initialize depth estimation models"""
        # Image processors do not take a torch_dtype; only the model needs it
        processor = AutoImageProcessor.from_pretrained("LiheYoung/depth-anything-large-hf")
        model = AutoModelForDepthEstimation.from_pretrained(
            "LiheYoung/depth-anything-large-hf",
            torch_dtype=self.dtype
        )
        return processor, model

    def _get_depth_map(self, image: Image.Image) -> Image.Image:
        """Generate depth map for input image"""
        # Cast inputs to the model's dtype (fp16 on CUDA) as well as its device,
        # otherwise the fp16 depth model rejects fp32 pixel values
        image_to_depth = self.depth_processor(images=image, return_tensors="pt").to(
            device=self.device, dtype=self.dtype
        )
        with torch.inference_mode():
            depth_map = self.depth_model(**image_to_depth).predicted_depth
        width, height = image.size
        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1).float(),
            size=(height, width),
            mode="bicubic",
            align_corners=False,
        )
        # Normalize to [0, 1], replicate to 3 channels, and convert to PIL
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        return Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))

    def _segment_image(self, image: Image.Image) -> Image.Image:
        """Generate segmentation map for input image"""
        pixel_values = self.seg_processor(image, return_tensors="pt").pixel_values
        with torch.inference_mode():
            outputs = self.seg_model(pixel_values)
        seg = self.seg_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]])[0]
        seg = seg.cpu().numpy()
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        # Placeholder palette: a seeded random color per label keeps labels
        # distinct and stable across runs. The segmentation ControlNet was
        # presumably trained against a fixed palette (e.g. ADE20K); substitute
        # that palette here for best results.
        palette = np.random.RandomState(0).randint(0, 256, size=(int(seg.max()) + 1, 3), dtype=np.uint8)
        for label in range(int(seg.max()) + 1):
            color_seg[seg == label, :] = palette[label]
        return Image.fromarray(color_seg).convert('RGB')

    def _resize_image(self, image: Image.Image, target_size: int) -> Image.Image:
        """Resize image while maintaining aspect ratio"""
        width, height = image.size
        if width > height:
            new_width = target_size
            new_height = int(height * (target_size / width))
        else:
            new_height = target_size
            new_width = int(width * (target_size / height))
        return image.resize((new_width, new_height), Image.LANCZOS)

    def _flush(self):
        """Run garbage collection and clear the CUDA cache"""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate_design(self, image, num_variations=1, **kwargs) -> List[Image.Image]:
        """Generate design variations using the model.

        Args:
            image: Input image (PIL Image, numpy array, or torch tensor)
            num_variations: Number of variations to generate
            **kwargs: Additional parameters like prompt, num_steps, guidance_scale, strength, seed

        Returns:
            List of generated images
        """
        try:
            # Convert image to PIL Image if needed
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            elif isinstance(image, torch.Tensor):
                # Convert tensor to numpy then PIL
                image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))
            if not isinstance(image, Image.Image):
                raise ValueError(f"Unsupported image type: {type(image)}")

            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Get parameters
            prompt = kwargs.get('prompt', '')
            num_steps = int(kwargs.get('num_steps', 50))
            guidance_scale = float(kwargs.get('guidance_scale', 10.0))
            strength = float(kwargs.get('strength', 0.9))
            seed_param = kwargs.get('seed')

            # Handle seed
            base_seed = int(time.time()) if seed_param is None else int(seed_param)
            logging.info(f"Using base seed: {base_seed}")

            # Append the quality suffix defined in __init__ (its evident purpose)
            if self.additional_quality_suffix:
                prompt = f"{prompt}, {self.additional_quality_suffix}" if prompt else self.additional_quality_suffix

            # Prepare control inputs once: the inpaint ControlNet pipeline
            # requires a mask and one control image per ControlNet (depth +
            # segmentation), and the loaded IP-Adapter needs a reference image.
            # A fully white mask repaints the whole image (assumption:
            # full-image redesign is the intended behavior).
            depth_image = self._get_depth_map(image)
            seg_image = self._segment_image(image)
            mask_image = Image.new("L", image.size, 255)

            variations = []
            for i in range(num_variations):
                try:
                    # Generate a distinct seed for each variation
                    seed = base_seed + i
                    generator = torch.Generator(device=self.device).manual_seed(seed)

                    # Generate variation
                    output = self.pipe(
                        prompt=prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        mask_image=mask_image,
                        control_image=[depth_image, seg_image],
                        ip_adapter_image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator,
                    ).images[0]
                    variations.append(output)
                    logging.info(f"Successfully generated variation {i} with seed {seed}")
                except Exception as e:
                    logging.error(f"Error generating variation {i}: {e}")
                    continue
                finally:
                    # Clear CUDA cache after each variation
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

            if not variations:
                logging.warning("No variations were generated successfully")
                return [image]  # Return original image if no variations generated
            return variations
        except Exception as e:
            logging.error(f"Error in generate_design: {e}")
            return [image]  # Return original image on error

    def __del__(self):
        """Cleanup when the model is deleted"""
        self._flush()
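

# Minimal usage sketch. Assumptions (not part of the original module):
# "room.jpg" is an input photo of the room to restyle, and the checkpoints
# referenced above ("controlnet_depth", "own_controlnet") are available
# locally; adjust paths and parameters to taste.
if __name__ == "__main__":
    model = ProductionDesignModel()
    results = model.generate_design(
        Image.open("room.jpg"),
        num_variations=2,
        prompt="a scandinavian living room, light oak floor",
        num_steps=30,
        guidance_scale=7.5,
        strength=0.85,
        seed=42,
    )
    for idx, img in enumerate(results):
        img.save(f"variation_{idx}.png")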