ItsJATAYU commited on
Commit
147a8af
·
verified ·
1 Parent(s): 0cd0942

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
4
+ from PIL import Image
5
+
6
+ # Load ControlNet model
7
+ controlnet = ControlNetModel.from_pretrained(
8
+ "rsortino/ColorizeNet",
9
+ torch_dtype=torch.float16,
10
+ use_safetensors=True
11
+ )
12
+
13
+ # Load the pipeline
14
+ pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
15
+ "stabilityai/stable-diffusion-2-1",
16
+ controlnet=controlnet,
17
+ torch_dtype=torch.float16
18
+ )
19
+
20
+ # Move to CUDA if available
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ pipe = pipe.to(device)
23
+
24
+ # Disable safety checker
25
+ pipe.safety_checker = lambda images, **kwargs: (images, False)
26
+
27
+ def colorize(image: Image.Image) -> Image.Image:
28
+ image = image.convert("RGB").resize((512, 512))
29
+ result = pipe(
30
+ prompt="A realistic colorized version of this image.",
31
+ image=image,
32
+ control_image=image,
33
+ strength=1.0,
34
+ guidance_scale=9.0,
35
+ num_inference_steps=30
36
+ )
37
+ return result.images[0]
38
+
39
+ with gr.Blocks() as demo:
40
+ gr.Markdown("## 🎨 ColorizeNet - Grayscale to Color Image")
41
+ gr.Markdown("Upload a grayscale image. The model will generate a realistic colorized version.")
42
+
43
+ with gr.Row():
44
+ with gr.Column():
45
+ input_img = gr.Image(label="Grayscale Input", type="pil")
46
+ submit_btn = gr.Button("Colorize")
47
+
48
+ with gr.Column():
49
+ output_img = gr.Image(label="Colorized Output", type="pil")
50
+ download_btn = gr.Button("Download")
51
+
52
+ def handle_colorize(img):
53
+ return colorize(img)
54
+
55
+ def download_image(img):
56
+ return img
57
+
58
+ submit_btn.click(fn=handle_colorize, inputs=input_img, outputs=output_img)
59
+ download_btn.click(fn=download_image, inputs=output_img, outputs=gr.File())
60
+
61
+ demo.launch()