Spaces:
Running
Running
1w33
Browse files- .gitignore +54 -0
- app.py +299 -0
- requirements.txt +4 -0
.gitignore
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# macOS system files
|
2 |
+
.DS_Store
|
3 |
+
|
4 |
+
# Node.js dependencies
|
5 |
+
node_modules/
|
6 |
+
|
7 |
+
# Python
|
8 |
+
__pycache__/
|
9 |
+
*.pyc
|
10 |
+
*.pyo
|
11 |
+
*.pyd
|
12 |
+
env/
|
13 |
+
venv/
|
14 |
+
ENV/
|
15 |
+
.venv/
|
16 |
+
|
17 |
+
# Logs
|
18 |
+
*.log
|
19 |
+
|
20 |
+
# OS generated files
|
21 |
+
Thumbs.db
|
22 |
+
ehthumbs.db
|
23 |
+
Icon?
|
24 |
+
Desktop.ini
|
25 |
+
|
26 |
+
# IDEs and editors
|
27 |
+
.vscode/
|
28 |
+
.idea/
|
29 |
+
*.sublime-workspace
|
30 |
+
*.sublime-project
|
31 |
+
|
32 |
+
# Build output
|
33 |
+
dist/
|
34 |
+
build/
|
35 |
+
out/
|
36 |
+
|
37 |
+
# Temporary files
|
38 |
+
*.tmp
|
39 |
+
*.swp
|
40 |
+
*.bak
|
41 |
+
*.orig
|
42 |
+
|
43 |
+
# Test coverage
|
44 |
+
coverage/
|
45 |
+
*.cover
|
46 |
+
.nyc_output/
|
47 |
+
|
48 |
+
# Environment files
|
49 |
+
.env
|
50 |
+
.env.*
|
51 |
+
|
52 |
+
# Yarn
|
53 |
+
.yarn/
|
54 |
+
.pnp.*
|
app.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import torchvision.transforms as tfs
|
8 |
+
import os
|
9 |
+
|
10 |
+
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
11 |
+
return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
|
12 |
+
|
13 |
+
class PALayer(nn.Module):
|
14 |
+
def __init__(self, channel):
|
15 |
+
super(PALayer, self).__init__()
|
16 |
+
self.pa = nn.Sequential(
|
17 |
+
nn.Conv2d(channel, channel // 8, 1, bias=True),
|
18 |
+
nn.ReLU(inplace=True),
|
19 |
+
nn.Conv2d(channel // 8, 1, 1, bias=True),
|
20 |
+
nn.Sigmoid()
|
21 |
+
)
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
y = self.pa(x)
|
25 |
+
return x * y
|
26 |
+
|
27 |
+
class CALayer(nn.Module):
|
28 |
+
def __init__(self, channel):
|
29 |
+
super(CALayer, self).__init__()
|
30 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
31 |
+
self.ca = nn.Sequential(
|
32 |
+
nn.Conv2d(channel, channel // 8, 1, bias=True),
|
33 |
+
nn.ReLU(inplace=True),
|
34 |
+
nn.Conv2d(channel // 8, channel, 1, bias=True),
|
35 |
+
nn.Sigmoid()
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
y = self.avg_pool(x)
|
40 |
+
y = self.ca(y)
|
41 |
+
return x * y
|
42 |
+
|
43 |
+
class Block(nn.Module):
|
44 |
+
def __init__(self, conv, dim, kernel_size):
|
45 |
+
super(Block, self).__init__()
|
46 |
+
self.conv1 = conv(dim, dim, kernel_size, bias=True)
|
47 |
+
self.act1 = nn.ReLU(inplace=True)
|
48 |
+
self.conv2 = conv(dim, dim, kernel_size, bias=True)
|
49 |
+
self.calayer = CALayer(dim)
|
50 |
+
self.palayer = PALayer(dim)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
res = self.act1(self.conv1(x))
|
54 |
+
res = res + x
|
55 |
+
res = self.conv2(res)
|
56 |
+
res = self.calayer(res)
|
57 |
+
res = self.palayer(res)
|
58 |
+
res += x
|
59 |
+
return res
|
60 |
+
|
61 |
+
class Group(nn.Module):
|
62 |
+
def __init__(self, conv, dim, kernel_size, blocks):
|
63 |
+
super(Group, self).__init__()
|
64 |
+
modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]
|
65 |
+
modules.append(conv(dim, dim, kernel_size))
|
66 |
+
self.gp = nn.Sequential(*modules)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
res = self.gp(x)
|
70 |
+
res += x
|
71 |
+
return res
|
72 |
+
|
73 |
+
class FFA(nn.Module):
|
74 |
+
def __init__(self, gps, blocks, conv=default_conv):
|
75 |
+
super(FFA, self).__init__()
|
76 |
+
self.gps = gps
|
77 |
+
self.dim = 64
|
78 |
+
kernel_size = 3
|
79 |
+
|
80 |
+
pre_process = [conv(3, self.dim, kernel_size)]
|
81 |
+
assert self.gps == 3
|
82 |
+
self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
|
83 |
+
self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
|
84 |
+
self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)
|
85 |
+
self.ca = nn.Sequential(
|
86 |
+
nn.AdaptiveAvgPool2d(1),
|
87 |
+
nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, bias=True),
|
88 |
+
nn.ReLU(inplace=True),
|
89 |
+
nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, bias=True),
|
90 |
+
nn.Sigmoid()
|
91 |
+
)
|
92 |
+
self.palayer = PALayer(self.dim)
|
93 |
+
|
94 |
+
post_process = [
|
95 |
+
conv(self.dim, self.dim, kernel_size),
|
96 |
+
conv(self.dim, 3, kernel_size)
|
97 |
+
]
|
98 |
+
|
99 |
+
self.pre = nn.Sequential(*pre_process)
|
100 |
+
self.post = nn.Sequential(*post_process)
|
101 |
+
|
102 |
+
def forward(self, x1):
|
103 |
+
x = self.pre(x1)
|
104 |
+
res1 = self.g1(x)
|
105 |
+
res2 = self.g2(res1)
|
106 |
+
res3 = self.g3(res2)
|
107 |
+
w = self.ca(torch.cat([res1, res2, res3], dim=1))
|
108 |
+
w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
|
109 |
+
out = w[:, 0, :, :, :] * res1 + w[:, 1, :, :, :] * res2 + w[:, 2, :, :, :] * res3
|
110 |
+
out = self.palayer(out)
|
111 |
+
x = self.post(out)
|
112 |
+
return x + x1
|
113 |
+
|
114 |
+
MODEL_PATH = 'tti.pk'
|
115 |
+
gps = 3
|
116 |
+
blocks = 19
|
117 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
118 |
+
|
119 |
+
net = FFA(gps=gps, blocks=blocks).to(device)
|
120 |
+
net = torch.nn.DataParallel(net)
|
121 |
+
|
122 |
+
if not os.path.exists(MODEL_PATH):
|
123 |
+
raise FileNotFoundError(f"Model checkpoint not found at {MODEL_PATH}")
|
124 |
+
|
125 |
+
try:
|
126 |
+
|
127 |
+
torch.serialization.add_safe_globals([np.core.multiarray.scalar])
|
128 |
+
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
|
129 |
+
except:
|
130 |
+
|
131 |
+
print("Warning: Loading checkpoint with weights_only=False. Ensure the checkpoint is from a trusted source.")
|
132 |
+
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
|
133 |
+
net.load_state_dict(checkpoint['model'])
|
134 |
+
net.eval()
|
135 |
+
|
136 |
+
print(f"Model loaded successfully on {device}")
|
137 |
+
|
138 |
+
def dehaze_image(image):
|
139 |
+
"""
|
140 |
+
Process a hazy image and return the dehazed result.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
image: PIL Image or numpy array
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
PIL Image: Dehazed image
|
147 |
+
"""
|
148 |
+
try:
|
149 |
+
|
150 |
+
if isinstance(image, np.ndarray):
|
151 |
+
image = Image.fromarray(image)
|
152 |
+
|
153 |
+
|
154 |
+
haze_img = image.convert("RGB")
|
155 |
+
|
156 |
+
|
157 |
+
transform = tfs.Compose([
|
158 |
+
tfs.ToTensor(),
|
159 |
+
tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])
|
160 |
+
])
|
161 |
+
|
162 |
+
haze_tensor = transform(haze_img).unsqueeze(0).to(device)
|
163 |
+
|
164 |
+
|
165 |
+
with torch.no_grad():
|
166 |
+
pred = net(haze_tensor)
|
167 |
+
|
168 |
+
|
169 |
+
pred_clamped = pred.clamp(0, 1).cpu()
|
170 |
+
pred_numpy = pred_clamped.squeeze(0).permute(1, 2, 0).numpy()
|
171 |
+
pred_numpy = (pred_numpy * 255).astype(np.uint8)
|
172 |
+
|
173 |
+
return Image.fromarray(pred_numpy)
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
print(f"Error processing image: {str(e)}")
|
177 |
+
return None
|
178 |
+
|
179 |
+
SAMPLE_IMAGES = [
|
180 |
+
"./img/s2.png",
|
181 |
+
"./img/s4.png"
|
182 |
+
]
|
183 |
+
|
184 |
+
def load_sample_image(sample_path):
|
185 |
+
"""Load and return a sample image"""
|
186 |
+
try:
|
187 |
+
if os.path.exists(sample_path):
|
188 |
+
return Image.open(sample_path)
|
189 |
+
else:
|
190 |
+
print(f"Sample image not found: {sample_path}")
|
191 |
+
return None
|
192 |
+
except Exception as e:
|
193 |
+
print(f"Error loading sample image {sample_path}: {e}")
|
194 |
+
return None
|
195 |
+
|
196 |
+
def create_interface():
|
197 |
+
with gr.Blocks(title="Image Dehazing App", theme=gr.themes.Soft()) as demo:
|
198 |
+
gr.Markdown("# 🌫️ Image Dehazing with FFA-Net")
|
199 |
+
gr.Markdown("Upload a hazy image to remove fog, haze, and improve visibility!")
|
200 |
+
|
201 |
+
with gr.Row():
|
202 |
+
with gr.Column():
|
203 |
+
input_image = gr.Image(
|
204 |
+
label="Upload Hazy Image",
|
205 |
+
type="pil",
|
206 |
+
height=400
|
207 |
+
)
|
208 |
+
|
209 |
+
|
210 |
+
gr.Markdown("### Try Sample Images")
|
211 |
+
with gr.Row():
|
212 |
+
sample1_btn = gr.Image(
|
213 |
+
value=load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None,
|
214 |
+
label="Sample 1",
|
215 |
+
interactive=True,
|
216 |
+
width=150,
|
217 |
+
height=150,
|
218 |
+
container=True,
|
219 |
+
show_download_button=False
|
220 |
+
)
|
221 |
+
sample2_btn = gr.Image(
|
222 |
+
value=load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None,
|
223 |
+
label="Sample 2",
|
224 |
+
interactive=True,
|
225 |
+
width=150,
|
226 |
+
height=150,
|
227 |
+
container=True,
|
228 |
+
show_download_button=False
|
229 |
+
)
|
230 |
+
|
231 |
+
process_btn = gr.Button(
|
232 |
+
"Remove Haze ✨",
|
233 |
+
variant="primary",
|
234 |
+
size="lg"
|
235 |
+
)
|
236 |
+
|
237 |
+
with gr.Column():
|
238 |
+
output_image = gr.Image(
|
239 |
+
label="Dehazed Result",
|
240 |
+
type="pil",
|
241 |
+
height=400
|
242 |
+
)
|
243 |
+
|
244 |
+
|
245 |
+
def use_sample1():
|
246 |
+
return load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None
|
247 |
+
|
248 |
+
def use_sample2():
|
249 |
+
return load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None
|
250 |
+
|
251 |
+
sample1_btn.select(
|
252 |
+
fn=use_sample1,
|
253 |
+
outputs=input_image
|
254 |
+
)
|
255 |
+
|
256 |
+
sample2_btn.select(
|
257 |
+
fn=use_sample2,
|
258 |
+
outputs=input_image
|
259 |
+
)
|
260 |
+
|
261 |
+
|
262 |
+
process_btn.click(
|
263 |
+
fn=dehaze_image,
|
264 |
+
inputs=input_image,
|
265 |
+
outputs=output_image,
|
266 |
+
api_name="dehaze"
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
input_image.change(
|
271 |
+
fn=dehaze_image,
|
272 |
+
inputs=input_image,
|
273 |
+
outputs=output_image
|
274 |
+
)
|
275 |
+
|
276 |
+
gr.Markdown("""
|
277 |
+
### About
|
278 |
+
This app uses the FFA-Net (Feature Fusion Attention Network) for single image dehazing.
|
279 |
+
The model removes atmospheric haze and fog to restore clear, vibrant images.
|
280 |
+
|
281 |
+
**Tips for best results:**
|
282 |
+
- Use good quality images with visible haze or fog
|
283 |
+
- Model works best on indoor images
|
284 |
+
**Made by <a href="https://www.linkedin.com/in/aditsg26/">Aditya Singh</a> and <a href="https://www.linkedin.com/in/ramandeep-singh-makkar/">Ramandeep Singh Makkar</a>**
|
285 |
+
""")
|
286 |
+
|
287 |
+
return demo
|
288 |
+
|
289 |
+
if __name__ == "__main__":
|
290 |
+
|
291 |
+
demo = create_interface()
|
292 |
+
|
293 |
+
|
294 |
+
demo.launch(
|
295 |
+
server_name="0.0.0.0",
|
296 |
+
server_port=7860,
|
297 |
+
share=False,
|
298 |
+
debug=False
|
299 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
pillow
|
3 |
+
torch
|
4 |
+
torchvision
|