File size: 3,757 Bytes
f3c4d54 1519a8d ad11ca2 1519a8d 7b39006 f3c4d54 ad11ca2 1519a8d f3c4d54 1519a8d 7b39006 1da87bb 7b39006 f3c4d54 1519a8d f3c4d54 1519a8d f3c4d54 1519a8d 0c35185 1519a8d f3c4d54 1519a8d 26a752a 1519a8d f3c4d54 18ff062 7b39006 1519a8d f3c4d54 1519a8d 7b39006 a76fdf3 1da87bb a76fdf3 3bb88a4 a76fdf3 7b39006 1519a8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import gradio as gr
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from model.flol import create_model
# Inference device: the default CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Converter from PIL images to torch tensors (torchvision ``ToTensor`` transform).
pil_to_tensor = transforms.ToTensor()
# Map each bundled example image's bare file name to the checkpoint it should be
# enhanced with: the *_UHD_LL.JPG samples share one checkpoint, the low*.png
# samples share another. Key order matches the original listing.
image_to_weights = {
    **dict.fromkeys(
        ("425_UHD_LL.JPG", "1778_UHD_LL.JPG", "1791_UHD_LL.JPG"),
        './weights/flolv2_UHDLL.pt',
    ),
    **dict.fromkeys(
        ("low00748.png", "low00723.png", "low00772.png"),
        './weights/flolv2_all_111439.pt',
    ),
}
# Build the network once at import time, without trained weights; checkpoints are
# loaded on demand by the request handler, chosen by the input image's file name.
# Placing it on `device` and in eval mode here makes it inference-ready even when a
# request arrives whose file name has no mapped checkpoint (otherwise, on a CUDA
# host, a CPU-resident model would receive a CUDA input tensor).
model = create_model().to(device).eval()
def load_img(filename):
    """Read an image file from disk and return it as a tensor.

    Args:
        filename: Path to the image file.

    Returns:
        The image, converted to RGB and passed through ``pil_to_tensor``.
    """
    # The context manager closes the underlying file handle even if decoding or the
    # RGB conversion raises, instead of leaving it open until garbage collection.
    # convert() fully loads the pixel data, so the result outlives the closed file.
    with Image.open(filename) as img:
        rgb = img.convert("RGB")
    return pil_to_tensor(rgb)
def process_img(image, filename):
# Select the correct weight file based on the image filename
filename = image.name.split("/")[-1]
if filename in image_to_weights:
model_path = image_to_weights[filename]
checkpoints = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoints['params'])
model.to(device)
img = np.array(image)
img = img / 255. # Normalize to [0, 1]
img = img.astype(np.float32)
y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
with torch.no_grad():
x_hat = model(y)
restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
restored_img = np.clip(restored_img, 0., 1.)
restored_img = (restored_img * 255.0).round().astype(np.uint8) # Convert to uint8
return Image.fromarray(restored_img)
# Page title for the Gradio UI.
title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
# Markdown shown under the title: project link, author, affiliation and usage caveats.
# NOTE(review): the project link points to NAFourNet while the checkpoints are named
# flolv2_* -- confirm the intended repository.
description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
[Juan Carlos Benito](https://github.com/juaben)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.**
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
<br>
'''
# Example rows for the UI: [path of the bundled sample image, its bare file name].
# The bare names are the keys of `image_to_weights`.
examples = [
    [f"images/{name}", name]
    for name in (
        "425_UHD_LL.JPG",
        "low00772.png",
        "low00723.png",
        "low00748.png",
        "1778_UHD_LL.JPG",
        "1791_UHD_LL.JPG",
    )
]
# Custom CSS: show images at their natural size instead of scaling them to fit
# the image frame/container.
css = (
    "\n"
    ".image-frame img, .image-container img {\n"
    "width: auto;\n"
    "height: auto;\n"
    "max-width: none;\n"
    "}\n"
)
# Gradio UI: a single Blocks app carrying the page title, description and custom CSS.
# Uploaded images are enhanced immediately; clicking an example fills in both the
# image and its file name and runs `process_img` on them.
with gr.Blocks(css=css, title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)
    with gr.Row():
        input_image = gr.Image(type='pil', label='input', interactive=True)
        filename_output = gr.Textbox(label="Image Filename", interactive=False)
    # Define the output
    output_image = gr.Image(type='pil', label='output')
    # A user upload has no known file name, so first clear any name left over from a
    # previously clicked example, then enhance the new image.
    input_image.upload(lambda: "", inputs=None, outputs=filename_output).then(
        process_img, inputs=[input_image, filename_output], outputs=output_image
    )
    # Clicking an example populates both inputs and runs `process_img` right away.
    # cache_examples=False avoids running the model over every example at startup.
    gr.Examples(
        examples=examples,
        inputs=[input_image, filename_output],
        outputs=output_image,
        fn=process_img,
        run_on_click=True,
        cache_examples=False,
    )
# Launch the web app only when run as a script (not when imported).
if __name__ == '__main__':
    demo.launch()