ItsJATAYU commited on
Commit
e5285b0
·
verified ·
1 Parent(s): e2b90c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -14
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
- from huggingface_hub import hf_hub_download
3
  import torch
 
4
  from PIL import Image
5
  from torchvision import transforms
6
  from skimage.color import rgb2lab, lab2rgb
@@ -13,7 +13,6 @@ repo_id = "Hammad712/GAN-Colorization-Model"
13
  model_filename = "generator.pt"
14
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
15
 
16
- # Define the generator model (same architecture as used during training)
17
  from fastai.vision.learner import create_body
18
  from torchvision.models import resnet34
19
  from fastai.vision.models.unet import DynamicUnet
@@ -52,21 +51,23 @@ def colorize_image(img, model):
52
  for img in Lab:
53
  img_rgb = lab2rgb(img)
54
  rgb_imgs.append(img_rgb)
55
- return np.stack(rgb_imgs, axis=0)[0] # Return the first (and only) image
56
 
57
- # Gradio interface function
58
- def colorization_function(image):
59
- colorized_image = colorize_image(image, G_net)
60
- return colorized_image
 
61
 
62
- # Gradio Interface Setup
63
  iface = gr.Interface(
64
- fn=colorization_function,
65
- inputs=gr.inputs.Image(type="pil", label="Upload Grayscale Image"),
66
- outputs=gr.outputs.Image(type="numpy", label="Colorized Image"),
67
- title="Image Colorization with GAN",
68
- description="Upload a grayscale image, and the model will colorize it using AI."
 
69
  )
70
 
71
- # Launch the Gradio interface
72
  iface.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ from huggingface_hub import hf_hub_download
4
  from PIL import Image
5
  from torchvision import transforms
6
  from skimage.color import rgb2lab, lab2rgb
 
13
  model_filename = "generator.pt"
14
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
15
 
 
16
  from fastai.vision.learner import create_body
17
  from torchvision.models import resnet34
18
  from fastai.vision.models.unet import DynamicUnet
 
51
  for img in Lab:
52
  img_rgb = lab2rgb(img)
53
  rgb_imgs.append(img_rgb)
54
+ return np.stack(rgb_imgs, axis=0)
55
 
56
+ # Gradio interface
57
+ def colorize(img):
58
+ colorized_images = colorize_image(img, G_net)
59
+ colorized_image = colorized_images[0]
60
+ return Image.fromarray((colorized_image * 255).astype(np.uint8))
61
 
62
+ # Create the Gradio interface
63
  iface = gr.Interface(
64
+ fn=colorize,
65
+ inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
66
+ outputs=gr.Image(type="pil", label="Colorized Image"),
67
+ title="AI Image Colorization",
68
+ description="Upload a black and white image, and the AI will colorize it.",
69
+ allow_flagging="never"
70
  )
71
 
72
+ # Launch the interface
73
  iface.launch()