import os

import gradio as gr
import torch
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from model.flol import create_model

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Auxiliary transform: PIL image -> CHW float tensor scaled to [0, 1]
pil_to_tensor = transforms.ToTensor()

# Map each example image filename to its corresponding weight file
image_to_weights = {
    "425_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
    "1778_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
    "1791_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
    "low00748.png": './weights/flolv2_all_111439.pt',
    "low00723.png": './weights/flolv2_all_111439.pt',
    "low00772.png": './weights/flolv2_all_111439.pt',
}

# Initial model setup (without weights)
model = create_model()
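# Weights are loaded lazily inside process_img, once the input filename is known.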

def load_img(filename):
    img = Image.open(filename).convert("RGB")
    img_tensor = pil_to_tensor(img)
    return img_tensor
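
# Note: load_img is a standalone helper for loading images from disk (handy
# for local testing); the Gradio flow below passes PIL images to process_img
# directly.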

def process_img(image, filename):
    # Select the correct weight file based on the filename passed in from the
    # Gradio textbox (an uploaded PIL image has no reliable `.name` attribute).
    filename = os.path.basename(filename) if filename else ""
    if filename in image_to_weights:
        model_path = image_to_weights[filename]
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['params'])
        model.to(device)
    # If the filename is unknown, the model keeps whatever weights were loaded last.

    img = np.array(image)
    img = img / 255.  # Normalize to [0, 1]
    img = img.astype(np.float32)
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0., 1.)

    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # Convert to uint8
    return Image.fromarray(restored_img)
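
# Minimal local sanity check (a sketch; assumes the repo's `images/` folder
# and the weight files listed above are present). Uncomment to push one
# example through process_img without launching the UI:
#
#     result = process_img(Image.open('images/425_UHD_LL.JPG').convert('RGB'),
#                          '425_UHD_LL.JPG')
#     result.save('restored_425_UHD_LL.png')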


title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
[Juan Carlos Benito](https://github.com/juaben)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, so you may notice some limitations.
**This demo expects an image with some degradations.**
Due to GPU memory limitations, the app might crash if you feed it a high-resolution image (2K or 4K).
<br>
'''
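
# Each example row supplies both inputs: the image path and the filename key
# that process_img uses to pick the matching checkpoint.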

examples = [['images/425_UHD_LL.JPG', '425_UHD_LL.JPG'],
            ['images/low00772.png', 'low00772.png'], 
            ['images/low00723.png', 'low00723.png'], 
            ['images/low00748.png', 'low00748.png'], 
            ['images/1778_UHD_LL.JPG', '1778_UHD_LL.JPG'],
            ['images/1791_UHD_LL.JPG', '1791_UHD_LL.JPG']]

css = """
    .image-frame img, .image-container img {
        width: auto;
        height: auto;
        max-width: none;
    }
"""

def update_filename(image):
    # Recover the original filename from an uploaded PIL image when possible.
    # PIL only sets `.filename` when the image was opened from a file path,
    # and whether Gradio preserves it is version-dependent, so fall back to "".
    if image is None:
        return ""
    return os.path.basename(getattr(image, "filename", "") or "")

with gr.Blocks(css=css) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Row():
        input_image = gr.Image(type='pil', label='input', interactive=True)
        filename_output = gr.Textbox(label="Image Filename", interactive=False)

    # Define the output
    output_image = gr.Image(type='pil', label='output')

    # When the image changes, update the filename first, then process the
    # image once the textbox holds the refreshed value.
    input_image.change(
        update_filename, inputs=input_image, outputs=filename_output
    ).then(
        process_img, inputs=[input_image, filename_output], outputs=output_image
    )

    # Provide examples
    gr.Examples(examples=examples, inputs=[input_image, filename_output],
                outputs=output_image, fn=process_img)

if __name__ == '__main__':
    demo.launch()