"""Gradio demo app for low-light image enhancement.

Builds the model once at import time, lets the user pick a weight file from a
dropdown, and runs the selected model on an uploaded image.
"""
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from model.flol import create_model

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Auxiliary converter used by load_img (PIL image -> tensor).
pil_to_tensor = transforms.ToTensor()

# Available weight files, keyed by the name shown in the dropdown.
weights_files = {
    "flolv2_UHDLL": './weights/flolv2_UHDLL.pt',
    "flolv2_LOLv2-Real": './weights/flolv2_all_111439.pt',
    # Add other available weights here
}

# Initial model setup (without weights)
model = create_model()

# Key of the weight set currently loaded into `model`; None until the first
# successful load. Lets process_img skip re-reading the checkpoint from disk
# when consecutive requests use the same weights.
_loaded_weights = None


def load_img(filename):
    """Open an image file, force it to RGB and return it as a tensor."""
    img = Image.open(filename).convert("RGB")
    return pil_to_tensor(img)


def process_img(image, weight_file):
    """Enhance ``image`` with the model using the selected weights.

    Args:
        image: PIL image delivered by the Gradio image input.
        weight_file: Key into ``weights_files``. If it is not a known key the
            model keeps whatever weights it currently holds.

    Returns:
        The restored image as an 8-bit ``PIL.Image``.

    Raises:
        gr.Error: If no image was provided.
    """
    global _loaded_weights

    if image is None:
        # Submitting without an upload hands us None; fail with a readable
        # message instead of an opaque TypeError from the arithmetic below.
        raise gr.Error("Please provide an input image.")

    # Load the selected weight file, but only when it differs from what is
    # already in the model.
    if weight_file in weights_files and weight_file != _loaded_weights:
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        checkpoint = torch.load(weights_files[weight_file], map_location=device)
        model.load_state_dict(checkpoint['params'])
        _loaded_weights = weight_file

    # Always keep the model on the same device as the input and in inference
    # mode, regardless of whether weights were (re)loaded above.
    model.to(device)
    model.eval()

    # HWC uint8 -> float32 in [0, 1]. convert("RGB") guarantees three
    # channels for the permute below.
    img = np.array(image.convert("RGB"))
    img = img / 255.  # Normalize to [0, 1]
    img = img.astype(np.float32)

    # HWC -> 1xCxHxW
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    # Drop only the batch dimension (a bare squeeze() would also drop a
    # height/width of 1), go back to HWC and clamp to the valid range.
    # Assumes the model returns a 1xCxHxW tensor - TODO confirm.
    restored_img = x_hat.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # Convert to uint8
    return Image.fromarray(restored_img)


title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet) [Juan Carlos Benito](https://github.com/juaben) Fundación Cidaut > **Disclaimer:** please remember this is not a product, thus, you will notice some limitations. **This demo expects an image with some degradations.** Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
'''

# Each example supplies only the image input; the dropdown keeps its value.
examples = [
    ['images/425_UHD_LL.JPG'],
    ['images/674_UHD_LL.JPG'],
    ['images/685_UHD_LL.JPG'],
    ['images/945_UHD_LL.JPG'],
    ['images/1778_UHD_LL.JPG'],
    ['images/1791_UHD_LL.JPG'],
]

css = """ .image-frame img, .image-container img { width: auto; height: auto; max-width: none; } """

demo = gr.Interface(
    fn=process_img,
    inputs=[
        gr.Image(type='pil', label='input'),
        gr.Dropdown(
            choices=list(weights_files),
            label='Select Weight File',
            # `value` (not the removed `default` keyword) sets the initial
            # selection, so a weight file is always chosen on first run.
            value="flolv2_UHDLL",
        ),
    ],
    outputs=[gr.Image(type='pil', label='output')],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

if __name__ == '__main__':
    demo.launch()