import os

import gradio as gr
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

from model.flol import create_model

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Shared PIL -> float tensor converter used by load_img.
pil_to_tensor = transforms.ToTensor()

# Map example image basenames to the checkpoint that should be used for them.
image_to_weights = {
    "425_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
    "1778_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
    "1791_UHD_LL.JPG": './weights/flolv2_UHDLL.pt',
    "low00748.png": './weights/flolv2_all_111439.pt',
    "low00723.png": './weights/flolv2_all_111439.pt',
    "low00772.png": './weights/flolv2_all_111439.pt',
}

# Checkpoint used when the filename is empty or not listed above (e.g. user
# uploads). Without a fallback, the model would run with whatever weights a
# previous request loaded, or with none at all on the first request.
# NOTE(review): the "all" checkpoint is assumed to be the general-purpose one.
DEFAULT_WEIGHTS = './weights/flolv2_all_111439.pt'

# Build the network once, move it to the target device and switch it to
# inference mode. Weights are loaded lazily by _ensure_weights.
model = create_model()
model.to(device)
model.eval()

# Path of the checkpoint currently loaded into `model` (None = nothing loaded).
_loaded_weights = None


def load_img(filename):
    """Open an image file and return it as a float tensor (C, H, W) in [0, 1]."""
    # `with` closes the underlying file handle once the RGB copy is made.
    with Image.open(filename) as img:
        img_tensor = pil_to_tensor(img.convert("RGB"))
    return img_tensor


def _ensure_weights(model_path):
    """Load the checkpoint at `model_path` into `model` unless it is already loaded."""
    global _loaded_weights
    if model_path == _loaded_weights:
        return
    checkpoints = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoints['params'])
    _loaded_weights = model_path


def process_img(image, filename):
    """Enhance `image` with the FLOL model and return the result.

    Args:
        image: input PIL image from the Gradio input component, or None.
        filename: name of the source image, used to pick the checkpoint from
            `image_to_weights`. It may be a full path, empty or None; unknown
            names fall back to DEFAULT_WEIGHTS.

    Returns:
        The restored image as an 8-bit RGB PIL image, or None if no image was
        given.
    """
    if image is None:
        return None

    # The dict is keyed by basename, but the textbox may hold a full path.
    key = os.path.basename(filename) if filename else ""
    _ensure_weights(image_to_weights.get(key, DEFAULT_WEIGHTS))

    # HWC uint8 -> HWC float32 in [0, 1] -> NCHW batch of one.
    img = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    y = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    # Drop only the batch axis, so size-1 spatial dims are never squeezed away.
    restored_img = x_hat.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    restored_img = (restored_img * 255.0).round().astype(np.uint8)
    return Image.fromarray(restored_img)


title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"

description = '''
## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
[Juan Carlos Benito](https://github.com/juaben)
Fundación Cidaut
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
**This demo expects an image with some degradations.** Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
'''

# Each example supplies both inputs, so the filename textbox (which the user
# cannot edit) is filled and the matching checkpoint is selected.
examples = [
    ['images/425_UHD_LL.JPG', '425_UHD_LL.JPG'],
    ['images/low00772.png', 'low00772.png'],
    ['images/low00723.png', 'low00723.png'],
    ['images/low00748.png', 'low00748.png'],
    ['images/1778_UHD_LL.JPG', '1778_UHD_LL.JPG'],
    ['images/1791_UHD_LL.JPG', '1791_UHD_LL.JPG'],
]

css = """
.image-frame img, .image-container img {
    width: auto;
    height: auto;
    max-width: none;
}
"""

demo = gr.Interface(
    fn=process_img,
    inputs=[
        gr.Image(type='pil', label='input', tool='editor'),
        gr.Textbox(label="Image Filename", interactive=False),
    ],
    outputs=[gr.Image(type='pil', label='output')],
    title=title,
    description=description,
    examples=examples,
    css=css,
)


def update_filename(image):
    """Return the basename of the file `image` was opened from, or "" if unknown."""
    if image is None:
        return ""
    # PIL only records `filename` for images opened from a file; images built
    # from in-memory data may lack the attribute or have it empty.
    return os.path.basename(getattr(image, "filename", "") or "")


# Gradio only accepts event listeners inside a Blocks context. Interface is a
# Blocks subclass, so re-enter it to attach the listener.
with demo:
    demo.input_components[0].change(
        update_filename,
        inputs=demo.input_components[0],
        outputs=demo.input_components[1],
    )

if __name__ == '__main__':
    demo.launch()