Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import gradio as gr
|
2 |
import torch
|
3 |
import torchvision.transforms as transforms
|
4 |
import numpy as np
|
@@ -10,46 +10,51 @@ from model.flol import create_model
|
|
10 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
11 |
#define some auxiliary functions
|
12 |
pil_to_tensor = transforms.ToTensor()
|
13 |
-
|
14 |
-
# define some parameters based on the run we want to make
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
model = create_model()
|
17 |
-
checkpoints = torch.load('./weights/flolv2_UHDLL.pt', map_location=device)
|
18 |
-
model.load_state_dict(checkpoints['params'])
|
19 |
-
model = model.to(device)
|
20 |
|
21 |
-
def load_img
|
22 |
img = Image.open(filename).convert("RGB")
|
23 |
img_tensor = pil_to_tensor(img)
|
24 |
return img_tensor
|
25 |
|
26 |
-
def process_img(image):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
img = np.array(image)
|
28 |
-
img = img / 255.
|
29 |
img = img.astype(np.float32)
|
30 |
-
y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
|
31 |
|
32 |
with torch.no_grad():
|
33 |
x_hat = model(y)
|
34 |
|
35 |
-
restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
|
36 |
-
restored_img = np.clip(restored_img, 0
|
|
|
|
|
|
|
37 |
|
38 |
-
restored_img = (restored_img * 255.0).round().astype(np.uint8) # float32 to uint8
|
39 |
-
return Image.fromarray(restored_img) #(image, Image.fromarray(restored_img))
|
40 |
|
41 |
title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
|
42 |
description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
|
43 |
-
|
44 |
[Juan Carlos Benito](https://github.com/juaben)
|
45 |
-
|
46 |
Fundación Cidaut
|
47 |
-
|
48 |
-
|
49 |
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
|
50 |
**This demo expects an image with some degradations.**
|
51 |
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
|
52 |
-
|
53 |
<br>
|
54 |
'''
|
55 |
|
@@ -69,15 +74,16 @@ css = """
|
|
69 |
"""
|
70 |
|
71 |
demo = gr.Interface(
|
72 |
-
fn
|
73 |
-
inputs
|
74 |
-
|
|
|
75 |
],
|
76 |
-
outputs
|
77 |
-
title
|
78 |
-
description
|
79 |
-
examples
|
80 |
-
css
|
81 |
)
|
82 |
|
83 |
if __name__ == '__main__':
|
|
|
1 |
+
import gradio as gr
|
2 |
import torch
|
3 |
import torchvision.transforms as transforms
|
4 |
import numpy as np
|
|
|
10 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
11 |
#define some auxiliary functions
|
12 |
pil_to_tensor = transforms.ToTensor()
|
|
|
|
|
13 |
|
14 |
+
# Define a list of available weight files
|
15 |
+
weights_files = {
|
16 |
+
"flolv2_UHDLL": './weights/flolv2_UHDLL.pt',
|
17 |
+
"flolv2_other": './weights/flolv2_other.pt', # Add other available weights here
|
18 |
+
}
|
19 |
+
|
20 |
+
# Initial model setup (without weights)
|
21 |
model = create_model()
|
|
|
|
|
|
|
22 |
|
23 |
+
def load_img(filename):
|
24 |
img = Image.open(filename).convert("RGB")
|
25 |
img_tensor = pil_to_tensor(img)
|
26 |
return img_tensor
|
27 |
|
28 |
+
def process_img(image, weight_file):
|
29 |
+
# Load the selected weight file
|
30 |
+
if weight_file in weights_files:
|
31 |
+
model_path = weights_files[weight_file]
|
32 |
+
checkpoints = torch.load(model_path, map_location=device)
|
33 |
+
model.load_state_dict(checkpoints['params'])
|
34 |
+
model.to(device)
|
35 |
+
|
36 |
img = np.array(image)
|
37 |
+
img = img / 255. # Normalize to [0, 1]
|
38 |
img = img.astype(np.float32)
|
39 |
+
y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)
|
40 |
|
41 |
with torch.no_grad():
|
42 |
x_hat = model(y)
|
43 |
|
44 |
+
restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
|
45 |
+
restored_img = np.clip(restored_img, 0., 1.)
|
46 |
+
|
47 |
+
restored_img = (restored_img * 255.0).round().astype(np.uint8) # Convert to uint8
|
48 |
+
return Image.fromarray(restored_img)
|
49 |
|
|
|
|
|
50 |
|
51 |
title = "Efficient Low-Light Enhancement ✏️🖼️ 🤗"
|
52 |
description = ''' ## [Efficient Low-Light Enhancement](https://github.com/cidautai/NAFourNet)
|
|
|
53 |
[Juan Carlos Benito](https://github.com/juaben)
|
|
|
54 |
Fundación Cidaut
|
|
|
|
|
55 |
> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.
|
56 |
**This demo expects an image with some degradations.**
|
57 |
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).
|
|
|
58 |
<br>
|
59 |
'''
|
60 |
|
|
|
74 |
"""
|
75 |
|
76 |
demo = gr.Interface(
|
77 |
+
fn=process_img,
|
78 |
+
inputs=[
|
79 |
+
gr.Image(type='pil', label='input'),
|
80 |
+
gr.Dropdown(choices=list(weights_files.keys()), label='Select Weight File', default="flolv2_UHDLL")
|
81 |
],
|
82 |
+
outputs=[gr.Image(type='pil', label='output')],
|
83 |
+
title=title,
|
84 |
+
description=description,
|
85 |
+
examples=examples,
|
86 |
+
css=css
|
87 |
)
|
88 |
|
89 |
if __name__ == '__main__':
|