JunhaoZhuang commited on
Commit
2aa0b4c
·
verified ·
1 Parent(s): b3ef14a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -6
app.py CHANGED
@@ -152,6 +152,22 @@ examples = [
152
  0,
153
  10
154
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  [
156
  "./assets/example_1/input.jpg",
157
  ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
@@ -177,9 +193,13 @@ global MultiResNetModel
177
  def load_ckpt(input_style):
178
  global pipeline
179
  global MultiResNetModel
180
- if input_style == "Sketch":
181
- ckpt_path = model_global_path + '/sketch/'
182
- rank = 128
 
 
 
 
183
  pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
184
  transformer = PixArtTransformer2DModel.from_pretrained(
185
  pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
@@ -345,6 +365,17 @@ def extract_line_image(query_image_, input_style, resolution):
345
  extracted_line = extract_lines(query_image)
346
  extracted_line = extracted_line.convert('L').convert('RGB')
347
  input_context = extracted_line
 
 
 
 
 
 
 
 
 
 
 
348
  torch.cuda.empty_cache()
349
  return input_context, extracted_line, input_context
350
 
@@ -461,7 +492,7 @@ with gr.Blocks() as demo:
461
  </div>
462
  <div style="text-align: left; margin: 0 auto;">
463
  <ol style="font-size: 1.1em;">
464
- <li>Choose input style: GrayImage(ScreenStyle) or Sketch.</li>
465
  <li>Upload your image: Use the 'Upload' button to select the image you want to colorize.</li>
466
  <li>Preprocess the image: Click the 'Preprocess' button to decolorize the image.</li>
467
  <li>Upload reference images: Upload multiple reference images to guide the colorization.</li>
@@ -481,7 +512,7 @@ with gr.Blocks() as demo:
481
  </div>
482
  <div style="text-align: left; margin: 0 auto;">
483
  <ol style="font-size: 1.1em;">
484
- <li>选择输入样式:灰度图(ScreenStyle)、线稿。</li>
485
  <li>上传您的图像:使用“上传”按钮选择要上色的图像。</li>
486
  <li>预处理图像:点击“预处理”按钮以去色图像。</li>
487
  <li>上传参考图像:上传多张参考图像以指导上色。</li>
@@ -499,7 +530,7 @@ with gr.Blocks() as demo:
499
 
500
  with gr.Column():
501
  with gr.Row():
502
- input_style = gr.Radio(["GrayImage(ScreenStyle)", "Sketch"], label="Input Style", value="GrayImage(ScreenStyle)")
503
  with gr.Row():
504
  with gr.Column():
505
  input_image = gr.Image(type="pil", label="Image to Colorize")
 
152
  0,
153
  10
154
  ],
155
+ [
156
+ "./assets/example_6/input.png",
157
+ ["./assets/example_6/ref1.png", "./assets/example_6/ref2.png", "./assets/example_6/ref3.png"],
158
+ "Sketch_Shading",
159
+ "512x800",
160
+ 0,
161
+ 10
162
+ ],
163
+ [
164
+ "./assets/example_7/input.png",
165
+ ["./assets/example_7/ref1.png", "./assets/example_7/ref2.png", "./assets/example_7/ref3.png", "./assets/example_7/ref4.png"],
166
+ "Sketch_Shading",
167
+ "640x640",
168
+ 2,
169
+ 10
170
+ ],
171
  [
172
  "./assets/example_1/input.jpg",
173
  ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
 
193
  def load_ckpt(input_style):
194
  global pipeline
195
  global MultiResNetModel
196
+ if input_style == "Sketch" or input_style == "Sketch_Shading":
197
+ if input_style == "Sketch":
198
+ ckpt_path = './ckpt/sketch/'
199
+ rank = 128
200
+ else:
201
+ ckpt_path = './ckpt/shading/'
202
+ rank = 128
203
  pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
204
  transformer = PixArtTransformer2DModel.from_pretrained(
205
  pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
 
365
  extracted_line = extract_lines(query_image)
366
  extracted_line = extracted_line.convert('L').convert('RGB')
367
  input_context = extracted_line
368
+ elif input_style == "Sketch_Shading":
369
+ query_image = query_image.convert('L').convert('RGB')
370
+ extracted_line = extract_lines(query_image)
371
+ extracted_line = extracted_line.convert('L').convert('RGB')
372
+ array1 = np.array(query_image)
373
+ array2 = np.array(extracted_line)
374
+ array2[array1 < 0.3 * 255.0] = 0
375
+ gray_rate = 125
376
+ up_bound = 145
377
+ array2[(array2 > gray_rate) & (array1 < up_bound) & (array1 > 0.3 * 255.0)] = gray_rate
378
+ input_context = Image.fromarray(np.uint8(array2))
379
  torch.cuda.empty_cache()
380
  return input_context, extracted_line, input_context
381
 
 
492
  </div>
493
  <div style="text-align: left; margin: 0 auto;">
494
  <ol style="font-size: 1.1em;">
495
+ <li>Choose input style: GrayImage(ScreenStyle)、Sketch with Shading or Sketch.</li>
496
  <li>Upload your image: Use the 'Upload' button to select the image you want to colorize.</li>
497
  <li>Preprocess the image: Click the 'Preprocess' button to decolorize the image.</li>
498
  <li>Upload reference images: Upload multiple reference images to guide the colorization.</li>
 
512
  </div>
513
  <div style="text-align: left; margin: 0 auto;">
514
  <ol style="font-size: 1.1em;">
515
+ <li>选择输入样式:灰度图(ScreenStyle)、线稿+阴影、线稿。</li>
516
  <li>上传您的图像:使用“上传”按钮选择要上色的图像。</li>
517
  <li>预处理图像:点击“预处理”按钮以去色图像。</li>
518
  <li>上传参考图像:上传多张参考图像以指导上色。</li>
 
530
 
531
  with gr.Column():
532
  with gr.Row():
533
+ input_style = gr.Radio(["GrayImage(ScreenStyle)", "Sketch_Shading", "Sketch"], label="Input Style", value="GrayImage(ScreenStyle)")
534
  with gr.Row():
535
  with gr.Column():
536
  input_image = gr.Image(type="pil", label="Image to Colorize")