import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as tfs
import os
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Create a 2-D convolution padded by half the kernel size.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Integer kernel size; the padding is ``kernel_size // 2``.
        bias: Whether the convolution has a learnable bias.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
class PALayer(nn.Module):
    """Pixel-attention layer.

    A 1x1 conv -> ReLU -> 1x1 conv -> Sigmoid stack reduces the input to a
    single-channel map, which is multiplied back onto the input.
    """

    def __init__(self, channel):
        """Build the attention branch for inputs with ``channel`` channels."""
        super().__init__()
        squeezed = channel // 8
        attention_layers = [
            nn.Conv2d(channel, squeezed, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, 1, 1, bias=True),
            nn.Sigmoid(),
        ]
        # Attribute name and layer order determine the state-dict keys.
        self.pa = nn.Sequential(*attention_layers)

    def forward(self, x):
        """Return ``x`` scaled by its single-channel attention map."""
        return x * self.pa(x)
class CALayer(nn.Module):
    """Channel-attention layer.

    The input is average-pooled to 1x1, passed through a
    1x1 conv -> ReLU -> 1x1 conv -> Sigmoid stack that restores the original
    channel count, and the result is multiplied back onto the input.
    """

    def __init__(self, channel):
        """Build the pooling and gating branch for ``channel`` input channels."""
        super().__init__()
        squeezed = channel // 8
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Conv2d(channel, squeezed, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, bias=True),
            nn.Sigmoid(),
        ]
        # Attribute name and layer order determine the state-dict keys.
        self.ca = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` scaled by its per-channel gate."""
        channel_gate = self.ca(self.avg_pool(x))
        return x * channel_gate
class Block(nn.Module):
    """Residual block: conv -> ReLU -> skip -> conv -> channel/pixel attention -> skip."""

    def __init__(self, conv, dim, kernel_size):
        """Build the block.

        Args:
            conv: Factory called as ``conv(in_ch, out_ch, kernel_size, bias=...)``.
            dim: Channel count kept constant throughout the block.
            kernel_size: Kernel size handed to ``conv``.
        """
        super().__init__()
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor produced by the two skips."""
        features = self.conv1(x)
        features = self.act1(features)
        # First skip connection, before the second convolution.
        features = features + x
        features = self.conv2(features)
        features = self.palayer(self.calayer(features))
        # Second skip connection, added in place onto the attention output.
        features += x
        return features
class Group(nn.Module):
    """A run of ``Block`` modules followed by one convolution, with a skip connection."""

    def __init__(self, conv, dim, kernel_size, blocks):
        """Build the group.

        Args:
            conv: Factory called as ``conv(in_ch, out_ch, kernel_size)``.
            dim: Channel count kept constant throughout the group.
            kernel_size: Kernel size handed to ``conv`` and to every block.
            blocks: Number of ``Block`` modules placed before the final conv.
        """
        super().__init__()
        stages = []
        for _ in range(blocks):
            stages.append(Block(conv, dim, kernel_size))
        stages.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*stages)

    def forward(self, x):
        """Run the stacked stages and add the input back onto the result."""
        out = self.gp(x)
        out += x
        return out
class FFA(nn.Module):
    """Feature-fusion attention network built from three chained ``Group`` modules.

    Pipeline, as implemented in ``forward``:

    1. ``pre``: one conv lifting the 3-channel input to ``dim`` (64) channels.
    2. ``g1`` -> ``g2`` -> ``g3``: three groups applied in sequence.
    3. ``ca``: the three group outputs are concatenated on the channel axis
       and reduced to one sigmoid weight per (group, channel) pair.
    4. The group outputs are summed using those weights, then passed through
       ``palayer`` (pixel attention).
    5. ``post``: two convs bringing the features back to 3 channels.
    6. The network input is added to the result (global residual).
    """

    def __init__(self, gps, blocks, conv=default_conv):
        """Build the network.

        Args:
            gps: Number of groups. Must be 3, because the architecture is
                hard-wired to ``g1``, ``g2`` and ``g3``.
            blocks: Number of ``Block`` modules inside each group.
            conv: Convolution factory called as
                ``conv(in_ch, out_ch, kernel_size)``.

        Raises:
            ValueError: If ``gps`` is not 3.
        """
        super().__init__()
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let a wrong ``gps`` build a fusion layer
        # whose channel count does not match the three groups below.
        if gps != 3:
            raise ValueError(
                f"FFA is hard-wired to exactly 3 groups (g1, g2, g3); got gps={gps!r}"
            )
        self.gps = gps
        self.dim = 64
        kernel_size = 3

        pre_process = [conv(3, self.dim, kernel_size)]

        self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)

        # Fusion attention over the concatenated group outputs
        # (dim * gps channels in, dim * gps sigmoid weights out).
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, bias=True),
            nn.Sigmoid()
        )
        self.palayer = PALayer(self.dim)

        post_process = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)
        ]

        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_process)

    def forward(self, x1):
        """Run the network on ``x1`` and return a tensor of the same shape.

        Args:
            x1: Input batch with 3 channels (required by the first conv and
                by the final residual addition).

        Returns:
            ``post(fused_features) + x1``.
        """
        x = self.pre(x1)
        res1 = self.g1(x)
        res2 = self.g2(res1)
        res3 = self.g3(res2)

        # One weight per (group, channel); trailing singleton axes let the
        # weights broadcast over the spatial dimensions.
        w = self.ca(torch.cat([res1, res2, res3], dim=1))
        w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
        out = w[:, 0, :, :, :] * res1 + w[:, 1, :, :, :] * res2 + w[:, 2, :, :, :] * res3

        out = self.palayer(out)
        x = self.post(out)
        return x + x1
# ---------------------------------------------------------------------------
# Model setup: runs once at import time and leaves `net` / `device` as
# module-level globals used by dehaze_image().
# ---------------------------------------------------------------------------
MODEL_PATH = 'tti.pk'
gps = 3
blocks = 19
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = FFA(gps=gps, blocks=blocks).to(device)
# The network is wrapped before the state dict is loaded.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model
# (keys prefixed with "module.") -- confirm against the training script.
net = torch.nn.DataParallel(net)

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found at {MODEL_PATH}")

try:
    # Preferred path: restricted unpickling, with the NumPy scalar
    # constructor allow-listed.
    torch.serialization.add_safe_globals([np.core.multiarray.scalar])
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
except Exception as exc:
    # `except Exception` (not a bare `except`) so that KeyboardInterrupt and
    # SystemExit are not swallowed and turned into an unsafe load.
    print(f"Safe checkpoint load failed ({type(exc).__name__}: {exc})")
    # SECURITY: weights_only=False performs full pickle deserialization and
    # can execute arbitrary code embedded in the file. Only acceptable for
    # checkpoints from a trusted source.
    print("Warning: Loading checkpoint with weights_only=False. Ensure the checkpoint is from a trusted source.")
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

net.load_state_dict(checkpoint['model'])
net.eval()
print(f"Model loaded successfully on {device}")
def dehaze_image(image):
    """Process a hazy image and return the dehazed result.

    The image is converted to RGB, turned into a tensor, normalized with
    fixed per-channel mean/std values, and passed through the module-level
    ``net`` on ``device``. The prediction is clamped to [0, 1], scaled by 255
    and cast to ``uint8``.

    Args:
        image: PIL Image, numpy array, or ``None``.

    Returns:
        PIL Image with the dehazed result, or ``None`` when ``image`` is
        ``None`` or when processing fails (the error is printed).
    """
    # Nothing to process (e.g. the handler fired with no image supplied).
    # Return quietly instead of failing on `None.convert` and logging a
    # misleading processing error.
    if image is None:
        return None

    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        haze_img = image.convert("RGB")

        transform = tfs.Compose([
            tfs.ToTensor(),
            tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])
        ])
        # unsqueeze(0) adds the batch axis expected by the network.
        haze_tensor = transform(haze_img).unsqueeze(0).to(device)

        with torch.no_grad():
            pred = net(haze_tensor)

        pred_clamped = pred.clamp(0, 1).cpu()
        # (1, C, H, W) -> (H, W, C) for PIL.
        pred_numpy = pred_clamped.squeeze(0).permute(1, 2, 0).numpy()
        pred_numpy = (pred_numpy * 255).astype(np.uint8)
        return Image.fromarray(pred_numpy)
    except Exception as e:
        # Best effort for the UI: report the failure and show no output
        # rather than raising into the event handler.
        print(f"Error processing image: {str(e)}")
        return None
# Paths of the sample images offered in the UI, in display order.
SAMPLE_IMAGES = ["./img/s2.png", "./img/s4.png"]
def load_sample_image(sample_path):
    """Open the sample image at ``sample_path``.

    Returns the opened PIL image, or ``None`` (after printing a message) when
    the path does not exist or opening it raises.
    """
    try:
        if not os.path.exists(sample_path):
            print(f"Sample image not found: {sample_path}")
            return None
        return Image.open(sample_path)
    except Exception as err:
        print(f"Error loading sample image {sample_path}: {err}")
        return None
def create_interface():
    """Build the Gradio Blocks UI for the dehazing demo and return it.

    The page contains an input image, two sample-image thumbnails, a
    "Remove Haze" button, an output image and an "About" section.
    ``dehaze_image`` is wired to both the button click and the input image's
    ``change`` event; selecting a sample thumbnail copies that sample into
    the input image.

    Returns:
        The ``gr.Blocks`` instance. It is not launched here.
    """
    with gr.Blocks(title="Image Dehazing App", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🌫️ Image Dehazing with FFA-Net")
        gr.Markdown("Upload a hazy image to remove fog, haze, and improve visibility!")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Hazy Image",
                    type="pil",
                    height=400
                )
                gr.Markdown("### Try Sample Images")
                with gr.Row():
                    # The samples are image components (not buttons); their
                    # `select` event is used below to act as a click target.
                    sample1_btn = gr.Image(
                        value=load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None,
                        label="Sample 1",
                        interactive=True,
                        width=150,
                        height=150,
                        container=True,
                        show_download_button=False
                    )
                    sample2_btn = gr.Image(
                        value=load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None,
                        label="Sample 2",
                        interactive=True,
                        width=150,
                        height=150,
                        container=True,
                        show_download_button=False
                    )
                process_btn = gr.Button(
                    "Remove Haze ✨",
                    variant="primary",
                    size="lg"
                )
            with gr.Column():
                output_image = gr.Image(
                    label="Dehazed Result",
                    type="pil",
                    height=400
                )

        def use_sample1():
            """Reload the first sample from disk, or return None if it is not configured."""
            return load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None

        def use_sample2():
            """Reload the second sample from disk, or return None if it is not configured."""
            return load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None

        # Selecting a thumbnail writes the sample into the main input image.
        sample1_btn.select(
            fn=use_sample1,
            outputs=input_image
        )
        sample2_btn.select(
            fn=use_sample2,
            outputs=input_image
        )
        # Explicit trigger; this is the only handler given an api_name.
        process_btn.click(
            fn=dehaze_image,
            inputs=input_image,
            outputs=output_image,
            api_name="dehaze"
        )
        # Also run the model whenever the input image changes, which includes
        # the updates made by the sample handlers above.
        input_image.change(
            fn=dehaze_image,
            inputs=input_image,
            outputs=output_image
        )
        gr.Markdown("""
### About
This app uses the FFA-Net (Feature Fusion Attention Network) for single image dehazing.
The model removes atmospheric haze and fog to restore clear, vibrant images.
**Tips for best results:**
- Use good quality images with visible haze or fog
- Model works best on indoor images
**Made by Aditya Singh and Ramandeep Singh Makkar**
""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on port 7860, bound to
    # all network interfaces, without a public share link or debug mode.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)