rsortino commited on
Commit
e39715e
·
1 Parent(s): 81abe5e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +88 -10
README.md CHANGED
@@ -35,13 +35,7 @@ ColorizeNet is an image colorization model based on ControlNet, trained using th
35
 
36
  - **Repository:** [https://github.com/rensortino/ColorizeNet]
37
  -
38
- ## How to Get Started with the Model
39
-
40
- Use the code below to get started with the model.
41
-
42
- [More Information Needed]
43
-
44
- ## Training Details
45
 
46
  ### Training Data
47
 
@@ -51,6 +45,90 @@ The model has been trained on COCO, using all the images in the dataset and conv
51
 
52
  [https://huggingface.co/datasets/detection-datasets/coco]
53
 
54
- ### Results
55
-
56
- [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  - **Repository:** [https://github.com/rensortino/ColorizeNet]
37
  -
38
+ ## Usage
 
 
 
 
 
 
39
 
40
  ### Training Data
41
 
 
45
 
46
  [https://huggingface.co/datasets/detection-datasets/coco]
47
 
48
+ ### Run the model
49
+
50
+ Instantiate the model and load its configuration and weights
51
+
52
+ ```python
53
+ import random
54
+
55
+ import cv2
56
+ import einops
57
+ import numpy as np
58
+ import torch
59
+ from pytorch_lightning import seed_everything
60
+
61
+ from utils.data import HWC3, apply_color, resize_image
62
+ from utils.ddim import DDIMSampler
63
+ from utils.model import create_model, load_state_dict
64
+
65
+ model = create_model('./models/cldm_v21.yaml').cpu()
66
+ model.load_state_dict(load_state_dict(
67
+ 'lightning_logs/version_6/checkpoints/colorizenet-sd21.ckpt', location='cuda'))
68
+ model = model.cuda()
69
+ ddim_sampler = DDIMSampler(model)
70
+ ```
71
+
72
+ Read the image to be colorized
73
+
74
+ ```python
75
+ input_image = cv2.imread("sample_data/sample1_bw.jpg")
76
+ input_image = HWC3(input_image)
77
+ img = resize_image(input_image, resolution=512)
78
+ H, W, C = img.shape
79
+
80
+ num_samples = 1
81
+ control = torch.from_numpy(img.copy()).float().cuda() / 255.0
82
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
83
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
84
+ ```
85
+
86
+ Prepare the input and parameters of the model
87
+
88
+ ```python
89
+ seed = 1294574436
90
+ seed_everything(seed)
91
+ prompt = "Colorize this image"
92
+ n_prompt = ""
93
+ guess_mode = False
94
+ strength = 1.0
95
+ eta = 0.0
96
+ ddim_steps = 20
97
+ scale = 9.0
98
+
99
+ cond = {"c_concat": [control], "c_crossattn": [
100
+ model.get_learned_conditioning([prompt] * num_samples)]}
101
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [
102
+ model.get_learned_conditioning([n_prompt] * num_samples)]}
103
+ shape = (4, H // 8, W // 8)
104
+
105
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
106
+ [strength] * 13)
107
+ ```
108
+
109
+ Sample and post-process the results
110
+
111
+ ```python
112
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
113
+ shape, cond, verbose=False, eta=eta,
114
+ unconditional_guidance_scale=scale,
115
+ unconditional_conditioning=un_cond)
116
+
117
+ x_samples = model.decode_first_stage(samples)
118
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
119
+ * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
120
+
121
+ results = [x_samples[i] for i in range(num_samples)]
122
+ colored_results = [apply_color(img, result) for result in results]
123
+ ```
124
+
125
+ ## Results
126
+
127
+ BW Input | Colorized
128
+ :-------------------------:|:-------------------------:
129
+ ![image](docs/sample1_bw.png) | ![image](docs/sample1.png)
130
+ ![image](docs/sample2_bw.png) | ![image](docs/sample2.png)
131
+ ![image](docs/sample3_bw.png) | ![image](docs/sample3.png)
132
+ ![image](docs/sample4_bw.png) | ![image](docs/sample4.png)
133
+ ![image](docs/sample5_bw.png) | ![image](docs/sample5.png)
134
+ ![image](docs/sample6_bw.png) | ![image](docs/sample6.png)