File size: 5,384 Bytes
1f460ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import torch
import yaml
import os
import numpy as np
from PIL import Image
import time
from cdim.noise import get_noise
from cdim.operators import get_operator
from cdim.image_utils import save_to_image
from cdim.dps_model.dps_unet import create_model
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.diffusion.diffusion_pipeline import run_diffusion
from cdim.eta_scheduler import EtaScheduler
from diffusers import DiffusionPipeline


# Module-level state shared by the Gradio callbacks. `device` prefers CUDA when
# it is available; the diffusion model, its DDIM scheduler and the model-type
# tag start out unset and are filled in on the first restoration request.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ddim_scheduler = model_type = None


def load_image(image_path, size=256):
    """Load an image file as a normalised ``(1, 3, size, size)`` float tensor.

    The image is resized with bicubic resampling, forced to 3-channel RGB
    (so grayscale, palette and RGBA inputs all work), and mapped from
    ``[0, 255]`` to ``[-1, 1]`` via ``x / 127.5 - 1``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        size: Side length of the square output image; defaults to 256.

    Returns:
        A float32 tensor on the CPU.
    """
    # The context manager closes the underlying file handle promptly.
    with Image.open(image_path) as img:
        # Resize first so RGB/RGBA inputs get exactly the same resampled pixels
        # as a plain resize; convert("RGB") then drops any alpha channel and
        # expands single-channel / palette images to three channels.
        rgb = img.resize((size, size), Image.BICUBIC).convert("RGB")
    # HWC uint8 array -> NCHW tensor with a leading batch dimension of 1.
    pixels = torch.from_numpy(np.array(rgb)).unsqueeze(0).permute(0, 3, 1, 2)
    return (pixels / 127.5 - 1.0).to(torch.float)


def load_yaml(file_path: str) -> dict:
    """Load configurations from a YAML file.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed document. An empty document yields ``{}`` rather than
        ``None``, so the ``-> dict`` annotation holds and callers can assign
        keys on the result directly.
    """
    # Read as UTF-8 explicitly; the default text encoding is platform-dependent
    # (e.g. cp1252 on Windows).
    # NOTE(review): FullLoader is acceptable for the bundled config files, but
    # switch to yaml.safe_load if file_path could ever come from user input.
    with open(file_path, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return {} if config is None else config


def convert_to_np(torch_image):
    """Convert a ``(C, H, W)`` tensor with values in ``[-1, 1]`` to an ``(H, W, C)`` uint8 array.

    Values outside ``[-1, 1]`` are clipped first. The map ``(x + 1) * 127.5``
    inverts the ``x / 127.5 - 1`` normalisation used by ``load_image``. The
    result is rounded to the nearest integer rather than truncated, so tiny
    float errors cannot shift a pixel down by one level.
    """
    # .float() lets half/bfloat16 tensors go through .numpy(), which rejects bfloat16.
    array = torch_image.detach().float().clamp(-1, 1).cpu().numpy()
    pixels = (array.transpose(1, 2, 0) + 1) * 127.5
    return np.rint(pixels).astype(np.uint8)


def generate_noisy_image(image_choice, noise_sigma, operator_key):
    """Degrade a bundled sample image and add noise to build a measurement.

    Args:
        image_choice: Dropdown label selecting one of the sample images.
        noise_sigma: Value written into the ``sigma`` field of the noise config
            loaded from ``noise_configs/gaussian_noise_config.yaml``.
        operator_key: Dropdown label selecting the operator config file.

    Returns:
        A ``(PIL.Image, dict)`` pair. The image is the noisy measurement rendered
        for display. The dict holds the measurement tensor, the operator and the
        noise function, under the keys ``run_restoration`` reads.

    Raises:
        KeyError: if either label is not one of the known choices.
    """
    sample_files = {
        "CelebA HQ 1": "sample_images/celebhq_29999.jpg",
        "CelebA HQ 2": "sample_images/celebhq_00001.jpg",
        "CelebA HQ 3": "sample_images/celebhq_00000.jpg",
    }
    operator_configs = {
        "Box Inpainting": "operator_configs/box_inpainting_config.yaml",
        "Random Inpainting": "operator_configs/random_inpainting_config.yaml",
        "Super Resolution": "operator_configs/super_resolution_config.yaml",
        "Gaussian Deblur": "operator_configs/gaussian_blur_config.yaml",
    }

    # An unknown image label raises KeyError before anything is loaded.
    clean = load_image(sample_files[image_choice]).to(device)

    noise_cfg = load_yaml("noise_configs/gaussian_noise_config.yaml")
    noise_cfg["sigma"] = noise_sigma
    add_noise = get_noise(**noise_cfg)

    op_cfg = load_yaml(operator_configs[operator_key])
    op_cfg["device"] = device
    degrade = get_operator(**op_cfg)

    measurement = add_noise(degrade(clean))
    preview = Image.fromarray(convert_to_np(measurement[0]))

    # Everything run_restoration needs, carried between events via gr.State.
    restoration_inputs = {
        "noisy_measurement": measurement,
        "operator": degrade,
        "noise_function": add_noise,
    }
    return preview, restoration_inputs


def _ensure_model_loaded():
    """Lazily create the DDPM UNet and DDIM scheduler in the module globals."""
    global model, ddim_scheduler, model_type
    if model is not None:
        return
    model_type = "diffusers"
    # Use the module-level `device` rather than a hard-coded "cuda": this keeps
    # CPU-only hosts working and matches the device later passed to run_diffusion.
    model = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256").to(device).unet
    ddim_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
    )


def run_restoration(data, T, K):
    """Run the restoration process and return the restored image.

    Args:
        data: Dict produced by ``generate_noisy_image`` (keys
            ``noisy_measurement``, ``operator``, ``noise_function``).
        T: Number of inference steps; cast to int because it comes from a UI slider.
        K: The K value forwarded to the eta scheduler and ``run_diffusion``; also cast to int.

    Returns:
        The restored image as a ``PIL.Image``.

    Raises:
        gr.Error: if no measurement data is available yet.
    """
    # gr.State starts out as None; surface a readable UI error instead of a TypeError.
    if not data:
        raise gr.Error("No noisy measurement available; generate a noisy image first.")

    T = int(T)
    K = int(K)

    noisy_measurement = data['noisy_measurement']
    operator = data['operator']
    noise_function = data['noise_function']

    _ensure_model_loaded()

    eta_scheduler = EtaScheduler("gradnorm", operator.name, T, K, 'l2', noise_function, None)
    output_image = run_diffusion(
        model, ddim_scheduler, noisy_measurement, operator, noise_function, device,
        eta_scheduler, num_inference_steps=T, K=K, model_type=model_type, loss_type='l2'
    )

    # Convert the first image of the output batch for display.
    return Image.fromarray(convert_to_np(output_image[0]))


with gr.Blocks() as demo:
    gr.Markdown("# Noisy Image Restoration with Diffusion Models")

    with gr.Row():
        T = gr.Slider(10, 200, value=50, step=1, label="Number of Inference Steps (T)")
        K = gr.Slider(1, 10, value=3, step=1, label="K Value")
        noise_sigma = gr.Slider(0, 0.6, value=0.05, step=0.01, label="Noise Sigma")

    # These choice labels must match the keys of the image mapping in generate_noisy_image.
    image_select = gr.Dropdown(
        choices=["CelebA HQ 1", "CelebA HQ 2", "CelebA HQ 3"],
        value="CelebA HQ 1",
        label="Select Input Image"
    )

    # These choice labels must match the keys of the operator-config mapping in generate_noisy_image.
    operator_select = gr.Dropdown(
        choices=["Box Inpainting", "Random Inpainting", "Super Resolution", "Gaussian Deblur"],
        value="Box Inpainting",
        label="Select Task"
    )

    run_button = gr.Button("Run Inference")
    noisy_image = gr.Image(label="Noisy Image")
    restored_image = gr.Image(label="Restored Image")
    state = gr.State()  # Carries the restoration inputs from step 1 to step 2

    # Step 1 builds the noisy measurement and stores what restoration needs in
    # `state`. Step 2 restores from it. `.success` (unlike `.then`) runs step 2
    # only if step 1 succeeded, so a failure cannot trigger a restoration from
    # an empty or stale `state` left over from an earlier click.
    run_button.click(
        fn=generate_noisy_image,
        inputs=[image_select, noise_sigma, operator_select],
        outputs=[noisy_image, state],
    ).success(
        fn=run_restoration,
        inputs=[state, T, K],
        outputs=restored_image
    )


if __name__ == "__main__":
    # Bind to all interfaces on the fixed port so the app is reachable from
    # outside the host/container, not only from localhost.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )