# File size: 2,680 Bytes
# f3c4d54 1519a8d ad11ca2 1519a8d 7b39006 d5949ba f3c4d54 ad11ca2 1519a8d f3c4d54 1519a8d 91844a8 7b39006 1da87bb 2c5c388 91844a8 2c5c388 f3c4d54 1519a8d f3c4d54 1519a8d f3c4d54 1519a8d 0c35185 1519a8d f3c4d54 1519a8d 7281338 1519a8d 2c5c388 91844a8 2c5c388 1519a8d 7b39006 2c5c388 91844a8 2c5c388 7b39006 1519a8d |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
from model.flol import create_model
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Auxiliary converter from a PIL image to a tensor.
pil_to_tensor = transforms.ToTensor()

# Checkpoint paths, selected by position: index 0 is the UHD-LL checkpoint,
# index 1 is the checkpoint used for everything else.
image_to_weights = [
    './weights/flolv2_UHDLL.pt',
    './weights/flolv2_all_111439.pt',
]

# Build the network once at import time; its weights are loaded later,
# when a request selects a checkpoint.
model = create_model()
def load_img(filename):
    """Open the image at ``filename`` as RGB and return it as a tensor.

    The PIL image is converted with the module-level ``pil_to_tensor``
    transform.
    """
    rgb_image = Image.open(filename).convert("RGB")
    return pil_to_tensor(rgb_image)
# Path of the checkpoint currently loaded into the shared ``model``
# (None until the first request). Lets repeated requests that use the same
# weights skip the disk read and the state-dict reload.
_active_weights_path = None


def _activate_weights(model_path):
    """Load the checkpoint at ``model_path`` into ``model`` unless it is already active."""
    global _active_weights_path
    if model_path == _active_weights_path:
        return
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['params'])
    model.to(device)
    # Inference only: switch layers such as dropout / batch-norm (if the
    # network has any) to evaluation behaviour.
    model.eval()
    # Record the path only after a successful load so a failed load is retried.
    _active_weights_path = model_path


def process_img(image, UHD_LL_model):
    """Enhance ``image`` with the shared model and return the restored image.

    Args:
        image: Input PIL image, or ``None`` when no image was supplied.
        UHD_LL_model: Truthy selects ``image_to_weights[0]``; falsy selects
            ``image_to_weights[1]``.

    Returns:
        The restored image as a ``PIL.Image`` built from a uint8 array, or
        ``None`` if ``image`` is ``None``.
    """
    if image is None:
        # Nothing to process (e.g. the form was submitted without an image).
        return None

    model_path = image_to_weights[0] if UHD_LL_model else image_to_weights[1]
    _activate_weights(model_path)

    # Force three channels so grayscale / RGBA inputs still produce the
    # (H, W, 3) array that the permute below requires.
    pixels = np.array(image.convert("RGB"))
    pixels = (pixels / 255.).astype(np.float32)  # normalize to [0, 1]
    y = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop
    # a height or width of size 1 and break the permute.
    # NOTE(review): assumes the model returns a (1, C, H, W) tensor - confirm.
    restored = x_hat.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    restored = (restored * 255.0).round().astype(np.uint8)  # back to uint8
    return Image.fromarray(restored)
# Title string passed to gr.Interface as ``title``.
title = "Fast Baselines for Real-World Low-Light Enhancement 🌠⚡🎆"
# Markdown text passed to gr.Interface as ``description``: repository link,
# author credit, and usage caveats (including the high-resolution warning).
description = ''' ## [Github Repository](https://github.com/cidautai/NAFourNet)
[Juan Carlos Benito](https://github.com/juaben)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.**
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
<br>
'''
# Example rows passed to gr.Interface as ``examples``; each row holds a single
# image path (no value is given for the checkbox input).
examples = [
['images/low00772.png'],
['images/low00723.png'],
['images/425_UHD_LL.JPG'],
['images/1778_UHD_LL.JPG'],
['images/1791_UHD_LL.JPG']
]
# Custom stylesheet passed to gr.Interface as ``css``: automatic width/height
# and no max-width for <img> elements inside .image-frame / .image-container.
css = """
.image-frame img, .image-container img {
width: auto;
height: auto;
max-width: none;
}
"""
# Build the web UI around process_img: the inputs are a PIL image and a
# checkbox (given via the 'checkbox' string shortcut), and the output is a
# PIL image.
demo = gr.Interface(
fn = process_img,
inputs = [gr.Image(type = 'pil', label = 'input'), 'checkbox'],
outputs = [gr.Image(type='pil', label = 'output')],
title = title,
description = description,
examples = examples,
css = css
)
# Start the Gradio app only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    demo.launch()