import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as tfs
import os
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    # "Same" padding for odd kernel sizes, so spatial dimensions are preserved.
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
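
# Illustrative check (a sketch, not executed by the app): with kernel_size=3 the
# padding of kernel_size // 2 keeps H and W unchanged, e.g.
#   conv = default_conv(3, 64, 3)
#   assert conv(torch.randn(1, 3, 16, 16)).shape == (1, 64, 16, 16)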
class PALayer(nn.Module):
    """Pixel Attention: predicts a single-channel spatial attention map in [0, 1]."""

    def __init__(self, channel):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // 8, 1, 1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.pa(x)  # (N, 1, H, W), broadcast across channels
        return x * y
class CALayer(nn.Module):
    """Channel Attention: global average pooling followed by a 1x1-conv bottleneck."""

    def __init__(self, channel):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // 8, channel, 1, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)  # (N, C, 1, 1) per-channel weights
        y = self.ca(y)
        return x * y
class Block(nn.Module):
    """Basic FFA block: two convs with channel and pixel attention, plus two skips."""

    def __init__(self, conv, dim, kernel_size):
        super(Block, self).__init__()
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        res = self.act1(self.conv1(x))
        res = res + x  # local skip after the first conv
        res = self.conv2(res)
        res = self.calayer(res)
        res = self.palayer(res)
        res += x  # block-level skip
        return res
class Group(nn.Module):
    def __init__(self, conv, dim, kernel_size, blocks):
        super(Group, self).__init__()
        modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]
        modules.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*modules)

    def forward(self, x):
        res = self.gp(x)
        res += x
        return res
class FFA(nn.Module):
    """FFA-Net: three Groups whose outputs are fused with learned attention weights."""

    def __init__(self, gps, blocks, conv=default_conv):
        super(FFA, self).__init__()
        self.gps = gps
        self.dim = 64
        kernel_size = 3
        pre_process = [conv(3, self.dim, kernel_size)]
        assert self.gps == 3
        self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, bias=True),
            nn.Sigmoid()
        )
        self.palayer = PALayer(self.dim)
        post_process = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)
        ]
        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_process)

    def forward(self, x1):
        x = self.pre(x1)
        res1 = self.g1(x)
        res2 = self.g2(res1)
        res3 = self.g3(res2)
        # Learned fusion: per-group channel weights from the pooled concatenation.
        w = self.ca(torch.cat([res1, res2, res3], dim=1))
        w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
        out = w[:, 0, :, :, :] * res1 + w[:, 1, :, :, :] * res2 + w[:, 2, :, :, :] * res3
        out = self.palayer(out)
        x = self.post(out)
        return x + x1  # global residual over the hazy input
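
# Shape sanity check (illustrative sketch, not executed by the app): the network
# is fully convolutional, so the output matches the input size for any RGB input:
#   with torch.no_grad():
#       assert FFA(gps=3, blocks=1)(torch.randn(1, 3, 64, 64)).shape == (1, 3, 64, 64)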
MODEL_PATH = 'tti.pk'
gps = 3
blocks = 19

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = FFA(gps=gps, blocks=blocks).to(device)
# The checkpoint is loaded into a DataParallel wrapper, which implies its
# state-dict keys carry the 'module.' prefix.
net = torch.nn.DataParallel(net)

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found at {MODEL_PATH}")

try:
    # Allowlist the numpy scalar type so the safe (weights_only) loader accepts it.
    torch.serialization.add_safe_globals([np.core.multiarray.scalar])
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
except Exception:
    print("Warning: Loading checkpoint with weights_only=False. "
          "Ensure the checkpoint is from a trusted source.")
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

net.load_state_dict(checkpoint['model'])
net.eval()
print(f"Model loaded successfully on {device}")
def dehaze_image(image):
    """
    Process a hazy image and return the dehazed result.

    Args:
        image: PIL Image or numpy array

    Returns:
        PIL Image: Dehazed image, or None on failure
    """
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        haze_img = image.convert("RGB")
        # Normalization statistics must match those used at training time.
        transform = tfs.Compose([
            tfs.ToTensor(),
            tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])
        ])
        haze_tensor = transform(haze_img).unsqueeze(0).to(device)

        with torch.no_grad():
            pred = net(haze_tensor)

        pred_clamped = pred.clamp(0, 1).cpu()
        pred_numpy = pred_clamped.squeeze(0).permute(1, 2, 0).numpy()
        pred_numpy = (pred_numpy * 255).astype(np.uint8)
        return Image.fromarray(pred_numpy)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return None
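
# Offline usage sketch (hypothetical file names, outside the Gradio flow):
#   result = dehaze_image(Image.open("hazy.png"))
#   if result is not None:
#       result.save("dehazed.png")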
SAMPLE_IMAGES = [
    "./img/s2.png",
    "./img/s4.png"
]
def load_sample_image(sample_path):
    """Load and return a sample image"""
    try:
        if os.path.exists(sample_path):
            return Image.open(sample_path)
        else:
            print(f"Sample image not found: {sample_path}")
            return None
    except Exception as e:
        print(f"Error loading sample image {sample_path}: {e}")
        return None
def create_interface():
    with gr.Blocks(title="Image Dehazing App", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🌫️ Image Dehazing with FFA-Net")
        gr.Markdown("Upload a hazy image to remove fog and haze and improve visibility!")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Hazy Image",
                    type="pil",
                    height=400
                )

                gr.Markdown("### Try Sample Images")
                with gr.Row():
                    sample1_btn = gr.Image(
                        value=load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None,
                        label="Sample 1",
                        interactive=True,
                        width=150,
                        height=150,
                        container=True,
                        show_download_button=False
                    )
                    sample2_btn = gr.Image(
                        value=load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None,
                        label="Sample 2",
                        interactive=True,
                        width=150,
                        height=150,
                        container=True,
                        show_download_button=False
                    )

                process_btn = gr.Button(
                    "Remove Haze ✨",
                    variant="primary",
                    size="lg"
                )

            with gr.Column():
                output_image = gr.Image(
                    label="Dehazed Result",
                    type="pil",
                    height=400
                )
        def use_sample1():
            return load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None

        def use_sample2():
            return load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None

        # Clicking a sample thumbnail copies it into the input slot.
        sample1_btn.select(
            fn=use_sample1,
            outputs=input_image
        )
        sample2_btn.select(
            fn=use_sample2,
            outputs=input_image
        )

        process_btn.click(
            fn=dehaze_image,
            inputs=input_image,
            outputs=output_image,
            api_name="dehaze"
        )

        # Dehaze automatically whenever the input image changes (including samples).
        input_image.change(
            fn=dehaze_image,
            inputs=input_image,
            outputs=output_image
        )
gr.Markdown(""" | |
### About | |
This app uses the FFA-Net (Feature Fusion Attention Network) for single image dehazing. | |
The model removes atmospheric haze and fog to restore clear, vibrant images. | |
**Tips for best results:** | |
- Use good quality images with visible haze or fog | |
- Model works best on indoor images | |
**Made by <a href="https://www.linkedin.com/in/aditsg26/">Aditya Singh</a> and <a href="https://www.linkedin.com/in/ramandeep-singh-makkar/">Ramandeep Singh Makkar</a>** | |
""") | |
return demo | |
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False
    )
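
# Remote-call sketch (assumes the gradio_client package and a reachable server;
# the "/dehaze" endpoint name comes from api_name="dehaze" above):
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860/")
#   result_path = client.predict(handle_file("hazy.png"), api_name="/dehaze")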