Spaces:
Sleeping
Sleeping
ivanovot
commited on
Commit
·
6b599fd
1
Parent(s):
a6a4a50
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +21 -0
- QuckDrawGAN/__init__.py +2 -0
- QuckDrawGAN/__pycache__/__init__.cpython-312.pyc +0 -0
- QuckDrawGAN/__pycache__/model.cpython-312.pyc +0 -0
- QuckDrawGAN/__pycache__/train.cpython-312.pyc +0 -0
- QuckDrawGAN/model.py +89 -0
- QuckDrawGAN/train.py +198 -0
- QuckDrawGAN/utils/__pycache__/data.cpython-312.pyc +0 -0
- QuckDrawGAN/utils/__pycache__/models.cpython-312.pyc +0 -0
- QuckDrawGAN/utils/data.py +63 -0
- QuckDrawGAN/utils/models.py +107 -0
- app.py +71 -0
- pretrained_output/images/1.png +0 -0
- pretrained_output/images/10.png +0 -0
- pretrained_output/images/100.png +0 -0
- pretrained_output/images/11.png +0 -0
- pretrained_output/images/12.png +0 -0
- pretrained_output/images/13.png +0 -0
- pretrained_output/images/14.png +0 -0
- pretrained_output/images/15.png +0 -0
- pretrained_output/images/16.png +0 -0
- pretrained_output/images/17.png +0 -0
- pretrained_output/images/18.png +0 -0
- pretrained_output/images/19.png +0 -0
- pretrained_output/images/2.png +0 -0
- pretrained_output/images/20.png +0 -0
- pretrained_output/images/21.png +0 -0
- pretrained_output/images/22.png +0 -0
- pretrained_output/images/23.png +0 -0
- pretrained_output/images/24.png +0 -0
- pretrained_output/images/25.png +0 -0
- pretrained_output/images/26.png +0 -0
- pretrained_output/images/27.png +0 -0
- pretrained_output/images/28.png +0 -0
- pretrained_output/images/29.png +0 -0
- pretrained_output/images/3.png +0 -0
- pretrained_output/images/30.png +0 -0
- pretrained_output/images/31.png +0 -0
- pretrained_output/images/32.png +0 -0
- pretrained_output/images/33.png +0 -0
- pretrained_output/images/34.png +0 -0
- pretrained_output/images/35.png +0 -0
- pretrained_output/images/36.png +0 -0
- pretrained_output/images/37.png +0 -0
- pretrained_output/images/38.png +0 -0
- pretrained_output/images/39.png +0 -0
- pretrained_output/images/4.png +0 -0
- pretrained_output/images/40.png +0 -0
- pretrained_output/images/41.png +0 -0
- pretrained_output/images/42.png +0 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 ivanovot
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
QuckDrawGAN/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .train import train, discriminator_fine_tune
|
2 |
+
from .model import Model
|
QuckDrawGAN/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (250 Bytes). View file
|
|
QuckDrawGAN/__pycache__/model.cpython-312.pyc
ADDED
Binary file (5.46 kB). View file
|
|
QuckDrawGAN/__pycache__/train.cpython-312.pyc
ADDED
Binary file (11.3 kB). View file
|
|
QuckDrawGAN/model.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torchvision.utils as vutils
|
5 |
+
from .utils.models import Generator, Discriminator, latent_dim
|
6 |
+
import hashlib
|
7 |
+
from PIL import Image
|
8 |
+
import warnings
|
9 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
10 |
+
|
11 |
+
|
12 |
+
class Model:
|
13 |
+
def __init__(self, generator_path, discriminator_path=None):
|
14 |
+
# Определяем устройство для выполнения (GPU или CPU)
|
15 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
16 |
+
|
17 |
+
# Инициализация и загрузка генератора
|
18 |
+
self.generator = Generator(latent_dim).to(self.device)
|
19 |
+
self.generator.load_state_dict(torch.load(generator_path, map_location=self.device, weights_only=True)) # Загружаем веса генератора
|
20 |
+
self.generator.eval() # Переводим генератор в режим оценки
|
21 |
+
|
22 |
+
# Инициализация дискриминатора, если задан путь к его весам
|
23 |
+
if discriminator_path:
|
24 |
+
self.discriminator = Discriminator().to(self.device)
|
25 |
+
self.discriminator.load_state_dict(torch.load(discriminator_path, map_location=self.device, weights_only=True)) # Загружаем веса дискриминатора
|
26 |
+
self.discriminator.eval() # Переводим дискриминатор в режим оценки
|
27 |
+
else:
|
28 |
+
self.discriminator = None # Если дискриминатор не используется
|
29 |
+
|
30 |
+
def generate(self, n=1, seed=None):
|
31 |
+
"""Генерирует n изображений. Если дискриминатор загружен, возвращает изображение с наибольшей оценкой дискриминатора."""
|
32 |
+
with torch.no_grad(): # Отключаем градиенты для режима оценки
|
33 |
+
# Установка сида для воспроизводимости, если задан
|
34 |
+
if seed is not None:
|
35 |
+
seed_number = int(hashlib.md5(seed.encode()).hexdigest(), 16) % (2**32) # Преобразуем текстовый сид в число
|
36 |
+
torch.manual_seed(seed_number) # Устанавливаем сид для генерации
|
37 |
+
|
38 |
+
# Генерация случайного латентного вектора
|
39 |
+
z = torch.randn(n, latent_dim).to(self.device)
|
40 |
+
|
41 |
+
# Генерация изображений
|
42 |
+
gen_imgs = self.generator(z)
|
43 |
+
|
44 |
+
# Если дискриминатор загружен, выбираем изображение с наилучшей оценкой
|
45 |
+
if self.discriminator:
|
46 |
+
predictions = self.discriminator(gen_imgs).cpu().numpy().flatten() # Получаем оценки дискриминатора
|
47 |
+
max_pred_idx = predictions.argmax() # Находим индекс изображения с максимальной оценкой
|
48 |
+
best_img = gen_imgs[max_pred_idx].cpu().squeeze().numpy() # Преобразуем изображение в формат (H, W)
|
49 |
+
return best_img # Возвращаем лучшее изображение
|
50 |
+
else:
|
51 |
+
# Если дискриминатор не загружен, возвращаем первое сгенерированное изображение
|
52 |
+
return gen_imgs[0].cpu().squeeze().numpy() # Преобразуем изображение в формат (H, W)
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
# Определение аргументов командной строки
|
56 |
+
parser = argparse.ArgumentParser(description="Generate image using pretrained GAN model")
|
57 |
+
parser.add_argument('--generator_path', type=str, required=True, help='Path to generator model weights')
|
58 |
+
parser.add_argument('--discriminator_path', type=str, help='Path to discriminator model weights (optional)')
|
59 |
+
parser.add_argument('--output_path', type=str, default='result.png', help='Path to save the generated image')
|
60 |
+
parser.add_argument('--n', type=int, default=1, help='Number of images to generate')
|
61 |
+
parser.add_argument('--seed', type=str, help='Seed for random generation (optional)')
|
62 |
+
|
63 |
+
args = parser.parse_args()
|
64 |
+
|
65 |
+
# Инициализация модели и генерация изображения
|
66 |
+
model = Model(args.generator_path, args.discriminator_path)
|
67 |
+
generated_image = model.generate(n=args.n, seed=args.seed)
|
68 |
+
|
69 |
+
# Нормализация изображения
|
70 |
+
min_val = np.min(generated_image)
|
71 |
+
max_val = np.max(generated_image)
|
72 |
+
|
73 |
+
# Применяем нормализацию
|
74 |
+
normalized_image = (generated_image - min_val) / (max_val - min_val) * 255
|
75 |
+
|
76 |
+
# Приводим к 8-битному формату
|
77 |
+
normalized_image = normalized_image.astype(np.uint8)
|
78 |
+
|
79 |
+
# Проверяем размерность и преобразуем в RGB, если это необходимо
|
80 |
+
if normalized_image.ndim == 2: # Если изображение в градациях серого
|
81 |
+
# Преобразуем в RGB (64, 64) -> (64, 64, 3)
|
82 |
+
normalized_image = np.stack([normalized_image] * 3, axis=-1)
|
83 |
+
elif normalized_image.shape[2] == 1: # Если изображение с одним каналом
|
84 |
+
# Удаляем канал и преобразуем в RGB
|
85 |
+
normalized_image = np.squeeze(normalized_image, axis=2)
|
86 |
+
|
87 |
+
# Создаем изображение в формате RGB
|
88 |
+
img = Image.fromarray(normalized_image, mode='RGB')
|
89 |
+
img.save(args.output_path)
|
QuckDrawGAN/train.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import logging
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torch.optim as optim
|
7 |
+
import torch.nn as nn
|
8 |
+
import torchvision.utils as vutils
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from tqdm import tqdm
|
11 |
+
from .utils.models import Generator, Discriminator, latent_dim
|
12 |
+
from .utils.data import DrawDataset
|
13 |
+
|
14 |
+
|
15 |
+
def train(epochs, batch_size, data_path, output_path='output', lr_g=0.001, lr_d=0.002, data_max_size=None):
|
16 |
+
# Создание директорий для сохранения изображений и моделей
|
17 |
+
os.makedirs(os.path.join(output_path, 'images'), exist_ok=True)
|
18 |
+
os.makedirs(os.path.join(output_path, 'models'), exist_ok=True)
|
19 |
+
|
20 |
+
# Инициализация логирования
|
21 |
+
log_file = os.path.join(output_path, 'training_logs.log')
|
22 |
+
with open(log_file, 'w'):
|
23 |
+
pass # Очищаем файл логов
|
24 |
+
logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(message)s')
|
25 |
+
|
26 |
+
# Определение устройства для обучения
|
27 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
+
logging.info(f"Using device: {device}")
|
29 |
+
|
30 |
+
logging.info("Loading dataset")
|
31 |
+
dataset = DrawDataset(data_path, data_max_size)
|
32 |
+
|
33 |
+
# Инициализация генератора и дискриминатора
|
34 |
+
generator = Generator(latent_dim).to(device)
|
35 |
+
discriminator = Discriminator().to(device)
|
36 |
+
|
37 |
+
# Оптимизаторы для генератора и дискриминатора
|
38 |
+
optimizer_G = optim.Adam(generator.parameters(), lr=lr_g, betas=(0.9, 0.999))
|
39 |
+
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.9, 0.999))
|
40 |
+
|
41 |
+
# Функция потерь
|
42 |
+
adversarial_loss = nn.L1Loss()
|
43 |
+
|
44 |
+
# Фиксированные векторы шума для генерации изображений в каждом эпохе
|
45 |
+
fix_z = torch.randn(64, latent_dim).to(device)
|
46 |
+
|
47 |
+
# Загрузка данных
|
48 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
|
49 |
+
|
50 |
+
logging.info("Training started")
|
51 |
+
|
52 |
+
# Основной цикл обучения
|
53 |
+
for epoch in range(epochs):
|
54 |
+
progress_bar = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{epochs}]", bar_format='{l_bar}{bar:12}{r_bar}')
|
55 |
+
generator.train()
|
56 |
+
|
57 |
+
for i, real_imgs in enumerate(progress_bar):
|
58 |
+
real_imgs = real_imgs.to(device)
|
59 |
+
batch_size = real_imgs.size(0)
|
60 |
+
|
61 |
+
# Создание меток для реальных и поддельных изображений
|
62 |
+
valid_labels = torch.full((batch_size, 1), random.uniform(0.7, 1)).to(device)
|
63 |
+
fake_labels = torch.full((batch_size, 1), random.uniform(0, 0.3)).to(device)
|
64 |
+
|
65 |
+
# Обновление дискриминатора
|
66 |
+
optimizer_D.zero_grad()
|
67 |
+
|
68 |
+
# Генерация поддельных изображений
|
69 |
+
z = torch.randn(batch_size, latent_dim).to(device)
|
70 |
+
gen_imgs = generator(z)
|
71 |
+
|
72 |
+
# Вычисление потерь для реальных и поддельных изображений
|
73 |
+
real_preds = discriminator(real_imgs)
|
74 |
+
fake_preds = discriminator(gen_imgs.detach())
|
75 |
+
|
76 |
+
loss_real = adversarial_loss(real_preds, valid_labels)
|
77 |
+
loss_fake = adversarial_loss(fake_preds, fake_labels)
|
78 |
+
|
79 |
+
loss_D = loss_real + loss_fake
|
80 |
+
loss_D.backward()
|
81 |
+
optimizer_D.step()
|
82 |
+
|
83 |
+
# Обновление генератора
|
84 |
+
optimizer_G.zero_grad()
|
85 |
+
|
86 |
+
# Генерация новых поддельных изображений
|
87 |
+
gen_imgs = generator(z)
|
88 |
+
|
89 |
+
# Потери генератора на основе предсказаний дискриминатора
|
90 |
+
fake_preds_for_gen = discriminator(gen_imgs)
|
91 |
+
loss_G = adversarial_loss(fake_preds_for_gen, valid_labels)
|
92 |
+
|
93 |
+
loss_G.backward()
|
94 |
+
optimizer_G.step()
|
95 |
+
|
96 |
+
# Обновление информации в прогресс-баре
|
97 |
+
progress_bar.set_postfix(Loss_D=loss_D.item(), Loss_G=loss_G.item())
|
98 |
+
|
99 |
+
# Логирование итогов эпохи
|
100 |
+
logging.info(f"Epoch [{epoch+1}/{epochs}], Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")
|
101 |
+
|
102 |
+
# Сохранение изображений и модели
|
103 |
+
with torch.no_grad():
|
104 |
+
generator.eval()
|
105 |
+
gen_imgs = generator(fix_z)
|
106 |
+
vutils.save_image(gen_imgs.data, os.path.join(output_path, 'images', f'{epoch+1}.png'), nrow=8, normalize=True)
|
107 |
+
torch.save(generator.state_dict(), os.path.join(output_path, 'models', 'generator.pt'))
|
108 |
+
torch.save(discriminator.state_dict(), os.path.join(output_path, 'models', 'discriminator.pt'))
|
109 |
+
|
110 |
+
|
111 |
+
def discriminator_fine_tune(generator_file, discriminator_file, data_path, batch_size=64, fine_tune_epochs=10, lr_d=0.002, data_max_size=None):
|
112 |
+
# Определение устройства
|
113 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
114 |
+
logging.info(f"Using device: {device}")
|
115 |
+
|
116 |
+
# Загрузка сохранённых моделей
|
117 |
+
generator = Generator(latent_dim).to(device)
|
118 |
+
discriminator = Discriminator().to(device)
|
119 |
+
generator.load_state_dict(torch.load(generator_file, map_location=device, weights_only=True))
|
120 |
+
discriminator.load_state_dict(torch.load(discriminator_file, map_location=device, weights_only=True))
|
121 |
+
|
122 |
+
# Оптимизатор для дискриминатора
|
123 |
+
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.9, 0.999))
|
124 |
+
adversarial_loss = nn.L1Loss()
|
125 |
+
|
126 |
+
# Загрузка данных
|
127 |
+
dataset = DrawDataset(data_path, data_max_size)
|
128 |
+
fine_tune_dataloader = DataLoader(dataset, batch_size=batch_size//2, shuffle=True, pin_memory=True)
|
129 |
+
|
130 |
+
logging.info(f"Fine-tuning discriminator for {fine_tune_epochs} epochs")
|
131 |
+
|
132 |
+
for epoch in range(fine_tune_epochs):
|
133 |
+
progress_bar = tqdm(fine_tune_dataloader, desc=f"Fine-tuning Discriminator [{epoch+1}/{fine_tune_epochs}]", bar_format='{l_bar}{bar:12}{r_bar}')
|
134 |
+
discriminator.train()
|
135 |
+
|
136 |
+
for i, real_imgs in enumerate(progress_bar):
|
137 |
+
real_imgs = real_imgs.to(device)
|
138 |
+
batch_size = real_imgs.size(0)
|
139 |
+
|
140 |
+
# Создание меток для реальных и поддельных изображений
|
141 |
+
valid_labels = torch.full((batch_size, 1), random.uniform(0.7, 1)).to(device)
|
142 |
+
fake_labels = torch.full((batch_size // 2, 1), random.uniform(0, 0.3)).to(device)
|
143 |
+
|
144 |
+
# Обновление дискриминатора
|
145 |
+
optimizer_D.zero_grad()
|
146 |
+
|
147 |
+
# Генерация поддельных изображений
|
148 |
+
z = torch.randn(batch_size // 2, latent_dim).to(device)
|
149 |
+
gen_imgs = generator(z)
|
150 |
+
|
151 |
+
# Потери для реальных и поддельных изображений
|
152 |
+
real_preds = discriminator(real_imgs)
|
153 |
+
fake_preds = discriminator(gen_imgs.detach())
|
154 |
+
|
155 |
+
loss_real = adversarial_loss(real_preds, valid_labels)
|
156 |
+
loss_fake = adversarial_loss(fake_preds, fake_labels)
|
157 |
+
|
158 |
+
loss_D = loss_real + loss_fake
|
159 |
+
loss_D.backward()
|
160 |
+
optimizer_D.step()
|
161 |
+
|
162 |
+
# Обновление информации в прогресс-баре
|
163 |
+
progress_bar.set_postfix(Loss_D=loss_D.item())
|
164 |
+
|
165 |
+
# Логирование результатов дообучения
|
166 |
+
logging.info(f"Fine-tune Epoch [{epoch+1}/{fine_tune_epochs}], Loss_D: {loss_D.item():.4f}")
|
167 |
+
|
168 |
+
# Сохранение обновлённого дискриминатора
|
169 |
+
torch.save(discriminator.state_dict(), os.path.join(os.path.dirname(discriminator_file), 'discriminator_fine_tuned.pt'))
|
170 |
+
|
171 |
+
|
172 |
+
# Определение аргументов командной строки
|
173 |
+
if __name__ == "__main__":
|
174 |
+
parser = argparse.ArgumentParser(description="Train GAN model with specified parameters.")
|
175 |
+
|
176 |
+
# Основные аргументы для функции train
|
177 |
+
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs for training')
|
178 |
+
parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
|
179 |
+
parser.add_argument('--data_path', type=str, required=True, help='Path to training data')
|
180 |
+
parser.add_argument('--output_path', type=str, default='output', help='Directory to save outputs')
|
181 |
+
parser.add_argument('--lr_g', type=float, default=0.001, help='Learning rate for generator')
|
182 |
+
parser.add_argument('--lr_d', type=float, default=0.002, help='Learning rate for discriminator')
|
183 |
+
parser.add_argument('--data_max_size', type=int, default=None, help='Maximum size of data to use')
|
184 |
+
|
185 |
+
# Аргументы для дообучения дискриминатора
|
186 |
+
parser.add_argument('--fine_tune', action='store_true', help='Fine-tune discriminator')
|
187 |
+
parser.add_argument('--generator_file', type=str, help='Path to generator weights for fine-tuning')
|
188 |
+
parser.add_argument('--discriminator_file', type=str, help='Path to discriminator weights for fine-tuning')
|
189 |
+
parser.add_argument('--fine_tune_epochs', type=int, default=10, help='Number of epochs for fine-tuning discriminator')
|
190 |
+
|
191 |
+
args = parser.parse_args()
|
192 |
+
|
193 |
+
if args.fine_tune:
|
194 |
+
# Запуск функции дообучения дискриминатора
|
195 |
+
discriminator_fine_tune(args.generator_file, args.discriminator_file, args.data_path, args.batch_size, args.fine_tune_epochs, args.lr_d, args.data_max_size)
|
196 |
+
else:
|
197 |
+
# Запуск основной функции тренировки
|
198 |
+
train(args.epochs, args.batch_size, args.data_path, args.output_path, args.lr_g, args.lr_d, args.data_max_size)
|
QuckDrawGAN/utils/__pycache__/data.cpython-312.pyc
ADDED
Binary file (3.5 kB). View file
|
|
QuckDrawGAN/utils/__pycache__/models.cpython-312.pyc
ADDED
Binary file (5.28 kB). View file
|
|
QuckDrawGAN/utils/data.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import pandas as pd
|
4 |
+
from PIL import Image, ImageDraw
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
class DrawDataset(Dataset):
|
8 |
+
def __init__(self, file_path, data_max_size=None):
|
9 |
+
# Загрузка данных из файла формата JSON
|
10 |
+
self.data = pd.read_json(file_path, lines=True)
|
11 |
+
# Отбор только распознанных рисунков
|
12 |
+
self.data = self.data[self.data['recognized'] == True]
|
13 |
+
# Ограничение размера набора данных, если указано
|
14 |
+
if data_max_size and len(self.data) > data_max_size:
|
15 |
+
self.data = self.data[:data_max_size]
|
16 |
+
|
17 |
+
self.images = self.data['drawing'].values
|
18 |
+
self.processed_images = []
|
19 |
+
|
20 |
+
# Преобразование набора в изображения и нормализация
|
21 |
+
for raw_drawing in self.images:
|
22 |
+
img = self.stroke_to_image(raw_drawing)
|
23 |
+
img = np.array(img).astype(np.float32) / 255.0 # Нормализация изображения в диапазон [0, 1]
|
24 |
+
img = torch.from_numpy(img) # Преобразование в тензор PyTorch
|
25 |
+
self.processed_images.append(img.unsqueeze(0)) # Добавление оси канала (1, 64, 64)
|
26 |
+
|
27 |
+
def stroke_to_image(self, raw_drawing):
|
28 |
+
# Коэффициенты для изменения размера изображения и его улучшения
|
29 |
+
scale_factor = 0.22 # Масштаб для уменьшения координат рисунков
|
30 |
+
upscale_factor = 8 # Коэффициент увеличения для получения плавных линий
|
31 |
+
original_size = 64 # Окончательный размер изображения
|
32 |
+
large_size = original_size * upscale_factor # Увеличенный размер для рисования линий
|
33 |
+
|
34 |
+
# Преобразование координат линий с масштабированием и смещением
|
35 |
+
polylines = (
|
36 |
+
zip([(x + 25) * scale_factor * upscale_factor for x in polyline[0]],
|
37 |
+
[(y + 25) * scale_factor * upscale_factor for y in polyline[1]])
|
38 |
+
for polyline in raw_drawing if len(polyline) == 2
|
39 |
+
)
|
40 |
+
|
41 |
+
# Преобразуем набор линий в список для последующего рисования
|
42 |
+
polylines_list = [list(polyline) for polyline in polylines]
|
43 |
+
|
44 |
+
# Создание пустого увеличенного изображения
|
45 |
+
pil_img = Image.new("L", (large_size, large_size), 255) # Черно-белое изображение, белый фон
|
46 |
+
d = ImageDraw.Draw(pil_img)
|
47 |
+
|
48 |
+
# Рисование линий с учетом масштабирования и увеличенной толщины
|
49 |
+
for polyline in polylines_list:
|
50 |
+
d.line(polyline, fill=0, width=int(1.5 * upscale_factor)) # Линии черного цвета
|
51 |
+
|
52 |
+
# Масштабирование изображения обратно до 64x64 с использованием LANCZOS для сглаживания
|
53 |
+
pil_img = pil_img.resize((original_size, original_size), Image.Resampling.LANCZOS)
|
54 |
+
|
55 |
+
return pil_img
|
56 |
+
|
57 |
+
def __len__(self):
|
58 |
+
# Возвращает количество изображений в наборе данных
|
59 |
+
return len(self.images)
|
60 |
+
|
61 |
+
def __getitem__(self, idx):
|
62 |
+
# Возвращает обработанное изображение по индексу
|
63 |
+
return self.processed_images[idx]
|
QuckDrawGAN/utils/models.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
latent_dim = 100 # Размерность латентного пространства (входного шума для генератора)
|
4 |
+
|
5 |
+
# Класс генератора для генерации изображений из латентного пространства
|
6 |
+
class Generator(nn.Module):
|
7 |
+
def __init__(self, latent_dim):
|
8 |
+
super(Generator, self).__init__()
|
9 |
+
|
10 |
+
# Размеры начального изображения и количество каналов для начала транспонированных операций
|
11 |
+
self.init_size = 4 # Размер изображения после первой линейной трансформации
|
12 |
+
self.start_channels = 512 # Количество каналов на первом этапе
|
13 |
+
|
14 |
+
# Последовательная модель генератора
|
15 |
+
self.model = nn.Sequential(
|
16 |
+
# Линейное преобразование латентного вектора в развернутую форму для дальнейшего увеличения
|
17 |
+
nn.Linear(latent_dim, self.start_channels * self.init_size ** 2),
|
18 |
+
nn.BatchNorm1d(self.start_channels * self.init_size ** 2, 0.8),
|
19 |
+
nn.LeakyReLU(0.2, inplace=True),
|
20 |
+
|
21 |
+
# Преобразование в 4D тензор для начала операций с изображениями
|
22 |
+
nn.Unflatten(1, (self.start_channels, self.init_size, self.init_size)),
|
23 |
+
|
24 |
+
# Начало операций с изображением (увеличение размера)
|
25 |
+
nn.Upsample(scale_factor=2), # Увеличение размера изображения в 2 раза
|
26 |
+
|
27 |
+
# Сверточные слои с уменьшением количества каналов и последующими нелинейностями
|
28 |
+
nn.Conv2d(self.start_channels, self.start_channels // 3, 3, stride=1, padding=1),
|
29 |
+
nn.BatchNorm2d(self.start_channels // 3, 0.8),
|
30 |
+
nn.LeakyReLU(0.2, inplace=True),
|
31 |
+
|
32 |
+
nn.Upsample(scale_factor=2), # Еще одно увеличение
|
33 |
+
|
34 |
+
nn.Conv2d(self.start_channels // 3, self.start_channels // 4, 3, stride=1, padding=1),
|
35 |
+
nn.BatchNorm2d(self.start_channels // 4, 0.8),
|
36 |
+
nn.LeakyReLU(0.2, inplace=True),
|
37 |
+
|
38 |
+
nn.Upsample(scale_factor=2), # Третье увеличение
|
39 |
+
|
40 |
+
nn.Conv2d(self.start_channels // 4, self.start_channels // 6, 3, stride=1, padding=1),
|
41 |
+
nn.BatchNorm2d(self.start_channels // 6, 0.8),
|
42 |
+
nn.LeakyReLU(0.2, inplace=True),
|
43 |
+
|
44 |
+
nn.Upsample(scale_factor=2), # Четвертое увеличение
|
45 |
+
|
46 |
+
nn.Conv2d(self.start_channels // 6, self.start_channels // 8, 3, stride=1, padding=1),
|
47 |
+
nn.BatchNorm2d(self.start_channels // 8, 0.8),
|
48 |
+
nn.LeakyReLU(0.2, inplace=True),
|
49 |
+
|
50 |
+
# Последний сверточный слой для вывода изображения размером 1xWxH с функцией активации Тангенс
|
51 |
+
nn.Conv2d(self.start_channels // 8, 1, 3, stride=1, padding=1),
|
52 |
+
nn.Tanh() # Приведение значений пикселей в диапазон [-1, 1]
|
53 |
+
)
|
54 |
+
|
55 |
+
def forward(self, z):
|
56 |
+
# Прямое распространение через сеть генератора
|
57 |
+
out = self.model(z)
|
58 |
+
return out # Возвращаем сгенерированное изображение
|
59 |
+
|
60 |
+
|
61 |
+
# Класс дискриминатора для различения реальных и сгенерированных изображений
|
62 |
+
class Discriminator(nn.Module):
|
63 |
+
def __init__(self):
|
64 |
+
super(Discriminator, self).__init__()
|
65 |
+
|
66 |
+
# Последовательная модель дискриминатора
|
67 |
+
self.model = nn.Sequential(
|
68 |
+
# Первый сверточный блок
|
69 |
+
nn.Conv2d(1, 64, 3, stride=2, padding=1), # Уменьшает размер изображения до (64, 32, 32)
|
70 |
+
nn.BatchNorm2d(64, 0.8),
|
71 |
+
nn.LeakyReLU(0.2, inplace=True),
|
72 |
+
nn.Dropout(0.3), # Вероятность выключения нейронов для регуляризации
|
73 |
+
|
74 |
+
# Второй сверточный блок
|
75 |
+
nn.Conv2d(64, 128, 3, stride=2, padding=1), # Уменьшает размер до (128, 16, 16)
|
76 |
+
nn.BatchNorm2d(128, 0.8),
|
77 |
+
nn.LeakyReLU(0.2, inplace=True),
|
78 |
+
nn.Dropout(0.3),
|
79 |
+
|
80 |
+
# Третий сверточный блок
|
81 |
+
nn.Conv2d(128, 256, 3, stride=2, padding=1), # Уменьшает размер до (256, 8, 8)
|
82 |
+
nn.BatchNorm2d(256, 0.8),
|
83 |
+
nn.LeakyReLU(0.2, inplace=True),
|
84 |
+
nn.Dropout(0.3),
|
85 |
+
|
86 |
+
# Четвертый сверточный блок
|
87 |
+
nn.Conv2d(256, 256, 3, stride=1, padding=1), # Поддерживает размер (256, 8, 8)
|
88 |
+
nn.BatchNorm2d(256, 0.8),
|
89 |
+
nn.LeakyReLU(0.2, inplace=True),
|
90 |
+
nn.MaxPool2d(2), # Уменьшает размер до (256, 4, 4)
|
91 |
+
nn.Dropout(0.3),
|
92 |
+
|
93 |
+
# Пятый сверточный блок
|
94 |
+
nn.Conv2d(256, 512, 3, stride=1, padding=1), # Уменьшает размер до (512, 2, 2)
|
95 |
+
nn.BatchNorm2d(512, 0.8),
|
96 |
+
nn.LeakyReLU(0.2, inplace=True),
|
97 |
+
nn.MaxPool2d(2), # Размер до (512, 1, 1)
|
98 |
+
|
99 |
+
# Преобразование в плоский вектор для классификации
|
100 |
+
nn.Flatten(), # Преобразует изображение в вектор (512 * 2 * 2 = 2048)
|
101 |
+
nn.Linear(512 * 2 * 2, 1) # Полносвязный слой для получения одного скалярного выхода
|
102 |
+
)
|
103 |
+
|
104 |
+
def forward(self, img):
|
105 |
+
# Прямое распространение через сеть дискриминатора
|
106 |
+
out = self.model(img)
|
107 |
+
return out # Возвращаем вероятность принадлежности к классу "реальное" или "сгенерированное"
|
app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import QuckDrawGAN as qd
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
# Загрузим модель
|
7 |
+
generator_file = r'pretrained_output/models/generator.pt'
|
8 |
+
discriminator_file = r'pretrained_output/models/discriminator_fine_tuned.pt'
|
9 |
+
|
10 |
+
# Создаем объект модели
|
11 |
+
model = qd.Model(generator_file, discriminator_file)
|
12 |
+
|
13 |
+
# Функция для генерации изображения с учетом сида, нормализацией и изменением размера
|
14 |
+
def generate_image(n_images=16, seed=""):
|
15 |
+
# Если сид не задан, не передаем его в модель
|
16 |
+
if seed == "":
|
17 |
+
seed = None
|
18 |
+
best_image = model.generate(n_images, seed) # Генерация с учетом сида
|
19 |
+
|
20 |
+
# Нормализация: находим минимум и максимум в изображении
|
21 |
+
best_image_min = np.min(best_image)
|
22 |
+
best_image_max = np.max(best_image)
|
23 |
+
|
24 |
+
# Нормализуем изображение, чтобы значения были в диапазоне от 0 до 255
|
25 |
+
normalized_image = 255 * (best_image - best_image_min) / (best_image_max - best_image_min)
|
26 |
+
|
27 |
+
# Преобразуем изображение в формат, подходящий для отображения
|
28 |
+
pil_image = Image.fromarray(normalized_image.astype(np.uint8)) # Преобразуем в uint8 для отображения
|
29 |
+
pil_image = pil_image.resize((256, 256), Image.Resampling.LANCZOS) # Ресайз изображения до 256x256
|
30 |
+
|
31 |
+
return pil_image
|
32 |
+
|
33 |
+
# Создаем интерфейс Gradio с помощью Blocks (для большей гибкости)
|
34 |
+
with gr.Blocks() as interface:
|
35 |
+
gr.Markdown("# Генератор изображений с использованием QuckDrawGAN")
|
36 |
+
gr.Markdown("Этот интерфейс позволяет генерировать изображения с помощью модели QuckDrawGAN. Настройте количество генерируемых изображений и задайте сид для повторяемости.")
|
37 |
+
|
38 |
+
with gr.Row():
|
39 |
+
# Блок для изображения сверху
|
40 |
+
generated_image = gr.Image(type="pil", label="Сгенерированное изображение", elem_id="generated_image", scale=2) # Увеличиваем масштаб изображения
|
41 |
+
|
42 |
+
with gr.Row():
|
43 |
+
# Блок параметров и кнопки
|
44 |
+
with gr.Column():
|
45 |
+
seed_input = gr.Textbox(value="", label="Сид (опционально)", interactive=True)
|
46 |
+
num_images = gr.Slider(minimum=1, maximum=1024, value=16, label="Количество изображений для генерации", interactive=True, step=1)
|
47 |
+
|
48 |
+
# Кнопка генерации изображения справа
|
49 |
+
generate_button = gr.Button("Сгенерировать")
|
50 |
+
|
51 |
+
# Логика для автогенерации при изменении параметров
|
52 |
+
seed_input.change(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
|
53 |
+
num_images.change(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
|
54 |
+
|
55 |
+
# Логика для кнопки генерации
|
56 |
+
generate_button.click(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
|
57 |
+
|
58 |
+
# Автогенерация при старте
|
59 |
+
interface.load(generate_image, inputs=[num_images, seed_input], outputs=generated_image)
|
60 |
+
|
61 |
+
# Стилизация блока изображения (увеличение размера блока)
|
62 |
+
interface.css = """
|
63 |
+
#generated_image {
|
64 |
+
width: 400px;
|
65 |
+
height: 400px;
|
66 |
+
margin-top: 20px;
|
67 |
+
}
|
68 |
+
"""
|
69 |
+
|
70 |
+
# Запуск интерфейса
|
71 |
+
interface.launch()
|
pretrained_output/images/1.png
ADDED
![]() |
pretrained_output/images/10.png
ADDED
![]() |
pretrained_output/images/100.png
ADDED
![]() |
pretrained_output/images/11.png
ADDED
![]() |
pretrained_output/images/12.png
ADDED
![]() |
pretrained_output/images/13.png
ADDED
![]() |
pretrained_output/images/14.png
ADDED
![]() |
pretrained_output/images/15.png
ADDED
![]() |
pretrained_output/images/16.png
ADDED
![]() |
pretrained_output/images/17.png
ADDED
![]() |
pretrained_output/images/18.png
ADDED
![]() |
pretrained_output/images/19.png
ADDED
![]() |
pretrained_output/images/2.png
ADDED
![]() |
pretrained_output/images/20.png
ADDED
![]() |
pretrained_output/images/21.png
ADDED
![]() |
pretrained_output/images/22.png
ADDED
![]() |
pretrained_output/images/23.png
ADDED
![]() |
pretrained_output/images/24.png
ADDED
![]() |
pretrained_output/images/25.png
ADDED
![]() |
pretrained_output/images/26.png
ADDED
![]() |
pretrained_output/images/27.png
ADDED
![]() |
pretrained_output/images/28.png
ADDED
![]() |
pretrained_output/images/29.png
ADDED
![]() |
pretrained_output/images/3.png
ADDED
![]() |
pretrained_output/images/30.png
ADDED
![]() |
pretrained_output/images/31.png
ADDED
![]() |
pretrained_output/images/32.png
ADDED
![]() |
pretrained_output/images/33.png
ADDED
![]() |
pretrained_output/images/34.png
ADDED
![]() |
pretrained_output/images/35.png
ADDED
![]() |
pretrained_output/images/36.png
ADDED
![]() |
pretrained_output/images/37.png
ADDED
![]() |
pretrained_output/images/38.png
ADDED
![]() |
pretrained_output/images/39.png
ADDED
![]() |
pretrained_output/images/4.png
ADDED
![]() |
pretrained_output/images/40.png
ADDED
![]() |
pretrained_output/images/41.png
ADDED
![]() |
pretrained_output/images/42.png
ADDED
![]() |