FLOL / app.py
juaben's picture
Update app.py
7b39006 verified
raw
history blame
3.37 kB
import os

import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from model.flol import create_model
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torchvision transform turning a PIL image into a tensor (used by load_img).
pil_to_tensor = transforms.ToTensor()

# Checkpoint files shipped with the demo.
_UHDLL_WEIGHTS = './weights/flolv2_UHDLL.pt'
_ALL_WEIGHTS = './weights/flolv2_all_111439.pt'

# Map each bundled example image (by bare file name) to the checkpoint used to enhance it.
image_to_weights = {
    **dict.fromkeys(("425_UHD_LL.JPG", "1778_UHD_LL.JPG", "1791_UHD_LL.JPG"), _UHDLL_WEIGHTS),
    **dict.fromkeys(("low00748.png", "low00723.png", "low00772.png"), _ALL_WEIGHTS),
}

# Build the network once at import time; weights are loaded later, in process_img.
model = create_model()
def load_img(filename):
    """Open the image at ``filename``, convert it to RGB and return it as a tensor.

    The conversion to a tensor is delegated to the module-level ``pil_to_tensor``.
    """
    # The context manager releases the file handle once the RGB copy exists.
    with Image.open(filename) as source:
        rgb = source.convert("RGB")
    return pil_to_tensor(rgb)
# Checkpoint used for any image whose filename is not a key of `image_to_weights`
# (e.g. user uploads, or an empty filename box).
# NOTE(review): assumes the "all" checkpoint is the general-purpose one — confirm.
_FALLBACK_WEIGHTS = './weights/flolv2_all_111439.pt'
# Path of the checkpoint currently loaded into `model`; None until the first request.
_loaded_weights_path = None


def _load_weights(model_path):
    """Load ``model_path`` into the global ``model`` on ``device`` in eval mode, unless already loaded."""
    global _loaded_weights_path
    if model_path == _loaded_weights_path:
        return
    checkpoints = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoints['params'])
    model.to(device)
    # Inference only: disable training-mode behaviour (dropout / batch-norm statistics updates).
    model.eval()
    _loaded_weights_path = model_path


def process_img(image, filename):
    """Enhance a low-light image with the FLOL model.

    Args:
        image: PIL image from the Gradio input, or None if nothing was provided.
        filename: name of the source file. Its base name selects the checkpoint
            through ``image_to_weights``; unknown or empty names fall back to
            ``_FALLBACK_WEIGHTS`` so the model always runs with trained weights.

    Returns:
        The enhanced image as an 8-bit RGB ``PIL.Image``, or None when ``image`` is None.
    """
    if image is None:
        return None

    # Keys are bare file names, so strip any directory part (e.g. Gradio temp paths).
    key = os.path.basename(filename or "")
    _load_weights(image_to_weights.get(key, _FALLBACK_WEIGHTS))

    # Force 3 channels: RGBA / greyscale inputs would otherwise break the HWC -> CHW permute.
    img = np.array(image.convert("RGB"))
    img = img / 255.  # Normalize to [0, 1]
    img = img.astype(np.float32)
    # HWC -> NCHW with a batch of one.
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        x_hat = model(y)
    # squeeze(0) removes only the batch dim; a bare squeeze() would also drop a size-1 spatial dim.
    restored_img = x_hat.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # Convert to uint8
    return Image.fromarray(restored_img)
# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------
title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
[Juan Carlos Benito](https://github.com/juaben)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.**
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
<br>
'''
# Clickable sample inputs; each row fills only the first (image) input.
examples = [
    ['images/425_UHD_LL.JPG'],
    ['images/low00772.png'],
    ['images/low00723.png'],
    ['images/low00748.png'],
    ['images/1778_UHD_LL.JPG'],
    ['images/1791_UHD_LL.JPG'],
]
# Show images at their natural size instead of letting the frame rescale them.
css = """
.image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}
"""

# Inputs: the picture to enhance and a read-only box holding its filename,
# which process_img uses to choose the checkpoint.
_input_components = [
    gr.Image(type='pil', label='input', tool='editor'),
    gr.Textbox(label="Image Filename", interactive=False),
]
_output_components = [gr.Image(type='pil', label='output')]

demo = gr.Interface(
    process_img,
    _input_components,
    _output_components,
    css=css,
    examples=examples,
    description=description,
    title=title,
)
# Keep the "Image Filename" textbox in sync with the selected image.
def update_filename(image):
    """Return the base name of the file ``image`` was read from, or "" if unknown.

    The result feeds the read-only filename textbox, whose value process_img
    uses to look up a checkpoint in ``image_to_weights``. Those keys are bare
    file names, so any directory part (e.g. a temp-upload path) is stripped.

    NOTE(review): a PIL image produced by ``Image.convert``/``copy`` may have
    no ``filename`` attribute, so it is read defensively instead of raising
    ``AttributeError``.
    """
    if image is None:
        return ""
    return os.path.basename(getattr(image, "filename", "") or "")
# Keep the filename textbox in sync with the image input.
# Gradio attaches event listeners to the active Blocks context. An Interface is a
# Blocks, so re-enter `demo` to register the listener and include it in the app config.
with demo:
    demo.input_components[0].change(
        update_filename,
        inputs=demo.input_components[0],
        outputs=demo.input_components[1],
    )

if __name__ == '__main__':
    demo.launch()