from model import DesignModel
from PIL import Image
import numpy as np
from typing import List
import random
import time
import torch
from diffusers.pipelines.controlnet import StableDiffusionControlNetInpaintPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler, AutoPipelineForText2Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, AutoModelForDepthEstimation
import logging
import os
from datetime import datetime
import gc

# Set up logging
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"prod_model_{datetime.now().strftime('%Y%m%d')}.log")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)


class ProductionDesignModel(DesignModel):
    def __init__(self):
        """Initialize the production model with advanced architecture"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Logging is already configured at module level; no per-instance setup is needed.

        self.seed = 323 * 111
        self.neg_prompt = "window, door, low resolution, banner, logo, watermark, text, deformed, blurry, out of focus, surreal, ugly, beginner"
        self.control_items = ["windowpane;window", "door;double;door"]
        self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"

        try:
            logging.info(f"Initializing models on {self.device} with {self.dtype}")
            self._initialize_models()
            logging.info("Models initialized successfully")
        except Exception as e:
            logging.error(f"Error initializing models: {e}")
            raise

    def _initialize_models(self):
        """Initialize all required models and pipelines"""
        # Initialize ControlNet models
        self.controlnet_depth = ControlNetModel.from_pretrained(
            "controlnet_depth",
            torch_dtype=self.dtype,
            use_safetensors=True
        )
        self.controlnet_seg = ControlNetModel.from_pretrained(
            "own_controlnet",
            torch_dtype=self.dtype,
            use_safetensors=True
        )

        # Initialize main pipeline
        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "SG161222/Realistic_Vision_V5.1_noVAE",
            controlnet=[self.controlnet_depth, self.controlnet_seg],
            safety_checker=None,
            torch_dtype=self.dtype
        )

        # Setup IP-Adapter
        self.pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models",
                                  weight_name="ip-adapter_sd15.bin")
        self.pipe.set_ip_adapter_scale(0.4)
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(self.device)

        # Initialize guide pipeline
        self.guide_pipe = AutoPipelineForText2Image.from_pretrained(
            "segmind/SSD-1B",
            torch_dtype=self.dtype,
            use_safetensors=True,
            variant="fp16"
        ).to(self.device)

        # Initialize segmentation and depth models
        self.seg_processor, self.seg_model = self._init_segmentation()
        self.depth_processor, self.depth_model = self._init_depth()
        self.depth_model = self.depth_model.to(self.device)

    def _init_segmentation(self):
        """Initialize segmentation models"""
        processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
        return processor, model

    def _init_depth(self):
        """Initialize depth estimation models"""
        processor = AutoImageProcessor.from_pretrained(
            "LiheYoung/depth-anything-large-hf",
            torch_dtype=self.dtype
        )
        model = AutoModelForDepthEstimation.from_pretrained(
"LiheYoung/depth-anything-large-hf", torch_dtype=self.dtype ) return processor, model def _get_depth_map(self, image: Image.Image) -> Image.Image: """Generate depth map for input image""" image_to_depth = self.depth_processor(images=image, return_tensors="pt").to(self.device) with torch.inference_mode(): depth_map = self.depth_model(**image_to_depth).predicted_depth width, height = image.size depth_map = torch.nn.functional.interpolate( depth_map.unsqueeze(1).float(), size=(height, width), mode="bicubic", align_corners=False, ) depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True) depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True) depth_map = (depth_map - depth_min) / (depth_max - depth_min) image = torch.cat([depth_map] * 3, dim=1) image = image.permute(0, 2, 3, 1).cpu().numpy()[0] return Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8)) def _segment_image(self, image: Image.Image) -> Image.Image: """Generate segmentation map for input image""" pixel_values = self.seg_processor(image, return_tensors="pt").pixel_values with torch.inference_mode(): outputs = self.seg_model(pixel_values) seg = self.seg_processor.post_process_semantic_segmentation( outputs, target_sizes=[image.size[::-1]])[0] color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # You'll need to implement the palette mapping here # This is a placeholder - you should implement proper color mapping for label in range(seg.max() + 1): color_seg[seg == label, :] = [label * 30 % 255] * 3 return Image.fromarray(color_seg).convert('RGB') def _resize_image(self, image: Image.Image, target_size: int) -> Image.Image: """Resize image while maintaining aspect ratio""" width, height = image.size if width > height: new_width = target_size new_height = int(height * (target_size / width)) else: new_height = target_size new_width = int(width * (target_size / height)) return image.resize((new_width, new_height), Image.LANCZOS) def _flush(self): """Clear CUDA cache""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def generate_design(self, image, num_variations=1, **kwargs): """Generate design variations using the model. 

        Args:
            image: Input image (PIL Image, numpy array, or torch tensor)
            num_variations: Number of variations to generate
            **kwargs: Additional parameters like prompt, num_steps, guidance_scale, strength

        Returns:
            List of generated images
        """
        try:
            # Convert image to PIL Image if needed
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            elif isinstance(image, torch.Tensor):
                # Convert tensor to numpy then PIL
                image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))

            if not isinstance(image, Image.Image):
                raise ValueError(f"Unsupported image type: {type(image)}")

            # Ensure image is RGB
            if image.mode != "RGB":
                image = image.convert("RGB")

            # Get parameters
            prompt = kwargs.get('prompt', '')
            num_steps = int(kwargs.get('num_steps', 50))
            guidance_scale = float(kwargs.get('guidance_scale', 10.0))
            strength = float(kwargs.get('strength', 0.9))
            seed_param = kwargs.get('seed')

            # Handle seed
            base_seed = int(time.time()) if seed_param is None else int(seed_param)
            logging.info(f"Using base seed: {base_seed}")

            # Append the quality suffix to the user prompt
            pos_prompt = f"{prompt}, {self.additional_quality_suffix}" if prompt else self.additional_quality_suffix

            # Conditioning inputs for the inpaint + ControlNet pipeline: one control
            # image per ControlNet (depth first, segmentation second, matching the
            # order passed at construction) and a full-white mask so the entire
            # image is repainted.
            depth_image = self._get_depth_map(image)
            seg_image = self._segment_image(image)
            mask_image = Image.new("L", image.size, 255)

            # Style reference for the loaded IP-Adapter: a guide image generated
            # from the prompt with the SSD-1B pipeline (assumed wiring for the
            # otherwise unused guide_pipe).
            guide_generator = torch.Generator(device=self.device).manual_seed(base_seed)
            ip_image = self.guide_pipe(
                prompt=pos_prompt,
                negative_prompt=self.neg_prompt,
                generator=guide_generator
            ).images[0]

            variations = []
            for i in range(num_variations):
                try:
                    # Generate distinct seed for each variation
                    seed = base_seed + i
                    generator = torch.Generator(device=self.device).manual_seed(seed)

                    # Generate variation
                    output = self.pipe(
                        prompt=pos_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        mask_image=mask_image,
                        control_image=[depth_image, seg_image],
                        ip_adapter_image=ip_image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator
                    ).images[0]

                    variations.append(output)
                    logging.info(f"Successfully generated variation {i} with seed {seed}")

                except Exception as e:
                    logging.error(f"Error generating variation {i}: {str(e)}")
                    continue
                finally:
                    # Clear CUDA cache after each variation
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

            if not variations:
                logging.warning("No variations were generated successfully")
                return [image]  # Return original image if no variations generated

            return variations

        except Exception as e:
            logging.error(f"Error in generate_design: {str(e)}")
            return [image]  # Return original image on error

    def __del__(self):
        """Cleanup when the model is deleted"""
        self._flush()
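

# Minimal usage sketch. Assumptions: a local "room.jpg" photo exists and the
# "controlnet_depth" / "own_controlnet" checkpoints referenced above are
# available; the file name, prompt, and parameters are illustrative, not part
# of the original module.
if __name__ == "__main__":
    model = ProductionDesignModel()
    source = Image.open("room.jpg")  # hypothetical input photo
    results = model.generate_design(
        source,
        num_variations=2,
        prompt="scandinavian living room, light oak floor, linen sofa",
        num_steps=30,
        guidance_scale=7.5,
        strength=0.9,
        seed=42,
    )
    for idx, img in enumerate(results):
        img.save(f"design_variation_{idx}.png")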