import os

import gradio as gr
import spaces
import torch
from PIL import Image
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

from models.controlnet import ControlNetModel
from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix

# Global model state, populated lazily on the first request
pipeline = None
generator = None
accelerator = None


@spaces.GPU
def initialize_models():
    global pipeline, generator, accelerator

    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=1,
    )

    try:
        # Download and load models, authenticating with the token stored in
        # the "Read2" environment variable
        auth_token = os.environ["Read2"]

        scheduler = DDPMScheduler.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/scheduler",
            use_auth_token=auth_token,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/text_encoder",
            use_auth_token=auth_token,
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/tokenizer",
            use_auth_token=auth_token,
        )
        feature_extractor = CLIPImageProcessor.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/feature_extractor",
            use_auth_token=auth_token,
        )
        unet = UNet2DConditionModel.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/unet",
            use_auth_token=auth_token,
        )
        controlnet = ControlNetModel.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="Controlnet",
            use_auth_token=auth_token,
        )
        vae = AutoencoderKL.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="vae",
            use_auth_token=auth_token,
        )

        # Freeze models; this app only runs inference
        for model in [vae, text_encoder, unet, controlnet]:
            model.requires_grad_(False)

        # Initialize pipeline
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
        )

        # Pick the weight dtype matching the accelerator's mixed precision
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        # Move models to the accelerator device with the appropriate dtype
        for model in [text_encoder, vae, unet, controlnet]:
            model.to(accelerator.device, dtype=weight_dtype)

        # Initialize generator
        generator = torch.Generator(device=accelerator.device)

        return True

    except Exception as e:
        print(f"Error initializing models: {e}")
        return False


@spaces.GPU
def process_image(
    input_image,
    prompt="clean, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=1.0,
    conditioning_scale=1.0,
    num_inference_steps=20,
    seed=42,
    upscale_factor=2,
    color_fix_method="adain",
):
    global pipeline, generator, accelerator

    if pipeline is None:
        if not initialize_models():
            return None

    try:
        # Seed the generator for reproducible results (Gradio may deliver the
        # seed as a float, so cast to int)
        if seed is not None:
            generator.manual_seed(int(seed))

        # Process input image
        input_pil = Image.fromarray(input_image)
        width, height = input_pil.size

        # Compute the output resolution, rounded down to a multiple of 8 as
        # required by the VAE's downsampling factor
        target_width = int(width * upscale_factor)
        target_height = int(height * upscale_factor)
        target_width -= target_width % 8
        target_height -= target_height % 8

        # Move pipeline to GPU for processing
        pipeline.to(accelerator.device)

        # Generate image
        with torch.no_grad():
            output = pipeline(
                t_max=0.6666,
                t_min=0.0,
                tile_diffusion=False,
                added_prompt=prompt,
                image=input_pil,
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                height=target_height,
                width=target_width,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                conditioning_scale=conditioning_scale,
            )

        generated_image = output.images[0]

        # Optionally correct the output colors against the input image
        if color_fix_method != "none":
            fix_func = wavelet_color_fix if color_fix_method == "wavelet" else adain_color_fix
            generated_image = fix_func(generated_image, input_pil)

        # Move the pipeline back to CPU and release GPU memory between requests
        pipeline.to("cpu")
        torch.cuda.empty_cache()

        return generated_image

    except Exception as e:
        print(f"Error processing image: {e}")
        return None


# Create Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
        gr.Textbox(
            label="Negative Prompt",
            value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        ),
        gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
        gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
        gr.Number(label="Seed", value=42),
        gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
        gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Controllable Conditional Super-Resolution",
    description="Upload an image to enhance its resolution using CCSR.",
    examples=[
        ["example1.jpg", "clean, sharp, detailed", "blurry, noise", 1.0, 1.0, 20, 42, 2, "adain"],
        ["example2.jpg", "high-resolution, pristine", "artifacts, pixelated", 1.5, 1.0, 30, 123, 2, "wavelet"],
    ],
)

if __name__ == "__main__":
    iface.launch()
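
# A minimal sketch of calling process_image directly, without the Gradio UI.
# It assumes the repo layout above, a valid token in the "Read2" environment
# variable, a GPU, and a local input file named "example1.jpg" (the output
# filename "upscaled.png" is hypothetical):
#
#   import numpy as np
#   from PIL import Image
#
#   img = np.array(Image.open("example1.jpg").convert("RGB"))
#   result = process_image(img, upscale_factor=2, color_fix_method="wavelet")
#   if result is not None:
#       result.save("upscaled.png")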