Spaces: ismot · Runtime error

ismot geninhu committed · Commit 40e6a61 · 0 Parent(s)

Duplicate from huggan/FastGan

Co-authored-by: Nhu Hoang <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,47 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ aurora.png filter=lfs diff=lfs merge=lfs -text
+ fauvism.png filter=lfs diff=lfs merge=lfs -text
+ painting.png filter=lfs diff=lfs merge=lfs -text
+ shell.png filter=lfs diff=lfs merge=lfs -text
+ grumpy_cat.png filter=lfs diff=lfs merge=lfs -text
+ universe.png filter=lfs diff=lfs merge=lfs -text
+ moon_gate.png filter=lfs diff=lfs merge=lfs -text
+ assets/video/anime.gif filter=lfs diff=lfs merge=lfs -text
+ assets/video/aurora.gif filter=lfs diff=lfs merge=lfs -text
+ assets/video/fauvism.gif filter=lfs diff=lfs merge=lfs -text
+ assets/video/grumpy_cat.gif filter=lfs diff=lfs merge=lfs -text
+ assets/video/moon_gate.gif filter=lfs diff=lfs merge=lfs -text
+ assets/video/painting.gif filter=lfs diff=lfs merge=lfs -text
+ assets/video/universe.gif filter=lfs diff=lfs merge=lfs -text
+ assets/video/anime.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/video/fauvism.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/video/universe.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/video/moongate.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/video/painting.mp4 filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: FastGan
+ emoji: 😎
+ colorFrom: indigo
+ colorTo: pink
+ sdk: streamlit
+ sdk_version: 1.2.0
+ app_file: app.py
+ pinned: true
+ duplicated_from: huggan/FastGan
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
StyleMix.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ from torch import nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+ from torch.utils.data.dataloader import DataLoader
+ from torchvision import transforms
+ from torchvision import utils as vutils
+
+ from models import Generator
+ from utils import copy_G_params, load_params
+
+
+
+ def get_early_features(net, noise):
+     with torch.no_grad():
+         feat_4 = net._init(noise)
+         feat_8 = net._upsample_8(feat_4)
+         feat_16 = net._upsample_16(feat_8)
+         feat_32 = net._upsample_32(feat_16)
+         feat_64 = net._upsample_64(feat_32)
+     return feat_8, feat_16, feat_32, feat_64
+
+ def get_late_features(net, feat_64, feat_8, feat_16, feat_32):
+     with torch.no_grad():
+         feat_128 = net._upsample_128(feat_64)
+         feat_128 = net._sle_128(feat_8, feat_128)
+
+         feat_256 = net._upsample_256(feat_128)
+         feat_256 = net._sle_256(feat_16, feat_256)
+
+         feat_512 = net._upsample_512(feat_256)
+         feat_512 = net._sle_512(feat_32, feat_512)
+
+         feat_1024 = net._upsample_1024(feat_512)
+
+     return net._out_1024(feat_1024)
+
+ def style_mix(model_name_or_path, bs, device):
+     _in_channels = 256
+     im_size = 1024
+
+     netG = Generator(in_channels=_in_channels, out_channels=3)
+     netG = netG.from_pretrained(model_name_or_path, in_channels=256, out_channels=3)
+     _ = netG.to(device)
+     _ = netG.eval()
+
+     avg_param_G = copy_G_params(netG)
+     load_params(netG, avg_param_G)
+
+     noise_a = torch.randn(bs, 256, 1, 1, device=device).to(device)
+     noise_b = torch.randn(bs, 256, 1, 1, device=device).to(device)
+
+     feat_8_a, feat_16_a, feat_32_a, feat_64_a = get_early_features(netG, noise_a)
+     feat_8_b, feat_16_b, feat_32_b, feat_64_b = get_early_features(netG, noise_b)
+
+     images_b = get_late_features(netG, feat_64_b, feat_8_b, feat_16_b, feat_32_b)
+     images_a = get_late_features(netG, feat_64_a, feat_8_a, feat_16_a, feat_32_a)
+
+     imgs = [ torch.ones(1, 3, im_size, im_size) ]
+
+     imgs.append(images_b.cpu())
+     for i in range(bs):
+         imgs.append(images_a[i].unsqueeze(0).cpu())
+         gimgs = get_late_features(netG, feat_64_a[i].unsqueeze(0).repeat(bs, 1, 1, 1), feat_8_b, feat_16_b, feat_32_b)
+         imgs.append(gimgs.cpu())
+
+     imgs = torch.cat(imgs)
+     # vutils.save_image(imgs.add(1).mul(0.5), 'style_mix/style_mix_2.jpg', nrow=bs+1)
+
+     return imgs
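A minimal usage sketch for style_mix, mirroring the "Mix style" tab in app.py below; the model id is one of the huggan few-shot checkpoints listed there, and the nrow value of bs + 1 matches the white placeholder image prepended to the grid above (illustrative only, not part of the commit):

    from torchvision.utils import make_grid
    from StyleMix import style_mix

    # mix three noise pairs from the aurora checkpoint on CPU
    imgs = style_mix('huggan/fastgan-few-shot-aurora', bs=3, device='cpu')
    grid = make_grid(imgs, nrow=3 + 1, normalize=True)  # one row per content image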
__pycache__/StyleMix.cpython-39.pyc ADDED
Binary file (2.15 kB).
 
__pycache__/layers.cpython-39.pyc ADDED
Binary file (6.82 kB).
 
__pycache__/models.cpython-39.pyc ADDED
Binary file (4.76 kB).
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.82 kB).
 
app.py ADDED
@@ -0,0 +1,217 @@
+ import pandas as pd
+ import numpy as np
+ import streamlit as st
+
+ from models import Generator, Discriminrator
+ from StyleMix import style_mix
+ import torch
+ import torchvision.transforms as T
+ from torchvision.utils import make_grid
+ from PIL import Image
+
+ from streamlit_lottie import st_lottie
+ from streamlit_option_menu import option_menu
+ import requests
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ model_name = {
+     "aurora": 'huggan/fastgan-few-shot-aurora',
+     "painting": 'huggan/fastgan-few-shot-painting',
+     "shell": 'huggan/fastgan-few-shot-shells',
+     "fauvism": 'huggan/fastgan-few-shot-fauvism-still-life',
+     "universe": 'huggan/fastgan-few-shot-universe',
+     "grumpy cat": 'huggan/fastgan-few-shot-grumpy-cat',
+     "anime": 'huggan/fastgan-few-shot-anime-face',
+     "moon gate": 'huggan/fastgan-few-shot-moongate',
+ }
+
+ #@st.cache(allow_output_mutation=True)
+ def load_generator(model_name_or_path):
+     generator = Generator(in_channels=256, out_channels=3)
+     generator = generator.from_pretrained(model_name_or_path, in_channels=256, out_channels=3)
+     _ = generator.to(device)
+     _ = generator.eval()
+
+     return generator
+
+ def _denormalize(input: torch.Tensor) -> torch.Tensor:
+     return (input * 127.5) + 127.5
+
+
+ def generate_images(generator, number_imgs):
+     noise = torch.zeros(number_imgs, 256, 1, 1, device=device).normal_(0.0, 1.0)
+     with torch.no_grad():
+         gan_images, _ = generator(noise)
+
+     gan_images = _denormalize(gan_images.detach()).cpu()
+     gan_images = [i for i in gan_images]
+     gan_images = [make_grid(i, nrow=1, normalize=True) for i in gan_images]
+     gan_images = [i.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() for i in gan_images]
+     gan_images = [Image.fromarray(i) for i in gan_images]
+     return gan_images
+
+ def load_lottieurl(url: str):
+     r = requests.get(url)
+     if r.status_code != 200:
+         return None
+     return r.json()
+
+ def show_model_summary(expanded):
+     st.subheader("Model gallery")
+     with st.expander('Image gallery', expanded=expanded):
+         col1, col2, col3, col4 = st.columns(4)
+         with col1:
+             st.markdown('Fauvism GAN [model](https://huggingface.co/huggan/fastgan-few-shot-fauvism-still-life)', unsafe_allow_html=True)
+             st.image('assets/image/fauvism.png', width=200)
+             st.markdown('Painting GAN [model](https://huggingface.co/huggan/fastgan-few-shot-painting)', unsafe_allow_html=True)
+             st.image('assets/image/painting.png', width=200)
+
+         with col2:
+             st.markdown('Aurora GAN [model](https://huggingface.co/huggan/fastgan-few-shot-aurora)', unsafe_allow_html=True)
+             st.image('assets/image/aurora.png', width=200)
+             st.markdown('Universe GAN [model](https://huggingface.co/huggan/fastgan-few-shot-universe)', unsafe_allow_html=True)
+             st.image('assets/image/universe.png', width=200)
+
+         with col3:
+             st.markdown('Anime GAN [model](https://huggingface.co/huggan/fastgan-few-shot-anime-face)', unsafe_allow_html=True)
+             st.image('assets/image/anime.png', width=200)
+             st.markdown('Shell GAN [model](https://huggingface.co/huggan/fastgan-few-shot-shells)', unsafe_allow_html=True)
+             st.image('assets/image/shell.png', width=200)
+
+         with col4:
+             st.markdown('Grumpy cat GAN [model](https://huggingface.co/huggan/fastgan-few-shot-grumpy-cat)', unsafe_allow_html=True)
+             st.image('assets/image/grumpy_cat.png', width=200)
+             st.markdown('Moon gate GAN [model](https://huggingface.co/huggan/fastgan-few-shot-moongate)', unsafe_allow_html=True)
+             st.image('assets/image/moon_gate.png', width=200)
+
+     with st.expander('Video gallery', expanded=True):
+         cols = st.columns(4)
+
+         cols[0].write("Universe GAN")
+         cols[0].video('assets/video/universe.mp4')
+         cols[0].write("Fauvism still life GAN")
+         cols[0].video('assets/video/fauvism.mp4')
+
+         cols[1].write("Aurora GAN")
+         cols[1].video('assets/video/aurora.mp4')
+         cols[1].write("Moon gate GAN")
+         cols[1].video('assets/video/moongate.mp4')
+
+         cols[2].write("Anime GAN")
+         cols[2].video('assets/video/anime.mp4')
+         cols[2].write("Painting GAN")
+         cols[2].video('assets/video/painting.mp4')
+
+         cols[3].write("Grumpy cat GAN")
+         cols[3].video('assets/video/grumpy.mp4')
+
+
+ def main():
+
+     st.set_page_config(
+         page_title="FastGAN Generator",
+         page_icon="🖥️",
+         layout="wide",
+         initial_sidebar_state="expanded"
+     )
+
+     lottie_penguin = load_lottieurl('https://assets7.lottiefiles.com/packages/lf20_mm4bsl3l.json')
+
+     with st.sidebar:
+         st_lottie(lottie_penguin, height=200)
+         choose = option_menu("FastGAN", ["Model Gallery", "Generate images", "Mix style"],
+                              icons=['collection', 'file-plus', 'intersect'],
+                              menu_icon="infinity", default_index=0,
+                              styles={
+                                  "container": {"padding": ".0rem", "font-size": "14px"},
+                                  "nav-link-selected": {"color": "#000000", "font-size": "16px"},
+                              }
+                              )
+     st.sidebar.markdown(
+         """
+         ___
+         <p style='text-align: center'>
+         FastGAN is a few-shot GAN model trained on high-fidelity images which requires less computation resource and samples for training.
+         <br/>
+         <a href="https://arxiv.org/abs/2101.04775" target="_blank">Article</a>
+         </p>
+         <p style='text-align: center; font-size: 14px;'>
+         Model training and Spaces creating by
+         <br/>
+         <a href="https://www.linkedin.com/in/vumichien/" target="_blank">Chien Vu</a> | <a href="https://www.linkedin.com/in/nhu-hoang/" target="_blank">Nhu Hoang</a>
+         <br/>
+         </p>
+         """,
+         unsafe_allow_html=True,
+     )
+
+     if choose == 'Model Gallery':
+         st.header("Welcome to FastGAN")
+         show_model_summary(True)
+     elif choose == 'Generate images':
+         st.header("Generate images")
+         col11, col12, col13 = st.columns([3, 3.5, 3.5])
+         with col11:
+             img_type = st.selectbox("Choose type of image to generate", index=0,
+                                     options=["aurora", "anime", "painting", "fauvism", "shell", "universe", "grumpy cat", "moon gate"])
+
+             number_imgs = st.slider('How many images you want to generate ?', min_value=1, max_value=5)
+             if number_imgs is None:
+                 st.write('Invalid number ! Please insert number of images to generate !')
+                 raise ValueError('Invalid number ! Please insert number of images to generate !')
+
+             generate_button = st.button('Get Image')
+             if generate_button:
+                 st.markdown("""
+                 <small><i>Predictions may take up to 1 minute under high load. Please stand by.</i></small>
+                 """,
+                 unsafe_allow_html=True,)
+
+         if generate_button:
+             with col11:
+                 with st.spinner(text=f"Loading selected model..."):
+                     generator = load_generator(model_name[img_type])
+                 with st.spinner(text=f"Generating images..."):
+                     gan_images = generate_images(generator, number_imgs)
+             with col12:
+                 st.image(gan_images[0], width=300)
+             if len(gan_images) > 1:
+                 with col13:
+                     if len(gan_images) <= 2:
+                         st.image(gan_images[1], width=300)
+                     else:
+                         st.image(gan_images[1:], width=150)
+
+     elif choose == 'Mix style':
+         st.header("Mix style")
+         st.markdown(
+             """
+             <p style='text-align: left'>
+             Get the style representations of 2 images generated from the model to create a new one that mixes the style of two.
+             </p>
+             """,
+             unsafe_allow_html=True,
+         )
+         st.markdown("""___""")
+         col21, col22 = st.columns([3, 6])
+         with col21:
+             img_type = st.selectbox("Choose type of image to mix", index=0,
+                                     options=["aurora", "anime", "painting", "fauvism", "shell", "universe", "grumpy cat", "moon gate"])
+             number_imgs = st.slider('How many images you want to generate ?', min_value=1, max_value=3)
+             generate_button = st.button('Mix style')
+
+         if generate_button:
+             with col21:
+                 with st.spinner(text=f"Mixing styles..."):
+                     mix_imgs = style_mix(model_name[img_type], number_imgs, device)
+                     mix_imgs = make_grid(mix_imgs, nrow=number_imgs+1, normalize=True)
+                     mix_imgs = mix_imgs.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+                     mix_imgs = Image.fromarray(mix_imgs)
+             with col22:
+                 st.image(mix_imgs, width=600)
+
+
+ if __name__ == '__main__':
+     main()
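A hedged usage sketch of the two generation helpers above outside Streamlit (model id taken from the model_name dict; the 1024x1024 output size follows from the Generator architecture in models.py, not verified against the original repo):

    generator = load_generator('huggan/fastgan-few-shot-aurora')
    pil_images = generate_images(generator, number_imgs=2)  # list of PIL.Image objects
    pil_images[0].save('aurora_sample.png')                 # hypothetical output path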
assets/image/anime.png ADDED
assets/image/aurora.png ADDED

Git LFS Details

  • SHA256: b48f5574f1e5dbd8a7ea95d31d55b3b8965ba968b900feaea30013f7258e0075
  • Pointer size: 132 Bytes
  • Size of remote file: 1.43 MB
assets/image/fauvism.png ADDED

Git LFS Details

  • SHA256: dc1126a228d6a58153c798fa42f107d6a4381866b4e16354269492614b33966f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
assets/image/grumpy_cat.png ADDED

Git LFS Details

  • SHA256: d91e48877b054a79cee2ccd8d6f4b3db9531afce3df8a7655bb38535740c8a6c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
assets/image/moon_gate.png ADDED

Git LFS Details

  • SHA256: 0bf214e6f577a3acc4e2be8e00178805fba5057bec9c4e8b4afd2eaf41238cd6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.17 MB
assets/image/painting.png ADDED

Git LFS Details

  • SHA256: fee0d866738cb2c70209cdb2e620045a564e31c84d3e7887310c0d81eaf38bc7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
assets/image/shell.png ADDED

Git LFS Details

  • SHA256: d1840716d3a61e93cd8fb581caa14bc1a92507a42e89bcaa5a4ca441dd8253a5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.24 MB
assets/image/universe.png ADDED

Git LFS Details

  • SHA256: 9a40182f912968b8df5e5d2733566ef143bc7366895674d99f5ed47370e45e68
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
assets/video/anime.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:040ef94e35a65978826df636850403c72a8ad8ca97432f0e8b543db9e1474b08
+ size 3398750
assets/video/aurora.mp4 ADDED
Binary file (903 kB).
 
assets/video/fauvism.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6fe077eefbd400b8202876f8a7f1de2982a11a3e4e6e68ef2ed7f85eb398ab1
+ size 1573497
assets/video/grumpy.mp4 ADDED
Binary file (627 kB).
 
assets/video/moongate.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba99ee59ea1c330dcad52c6448e726090bf741ac115d4a31765bd85d8316e85c
+ size 2861613
assets/video/painting.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f172256a1373a8c513aaf097f61ec69df14dbdce08c0d765b3ecd92e132b9c56
+ size 1477719
assets/video/shells.mp4 ADDED
Binary file (880 kB).
 
assets/video/universe.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0d7ed5c0b077180d5bb9975e2a2051fca3bf67d8e1e59bfca8a8b31728c63271
+ size 15562186
layers.py ADDED
@@ -0,0 +1,272 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.modules.batchnorm import BatchNorm2d
+ from torch.nn.utils import spectral_norm
+
+
+ class SpectralConv2d(nn.Module):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+         self._conv = spectral_norm(
+             nn.Conv2d(*args, **kwargs)
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self._conv(input)
+
+
+ class SpectralConvTranspose2d(nn.Module):
+
+     def __init__(self, *args, **kwargs):
+         super().__init__()
+         self._conv = spectral_norm(
+             nn.ConvTranspose2d(*args, **kwargs)
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self._conv(input)
+
+
+ class Noise(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self._weight = nn.Parameter(
+             torch.zeros(1),
+             requires_grad=True,
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         batch_size, _, height, width = input.shape
+         noise = torch.randn(batch_size, 1, height, width, device=input.device)
+         return self._weight * noise + input
+
+
+ class InitLayer(nn.Module):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._layers = nn.Sequential(
+             SpectralConvTranspose2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels * 2,
+                 kernel_size=4,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=out_channels * 2),
+             nn.GLU(dim=1),
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self._layers(input)
+
+
+ class SLEBlock(nn.Module):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._layers = nn.Sequential(
+             nn.AdaptiveAvgPool2d(output_size=4),
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=4,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+             ),
+             nn.SiLU(),
+             SpectralConv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+             ),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, low_dim: torch.Tensor,
+                       high_dim: torch.Tensor) -> torch.Tensor:
+         return high_dim * self._layers(low_dim)
+
+
+ class UpsampleBlockT1(nn.Module):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._layers = nn.Sequential(
+             nn.Upsample(scale_factor=2, mode='nearest'),
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels * 2,
+                 kernel_size=3,
+                 stride=1,
+                 padding='same',
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=out_channels * 2),
+             nn.GLU(dim=1),
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self._layers(input)
+
+
+ class UpsampleBlockT2(nn.Module):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._layers = nn.Sequential(
+             nn.Upsample(scale_factor=2, mode='nearest'),
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels * 2,
+                 kernel_size=3,
+                 stride=1,
+                 padding='same',
+                 bias=False,
+             ),
+             Noise(),
+             BatchNorm2d(num_features=out_channels * 2),
+             nn.GLU(dim=1),
+             SpectralConv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels * 2,
+                 kernel_size=3,
+                 stride=1,
+                 padding='same',
+                 bias=False,
+             ),
+             Noise(),
+             nn.BatchNorm2d(num_features=out_channels * 2),
+             nn.GLU(dim=1),
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self._layers(input)
+
+
+ class DownsampleBlockT1(nn.Module):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._layers = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=4,
+                 stride=2,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=out_channels),
+             nn.LeakyReLU(negative_slope=0.2),
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self._layers(input)
+
+
+ class DownsampleBlockT2(nn.Module):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._layers_1 = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=4,
+                 stride=2,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=out_channels),
+             nn.LeakyReLU(negative_slope=0.2),
+             SpectralConv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding='same',
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=out_channels),
+             nn.LeakyReLU(negative_slope=0.2),
+         )
+
+         self._layers_2 = nn.Sequential(
+             nn.AvgPool2d(
+                 kernel_size=2,
+                 stride=2,
+             ),
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=out_channels),
+             nn.LeakyReLU(negative_slope=0.2),
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         t1 = self._layers_1(input)
+         t2 = self._layers_2(input)
+         return (t1 + t2) / 2
+
+
+ class Decoder(nn.Module):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._channels = {
+             16: 128,
+             32: 64,
+             64: 64,
+             128: 32,
+             256: 16,
+             512: 8,
+             1024: 4,
+         }
+
+         self._layers = nn.Sequential(
+             nn.AdaptiveAvgPool2d(output_size=8),
+             UpsampleBlockT1(in_channels=in_channels, out_channels=self._channels[16]),
+             UpsampleBlockT1(in_channels=self._channels[16], out_channels=self._channels[32]),
+             UpsampleBlockT1(in_channels=self._channels[32], out_channels=self._channels[64]),
+             UpsampleBlockT1(in_channels=self._channels[64], out_channels=self._channels[128]),
+             SpectralConv2d(
+                 in_channels=self._channels[128],
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding='same',
+                 bias=False,
+             ),
+             nn.Tanh(),
+         )
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return self._layers(input)
models.py ADDED
@@ -0,0 +1,245 @@
+ import torch
+ import torch.nn as nn
+
+ from typing import Any, Tuple, Union
+
+ from utils import (
+     ImageType,
+     crop_image_part,
+ )
+
+ from layers import (
+     SpectralConv2d,
+     InitLayer,
+     SLEBlock,
+     UpsampleBlockT1,
+     UpsampleBlockT2,
+     DownsampleBlockT1,
+     DownsampleBlockT2,
+     Decoder,
+ )
+
+ from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
+
+
+ class Generator(nn.Module, HugGANModelHubMixin):
+
+     def __init__(self, in_channels: int,
+                        out_channels: int):
+         super().__init__()
+
+         self._channels = {
+             4: 1024,
+             8: 512,
+             16: 256,
+             32: 128,
+             64: 128,
+             128: 64,
+             256: 32,
+             512: 16,
+             1024: 8,
+         }
+
+         self._init = InitLayer(
+             in_channels=in_channels,
+             out_channels=self._channels[4],
+         )
+
+         self._upsample_8 = UpsampleBlockT2(in_channels=self._channels[4], out_channels=self._channels[8] )
+         self._upsample_16 = UpsampleBlockT1(in_channels=self._channels[8], out_channels=self._channels[16] )
+         self._upsample_32 = UpsampleBlockT2(in_channels=self._channels[16], out_channels=self._channels[32] )
+         self._upsample_64 = UpsampleBlockT1(in_channels=self._channels[32], out_channels=self._channels[64] )
+         self._upsample_128 = UpsampleBlockT2(in_channels=self._channels[64], out_channels=self._channels[128] )
+         self._upsample_256 = UpsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[256] )
+         self._upsample_512 = UpsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[512] )
+         self._upsample_1024 = UpsampleBlockT1(in_channels=self._channels[512], out_channels=self._channels[1024])
+
+         self._sle_64 = SLEBlock(in_channels=self._channels[4], out_channels=self._channels[64] )
+         self._sle_128 = SLEBlock(in_channels=self._channels[8], out_channels=self._channels[128])
+         self._sle_256 = SLEBlock(in_channels=self._channels[16], out_channels=self._channels[256])
+         self._sle_512 = SLEBlock(in_channels=self._channels[32], out_channels=self._channels[512])
+
+         self._out_128 = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=self._channels[128],
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding='same',
+                 bias=False,
+             ),
+             nn.Tanh(),
+         )
+
+         self._out_1024 = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=self._channels[1024],
+                 out_channels=out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding='same',
+                 bias=False,
+             ),
+             nn.Tanh(),
+         )
+
+     def forward(self, input: torch.Tensor) -> \
+             Tuple[torch.Tensor, torch.Tensor]:
+         size_4 = self._init(input)
+         size_8 = self._upsample_8(size_4)
+         size_16 = self._upsample_16(size_8)
+         size_32 = self._upsample_32(size_16)
+
+         size_64 = self._sle_64 (size_4, self._upsample_64 (size_32) )
+         size_128 = self._sle_128(size_8, self._upsample_128(size_64) )
+         size_256 = self._sle_256(size_16, self._upsample_256(size_128))
+         size_512 = self._sle_512(size_32, self._upsample_512(size_256))
+
+         size_1024 = self._upsample_1024(size_512)
+
+         out_128 = self._out_128 (size_128)
+         out_1024 = self._out_1024(size_1024)
+         return out_1024, out_128
+
+
+ class Discriminrator(nn.Module, HugGANModelHubMixin):
+
+     def __init__(self, in_channels: int):
+         super().__init__()
+
+         self._channels = {
+             4: 1024,
+             8: 512,
+             16: 256,
+             32: 128,
+             64: 128,
+             128: 64,
+             256: 32,
+             512: 16,
+             1024: 8,
+         }
+
+         self._init = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=self._channels[1024],
+                 kernel_size=4,
+                 stride=2,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.LeakyReLU(negative_slope=0.2),
+             SpectralConv2d(
+                 in_channels=self._channels[1024],
+                 out_channels=self._channels[512],
+                 kernel_size=4,
+                 stride=2,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=self._channels[512]),
+             nn.LeakyReLU(negative_slope=0.2),
+         )
+
+         self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
+         self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
+         self._downsample_64 = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64] )
+         self._downsample_32 = DownsampleBlockT2(in_channels=self._channels[64], out_channels=self._channels[32] )
+         self._downsample_16 = DownsampleBlockT2(in_channels=self._channels[32], out_channels=self._channels[16] )
+
+         self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
+         self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
+         self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])
+
+         self._small_track = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=in_channels,
+                 out_channels=self._channels[256],
+                 kernel_size=4,
+                 stride=2,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.LeakyReLU(negative_slope=0.2),
+             DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
+             DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64] ),
+             DownsampleBlockT1(in_channels=self._channels[64], out_channels=self._channels[32] ),
+         )
+
+         self._features_large = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=self._channels[16] ,
+                 out_channels=self._channels[8],
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(num_features=self._channels[8]),
+             nn.LeakyReLU(negative_slope=0.2),
+             SpectralConv2d(
+                 in_channels=self._channels[8],
+                 out_channels=1,
+                 kernel_size=4,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+             )
+         )
+
+         self._features_small = nn.Sequential(
+             SpectralConv2d(
+                 in_channels=self._channels[32],
+                 out_channels=1,
+                 kernel_size=4,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+             ),
+         )
+
+         self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=3)
+         self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=3)
+         self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=3)
+
+     def forward(self, images_1024: torch.Tensor,
+                       images_128: torch.Tensor,
+                       image_type: ImageType) -> \
+             Union[
+                 torch.Tensor,
+                 Tuple[torch.Tensor, Tuple[Any, Any, Any]]
+             ]:
+         # large track
+
+         down_512 = self._init(images_1024)
+         down_256 = self._downsample_256(down_512)
+         down_128 = self._downsample_128(down_256)
+
+         down_64 = self._downsample_64(down_128)
+         down_64 = self._sle_64(down_512, down_64)
+
+         down_32 = self._downsample_32(down_64)
+         down_32 = self._sle_32(down_256, down_32)
+
+         down_16 = self._downsample_16(down_32)
+         down_16 = self._sle_16(down_128, down_16)
+
+         # small track
+
+         down_small = self._small_track(images_128)
+
+         # features
+
+         features_large = self._features_large(down_16).view(-1)
+         features_small = self._features_small(down_small).view(-1)
+         features = torch.cat([features_large, features_small], dim=0)
+
+         # decoder
+
+         if image_type != ImageType.FAKE:
+             dec_large = self._decoder_large(down_16)
+             dec_small = self._decoder_small(down_small)
+             dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
+             return features, (dec_large, dec_small, dec_piece)
+
+         return features
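For orientation, a small sketch of the Generator's two-resolution output contract relied on by app.py and StyleMix.py above (the shapes follow from the channel table and the eight 2x upsampling stages; this snippet is illustrative and not part of the commit):

    import torch
    from models import Generator

    net = Generator(in_channels=256, out_channels=3).eval()
    noise = torch.randn(1, 256, 1, 1)
    with torch.no_grad():
        out_1024, out_128 = net(noise)
    # out_1024: (1, 3, 1024, 1024) full-resolution image
    # out_128:  (1, 3, 128, 128) low-resolution head used by the Discriminrator's small track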
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ git+https://github.com/huggingface/community-events@main
+ streamlit==1.8.0
+ torch
+ streamlit-lottie
+ streamlit-option-menu
utils.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ import torch.nn as nn
+ from enum import Enum
+
+ import base64
+ import json
+ from io import BytesIO
+ from PIL import Image
+ import requests
+ import re
+ from copy import deepcopy
+
+ class ImageType(Enum):
+     REAL_UP_L = 0
+     REAL_UP_R = 1
+     REAL_DOWN_R = 2
+     REAL_DOWN_L = 3
+     FAKE = 4
+
+
+ def crop_image_part(image: torch.Tensor,
+                     part: ImageType) -> torch.Tensor:
+     size = image.shape[2] // 2
+
+     if part == ImageType.REAL_UP_L:
+         return image[:, :, :size, :size]
+
+     elif part == ImageType.REAL_UP_R:
+         return image[:, :, :size, size:]
+
+     elif part == ImageType.REAL_DOWN_L:
+         return image[:, :, size:, :size]
+
+     elif part == ImageType.REAL_DOWN_R:
+         return image[:, :, size:, size:]
+
+     else:
+         raise ValueError('invalid part')
+
+
+ def init_weights(module: nn.Module):
+     if isinstance(module, nn.Conv2d):
+         torch.nn.init.normal_(module.weight, 0.0, 0.02)
+
+     if isinstance(module, nn.BatchNorm2d):
+         torch.nn.init.normal_(module.weight, 1.0, 0.02)
+         module.bias.data.fill_(0)
+
+ def load_image_from_local(image_path, image_resize=None):
+     image = Image.open(image_path)
+
+     if isinstance(image_resize, tuple):
+         image = image.resize(image_resize)
+     return image
+
+ def load_image_from_url(image_url, rgba_mode=False, image_resize=None, default_image=None):
+     try:
+         image = Image.open(requests.get(image_url, stream=True).raw)
+
+         if rgba_mode:
+             image = image.convert("RGBA")
+
+         if isinstance(image_resize, tuple):
+             image = image.resize(image_resize)
+
+     except Exception as e:
+         image = None
+         if default_image:
+             image = load_image_from_local(default_image, image_resize=image_resize)
+
+     return image
+
+ def image_to_base64(image_array):
+     buffered = BytesIO()
+     image_array.save(buffered, format="PNG")
+     image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return f"data:image/png;base64, {image_b64}"
+
+
+ def copy_G_params(model):
+     flatten = deepcopy(list(p.data for p in model.parameters()))
+     return flatten
+
+
+ def load_params(model, new_param):
+     for p, new_p in zip(model.parameters(), new_param):
+         p.data.copy_(new_p)
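copy_G_params and load_params are the usual snapshot/restore pair for averaged (EMA-style) generator weights; StyleMix.py above calls them back to back on the same model, which leaves the weights unchanged. A brief sketch of the intended swap pattern, assuming a netG model and a separately maintained avg_param_G list as in FastGAN training code (illustrative only):

    backup = copy_G_params(netG)      # snapshot the current weights
    load_params(netG, avg_param_G)    # evaluate with the averaged weights
    # ... generate samples ...
    load_params(netG, backup)         # restore the original weights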