File size: 1,961 Bytes
147a8af
 
4852587
147a8af
4852587
147a8af
4852587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147a8af
 
4852587
147a8af
 
 
4852587
 
 
147a8af
4852587
 
 
147a8af
4852587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147a8af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from model import Generator  # Assuming you are using Hammad712's model structure

# --- Generator setup -------------------------------------------------------
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

# Instantiate the generator on `device`, restore its trained weights from
# generator.pth (mapped onto that same device), then switch the module to
# evaluation mode for inference.
model = Generator().to(device)
model.load_state_dict(
    torch.load("generator.pth", map_location=device)
)
model.eval()

# Preprocessing: resize to 256x256, collapse to a single grayscale channel and
# convert to a tensor -- this is the form colorize_image feeds to the model.
# Postprocessing converts the model's output tensor back into a PIL image.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
]
preprocess = transforms.Compose(_preprocess_steps)

postprocess = transforms.ToPILImage()

def colorize_image(input_image):
    """Colorize an image with the loaded generator.

    The input is passed through ``preprocess`` (resize to 256x256, grayscale,
    tensor). A batch dimension is added, and the result is run through
    ``model`` without gradient tracking. The output is clamped to [0, 1] and
    converted back to a PIL image with ``postprocess``.

    Args:
        input_image: PIL image from the upload component, or ``None`` when the
            user clicks "Colorize" before uploading anything.

    Returns:
        The colorized PIL image at whatever resolution the model emits
        (presumably 256x256, since the input is resized to that -- confirm
        against the Generator architecture), or ``None`` if no image was given.
    """
    # Gradio passes None for an empty input component. preprocess(None) would
    # raise inside torchvision, so return an empty output instead.
    if input_image is None:
        return None

    # unsqueeze(0) adds the batch dimension the model is called with.
    input_tensor = preprocess(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output_tensor = model(input_tensor)
    # Drop the batch dim, move to CPU and clamp to [0, 1] before ToPILImage.
    return postprocess(output_tensor.squeeze(0).cpu().clamp(0, 1))

def reset():
    """Return empty values for the components cleared by the Reset button.

    Returns:
        A 2-tuple of ``None``: one value for the input image and one for the
        colorized output (the button's outputs are
        ``[input_image, output_image]``).
    """
    return (None,) * 2

with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Image Colorization App")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload your grayscale image", type="pil")
            # The emoji in this label was previously mis-encoded (mojibake).
            clear_button = gr.Button("🔄 Reset / Clear")
            download_button = gr.File(label="Download Colorized Image")
        with gr.Column():
            # type="pil": handlers that read this component (prepare_download)
            # must receive a PIL image. gr.Image defaults to numpy arrays,
            # which have no .save() and raise on truth-testing.
            output_image = gr.Image(label="Colorized Image", type="pil")

    colorize_btn = gr.Button("✨ Colorize Image")

    colorize_btn.click(
        colorize_image,
        inputs=input_image,
        outputs=output_image,
    )

    clear_button.click(
        reset,
        inputs=[],
        outputs=[input_image, output_image],
    )

    def prepare_download(image):
        """Save the colorized image to a fresh PNG file offered for download.

        Args:
            image: PIL image currently shown in ``output_image``, or ``None``
                when that component is empty.

        Returns:
            Path (str) of the written PNG, or ``None`` to clear the download
            component.
        """
        import tempfile
        from pathlib import Path

        if image is None:
            return None
        # A new directory per call keeps concurrent sessions from overwriting
        # one shared file. It also keeps the user-facing file name unchanged.
        out_dir = Path(tempfile.mkdtemp(prefix="colorized_"))
        path = out_dir / "colorized_output.png"
        image.save(path)
        return str(path)

    # Refresh the download file whenever the output image changes (e.g. after
    # colorizing or resetting). An empty output yields None, which clears it.
    output_image.change(
        prepare_download,
        inputs=output_image,
        outputs=download_button,
    )

demo.launch()