FLOL / app.py
juaben's picture
Update app.py
f3c4d54 verified
raw
history blame
2.75 kB
import gradio as gr
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from model.flol import create_model
# Prefer the GPU when one is present; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Auxiliary converter: PIL image -> float tensor (torchvision ToTensor).
pil_to_tensor = transforms.ToTensor()

# Dropdown label -> checkpoint path. Register further checkpoints here to
# make them selectable in the UI.
weights_files = dict(
    flolv2_UHDLL='./weights/flolv2_UHDLL.pt',
    flolv2_other='./weights/flolv2_other.pt',
)

# Build the network once at import time; trained weights are loaded later,
# when a request selects a checkpoint.
model = create_model()
def load_img(filename):
    """Load an image file and convert it to a tensor.

    Args:
        filename: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        The image converted to RGB and passed through ``pil_to_tensor``.
    """
    # The context manager closes the underlying file deterministically; for
    # multi-frame formats (GIF/TIFF) Pillow otherwise keeps it open after
    # the pixel data has been read.
    with Image.open(filename) as img:
        return pil_to_tensor(img.convert("RGB"))
# Checkpoint used when no selection reaches the handler (e.g. examples or
# API calls that omit the dropdown value and therefore pass ``None``).
_DEFAULT_WEIGHTS = "flolv2_UHDLL"
# Key of the checkpoint currently loaded into ``model`` (None = none yet),
# so the weight file is only re-read from disk when the selection changes.
_loaded_weights = None


def process_img(image, weight_file=_DEFAULT_WEIGHTS):
    """Enhance a low-light image with the selected FLOL checkpoint.

    Args:
        image: Input ``PIL.Image`` from the Gradio image component; it is
            converted to RGB before inference.
        weight_file: Key of ``weights_files`` naming the checkpoint to use.
            ``None`` falls back to ``_DEFAULT_WEIGHTS``.

    Returns:
        The enhanced image as an 8-bit RGB ``PIL.Image``.

    Raises:
        gr.Error: If no image is given, or ``weight_file`` is not a known key
            (previously the model silently ran without trained weights).
    """
    global _loaded_weights

    if image is None:
        raise gr.Error("Please provide an input image.")
    if weight_file is None:
        weight_file = _DEFAULT_WEIGHTS
    if weight_file not in weights_files:
        raise gr.Error(f"Unknown weight file: {weight_file!r}")

    # Reload only when the selection changes instead of on every request.
    if weight_file != _loaded_weights:
        # NOTE(review): torch.load unpickles the file; acceptable for the
        # bundled checkpoints, but never point this at untrusted paths.
        checkpoints = torch.load(weights_files[weight_file], map_location=device)
        model.load_state_dict(checkpoints['params'])
        _loaded_weights = weight_file

    # Always place the model on ``device`` (where the input tensor goes) and
    # switch to eval mode, in case the network has dropout/batch-norm layers.
    model.to(device)
    model.eval()

    # HWC uint8 -> float32 in [0, 1] -> NCHW batch of one.
    img = (np.asarray(image.convert("RGB")) / 255.0).astype(np.float32)
    y = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    # assumes the model returns a (1, 3, H, W) tensor - TODO confirm.
    # squeeze(0) drops only the batch axis (plain squeeze() would also drop
    # a height/width of 1).
    restored = x_hat.squeeze(0).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    restored = (restored * 255.0).round().astype(np.uint8)
    return Image.fromarray(restored)
title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
[Juan Carlos Benito](https://github.com/juaben)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.**
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
<br>
'''
# The interface has two inputs (image, weight key), so every example row
# must supply a value for each. All samples come from the UHD-LL set, so
# they use the matching UHD-LL checkpoint.
examples = [
    ['images/425_UHD_LL.JPG', 'flolv2_UHDLL'],
    ['images/674_UHD_LL.JPG', 'flolv2_UHDLL'],
    ['images/685_UHD_LL.JPG', 'flolv2_UHDLL'],
    ['images/945_UHD_LL.JPG', 'flolv2_UHDLL'],
    ['images/1778_UHD_LL.JPG', 'flolv2_UHDLL'],
    ['images/1791_UHD_LL.JPG', 'flolv2_UHDLL'],
]
# Show images at their natural size instead of scaling them to the frame.
css = """
.image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}
"""
# Web UI: an image plus a checkpoint selector, fed into ``process_img``.
demo = gr.Interface(
    fn=process_img,
    inputs=[
        gr.Image(type='pil', label='input'),
        # ``value`` sets the initial selection. The pre-3.0 ``default``
        # keyword is ignored by Gradio 3 (leaving the dropdown empty) and
        # rejected with a TypeError by Gradio 4.
        gr.Dropdown(
            choices=list(weights_files.keys()),
            value="flolv2_UHDLL",
            label='Select Weight File',
        ),
    ],
    outputs=[gr.Image(type='pil', label='output')],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

if __name__ == '__main__':
    demo.launch()