Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,132 Bytes
7c89d3a 4445e78 7c89d3a 4445e78 7c89d3a 4445e78 7c89d3a 4445e78 7c89d3a 4445e78 7c89d3a 4445e78 7c89d3a 4445e78 7c89d3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import os
import torch
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import safetensors.torch
from huggingface_hub import hf_hub_download
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import (
AutoencoderKL,
DDPMScheduler,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from models.controlnet import ControlNetModel
from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
# Module-level model state. Everything starts out unset (None) and is
# populated lazily by initialize_models() on first use.
pipeline = generator = accelerator = None
@spaces.GPU
def initialize_models():
    """Load all model components and publish the inference pipeline.

    Creates the Accelerator, downloads every component from the
    "NightRaven109/CCSRModels" repository (authenticating with the token
    stored in the ``Read2`` environment variable), freezes the weights,
    assembles the ControlNet pipeline, and moves the models to the
    accelerator device in the dtype matching the mixed-precision mode.

    The module-level ``pipeline`` and ``generator`` are assigned only after
    every step has succeeded, so a failure part-way through never leaves a
    half-initialized state behind (callers test ``pipeline is None`` to
    decide whether initialization is still required).

    Returns:
        bool: True when the models are ready; False when loading or setup
        failed (the error is printed). Errors raised while constructing the
        Accelerator itself are not caught and propagate to the caller.
    """
    global pipeline, generator, accelerator

    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=1
    )

    try:
        repo_id = "NightRaven109/CCSRModels"
        # A missing environment variable raises KeyError, which is reported
        # by the handler below like any other initialization failure.
        auth_token = os.environ['Read2']

        def load(component_cls, subfolder):
            # Fetch one component from the shared repository.
            return component_cls.from_pretrained(
                repo_id,
                subfolder=subfolder,
                use_auth_token=auth_token,
            )

        # Download and load models with authentication token
        scheduler = load(DDPMScheduler, "stable-diffusion-2-1-base/scheduler")
        text_encoder = load(CLIPTextModel, "stable-diffusion-2-1-base/text_encoder")
        tokenizer = load(CLIPTokenizer, "stable-diffusion-2-1-base/tokenizer")
        feature_extractor = load(
            CLIPImageProcessor, "stable-diffusion-2-1-base/feature_extractor"
        )
        unet = load(UNet2DConditionModel, "stable-diffusion-2-1-base/unet")
        controlnet = load(ControlNetModel, "Controlnet")
        vae = load(AutoencoderKL, "vae")

        # Freeze models: this app only runs inference.
        for model in (vae, text_encoder, unet, controlnet):
            model.requires_grad_(False)

        # Assemble the pipeline in a local first; it is published to the
        # module-level name only once the whole setup has succeeded.
        new_pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
        )

        # Get weight dtype based on mixed precision
        dtype_by_precision = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        weight_dtype = dtype_by_precision.get(
            accelerator.mixed_precision, torch.float32
        )

        # Move models to device with appropriate dtype
        for model in (text_encoder, vae, unet, controlnet):
            model.to(accelerator.device, dtype=weight_dtype)

        # Initialize generator
        new_generator = torch.Generator(device=accelerator.device)

        # Publish last, and the pipeline after the generator, so that
        # "pipeline is not None" implies the generator is usable too.
        generator = new_generator
        pipeline = new_pipeline
        return True
    except Exception as e:
        print(f"Error initializing models: {str(e)}")
        return False
@spaces.GPU
def process_image(
    input_image,
    prompt="clean, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=1.0,
    conditioning_scale=1.0,
    num_inference_steps=20,
    seed=42,
    upscale_factor=2,
    color_fix_method="adain"
):
    """Upscale one image with the pipeline and optionally fix its colors.

    Models are initialized on the first call. The requested output size is
    the input size multiplied by ``upscale_factor``, rounded down to a
    multiple of 8 in each dimension.

    Args:
        input_image: Image array accepted by ``PIL.Image.fromarray``;
            ``None`` (nothing uploaded) is rejected up front.
        prompt: Text passed to the pipeline as ``added_prompt``.
        negative_prompt: Negative prompt passed to the pipeline.
        guidance_scale: Forwarded to the pipeline unchanged.
        conditioning_scale: Forwarded to the pipeline unchanged.
        num_inference_steps: Forwarded to the pipeline unchanged.
        seed: Seed for the shared generator; ``None`` leaves the generator
            state as it is. Converted to ``int`` before use.
        upscale_factor: Multiplier applied to the input width and height.
        color_fix_method: ``"none"`` skips color fixing, ``"wavelet"`` uses
            ``wavelet_color_fix``, any other value uses ``adain_color_fix``.

    Returns:
        The first image produced by the pipeline (after the optional color
        fix), or ``None`` when initialization or processing failed; the
        error is printed in that case.
    """
    global pipeline, generator, accelerator

    # Reject a missing upload before paying for model initialization.
    if input_image is None:
        print("Error processing image: no input image provided")
        return None

    if pipeline is None and not initialize_models():
        return None

    try:
        # Set seed. UI number widgets may deliver it as a float, while
        # torch.Generator.manual_seed expects an integer.
        if seed is not None:
            generator.manual_seed(int(seed))

        # Process input image
        input_pil = Image.fromarray(input_image)
        width, height = input_pil.size

        # Target size: scaled, truncated to whole pixels, then rounded down
        # to a multiple of 8.
        target_width = int(width * upscale_factor) // 8 * 8
        target_height = int(height * upscale_factor) // 8 * 8

        try:
            # Move pipeline to GPU for processing
            pipeline.to(accelerator.device)

            # Generate image
            with torch.no_grad():
                output = pipeline(
                    t_max=0.6666,
                    t_min=0.0,
                    tile_diffusion=False,
                    added_prompt=prompt,
                    image=input_pil,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    height=target_height,
                    width=target_width,
                    guidance_scale=guidance_scale,
                    negative_prompt=negative_prompt,
                    conditioning_scale=conditioning_scale,
                )
        finally:
            # Always hand the GPU back, including when inference failed.
            pipeline.to("cpu")
            torch.cuda.empty_cache()

        generated_image = output.images[0]

        # Apply color fixing if specified
        # NOTE(review): input_pil is still at its original size here while
        # the generated image is upscaled - confirm that both color-fix
        # helpers accept a reference of a different size.
        if color_fix_method != "none":
            fix_func = wavelet_color_fix if color_fix_method == "wavelet" else adain_color_fix
            generated_image = fix_func(generated_image, input_pil)

        return generated_image
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return None
# Create Gradio interface. The inputs are listed in the same order as the
# parameters of process_image, which receives them positionally.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
        gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
        gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
        gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
        # precision=0 makes the component deliver an int; the seed ends up
        # in torch.Generator.manual_seed, which does not accept floats.
        gr.Number(label="Seed", value=42, precision=0),
        gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
        gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Controllable Conditional Super-Resolution",
    description="Upload an image to enhance its resolution using CCSR.",
    examples=[
        ["example1.jpg", "clean, sharp, detailed", "blurry, noise", 1.0, 1.0, 20, 42, 2, "adain"],
        ["example2.jpg", "high-resolution, pristine", "artifacts, pixelated", 1.5, 1.0, 30, 123, 2, "wavelet"],
    ]
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|