Adieee5 committed on
Commit
072b5ae
1 Parent(s): 4a07e4c
Files changed (3)
  1. .gitignore +54 -0
  2. app.py +299 -0
  3. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,54 @@
+ # macOS system files
+ .DS_Store
+
+ # Node.js dependencies
+ node_modules/
+
+ # Python
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ env/
+ venv/
+ ENV/
+ .venv/
+
+ # Logs
+ *.log
+
+ # OS generated files
+ Thumbs.db
+ ehthumbs.db
+ Icon?
+ Desktop.ini
+
+ # IDEs and editors
+ .vscode/
+ .idea/
+ *.sublime-workspace
+ *.sublime-project
+
+ # Build output
+ dist/
+ build/
+ out/
+
+ # Temporary files
+ *.tmp
+ *.swp
+ *.bak
+ *.orig
+
+ # Test coverage
+ coverage/
+ *.cover
+ .nyc_output/
+
+ # Environment files
+ .env
+ .env.*
+
+ # Yarn
+ .yarn/
+ .pnp.*
app.py ADDED
@@ -0,0 +1,299 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as tfs
+ import os
+
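+ # FFA-Net building blocks (Feature Fusion Attention Network for single image dehazing)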
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
+     return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
+
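+ # Pixel attention: squeeze the feature map to a single-channel spatial gate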
+ class PALayer(nn.Module):
+     def __init__(self, channel):
+         super(PALayer, self).__init__()
+         self.pa = nn.Sequential(
+             nn.Conv2d(channel, channel // 8, 1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(channel // 8, 1, 1, bias=True),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         y = self.pa(x)
+         return x * y
+
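+ # Channel attention: global average pooling feeding a per-channel bottleneck gate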
+ class CALayer(nn.Module):
+     def __init__(self, channel):
+         super(CALayer, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.ca = nn.Sequential(
+             nn.Conv2d(channel, channel // 8, 1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(channel // 8, channel, 1, bias=True),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         y = self.avg_pool(x)
+         y = self.ca(y)
+         return x * y
+
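+ # Basic block: conv-ReLU-conv with channel and pixel attention and two residual skips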
+ class Block(nn.Module):
+     def __init__(self, conv, dim, kernel_size):
+         super(Block, self).__init__()
+         self.conv1 = conv(dim, dim, kernel_size, bias=True)
+         self.act1 = nn.ReLU(inplace=True)
+         self.conv2 = conv(dim, dim, kernel_size, bias=True)
+         self.calayer = CALayer(dim)
+         self.palayer = PALayer(dim)
+
+     def forward(self, x):
+         res = self.act1(self.conv1(x))
+         res = res + x
+         res = self.conv2(res)
+         res = self.calayer(res)
+         res = self.palayer(res)
+         res += x
+         return res
+
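+ # Group: a chain of Blocks plus a trailing conv, wrapped in a long skip connection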
+ class Group(nn.Module):
+     def __init__(self, conv, dim, kernel_size, blocks):
+         super(Group, self).__init__()
+         modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]
+         modules.append(conv(dim, dim, kernel_size))
+         self.gp = nn.Sequential(*modules)
+
+     def forward(self, x):
+         res = self.gp(x)
+         res += x
+         return res
+
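+ # Full network: three Groups whose outputs are fused by learned per-group channel weights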
+ class FFA(nn.Module):
+     def __init__(self, gps, blocks, conv=default_conv):
+         super(FFA, self).__init__()
+         self.gps = gps
+         self.dim = 64
+         kernel_size = 3
+
+         pre_process = [conv(3, self.dim, kernel_size)]
+         assert self.gps == 3
+         self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
+         self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
+         self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)
+         self.ca = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, bias=True),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, bias=True),
+             nn.Sigmoid()
+         )
+         self.palayer = PALayer(self.dim)
+
+         post_process = [
+             conv(self.dim, self.dim, kernel_size),
+             conv(self.dim, 3, kernel_size)
+         ]
+
+         self.pre = nn.Sequential(*pre_process)
+         self.post = nn.Sequential(*post_process)
+
+     def forward(self, x1):
+         x = self.pre(x1)
+         res1 = self.g1(x)
+         res2 = self.g2(res1)
+         res3 = self.g3(res2)
+         w = self.ca(torch.cat([res1, res2, res3], dim=1))
+         w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
+         out = w[:, 0, :, :, :] * res1 + w[:, 1, :, :, :] * res2 + w[:, 2, :, :, :] * res3
+         out = self.palayer(out)
+         x = self.post(out)
+         return x + x1
+
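+ # Runtime configuration: 3 attention groups of 19 blocks each (the FFA-Net defaults)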
+ MODEL_PATH = 'tti.pk'
+ gps = 3
+ blocks = 19
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ net = FFA(gps=gps, blocks=blocks).to(device)
+ net = torch.nn.DataParallel(net)
+
+ if not os.path.exists(MODEL_PATH):
+     raise FileNotFoundError(f"Model checkpoint not found at {MODEL_PATH}")
+
+ try:
+     # Prefer the safe loader: allowlist the NumPy scalar global so
+     # weights_only=True can deserialize the checkpoint metadata
+     torch.serialization.add_safe_globals([np.core.multiarray.scalar])
+     checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
+ except Exception:
+     # Fall back to the unrestricted loader for older checkpoint formats
+     print("Warning: Loading checkpoint with weights_only=False. Ensure the checkpoint is from a trusted source.")
+     checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
+ net.load_state_dict(checkpoint['model'])
+ net.eval()
+
+ print(f"Model loaded successfully on {device}")
+
+ def dehaze_image(image):
+     """
+     Process a hazy image and return the dehazed result.
+
+     Args:
+         image: PIL Image or numpy array
+
+     Returns:
+         PIL Image: Dehazed image
+     """
+     try:
+         # Gradio may hand over a numpy array; normalize to PIL
+         if isinstance(image, np.ndarray):
+             image = Image.fromarray(image)
+
+         haze_img = image.convert("RGB")
+
+         # Same normalization statistics the model was trained with
+         transform = tfs.Compose([
+             tfs.ToTensor(),
+             tfs.Normalize(mean=[0.64, 0.6, 0.58], std=[0.14, 0.15, 0.152])
+         ])
+
+         haze_tensor = transform(haze_img).unsqueeze(0).to(device)
+
+         # Run inference without tracking gradients
+         with torch.no_grad():
+             pred = net(haze_tensor)
+
+         # Clamp to [0, 1] and convert back to an 8-bit RGB image
+         pred_clamped = pred.clamp(0, 1).cpu()
+         pred_numpy = pred_clamped.squeeze(0).permute(1, 2, 0).numpy()
+         pred_numpy = (pred_numpy * 255).astype(np.uint8)
+
+         return Image.fromarray(pred_numpy)
+
+     except Exception as e:
+         print(f"Error processing image: {str(e)}")
+         return None
+
+ SAMPLE_IMAGES = [
+     "./img/s2.png",
+     "./img/s4.png"
+ ]
+
+ def load_sample_image(sample_path):
+     """Load and return a sample image"""
+     try:
+         if os.path.exists(sample_path):
+             return Image.open(sample_path)
+         else:
+             print(f"Sample image not found: {sample_path}")
+             return None
+     except Exception as e:
+         print(f"Error loading sample image {sample_path}: {e}")
+         return None
+
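+ # Gradio UI: upload pane, clickable sample thumbnails, and the dehazed output pane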
+ def create_interface():
+     with gr.Blocks(title="Image Dehazing App", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🌫️ Image Dehazing with FFA-Net")
+         gr.Markdown("Upload a hazy image to remove fog, haze, and improve visibility!")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(
+                     label="Upload Hazy Image",
+                     type="pil",
+                     height=400
+                 )
+
+                 gr.Markdown("### Try Sample Images")
+                 with gr.Row():
+                     sample1_btn = gr.Image(
+                         value=load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None,
+                         label="Sample 1",
+                         interactive=True,
+                         width=150,
+                         height=150,
+                         container=True,
+                         show_download_button=False
+                     )
+                     sample2_btn = gr.Image(
+                         value=load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None,
+                         label="Sample 2",
+                         interactive=True,
+                         width=150,
+                         height=150,
+                         container=True,
+                         show_download_button=False
+                     )
+
+                 process_btn = gr.Button(
+                     "Remove Haze ✨",
+                     variant="primary",
+                     size="lg"
+                 )
+
+             with gr.Column():
+                 output_image = gr.Image(
+                     label="Dehazed Result",
+                     type="pil",
+                     height=400
+                 )
+
+         # Clicking a sample thumbnail loads it into the input slot
+         def use_sample1():
+             return load_sample_image(SAMPLE_IMAGES[0]) if len(SAMPLE_IMAGES) > 0 else None
+
+         def use_sample2():
+             return load_sample_image(SAMPLE_IMAGES[1]) if len(SAMPLE_IMAGES) > 1 else None
+
+         sample1_btn.select(
+             fn=use_sample1,
+             outputs=input_image
+         )
+
+         sample2_btn.select(
+             fn=use_sample2,
+             outputs=input_image
+         )
+
+         # Explicit button click, exposed as the "dehaze" API endpoint
+         process_btn.click(
+             fn=dehaze_image,
+             inputs=input_image,
+             outputs=output_image,
+             api_name="dehaze"
+         )
+
+         # Also dehaze automatically whenever the input image changes
+         input_image.change(
+             fn=dehaze_image,
+             inputs=input_image,
+             outputs=output_image
+         )
+
+         gr.Markdown("""
+         ### About
+         This app uses FFA-Net (Feature Fusion Attention Network) for single image dehazing.
+         The model removes atmospheric haze and fog to restore clear, vibrant images.
+
+         **Tips for best results:**
+         - Use good-quality images with visible haze or fog
+         - The model works best on indoor images
+
+         **Made by <a href="https://www.linkedin.com/in/aditsg26/">Aditya Singh</a> and <a href="https://www.linkedin.com/in/ramandeep-singh-makkar/">Ramandeep Singh Makkar</a>**
+         """)
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_interface()
+
+     # Listen on all interfaces on the standard Hugging Face Spaces port
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         debug=False
+     )
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ gradio
+ pillow
+ torch
+ torchvision
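Since the click handler registers api_name="dehaze", the deployed Space can also be called programmatically. A minimal sketch using gradio_client; the Space id and input filename below are placeholders, not values from this commit:

    from gradio_client import Client, handle_file

    # Hypothetical Space id; substitute the actual owner/space-name
    client = Client("Adieee5/image-dehazing")
    # Calls the endpoint registered via api_name="dehaze" in app.py
    result = client.predict(handle_file("hazy.jpg"), api_name="/dehaze")
    print(result)  # server-side path of the dehazed image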