File size: 2,231 Bytes
cab8a49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torchvision
from vqgan import VQModel
from torch.utils.data import Dataset, DataLoader
from transformers import T5EncoderModel, AutoTokenizer

# Image preprocessing applied before VQGAN encoding:
#   1. ToTensor  — PIL/ndarray -> float tensor in [0, 1], CHW layout.
#   2. Resize(256) — scale the shorter edge to 256 (aspect ratio kept).
#   3. RandomCrop(256) — take a random 256x256 patch (light augmentation).
# NOTE(review): name shadows the common `torchvision.transforms` alias;
# callers import this `Compose` object, not the module.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(256),
    torchvision.transforms.RandomCrop(256),
])


class YOUR_DATASET(Dataset):
    """Template dataset — replace the stubs with loading logic for your data.

    The original stub discarded ``dataset_path`` and implemented none of the
    map-style ``Dataset`` protocol, so any ``DataLoader`` built on it failed
    opaquely. The path is now retained and the required methods raise a clear
    ``NotImplementedError`` until filled in.
    """

    def __init__(self, dataset_path):
        # Keep the path so the fill-in implementation can locate the data.
        self.dataset_path = dataset_path

    def __len__(self):
        # DataLoader needs this to size an epoch for a map-style dataset.
        raise NotImplementedError("Implement __len__ for your dataset")

    def __getitem__(self, index):
        # Should return one training sample (e.g. an image/caption pair
        # compatible with the module-level `transforms` pipeline).
        raise NotImplementedError("Implement __getitem__ for your dataset")


def get_dataloader(dataset_path, batch_size, shuffle=False, num_workers=8):
    """Build a DataLoader over ``YOUR_DATASET``.

    Args:
        dataset_path: location of the raw data, forwarded to ``YOUR_DATASET``.
        batch_size: number of samples per batch.
        shuffle: reshuffle the data every epoch. Defaults to False to match
            the previous hard-coded behavior; pass True for training runs.
        num_workers: loader subprocesses (previously fixed at 8); set lower
            on machines with few cores.

    Returns:
        A ``DataLoader`` with ``pin_memory=True`` for faster host-to-GPU
        transfers.
    """
    dataset = YOUR_DATASET(dataset_path)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


def load_conditional_models(byt5_model_name, vqgan_path, device):
    """Load the frozen conditioning models used for sampling.

    Args:
        byt5_model_name: Hugging Face model id/path for the ByT5 encoder
            and its tokenizer.
        vqgan_path: checkpoint file whose ``'state_dict'`` entry holds the
            VQGAN weights.
        device: target device for both models.

    Returns:
        ``(vqgan, (byt5_tokenizer, byt5))`` — both models on ``device``,
        in eval mode, with gradients disabled.
    """
    # Tokenizer is pure Python/CPU — no device placement or freezing needed.
    byt5_tokenizer = AutoTokenizer.from_pretrained(byt5_model_name)

    # Frozen ByT5 text encoder.
    byt5 = T5EncoderModel.from_pretrained(byt5_model_name)
    byt5 = byt5.to(device)
    byt5.eval()
    byt5.requires_grad_(False)

    # VQGAN: restore weights from the checkpoint's 'state_dict' entry.
    checkpoint = torch.load(vqgan_path, map_location=device)
    vqgan = VQModel().to(device)
    vqgan.load_state_dict(checkpoint['state_dict'])
    vqgan.eval()
    vqgan.requires_grad_(False)

    return vqgan, (byt5_tokenizer, byt5)


def sample(model, model_inputs, latent_shape, unconditional_inputs=None, steps=12, renoise_steps=11, temperature=(1.0, 0.2), cfg=8.0, t_start=1.0, t_end=0.0, device="cuda"):
    """Sample a grid of discrete VQ token indices by iterative re-sampling.

    Starting from uniformly random tokens, each step predicts per-position
    class logits, re-draws every token from a temperature-scaled categorical,
    and (except on the final step) re-noises toward the next timestep.

    Args:
        model: callable as ``model(tokens, t, **inputs)`` returning logits
            with the class dimension at dim=1; must expose ``num_labels``
            (codebook size) and an ``add_noise(tokens, t, random_x=...)``
            method returning a tuple whose first element is the noised tokens.
        model_inputs: dict of conditional inputs forwarded to the model.
        latent_shape: shape of the token grid; the permute/view below implies
            3-D ``(batch, h, w)`` tokens with 4-D logits — TODO confirm.
        unconditional_inputs: inputs for the unconditional pass. Required
            whenever ``cfg`` is truthy — there is no guard, so leaving it
            None with the default ``cfg=8.0`` raises a TypeError.
        steps: number of denoising iterations.
        renoise_steps: steps after which noise is re-added (default steps-1,
            so the last iteration is left un-noised).
        temperature: (start, end) softmax temperatures, linearly interpolated
            across ``steps``.
        cfg: classifier-free guidance scale; falsy (0/None) disables the
            unconditional pass entirely.
        t_start, t_end: endpoints of the linear time schedule.
        device: device on which the token tensors are created.

    Returns:
        Integer tensor of sampled token indices with shape ``latent_shape``.
    """
    with torch.inference_mode():
        # Start from fully random tokens over the model's codebook.
        sampled = torch.randint(0, model.num_labels, size=latent_shape, device=device)
        init_noise = sampled.clone()  # fixed noise reused at every renoise step
        # steps+1 points so t_list[i+1] is valid for every sampling step i.
        t_list = torch.linspace(t_start, t_end, steps+1)
        temperatures = torch.linspace(temperature[0], temperature[1], steps)
        for i, t in enumerate(t_list[:steps]):
            # Broadcast the scalar timestep to one value per batch element.
            t = torch.ones(latent_shape[0], device=device) * t

            logits = model(sampled, t, **model_inputs)
            if cfg:
                # Classifier-free guidance: cond*cfg + uncond*(1-cfg)
                # == uncond + cfg*(cond - uncond).
                logits = logits * cfg + model(sampled, t, **unconditional_inputs) * (1-cfg)
            # Temperature-scaled softmax over the class dimension (dim=1).
            scores = logits.div(temperatures[i]).softmax(dim=1)

            # Flatten all spatial positions to rows of class probabilities,
            # draw one token per position, then restore the spatial layout.
            sampled = scores.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
            sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])

            if i < renoise_steps:
                # Partially re-noise toward the next timestep, always reusing
                # the same initial noise tokens.
                t_next = torch.ones(latent_shape[0], device=device) * t_list[i+1]
                sampled = model.add_noise(sampled, t_next, random_x=init_noise)[0]
    return sampled