File size: 3,741 Bytes
1cae162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import os
from PIL import Image
import numpy as np
from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
from diffusion_module.utils.Pipline import SDMLDMPipeline

def log_validation(vae, unet, noise_scheduler, accelerator, weight_dtype, data_ld,
                   resolution=512, g_step=2, save_dir="cityspace_test",
                   num_classes=151, num_inference_steps=50, max_batches=3):
    """Sample images from segmentation maps and write them to the logging dir.

    Builds an ``SDMLDMPipeline`` from the unwrapped ``vae``/``unet`` and a
    ``UniPCMultistepScheduler`` derived from ``noise_scheduler.config``, runs
    it on the first sample of each of the first ``max_batches`` batches of
    ``data_ld`` and hands the results to ``merge_images``.

    Args:
        vae: Model passed through ``accelerator.unwrap_model`` and given to
            the pipeline as ``vae``.
        unet: Model passed through ``accelerator.unwrap_model`` and given to
            the pipeline as ``unet``.
        noise_scheduler: Scheduler whose ``config`` seeds the sampling
            scheduler.
        accelerator: Provides ``unwrap_model``, ``device`` and (through
            ``merge_images``) ``logging_dir``.
        weight_dtype: Forwarded to the pipeline as ``torch_dtype``.
        data_ld: Iterable of batches; ``batch[1]['label']`` must be a label
            tensor accepted by ``preprocess_input``.
        resolution: Forwarded to the pipeline as ``resolution``.
        g_step: Step number used by ``merge_images`` for the output folder.
        save_dir: Currently unused; kept so existing callers keep working.
        num_classes: Number of one-hot channels built from the label map.
        num_inference_steps: Number of sampling steps requested per image.
        max_batches: Maximum number of batches of ``data_ld`` to sample from.
    """
    scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=scheduler,
        torch_dtype=weight_dtype,
        resolution=resolution,
        resolution_type="crack",
    )

    device = accelerator.device
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=False)
    pipeline.enable_xformers_memory_efficient_attention()

    for i, batch in enumerate(data_ld):
        if i >= max_batches:
            break
        with torch.autocast("cuda"):
            segmap = preprocess_input(batch[1]['label'], num_classes=num_classes)
            # Put the condition on the same device as the pipeline instead of
            # a hard-coded "cuda", so both always agree.
            segmap = segmap.to(device=device, dtype=torch.float16)

            # NOTE: saving a colourised copy of the segmentation map
            # (batch[1]['label_ori']) next to the samples is disabled for now;
            # no plotting helper exists yet because there are too many classes.

            # Only the first sample of the batch is used; [None, :] restores
            # the leading batch dimension of size 1.
            # `s` is forwarded as-is (presumably a guidance scale - TODO confirm).
            image = pipeline(segmap=segmap[0][None, :], generator=None, batch_size=1,
                             num_inference_steps=num_inference_steps, s=1.5).images

        merge_images(list(image), i, accelerator, g_step)

    # Drop the pipeline reference and release cached CUDA memory.
    del pipeline
    torch.cuda.empty_cache()
        

def merge_images(images, val_step,accelerator,step):
    """Save every image on its own and as one side-by-side strip.

    Output goes under ``<accelerator.logging_dir>/step_<step>/``: image ``k``
    is written to ``singles/<val_step>_<k>.png`` and the horizontal
    concatenation of all images to ``merges/<val_step>_merge.png``.

    Args:
        images: Iterable of image objects exposing ``width``, ``height``,
            ``mode``, ``save`` and ``convert`` (presumably PIL images).
        val_step: Value used as the file-name prefix.
        accelerator: Object whose ``logging_dir`` attribute is the output root.
        step: Value used in the ``step_<step>`` directory name.

    Returns:
        None. When ``images`` is empty nothing is written.
    """
    # Materialise once: the images are traversed several times below.
    images = list(images)
    if not images:
        # Nothing to save; avoids max() failing on an empty sequence.
        return

    step_dir = os.path.join(accelerator.logging_dir, "step_{}".format(step))

    # Individual images live in the 'singles' sub-folder.
    singles_dir = os.path.join(step_dir, "singles")
    os.makedirs(singles_dir, exist_ok=True)
    for k, image in enumerate(images):
        image.save(os.path.join(singles_dir, "{}_{}.png".format(val_step, k)))

    # Canvas wide enough for all images in a row and as tall as the tallest.
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    combined_image = Image.new('RGB', (total_width, max_height))

    # Paste the images left to right.
    x_offset = 0
    for img in images:
        # Non-RGB images (e.g. grayscale) are converted to match the canvas.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        combined_image.paste(img, (x_offset, 0))
        x_offset += img.width

    # The merged strip lives in the 'merges' sub-folder.
    merges_dir = os.path.join(step_dir, "merges")
    os.makedirs(merges_dir, exist_ok=True)
    combined_image.save(os.path.join(merges_dir, "{}_merge.png".format(val_step)))
    
def preprocess_input(data, num_classes):
    """Turn a label map into a one-hot float tensor.

    Args:
        data: 4-D tensor of class indices. It is cast to ``int64`` and used
            as the scatter index along dimension 1.
        num_classes: Size of dimension 1 of the result.

    Returns:
        A float32 tensor on ``data``'s device whose dimensions 0, 2 and 3
        match ``data`` and whose dimension 1 has ``num_classes`` entries;
        positions selected by the indices hold 1.0, all others 0.0.
    """
    indices = data.long()
    batch, _, height, width = indices.shape

    # Start from all zeros, then write 1.0 at each indexed class channel.
    one_hot = torch.zeros(
        batch, num_classes, height, width,
        dtype=torch.float32, device=data.device,
    )
    return one_hot.scatter_(1, indices, 1.0)