FLOL / app.py
juaben's picture
Update app.py
d8a314a verified
raw
history blame
2.43 kB
import gradio as gr
from PIL import Image
import torch
import torchvision.transforms as transforms
import numpy as np
from archs.model import FourNet
# Run inference on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Auxiliary transform: torchvision ToTensor turns an (H, W, C) PIL image in
# [0, 255] into a float (C, H, W) tensor in [0, 1].
pil_to_tensor = transforms.ToTensor()

# Build the network and restore its trained weights.
# NOTE(review): nf=16 presumably has to match the configuration the checkpoint
# was trained with (the file name suggests "NAFourNet16") — confirm.
model = FourNet(nf=16)
checkpoints = torch.load('./models/NAFourNet16_LOLv2Real.pt', map_location=device)
model.load_state_dict(checkpoints['model_state_dict'])
model = model.to(device)
# Put the network in evaluation mode for serving, so that any dropout /
# normalization layers FourNet may contain use their inference behaviour.
model.eval()
def load_img(filename):
    """Load an image file and return it as an RGB float tensor.

    Args:
        filename: Anything ``PIL.Image.open`` accepts (a path or a file object).

    Returns:
        The image converted to RGB and passed through the module-level
        ``pil_to_tensor`` transform (torchvision ``ToTensor``), i.e. a float
        tensor of shape (3, H, W) with values in [0, 1].
    """
    # Open in a context manager so the file handle is closed deterministically;
    # Pillow may otherwise keep it open (e.g. for multi-frame formats) until GC.
    # ``convert`` returns a new, fully loaded image, so it stays valid after close.
    with Image.open(filename) as img:
        rgb = img.convert("RGB")
    return pil_to_tensor(rgb)
def process_img(image):
    """Enhance a low-light image with the module-level restoration ``model``.

    Args:
        image: The input delivered by the ``gr.Image(type='pil')`` component,
            normally a PIL image; ``None`` when the user submits without
            uploading anything.

    Returns:
        The restored image as an 8-bit RGB ``PIL.Image``, or ``None`` when no
        input image was given (which leaves the Gradio output empty).
    """
    if image is None:
        return None

    # The tensor layout below assumes exactly three channels; convert RGBA,
    # grayscale or palette images instead of failing inside permute/the model.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    # uint8 [0, 255] -> float32 [0, 1], reshaped to a (1, 3, H, W) batch.
    img = (np.array(image) / 255.).astype(np.float32)
    y = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model returns four outputs; only the first (x_hat) is used as
        # the restored image.
        x_hat, _, _, _ = model(y)

    # Drop only the batch dimension: a bare squeeze() would also remove H or W
    # when either equals 1 and make the permute fail.
    restored_img = x_hat.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8
    return Image.fromarray(restored_img)
# Static texts and assets displayed by the Gradio interface.
title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"

description = """ ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
[Juan Carlos Benito](https://github.com/juaben)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.**
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
<br>
"""

# One example row per sample file; each row holds the single input image path.
examples = [
    [f"examples/inputs/{name}"]
    for name in (
        "0010.png",
        "0060.png",
        "0075.png",
        "0087.png",
        "0088.png",
        "BDD100k_input.jpg",
    )
]

# Show the images at their natural size instead of letting the frame shrink them.
css = """
.image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}
"""
# Single-image-in / single-image-out Gradio app wrapping process_img.
demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type="pil", label="input")],
    outputs=[gr.Image(type="pil", label="output")],
    title=title,
    description=description,
    examples=examples,
    css=css,
)
if __name__ == "__main__":
    # Serve the `demo` interface with Gradio's default launch settings.
    demo.launch()