ItsJATAYU commited on
Commit
4820090
·
verified ·
1 Parent(s): 50039ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -53
app.py CHANGED
@@ -1,55 +1,72 @@
1
  import gradio as gr
2
- from diffusers import StableDiffusionPipeline
3
  import torch
4
- from transformers import pipeline
5
-
6
- # Use an available colorization pipeline from Hugging Face
7
- colorizer = pipeline("image-to-image", model="HuggingFaceM4/DeOldify", device_map="auto")
8
-
9
- def colorize_image(input_image):
10
- output = colorizer(input_image)
11
- return output[0]['image']
12
-
13
- def reset():
14
- return None, None
15
-
16
- with gr.Blocks() as demo:
17
- gr.Markdown("# 🎨 Image Colorization App")
18
-
19
- with gr.Row():
20
- with gr.Column():
21
- input_image = gr.Image(label="Upload your grayscale image", type="pil")
22
- clear_button = gr.Button("🔄 Reset / Clear")
23
- download_button = gr.File(label="Download Colorized Image")
24
- with gr.Column():
25
- output_image = gr.Image(label="Colorized Image")
26
-
27
- colorize_btn = gr.Button("✨ Colorize Image")
28
-
29
- colorize_btn.click(
30
- colorize_image,
31
- inputs=input_image,
32
- outputs=output_image
33
- )
34
-
35
- clear_button.click(
36
- reset,
37
- inputs=[],
38
- outputs=[input_image, output_image]
39
- )
40
-
41
- def prepare_download(image):
42
- if image:
43
- path = "colorized_output.png"
44
- image.save(path)
45
- return path
46
- else:
47
- return None
48
-
49
- output_image.change(
50
- prepare_download,
51
- inputs=output_image,
52
- outputs=download_button
53
- )
54
-
55
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
  import torch
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ from skimage.color import rgb2lab, lab2rgb
7
+ import numpy as np
8
+ import requests
9
+ from io import BytesIO
10
+
11
+ # Download the model from Hugging Face Hub
12
+ repo_id = "Hammad712/GAN-Colorization-Model"
13
+ model_filename = "generator.pt"
14
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
15
+
16
+ # Define the generator model (same architecture as used during training)
17
+ from fastai.vision.learner import create_body
18
+ from torchvision.models import resnet34
19
+ from fastai.vision.models.unet import DynamicUnet
20
+
21
+ def build_generator(n_input=1, n_output=2, size=256):
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ backbone = create_body(resnet34(), pretrained=True, n_in=n_input, cut=-2)
24
+ G_net = DynamicUnet(backbone, n_output, (size, size)).to(device)
25
+ return G_net
26
+
27
+ # Initialize and load the model
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ G_net = build_generator(n_input=1, n_output=2, size=256)
30
+ G_net.load_state_dict(torch.load(model_path, map_location=device))
31
+ G_net.eval()
32
+
33
+ # Preprocessing function
34
+ def preprocess_image(img):
35
+ img = img.convert("RGB")
36
+ img = transforms.Resize((256, 256), Image.BICUBIC)(img)
37
+ img = np.array(img)
38
+ img_to_lab = rgb2lab(img).astype("float32")
39
+ img_to_lab = transforms.ToTensor()(img_to_lab)
40
+ L = img_to_lab[[0], ...] / 50. - 1.
41
+ return L.unsqueeze(0).to(device)
42
+
43
+ # Inference function
44
+ def colorize_image(img, model):
45
+ L = preprocess_image(img)
46
+ with torch.no_grad():
47
+ ab = model(L)
48
+ L = (L + 1.) * 50.
49
+ ab = ab * 110.
50
+ Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
51
+ rgb_imgs = []
52
+ for img in Lab:
53
+ img_rgb = lab2rgb(img)
54
+ rgb_imgs.append(img_rgb)
55
+ return np.stack(rgb_imgs, axis=0)[0] # Return the first (and only) image
56
+
57
+ # Gradio interface function
58
+ def colorization_function(image):
59
+ colorized_image = colorize_image(image, G_net)
60
+ return colorized_image
61
+
62
+ # Gradio Interface Setup
63
+ iface = gr.Interface(
64
+ fn=colorization_function,
65
+ inputs=gr.inputs.Image(type="pil", label="Upload Grayscale Image"),
66
+ outputs=gr.outputs.Image(type="numpy", label="Colorized Image"),
67
+ title="Image Colorization with GAN",
68
+ description="Upload a grayscale image, and the model will colorize it using AI."
69
+ )
70
+
71
+ # Launch the Gradio interface
72
+ iface.launch()