Update app.py
Browse files
app.py
CHANGED
@@ -1,46 +1,72 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
-
from
|
4 |
from PIL import Image
|
|
|
5 |
|
6 |
-
# Load
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
)
|
11 |
-
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
29 |
|
30 |
with gr.Blocks() as demo:
|
31 |
-
gr.Markdown("
|
32 |
-
gr.Markdown("Upload a grayscale image. The model will generate a realistic colorized version.")
|
33 |
|
34 |
with gr.Row():
|
35 |
with gr.Column():
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
with gr.Column():
|
40 |
-
|
41 |
-
|
|
|
42 |
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from torchvision import transforms
|
4 |
from PIL import Image
|
5 |
+
from model import Generator # Assuming you are using Hammad712's model structure
|
6 |
|
7 |
+
# Load model
|
8 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
9 |
+
model = Generator().to(device)
|
10 |
+
model.load_state_dict(torch.load('generator.pth', map_location=device))
|
11 |
+
model.eval()
|
12 |
+
|
13 |
+
# Define preprocessing and postprocessing
|
14 |
+
preprocess = transforms.Compose([
|
15 |
+
transforms.Resize((256, 256)),
|
16 |
+
transforms.Grayscale(num_output_channels=1),
|
17 |
+
transforms.ToTensor()
|
18 |
+
])
|
19 |
+
|
20 |
+
postprocess = transforms.ToPILImage()
|
21 |
+
|
22 |
+
def colorize_image(input_image):
|
23 |
+
input_tensor = preprocess(input_image).unsqueeze(0).to(device)
|
24 |
+
with torch.no_grad():
|
25 |
+
output_tensor = model(input_tensor)
|
26 |
+
output_image = postprocess(output_tensor.squeeze(0).cpu().clamp(0, 1))
|
27 |
+
return output_image
|
28 |
+
|
29 |
+
def reset():
|
30 |
+
return None, None
|
31 |
|
32 |
with gr.Blocks() as demo:
|
33 |
+
gr.Markdown("# π¨ Image Colorization App")
|
|
|
34 |
|
35 |
with gr.Row():
|
36 |
with gr.Column():
|
37 |
+
input_image = gr.Image(label="Upload your grayscale image", type="pil")
|
38 |
+
clear_button = gr.Button("π Reset / Clear")
|
39 |
+
download_button = gr.File(label="Download Colorized Image")
|
40 |
with gr.Column():
|
41 |
+
output_image = gr.Image(label="Colorized Image")
|
42 |
+
|
43 |
+
colorize_btn = gr.Button("β¨ Colorize Image")
|
44 |
|
45 |
+
colorize_btn.click(
|
46 |
+
colorize_image,
|
47 |
+
inputs=input_image,
|
48 |
+
outputs=output_image
|
49 |
+
)
|
50 |
+
|
51 |
+
clear_button.click(
|
52 |
+
reset,
|
53 |
+
inputs=[],
|
54 |
+
outputs=[input_image, output_image]
|
55 |
+
)
|
56 |
+
|
57 |
+
# Allow download after processing
|
58 |
+
def prepare_download(image):
|
59 |
+
if image:
|
60 |
+
path = "colorized_output.png"
|
61 |
+
image.save(path)
|
62 |
+
return path
|
63 |
+
else:
|
64 |
+
return None
|
65 |
+
|
66 |
+
output_image.change(
|
67 |
+
prepare_download,
|
68 |
+
inputs=output_image,
|
69 |
+
outputs=download_button
|
70 |
+
)
|
71 |
|
72 |
demo.launch()
|