"""Gradio demo that dehazes a single image with a pretrained FFA-Net model."""

import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tfs
from PIL import Image


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Return a ``Conv2d`` padded with ``kernel_size // 2`` on each side."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=(kernel_size // 2),
        bias=bias,
    )


class PALayer(nn.Module):
    """Pixel attention: scale the input by a learned single-channel map."""

    def __init__(self, channel):
        super().__init__()
        # Two 1x1 convs squeeze to ``channel // 8`` and then to one channel;
        # the sigmoid keeps the resulting weights in (0, 1).
        self.pa = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // 8, 1, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its attention map (broadcast over channels)."""
        y = self.pa(x)
        return x * y


class CALayer(nn.Module):
    """Channel attention: scale each channel by a weight from pooled features."""

    def __init__(self, channel):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // 8, channel, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by per-channel weights from its 1x1 pooling."""
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y


class Block(nn.Module):
    """Residual block: conv, ReLU, conv, channel attention, pixel attention."""

    def __init__(self, conv, dim, kernel_size):
        super().__init__()
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        """Apply the block; ``x`` is added back after the first conv and at the end."""
        res = self.act1(self.conv1(x))
        res = res + x
        res = self.conv2(res)
        res = self.calayer(res)
        res = self.palayer(res)
        res += x
        return res


class Group(nn.Module):
    """A run of ``blocks`` ``Block`` modules plus a trailing conv, with a skip."""

    def __init__(self, conv, dim, kernel_size, blocks):
        super().__init__()
        modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]
        modules.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*modules)

    def forward(self, x):
        """Return the group output with the group input added back."""
        res = self.gp(x)
        res += x
        return res


class FFA(nn.Module):
    """FFA-Net: three groups whose outputs are fused by learned channel weights.

    Args:
        gps: Number of groups. Must be 3; ``forward`` hard-codes three groups.
        blocks: Number of ``Block`` modules inside each group.
        conv: Factory used to build the convolution layers.
    """

    def __init__(self, gps, blocks, conv=default_conv):
        super().__init__()
        self.gps = gps
        self.dim = 64
        kernel_size = 3
        pre_process = [conv(3, self.dim, kernel_size)]
        # g1/g2/g3 and the fusion in ``forward`` are written for exactly three
        # groups, so any other value cannot work.
        assert self.gps == 3
        # NOTE: attribute names and registration order are kept as-is so the
        # state_dict keys continue to match the checkpoint being loaded.
        self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, bias=True),
            nn.Sigmoid(),
        )
        self.palayer = PALayer(self.dim)

        post_process = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size),
        ]

        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_process)

    def forward(self, x1):
        """Return the network output added to the input ``x1`` (global residual)."""
        x = self.pre(x1)
        res1 = self.g1(x)
        res2 = self.g2(res1)
        res3 = self.g3(res2)
        # One weight per (group, channel): reshape the pooled attention output
        # to (batch, gps, dim, 1, 1) so each slice broadcasts over H and W.
        w = self.ca(torch.cat([res1, res2, res3], dim=1))
        w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
        out = (
            w[:, 0, :, :, :] * res1
            + w[:, 1, :, :, :] * res2
            + w[:, 2, :, :, :] * res3
        )
        out = self.palayer(out)
        x = self.post(out)
        return x + x1


MODEL_PATH = 'tti.pk'
gps = 3
blocks = 19
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = FFA(gps=gps, blocks=blocks).to(device)
# The model is wrapped before loading, so the checkpoint's keys are loaded
# into the DataParallel wrapper rather than the bare FFA module.
net = torch.nn.DataParallel(net)

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found at {MODEL_PATH}")

try:
    torch.serialization.add_safe_globals([np.core.multiarray.scalar])
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
except Exception as exc:
    # Best-effort fallback, but no longer a bare ``except`` so Ctrl-C and
    # SystemExit propagate. SECURITY: weights_only=False unpickles arbitrary
    # objects and can execute code from the file; only use trusted checkpoints.
    print(f"Safe checkpoint load failed: {exc}")
    print("Warning: Loading checkpoint with weights_only=False. Ensure the checkpoint is from a trusted source.")
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

net.load_state_dict(checkpoint['model'])
net.eval()
print(f"Model loaded successfully on {device}")

# Built once instead of on every request; the pipeline has no per-call state.
_PREPROCESS = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152]),
])


def dehaze_image(image):
    """Process a hazy image and return the dehazed result.

    Args:
        image: PIL Image, numpy array, or ``None`` (e.g. when the Gradio
            input is cleared and its ``change`` event fires).

    Returns:
        The dehazed PIL Image, or ``None`` when there is no input or
        processing fails (the error is printed).
    """
    if image is None:
        # Clearing the input is not an error; just clear the output as well.
        return None

    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        haze_img = image.convert("RGB")
        haze_tensor = _PREPROCESS(haze_img).unsqueeze(0).to(device)

        with torch.no_grad():
            pred = net(haze_tensor)

        # Clamp to [0, 1] before scaling to 8-bit so the cast cannot wrap.
        pred_clamped = pred.clamp(0, 1).cpu()
        pred_numpy = pred_clamped.squeeze(0).permute(1, 2, 0).numpy()
        pred_numpy = (pred_numpy * 255).astype(np.uint8)

        return Image.fromarray(pred_numpy)
    except Exception as e:
        # UI callback boundary: report the failure and return an empty
        # output instead of propagating into Gradio.
        print(f"Error processing image: {str(e)}")
        return None


SAMPLE_IMAGES = [
    "./img/s2.png",
    "./img/s4.png"
]


def load_sample_image(sample_path):
    """Load and return a sample image, or ``None`` if missing or unreadable."""
    try:
        if os.path.exists(sample_path):
            return Image.open(sample_path)
        else:
            print(f"Sample image not found: {sample_path}")
            return None
    except Exception as e:
        print(f"Error loading sample image {sample_path}: {e}")
        return None


def _sample_at(index):
    """Return the sample image at ``index``, or ``None`` if no such entry."""
    if len(SAMPLE_IMAGES) > index:
        return load_sample_image(SAMPLE_IMAGES[index])
    return None


def create_interface():
    """Build and return the Gradio Blocks UI for the dehazing demo."""
    with gr.Blocks(title="Image Dehazing App", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🌫️ Image Dehazing with FFA-Net")
        gr.Markdown("Upload a hazy image to remove fog, haze, and improve visibility!")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Hazy Image",
                    type="pil",
                    height=400
                )

                gr.Markdown("### Try Sample Images")
                with gr.Row():
                    sample1_btn = gr.Image(
                        value=_sample_at(0),
                        label="Sample 1",
                        interactive=True,
                        width=150,
                        height=150,
                        container=True,
                        show_download_button=False
                    )
                    sample2_btn = gr.Image(
                        value=_sample_at(1),
                        label="Sample 2",
                        interactive=True,
                        width=150,
                        height=150,
                        container=True,
                        show_download_button=False
                    )

                process_btn = gr.Button(
                    "Remove Haze ✨",
                    variant="primary",
                    size="lg"
                )

            with gr.Column():
                output_image = gr.Image(
                    label="Dehazed Result",
                    type="pil",
                    height=400
                )

        def use_sample1():
            return _sample_at(0)

        def use_sample2():
            return _sample_at(1)

        # Selecting a sample thumbnail copies that sample into the main input.
        sample1_btn.select(
            fn=use_sample1,
            outputs=input_image
        )

        sample2_btn.select(
            fn=use_sample2,
            outputs=input_image
        )

        process_btn.click(
            fn=dehaze_image,
            inputs=input_image,
            outputs=output_image,
            api_name="dehaze"
        )

        # Also dehaze whenever the input image changes, including when a
        # sample is copied in.
        input_image.change(
            fn=dehaze_image,
            inputs=input_image,
            outputs=output_image
        )

        gr.Markdown("""
        ### About
        This app uses the FFA-Net (Feature Fusion Attention Network) for single image dehazing.
        The model removes atmospheric haze and fog to restore clear, vibrant images.

        **Tips for best results:**
        - Use good quality images with visible haze or fog
        - Model works best on indoor images

        **Made by Aditya Singh and Ramandeep Singh Makkar**
        """)

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False
    )