# File size: 1,794 Bytes
# 55d914b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch


def sample_x0(x1):
    """Draw standard-normal noise ``x0`` shaped like the data ``x1``.

    Args:

      x1 - data point; either a tensor of shape [batch, *dim] or a
           list/tuple of tensors (one per sample)

    Returns:

      Gaussian noise with the same shape, dtype and device as ``x1``:
      a single tensor for tensor input, or a list of tensors (one per
      element, in the same order) for list/tuple input.
    """
    if not isinstance(x1, (list, tuple)):
        return torch.randn_like(x1)
    # Per-sample noise; elements may have different shapes.
    return list(map(torch.randn_like, x1))

def sample_timestep(x1):
    """Draw one timestep per sample from a logit-normal distribution.

    Args:

      x1 - data batch; a tensor of shape [batch, *dim] or a list/tuple of
           tensors. Only ``len(x1)`` and the first element are used.

    Returns:

      Tensor of shape [len(x1)] with values sigmoid(u), u ~ N(0, 1), i.e. in
      (0, 1) up to floating-point rounding, cast to the dtype/device of
      ``x1[0]``.
    """
    batch_size = len(x1)
    logits = torch.normal(mean=0.0, std=1.0, size=(batch_size,))
    timesteps = 1 / (1 + torch.exp(-logits))
    # Sampled on the default device/dtype, then moved to match the data.
    reference = x1[0]
    return timesteps.to(reference)


def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
    """Compute the flow-matching training loss for ``model``.

    A standard-normal noise sample ``x0`` and a timestep ``t`` are drawn for
    every item of ``x1``. The model is evaluated at the linear interpolant
    ``xt = t * x1 + (1 - t) * x0``, and its output is regressed onto the
    velocity target ``ut = x1 - x0`` with a mean-squared error.

    Args:

    - model: callable invoked as ``model(xt, t, **model_kwargs)``; its output
      is compared against ``x1 - x0``.

    - x1: datapoint; either a tensor of shape [batch, *dim] or a list/tuple
      of per-sample tensors (which may differ in shape).

    - model_kwargs: additional keyword arguments forwarded to the model;
      ``None`` means no extra arguments.

    - snr_type: currently unused; kept for interface compatibility.

    Returns:

      dict with key ``"loss"`` holding a tensor of shape [batch] with the
      per-sample mean-squared error.

    Raises:

      AssertionError: for list/tuple input, if the model does not return one
      output per sample.
    """
    if model_kwargs is None:
        model_kwargs = {}

    is_sequence = isinstance(x1, (list, tuple))

    x0 = sample_x0(x1)
    t = sample_timestep(x1)

    if is_sequence:
        # Per-sample interpolation; t_i is a 0-d tensor, so it broadcasts
        # against each sample regardless of its shape.
        xt = [t_i * a + (1 - t_i) * b for t_i, a, b in zip(t, x1, x0)]
        ut = [a - b for a, b in zip(x1, x0)]
    else:
        # Reshape t to [batch, 1, ..., 1] so it broadcasts over non-batch dims.
        dims = [1] * (x1.dim() - 1)
        t_ = t.view(t.size(0), *dims)
        xt = t_ * x1 + (1 - t_) * x0
        ut = x1 - x0

    model_output = model(xt, t, **model_kwargs)

    terms = {}

    if is_sequence:
        assert len(model_output) == len(ut) == len(x1)
        # One scalar MSE per sample (shapes may differ), stacked into [batch].
        # Computed once; the previous version recomputed it inside a
        # redundant loop over the batch.
        terms["loss"] = torch.stack(
            [((u - out) ** 2).mean() for u, out in zip(ut, model_output)],
            dim=0,
        )
    else:
        terms["loss"] = mean_flat((model_output - ut) ** 2)

    return terms


def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.

    Args:
      x - tensor whose first dimension is the batch dimension.

    Returns:
      Tensor of shape [batch] (for inputs with at least two dimensions).
      For inputs with no non-batch dimensions (0-d or 1-d), ``x`` is
      returned unchanged, since there is nothing to average over.
    """
    if x.dim() <= 1:
        # ``torch.mean(x, dim=[])`` has historically reduced over *all*
        # dimensions, which would collapse a [batch] tensor to a scalar
        # instead of leaving the per-sample values intact.
        return x
    return torch.mean(x, dim=list(range(1, x.dim())))