Spaces: Running on Zero
File size: 7,409 Bytes
import os

import torch
import gradio as gr
import spaces
from PIL import Image
from huggingface_hub import snapshot_download
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor

from models.controlnet import ControlNetModel
from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
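
# models/, pipelines/, and myutils/ are local modules bundled with this Space
# (presumably from the CCSR codebase), not pip-installable packages.
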
# Initialize global variables for models
pipeline = None
generator = None
accelerator = None
model_path = None


def load_pipeline(accelerator, model_path):
    # Load scheduler
    scheduler = DDPMScheduler.from_pretrained(
        model_path,
        subfolder="stable-diffusion-2-1-base/scheduler",
    )

    # Load models
    text_encoder = CLIPTextModel.from_pretrained(
        model_path,
        subfolder="stable-diffusion-2-1-base/text_encoder",
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        model_path,
        subfolder="stable-diffusion-2-1-base/tokenizer",
    )
    feature_extractor = CLIPImageProcessor.from_pretrained(
        os.path.join(model_path, "stable-diffusion-2-1-base/feature_extractor")
    )
    unet = UNet2DConditionModel.from_pretrained(
        model_path,
        subfolder="stable-diffusion-2-1-base/unet",
    )
    controlnet = ControlNetModel.from_pretrained(
        model_path,
        subfolder="Controlnet",
    )
    vae = AutoencoderKL.from_pretrained(
        model_path,
        subfolder="vae",
    )

    # Freeze models (inference only)
    for model in [vae, text_encoder, unet, controlnet]:
        model.requires_grad_(False)

    # Initialize pipeline
    pipeline = StableDiffusionControlNetPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=None,
        requires_safety_checker=False,
    )

    # Set weight dtype based on mixed precision
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move models to accelerator device with appropriate dtype
    for model in [text_encoder, vae, unet, controlnet]:
        model.to(accelerator.device, dtype=weight_dtype)

    return pipeline


@spaces.GPU
def initialize_models():
    global pipeline, generator, accelerator, model_path

    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=1,
    )

    try:
        # Download the entire model repository
        model_path = snapshot_download(
            repo_id="NightRaven109/CCSRModels",
            token=os.environ['Read2'],
        )

        # Load pipeline using the original loading function
        pipeline = load_pipeline(accelerator, model_path)

        # Initialize generator
        generator = torch.Generator(device=accelerator.device)

        return True
    except Exception as e:
        print(f"Error initializing models: {str(e)}")
        return False
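

# Note: on ZeroGPU Spaces a GPU is attached only while a @spaces.GPU-decorated
# function is executing, which is why both initialization and inference are
# decorated rather than loading the models at import time.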
@spaces.GPU
def process_image(
    input_image,
    prompt="clean, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=1.0,
    conditioning_scale=1.0,
    num_inference_steps=20,
    seed=42,
    upscale_factor=2,
    color_fix_method="adain",
):
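    """Upscale input_image with the CCSR pipeline.

    input_image is the HxWx3 uint8 numpy array supplied by Gradio's Image
    component; returns a PIL.Image on success, or None on failure.
    """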
    global pipeline, generator, accelerator

    # Lazily initialize models on the first request
    if pipeline is None:
        if not initialize_models():
            return None
    try:
        # Set seed (Gradio's Number component may deliver a float)
        if seed is not None:
            generator.manual_seed(int(seed))

        # Process input image
        validation_image = Image.fromarray(input_image)
        ori_width, ori_height = validation_image.size

        # Resize logic from the original CCSR script
        resize_flag = False
        rscale = int(upscale_factor)
        process_size = 512  # same as args.process_size in the original script

        # If the input is too small for the requested upscale factor,
        # pre-enlarge it so its shorter side reaches process_size // rscale
        if ori_width < process_size // rscale or ori_height < process_size // rscale:
            scale = (process_size // rscale) / min(ori_width, ori_height)
            validation_image = validation_image.resize(
                (round(scale * ori_width), round(scale * ori_height))
            )
            resize_flag = True

        # Upscale by the requested factor, then round both sides down to a
        # multiple of 8 as the VAE requires
        validation_image = validation_image.resize(
            (validation_image.size[0] * rscale, validation_image.size[1] * rscale)
        )
        validation_image = validation_image.resize(
            (validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8)
        )
        width, height = validation_image.size
        # Move pipeline to GPU for processing
        pipeline.to(accelerator.device)

        # Generate image
        with torch.no_grad():
            inference_time, output = pipeline(
                0.6666,  # t_max
                0.0,  # t_min
                False,  # tile_diffusion
                None,  # tile_diffusion_size
                None,  # tile_diffusion_stride
                prompt,
                validation_image,
                num_inference_steps=num_inference_steps,
                generator=generator,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                conditioning_scale=conditioning_scale,
                start_steps=999,
                start_point='lr',
                use_vae_encode_condition=False,
            )
        image = output.images[0]

        # Apply color fixing if specified
        if color_fix_method != "none":
            fix_func = wavelet_color_fix if color_fix_method == "wavelet" else adain_color_fix
            image = fix_func(image, validation_image)

        if resize_flag:
            image = image.resize((ori_width * rscale, ori_height * rscale))

        # Move pipeline back to CPU
        pipeline.to("cpu")
        torch.cuda.empty_cache()

        return image
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return None
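

# A minimal sketch of calling process_image directly, bypassing the UI
# (hypothetical file names; assumes the Read2 token is available so the
# model snapshot can be downloaded):
#
#   import numpy as np
#   from PIL import Image
#
#   lr = np.array(Image.open("input.png").convert("RGB"))
#   sr = process_image(lr, upscale_factor=2, color_fix_method="wavelet")
#   if sr is not None:
#       sr.save("output.png")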

# Create Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
        gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
        gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
        gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
        gr.Number(label="Seed", value=42),
        gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
        gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Controllable Conditional Super-Resolution",
    description="Upload an image to enhance its resolution using CCSR.",
)

if __name__ == "__main__":
    iface.launch()