from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import random
import time
import torch
from diffusers.pipelines.controlnet import StableDiffusionControlNetInpaintPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler, AutoPipelineForText2Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, AutoModelForDepthEstimation
import logging
import os
from datetime import datetime
import gc

# Set up logging
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)


class ProductionDesignModel(DesignModel):
    def __init__(self):
        """Initialize the production model with advanced architecture"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.seed = 323 * 111
        self.neg_prompt = ("window, door, low resolution, banner, logo, watermark, text, "
                           "deformed, blurry, out of focus, surreal, ugly, beginner")
        self.control_items = ["windowpane;window", "door;double;door"]
        self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"

        try:
            logging.info(f"Initializing models on {self.device} with {self.dtype}")
            self._initialize_models()
            logging.info("Models initialized successfully")
        except Exception as e:
            logging.error(f"Error initializing models: {e}")
            raise

    def _initialize_models(self):
        """Initialize all required models and pipelines"""
        # Initialize ControlNet models
        self.controlnet_depth = ControlNetModel.from_pretrained(
            "controlnet_depth", torch_dtype=self.dtype, use_safetensors=True
        )
        self.controlnet_seg = ControlNetModel.from_pretrained(
            "own_controlnet", torch_dtype=self.dtype, use_safetensors=True
        )

        # Initialize main pipeline
        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V5.1_noVAE",
            controlnet=[self.controlnet_depth, self.controlnet_seg],
            safety_checker=None,
            torch_dtype=self.dtype
        )

        # Setup IP-Adapter
        self.pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
        self.pipe.set_ip_adapter_scale(0.4)
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(self.device)

        # Initialize guide pipeline
        self.guide_pipe = AutoPipelineForText2Image.from_pretrained(
            "segmind/SSD-1B",
            torch_dtype=self.dtype,
            use_safetensors=True,
            variant="fp16"
        ).to(self.device)

        # Initialize segmentation and depth models
        self.seg_processor, self.seg_model = self._init_segmentation()
        self.depth_processor, self.depth_model = self._init_depth()
        self.depth_model = self.depth_model.to(self.device)

    def _init_segmentation(self):
        """Initialize segmentation models"""
        processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
        return processor, model

    def _init_depth(self):
        """Initialize depth estimation models"""
        processor = AutoImageProcessor.from_pretrained(
            "LiheYoung/depth-anything-large-hf", torch_dtype=self.dtype
        )
        model = AutoModelForDepthEstimation.from_pretrained(
"LiheYoung/depth-anything-large-hf", torch_dtype=self.dtype ) return processor, model def _get_depth_map(self, image: Image.Image) -> Image.Image: """Generate depth map for input image""" image_to_depth = self.depth_processor(images=image, return_tensors="pt").to(self.device) with torch.inference_mode(): depth_map = self.depth_model(**image_to_depth).predicted_depth width, height = image.size depth_map = torch.nn.functional.interpolate( depth_map.unsqueeze(1).float(), size=(height, width), mode="bicubic", align_corners=False, ) depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) depth_map = (depth_map - depth_min) / (depth_max - depth_min) image = torch.cat([depth_map] * 3, dim=1) image = image.permute(0, 2, 3, 1).cpu().numpy()[0] return Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) def _segment_image(self, image: Image.Image) -> Image.Image: """Generate segmentation map for input image""" pixel_values = self.seg_processor(image, return_tensors="pt").pixel_values with torch.inference_mode(): outputs = self.seg_model(pixel_values) seg = self.seg_processor.post_process_semantic_segmentation( outputs, target_sizes=[image.size[::-1]])[0] color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # You'll need to implement the palette mapping here # This is a placeholder - you should implement proper color mapping for label in range(seg.max() + 1): color_seg[seg == label, :] = [label * 30 % 255] * 3 return Image.fromarray(color_seg).convert('RGB') def _resize_image(self, image: Image.Image, target_size: int) -> Image.Image: """Resize image while maintaining aspect ratio""" width, height = image.size if width > height: new_width = target_size new_height = int(height * (target_size / width)) else: new_height = target_size new_width = int(width * (target_size / height)) return image.resize((new_width, new_height), Image.LANCZOS) def _flush(self): """Clear CUDA cache""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def generate_design(self, image: Image.Image, prompt: str, **kwargs) -> List[Image.Image]: """ Generate design variations based on input image and prompt """ try: # Set seed seed_param = kwargs.get('seed') base_seed = int(time.time()) if seed_param is None else int(seed_param) self.generator = torch.Generator(device=self.device).manual_seed(base_seed) # Get parameters num_variations = kwargs.get('num_variations', 1) guidance_scale = float(kwargs.get('guidance_scale', 10.0)) num_steps = int(kwargs.get('num_steps', 50)) strength = float(kwargs.get('strength', 0.9)) img_size = int(kwargs.get('img_size', 768)) logging.info(f"Generating design with parameters: guidance_scale={guidance_scale}, " f"num_steps={num_steps}, strength={strength}, img_size={img_size}") # Prepare prompt pos_prompt = f"{prompt}, {self.additional_quality_suffix}" # Process input image orig_size = image.size input_image = self._resize_image(image, img_size) # Generate depth map depth_map = self._get_depth_map(input_image) # Generate segmentation seg_map = self._segment_image(input_image) # Generate IP-adapter reference image self._flush() ip_image = self.guide_pipe( pos_prompt, num_inference_steps=num_steps, negative_prompt=self.neg_prompt, generator=self.generator ).images[0] # Generate variations variations = [] for i in range(num_variations): try: self._flush() variation = self.pipe( prompt=pos_prompt, negative_prompt=self.neg_prompt, num_inference_steps=num_steps, strength=strength, 
                        guidance_scale=guidance_scale,
                        generator=self.generator,
                        image=input_image,
                        mask_image=mask_image,
                        ip_adapter_image=ip_image,
                        control_image=[depth_map, seg_map],
                        controlnet_conditioning_scale=[0.5, 0.5]
                    ).images[0]

                    # Resize back to original size
                    variation = variation.resize(orig_size, Image.LANCZOS)
                    variations.append(variation)
                except Exception as e:
                    logging.error(f"Error generating variation {i}: {e}")
                    continue

            if not variations:
                logging.warning("No variations were generated successfully")
                return [image]  # Return original image if no variations were generated

            return variations

        except Exception as e:
            logging.error(f"Error in generate_design: {e}")
            return [image]  # Return original image in case of error

    def __del__(self):
        """Cleanup when the model is deleted"""
        self._flush()
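

# Minimal usage sketch: the image path, prompt, and keyword arguments below are
# illustrative placeholders, not part of the interface defined by DesignModel.
if __name__ == "__main__":
    model = ProductionDesignModel()
    source = Image.open("room.jpg").convert("RGB")  # hypothetical input photo
    designs = model.generate_design(
        source,
        "a cozy scandinavian living room with a wooden coffee table",
        num_variations=2,
        guidance_scale=10.0,
        num_steps=50,
        strength=0.9,
        img_size=768,
        seed=42,
    )
    for idx, design in enumerate(designs):
        design.save(f"design_{idx}.png")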