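"""Gradio Space for Controllable Conditional Super-Resolution (CCSR).

Builds a Stable Diffusion 2.1-base ControlNet pipeline from the
NightRaven109/CCSRModels Hub repository and exposes it as an image
upscaling interface.
"""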
import os
import torch
import gradio as gr
import spaces
from PIL import Image
from accelerate import Accelerator
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from models.controlnet import ControlNetModel
from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
# Initialize global variables for models
pipeline = None
generator = None
accelerator = None
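
# Models are created lazily inside the @spaces.GPU-decorated functions, so a
# GPU is only attached while a request is actually being served (ZeroGPU).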
@spaces.GPU
def initialize_models():
    global pipeline, generator, accelerator

    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=1
    )

    try:
        # Read the Hub access token once; all checkpoints live in the same repo
        hf_token = os.environ["Read2"]

        # Download and load the individual pipeline components
        scheduler = DDPMScheduler.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/scheduler",
            use_auth_token=hf_token,
        )

        text_encoder = CLIPTextModel.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/text_encoder",
            use_auth_token=hf_token,
        )

        tokenizer = CLIPTokenizer.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/tokenizer",
            use_auth_token=hf_token,
        )

        feature_extractor = CLIPImageProcessor.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/feature_extractor",
            use_auth_token=hf_token,
        )

        unet = UNet2DConditionModel.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="stable-diffusion-2-1-base/unet",
            use_auth_token=hf_token,
        )

        controlnet = ControlNetModel.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="Controlnet",
            use_auth_token=hf_token,
        )

        vae = AutoencoderKL.from_pretrained(
            "NightRaven109/CCSRModels",
            subfolder="vae",
            use_auth_token=hf_token,
        )

        # Freeze all models; this app only runs inference
        for model in [vae, text_encoder, unet, controlnet]:
            model.requires_grad_(False)

        # Assemble the CCSR pipeline
        pipeline = StableDiffusionControlNetPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=None,
            requires_safety_checker=False,
        )

        # Pick the weight dtype matching the accelerator's mixed-precision mode
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16

        # Move models to the accelerator device with the chosen dtype
        for model in [text_encoder, vae, unet, controlnet]:
            model.to(accelerator.device, dtype=weight_dtype)

        # Seedable RNG used for reproducible sampling
        generator = torch.Generator(device=accelerator.device)

        return True

    except Exception as e:
        print(f"Error initializing models: {e}")
        return False
@spaces.GPU
def process_image(
    input_image,
    prompt="clean, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=1.0,
    conditioning_scale=1.0,
    num_inference_steps=20,
    seed=42,
    upscale_factor=2,
    color_fix_method="adain"
):
    global pipeline, generator, accelerator

    # Lazily initialize the models on first use
    if pipeline is None:
        if not initialize_models():
            return None

    try:
        # Seed the generator (Gradio's Number component returns a float)
        if seed is not None:
            generator.manual_seed(int(seed))

        # Convert the input array to an RGB PIL image
        input_pil = Image.fromarray(input_image).convert("RGB")
        width, height = input_pil.size

        # Compute the output resolution; round down to a multiple of 8 so the
        # dimensions are compatible with the VAE's downsampling factor
        target_width = int(width * upscale_factor)
        target_height = int(height * upscale_factor)
        target_width = target_width - (target_width % 8)
        target_height = target_height - (target_height % 8)

        # Move pipeline to GPU for processing
        pipeline.to(accelerator.device)

        # Generate the upscaled image; t_max/t_min truncate the diffusion
        # trajectory as in CCSR (sampling starts at t_max rather than pure noise)
        with torch.no_grad():
            output = pipeline(
                t_max=0.6666,
                t_min=0.0,
                tile_diffusion=False,
                added_prompt=prompt,
                image=input_pil,
                num_inference_steps=int(num_inference_steps),
                generator=generator,
                height=target_height,
                width=target_width,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                conditioning_scale=conditioning_scale,
            )

        generated_image = output.images[0]

        # Optionally transfer the input's color statistics onto the output
        if color_fix_method != "none":
            fix_func = wavelet_color_fix if color_fix_method == "wavelet" else adain_color_fix
            generated_image = fix_func(generated_image, input_pil)

        return generated_image

    except Exception as e:
        print(f"Error processing image: {e}")
        return None

    finally:
        # Always release the GPU, even if generation failed
        pipeline.to("cpu")
        torch.cuda.empty_cache()
# Create Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
        gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
        gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
        gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
        gr.Number(label="Seed", value=42),
        gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
        gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Controllable Conditional Super-Resolution",
    description="Upload an image to enhance its resolution using CCSR.",
    examples=[
        ["example1.jpg", "clean, sharp, detailed", "blurry, noise", 1.0, 1.0, 20, 42, 2, "adain"],
        ["example2.jpg", "high-resolution, pristine", "artifacts, pixelated", 1.5, 1.0, 30, 123, 2, "wavelet"],
    ],
)
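
# Minimal sketch of calling process_image directly, e.g. for local testing
# without the UI. File names here are hypothetical:
#
#   from PIL import Image
#   import numpy as np
#   img = np.array(Image.open("input.jpg").convert("RGB"))
#   result = process_image(img, upscale_factor=2, color_fix_method="wavelet")
#   if result is not None:
#       result.save("output.png")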
if __name__ == "__main__":
    iface.launch()