mpamt commited on
Commit
33d2a9d
·
1 Parent(s): 9b2b758

Update README.md

Browse files

Added "How to Use"

Files changed (1) hide show
  1. README.md +51 -0
README.md CHANGED
@@ -4,3 +4,54 @@ tags:
4
  - PyTorch
5
  - huggan
6
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  - PyTorch
5
  - huggan
6
  ---
7
+
8
+ ## How To Use
9
+
10
+ ```python
11
+
12
+ from huggingface_hub import hf_hub_download
13
+ import torch
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ from torch import nn
17
+
18
+ class Generator(nn.Module):
19
+ def __init__(self):
20
+ super(Generator, self).__init__()
21
+ self.main = nn.Sequential(
22
+ nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
23
+ nn.BatchNorm2d(64 * 8),
24
+ nn.ReLU(True),
25
+ nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
26
+ nn.BatchNorm2d(64 * 4),
27
+ nn.ReLU(True),
28
+ nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
29
+ nn.BatchNorm2d(64 * 2),
30
+ nn.ReLU(True),
31
+ nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
32
+ nn.BatchNorm2d(64),
33
+ nn.ReLU(True),
34
+ nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
35
+ nn.Tanh()
36
+ )
37
+
38
+ def forward(self, input):
39
+ return self.main(input)
40
+
41
+ path = hf_hub_download('huggan/ArtGAN', 'ArtGAN.pt')
42
+ model = torch.load(path, map_location=torch.device('cpu'))
43
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
44
+
45
+ def generate(seed):
46
+ with torch.no_grad():
47
+ noise = torch.randn(seed, 100, 1, 1, device=device)
48
+ with torch.no_grad():
49
+ art = model(noise).detach().cpu()
50
+ gen = np.transpose(art[-1], (1, 2, 0))
51
+ fig = plt.figure(figsize=(5, 5))
52
+ plt.imshow(gen)
53
+ plt.axis('off')
54
+
55
+ generate(25)
56
+
57
+ ```