ivanovot commited on
Commit
6b599fd
·
1 Parent(s): a6a4a50
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +21 -0
  2. QuckDrawGAN/__init__.py +2 -0
  3. QuckDrawGAN/__pycache__/__init__.cpython-312.pyc +0 -0
  4. QuckDrawGAN/__pycache__/model.cpython-312.pyc +0 -0
  5. QuckDrawGAN/__pycache__/train.cpython-312.pyc +0 -0
  6. QuckDrawGAN/model.py +89 -0
  7. QuckDrawGAN/train.py +198 -0
  8. QuckDrawGAN/utils/__pycache__/data.cpython-312.pyc +0 -0
  9. QuckDrawGAN/utils/__pycache__/models.cpython-312.pyc +0 -0
  10. QuckDrawGAN/utils/data.py +63 -0
  11. QuckDrawGAN/utils/models.py +107 -0
  12. app.py +71 -0
  13. pretrained_output/images/1.png +0 -0
  14. pretrained_output/images/10.png +0 -0
  15. pretrained_output/images/100.png +0 -0
  16. pretrained_output/images/11.png +0 -0
  17. pretrained_output/images/12.png +0 -0
  18. pretrained_output/images/13.png +0 -0
  19. pretrained_output/images/14.png +0 -0
  20. pretrained_output/images/15.png +0 -0
  21. pretrained_output/images/16.png +0 -0
  22. pretrained_output/images/17.png +0 -0
  23. pretrained_output/images/18.png +0 -0
  24. pretrained_output/images/19.png +0 -0
  25. pretrained_output/images/2.png +0 -0
  26. pretrained_output/images/20.png +0 -0
  27. pretrained_output/images/21.png +0 -0
  28. pretrained_output/images/22.png +0 -0
  29. pretrained_output/images/23.png +0 -0
  30. pretrained_output/images/24.png +0 -0
  31. pretrained_output/images/25.png +0 -0
  32. pretrained_output/images/26.png +0 -0
  33. pretrained_output/images/27.png +0 -0
  34. pretrained_output/images/28.png +0 -0
  35. pretrained_output/images/29.png +0 -0
  36. pretrained_output/images/3.png +0 -0
  37. pretrained_output/images/30.png +0 -0
  38. pretrained_output/images/31.png +0 -0
  39. pretrained_output/images/32.png +0 -0
  40. pretrained_output/images/33.png +0 -0
  41. pretrained_output/images/34.png +0 -0
  42. pretrained_output/images/35.png +0 -0
  43. pretrained_output/images/36.png +0 -0
  44. pretrained_output/images/37.png +0 -0
  45. pretrained_output/images/38.png +0 -0
  46. pretrained_output/images/39.png +0 -0
  47. pretrained_output/images/4.png +0 -0
  48. pretrained_output/images/40.png +0 -0
  49. pretrained_output/images/41.png +0 -0
  50. pretrained_output/images/42.png +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 ivanovot
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
QuckDrawGAN/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .train import train, discriminator_fine_tune
2
+ from .model import Model
QuckDrawGAN/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (250 Bytes). View file
 
QuckDrawGAN/__pycache__/model.cpython-312.pyc ADDED
Binary file (5.46 kB). View file
 
QuckDrawGAN/__pycache__/train.cpython-312.pyc ADDED
Binary file (11.3 kB). View file
 
QuckDrawGAN/model.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import numpy as np
4
+ import torchvision.utils as vutils
5
+ from .utils.models import Generator, Discriminator, latent_dim
6
+ import hashlib
7
+ from PIL import Image
8
+ import warnings
9
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
10
+
11
+
12
+ class Model:
13
+ def __init__(self, generator_path, discriminator_path=None):
14
+ # Определяем устройство для выполнения (GPU или CPU)
15
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ # Инициализация и загрузка генератора
18
+ self.generator = Generator(latent_dim).to(self.device)
19
+ self.generator.load_state_dict(torch.load(generator_path, map_location=self.device, weights_only=True)) # Загружаем веса генератора
20
+ self.generator.eval() # Переводим генератор в режим оценки
21
+
22
+ # Инициализация дискриминатора, если задан путь к его весам
23
+ if discriminator_path:
24
+ self.discriminator = Discriminator().to(self.device)
25
+ self.discriminator.load_state_dict(torch.load(discriminator_path, map_location=self.device, weights_only=True)) # Загружаем веса дискриминатора
26
+ self.discriminator.eval() # Переводим дискриминатор в режим оценки
27
+ else:
28
+ self.discriminator = None # Если дискриминатор не используется
29
+
30
+ def generate(self, n=1, seed=None):
31
+ """Генерирует n изображений. Если дискриминатор загружен, возвращает изображение с наибольшей оценкой дискриминатора."""
32
+ with torch.no_grad(): # Отключаем градиенты для режима оценки
33
+ # Установка сида для воспроизводимости, если задан
34
+ if seed is not None:
35
+ seed_number = int(hashlib.md5(seed.encode()).hexdigest(), 16) % (2**32) # Преобразуем текстовый сид в число
36
+ torch.manual_seed(seed_number) # Устанавливаем сид для генерации
37
+
38
+ # Генерация случайного латентного вектора
39
+ z = torch.randn(n, latent_dim).to(self.device)
40
+
41
+ # Генерация изображений
42
+ gen_imgs = self.generator(z)
43
+
44
+ # Если дискриминатор загружен, выбираем изображение с наилучшей оценкой
45
+ if self.discriminator:
46
+ predictions = self.discriminator(gen_imgs).cpu().numpy().flatten() # Получаем оценки дискриминатора
47
+ max_pred_idx = predictions.argmax() # Находим индекс изображения с максимальной оценкой
48
+ best_img = gen_imgs[max_pred_idx].cpu().squeeze().numpy() # Преобразуем изображение в формат (H, W)
49
+ return best_img # Возвращаем лучшее изображение
50
+ else:
51
+ # Если дискриминатор не загружен, возвращаем первое сгенерированное изображение
52
+ return gen_imgs[0].cpu().squeeze().numpy() # Преобразуем изображение в формат (H, W)
53
+
54
+ if __name__ == "__main__":
55
+ # Определение аргументов командной строки
56
+ parser = argparse.ArgumentParser(description="Generate image using pretrained GAN model")
57
+ parser.add_argument('--generator_path', type=str, required=True, help='Path to generator model weights')
58
+ parser.add_argument('--discriminator_path', type=str, help='Path to discriminator model weights (optional)')
59
+ parser.add_argument('--output_path', type=str, default='result.png', help='Path to save the generated image')
60
+ parser.add_argument('--n', type=int, default=1, help='Number of images to generate')
61
+ parser.add_argument('--seed', type=str, help='Seed for random generation (optional)')
62
+
63
+ args = parser.parse_args()
64
+
65
+ # Инициализация модели и генерация изображения
66
+ model = Model(args.generator_path, args.discriminator_path)
67
+ generated_image = model.generate(n=args.n, seed=args.seed)
68
+
69
+ # Нормализация изображения
70
+ min_val = np.min(generated_image)
71
+ max_val = np.max(generated_image)
72
+
73
+ # Применяем нормализацию
74
+ normalized_image = (generated_image - min_val) / (max_val - min_val) * 255
75
+
76
+ # Приводим к 8-битному формату
77
+ normalized_image = normalized_image.astype(np.uint8)
78
+
79
+ # Проверяем размерность и преобразуем в RGB, если это необходимо
80
+ if normalized_image.ndim == 2: # Если изображение в градациях серого
81
+ # Преобразуем в RGB (64, 64) -> (64, 64, 3)
82
+ normalized_image = np.stack([normalized_image] * 3, axis=-1)
83
+ elif normalized_image.shape[2] == 1: # Если изображение с одним каналом
84
+ # Удаляем канал и преобразуем в RGB
85
+ normalized_image = np.squeeze(normalized_image, axis=2)
86
+
87
+ # Создаем изображение в формате RGB
88
+ img = Image.fromarray(normalized_image, mode='RGB')
89
+ img.save(args.output_path)
QuckDrawGAN/train.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import logging
4
+ import random
5
+ import torch
6
+ import torch.optim as optim
7
+ import torch.nn as nn
8
+ import torchvision.utils as vutils
9
+ from torch.utils.data import DataLoader
10
+ from tqdm import tqdm
11
+ from .utils.models import Generator, Discriminator, latent_dim
12
+ from .utils.data import DrawDataset
13
+
14
+
15
+ def train(epochs, batch_size, data_path, output_path='output', lr_g=0.001, lr_d=0.002, data_max_size=None):
16
+ # Создание директорий для сохранения изображений и моделей
17
+ os.makedirs(os.path.join(output_path, 'images'), exist_ok=True)
18
+ os.makedirs(os.path.join(output_path, 'models'), exist_ok=True)
19
+
20
+ # Инициализация логирования
21
+ log_file = os.path.join(output_path, 'training_logs.log')
22
+ with open(log_file, 'w'):
23
+ pass # Очищаем файл логов
24
+ logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(message)s')
25
+
26
+ # Определение устройства для обучения
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ logging.info(f"Using device: {device}")
29
+
30
+ logging.info("Loading dataset")
31
+ dataset = DrawDataset(data_path, data_max_size)
32
+
33
+ # Инициализация генератора и дискриминатора
34
+ generator = Generator(latent_dim).to(device)
35
+ discriminator = Discriminator().to(device)
36
+
37
+ # Оптимизаторы для генератора и дискриминатора
38
+ optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.9, 0.999))
39
+ optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.9, 0.999))
40
+
41
+ # Функция потерь
42
+ adversarial_loss = nn.L1Loss()
43
+
44
+ # Фиксированные векторы шума для генерации изображений в каждом эпохе
45
+ fix_z = torch.randn(64, latent_dim).to(device)
46
+
47
+ # Загрузка данных
48
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
49
+
50
+ logging.info("Training started")
51
+
52
+ # Основной цикл обучения
53
+ for epoch in range(epochs):
54
+ progress_bar = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{epochs}]", bar_format='{l_bar}{bar:12}{r_bar}')
55
+ generator.train()
56
+
57
+ for i, real_imgs in enumerate(progress_bar):
58
+ real_imgs = real_imgs.to(device)
59
+ batch_size = real_imgs.size(0)
60
+
61
+ # Создание меток для реальных и поддельных изображений
62
+ valid_labels = torch.full((batch_size, 1), random.uniform(0.7, 1)).to(device)
63
+ fake_labels = torch.full((batch_size, 1), random.uniform(0, 0.3)).to(device)
64
+
65
+ # Обновление дискриминатора
66
+ optimizer_D.zero_grad()
67
+
68
+ # Генерация поддельных изображений
69
+ z = torch.randn(batch_size, latent_dim).to(device)
70
+ gen_imgs = generator(z)
71
+
72
+ # Вычисление потерь для реальных и поддельных изображений
73
+ real_preds = discriminator(real_imgs)
74
+ fake_preds = discriminator(gen_imgs.detach())
75
+
76
+ loss_real = adversarial_loss(real_preds, valid_labels)
77
+ loss_fake = adversarial_loss(fake_preds, fake_labels)
78
+
79
+ loss_D = loss_real + loss_fake
80
+ loss_D.backward()
81
+ optimizer_D.step()
82
+
83
+ # Обновление генератора
84
+ optimizer_G.zero_grad()
85
+
86
+ # Генерация новых поддельных изображений
87
+ gen_imgs = generator(z)
88
+
89
+ # Потери генератора на основе предсказаний дискриминатора
90
+ fake_preds_for_gen = discriminator(gen_imgs)
91
+ loss_G = adversarial_loss(fake_preds_for_gen, valid_labels)
92
+
93
+ loss_G.backward()
94
+ optimizer_G.step()
95
+
96
+ # Обновление информации в прогресс-баре
97
+ progress_bar.set_postfix(Loss_D=loss_D.item(), Loss_G=loss_G.item())
98
+
99
+ # Логирование итогов эпохи
100
+ logging.info(f"Epoch [{epoch+1}/{epochs}], Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")
101
+
102
+ # Сохранение изображений и модели
103
+ with torch.no_grad():
104
+ generator.eval()
105
+ gen_imgs = generator(fix_z)
106
+ vutils.save_image(gen_imgs.data, os.path.join(output_path, 'images', f'{epoch+1}.png'), nrow=8, normalize=True)
107
+ torch.save(generator.state_dict(), os.path.join(output_path, 'models', 'generator.pt'))
108
+ torch.save(discriminator.state_dict(), os.path.join(output_path, 'models', 'discriminator.pt'))
109
+
110
+
111
+ def discriminator_fine_tune(generator_file, discriminator_file, data_path, batch_size=64, fine_tune_epochs=10, lr_d=0.002, data_max_size=None):
112
+ # Определение устройства
113
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
+ logging.info(f"Using device: {device}")
115
+
116
+ # Загрузка сохранённых моделей
117
+ generator = Generator(latent_dim).to(device)
118
+ discriminator = Discriminator().to(device)
119
+ generator.load_state_dict(torch.load(generator_file, map_location=device, weights_only=True))
120
+ discriminator.load_state_dict(torch.load(discriminator_file, map_location=device, weights_only=True))
121
+
122
+ # Оптимизатор для дискриминатора
123
+ optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.9, 0.999))
124
+ adversarial_loss = nn.L1Loss()
125
+
126
+ # Загрузка данных
127
+ dataset = DrawDataset(data_path, data_max_size)
128
+ fine_tune_dataloader = DataLoader(dataset, batch_size=batch_size//2, shuffle=True, pin_memory=True)
129
+
130
+ logging.info(f"Fine-tuning discriminator for {fine_tune_epochs} epochs")
131
+
132
+ for epoch in range(fine_tune_epochs):
133
+ progress_bar = tqdm(fine_tune_dataloader, desc=f"Fine-tuning Discriminator [{epoch+1}/{fine_tune_epochs}]", bar_format='{l_bar}{bar:12}{r_bar}')
134
+ discriminator.train()
135
+
136
+ for i, real_imgs in enumerate(progress_bar):
137
+ real_imgs = real_imgs.to(device)
138
+ batch_size = real_imgs.size(0)
139
+
140
+ # Создание меток для реальных и поддельных изображений
141
+ valid_labels = torch.full((batch_size, 1), random.uniform(0.7, 1)).to(device)
142
+ fake_labels = torch.full((batch_size // 2, 1), random.uniform(0, 0.3)).to(device)
143
+
144
+ # Обновление дискриминатора
145
+ optimizer_D.zero_grad()
146
+
147
+ # Генерация поддельных изображений
148
+ z = torch.randn(batch_size // 2, latent_dim).to(device)
149
+ gen_imgs = generator(z)
150
+
151
+ # Потери для реальных и поддельных изображений
152
+ real_preds = discriminator(real_imgs)
153
+ fake_preds = discriminator(gen_imgs.detach())
154
+
155
+ loss_real = adversarial_loss(real_preds, valid_labels)
156
+ loss_fake = adversarial_loss(fake_preds, fake_labels)
157
+
158
+ loss_D = loss_real + loss_fake
159
+ loss_D.backward()
160
+ optimizer_D.step()
161
+
162
+ # Обновление информации в прогресс-баре
163
+ progress_bar.set_postfix(Loss_D=loss_D.item())
164
+
165
+ # Логирование результатов дообучения
166
+ logging.info(f"Fine-tune Epoch [{epoch+1}/{fine_tune_epochs}], Loss_D: {loss_D.item():.4f}")
167
+
168
+ # Сохранение обновлённого дискриминатора
169
+ torch.save(discriminator.state_dict(), os.path.join(os.path.dirname(discriminator_file), 'discriminator_fine_tuned.pt'))
170
+
171
+
172
+ # Определение аргументов командной строки
173
+ if __name__ == "__main__":
174
+ parser = argparse.ArgumentParser(description="Train GAN model with specified parameters.")
175
+
176
+ # Основные аргументы для функции train
177
+ parser.add_argument('--epochs', type=int, default=10, help='Number of epochs for training')
178
+ parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
179
+ parser.add_argument('--data_path', type=str, required=True, help='Path to training data')
180
+ parser.add_argument('--output_path', type=str, default='output', help='Directory to save outputs')
181
+ parser.add_argument('--lr_g', type=float, default=0.001, help='Learning rate for generator')
182
+ parser.add_argument('--lr_d', type=float, default=0.002, help='Learning rate for discriminator')
183
+ parser.add_argument('--data_max_size', type=int, default=None, help='Maximum size of data to use')
184
+
185
+ # Аргументы для дообучения дискриминатора
186
+ parser.add_argument('--fine_tune', action='store_true', help='Fine-tune discriminator')
187
+ parser.add_argument('--generator_file', type=str, help='Path to generator weights for fine-tuning')
188
+ parser.add_argument('--discriminator_file', type=str, help='Path to discriminator weights for fine-tuning')
189
+ parser.add_argument('--fine_tune_epochs', type=int, default=10, help='Number of epochs for fine-tuning discriminator')
190
+
191
+ args = parser.parse_args()
192
+
193
+ if args.fine_tune:
194
+ # Запуск функции дообучения дискриминатора
195
+ discriminator_fine_tune(args.generator_file, args.discriminator_file, args.data_path, args.batch_size, args.fine_tune_epochs, args.lr_d, args.data_max_size)
196
+ else:
197
+ # Запуск основной функции тренировки
198
+ train(args.epochs, args.batch_size, args.data_path, args.output_path, args.lr_g, args.lr_d, args.data_max_size)
QuckDrawGAN/utils/__pycache__/data.cpython-312.pyc ADDED
Binary file (3.5 kB). View file
 
QuckDrawGAN/utils/__pycache__/models.cpython-312.pyc ADDED
Binary file (5.28 kB). View file
 
QuckDrawGAN/utils/data.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import pandas as pd
4
+ from PIL import Image, ImageDraw
5
+ import numpy as np
6
+
7
+ class DrawDataset(Dataset):
8
+ def __init__(self, file_path, data_max_size=None):
9
+ # Загрузка данных из файла формата JSON
10
+ self.data = pd.read_json(file_path, lines=True)
11
+ # Отбор только распознанных рисунков
12
+ self.data = self.data[self.data['recognized'] == True]
13
+ # Ограничение размера набора данных, если указано
14
+ if data_max_size and len(self.data) > data_max_size:
15
+ self.data = self.data[:data_max_size]
16
+
17
+ self.images = self.data['drawing'].values
18
+ self.processed_images = []
19
+
20
+ # Преобразование набора в изображения и нормализация
21
+ for raw_drawing in self.images:
22
+ img = self.stroke_to_image(raw_drawing)
23
+ img = np.array(img).astype(np.float32) / 255.0 # Нормализация изображения в диапазон [0, 1]
24
+ img = torch.from_numpy(img) # Преобразование в тензор PyTorch
25
+ self.processed_images.append(img.unsqueeze(0)) # Добавление оси канала (1, 64, 64)
26
+
27
+ def stroke_to_image(self, raw_drawing):
28
+ # Коэффициенты для изменения размера изображения и его улучшения
29
+ scale_factor = 0.22 # Масштаб для уменьшения координат рисунков
30
+ upscale_factor = 8 # Коэффициент увеличения для получения плавных линий
31
+ original_size = 64 # Окончательный размер изображения
32
+ large_size = original_size * upscale_factor # Увеличенный размер для рисования линий
33
+
34
+ # Преобразование координат линий с масштабированием и смещением
35
+ polylines = (
36
+ zip([(x + 25) * scale_factor * upscale_factor for x in polyline[0]],
37
+ [(y + 25) * scale_factor * upscale_factor for y in polyline[1]])
38
+ for polyline in raw_drawing if len(polyline) == 2
39
+ )
40
+
41
+ # Преобразуем набор линий в список для последующего рисования
42
+ polylines_list = [list(polyline) for polyline in polylines]
43
+
44
+ # Создание пустого увеличенного изображения
45
+ pil_img = Image.new("L", (large_size, large_size), 255) # Черно-белое изображение, белый фон
46
+ d = ImageDraw.Draw(pil_img)
47
+
48
+ # Рисование линий с учетом масштабирования и увеличенной толщины
49
+ for polyline in polylines_list:
50
+ d.line(polyline, fill=0, width=int(1.5 * upscale_factor)) # Линии черного цвета
51
+
52
+ # Масштабирование изображения обратно до 64x64 с использованием LANCZOS для сглаживания
53
+ pil_img = pil_img.resize((original_size, original_size), Image.Resampling.LANCZOS)
54
+
55
+ return pil_img
56
+
57
+ def __len__(self):
58
+ # Возвращает количество изображений в наборе данных
59
+ return len(self.images)
60
+
61
+ def __getitem__(self, idx):
62
+ # Возвращает обработанное изображение по индексу
63
+ return self.processed_images[idx]
QuckDrawGAN/utils/models.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ latent_dim = 100 # Размерность латентного пространства (входного шума для генератора)
4
+
5
+ # Класс генератора для генерации изображений из латентного пространства
6
+ class Generator(nn.Module):
7
+ def __init__(self, latent_dim):
8
+ super(Generator, self).__init__()
9
+
10
+ # Размеры начального изображения и количество каналов для начала транспонированных операций
11
+ self.init_size = 4 # Размер изображения после первой линейной трансформации
12
+ self.start_channels = 512 # Количество каналов на первом этапе
13
+
14
+ # Последовательная модель генератора
15
+ self.model = nn.Sequential(
16
+ # Линейное преобразование латентного вектора в развернутую форму для дальнейшего увеличения
17
+ nn.Linear(latent_dim, self.start_channels * self.init_size ** 2),
18
+ nn.BatchNorm1d(self.start_channels * self.init_size ** 2, 0.8),
19
+ nn.LeakyReLU(0.2, inplace=True),
20
+
21
+ # Преобразование в 4D тензор для начала операций с изображениями
22
+ nn.Unflatten(1, (self.start_channels, self.init_size, self.init_size)),
23
+
24
+ # Начало операций с изображением (увеличение размера)
25
+ nn.Upsample(scale_factor=2), # Увеличение размера изображения в 2 раза
26
+
27
+ # Сверточные слои с уменьшением количества каналов и последующими нелинейностями
28
+ nn.Conv2d(self.start_channels, self.start_channels // 3, 3, stride=1, padding=1),
29
+ nn.BatchNorm2d(self.start_channels // 3, 0.8),
30
+ nn.LeakyReLU(0.2, inplace=True),
31
+
32
+ nn.Upsample(scale_factor=2), # Еще одно увеличение
33
+
34
+ nn.Conv2d(self.start_channels // 3, self.start_channels // 4, 3, stride=1, padding=1),
35
+ nn.BatchNorm2d(self.start_channels // 4, 0.8),
36
+ nn.LeakyReLU(0.2, inplace=True),
37
+
38
+ nn.Upsample(scale_factor=2), # Третье увеличение
39
+
40
+ nn.Conv2d(self.start_channels // 4, self.start_channels // 6, 3, stride=1, padding=1),
41
+ nn.BatchNorm2d(self.start_channels // 6, 0.8),
42
+ nn.LeakyReLU(0.2, inplace=True),
43
+
44
+ nn.Upsample(scale_factor=2), # Четвертое увеличение
45
+
46
+ nn.Conv2d(self.start_channels // 6, self.start_channels // 8, 3, stride=1, padding=1),
47
+ nn.BatchNorm2d(self.start_channels // 8, 0.8),
48
+ nn.LeakyReLU(0.2, inplace=True),
49
+
50
+ # Последний сверточный слой для вывода изображения размером 1xWxH с функцией активации Тангенс
51
+ nn.Conv2d(self.start_channels // 8, 1, 3, stride=1, padding=1),
52
+ nn.Tanh() # Приведение значений пикселей в диапазон [-1, 1]
53
+ )
54
+
55
+ def forward(self, z):
56
+ # Прямое распространение через сеть генератора
57
+ out = self.model(z)
58
+ return out # Возвращаем сгенерированное изображение
59
+
60
+
61
+ # Класс дискриминатора для различения реальных и сгенерированных изображений
62
+ class Discriminator(nn.Module):
63
+ def __init__(self):
64
+ super(Discriminator, self).__init__()
65
+
66
+ # Последовательная модель дискриминатора
67
+ self.model = nn.Sequential(
68
+ # Первый сверточный блок
69
+ nn.Conv2d(1, 64, 3, stride=2, padding=1), # Уменьшает размер изображения до (64, 32, 32)
70
+ nn.BatchNorm2d(64, 0.8),
71
+ nn.LeakyReLU(0.2, inplace=True),
72
+ nn.Dropout(0.3), # Вероятность выключения нейронов для регуляризации
73
+
74
+ # Второй сверточный блок
75
+ nn.Conv2d(64, 128, 3, stride=2, padding=1), # Уменьшает размер до (128, 16, 16)
76
+ nn.BatchNorm2d(128, 0.8),
77
+ nn.LeakyReLU(0.2, inplace=True),
78
+ nn.Dropout(0.3),
79
+
80
+ # Третий сверточный блок
81
+ nn.Conv2d(128, 256, 3, stride=2, padding=1), # Уменьшает размер до (256, 8, 8)
82
+ nn.BatchNorm2d(256, 0.8),
83
+ nn.LeakyReLU(0.2, inplace=True),
84
+ nn.Dropout(0.3),
85
+
86
+ # Четвертый сверточный блок
87
+ nn.Conv2d(256, 256, 3, stride=1, padding=1), # Поддерживает размер (256, 8, 8)
88
+ nn.BatchNorm2d(256, 0.8),
89
+ nn.LeakyReLU(0.2, inplace=True),
90
+ nn.MaxPool2d(2), # Уменьшает размер до (256, 4, 4)
91
+ nn.Dropout(0.3),
92
+
93
+ # Пятый сверточный блок
94
+ nn.Conv2d(256, 512, 3, stride=1, padding=1), # Уменьшает размер до (512, 2, 2)
95
+ nn.BatchNorm2d(512, 0.8),
96
+ nn.LeakyReLU(0.2, inplace=True),
97
+ nn.MaxPool2d(2), # Размер до (512, 1, 1)
98
+
99
+ # Преобразование в плоский вектор для классификации
100
+ nn.Flatten(), # Преобразует изображение в вектор (512 * 2 * 2 = 2048)
101
+ nn.Linear(512 * 2 * 2, 1) # Полносвязный слой для получения одного скалярного выхода
102
+ )
103
+
104
+ def forward(self, img):
105
+ # Прямое распространение через сеть дискриминатора
106
+ out = self.model(img)
107
+ return out # Возвращаем вероятность принадлежности к классу "реальное" или "сгенерированное"
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import QuckDrawGAN as qd
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # Загрузим модель
7
+ generator_file = r'pretrained_output/models/generator.pt'
8
+ discriminator_file = r'pretrained_output/models/discriminator_fine_tuned.pt'
9
+
10
+ # Создаем объект модели
11
+ model = qd.Model(generator_file, discriminator_file)
12
+
13
+ # Функция для генерации изображения с учетом сида, нормализацией и изменением размера
14
+ def generate_image(n_images=16, seed=""):
15
+ # Если сид не задан, не передаем его в модель
16
+ if seed == "":
17
+ seed = None
18
+ best_image = model.generate(n_images, seed) # Генерация с учетом сида
19
+
20
+ # Нормализация: находим минимум и максимум в изображении
21
+ best_image_min = np.min(best_image)
22
+ best_image_max = np.max(best_image)
23
+
24
+ # Нормализуем изображение, чтобы значения были в диапазоне от 0 до 255
25
+ normalized_image = 255 * (best_image - best_image_min) / (best_image_max - best_image_min)
26
+
27
+ # Преобразуем изображение в формат, подходящий для отображения
28
+ pil_image = Image.fromarray(normalized_image.astype(np.uint8)) # Преобразуем в uint8 для отображения
29
+ pil_image = pil_image.resize((256, 256), Image.Resampling.LANCZOS) # Ресайз изображения до 256x256
30
+
31
+ return pil_image
32
+
33
+ # Создаем интерфейс Gradio с помощью Blocks (для большей гибкости)
34
+ with gr.Blocks() as interface:
35
+ gr.Markdown("# Генератор изображений с использованием QuckDrawGAN")
36
+ gr.Markdown("Этот интерфейс позволяет генерировать изображения с помощью модели QuckDrawGAN. Настройте количество генерируемых изображений и задайте сид для повторяемости.")
37
+
38
+ with gr.Row():
39
+ # Блок для изображения сверху
40
+ generated_image = gr.Image(type="pil", label="Сгенерированное изображение", elem_id="generated_image", scale=2) # Увеличиваем масштаб изображения
41
+
42
+ with gr.Row():
43
+ # Блок параметров и кнопки
44
+ with gr.Column():
45
+ seed_input = gr.Textbox(value="", label="Сид (опционально)", interactive=True)
46
+ num_images = gr.Slider(minimum=1, maximum=1024, value=16, label="Количество изображений для генерации", interactive=True, step=1)
47
+
48
+ # Кнопка генерации изображения справа
49
+ generate_button = gr.Button("Сгенерировать")
50
+
51
+ # Логика для автогенерации при изменении параметров
52
+ seed_input.change(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
53
+ num_images.change(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
54
+
55
+ # Логика для кнопки генерации
56
+ generate_button.click(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
57
+
58
+ # Автогенерация при старте
59
+ interface.load(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
60
+
61
+ # Стилизация блока изображения (увеличение размера блока)
62
+ interface.css = """
63
+ #generated_image {
64
+ width: 400px;
65
+ height: 400px;
66
+ margin-top: 20px;
67
+ }
68
+ """
69
+
70
+ # Запуск интерфейса
71
+ interface.launch()
pretrained_output/images/1.png ADDED
pretrained_output/images/10.png ADDED
pretrained_output/images/100.png ADDED
pretrained_output/images/11.png ADDED
pretrained_output/images/12.png ADDED
pretrained_output/images/13.png ADDED
pretrained_output/images/14.png ADDED
pretrained_output/images/15.png ADDED
pretrained_output/images/16.png ADDED
pretrained_output/images/17.png ADDED
pretrained_output/images/18.png ADDED
pretrained_output/images/19.png ADDED
pretrained_output/images/2.png ADDED
pretrained_output/images/20.png ADDED
pretrained_output/images/21.png ADDED
pretrained_output/images/22.png ADDED
pretrained_output/images/23.png ADDED
pretrained_output/images/24.png ADDED
pretrained_output/images/25.png ADDED
pretrained_output/images/26.png ADDED
pretrained_output/images/27.png ADDED
pretrained_output/images/28.png ADDED
pretrained_output/images/29.png ADDED
pretrained_output/images/3.png ADDED
pretrained_output/images/30.png ADDED
pretrained_output/images/31.png ADDED
pretrained_output/images/32.png ADDED
pretrained_output/images/33.png ADDED
pretrained_output/images/34.png ADDED
pretrained_output/images/35.png ADDED
pretrained_output/images/36.png ADDED
pretrained_output/images/37.png ADDED
pretrained_output/images/38.png ADDED
pretrained_output/images/39.png ADDED
pretrained_output/images/4.png ADDED
pretrained_output/images/40.png ADDED
pretrained_output/images/41.png ADDED
pretrained_output/images/42.png ADDED