ItsJATAYU commited on
Commit
4852587
Β·
verified Β·
1 Parent(s): 0bd319a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -33
app.py CHANGED
@@ -1,46 +1,72 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionImg2ImgPipeline
4
  from PIL import Image
 
5
 
6
- # Load the model
7
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
8
- "rsortino/ColorizeNet", # YOUR MODEL
9
- torch_dtype=torch.float16
10
- )
11
-
12
- # Move to CUDA if available
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
- pipe = pipe.to(device)
15
-
16
- # Disable safety checker
17
- pipe.safety_checker = lambda images, **kwargs: (images, False)
18
-
19
- def colorize(image: Image.Image) -> Image.Image:
20
- image = image.convert("RGB").resize((512, 512))
21
- result = pipe(
22
- prompt="A realistic colorized version of this image.",
23
- image=image,
24
- strength=0.8,
25
- guidance_scale=7.5,
26
- num_inference_steps=30
27
- )
28
- return result.images[0]
 
29
 
30
  with gr.Blocks() as demo:
31
- gr.Markdown("## 🎨 ColorizeNet - Grayscale to Color Image")
32
- gr.Markdown("Upload a grayscale image. The model will generate a realistic colorized version.")
33
 
34
  with gr.Row():
35
  with gr.Column():
36
- input_img = gr.Image(label="Grayscale Input", type="pil")
37
- submit_btn = gr.Button("Colorize")
38
-
39
  with gr.Column():
40
- output_img = gr.Image(label="Colorized Output", type="pil")
41
- download_btn = gr.Button("Download")
 
42
 
43
- submit_btn.click(fn=colorize, inputs=input_img, outputs=output_img)
44
- download_btn.click(fn=lambda img: img, inputs=output_img, outputs=gr.File())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from torchvision import transforms
4
  from PIL import Image
5
+ from model import Generator # Assuming you are using Hammad712's model structure
6
 
7
+ # Load model
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+ model = Generator().to(device)
10
+ model.load_state_dict(torch.load('generator.pth', map_location=device))
11
+ model.eval()
12
+
13
+ # Define preprocessing and postprocessing
14
+ preprocess = transforms.Compose([
15
+ transforms.Resize((256, 256)),
16
+ transforms.Grayscale(num_output_channels=1),
17
+ transforms.ToTensor()
18
+ ])
19
+
20
+ postprocess = transforms.ToPILImage()
21
+
22
+ def colorize_image(input_image):
23
+ input_tensor = preprocess(input_image).unsqueeze(0).to(device)
24
+ with torch.no_grad():
25
+ output_tensor = model(input_tensor)
26
+ output_image = postprocess(output_tensor.squeeze(0).cpu().clamp(0, 1))
27
+ return output_image
28
+
29
+ def reset():
30
+ return None, None
31
 
32
  with gr.Blocks() as demo:
33
+ gr.Markdown("# 🎨 Image Colorization App")
 
34
 
35
  with gr.Row():
36
  with gr.Column():
37
+ input_image = gr.Image(label="Upload your grayscale image", type="pil")
38
+ clear_button = gr.Button("πŸ”„ Reset / Clear")
39
+ download_button = gr.File(label="Download Colorized Image")
40
  with gr.Column():
41
+ output_image = gr.Image(label="Colorized Image")
42
+
43
+ colorize_btn = gr.Button("✨ Colorize Image")
44
 
45
+ colorize_btn.click(
46
+ colorize_image,
47
+ inputs=input_image,
48
+ outputs=output_image
49
+ )
50
+
51
+ clear_button.click(
52
+ reset,
53
+ inputs=[],
54
+ outputs=[input_image, output_image]
55
+ )
56
+
57
+ # Allow download after processing
58
+ def prepare_download(image):
59
+ if image:
60
+ path = "colorized_output.png"
61
+ image.save(path)
62
+ return path
63
+ else:
64
+ return None
65
+
66
+ output_image.change(
67
+ prepare_download,
68
+ inputs=output_image,
69
+ outputs=download_button
70
+ )
71
 
72
  demo.launch()