---
license: mit
---
# A Text-Conditioned Diffusion Prior

## Training Details
Training details can be found in this Weights & Biases report: https://wandb.ai/nousr_laion/conditioned-prior/reports/Updated-Text-Conditioned-Prior--VmlldzoyMDI2OTIx
## Source Code
The models are diffusion prior trainers from https://github.com/lucidrains/DALLE2-pytorch, trained with the defaults specified in the `train_diffusion_prior.py` script.
## Community: LAION
Join us on Discord: https://discord.gg/uPMftTmrvS

---

# Models
```
depth=12
d_model=768  # for ViT-L/14; the loading code below uses 512 for ViT-B/32
clip = OpenAIClipAdapter(clip_choice=["ViT-L/14" | "ViT-B/32"])  # one of the two CLIP backbones
```
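The embedding width of the prior follows the CLIP backbone: 768 for ViT-L/14 and 512 for ViT-B/32, which is what the loading code below selects. A minimal sketch of that mapping as a standalone helper; `CLIP_EMBED_DIMS` and `prior_dim_for` are hypothetical names, not part of DALLE2-pytorch.

```python
# Hypothetical helper (not part of DALLE2-pytorch): maps the CLIP backbone
# names used in this card to their embedding width.
CLIP_EMBED_DIMS = {
    "ViT-B/32": 512,
    "ViT-L/14": 768,
}


def prior_dim_for(clip_choice: str) -> int:
    """Return the embedding width the prior network should use."""
    try:
        return CLIP_EMBED_DIMS[clip_choice]
    except KeyError:
        # Fail loudly instead of silently falling back to a default width.
        raise ValueError(
            f"unsupported clip_choice {clip_choice!r}; "
            f"expected one of {sorted(CLIP_EMBED_DIMS)}"
        ) from None
```

Unlike the `else` branch in the loading code, which treats any name other than ViT-B/32 as 768-d, this version rejects unknown backbone names, so a typo surfaces as an error rather than a shape mismatch at `load_state_dict` time.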

### Loading the models
Loading a checkpoint might look something like this:
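```python
import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer


def load_diffusion_model(dprior_path, device, clip_choice):
    # Load the checkpoint on CPU first; modules are moved to `device` below.
    loaded_obj = torch.load(str(dprior_path), map_location="cpu")

    # ViT-B/32 produces 512-d CLIP embeddings; ViT-L/14 produces 768-d.
    if clip_choice == "ViT-B/32":
        dim = 512
    else:
        dim = 768

    # Transformer that denoises CLIP image embeddings conditioned on text.
    prior_network = DiffusionPriorNetwork(
        dim=dim,
        depth=12,
        dim_head=64,
        heads=12,
        normformer=True,
    ).to(device)

    # Diffusion wrapper; these hyperparameters must match the training run.
    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter(clip_choice),
        image_embed_dim=dim,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
    ).to(device)

    # Restore the weights; strict=True raises on any missing or unexpected key.
    diffusion_prior.load_state_dict(loaded_obj["model"], strict=True)

    # Wrap in the trainer so optimizer and grad-scaler state can be restored.
    trainer = DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        lr=1.1e-4,
        wd=6.02e-2,
        max_grad_norm=0.5,
        amp=False,
    ).to(device)

    trainer.optimizer.load_state_dict(loaded_obj["optimizer"])
    trainer.scaler.load_state_dict(loaded_obj["scaler"])

    return trainer
```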
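The loading function reads three entries from the checkpoint: `model`, `optimizer` and `scaler`. A small pre-flight check, sketched here with the hypothetical helper `missing_checkpoint_keys` (not part of DALLE2-pytorch), reports absent entries before a `load_state_dict` call fails partway through. Only `model` feeds the weights; `optimizer` and `scaler` presumably matter only when resuming training.

```python
from typing import Any, List, Mapping

# Keys that load_diffusion_model reads from the object returned by torch.load.
REQUIRED_KEYS = ("model", "optimizer", "scaler")


def missing_checkpoint_keys(loaded_obj: Mapping[str, Any]) -> List[str]:
    """Return the required keys absent from a loaded checkpoint, in order."""
    return [key for key in REQUIRED_KEYS if key not in loaded_obj]
```

In `load_diffusion_model`, it could be called right after `torch.load`, raising a `KeyError` that lists the result whenever it is non-empty.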